Commit 75944861 authored by Mikio Hara's avatar Mikio Hara

internal/socks: add DialWithConn method to Dialer

This change adds DialWithConn method for allowing package users to use
own net.Conn implementations optionally.

Also makes the deprecated Dialer.Dial return a raw transport connection
instead of a forward proxy connection for preserving the backward
compatibility on proxy.Dialer.Dial method.

Fixes golang/go#25104.

Change-Id: I4259cd10e299c1e36406545708e9f6888191705a
Reviewed-on: https://go-review.googlesource.com/110135
Run-TryBot: Mikio Hara <mikioh.mikioh@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: 's avatarBrad Fitzpatrick <bradfitz@golang.org>
parent 4ef37e81
...@@ -17,19 +17,11 @@ import ( ...@@ -17,19 +17,11 @@ import (
"golang.org/x/net/internal/sockstest" "golang.org/x/net/internal/sockstest"
) )
const (
targetNetwork = "tcp6"
targetHostname = "fqdn.doesnotexist"
targetHostIP = "2001:db8::1"
targetPort = "5963"
)
func TestDial(t *testing.T) { func TestDial(t *testing.T) {
t.Run("Connect", func(t *testing.T) { t.Run("Connect", func(t *testing.T) {
ss, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired) ss, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
if err != nil { if err != nil {
t.Error(err) t.Fatal(err)
return
} }
defer ss.Close() defer ss.Close()
d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String()) d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
...@@ -41,21 +33,45 @@ func TestDial(t *testing.T) { ...@@ -41,21 +33,45 @@ func TestDial(t *testing.T) {
Username: "username", Username: "username",
Password: "password", Password: "password",
}).Authenticate }).Authenticate
c, err := d.Dial(targetNetwork, net.JoinHostPort(targetHostIP, targetPort)) c, err := d.DialContext(context.Background(), ss.TargetAddr().Network(), ss.TargetAddr().String())
if err == nil { if err != nil {
c.(*socks.Conn).BoundAddr() t.Fatal(err)
c.Close() }
c.(*socks.Conn).BoundAddr()
c.Close()
})
t.Run("ConnectWithConn", func(t *testing.T) {
ss, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
if err != nil {
t.Fatal(err)
}
defer ss.Close()
c, err := net.Dial(ss.Addr().Network(), ss.Addr().String())
if err != nil {
t.Fatal(err)
}
defer c.Close()
d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
d.AuthMethods = []socks.AuthMethod{
socks.AuthMethodNotRequired,
socks.AuthMethodUsernamePassword,
} }
d.Authenticate = (&socks.UsernamePassword{
Username: "username",
Password: "password",
}).Authenticate
a, err := d.DialWithConn(context.Background(), c, ss.TargetAddr().Network(), ss.TargetAddr().String())
if err != nil { if err != nil {
t.Error(err) t.Fatal(err)
return }
if _, ok := a.(*socks.Addr); !ok {
t.Fatalf("got %+v; want socks.Addr", a)
} }
}) })
t.Run("Cancel", func(t *testing.T) { t.Run("Cancel", func(t *testing.T) {
ss, err := sockstest.NewServer(sockstest.NoAuthRequired, blackholeCmdFunc) ss, err := sockstest.NewServer(sockstest.NoAuthRequired, blackholeCmdFunc)
if err != nil { if err != nil {
t.Error(err) t.Fatal(err)
return
} }
defer ss.Close() defer ss.Close()
d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String()) d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
...@@ -63,7 +79,7 @@ func TestDial(t *testing.T) { ...@@ -63,7 +79,7 @@ func TestDial(t *testing.T) {
defer cancel() defer cancel()
dialErr := make(chan error) dialErr := make(chan error)
go func() { go func() {
c, err := d.DialContext(ctx, ss.TargetAddr().Network(), net.JoinHostPort(targetHostname, targetPort)) c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
if err == nil { if err == nil {
c.Close() c.Close()
} }
...@@ -73,41 +89,37 @@ func TestDial(t *testing.T) { ...@@ -73,41 +89,37 @@ func TestDial(t *testing.T) {
cancel() cancel()
err = <-dialErr err = <-dialErr
if perr, nerr := parseDialError(err); perr != context.Canceled && nerr == nil { if perr, nerr := parseDialError(err); perr != context.Canceled && nerr == nil {
t.Errorf("got %v; want context.Canceled or equivalent", err) t.Fatalf("got %v; want context.Canceled or equivalent", err)
return
} }
}) })
t.Run("Deadline", func(t *testing.T) { t.Run("Deadline", func(t *testing.T) {
ss, err := sockstest.NewServer(sockstest.NoAuthRequired, blackholeCmdFunc) ss, err := sockstest.NewServer(sockstest.NoAuthRequired, blackholeCmdFunc)
if err != nil { if err != nil {
t.Error(err) t.Fatal(err)
return
} }
defer ss.Close() defer ss.Close()
d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String()) d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond)) ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
defer cancel() defer cancel()
c, err := d.DialContext(ctx, ss.TargetAddr().Network(), net.JoinHostPort(targetHostname, targetPort)) c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
if err == nil { if err == nil {
c.Close() c.Close()
} }
if perr, nerr := parseDialError(err); perr != context.DeadlineExceeded && nerr == nil { if perr, nerr := parseDialError(err); perr != context.DeadlineExceeded && nerr == nil {
t.Errorf("got %v; want context.DeadlineExceeded or equivalent", err) t.Fatalf("got %v; want context.DeadlineExceeded or equivalent", err)
return
} }
}) })
t.Run("WithRogueServer", func(t *testing.T) { t.Run("WithRogueServer", func(t *testing.T) {
ss, err := sockstest.NewServer(sockstest.NoAuthRequired, rogueCmdFunc) ss, err := sockstest.NewServer(sockstest.NoAuthRequired, rogueCmdFunc)
if err != nil { if err != nil {
t.Error(err) t.Fatal(err)
return
} }
defer ss.Close() defer ss.Close()
d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String()) d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
for i := 0; i < 2*len(rogueCmdList); i++ { for i := 0; i < 2*len(rogueCmdList); i++ {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond)) ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
defer cancel() defer cancel()
c, err := d.DialContext(ctx, targetNetwork, net.JoinHostPort(targetHostIP, targetPort)) c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
if err == nil { if err == nil {
t.Log(c.(*socks.Conn).BoundAddr()) t.Log(c.(*socks.Conn).BoundAddr())
c.Close() c.Close()
......
...@@ -149,20 +149,13 @@ type Dialer struct { ...@@ -149,20 +149,13 @@ type Dialer struct {
// See func Dial of the net package of standard library for a // See func Dial of the net package of standard library for a
// description of the network and address parameters. // description of the network and address parameters.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
switch network { if err := d.validateTarget(network, address); err != nil {
case "tcp", "tcp6", "tcp4":
default:
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("network not implemented")}
}
switch d.cmd {
case CmdConnect, cmdBind:
default:
proxy, dst, _ := d.pathAddrs(address) proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("command not implemented")} return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
} }
if ctx == nil { if ctx == nil {
ctx = context.Background() proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")}
} }
var err error var err error
var c net.Conn var c net.Conn
...@@ -185,11 +178,69 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. ...@@ -185,11 +178,69 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
return &Conn{Conn: c, boundAddr: a}, nil return &Conn{Conn: c, boundAddr: a}, nil
} }
// DialWithConn initiates a connection from SOCKS server to the target
// network and address using the connection c that is already
// connected to the SOCKS server.
//
// It returns the connection's local address assigned by the SOCKS
// server.
func (d *Dialer) DialWithConn(ctx context.Context, c net.Conn, network, address string) (net.Addr, error) {
if err := d.validateTarget(network, address); err != nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
if ctx == nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")}
}
a, err := d.connect(ctx, c, address)
if err != nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
return a, nil
}
// Dial connects to the provided address on the provided network. // Dial connects to the provided address on the provided network.
// //
// Deprecated: Use DialContext instead. // Unlike DialContext, it returns a raw transport connection instead
// of a forward proxy connection.
//
// Deprecated: Use DialContext or DialWithConn instead.
func (d *Dialer) Dial(network, address string) (net.Conn, error) { func (d *Dialer) Dial(network, address string) (net.Conn, error) {
return d.DialContext(context.Background(), network, address) if err := d.validateTarget(network, address); err != nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
var err error
var c net.Conn
if d.ProxyDial != nil {
c, err = d.ProxyDial(context.Background(), d.proxyNetwork, d.proxyAddress)
} else {
c, err = net.Dial(d.proxyNetwork, d.proxyAddress)
}
if err != nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
if _, err := d.DialWithConn(context.Background(), c, network, address); err != nil {
return nil, err
}
return c, nil
}
func (d *Dialer) validateTarget(network, address string) error {
switch network {
case "tcp", "tcp6", "tcp4":
default:
return errors.New("network not implemented")
}
switch d.cmd {
case CmdConnect, cmdBind:
default:
return errors.New("command not implemented")
}
return nil
} }
func (d *Dialer) pathAddrs(address string) (proxy, dst net.Addr, err error) { func (d *Dialer) pathAddrs(address string) (proxy, dst net.Addr, err error) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment