Commit 08ce7f1d authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: follow certain redirects after POST requests

Fixes #4145

R=golang-dev, rsc
CC=golang-dev
https://golang.org/cl/6923055
parent 39067062
...@@ -120,7 +120,10 @@ func (c *Client) send(req *Request) (*Response, error) { ...@@ -120,7 +120,10 @@ func (c *Client) send(req *Request) (*Response, error) {
// Generally Get, Post, or PostForm will be used instead of Do. // Generally Get, Post, or PostForm will be used instead of Do.
func (c *Client) Do(req *Request) (resp *Response, err error) { func (c *Client) Do(req *Request) (resp *Response, err error) {
if req.Method == "GET" || req.Method == "HEAD" { if req.Method == "GET" || req.Method == "HEAD" {
return c.doFollowingRedirects(req) return c.doFollowingRedirects(req, shouldRedirectGet)
}
if req.Method == "POST" || req.Method == "PUT" {
return c.doFollowingRedirects(req, shouldRedirectPost)
} }
return c.send(req) return c.send(req)
} }
...@@ -166,7 +169,7 @@ func send(req *Request, t RoundTripper) (resp *Response, err error) { ...@@ -166,7 +169,7 @@ func send(req *Request, t RoundTripper) (resp *Response, err error) {
// True if the specified HTTP status code is one for which the Get utility should // True if the specified HTTP status code is one for which the Get utility should
// automatically redirect. // automatically redirect.
func shouldRedirect(statusCode int) bool { func shouldRedirectGet(statusCode int) bool {
switch statusCode { switch statusCode {
case StatusMovedPermanently, StatusFound, StatusSeeOther, StatusTemporaryRedirect: case StatusMovedPermanently, StatusFound, StatusSeeOther, StatusTemporaryRedirect:
return true return true
...@@ -174,6 +177,16 @@ func shouldRedirect(statusCode int) bool { ...@@ -174,6 +177,16 @@ func shouldRedirect(statusCode int) bool {
return false return false
} }
// True if the specified HTTP status code is one for which the Post utility should
// automatically redirect.
func shouldRedirectPost(statusCode int) bool {
switch statusCode {
case StatusFound, StatusSeeOther:
return true
}
return false
}
// Get issues a GET to the specified URL. If the response is one of the following // Get issues a GET to the specified URL. If the response is one of the following
// redirect codes, Get follows the redirect, up to a maximum of 10 redirects: // redirect codes, Get follows the redirect, up to a maximum of 10 redirects:
// //
...@@ -214,10 +227,10 @@ func (c *Client) Get(url string) (resp *Response, err error) { ...@@ -214,10 +227,10 @@ func (c *Client) Get(url string) (resp *Response, err error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return c.doFollowingRedirects(req) return c.doFollowingRedirects(req, shouldRedirectGet)
} }
func (c *Client) doFollowingRedirects(ireq *Request) (resp *Response, err error) { func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bool) (resp *Response, err error) {
// TODO: if/when we add cookie support, the redirected request shouldn't // TODO: if/when we add cookie support, the redirected request shouldn't
// necessarily supply the same cookies as the original. // necessarily supply the same cookies as the original.
var base *url.URL var base *url.URL
...@@ -238,6 +251,9 @@ func (c *Client) doFollowingRedirects(ireq *Request) (resp *Response, err error) ...@@ -238,6 +251,9 @@ func (c *Client) doFollowingRedirects(ireq *Request) (resp *Response, err error)
if redirect != 0 { if redirect != 0 {
req = new(Request) req = new(Request)
req.Method = ireq.Method req.Method = ireq.Method
if ireq.Method == "POST" || ireq.Method == "PUT" {
req.Method = "GET"
}
req.Header = make(Header) req.Header = make(Header)
req.URL, err = base.Parse(urlStr) req.URL, err = base.Parse(urlStr)
if err != nil { if err != nil {
...@@ -321,7 +337,7 @@ func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *Respon ...@@ -321,7 +337,7 @@ func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *Respon
return nil, err return nil, err
} }
req.Header.Set("Content-Type", bodyType) req.Header.Set("Content-Type", bodyType)
return c.send(req) return c.doFollowingRedirects(req, shouldRedirectPost)
} }
// PostForm issues a POST to the specified URL, with data's keys and // PostForm issues a POST to the specified URL, with data's keys and
...@@ -371,5 +387,5 @@ func (c *Client) Head(url string) (resp *Response, err error) { ...@@ -371,5 +387,5 @@ func (c *Client) Head(url string) (resp *Response, err error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return c.doFollowingRedirects(req) return c.doFollowingRedirects(req, shouldRedirectGet)
} }
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
package http_test package http_test
import ( import (
"bytes"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"errors" "errors"
...@@ -246,6 +247,52 @@ func TestRedirects(t *testing.T) { ...@@ -246,6 +247,52 @@ func TestRedirects(t *testing.T) {
} }
} }
func TestPostRedirects(t *testing.T) {
var log struct {
sync.Mutex
bytes.Buffer
}
var ts *httptest.Server
ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
log.Lock()
fmt.Fprintf(&log.Buffer, "%s %s ", r.Method, r.RequestURI)
log.Unlock()
if v := r.URL.Query().Get("code"); v != "" {
code, _ := strconv.Atoi(v)
if code/100 == 3 {
w.Header().Set("Location", ts.URL)
}
w.WriteHeader(code)
}
}))
tests := []struct {
suffix string
want int // response code
}{
{"/", 200},
{"/?code=301", 301},
{"/?code=302", 200},
{"/?code=303", 200},
{"/?code=404", 404},
}
for _, tt := range tests {
res, err := Post(ts.URL+tt.suffix, "text/plain", strings.NewReader("Some content"))
if err != nil {
t.Fatal(err)
}
if res.StatusCode != tt.want {
t.Errorf("POST %s: status code = %d; want %d", tt.suffix, res.StatusCode, tt.want)
}
}
log.Lock()
got := log.String()
log.Unlock()
want := "POST / POST /?code=301 POST /?code=302 GET / POST /?code=303 GET / POST /?code=404 "
if got != want {
t.Errorf("Log differs.\n Got: %q\nWant: %q", got, want)
}
}
var expectedCookies = []*Cookie{ var expectedCookies = []*Cookie{
{Name: "ChocolateChip", Value: "tasty"}, {Name: "ChocolateChip", Value: "tasty"},
{Name: "First", Value: "Hit"}, {Name: "First", Value: "Hit"},
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment