Commit 0005f0a0 authored by Fumitoshi Ukai's avatar Fumitoshi Ukai Committed by Mikio Hara

go.net/websocket: allow server configurable

Add websocket.Server to configure WebSocket server handler.

- Config.Header is additional headers to send, so you can use it
  to send cookies or so.
  To read cookies, you can use Conn.Request().Header.
- factor out Handshake.
  You can set func to check origin, subprotocol etc.
  Handler checks origin by default.

Fixes golang/go#4198.
Fixes golang/go#5178.

R=golang-dev, mikioh.mikioh, crobin
CC=golang-dev
https://golang.org/cl/8731044
parent 94458b3b
......@@ -9,6 +9,7 @@ import (
"crypto/tls"
"io"
"net"
"net/http"
"net/url"
)
......@@ -34,6 +35,7 @@ func NewConfig(server, origin string) (config *Config, err error) {
if err != nil {
return
}
config.Header = http.Header(make(map[string][]string))
return
}
......
......@@ -46,6 +46,17 @@ var (
ErrBadClosingStatus = &ProtocolError{"bad closing status"}
ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"}
ErrNotImplemented = &ProtocolError{"not implemented"}
handshakeHeader = map[string]bool{
"Host": true,
"Upgrade": true,
"Connection": true,
"Sec-Websocket-Key": true,
"Sec-Websocket-Origin": true,
"Sec-Websocket-Version": true,
"Sec-Websocket-Protocol": true,
"Sec-Websocket-Accept": true,
}
)
// A hybiFrameHeader is a frame header as defined in hybi draft.
......@@ -408,8 +419,11 @@ func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (er
if len(config.Protocol) > 0 {
bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n")
}
// TODO(ukai): send extensions.
// TODO(ukai): send cookie if any.
// TODO(ukai): send Sec-WebSocket-Extensions.
err = config.Header.WriteSubset(bw, handshakeHeader)
if err != nil {
return err
}
bw.WriteString("\r\n")
if err = bw.Flush(); err != nil {
......@@ -483,21 +497,14 @@ func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Reques
return http.StatusBadRequest, ErrChallengeResponse
}
version := req.Header.Get("Sec-Websocket-Version")
var origin string
switch version {
case "13":
c.Version = ProtocolVersionHybi13
origin = req.Header.Get("Origin")
case "8":
c.Version = ProtocolVersionHybi08
origin = req.Header.Get("Sec-Websocket-Origin")
default:
return http.StatusBadRequest, ErrBadWebSocketVersion
}
c.Origin, err = url.ParseRequestURI(origin)
if err != nil {
return http.StatusForbidden, err
}
var scheme string
if req.TLS != nil {
scheme = "wss"
......@@ -520,6 +527,22 @@ func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Reques
return http.StatusSwitchingProtocols, nil
}
// Origin parses Origin header in "req".
// If origin is "null", returns (nil, nil).
func Origin(config *Config, req *http.Request) (*url.URL, error) {
var origin string
switch config.Version {
case ProtocolVersionHybi13:
origin = req.Header.Get("Origin")
case ProtocolVersionHybi08:
origin = req.Header.Get("Sec-Websocket-Origin")
}
if origin == "null" {
return nil, nil
}
return url.ParseRequestURI(origin)
}
func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) {
if len(c.Protocol) > 0 {
if len(c.Protocol) != 1 {
......@@ -533,7 +556,13 @@ func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) {
if len(c.Protocol) > 0 {
buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n")
}
// TODO(ukai): support extensions
// TODO(ukai): send Sec-WebSocket-Extensions.
if c.Header != nil {
err := c.Header.WriteSubset(buf, handshakeHeader)
if err != nil {
return err
}
}
buf.WriteString("\r\n")
return buf.Flush()
}
......
......@@ -92,6 +92,71 @@ Sec-WebSocket-Protocol: chat
}
}
func TestHybiClientHandshakeWithHeader(t *testing.T) {
b := bytes.NewBuffer([]byte{})
bw := bufio.NewWriter(b)
br := bufio.NewReader(strings.NewReader(`HTTP/1.1 101 Switching Protocols
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
Sec-WebSocket-Protocol: chat
`))
var err error
config := new(Config)
config.Location, err = url.ParseRequestURI("ws://server.example.com/chat")
if err != nil {
t.Fatal("location url", err)
}
config.Origin, err = url.ParseRequestURI("http://example.com")
if err != nil {
t.Fatal("origin url", err)
}
config.Protocol = append(config.Protocol, "chat")
config.Protocol = append(config.Protocol, "superchat")
config.Version = ProtocolVersionHybi13
config.Header = http.Header(make(map[string][]string))
config.Header.Add("User-Agent", "test")
config.handshakeData = map[string]string{
"key": "dGhlIHNhbXBsZSBub25jZQ==",
}
err = hybiClientHandshake(config, br, bw)
if err != nil {
t.Errorf("handshake failed: %v", err)
}
req, err := http.ReadRequest(bufio.NewReader(b))
if err != nil {
t.Fatalf("read request: %v", err)
}
if req.Method != "GET" {
t.Errorf("request method expected GET, but got %q", req.Method)
}
if req.URL.Path != "/chat" {
t.Errorf("request path expected /chat, but got %q", req.URL.Path)
}
if req.Proto != "HTTP/1.1" {
t.Errorf("request proto expected HTTP/1.1, but got %q", req.Proto)
}
if req.Host != "server.example.com" {
t.Errorf("request Host expected server.example.com, but got %v", req.Host)
}
var expectedHeader = map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-Websocket-Key": config.handshakeData["key"],
"Origin": config.Origin.String(),
"Sec-Websocket-Protocol": "chat, superchat",
"Sec-Websocket-Version": fmt.Sprintf("%d", ProtocolVersionHybi13),
"User-Agent": "test",
}
for k, v := range expectedHeader {
if req.Header.Get(k) != v {
t.Errorf(fmt.Sprintf("%s expected %q but got %q", k, v, req.Header.Get(k)))
}
}
}
func TestHybiClientHandshakeHybi08(t *testing.T) {
b := bytes.NewBuffer([]byte{})
bw := bufio.NewWriter(b)
......
......@@ -11,8 +11,7 @@ import (
"net/http"
)
func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request) (conn *Conn, err error) {
config := new(Config)
func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request, config *Config, handshake func(*Config, *http.Request) error) (conn *Conn, err error) {
var hs serverHandshaker = &hybiServerHandshaker{Config: config}
code, err := hs.ReadHandshake(buf.Reader, req)
if err == ErrBadWebSocketVersion {
......@@ -38,8 +37,16 @@ func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Requ
buf.Flush()
return
}
config.Protocol = nil
if handshake != nil {
err = handshake(config, req)
if err != nil {
code = http.StatusForbidden
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
buf.WriteString("\r\n")
buf.Flush()
return
}
}
err = hs.AcceptHandshake(buf.Writer)
if err != nil {
code = http.StatusBadRequest
......@@ -52,11 +59,26 @@ func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Requ
return
}
// Handler is an interface to a WebSocket.
type Handler func(*Conn)
// Server represents a server of a WebSocket.
type Server struct {
// Config is a WebSocket configuration for new WebSocket connection.
Config
// ServeHTTP implements the http.Handler interface for a Web Socket
func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// Handshake is an optional function in WebSocket handshake.
// For example, you can check, or don't check Origin header.
// Another example, you can select config.Protocol.
Handshake func(*Config, *http.Request) error
// Handler handles a WebSocket connection.
Handler
}
// ServeHTTP implements the http.Handler interface for a WebSocket
func (s Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
s.serveWebSocket(w, req)
}
func (s Server) serveWebSocket(w http.ResponseWriter, req *http.Request) {
rwc, buf, err := w.(http.Hijacker).Hijack()
if err != nil {
panic("Hijack failed: " + err.Error())
......@@ -66,12 +88,35 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// the client did not send a handshake that matches with protocol
// specification.
defer rwc.Close()
conn, err := newServerConn(rwc, buf, req)
conn, err := newServerConn(rwc, buf, req, &s.Config, s.Handshake)
if err != nil {
return
}
if conn == nil {
panic("unexpected nil conn")
}
h(conn)
s.Handler(conn)
}
// Handler is a simple interface to a WebSocket browser client.
// It checks if Origin header is valid URL by default.
// You might want to verify websocket.Conn.Config().Origin in the func.
// If you use Server instead of Handler, you could call websocket.Origin and
// check the origin in your Handshake func. So, if you want to accept
// non-browser client, which doesn't send Origin header, you could use Server
//. that doesn't check origin in its Handshake.
type Handler func(*Conn)
func checkOrigin(config *Config, req *http.Request) (err error) {
config.Origin, err = Origin(config, req)
if err == nil && config.Origin == nil {
return fmt.Errorf("null origin")
}
return err
}
// ServeHTTP implements the http.Handler interface for a WebSocket
func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
s := Server{Handler: h, Handshake: checkOrigin}
s.serveWebSocket(w, req)
}
......@@ -87,6 +87,9 @@ type Config struct {
// TLS config for secure WebSocket (wss).
TlsConfig *tls.Config
// Additional header fields to be sent in WebSocket opening handshake.
Header http.Header
handshakeData map[string]string
}
......
......@@ -44,9 +44,30 @@ func countServer(ws *Conn) {
}
}
func subProtocolHandshake(config *Config, req *http.Request) error {
for _, proto := range config.Protocol {
if proto == "chat" {
config.Protocol = []string{proto}
return nil
}
}
return ErrBadWebSocketProtocol
}
func subProtoServer(ws *Conn) {
for _, proto := range ws.Config().Protocol {
io.WriteString(ws, proto)
}
}
func startServer() {
http.Handle("/echo", Handler(echoServer))
http.Handle("/count", Handler(countServer))
subproto := Server{
Handshake: subProtocolHandshake,
Handler: Handler(subProtoServer),
}
http.Handle("/subproto", subproto)
server := httptest.NewServer(nil)
serverAddr = server.Listener.Addr().String()
log.Print("Test WebSocket server listening on ", serverAddr)
......@@ -177,7 +198,7 @@ func TestWithQuery(t *testing.T) {
ws.Close()
}
func TestWithProtocol(t *testing.T) {
func testWithProtocol(t *testing.T, subproto []string) (string, error) {
once.Do(startServer)
client, err := net.Dial("tcp", serverAddr)
......@@ -185,15 +206,47 @@ func TestWithProtocol(t *testing.T) {
t.Fatal("dialing", err)
}
config := newConfig(t, "/echo")
config.Protocol = append(config.Protocol, "test")
config := newConfig(t, "/subproto")
config.Protocol = subproto
ws, err := NewClient(config, client)
if err != nil {
t.Errorf("WebSocket handshake: %v", err)
return
return "", err
}
msg := make([]byte, 16)
n, err := ws.Read(msg)
if err != nil {
return "", err
}
ws.Close()
return string(msg[:n]), nil
}
func TestWithProtocol(t *testing.T) {
proto, err := testWithProtocol(t, []string{"chat"})
if err != nil {
t.Errorf("SubProto: unexpected error: %v", err)
}
if proto != "chat" {
t.Errorf("SubProto: expected %q, got %q", "chat", proto)
}
}
func TestWithTwoProtocol(t *testing.T) {
proto, err := testWithProtocol(t, []string{"test", "chat"})
if err != nil {
t.Errorf("SubProto: unexpected error: %v", err)
}
if proto != "chat" {
t.Errorf("SubProto: expected %q, got %q", "chat", proto)
}
}
func TestWithBadProtocol(t *testing.T) {
_, err := testWithProtocol(t, []string{"test"})
if err != ErrBadStatus {
t.Errorf("SubProto: expected %q, got %q", ErrBadStatus)
}
}
func TestHTTP(t *testing.T) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment