Commit c24de9d5 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

http2: add Server support for reading trailers from clients

Updates golang/go#13557

Change-Id: I95bbb15d9abbbbc4dc6c3a22cd965d8dcef53fb8
Reviewed-on: https://go-review.googlesource.com/17891Reviewed-by: 's avatarBlake Mizerany <blake.mizerany@gmail.com>
parent 548f7bf2
...@@ -57,6 +57,7 @@ func init() { ...@@ -57,6 +57,7 @@ func init() {
"server", "server",
"set-cookie", "set-cookie",
"strict-transport-security", "strict-transport-security",
"trailer",
"transfer-encoding", "transfer-encoding",
"user-agent", "user-agent",
"vary", "vary",
......
...@@ -102,6 +102,13 @@ func (d *Decoder) SetMaxStringLength(n int) { ...@@ -102,6 +102,13 @@ func (d *Decoder) SetMaxStringLength(n int) {
d.maxStrLen = n d.maxStrLen = n
} }
// SetEmitFunc changes the callback used when new header fields
// are decoded.
// It must be non-nil. It does not affect EmitEnabled.
func (d *Decoder) SetEmitFunc(emitFunc func(f HeaderField)) {
d.emit = emitFunc
}
// SetEmitEnabled controls whether the emitFunc provided to NewDecoder // SetEmitEnabled controls whether the emitFunc provided to NewDecoder
// should be called. The default is true. // should be called. The default is true.
// //
......
...@@ -14,11 +14,12 @@ import ( ...@@ -14,11 +14,12 @@ import (
// io.Pipe except there are no PipeReader/PipeWriter halves, and the // io.Pipe except there are no PipeReader/PipeWriter halves, and the
// underlying buffer is an interface. (io.Pipe is always unbuffered) // underlying buffer is an interface. (io.Pipe is always unbuffered)
type pipe struct { type pipe struct {
mu sync.Mutex mu sync.Mutex
c sync.Cond // c.L must point to c sync.Cond // c.L must point to
b pipeBuffer b pipeBuffer
err error // read error once empty. non-nil means closed. err error // read error once empty. non-nil means closed.
donec chan struct{} // closed on error donec chan struct{} // closed on error
readFn func() // optional code to run in Read before error
} }
type pipeBuffer interface { type pipeBuffer interface {
...@@ -40,6 +41,10 @@ func (p *pipe) Read(d []byte) (n int, err error) { ...@@ -40,6 +41,10 @@ func (p *pipe) Read(d []byte) (n int, err error) {
return p.b.Read(d) return p.b.Read(d)
} }
if p.err != nil { if p.err != nil {
if p.readFn != nil {
p.readFn() // e.g. copy trailers
p.readFn = nil // not sticky like p.err
}
return 0, p.err return 0, p.err
} }
p.c.Wait() p.c.Wait()
...@@ -63,13 +68,18 @@ func (p *pipe) Write(d []byte) (n int, err error) { ...@@ -63,13 +68,18 @@ func (p *pipe) Write(d []byte) (n int, err error) {
return p.b.Write(d) return p.b.Write(d)
} }
// CloseWithError causes Reads to wake up and return the // CloseWithError causes the next Read (waking up a current blocked
// provided err after all data has been read. // Read if needed) to return the provided err after all data has been
// read.
// //
// The error must be non-nil. // The error must be non-nil.
func (p *pipe) CloseWithError(err error) { func (p *pipe) CloseWithError(err error) { p.closeWithErrorAndCode(err, nil) }
// closeWithErrorAndCode is like CloseWithError but also sets some code to run
// in the caller's goroutine before returning the error.
func (p *pipe) closeWithErrorAndCode(err error, fn func()) {
if err == nil { if err == nil {
panic("CloseWithError must be non-nil") panic("CloseWithError err must be non-nil")
} }
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()
...@@ -77,11 +87,14 @@ func (p *pipe) CloseWithError(err error) { ...@@ -77,11 +87,14 @@ func (p *pipe) CloseWithError(err error) {
p.c.L = &p.mu p.c.L = &p.mu
} }
defer p.c.Signal() defer p.c.Signal()
if p.err == nil { if p.err != nil {
p.err = err // Already been done.
if p.donec != nil { return
close(p.donec) }
} p.readFn = fn
p.err = err
if p.donec != nil {
close(p.donec)
} }
} }
......
...@@ -224,7 +224,7 @@ func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) { ...@@ -224,7 +224,7 @@ func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) {
sc.flow.add(initialWindowSize) sc.flow.add(initialWindowSize)
sc.inflow.add(initialWindowSize) sc.inflow.add(initialWindowSize)
sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
sc.hpackDecoder = hpack.NewDecoder(initialHeaderTableSize, sc.onNewHeaderField) sc.hpackDecoder = hpack.NewDecoder(initialHeaderTableSize, nil)
sc.hpackDecoder.SetMaxStringLength(sc.maxHeaderStringLen()) sc.hpackDecoder.SetMaxStringLength(sc.maxHeaderStringLen())
fr := NewFramer(sc.bw, c) fr := NewFramer(sc.bw, c)
...@@ -411,20 +411,26 @@ type requestParam struct { ...@@ -411,20 +411,26 @@ type requestParam struct {
// responseWriter's state field. // responseWriter's state field.
type stream struct { type stream struct {
// immutable: // immutable:
sc *serverConn
id uint32 id uint32
body *pipe // non-nil if expecting DATA frames body *pipe // non-nil if expecting DATA frames
cw closeWaiter // closed wait stream transitions to closed state cw closeWaiter // closed wait stream transitions to closed state
// owned by serverConn's serve loop: // owned by serverConn's serve loop:
bodyBytes int64 // body bytes seen so far bodyBytes int64 // body bytes seen so far
declBodyBytes int64 // or -1 if undeclared declBodyBytes int64 // or -1 if undeclared
flow flow // limits writing from Handler to client flow flow // limits writing from Handler to client
inflow flow // what the client is allowed to POST/etc to us inflow flow // what the client is allowed to POST/etc to us
parent *stream // or nil parent *stream // or nil
weight uint8 numTrailerValues int64
state streamState weight uint8
sentReset bool // only true once detached from streams map state streamState
gotReset bool // only true once detacted from streams map sentReset bool // only true once detached from streams map
gotReset bool // only true once detacted from streams map
gotTrailerHeader bool // HEADER frame for trailers was seen
trailer http.Header // accumulated trailers
reqTrailer http.Header // handler's Request.Trailer
} }
func (sc *serverConn) Framer() *Framer { return sc.framer } func (sc *serverConn) Framer() *Framer { return sc.framer }
...@@ -537,6 +543,37 @@ func (sc *serverConn) onNewHeaderField(f hpack.HeaderField) { ...@@ -537,6 +543,37 @@ func (sc *serverConn) onNewHeaderField(f hpack.HeaderField) {
} }
} }
func (st *stream) onNewTrailerField(f hpack.HeaderField) {
sc := st.sc
sc.serveG.check()
sc.vlogf("got trailer field %+v", f)
switch {
case !validHeader(f.Name):
// TODO: change hpack signature so this can return
// errors? Or stash an error somewhere on st or sc
// for processHeaderBlockFragment etc to pick up and
// return after the hpack Write/Close. For now just
// ignore.
return
case strings.HasPrefix(f.Name, ":"):
// TODO: same TODO as above.
return
default:
key := sc.canonicalHeader(f.Name)
if st.trailer != nil {
vv := append(st.trailer[key], f.Value)
st.trailer[key] = vv
// arbitrary; TODO: read spec about header list size limits wrt trailers
const tooBig = 1000
if len(vv) >= tooBig {
sc.hpackDecoder.SetEmitEnabled(false)
}
}
}
}
func (sc *serverConn) canonicalHeader(v string) string { func (sc *serverConn) canonicalHeader(v string) string {
sc.serveG.check() sc.serveG.check()
cv, ok := commonCanonHeader[v] cv, ok := commonCanonHeader[v]
...@@ -1249,7 +1286,7 @@ func (sc *serverConn) processData(f *DataFrame) error { ...@@ -1249,7 +1286,7 @@ func (sc *serverConn) processData(f *DataFrame) error {
// with a stream error (Section 5.4.2) of type STREAM_CLOSED." // with a stream error (Section 5.4.2) of type STREAM_CLOSED."
id := f.Header().StreamID id := f.Header().StreamID
st, ok := sc.streams[id] st, ok := sc.streams[id]
if !ok || st.state != stateOpen { if !ok || st.state != stateOpen || st.gotTrailerHeader {
// This includes sending a RST_STREAM if the stream is // This includes sending a RST_STREAM if the stream is
// in stateHalfClosedLocal (which currently means that // in stateHalfClosedLocal (which currently means that
// the http.Handler returned, so it's done reading & // the http.Handler returned, so it's done reading &
...@@ -1283,17 +1320,38 @@ func (sc *serverConn) processData(f *DataFrame) error { ...@@ -1283,17 +1320,38 @@ func (sc *serverConn) processData(f *DataFrame) error {
st.bodyBytes += int64(len(data)) st.bodyBytes += int64(len(data))
} }
if f.StreamEnded() { if f.StreamEnded() {
if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes { st.endStream()
st.body.CloseWithError(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes",
st.declBodyBytes, st.bodyBytes))
} else {
st.body.CloseWithError(io.EOF)
}
st.state = stateHalfClosedRemote
} }
return nil return nil
} }
// endStream closes a Request.Body's pipe. It is called when a DATA
// frame says a request body is over (or after trailers).
func (st *stream) endStream() {
sc := st.sc
sc.serveG.check()
if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes {
st.body.CloseWithError(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes",
st.declBodyBytes, st.bodyBytes))
} else {
st.body.closeWithErrorAndCode(io.EOF, st.copyTrailersToHandlerRequest)
st.body.CloseWithError(io.EOF)
}
st.state = stateHalfClosedRemote
}
// copyTrailersToHandlerRequest is run in the Handler's goroutine in
// its Request.Body.Read just before it gets io.EOF.
func (st *stream) copyTrailersToHandlerRequest() {
for k, vv := range st.trailer {
if _, ok := st.reqTrailer[k]; ok {
// Only copy it over it was pre-declared.
st.reqTrailer[k] = vv
}
}
}
func (sc *serverConn) processHeaders(f *HeadersFrame) error { func (sc *serverConn) processHeaders(f *HeadersFrame) error {
sc.serveG.check() sc.serveG.check()
id := f.Header().StreamID id := f.Header().StreamID
...@@ -1302,20 +1360,36 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error { ...@@ -1302,20 +1360,36 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
return nil return nil
} }
// http://http2.github.io/http2-spec/#rfc.section.5.1.1 // http://http2.github.io/http2-spec/#rfc.section.5.1.1
if id%2 != 1 || id <= sc.maxStreamID || sc.req.stream != nil { // Streams initiated by a client MUST use odd-numbered stream
// Streams initiated by a client MUST use odd-numbered // identifiers. [...] An endpoint that receives an unexpected
// stream identifiers. [...] The identifier of a newly // stream identifier MUST respond with a connection error
// established stream MUST be numerically greater than all // (Section 5.4.1) of type PROTOCOL_ERROR.
// streams that the initiating endpoint has opened or if id%2 != 1 {
// reserved. [...] An endpoint that receives an unexpected
// stream identifier MUST respond with a connection error
// (Section 5.4.1) of type PROTOCOL_ERROR.
return ConnectionError(ErrCodeProtocol) return ConnectionError(ErrCodeProtocol)
} }
// A HEADERS frame can be used to create a new stream or
// send a trailer for an open one. If we already have a stream
// open, let it process its own HEADERS frame (trailers at this
// point, if it's valid).
st := sc.streams[f.Header().StreamID]
if st != nil {
return st.processTrailerHeaders(f)
}
// [...] The identifier of a newly established stream MUST be
// numerically greater than all streams that the initiating
// endpoint has opened or reserved. [...] An endpoint that
// receives an unexpected stream identifier MUST respond with
// a connection error (Section 5.4.1) of type PROTOCOL_ERROR.
if id <= sc.maxStreamID || sc.req.stream != nil {
return ConnectionError(ErrCodeProtocol)
}
if id > sc.maxStreamID { if id > sc.maxStreamID {
sc.maxStreamID = id sc.maxStreamID = id
} }
st := &stream{ st = &stream{
sc: sc,
id: id, id: id,
state: stateOpen, state: stateOpen,
} }
...@@ -1341,16 +1415,30 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error { ...@@ -1341,16 +1415,30 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
stream: st, stream: st,
header: make(http.Header), header: make(http.Header),
} }
sc.hpackDecoder.SetEmitFunc(sc.onNewHeaderField)
sc.hpackDecoder.SetEmitEnabled(true) sc.hpackDecoder.SetEmitEnabled(true)
return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded()) return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded())
} }
func (st *stream) processTrailerHeaders(f *HeadersFrame) error {
sc := st.sc
sc.serveG.check()
if st.gotTrailerHeader {
return ConnectionError(ErrCodeProtocol)
}
st.gotTrailerHeader = true
return st.processTrailerHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded())
}
func (sc *serverConn) processContinuation(f *ContinuationFrame) error { func (sc *serverConn) processContinuation(f *ContinuationFrame) error {
sc.serveG.check() sc.serveG.check()
st := sc.streams[f.Header().StreamID] st := sc.streams[f.Header().StreamID]
if st == nil || sc.curHeaderStreamID() != st.id { if st == nil || sc.curHeaderStreamID() != st.id {
return ConnectionError(ErrCodeProtocol) return ConnectionError(ErrCodeProtocol)
} }
if st.gotTrailerHeader {
return st.processTrailerHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded())
}
return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded()) return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded())
} }
...@@ -1389,6 +1477,10 @@ func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bo ...@@ -1389,6 +1477,10 @@ func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bo
if err != nil { if err != nil {
return err return err
} }
st.reqTrailer = req.Trailer
if st.reqTrailer != nil {
st.trailer = make(http.Header)
}
st.body = req.Body.(*requestBody).pipe // may be nil st.body = req.Body.(*requestBody).pipe // may be nil
st.declBodyBytes = req.ContentLength st.declBodyBytes = req.ContentLength
...@@ -1402,6 +1494,24 @@ func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bo ...@@ -1402,6 +1494,24 @@ func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bo
return nil return nil
} }
func (st *stream) processTrailerHeaderBlockFragment(frag []byte, end bool) error {
sc := st.sc
sc.serveG.check()
sc.hpackDecoder.SetEmitFunc(st.onNewTrailerField)
if _, err := sc.hpackDecoder.Write(frag); err != nil {
return ConnectionError(ErrCodeCompression)
}
if !end {
return nil
}
err := sc.hpackDecoder.Close()
st.endStream()
if err != nil {
return ConnectionError(ErrCodeCompression)
}
return nil
}
func (sc *serverConn) processPriority(f *PriorityFrame) error { func (sc *serverConn) processPriority(f *PriorityFrame) error {
adjustStreamPriority(sc.streams, f.StreamID, f.PriorityParam) adjustStreamPriority(sc.streams, f.StreamID, f.PriorityParam)
return nil return nil
...@@ -1489,6 +1599,26 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err ...@@ -1489,6 +1599,26 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
if cookies := rp.header["Cookie"]; len(cookies) > 1 { if cookies := rp.header["Cookie"]; len(cookies) > 1 {
rp.header.Set("Cookie", strings.Join(cookies, "; ")) rp.header.Set("Cookie", strings.Join(cookies, "; "))
} }
// Setup Trailers
var trailer http.Header
for _, v := range rp.header["Trailer"] {
for _, key := range strings.Split(v, ",") {
key = http.CanonicalHeaderKey(strings.TrimSpace(key))
switch key {
case "Transfer-Encoding", "Trailer", "Content-Length":
// Bogus. (copy of http1 rules)
// Ignore.
default:
if trailer == nil {
trailer = make(http.Header)
}
trailer[key] = nil
}
}
}
delete(rp.header, "Trailer")
body := &requestBody{ body := &requestBody{
conn: sc, conn: sc,
stream: rp.stream, stream: rp.stream,
...@@ -1512,10 +1642,11 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err ...@@ -1512,10 +1642,11 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
TLS: tlsState, TLS: tlsState,
Host: authority, Host: authority,
Body: body, Body: body,
Trailer: trailer,
} }
if bodyOpen { if bodyOpen {
body.pipe = &pipe{ body.pipe = &pipe{
b: &fixedBuffer{buf: make([]byte, initialWindowSize)}, // TODO: share/remove XXX b: &fixedBuffer{buf: make([]byte, initialWindowSize)}, // TODO: garbage
} }
if vv, ok := rp.header["Content-Length"]; ok { if vv, ok := rp.header["Content-Length"]; ok {
......
...@@ -246,6 +246,21 @@ func (st *serverTester) encodeHeaderField(k, v string) { ...@@ -246,6 +246,21 @@ func (st *serverTester) encodeHeaderField(k, v string) {
} }
} }
// encodeHeaderRaw is the magic-free version of encodeHeader.
// It takes 0 or more (k, v) pairs and encodes them.
func (st *serverTester) encodeHeaderRaw(headers ...string) []byte {
if len(headers)%2 == 1 {
panic("odd number of kv args")
}
st.headerBuf.Reset()
for len(headers) > 0 {
k, v := headers[0], headers[1]
st.encodeHeaderField(k, v)
headers = headers[2:]
}
return st.headerBuf.Bytes()
}
// encodeHeader encodes headers and returns their HPACK bytes. headers // encodeHeader encodes headers and returns their HPACK bytes. headers
// must contain an even number of key/value pairs. There may be // must contain an even number of key/value pairs. There may be
// multiple pairs for keys (e.g. "cookie"). The :method, :path, and // multiple pairs for keys (e.g. "cookie"). The :method, :path, and
...@@ -299,7 +314,6 @@ func (st *serverTester) encodeHeader(headers ...string) []byte { ...@@ -299,7 +314,6 @@ func (st *serverTester) encodeHeader(headers ...string) []byte {
vals[k] = append(vals[k], v) vals[k] = append(vals[k], v)
} }
} }
st.headerBuf.Reset()
for _, k := range keys { for _, k := range keys {
for _, v := range vals[k] { for _, v := range vals[k] {
st.encodeHeaderField(k, v) st.encodeHeaderField(k, v)
...@@ -2451,8 +2465,53 @@ func TestCompressionErrorOnClose(t *testing.T) { ...@@ -2451,8 +2465,53 @@ func TestCompressionErrorOnClose(t *testing.T) {
// test that a server handler can read trailers from a client // test that a server handler can read trailers from a client
func TestServerReadsTrailers(t *testing.T) { func TestServerReadsTrailers(t *testing.T) {
// TODO: use testBodyContents or testServerRequest const testBody = "some test body"
t.Skip("unimplemented") writeReq := func(st *serverTester) {
st.writeHeaders(HeadersFrameParam{
StreamID: 1, // clients send odd numbers
BlockFragment: st.encodeHeader("trailer", "Foo, Bar", "trailer", "Baz"),
EndStream: false,
EndHeaders: true,
})
st.writeData(1, false, []byte(testBody))
st.writeHeaders(HeadersFrameParam{
StreamID: 1, // clients send odd numbers
BlockFragment: st.encodeHeaderRaw(
"foo", "foov",
"bar", "barv",
"baz", "bazv",
"surprise", "wasn't declared; shouldn't show up",
),
EndStream: true,
EndHeaders: true,
})
}
checkReq := func(r *http.Request) {
wantTrailer := http.Header{
"Foo": nil,
"Bar": nil,
"Baz": nil,
}
if !reflect.DeepEqual(r.Trailer, wantTrailer) {
t.Errorf("initial Trailer = %v; want %v", r.Trailer, wantTrailer)
}
slurp, err := ioutil.ReadAll(r.Body)
if string(slurp) != testBody {
t.Errorf("read body %q; want %q", slurp, testBody)
}
if err != nil {
t.Fatalf("Body slurp: %v", err)
}
wantTrailerAfter := http.Header{
"Foo": {"foov"},
"Bar": {"barv"},
"Baz": {"bazv"},
}
if !reflect.DeepEqual(r.Trailer, wantTrailerAfter) {
t.Errorf("final Trailer = %v; want %v", r.Trailer, wantTrailerAfter)
}
}
testServerRequest(t, writeReq, checkReq)
} }
// test that a server handler can send trailers // test that a server handler can send trailers
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment