Commit c94bffa2 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

http2: don't leaving hanging server goroutines after RST_STREAM from client

In general, clean up and simplify the handling of frame writing from
handler goroutines.  Always select on streams closing, and don't try
to pass around and re-use channels. It was too confusing. Instead,
reuse channels in a very local manner that's easy to reason about.

Thanks to Github user @pabbott0 (who has signed the Google CLA) for
the initial bug report and test cases.

Fixes bradfitz/http2#45

Change-Id: Ib72a87cb6e33a4bb118ae23d765ba594e9182ade
Reviewed-on: https://go-review.googlesource.com/15820Reviewed-by: 's avatarAndrew Gerrand <adg@golang.org>
parent 271cfc1e
......@@ -322,7 +322,7 @@ type serverConn struct {
wantWriteFrameCh chan frameWriteMsg // from handlers -> serve
wroteFrameCh chan struct{} // from writeFrameAsync -> serve, tickles more frame writes
bodyReadCh chan bodyReadMsg // from handlers -> serve
testHookCh chan func() // code to run on the serve loop
testHookCh chan func(int) // code to run on the serve loop
flow flow // conn-wide (not stream-specific) outbound flow control
inflow flow // conn-wide inbound flow control
tlsState *tls.ConnectionState // shared by all handlers, like net/http
......@@ -636,7 +636,9 @@ func (sc *serverConn) serve() {
go sc.readFrames() // closed by defer sc.conn.Close above
settingsTimer := time.NewTimer(firstSettingsTimeout)
loopNum := 0
for {
loopNum++
select {
case wm := <-sc.wantWriteFrameCh:
sc.writeFrame(wm)
......@@ -664,7 +666,7 @@ func (sc *serverConn) serve() {
sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr())
return
case fn := <-sc.testHookCh:
fn()
fn(loopNum)
}
}
}
......@@ -697,19 +699,20 @@ func (sc *serverConn) readPreface() error {
}
}
var errChanPool = sync.Pool{
New: func() interface{} { return make(chan error, 1) },
}
// writeDataFromHandler writes the data described in req to stream.id.
//
// The provided ch is used to avoid allocating new channels for each
// write operation. It's expected that the caller reuses writeData and ch
// over time.
//
// The flow control currently happens in the Handler where it waits
// for 1 or more bytes to be available to then write here. So at this
// point we know that we have flow control. But this might have to
// change when priority is implemented, so the serve goroutine knows
// the total amount of bytes waiting to be sent and can can have more
// scheduling decisions available.
func (sc *serverConn) writeDataFromHandler(stream *stream, writeData *writeData, ch chan error) error {
func (sc *serverConn) writeDataFromHandler(stream *stream, writeData *writeData) error {
ch := errChanPool.Get().(chan error)
sc.writeFrameFromHandler(frameWriteMsg{
write: writeData,
stream: stream,
......@@ -717,6 +720,7 @@ func (sc *serverConn) writeDataFromHandler(stream *stream, writeData *writeData,
})
select {
case err := <-ch:
errChanPool.Put(ch)
return err
case <-sc.doneServing:
return errClientDisconnected
......@@ -734,10 +738,22 @@ func (sc *serverConn) writeDataFromHandler(stream *stream, writeData *writeData,
// goroutine, call writeFrame instead.
func (sc *serverConn) writeFrameFromHandler(wm frameWriteMsg) {
sc.serveG.checkNotOn() // NOT
var scheduled bool
select {
case sc.wantWriteFrameCh <- wm:
scheduled = true
case <-sc.doneServing:
// Client has closed their connection to the server.
case <-wm.stream.cw:
// Stream closed.
}
// Don't block writers expecting a reply.
if !scheduled && wm.done != nil {
select {
case wm.done <- errStreamBroken:
default:
panic("expected buffered channel")
}
}
}
......@@ -1435,7 +1451,6 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
rws.stream = rp.stream
rws.req = req
rws.body = body
rws.frameWriteCh = make(chan error, 1)
rw := &responseWriter{rws: rws}
return rw, req, nil
......@@ -1460,7 +1475,7 @@ func handleHeaderListTooLong(w http.ResponseWriter, r *http.Request) {
// called from handler goroutines.
// h may be nil.
func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders, tempCh chan error) {
func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders) {
sc.serveG.checkNotOn() // NOT on
var errc chan error
if headerData.h != nil {
......@@ -1468,7 +1483,7 @@ func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders, temp
// waiting for this frame to be written, so an http.Flush mid-handler
// writes out the correct value of keys, before a handler later potentially
// mutates it.
errc = tempCh
errc = errChanPool.Get().(chan error)
}
sc.writeFrameFromHandler(frameWriteMsg{
write: headerData,
......@@ -1480,8 +1495,11 @@ func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders, temp
case <-errc:
// Ignore. Just for synchronization.
// Any error will be handled in the writing goroutine.
errChanPool.Put(errc)
case <-sc.doneServing:
// Client has closed the connection.
case <-st.cw:
// Client did RST_STREAM, etc. (but conn still alive)
}
}
}
......@@ -1629,7 +1647,6 @@ type responseWriterState struct {
sentHeader bool // have we sent the header frame?
handlerDone bool // handler has finished
curWrite writeData
frameWriteCh chan error // re-used whenever we need to block on a frame being written
closeNotifierMu sync.Mutex // guards closeNotifierCh
closeNotifierCh chan bool // nil until first used
......@@ -1666,7 +1683,7 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
endStream: endStream,
contentType: ctype,
contentLength: clen,
}, rws.frameWriteCh)
})
if endStream {
return 0, nil
}
......@@ -1678,7 +1695,7 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
curWrite.streamID = rws.stream.id
curWrite.p = p
curWrite.endStream = rws.handlerDone
if err := rws.conn.writeDataFromHandler(rws.stream, curWrite, rws.frameWriteCh); err != nil {
if err := rws.conn.writeDataFromHandler(rws.stream, curWrite); err != nil {
return 0, err
}
return len(p), nil
......
......@@ -125,7 +125,7 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}
st.scMu.Lock()
defer st.scMu.Unlock()
st.sc = v
st.sc.testHookCh = make(chan func())
st.sc.testHookCh = make(chan func(int))
}
log.SetOutput(io.MultiWriter(stderrv(), twriter{t: t, st: st}))
if !onlyServer {
......@@ -152,7 +152,7 @@ func (st *serverTester) addLogFilter(phrase string) {
func (st *serverTester) stream(id uint32) *stream {
ch := make(chan *stream, 1)
st.sc.testHookCh <- func() {
st.sc.testHookCh <- func(int) {
ch <- st.sc.streams[id]
}
return <-ch
......@@ -160,13 +160,39 @@ func (st *serverTester) stream(id uint32) *stream {
func (st *serverTester) streamState(id uint32) streamState {
ch := make(chan streamState, 1)
st.sc.testHookCh <- func() {
st.sc.testHookCh <- func(int) {
state, _ := st.sc.state(id)
ch <- state
}
return <-ch
}
// loopNum reports how many times this conn's select loop has gone around.
func (st *serverTester) loopNum() int {
lastc := make(chan int, 1)
st.sc.testHookCh <- func(loopNum int) {
lastc <- loopNum
}
return <-lastc
}
// awaitIdle heuristically awaits for the server conn's select loop to be idle.
// The heuristic is that the server connection's serve loop must schedule
// 50 times in a row without any channel sends or receives occuring.
func (st *serverTester) awaitIdle() {
remain := 50
last := st.loopNum()
for remain > 0 {
n := st.loopNum()
if n == last+1 {
remain--
} else {
remain = 50
}
last = n
}
}
func (st *serverTester) Close() {
st.ts.Close()
if st.cc != nil {
......@@ -1028,6 +1054,56 @@ func TestServer_RSTStream_Unblocks_Read(t *testing.T) {
)
}
func TestServer_RSTStream_Unblocks_Header_Write(t *testing.T) {
// Run this test a bunch, because it doesn't always
// deadlock. But with a bunch, it did.
n := 50
if testing.Short() {
n = 5
}
for i := 0; i < n; i++ {
testServer_RSTStream_Unblocks_Header_Write(t)
}
}
func testServer_RSTStream_Unblocks_Header_Write(t *testing.T) {
inHandler := make(chan bool, 1)
unblockHandler := make(chan bool, 1)
headerWritten := make(chan bool, 1)
wroteRST := make(chan bool, 1)
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
inHandler <- true
<-wroteRST
w.Header().Set("foo", "bar")
w.WriteHeader(200)
w.(http.Flusher).Flush()
headerWritten <- true
<-unblockHandler
})
defer st.Close()
st.greet()
st.writeHeaders(HeadersFrameParam{
StreamID: 1,
BlockFragment: st.encodeHeader(":method", "POST"),
EndStream: false, // keep it open
EndHeaders: true,
})
<-inHandler
if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
t.Fatal(err)
}
wroteRST <- true
st.awaitIdle()
select {
case <-headerWritten:
case <-time.After(2 * time.Second):
t.Error("timeout waiting for header write")
}
unblockHandler <- true
}
func TestServer_DeadConn_Unblocks_Read(t *testing.T) {
testServerPostUnblock(t,
func(w http.ResponseWriter, r *http.Request) (err error) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment