Commit a6701f26 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http, net/url: permit Request-URI "*"

Also, implement a global OPTIONS * handler, like Apache.

Permit sending "*" requests to handlers, but not path-based
(ServeMux) handlers.  That means people can go out of their
way to support SSDP or SIP or whatever, but most users will be
unaffected.

See RFC 2616 Section 5.1.2 (Request-URI)
See RFC 2616 Section 9.2 (OPTIONS)

Fixes #3692

R=rsc
CC=golang-dev
https://golang.org/cl/6868095
parent fc393638
...@@ -247,6 +247,54 @@ var reqTests = []reqTest{ ...@@ -247,6 +247,54 @@ var reqTests = []reqTest{
noTrailer, noTrailer,
noError, noError,
}, },
// SSDP Notify request. golang.org/issue/3692
{
"NOTIFY * HTTP/1.1\r\nServer: foo\r\n\r\n",
&Request{
Method: "NOTIFY",
URL: &url.URL{
Path: "*",
},
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: Header{
"Server": []string{"foo"},
},
Close: false,
ContentLength: 0,
RequestURI: "*",
},
noBody,
noTrailer,
noError,
},
// OPTIONS request. Similar to golang.org/issue/3692
{
"OPTIONS * HTTP/1.1\r\nServer: foo\r\n\r\n",
&Request{
Method: "OPTIONS",
URL: &url.URL{
Path: "*",
},
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: Header{
"Server": []string{"foo"},
},
Close: false,
ContentLength: 0,
RequestURI: "*",
},
noBody,
noTrailer,
noError,
},
} }
func TestReadRequest(t *testing.T) { func TestReadRequest(t *testing.T) {
......
...@@ -1288,6 +1288,58 @@ For: ...@@ -1288,6 +1288,58 @@ For:
ts.Close() ts.Close()
} }
func TestOptions(t *testing.T) {
uric := make(chan string, 2) // only expect 1, but leave space for 2
mux := NewServeMux()
mux.HandleFunc("/", func(w ResponseWriter, r *Request) {
uric <- r.RequestURI
})
ts := httptest.NewServer(mux)
defer ts.Close()
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
t.Fatal(err)
}
defer conn.Close()
// An OPTIONS * request should succeed.
_, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
if err != nil {
t.Fatal(err)
}
br := bufio.NewReader(conn)
res, err := ReadResponse(br, &Request{Method: "OPTIONS"})
if err != nil {
t.Fatal(err)
}
if res.StatusCode != 200 {
t.Errorf("Got non-200 response to OPTIONS *: %#v", res)
}
// A GET * request on a ServeMux should fail.
_, err = conn.Write([]byte("GET * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
if err != nil {
t.Fatal(err)
}
res, err = ReadResponse(br, &Request{Method: "GET"})
if err != nil {
t.Fatal(err)
}
if res.StatusCode != 400 {
t.Errorf("Got non-400 response to GET *: %#v", res)
}
res, err = Get(ts.URL + "/second")
if err != nil {
t.Fatal(err)
}
res.Body.Close()
if got := <-uric; got != "/second" {
t.Errorf("Handler saw request for %q; want /second", got)
}
}
// goTimeout runs f, failing t if f takes more than ns to complete. // goTimeout runs f, failing t if f takes more than ns to complete.
func goTimeout(t *testing.T, d time.Duration, f func()) { func goTimeout(t *testing.T, d time.Duration, f func()) {
ch := make(chan bool, 2) ch := make(chan bool, 2)
......
...@@ -770,6 +770,9 @@ func (c *conn) serve() { ...@@ -770,6 +770,9 @@ func (c *conn) serve() {
if handler == nil { if handler == nil {
handler = DefaultServeMux handler = DefaultServeMux
} }
if req.RequestURI == "*" && req.Method == "OPTIONS" {
handler = globalOptionsHandler{}
}
// HTTP cannot have multiple simultaneous active requests.[*] // HTTP cannot have multiple simultaneous active requests.[*]
// Until the server replies to this request, it can't read another, // Until the server replies to this request, it can't read another,
...@@ -1085,6 +1088,11 @@ func (mux *ServeMux) handler(host, path string) (h Handler, pattern string) { ...@@ -1085,6 +1088,11 @@ func (mux *ServeMux) handler(host, path string) (h Handler, pattern string) {
// ServeHTTP dispatches the request to the handler whose // ServeHTTP dispatches the request to the handler whose
// pattern most closely matches the request URL. // pattern most closely matches the request URL.
func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) { func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) {
if r.RequestURI == "*" {
w.Header().Set("Connection", "close")
w.WriteHeader(StatusBadRequest)
return
}
h, _ := mux.Handler(r) h, _ := mux.Handler(r)
h.ServeHTTP(w, r) h.ServeHTTP(w, r)
} }
...@@ -1408,6 +1416,22 @@ func (tw *timeoutWriter) WriteHeader(code int) { ...@@ -1408,6 +1416,22 @@ func (tw *timeoutWriter) WriteHeader(code int) {
tw.w.WriteHeader(code) tw.w.WriteHeader(code)
} }
// globalOptionsHandler responds to "OPTIONS *" requests.
type globalOptionsHandler struct{}
func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) {
w.Header().Set("Content-Length", "0")
if r.ContentLength != 0 {
// Read up to 4KB of OPTIONS body (as mentioned in the
// spec as being reserved for future use), but anything
// over that is considered a waste of server resources
// (or an attack) and we abort and close the connection,
// courtesy of MaxBytesReader's EOF behavior.
mb := MaxBytesReader(w, r.Body, 4<<10)
io.Copy(ioutil.Discard, mb)
}
}
// loggingConn is used for debugging. // loggingConn is used for debugging.
type loggingConn struct { type loggingConn struct {
name string name string
......
...@@ -361,6 +361,11 @@ func parse(rawurl string, viaRequest bool) (url *URL, err error) { ...@@ -361,6 +361,11 @@ func parse(rawurl string, viaRequest bool) (url *URL, err error) {
} }
url = new(URL) url = new(URL)
if rawurl == "*" {
url.Path = "*"
return
}
// Split off possible leading "http:", "mailto:", etc. // Split off possible leading "http:", "mailto:", etc.
// Cannot contain escaped characters. // Cannot contain escaped characters.
if url.Scheme, rest, err = getscheme(rawurl); err != nil { if url.Scheme, rest, err = getscheme(rawurl); err != nil {
......
...@@ -277,7 +277,7 @@ func TestParse(t *testing.T) { ...@@ -277,7 +277,7 @@ func TestParse(t *testing.T) {
const pathThatLooksSchemeRelative = "//not.a.user@not.a.host/just/a/path" const pathThatLooksSchemeRelative = "//not.a.user@not.a.host/just/a/path"
var parseRequestUrlTests = []struct { var parseRequestURLTests = []struct {
url string url string
expectedValid bool expectedValid bool
}{ }{
...@@ -289,10 +289,11 @@ var parseRequestUrlTests = []struct { ...@@ -289,10 +289,11 @@ var parseRequestUrlTests = []struct {
{"//not.a.user@%66%6f%6f.com/just/a/path/also", true}, {"//not.a.user@%66%6f%6f.com/just/a/path/also", true},
{"foo.html", false}, {"foo.html", false},
{"../dir/", false}, {"../dir/", false},
{"*", true},
} }
func TestParseRequestURI(t *testing.T) { func TestParseRequestURI(t *testing.T) {
for _, test := range parseRequestUrlTests { for _, test := range parseRequestURLTests {
_, err := ParseRequestURI(test.url) _, err := ParseRequestURI(test.url)
valid := err == nil valid := err == nil
if valid != test.expectedValid { if valid != test.expectedValid {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment