Commit e4ed9494 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: fix response Connection: close, close client connections

Fixes #3663
Updates #3540 (fixes it more)
Updates #1967 (fixes it more, re-enables a test)

R=golang-dev, n13m3y3r
CC=golang-dev
https://golang.org/cl/6213064
parent 7482822b
...@@ -386,17 +386,18 @@ func testTCPConnectionCloses(t *testing.T, req string, h Handler) { ...@@ -386,17 +386,18 @@ func testTCPConnectionCloses(t *testing.T, req string, h Handler) {
} }
r := bufio.NewReader(conn) r := bufio.NewReader(conn)
_, err = ReadResponse(r, &Request{Method: "GET"}) res, err := ReadResponse(r, &Request{Method: "GET"})
if err != nil { if err != nil {
t.Fatal("ReadResponse error:", err) t.Fatal("ReadResponse error:", err)
} }
success := make(chan bool) didReadAll := make(chan bool, 1)
go func() { go func() {
select { select {
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
t.Fatal("body not closed after 5s") t.Error("body not closed after 5s")
case <-success: return
case <-didReadAll:
} }
}() }()
...@@ -404,8 +405,11 @@ func testTCPConnectionCloses(t *testing.T, req string, h Handler) { ...@@ -404,8 +405,11 @@ func testTCPConnectionCloses(t *testing.T, req string, h Handler) {
if err != nil { if err != nil {
t.Fatal("read error:", err) t.Fatal("read error:", err)
} }
didReadAll <- true
success <- true if !res.Close {
t.Errorf("Response.Close = false; want true")
}
} }
// TestServeHTTP10Close verifies that HTTP/1.0 requests won't be kept alive. // TestServeHTTP10Close verifies that HTTP/1.0 requests won't be kept alive.
......
...@@ -389,6 +389,11 @@ func (w *response) WriteHeader(code int) { ...@@ -389,6 +389,11 @@ func (w *response) WriteHeader(code int) {
if !w.req.ProtoAtLeast(1, 0) { if !w.req.ProtoAtLeast(1, 0) {
return return
} }
if w.closeAfterReply && !hasToken(w.header.Get("Connection"), "close") {
w.header.Set("Connection", "close")
}
proto := "HTTP/1.0" proto := "HTTP/1.0"
if w.req.ProtoAtLeast(1, 1) { if w.req.ProtoAtLeast(1, 1) {
proto = "HTTP/1.1" proto = "HTTP/1.1"
......
...@@ -480,6 +480,7 @@ type persistConn struct { ...@@ -480,6 +480,7 @@ type persistConn struct {
t *Transport t *Transport
cacheKey string // its connectMethod.String() cacheKey string // its connectMethod.String()
conn net.Conn conn net.Conn
closed bool // whether conn has been closed
br *bufio.Reader // from conn br *bufio.Reader // from conn
bw *bufio.Writer // to conn bw *bufio.Writer // to conn
reqch chan requestAndChan // written by roundTrip(); read by readLoop() reqch chan requestAndChan // written by roundTrip(); read by readLoop()
...@@ -574,6 +575,9 @@ func (pc *persistConn) readLoop() { ...@@ -574,6 +575,9 @@ func (pc *persistConn) readLoop() {
if alive && !pc.t.putIdleConn(pc) { if alive && !pc.t.putIdleConn(pc) {
alive = false alive = false
} }
if !alive {
pc.close()
}
waitForBodyRead <- true waitForBodyRead <- true
} }
} }
...@@ -669,7 +673,10 @@ func (pc *persistConn) close() { ...@@ -669,7 +673,10 @@ func (pc *persistConn) close() {
func (pc *persistConn) closeLocked() { func (pc *persistConn) closeLocked() {
pc.broken = true pc.broken = true
pc.conn.Close() if !pc.closed {
pc.conn.Close()
pc.closed = true
}
pc.mutateHeaderFunc = nil pc.mutateHeaderFunc = nil
} }
......
...@@ -37,17 +37,21 @@ var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) { ...@@ -37,17 +37,21 @@ var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
w.Write([]byte(r.RemoteAddr)) w.Write([]byte(r.RemoteAddr))
}) })
// testCloseConn is a net.Conn tracked by a testConnSet.
type testCloseConn struct { type testCloseConn struct {
net.Conn net.Conn
set *testConnSet set *testConnSet
} }
func (conn *testCloseConn) Close() error { func (c *testCloseConn) Close() error {
conn.set.remove(conn) c.set.remove(c)
return conn.Conn.Close() return c.Conn.Close()
} }
// testConnSet tracks a set of TCP connections and whether they've
// been closed.
type testConnSet struct { type testConnSet struct {
t *testing.T
closed map[net.Conn]bool closed map[net.Conn]bool
list []net.Conn // in order created list []net.Conn // in order created
mutex sync.Mutex mutex sync.Mutex
...@@ -67,8 +71,9 @@ func (tcs *testConnSet) remove(c net.Conn) { ...@@ -67,8 +71,9 @@ func (tcs *testConnSet) remove(c net.Conn) {
} }
// some tests use this to manage raw tcp connections for later inspection // some tests use this to manage raw tcp connections for later inspection
func makeTestDial() (*testConnSet, func(n, addr string) (net.Conn, error)) { func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
connSet := &testConnSet{ connSet := &testConnSet{
t: t,
closed: make(map[net.Conn]bool), closed: make(map[net.Conn]bool),
} }
dial := func(n, addr string) (net.Conn, error) { dial := func(n, addr string) (net.Conn, error) {
...@@ -89,10 +94,7 @@ func (tcs *testConnSet) check(t *testing.T) { ...@@ -89,10 +94,7 @@ func (tcs *testConnSet) check(t *testing.T) {
for i, c := range tcs.list { for i, c := range tcs.list {
if !tcs.closed[c] { if !tcs.closed[c] {
// TODO(bradfitz,gustavo): make the following t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
// line an Errorf, not Logf, once issue 3540
// is fixed again.
t.Logf("TCP connection #%d (of %d total) was not closed", i+1, len(tcs.list))
} }
} }
} }
...@@ -134,7 +136,7 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) { ...@@ -134,7 +136,7 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) {
ts := httptest.NewServer(hostPortHandler) ts := httptest.NewServer(hostPortHandler)
defer ts.Close() defer ts.Close()
connSet, testDial := makeTestDial() connSet, testDial := makeTestDial(t)
for _, connectionClose := range []bool{false, true} { for _, connectionClose := range []bool{false, true} {
tr := &Transport{ tr := &Transport{
...@@ -184,7 +186,7 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) { ...@@ -184,7 +186,7 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) {
ts := httptest.NewServer(hostPortHandler) ts := httptest.NewServer(hostPortHandler)
defer ts.Close() defer ts.Close()
connSet, testDial := makeTestDial() connSet, testDial := makeTestDial(t)
for _, connectionClose := range []bool{false, true} { for _, connectionClose := range []bool{false, true} {
tr := &Transport{ tr := &Transport{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment