Commit 0cb26f78 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

http2: move HEADERS/CONTINUATION order checking into Framer

Removes state machine complication and duplication out of Server &
Transport and puts it into the Framer instead (where it's nicely
tested).

Also, for testing, start tracking the reason for errors. Later we'll
use it in GOAWAY frames' debug data too.

Change-Id: Ic933654a33edb62b4432c28fe09f7bfdb6f9b334
Reviewed-on: https://go-review.googlesource.com/18101Reviewed-by: 's avatarBlake Mizerany <blake.mizerany@gmail.com>
parent 5d0a0f8c
......@@ -75,3 +75,16 @@ func (e StreamError) Error() string {
type goAwayFlowError struct{}
func (goAwayFlowError) Error() string { return "connection exceeded flow control window size" }
// connErrorReason wraps a ConnectionError with an informative error about why it occurs.
// Errors of this type are only returned by the frame parser functions
// and converted into ConnectionError(ErrCodeProtocol).
type connError struct {
Code ErrCode
Reason string
}
func (e connError) Error() string {
return fmt.Sprintf("http2: connection error: %v: %v", e.Code, e.Reason)
}
......@@ -255,6 +255,11 @@ type Frame interface {
type Framer struct {
r io.Reader
lastFrame Frame
errReason string
// lastHeaderStream is non-zero if the last frame was an
// unfinished HEADERS/CONTINUATION.
lastHeaderStream uint32
maxReadSize uint32
headerBuf [frameHeaderLen]byte
......@@ -271,13 +276,19 @@ type Framer struct {
wbuf []byte
// AllowIllegalWrites permits the Framer's Write methods to
// write frames that do not conform to the HTTP/2 spec. This
// write frames that do not conform to the HTTP/2 spec. This
// permits using the Framer to test other HTTP/2
// implementations' conformance to the spec.
// If false, the Write methods will prefer to return an error
// rather than comply.
AllowIllegalWrites bool
// AllowIllegalReads permits the Framer's ReadFrame method
// to return non-compliant frames or frame orders.
// This is for testing and permits using the Framer to test
// other HTTP/2 implementations' conformance to the spec.
AllowIllegalReads bool
// TODO: track which type of frame & with which flags was sent
// last. Then return an error (unless AllowIllegalWrites) if
// we're in the middle of a header block and a
......@@ -394,12 +405,65 @@ func (fr *Framer) ReadFrame() (Frame, error) {
}
f, err := typeFrameParser(fh.Type)(fh, payload)
if err != nil {
if ce, ok := err.(connError); ok {
return nil, fr.connError(ce.Code, ce.Reason)
}
return nil, err
}
if err := fr.checkFrameOrder(f); err != nil {
return nil, err
}
fr.lastFrame = f
return f, nil
}
// connError returns ConnectionError(code) but first
// stashes away a public reason to the caller can optionally relay it
// to the peer before hanging up on them. This might help others debug
// their implementations.
func (fr *Framer) connError(code ErrCode, reason string) error {
fr.errReason = reason
return ConnectionError(code)
}
// checkFrameOrder reports an error if f is an invalid frame to return
// next from ReadFrame. Mostly it checks whether HEADERS and
// CONTINUATION frames are contiguous.
func (fr *Framer) checkFrameOrder(f Frame) error {
last := fr.lastFrame
fr.lastFrame = f
if fr.AllowIllegalReads {
return nil
}
fh := f.Header()
if fr.lastHeaderStream != 0 {
if fh.Type != FrameContinuation {
return fr.connError(ErrCodeProtocol,
fmt.Sprintf("got %s for stream %d; expected CONTINUATION following %s for stream %d",
fh.Type, fh.StreamID,
last.Header().Type, fr.lastHeaderStream))
}
if fh.StreamID != fr.lastHeaderStream {
return fr.connError(ErrCodeProtocol,
fmt.Sprintf("got CONTINUATION for stream %d; expected stream %d",
fh.StreamID, fr.lastHeaderStream))
}
} else if fh.Type == FrameContinuation {
return fr.connError(ErrCodeProtocol, fmt.Sprintf("unexpected CONTINUATION for stream %d", fh.StreamID))
}
switch fh.Type {
case FrameHeaders, FrameContinuation:
if fh.Flags.Has(FlagHeadersEndHeaders) {
fr.lastHeaderStream = 0
} else {
fr.lastHeaderStream = fh.StreamID
}
}
return nil
}
// A DataFrame conveys arbitrary, variable-length sequences of octets
// associated with a stream.
// See http://http2.github.io/http2-spec/#rfc.section.6.1
......@@ -428,7 +492,7 @@ func parseDataFrame(fh FrameHeader, payload []byte) (Frame, error) {
// field is 0x0, the recipient MUST respond with a
// connection error (Section 5.4.1) of type
// PROTOCOL_ERROR.
return nil, ConnectionError(ErrCodeProtocol)
return nil, connError{ErrCodeProtocol, "DATA frame with stream ID 0"}
}
f := &DataFrame{
FrameHeader: fh,
......@@ -446,7 +510,7 @@ func parseDataFrame(fh FrameHeader, payload []byte) (Frame, error) {
// length of the frame payload, the recipient MUST
// treat this as a connection error.
// Filed: https://github.com/http2/http2-spec/issues/610
return nil, ConnectionError(ErrCodeProtocol)
return nil, connError{ErrCodeProtocol, "pad size larger than data payload"}
}
f.data = payload[:len(payload)-int(padSize)]
return f, nil
......@@ -753,7 +817,7 @@ func parseHeadersFrame(fh FrameHeader, p []byte) (_ Frame, err error) {
// is received whose stream identifier field is 0x0, the recipient MUST
// respond with a connection error (Section 5.4.1) of type
// PROTOCOL_ERROR.
return nil, ConnectionError(ErrCodeProtocol)
return nil, connError{ErrCodeProtocol, "HEADERS frame with stream ID 0"}
}
var padLength uint8
if fh.Flags.Has(FlagHeadersPadded) {
......@@ -883,10 +947,10 @@ func (p PriorityParam) IsZero() bool {
func parsePriorityFrame(fh FrameHeader, payload []byte) (Frame, error) {
if fh.StreamID == 0 {
return nil, ConnectionError(ErrCodeProtocol)
return nil, connError{ErrCodeProtocol, "PRIORITY frame with stream ID 0"}
}
if len(payload) != 5 {
return nil, ConnectionError(ErrCodeFrameSize)
return nil, connError{ErrCodeFrameSize, fmt.Sprintf("PRIORITY frame payload size was %d; want 5", len(payload))}
}
v := binary.BigEndian.Uint32(payload[:4])
streamID := v & 0x7fffffff // mask off high bit
......@@ -956,6 +1020,9 @@ type ContinuationFrame struct {
}
func parseContinuationFrame(fh FrameHeader, p []byte) (Frame, error) {
if fh.StreamID == 0 {
return nil, connError{ErrCodeProtocol, "CONTINUATION frame with stream ID 0"}
}
return &ContinuationFrame{fh, p}, nil
}
......
......@@ -6,6 +6,8 @@ package http2
import (
"bytes"
"fmt"
"io"
"reflect"
"strings"
"testing"
......@@ -264,6 +266,7 @@ func TestWriteContinuation(t *testing.T) {
t.Errorf("test %q: %v", tt.name, err)
continue
}
fr.AllowIllegalReads = true
f, err := fr.ReadFrame()
if err != nil {
t.Errorf("test %q: failed to read the frame back: %v", tt.name, err)
......@@ -595,3 +598,138 @@ func TestWritePushPromise(t *testing.T) {
t.Fatalf("parsed back:\n%#v\nwant:\n%#v", f, want)
}
}
// test checkFrameOrder and that HEADERS and CONTINUATION frames can't be intermingled.
func TestReadFrameOrder(t *testing.T) {
head := func(f *Framer, id uint32, end bool) {
f.WriteHeaders(HeadersFrameParam{
StreamID: id,
BlockFragment: []byte("foo"), // unused, but non-empty
EndHeaders: end,
})
}
cont := func(f *Framer, id uint32, end bool) {
f.WriteContinuation(id, end, []byte("foo"))
}
tests := [...]struct {
name string
w func(*Framer)
atLeast int
wantErr string
}{
0: {
w: func(f *Framer) {
head(f, 1, true)
},
},
1: {
w: func(f *Framer) {
head(f, 1, true)
head(f, 2, true)
},
},
2: {
wantErr: "got HEADERS for stream 2; expected CONTINUATION following HEADERS for stream 1",
w: func(f *Framer) {
head(f, 1, false)
head(f, 2, true)
},
},
3: {
wantErr: "got DATA for stream 1; expected CONTINUATION following HEADERS for stream 1",
w: func(f *Framer) {
head(f, 1, false)
},
},
4: {
w: func(f *Framer) {
head(f, 1, false)
cont(f, 1, true)
head(f, 2, true)
},
},
5: {
wantErr: "got CONTINUATION for stream 2; expected stream 1",
w: func(f *Framer) {
head(f, 1, false)
cont(f, 2, true)
head(f, 2, true)
},
},
6: {
wantErr: "unexpected CONTINUATION for stream 1",
w: func(f *Framer) {
cont(f, 1, true)
},
},
7: {
wantErr: "unexpected CONTINUATION for stream 1",
w: func(f *Framer) {
cont(f, 1, false)
},
},
8: {
wantErr: "HEADERS frame with stream ID 0",
w: func(f *Framer) {
head(f, 0, true)
},
},
9: {
wantErr: "CONTINUATION frame with stream ID 0",
w: func(f *Framer) {
cont(f, 0, true)
},
},
10: {
wantErr: "unexpected CONTINUATION for stream 1",
atLeast: 5,
w: func(f *Framer) {
head(f, 1, false)
cont(f, 1, false)
cont(f, 1, false)
cont(f, 1, false)
cont(f, 1, true)
cont(f, 1, false)
},
},
}
for i, tt := range tests {
buf := new(bytes.Buffer)
f := NewFramer(buf, buf)
f.AllowIllegalWrites = true
tt.w(f)
f.WriteData(1, true, nil) // to test transition away from last step
var err error
n := 0
var log bytes.Buffer
for {
var got Frame
got, err = f.ReadFrame()
fmt.Fprintf(&log, " read %v, %v\n", got, err)
if err != nil {
break
}
n++
}
if err == io.EOF {
err = nil
}
ok := tt.wantErr == ""
if ok && err != nil {
t.Errorf("%d. after %d good frames, ReadFrame = %v; want success\n%s", i, n, err, log.Bytes())
continue
}
if !ok && err != ConnectionError(ErrCodeProtocol) {
t.Errorf("%d. after %d good frames, ReadFrame = %v; want ConnectionError(ErrCodeProtocol)\n%s", i, n, err, log.Bytes())
continue
}
if f.errReason != tt.wantErr {
t.Errorf("%d. framer eror = %q; want %q\n%s", i, f.errReason, tt.wantErr, log.Bytes())
}
if n < tt.atLeast {
t.Errorf("%d. framer only read %d frames; want at least %d\n%s", i, n, tt.atLeast, log.Bytes())
}
}
}
......@@ -1015,18 +1015,6 @@ func (sc *serverConn) resetStream(se StreamError) {
}
}
// curHeaderStreamID returns the stream ID of the header block we're
// currently in the middle of reading. If this returns non-zero, the
// next frame must be a CONTINUATION with this stream id.
func (sc *serverConn) curHeaderStreamID() uint32 {
sc.serveG.check()
st := sc.req.stream
if st == nil {
return 0
}
return st.id
}
// processFrameFromReader processes the serve loop's read from readFrameCh from the
// frame-reading goroutine.
// processFrameFromReader returns whether the connection should be kept open.
......@@ -1091,14 +1079,6 @@ func (sc *serverConn) processFrame(f Frame) error {
sc.sawFirstSettings = true
}
if s := sc.curHeaderStreamID(); s != 0 {
if cf, ok := f.(*ContinuationFrame); !ok {
return ConnectionError(ErrCodeProtocol)
} else if cf.Header().StreamID != s {
return ConnectionError(ErrCodeProtocol)
}
}
switch f := f.(type) {
case *SettingsFrame:
return sc.processSettings(f)
......@@ -1437,9 +1417,6 @@ func (st *stream) processTrailerHeaders(f *HeadersFrame) error {
func (sc *serverConn) processContinuation(f *ContinuationFrame) error {
sc.serveG.check()
st := sc.streams[f.Header().StreamID]
if st == nil || sc.curHeaderStreamID() != st.id {
return ConnectionError(ErrCodeProtocol)
}
if st.gotTrailerHeader {
return st.processTrailerHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded())
}
......
......@@ -851,10 +851,6 @@ type clientConnReadLoop struct {
cc *ClientConn
activeRes map[uint32]*clientStream // keyed by streamID
// continueStreamID is the stream ID we're waiting for
// continuation frames for.
continueStreamID uint32
hdec *hpack.Decoder
// Fields reset on each HEADERS:
......@@ -924,21 +920,6 @@ func (rl *clientConnReadLoop) run() error {
}
cc.vlogf("Transport received %v: %#v", f.Header(), f)
streamID := f.Header().StreamID
_, isContinue := f.(*ContinuationFrame)
if isContinue {
if streamID != rl.continueStreamID {
cc.logf("Protocol violation: got CONTINUATION with id %d; want %d", streamID, rl.continueStreamID)
return ConnectionError(ErrCodeProtocol)
}
} else if rl.continueStreamID != 0 {
// Continue frames need to be adjacent in the stream
// and we were in the middle of headers.
cc.logf("Protocol violation: got %T for stream %d, want CONTINUATION for %d", f, streamID, rl.continueStreamID)
return ConnectionError(ErrCodeProtocol)
}
switch f := f.(type) {
case *HeadersFrame:
err = rl.processHeaders(f)
......@@ -986,13 +967,12 @@ func (rl *clientConnReadLoop) processHeaderBlockFragment(frag []byte, streamID u
cc := rl.cc
cs := cc.streamByID(streamID, streamEnded)
if cs == nil {
// We could return a ConnectionError(ErrCodeProtocol)
// here, except that in the case of us canceling
// client requests, we may also delete from the
// streams map, in which case we forgot that we sent
// this request. So, just ignore any responses for
// now. They might've been in-flight before the
// server got our RST_STREAM.
// We'd get here if we canceled a request while the
// server was mid-way through replying with its
// headers. (The case of a CONTINUATION arriving
// without HEADERS would be rejected earlier by the
// Framer). So if this was just something we canceled,
// ignore it.
return nil
}
if cs.headersDone {
......@@ -1004,12 +984,12 @@ func (rl *clientConnReadLoop) processHeaderBlockFragment(frag []byte, streamID u
if err != nil {
return ConnectionError(ErrCodeCompression)
}
if err := rl.hdec.Close(); err != nil {
return ConnectionError(ErrCodeCompression)
}
if !headersEnded {
rl.continueStreamID = cs.ID
return nil
}
// HEADERS (or CONTINUATION) are now over.
rl.continueStreamID = 0
if !cs.headersDone {
cs.headersDone = true
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment