Commit 2161e3e2 authored by Fumitoshi Ukai's avatar Fumitoshi Ukai Committed by Russ Cox

http: avoid server crash on malformed client request

R=r, rsc
CC=golang-dev
https://golang.org/cl/206079
parent c2dea219
...@@ -38,20 +38,34 @@ type Handler func(*Conn) ...@@ -38,20 +38,34 @@ type Handler func(*Conn)
// ServeHTTP implements the http.Handler interface for a Web Socket. // ServeHTTP implements the http.Handler interface for a Web Socket.
func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) { func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) {
if req.Method != "GET" || req.Proto != "HTTP/1.1" || if req.Method != "GET" || req.Proto != "HTTP/1.1" {
req.Header["Upgrade"] != "WebSocket" || c.WriteHeader(http.StatusBadRequest)
req.Header["Connection"] != "Upgrade" { io.WriteString(c, "Unexpected request")
c.WriteHeader(http.StatusNotFound)
io.WriteString(c, "must use websocket to connect here")
return return
} }
if v, present := req.Header["Upgrade"]; !present || v != "WebSocket" {
c.WriteHeader(http.StatusBadRequest)
io.WriteString(c, "missing Upgrade: WebSocket header")
return
}
if v, present := req.Header["Connection"]; !present || v != "Upgrade" {
c.WriteHeader(http.StatusBadRequest)
io.WriteString(c, "missing Connection: Upgrade header")
return
}
origin, present := req.Header["Origin"]
if !present {
c.WriteHeader(http.StatusBadRequest)
io.WriteString(c, "missing Origin header")
return
}
rwc, buf, err := c.Hijack() rwc, buf, err := c.Hijack()
if err != nil { if err != nil {
panic("Hijack failed: ", err.String()) panic("Hijack failed: ", err.String())
return return
} }
defer rwc.Close() defer rwc.Close()
origin := req.Header["Origin"]
location := "ws://" + req.Host + req.URL.Path location := "ws://" + req.Host + req.URL.Path
// TODO(ukai): verify origin,location,protocol. // TODO(ukai): verify origin,location,protocol.
...@@ -61,9 +75,9 @@ func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) { ...@@ -61,9 +75,9 @@ func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) {
buf.WriteString("Connection: Upgrade\r\n") buf.WriteString("Connection: Upgrade\r\n")
buf.WriteString("WebSocket-Origin: " + origin + "\r\n") buf.WriteString("WebSocket-Origin: " + origin + "\r\n")
buf.WriteString("WebSocket-Location: " + location + "\r\n") buf.WriteString("WebSocket-Location: " + location + "\r\n")
protocol := "" protocol, present := req.Header["Websocket-Protocol"]
// canonical header key of WebSocket-Protocol. // canonical header key of WebSocket-Protocol.
if protocol, found := req.Header["Websocket-Protocol"]; found { if present {
buf.WriteString("WebSocket-Protocol: " + protocol + "\r\n") buf.WriteString("WebSocket-Protocol: " + protocol + "\r\n")
} }
buf.WriteString("\r\n") buf.WriteString("\r\n")
......
...@@ -6,6 +6,7 @@ package websocket ...@@ -6,6 +6,7 @@ package websocket
import ( import (
"bytes" "bytes"
"fmt"
"http" "http"
"io" "io"
"log" "log"
...@@ -59,3 +60,17 @@ func TestEcho(t *testing.T) { ...@@ -59,3 +60,17 @@ func TestEcho(t *testing.T) {
} }
ws.Close() ws.Close()
} }
func TestHTTP(t *testing.T) {
once.Do(startServer)
r, _, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
if err != nil {
t.Errorf("Get: error %v", err)
return
}
if r.StatusCode != http.StatusBadRequest {
t.Errorf("Get: got status %d", r.StatusCode)
return
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment