Commit f5de73ef authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

http2: make Transport respect http1 Transport settings

The http2 Transport now respects the http1 Transport's
DisableCompression, DisableKeepAlives, and ResponseHeaderTimeout, if
the http2 and http1 Transports are wired up together, as they are in
the upcoming Go 1.6.

Updates golang/go#14008

Change-Id: I2f477f6fe5dbef9d0e5439dfc7f3ec2c0da7f296
Reviewed-on: https://go-review.googlesource.com/18721Reviewed-by: 's avatarAndrew Gerrand <adg@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 5c0dae85
......@@ -14,7 +14,11 @@ import (
func configureTransport(t1 *http.Transport) (*Transport, error) {
connPool := new(clientConnPool)
t2 := &Transport{ConnPool: noDialClientConnPool{connPool}}
t2 := &Transport{
ConnPool: noDialClientConnPool{connPool},
t1: t1,
}
connPool.t = t2
if err := registerHTTPSProtocol(t1, noDialH2RoundTripper{t2}); err != nil {
return nil, err
}
......
......@@ -285,3 +285,14 @@ func bodyAllowedForStatus(status int) bool {
}
return true
}
type httpError struct {
msg string
timeout bool
}
func (e *httpError) Error() string { return e.msg }
func (e *httpError) Timeout() bool { return e.timeout }
func (e *httpError) Temporary() bool { return true }
var errTimeout error = &httpError{msg: "http2: timeout awaiting response headers", timeout: true}
......@@ -22,6 +22,7 @@ import (
"strconv"
"strings"
"sync"
"time"
"golang.org/x/net/http2/hpack"
)
......@@ -84,6 +85,11 @@ type Transport struct {
// to mean no limit.
MaxHeaderListSize uint32
// t1, if non-nil, is the standard library Transport using
// this transport. Its settings are used (but not its
// RoundTrip method, etc).
t1 *http.Transport
connPoolOnce sync.Once
connPoolOrDef ClientConnPool // non-nil version of ConnPool
}
......@@ -99,12 +105,7 @@ func (t *Transport) maxHeaderListSize() uint32 {
}
func (t *Transport) disableCompression() bool {
if t.DisableCompression {
return true
}
// TODO: also disable if this transport is somehow linked to an http1 Transport
// and it's configured there?
return false
return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression)
}
var errTransportVersion = errors.New("http2: ConfigureTransport is only supported starting at Go 1.6")
......@@ -160,7 +161,7 @@ type ClientConn struct {
henc *hpack.Encoder
freeBuf [][]byte
wmu sync.Mutex // held while writing; acquire AFTER wmu if holding both
wmu sync.Mutex // held while writing; acquire AFTER mu if holding both
werr error // first write error that has occurred
}
......@@ -178,7 +179,7 @@ type clientStream struct {
inflow flow // guarded by cc.mu
bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read
readErr error // sticky read error; owned by transportResponseBody.Read
stopReqBody bool // stop writing req body; guarded by cc.mu
stopReqBody error // if non-nil, stop writing req body; guarded by cc.mu
peerReset chan struct{} // closed on peer reset
resetErr error // populated before peerReset is closed
......@@ -221,10 +222,13 @@ func (cs *clientStream) checkReset() error {
}
}
func (cs *clientStream) abortRequestBodyWrite() {
func (cs *clientStream) abortRequestBodyWrite(err error) {
if err == nil {
panic("nil error")
}
cc := cs.cc
cc.mu.Lock()
cs.stopReqBody = true
cs.stopReqBody = err
cc.cond.Broadcast()
cc.mu.Unlock()
}
......@@ -364,6 +368,12 @@ func (t *Transport) dialTLSDefault(network, addr string, cfg *tls.Config) (net.C
return cn, nil
}
// disableKeepAlives reports whether connections should be closed as
// soon as possible.
func (t *Transport) disableKeepAlives() bool {
return t.t1 != nil && t.t1.DisableKeepAlives
}
func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) {
if VerboseLogs {
t.vlogf("http2: Transport creating client conn to %v", c.RemoteAddr())
......@@ -463,7 +473,7 @@ func (cc *ClientConn) CanTakeNewRequest() bool {
}
func (cc *ClientConn) canTakeNewRequestLocked() bool {
return cc.goAway == nil &&
return cc.goAway == nil && !cc.closed &&
int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) &&
cc.nextStreamID < 2147483647
}
......@@ -544,6 +554,17 @@ func commaSeparatedTrailers(req *http.Request) (string, error) {
return "", nil
}
func (cc *ClientConn) responseHeaderTimeout() time.Duration {
if cc.t.t1 != nil {
return cc.t.t1.ResponseHeaderTimeout
}
// No way to do this (yet?) with just an http2.Transport. Probably
// no need. Request.Cancel this is the new way. We only need to support
// this for compatibility with the old http.Transport fields when
// we're doing transparent http2.
return 0
}
func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
trailers, err := commaSeparatedTrailers(req)
if err != nil {
......@@ -623,17 +644,25 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, werr
}
var respHeaderTimer <-chan time.Time
var bodyCopyErrc chan error // result of body copy
if hasBody {
bodyCopyErrc = make(chan error, 1)
go func() {
bodyCopyErrc <- cs.writeRequestBody(body, req.Body)
}()
} else {
if d := cc.responseHeaderTimeout(); d != 0 {
timer := time.NewTimer(d)
defer timer.Stop()
respHeaderTimer = timer.C
}
}
readLoopResCh := cs.resc
requestCanceledCh := requestCancel(req)
requestCanceled := false
bodyWritten := false
for {
select {
case re := <-readLoopResCh:
......@@ -648,7 +677,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
// doesn't, they'll RST_STREAM us soon enough. This is a
// heuristic to avoid adding knobs to Transport. Hopefully
// we can keep it.
cs.abortRequestBodyWrite()
cs.abortRequestBodyWrite(errStopReqBodyWrite)
}
if re.err != nil {
cc.forgetStreamID(cs.ID)
......@@ -657,37 +686,37 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
res.Request = req
res.TLS = cc.tlsState
return res, nil
case <-respHeaderTimer:
cc.forgetStreamID(cs.ID)
if !hasBody || bodyWritten {
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
} else {
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
}
return nil, errTimeout
case <-requestCanceledCh:
cc.forgetStreamID(cs.ID)
cs.abortRequestBodyWrite()
if !hasBody {
if !hasBody || bodyWritten {
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
return nil, errRequestCanceled
} else {
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
}
// If we have a body, wait for the body write to be
// finished before sending the RST_STREAM frame.
requestCanceled = true
requestCanceledCh = nil // to prevent spins
readLoopResCh = nil // ignore responses at this point
case <-cs.peerReset:
if requestCanceled {
// They hung up on us first. No need to write a RST_STREAM.
// But prioritize the request canceled error value, since
// it's likely related. (same spirit as http1 code)
return nil, errRequestCanceled
}
case <-cs.peerReset:
// processResetStream already removed the
// stream from the streams map; no need for
// forgetStreamID.
return nil, cs.resetErr
case err := <-bodyCopyErrc:
if requestCanceled {
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
return nil, errRequestCanceled
}
if err != nil {
return nil, err
}
bodyWritten = true
if d := cc.responseHeaderTimeout(); d != 0 {
timer := time.NewTimer(d)
defer timer.Stop()
respHeaderTimer = timer.C
}
}
}
}
......@@ -723,9 +752,14 @@ func (cc *ClientConn) writeHeaders(streamID uint32, endStream bool, hdrs []byte)
return cc.werr
}
// errAbortReqBodyWrite is an internal error value.
// It doesn't escape to callers.
var errAbortReqBodyWrite = errors.New("http2: aborting request body write")
// internal error values; they don't escape to callers
var (
// abort request body write; don't send cancel
errStopReqBodyWrite = errors.New("http2: aborting request body write")
// abort request body write, but send stream reset of cancel.
errStopReqBodyWriteAndCancel = errors.New("http2: canceling request")
)
func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (err error) {
cc := cs.cc
......@@ -761,7 +795,13 @@ func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (
for len(remain) > 0 && err == nil {
var allowed int32
allowed, err = cs.awaitFlowControl(len(remain))
if err != nil {
switch {
case err == errStopReqBodyWrite:
return err
case err == errStopReqBodyWriteAndCancel:
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
return err
case err != nil:
return err
}
cc.wmu.Lock()
......@@ -821,8 +861,8 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error)
if cc.closed {
return 0, errClientConnClosed
}
if cs.stopReqBody {
return 0, errAbortReqBodyWrite
if cs.stopReqBody != nil {
return 0, cs.stopReqBody
}
if err := cs.checkReset(); err != nil {
return 0, err
......@@ -898,7 +938,7 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail
cc.writeHeader(lowKey, v)
}
}
if contentLength >= 0 {
if shouldSendReqContentLength(req.Method, contentLength) {
cc.writeHeader("content-length", strconv.FormatInt(contentLength, 10))
}
if addGzipHeader {
......@@ -910,6 +950,28 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail
return cc.hbuf.Bytes()
}
// shouldSendReqContentLength reports whether the http2.Transport should send
// a "content-length" request header. This logic is basically a copy of the net/http
// transferWriter.shouldSendContentLength.
// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
// -1 means unknown.
func shouldSendReqContentLength(method string, contentLength int64) bool {
if contentLength > 0 {
return true
}
if contentLength < 0 {
return false
}
// For zero bodies, whether we send a content-length depends on the method.
// It also kinda doesn't matter for http2 either way, with END_STREAM.
switch method {
case "POST", "PUT", "PATCH":
return true
default:
return false
}
}
// requires cc.mu be held.
func (cc *ClientConn) encodeTrailers(req *http.Request) []byte {
cc.hbuf.Reset()
......@@ -1032,6 +1094,8 @@ func (rl *clientConnReadLoop) cleanup() {
func (rl *clientConnReadLoop) run() error {
cc := rl.cc
closeWhenIdle := cc.t.disableKeepAlives()
gotReply := false // ever saw a reply
for {
f, err := cc.fr.ReadFrame()
if err != nil {
......@@ -1046,18 +1110,25 @@ func (rl *clientConnReadLoop) run() error {
if VerboseLogs {
cc.vlogf("http2: Transport received %s", summarizeFrame(f))
}
maybeClose := false // whether frame might transition us to idle
switch f := f.(type) {
case *HeadersFrame:
err = rl.processHeaders(f)
maybeClose = true
gotReply = true
case *ContinuationFrame:
err = rl.processContinuation(f)
maybeClose = true
case *DataFrame:
err = rl.processData(f)
maybeClose = true
case *GoAwayFrame:
err = rl.processGoAway(f)
maybeClose = true
case *RSTStreamFrame:
err = rl.processResetStream(f)
maybeClose = true
case *SettingsFrame:
err = rl.processSettings(f)
case *PushPromiseFrame:
......@@ -1072,6 +1143,9 @@ func (rl *clientConnReadLoop) run() error {
if err != nil {
return err
}
if closeWhenIdle && gotReply && maybeClose && len(rl.activeRes) == 0 {
cc.closeIfIdle()
}
}
}
......
......@@ -99,7 +99,6 @@ func TestTransport(t *testing.T) {
} else if string(slurp) != body {
t.Errorf("Body = %q; want %q", slurp, body)
}
}
func TestTransportReusesConns(t *testing.T) {
......@@ -1318,3 +1317,225 @@ func TestTransportDoubleCloseOnWriteError(t *testing.T) {
c := &http.Client{Transport: tr}
c.Get(st.ts.URL)
}
// Test that the http1 Transport.DisableKeepAlives option is respected
// and connections are closed as soon as idle.
// See golang.org/issue/14008
func TestTransportDisableKeepAlives(t *testing.T) {
st := newServerTester(t,
func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, "hi")
},
optOnlyServer,
)
defer st.Close()
connClosed := make(chan struct{}) // closed on tls.Conn.Close
tr := &Transport{
t1: &http.Transport{
DisableKeepAlives: true,
},
TLSClientConfig: tlsConfigInsecure,
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
tc, err := tls.Dial(network, addr, cfg)
if err != nil {
return nil, err
}
return &noteCloseConn{Conn: tc, closefn: func() { close(connClosed) }}, nil
},
}
c := &http.Client{Transport: tr}
res, err := c.Get(st.ts.URL)
if err != nil {
t.Fatal(err)
}
if _, err := ioutil.ReadAll(res.Body); err != nil {
t.Fatal(err)
}
defer res.Body.Close()
select {
case <-connClosed:
case <-time.After(1 * time.Second):
t.Errorf("timeout")
}
}
// Test concurrent requests with Transport.DisableKeepAlives. We can share connections,
// but when things are totally idle, it still needs to close.
func TestTransportDisableKeepAlives_Concurrency(t *testing.T) {
const D = 25 * time.Millisecond
st := newServerTester(t,
func(w http.ResponseWriter, r *http.Request) {
time.Sleep(D)
io.WriteString(w, "hi")
},
optOnlyServer,
)
defer st.Close()
var dials int32
var conns sync.WaitGroup
tr := &Transport{
t1: &http.Transport{
DisableKeepAlives: true,
},
TLSClientConfig: tlsConfigInsecure,
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
tc, err := tls.Dial(network, addr, cfg)
if err != nil {
return nil, err
}
atomic.AddInt32(&dials, 1)
conns.Add(1)
return &noteCloseConn{Conn: tc, closefn: func() { conns.Done() }}, nil
},
}
c := &http.Client{Transport: tr}
var reqs sync.WaitGroup
const N = 20
for i := 0; i < N; i++ {
reqs.Add(1)
if i == N-1 {
// For the final request, try to make all the
// others close. This isn't verified in the
// count, other than the Log statement, since
// it's so timing dependent. This test is
// really to make sure we don't interrupt a
// valid request.
time.Sleep(D * 2)
}
go func() {
defer reqs.Done()
res, err := c.Get(st.ts.URL)
if err != nil {
t.Error(err)
return
}
if _, err := ioutil.ReadAll(res.Body); err != nil {
t.Error(err)
return
}
res.Body.Close()
}()
}
reqs.Wait()
conns.Wait()
t.Logf("did %d dials, %d requests", atomic.LoadInt32(&dials), N)
}
type noteCloseConn struct {
net.Conn
onceClose sync.Once
closefn func()
}
func (c *noteCloseConn) Close() error {
c.onceClose.Do(c.closefn)
return c.Conn.Close()
}
func isTimeout(err error) bool {
switch err := err.(type) {
case nil:
return false
case *url.Error:
return isTimeout(err.Err)
case net.Error:
return err.Timeout()
}
return false
}
// Test that the http1 Transport.ResponseHeaderTimeout option and cancel is sent.
func TestTransportResponseHeaderTimeout_NoBody(t *testing.T) {
testTransportResponseHeaderTimeout(t, false)
}
func TestTransportResponseHeaderTimeout_Body(t *testing.T) {
testTransportResponseHeaderTimeout(t, true)
}
func testTransportResponseHeaderTimeout(t *testing.T, body bool) {
ct := newClientTester(t)
ct.tr.t1 = &http.Transport{
ResponseHeaderTimeout: 5 * time.Millisecond,
}
ct.client = func() error {
c := &http.Client{Transport: ct.tr}
var err error
var n int64
const bodySize = 4 << 20
if body {
_, err = c.Post("https://dummy.tld/", "text/foo", io.LimitReader(countingReader{&n}, bodySize))
} else {
_, err = c.Get("https://dummy.tld/")
}
if !isTimeout(err) {
t.Errorf("client expected timeout error; got %#v", err)
}
if body && n != bodySize {
t.Errorf("only read %d bytes of body; want %d", n, bodySize)
}
return nil
}
ct.server = func() error {
ct.greet()
for {
f, err := ct.fr.ReadFrame()
if err != nil {
t.Logf("ReadFrame: %v", err)
return nil
}
switch f := f.(type) {
case *DataFrame:
dataLen := len(f.Data())
if dataLen > 0 {
if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil {
return err
}
if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil {
return err
}
}
case *RSTStreamFrame:
if f.StreamID == 1 && f.ErrCode == ErrCodeCancel {
return nil
}
}
}
return nil
}
ct.run()
}
func TestTransportDisableCompression(t *testing.T) {
const body = "sup"
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
want := http.Header{
"User-Agent": []string{"Go-http-client/2.0"},
}
if !reflect.DeepEqual(r.Header, want) {
t.Errorf("request headers = %v; want %v", r.Header, want)
}
}, optOnlyServer)
defer st.Close()
tr := &Transport{
TLSClientConfig: tlsConfigInsecure,
t1: &http.Transport{
DisableCompression: true,
},
}
defer tr.CloseIdleConnections()
req, err := http.NewRequest("GET", st.ts.URL, nil)
if err != nil {
t.Fatal(err)
}
res, err := tr.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
defer res.Body.Close()
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment