Commit b016eba4 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: fix Transport data race, double cancel panic, cancel error message

Fixes #9496
Fixes #9946
Fixes #10474
Fixes #10405

Change-Id: I4e65f1706e46499811d9ebf4ad6d83a5dfb2ddaa
Reviewed-on: https://go-review.googlesource.com/8550Reviewed-by: 's avatarDaniel Morsing <daniel.morsing@gmail.com>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
parent 35b1dcc2
...@@ -334,6 +334,7 @@ var echoCookiesRedirectHandler = HandlerFunc(func(w ResponseWriter, r *Request) ...@@ -334,6 +334,7 @@ var echoCookiesRedirectHandler = HandlerFunc(func(w ResponseWriter, r *Request)
}) })
func TestClientSendsCookieFromJar(t *testing.T) { func TestClientSendsCookieFromJar(t *testing.T) {
defer afterTest(t)
tr := &recordingTransport{} tr := &recordingTransport{}
client := &Client{Transport: tr} client := &Client{Transport: tr}
client.Jar = &TestJar{perURL: make(map[string][]*Cookie)} client.Jar = &TestJar{perURL: make(map[string][]*Cookie)}
......
...@@ -110,3 +110,5 @@ func SetPendingDialHooks(before, after func()) { ...@@ -110,3 +110,5 @@ func SetPendingDialHooks(before, after func()) {
var ExportServerNewConn = (*Server).newConn var ExportServerNewConn = (*Server).newConn
var ExportCloseWriteAndWait = (*conn).closeWriteAndWait var ExportCloseWriteAndWait = (*conn).closeWriteAndWait
var ExportErrRequestCanceled = errRequestCanceled
...@@ -56,17 +56,21 @@ func goroutineLeaked() bool { ...@@ -56,17 +56,21 @@ func goroutineLeaked() bool {
// not counting goroutines for leakage in -short mode // not counting goroutines for leakage in -short mode
return false return false
} }
gs := interestingGoroutines()
n := 0 var stackCount map[string]int
stackCount := make(map[string]int) for i := 0; i < 5; i++ {
for _, g := range gs { n := 0
stackCount[g]++ stackCount = make(map[string]int)
n++ gs := interestingGoroutines()
} for _, g := range gs {
stackCount[g]++
if n == 0 { n++
return false }
if n == 0 {
return false
}
// Wait for goroutines to schedule and die off:
time.Sleep(100 * time.Millisecond)
} }
fmt.Fprintf(os.Stderr, "Too many goroutines running after net/http test(s).\n") fmt.Fprintf(os.Stderr, "Too many goroutines running after net/http test(s).\n")
for stack, count := range stackCount { for stack, count := range stackCount {
......
...@@ -6,6 +6,7 @@ package http_test ...@@ -6,6 +6,7 @@ package http_test
import ( import (
"bufio" "bufio"
"bytes"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"io" "io"
...@@ -17,6 +18,7 @@ import ( ...@@ -17,6 +18,7 @@ import (
) )
func TestNextProtoUpgrade(t *testing.T) { func TestNextProtoUpgrade(t *testing.T) {
defer afterTest(t)
ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
fmt.Fprintf(w, "path=%s,proto=", r.URL.Path) fmt.Fprintf(w, "path=%s,proto=", r.URL.Path)
if r.TLS != nil { if r.TLS != nil {
...@@ -38,12 +40,12 @@ func TestNextProtoUpgrade(t *testing.T) { ...@@ -38,12 +40,12 @@ func TestNextProtoUpgrade(t *testing.T) {
ts.StartTLS() ts.StartTLS()
defer ts.Close() defer ts.Close()
tr := newTLSTransport(t, ts)
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
// Normal request, without NPN. // Normal request, without NPN.
{ {
tr := newTLSTransport(t, ts)
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
res, err := c.Get(ts.URL) res, err := c.Get(ts.URL)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -60,11 +62,17 @@ func TestNextProtoUpgrade(t *testing.T) { ...@@ -60,11 +62,17 @@ func TestNextProtoUpgrade(t *testing.T) {
// Request to an advertised but unhandled NPN protocol. // Request to an advertised but unhandled NPN protocol.
// Server will hang up. // Server will hang up.
{ {
tr.CloseIdleConnections() tr := newTLSTransport(t, ts)
tr.TLSClientConfig.NextProtos = []string{"unhandled-proto"} tr.TLSClientConfig.NextProtos = []string{"unhandled-proto"}
_, err := c.Get(ts.URL) defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
res, err := c.Get(ts.URL)
if err == nil { if err == nil {
t.Errorf("expected error on unhandled-proto request") defer res.Body.Close()
var buf bytes.Buffer
res.Write(&buf)
t.Errorf("expected error on unhandled-proto request; got: %s", buf.Bytes())
} }
} }
......
...@@ -178,6 +178,7 @@ func TestParseMultipartForm(t *testing.T) { ...@@ -178,6 +178,7 @@ func TestParseMultipartForm(t *testing.T) {
} }
func TestRedirect(t *testing.T) { func TestRedirect(t *testing.T) {
defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
switch r.URL.Path { switch r.URL.Path {
case "/": case "/":
......
...@@ -146,6 +146,7 @@ func (ht handlerTest) rawResponse(req string) string { ...@@ -146,6 +146,7 @@ func (ht handlerTest) rawResponse(req string) string {
} }
func TestConsumingBodyOnNextConn(t *testing.T) { func TestConsumingBodyOnNextConn(t *testing.T) {
defer afterTest(t)
conn := new(testConn) conn := new(testConn)
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
conn.readBuf.Write([]byte( conn.readBuf.Write([]byte(
......
...@@ -279,6 +279,7 @@ func (t *Transport) CloseIdleConnections() { ...@@ -279,6 +279,7 @@ func (t *Transport) CloseIdleConnections() {
func (t *Transport) CancelRequest(req *Request) { func (t *Transport) CancelRequest(req *Request) {
t.reqMu.Lock() t.reqMu.Lock()
cancel := t.reqCanceler[req] cancel := t.reqCanceler[req]
delete(t.reqCanceler, req)
t.reqMu.Unlock() t.reqMu.Unlock()
if cancel != nil { if cancel != nil {
cancel() cancel()
...@@ -805,6 +806,7 @@ type persistConn struct { ...@@ -805,6 +806,7 @@ type persistConn struct {
numExpectedResponses int numExpectedResponses int
closed bool // whether conn has been closed closed bool // whether conn has been closed
broken bool // an error has happened on this connection; marked broken so it's not reused. broken bool // an error has happened on this connection; marked broken so it's not reused.
canceled bool // whether this conn was broken due a CancelRequest
// mutateHeaderFunc is an optional func to modify extra // mutateHeaderFunc is an optional func to modify extra
// headers on each outbound request before it's written. (the // headers on each outbound request before it's written. (the
// original Request given to RoundTrip is not modified) // original Request given to RoundTrip is not modified)
...@@ -819,8 +821,18 @@ func (pc *persistConn) isBroken() bool { ...@@ -819,8 +821,18 @@ func (pc *persistConn) isBroken() bool {
return b return b
} }
// isCanceled reports whether this connection was closed due to CancelRequest.
func (pc *persistConn) isCanceled() bool {
pc.lk.Lock()
defer pc.lk.Unlock()
return pc.canceled
}
func (pc *persistConn) cancelRequest() { func (pc *persistConn) cancelRequest() {
pc.conn.Close() pc.lk.Lock()
defer pc.lk.Unlock()
pc.canceled = true
pc.closeLocked()
} }
var remoteSideClosedFunc func(error) bool // or nil to use default var remoteSideClosedFunc func(error) bool // or nil to use default
...@@ -836,8 +848,13 @@ func remoteSideClosed(err error) bool { ...@@ -836,8 +848,13 @@ func remoteSideClosed(err error) bool {
} }
func (pc *persistConn) readLoop() { func (pc *persistConn) readLoop() {
alive := true // eofc is used to block http.Handler goroutines reading from Response.Body
// at EOF until this goroutines has (potentially) added the connection
// back to the idle pool.
eofc := make(chan struct{})
defer close(eofc) // unblock reader on errors
alive := true
for alive { for alive {
pb, err := pc.br.Peek(1) pb, err := pc.br.Peek(1)
...@@ -895,22 +912,22 @@ func (pc *persistConn) readLoop() { ...@@ -895,22 +912,22 @@ func (pc *persistConn) readLoop() {
alive = false alive = false
} }
var waitForBodyRead chan bool var waitForBodyRead chan bool // channel is nil when there's no body
if hasBody { if hasBody {
waitForBodyRead = make(chan bool, 2) waitForBodyRead = make(chan bool, 2)
resp.Body.(*bodyEOFSignal).earlyCloseFn = func() error { resp.Body.(*bodyEOFSignal).earlyCloseFn = func() error {
// Sending false here sets alive to
// false and closes the connection
// below.
waitForBodyRead <- false waitForBodyRead <- false
return nil return nil
} }
resp.Body.(*bodyEOFSignal).fn = func(err error) { resp.Body.(*bodyEOFSignal).fn = func(err error) error {
waitForBodyRead <- alive && isEOF := err == io.EOF
err == nil && waitForBodyRead <- isEOF
!pc.sawEOF && if isEOF {
pc.wroteRequest() && <-eofc // see comment at top
pc.t.putIdleConn(pc) } else if err != nil && pc.isCanceled() {
return errRequestCanceled
}
return err
} }
} }
...@@ -924,28 +941,33 @@ func (pc *persistConn) readLoop() { ...@@ -924,28 +941,33 @@ func (pc *persistConn) readLoop() {
// on the response channel before erroring out. // on the response channel before erroring out.
rc.ch <- responseAndError{resp, err} rc.ch <- responseAndError{resp, err}
if alive && !hasBody { if hasBody {
alive = !pc.sawEOF && // To avoid a race, wait for the just-returned
pc.wroteRequest() && // response body to be fully consumed before peek on
pc.t.putIdleConn(pc) // the underlying bufio reader.
}
// Wait for the just-returned response body to be fully consumed
// before we race and peek on the underlying bufio reader.
if waitForBodyRead != nil {
select { select {
case alive = <-waitForBodyRead: case bodyEOF := <-waitForBodyRead:
pc.t.setReqCanceler(rc.req, nil) // before pc might return to idle pool
alive = alive &&
bodyEOF &&
!pc.sawEOF &&
pc.wroteRequest() &&
pc.t.putIdleConn(pc)
if bodyEOF {
eofc <- struct{}{}
}
case <-pc.closech: case <-pc.closech:
alive = false alive = false
} }
} } else {
pc.t.setReqCanceler(rc.req, nil) // before pc might return to idle pool
pc.t.setReqCanceler(rc.req, nil) alive = alive &&
!pc.sawEOF &&
if !alive { pc.wroteRequest() &&
pc.close() pc.t.putIdleConn(pc)
} }
} }
pc.close()
} }
func (pc *persistConn) writeLoop() { func (pc *persistConn) writeLoop() {
...@@ -1035,6 +1057,7 @@ func (e *httpError) Temporary() bool { return true } ...@@ -1035,6 +1057,7 @@ func (e *httpError) Temporary() bool { return true }
var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true} var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true}
var errClosed error = &httpError{err: "net/http: transport closed before response was received"} var errClosed error = &httpError{err: "net/http: transport closed before response was received"}
var errRequestCanceled = errors.New("net/http: request canceled")
var testHookPersistConnClosedGotRes func() // nil except for tests var testHookPersistConnClosedGotRes func() // nil except for tests
...@@ -1183,16 +1206,18 @@ func canonicalAddr(url *url.URL) string { ...@@ -1183,16 +1206,18 @@ func canonicalAddr(url *url.URL) string {
// bodyEOFSignal wraps a ReadCloser but runs fn (if non-nil) at most // bodyEOFSignal wraps a ReadCloser but runs fn (if non-nil) at most
// once, right before its final (error-producing) Read or Close call // once, right before its final (error-producing) Read or Close call
// returns. If earlyCloseFn is non-nil and Close is called before // returns. fn should return the new error to return from Read or Close.
// io.EOF is seen, earlyCloseFn is called instead of fn, and its //
// return value is the return value from Close. // If earlyCloseFn is non-nil and Close is called before io.EOF is
// seen, earlyCloseFn is called instead of fn, and its return value is
// the return value from Close.
type bodyEOFSignal struct { type bodyEOFSignal struct {
body io.ReadCloser body io.ReadCloser
mu sync.Mutex // guards following 4 fields mu sync.Mutex // guards following 4 fields
closed bool // whether Close has been called closed bool // whether Close has been called
rerr error // sticky Read error rerr error // sticky Read error
fn func(error) // error will be nil on Read io.EOF fn func(error) error // err will be nil on Read io.EOF
earlyCloseFn func() error // optional alt Close func used if io.EOF not seen earlyCloseFn func() error // optional alt Close func used if io.EOF not seen
} }
func (es *bodyEOFSignal) Read(p []byte) (n int, err error) { func (es *bodyEOFSignal) Read(p []byte) (n int, err error) {
...@@ -1213,7 +1238,7 @@ func (es *bodyEOFSignal) Read(p []byte) (n int, err error) { ...@@ -1213,7 +1238,7 @@ func (es *bodyEOFSignal) Read(p []byte) (n int, err error) {
if es.rerr == nil { if es.rerr == nil {
es.rerr = err es.rerr = err
} }
es.condfn(err) err = es.condfn(err)
} }
return return
} }
...@@ -1229,20 +1254,17 @@ func (es *bodyEOFSignal) Close() error { ...@@ -1229,20 +1254,17 @@ func (es *bodyEOFSignal) Close() error {
return es.earlyCloseFn() return es.earlyCloseFn()
} }
err := es.body.Close() err := es.body.Close()
es.condfn(err) return es.condfn(err)
return err
} }
// caller must hold es.mu. // caller must hold es.mu.
func (es *bodyEOFSignal) condfn(err error) { func (es *bodyEOFSignal) condfn(err error) error {
if es.fn == nil { if es.fn == nil {
return return err
}
if err == io.EOF {
err = nil
} }
es.fn(err) err = es.fn(err)
es.fn = nil es.fn = nil
return err
} }
// gzipReader wraps a response body so it can lazily // gzipReader wraps a response body so it can lazily
......
...@@ -505,12 +505,17 @@ func TestStressSurpriseServerCloses(t *testing.T) { ...@@ -505,12 +505,17 @@ func TestStressSurpriseServerCloses(t *testing.T) {
tr := &Transport{DisableKeepAlives: false} tr := &Transport{DisableKeepAlives: false}
c := &Client{Transport: tr} c := &Client{Transport: tr}
defer tr.CloseIdleConnections()
// Do a bunch of traffic from different goroutines. Send to activityc // Do a bunch of traffic from different goroutines. Send to activityc
// after each request completes, regardless of whether it failed. // after each request completes, regardless of whether it failed.
// If these are too high, OS X exhausts its emphemeral ports
// and hangs waiting for them to transition TCP states. That's
// not what we want to test. TODO(bradfitz): use an io.Pipe
// dialer for this test instead?
const ( const (
numClients = 50 numClients = 20
reqsPerClient = 250 reqsPerClient = 25
) )
activityc := make(chan bool) activityc := make(chan bool)
for i := 0; i < numClients; i++ { for i := 0; i < numClients; i++ {
...@@ -1371,8 +1376,8 @@ func TestTransportCancelRequest(t *testing.T) { ...@@ -1371,8 +1376,8 @@ func TestTransportCancelRequest(t *testing.T) {
body, err := ioutil.ReadAll(res.Body) body, err := ioutil.ReadAll(res.Body)
d := time.Since(t0) d := time.Since(t0)
if err == nil { if err != ExportErrRequestCanceled {
t.Error("expected an error reading the body") t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
} }
if string(body) != "Hello" { if string(body) != "Hello" {
t.Errorf("Body = %q; want Hello", body) t.Errorf("Body = %q; want Hello", body)
...@@ -1382,7 +1387,7 @@ func TestTransportCancelRequest(t *testing.T) { ...@@ -1382,7 +1387,7 @@ func TestTransportCancelRequest(t *testing.T) {
} }
// Verify no outstanding requests after readLoop/writeLoop // Verify no outstanding requests after readLoop/writeLoop
// goroutines shut down. // goroutines shut down.
for tries := 3; tries > 0; tries-- { for tries := 5; tries > 0; tries-- {
n := tr.NumPendingRequestsForTesting() n := tr.NumPendingRequestsForTesting()
if n == 0 { if n == 0 {
break break
...@@ -1431,6 +1436,7 @@ func TestTransportCancelRequestInDial(t *testing.T) { ...@@ -1431,6 +1436,7 @@ func TestTransportCancelRequestInDial(t *testing.T) {
eventLog.Printf("canceling") eventLog.Printf("canceling")
tr.CancelRequest(req) tr.CancelRequest(req)
tr.CancelRequest(req) // used to panic on second call
select { select {
case <-gotres: case <-gotres:
...@@ -2321,6 +2327,47 @@ func TestTransportResponseCloseRace(t *testing.T) { ...@@ -2321,6 +2327,47 @@ func TestTransportResponseCloseRace(t *testing.T) {
} }
} }
// Test for issue 10474
func TestTransportResponseCancelRace(t *testing.T) {
defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
// important that this response has a body.
var b [1024]byte
w.Write(b[:])
}))
defer ts.Close()
tr := &Transport{}
defer tr.CloseIdleConnections()
req, err := NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatal(err)
}
res, err := tr.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
// If we do an early close, Transport just throws the connection away and
// doesn't reuse it. In order to trigger the bug, it has to reuse the connection
// so read the body
if _, err := io.Copy(ioutil.Discard, res.Body); err != nil {
t.Fatal(err)
}
req2, err := NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatal(err)
}
tr.CancelRequest(req)
res, err = tr.RoundTrip(req2)
if err != nil {
t.Fatal(err)
}
res.Body.Close()
}
func wantBody(res *http.Response, err error, want string) error { func wantBody(res *http.Response, err error, want string) error {
if err != nil { if err != nil {
return err return err
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment