Commit a8d90ec3 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: close Body in client code always, even on errors, and document

Fixes #6981

LGTM=rsc
R=golang-codereviews, nightlyone
CC=adg, dsymonds, golang-codereviews, rsc
https://golang.org/cl/85560045
parent 0d441a08
...@@ -91,8 +91,9 @@ type RoundTripper interface { ...@@ -91,8 +91,9 @@ type RoundTripper interface {
// authentication, or cookies. // authentication, or cookies.
// //
// RoundTrip should not modify the request, except for // RoundTrip should not modify the request, except for
// consuming and closing the Body. The request's URL and // consuming and closing the Body, including on errors. The
// Header fields are guaranteed to be initialized. // request's URL and Header fields are guaranteed to be
// initialized.
RoundTrip(*Request) (*Response, error) RoundTrip(*Request) (*Response, error)
} }
...@@ -140,6 +141,9 @@ func (c *Client) send(req *Request) (*Response, error) { ...@@ -140,6 +141,9 @@ func (c *Client) send(req *Request) (*Response, error) {
// (typically Transport) may not be able to re-use a persistent TCP // (typically Transport) may not be able to re-use a persistent TCP
// connection to the server for a subsequent "keep-alive" request. // connection to the server for a subsequent "keep-alive" request.
// //
// The request Body, if non-nil, will be closed by the underlying
// Transport, even on errors.
//
// Generally Get, Post, or PostForm will be used instead of Do. // Generally Get, Post, or PostForm will be used instead of Do.
func (c *Client) Do(req *Request) (resp *Response, err error) { func (c *Client) Do(req *Request) (resp *Response, err error) {
if req.Method == "GET" || req.Method == "HEAD" { if req.Method == "GET" || req.Method == "HEAD" {
...@@ -162,14 +166,17 @@ func (c *Client) transport() RoundTripper { ...@@ -162,14 +166,17 @@ func (c *Client) transport() RoundTripper {
// Caller should close resp.Body when done reading from it. // Caller should close resp.Body when done reading from it.
func send(req *Request, t RoundTripper) (resp *Response, err error) { func send(req *Request, t RoundTripper) (resp *Response, err error) {
if t == nil { if t == nil {
req.closeBody()
return nil, errors.New("http: no Client.Transport or DefaultTransport") return nil, errors.New("http: no Client.Transport or DefaultTransport")
} }
if req.URL == nil { if req.URL == nil {
req.closeBody()
return nil, errors.New("http: nil Request.URL") return nil, errors.New("http: nil Request.URL")
} }
if req.RequestURI != "" { if req.RequestURI != "" {
req.closeBody()
return nil, errors.New("http: Request.RequestURI can't be set in client requests.") return nil, errors.New("http: Request.RequestURI can't be set in client requests.")
} }
...@@ -277,6 +284,7 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo ...@@ -277,6 +284,7 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
var via []*Request var via []*Request
if ireq.URL == nil { if ireq.URL == nil {
ireq.closeBody()
return nil, errors.New("http: nil Request.URL") return nil, errors.New("http: nil Request.URL")
} }
...@@ -399,7 +407,7 @@ func Post(url string, bodyType string, body io.Reader) (resp *Response, err erro ...@@ -399,7 +407,7 @@ func Post(url string, bodyType string, body io.Reader) (resp *Response, err erro
// Caller should close resp.Body when done reading from it. // Caller should close resp.Body when done reading from it.
// //
// If the provided body is also an io.Closer, it is closed after the // If the provided body is also an io.Closer, it is closed after the
// body is successfully written to the server. // request.
func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *Response, err error) { func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *Response, err error) {
req, err := NewRequest("POST", url, body) req, err := NewRequest("POST", url, body)
if err != nil { if err != nil {
......
...@@ -867,3 +867,9 @@ func (r *Request) wantsHttp10KeepAlive() bool { ...@@ -867,3 +867,9 @@ func (r *Request) wantsHttp10KeepAlive() bool {
func (r *Request) wantsClose() bool { func (r *Request) wantsClose() bool {
return hasToken(r.Header.get("Connection"), "close") return hasToken(r.Header.get("Connection"), "close")
} }
func (r *Request) closeBody() {
if r.Body != nil {
r.Body.Close()
}
}
...@@ -160,9 +160,11 @@ func (tr *transportRequest) extraHeaders() Header { ...@@ -160,9 +160,11 @@ func (tr *transportRequest) extraHeaders() Header {
// and redirects), see Get, Post, and the Client type. // and redirects), see Get, Post, and the Client type.
func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) { func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) {
if req.URL == nil { if req.URL == nil {
req.closeBody()
return nil, errors.New("http: nil Request.URL") return nil, errors.New("http: nil Request.URL")
} }
if req.Header == nil { if req.Header == nil {
req.closeBody()
return nil, errors.New("http: nil Request.Header") return nil, errors.New("http: nil Request.Header")
} }
if req.URL.Scheme != "http" && req.URL.Scheme != "https" { if req.URL.Scheme != "http" && req.URL.Scheme != "https" {
...@@ -173,16 +175,19 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) { ...@@ -173,16 +175,19 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) {
} }
t.altMu.RUnlock() t.altMu.RUnlock()
if rt == nil { if rt == nil {
req.closeBody()
return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme} return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme}
} }
return rt.RoundTrip(req) return rt.RoundTrip(req)
} }
if req.URL.Host == "" { if req.URL.Host == "" {
req.closeBody()
return nil, errors.New("http: no Host in request URL") return nil, errors.New("http: no Host in request URL")
} }
treq := &transportRequest{Request: req} treq := &transportRequest{Request: req}
cm, err := t.connectMethodForRequest(treq) cm, err := t.connectMethodForRequest(treq)
if err != nil { if err != nil {
req.closeBody()
return nil, err return nil, err
} }
...@@ -193,6 +198,7 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) { ...@@ -193,6 +198,7 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) {
pconn, err := t.getConn(req, cm) pconn, err := t.getConn(req, cm)
if err != nil { if err != nil {
t.setReqCanceler(req, nil) t.setReqCanceler(req, nil)
req.closeBody()
return nil, err return nil, err
} }
...@@ -885,6 +891,7 @@ func (pc *persistConn) writeLoop() { ...@@ -885,6 +891,7 @@ func (pc *persistConn) writeLoop() {
} }
if err != nil { if err != nil {
pc.markBroken() pc.markBroken()
wr.req.Request.closeBody()
} }
pc.writeErrCh <- err // to the body reader, which might recycle us pc.writeErrCh <- err // to the body reader, which might recycle us
wr.ch <- err // to the roundTrip function wr.ch <- err // to the roundTrip function
......
...@@ -2028,6 +2028,52 @@ func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { ...@@ -2028,6 +2028,52 @@ func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
} }
} }
type errorReader struct {
err error
}
func (e errorReader) Read(p []byte) (int, error) { return 0, e.err }
type closerFunc func() error
func (f closerFunc) Close() error { return f() }
// Issue 6981
func TestTransportClosesBodyOnError(t *testing.T) {
defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
ioutil.ReadAll(r.Body)
}))
defer ts.Close()
fakeErr := errors.New("fake error")
didClose := make(chan bool, 1)
req, _ := NewRequest("POST", ts.URL, struct {
io.Reader
io.Closer
}{
io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), errorReader{fakeErr}),
closerFunc(func() error {
select {
case didClose <- true:
default:
}
return nil
}),
})
res, err := DefaultClient.Do(req)
if res != nil {
defer res.Body.Close()
}
if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) {
t.Fatalf("Do error = %v; want something containing %q", fakeErr.Error())
}
select {
case <-didClose:
default:
t.Errorf("didn't see Body.Close")
}
}
func wantBody(res *http.Response, err error, want string) error { func wantBody(res *http.Response, err error, want string) error {
if err != nil { if err != nil {
return err return err
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment