Commit b7f5d985 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

http2: change the pipe and buffer code

Make the pipe code take an interface as the backing store.  Now a pipe
is something that's goroutine-safe and does the Cond waits but its underlying data
is now an interface: anything that's a ReaderWriter with a Len method (such as a
*bytes.Buffer), or a fixedBuffer (renamed in this CL from 'buffer').

This opens the ground to having a non-fixed buffer used with pipe.

This also moves the CloseWithError code up into the pipe code, out of
fixedBuffer.

Change-Id: Ia3b853e8aa8920807b705ff4e41bed934a8c67b7
Reviewed-on: https://go-review.googlesource.com/16312Reviewed-by: 's avatarBlake Mizerany <blake.mizerany@gmail.com>
parent 2cba614e
...@@ -8,46 +8,41 @@ import ( ...@@ -8,46 +8,41 @@ import (
"errors" "errors"
) )
// buffer is an io.ReadWriteCloser backed by a fixed size buffer. // fixedBuffer is an io.ReadWriter backed by a fixed size buffer.
// It never allocates, but moves old data as new data is written. // It never allocates, but moves old data as new data is written.
type buffer struct { type fixedBuffer struct {
buf []byte buf []byte
r, w int r, w int
closed bool
err error // err to return to reader
} }
var ( var (
errReadEmpty = errors.New("read from empty buffer") errReadEmpty = errors.New("read from empty fixedBuffer")
errWriteClosed = errors.New("write on closed buffer") errWriteFull = errors.New("write on full fixedBuffer")
errWriteFull = errors.New("write on full buffer")
) )
// Read copies bytes from the buffer into p. // Read copies bytes from the buffer into p.
// It is an error to read when no data is available. // It is an error to read when no data is available.
func (b *buffer) Read(p []byte) (n int, err error) { func (b *fixedBuffer) Read(p []byte) (n int, err error) {
if b.r == b.w {
return 0, errReadEmpty
}
n = copy(p, b.buf[b.r:b.w]) n = copy(p, b.buf[b.r:b.w])
b.r += n b.r += n
if b.closed && b.r == b.w { if b.r == b.w {
err = b.err b.r = 0
} else if b.r == b.w && n == 0 { b.w = 0
err = errReadEmpty
} }
return n, err return n, nil
} }
// Len returns the number of bytes of the unread portion of the buffer. // Len returns the number of bytes of the unread portion of the buffer.
func (b *buffer) Len() int { func (b *fixedBuffer) Len() int {
return b.w - b.r return b.w - b.r
} }
// Write copies bytes from p into the buffer. // Write copies bytes from p into the buffer.
// It is an error to write more data than the buffer can hold. // It is an error to write more data than the buffer can hold.
func (b *buffer) Write(p []byte) (n int, err error) { func (b *fixedBuffer) Write(p []byte) (n int, err error) {
if b.closed {
return 0, errWriteClosed
}
// Slide existing data to beginning. // Slide existing data to beginning.
if b.r > 0 && len(p) > len(b.buf)-b.w { if b.r > 0 && len(p) > len(b.buf)-b.w {
copy(b.buf, b.buf[b.r:b.w]) copy(b.buf, b.buf[b.r:b.w])
...@@ -63,13 +58,3 @@ func (b *buffer) Write(p []byte) (n int, err error) { ...@@ -63,13 +58,3 @@ func (b *buffer) Write(p []byte) (n int, err error) {
} }
return n, err return n, err
} }
// Close marks the buffer as closed. Future calls to Write will
// return an error. Future calls to Read, once the buffer is
// empty, will return err.
func (b *buffer) Close(err error) {
if !b.closed {
b.closed = true
b.err = err
}
}
...@@ -5,47 +5,36 @@ ...@@ -5,47 +5,36 @@
package http2 package http2
import ( import (
"io"
"reflect" "reflect"
"testing" "testing"
) )
var bufferReadTests = []struct { var bufferReadTests = []struct {
buf buffer buf fixedBuffer
read, wn int read, wn int
werr error werr error
wp []byte wp []byte
wbuf buffer wbuf fixedBuffer
}{ }{
{ {
buffer{[]byte{'a', 0}, 0, 1, false, nil}, fixedBuffer{[]byte{'a', 0}, 0, 1},
5, 1, nil, []byte{'a'}, 5, 1, nil, []byte{'a'},
buffer{[]byte{'a', 0}, 1, 1, false, nil}, fixedBuffer{[]byte{'a', 0}, 0, 0},
}, },
{ {
buffer{[]byte{'a', 0}, 0, 1, true, io.EOF}, fixedBuffer{[]byte{0, 'a'}, 1, 2},
5, 1, io.EOF, []byte{'a'},
buffer{[]byte{'a', 0}, 1, 1, true, io.EOF},
},
{
buffer{[]byte{0, 'a'}, 1, 2, false, nil},
5, 1, nil, []byte{'a'}, 5, 1, nil, []byte{'a'},
buffer{[]byte{0, 'a'}, 2, 2, false, nil}, fixedBuffer{[]byte{0, 'a'}, 0, 0},
}, },
{ {
buffer{[]byte{0, 'a'}, 1, 2, true, io.EOF}, fixedBuffer{[]byte{'a', 'b'}, 0, 2},
5, 1, io.EOF, []byte{'a'}, 1, 1, nil, []byte{'a'},
buffer{[]byte{0, 'a'}, 2, 2, true, io.EOF}, fixedBuffer{[]byte{'a', 'b'}, 1, 2},
}, },
{ {
buffer{[]byte{}, 0, 0, false, nil}, fixedBuffer{[]byte{}, 0, 0},
5, 0, errReadEmpty, []byte{}, 5, 0, errReadEmpty, []byte{},
buffer{[]byte{}, 0, 0, false, nil}, fixedBuffer{[]byte{}, 0, 0},
},
{
buffer{[]byte{}, 0, 0, true, io.EOF},
5, 0, io.EOF, []byte{},
buffer{[]byte{}, 0, 0, true, io.EOF},
}, },
} }
...@@ -72,64 +61,50 @@ func TestBufferRead(t *testing.T) { ...@@ -72,64 +61,50 @@ func TestBufferRead(t *testing.T) {
} }
var bufferWriteTests = []struct { var bufferWriteTests = []struct {
buf buffer buf fixedBuffer
write, wn int write, wn int
werr error werr error
wbuf buffer wbuf fixedBuffer
}{ }{
{ {
buf: buffer{ buf: fixedBuffer{
buf: []byte{}, buf: []byte{},
}, },
wbuf: buffer{ wbuf: fixedBuffer{
buf: []byte{}, buf: []byte{},
}, },
}, },
{ {
buf: buffer{ buf: fixedBuffer{
buf: []byte{1, 'a'}, buf: []byte{1, 'a'},
}, },
write: 1, write: 1,
wn: 1, wn: 1,
wbuf: buffer{ wbuf: fixedBuffer{
buf: []byte{0, 'a'}, buf: []byte{0, 'a'},
w: 1, w: 1,
}, },
}, },
{ {
buf: buffer{ buf: fixedBuffer{
buf: []byte{'a', 1}, buf: []byte{'a', 1},
r: 1, r: 1,
w: 1, w: 1,
}, },
write: 2, write: 2,
wn: 2, wn: 2,
wbuf: buffer{ wbuf: fixedBuffer{
buf: []byte{0, 0}, buf: []byte{0, 0},
w: 2, w: 2,
}, },
}, },
{ {
buf: buffer{ buf: fixedBuffer{
buf: []byte{},
r: 1,
closed: true,
},
write: 5,
werr: errWriteClosed,
wbuf: buffer{
buf: []byte{},
r: 1,
closed: true,
},
},
{
buf: buffer{
buf: []byte{}, buf: []byte{},
}, },
write: 5, write: 5,
werr: errWriteFull, werr: errWriteFull,
wbuf: buffer{ wbuf: fixedBuffer{
buf: []byte{}, buf: []byte{},
}, },
}, },
......
...@@ -5,38 +5,78 @@ ...@@ -5,38 +5,78 @@
package http2 package http2
import ( import (
"errors"
"io"
"sync" "sync"
) )
// pipe is a goroutine-safe io.Reader/io.Writer pair. It's like
// io.Pipe except there are no PipeReader/PipeWriter halves, and the
// underlying buffer is an interface. (io.Pipe is always unbuffered)
type pipe struct { type pipe struct {
b buffer mu sync.Mutex
c sync.Cond c sync.Cond // c.L must point to
m sync.Mutex b pipeBuffer
err error // read error once empty. non-nil means closed.
}
type pipeBuffer interface {
Len() int
io.Writer
io.Reader
} }
// Read waits until data is available and copies bytes // Read waits until data is available and copies bytes
// from the buffer into p. // from the buffer into p.
func (r *pipe) Read(p []byte) (n int, err error) { func (p *pipe) Read(d []byte) (n int, err error) {
r.c.L.Lock() p.mu.Lock()
defer r.c.L.Unlock() defer p.mu.Unlock()
for r.b.Len() == 0 && !r.b.closed { if p.c.L == nil {
r.c.Wait() p.c.L = &p.mu
}
for {
if p.b.Len() > 0 {
return p.b.Read(d)
}
if p.err != nil {
return 0, p.err
}
p.c.Wait()
} }
return r.b.Read(p)
} }
var errClosedPipeWrite = errors.New("write on closed buffer")
// Write copies bytes from p into the buffer and wakes a reader. // Write copies bytes from p into the buffer and wakes a reader.
// It is an error to write more data than the buffer can hold. // It is an error to write more data than the buffer can hold.
func (w *pipe) Write(p []byte) (n int, err error) { func (p *pipe) Write(d []byte) (n int, err error) {
w.c.L.Lock() p.mu.Lock()
defer w.c.L.Unlock() defer p.mu.Unlock()
defer w.c.Signal() if p.c.L == nil {
return w.b.Write(p) p.c.L = &p.mu
}
defer p.c.Signal()
if p.err != nil {
return 0, errClosedPipeWrite
}
return p.b.Write(d)
} }
func (c *pipe) Close(err error) { // CloseWithError causes Reads to wake up and return the
c.c.L.Lock() // provided err after all data has been read.
defer c.c.L.Unlock() //
defer c.c.Signal() // The error must be non-nil.
c.b.Close(err) func (p *pipe) CloseWithError(err error) {
if err == nil {
panic("CloseWithError must be non-nil")
}
p.mu.Lock()
defer p.mu.Unlock()
if p.c.L == nil {
p.c.L = &p.mu
}
defer p.c.Signal()
if p.err == nil {
p.err = err
}
} }
...@@ -5,17 +5,18 @@ ...@@ -5,17 +5,18 @@
package http2 package http2
import ( import (
"bytes"
"errors" "errors"
"testing" "testing"
) )
func TestPipeClose(t *testing.T) { func TestPipeClose(t *testing.T) {
var p pipe var p pipe
p.c.L = &p.m p.b = new(bytes.Buffer)
a := errors.New("a") a := errors.New("a")
b := errors.New("b") b := errors.New("b")
p.Close(a) p.CloseWithError(a)
p.Close(b) p.CloseWithError(b)
_, err := p.Read(make([]byte, 1)) _, err := p.Read(make([]byte, 1))
if err != a { if err != a {
t.Errorf("err = %v want %v", err, a) t.Errorf("err = %v want %v", err, a)
......
...@@ -65,6 +65,7 @@ const ( ...@@ -65,6 +65,7 @@ const (
var ( var (
errClientDisconnected = errors.New("client disconnected") errClientDisconnected = errors.New("client disconnected")
errClosedBody = errors.New("body closed by handler") errClosedBody = errors.New("body closed by handler")
errHandlerComplete = errors.New("http2: request body closed due to handler exiting")
errStreamClosed = errors.New("http2: stream closed") errStreamClosed = errors.New("http2: stream closed")
) )
...@@ -872,7 +873,7 @@ func (sc *serverConn) wroteFrame(res frameWriteResult) { ...@@ -872,7 +873,7 @@ func (sc *serverConn) wroteFrame(res frameWriteResult) {
errCancel := StreamError{st.id, ErrCodeCancel} errCancel := StreamError{st.id, ErrCodeCancel}
sc.resetStream(errCancel) sc.resetStream(errCancel)
case stateHalfClosedRemote: case stateHalfClosedRemote:
sc.closeStream(st, nil) sc.closeStream(st, errHandlerComplete)
} }
} }
...@@ -1142,7 +1143,7 @@ func (sc *serverConn) closeStream(st *stream, err error) { ...@@ -1142,7 +1143,7 @@ func (sc *serverConn) closeStream(st *stream, err error) {
} }
delete(sc.streams, st.id) delete(sc.streams, st.id)
if p := st.body; p != nil { if p := st.body; p != nil {
p.Close(err) p.CloseWithError(err)
} }
st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc
sc.writeSched.forgetStream(st.id) sc.writeSched.forgetStream(st.id)
...@@ -1246,7 +1247,7 @@ func (sc *serverConn) processData(f *DataFrame) error { ...@@ -1246,7 +1247,7 @@ func (sc *serverConn) processData(f *DataFrame) error {
// Sender sending more than they'd declared? // Sender sending more than they'd declared?
if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes { if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes {
st.body.Close(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes)) st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes))
return StreamError{id, ErrCodeStreamClosed} return StreamError{id, ErrCodeStreamClosed}
} }
if len(data) > 0 { if len(data) > 0 {
...@@ -1266,10 +1267,10 @@ func (sc *serverConn) processData(f *DataFrame) error { ...@@ -1266,10 +1267,10 @@ func (sc *serverConn) processData(f *DataFrame) error {
} }
if f.StreamEnded() { if f.StreamEnded() {
if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes { if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes {
st.body.Close(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes", st.body.CloseWithError(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes",
st.declBodyBytes, st.bodyBytes)) st.declBodyBytes, st.bodyBytes))
} else { } else {
st.body.Close(io.EOF) st.body.CloseWithError(io.EOF)
} }
st.state = stateHalfClosedRemote st.state = stateHalfClosedRemote
} }
...@@ -1493,9 +1494,8 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err ...@@ -1493,9 +1494,8 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
} }
if bodyOpen { if bodyOpen {
body.pipe = &pipe{ body.pipe = &pipe{
b: buffer{buf: make([]byte, initialWindowSize)}, // TODO: share/remove XXX b: &fixedBuffer{buf: make([]byte, initialWindowSize)}, // TODO: share/remove XXX
} }
body.pipe.c.L = &body.pipe.m
if vv, ok := rp.header["Content-Length"]; ok { if vv, ok := rp.header["Content-Length"]; ok {
req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64) req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64)
...@@ -1655,7 +1655,7 @@ type requestBody struct { ...@@ -1655,7 +1655,7 @@ type requestBody struct {
func (b *requestBody) Close() error { func (b *requestBody) Close() error {
if b.pipe != nil { if b.pipe != nil {
b.pipe.Close(errClosedBody) b.pipe.CloseWithError(errClosedBody)
} }
b.closed = true b.closed = true
return nil return nil
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment