Commit fecab405 authored by William Chan's avatar William Chan Committed by Brad Fitzpatrick

http/spdy: fix data race in header decompression.

flate's reader greedily reads from the shared io.Reader in Framer. This leads to a data race on Framer.r. Fix this by providing a corkedReader to zlib.NewReaderDict(). We uncork the reader and allow it to read the number of bytes in the compressed payload.

Fixes #1884.

R=bradfitz, rsc, go.peter.90
CC=golang-dev
https://golang.org/cl/4530089
parent f4349f73
...@@ -44,6 +44,24 @@ func (e FramerError) String() string { ...@@ -44,6 +44,24 @@ func (e FramerError) String() string {
return "Error(" + strconv.Itoa(int(e)) + ")" return "Error(" + strconv.Itoa(int(e)) + ")"
} }
type corkedReader struct {
r io.Reader
ch chan int
n int
}
func (cr *corkedReader) Read(p []byte) (int, os.Error) {
if cr.n == 0 {
cr.n = <-cr.ch
}
if len(p) > cr.n {
p = p[:cr.n]
}
n, err := cr.r.Read(p)
cr.n -= n
return n, err
}
// Framer handles serializing/deserializing SPDY frames, including compressing/ // Framer handles serializing/deserializing SPDY frames, including compressing/
// decompressing payloads. // decompressing payloads.
type Framer struct { type Framer struct {
...@@ -52,6 +70,7 @@ type Framer struct { ...@@ -52,6 +70,7 @@ type Framer struct {
headerBuf *bytes.Buffer headerBuf *bytes.Buffer
headerCompressor *zlib.Writer headerCompressor *zlib.Writer
r io.Reader r io.Reader
headerReader corkedReader
headerDecompressor io.ReadCloser headerDecompressor io.ReadCloser
} }
...@@ -74,11 +93,13 @@ func NewFramer(w io.Writer, r io.Reader) (*Framer, os.Error) { ...@@ -74,11 +93,13 @@ func NewFramer(w io.Writer, r io.Reader) (*Framer, os.Error) {
return framer, nil return framer, nil
} }
func (f *Framer) initHeaderDecompression() os.Error { func (f *Framer) uncorkHeaderDecompressor(payloadSize int) os.Error {
if f.headerDecompressor != nil { if f.headerDecompressor != nil {
f.headerReader.ch <- payloadSize
return nil return nil
} }
decompressor, err := zlib.NewReaderDict(f.r, []byte(HeaderDictionary)) f.headerReader = corkedReader{r: f.r, ch: make(chan int, 1), n: payloadSize}
decompressor, err := zlib.NewReaderDict(&f.headerReader, []byte(HeaderDictionary))
if err != nil { if err != nil {
return err return err
} }
...@@ -171,7 +192,7 @@ func (f *Framer) readSynStreamFrame(h ControlFrameHeader, frame *SynStreamFrame) ...@@ -171,7 +192,7 @@ func (f *Framer) readSynStreamFrame(h ControlFrameHeader, frame *SynStreamFrame)
reader := f.r reader := f.r
if !f.headerCompressionDisabled { if !f.headerCompressionDisabled {
f.initHeaderDecompression() f.uncorkHeaderDecompressor(int(h.length - 10))
reader = f.headerDecompressor reader = f.headerDecompressor
} }
...@@ -194,7 +215,7 @@ func (f *Framer) readSynReplyFrame(h ControlFrameHeader, frame *SynReplyFrame) o ...@@ -194,7 +215,7 @@ func (f *Framer) readSynReplyFrame(h ControlFrameHeader, frame *SynReplyFrame) o
} }
reader := f.r reader := f.r
if !f.headerCompressionDisabled { if !f.headerCompressionDisabled {
f.initHeaderDecompression() f.uncorkHeaderDecompressor(int(h.length - 6))
reader = f.headerDecompressor reader = f.headerDecompressor
} }
frame.Headers, err = parseHeaderValueBlock(reader) frame.Headers, err = parseHeaderValueBlock(reader)
...@@ -216,7 +237,7 @@ func (f *Framer) readHeadersFrame(h ControlFrameHeader, frame *HeadersFrame) os. ...@@ -216,7 +237,7 @@ func (f *Framer) readHeadersFrame(h ControlFrameHeader, frame *HeadersFrame) os.
} }
reader := f.r reader := f.r
if !f.headerCompressionDisabled { if !f.headerCompressionDisabled {
f.initHeaderDecompression() f.uncorkHeaderDecompressor(int(h.length - 6))
reader = f.headerDecompressor reader = f.headerDecompressor
} }
frame.Headers, err = parseHeaderValueBlock(reader) frame.Headers, err = parseHeaderValueBlock(reader)
......
...@@ -371,12 +371,6 @@ func TestCreateParseDataFrame(t *testing.T) { ...@@ -371,12 +371,6 @@ func TestCreateParseDataFrame(t *testing.T) {
} }
func TestCompressionContextAcrossFrames(t *testing.T) { func TestCompressionContextAcrossFrames(t *testing.T) {
{
// TODO(willchan,bradfitz): test is temporarily disabled
t.Logf("test temporarily disabled; http://code.google.com/p/go/issues/detail?id=1884")
return
}
buffer := new(bytes.Buffer) buffer := new(bytes.Buffer)
framer, err := NewFramer(buffer, buffer) framer, err := NewFramer(buffer, buffer)
if err != nil { if err != nil {
...@@ -430,12 +424,6 @@ func TestCompressionContextAcrossFrames(t *testing.T) { ...@@ -430,12 +424,6 @@ func TestCompressionContextAcrossFrames(t *testing.T) {
} }
func TestMultipleSPDYFrames(t *testing.T) { func TestMultipleSPDYFrames(t *testing.T) {
{
// TODO(willchan,bradfitz): test is temporarily disabled
t.Logf("test temporarily disabled; http://code.google.com/p/go/issues/detail?id=1884")
return
}
// Initialize the framers. // Initialize the framers.
pr1, pw1 := io.Pipe() pr1, pw1 := io.Pipe()
pr2, pw2 := io.Pipe() pr2, pw2 := io.Pipe()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment