Commit 518ee115 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http/httputil: preserve query params in reverse proxy

Fixes #2853

R=golang-dev, r
CC=golang-dev
https://golang.org/cl/5642056
parent 92f55949
...@@ -55,11 +55,16 @@ func singleJoiningSlash(a, b string) string { ...@@ -55,11 +55,16 @@ func singleJoiningSlash(a, b string) string {
// target's path is "/base" and the incoming request was for "/dir", // target's path is "/base" and the incoming request was for "/dir",
// the target request will be for /base/dir. // the target request will be for /base/dir.
func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
targetQuery := target.RawQuery
director := func(req *http.Request) { director := func(req *http.Request) {
req.URL.Scheme = target.Scheme req.URL.Scheme = target.Scheme
req.URL.Host = target.Host req.URL.Host = target.Host
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
req.URL.RawQuery = target.RawQuery if targetQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = targetQuery + req.URL.RawQuery
} else {
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
}
} }
return &ReverseProxy{Director: director} return &ReverseProxy{Director: director}
} }
......
...@@ -69,3 +69,41 @@ func TestReverseProxy(t *testing.T) { ...@@ -69,3 +69,41 @@ func TestReverseProxy(t *testing.T) {
t.Errorf("got body %q; expected %q", g, e) t.Errorf("got body %q; expected %q", g, e)
} }
} }
var proxyQueryTests = []struct {
baseSuffix string // suffix to add to backend URL
reqSuffix string // suffix to add to frontend's request URL
want string // what backend should see for final request URL (without ?)
}{
{"", "", ""},
{"?sta=tic", "?us=er", "sta=tic&us=er"},
{"", "?us=er", "us=er"},
{"?sta=tic", "", "sta=tic"},
}
func TestReverseProxyQuery(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Got-Query", r.URL.RawQuery)
w.Write([]byte("hi"))
}))
defer backend.Close()
for i, tt := range proxyQueryTests {
backendURL, err := url.Parse(backend.URL + tt.baseSuffix)
if err != nil {
t.Fatal(err)
}
frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
req.Close = true
res, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("%d. Get: %v", i, err)
}
if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e {
t.Errorf("%d. got query %q; expected %q", i, g, e)
}
res.Body.Close()
frontend.Close()
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment