Commit d0a7d01f authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: fix client goroutine leak with persistent connections

Thanks to Sascha Matzke & Florian Weimer for diagnosing.

R=golang-dev, adg, bradfitz, kevlar
CC=golang-dev
https://golang.org/cl/5656046
parent 9df6fdcc
...@@ -235,15 +235,19 @@ func (cm *connectMethod) proxyAuth() string { ...@@ -235,15 +235,19 @@ func (cm *connectMethod) proxyAuth() string {
return "" return ""
} }
func (t *Transport) putIdleConn(pconn *persistConn) { // putIdleConn adds pconn to the list of idle persistent connections awaiting
// a new request.
// If pconn is no longer needed or not in a good state, putIdleConn
// returns false.
func (t *Transport) putIdleConn(pconn *persistConn) bool {
t.lk.Lock() t.lk.Lock()
defer t.lk.Unlock() defer t.lk.Unlock()
if t.DisableKeepAlives || t.MaxIdleConnsPerHost < 0 { if t.DisableKeepAlives || t.MaxIdleConnsPerHost < 0 {
pconn.close() pconn.close()
return return false
} }
if pconn.isBroken() { if pconn.isBroken() {
return return false
} }
key := pconn.cacheKey key := pconn.cacheKey
max := t.MaxIdleConnsPerHost max := t.MaxIdleConnsPerHost
...@@ -252,9 +256,10 @@ func (t *Transport) putIdleConn(pconn *persistConn) { ...@@ -252,9 +256,10 @@ func (t *Transport) putIdleConn(pconn *persistConn) {
} }
if len(t.idleConn[key]) >= max { if len(t.idleConn[key]) >= max {
pconn.close() pconn.close()
return return false
} }
t.idleConn[key] = append(t.idleConn[key], pconn) t.idleConn[key] = append(t.idleConn[key], pconn)
return true
} }
func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) {
...@@ -565,7 +570,9 @@ func (pc *persistConn) readLoop() { ...@@ -565,7 +570,9 @@ func (pc *persistConn) readLoop() {
lastbody = resp.Body lastbody = resp.Body
waitForBodyRead = make(chan bool) waitForBodyRead = make(chan bool)
resp.Body.(*bodyEOFSignal).fn = func() { resp.Body.(*bodyEOFSignal).fn = func() {
pc.t.putIdleConn(pc) if !pc.t.putIdleConn(pc) {
alive = false
}
waitForBodyRead <- true waitForBodyRead <- true
} }
} else { } else {
...@@ -578,7 +585,9 @@ func (pc *persistConn) readLoop() { ...@@ -578,7 +585,9 @@ func (pc *persistConn) readLoop() {
// read it (even though it'll just be 0, EOF). // read it (even though it'll just be 0, EOF).
lastbody = nil lastbody = nil
pc.t.putIdleConn(pc) if !pc.t.putIdleConn(pc) {
alive = false
}
} }
} }
......
...@@ -16,6 +16,7 @@ import ( ...@@ -16,6 +16,7 @@ import (
. "net/http" . "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"runtime"
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
...@@ -632,6 +633,66 @@ func TestTransportGzipRecursive(t *testing.T) { ...@@ -632,6 +633,66 @@ func TestTransportGzipRecursive(t *testing.T) {
} }
} }
// tests that persistent goroutine connections shut down when no longer desired.
func TestTransportPersistConnLeak(t *testing.T) {
gotReqCh := make(chan bool)
unblockCh := make(chan bool)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
gotReqCh <- true
<-unblockCh
w.Header().Set("Content-Length", "0")
w.WriteHeader(204)
}))
defer ts.Close()
tr := &Transport{}
c := &Client{Transport: tr}
n0 := runtime.Goroutines()
const numReq = 100
didReqCh := make(chan bool)
for i := 0; i < numReq; i++ {
go func() {
c.Get(ts.URL)
didReqCh <- true
}()
}
// Wait for all goroutines to be stuck in the Handler.
for i := 0; i < numReq; i++ {
<-gotReqCh
}
nhigh := runtime.Goroutines()
// Tell all handlers to unblock and reply.
for i := 0; i < numReq; i++ {
unblockCh <- true
}
// Wait for all HTTP clients to be done.
for i := 0; i < numReq; i++ {
<-didReqCh
}
time.Sleep(100 * time.Millisecond)
runtime.GC()
runtime.GC() // even more.
nfinal := runtime.Goroutines()
growth := nfinal - n0
// We expect 5 extra goroutines, empirically. That number is at least
// DefaultMaxIdleConnsPerHost * 2 (one reader goroutine, one writer),
// and something else.
expectedGoroutineGrowth := DefaultMaxIdleConnsPerHost*2 + 1
if int(growth) > expectedGoroutineGrowth*2 {
t.Errorf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
}
}
type fooProto struct{} type fooProto struct{}
func (fooProto) RoundTrip(req *Request) (*Response, error) { func (fooProto) RoundTrip(req *Request) (*Response, error) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment