Commit feca99fd authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

http: add Transport.ProxySelector

R=mattn.jp, rsc
CC=golang-dev
https://golang.org/cl/4528077
parent cef64d8e
......@@ -40,10 +40,8 @@ func TestUseProxy(t *testing.T) {
no_proxy := "foobar.com, .barbaz.net"
os.Setenv("NO_PROXY", no_proxy)
tr := &Transport{}
for _, test := range UseProxyTests {
if tr.useProxy(test.host+":80") != test.match {
if useProxy(test.host+":80") != test.match {
t.Errorf("useProxy(%v) = %v, want %v", test.host, !test.match, test.match)
}
}
......
......@@ -24,7 +24,7 @@ import (
// each call to Do and uses HTTP proxies as directed by the
// $HTTP_PROXY and $NO_PROXY (or $http_proxy and $no_proxy)
// environment variables.
var DefaultTransport RoundTripper = &Transport{}
var DefaultTransport RoundTripper = &Transport{Proxy: ProxyFromEnvironment}
// DefaultMaxIdleConnsPerHost is the default value of Transport's
// MaxIdleConnsPerHost.
......@@ -41,7 +41,12 @@ type Transport struct {
// TODO: tunable on timeout on cached connections
// TODO: optional pipelining
IgnoreEnvironment bool // don't look at environment variables for proxy configuration
// Proxy optionally specifies a function to return a proxy for
// a given Request. If the function returns a non-nil error,
// the request is aborted with the provided error. If Proxy is
// nil or returns a nil *URL, no proxy is used.
Proxy func(*Request) (*URL, os.Error)
DisableKeepAlives bool
DisableCompression bool
......@@ -51,6 +56,39 @@ type Transport struct {
MaxIdleConnsPerHost int
}
// ProxyFromEnvironment returns the URL of the proxy to use for a
// given request, as indicated by the environment variables
// $HTTP_PROXY and $NO_PROXY (or $http_proxy and $no_proxy).
// Either URL or an error is returned.
func ProxyFromEnvironment(req *Request) (*URL, os.Error) {
proxy := getenvEitherCase("HTTP_PROXY")
if proxy == "" {
return nil, nil
}
if !useProxy(canonicalAddr(req.URL)) {
return nil, nil
}
proxyURL, err := ParseRequestURL(proxy)
if err != nil {
return nil, os.ErrorString("invalid proxy address")
}
if proxyURL.Host == "" {
proxyURL, err = ParseRequestURL("http://" + proxy)
if err != nil {
return nil, os.ErrorString("invalid proxy address")
}
}
return proxyURL, nil
}
// ProxyURL returns a proxy function (for use in a Transport)
// that always returns the same URL.
func ProxyURL(url *URL) func(*Request) (*URL, os.Error) {
return func(*Request) (*URL, os.Error) {
return url, nil
}
}
// RoundTrip implements the RoundTripper interface.
func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
if req.URL == nil {
......@@ -101,21 +139,11 @@ func (t *Transport) CloseIdleConnections() {
// Private implementation past this point.
//
func (t *Transport) getenvEitherCase(k string) string {
if t.IgnoreEnvironment {
return ""
}
if v := t.getenv(strings.ToUpper(k)); v != "" {
func getenvEitherCase(k string) string {
if v := os.Getenv(strings.ToUpper(k)); v != "" {
return v
}
return t.getenv(strings.ToLower(k))
}
func (t *Transport) getenv(k string) string {
if t.IgnoreEnvironment {
return ""
}
return os.Getenv(k)
return os.Getenv(strings.ToLower(k))
}
func (t *Transport) connectMethodForRequest(req *Request) (*connectMethod, os.Error) {
......@@ -123,20 +151,12 @@ func (t *Transport) connectMethodForRequest(req *Request) (*connectMethod, os.Er
targetScheme: req.URL.Scheme,
targetAddr: canonicalAddr(req.URL),
}
proxy := t.getenvEitherCase("HTTP_PROXY")
if proxy != "" && t.useProxy(cm.targetAddr) {
proxyURL, err := ParseRequestURL(proxy)
if t.Proxy != nil {
var err os.Error
cm.proxyURL, err = t.Proxy(req)
if err != nil {
return nil, os.ErrorString("invalid proxy address")
}
if proxyURL.Host == "" {
proxyURL, err = ParseRequestURL("http://" + proxy)
if err != nil {
return nil, os.ErrorString("invalid proxy address")
}
return nil, err
}
cm.proxyURL = proxyURL
}
return cm, nil
}
......@@ -296,7 +316,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) {
// useProxy returns true if requests to addr should use a proxy,
// according to the NO_PROXY or no_proxy environment variable.
// addr is always a canonicalAddr with a host and port.
func (t *Transport) useProxy(addr string) bool {
func useProxy(addr string) bool {
if len(addr) == 0 {
return true
}
......@@ -313,7 +333,7 @@ func (t *Transport) useProxy(addr string) bool {
}
}
no_proxy := t.getenvEitherCase("NO_PROXY")
no_proxy := getenvEitherCase("NO_PROXY")
if no_proxy == "*" {
return false
}
......
......@@ -478,6 +478,30 @@ func TestTransportGzip(t *testing.T) {
}
}
func TestTransportProxy(t *testing.T) {
ch := make(chan string, 1)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
ch <- "real server"
}))
defer ts.Close()
proxy := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
ch <- "proxy for " + r.URL.String()
}))
defer proxy.Close()
pu, err := ParseURL(proxy.URL)
if err != nil {
t.Fatal(err)
}
c := &Client{Transport: &Transport{Proxy: ProxyURL(pu)}}
c.Head(ts.URL)
got := <-ch
want := "proxy for " + ts.URL + "/"
if got != want {
t.Errorf("want %q, got %q", want, got)
}
}
// TestTransportGzipRecursive sends a gzip quine and checks that the
// client gets the same value back. This is more cute than anything,
// but checks that we don't recurse forever, and checks that
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment