Commit 1c5f79cf authored by Ruslan Nigmatullin's avatar Ruslan Nigmatullin Committed by Brad Fitzpatrick

http2: don't leak streams on broken body

Updates golang/go#27208

Change-Id: I5d9a643f33d27d33b24f670c98f5a51aa6000967
GitHub-Last-Rev: 3ac4a573b62846ef4944599085218e119819383c
GitHub-Pull-Request: golang/net#18
Reviewed-on: https://go-review.googlesource.com/c/132715
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: 's avatarBrad Fitzpatrick <bradfitz@golang.org>
parent a544f70c
...@@ -1100,6 +1100,7 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf ...@@ -1100,6 +1100,7 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf
default: default:
} }
if err != nil { if err != nil {
cc.forgetStreamID(cs.ID)
return nil, cs.getStartedWrite(), err return nil, cs.getStartedWrite(), err
} }
bodyWritten = true bodyWritten = true
...@@ -1221,6 +1222,7 @@ func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) ( ...@@ -1221,6 +1222,7 @@ func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (
sawEOF = true sawEOF = true
err = nil err = nil
} else if err != nil { } else if err != nil {
cc.writeStreamReset(cs.ID, ErrCodeCancel, err)
return err return err
} }
......
...@@ -4180,3 +4180,99 @@ func TestNoDialH2RoundTripperType(t *testing.T) { ...@@ -4180,3 +4180,99 @@ func TestNoDialH2RoundTripperType(t *testing.T) {
t.Fatalf("wrong kind %T; want *Transport", v.Interface()) t.Fatalf("wrong kind %T; want *Transport", v.Interface())
} }
} }
type errReader struct {
body []byte
err error
}
func (r *errReader) Read(p []byte) (int, error) {
if len(r.body) > 0 {
n := copy(p, r.body)
r.body = r.body[n:]
return n, nil
}
return 0, r.err
}
func testTransportBodyReadError(t *testing.T, body []byte) {
clientDone := make(chan struct{})
ct := newClientTester(t)
ct.client = func() error {
defer ct.cc.(*net.TCPConn).CloseWrite()
defer close(clientDone)
checkNoStreams := func() error {
cp, ok := ct.tr.connPool().(*clientConnPool)
if !ok {
return fmt.Errorf("conn pool is %T; want *clientConnPool", ct.tr.connPool())
}
cp.mu.Lock()
defer cp.mu.Unlock()
conns, ok := cp.conns["dummy.tld:443"]
if !ok {
return fmt.Errorf("missing connection")
}
if len(conns) != 1 {
return fmt.Errorf("conn pool size: %v; expect 1", len(conns))
}
if activeStreams(conns[0]) != 0 {
return fmt.Errorf("active streams count: %v; want 0", activeStreams(conns[0]))
}
return nil
}
bodyReadError := errors.New("body read error")
body := &errReader{body, bodyReadError}
req, err := http.NewRequest("PUT", "https://dummy.tld/", body)
if err != nil {
return err
}
_, err = ct.tr.RoundTrip(req)
if err != bodyReadError {
return fmt.Errorf("err = %v; want %v", err, bodyReadError)
}
if err = checkNoStreams(); err != nil {
return err
}
return nil
}
ct.server = func() error {
ct.greet()
var receivedBody []byte
var resetCount int
for {
f, err := ct.fr.ReadFrame()
if err != nil {
select {
case <-clientDone:
// If the client's done, it
// will have reported any
// errors on its side.
if bytes.Compare(receivedBody, body) != 0 {
return fmt.Errorf("body: %v; expected %v", receivedBody, body)
}
if resetCount != 1 {
return fmt.Errorf("stream reset count: %v; expected: 1", resetCount)
}
return nil
default:
return err
}
}
switch f := f.(type) {
case *WindowUpdateFrame, *SettingsFrame:
case *HeadersFrame:
case *DataFrame:
receivedBody = append(receivedBody, f.Data()...)
case *RSTStreamFrame:
resetCount++
default:
return fmt.Errorf("Unexpected client frame %v", f)
}
}
}
ct.run()
}
func TestTransportBodyReadError_Immediately(t *testing.T) { testTransportBodyReadError(t, nil) }
func TestTransportBodyReadError_Some(t *testing.T) { testTransportBodyReadError(t, []byte("123")) }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment