Commit 961116ae authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

http2: support CONNECT requests

Support CONNECT requests in both the server & transport.

See https://httpwg.github.io/specs/rfc7540.html#CONNECT

When I bundle this into the main Go repo I will also add h1-vs-h2
compatibility tests there, making sure they match behavior. (I now
expect that they do match)

Updates golang/go#13717

Change-Id: I0c65ad47b029419027efb616fed3d8e0e2a363f4
Reviewed-on: https://go-review.googlesource.com/18266Reviewed-by: 's avatarAndrew Gerrand <adg@golang.org>
parent 3b90a77d
......@@ -1545,7 +1545,17 @@ func (sc *serverConn) resetPendingRequest() {
func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, error) {
sc.serveG.check()
rp := &sc.req
if rp.invalidHeader || rp.method == "" || rp.path == "" ||
if rp.invalidHeader {
return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol}
}
isConnect := rp.method == "CONNECT"
if isConnect {
if rp.path != "" || rp.scheme != "" || rp.authority == "" {
return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol}
}
} else if rp.method == "" || rp.path == "" ||
(rp.scheme != "https" && rp.scheme != "http") {
// See 8.1.2.6 Malformed Requests and Responses:
//
......@@ -1559,12 +1569,14 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
// pseudo-header fields"
return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol}
}
bodyOpen := rp.stream.state == stateOpen
if rp.method == "HEAD" && bodyOpen {
// HEAD requests can't have bodies
return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol}
}
var tlsState *tls.ConnectionState // nil if not scheme https
if rp.scheme == "https" {
tlsState = sc.tlsState
}
......@@ -1605,18 +1617,26 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
stream: rp.stream,
needsContinue: needsContinue,
}
// TODO: handle asterisk '*' requests + test
url, err := url.ParseRequestURI(rp.path)
if err != nil {
// TODO: find the right error code?
return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol}
var url_ *url.URL
var requestURI string
if isConnect {
url_ = &url.URL{Host: rp.authority}
requestURI = rp.authority // mimic HTTP/1 server behavior
} else {
var err error
// TODO: handle asterisk '*' requests + test
url_, err = url.ParseRequestURI(rp.path)
if err != nil {
return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol}
}
requestURI = rp.path
}
req := &http.Request{
Method: rp.method,
URL: url,
URL: url_,
RemoteAddr: sc.remoteAddrStr,
Header: rp.header,
RequestURI: rp.path,
RequestURI: requestURI,
Proto: "HTTP/2.0",
ProtoMajor: 2,
ProtoMinor: 0,
......
......@@ -901,6 +901,60 @@ func testRejectRequest(t *testing.T, send func(*serverTester)) {
st.wantRSTStream(1, ErrCodeProtocol)
}
func TestServer_Request_Connect(t *testing.T) {
testServerRequest(t, func(st *serverTester) {
st.writeHeaders(HeadersFrameParam{
StreamID: 1,
BlockFragment: st.encodeHeaderRaw(
":method", "CONNECT",
":authority", "example.com:123",
),
EndStream: true,
EndHeaders: true,
})
}, func(r *http.Request) {
if g, w := r.Method, "CONNECT"; g != w {
t.Errorf("Method = %q; want %q", g, w)
}
if g, w := r.RequestURI, "example.com:123"; g != w {
t.Errorf("RequestURI = %q; want %q", g, w)
}
if g, w := r.URL.Host, "example.com:123"; g != w {
t.Errorf("URL.Host = %q; want %q", g, w)
}
})
}
func TestServer_Request_Connect_InvalidPath(t *testing.T) {
testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
st.writeHeaders(HeadersFrameParam{
StreamID: 1,
BlockFragment: st.encodeHeaderRaw(
":method", "CONNECT",
":authority", "example.com:123",
":path", "/bogus",
),
EndStream: true,
EndHeaders: true,
})
})
}
func TestServer_Request_Connect_InvalidScheme(t *testing.T) {
testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
st.writeHeaders(HeadersFrameParam{
StreamID: 1,
BlockFragment: st.encodeHeaderRaw(
":method", "CONNECT",
":authority", "example.com:123",
":scheme", "https",
),
EndStream: true,
EndHeaders: true,
})
})
}
func TestServer_Ping(t *testing.T) {
st := newServerTester(t, nil)
defer st.Close()
......@@ -1222,7 +1276,7 @@ func TestServer_StateTransitions(t *testing.T) {
// test HEADERS w/o EndHeaders + another HEADERS (should get rejected)
func TestServer_Rejects_HeadersNoEnd_Then_Headers(t *testing.T) {
testServerRejects(t, func(st *serverTester) {
testServerRejectsConn(t, func(st *serverTester) {
st.writeHeaders(HeadersFrameParam{
StreamID: 1,
BlockFragment: st.encodeHeader(),
......@@ -1240,7 +1294,7 @@ func TestServer_Rejects_HeadersNoEnd_Then_Headers(t *testing.T) {
// test HEADERS w/o EndHeaders + PING (should get rejected)
func TestServer_Rejects_HeadersNoEnd_Then_Ping(t *testing.T) {
testServerRejects(t, func(st *serverTester) {
testServerRejectsConn(t, func(st *serverTester) {
st.writeHeaders(HeadersFrameParam{
StreamID: 1,
BlockFragment: st.encodeHeader(),
......@@ -1255,7 +1309,7 @@ func TestServer_Rejects_HeadersNoEnd_Then_Ping(t *testing.T) {
// test HEADERS w/ EndHeaders + a continuation HEADERS (should get rejected)
func TestServer_Rejects_HeadersEnd_Then_Continuation(t *testing.T) {
testServerRejects(t, func(st *serverTester) {
testServerRejectsConn(t, func(st *serverTester) {
st.writeHeaders(HeadersFrameParam{
StreamID: 1,
BlockFragment: st.encodeHeader(),
......@@ -1271,7 +1325,7 @@ func TestServer_Rejects_HeadersEnd_Then_Continuation(t *testing.T) {
// test HEADERS w/o EndHeaders + a continuation HEADERS on wrong stream ID
func TestServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream(t *testing.T) {
testServerRejects(t, func(st *serverTester) {
testServerRejectsConn(t, func(st *serverTester) {
st.writeHeaders(HeadersFrameParam{
StreamID: 1,
BlockFragment: st.encodeHeader(),
......@@ -1286,7 +1340,7 @@ func TestServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream(t *testing.T)
// No HEADERS on stream 0.
func TestServer_Rejects_Headers0(t *testing.T) {
testServerRejects(t, func(st *serverTester) {
testServerRejectsConn(t, func(st *serverTester) {
st.fr.AllowIllegalWrites = true
st.writeHeaders(HeadersFrameParam{
StreamID: 0,
......@@ -1299,7 +1353,7 @@ func TestServer_Rejects_Headers0(t *testing.T) {
// No CONTINUATION on stream 0.
func TestServer_Rejects_Continuation0(t *testing.T) {
testServerRejects(t, func(st *serverTester) {
testServerRejectsConn(t, func(st *serverTester) {
st.fr.AllowIllegalWrites = true
if err := st.fr.WriteContinuation(0, true, st.encodeHeader()); err != nil {
t.Fatal(err)
......@@ -1308,7 +1362,7 @@ func TestServer_Rejects_Continuation0(t *testing.T) {
}
func TestServer_Rejects_PushPromise(t *testing.T) {
testServerRejects(t, func(st *serverTester) {
testServerRejectsConn(t, func(st *serverTester) {
pp := PushPromiseParam{
StreamID: 1,
PromiseID: 3,
......@@ -1319,10 +1373,10 @@ func TestServer_Rejects_PushPromise(t *testing.T) {
})
}
// testServerRejects tests that the server hangs up with a GOAWAY
// testServerRejectsConn tests that the server hangs up with a GOAWAY
// frame and a server close after the client does something
// deserving a CONNECTION_ERROR.
func testServerRejects(t *testing.T, writeReq func(*serverTester)) {
func testServerRejectsConn(t *testing.T, writeReq func(*serverTester)) {
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
st.addLogFilter("connection error: PROTOCOL_ERROR")
defer st.Close()
......@@ -1348,6 +1402,16 @@ func testServerRejects(t *testing.T, writeReq func(*serverTester)) {
}
}
// testServerRejectsStream tests that the server sends a RST_STREAM with the provided
// error code after a client sends a bogus request.
func testServerRejectsStream(t *testing.T, code ErrCode, writeReq func(*serverTester)) {
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
defer st.Close()
st.greet()
writeReq(st)
st.wantRSTStream(1, code)
}
// testServerRequest sets up an idle HTTP/2 connection and lets you
// write a single request with writeReq, and then verify that the
// *http.Request is built correctly in checkReq.
......
......@@ -775,8 +775,10 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail
// [RFC3986]).
cc.writeHeader(":authority", host) // probably not right for all sites
cc.writeHeader(":method", req.Method)
cc.writeHeader(":path", req.URL.RequestURI())
cc.writeHeader(":scheme", "https")
if req.Method != "CONNECT" {
cc.writeHeader(":path", req.URL.RequestURI())
cc.writeHeader(":scheme", "https")
}
if trailers != "" {
cc.writeHeader("trailer", trailers)
}
......
......@@ -760,3 +760,62 @@ func TestTransportFullDuplex(t *testing.T) {
t.Fatal(err)
}
}
func TestTransportConnectRequest(t *testing.T) {
gotc := make(chan *http.Request, 1)
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
gotc <- r
}, optOnlyServer)
defer st.Close()
u, err := url.Parse(st.ts.URL)
if err != nil {
t.Fatal(err)
}
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
c := &http.Client{Transport: tr}
tests := []struct {
req *http.Request
want string
}{
{
req: &http.Request{
Method: "CONNECT",
Header: http.Header{},
URL: u,
},
want: u.Host,
},
{
req: &http.Request{
Method: "CONNECT",
Header: http.Header{},
URL: u,
Host: "example.com:123",
},
want: "example.com:123",
},
}
for i, tt := range tests {
res, err := c.Do(tt.req)
if err != nil {
t.Errorf("%d. RoundTrip = %v", i, err)
continue
}
res.Body.Close()
req := <-gotc
if req.Method != "CONNECT" {
t.Errorf("method = %q; want CONNECT", req.Method)
}
if req.Host != tt.want {
t.Errorf("Host = %q; want %q", req.Host, tt.want)
}
if req.URL.Host != tt.want {
t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment