Commit 3b988eb6 authored by Johan Brandhorst's avatar Johan Brandhorst Committed by Brad Fitzpatrick

net/http: use httptest.Server Client in tests

After merging https://go-review.googlesource.com/c/34639/,
it was pointed out to me that a lot of tests under net/http
could use the new functionality to simplify and unify testing.

Using the httptest.Server provided Client removes the need to
call CloseIdleConnections() on all Transports created, as it
is automatically called on the Transport associated with the
client when Server.Close() is called.

Change the transport used by the non-TLS
httptest.Server to a new *http.Transport rather than using
http.DefaultTransport implicitly. The TLS version already
used its own *http.Transport. This change is to prevent
concurrency problems with using DefaultTransport implicitly
across several httptest.Server's.

Add tests to ensure the httptest.Server.Client().Transport
RoundTripper interface is implemented by a *http.Transport,
as is now assumed across large parts of net/http tests.

Change-Id: I9f9d15f59d72893deead5678d314388718c91821
Reviewed-on: https://go-review.googlesource.com/37771
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: 's avatarBrad Fitzpatrick <bradfitz@golang.org>
parent 2bd6360e
This diff is collapsed.
...@@ -74,6 +74,7 @@ func TestServeFile(t *testing.T) { ...@@ -74,6 +74,7 @@ func TestServeFile(t *testing.T) {
ServeFile(w, r, "testdata/file") ServeFile(w, r, "testdata/file")
})) }))
defer ts.Close() defer ts.Close()
c := ts.Client()
var err error var err error
...@@ -91,7 +92,7 @@ func TestServeFile(t *testing.T) { ...@@ -91,7 +92,7 @@ func TestServeFile(t *testing.T) {
req.Method = "GET" req.Method = "GET"
// straight GET // straight GET
_, body := getBody(t, "straight get", req) _, body := getBody(t, "straight get", req, c)
if !bytes.Equal(body, file) { if !bytes.Equal(body, file) {
t.Fatalf("body mismatch: got %q, want %q", body, file) t.Fatalf("body mismatch: got %q, want %q", body, file)
} }
...@@ -102,7 +103,7 @@ Cases: ...@@ -102,7 +103,7 @@ Cases:
if rt.r != "" { if rt.r != "" {
req.Header.Set("Range", rt.r) req.Header.Set("Range", rt.r)
} }
resp, body := getBody(t, fmt.Sprintf("range test %q", rt.r), req) resp, body := getBody(t, fmt.Sprintf("range test %q", rt.r), req, c)
if resp.StatusCode != rt.code { if resp.StatusCode != rt.code {
t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, resp.StatusCode, rt.code) t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, resp.StatusCode, rt.code)
} }
...@@ -704,7 +705,8 @@ func TestDirectoryIfNotModified(t *testing.T) { ...@@ -704,7 +705,8 @@ func TestDirectoryIfNotModified(t *testing.T) {
req, _ := NewRequest("GET", ts.URL, nil) req, _ := NewRequest("GET", ts.URL, nil)
req.Header.Set("If-Modified-Since", lastMod) req.Header.Set("If-Modified-Since", lastMod)
res, err = DefaultClient.Do(req) c := ts.Client()
res, err = c.Do(req)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -716,7 +718,7 @@ func TestDirectoryIfNotModified(t *testing.T) { ...@@ -716,7 +718,7 @@ func TestDirectoryIfNotModified(t *testing.T) {
// Advance the index.html file's modtime, but not the directory's. // Advance the index.html file's modtime, but not the directory's.
indexFile.modtime = indexFile.modtime.Add(1 * time.Hour) indexFile.modtime = indexFile.modtime.Add(1 * time.Hour)
res, err = DefaultClient.Do(req) res, err = c.Do(req)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -995,7 +997,9 @@ func TestServeContent(t *testing.T) { ...@@ -995,7 +997,9 @@ func TestServeContent(t *testing.T) {
for k, v := range tt.reqHeader { for k, v := range tt.reqHeader {
req.Header.Set(k, v) req.Header.Set(k, v)
} }
res, err := DefaultClient.Do(req)
c := ts.Client()
res, err := c.Do(req)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -1050,8 +1054,9 @@ func TestServeContentErrorMessages(t *testing.T) { ...@@ -1050,8 +1054,9 @@ func TestServeContentErrorMessages(t *testing.T) {
} }
ts := httptest.NewServer(FileServer(fs)) ts := httptest.NewServer(FileServer(fs))
defer ts.Close() defer ts.Close()
c := ts.Client()
for _, code := range []int{403, 404, 500} { for _, code := range []int{403, 404, 500} {
res, err := DefaultClient.Get(fmt.Sprintf("%s/%d", ts.URL, code)) res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, code))
if err != nil { if err != nil {
t.Errorf("Error fetching /%d: %v", code, err) t.Errorf("Error fetching /%d: %v", code, err)
continue continue
...@@ -1125,8 +1130,8 @@ func TestLinuxSendfile(t *testing.T) { ...@@ -1125,8 +1130,8 @@ func TestLinuxSendfile(t *testing.T) {
} }
} }
func getBody(t *testing.T, testName string, req Request) (*Response, []byte) { func getBody(t *testing.T, testName string, req Request, client *Client) (*Response, []byte) {
r, err := DefaultClient.Do(&req) r, err := client.Do(&req)
if err != nil { if err != nil {
t.Fatalf("%s: for URL %q, send error: %v", testName, req.URL.String(), err) t.Fatalf("%s: for URL %q, send error: %v", testName, req.URL.String(), err)
} }
......
...@@ -93,7 +93,9 @@ func NewUnstartedServer(handler http.Handler) *Server { ...@@ -93,7 +93,9 @@ func NewUnstartedServer(handler http.Handler) *Server {
return &Server{ return &Server{
Listener: newLocalListener(), Listener: newLocalListener(),
Config: &http.Server{Handler: handler}, Config: &http.Server{Handler: handler},
client: &http.Client{}, client: &http.Client{
Transport: &http.Transport{},
},
} }
} }
......
...@@ -121,3 +121,27 @@ func TestServerClient(t *testing.T) { ...@@ -121,3 +121,27 @@ func TestServerClient(t *testing.T) {
t.Errorf("got %q, want hello", string(got)) t.Errorf("got %q, want hello", string(got))
} }
} }
// Tests that the Server.Client.Transport interface is implemented
// by a *http.Transport.
func TestServerClientTransportType(t *testing.T) {
ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
defer ts.Close()
client := ts.Client()
if _, ok := client.Transport.(*http.Transport); !ok {
t.Errorf("got %T, want *http.Transport", client.Transport)
}
}
// Tests that the TLS Server.Client.Transport interface is implemented
// by a *http.Transport.
func TestTLSServerClientTransportType(t *testing.T) {
ts := NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
defer ts.Close()
client := ts.Client()
if _, ok := client.Transport.(*http.Transport); !ok {
t.Errorf("got %T, want *http.Transport", client.Transport)
}
}
...@@ -79,6 +79,7 @@ func TestReverseProxy(t *testing.T) { ...@@ -79,6 +79,7 @@ func TestReverseProxy(t *testing.T) {
proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
frontend := httptest.NewServer(proxyHandler) frontend := httptest.NewServer(proxyHandler)
defer frontend.Close() defer frontend.Close()
frontendClient := frontend.Client()
getReq, _ := http.NewRequest("GET", frontend.URL, nil) getReq, _ := http.NewRequest("GET", frontend.URL, nil)
getReq.Host = "some-name" getReq.Host = "some-name"
...@@ -86,7 +87,7 @@ func TestReverseProxy(t *testing.T) { ...@@ -86,7 +87,7 @@ func TestReverseProxy(t *testing.T) {
getReq.Header.Set("Proxy-Connection", "should be deleted") getReq.Header.Set("Proxy-Connection", "should be deleted")
getReq.Header.Set("Upgrade", "foo") getReq.Header.Set("Upgrade", "foo")
getReq.Close = true getReq.Close = true
res, err := http.DefaultClient.Do(getReq) res, err := frontendClient.Do(getReq)
if err != nil { if err != nil {
t.Fatalf("Get: %v", err) t.Fatalf("Get: %v", err)
} }
...@@ -126,7 +127,7 @@ func TestReverseProxy(t *testing.T) { ...@@ -126,7 +127,7 @@ func TestReverseProxy(t *testing.T) {
// a response results in a StatusBadGateway. // a response results in a StatusBadGateway.
getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil) getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil)
getReq.Close = true getReq.Close = true
res, err = http.DefaultClient.Do(getReq) res, err = frontendClient.Do(getReq)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -172,7 +173,7 @@ func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { ...@@ -172,7 +173,7 @@ func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
getReq.Header.Set("Connection", "Upgrade, "+fakeConnectionToken) getReq.Header.Set("Connection", "Upgrade, "+fakeConnectionToken)
getReq.Header.Set("Upgrade", "original value") getReq.Header.Set("Upgrade", "original value")
getReq.Header.Set(fakeConnectionToken, "should be deleted") getReq.Header.Set(fakeConnectionToken, "should be deleted")
res, err := http.DefaultClient.Do(getReq) res, err := frontend.Client().Do(getReq)
if err != nil { if err != nil {
t.Fatalf("Get: %v", err) t.Fatalf("Get: %v", err)
} }
...@@ -220,7 +221,7 @@ func TestXForwardedFor(t *testing.T) { ...@@ -220,7 +221,7 @@ func TestXForwardedFor(t *testing.T) {
getReq.Header.Set("Connection", "close") getReq.Header.Set("Connection", "close")
getReq.Header.Set("X-Forwarded-For", prevForwardedFor) getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
getReq.Close = true getReq.Close = true
res, err := http.DefaultClient.Do(getReq) res, err := frontend.Client().Do(getReq)
if err != nil { if err != nil {
t.Fatalf("Get: %v", err) t.Fatalf("Get: %v", err)
} }
...@@ -259,7 +260,7 @@ func TestReverseProxyQuery(t *testing.T) { ...@@ -259,7 +260,7 @@ func TestReverseProxyQuery(t *testing.T) {
frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL)) frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil) req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
req.Close = true req.Close = true
res, err := http.DefaultClient.Do(req) res, err := frontend.Client().Do(req)
if err != nil { if err != nil {
t.Fatalf("%d. Get: %v", i, err) t.Fatalf("%d. Get: %v", i, err)
} }
...@@ -295,7 +296,7 @@ func TestReverseProxyFlushInterval(t *testing.T) { ...@@ -295,7 +296,7 @@ func TestReverseProxyFlushInterval(t *testing.T) {
req, _ := http.NewRequest("GET", frontend.URL, nil) req, _ := http.NewRequest("GET", frontend.URL, nil)
req.Close = true req.Close = true
res, err := http.DefaultClient.Do(req) res, err := frontend.Client().Do(req)
if err != nil { if err != nil {
t.Fatalf("Get: %v", err) t.Fatalf("Get: %v", err)
} }
...@@ -349,13 +350,14 @@ func TestReverseProxyCancelation(t *testing.T) { ...@@ -349,13 +350,14 @@ func TestReverseProxyCancelation(t *testing.T) {
frontend := httptest.NewServer(proxyHandler) frontend := httptest.NewServer(proxyHandler)
defer frontend.Close() defer frontend.Close()
frontendClient := frontend.Client()
getReq, _ := http.NewRequest("GET", frontend.URL, nil) getReq, _ := http.NewRequest("GET", frontend.URL, nil)
go func() { go func() {
<-reqInFlight <-reqInFlight
http.DefaultTransport.(*http.Transport).CancelRequest(getReq) frontendClient.Transport.(*http.Transport).CancelRequest(getReq)
}() }()
res, err := http.DefaultClient.Do(getReq) res, err := frontendClient.Do(getReq)
if res != nil { if res != nil {
t.Errorf("got response %v; want nil", res.Status) t.Errorf("got response %v; want nil", res.Status)
} }
...@@ -363,7 +365,7 @@ func TestReverseProxyCancelation(t *testing.T) { ...@@ -363,7 +365,7 @@ func TestReverseProxyCancelation(t *testing.T) {
// This should be an error like: // This should be an error like:
// Get http://127.0.0.1:58079: read tcp 127.0.0.1:58079: // Get http://127.0.0.1:58079: read tcp 127.0.0.1:58079:
// use of closed network connection // use of closed network connection
t.Error("DefaultClient.Do() returned nil error; want non-nil error") t.Error("Server.Client().Do() returned nil error; want non-nil error")
} }
} }
...@@ -428,11 +430,12 @@ func TestUserAgentHeader(t *testing.T) { ...@@ -428,11 +430,12 @@ func TestUserAgentHeader(t *testing.T) {
proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
frontend := httptest.NewServer(proxyHandler) frontend := httptest.NewServer(proxyHandler)
defer frontend.Close() defer frontend.Close()
frontendClient := frontend.Client()
getReq, _ := http.NewRequest("GET", frontend.URL, nil) getReq, _ := http.NewRequest("GET", frontend.URL, nil)
getReq.Header.Set("User-Agent", explicitUA) getReq.Header.Set("User-Agent", explicitUA)
getReq.Close = true getReq.Close = true
res, err := http.DefaultClient.Do(getReq) res, err := frontendClient.Do(getReq)
if err != nil { if err != nil {
t.Fatalf("Get: %v", err) t.Fatalf("Get: %v", err)
} }
...@@ -441,7 +444,7 @@ func TestUserAgentHeader(t *testing.T) { ...@@ -441,7 +444,7 @@ func TestUserAgentHeader(t *testing.T) {
getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil) getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil)
getReq.Header.Set("User-Agent", "") getReq.Header.Set("User-Agent", "")
getReq.Close = true getReq.Close = true
res, err = http.DefaultClient.Do(getReq) res, err = frontendClient.Do(getReq)
if err != nil { if err != nil {
t.Fatalf("Get: %v", err) t.Fatalf("Get: %v", err)
} }
...@@ -493,7 +496,7 @@ func TestReverseProxyGetPutBuffer(t *testing.T) { ...@@ -493,7 +496,7 @@ func TestReverseProxyGetPutBuffer(t *testing.T) {
req, _ := http.NewRequest("GET", frontend.URL, nil) req, _ := http.NewRequest("GET", frontend.URL, nil)
req.Close = true req.Close = true
res, err := http.DefaultClient.Do(req) res, err := frontend.Client().Do(req)
if err != nil { if err != nil {
t.Fatalf("Get: %v", err) t.Fatalf("Get: %v", err)
} }
...@@ -540,7 +543,7 @@ func TestReverseProxy_Post(t *testing.T) { ...@@ -540,7 +543,7 @@ func TestReverseProxy_Post(t *testing.T) {
defer frontend.Close() defer frontend.Close()
postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody)) postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody))
res, err := http.DefaultClient.Do(postReq) res, err := frontend.Client().Do(postReq)
if err != nil { if err != nil {
t.Fatalf("Do: %v", err) t.Fatalf("Do: %v", err)
} }
...@@ -573,7 +576,7 @@ func TestReverseProxy_NilBody(t *testing.T) { ...@@ -573,7 +576,7 @@ func TestReverseProxy_NilBody(t *testing.T) {
frontend := httptest.NewServer(proxyHandler) frontend := httptest.NewServer(proxyHandler)
defer frontend.Close() defer frontend.Close()
res, err := http.DefaultClient.Get(frontend.URL) res, err := frontend.Client().Get(frontend.URL)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
......
...@@ -151,7 +151,3 @@ func waitErrCondition(waitFor, checkEvery time.Duration, fn func() error) error ...@@ -151,7 +151,3 @@ func waitErrCondition(waitFor, checkEvery time.Duration, fn func() error) error
} }
return err return err
} }
func closeClient(c *http.Client) {
c.Transport.(*http.Transport).CloseIdleConnections()
}
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"crypto/tls" "crypto/tls"
"crypto/x509"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
...@@ -43,10 +44,7 @@ func TestNextProtoUpgrade(t *testing.T) { ...@@ -43,10 +44,7 @@ func TestNextProtoUpgrade(t *testing.T) {
// Normal request, without NPN. // Normal request, without NPN.
{ {
tr := newTLSTransport(t, ts) c := ts.Client()
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
res, err := c.Get(ts.URL) res, err := c.Get(ts.URL)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -63,11 +61,18 @@ func TestNextProtoUpgrade(t *testing.T) { ...@@ -63,11 +61,18 @@ func TestNextProtoUpgrade(t *testing.T) {
// Request to an advertised but unhandled NPN protocol. // Request to an advertised but unhandled NPN protocol.
// Server will hang up. // Server will hang up.
{ {
tr := newTLSTransport(t, ts) certPool := x509.NewCertPool()
tr.TLSClientConfig.NextProtos = []string{"unhandled-proto"} certPool.AddCert(ts.Certificate())
tr := &Transport{
TLSClientConfig: &tls.Config{
RootCAs: certPool,
NextProtos: []string{"unhandled-proto"},
},
}
defer tr.CloseIdleConnections() defer tr.CloseIdleConnections()
c := &Client{Transport: tr} c := &Client{
Transport: tr,
}
res, err := c.Get(ts.URL) res, err := c.Get(ts.URL)
if err == nil { if err == nil {
defer res.Body.Close() defer res.Body.Close()
...@@ -80,7 +85,8 @@ func TestNextProtoUpgrade(t *testing.T) { ...@@ -80,7 +85,8 @@ func TestNextProtoUpgrade(t *testing.T) {
// Request using the "tls-0.9" protocol, which we register here. // Request using the "tls-0.9" protocol, which we register here.
// It is HTTP/0.9 over TLS. // It is HTTP/0.9 over TLS.
{ {
tlsConfig := newTLSTransport(t, ts).TLSClientConfig c := ts.Client()
tlsConfig := c.Transport.(*Transport).TLSClientConfig
tlsConfig.NextProtos = []string{"tls-0.9"} tlsConfig.NextProtos = []string{"tls-0.9"}
conn, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig) conn, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig)
if err != nil { if err != nil {
......
...@@ -474,9 +474,7 @@ func TestServerTimeouts(t *testing.T) { ...@@ -474,9 +474,7 @@ func TestServerTimeouts(t *testing.T) {
defer ts.Close() defer ts.Close()
// Hit the HTTP server successfully. // Hit the HTTP server successfully.
tr := &Transport{DisableKeepAlives: true} // they interfere with this test c := ts.Client()
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
r, err := c.Get(ts.URL) r, err := c.Get(ts.URL)
if err != nil { if err != nil {
t.Fatalf("http Get #1: %v", err) t.Fatalf("http Get #1: %v", err)
...@@ -548,12 +546,10 @@ func TestHTTP2WriteDeadlineExtendedOnNewRequest(t *testing.T) { ...@@ -548,12 +546,10 @@ func TestHTTP2WriteDeadlineExtendedOnNewRequest(t *testing.T) {
ts.StartTLS() ts.StartTLS()
defer ts.Close() defer ts.Close()
tr := newTLSTransport(t, ts) c := ts.Client()
defer tr.CloseIdleConnections() if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil {
if err := ExportHttp2ConfigureTransport(tr); err != nil {
t.Fatal(err) t.Fatal(err)
} }
c := &Client{Transport: tr}
for i := 1; i <= 3; i++ { for i := 1; i <= 3; i++ {
req, err := NewRequest("GET", ts.URL, nil) req, err := NewRequest("GET", ts.URL, nil)
...@@ -608,9 +604,7 @@ func TestOnlyWriteTimeout(t *testing.T) { ...@@ -608,9 +604,7 @@ func TestOnlyWriteTimeout(t *testing.T) {
ts.Start() ts.Start()
defer ts.Close() defer ts.Close()
tr := &Transport{DisableKeepAlives: false} c := ts.Client()
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
errc := make(chan error) errc := make(chan error)
go func() { go func() {
...@@ -671,8 +665,7 @@ func TestIdentityResponse(t *testing.T) { ...@@ -671,8 +665,7 @@ func TestIdentityResponse(t *testing.T) {
ts := httptest.NewServer(handler) ts := httptest.NewServer(handler)
defer ts.Close() defer ts.Close()
c := &Client{Transport: new(Transport)} c := ts.Client()
defer closeClient(c)
// Note: this relies on the assumption (which is true) that // Note: this relies on the assumption (which is true) that
// Get sends HTTP/1.1 or greater requests. Otherwise the // Get sends HTTP/1.1 or greater requests. Otherwise the
...@@ -949,9 +942,8 @@ func TestServerAllowsBlockingRemoteAddr(t *testing.T) { ...@@ -949,9 +942,8 @@ func TestServerAllowsBlockingRemoteAddr(t *testing.T) {
ts.Start() ts.Start()
defer ts.Close() defer ts.Close()
tr := &Transport{DisableKeepAlives: true} c := ts.Client()
defer tr.CloseIdleConnections() c.Timeout = time.Second
c := &Client{Transport: tr, Timeout: time.Second}
fetch := func(num int, response chan<- string) { fetch := func(num int, response chan<- string) {
resp, err := c.Get(ts.URL) resp, err := c.Get(ts.URL)
...@@ -1022,9 +1014,7 @@ func TestIdentityResponseHeaders(t *testing.T) { ...@@ -1022,9 +1014,7 @@ func TestIdentityResponseHeaders(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
c := &Client{Transport: new(Transport)} c := ts.Client()
defer closeClient(c)
res, err := c.Get(ts.URL) res, err := c.Get(ts.URL)
if err != nil { if err != nil {
t.Fatalf("Get error: %v", err) t.Fatalf("Get error: %v", err)
...@@ -1145,12 +1135,7 @@ func TestTLSServer(t *testing.T) { ...@@ -1145,12 +1135,7 @@ func TestTLSServer(t *testing.T) {
t.Errorf("expected test TLS server to start with https://, got %q", ts.URL) t.Errorf("expected test TLS server to start with https://, got %q", ts.URL)
return return
} }
noVerifyTransport := &Transport{ client := ts.Client()
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
client := &Client{Transport: noVerifyTransport}
res, err := client.Get(ts.URL) res, err := client.Get(ts.URL)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
...@@ -1967,8 +1952,7 @@ func TestTimeoutHandlerRace(t *testing.T) { ...@@ -1967,8 +1952,7 @@ func TestTimeoutHandlerRace(t *testing.T) {
ts := httptest.NewServer(TimeoutHandler(delayHi, 20*time.Millisecond, "")) ts := httptest.NewServer(TimeoutHandler(delayHi, 20*time.Millisecond, ""))
defer ts.Close() defer ts.Close()
c := &Client{Transport: new(Transport)} c := ts.Client()
defer closeClient(c)
var wg sync.WaitGroup var wg sync.WaitGroup
gate := make(chan bool, 10) gate := make(chan bool, 10)
...@@ -2011,8 +1995,8 @@ func TestTimeoutHandlerRaceHeader(t *testing.T) { ...@@ -2011,8 +1995,8 @@ func TestTimeoutHandlerRaceHeader(t *testing.T) {
if testing.Short() { if testing.Short() {
n = 10 n = 10
} }
c := &Client{Transport: new(Transport)}
defer closeClient(c) c := ts.Client()
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
gate <- true gate <- true
wg.Add(1) wg.Add(1)
...@@ -2099,8 +2083,7 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) { ...@@ -2099,8 +2083,7 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
ts := httptest.NewServer(TimeoutHandler(handler, timeout, "")) ts := httptest.NewServer(TimeoutHandler(handler, timeout, ""))
defer ts.Close() defer ts.Close()
c := &Client{Transport: new(Transport)} c := ts.Client()
defer closeClient(c)
// Issue was caused by the timeout handler starting the timer when // Issue was caused by the timeout handler starting the timer when
// was created, not when the request. So wait for more than the timeout // was created, not when the request. So wait for more than the timeout
...@@ -2127,8 +2110,7 @@ func TestTimeoutHandlerEmptyResponse(t *testing.T) { ...@@ -2127,8 +2110,7 @@ func TestTimeoutHandlerEmptyResponse(t *testing.T) {
ts := httptest.NewServer(TimeoutHandler(handler, timeout, "")) ts := httptest.NewServer(TimeoutHandler(handler, timeout, ""))
defer ts.Close() defer ts.Close()
c := &Client{Transport: new(Transport)} c := ts.Client()
defer closeClient(c)
res, err := c.Get(ts.URL) res, err := c.Get(ts.URL)
if err != nil { if err != nil {
...@@ -2364,9 +2346,7 @@ func TestServerWriteHijackZeroBytes(t *testing.T) { ...@@ -2364,9 +2346,7 @@ func TestServerWriteHijackZeroBytes(t *testing.T) {
ts.Start() ts.Start()
defer ts.Close() defer ts.Close()
tr := &Transport{} c := ts.Client()
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
res, err := c.Get(ts.URL) res, err := c.Get(ts.URL)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -2411,8 +2391,7 @@ func TestStripPrefix(t *testing.T) { ...@@ -2411,8 +2391,7 @@ func TestStripPrefix(t *testing.T) {
ts := httptest.NewServer(StripPrefix("/foo", h)) ts := httptest.NewServer(StripPrefix("/foo", h))
defer ts.Close() defer ts.Close()
c := &Client{Transport: new(Transport)} c := ts.Client()
defer closeClient(c)
res, err := c.Get(ts.URL + "/foo/bar") res, err := c.Get(ts.URL + "/foo/bar")
if err != nil { if err != nil {
...@@ -3654,9 +3633,7 @@ func TestServerConnState(t *testing.T) { ...@@ -3654,9 +3633,7 @@ func TestServerConnState(t *testing.T) {
} }
ts.Start() ts.Start()
tr := &Transport{} c := ts.Client()
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
mustGet := func(url string, headers ...string) { mustGet := func(url string, headers ...string) {
req, err := NewRequest("GET", url, nil) req, err := NewRequest("GET", url, nil)
...@@ -4491,15 +4468,9 @@ func benchmarkClientServerParallel(b *testing.B, parallelism int, useTLS bool) { ...@@ -4491,15 +4468,9 @@ func benchmarkClientServerParallel(b *testing.B, parallelism int, useTLS bool) {
b.ResetTimer() b.ResetTimer()
b.SetParallelism(parallelism) b.SetParallelism(parallelism)
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
noVerifyTransport := &Transport{ c := ts.Client()
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
defer noVerifyTransport.CloseIdleConnections()
client := &Client{Transport: noVerifyTransport}
for pb.Next() { for pb.Next() {
res, err := client.Get(ts.URL) res, err := c.Get(ts.URL)
if err != nil { if err != nil {
b.Logf("Get: %v", err) b.Logf("Get: %v", err)
continue continue
...@@ -4934,10 +4905,7 @@ func TestServerIdleTimeout(t *testing.T) { ...@@ -4934,10 +4905,7 @@ func TestServerIdleTimeout(t *testing.T) {
ts.Config.IdleTimeout = 2 * time.Second ts.Config.IdleTimeout = 2 * time.Second
ts.Start() ts.Start()
defer ts.Close() defer ts.Close()
c := ts.Client()
tr := &Transport{}
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
get := func() string { get := func() string {
res, err := c.Get(ts.URL) res, err := c.Get(ts.URL)
...@@ -4998,9 +4966,8 @@ func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) { ...@@ -4998,9 +4966,8 @@ func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
tr := &Transport{} c := ts.Client()
defer tr.CloseIdleConnections() tr := c.Transport.(*Transport)
c := &Client{Transport: tr}
get := func() string { return get(t, c, ts.URL) } get := func() string { return get(t, c, ts.URL) }
...@@ -5119,9 +5086,7 @@ func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { ...@@ -5119,9 +5086,7 @@ func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) {
ts.Start() ts.Start()
defer ts.Close() defer ts.Close()
tr := &Transport{} c := ts.Client()
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
res, err := c.Get(ts.URL) res, err := c.Get(ts.URL)
if err != nil { if err != nil {
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment