Commit ea0da6f3 authored by Tom Bergan's avatar Tom Bergan

http2: remove afterReqBodyWriteError wrapper

There was a case where we forgot to undo this wrapper. Instead of fixing
that case, I moved the implementation of ClientConn.RoundTrip into an
unexported method that returns the same info as a bool.

Fixes golang/go#22136

Change-Id: I7e5fc467f9c26fb74b9b83f2b3b7f8882645e34c
Reviewed-on: https://go-review.googlesource.com/75252Reviewed-by: 's avatarBrad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent c73622c7
...@@ -274,6 +274,13 @@ func (cs *clientStream) checkResetOrDone() error { ...@@ -274,6 +274,13 @@ func (cs *clientStream) checkResetOrDone() error {
} }
} }
func (cs *clientStream) getStartedWrite() bool {
cc := cs.cc
cc.mu.Lock()
defer cc.mu.Unlock()
return cs.startedWrite
}
func (cs *clientStream) abortRequestBodyWrite(err error) { func (cs *clientStream) abortRequestBodyWrite(err error) {
if err == nil { if err == nil {
panic("nil error") panic("nil error")
...@@ -349,14 +356,9 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res ...@@ -349,14 +356,9 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res
return nil, err return nil, err
} }
traceGotConn(req, cc) traceGotConn(req, cc)
res, err := cc.RoundTrip(req) res, gotErrAfterReqBodyWrite, err := cc.roundTrip(req)
if err != nil && retry <= 6 { if err != nil && retry <= 6 {
afterBodyWrite := false if req, err = shouldRetryRequest(req, err, gotErrAfterReqBodyWrite); err == nil {
if e, ok := err.(afterReqBodyWriteError); ok {
err = e
afterBodyWrite = true
}
if req, err = shouldRetryRequest(req, err, afterBodyWrite); err == nil {
// After the first retry, do exponential backoff with 10% jitter. // After the first retry, do exponential backoff with 10% jitter.
if retry == 0 { if retry == 0 {
continue continue
...@@ -394,16 +396,6 @@ var ( ...@@ -394,16 +396,6 @@ var (
errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY") errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY")
) )
// afterReqBodyWriteError is a wrapper around errors returned by ClientConn.RoundTrip.
// It is used to signal that err happened after part of Request.Body was sent to the server.
type afterReqBodyWriteError struct {
err error
}
func (e afterReqBodyWriteError) Error() string {
return e.err.Error() + "; some request body already written"
}
// shouldRetryRequest is called by RoundTrip when a request fails to get // shouldRetryRequest is called by RoundTrip when a request fails to get
// response headers. It is always called with a non-nil error. // response headers. It is always called with a non-nil error.
// It returns either a request to retry (either the same request, or a // It returns either a request to retry (either the same request, or a
...@@ -752,8 +744,13 @@ func actualContentLength(req *http.Request) int64 { ...@@ -752,8 +744,13 @@ func actualContentLength(req *http.Request) int64 {
} }
func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
resp, _, err := cc.roundTrip(req)
return resp, err
}
func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAfterReqBodyWrite bool, err error) {
if err := checkConnHeaders(req); err != nil { if err := checkConnHeaders(req); err != nil {
return nil, err return nil, false, err
} }
if cc.idleTimer != nil { if cc.idleTimer != nil {
cc.idleTimer.Stop() cc.idleTimer.Stop()
...@@ -761,14 +758,14 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { ...@@ -761,14 +758,14 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
trailers, err := commaSeparatedTrailers(req) trailers, err := commaSeparatedTrailers(req)
if err != nil { if err != nil {
return nil, err return nil, false, err
} }
hasTrailers := trailers != "" hasTrailers := trailers != ""
cc.mu.Lock() cc.mu.Lock()
if err := cc.awaitOpenSlotForRequest(req); err != nil { if err := cc.awaitOpenSlotForRequest(req); err != nil {
cc.mu.Unlock() cc.mu.Unlock()
return nil, err return nil, false, err
} }
body := req.Body body := req.Body
...@@ -802,7 +799,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { ...@@ -802,7 +799,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
hdrs, err := cc.encodeHeaders(req, requestedGzip, trailers, contentLen) hdrs, err := cc.encodeHeaders(req, requestedGzip, trailers, contentLen)
if err != nil { if err != nil {
cc.mu.Unlock() cc.mu.Unlock()
return nil, err return nil, false, err
} }
cs := cc.newStream() cs := cc.newStream()
...@@ -828,7 +825,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { ...@@ -828,7 +825,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
// Don't bother sending a RST_STREAM (our write already failed; // Don't bother sending a RST_STREAM (our write already failed;
// no need to keep writing) // no need to keep writing)
traceWroteRequest(cs.trace, werr) traceWroteRequest(cs.trace, werr)
return nil, werr return nil, false, werr
} }
var respHeaderTimer <-chan time.Time var respHeaderTimer <-chan time.Time
...@@ -847,7 +844,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { ...@@ -847,7 +844,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
bodyWritten := false bodyWritten := false
ctx := reqContext(req) ctx := reqContext(req)
handleReadLoopResponse := func(re resAndError) (*http.Response, error) { handleReadLoopResponse := func(re resAndError) (*http.Response, bool, error) {
res := re.res res := re.res
if re.err != nil || res.StatusCode > 299 { if re.err != nil || res.StatusCode > 299 {
// On error or status code 3xx, 4xx, 5xx, etc abort any // On error or status code 3xx, 4xx, 5xx, etc abort any
...@@ -863,18 +860,12 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { ...@@ -863,18 +860,12 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
cs.abortRequestBodyWrite(errStopReqBodyWrite) cs.abortRequestBodyWrite(errStopReqBodyWrite)
} }
if re.err != nil { if re.err != nil {
cc.mu.Lock()
afterBodyWrite := cs.startedWrite
cc.mu.Unlock()
cc.forgetStreamID(cs.ID) cc.forgetStreamID(cs.ID)
if afterBodyWrite { return nil, cs.getStartedWrite(), re.err
return nil, afterReqBodyWriteError{re.err}
}
return nil, re.err
} }
res.Request = req res.Request = req
res.TLS = cc.tlsState res.TLS = cc.tlsState
return res, nil return res, false, nil
} }
for { for {
...@@ -889,7 +880,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { ...@@ -889,7 +880,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel) cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
} }
cc.forgetStreamID(cs.ID) cc.forgetStreamID(cs.ID)
return nil, errTimeout return nil, cs.getStartedWrite(), errTimeout
case <-ctx.Done(): case <-ctx.Done():
if !hasBody || bodyWritten { if !hasBody || bodyWritten {
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil) cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
...@@ -898,7 +889,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { ...@@ -898,7 +889,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel) cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
} }
cc.forgetStreamID(cs.ID) cc.forgetStreamID(cs.ID)
return nil, ctx.Err() return nil, cs.getStartedWrite(), ctx.Err()
case <-req.Cancel: case <-req.Cancel:
if !hasBody || bodyWritten { if !hasBody || bodyWritten {
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil) cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
...@@ -907,12 +898,12 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { ...@@ -907,12 +898,12 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel) cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
} }
cc.forgetStreamID(cs.ID) cc.forgetStreamID(cs.ID)
return nil, errRequestCanceled return nil, cs.getStartedWrite(), errRequestCanceled
case <-cs.peerReset: case <-cs.peerReset:
// processResetStream already removed the // processResetStream already removed the
// stream from the streams map; no need for // stream from the streams map; no need for
// forgetStreamID. // forgetStreamID.
return nil, cs.resetErr return nil, cs.getStartedWrite(), cs.resetErr
case err := <-bodyWriter.resc: case err := <-bodyWriter.resc:
// Prefer the read loop's response, if available. Issue 16102. // Prefer the read loop's response, if available. Issue 16102.
select { select {
...@@ -921,7 +912,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { ...@@ -921,7 +912,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
default: default:
} }
if err != nil { if err != nil {
return nil, err return nil, cs.getStartedWrite(), err
} }
bodyWritten = true bodyWritten = true
if d := cc.responseHeaderTimeout(); d != 0 { if d := cc.responseHeaderTimeout(); d != 0 {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment