Commit 7e767791 authored by Adam Langley's avatar Adam Langley

crypto/tls: implement TLS 1.2.

This does not include AES-GCM yet. Also, it assumes that the handshake and
certificate signature hash are always SHA-256, which is true of the ciphersuites
that we currently support.

R=golang-dev, rsc
CC=golang-dev
https://golang.org/cl/10762044
parent 1f954e5c
...@@ -42,7 +42,7 @@ type cipherSuite struct { ...@@ -42,7 +42,7 @@ type cipherSuite struct {
keyLen int keyLen int
macLen int macLen int
ivLen int ivLen int
ka func() keyAgreement ka func(version uint16) keyAgreement
// If elliptic is set, a server will only consider this ciphersuite if // If elliptic is set, a server will only consider this ciphersuite if
// the ClientHello indicated that the client supports an elliptic curve // the ClientHello indicated that the client supports an elliptic curve
// and point format that we can handle. // and point format that we can handle.
...@@ -157,12 +157,14 @@ func (s tls10MAC) MAC(digestBuf, seq, header, data []byte) []byte { ...@@ -157,12 +157,14 @@ func (s tls10MAC) MAC(digestBuf, seq, header, data []byte) []byte {
return s.h.Sum(digestBuf[:0]) return s.h.Sum(digestBuf[:0])
} }
func rsaKA() keyAgreement { func rsaKA(version uint16) keyAgreement {
return rsaKeyAgreement{} return rsaKeyAgreement{}
} }
func ecdheRSAKA() keyAgreement { func ecdheRSAKA(version uint16) keyAgreement {
return new(ecdheRSAKeyAgreement) return &ecdheRSAKeyAgreement{
version: version,
}
} }
// mutualCipherSuite returns a cipherSuite given a list of supported // mutualCipherSuite returns a cipherSuite given a list of supported
......
...@@ -18,6 +18,7 @@ const ( ...@@ -18,6 +18,7 @@ const (
VersionSSL30 = 0x0300 VersionSSL30 = 0x0300
VersionTLS10 = 0x0301 VersionTLS10 = 0x0301
VersionTLS11 = 0x0302 VersionTLS11 = 0x0302
VersionTLS12 = 0x0303
) )
const ( const (
...@@ -27,7 +28,7 @@ const ( ...@@ -27,7 +28,7 @@ const (
maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB) maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB)
minVersion = VersionSSL30 minVersion = VersionSSL30
maxVersion = VersionTLS11 maxVersion = VersionTLS12
) )
// TLS record types. // TLS record types.
...@@ -63,12 +64,13 @@ const ( ...@@ -63,12 +64,13 @@ const (
// TLS extension numbers // TLS extension numbers
var ( var (
extensionServerName uint16 = 0 extensionServerName uint16 = 0
extensionStatusRequest uint16 = 5 extensionStatusRequest uint16 = 5
extensionSupportedCurves uint16 = 10 extensionSupportedCurves uint16 = 10
extensionSupportedPoints uint16 = 11 extensionSupportedPoints uint16 = 11
extensionSessionTicket uint16 = 35 extensionSignatureAlgorithms uint16 = 13
extensionNextProtoNeg uint16 = 13172 // not IANA assigned extensionSessionTicket uint16 = 35
extensionNextProtoNeg uint16 = 13172 // not IANA assigned
) )
// TLS Elliptic Curves // TLS Elliptic Curves
...@@ -99,6 +101,31 @@ const ( ...@@ -99,6 +101,31 @@ const (
// Rest of these are reserved by the TLS spec // Rest of these are reserved by the TLS spec
) )
// Hash functions for TLS 1.2 (See RFC 5246, section A.4.1)
const (
hashSHA1 uint8 = 2
hashSHA256 uint8 = 4
)
// Signature algorithms for TLS 1.2 (See RFC 5246, section A.4.1)
const (
signatureRSA uint8 = 1
signatureECDSA uint8 = 3
)
// signatureAndHash mirrors the TLS 1.2, SignatureAndHashAlgorithm struct. See
// RFC 5246, section A.4.1.
type signatureAndHash struct {
hash, signature uint8
}
// supportedSignatureAlgorithms contains the signature and hash algorithms that
// the code will adverse as supported both in a TLS 1.2 ClientHello and
// CertificateRequest.
var supportedSignatureAlgorithms = []signatureAndHash{
{hashSHA256, signatureRSA},
}
// ConnectionState records basic TLS details about the connection. // ConnectionState records basic TLS details about the connection.
type ConnectionState struct { type ConnectionState struct {
HandshakeComplete bool HandshakeComplete bool
......
...@@ -748,7 +748,9 @@ func (c *Conn) readHandshake() (interface{}, error) { ...@@ -748,7 +748,9 @@ func (c *Conn) readHandshake() (interface{}, error) {
case typeCertificate: case typeCertificate:
m = new(certificateMsg) m = new(certificateMsg)
case typeCertificateRequest: case typeCertificateRequest:
m = new(certificateRequestMsg) m = &certificateRequestMsg{
hasSignatureAndHash: c.vers >= VersionTLS12,
}
case typeCertificateStatus: case typeCertificateStatus:
m = new(certificateStatusMsg) m = new(certificateStatusMsg)
case typeServerKeyExchange: case typeServerKeyExchange:
...@@ -758,7 +760,9 @@ func (c *Conn) readHandshake() (interface{}, error) { ...@@ -758,7 +760,9 @@ func (c *Conn) readHandshake() (interface{}, error) {
case typeClientKeyExchange: case typeClientKeyExchange:
m = new(clientKeyExchangeMsg) m = new(clientKeyExchangeMsg)
case typeCertificateVerify: case typeCertificateVerify:
m = new(certificateVerifyMsg) m = &certificateVerifyMsg{
hasSignatureAndHash: c.vers >= VersionTLS12,
}
case typeNextProtocol: case typeNextProtocol:
m = new(nextProtoMsg) m = new(nextProtoMsg)
case typeFinished: case typeFinished:
......
...@@ -6,7 +6,6 @@ package tls ...@@ -6,7 +6,6 @@ package tls
import ( import (
"bytes" "bytes"
"crypto"
"crypto/rsa" "crypto/rsa"
"crypto/subtle" "crypto/subtle"
"crypto/x509" "crypto/x509"
...@@ -16,8 +15,6 @@ import ( ...@@ -16,8 +15,6 @@ import (
) )
func (c *Conn) clientHandshake() error { func (c *Conn) clientHandshake() error {
finishedHash := newFinishedHash(VersionTLS10)
if c.config == nil { if c.config == nil {
c.config = defaultConfig() c.config = defaultConfig()
} }
...@@ -45,7 +42,10 @@ func (c *Conn) clientHandshake() error { ...@@ -45,7 +42,10 @@ func (c *Conn) clientHandshake() error {
return errors.New("short read from Rand") return errors.New("short read from Rand")
} }
finishedHash.Write(hello.marshal()) if hello.vers >= VersionTLS12 {
hello.signatureAndHashes = supportedSignatureAlgorithms
}
c.writeRecord(recordTypeHandshake, hello.marshal()) c.writeRecord(recordTypeHandshake, hello.marshal())
msg, err := c.readHandshake() msg, err := c.readHandshake()
...@@ -56,7 +56,6 @@ func (c *Conn) clientHandshake() error { ...@@ -56,7 +56,6 @@ func (c *Conn) clientHandshake() error {
if !ok { if !ok {
return c.sendAlert(alertUnexpectedMessage) return c.sendAlert(alertUnexpectedMessage)
} }
finishedHash.Write(serverHello.marshal())
vers, ok := c.config.mutualVersion(serverHello.vers) vers, ok := c.config.mutualVersion(serverHello.vers)
if !ok || vers < VersionTLS10 { if !ok || vers < VersionTLS10 {
...@@ -66,6 +65,10 @@ func (c *Conn) clientHandshake() error { ...@@ -66,6 +65,10 @@ func (c *Conn) clientHandshake() error {
c.vers = vers c.vers = vers
c.haveVers = true c.haveVers = true
finishedHash := newFinishedHash(c.vers)
finishedHash.Write(hello.marshal())
finishedHash.Write(serverHello.marshal())
if serverHello.compressionMethod != compressionNone { if serverHello.compressionMethod != compressionNone {
return c.sendAlert(alertUnexpectedMessage) return c.sendAlert(alertUnexpectedMessage)
} }
...@@ -148,7 +151,7 @@ func (c *Conn) clientHandshake() error { ...@@ -148,7 +151,7 @@ func (c *Conn) clientHandshake() error {
return err return err
} }
keyAgreement := suite.ka() keyAgreement := suite.ka(c.vers)
skx, ok := msg.(*serverKeyExchangeMsg) skx, ok := msg.(*serverKeyExchangeMsg)
if ok { if ok {
...@@ -269,10 +272,8 @@ func (c *Conn) clientHandshake() error { ...@@ -269,10 +272,8 @@ func (c *Conn) clientHandshake() error {
if chainToSend != nil { if chainToSend != nil {
certVerify := new(certificateVerifyMsg) certVerify := new(certificateVerifyMsg)
digest := make([]byte, 0, 36) digest, hashFunc := finishedHash.hashForClientCertificate()
digest = finishedHash.serverMD5.Sum(digest) signed, err := rsa.SignPKCS1v15(c.config.rand(), c.config.Certificates[0].PrivateKey.(*rsa.PrivateKey), hashFunc, digest)
digest = finishedHash.serverSHA1.Sum(digest)
signed, err := rsa.SignPKCS1v15(c.config.rand(), c.config.Certificates[0].PrivateKey.(*rsa.PrivateKey), crypto.MD5SHA1, digest)
if err != nil { if err != nil {
return c.sendAlert(alertInternalError) return c.sendAlert(alertInternalError)
} }
......
This diff is collapsed.
...@@ -20,6 +20,7 @@ type clientHelloMsg struct { ...@@ -20,6 +20,7 @@ type clientHelloMsg struct {
supportedPoints []uint8 supportedPoints []uint8
ticketSupported bool ticketSupported bool
sessionTicket []uint8 sessionTicket []uint8
signatureAndHashes []signatureAndHash
} }
func (m *clientHelloMsg) equal(i interface{}) bool { func (m *clientHelloMsg) equal(i interface{}) bool {
...@@ -40,7 +41,8 @@ func (m *clientHelloMsg) equal(i interface{}) bool { ...@@ -40,7 +41,8 @@ func (m *clientHelloMsg) equal(i interface{}) bool {
eqUint16s(m.supportedCurves, m1.supportedCurves) && eqUint16s(m.supportedCurves, m1.supportedCurves) &&
bytes.Equal(m.supportedPoints, m1.supportedPoints) && bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
m.ticketSupported == m1.ticketSupported && m.ticketSupported == m1.ticketSupported &&
bytes.Equal(m.sessionTicket, m1.sessionTicket) bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes)
} }
func (m *clientHelloMsg) marshal() []byte { func (m *clientHelloMsg) marshal() []byte {
...@@ -74,6 +76,10 @@ func (m *clientHelloMsg) marshal() []byte { ...@@ -74,6 +76,10 @@ func (m *clientHelloMsg) marshal() []byte {
extensionsLength += len(m.sessionTicket) extensionsLength += len(m.sessionTicket)
numExtensions++ numExtensions++
} }
if len(m.signatureAndHashes) > 0 {
extensionsLength += 2 + 2*len(m.signatureAndHashes)
numExtensions++
}
if numExtensions > 0 { if numExtensions > 0 {
extensionsLength += 4 * numExtensions extensionsLength += 4 * numExtensions
length += 2 + extensionsLength length += 2 + extensionsLength
...@@ -199,6 +205,25 @@ func (m *clientHelloMsg) marshal() []byte { ...@@ -199,6 +205,25 @@ func (m *clientHelloMsg) marshal() []byte {
copy(z, m.sessionTicket) copy(z, m.sessionTicket)
z = z[len(m.sessionTicket):] z = z[len(m.sessionTicket):]
} }
if len(m.signatureAndHashes) > 0 {
// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
z[0] = byte(extensionSignatureAlgorithms >> 8)
z[1] = byte(extensionSignatureAlgorithms)
l := 2 + 2*len(m.signatureAndHashes)
z[2] = byte(l >> 8)
z[3] = byte(l)
z = z[4:]
l -= 2
z[0] = byte(l >> 8)
z[1] = byte(l)
z = z[2:]
for _, sigAndHash := range m.signatureAndHashes {
z[0] = sigAndHash.hash
z[1] = sigAndHash.signature
z = z[2:]
}
}
m.raw = x m.raw = x
...@@ -249,6 +274,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { ...@@ -249,6 +274,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
m.ocspStapling = false m.ocspStapling = false
m.ticketSupported = false m.ticketSupported = false
m.sessionTicket = nil m.sessionTicket = nil
m.signatureAndHashes = nil
if len(data) == 0 { if len(data) == 0 {
// ClientHello is optionally followed by extension data // ClientHello is optionally followed by extension data
...@@ -336,6 +362,23 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { ...@@ -336,6 +362,23 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
// http://tools.ietf.org/html/rfc5077#section-3.2 // http://tools.ietf.org/html/rfc5077#section-3.2
m.ticketSupported = true m.ticketSupported = true
m.sessionTicket = data[:length] m.sessionTicket = data[:length]
case extensionSignatureAlgorithms:
// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
if length < 2 || length&1 != 0 {
return false
}
l := int(data[0])<<8 | int(data[1])
if l != length-2 {
return false
}
n := l / 2
d := data[2:]
m.signatureAndHashes = make([]signatureAndHash, n)
for i := range m.signatureAndHashes {
m.signatureAndHashes[i].hash = d[0]
m.signatureAndHashes[i].signature = d[1]
d = d[2:]
}
} }
data = data[length:] data = data[length:]
} }
...@@ -899,8 +942,14 @@ func (m *nextProtoMsg) unmarshal(data []byte) bool { ...@@ -899,8 +942,14 @@ func (m *nextProtoMsg) unmarshal(data []byte) bool {
} }
type certificateRequestMsg struct { type certificateRequestMsg struct {
raw []byte raw []byte
// hasSignatureAndHash indicates whether this message includes a list
// of signature and hash functions. This change was introduced with TLS
// 1.2.
hasSignatureAndHash bool
certificateTypes []byte certificateTypes []byte
signatureAndHashes []signatureAndHash
certificateAuthorities [][]byte certificateAuthorities [][]byte
} }
...@@ -912,7 +961,8 @@ func (m *certificateRequestMsg) equal(i interface{}) bool { ...@@ -912,7 +961,8 @@ func (m *certificateRequestMsg) equal(i interface{}) bool {
return bytes.Equal(m.raw, m1.raw) && return bytes.Equal(m.raw, m1.raw) &&
bytes.Equal(m.certificateTypes, m1.certificateTypes) && bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) &&
eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes)
} }
func (m *certificateRequestMsg) marshal() (x []byte) { func (m *certificateRequestMsg) marshal() (x []byte) {
...@@ -928,6 +978,10 @@ func (m *certificateRequestMsg) marshal() (x []byte) { ...@@ -928,6 +978,10 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
} }
length += casLength length += casLength
if m.hasSignatureAndHash {
length += 2 + 2*len(m.signatureAndHashes)
}
x = make([]byte, 4+length) x = make([]byte, 4+length)
x[0] = typeCertificateRequest x[0] = typeCertificateRequest
x[1] = uint8(length >> 16) x[1] = uint8(length >> 16)
...@@ -938,6 +992,19 @@ func (m *certificateRequestMsg) marshal() (x []byte) { ...@@ -938,6 +992,19 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
copy(x[5:], m.certificateTypes) copy(x[5:], m.certificateTypes)
y := x[5+len(m.certificateTypes):] y := x[5+len(m.certificateTypes):]
if m.hasSignatureAndHash {
n := len(m.signatureAndHashes) * 2
y[0] = uint8(n >> 8)
y[1] = uint8(n)
y = y[2:]
for _, sigAndHash := range m.signatureAndHashes {
y[0] = sigAndHash.hash
y[1] = sigAndHash.signature
y = y[2:]
}
}
y[0] = uint8(casLength >> 8) y[0] = uint8(casLength >> 8)
y[1] = uint8(casLength) y[1] = uint8(casLength)
y = y[2:] y = y[2:]
...@@ -978,6 +1045,27 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool { ...@@ -978,6 +1045,27 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool {
data = data[numCertTypes:] data = data[numCertTypes:]
if m.hasSignatureAndHash {
if len(data) < 2 {
return false
}
sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
data = data[2:]
if sigAndHashLen&1 != 0 {
return false
}
if len(data) < int(sigAndHashLen) {
return false
}
numSigAndHash := sigAndHashLen / 2
m.signatureAndHashes = make([]signatureAndHash, numSigAndHash)
for i := range m.signatureAndHashes {
m.signatureAndHashes[i].hash = data[0]
m.signatureAndHashes[i].signature = data[1]
data = data[2:]
}
}
if len(data) < 2 { if len(data) < 2 {
return false return false
} }
...@@ -1013,8 +1101,10 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool { ...@@ -1013,8 +1101,10 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool {
} }
type certificateVerifyMsg struct { type certificateVerifyMsg struct {
raw []byte raw []byte
signature []byte hasSignatureAndHash bool
signatureAndHash signatureAndHash
signature []byte
} }
func (m *certificateVerifyMsg) equal(i interface{}) bool { func (m *certificateVerifyMsg) equal(i interface{}) bool {
...@@ -1024,6 +1114,9 @@ func (m *certificateVerifyMsg) equal(i interface{}) bool { ...@@ -1024,6 +1114,9 @@ func (m *certificateVerifyMsg) equal(i interface{}) bool {
} }
return bytes.Equal(m.raw, m1.raw) && return bytes.Equal(m.raw, m1.raw) &&
m.hasSignatureAndHash == m1.hasSignatureAndHash &&
m.signatureAndHash.hash == m1.signatureAndHash.hash &&
m.signatureAndHash.signature == m1.signatureAndHash.signature &&
bytes.Equal(m.signature, m1.signature) bytes.Equal(m.signature, m1.signature)
} }
...@@ -1035,14 +1128,23 @@ func (m *certificateVerifyMsg) marshal() (x []byte) { ...@@ -1035,14 +1128,23 @@ func (m *certificateVerifyMsg) marshal() (x []byte) {
// See http://tools.ietf.org/html/rfc4346#section-7.4.8 // See http://tools.ietf.org/html/rfc4346#section-7.4.8
siglength := len(m.signature) siglength := len(m.signature)
length := 2 + siglength length := 2 + siglength
if m.hasSignatureAndHash {
length += 2
}
x = make([]byte, 4+length) x = make([]byte, 4+length)
x[0] = typeCertificateVerify x[0] = typeCertificateVerify
x[1] = uint8(length >> 16) x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8) x[2] = uint8(length >> 8)
x[3] = uint8(length) x[3] = uint8(length)
x[4] = uint8(siglength >> 8) y := x[4:]
x[5] = uint8(siglength) if m.hasSignatureAndHash {
copy(x[6:], m.signature) y[0] = m.signatureAndHash.hash
y[1] = m.signatureAndHash.signature
y = y[2:]
}
y[0] = uint8(siglength >> 8)
y[1] = uint8(siglength)
copy(y[2:], m.signature)
m.raw = x m.raw = x
...@@ -1061,12 +1163,23 @@ func (m *certificateVerifyMsg) unmarshal(data []byte) bool { ...@@ -1061,12 +1163,23 @@ func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
return false return false
} }
siglength := int(data[4])<<8 + int(data[5]) data = data[4:]
if len(data)-6 != siglength { if m.hasSignatureAndHash {
m.signatureAndHash.hash = data[0]
m.signatureAndHash.signature = data[1]
data = data[2:]
}
if len(data) < 2 {
return false
}
siglength := int(data[0])<<8 + int(data[1])
data = data[2:]
if len(data) != siglength {
return false return false
} }
m.signature = data[6:] m.signature = data
return true return true
} }
...@@ -1165,3 +1278,16 @@ func eqByteSlices(x, y [][]byte) bool { ...@@ -1165,3 +1278,16 @@ func eqByteSlices(x, y [][]byte) bool {
} }
return true return true
} }
func eqSignatureAndHashes(x, y []signatureAndHash) bool {
if len(x) != len(y) {
return false
}
for i, v := range x {
v2 := y[i]
if v.hash != v2.hash || v.signature != v2.signature {
return false
}
}
return true
}
...@@ -135,6 +135,9 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { ...@@ -135,6 +135,9 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m.sessionTicket = randomBytes(rand.Intn(300), rand) m.sessionTicket = randomBytes(rand.Intn(300), rand)
} }
} }
if rand.Intn(10) > 5 {
m.signatureAndHashes = supportedSignatureAlgorithms
}
return reflect.ValueOf(m) return reflect.ValueOf(m)
} }
......
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
package tls package tls
import ( import (
"crypto"
"crypto/rsa" "crypto/rsa"
"crypto/subtle" "crypto/subtle"
"crypto/x509" "crypto/x509"
...@@ -292,7 +291,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { ...@@ -292,7 +291,7 @@ func (hs *serverHandshakeState) doFullHandshake() error {
c.writeRecord(recordTypeHandshake, certStatus.marshal()) c.writeRecord(recordTypeHandshake, certStatus.marshal())
} }
keyAgreement := hs.suite.ka() keyAgreement := hs.suite.ka(c.vers)
skx, err := keyAgreement.generateServerKeyExchange(config, cert, hs.clientHello, hs.hello) skx, err := keyAgreement.generateServerKeyExchange(config, cert, hs.clientHello, hs.hello)
if err != nil { if err != nil {
c.sendAlert(alertHandshakeFailure) c.sendAlert(alertHandshakeFailure)
...@@ -307,6 +306,10 @@ func (hs *serverHandshakeState) doFullHandshake() error { ...@@ -307,6 +306,10 @@ func (hs *serverHandshakeState) doFullHandshake() error {
// Request a client certificate // Request a client certificate
certReq := new(certificateRequestMsg) certReq := new(certificateRequestMsg)
certReq.certificateTypes = []byte{certTypeRSASign} certReq.certificateTypes = []byte{certTypeRSASign}
if c.vers >= VersionTLS12 {
certReq.hasSignatureAndHash = true
certReq.signatureAndHashes = supportedSignatureAlgorithms
}
// An empty list of certificateAuthorities signals to // An empty list of certificateAuthorities signals to
// the client that it may send any certificate in response // the client that it may send any certificate in response
...@@ -383,10 +386,8 @@ func (hs *serverHandshakeState) doFullHandshake() error { ...@@ -383,10 +386,8 @@ func (hs *serverHandshakeState) doFullHandshake() error {
return c.sendAlert(alertUnexpectedMessage) return c.sendAlert(alertUnexpectedMessage)
} }
digest := make([]byte, 0, 36) digest, hashFunc := hs.finishedHash.hashForClientCertificate()
digest = hs.finishedHash.serverMD5.Sum(digest) err = rsa.VerifyPKCS1v15(pub, hashFunc, digest, certVerify.signature)
digest = hs.finishedHash.serverSHA1.Sum(digest)
err = rsa.VerifyPKCS1v15(pub, crypto.MD5SHA1, digest, certVerify.signature)
if err != nil { if err != nil {
c.sendAlert(alertBadCertificate) c.sendAlert(alertBadCertificate)
return errors.New("could not validate signature of connection nonces: " + err.Error()) return errors.New("could not validate signature of connection nonces: " + err.Error())
......
This diff is collapsed.
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"crypto/md5" "crypto/md5"
"crypto/rsa" "crypto/rsa"
"crypto/sha1" "crypto/sha1"
"crypto/sha256"
"crypto/x509" "crypto/x509"
"errors" "errors"
"io" "io"
...@@ -84,7 +85,7 @@ func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello ...@@ -84,7 +85,7 @@ func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello
// md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the // md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the
// concatenation of an MD5 and SHA1 hash. // concatenation of an MD5 and SHA1 hash.
func md5SHA1Hash(slices ...[]byte) []byte { func md5SHA1Hash(slices [][]byte) []byte {
md5sha1 := make([]byte, md5.Size+sha1.Size) md5sha1 := make([]byte, md5.Size+sha1.Size)
hmd5 := md5.New() hmd5 := md5.New()
for _, slice := range slices { for _, slice := range slices {
...@@ -100,10 +101,29 @@ func md5SHA1Hash(slices ...[]byte) []byte { ...@@ -100,10 +101,29 @@ func md5SHA1Hash(slices ...[]byte) []byte {
return md5sha1 return md5sha1
} }
// sha256Hash implements TLS 1.2's hash function.
func sha256Hash(slices [][]byte) []byte {
h := sha256.New()
for _, slice := range slices {
h.Write(slice)
}
return h.Sum(nil)
}
// hashForServerKeyExchange hashes the given slices and returns their digest
// and the identifier of the hash function used.
func hashForServerKeyExchange(version uint16, slices ...[]byte) ([]byte, crypto.Hash) {
if version >= VersionTLS12 {
return sha256Hash(slices), crypto.SHA256
}
return md5SHA1Hash(slices), crypto.MD5SHA1
}
// ecdheRSAKeyAgreement implements a TLS key agreement where the server // ecdheRSAKeyAgreement implements a TLS key agreement where the server
// generates a ephemeral EC public/private key pair and signs it. The // generates a ephemeral EC public/private key pair and signs it. The
// pre-master secret is then calculated using ECDH. // pre-master secret is then calculated using ECDH.
type ecdheRSAKeyAgreement struct { type ecdheRSAKeyAgreement struct {
version uint16
privateKey []byte privateKey []byte
curve elliptic.Curve curve elliptic.Curve
x, y *big.Int x, y *big.Int
...@@ -150,16 +170,25 @@ Curve: ...@@ -150,16 +170,25 @@ Curve:
serverECDHParams[3] = byte(len(ecdhePublic)) serverECDHParams[3] = byte(len(ecdhePublic))
copy(serverECDHParams[4:], ecdhePublic) copy(serverECDHParams[4:], ecdhePublic)
md5sha1 := md5SHA1Hash(clientHello.random, hello.random, serverECDHParams) digest, hashFunc := hashForServerKeyExchange(ka.version, clientHello.random, hello.random, serverECDHParams)
sig, err := rsa.SignPKCS1v15(config.rand(), cert.PrivateKey.(*rsa.PrivateKey), crypto.MD5SHA1, md5sha1) sig, err := rsa.SignPKCS1v15(config.rand(), cert.PrivateKey.(*rsa.PrivateKey), hashFunc, digest)
if err != nil { if err != nil {
return nil, errors.New("failed to sign ECDHE parameters: " + err.Error()) return nil, errors.New("failed to sign ECDHE parameters: " + err.Error())
} }
skx := new(serverKeyExchangeMsg) skx := new(serverKeyExchangeMsg)
skx.key = make([]byte, len(serverECDHParams)+2+len(sig)) sigAndHashLen := 0
if ka.version >= VersionTLS12 {
sigAndHashLen = 2
}
skx.key = make([]byte, len(serverECDHParams)+sigAndHashLen+2+len(sig))
copy(skx.key, serverECDHParams) copy(skx.key, serverECDHParams)
k := skx.key[len(serverECDHParams):] k := skx.key[len(serverECDHParams):]
if ka.version >= VersionTLS12 {
k[0] = hashSHA256
k[1] = signatureRSA
k = k[2:]
}
k[0] = byte(len(sig) >> 8) k[0] = byte(len(sig) >> 8)
k[1] = byte(len(sig)) k[1] = byte(len(sig))
copy(k[2:], sig) copy(k[2:], sig)
...@@ -219,14 +248,21 @@ func (ka *ecdheRSAKeyAgreement) processServerKeyExchange(config *Config, clientH ...@@ -219,14 +248,21 @@ func (ka *ecdheRSAKeyAgreement) processServerKeyExchange(config *Config, clientH
if len(sig) < 2 { if len(sig) < 2 {
return errServerKeyExchange return errServerKeyExchange
} }
if ka.version >= VersionTLS12 {
// ignore SignatureAndHashAlgorithm
sig = sig[2:]
if len(sig) < 2 {
return errServerKeyExchange
}
}
sigLen := int(sig[0])<<8 | int(sig[1]) sigLen := int(sig[0])<<8 | int(sig[1])
if sigLen+2 != len(sig) { if sigLen+2 != len(sig) {
return errServerKeyExchange return errServerKeyExchange
} }
sig = sig[2:] sig = sig[2:]
md5sha1 := md5SHA1Hash(clientHello.random, serverHello.random, serverECDHParams) digest, hashFunc := hashForServerKeyExchange(ka.version, clientHello.random, serverHello.random, serverECDHParams)
return rsa.VerifyPKCS1v15(cert.PublicKey.(*rsa.PublicKey), crypto.MD5SHA1, md5sha1, sig) return rsa.VerifyPKCS1v15(cert.PublicKey.(*rsa.PublicKey), hashFunc, digest, sig)
} }
func (ka *ecdheRSAKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { func (ka *ecdheRSAKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
......
...@@ -5,9 +5,11 @@ ...@@ -5,9 +5,11 @@
package tls package tls
import ( import (
"crypto"
"crypto/hmac" "crypto/hmac"
"crypto/md5" "crypto/md5"
"crypto/sha1" "crypto/sha1"
"crypto/sha256"
"hash" "hash"
) )
...@@ -43,8 +45,8 @@ func pHash(result, secret, seed []byte, hash func() hash.Hash) { ...@@ -43,8 +45,8 @@ func pHash(result, secret, seed []byte, hash func() hash.Hash) {
} }
} }
// pRF10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, section 5. // prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, section 5.
func pRF10(result, secret, label, seed []byte) { func prf10(result, secret, label, seed []byte) {
hashSHA1 := sha1.New hashSHA1 := sha1.New
hashMD5 := md5.New hashMD5 := md5.New
...@@ -62,9 +64,18 @@ func pRF10(result, secret, label, seed []byte) { ...@@ -62,9 +64,18 @@ func pRF10(result, secret, label, seed []byte) {
} }
} }
// pRF30 implements the SSL 3.0 pseudo-random function, as defined in // prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, section 5.
func prf12(result, secret, label, seed []byte) {
labelAndSeed := make([]byte, len(label)+len(seed))
copy(labelAndSeed, label)
copy(labelAndSeed[len(label):], seed)
pHash(result, secret, labelAndSeed, sha256.New)
}
// prf30 implements the SSL 3.0 pseudo-random function, as defined in
// www.mozilla.org/projects/security/pki/nss/ssl/draft302.txt section 6. // www.mozilla.org/projects/security/pki/nss/ssl/draft302.txt section 6.
func pRF30(result, secret, label, seed []byte) { func prf30(result, secret, label, seed []byte) {
hashSHA1 := sha1.New() hashSHA1 := sha1.New()
hashMD5 := md5.New() hashMD5 := md5.New()
...@@ -106,19 +117,27 @@ var keyExpansionLabel = []byte("key expansion") ...@@ -106,19 +117,27 @@ var keyExpansionLabel = []byte("key expansion")
var clientFinishedLabel = []byte("client finished") var clientFinishedLabel = []byte("client finished")
var serverFinishedLabel = []byte("server finished") var serverFinishedLabel = []byte("server finished")
func prfForVersion(version uint16) func(result, secret, label, seed []byte) {
switch version {
case VersionSSL30:
return prf30
case VersionTLS10, VersionTLS11:
return prf10
case VersionTLS12:
return prf12
default:
panic("unknown version")
}
}
// masterFromPreMasterSecret generates the master secret from the pre-master // masterFromPreMasterSecret generates the master secret from the pre-master
// secret. See http://tools.ietf.org/html/rfc5246#section-8.1 // secret. See http://tools.ietf.org/html/rfc5246#section-8.1
func masterFromPreMasterSecret(version uint16, preMasterSecret, clientRandom, serverRandom []byte) []byte { func masterFromPreMasterSecret(version uint16, preMasterSecret, clientRandom, serverRandom []byte) []byte {
prf := pRF10
if version == VersionSSL30 {
prf = pRF30
}
var seed [tlsRandomLength * 2]byte var seed [tlsRandomLength * 2]byte
copy(seed[0:len(clientRandom)], clientRandom) copy(seed[0:len(clientRandom)], clientRandom)
copy(seed[len(clientRandom):], serverRandom) copy(seed[len(clientRandom):], serverRandom)
masterSecret := make([]byte, masterSecretLength) masterSecret := make([]byte, masterSecretLength)
prf(masterSecret, preMasterSecret, masterSecretLabel, seed[0:]) prfForVersion(version)(masterSecret, preMasterSecret, masterSecretLabel, seed[0:])
return masterSecret return masterSecret
} }
...@@ -126,18 +145,13 @@ func masterFromPreMasterSecret(version uint16, preMasterSecret, clientRandom, se ...@@ -126,18 +145,13 @@ func masterFromPreMasterSecret(version uint16, preMasterSecret, clientRandom, se
// secret, given the lengths of the MAC key, cipher key and IV, as defined in // secret, given the lengths of the MAC key, cipher key and IV, as defined in
// RFC 2246, section 6.3. // RFC 2246, section 6.3.
func keysFromMasterSecret(version uint16, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) { func keysFromMasterSecret(version uint16, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) {
prf := pRF10
if version == VersionSSL30 {
prf = pRF30
}
var seed [tlsRandomLength * 2]byte var seed [tlsRandomLength * 2]byte
copy(seed[0:len(clientRandom)], serverRandom) copy(seed[0:len(clientRandom)], serverRandom)
copy(seed[len(serverRandom):], clientRandom) copy(seed[len(serverRandom):], clientRandom)
n := 2*macLen + 2*keyLen + 2*ivLen n := 2*macLen + 2*keyLen + 2*ivLen
keyMaterial := make([]byte, n) keyMaterial := make([]byte, n)
prf(keyMaterial, masterSecret, keyExpansionLabel, seed[0:]) prfForVersion(version)(keyMaterial, masterSecret, keyExpansionLabel, seed[0:])
clientMAC = keyMaterial[:macLen] clientMAC = keyMaterial[:macLen]
keyMaterial = keyMaterial[macLen:] keyMaterial = keyMaterial[macLen:]
serverMAC = keyMaterial[:macLen] serverMAC = keyMaterial[:macLen]
...@@ -153,37 +167,34 @@ func keysFromMasterSecret(version uint16, masterSecret, clientRandom, serverRand ...@@ -153,37 +167,34 @@ func keysFromMasterSecret(version uint16, masterSecret, clientRandom, serverRand
} }
func newFinishedHash(version uint16) finishedHash { func newFinishedHash(version uint16) finishedHash {
return finishedHash{md5.New(), sha1.New(), md5.New(), sha1.New(), version} if version >= VersionTLS12 {
return finishedHash{sha256.New(), sha256.New(), nil, nil, version}
}
return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), version}
} }
// A finishedHash calculates the hash of a set of handshake messages suitable // A finishedHash calculates the hash of a set of handshake messages suitable
// for including in a Finished message. // for including in a Finished message.
type finishedHash struct { type finishedHash struct {
clientMD5 hash.Hash client hash.Hash
clientSHA1 hash.Hash server hash.Hash
serverMD5 hash.Hash
serverSHA1 hash.Hash // Prior to TLS 1.2, an additional MD5 hash is required.
version uint16 clientMD5 hash.Hash
serverMD5 hash.Hash
version uint16
} }
func (h finishedHash) Write(msg []byte) (n int, err error) { func (h finishedHash) Write(msg []byte) (n int, err error) {
h.clientMD5.Write(msg) h.client.Write(msg)
h.clientSHA1.Write(msg) h.server.Write(msg)
h.serverMD5.Write(msg)
h.serverSHA1.Write(msg)
return len(msg), nil
}
// finishedSum10 calculates the contents of the verify_data member of a TLSv1 if h.version < VersionTLS12 {
// Finished message given the MD5 and SHA1 hashes of a set of handshake h.clientMD5.Write(msg)
// messages. h.serverMD5.Write(msg)
func finishedSum10(md5, sha1, label, masterSecret []byte) []byte { }
seed := make([]byte, len(md5)+len(sha1)) return len(msg), nil
copy(seed, md5)
copy(seed[len(md5):], sha1)
out := make([]byte, finishedVerifyLength)
pRF10(out, masterSecret, label, seed)
return out
} }
// finishedSum30 calculates the contents of the verify_data member of a SSLv3 // finishedSum30 calculates the contents of the verify_data member of a SSLv3
...@@ -225,22 +236,52 @@ var ssl3ServerFinishedMagic = [4]byte{0x53, 0x52, 0x56, 0x52} ...@@ -225,22 +236,52 @@ var ssl3ServerFinishedMagic = [4]byte{0x53, 0x52, 0x56, 0x52}
// Finished message. // Finished message.
func (h finishedHash) clientSum(masterSecret []byte) []byte { func (h finishedHash) clientSum(masterSecret []byte) []byte {
if h.version == VersionSSL30 { if h.version == VersionSSL30 {
return finishedSum30(h.clientMD5, h.clientSHA1, masterSecret, ssl3ClientFinishedMagic) return finishedSum30(h.clientMD5, h.client, masterSecret, ssl3ClientFinishedMagic)
} }
md5 := h.clientMD5.Sum(nil) out := make([]byte, finishedVerifyLength)
sha1 := h.clientSHA1.Sum(nil) if h.version >= VersionTLS12 {
return finishedSum10(md5, sha1, clientFinishedLabel, masterSecret) seed := h.client.Sum(nil)
prf12(out, masterSecret, clientFinishedLabel, seed)
} else {
seed := make([]byte, 0, md5.Size+sha1.Size)
seed = h.clientMD5.Sum(seed)
seed = h.client.Sum(seed)
prf10(out, masterSecret, clientFinishedLabel, seed)
}
return out
} }
// serverSum returns the contents of the verify_data member of a server's // serverSum returns the contents of the verify_data member of a server's
// Finished message. // Finished message.
func (h finishedHash) serverSum(masterSecret []byte) []byte { func (h finishedHash) serverSum(masterSecret []byte) []byte {
if h.version == VersionSSL30 { if h.version == VersionSSL30 {
return finishedSum30(h.serverMD5, h.serverSHA1, masterSecret, ssl3ServerFinishedMagic) return finishedSum30(h.serverMD5, h.server, masterSecret, ssl3ServerFinishedMagic)
}
out := make([]byte, finishedVerifyLength)
if h.version >= VersionTLS12 {
seed := h.server.Sum(nil)
prf12(out, masterSecret, serverFinishedLabel, seed)
} else {
seed := make([]byte, 0, md5.Size+sha1.Size)
seed = h.serverMD5.Sum(seed)
seed = h.server.Sum(seed)
prf10(out, masterSecret, serverFinishedLabel, seed)
}
return out
}
// hashForClientCertificate returns a digest and hash function identifier
// suitable for signing by a TLS client certificate.
func (h finishedHash) hashForClientCertificate() ([]byte, crypto.Hash) {
if h.version >= VersionTLS12 {
digest := h.server.Sum(nil)
return digest, crypto.SHA256
} }
md5 := h.serverMD5.Sum(nil) digest := make([]byte, 0, 36)
sha1 := h.serverSHA1.Sum(nil) digest = h.serverMD5.Sum(digest)
return finishedSum10(md5, sha1, serverFinishedLabel, masterSecret) digest = h.server.Sum(digest)
return digest, crypto.MD5SHA1
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment