Commit 2161e3e2 authored by Fumitoshi Ukai's avatar Fumitoshi Ukai Committed by Russ Cox

http: avoid server crash on malformed client request

R=r, rsc
CC=golang-dev
https://golang.org/cl/206079
parent c2dea219
......@@ -38,20 +38,34 @@ type Handler func(*Conn)
// ServeHTTP implements the http.Handler interface for a Web Socket.
func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) {
if req.Method != "GET" || req.Proto != "HTTP/1.1" ||
req.Header["Upgrade"] != "WebSocket" ||
req.Header["Connection"] != "Upgrade" {
c.WriteHeader(http.StatusNotFound)
io.WriteString(c, "must use websocket to connect here")
if req.Method != "GET" || req.Proto != "HTTP/1.1" {
c.WriteHeader(http.StatusBadRequest)
io.WriteString(c, "Unexpected request")
return
}
if v, present := req.Header["Upgrade"]; !present || v != "WebSocket" {
c.WriteHeader(http.StatusBadRequest)
io.WriteString(c, "missing Upgrade: WebSocket header")
return
}
if v, present := req.Header["Connection"]; !present || v != "Upgrade" {
c.WriteHeader(http.StatusBadRequest)
io.WriteString(c, "missing Connection: Upgrade header")
return
}
origin, present := req.Header["Origin"]
if !present {
c.WriteHeader(http.StatusBadRequest)
io.WriteString(c, "missing Origin header")
return
}
rwc, buf, err := c.Hijack()
if err != nil {
panic("Hijack failed: ", err.String())
return
}
defer rwc.Close()
origin := req.Header["Origin"]
location := "ws://" + req.Host + req.URL.Path
// TODO(ukai): verify origin,location,protocol.
......@@ -61,9 +75,9 @@ func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) {
buf.WriteString("Connection: Upgrade\r\n")
buf.WriteString("WebSocket-Origin: " + origin + "\r\n")
buf.WriteString("WebSocket-Location: " + location + "\r\n")
protocol := ""
protocol, present := req.Header["Websocket-Protocol"]
// canonical header key of WebSocket-Protocol.
if protocol, found := req.Header["Websocket-Protocol"]; found {
if present {
buf.WriteString("WebSocket-Protocol: " + protocol + "\r\n")
}
buf.WriteString("\r\n")
......
......@@ -6,6 +6,7 @@ package websocket
import (
"bytes"
"fmt"
"http"
"io"
"log"
......@@ -59,3 +60,17 @@ func TestEcho(t *testing.T) {
}
ws.Close()
}
func TestHTTP(t *testing.T) {
once.Do(startServer)
r, _, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
if err != nil {
t.Errorf("Get: error %v", err)
return
}
if r.StatusCode != http.StatusBadRequest {
t.Errorf("Get: got status %d", r.StatusCode)
return
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment