Commit c1c7547f authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: add Request.Context and Request.WithContext

Currently only used by the client. The server is not yet wired up.  A
TODO remains to document how it works server-side, once implemented.

Updates #14660

Change-Id: I27c2e74198872b2720995fa8271d91de200e23d5
Reviewed-on: https://go-review.googlesource.com/21496Reviewed-by: 's avatarAndrew Gerrand <adg@golang.org>
parent fcac8809
...@@ -357,7 +357,8 @@ var pkgDeps = map[string][]string{ ...@@ -357,7 +357,8 @@ var pkgDeps = map[string][]string{
// HTTP, kingpin of dependencies. // HTTP, kingpin of dependencies.
"net/http": { "net/http": {
"L4", "NET", "OS", "L4", "NET", "OS",
"compress/gzip", "crypto/tls", "mime/multipart", "runtime/debug", "context", "compress/gzip", "crypto/tls",
"mime/multipart", "runtime/debug",
"net/http/internal", "net/http/internal",
"golang.org/x/net/http2/hpack", "golang.org/x/net/http2/hpack",
}, },
......
...@@ -9,6 +9,7 @@ package http ...@@ -9,6 +9,7 @@ package http
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
"errors" "errors"
...@@ -247,7 +248,43 @@ type Request struct { ...@@ -247,7 +248,43 @@ type Request struct {
// RoundTripper may support Cancel. // RoundTripper may support Cancel.
// //
// For server requests, this field is not applicable. // For server requests, this field is not applicable.
//
// Deprecated: use the Context and WithContext methods
// instead. If a Request's Cancel field and context are both
// set, it is undefined whether Cancel is respected.
Cancel <-chan struct{} Cancel <-chan struct{}
// ctx is either the client or server context. It should only
// be modified via copying the whole Request using WithContext.
// It is unexported to prevent people from using Context wrong
// and mutating the contexts held by callers of the same request.
ctx context.Context
}
// Context returns the request's context. To change the context, use
// WithContext.
//
// The returned context is always non-nil; it defaults to the
// background context.
func (r *Request) Context() context.Context {
// TODO(bradfitz): document above what Context means for server and client
// requests, once implemented.
if r.ctx != nil {
return r.ctx
}
return context.Background()
}
// WithContext returns a shallow copy of r with its context changed
// to ctx. The provided ctx must be non-nil.
func (r *Request) WithContext(ctx context.Context) *Request {
if ctx == nil {
panic("nil context")
}
r2 := new(Request)
*r2 = *r
r2.ctx = ctx
return r2
} }
// ProtoAtLeast reports whether the HTTP protocol used // ProtoAtLeast reports whether the HTTP protocol used
......
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"compress/gzip" "compress/gzip"
"crypto/rand" "crypto/rand"
"fmt" "fmt"
"go/ast"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http/internal" "net/http/internal"
...@@ -656,10 +657,14 @@ func diff(t *testing.T, prefix string, have, want interface{}) { ...@@ -656,10 +657,14 @@ func diff(t *testing.T, prefix string, have, want interface{}) {
t.Errorf("%s: type mismatch %v want %v", prefix, hv.Type(), wv.Type()) t.Errorf("%s: type mismatch %v want %v", prefix, hv.Type(), wv.Type())
} }
for i := 0; i < hv.NumField(); i++ { for i := 0; i < hv.NumField(); i++ {
name := hv.Type().Field(i).Name
if !ast.IsExported(name) {
continue
}
hf := hv.Field(i).Interface() hf := hv.Field(i).Interface()
wf := wv.Field(i).Interface() wf := wv.Field(i).Interface()
if !reflect.DeepEqual(hf, wf) { if !reflect.DeepEqual(hf, wf) {
t.Errorf("%s: %s = %v want %v", prefix, hv.Type().Field(i).Name, hf, wf) t.Errorf("%s: %s = %v want %v", prefix, name, hf, wf)
} }
} }
} }
......
...@@ -758,6 +758,9 @@ func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error ...@@ -758,6 +758,9 @@ func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error
case <-req.Cancel: case <-req.Cancel:
handlePendingDial() handlePendingDial()
return nil, errRequestCanceledConn return nil, errRequestCanceledConn
case <-req.Context().Done():
handlePendingDial()
return nil, errRequestCanceledConn
case <-cancelc: case <-cancelc:
handlePendingDial() handlePendingDial()
return nil, errRequestCanceledConn return nil, errRequestCanceledConn
...@@ -1263,6 +1266,9 @@ func (pc *persistConn) readLoop() { ...@@ -1263,6 +1266,9 @@ func (pc *persistConn) readLoop() {
case <-rc.req.Cancel: case <-rc.req.Cancel:
alive = false alive = false
pc.t.CancelRequest(rc.req) pc.t.CancelRequest(rc.req)
case <-rc.req.Context().Done():
alive = false
pc.t.CancelRequest(rc.req)
case <-pc.closech: case <-pc.closech:
alive = false alive = false
} }
...@@ -1567,6 +1573,9 @@ WaitResponse: ...@@ -1567,6 +1573,9 @@ WaitResponse:
case <-cancelChan: case <-cancelChan:
pc.t.CancelRequest(req.Request) pc.t.CancelRequest(req.Request)
cancelChan = nil cancelChan = nil
case <-req.Context().Done():
pc.t.CancelRequest(req.Request)
cancelChan = nil
} }
} }
......
...@@ -13,6 +13,7 @@ import ( ...@@ -13,6 +13,7 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"context"
"crypto/rand" "crypto/rand"
"crypto/tls" "crypto/tls"
"errors" "errors"
...@@ -1625,7 +1626,13 @@ func TestCancelRequestWithChannel(t *testing.T) { ...@@ -1625,7 +1626,13 @@ func TestCancelRequestWithChannel(t *testing.T) {
} }
} }
func TestCancelRequestWithChannelBeforeDo(t *testing.T) { func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) {
testCancelRequestWithChannelBeforeDo(t, false)
}
func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) {
testCancelRequestWithChannelBeforeDo(t, true)
}
func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) {
setParallel(t) setParallel(t)
defer afterTest(t) defer afterTest(t)
unblockc := make(chan bool) unblockc := make(chan bool)
...@@ -1646,9 +1653,15 @@ func TestCancelRequestWithChannelBeforeDo(t *testing.T) { ...@@ -1646,9 +1653,15 @@ func TestCancelRequestWithChannelBeforeDo(t *testing.T) {
c := &Client{Transport: tr} c := &Client{Transport: tr}
req, _ := NewRequest("GET", ts.URL, nil) req, _ := NewRequest("GET", ts.URL, nil)
if withCtx {
ctx, cancel := context.WithCancel(context.Background())
cancel()
req = req.WithContext(ctx)
} else {
ch := make(chan struct{}) ch := make(chan struct{})
req.Cancel = ch req.Cancel = ch
close(ch) close(ch)
}
_, err := c.Do(req) _, err := c.Do(req)
if err == nil || !strings.Contains(err.Error(), "canceled") { if err == nil || !strings.Contains(err.Error(), "canceled") {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment