Commit 2ad72ecf authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: add Client.Timeout for end-to-end timeouts

Fixes #3362

LGTM=josharian
R=golang-codereviews, josharian
CC=adg, dsymonds, golang-codereviews, n13m3y3r
https://golang.org/cl/70120045
parent 92d54833
...@@ -17,6 +17,8 @@ import ( ...@@ -17,6 +17,8 @@ import (
"log" "log"
"net/url" "net/url"
"strings" "strings"
"sync"
"time"
) )
// A Client is an HTTP client. Its zero value (DefaultClient) is a // A Client is an HTTP client. Its zero value (DefaultClient) is a
...@@ -52,6 +54,21 @@ type Client struct { ...@@ -52,6 +54,21 @@ type Client struct {
// If Jar is nil, cookies are not sent in requests and ignored // If Jar is nil, cookies are not sent in requests and ignored
// in responses. // in responses.
Jar CookieJar Jar CookieJar
// Timeout specifies the end-to-end timeout for requests made
// via this Client. The timeout includes connection time, any
// redirects, and reading the response body. The timeout
// remains running once Get, Head, Post, or Do returns and
// will interrupt the read of the Response.Body if EOF hasn't
// been reached.
//
// A Timeout of zero means no timeout.
//
// The Client's Transport must support the CancelRequest
// method or Client will return errors when attempting to make
// a request with Get, Head, Post, or Do. Client's default
// Transport (DefaultTransport) supports CancelRequest.
Timeout time.Duration
} }
// DefaultClient is the default Client and is used by Get, Head, and Post. // DefaultClient is the default Client and is used by Get, Head, and Post.
...@@ -97,7 +114,7 @@ func (c *Client) send(req *Request) (*Response, error) { ...@@ -97,7 +114,7 @@ func (c *Client) send(req *Request) (*Response, error) {
req.AddCookie(cookie) req.AddCookie(cookie)
} }
} }
resp, err := send(req, c.Transport) resp, err := send(req, c.transport())
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -134,15 +151,18 @@ func (c *Client) Do(req *Request) (resp *Response, err error) { ...@@ -134,15 +151,18 @@ func (c *Client) Do(req *Request) (resp *Response, err error) {
return c.send(req) return c.send(req)
} }
func (c *Client) transport() RoundTripper {
if c.Transport != nil {
return c.Transport
}
return DefaultTransport
}
// send issues an HTTP request. // send issues an HTTP request.
// Caller should close resp.Body when done reading from it. // Caller should close resp.Body when done reading from it.
func send(req *Request, t RoundTripper) (resp *Response, err error) { func send(req *Request, t RoundTripper) (resp *Response, err error) {
if t == nil { if t == nil {
t = DefaultTransport return nil, errors.New("http: no Client.Transport or DefaultTransport")
if t == nil {
err = errors.New("http: no Client.Transport or DefaultTransport")
return
}
} }
if req.URL == nil { if req.URL == nil {
...@@ -260,18 +280,36 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo ...@@ -260,18 +280,36 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
return nil, errors.New("http: nil Request.URL") return nil, errors.New("http: nil Request.URL")
} }
var reqmu sync.Mutex // guards req
req := ireq req := ireq
var timer *time.Timer
if c.Timeout > 0 {
type canceler interface {
CancelRequest(*Request)
}
tr, ok := c.transport().(canceler)
if !ok {
return nil, fmt.Errorf("net/http: Client Transport of type %T doesn't support CancelRequest; Timeout not supported", c.transport())
}
timer = time.AfterFunc(c.Timeout, func() {
reqmu.Lock()
defer reqmu.Unlock()
tr.CancelRequest(req)
})
}
urlStr := "" // next relative or absolute URL to fetch (after first request) urlStr := "" // next relative or absolute URL to fetch (after first request)
redirectFailed := false redirectFailed := false
for redirect := 0; ; redirect++ { for redirect := 0; ; redirect++ {
if redirect != 0 { if redirect != 0 {
req = new(Request) nreq := new(Request)
req.Method = ireq.Method nreq.Method = ireq.Method
if ireq.Method == "POST" || ireq.Method == "PUT" { if ireq.Method == "POST" || ireq.Method == "PUT" {
req.Method = "GET" nreq.Method = "GET"
} }
req.Header = make(Header) nreq.Header = make(Header)
req.URL, err = base.Parse(urlStr) nreq.URL, err = base.Parse(urlStr)
if err != nil { if err != nil {
break break
} }
...@@ -279,15 +317,18 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo ...@@ -279,15 +317,18 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
// Add the Referer header. // Add the Referer header.
lastReq := via[len(via)-1] lastReq := via[len(via)-1]
if lastReq.URL.Scheme != "https" { if lastReq.URL.Scheme != "https" {
req.Header.Set("Referer", lastReq.URL.String()) nreq.Header.Set("Referer", lastReq.URL.String())
} }
err = redirectChecker(req, via) err = redirectChecker(nreq, via)
if err != nil { if err != nil {
redirectFailed = true redirectFailed = true
break break
} }
} }
reqmu.Lock()
req = nreq
reqmu.Unlock()
} }
urlStr = req.URL.String() urlStr = req.URL.String()
...@@ -305,7 +346,10 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo ...@@ -305,7 +346,10 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
via = append(via, req) via = append(via, req)
continue continue
} }
return if timer != nil {
resp.Body = &cancelTimerBody{timer, resp.Body}
}
return resp, nil
} }
method := ireq.Method method := ireq.Method
...@@ -408,3 +452,22 @@ func (c *Client) Head(url string) (resp *Response, err error) { ...@@ -408,3 +452,22 @@ func (c *Client) Head(url string) (resp *Response, err error) {
} }
return c.doFollowingRedirects(req, shouldRedirectGet) return c.doFollowingRedirects(req, shouldRedirectGet)
} }
type cancelTimerBody struct {
t *time.Timer
rc io.ReadCloser
}
func (b *cancelTimerBody) Read(p []byte) (n int, err error) {
n, err = b.rc.Read(p)
if err == io.EOF {
b.t.Stop()
}
return
}
func (b *cancelTimerBody) Close() error {
err := b.rc.Close()
b.t.Stop()
return err
}
...@@ -812,3 +812,70 @@ func TestBasicAuth(t *testing.T) { ...@@ -812,3 +812,70 @@ func TestBasicAuth(t *testing.T) {
t.Errorf("Invalid auth %q", auth) t.Errorf("Invalid auth %q", auth)
} }
} }
func TestClientTimeout(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
defer afterTest(t)
sawRoot := make(chan bool, 1)
sawSlow := make(chan bool, 1)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if r.URL.Path == "/" {
sawRoot <- true
Redirect(w, r, "/slow", StatusFound)
return
}
if r.URL.Path == "/slow" {
w.Write([]byte("Hello"))
w.(Flusher).Flush()
sawSlow <- true
time.Sleep(2 * time.Second)
return
}
}))
defer ts.Close()
const timeout = 500 * time.Millisecond
c := &Client{
Timeout: timeout,
}
res, err := c.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
select {
case <-sawRoot:
// good.
default:
t.Fatal("handler never got / request")
}
select {
case <-sawSlow:
// good.
default:
t.Fatal("handler never got /slow request")
}
var all []byte
errc := make(chan error, 1)
go func() {
var err error
all, err = ioutil.ReadAll(res.Body)
errc <- err
res.Body.Close()
}()
const failTime = timeout * 2
select {
case err := <-errc:
if err == nil {
t.Error("expected error from ReadAll")
}
t.Logf("Got expected ReadAll error of %v after reading body %q", err, all)
case <-time.After(failTime):
t.Errorf("timeout after %v waiting for timeout of %v", failTime, timeout)
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment