Commit 8b29f158 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: fix a panic in Redirect

R=golang-dev, dsymonds, r
CC=golang-dev
https://golang.org/cl/8721045
parent 782a5781
...@@ -180,7 +180,7 @@ func TestPostFormRequestFormat(t *testing.T) { ...@@ -180,7 +180,7 @@ func TestPostFormRequestFormat(t *testing.T) {
} }
} }
func TestRedirects(t *testing.T) { func TestClientRedirects(t *testing.T) {
defer afterTest(t) defer afterTest(t)
var ts *httptest.Server var ts *httptest.Server
ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
......
...@@ -1222,9 +1222,9 @@ func Redirect(w ResponseWriter, r *Request, urlStr string, code int) { ...@@ -1222,9 +1222,9 @@ func Redirect(w ResponseWriter, r *Request, urlStr string, code int) {
} }
// clean up but preserve trailing slash // clean up but preserve trailing slash
trailing := urlStr[len(urlStr)-1] == '/' trailing := strings.HasSuffix(urlStr, "/")
urlStr = path.Clean(urlStr) urlStr = path.Clean(urlStr)
if trailing && urlStr[len(urlStr)-1] != '/' { if trailing && !strings.HasSuffix(urlStr, "/") {
urlStr += "/" urlStr += "/"
} }
urlStr += query urlStr += query
......
...@@ -2,9 +2,11 @@ ...@@ -2,9 +2,11 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package http package http_test
import ( import (
. "net/http"
"net/http/httptest"
"net/url" "net/url"
"testing" "testing"
) )
...@@ -76,20 +78,27 @@ func TestServeMuxHandler(t *testing.T) { ...@@ -76,20 +78,27 @@ func TestServeMuxHandler(t *testing.T) {
}, },
} }
h, pattern := mux.Handler(r) h, pattern := mux.Handler(r)
cs := &codeSaver{h: Header{}} rr := httptest.NewRecorder()
h.ServeHTTP(cs, r) h.ServeHTTP(rr, r)
if pattern != tt.pattern || cs.code != tt.code { if pattern != tt.pattern || rr.Code != tt.code {
t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, cs.code, pattern, tt.code, tt.pattern) t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, rr.Code, pattern, tt.code, tt.pattern)
} }
} }
} }
// A codeSaver is a ResponseWriter that saves the code passed to WriteHeader. func TestServerRedirect(t *testing.T) {
type codeSaver struct { // This used to crash. It's not valid input (bad path), but it
h Header // shouldn't crash.
code int rr := httptest.NewRecorder()
req := &Request{
Method: "GET",
URL: &url.URL{
Scheme: "http",
Path: "not-empty-but-no-leading-slash", // bogus
},
}
Redirect(rr, req, "", 304)
if rr.Code != 304 {
t.Errorf("Code = %d; want 304", rr.Code)
}
} }
func (cs *codeSaver) Header() Header { return cs.h }
func (cs *codeSaver) Write(p []byte) (int, error) { return len(p), nil }
func (cs *codeSaver) WriteHeader(code int) { cs.code = code }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment