Commit 819f4c53 authored by Cezar Sa Espinola's avatar Cezar Sa Espinola Committed by Russ Cox

websocket: use net.Dialer to open tcp connection

This change adds a Dialer field to websocket.Config struct. If this
value is set the Dialer will be used. If it's nil, DialConfig will
create an empty Dialer that will maintain the original behavior.

Because before Go 1.3 there was no crypto/tls.DialWithDialer function,
the Dialer will be ignored when opening TLS connections in these
versions.

Fixes golang/go#9198.

Change-Id: If8b5c3c47019a3d367c436e3e60eb54bf0276184
Reviewed-on: https://go-review.googlesource.com/12191Reviewed-by: 's avatarRuss Cox <rsc@golang.org>
parent cd95c68b
......@@ -6,7 +6,6 @@ package websocket
import (
"bufio"
"crypto/tls"
"io"
"net"
"net/http"
......@@ -87,20 +86,14 @@ func DialConfig(config *Config) (ws *Conn, err error) {
if config.Origin == nil {
return nil, &DialError{config, ErrBadWebSocketOrigin}
}
switch config.Location.Scheme {
case "ws":
client, err = net.Dial("tcp", parseAuthority(config.Location))
case "wss":
client, err = tls.Dial("tcp", parseAuthority(config.Location), config.TlsConfig)
default:
err = ErrBadScheme
dialer := config.Dialer
if dialer == nil {
dialer = &net.Dialer{}
}
client, err = dialWithDialer(dialer, config)
if err != nil {
goto Error
}
ws, err = NewClient(config, client)
if err != nil {
client.Close()
......
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.3
package websocket
import (
"crypto/tls"
"net"
)
func dialWithDialer(dialer *net.Dialer, config *Config) (conn net.Conn, err error) {
switch config.Location.Scheme {
case "ws":
conn, err = dialer.Dial("tcp", parseAuthority(config.Location))
case "wss":
conn, err = tls.Dial("tcp", parseAuthority(config.Location), config.TlsConfig)
default:
err = ErrBadScheme
}
return
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.3
// We only compile this with Go 1.3+ because previously tls.DialWithDialer
// wasn't available. The dial.go file is used for previous Go versions.
package websocket
import (
"crypto/tls"
"net"
)
func dialWithDialer(dialer *net.Dialer, config *Config) (conn net.Conn, err error) {
switch config.Location.Scheme {
case "ws":
conn, err = dialer.Dial("tcp", parseAuthority(config.Location))
case "wss":
conn, err = tls.DialWithDialer(dialer, "tcp", parseAuthority(config.Location), config.TlsConfig)
default:
err = ErrBadScheme
}
return
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.3
package websocket
import (
"crypto/tls"
"fmt"
"log"
"net"
"net/http/httptest"
"testing"
"time"
)
// This test depend on Go 1.3+ because in earlier versions the Dialer won't be
// used in TLS connections and a timeout won't be triggered.
func TestDialConfigTLSWithDialer(t *testing.T) {
tlsServer := httptest.NewTLSServer(nil)
tlsServerAddr := tlsServer.Listener.Addr().String()
log.Print("Test TLS WebSocket server listening on ", tlsServerAddr)
defer tlsServer.Close()
config, _ := NewConfig(fmt.Sprintf("wss://%s/echo", tlsServerAddr), "http://localhost")
config.Dialer = &net.Dialer{
Deadline: time.Now().Add(-time.Minute),
}
config.TlsConfig = &tls.Config{
InsecureSkipVerify: true,
}
_, err := DialConfig(config)
dialerr, ok := err.(*DialError)
if !ok {
t.Fatalf("DialError expected, got %#v", err)
}
neterr, ok := dialerr.Err.(*net.OpError)
if !ok {
t.Fatalf("net.OpError error expected, got %#v", dialerr.Err)
}
if !neterr.Timeout() {
t.Fatalf("expected timeout error, got %#v", neterr)
}
}
......@@ -86,6 +86,9 @@ type Config struct {
// Additional header fields to be sent in WebSocket opening handshake.
Header http.Header
// Dialer used when opening websocket connections.
Dialer *net.Dialer
handshakeData map[string]string
}
......
......@@ -357,6 +357,26 @@ func TestDialConfigBadVersion(t *testing.T) {
}
}
func TestDialConfigWithDialer(t *testing.T) {
once.Do(startServer)
config := newConfig(t, "/echo")
config.Dialer = &net.Dialer{
Deadline: time.Now().Add(-time.Minute),
}
_, err := DialConfig(config)
dialerr, ok := err.(*DialError)
if !ok {
t.Fatalf("DialError expected, got %#v", err)
}
neterr, ok := dialerr.Err.(*net.OpError)
if !ok {
t.Fatalf("net.OpError error expected, got %#v", dialerr.Err)
}
if !neterr.Timeout() {
t.Fatalf("expected timeout error, got %#v", neterr)
}
}
func TestSmallBuffer(t *testing.T) {
// http://code.google.com/p/go/issues/detail?id=1145
// Read should be able to handle reading a fragment of a frame.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment