Commit 01b25600 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: make Client.Timeout return net.Error errors indicating timeout

Co-hacking with Dave Cheney.

Fixes #9405

Change-Id: I14fc3b6a47dcdb5e514e93d062b804bb24e89f47
Reviewed-on: https://go-review.googlesource.com/1875Reviewed-by: 's avatarDave Cheney <dave@cheney.net>
parent 421c0170
......@@ -19,6 +19,7 @@ import (
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
)
......@@ -299,6 +300,8 @@ func (c *Client) Get(url string) (resp *Response, err error) {
return c.doFollowingRedirects(req, shouldRedirectGet)
}
func alwaysFalse() bool { return false }
func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bool) (resp *Response, err error) {
var base *url.URL
redirectChecker := c.CheckRedirect
......@@ -316,7 +319,10 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
req := ireq
var timer *time.Timer
var atomicWasCanceled int32 // atomic bool (1 or 0)
var wasCanceled = alwaysFalse
if c.Timeout > 0 {
wasCanceled = func() bool { return atomic.LoadInt32(&atomicWasCanceled) != 0 }
type canceler interface {
CancelRequest(*Request)
}
......@@ -325,6 +331,7 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
return nil, fmt.Errorf("net/http: Client Transport of type %T doesn't support CancelRequest; Timeout not supported", c.transport())
}
timer = time.AfterFunc(c.Timeout, func() {
atomic.StoreInt32(&atomicWasCanceled, 1)
reqmu.Lock()
defer reqmu.Unlock()
tr.CancelRequest(req)
......@@ -365,6 +372,12 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
urlStr = req.URL.String()
if resp, err = c.send(req); err != nil {
if wasCanceled() {
err = &httpError{
err: err.Error() + " (Client.Timeout exceeded while awaiting headers)",
timeout: true,
}
}
break
}
......@@ -385,7 +398,11 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
continue
}
if timer != nil {
resp.Body = &cancelTimerBody{timer, resp.Body}
resp.Body = &cancelTimerBody{
t: timer,
rc: resp.Body,
reqWasCanceled: wasCanceled,
}
}
return resp, nil
}
......@@ -491,15 +508,25 @@ func (c *Client) Head(url string) (resp *Response, err error) {
return c.doFollowingRedirects(req, shouldRedirectGet)
}
// cancelTimerBody is an io.ReadCloser that wraps rc with two features:
// 1) on Read EOF or Close, the timer t is Stopped,
// 2) On Read failure, if reqWasCanceled is true, the error is wrapped and
// marked as net.Error that hit its timeout.
type cancelTimerBody struct {
t *time.Timer
rc io.ReadCloser
t *time.Timer
rc io.ReadCloser
reqWasCanceled func() bool
}
func (b *cancelTimerBody) Read(p []byte) (n int, err error) {
n, err = b.rc.Read(p)
if err == io.EOF {
b.t.Stop()
} else if err != nil && b.reqWasCanceled() {
return n, &httpError{
err: err.Error() + " (Client.Timeout exceeded while reading body)",
timeout: true,
}
}
return
}
......
......@@ -899,14 +899,57 @@ func TestClientTimeout(t *testing.T) {
select {
case err := <-errc:
if err == nil {
t.Error("expected error from ReadAll")
t.Fatal("expected error from ReadAll")
}
ne, ok := err.(net.Error)
if !ok {
t.Errorf("error value from ReadAll was %T; expected some net.Error", err)
} else if !ne.Timeout() {
t.Errorf("net.Error.Timeout = false; want true")
}
if got := ne.Error(); !strings.Contains(got, "Client.Timeout exceeded") {
t.Errorf("error string = %q; missing timeout substring", got)
}
// Expected error.
case <-time.After(failTime):
t.Errorf("timeout after %v waiting for timeout of %v", failTime, timeout)
}
}
// Client.Timeout firing before getting to the body
func TestClientTimeout_Headers(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
defer afterTest(t)
donec := make(chan bool)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
<-donec
}))
defer ts.Close()
defer close(donec)
c := &Client{Timeout: 500 * time.Millisecond}
_, err := c.Get(ts.URL)
if err == nil {
t.Fatal("got response from Get; expected error")
}
ue, ok := err.(*url.Error)
if !ok {
t.Fatalf("Got error of type %T; want *url.Error", err)
}
ne, ok := ue.Err.(net.Error)
if !ok {
t.Fatalf("Got url.Error.Err of type %T; want some net.Error", err)
}
if !ne.Timeout() {
t.Error("net.Error.Timeout = false; want true")
}
if got := ne.Error(); !strings.Contains(got, "Client.Timeout exceeded") {
t.Errorf("error string = %q; missing timeout substring", got)
}
}
func TestClientRedirectEatsBody(t *testing.T) {
defer afterTest(t)
saw := make(chan string, 2)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment