Commit 72d93220 authored by Russ Cox's avatar Russ Cox

crypto/tls: simpler implementation of record layer

Depends on CL 957045, 980043, 1004043.
Fixes #715.

R=agl1, agl
CC=golang-dev
https://golang.org/cl/943043
parent 47a05334
......@@ -7,15 +7,13 @@ include ../../../Make.$(GOARCH)
TARG=crypto/tls
GOFILES=\
alert.go\
ca_set.go\
common.go\
conn.go\
handshake_client.go\
handshake_messages.go\
handshake_server.go\
prf.go\
record_process.go\
record_read.go\
record_write.go\
ca_set.go\
tls.go\
include ../../../Make.pkg
......@@ -4,40 +4,70 @@
package tls
type alertLevel int
type alertType int
import "strconv"
type alert uint8
const (
alertLevelWarning alertLevel = 1
alertLevelError alertLevel = 2
// alert level
alertLevelWarning = 1
alertLevelError = 2
)
const (
alertCloseNotify alertType = 0
alertUnexpectedMessage alertType = 10
alertBadRecordMAC alertType = 20
alertDecryptionFailed alertType = 21
alertRecordOverflow alertType = 22
alertDecompressionFailure alertType = 30
alertHandshakeFailure alertType = 40
alertBadCertificate alertType = 42
alertUnsupportedCertificate alertType = 43
alertCertificateRevoked alertType = 44
alertCertificateExpired alertType = 45
alertCertificateUnknown alertType = 46
alertIllegalParameter alertType = 47
alertUnknownCA alertType = 48
alertAccessDenied alertType = 49
alertDecodeError alertType = 50
alertDecryptError alertType = 51
alertProtocolVersion alertType = 70
alertInsufficientSecurity alertType = 71
alertInternalError alertType = 80
alertUserCanceled alertType = 90
alertNoRenegotiation alertType = 100
alertCloseNotify alert = 0
alertUnexpectedMessage alert = 10
alertBadRecordMAC alert = 20
alertDecryptionFailed alert = 21
alertRecordOverflow alert = 22
alertDecompressionFailure alert = 30
alertHandshakeFailure alert = 40
alertBadCertificate alert = 42
alertUnsupportedCertificate alert = 43
alertCertificateRevoked alert = 44
alertCertificateExpired alert = 45
alertCertificateUnknown alert = 46
alertIllegalParameter alert = 47
alertUnknownCA alert = 48
alertAccessDenied alert = 49
alertDecodeError alert = 50
alertDecryptError alert = 51
alertProtocolVersion alert = 70
alertInsufficientSecurity alert = 71
alertInternalError alert = 80
alertUserCanceled alert = 90
alertNoRenegotiation alert = 100
)
type alert struct {
level alertLevel
error alertType
var alertText = map[alert]string{
alertCloseNotify: "close notify",
alertUnexpectedMessage: "unexpected message",
alertBadRecordMAC: "bad record MAC",
alertDecryptionFailed: "decryption failed",
alertRecordOverflow: "record overflow",
alertDecompressionFailure: "decompression failure",
alertHandshakeFailure: "handshake failure",
alertBadCertificate: "bad certificate",
alertUnsupportedCertificate: "unsupported certificate",
alertCertificateRevoked: "revoked certificate",
alertCertificateExpired: "expired certificate",
alertCertificateUnknown: "unknown certificate",
alertIllegalParameter: "illegal parameter",
alertUnknownCA: "unknown certificate authority",
alertAccessDenied: "access denied",
alertDecodeError: "error decoding message",
alertDecryptError: "error decrypting message",
alertProtocolVersion: "protocol version not supported",
alertInsufficientSecurity: "insufficient security level",
alertInternalError: "internal error",
alertUserCanceled: "user canceled",
alertNoRenegotiation: "no renegotiation",
}
func (e alert) String() string {
s, ok := alertText[e]
if ok {
return s
}
return "alert(" + strconv.Itoa(int(e)) + ")"
}
......@@ -10,22 +10,18 @@ import (
"io"
"io/ioutil"
"once"
"os"
"time"
)
const (
// maxTLSCiphertext is the maximum length of a plaintext payload.
maxTLSPlaintext = 16384
// maxTLSCiphertext is the maximum length payload after compression and encryption.
maxTLSCiphertext = 16384 + 2048
// maxHandshakeMsg is the largest single handshake message that we'll buffer.
maxHandshakeMsg = 65536
// defaultMajor and defaultMinor are the maximum TLS version that we support.
defaultMajor = 3
defaultMinor = 2
)
maxPlaintext = 16384 // maximum plaintext payload length
maxCiphertext = 16384 + 2048 // maximum ciphertext payload length
recordHeaderLen = 5 // record header length
maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB)
minVersion = 0x0301 // minimum supported version - TLS 1.0
maxVersion = 0x0302 // maximum supported version - TLS 1.1
)
// TLS record types.
type recordType uint8
......@@ -67,7 +63,7 @@ var (
type ConnectionState struct {
HandshakeComplete bool
CipherSuite string
Error alertType
Error alert
NegotiatedProtocol string
}
......@@ -99,6 +95,7 @@ type record struct {
type handshakeMessage interface {
marshal() []byte
unmarshal([]byte) bool
}
type encryptor interface {
......@@ -108,34 +105,16 @@ type encryptor interface {
// mutualVersion returns the protocol version to use given the advertised
// version of the peer.
func mutualVersion(theirMajor, theirMinor uint8) (major, minor uint8, ok bool) {
// We don't deal with peers < TLS 1.0 (aka version 3.1).
if theirMajor < 3 || theirMajor == 3 && theirMinor < 1 {
return 0, 0, false
func mutualVersion(vers uint16) (uint16, bool) {
if vers < minVersion {
return 0, false
}
major = 3
minor = 2
if theirMinor < minor {
minor = theirMinor
if vers > maxVersion {
vers = maxVersion
}
ok = true
return
return vers, true
}
// A nop implements the NULL encryption and MAC algorithms.
type nop struct{}
func (nop) XORKeyStream(buf []byte) {}
func (nop) Write(buf []byte) (int, os.Error) { return len(buf), nil }
func (nop) Sum() []byte { return nil }
func (nop) Reset() {}
func (nop) Size() int { return 0 }
// The defaultConfig is used in place of a nil *Config in the TLS server and client.
var varDefaultConfig *Config
......
This diff is collapsed.
......@@ -12,74 +12,63 @@ import (
"crypto/subtle"
"crypto/x509"
"io"
"os"
)
// A serverHandshake performs the server side of the TLS 1.1 handshake protocol.
type clientHandshake struct {
writeChan chan<- interface{}
controlChan chan<- interface{}
msgChan <-chan interface{}
config *Config
}
func (h *clientHandshake) loop(writeChan chan<- interface{}, controlChan chan<- interface{}, msgChan <-chan interface{}, config *Config) {
h.writeChan = writeChan
h.controlChan = controlChan
h.msgChan = msgChan
h.config = config
defer close(writeChan)
defer close(controlChan)
func (c *Conn) clientHandshake() os.Error {
finishedHash := newFinishedHash()
config := defaultConfig()
hello := &clientHelloMsg{
major: defaultMajor,
minor: defaultMinor,
vers: maxVersion,
cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
compressionMethods: []uint8{compressionNone},
random: make([]byte, 32),
}
currentTime := uint32(config.Time())
hello.random[0] = byte(currentTime >> 24)
hello.random[1] = byte(currentTime >> 16)
hello.random[2] = byte(currentTime >> 8)
hello.random[3] = byte(currentTime)
t := uint32(config.Time())
hello.random[0] = byte(t >> 24)
hello.random[1] = byte(t >> 16)
hello.random[2] = byte(t >> 8)
hello.random[3] = byte(t)
_, err := io.ReadFull(config.Rand, hello.random[4:])
if err != nil {
h.error(alertInternalError)
return
return c.sendAlert(alertInternalError)
}
finishedHash.Write(hello.marshal())
writeChan <- writerSetVersion{defaultMajor, defaultMinor}
writeChan <- hello
c.writeRecord(recordTypeHandshake, hello.marshal())
serverHello, ok := h.readHandshakeMsg().(*serverHelloMsg)
msg, err := c.readHandshake()
if err != nil {
return err
}
serverHello, ok := msg.(*serverHelloMsg)
if !ok {
h.error(alertUnexpectedMessage)
return
return c.sendAlert(alertUnexpectedMessage)
}
finishedHash.Write(serverHello.marshal())
major, minor, ok := mutualVersion(serverHello.major, serverHello.minor)
vers, ok := mutualVersion(serverHello.vers)
if !ok {
h.error(alertProtocolVersion)
return
c.sendAlert(alertProtocolVersion)
}
writeChan <- writerSetVersion{major, minor}
c.vers = vers
c.haveVers = true
if serverHello.cipherSuite != TLS_RSA_WITH_RC4_128_SHA ||
serverHello.compressionMethod != compressionNone {
h.error(alertUnexpectedMessage)
return
return c.sendAlert(alertUnexpectedMessage)
}
certMsg, ok := h.readHandshakeMsg().(*certificateMsg)
msg, err = c.readHandshake()
if err != nil {
return err
}
certMsg, ok := msg.(*certificateMsg)
if !ok || len(certMsg.certificates) == 0 {
h.error(alertUnexpectedMessage)
return
return c.sendAlert(alertUnexpectedMessage)
}
finishedHash.Write(certMsg.marshal())
......@@ -87,139 +76,98 @@ func (h *clientHandshake) loop(writeChan chan<- interface{}, controlChan chan<-
for i, asn1Data := range certMsg.certificates {
cert, err := x509.ParseCertificate(asn1Data)
if err != nil {
h.error(alertBadCertificate)
return
return c.sendAlert(alertBadCertificate)
}
certs[i] = cert
}
// TODO(agl): do better validation of certs: max path length, name restrictions etc.
for i := 1; i < len(certs); i++ {
if certs[i-1].CheckSignatureFrom(certs[i]) != nil {
h.error(alertBadCertificate)
return
if err := certs[i-1].CheckSignatureFrom(certs[i]); err != nil {
return c.sendAlert(alertBadCertificate)
}
}
if config.RootCAs != nil {
// TODO(rsc): Find certificates for OS X 10.6.
if false && config.RootCAs != nil {
root := config.RootCAs.FindParent(certs[len(certs)-1])
if root == nil {
h.error(alertBadCertificate)
return
return c.sendAlert(alertBadCertificate)
}
if certs[len(certs)-1].CheckSignatureFrom(root) != nil {
h.error(alertBadCertificate)
return
return c.sendAlert(alertBadCertificate)
}
}
pub, ok := certs[0].PublicKey.(*rsa.PublicKey)
if !ok {
h.error(alertUnsupportedCertificate)
return
return c.sendAlert(alertUnsupportedCertificate)
}
shd, ok := h.readHandshakeMsg().(*serverHelloDoneMsg)
msg, err = c.readHandshake()
if err != nil {
return err
}
shd, ok := msg.(*serverHelloDoneMsg)
if !ok {
h.error(alertUnexpectedMessage)
return
return c.sendAlert(alertUnexpectedMessage)
}
finishedHash.Write(shd.marshal())
ckx := new(clientKeyExchangeMsg)
preMasterSecret := make([]byte, 48)
// Note that the version number in the preMasterSecret must be the
// version offered in the ClientHello.
preMasterSecret[0] = defaultMajor
preMasterSecret[1] = defaultMinor
preMasterSecret[0] = byte(hello.vers >> 8)
preMasterSecret[1] = byte(hello.vers)
_, err = io.ReadFull(config.Rand, preMasterSecret[2:])
if err != nil {
h.error(alertInternalError)
return
return c.sendAlert(alertInternalError)
}
ckx.ciphertext, err = rsa.EncryptPKCS1v15(config.Rand, pub, preMasterSecret)
if err != nil {
h.error(alertInternalError)
return
return c.sendAlert(alertInternalError)
}
finishedHash.Write(ckx.marshal())
writeChan <- ckx
c.writeRecord(recordTypeHandshake, ckx.marshal())
suite := cipherSuites[0]
masterSecret, clientMAC, serverMAC, clientKey, serverKey :=
keysFromPreMasterSecret11(preMasterSecret, hello.random, serverHello.random, suite.hashLength, suite.cipherKeyLength)
cipher, _ := rc4.NewCipher(clientKey)
writeChan <- writerChangeCipherSpec{cipher, hmac.New(sha1.New(), clientMAC)}
c.out.prepareCipherSpec(cipher, hmac.New(sha1.New(), clientMAC))
c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
finished := new(finishedMsg)
finished.verifyData = finishedHash.clientSum(masterSecret)
finishedHash.Write(finished.marshal())
writeChan <- finished
// TODO(agl): this is cut-through mode which should probably be an option.
writeChan <- writerEnableApplicationData{}
_, ok = h.readHandshakeMsg().(changeCipherSpec)
if !ok {
h.error(alertUnexpectedMessage)
return
}
c.writeRecord(recordTypeHandshake, finished.marshal())
cipher2, _ := rc4.NewCipher(serverKey)
controlChan <- &newCipherSpec{cipher2, hmac.New(sha1.New(), serverMAC)}
c.in.prepareCipherSpec(cipher2, hmac.New(sha1.New(), serverMAC))
c.readRecord(recordTypeChangeCipherSpec)
if c.err != nil {
return c.err
}
serverFinished, ok := h.readHandshakeMsg().(*finishedMsg)
msg, err = c.readHandshake()
if err != nil {
return err
}
serverFinished, ok := msg.(*finishedMsg)
if !ok {
h.error(alertUnexpectedMessage)
return
return c.sendAlert(alertUnexpectedMessage)
}
verify := finishedHash.serverSum(masterSecret)
if len(verify) != len(serverFinished.verifyData) ||
subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 {
h.error(alertHandshakeFailure)
return
return c.sendAlert(alertHandshakeFailure)
}
controlChan <- ConnectionState{HandshakeComplete: true, CipherSuite: "TLS_RSA_WITH_RC4_128_SHA"}
// This should just block forever.
_ = h.readHandshakeMsg()
h.error(alertUnexpectedMessage)
return
}
func (h *clientHandshake) readHandshakeMsg() interface{} {
v := <-h.msgChan
if closed(h.msgChan) {
// If the channel closed then the processor received an error
// from the peer and we don't want to echo it back to them.
h.msgChan = nil
return 0
}
if _, ok := v.(alert); ok {
// We got an alert from the processor. We forward to the writer
// and shutdown.
h.writeChan <- v
h.msgChan = nil
return 0
}
return v
}
func (h *clientHandshake) error(e alertType) {
if h.msgChan != nil {
// If we didn't get an error from the processor, then we need
// to tell it about the error.
go func() {
for _ = range h.msgChan {
}
}()
h.controlChan <- ConnectionState{Error: e}
close(h.controlChan)
h.writeChan <- alert{alertLevelError, e}
}
c.handshakeComplete = true
c.cipherSuite = TLS_RSA_WITH_RC4_128_SHA
return nil
}
......@@ -6,7 +6,7 @@ package tls
type clientHelloMsg struct {
raw []byte
major, minor uint8
vers uint16
random []byte
sessionId []byte
cipherSuites []uint16
......@@ -40,8 +40,8 @@ func (m *clientHelloMsg) marshal() []byte {
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
x[3] = uint8(length)
x[4] = m.major
x[5] = m.minor
x[4] = uint8(m.vers >> 8)
x[5] = uint8(m.vers)
copy(x[6:38], m.random)
x[38] = uint8(len(m.sessionId))
copy(x[39:39+len(m.sessionId)], m.sessionId)
......@@ -108,12 +108,11 @@ func (m *clientHelloMsg) marshal() []byte {
}
func (m *clientHelloMsg) unmarshal(data []byte) bool {
if len(data) < 43 {
if len(data) < 42 {
return false
}
m.raw = data
m.major = data[4]
m.minor = data[5]
m.vers = uint16(data[4])<<8 | uint16(data[5])
m.random = data[6:38]
sessionIdLen := int(data[38])
if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
......@@ -136,7 +135,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
}
data = data[2+cipherSuiteLen:]
if len(data) < 2 {
if len(data) < 1 {
return false
}
compressionMethodsLen := int(data[0])
......@@ -212,7 +211,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
type serverHelloMsg struct {
raw []byte
major, minor uint8
vers uint16
random []byte
sessionId []byte
cipherSuite uint16
......@@ -249,8 +248,8 @@ func (m *serverHelloMsg) marshal() []byte {
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
x[3] = uint8(length)
x[4] = m.major
x[5] = m.minor
x[4] = uint8(m.vers >> 8)
x[5] = uint8(m.vers)
copy(x[6:38], m.random)
x[38] = uint8(len(m.sessionId))
copy(x[39:39+len(m.sessionId)], m.sessionId)
......@@ -306,8 +305,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
return false
}
m.raw = data
m.major = data[4]
m.minor = data[5]
m.vers = uint16(data[4])<<8 | uint16(data[5])
m.random = data[6:38]
sessionIdLen := int(data[38])
if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
......
......@@ -97,8 +97,7 @@ func randomString(n int, rand *rand.Rand) string {
func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &clientHelloMsg{}
m.major = uint8(rand.Intn(256))
m.minor = uint8(rand.Intn(256))
m.vers = uint16(rand.Intn(65536))
m.random = randomBytes(32, rand)
m.sessionId = randomBytes(rand.Intn(32), rand)
m.cipherSuites = make([]uint16, rand.Intn(63)+1)
......@@ -118,8 +117,7 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &serverHelloMsg{}
m.major = uint8(rand.Intn(256))
m.minor = uint8(rand.Intn(256))
m.vers = uint16(rand.Intn(65536))
m.random = randomBytes(32, rand)
m.sessionId = randomBytes(rand.Intn(32), rand)
m.cipherSuite = uint16(rand.Int31())
......
......@@ -19,6 +19,7 @@ import (
"crypto/sha1"
"crypto/subtle"
"io"
"os"
)
type cipherSuite struct {
......@@ -31,33 +32,22 @@ var cipherSuites = []cipherSuite{
cipherSuite{TLS_RSA_WITH_RC4_128_SHA, 20, 16},
}
// A serverHandshake performs the server side of the TLS 1.1 handshake protocol.
type serverHandshake struct {
writeChan chan<- interface{}
controlChan chan<- interface{}
msgChan <-chan interface{}
config *Config
}
func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<- interface{}, msgChan <-chan interface{}, config *Config) {
h.writeChan = writeChan
h.controlChan = controlChan
h.msgChan = msgChan
h.config = config
defer close(writeChan)
defer close(controlChan)
clientHello, ok := h.readHandshakeMsg().(*clientHelloMsg)
func (c *Conn) serverHandshake() os.Error {
config := c.config
msg, err := c.readHandshake()
if err != nil {
return err
}
clientHello, ok := msg.(*clientHelloMsg)
if !ok {
h.error(alertUnexpectedMessage)
return
return c.sendAlert(alertUnexpectedMessage)
}
major, minor, ok := mutualVersion(clientHello.major, clientHello.minor)
vers, ok := mutualVersion(clientHello.vers)
if !ok {
h.error(alertProtocolVersion)
return
return c.sendAlert(alertProtocolVersion)
}
c.vers = vers
c.haveVers = true
finishedHash := newFinishedHash()
finishedHash.Write(clientHello.marshal())
......@@ -89,23 +79,20 @@ func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<-
}
if suite == nil || !foundCompression {
h.error(alertHandshakeFailure)
return
return c.sendAlert(alertHandshakeFailure)
}
hello.major = major
hello.minor = minor
hello.vers = vers
hello.cipherSuite = suite.id
currentTime := uint32(config.Time())
t := uint32(config.Time())
hello.random = make([]byte, 32)
hello.random[0] = byte(currentTime >> 24)
hello.random[1] = byte(currentTime >> 16)
hello.random[2] = byte(currentTime >> 8)
hello.random[3] = byte(currentTime)
_, err := io.ReadFull(config.Rand, hello.random[4:])
hello.random[0] = byte(t >> 24)
hello.random[1] = byte(t >> 16)
hello.random[2] = byte(t >> 8)
hello.random[3] = byte(t)
_, err = io.ReadFull(config.Rand, hello.random[4:])
if err != nil {
h.error(alertInternalError)
return
return c.sendAlert(alertInternalError)
}
hello.compressionMethod = compressionNone
if clientHello.nextProtoNeg {
......@@ -114,41 +101,40 @@ func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<-
}
finishedHash.Write(hello.marshal())
writeChan <- writerSetVersion{major, minor}
writeChan <- hello
c.writeRecord(recordTypeHandshake, hello.marshal())
if len(config.Certificates) == 0 {
h.error(alertInternalError)
return
return c.sendAlert(alertInternalError)
}
certMsg := new(certificateMsg)
certMsg.certificates = config.Certificates[0].Certificate
finishedHash.Write(certMsg.marshal())
writeChan <- certMsg
c.writeRecord(recordTypeHandshake, certMsg.marshal())
helloDone := new(serverHelloDoneMsg)
finishedHash.Write(helloDone.marshal())
writeChan <- helloDone
c.writeRecord(recordTypeHandshake, helloDone.marshal())
ckx, ok := h.readHandshakeMsg().(*clientKeyExchangeMsg)
msg, err = c.readHandshake()
if err != nil {
return err
}
ckx, ok := msg.(*clientKeyExchangeMsg)
if !ok {
h.error(alertUnexpectedMessage)
return
return c.sendAlert(alertUnexpectedMessage)
}
finishedHash.Write(ckx.marshal())
preMasterSecret := make([]byte, 48)
_, err = io.ReadFull(config.Rand, preMasterSecret[2:])
if err != nil {
h.error(alertInternalError)
return
return c.sendAlert(alertInternalError)
}
err = rsa.DecryptPKCS1v15SessionKey(config.Rand, config.Certificates[0].PrivateKey, ckx.ciphertext, preMasterSecret)
if err != nil {
h.error(alertHandshakeFailure)
return
return c.sendAlert(alertHandshakeFailure)
}
// We don't check the version number in the premaster secret. For one,
// by checking it, we would leak information about the validity of the
......@@ -160,91 +146,53 @@ func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<-
masterSecret, clientMAC, serverMAC, clientKey, serverKey :=
keysFromPreMasterSecret11(preMasterSecret, clientHello.random, hello.random, suite.hashLength, suite.cipherKeyLength)
_, ok = h.readHandshakeMsg().(changeCipherSpec)
if !ok {
h.error(alertUnexpectedMessage)
return
}
cipher, _ := rc4.NewCipher(clientKey)
controlChan <- &newCipherSpec{cipher, hmac.New(sha1.New(), clientMAC)}
c.in.prepareCipherSpec(cipher, hmac.New(sha1.New(), clientMAC))
c.readRecord(recordTypeChangeCipherSpec)
if err := c.error(); err != nil {
return err
}
clientProtocol := ""
if hello.nextProtoNeg {
nextProto, ok := h.readHandshakeMsg().(*nextProtoMsg)
msg, err = c.readHandshake()
if err != nil {
return err
}
nextProto, ok := msg.(*nextProtoMsg)
if !ok {
h.error(alertUnexpectedMessage)
return
return c.sendAlert(alertUnexpectedMessage)
}
finishedHash.Write(nextProto.marshal())
clientProtocol = nextProto.proto
c.clientProtocol = nextProto.proto
}
clientFinished, ok := h.readHandshakeMsg().(*finishedMsg)
msg, err = c.readHandshake()
if err != nil {
return err
}
clientFinished, ok := msg.(*finishedMsg)
if !ok {
h.error(alertUnexpectedMessage)
return
return c.sendAlert(alertUnexpectedMessage)
}
verify := finishedHash.clientSum(masterSecret)
if len(verify) != len(clientFinished.verifyData) ||
subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 {
h.error(alertHandshakeFailure)
return
return c.sendAlert(alertHandshakeFailure)
}
controlChan <- ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0, clientProtocol}
finishedHash.Write(clientFinished.marshal())
cipher2, _ := rc4.NewCipher(serverKey)
writeChan <- writerChangeCipherSpec{cipher2, hmac.New(sha1.New(), serverMAC)}
c.out.prepareCipherSpec(cipher2, hmac.New(sha1.New(), serverMAC))
c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
finished := new(finishedMsg)
finished.verifyData = finishedHash.serverSum(masterSecret)
writeChan <- finished
writeChan <- writerEnableApplicationData{}
for {
_, ok := h.readHandshakeMsg().(*clientHelloMsg)
if !ok {
h.error(alertUnexpectedMessage)
return
}
// We reject all renegotication requests.
writeChan <- alert{alertLevelWarning, alertNoRenegotiation}
}
}
c.writeRecord(recordTypeHandshake, finished.marshal())
func (h *serverHandshake) readHandshakeMsg() interface{} {
v := <-h.msgChan
if closed(h.msgChan) {
// If the channel closed then the processor received an error
// from the peer and we don't want to echo it back to them.
h.msgChan = nil
return 0
}
if _, ok := v.(alert); ok {
// We got an alert from the processor. We forward to the writer
// and shutdown.
h.writeChan <- v
h.msgChan = nil
return 0
}
return v
}
c.handshakeComplete = true
c.cipherSuite = TLS_RSA_WITH_RC4_128_SHA
func (h *serverHandshake) error(e alertType) {
if h.msgChan != nil {
// If we didn't get an error from the processor, then we need
// to tell it about the error.
go func() {
for _ = range h.msgChan {
}
}()
h.controlChan <- ConnectionState{false, "", e, ""}
close(h.controlChan)
h.writeChan <- alert{alertLevelError, e}
}
return nil
}
This diff is collapsed.
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
// A recordProcessor accepts reassembled records, decrypts and verifies them
// and routes them either to the handshake processor, to up to the application.
// It also accepts requests from the application for the current connection
// state, or for a notification when the state changes.
import (
"container/list"
"crypto/subtle"
"hash"
)
// getConnectionState is a request from the application to get the current
// ConnectionState.
type getConnectionState struct {
reply chan<- ConnectionState
}
// waitConnectionState is a request from the application to be notified when
// the connection state changes.
type waitConnectionState struct {
reply chan<- ConnectionState
}
// connectionStateChange is a message from the handshake processor that the
// connection state has changed.
type connectionStateChange struct {
connState ConnectionState
}
// changeCipherSpec is a message send to the handshake processor to signal that
// the peer is switching ciphers.
type changeCipherSpec struct{}
// newCipherSpec is a message from the handshake processor that future
// records should be processed with a new cipher and MAC function.
type newCipherSpec struct {
encrypt encryptor
mac hash.Hash
}
type recordProcessor struct {
decrypt encryptor
mac hash.Hash
seqNum uint64
handshakeBuf []byte
appDataChan chan<- []byte
requestChan <-chan interface{}
controlChan <-chan interface{}
recordChan <-chan *record
handshakeChan chan<- interface{}
// recordRead is nil when we don't wish to read any more.
recordRead <-chan *record
// appDataSend is nil when len(appData) == 0.
appDataSend chan<- []byte
// appData contains any application data queued for upstream.
appData []byte
// A list of channels waiting for connState to change.
waitQueue *list.List
connState ConnectionState
shutdown bool
header [13]byte
}
// drainRequestChannel processes messages from the request channel until it's closed.
func drainRequestChannel(requestChan <-chan interface{}, c ConnectionState) {
for v := range requestChan {
if closed(requestChan) {
return
}
switch r := v.(type) {
case getConnectionState:
r.reply <- c
case waitConnectionState:
r.reply <- c
}
}
}
func (p *recordProcessor) loop(appDataChan chan<- []byte, requestChan <-chan interface{}, controlChan <-chan interface{}, recordChan <-chan *record, handshakeChan chan<- interface{}) {
noop := nop{}
p.decrypt = noop
p.mac = noop
p.waitQueue = list.New()
p.appDataChan = appDataChan
p.requestChan = requestChan
p.controlChan = controlChan
p.recordChan = recordChan
p.handshakeChan = handshakeChan
p.recordRead = recordChan
for !p.shutdown {
select {
case p.appDataSend <- p.appData:
p.appData = nil
p.appDataSend = nil
p.recordRead = p.recordChan
case c := <-controlChan:
p.processControlMsg(c)
case r := <-requestChan:
p.processRequestMsg(r)
case r := <-p.recordRead:
p.processRecord(r)
}
}
p.wakeWaiters()
go drainRequestChannel(p.requestChan, p.connState)
go func() {
for _ = range controlChan {
}
}()
close(handshakeChan)
if len(p.appData) > 0 {
appDataChan <- p.appData
}
close(appDataChan)
}
func (p *recordProcessor) processRequestMsg(requestMsg interface{}) {
if closed(p.requestChan) {
p.shutdown = true
return
}
switch r := requestMsg.(type) {
case getConnectionState:
r.reply <- p.connState
case waitConnectionState:
if p.connState.HandshakeComplete {
r.reply <- p.connState
}
p.waitQueue.PushBack(r.reply)
}
}
func (p *recordProcessor) processControlMsg(msg interface{}) {
connState, ok := msg.(ConnectionState)
if !ok || closed(p.controlChan) {
p.shutdown = true
return
}
p.connState = connState
p.wakeWaiters()
}
func (p *recordProcessor) wakeWaiters() {
for i := p.waitQueue.Front(); i != nil; i = i.Next() {
i.Value.(chan<- ConnectionState) <- p.connState
}
p.waitQueue.Init()
}
func (p *recordProcessor) processRecord(r *record) {
if closed(p.recordChan) {
p.shutdown = true
return
}
p.decrypt.XORKeyStream(r.payload)
if len(r.payload) < p.mac.Size() {
p.error(alertBadRecordMAC)
return
}
fillMACHeader(&p.header, p.seqNum, len(r.payload)-p.mac.Size(), r)
p.seqNum++
p.mac.Reset()
p.mac.Write(p.header[0:13])
p.mac.Write(r.payload[0 : len(r.payload)-p.mac.Size()])
macBytes := p.mac.Sum()
if subtle.ConstantTimeCompare(macBytes, r.payload[len(r.payload)-p.mac.Size():]) != 1 {
p.error(alertBadRecordMAC)
return
}
switch r.contentType {
case recordTypeHandshake:
p.processHandshakeRecord(r.payload[0 : len(r.payload)-p.mac.Size()])
case recordTypeChangeCipherSpec:
if len(r.payload) != 1 || r.payload[0] != 1 {
p.error(alertUnexpectedMessage)
return
}
p.handshakeChan <- changeCipherSpec{}
newSpec, ok := (<-p.controlChan).(*newCipherSpec)
if !ok {
p.connState.Error = alertUnexpectedMessage
p.shutdown = true
return
}
p.decrypt = newSpec.encrypt
p.mac = newSpec.mac
p.seqNum = 0
case recordTypeApplicationData:
if p.connState.HandshakeComplete == false {
p.error(alertUnexpectedMessage)
return
}
p.recordRead = nil
p.appData = r.payload[0 : len(r.payload)-p.mac.Size()]
p.appDataSend = p.appDataChan
default:
p.error(alertUnexpectedMessage)
return
}
}
func (p *recordProcessor) processHandshakeRecord(data []byte) {
if p.handshakeBuf == nil {
p.handshakeBuf = data
} else {
if len(p.handshakeBuf) > maxHandshakeMsg {
p.error(alertInternalError)
return
}
newBuf := make([]byte, len(p.handshakeBuf)+len(data))
copy(newBuf, p.handshakeBuf)
copy(newBuf[len(p.handshakeBuf):], data)
p.handshakeBuf = newBuf
}
for len(p.handshakeBuf) >= 4 {
handshakeLen := int(p.handshakeBuf[1])<<16 |
int(p.handshakeBuf[2])<<8 |
int(p.handshakeBuf[3])
if handshakeLen+4 > len(p.handshakeBuf) {
break
}
bytes := p.handshakeBuf[0 : handshakeLen+4]
p.handshakeBuf = p.handshakeBuf[handshakeLen+4:]
if bytes[0] == typeFinished {
// Special case because Finished is synchronous: the
// handshake handler has to tell us if it's ok to start
// forwarding application data.
m := new(finishedMsg)
if !m.unmarshal(bytes) {
p.error(alertUnexpectedMessage)
}
p.handshakeChan <- m
var ok bool
p.connState, ok = (<-p.controlChan).(ConnectionState)
if !ok || p.connState.Error != 0 {
p.shutdown = true
return
}
} else {
msg, ok := parseHandshakeMsg(bytes)
if !ok {
p.error(alertUnexpectedMessage)
return
}
p.handshakeChan <- msg
}
}
}
func (p *recordProcessor) error(err alertType) {
close(p.handshakeChan)
p.connState.Error = err
p.wakeWaiters()
p.shutdown = true
}
func parseHandshakeMsg(data []byte) (interface{}, bool) {
var m interface {
unmarshal([]byte) bool
}
switch data[0] {
case typeClientHello:
m = new(clientHelloMsg)
case typeServerHello:
m = new(serverHelloMsg)
case typeCertificate:
m = new(certificateMsg)
case typeServerHelloDone:
m = new(serverHelloDoneMsg)
case typeClientKeyExchange:
m = new(clientKeyExchangeMsg)
case typeNextProtocol:
m = new(nextProtoMsg)
default:
return nil, false
}
ok := m.unmarshal(data)
return m, ok
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"encoding/hex"
"testing"
"testing/script"
)
func setup() (appDataChan chan []byte, requestChan chan interface{}, controlChan chan interface{}, recordChan chan *record, handshakeChan chan interface{}) {
rp := new(recordProcessor)
appDataChan = make(chan []byte)
requestChan = make(chan interface{})
controlChan = make(chan interface{})
recordChan = make(chan *record)
handshakeChan = make(chan interface{})
go rp.loop(appDataChan, requestChan, controlChan, recordChan, handshakeChan)
return
}
func fromHex(s string) []byte {
b, _ := hex.DecodeString(s)
return b
}
func TestNullConnectionState(t *testing.T) {
_, requestChan, controlChan, recordChan, _ := setup()
defer close(requestChan)
defer close(controlChan)
defer close(recordChan)
// Test a simple request for the connection state.
replyChan := make(chan ConnectionState)
sendReq := script.NewEvent("send request", nil, script.Send{requestChan, getConnectionState{replyChan}})
getReply := script.NewEvent("get reply", []*script.Event{sendReq}, script.Recv{replyChan, ConnectionState{false, "", 0, ""}})
err := script.Perform(0, []*script.Event{sendReq, getReply})
if err != nil {
t.Errorf("Got error: %s", err)
}
}
func TestWaitConnectionState(t *testing.T) {
_, requestChan, controlChan, recordChan, _ := setup()
defer close(requestChan)
defer close(controlChan)
defer close(recordChan)
// Test that waitConnectionState doesn't get a reply until the connection state changes.
replyChan := make(chan ConnectionState)
sendReq := script.NewEvent("send request", nil, script.Send{requestChan, waitConnectionState{replyChan}})
replyChan2 := make(chan ConnectionState)
sendReq2 := script.NewEvent("send request 2", []*script.Event{sendReq}, script.Send{requestChan, getConnectionState{replyChan2}})
getReply2 := script.NewEvent("get reply 2", []*script.Event{sendReq2}, script.Recv{replyChan2, ConnectionState{false, "", 0, ""}})
sendState := script.NewEvent("send state", []*script.Event{getReply2}, script.Send{controlChan, ConnectionState{true, "test", 1, ""}})
getReply := script.NewEvent("get reply", []*script.Event{sendState}, script.Recv{replyChan, ConnectionState{true, "test", 1, ""}})
err := script.Perform(0, []*script.Event{sendReq, sendReq2, getReply2, sendState, getReply})
if err != nil {
t.Errorf("Got error: %s", err)
}
}
func TestHandshakeAssembly(t *testing.T) {
_, requestChan, controlChan, recordChan, handshakeChan := setup()
defer close(requestChan)
defer close(controlChan)
defer close(recordChan)
// Test the reassembly of a fragmented handshake message.
send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("10000003")}})
send2 := script.NewEvent("send 2", []*script.Event{send1}, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("0001")}})
send3 := script.NewEvent("send 3", []*script.Event{send2}, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("42")}})
recvMsg := script.NewEvent("recv", []*script.Event{send3}, script.Recv{handshakeChan, &clientKeyExchangeMsg{fromHex("10000003000142"), fromHex("42")}})
err := script.Perform(0, []*script.Event{send1, send2, send3, recvMsg})
if err != nil {
t.Errorf("Got error: %s", err)
}
}
func TestEarlyApplicationData(t *testing.T) {
_, requestChan, controlChan, recordChan, handshakeChan := setup()
defer close(requestChan)
defer close(controlChan)
defer close(recordChan)
// Test that applicaton data received before the handshake has completed results in an error.
send := script.NewEvent("send", nil, script.Send{recordChan, &record{recordTypeApplicationData, 0, 0, fromHex("")}})
recv := script.NewEvent("recv", []*script.Event{send}, script.Closed{handshakeChan})
err := script.Perform(0, []*script.Event{send, recv})
if err != nil {
t.Errorf("Got error: %s", err)
}
}
func TestApplicationData(t *testing.T) {
appDataChan, requestChan, controlChan, recordChan, handshakeChan := setup()
defer close(requestChan)
defer close(controlChan)
defer close(recordChan)
// Test that the application data is forwarded after a successful Finished message.
send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("1400000c000000000000000000000000")}})
recv1 := script.NewEvent("recv finished", []*script.Event{send1}, script.Recv{handshakeChan, &finishedMsg{fromHex("1400000c000000000000000000000000"), fromHex("000000000000000000000000")}})
send2 := script.NewEvent("send connState", []*script.Event{recv1}, script.Send{controlChan, ConnectionState{true, "", 0, ""}})
send3 := script.NewEvent("send 2", []*script.Event{send2}, script.Send{recordChan, &record{recordTypeApplicationData, 0, 0, fromHex("0102")}})
recv2 := script.NewEvent("recv data", []*script.Event{send3}, script.Recv{appDataChan, []byte{0x01, 0x02}})
err := script.Perform(0, []*script.Event{send1, recv1, send2, send3, recv2})
if err != nil {
t.Errorf("Got error: %s", err)
}
}
func TestInvalidChangeCipherSpec(t *testing.T) {
appDataChan, requestChan, controlChan, recordChan, handshakeChan := setup()
defer close(requestChan)
defer close(controlChan)
defer close(recordChan)
send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeChangeCipherSpec, 0, 0, []byte{1}}})
recv1 := script.NewEvent("recv 1", []*script.Event{send1}, script.Recv{handshakeChan, changeCipherSpec{}})
send2 := script.NewEvent("send 2", []*script.Event{recv1}, script.Send{controlChan, ConnectionState{false, "", 42, ""}})
close := script.NewEvent("close 1", []*script.Event{send2}, script.Closed{appDataChan})
close2 := script.NewEvent("close 2", []*script.Event{send2}, script.Closed{handshakeChan})
err := script.Perform(0, []*script.Event{send1, recv1, send2, close, close2})
if err != nil {
t.Errorf("Got error: %s", err)
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
// The record reader handles reading from the connection and reassembling TLS
// record structures. It loops forever doing this and writes the TLS records to
// it's outbound channel. On error, it closes its outbound channel.
import (
"io"
"bufio"
)
// recordReader loops, reading TLS records from source and writing them to the
// given channel. The channel is closed on EOF or on error.
func recordReader(c chan<- *record, source io.Reader) {
defer close(c)
buf := bufio.NewReader(source)
for {
var header [5]byte
n, _ := buf.Read(&header)
if n != 5 {
return
}
recordLength := int(header[3])<<8 | int(header[4])
if recordLength > maxTLSCiphertext {
return
}
payload := make([]byte, recordLength)
n, _ = buf.Read(payload)
if n != recordLength {
return
}
c <- &record{recordType(header[0]), header[1], header[2], payload}
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"testing"
"testing/iotest"
)
func matchRecord(r1, r2 *record) bool {
if (r1 == nil) != (r2 == nil) {
return false
}
if r1 == nil {
return true
}
return r1.contentType == r2.contentType &&
r1.major == r2.major &&
r1.minor == r2.minor &&
bytes.Compare(r1.payload, r2.payload) == 0
}
type recordReaderTest struct {
in []byte
out []*record
}
var recordReaderTests = []recordReaderTest{
recordReaderTest{nil, nil},
recordReaderTest{fromHex("01"), nil},
recordReaderTest{fromHex("0102"), nil},
recordReaderTest{fromHex("010203"), nil},
recordReaderTest{fromHex("01020300"), nil},
recordReaderTest{fromHex("0102030000"), []*record{&record{1, 2, 3, nil}}},
recordReaderTest{fromHex("01020300000102030000"), []*record{&record{1, 2, 3, nil}, &record{1, 2, 3, nil}}},
recordReaderTest{fromHex("0102030001fe0102030002feff"), []*record{&record{1, 2, 3, []byte{0xfe}}, &record{1, 2, 3, []byte{0xfe, 0xff}}}},
recordReaderTest{fromHex("010203000001020300"), []*record{&record{1, 2, 3, nil}}},
}
func TestRecordReader(t *testing.T) {
for i, test := range recordReaderTests {
buf := bytes.NewBuffer(test.in)
c := make(chan *record)
go recordReader(c, buf)
matchRecordReaderOutput(t, i, test, c)
buf = bytes.NewBuffer(test.in)
buf2 := iotest.OneByteReader(buf)
c = make(chan *record)
go recordReader(c, buf2)
matchRecordReaderOutput(t, i*2, test, c)
}
}
func matchRecordReaderOutput(t *testing.T, i int, test recordReaderTest, c <-chan *record) {
for j, r1 := range test.out {
r2 := <-c
if r2 == nil {
t.Errorf("#%d truncated after %d values", i, j)
break
}
if !matchRecord(r1, r2) {
t.Errorf("#%d (%d) got:%#v want:%#v", i, j, r2, r1)
}
}
<-c
if !closed(c) {
t.Errorf("#%d: channel didn't close", i)
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"fmt"
"hash"
"io"
)
// writerEnableApplicationData is a message which instructs recordWriter to
// start reading and transmitting data from the application data channel.
type writerEnableApplicationData struct{}
// writerChangeCipherSpec updates the encryption and MAC functions and resets
// the sequence count.
type writerChangeCipherSpec struct {
encryptor encryptor
mac hash.Hash
}
// writerSetVersion sets the version number bytes that we included in the
// record header for future records.
type writerSetVersion struct {
major, minor uint8
}
// A recordWriter accepts messages from the handshake processor and
// application data. It writes them to the outgoing connection and blocks on
// writing. It doesn't read from the application data channel until the
// handshake processor has signaled that the handshake is complete.
type recordWriter struct {
writer io.Writer
encryptor encryptor
mac hash.Hash
seqNum uint64
major, minor uint8
shutdown bool
appChan <-chan []byte
controlChan <-chan interface{}
header [13]byte
}
func (w *recordWriter) loop(writer io.Writer, appChan <-chan []byte, controlChan <-chan interface{}) {
w.writer = writer
w.encryptor = nop{}
w.mac = nop{}
w.appChan = appChan
w.controlChan = controlChan
for !w.shutdown {
msg := <-controlChan
if _, ok := msg.(writerEnableApplicationData); ok {
break
}
w.processControlMessage(msg)
}
for !w.shutdown {
// Always process control messages first.
if controlMsg, ok := <-controlChan; ok {
w.processControlMessage(controlMsg)
continue
}
select {
case controlMsg := <-controlChan:
w.processControlMessage(controlMsg)
case appMsg := <-appChan:
w.processAppMessage(appMsg)
}
}
if !closed(appChan) {
go func() {
for _ = range appChan {
}
}()
}
if !closed(controlChan) {
go func() {
for _ = range controlChan {
}
}()
}
}
// fillMACHeader generates a MAC header. See RFC 4346, section 6.2.3.1.
func fillMACHeader(header *[13]byte, seqNum uint64, length int, r *record) {
header[0] = uint8(seqNum >> 56)
header[1] = uint8(seqNum >> 48)
header[2] = uint8(seqNum >> 40)
header[3] = uint8(seqNum >> 32)
header[4] = uint8(seqNum >> 24)
header[5] = uint8(seqNum >> 16)
header[6] = uint8(seqNum >> 8)
header[7] = uint8(seqNum)
header[8] = uint8(r.contentType)
header[9] = r.major
header[10] = r.minor
header[11] = uint8(length >> 8)
header[12] = uint8(length)
}
func (w *recordWriter) writeRecord(r *record) {
w.mac.Reset()
fillMACHeader(&w.header, w.seqNum, len(r.payload), r)
w.mac.Write(w.header[0:13])
w.mac.Write(r.payload)
macBytes := w.mac.Sum()
w.encryptor.XORKeyStream(r.payload)
w.encryptor.XORKeyStream(macBytes)
length := len(r.payload) + len(macBytes)
w.header[11] = uint8(length >> 8)
w.header[12] = uint8(length)
w.writer.Write(w.header[8:13])
w.writer.Write(r.payload)
w.writer.Write(macBytes)
w.seqNum++
}
func (w *recordWriter) processControlMessage(controlMsg interface{}) {
if controlMsg == nil {
w.shutdown = true
return
}
switch msg := controlMsg.(type) {
case writerChangeCipherSpec:
w.writeRecord(&record{recordTypeChangeCipherSpec, w.major, w.minor, []byte{0x01}})
w.encryptor = msg.encryptor
w.mac = msg.mac
w.seqNum = 0
case writerSetVersion:
w.major = msg.major
w.minor = msg.minor
case alert:
w.writeRecord(&record{recordTypeAlert, w.major, w.minor, []byte{byte(msg.level), byte(msg.error)}})
case handshakeMessage:
// TODO(agl): marshal may return a slice too large for a single record.
w.writeRecord(&record{recordTypeHandshake, w.major, w.minor, msg.marshal()})
default:
fmt.Printf("processControlMessage: unknown %#v\n", msg)
}
}
func (w *recordWriter) processAppMessage(appMsg []byte) {
if closed(w.appChan) {
w.writeRecord(&record{recordTypeApplicationData, w.major, w.minor, []byte{byte(alertCloseNotify)}})
w.shutdown = true
return
}
var done int
for done < len(appMsg) {
todo := len(appMsg)
if todo > maxTLSPlaintext {
todo = maxTLSPlaintext
}
w.writeRecord(&record{recordTypeApplicationData, w.major, w.minor, appMsg[done : done+todo]})
done += todo
}
}
......@@ -6,158 +6,16 @@
package tls
import (
"io"
"os"
"net"
"time"
)
// A Conn represents a secure connection.
type Conn struct {
net.Conn
writeChan chan<- []byte
readChan <-chan []byte
requestChan chan<- interface{}
readBuf []byte
eof bool
readTimeout, writeTimeout int64
}
func timeout(c chan<- bool, nsecs int64) {
time.Sleep(nsecs)
c <- true
}
func (tls *Conn) Read(p []byte) (int, os.Error) {
if len(tls.readBuf) == 0 {
if tls.eof {
return 0, os.EOF
}
var timeoutChan chan bool
if tls.readTimeout > 0 {
timeoutChan = make(chan bool)
go timeout(timeoutChan, tls.readTimeout)
}
select {
case b := <-tls.readChan:
tls.readBuf = b
case <-timeoutChan:
return 0, os.EAGAIN
}
// TLS distinguishes between orderly closes and truncations. An
// orderly close is represented by a zero length slice.
if closed(tls.readChan) {
return 0, io.ErrUnexpectedEOF
}
if len(tls.readBuf) == 0 {
tls.eof = true
return 0, os.EOF
}
}
n := copy(p, tls.readBuf)
tls.readBuf = tls.readBuf[n:]
return n, nil
}
func (tls *Conn) Write(p []byte) (int, os.Error) {
if tls.eof || closed(tls.readChan) {
return 0, os.EOF
}
var timeoutChan chan bool
if tls.writeTimeout > 0 {
timeoutChan = make(chan bool)
go timeout(timeoutChan, tls.writeTimeout)
}
select {
case tls.writeChan <- p:
case <-timeoutChan:
return 0, os.EAGAIN
}
return len(p), nil
}
func (tls *Conn) Close() os.Error {
close(tls.writeChan)
close(tls.requestChan)
tls.eof = true
return nil
}
func (tls *Conn) SetTimeout(nsec int64) os.Error {
tls.readTimeout = nsec
tls.writeTimeout = nsec
return nil
}
func (tls *Conn) SetReadTimeout(nsec int64) os.Error {
tls.readTimeout = nsec
return nil
}
func (tls *Conn) SetWriteTimeout(nsec int64) os.Error {
tls.writeTimeout = nsec
return nil
}
func (tls *Conn) GetConnectionState() ConnectionState {
replyChan := make(chan ConnectionState)
tls.requestChan <- getConnectionState{replyChan}
return <-replyChan
}
func (tls *Conn) WaitConnectionState() ConnectionState {
replyChan := make(chan ConnectionState)
tls.requestChan <- waitConnectionState{replyChan}
return <-replyChan
}
type handshaker interface {
loop(writeChan chan<- interface{}, controlChan chan<- interface{}, msgChan <-chan interface{}, config *Config)
}
// Server establishes a secure connection over the given connection and acts
// as a TLS server.
func startTLSGoroutines(conn net.Conn, h handshaker, config *Config) *Conn {
if config == nil {
config = defaultConfig()
}
tls := new(Conn)
tls.Conn = conn
writeChan := make(chan []byte)
readChan := make(chan []byte)
requestChan := make(chan interface{})
tls.writeChan = writeChan
tls.readChan = readChan
tls.requestChan = requestChan
handshakeWriterChan := make(chan interface{})
processorHandshakeChan := make(chan interface{})
handshakeProcessorChan := make(chan interface{})
readerProcessorChan := make(chan *record)
go new(recordWriter).loop(conn, writeChan, handshakeWriterChan)
go recordReader(readerProcessorChan, conn)
go new(recordProcessor).loop(readChan, requestChan, handshakeProcessorChan, readerProcessorChan, processorHandshakeChan)
go h.loop(handshakeWriterChan, handshakeProcessorChan, processorHandshakeChan, config)
return tls
}
func Server(conn net.Conn, config *Config) *Conn {
return startTLSGoroutines(conn, new(serverHandshake), config)
return &Conn{conn: conn, config: config}
}
func Client(conn net.Conn, config *Config) *Conn {
return startTLSGoroutines(conn, new(clientHandshake), config)
return &Conn{conn: conn, config: config, isClient: true}
}
type Listener struct {
......@@ -180,22 +38,24 @@ func (l *Listener) Addr() net.Addr { return l.listener.Addr() }
// NewListener creates a Listener which accepts connections from an inner
// Listener and wraps each connection with Server.
// The configuration config must be non-nil and must have
// at least one certificate.
func NewListener(listener net.Listener, config *Config) (l *Listener) {
if config == nil {
config = defaultConfig()
}
l = new(Listener)
l.listener = listener
l.config = config
return
}
func Listen(network, laddr string) (net.Listener, os.Error) {
func Listen(network, laddr string, config *Config) (net.Listener, os.Error) {
if config == nil || len(config.Certificates) == 0 {
return nil, os.NewError("tls.Listen: no certificates in configuration")
}
l, err := net.Listen(network, laddr)
if err != nil {
return nil, err
}
return NewListener(l, nil), nil
return NewListener(l, config), nil
}
func Dial(network, laddr, raddr string) (net.Conn, os.Error) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment