Commit d8bd7b24 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: update bundled x/net/http2 for Server context changes

Updates x/net/http2 to golang.org/cl/23220
(http2: with Go 1.7 set Request.Context in ServeHTTP handlers)

Fixes #15134

Change-Id: I73bac2601118614528f051e85dab51dc48e74f41
Reviewed-on: https://go-review.googlesource.com/23221
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: 's avatarAndrew Gerrand <adg@golang.org>
parent 1efec481
...@@ -1974,6 +1974,27 @@ func http2summarizeFrame(f http2Frame) string { ...@@ -1974,6 +1974,27 @@ func http2summarizeFrame(f http2Frame) string {
return buf.String() return buf.String()
} }
type http2contextContext interface {
context.Context
}
func http2serverConnBaseContext(c net.Conn, opts *http2ServeConnOpts) (ctx http2contextContext, cancel func()) {
ctx, cancel = context.WithCancel(context.Background())
ctx = context.WithValue(ctx, LocalAddrContextKey, c.LocalAddr())
if hs := opts.baseConfig(); hs != nil {
ctx = context.WithValue(ctx, ServerContextKey, hs)
}
return
}
func http2contextWithCancel(ctx http2contextContext) (_ http2contextContext, cancel func()) {
return context.WithCancel(ctx)
}
func http2requestWithContext(req *Request, ctx http2contextContext) *Request {
return req.WithContext(ctx)
}
type http2clientTrace httptrace.ClientTrace type http2clientTrace httptrace.ClientTrace
func http2reqContext(r *Request) context.Context { return r.Context() } func http2reqContext(r *Request) context.Context { return r.Context() }
...@@ -2994,10 +3015,14 @@ func (o *http2ServeConnOpts) handler() Handler { ...@@ -2994,10 +3015,14 @@ func (o *http2ServeConnOpts) handler() Handler {
// //
// The opts parameter is optional. If nil, default values are used. // The opts parameter is optional. If nil, default values are used.
func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) {
baseCtx, cancel := http2serverConnBaseContext(c, opts)
defer cancel()
sc := &http2serverConn{ sc := &http2serverConn{
srv: s, srv: s,
hs: opts.baseConfig(), hs: opts.baseConfig(),
conn: c, conn: c,
baseCtx: baseCtx,
remoteAddrStr: c.RemoteAddr().String(), remoteAddrStr: c.RemoteAddr().String(),
bw: http2newBufferedWriter(c), bw: http2newBufferedWriter(c),
handler: opts.handler(), handler: opts.handler(),
...@@ -3016,6 +3041,7 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { ...@@ -3016,6 +3041,7 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) {
serveG: http2newGoroutineLock(), serveG: http2newGoroutineLock(),
pushEnabled: true, pushEnabled: true,
} }
sc.flow.add(http2initialWindowSize) sc.flow.add(http2initialWindowSize)
sc.inflow.add(http2initialWindowSize) sc.inflow.add(http2initialWindowSize)
sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
...@@ -3088,6 +3114,7 @@ type http2serverConn struct { ...@@ -3088,6 +3114,7 @@ type http2serverConn struct {
conn net.Conn conn net.Conn
bw *http2bufferedWriter // writing to conn bw *http2bufferedWriter // writing to conn
handler Handler handler Handler
baseCtx http2contextContext
framer *http2Framer framer *http2Framer
doneServing chan struct{} // closed when serverConn.serve ends doneServing chan struct{} // closed when serverConn.serve ends
readFrameCh chan http2readFrameResult // written by serverConn.readFrames readFrameCh chan http2readFrameResult // written by serverConn.readFrames
...@@ -3151,10 +3178,12 @@ func (sc *http2serverConn) maxHeaderListSize() uint32 { ...@@ -3151,10 +3178,12 @@ func (sc *http2serverConn) maxHeaderListSize() uint32 {
// responseWriter's state field. // responseWriter's state field.
type http2stream struct { type http2stream struct {
// immutable: // immutable:
sc *http2serverConn sc *http2serverConn
id uint32 id uint32
body *http2pipe // non-nil if expecting DATA frames body *http2pipe // non-nil if expecting DATA frames
cw http2closeWaiter // closed wait stream transitions to closed state cw http2closeWaiter // closed wait stream transitions to closed state
ctx http2contextContext
cancelCtx func()
// owned by serverConn's serve loop: // owned by serverConn's serve loop:
bodyBytes int64 // body bytes seen so far bodyBytes int64 // body bytes seen so far
...@@ -3818,6 +3847,7 @@ func (sc *http2serverConn) processResetStream(f *http2RSTStreamFrame) error { ...@@ -3818,6 +3847,7 @@ func (sc *http2serverConn) processResetStream(f *http2RSTStreamFrame) error {
} }
if st != nil { if st != nil {
st.gotReset = true st.gotReset = true
st.cancelCtx()
sc.closeStream(st, http2StreamError{f.StreamID, f.ErrCode}) sc.closeStream(st, http2StreamError{f.StreamID, f.ErrCode})
} }
return nil return nil
...@@ -3997,10 +4027,13 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { ...@@ -3997,10 +4027,13 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error {
} }
sc.maxStreamID = id sc.maxStreamID = id
ctx, cancelCtx := http2contextWithCancel(sc.baseCtx)
st = &http2stream{ st = &http2stream{
sc: sc, sc: sc,
id: id, id: id,
state: http2stateOpen, state: http2stateOpen,
ctx: ctx,
cancelCtx: cancelCtx,
} }
if f.StreamEnded() { if f.StreamEnded() {
st.state = http2stateHalfClosedRemote st.state = http2stateHalfClosedRemote
...@@ -4208,6 +4241,7 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead ...@@ -4208,6 +4241,7 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead
Body: body, Body: body,
Trailer: trailer, Trailer: trailer,
} }
req = http2requestWithContext(req, st.ctx)
if bodyOpen { if bodyOpen {
buf := make([]byte, http2initialWindowSize) buf := make([]byte, http2initialWindowSize)
...@@ -4250,6 +4284,7 @@ func (sc *http2serverConn) getRequestBodyBuf() []byte { ...@@ -4250,6 +4284,7 @@ func (sc *http2serverConn) getRequestBodyBuf() []byte {
func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *Request, handler func(ResponseWriter, *Request)) { func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *Request, handler func(ResponseWriter, *Request)) {
didPanic := true didPanic := true
defer func() { defer func() {
rw.rws.stream.cancelCtx()
if didPanic { if didPanic {
e := recover() e := recover()
// Same as net/http: // Same as net/http:
......
...@@ -4064,10 +4064,16 @@ func TestServerValidatesHeaders(t *testing.T) { ...@@ -4064,10 +4064,16 @@ func TestServerValidatesHeaders(t *testing.T) {
} }
} }
func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) { func TestServerRequestContextCancel_ServeHTTPDone_h1(t *testing.T) {
testServerRequestContextCancel_ServeHTTPDone(t, h1Mode)
}
func TestServerRequestContextCancel_ServeHTTPDone_h2(t *testing.T) {
testServerRequestContextCancel_ServeHTTPDone(t, h2Mode)
}
func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) {
defer afterTest(t) defer afterTest(t)
ctxc := make(chan context.Context, 1) ctxc := make(chan context.Context, 1)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
ctx := r.Context() ctx := r.Context()
select { select {
case <-ctx.Done(): case <-ctx.Done():
...@@ -4076,8 +4082,8 @@ func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) { ...@@ -4076,8 +4082,8 @@ func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) {
} }
ctxc <- ctx ctxc <- ctx
})) }))
defer ts.Close() defer cst.close()
res, err := Get(ts.URL) res, err := cst.c.Get(cst.ts.URL)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -4130,9 +4136,15 @@ func TestServerRequestContextCancel_ConnClose(t *testing.T) { ...@@ -4130,9 +4136,15 @@ func TestServerRequestContextCancel_ConnClose(t *testing.T) {
} }
} }
func TestServerContext_ServerContextKey(t *testing.T) { func TestServerContext_ServerContextKey_h1(t *testing.T) {
testServerContext_ServerContextKey(t, h1Mode)
}
func TestServerContext_ServerContextKey_h2(t *testing.T) {
testServerContext_ServerContextKey(t, h2Mode)
}
func testServerContext_ServerContextKey(t *testing.T, h2 bool) {
defer afterTest(t) defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
ctx := r.Context() ctx := r.Context()
got := ctx.Value(ServerContextKey) got := ctx.Value(ServerContextKey)
if _, ok := got.(*Server); !ok { if _, ok := got.(*Server); !ok {
...@@ -4140,12 +4152,14 @@ func TestServerContext_ServerContextKey(t *testing.T) { ...@@ -4140,12 +4152,14 @@ func TestServerContext_ServerContextKey(t *testing.T) {
} }
got = ctx.Value(LocalAddrContextKey) got = ctx.Value(LocalAddrContextKey)
if _, ok := got.(net.Addr); !ok { if addr, ok := got.(net.Addr); !ok {
t.Errorf("local addr value = %T; want net.Addr", got) t.Errorf("local addr value = %T; want net.Addr", got)
} else if fmt.Sprint(addr) != r.Host {
t.Errorf("local addr = %v; want %v", addr, r.Host)
} }
})) }))
defer ts.Close() defer cst.close()
res, err := Get(ts.URL) res, err := cst.c.Get(cst.ts.URL)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment