Commit 2ad72ecf authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: add Client.Timeout for end-to-end timeouts

Fixes #3362

LGTM=josharian
R=golang-codereviews, josharian
CC=adg, dsymonds, golang-codereviews, n13m3y3r
https://golang.org/cl/70120045
parent 92d54833
......@@ -17,6 +17,8 @@ import (
"log"
"net/url"
"strings"
"sync"
"time"
)
// A Client is an HTTP client. Its zero value (DefaultClient) is a
......@@ -52,6 +54,21 @@ type Client struct {
// If Jar is nil, cookies are not sent in requests and ignored
// in responses.
Jar CookieJar
// Timeout specifies the end-to-end timeout for requests made
// via this Client. The timeout includes connection time, any
// redirects, and reading the response body. The timeout
// remains running once Get, Head, Post, or Do returns and
// will interrupt the read of the Response.Body if EOF hasn't
// been reached.
//
// A Timeout of zero means no timeout.
//
// The Client's Transport must support the CancelRequest
// method or Client will return errors when attempting to make
// a request with Get, Head, Post, or Do. Client's default
// Transport (DefaultTransport) supports CancelRequest.
Timeout time.Duration
}
// DefaultClient is the default Client and is used by Get, Head, and Post.
......@@ -97,7 +114,7 @@ func (c *Client) send(req *Request) (*Response, error) {
req.AddCookie(cookie)
}
}
resp, err := send(req, c.Transport)
resp, err := send(req, c.transport())
if err != nil {
return nil, err
}
......@@ -134,15 +151,18 @@ func (c *Client) Do(req *Request) (resp *Response, err error) {
return c.send(req)
}
func (c *Client) transport() RoundTripper {
if c.Transport != nil {
return c.Transport
}
return DefaultTransport
}
// send issues an HTTP request.
// Caller should close resp.Body when done reading from it.
func send(req *Request, t RoundTripper) (resp *Response, err error) {
if t == nil {
t = DefaultTransport
if t == nil {
err = errors.New("http: no Client.Transport or DefaultTransport")
return
}
return nil, errors.New("http: no Client.Transport or DefaultTransport")
}
if req.URL == nil {
......@@ -260,18 +280,36 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
return nil, errors.New("http: nil Request.URL")
}
var reqmu sync.Mutex // guards req
req := ireq
var timer *time.Timer
if c.Timeout > 0 {
type canceler interface {
CancelRequest(*Request)
}
tr, ok := c.transport().(canceler)
if !ok {
return nil, fmt.Errorf("net/http: Client Transport of type %T doesn't support CancelRequest; Timeout not supported", c.transport())
}
timer = time.AfterFunc(c.Timeout, func() {
reqmu.Lock()
defer reqmu.Unlock()
tr.CancelRequest(req)
})
}
urlStr := "" // next relative or absolute URL to fetch (after first request)
redirectFailed := false
for redirect := 0; ; redirect++ {
if redirect != 0 {
req = new(Request)
req.Method = ireq.Method
nreq := new(Request)
nreq.Method = ireq.Method
if ireq.Method == "POST" || ireq.Method == "PUT" {
req.Method = "GET"
nreq.Method = "GET"
}
req.Header = make(Header)
req.URL, err = base.Parse(urlStr)
nreq.Header = make(Header)
nreq.URL, err = base.Parse(urlStr)
if err != nil {
break
}
......@@ -279,15 +317,18 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
// Add the Referer header.
lastReq := via[len(via)-1]
if lastReq.URL.Scheme != "https" {
req.Header.Set("Referer", lastReq.URL.String())
nreq.Header.Set("Referer", lastReq.URL.String())
}
err = redirectChecker(req, via)
err = redirectChecker(nreq, via)
if err != nil {
redirectFailed = true
break
}
}
reqmu.Lock()
req = nreq
reqmu.Unlock()
}
urlStr = req.URL.String()
......@@ -305,7 +346,10 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
via = append(via, req)
continue
}
return
if timer != nil {
resp.Body = &cancelTimerBody{timer, resp.Body}
}
return resp, nil
}
method := ireq.Method
......@@ -408,3 +452,22 @@ func (c *Client) Head(url string) (resp *Response, err error) {
}
return c.doFollowingRedirects(req, shouldRedirectGet)
}
type cancelTimerBody struct {
t *time.Timer
rc io.ReadCloser
}
func (b *cancelTimerBody) Read(p []byte) (n int, err error) {
n, err = b.rc.Read(p)
if err == io.EOF {
b.t.Stop()
}
return
}
func (b *cancelTimerBody) Close() error {
err := b.rc.Close()
b.t.Stop()
return err
}
......@@ -812,3 +812,70 @@ func TestBasicAuth(t *testing.T) {
t.Errorf("Invalid auth %q", auth)
}
}
func TestClientTimeout(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
defer afterTest(t)
sawRoot := make(chan bool, 1)
sawSlow := make(chan bool, 1)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if r.URL.Path == "/" {
sawRoot <- true
Redirect(w, r, "/slow", StatusFound)
return
}
if r.URL.Path == "/slow" {
w.Write([]byte("Hello"))
w.(Flusher).Flush()
sawSlow <- true
time.Sleep(2 * time.Second)
return
}
}))
defer ts.Close()
const timeout = 500 * time.Millisecond
c := &Client{
Timeout: timeout,
}
res, err := c.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
select {
case <-sawRoot:
// good.
default:
t.Fatal("handler never got / request")
}
select {
case <-sawSlow:
// good.
default:
t.Fatal("handler never got /slow request")
}
var all []byte
errc := make(chan error, 1)
go func() {
var err error
all, err = ioutil.ReadAll(res.Body)
errc <- err
res.Body.Close()
}()
const failTime = timeout * 2
select {
case err := <-errc:
if err == nil {
t.Error("expected error from ReadAll")
}
t.Logf("Got expected ReadAll error of %v after reading body %q", err, all)
case <-time.After(failTime):
t.Errorf("timeout after %v waiting for timeout of %v", failTime, timeout)
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment