Commit 64cc5be4 authored by Fumitoshi Ukai's avatar Fumitoshi Ukai Committed by Russ Cox

web socket: fix short Read

Fixes #1145.

R=rsc
CC=golang-dev
https://golang.org/cl/2302042
parent f57f8b6f
...@@ -27,6 +27,13 @@ func (addr WebSocketAddr) Network() string { return "websocket" } ...@@ -27,6 +27,13 @@ func (addr WebSocketAddr) Network() string { return "websocket" }
// String returns the network address for a Web Socket. // String returns the network address for a Web Socket.
func (addr WebSocketAddr) String() string { return string(addr) } func (addr WebSocketAddr) String() string { return string(addr) }
const (
stateFrameByte = iota
stateFrameLength
stateFrameData
stateFrameTextData
)
// Conn is a channel to communicate to a Web Socket. // Conn is a channel to communicate to a Web Socket.
// It implements the net.Conn interface. // It implements the net.Conn interface.
type Conn struct { type Conn struct {
...@@ -39,6 +46,10 @@ type Conn struct { ...@@ -39,6 +46,10 @@ type Conn struct {
buf *bufio.ReadWriter buf *bufio.ReadWriter
rwc io.ReadWriteCloser rwc io.ReadWriteCloser
// It holds text data in previous Read() that failed with small buffer.
data []byte
reading bool
} }
// newConn creates a new Web Socket. // newConn creates a new Web Socket.
...@@ -48,60 +59,66 @@ func newConn(origin, location, protocol string, buf *bufio.ReadWriter, rwc io.Re ...@@ -48,60 +59,66 @@ func newConn(origin, location, protocol string, buf *bufio.ReadWriter, rwc io.Re
bw := bufio.NewWriter(rwc) bw := bufio.NewWriter(rwc)
buf = bufio.NewReadWriter(br, bw) buf = bufio.NewReadWriter(br, bw)
} }
ws := &Conn{origin, location, protocol, buf, rwc} ws := &Conn{Origin: origin, Location: location, Protocol: protocol, buf: buf, rwc: rwc}
return ws return ws
} }
// Read implements the io.Reader interface for a Conn. // Read implements the io.Reader interface for a Conn.
func (ws *Conn) Read(msg []byte) (n int, err os.Error) { func (ws *Conn) Read(msg []byte) (n int, err os.Error) {
for { Frame:
frameByte, err := ws.buf.ReadByte() for !ws.reading && len(ws.data) == 0 {
// Beginning of frame, possibly.
b, err := ws.buf.ReadByte()
if err != nil { if err != nil {
return n, err return 0, err
} }
if (frameByte & 0x80) == 0x80 { if b&0x80 == 0x80 {
// Skip length frame.
length := 0 length := 0
for { for {
c, err := ws.buf.ReadByte() c, err := ws.buf.ReadByte()
if err != nil { if err != nil {
return n, err return 0, err
} }
length = length*128 + int(c&0x7f) length = length*128 + int(c&0x7f)
if (c & 0x80) == 0 { if c&0x80 == 0 {
break break
} }
} }
for length > 0 { for length > 0 {
_, err := ws.buf.ReadByte() _, err := ws.buf.ReadByte()
if err != nil { if err != nil {
return n, err return 0, err
} }
length--
} }
} else { continue Frame
}
// In text mode
if b != 0 {
// Skip this frame
for { for {
c, err := ws.buf.ReadByte() c, err := ws.buf.ReadByte()
if err != nil { if err != nil {
return n, err return 0, err
} }
if c == '\xff' { if c == '\xff' {
return n, err break
}
if frameByte == 0 {
if n+1 <= cap(msg) {
msg = msg[0 : n+1]
} }
msg[n] = c
n++
} }
if n >= cap(msg) { continue Frame
return n, os.E2BIG
} }
ws.reading = true
} }
if len(ws.data) == 0 {
ws.data, err = ws.buf.ReadSlice('\xff')
if err == nil {
ws.reading = false
ws.data = ws.data[:len(ws.data)-1] // trim \xff
} }
} }
n = copy(msg, ws.data)
panic("unreachable") ws.data = ws.data[n:]
return n, err
} }
// Write implements the io.Writer interface for a Conn. // Write implements the io.Writer interface for a Conn.
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package websocket package websocket
import ( import (
"bufio"
"bytes" "bytes"
"fmt" "fmt"
"http" "http"
...@@ -195,3 +196,77 @@ func TestTrailingSpaces(t *testing.T) { ...@@ -195,3 +196,77 @@ func TestTrailingSpaces(t *testing.T) {
} }
} }
} }
func TestSmallBuffer(t *testing.T) {
// http://code.google.com/p/go/issues/detail?id=1145
// Read should be able to handle reading a fragment of a frame.
once.Do(startServer)
// websocket.Dial()
client, err := net.Dial("tcp", "", serverAddr)
if err != nil {
t.Fatal("dialing", err)
}
ws, err := newClient("/echo", "localhost", "http://localhost",
"ws://localhost/echo", "", client, handshake)
if err != nil {
t.Errorf("WebSocket handshake error: %v", err)
return
}
msg := []byte("hello, world\n")
if _, err := ws.Write(msg); err != nil {
t.Errorf("Write: %v", err)
}
var small_msg = make([]byte, 8)
n, err := ws.Read(small_msg)
if err != nil {
t.Errorf("Read: %v", err)
}
if !bytes.Equal(msg[:len(small_msg)], small_msg) {
t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg)
}
var second_msg = make([]byte, len(msg))
n, err = ws.Read(second_msg)
if err != nil {
t.Errorf("Read: %v", err)
}
second_msg = second_msg[0:n]
if !bytes.Equal(msg[len(small_msg):], second_msg) {
t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg)
}
ws.Close()
}
func testSkipLengthFrame(t *testing.T) {
b := []byte{'\x80', '\x01', 'x', 0, 'h', 'e', 'l', 'l', 'o', '\xff'}
buf := bytes.NewBuffer(b)
br := bufio.NewReader(buf)
bw := bufio.NewWriter(buf)
ws := newConn("http://127.0.0.1/", "ws://127.0.0.1/", "", bufio.NewReadWriter(br, bw), nil)
msg := make([]byte, 5)
n, err := ws.Read(msg)
if err != nil {
t.Errorf("Read: %v", err)
}
if !bytes.Equal(b[4:8], msg[0:n]) {
t.Errorf("Read: expected %q got %q", msg[4:8], msg[0:n])
}
}
func testSkipNoUTF8Frame(t *testing.T) {
b := []byte{'\x01', 'n', '\xff', 0, 'h', 'e', 'l', 'l', 'o', '\xff'}
buf := bytes.NewBuffer(b)
br := bufio.NewReader(buf)
bw := bufio.NewWriter(buf)
ws := newConn("http://127.0.0.1/", "ws://127.0.0.1/", "", bufio.NewReadWriter(br, bw), nil)
msg := make([]byte, 5)
n, err := ws.Read(msg)
if err != nil {
t.Errorf("Read: %v", err)
}
if !bytes.Equal(b[4:8], msg[0:n]) {
t.Errorf("Read: expected %q got %q", msg[4:8], msg[0:n])
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment