Commit 6dba816f authored by Artyom Pervukhin's avatar Artyom Pervukhin Committed by Brad Fitzpatrick

websocket: limit incoming payload size

Codec's Receive method calls io.ReadAll of the whole frame payload,
which can be abused by user sending large payloads in order to exhaust
server memory.

Introduce limit on received payload size defined by
Conn.MaxPayloadBytes. If payload size of the message read with
Codec.Receive exceeds limit, ErrFrameTooLarge error is returned; the
connection can still be recovered if required: the next call to Receive
would at first discard leftovers of previous oversized message before
processing the next one.

Fixes golang/go#5082.

Change-Id: Ib04acd7038474fee39a1719324daaec1c0c496b1
Reviewed-on: https://go-review.googlesource.com/23590Reviewed-by: 's avatarBrad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent cf4effbb
...@@ -32,6 +32,8 @@ const ( ...@@ -32,6 +32,8 @@ const (
PingFrame = 9 PingFrame = 9
PongFrame = 10 PongFrame = 10
UnknownFrame = 255 UnknownFrame = 255
DefaultMaxPayloadBytes = 32 << 20 // 32MB
) )
// ProtocolError represents WebSocket protocol errors. // ProtocolError represents WebSocket protocol errors.
...@@ -58,6 +60,10 @@ var ( ...@@ -58,6 +60,10 @@ var (
ErrNotSupported = &ProtocolError{"not supported"} ErrNotSupported = &ProtocolError{"not supported"}
) )
// ErrFrameTooLarge is returned by Codec's Receive method if payload size
// exceeds limit set by Conn.MaxPayloadBytes
var ErrFrameTooLarge = errors.New("websocket: frame payload size exceeds limit")
// Addr is an implementation of net.Addr for WebSocket. // Addr is an implementation of net.Addr for WebSocket.
type Addr struct { type Addr struct {
*url.URL *url.URL
...@@ -166,6 +172,10 @@ type Conn struct { ...@@ -166,6 +172,10 @@ type Conn struct {
frameHandler frameHandler
PayloadType byte PayloadType byte
defaultCloseStatus int defaultCloseStatus int
// MaxPayloadBytes limits the size of frame payload received over Conn
// by Codec's Receive method. If zero, DefaultMaxPayloadBytes is used.
MaxPayloadBytes int
} }
// Read implements the io.Reader interface: // Read implements the io.Reader interface:
...@@ -302,7 +312,12 @@ func (cd Codec) Send(ws *Conn, v interface{}) (err error) { ...@@ -302,7 +312,12 @@ func (cd Codec) Send(ws *Conn, v interface{}) (err error) {
return err return err
} }
// Receive receives single frame from ws, unmarshaled by cd.Unmarshal and stores in v. // Receive receives single frame from ws, unmarshaled by cd.Unmarshal and stores
// in v. The whole frame payload is read to an in-memory buffer; max size of
// payload is defined by ws.MaxPayloadBytes. If frame payload size exceeds
// limit, ErrFrameTooLarge is returned; in this case frame is not read off wire
// completely. The next call to Receive would read and discard leftover data of
// previous oversized frame before processing next frame.
func (cd Codec) Receive(ws *Conn, v interface{}) (err error) { func (cd Codec) Receive(ws *Conn, v interface{}) (err error) {
ws.rio.Lock() ws.rio.Lock()
defer ws.rio.Unlock() defer ws.rio.Unlock()
...@@ -325,6 +340,19 @@ again: ...@@ -325,6 +340,19 @@ again:
if frame == nil { if frame == nil {
goto again goto again
} }
maxPayloadBytes := ws.MaxPayloadBytes
if maxPayloadBytes == 0 {
maxPayloadBytes = DefaultMaxPayloadBytes
}
if hf, ok := frame.(*hybiFrameReader); ok && hf.header.Length > int64(maxPayloadBytes) {
// payload size exceeds limit, no need to call Unmarshal
//
// set frameReader to current oversized frame so that
// the next call to this function can drain leftover
// data before processing the next frame
ws.frameReader = frame
return ErrFrameTooLarge
}
payloadType := frame.PayloadType() payloadType := frame.PayloadType()
data, err := ioutil.ReadAll(frame) data, err := ioutil.ReadAll(frame)
if err != nil { if err != nil {
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"io" "io"
"log" "log"
"math/rand"
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
...@@ -605,3 +606,60 @@ func TestCtrlAndData(t *testing.T) { ...@@ -605,3 +606,60 @@ func TestCtrlAndData(t *testing.T) {
} }
} }
} }
func TestCodec_ReceiveLimited(t *testing.T) {
const limit = 2048
var payloads [][]byte
for _, size := range []int{
1024,
2048,
4096, // receive of this message would be interrupted due to limit
2048, // this one is to make sure next receive recovers discarding leftovers
} {
b := make([]byte, size)
rand.Read(b)
payloads = append(payloads, b)
}
handlerDone := make(chan struct{})
limitedHandler := func(ws *Conn) {
defer close(handlerDone)
ws.MaxPayloadBytes = limit
defer ws.Close()
for i, p := range payloads {
t.Logf("payload #%d (size %d, exceeds limit: %v)", i, len(p), len(p) > limit)
var recv []byte
err := Message.Receive(ws, &recv)
switch err {
case nil:
case ErrFrameTooLarge:
if len(p) <= limit {
t.Fatalf("unexpected frame size limit: expected %d bytes of payload having limit at %d", len(p), limit)
}
continue
default:
t.Fatalf("unexpected error: %v (want either nil or ErrFrameTooLarge)", err)
}
if len(recv) > limit {
t.Fatalf("received %d bytes of payload having limit at %d", len(recv), limit)
}
if !bytes.Equal(p, recv) {
t.Fatalf("received payload differs:\ngot:\t%v\nwant:\t%v", recv, p)
}
}
}
server := httptest.NewServer(Handler(limitedHandler))
defer server.CloseClientConnections()
defer server.Close()
addr := server.Listener.Addr().String()
ws, err := Dial("ws://"+addr+"/", "", "http://localhost/")
if err != nil {
t.Fatal(err)
}
defer ws.Close()
for i, p := range payloads {
if err := Message.Send(ws, p); err != nil {
t.Fatalf("payload #%d (size %d): %v", i, len(p), err)
}
}
<-handlerDone
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment