Commit be0f3c28 authored by Filippo Valsorda's avatar Filippo Valsorda

crypto/tls: replace net.Pipe in tests with real TCP connections

crypto/tls is meant to work over network connections with buffering, not
synchronous connections, as explained in #24198. Tests based on net.Pipe
are unrealistic as reads and writes are matched one to one. Such tests
worked just thanks to the implementation details of the tls.Conn
internal buffering, and would break if for example the flush of the
first flight of the server was not entirely assimilated by the client
rawInput buffer before the client attempted to reply to the ServerHello.

Note that this might run into the Darwin network issues at #25696.

Fixed a few test races that were either hidden or synchronized by the
use of the in-memory net.Pipe.

Also, this gets us slightly more realistic benchmarks, reflecting some
syscall cost of Read and Write operations.

Change-Id: I5a597b3d7a81b8ccc776030cc837133412bf50f8
Reviewed-on: https://go-review.googlesource.com/c/142817
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: 's avatarBrad Fitzpatrick <bradfitz@golang.org>
parent 628403fd
...@@ -134,12 +134,13 @@ func TestCertificateSelection(t *testing.T) { ...@@ -134,12 +134,13 @@ func TestCertificateSelection(t *testing.T) {
// Run with multiple crypto configs to test the logic for computing TLS record overheads. // Run with multiple crypto configs to test the logic for computing TLS record overheads.
func runDynamicRecordSizingTest(t *testing.T, config *Config) { func runDynamicRecordSizingTest(t *testing.T, config *Config) {
clientConn, serverConn := net.Pipe() clientConn, serverConn := localPipe(t)
serverConfig := config.Clone() serverConfig := config.Clone()
serverConfig.DynamicRecordSizingDisabled = false serverConfig.DynamicRecordSizingDisabled = false
tlsConn := Server(serverConn, serverConfig) tlsConn := Server(serverConn, serverConfig)
handshakeDone := make(chan struct{})
recordSizesChan := make(chan []int, 1) recordSizesChan := make(chan []int, 1)
go func() { go func() {
// This goroutine performs a TLS handshake over clientConn and // This goroutine performs a TLS handshake over clientConn and
...@@ -153,6 +154,7 @@ func runDynamicRecordSizingTest(t *testing.T, config *Config) { ...@@ -153,6 +154,7 @@ func runDynamicRecordSizingTest(t *testing.T, config *Config) {
t.Errorf("Error from client handshake: %v", err) t.Errorf("Error from client handshake: %v", err)
return return
} }
close(handshakeDone)
var recordHeader [recordHeaderLen]byte var recordHeader [recordHeaderLen]byte
var record []byte var record []byte
...@@ -192,6 +194,7 @@ func runDynamicRecordSizingTest(t *testing.T, config *Config) { ...@@ -192,6 +194,7 @@ func runDynamicRecordSizingTest(t *testing.T, config *Config) {
if err := tlsConn.Handshake(); err != nil { if err := tlsConn.Handshake(); err != nil {
t.Fatalf("Error from server handshake: %s", err) t.Fatalf("Error from server handshake: %s", err)
} }
<-handshakeDone
// The server writes these plaintexts in order. // The server writes these plaintexts in order.
plaintext := bytes.Join([][]byte{ plaintext := bytes.Join([][]byte{
...@@ -269,7 +272,7 @@ func (conn *hairpinConn) Close() error { ...@@ -269,7 +272,7 @@ func (conn *hairpinConn) Close() error {
func TestHairpinInClose(t *testing.T) { func TestHairpinInClose(t *testing.T) {
// This tests that the underlying net.Conn can call back into the // This tests that the underlying net.Conn can call back into the
// tls.Conn when being closed without deadlocking. // tls.Conn when being closed without deadlocking.
client, server := net.Pipe() client, server := localPipe(t)
defer server.Close() defer server.Close()
defer client.Close() defer client.Close()
......
...@@ -179,7 +179,7 @@ func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, ...@@ -179,7 +179,7 @@ func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd,
var pemOut bytes.Buffer var pemOut bytes.Buffer
pem.Encode(&pemOut, &pem.Block{Type: pemType + " PRIVATE KEY", Bytes: derBytes}) pem.Encode(&pemOut, &pem.Block{Type: pemType + " PRIVATE KEY", Bytes: derBytes})
keyPath := tempFile(string(pemOut.Bytes())) keyPath := tempFile(pemOut.String())
defer os.Remove(keyPath) defer os.Remove(keyPath)
var command []string var command []string
...@@ -293,7 +293,7 @@ func (test *clientTest) run(t *testing.T, write bool) { ...@@ -293,7 +293,7 @@ func (test *clientTest) run(t *testing.T, write bool) {
} }
clientConn = recordingConn clientConn = recordingConn
} else { } else {
clientConn, serverConn = net.Pipe() clientConn, serverConn = localPipe(t)
} }
config := test.config config := test.config
...@@ -682,7 +682,7 @@ func TestClientResumption(t *testing.T) { ...@@ -682,7 +682,7 @@ func TestClientResumption(t *testing.T) {
} }
testResumeState := func(test string, didResume bool) { testResumeState := func(test string, didResume bool) {
_, hs, err := testHandshake(clientConfig, serverConfig) _, hs, err := testHandshake(t, clientConfig, serverConfig)
if err != nil { if err != nil {
t.Fatalf("%s: handshake failed: %s", test, err) t.Fatalf("%s: handshake failed: %s", test, err)
} }
...@@ -800,7 +800,7 @@ func TestKeyLog(t *testing.T) { ...@@ -800,7 +800,7 @@ func TestKeyLog(t *testing.T) {
serverConfig := testConfig.Clone() serverConfig := testConfig.Clone()
serverConfig.KeyLogWriter = &serverBuf serverConfig.KeyLogWriter = &serverBuf
c, s := net.Pipe() c, s := localPipe(t)
done := make(chan bool) done := make(chan bool)
go func() { go func() {
...@@ -838,8 +838,8 @@ func TestKeyLog(t *testing.T) { ...@@ -838,8 +838,8 @@ func TestKeyLog(t *testing.T) {
} }
} }
checkKeylogLine("client", string(clientBuf.Bytes())) checkKeylogLine("client", clientBuf.String())
checkKeylogLine("server", string(serverBuf.Bytes())) checkKeylogLine("server", serverBuf.String())
} }
func TestHandshakeClientALPNMatch(t *testing.T) { func TestHandshakeClientALPNMatch(t *testing.T) {
...@@ -1021,7 +1021,7 @@ var hostnameInSNITests = []struct { ...@@ -1021,7 +1021,7 @@ var hostnameInSNITests = []struct {
func TestHostnameInSNI(t *testing.T) { func TestHostnameInSNI(t *testing.T) {
for _, tt := range hostnameInSNITests { for _, tt := range hostnameInSNITests {
c, s := net.Pipe() c, s := localPipe(t)
go func(host string) { go func(host string) {
Client(c, &Config{ServerName: host, InsecureSkipVerify: true}).Handshake() Client(c, &Config{ServerName: host, InsecureSkipVerify: true}).Handshake()
...@@ -1059,7 +1059,7 @@ func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) { ...@@ -1059,7 +1059,7 @@ func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
// This checks that the server can't select a cipher suite that the // This checks that the server can't select a cipher suite that the
// client didn't offer. See #13174. // client didn't offer. See #13174.
c, s := net.Pipe() c, s := localPipe(t)
errChan := make(chan error, 1) errChan := make(chan error, 1)
go func() { go func() {
...@@ -1228,7 +1228,7 @@ func TestVerifyPeerCertificate(t *testing.T) { ...@@ -1228,7 +1228,7 @@ func TestVerifyPeerCertificate(t *testing.T) {
} }
for i, test := range tests { for i, test := range tests {
c, s := net.Pipe() c, s := localPipe(t)
done := make(chan error) done := make(chan error)
var clientCalled, serverCalled bool var clientCalled, serverCalled bool
...@@ -1287,7 +1287,7 @@ func (b *brokenConn) Write(data []byte) (int, error) { ...@@ -1287,7 +1287,7 @@ func (b *brokenConn) Write(data []byte) (int, error) {
func TestFailedWrite(t *testing.T) { func TestFailedWrite(t *testing.T) {
// Test that a write error during the handshake is returned. // Test that a write error during the handshake is returned.
for _, breakAfter := range []int{0, 1} { for _, breakAfter := range []int{0, 1} {
c, s := net.Pipe() c, s := localPipe(t)
done := make(chan bool) done := make(chan bool)
go func() { go func() {
...@@ -1321,7 +1321,7 @@ func (wcc *writeCountingConn) Write(data []byte) (int, error) { ...@@ -1321,7 +1321,7 @@ func (wcc *writeCountingConn) Write(data []byte) (int, error) {
} }
func TestBuffering(t *testing.T) { func TestBuffering(t *testing.T) {
c, s := net.Pipe() c, s := localPipe(t)
done := make(chan bool) done := make(chan bool)
clientWCC := &writeCountingConn{Conn: c} clientWCC := &writeCountingConn{Conn: c}
...@@ -1350,7 +1350,7 @@ func TestBuffering(t *testing.T) { ...@@ -1350,7 +1350,7 @@ func TestBuffering(t *testing.T) {
} }
func TestAlertFlushing(t *testing.T) { func TestAlertFlushing(t *testing.T) {
c, s := net.Pipe() c, s := localPipe(t)
done := make(chan bool) done := make(chan bool)
clientWCC := &writeCountingConn{Conn: c} clientWCC := &writeCountingConn{Conn: c}
...@@ -1399,7 +1399,7 @@ func TestHandshakeRace(t *testing.T) { ...@@ -1399,7 +1399,7 @@ func TestHandshakeRace(t *testing.T) {
// order to provide some evidence that there are no races or deadlocks // order to provide some evidence that there are no races or deadlocks
// in the handshake locking. // in the handshake locking.
for i := 0; i < 32; i++ { for i := 0; i < 32; i++ {
c, s := net.Pipe() c, s := localPipe(t)
go func() { go func() {
server := Server(s, testConfig) server := Server(s, testConfig)
...@@ -1430,7 +1430,7 @@ func TestHandshakeRace(t *testing.T) { ...@@ -1430,7 +1430,7 @@ func TestHandshakeRace(t *testing.T) {
go func() { go func() {
<-startRead <-startRead
var reply [1]byte var reply [1]byte
if n, err := client.Read(reply[:]); err != nil || n != 1 { if _, err := io.ReadFull(client, reply[:]); err != nil {
panic(err) panic(err)
} }
c.Close() c.Close()
...@@ -1559,7 +1559,7 @@ func TestGetClientCertificate(t *testing.T) { ...@@ -1559,7 +1559,7 @@ func TestGetClientCertificate(t *testing.T) {
err error err error
} }
c, s := net.Pipe() c, s := localPipe(t)
done := make(chan serverResult) done := make(chan serverResult)
go func() { go func() {
...@@ -1637,7 +1637,7 @@ RwBA9Xk1KBNF ...@@ -1637,7 +1637,7 @@ RwBA9Xk1KBNF
} }
func TestCloseClientConnectionOnIdleServer(t *testing.T) { func TestCloseClientConnectionOnIdleServer(t *testing.T) {
clientConn, serverConn := net.Pipe() clientConn, serverConn := localPipe(t)
client := Client(clientConn, testConfig.Clone()) client := Client(clientConn, testConfig.Clone())
go func() { go func() {
var b [1]byte var b [1]byte
...@@ -1647,8 +1647,8 @@ func TestCloseClientConnectionOnIdleServer(t *testing.T) { ...@@ -1647,8 +1647,8 @@ func TestCloseClientConnectionOnIdleServer(t *testing.T) {
client.SetWriteDeadline(time.Now().Add(time.Second)) client.SetWriteDeadline(time.Now().Add(time.Second))
err := client.Handshake() err := client.Handshake()
if err != nil { if err != nil {
if !strings.Contains(err.Error(), "read/write on closed pipe") { if err, ok := err.(net.Error); ok && err.Timeout() {
t.Errorf("Error expected containing 'read/write on closed pipe' but got '%s'", err.Error()) t.Errorf("Expected a closed network connection error but got '%s'", err.Error())
} }
} else { } else {
t.Errorf("Error expected, but no error returned") t.Errorf("Error expected, but no error returned")
......
...@@ -70,10 +70,7 @@ func testClientHello(t *testing.T, serverConfig *Config, m handshakeMessage) { ...@@ -70,10 +70,7 @@ func testClientHello(t *testing.T, serverConfig *Config, m handshakeMessage) {
} }
func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessage, expectedSubStr string) { func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessage, expectedSubStr string) {
// Create in-memory network connection, c, s := localPipe(t)
// send message to server. Should return
// expected error.
c, s := net.Pipe()
go func() { go func() {
cli := Client(c, testConfig) cli := Client(c, testConfig)
if ch, ok := m.(*clientHelloMsg); ok { if ch, ok := m.(*clientHelloMsg); ok {
...@@ -201,25 +198,26 @@ func TestRenegotiationExtension(t *testing.T) { ...@@ -201,25 +198,26 @@ func TestRenegotiationExtension(t *testing.T) {
cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
} }
var buf []byte bufChan := make(chan []byte)
c, s := net.Pipe() c, s := localPipe(t)
go func() { go func() {
cli := Client(c, testConfig) cli := Client(c, testConfig)
cli.vers = clientHello.vers cli.vers = clientHello.vers
cli.writeRecord(recordTypeHandshake, clientHello.marshal()) cli.writeRecord(recordTypeHandshake, clientHello.marshal())
buf = make([]byte, 1024) buf := make([]byte, 1024)
n, err := c.Read(buf) n, err := c.Read(buf)
if err != nil { if err != nil {
t.Errorf("Server read returned error: %s", err) t.Errorf("Server read returned error: %s", err)
return return
} }
buf = buf[:n]
c.Close() c.Close()
bufChan <- buf[:n]
}() }()
Server(s, testConfig).Handshake() Server(s, testConfig).Handshake()
buf := <-bufChan
if len(buf) < 5+4 { if len(buf) < 5+4 {
t.Fatalf("Server returned short message of length %d", len(buf)) t.Fatalf("Server returned short message of length %d", len(buf))
...@@ -262,22 +260,27 @@ func TestTLS12OnlyCipherSuites(t *testing.T) { ...@@ -262,22 +260,27 @@ func TestTLS12OnlyCipherSuites(t *testing.T) {
supportedPoints: []uint8{pointFormatUncompressed}, supportedPoints: []uint8{pointFormatUncompressed},
} }
c, s := net.Pipe() c, s := localPipe(t)
var reply interface{} replyChan := make(chan interface{})
var clientErr error
go func() { go func() {
cli := Client(c, testConfig) cli := Client(c, testConfig)
cli.vers = clientHello.vers cli.vers = clientHello.vers
cli.writeRecord(recordTypeHandshake, clientHello.marshal()) cli.writeRecord(recordTypeHandshake, clientHello.marshal())
reply, clientErr = cli.readHandshake() reply, err := cli.readHandshake()
c.Close() c.Close()
if err != nil {
replyChan <- err
} else {
replyChan <- reply
}
}() }()
config := testConfig.Clone() config := testConfig.Clone()
config.CipherSuites = clientHello.cipherSuites config.CipherSuites = clientHello.cipherSuites
Server(s, config).Handshake() Server(s, config).Handshake()
s.Close() s.Close()
if clientErr != nil { reply := <-replyChan
t.Fatal(clientErr) if err, ok := reply.(error); ok {
t.Fatal(err)
} }
serverHello, ok := reply.(*serverHelloMsg) serverHello, ok := reply.(*serverHelloMsg)
if !ok { if !ok {
...@@ -289,7 +292,7 @@ func TestTLS12OnlyCipherSuites(t *testing.T) { ...@@ -289,7 +292,7 @@ func TestTLS12OnlyCipherSuites(t *testing.T) {
} }
func TestAlertForwarding(t *testing.T) { func TestAlertForwarding(t *testing.T) {
c, s := net.Pipe() c, s := localPipe(t)
go func() { go func() {
Client(c, testConfig).sendAlert(alertUnknownCA) Client(c, testConfig).sendAlert(alertUnknownCA)
c.Close() c.Close()
...@@ -303,7 +306,7 @@ func TestAlertForwarding(t *testing.T) { ...@@ -303,7 +306,7 @@ func TestAlertForwarding(t *testing.T) {
} }
func TestClose(t *testing.T) { func TestClose(t *testing.T) {
c, s := net.Pipe() c, s := localPipe(t)
go c.Close() go c.Close()
err := Server(s, testConfig).Handshake() err := Server(s, testConfig).Handshake()
...@@ -313,8 +316,8 @@ func TestClose(t *testing.T) { ...@@ -313,8 +316,8 @@ func TestClose(t *testing.T) {
} }
} }
func testHandshake(clientConfig, serverConfig *Config) (serverState, clientState ConnectionState, err error) { func testHandshake(t *testing.T, clientConfig, serverConfig *Config) (serverState, clientState ConnectionState, err error) {
c, s := net.Pipe() c, s := localPipe(t)
done := make(chan bool) done := make(chan bool)
go func() { go func() {
cli := Client(c, clientConfig) cli := Client(c, clientConfig)
...@@ -341,7 +344,7 @@ func TestVersion(t *testing.T) { ...@@ -341,7 +344,7 @@ func TestVersion(t *testing.T) {
clientConfig := &Config{ clientConfig := &Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
} }
state, _, err := testHandshake(clientConfig, serverConfig) state, _, err := testHandshake(t, clientConfig, serverConfig)
if err != nil { if err != nil {
t.Fatalf("handshake failed: %s", err) t.Fatalf("handshake failed: %s", err)
} }
...@@ -360,7 +363,7 @@ func TestCipherSuitePreference(t *testing.T) { ...@@ -360,7 +363,7 @@ func TestCipherSuitePreference(t *testing.T) {
CipherSuites: []uint16{TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_RC4_128_SHA}, CipherSuites: []uint16{TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_RC4_128_SHA},
InsecureSkipVerify: true, InsecureSkipVerify: true,
} }
state, _, err := testHandshake(clientConfig, serverConfig) state, _, err := testHandshake(t, clientConfig, serverConfig)
if err != nil { if err != nil {
t.Fatalf("handshake failed: %s", err) t.Fatalf("handshake failed: %s", err)
} }
...@@ -370,7 +373,7 @@ func TestCipherSuitePreference(t *testing.T) { ...@@ -370,7 +373,7 @@ func TestCipherSuitePreference(t *testing.T) {
} }
serverConfig.PreferServerCipherSuites = true serverConfig.PreferServerCipherSuites = true
state, _, err = testHandshake(clientConfig, serverConfig) state, _, err = testHandshake(t, clientConfig, serverConfig)
if err != nil { if err != nil {
t.Fatalf("handshake failed: %s", err) t.Fatalf("handshake failed: %s", err)
} }
...@@ -391,7 +394,7 @@ func TestSCTHandshake(t *testing.T) { ...@@ -391,7 +394,7 @@ func TestSCTHandshake(t *testing.T) {
clientConfig := &Config{ clientConfig := &Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
} }
_, state, err := testHandshake(clientConfig, serverConfig) _, state, err := testHandshake(t, clientConfig, serverConfig)
if err != nil { if err != nil {
t.Fatalf("handshake failed: %s", err) t.Fatalf("handshake failed: %s", err)
} }
...@@ -420,13 +423,13 @@ func TestCrossVersionResume(t *testing.T) { ...@@ -420,13 +423,13 @@ func TestCrossVersionResume(t *testing.T) {
// Establish a session at TLS 1.1. // Establish a session at TLS 1.1.
clientConfig.MaxVersion = VersionTLS11 clientConfig.MaxVersion = VersionTLS11
_, _, err := testHandshake(clientConfig, serverConfig) _, _, err := testHandshake(t, clientConfig, serverConfig)
if err != nil { if err != nil {
t.Fatalf("handshake failed: %s", err) t.Fatalf("handshake failed: %s", err)
} }
// The client session cache now contains a TLS 1.1 session. // The client session cache now contains a TLS 1.1 session.
state, _, err := testHandshake(clientConfig, serverConfig) state, _, err := testHandshake(t, clientConfig, serverConfig)
if err != nil { if err != nil {
t.Fatalf("handshake failed: %s", err) t.Fatalf("handshake failed: %s", err)
} }
...@@ -436,7 +439,7 @@ func TestCrossVersionResume(t *testing.T) { ...@@ -436,7 +439,7 @@ func TestCrossVersionResume(t *testing.T) {
// Test that the server will decline to resume at a lower version. // Test that the server will decline to resume at a lower version.
clientConfig.MaxVersion = VersionTLS10 clientConfig.MaxVersion = VersionTLS10
state, _, err = testHandshake(clientConfig, serverConfig) state, _, err = testHandshake(t, clientConfig, serverConfig)
if err != nil { if err != nil {
t.Fatalf("handshake failed: %s", err) t.Fatalf("handshake failed: %s", err)
} }
...@@ -445,7 +448,7 @@ func TestCrossVersionResume(t *testing.T) { ...@@ -445,7 +448,7 @@ func TestCrossVersionResume(t *testing.T) {
} }
// The client session cache now contains a TLS 1.0 session. // The client session cache now contains a TLS 1.0 session.
state, _, err = testHandshake(clientConfig, serverConfig) state, _, err = testHandshake(t, clientConfig, serverConfig)
if err != nil { if err != nil {
t.Fatalf("handshake failed: %s", err) t.Fatalf("handshake failed: %s", err)
} }
...@@ -455,7 +458,7 @@ func TestCrossVersionResume(t *testing.T) { ...@@ -455,7 +458,7 @@ func TestCrossVersionResume(t *testing.T) {
// Test that the server will decline to resume at a higher version. // Test that the server will decline to resume at a higher version.
clientConfig.MaxVersion = VersionTLS11 clientConfig.MaxVersion = VersionTLS11
state, _, err = testHandshake(clientConfig, serverConfig) state, _, err = testHandshake(t, clientConfig, serverConfig)
if err != nil { if err != nil {
t.Fatalf("handshake failed: %s", err) t.Fatalf("handshake failed: %s", err)
} }
...@@ -579,7 +582,7 @@ func (test *serverTest) run(t *testing.T, write bool) { ...@@ -579,7 +582,7 @@ func (test *serverTest) run(t *testing.T, write bool) {
} }
serverConn = recordingConn serverConn = recordingConn
} else { } else {
clientConn, serverConn = net.Pipe() clientConn, serverConn = localPipe(t)
} }
config := test.config config := test.config
if config == nil { if config == nil {
...@@ -832,7 +835,7 @@ func TestHandshakeServerSNIGetCertificate(t *testing.T) { ...@@ -832,7 +835,7 @@ func TestHandshakeServerSNIGetCertificate(t *testing.T) {
nameToCert := config.NameToCertificate nameToCert := config.NameToCertificate
config.NameToCertificate = nil config.NameToCertificate = nil
config.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) { config.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) {
cert, _ := nameToCert[clientHello.ServerName] cert := nameToCert[clientHello.ServerName]
return cert, nil return cert, nil
} }
test := &serverTest{ test := &serverTest{
...@@ -1025,7 +1028,7 @@ func benchmarkHandshakeServer(b *testing.B, cipherSuite uint16, curve CurveID, c ...@@ -1025,7 +1028,7 @@ func benchmarkHandshakeServer(b *testing.B, cipherSuite uint16, curve CurveID, c
config.Certificates[0].PrivateKey = key config.Certificates[0].PrivateKey = key
config.BuildNameToCertificate() config.BuildNameToCertificate()
clientConn, serverConn := net.Pipe() clientConn, serverConn := localPipe(b)
serverConn = &recordingConn{Conn: serverConn} serverConn = &recordingConn{Conn: serverConn}
go func() { go func() {
client := Client(clientConn, testConfig) client := Client(clientConn, testConfig)
...@@ -1039,7 +1042,7 @@ func benchmarkHandshakeServer(b *testing.B, cipherSuite uint16, curve CurveID, c ...@@ -1039,7 +1042,7 @@ func benchmarkHandshakeServer(b *testing.B, cipherSuite uint16, curve CurveID, c
flows := serverConn.(*recordingConn).flows flows := serverConn.(*recordingConn).flows
feeder := make(chan struct{}) feeder := make(chan struct{})
clientConn, serverConn = net.Pipe() clientConn, serverConn = localPipe(b)
go func() { go func() {
for range feeder { for range feeder {
...@@ -1051,10 +1054,10 @@ func benchmarkHandshakeServer(b *testing.B, cipherSuite uint16, curve CurveID, c ...@@ -1051,10 +1054,10 @@ func benchmarkHandshakeServer(b *testing.B, cipherSuite uint16, curve CurveID, c
ff := make([]byte, len(f)) ff := make([]byte, len(f))
n, err := io.ReadFull(clientConn, ff) n, err := io.ReadFull(clientConn, ff)
if err != nil { if err != nil {
b.Fatalf("#%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", i+1, err, n, len(ff), ff[:n], f) b.Errorf("#%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", i+1, err, n, len(ff), ff[:n], f)
} }
if !bytes.Equal(f, ff) { if !bytes.Equal(f, ff) {
b.Fatalf("#%d: mismatch on read: got:%x want:%x", i+1, ff, f) b.Errorf("#%d: mismatch on read: got:%x want:%x", i+1, ff, f)
} }
} }
} }
...@@ -1216,7 +1219,7 @@ func TestSNIGivenOnFailure(t *testing.T) { ...@@ -1216,7 +1219,7 @@ func TestSNIGivenOnFailure(t *testing.T) {
// Erase the server's cipher suites to ensure the handshake fails. // Erase the server's cipher suites to ensure the handshake fails.
serverConfig.CipherSuites = nil serverConfig.CipherSuites = nil
c, s := net.Pipe() c, s := localPipe(t)
go func() { go func() {
cli := Client(c, testConfig) cli := Client(c, testConfig)
cli.vers = clientHello.vers cli.vers = clientHello.vers
...@@ -1346,7 +1349,7 @@ func TestGetConfigForClient(t *testing.T) { ...@@ -1346,7 +1349,7 @@ func TestGetConfigForClient(t *testing.T) {
configReturned = config configReturned = config
return config, err return config, err
} }
c, s := net.Pipe() c, s := localPipe(t)
done := make(chan error) done := make(chan error)
go func() { go func() {
...@@ -1423,7 +1426,7 @@ var testECDSAPrivateKey = &ecdsa.PrivateKey{ ...@@ -1423,7 +1426,7 @@ var testECDSAPrivateKey = &ecdsa.PrivateKey{
var testP256PrivateKey, _ = x509.ParseECPrivateKey(fromHex("30770201010420012f3b52bc54c36ba3577ad45034e2e8efe1e6999851284cb848725cfe029991a00a06082a8648ce3d030107a14403420004c02c61c9b16283bbcc14956d886d79b358aa614596975f78cece787146abf74c2d5dc578c0992b4f3c631373479ebf3892efe53d21c4f4f1cc9a11c3536b7f75")) var testP256PrivateKey, _ = x509.ParseECPrivateKey(fromHex("30770201010420012f3b52bc54c36ba3577ad45034e2e8efe1e6999851284cb848725cfe029991a00a06082a8648ce3d030107a14403420004c02c61c9b16283bbcc14956d886d79b358aa614596975f78cece787146abf74c2d5dc578c0992b4f3c631373479ebf3892efe53d21c4f4f1cc9a11c3536b7f75"))
func TestCloseServerConnectionOnIdleClient(t *testing.T) { func TestCloseServerConnectionOnIdleClient(t *testing.T) {
clientConn, serverConn := net.Pipe() clientConn, serverConn := localPipe(t)
server := Server(serverConn, testConfig.Clone()) server := Server(serverConn, testConfig.Clone())
go func() { go func() {
clientConn.Write([]byte{'0'}) clientConn.Write([]byte{'0'})
...@@ -1432,8 +1435,8 @@ func TestCloseServerConnectionOnIdleClient(t *testing.T) { ...@@ -1432,8 +1435,8 @@ func TestCloseServerConnectionOnIdleClient(t *testing.T) {
server.SetReadDeadline(time.Now().Add(time.Second)) server.SetReadDeadline(time.Now().Add(time.Second))
err := server.Handshake() err := server.Handshake()
if err != nil { if err != nil {
if !strings.Contains(err.Error(), "read/write on closed pipe") { if err, ok := err.(net.Error); ok && err.Timeout() {
t.Errorf("Error expected containing 'read/write on closed pipe' but got '%s'", err.Error()) t.Errorf("Expected a closed network connection error but got '%s'", err.Error())
} }
} else { } else {
t.Errorf("Error expected, but no error returned") t.Errorf("Error expected, but no error returned")
......
...@@ -13,6 +13,7 @@ import ( ...@@ -13,6 +13,7 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"net" "net"
"os"
"os/exec" "os/exec"
"strconv" "strconv"
"strings" "strings"
...@@ -224,3 +225,45 @@ func tempFile(contents string) string { ...@@ -224,3 +225,45 @@ func tempFile(contents string) string {
file.Close() file.Close()
return path return path
} }
// localListener is set up by TestMain and used by localPipe to create Conn
// pairs like net.Pipe, but connected by an actual buffered TCP connection.
var localListener struct {
sync.Mutex
net.Listener
}
func localPipe(t testing.TB) (net.Conn, net.Conn) {
localListener.Lock()
defer localListener.Unlock()
c := make(chan net.Conn)
go func() {
conn, err := localListener.Accept()
if err != nil {
t.Errorf("Failed to accept local connection: %v", err)
}
c <- conn
}()
addr := localListener.Addr()
c1, err := net.Dial(addr.Network(), addr.String())
if err != nil {
t.Fatalf("Failed to dial local connection: %v", err)
}
c2 := <-c
return c1, c2
}
func TestMain(m *testing.M) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
l, err = net.Listen("tcp6", "[::1]:0")
}
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to open local listener: %v", err)
os.Exit(1)
}
localListener.Listener = l
exitCode := m.Run()
localListener.Close()
os.Exit(exitCode)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment