Commit d8f3c68d authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

http2: fix enforcement of max header list size

In the first attempt to enforce the SETTINGS_MAX_HEADER_LIST_SIZE
(https://go-review.googlesource.com/15751), the enforcement happened
in the hpack decoder and the hpack decoder returned errors on Write
and Close if the limit was violated. This was incorrect because the
decoder is used over the life of the connection for all subsequent
requests and could therefore get out of sync.

Instead, this moves the counting of the limit up to the http2 package
in the serverConn type, and replaces the hpack counting mechanism with
a simple on/off switch. When SetEmitEnabled is set false, the header
field emit callbacks will be suppressed and the hpack Decoder will do
less work (less CPU and garbage) if possible, but will still return
nil from Write and Close on valid input, and will still stay in sync
with the stream.

The http2 Server then returns a 431 error if emits were disabled while
processing the HEADERS or any CONTINUATION frames.

Fixes golang/go#12843

Change-Id: I3b41aaefc6c6ee6218225f8dc62bba6ae5fe8f2d
Reviewed-on: https://go-review.googlesource.com/15733
Reviewed-by: Andrew Gerrand <adg@golang.org>
parent 2a045c20
...@@ -64,9 +64,7 @@ type Decoder struct { ...@@ -64,9 +64,7 @@ type Decoder struct {
dynTab dynamicTable dynTab dynamicTable
emit func(f HeaderField) emit func(f HeaderField)
headerListSize int64 emitEnabled bool // whether calls to emit are enabled
maxHeaderListSize uint32 // 0 means unlimited
hitLimit bool
// buf is the unparsed buffer. It's only written to // buf is the unparsed buffer. It's only written to
// saveBuf if it was truncated in the middle of a header // saveBuf if it was truncated in the middle of a header
...@@ -78,23 +76,29 @@ type Decoder struct { ...@@ -78,23 +76,29 @@ type Decoder struct {
// NewDecoder returns a new decoder with the provided maximum dynamic // NewDecoder returns a new decoder with the provided maximum dynamic
// table size. The emitFunc will be called for each valid field // table size. The emitFunc will be called for each valid field
// parsed. // parsed, in the same goroutine as calls to Write, before Write returns.
func NewDecoder(maxDynamicTableSize uint32, emitFunc func(f HeaderField)) *Decoder { func NewDecoder(maxDynamicTableSize uint32, emitFunc func(f HeaderField)) *Decoder {
d := &Decoder{ d := &Decoder{
emit: emitFunc, emit: emitFunc,
emitEnabled: true,
} }
d.dynTab.allowedMaxSize = maxDynamicTableSize d.dynTab.allowedMaxSize = maxDynamicTableSize
d.dynTab.setMaxSize(maxDynamicTableSize) d.dynTab.setMaxSize(maxDynamicTableSize)
return d return d
} }
// SetMaxHeaderListSize sets the decoder's SETTINGS_MAX_HEADER_LIST_SIZE. // SetEmitEnabled controls whether the emitFunc provided to NewDecoder
// It should be set before any call to Write. // should be called. The default is true.
// The default, 0, means unlimited. //
// If the limit is passed, calls to Write and Close will return ErrMaxHeaderListSize. // This facility exists to let servers enforce MAX_HEADER_LIST_SIZE
func (d *Decoder) SetMaxHeaderListSize(v uint32) { // while still decoding and keeping in-sync with decoder state, but
d.maxHeaderListSize = v // without doing unnecessary decompression or generating unnecessary
} // garbage for header fields past the limit.
func (d *Decoder) SetEmitEnabled(v bool) { d.emitEnabled = v }
// EmitEnabled reports whether calls to the emitFunc provided to NewDecoder
// are currently enabled. The default is true.
func (d *Decoder) EmitEnabled() bool { return d.emitEnabled }
// TODO: add method *Decoder.Reset(maxSize, emitFunc) to let callers re-use Decoders and their // TODO: add method *Decoder.Reset(maxSize, emitFunc) to let callers re-use Decoders and their
// underlying buffers for garbage reasons. // underlying buffers for garbage reasons.
...@@ -235,16 +239,11 @@ func (d *Decoder) DecodeFull(p []byte) ([]HeaderField, error) { ...@@ -235,16 +239,11 @@ func (d *Decoder) DecodeFull(p []byte) ([]HeaderField, error) {
return hf, nil return hf, nil
} }
var ErrMaxHeaderListSize = errors.New("hpack: max header list size exceeded")
func (d *Decoder) Close() error { func (d *Decoder) Close() error {
if d.saveBuf.Len() > 0 { if d.saveBuf.Len() > 0 {
d.saveBuf.Reset() d.saveBuf.Reset()
return DecodingError{errors.New("truncated headers")} return DecodingError{errors.New("truncated headers")}
} }
if d.hitLimit {
return ErrMaxHeaderListSize
}
return nil return nil
} }
...@@ -265,7 +264,7 @@ func (d *Decoder) Write(p []byte) (n int, err error) { ...@@ -265,7 +264,7 @@ func (d *Decoder) Write(p []byte) (n int, err error) {
d.saveBuf.Reset() d.saveBuf.Reset()
} }
for len(d.buf) > 0 && !d.hitLimit { for len(d.buf) > 0 {
err = d.parseHeaderFieldRepr() err = d.parseHeaderFieldRepr()
if err != nil { if err != nil {
if err == errNeedMore { if err == errNeedMore {
...@@ -275,9 +274,6 @@ func (d *Decoder) Write(p []byte) (n int, err error) { ...@@ -275,9 +274,6 @@ func (d *Decoder) Write(p []byte) (n int, err error) {
break break
} }
} }
if err == nil && d.hitLimit {
err = ErrMaxHeaderListSize
}
return len(p), err return len(p), err
} }
...@@ -359,6 +355,7 @@ func (d *Decoder) parseFieldLiteral(n uint8, it indexType) error { ...@@ -359,6 +355,7 @@ func (d *Decoder) parseFieldLiteral(n uint8, it indexType) error {
} }
var hf HeaderField var hf HeaderField
wantStr := d.emitEnabled || it.indexed()
if nameIdx > 0 { if nameIdx > 0 {
ihf, ok := d.at(nameIdx) ihf, ok := d.at(nameIdx)
if !ok { if !ok {
...@@ -366,12 +363,12 @@ func (d *Decoder) parseFieldLiteral(n uint8, it indexType) error { ...@@ -366,12 +363,12 @@ func (d *Decoder) parseFieldLiteral(n uint8, it indexType) error {
} }
hf.Name = ihf.Name hf.Name = ihf.Name
} else { } else {
hf.Name, buf, err = readString(buf) hf.Name, buf, err = readString(buf, wantStr)
if err != nil { if err != nil {
return err return err
} }
} }
hf.Value, buf, err = readString(buf) hf.Value, buf, err = readString(buf, wantStr)
if err != nil { if err != nil {
return err return err
} }
...@@ -385,13 +382,9 @@ func (d *Decoder) parseFieldLiteral(n uint8, it indexType) error { ...@@ -385,13 +382,9 @@ func (d *Decoder) parseFieldLiteral(n uint8, it indexType) error {
} }
func (d *Decoder) callEmit(hf HeaderField) { func (d *Decoder) callEmit(hf HeaderField) {
const overheadPerField = 32 // per http2 section 6.5.2, etc if d.emitEnabled {
d.headerListSize += int64(len(hf.Name)+len(hf.Value)) + overheadPerField d.emit(hf)
if d.maxHeaderListSize != 0 && d.headerListSize > int64(d.maxHeaderListSize) {
d.hitLimit = true
return
} }
d.emit(hf)
} }
// (same invariants and behavior as parseHeaderFieldRepr) // (same invariants and behavior as parseHeaderFieldRepr)
...@@ -452,7 +445,15 @@ func readVarInt(n byte, p []byte) (i uint64, remain []byte, err error) { ...@@ -452,7 +445,15 @@ func readVarInt(n byte, p []byte) (i uint64, remain []byte, err error) {
return 0, origP, errNeedMore return 0, origP, errNeedMore
} }
func readString(p []byte) (s string, remain []byte, err error) { // readString decodes an hpack string from p.
//
// wantStr is whether s will be used. If false, decompression and
// []byte->string garbage are skipped if s will be ignored
// anyway. This does mean that huffman decoding errors for non-indexed
// strings past the MAX_HEADER_LIST_SIZE are ignored, but the server
// is returning an error anyway, and because they're not indexed, the error
// won't affect the decoding state.
func readString(p []byte, wantStr bool) (s string, remain []byte, err error) {
if len(p) == 0 { if len(p) == 0 {
return "", p, errNeedMore return "", p, errNeedMore
} }
...@@ -465,13 +466,19 @@ func readString(p []byte) (s string, remain []byte, err error) { ...@@ -465,13 +466,19 @@ func readString(p []byte) (s string, remain []byte, err error) {
return "", p, errNeedMore return "", p, errNeedMore
} }
if !isHuff { if !isHuff {
return string(p[:strLen]), p[strLen:], nil if wantStr {
s = string(p[:strLen])
}
return s, p[strLen:], nil
} }
// TODO: optimize this garbage: if wantStr {
var buf bytes.Buffer // TODO: optimize this garbage:
if _, err := HuffmanDecode(&buf, p[:strLen]); err != nil { var buf bytes.Buffer
return "", nil, err if _, err := HuffmanDecode(&buf, p[:strLen]); err != nil {
return "", nil, err
}
s = buf.String()
} }
return buf.String(), p[strLen:], nil return s, p[strLen:], nil
} }
...@@ -647,40 +647,28 @@ func dehex(s string) []byte { ...@@ -647,40 +647,28 @@ func dehex(s string) []byte {
return b return b
} }
func TestMaxHeaderListSize(t *testing.T) { func TestEmitEnabled(t *testing.T) {
tests := []struct { var buf bytes.Buffer
fields []HeaderField enc := NewEncoder(&buf)
max int enc.WriteField(HeaderField{Name: "foo", Value: "bar"})
wantErr bool enc.WriteField(HeaderField{Name: "foo", Value: "bar"})
}{
// Plenty of space. numCallback := 0
{ var dec *Decoder
fields: []HeaderField{{Name: "foo", Value: "bar"}}, dec = NewDecoder(8<<20, func(HeaderField) {
max: 500, numCallback++
}, dec.SetEmitEnabled(false)
// Exactly right limit. })
{ if !dec.EmitEnabled() {
fields: []HeaderField{{Name: "foo", Value: "bar"}}, t.Errorf("initial emit enabled = false; want true")
max: len("foo") + len("bar") + 32,
},
// One byte too short.
{
fields: []HeaderField{{Name: "foo", Value: "bar"}},
max: len("foo") + len("bar") + 32 - 1,
wantErr: true,
},
} }
for i, tt := range tests { if _, err := dec.Write(buf.Bytes()); err != nil {
var buf bytes.Buffer t.Error(err)
enc := NewEncoder(&buf) }
for _, hf := range tt.fields { if numCallback != 1 {
enc.WriteField(hf) t.Errorf("num callbacks = %d; want 1", numCallback)
} }
dec := NewDecoder(8<<20, func(HeaderField) {}) if dec.EmitEnabled() {
dec.SetMaxHeaderListSize(uint32(tt.max)) t.Errorf("emit enabled = true; want false")
_, err := dec.Write(buf.Bytes())
if (err != nil) != tt.wantErr {
t.Errorf("%d. err = %v; want err = %v", i, err, tt.wantErr)
}
} }
} }
...@@ -220,7 +220,6 @@ func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) { ...@@ -220,7 +220,6 @@ func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) {
sc.inflow.add(initialWindowSize) sc.inflow.add(initialWindowSize)
sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
sc.hpackDecoder = hpack.NewDecoder(initialHeaderTableSize, sc.onNewHeaderField) sc.hpackDecoder = hpack.NewDecoder(initialHeaderTableSize, sc.onNewHeaderField)
sc.hpackDecoder.SetMaxHeaderListSize(sc.maxHeaderListSize())
fr := NewFramer(sc.bw, c) fr := NewFramer(sc.bw, c)
fr.SetMaxReadFrameSize(srv.maxReadFrameSize()) fr.SetMaxReadFrameSize(srv.maxReadFrameSize())
...@@ -373,7 +372,7 @@ type serverConn struct { ...@@ -373,7 +372,7 @@ type serverConn struct {
func (sc *serverConn) maxHeaderListSize() uint32 { func (sc *serverConn) maxHeaderListSize() uint32 {
n := sc.hs.MaxHeaderBytes n := sc.hs.MaxHeaderBytes
if n == 0 { if n <= 0 {
n = http.DefaultMaxHeaderBytes n = http.DefaultMaxHeaderBytes
} }
// http2's count is in a slightly different unit and includes 32 bytes per pair. // http2's count is in a slightly different unit and includes 32 bytes per pair.
...@@ -393,8 +392,9 @@ type requestParam struct { ...@@ -393,8 +392,9 @@ type requestParam struct {
header http.Header header http.Header
method, path string method, path string
scheme, authority string scheme, authority string
sawRegularHeader bool // saw a non-pseudo header already sawRegularHeader bool // saw a non-pseudo header already
invalidHeader bool // an invalid header was seen invalidHeader bool // an invalid header was seen
headerListSize int64 // actually uint32, but easier math this way
} }
// stream represents a stream. This is the minimal metadata needed by // stream represents a stream. This is the minimal metadata needed by
...@@ -515,6 +515,11 @@ func (sc *serverConn) onNewHeaderField(f hpack.HeaderField) { ...@@ -515,6 +515,11 @@ func (sc *serverConn) onNewHeaderField(f hpack.HeaderField) {
default: default:
sc.req.sawRegularHeader = true sc.req.sawRegularHeader = true
sc.req.header.Add(sc.canonicalHeader(f.Name), f.Value) sc.req.header.Add(sc.canonicalHeader(f.Name), f.Value)
const headerFieldOverhead = 32 // per spec
sc.req.headerListSize += int64(len(f.Name)) + int64(len(f.Value)) + headerFieldOverhead
if sc.req.headerListSize > int64(sc.maxHeaderListSize()) {
sc.hpackDecoder.SetEmitEnabled(false)
}
} }
} }
...@@ -1247,6 +1252,7 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error { ...@@ -1247,6 +1252,7 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
stream: st, stream: st,
header: make(http.Header), header: make(http.Header),
} }
sc.hpackDecoder.SetEmitEnabled(true)
return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded()) return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded())
} }
...@@ -1298,7 +1304,14 @@ func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bo ...@@ -1298,7 +1304,14 @@ func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bo
} }
st.body = req.Body.(*requestBody).pipe // may be nil st.body = req.Body.(*requestBody).pipe // may be nil
st.declBodyBytes = req.ContentLength st.declBodyBytes = req.ContentLength
go sc.runHandler(rw, req)
handler := sc.handler.ServeHTTP
if !sc.hpackDecoder.EmitEnabled() {
// Their header list was too long. Send a 431 error.
handler = handleHeaderListTooLong
}
go sc.runHandler(rw, req, handler)
return nil return nil
} }
...@@ -1438,10 +1451,20 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err ...@@ -1438,10 +1451,20 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
} }
// Run on its own goroutine. // Run on its own goroutine.
func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request) { func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) {
defer rw.handlerDone() defer rw.handlerDone()
// TODO: catch panics like net/http.Server // TODO: catch panics like net/http.Server
sc.handler.ServeHTTP(rw, req) handler(rw, req)
}
func handleHeaderListTooLong(w http.ResponseWriter, r *http.Request) {
// 10.5.1 Limits on Header Block Size:
// .. "A server that receives a larger header block than it is
// willing to handle can send an HTTP 431 (Request Header Fields Too
// Large) status code"
const statusRequestHeaderFieldsTooLarge = 431 // only in Go 1.6+
w.WriteHeader(statusRequestHeaderFieldsTooLarge)
io.WriteString(w, "<h1>HTTP Error 431</h1><p>Request Header Field(s) Too Large</p>")
} }
// called from handler goroutines. // called from handler goroutines.
......
...@@ -2251,9 +2251,18 @@ func TestServerDoS_MaxHeaderListSize(t *testing.T) { ...@@ -2251,9 +2251,18 @@ func TestServerDoS_MaxHeaderListSize(t *testing.T) {
st.fr.WriteContinuation(1, len(b) == 0, chunk) st.fr.WriteContinuation(1, len(b) == 0, chunk)
} }
fr, err := st.fr.ReadFrame() h := st.wantHeaders()
if err == nil { if !h.HeadersEnded() {
t.Fatalf("want error; got unexpected frame: %#v", fr) t.Fatalf("Got HEADERS without END_HEADERS set: %v", h)
}
headers := decodeHeader(t, h.HeaderBlockFragment())
want := [][2]string{
{":status", "431"},
{"content-type", "text/html; charset=utf-8"},
{"content-length", "63"},
}
if !reflect.DeepEqual(headers, want) {
t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment