Commit e2ba55e4 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

http2: fix Transport.RoundTrip hang on stream error before headers

If the Transport got a stream error on the response headers, it was
never unblocking the client. Previously, Response.Body reads would be
aborted with the stream error, but RoundTrip itself would never
unblock.

The Transport now also sends a RST_STREAM to the server when we
encounter a stream error.

Also, add a "Cause" field to StreamError with additional detail. The
old code was just returning the detail, without the stream error
header.

Fixes golang/go#16572

Change-Id: Ibecedb5779f17bf98c32787b68eb8a9b850833b3
Reviewed-on: https://go-review.googlesource.com/25402
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: 's avatarAndrew Gerrand <adg@golang.org>
parent f6d21198
......@@ -64,9 +64,17 @@ func (e ConnectionError) Error() string { return fmt.Sprintf("connection error:
type StreamError struct {
StreamID uint32
Code ErrCode
Cause error // optional additional detail
}
func streamError(id uint32, code ErrCode) StreamError {
return StreamError{StreamID: id, Code: code}
}
func (e StreamError) Error() string {
if e.Cause != nil {
return fmt.Sprintf("stream error: stream ID %d; %v; %v", e.StreamID, e.Code, e.Cause)
}
return fmt.Sprintf("stream error: stream ID %d; %v", e.StreamID, e.Code)
}
......
......@@ -863,7 +863,7 @@ func parseWindowUpdateFrame(fh FrameHeader, p []byte) (Frame, error) {
if fh.StreamID == 0 {
return nil, ConnectionError(ErrCodeProtocol)
}
return nil, StreamError{fh.StreamID, ErrCodeProtocol}
return nil, streamError(fh.StreamID, ErrCodeProtocol)
}
return &WindowUpdateFrame{
FrameHeader: fh,
......@@ -944,7 +944,7 @@ func parseHeadersFrame(fh FrameHeader, p []byte) (_ Frame, err error) {
}
}
if len(p)-int(padLength) <= 0 {
return nil, StreamError{fh.StreamID, ErrCodeProtocol}
return nil, streamError(fh.StreamID, ErrCodeProtocol)
}
hf.headerFragBuf = p[:len(p)-int(padLength)]
return hf, nil
......@@ -1483,14 +1483,14 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) {
if VerboseLogs {
log.Printf("http2: invalid header: %v", invalid)
}
return nil, StreamError{mh.StreamID, ErrCodeProtocol}
return nil, StreamError{mh.StreamID, ErrCodeProtocol, invalid}
}
if err := mh.checkPseudos(); err != nil {
fr.errDetail = err
if VerboseLogs {
log.Printf("http2: invalid pseudo headers: %v", err)
}
return nil, StreamError{mh.StreamID, ErrCodeProtocol}
return nil, StreamError{mh.StreamID, ErrCodeProtocol, err}
}
return mh, nil
}
......
......@@ -992,7 +992,7 @@ func TestMetaFrameHeader(t *testing.T) {
":path", "/", // bogus
))
},
want: StreamError{1, ErrCodeProtocol},
want: streamError(1, ErrCodeProtocol),
wantErrReason: "pseudo header field after regular",
},
7: {
......@@ -1003,7 +1003,7 @@ func TestMetaFrameHeader(t *testing.T) {
"foo", "bar",
))
},
want: StreamError{1, ErrCodeProtocol},
want: streamError(1, ErrCodeProtocol),
wantErrReason: "invalid pseudo-header \":unknown\"",
},
8: {
......@@ -1014,7 +1014,7 @@ func TestMetaFrameHeader(t *testing.T) {
":status", "100",
))
},
want: StreamError{1, ErrCodeProtocol},
want: streamError(1, ErrCodeProtocol),
wantErrReason: "mix of request and response pseudo headers",
},
9: {
......@@ -1025,7 +1025,7 @@ func TestMetaFrameHeader(t *testing.T) {
":method", "POST",
))
},
want: StreamError{1, ErrCodeProtocol},
want: streamError(1, ErrCodeProtocol),
wantErrReason: "duplicate pseudo-header \":method\"",
},
10: {
......@@ -1036,13 +1036,13 @@ func TestMetaFrameHeader(t *testing.T) {
11: {
name: "invalid_field_name",
w: func(f *Framer) { write(f, encodeHeaderRaw(t, "CapitalBad", "x")) },
want: StreamError{1, ErrCodeProtocol},
want: streamError(1, ErrCodeProtocol),
wantErrReason: "invalid header field name \"CapitalBad\"",
},
12: {
name: "invalid_field_value",
w: func(f *Framer) { write(f, encodeHeaderRaw(t, "key", "bad_null\x00")) },
want: StreamError{1, ErrCodeProtocol},
want: streamError(1, ErrCodeProtocol),
wantErrReason: "invalid header field value \"bad_null\\x00\"",
},
}
......@@ -1063,6 +1063,13 @@ func TestMetaFrameHeader(t *testing.T) {
got, err = f.ReadFrame()
if err != nil {
got = err
// Ignore the StreamError.Cause field, if it matches the wantErrReason.
// The test table above predates the Cause field.
if se, ok := err.(StreamError); ok && se.Cause != nil && se.Cause.Error() == tt.wantErrReason {
se.Cause = nil
got = se
}
}
if !reflect.DeepEqual(got, tt.want) {
if mhg, ok := got.(*MetaHeadersFrame); ok {
......
......@@ -922,7 +922,7 @@ func (sc *serverConn) wroteFrame(res frameWriteResult) {
// state here anyway, after telling the peer
// we're hanging up on them.
st.state = stateHalfClosedLocal // won't last long, but necessary for closeStream via resetStream
errCancel := StreamError{st.id, ErrCodeCancel}
errCancel := streamError(st.id, ErrCodeCancel)
sc.resetStream(errCancel)
case stateHalfClosedRemote:
sc.closeStream(st, errHandlerComplete)
......@@ -1133,7 +1133,7 @@ func (sc *serverConn) processWindowUpdate(f *WindowUpdateFrame) error {
return nil
}
if !st.flow.add(int32(f.Increment)) {
return StreamError{f.StreamID, ErrCodeFlowControl}
return streamError(f.StreamID, ErrCodeFlowControl)
}
default: // connection-level flow control
if !sc.flow.add(int32(f.Increment)) {
......@@ -1159,7 +1159,7 @@ func (sc *serverConn) processResetStream(f *RSTStreamFrame) error {
if st != nil {
st.gotReset = true
st.cancelCtx()
sc.closeStream(st, StreamError{f.StreamID, f.ErrCode})
sc.closeStream(st, streamError(f.StreamID, f.ErrCode))
}
return nil
}
......@@ -1299,7 +1299,7 @@ func (sc *serverConn) processData(f *DataFrame) error {
// and return any flow control bytes since we're not going
// to consume them.
if sc.inflow.available() < int32(f.Length) {
return StreamError{id, ErrCodeFlowControl}
return streamError(id, ErrCodeFlowControl)
}
// Deduct the flow control from inflow, since we're
// going to immediately add it back in
......@@ -1308,7 +1308,7 @@ func (sc *serverConn) processData(f *DataFrame) error {
sc.inflow.take(int32(f.Length))
sc.sendWindowUpdate(nil, int(f.Length)) // conn-level
return StreamError{id, ErrCodeStreamClosed}
return streamError(id, ErrCodeStreamClosed)
}
if st.body == nil {
panic("internal error: should have a body in this state")
......@@ -1317,19 +1317,19 @@ func (sc *serverConn) processData(f *DataFrame) error {
// Sender sending more than they'd declared?
if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes {
st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes))
return StreamError{id, ErrCodeStreamClosed}
return streamError(id, ErrCodeStreamClosed)
}
if f.Length > 0 {
// Check whether the client has flow control quota.
if st.inflow.available() < int32(f.Length) {
return StreamError{id, ErrCodeFlowControl}
return streamError(id, ErrCodeFlowControl)
}
st.inflow.take(int32(f.Length))
if len(data) > 0 {
wrote, err := st.body.Write(data)
if err != nil {
return StreamError{id, ErrCodeStreamClosed}
return streamError(id, ErrCodeStreamClosed)
}
if wrote != len(data) {
panic("internal error: bad Writer")
......@@ -1446,14 +1446,14 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
// REFUSED_STREAM."
if sc.unackedSettings == 0 {
// They should know better.
return StreamError{st.id, ErrCodeProtocol}
return streamError(st.id, ErrCodeProtocol)
}
// Assume it's a network race, where they just haven't
// received our last SETTINGS update. But actually
// this can't happen yet, because we don't yet provide
// a way for users to adjust server parameters at
// runtime.
return StreamError{st.id, ErrCodeRefusedStream}
return streamError(st.id, ErrCodeRefusedStream)
}
rw, req, err := sc.newWriterAndRequest(st, f)
......@@ -1487,11 +1487,11 @@ func (st *stream) processTrailerHeaders(f *MetaHeadersFrame) error {
}
st.gotTrailerHeader = true
if !f.StreamEnded() {
return StreamError{st.id, ErrCodeProtocol}
return streamError(st.id, ErrCodeProtocol)
}
if len(f.PseudoFields()) > 0 {
return StreamError{st.id, ErrCodeProtocol}
return streamError(st.id, ErrCodeProtocol)
}
if st.trailer != nil {
for _, hf := range f.RegularFields() {
......@@ -1500,7 +1500,7 @@ func (st *stream) processTrailerHeaders(f *MetaHeadersFrame) error {
// TODO: send more details to the peer somehow. But http2 has
// no way to send debug data at a stream level. Discuss with
// HTTP folk.
return StreamError{st.id, ErrCodeProtocol}
return streamError(st.id, ErrCodeProtocol)
}
st.trailer[key] = append(st.trailer[key], hf.Value)
}
......@@ -1561,7 +1561,7 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
isConnect := method == "CONNECT"
if isConnect {
if path != "" || scheme != "" || authority == "" {
return nil, nil, StreamError{f.StreamID, ErrCodeProtocol}
return nil, nil, streamError(f.StreamID, ErrCodeProtocol)
}
} else if method == "" || path == "" ||
(scheme != "https" && scheme != "http") {
......@@ -1575,13 +1575,13 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
// "All HTTP/2 requests MUST include exactly one valid
// value for the :method, :scheme, and :path
// pseudo-header fields"
return nil, nil, StreamError{f.StreamID, ErrCodeProtocol}
return nil, nil, streamError(f.StreamID, ErrCodeProtocol)
}
bodyOpen := !f.StreamEnded()
if method == "HEAD" && bodyOpen {
// HEAD requests can't have bodies
return nil, nil, StreamError{f.StreamID, ErrCodeProtocol}
return nil, nil, streamError(f.StreamID, ErrCodeProtocol)
}
var tlsState *tls.ConnectionState // nil if not scheme https
......@@ -1639,7 +1639,7 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
var err error
url_, err = url.ParseRequestURI(path)
if err != nil {
return nil, nil, StreamError{f.StreamID, ErrCodeProtocol}
return nil, nil, streamError(f.StreamID, ErrCodeProtocol)
}
requestURI = path
}
......
......@@ -55,11 +55,6 @@ type serverTester struct {
// writing headers:
headerBuf bytes.Buffer
hpackEnc *hpack.Encoder
// reading frames:
frc chan Frame
frErrc chan error
readTimer *time.Timer
}
func init() {
......@@ -117,8 +112,6 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}
t: t,
ts: ts,
logBuf: logBuf,
frc: make(chan Frame, 1),
frErrc: make(chan error, 1),
}
st.hpackEnc = hpack.NewEncoder(&st.headerBuf)
st.hpackDec = hpack.NewDecoder(initialHeaderTableSize, st.onHeaderField)
......@@ -365,32 +358,33 @@ func (st *serverTester) writeDataPadded(streamID uint32, endStream bool, data, p
}
}
func (st *serverTester) readFrame() (Frame, error) {
func readFrameTimeout(fr *Framer, wait time.Duration) (Frame, error) {
ch := make(chan interface{}, 1)
go func() {
fr, err := st.fr.ReadFrame()
fr, err := fr.ReadFrame()
if err != nil {
st.frErrc <- err
ch <- err
} else {
st.frc <- fr
ch <- fr
}
}()
t := st.readTimer
if t == nil {
t = time.NewTimer(2 * time.Second)
st.readTimer = t
}
t.Reset(2 * time.Second)
defer t.Stop()
t := time.NewTimer(wait)
select {
case f := <-st.frc:
return f, nil
case err := <-st.frErrc:
return nil, err
case v := <-ch:
t.Stop()
if fr, ok := v.(Frame); ok {
return fr, nil
}
return nil, v.(error)
case <-t.C:
return nil, errors.New("timeout waiting for frame")
}
}
func (st *serverTester) readFrame() (Frame, error) {
return readFrameTimeout(st.fr, 2*time.Second)
}
func (st *serverTester) wantHeaders() *HeadersFrame {
f, err := st.readFrame()
if err != nil {
......
......@@ -1229,7 +1229,11 @@ func (rl *clientConnReadLoop) run() error {
}
if se, ok := err.(StreamError); ok {
if cs := cc.streamByID(se.StreamID, true /*ended; remove it*/); cs != nil {
rl.endStreamError(cs, cc.fr.errDetail)
cs.cc.writeStreamReset(cs.ID, se.Code, err)
if se.Cause == nil {
se.Cause = cc.fr.errDetail
}
rl.endStreamError(cs, se)
}
continue
} else if err != nil {
......@@ -1639,6 +1643,11 @@ func (rl *clientConnReadLoop) endStreamError(cs *clientStream, err error) {
if isConnectionCloseRequest(cs.req) {
rl.closeWhenIdle = true
}
select {
case cs.resc <- resAndError{err: err}:
default:
}
}
func (cs *clientStream) copyTrailers() {
......@@ -1740,7 +1749,7 @@ func (rl *clientConnReadLoop) processResetStream(f *RSTStreamFrame) error {
// which closes this, so there
// isn't a race.
default:
err := StreamError{cs.ID, f.ErrCode}
err := streamError(cs.ID, f.ErrCode)
cs.resetErr = err
close(cs.peerReset)
cs.bufPipe.CloseWithError(err)
......
......@@ -699,6 +699,28 @@ func (ct *clientTester) start(which string, errc chan<- error, fn func() error)
}()
}
func (ct *clientTester) readFrame() (Frame, error) {
return readFrameTimeout(ct.fr, 2*time.Second)
}
func (ct *clientTester) firstHeaders() (*HeadersFrame, error) {
for {
f, err := ct.readFrame()
if err != nil {
return nil, fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
}
switch f.(type) {
case *WindowUpdateFrame, *SettingsFrame:
continue
}
hf, ok := f.(*HeadersFrame)
if !ok {
return nil, fmt.Errorf("Got %T; want HeadersFrame", f)
}
return hf, nil
}
}
type countingReader struct {
n *int64
}
......@@ -1224,8 +1246,9 @@ func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeT
return fmt.Errorf("status code = %v; want 200", res.StatusCode)
}
slurp, err := ioutil.ReadAll(res.Body)
if err != wantErr {
return fmt.Errorf("res.Body ReadAll error = %q, %#v; want %T of %#v", slurp, err, wantErr, wantErr)
se, ok := err.(StreamError)
if !ok || se.Cause != wantErr {
return fmt.Errorf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", slurp, err, wantErr, wantErr)
}
if len(slurp) > 0 {
return fmt.Errorf("body = %q; want nothing", slurp)
......@@ -2278,3 +2301,59 @@ func TestTransportReturnsDataPaddingFlowControl(t *testing.T) {
}
ct.run()
}
// golang.org/issue/16572 -- RoundTrip shouldn't hang when it gets a
// StreamError as a result of the response HEADERS
func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) {
ct := newClientTester(t)
ct.client = func() error {
req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
res, err := ct.tr.RoundTrip(req)
if err == nil {
res.Body.Close()
return errors.New("unexpected successful GET")
}
want := StreamError{1, ErrCodeProtocol, headerFieldNameError(" content-type")}
if !reflect.DeepEqual(want, err) {
t.Errorf("RoundTrip error = %#v; want %#v", err, want)
}
return nil
}
ct.server = func() error {
ct.greet()
hf, err := ct.firstHeaders()
if err != nil {
return err
}
var buf bytes.Buffer
enc := hpack.NewEncoder(&buf)
enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
enc.WriteField(hpack.HeaderField{Name: " content-type", Value: "bogus"}) // bogus spaces
ct.fr.WriteHeaders(HeadersFrameParam{
StreamID: hf.StreamID,
EndHeaders: true,
EndStream: false,
BlockFragment: buf.Bytes(),
})
for {
fr, err := ct.readFrame()
if err != nil {
return fmt.Errorf("error waiting for RST_STREAM from client: %v", err)
}
if _, ok := fr.(*SettingsFrame); ok {
continue
}
if rst, ok := fr.(*RSTStreamFrame); !ok || rst.StreamID != 1 || rst.ErrCode != ErrCodeProtocol {
t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr))
}
break
}
return nil
}
ct.run()
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment