Commit d1bbdbe7 authored by Peter Wu's avatar Peter Wu Committed by Adam Langley

crypto/tls: replace signatureAndHash by SignatureScheme.

Consolidate the signature and hash fields (SignatureAndHashAlgorithm in
TLS 1.2) into a single uint16 (SignatureScheme in TLS 1.3 draft 21).
This makes it easier to add RSASSA-PSS for TLS 1.2 in the future.

Fields were named like "signatureAlgorithm" rather than
"signatureScheme" since that name is also used throughout the 1.3 draft.

The only new public symbol is ECDSAWithSHA1, other than that this is an
internal change with no new functionality.

Change-Id: Iba63d262ab1af895420583ac9e302d9705a7e0f0
Reviewed-on: https://go-review.googlesource.com/62210Reviewed-by: 's avatarAdam Langley <agl@golang.org>
parent c996d07f
...@@ -126,35 +126,23 @@ const ( ...@@ -126,35 +126,23 @@ const (
// Rest of these are reserved by the TLS spec // Rest of these are reserved by the TLS spec
) )
// Hash functions for TLS 1.2 (See RFC 5246, section A.4.1)
const (
hashSHA1 uint8 = 2
hashSHA256 uint8 = 4
hashSHA384 uint8 = 5
)
// Signature algorithms for TLS 1.2 (See RFC 5246, section A.4.1) // Signature algorithms for TLS 1.2 (See RFC 5246, section A.4.1)
const ( const (
signatureRSA uint8 = 1 signatureRSA uint8 = 1
signatureECDSA uint8 = 3 signatureECDSA uint8 = 3
) )
// signatureAndHash mirrors the TLS 1.2, SignatureAndHashAlgorithm struct. See
// RFC 5246, section A.4.1.
type signatureAndHash struct {
hash, signature uint8
}
// supportedSignatureAlgorithms contains the signature and hash algorithms that // supportedSignatureAlgorithms contains the signature and hash algorithms that
// the code advertises as supported in a TLS 1.2 ClientHello and in a TLS 1.2 // the code advertises as supported in a TLS 1.2 ClientHello and in a TLS 1.2
// CertificateRequest. // CertificateRequest. The two fields are merged to match with TLS 1.3.
var supportedSignatureAlgorithms = []signatureAndHash{ // Note that in TLS 1.2, the ECDSA algorithms are not constrained to P-256, etc.
{hashSHA256, signatureRSA}, var supportedSignatureAlgorithms = []SignatureScheme{
{hashSHA256, signatureECDSA}, PKCS1WithSHA256,
{hashSHA384, signatureRSA}, ECDSAWithP256AndSHA256,
{hashSHA384, signatureECDSA}, PKCS1WithSHA384,
{hashSHA1, signatureRSA}, ECDSAWithP384AndSHA384,
{hashSHA1, signatureECDSA}, PKCS1WithSHA1,
ECDSAWithSHA1,
} }
// ConnectionState records basic TLS details about the connection. // ConnectionState records basic TLS details about the connection.
...@@ -234,6 +222,9 @@ const ( ...@@ -234,6 +222,9 @@ const (
ECDSAWithP256AndSHA256 SignatureScheme = 0x0403 ECDSAWithP256AndSHA256 SignatureScheme = 0x0403
ECDSAWithP384AndSHA384 SignatureScheme = 0x0503 ECDSAWithP384AndSHA384 SignatureScheme = 0x0503
ECDSAWithP521AndSHA512 SignatureScheme = 0x0603 ECDSAWithP521AndSHA512 SignatureScheme = 0x0603
// Legacy signature and hash algorithms for TLS 1.2.
ECDSAWithSHA1 SignatureScheme = 0x0203
) )
// ClientHelloInfo contains information from a ClientHello message in order to // ClientHelloInfo contains information from a ClientHello message in order to
...@@ -961,11 +952,24 @@ func unexpectedMessageError(wanted, got interface{}) error { ...@@ -961,11 +952,24 @@ func unexpectedMessageError(wanted, got interface{}) error {
return fmt.Errorf("tls: received unexpected handshake message of type %T when waiting for %T", got, wanted) return fmt.Errorf("tls: received unexpected handshake message of type %T when waiting for %T", got, wanted)
} }
func isSupportedSignatureAndHash(sigHash signatureAndHash, sigHashes []signatureAndHash) bool { func isSupportedSignatureAlgorithm(sigAlg SignatureScheme, supportedSignatureAlgorithms []SignatureScheme) bool {
for _, s := range sigHashes { for _, s := range supportedSignatureAlgorithms {
if s == sigHash { if s == sigAlg {
return true return true
} }
} }
return false return false
} }
// signatureFromSignatureScheme maps a signature algorithm to the underlying
// signature method (without hash function).
func signatureFromSignatureScheme(signatureAlgorithm SignatureScheme) uint8 {
switch signatureAlgorithm {
case PKCS1WithSHA1, PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512:
return signatureRSA
case ECDSAWithSHA1, ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512:
return signatureECDSA
default:
return 0
}
}
...@@ -85,7 +85,7 @@ NextCipherSuite: ...@@ -85,7 +85,7 @@ NextCipherSuite:
} }
if hello.vers >= VersionTLS12 { if hello.vers >= VersionTLS12 {
hello.signatureAndHashes = supportedSignatureAlgorithms hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms
} }
return hello, nil return hello, nil
...@@ -482,12 +482,15 @@ func (hs *clientHandshakeState) doFullHandshake() error { ...@@ -482,12 +482,15 @@ func (hs *clientHandshakeState) doFullHandshake() error {
return fmt.Errorf("tls: failed to sign handshake with client certificate: unknown client certificate key type: %T", key) return fmt.Errorf("tls: failed to sign handshake with client certificate: unknown client certificate key type: %T", key)
} }
certVerify.signatureAndHash, err = hs.finishedHash.selectClientCertSignatureAlgorithm(certReq.signatureAndHashes, signatureType) // SignatureAndHashAlgorithm was introduced in TLS 1.2.
if err != nil { if certVerify.hasSignatureAndHash {
c.sendAlert(alertInternalError) certVerify.signatureAlgorithm, err = hs.finishedHash.selectClientCertSignatureAlgorithm(certReq.supportedSignatureAlgorithms, signatureType)
return err if err != nil {
c.sendAlert(alertInternalError)
return err
}
} }
digest, hashFunc, err := hs.finishedHash.hashForClientCertificate(certVerify.signatureAndHash, hs.masterSecret) digest, hashFunc, err := hs.finishedHash.hashForClientCertificate(signatureType, certVerify.signatureAlgorithm, hs.masterSecret)
if err != nil { if err != nil {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
return err return err
...@@ -746,10 +749,7 @@ func (hs *clientHandshakeState) getCertificate(certReq *certificateRequestMsg) ( ...@@ -746,10 +749,7 @@ func (hs *clientHandshakeState) getCertificate(certReq *certificateRequestMsg) (
signatureSchemes = signatureSchemes[:len(signatureSchemes)-tls11SignatureSchemesNumRSA] signatureSchemes = signatureSchemes[:len(signatureSchemes)-tls11SignatureSchemesNumRSA]
} }
} else { } else {
signatureSchemes = make([]SignatureScheme, 0, len(certReq.signatureAndHashes)) signatureSchemes = certReq.supportedSignatureAlgorithms
for _, sah := range certReq.signatureAndHashes {
signatureSchemes = append(signatureSchemes, SignatureScheme(sah.hash)<<8+SignatureScheme(sah.signature))
}
} }
return c.config.GetClientCertificate(&CertificateRequestInfo{ return c.config.GetClientCertificate(&CertificateRequestInfo{
......
...@@ -24,7 +24,7 @@ type clientHelloMsg struct { ...@@ -24,7 +24,7 @@ type clientHelloMsg struct {
supportedPoints []uint8 supportedPoints []uint8
ticketSupported bool ticketSupported bool
sessionTicket []uint8 sessionTicket []uint8
signatureAndHashes []signatureAndHash supportedSignatureAlgorithms []SignatureScheme
secureRenegotiation []byte secureRenegotiation []byte
secureRenegotiationSupported bool secureRenegotiationSupported bool
alpnProtocols []string alpnProtocols []string
...@@ -50,7 +50,7 @@ func (m *clientHelloMsg) equal(i interface{}) bool { ...@@ -50,7 +50,7 @@ func (m *clientHelloMsg) equal(i interface{}) bool {
bytes.Equal(m.supportedPoints, m1.supportedPoints) && bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
m.ticketSupported == m1.ticketSupported && m.ticketSupported == m1.ticketSupported &&
bytes.Equal(m.sessionTicket, m1.sessionTicket) && bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) && eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms) &&
m.secureRenegotiationSupported == m1.secureRenegotiationSupported && m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) && bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
eqStrings(m.alpnProtocols, m1.alpnProtocols) eqStrings(m.alpnProtocols, m1.alpnProtocols)
...@@ -87,8 +87,8 @@ func (m *clientHelloMsg) marshal() []byte { ...@@ -87,8 +87,8 @@ func (m *clientHelloMsg) marshal() []byte {
extensionsLength += len(m.sessionTicket) extensionsLength += len(m.sessionTicket)
numExtensions++ numExtensions++
} }
if len(m.signatureAndHashes) > 0 { if len(m.supportedSignatureAlgorithms) > 0 {
extensionsLength += 2 + 2*len(m.signatureAndHashes) extensionsLength += 2 + 2*len(m.supportedSignatureAlgorithms)
numExtensions++ numExtensions++
} }
if m.secureRenegotiationSupported { if m.secureRenegotiationSupported {
...@@ -234,11 +234,11 @@ func (m *clientHelloMsg) marshal() []byte { ...@@ -234,11 +234,11 @@ func (m *clientHelloMsg) marshal() []byte {
copy(z, m.sessionTicket) copy(z, m.sessionTicket)
z = z[len(m.sessionTicket):] z = z[len(m.sessionTicket):]
} }
if len(m.signatureAndHashes) > 0 { if len(m.supportedSignatureAlgorithms) > 0 {
// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1 // https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
z[0] = byte(extensionSignatureAlgorithms >> 8) z[0] = byte(extensionSignatureAlgorithms >> 8)
z[1] = byte(extensionSignatureAlgorithms) z[1] = byte(extensionSignatureAlgorithms)
l := 2 + 2*len(m.signatureAndHashes) l := 2 + 2*len(m.supportedSignatureAlgorithms)
z[2] = byte(l >> 8) z[2] = byte(l >> 8)
z[3] = byte(l) z[3] = byte(l)
z = z[4:] z = z[4:]
...@@ -247,9 +247,9 @@ func (m *clientHelloMsg) marshal() []byte { ...@@ -247,9 +247,9 @@ func (m *clientHelloMsg) marshal() []byte {
z[0] = byte(l >> 8) z[0] = byte(l >> 8)
z[1] = byte(l) z[1] = byte(l)
z = z[2:] z = z[2:]
for _, sigAndHash := range m.signatureAndHashes { for _, sigAlgo := range m.supportedSignatureAlgorithms {
z[0] = sigAndHash.hash z[0] = byte(sigAlgo >> 8)
z[1] = sigAndHash.signature z[1] = byte(sigAlgo)
z = z[2:] z = z[2:]
} }
} }
...@@ -344,7 +344,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { ...@@ -344,7 +344,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
m.ocspStapling = false m.ocspStapling = false
m.ticketSupported = false m.ticketSupported = false
m.sessionTicket = nil m.sessionTicket = nil
m.signatureAndHashes = nil m.supportedSignatureAlgorithms = nil
m.alpnProtocols = nil m.alpnProtocols = nil
m.scts = false m.scts = false
...@@ -455,10 +455,9 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { ...@@ -455,10 +455,9 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
} }
n := l / 2 n := l / 2
d := data[2:] d := data[2:]
m.signatureAndHashes = make([]signatureAndHash, n) m.supportedSignatureAlgorithms = make([]SignatureScheme, n)
for i := range m.signatureAndHashes { for i := range m.supportedSignatureAlgorithms {
m.signatureAndHashes[i].hash = d[0] m.supportedSignatureAlgorithms[i] = SignatureScheme(d[0])<<8 | SignatureScheme(d[1])
m.signatureAndHashes[i].signature = d[1]
d = d[2:] d = d[2:]
} }
case extensionRenegotiationInfo: case extensionRenegotiationInfo:
...@@ -1203,9 +1202,9 @@ type certificateRequestMsg struct { ...@@ -1203,9 +1202,9 @@ type certificateRequestMsg struct {
// 1.2. // 1.2.
hasSignatureAndHash bool hasSignatureAndHash bool
certificateTypes []byte certificateTypes []byte
signatureAndHashes []signatureAndHash supportedSignatureAlgorithms []SignatureScheme
certificateAuthorities [][]byte certificateAuthorities [][]byte
} }
func (m *certificateRequestMsg) equal(i interface{}) bool { func (m *certificateRequestMsg) equal(i interface{}) bool {
...@@ -1217,7 +1216,7 @@ func (m *certificateRequestMsg) equal(i interface{}) bool { ...@@ -1217,7 +1216,7 @@ func (m *certificateRequestMsg) equal(i interface{}) bool {
return bytes.Equal(m.raw, m1.raw) && return bytes.Equal(m.raw, m1.raw) &&
bytes.Equal(m.certificateTypes, m1.certificateTypes) && bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) && eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) &&
eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms)
} }
func (m *certificateRequestMsg) marshal() (x []byte) { func (m *certificateRequestMsg) marshal() (x []byte) {
...@@ -1234,7 +1233,7 @@ func (m *certificateRequestMsg) marshal() (x []byte) { ...@@ -1234,7 +1233,7 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
length += casLength length += casLength
if m.hasSignatureAndHash { if m.hasSignatureAndHash {
length += 2 + 2*len(m.signatureAndHashes) length += 2 + 2*len(m.supportedSignatureAlgorithms)
} }
x = make([]byte, 4+length) x = make([]byte, 4+length)
...@@ -1249,13 +1248,13 @@ func (m *certificateRequestMsg) marshal() (x []byte) { ...@@ -1249,13 +1248,13 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
y := x[5+len(m.certificateTypes):] y := x[5+len(m.certificateTypes):]
if m.hasSignatureAndHash { if m.hasSignatureAndHash {
n := len(m.signatureAndHashes) * 2 n := len(m.supportedSignatureAlgorithms) * 2
y[0] = uint8(n >> 8) y[0] = uint8(n >> 8)
y[1] = uint8(n) y[1] = uint8(n)
y = y[2:] y = y[2:]
for _, sigAndHash := range m.signatureAndHashes { for _, sigAlgo := range m.supportedSignatureAlgorithms {
y[0] = sigAndHash.hash y[0] = uint8(sigAlgo >> 8)
y[1] = sigAndHash.signature y[1] = uint8(sigAlgo)
y = y[2:] y = y[2:]
} }
} }
...@@ -1312,11 +1311,10 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool { ...@@ -1312,11 +1311,10 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool {
if len(data) < int(sigAndHashLen) { if len(data) < int(sigAndHashLen) {
return false return false
} }
numSigAndHash := sigAndHashLen / 2 numSigAlgos := sigAndHashLen / 2
m.signatureAndHashes = make([]signatureAndHash, numSigAndHash) m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos)
for i := range m.signatureAndHashes { for i := range m.supportedSignatureAlgorithms {
m.signatureAndHashes[i].hash = data[0] m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
m.signatureAndHashes[i].signature = data[1]
data = data[2:] data = data[2:]
} }
} }
...@@ -1355,7 +1353,7 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool { ...@@ -1355,7 +1353,7 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool {
type certificateVerifyMsg struct { type certificateVerifyMsg struct {
raw []byte raw []byte
hasSignatureAndHash bool hasSignatureAndHash bool
signatureAndHash signatureAndHash signatureAlgorithm SignatureScheme
signature []byte signature []byte
} }
...@@ -1367,8 +1365,7 @@ func (m *certificateVerifyMsg) equal(i interface{}) bool { ...@@ -1367,8 +1365,7 @@ func (m *certificateVerifyMsg) equal(i interface{}) bool {
return bytes.Equal(m.raw, m1.raw) && return bytes.Equal(m.raw, m1.raw) &&
m.hasSignatureAndHash == m1.hasSignatureAndHash && m.hasSignatureAndHash == m1.hasSignatureAndHash &&
m.signatureAndHash.hash == m1.signatureAndHash.hash && m.signatureAlgorithm == m1.signatureAlgorithm &&
m.signatureAndHash.signature == m1.signatureAndHash.signature &&
bytes.Equal(m.signature, m1.signature) bytes.Equal(m.signature, m1.signature)
} }
...@@ -1390,8 +1387,8 @@ func (m *certificateVerifyMsg) marshal() (x []byte) { ...@@ -1390,8 +1387,8 @@ func (m *certificateVerifyMsg) marshal() (x []byte) {
x[3] = uint8(length) x[3] = uint8(length)
y := x[4:] y := x[4:]
if m.hasSignatureAndHash { if m.hasSignatureAndHash {
y[0] = m.signatureAndHash.hash y[0] = uint8(m.signatureAlgorithm >> 8)
y[1] = m.signatureAndHash.signature y[1] = uint8(m.signatureAlgorithm)
y = y[2:] y = y[2:]
} }
y[0] = uint8(siglength >> 8) y[0] = uint8(siglength >> 8)
...@@ -1417,8 +1414,7 @@ func (m *certificateVerifyMsg) unmarshal(data []byte) bool { ...@@ -1417,8 +1414,7 @@ func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
data = data[4:] data = data[4:]
if m.hasSignatureAndHash { if m.hasSignatureAndHash {
m.signatureAndHash.hash = data[0] m.signatureAlgorithm = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
m.signatureAndHash.signature = data[1]
data = data[2:] data = data[2:]
} }
...@@ -1554,13 +1550,12 @@ func eqByteSlices(x, y [][]byte) bool { ...@@ -1554,13 +1550,12 @@ func eqByteSlices(x, y [][]byte) bool {
return true return true
} }
func eqSignatureAndHashes(x, y []signatureAndHash) bool { func eqSignatureAlgorithms(x, y []SignatureScheme) bool {
if len(x) != len(y) { if len(x) != len(y) {
return false return false
} }
for i, v := range x { for i, v := range x {
v2 := y[i] if v != y[i] {
if v.hash != v2.hash || v.signature != v2.signature {
return false return false
} }
} }
......
...@@ -145,7 +145,7 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { ...@@ -145,7 +145,7 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
} }
} }
if rand.Intn(10) > 5 { if rand.Intn(10) > 5 {
m.signatureAndHashes = supportedSignatureAlgorithms m.supportedSignatureAlgorithms = supportedSignatureAlgorithms
} }
m.alpnProtocols = make([]string, rand.Intn(5)) m.alpnProtocols = make([]string, rand.Intn(5))
for i := range m.alpnProtocols { for i := range m.alpnProtocols {
......
...@@ -418,7 +418,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { ...@@ -418,7 +418,7 @@ func (hs *serverHandshakeState) doFullHandshake() error {
} }
if c.vers >= VersionTLS12 { if c.vers >= VersionTLS12 {
certReq.hasSignatureAndHash = true certReq.hasSignatureAndHash = true
certReq.signatureAndHashes = supportedSignatureAlgorithms certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms
} }
// An empty list of certificateAuthorities signals to // An empty list of certificateAuthorities signals to
...@@ -519,27 +519,30 @@ func (hs *serverHandshakeState) doFullHandshake() error { ...@@ -519,27 +519,30 @@ func (hs *serverHandshakeState) doFullHandshake() error {
} }
// Determine the signature type. // Determine the signature type.
var signatureAndHash signatureAndHash var signatureAlgorithm SignatureScheme
var sigType uint8
if certVerify.hasSignatureAndHash { if certVerify.hasSignatureAndHash {
signatureAndHash = certVerify.signatureAndHash signatureAlgorithm = certVerify.signatureAlgorithm
if !isSupportedSignatureAndHash(signatureAndHash, supportedSignatureAlgorithms) { if !isSupportedSignatureAlgorithm(signatureAlgorithm, supportedSignatureAlgorithms) {
return errors.New("tls: unsupported hash function for client certificate") return errors.New("tls: unsupported hash function for client certificate")
} }
sigType = signatureFromSignatureScheme(signatureAlgorithm)
} else { } else {
// Before TLS 1.2 the signature algorithm was implicit // Before TLS 1.2 the signature algorithm was implicit
// from the key type, and only one hash per signature // from the key type, and only one hash per signature
// algorithm was possible. Leave the hash as zero. // algorithm was possible. Leave signatureAlgorithm
// unset.
switch pub.(type) { switch pub.(type) {
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
signatureAndHash.signature = signatureECDSA sigType = signatureECDSA
case *rsa.PublicKey: case *rsa.PublicKey:
signatureAndHash.signature = signatureRSA sigType = signatureRSA
} }
} }
switch key := pub.(type) { switch key := pub.(type) {
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
if signatureAndHash.signature != signatureECDSA { if sigType != signatureECDSA {
err = errors.New("tls: bad signature type for client's ECDSA certificate") err = errors.New("tls: bad signature type for client's ECDSA certificate")
break break
} }
...@@ -552,20 +555,20 @@ func (hs *serverHandshakeState) doFullHandshake() error { ...@@ -552,20 +555,20 @@ func (hs *serverHandshakeState) doFullHandshake() error {
break break
} }
var digest []byte var digest []byte
if digest, _, err = hs.finishedHash.hashForClientCertificate(signatureAndHash, hs.masterSecret); err != nil { if digest, _, err = hs.finishedHash.hashForClientCertificate(sigType, signatureAlgorithm, hs.masterSecret); err != nil {
break break
} }
if !ecdsa.Verify(key, digest, ecdsaSig.R, ecdsaSig.S) { if !ecdsa.Verify(key, digest, ecdsaSig.R, ecdsaSig.S) {
err = errors.New("tls: ECDSA verification failure") err = errors.New("tls: ECDSA verification failure")
} }
case *rsa.PublicKey: case *rsa.PublicKey:
if signatureAndHash.signature != signatureRSA { if sigType != signatureRSA {
err = errors.New("tls: bad signature type for client's RSA certificate") err = errors.New("tls: bad signature type for client's RSA certificate")
break break
} }
var digest []byte var digest []byte
var hashFunc crypto.Hash var hashFunc crypto.Hash
if digest, hashFunc, err = hs.finishedHash.hashForClientCertificate(signatureAndHash, hs.masterSecret); err != nil { if digest, hashFunc, err = hs.finishedHash.hashForClientCertificate(sigType, signatureAlgorithm, hs.masterSecret); err != nil {
break break
} }
err = rsa.VerifyPKCS1v15(key, hashFunc, digest, certVerify.signature) err = rsa.VerifyPKCS1v15(key, hashFunc, digest, certVerify.signature)
...@@ -818,17 +821,12 @@ func (hs *serverHandshakeState) clientHelloInfo() *ClientHelloInfo { ...@@ -818,17 +821,12 @@ func (hs *serverHandshakeState) clientHelloInfo() *ClientHelloInfo {
supportedVersions = suppVersArray[VersionTLS12-hs.clientHello.vers:] supportedVersions = suppVersArray[VersionTLS12-hs.clientHello.vers:]
} }
signatureSchemes := make([]SignatureScheme, 0, len(hs.clientHello.signatureAndHashes))
for _, sah := range hs.clientHello.signatureAndHashes {
signatureSchemes = append(signatureSchemes, SignatureScheme(sah.hash)<<8+SignatureScheme(sah.signature))
}
hs.cachedClientHelloInfo = &ClientHelloInfo{ hs.cachedClientHelloInfo = &ClientHelloInfo{
CipherSuites: hs.clientHello.cipherSuites, CipherSuites: hs.clientHello.cipherSuites,
ServerName: hs.clientHello.serverName, ServerName: hs.clientHello.serverName,
SupportedCurves: hs.clientHello.supportedCurves, SupportedCurves: hs.clientHello.supportedCurves,
SupportedPoints: hs.clientHello.supportedPoints, SupportedPoints: hs.clientHello.supportedPoints,
SignatureSchemes: signatureSchemes, SignatureSchemes: hs.clientHello.supportedSignatureAlgorithms,
SupportedProtos: hs.clientHello.alpnProtocols, SupportedProtos: hs.clientHello.alpnProtocols,
SupportedVersions: supportedVersions, SupportedVersions: supportedVersions,
Conn: hs.c.conn, Conn: hs.c.conn,
......
...@@ -110,14 +110,14 @@ func md5SHA1Hash(slices [][]byte) []byte { ...@@ -110,14 +110,14 @@ func md5SHA1Hash(slices [][]byte) []byte {
} }
// hashForServerKeyExchange hashes the given slices and returns their digest // hashForServerKeyExchange hashes the given slices and returns their digest
// and the identifier of the hash function used. The sigAndHash argument is // and the identifier of the hash function used. The signatureAlgorithm argument
// only used for >= TLS 1.2 and precisely identifies the hash function to use. // is only used for >= TLS 1.2 and identifies the hash function to use.
func hashForServerKeyExchange(sigAndHash signatureAndHash, version uint16, slices ...[]byte) ([]byte, crypto.Hash, error) { func hashForServerKeyExchange(sigType uint8, signatureAlgorithm SignatureScheme, version uint16, slices ...[]byte) ([]byte, crypto.Hash, error) {
if version >= VersionTLS12 { if version >= VersionTLS12 {
if !isSupportedSignatureAndHash(sigAndHash, supportedSignatureAlgorithms) { if !isSupportedSignatureAlgorithm(signatureAlgorithm, supportedSignatureAlgorithms) {
return nil, crypto.Hash(0), errors.New("tls: unsupported hash function used by peer") return nil, crypto.Hash(0), errors.New("tls: unsupported hash function used by peer")
} }
hashFunc, err := lookupTLSHash(sigAndHash.hash) hashFunc, err := lookupTLSHash(signatureAlgorithm)
if err != nil { if err != nil {
return nil, crypto.Hash(0), err return nil, crypto.Hash(0), err
} }
...@@ -128,7 +128,7 @@ func hashForServerKeyExchange(sigAndHash signatureAndHash, version uint16, slice ...@@ -128,7 +128,7 @@ func hashForServerKeyExchange(sigAndHash signatureAndHash, version uint16, slice
digest := h.Sum(nil) digest := h.Sum(nil)
return digest, hashFunc, nil return digest, hashFunc, nil
} }
if sigAndHash.signature == signatureECDSA { if sigType == signatureECDSA {
return sha1Hash(slices), crypto.SHA1, nil return sha1Hash(slices), crypto.SHA1, nil
} }
return md5SHA1Hash(slices), crypto.MD5SHA1, nil return md5SHA1Hash(slices), crypto.MD5SHA1, nil
...@@ -137,20 +137,27 @@ func hashForServerKeyExchange(sigAndHash signatureAndHash, version uint16, slice ...@@ -137,20 +137,27 @@ func hashForServerKeyExchange(sigAndHash signatureAndHash, version uint16, slice
// pickTLS12HashForSignature returns a TLS 1.2 hash identifier for signing a // pickTLS12HashForSignature returns a TLS 1.2 hash identifier for signing a
// ServerKeyExchange given the signature type being used and the client's // ServerKeyExchange given the signature type being used and the client's
// advertised list of supported signature and hash combinations. // advertised list of supported signature and hash combinations.
func pickTLS12HashForSignature(sigType uint8, clientList []signatureAndHash) (uint8, error) { func pickTLS12HashForSignature(sigType uint8, clientList []SignatureScheme) (SignatureScheme, error) {
if len(clientList) == 0 { if len(clientList) == 0 {
// If the client didn't specify any signature_algorithms // If the client didn't specify any signature_algorithms
// extension then we can assume that it supports SHA1. See // extension then we can assume that it supports SHA1. See
// http://tools.ietf.org/html/rfc5246#section-7.4.1.4.1 // http://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
return hashSHA1, nil switch sigType {
case signatureRSA:
return PKCS1WithSHA1, nil
case signatureECDSA:
return ECDSAWithSHA1, nil
default:
return 0, errors.New("tls: unknown signature algorithm")
}
} }
for _, sigAndHash := range clientList { for _, sigAlg := range clientList {
if sigAndHash.signature != sigType { if signatureFromSignatureScheme(sigAlg) != sigType {
continue continue
} }
if isSupportedSignatureAndHash(sigAndHash, supportedSignatureAlgorithms) { if isSupportedSignatureAlgorithm(sigAlg, supportedSignatureAlgorithms) {
return sigAndHash.hash, nil return sigAlg, nil
} }
} }
...@@ -240,16 +247,17 @@ NextCandidate: ...@@ -240,16 +247,17 @@ NextCandidate:
serverECDHParams[3] = byte(len(ecdhePublic)) serverECDHParams[3] = byte(len(ecdhePublic))
copy(serverECDHParams[4:], ecdhePublic) copy(serverECDHParams[4:], ecdhePublic)
sigAndHash := signatureAndHash{signature: ka.sigType} var signatureAlgorithm SignatureScheme
if ka.version >= VersionTLS12 { if ka.version >= VersionTLS12 {
var err error var err error
if sigAndHash.hash, err = pickTLS12HashForSignature(ka.sigType, clientHello.signatureAndHashes); err != nil { signatureAlgorithm, err = pickTLS12HashForSignature(ka.sigType, clientHello.supportedSignatureAlgorithms)
if err != nil {
return nil, err return nil, err
} }
} }
digest, hashFunc, err := hashForServerKeyExchange(sigAndHash, ka.version, clientHello.random, hello.random, serverECDHParams) digest, hashFunc, err := hashForServerKeyExchange(ka.sigType, signatureAlgorithm, ka.version, clientHello.random, hello.random, serverECDHParams)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -287,8 +295,8 @@ NextCandidate: ...@@ -287,8 +295,8 @@ NextCandidate:
copy(skx.key, serverECDHParams) copy(skx.key, serverECDHParams)
k := skx.key[len(serverECDHParams):] k := skx.key[len(serverECDHParams):]
if ka.version >= VersionTLS12 { if ka.version >= VersionTLS12 {
k[0] = sigAndHash.hash k[0] = byte(signatureAlgorithm >> 8)
k[1] = sigAndHash.signature k[1] = byte(signatureAlgorithm)
k = k[2:] k = k[2:]
} }
k[0] = byte(len(sig) >> 8) k[0] = byte(len(sig) >> 8)
...@@ -368,11 +376,11 @@ func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHell ...@@ -368,11 +376,11 @@ func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHell
} }
} }
sigAndHash := signatureAndHash{signature: ka.sigType} var signatureAlgorithm SignatureScheme
if ka.version >= VersionTLS12 { if ka.version >= VersionTLS12 {
// handle SignatureAndHashAlgorithm // handle SignatureAndHashAlgorithm
sigAndHash = signatureAndHash{hash: sig[0], signature: sig[1]} signatureAlgorithm = SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1])
if sigAndHash.signature != ka.sigType { if signatureFromSignatureScheme(signatureAlgorithm) != ka.sigType {
return errServerKeyExchange return errServerKeyExchange
} }
sig = sig[2:] sig = sig[2:]
...@@ -386,7 +394,7 @@ func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHell ...@@ -386,7 +394,7 @@ func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHell
} }
sig = sig[2:] sig = sig[2:]
digest, hashFunc, err := hashForServerKeyExchange(sigAndHash, ka.version, clientHello.random, serverHello.random, serverECDHParams) digest, hashFunc, err := hashForServerKeyExchange(ka.sigType, signatureAlgorithm, ka.version, clientHello.random, serverHello.random, serverECDHParams)
if err != nil { if err != nil {
return err return err
} }
......
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"crypto/sha256" "crypto/sha256"
"crypto/sha512" "crypto/sha512"
"errors" "errors"
"fmt"
"hash" "hash"
) )
...@@ -180,17 +181,17 @@ func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clie ...@@ -180,17 +181,17 @@ func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clie
} }
// lookupTLSHash looks up the corresponding crypto.Hash for a given // lookupTLSHash looks up the corresponding crypto.Hash for a given
// TLS hash identifier. // hash from a TLS SignatureScheme.
func lookupTLSHash(hash uint8) (crypto.Hash, error) { func lookupTLSHash(signatureAlgorithm SignatureScheme) (crypto.Hash, error) {
switch hash { switch signatureAlgorithm {
case hashSHA1: case PKCS1WithSHA1, ECDSAWithSHA1:
return crypto.SHA1, nil return crypto.SHA1, nil
case hashSHA256: case PKCS1WithSHA256, PSSWithSHA256, ECDSAWithP256AndSHA256:
return crypto.SHA256, nil return crypto.SHA256, nil
case hashSHA384: case PKCS1WithSHA384, PSSWithSHA384, ECDSAWithP384AndSHA384:
return crypto.SHA384, nil return crypto.SHA384, nil
default: default:
return 0, errors.New("tls: unsupported hash algorithm") return 0, fmt.Errorf("tls: unsupported signature algorithm: %#04x", signatureAlgorithm)
} }
} }
...@@ -310,31 +311,26 @@ func (h finishedHash) serverSum(masterSecret []byte) []byte { ...@@ -310,31 +311,26 @@ func (h finishedHash) serverSum(masterSecret []byte) []byte {
return out return out
} }
// selectClientCertSignatureAlgorithm returns a signatureAndHash to sign a // selectClientCertSignatureAlgorithm returns a SignatureScheme to sign a
// client's CertificateVerify with, or an error if none can be found. // client's CertificateVerify with, or an error if none can be found.
func (h finishedHash) selectClientCertSignatureAlgorithm(serverList []signatureAndHash, sigType uint8) (signatureAndHash, error) { func (h finishedHash) selectClientCertSignatureAlgorithm(serverList []SignatureScheme, sigType uint8) (SignatureScheme, error) {
if h.version < VersionTLS12 {
// Nothing to negotiate before TLS 1.2.
return signatureAndHash{signature: sigType}, nil
}
for _, v := range serverList { for _, v := range serverList {
if v.signature == sigType && isSupportedSignatureAndHash(v, supportedSignatureAlgorithms) { if signatureFromSignatureScheme(v) == sigType && isSupportedSignatureAlgorithm(v, supportedSignatureAlgorithms) {
return v, nil return v, nil
} }
} }
return signatureAndHash{}, errors.New("tls: no supported signature algorithm found for signing client certificate") return 0, errors.New("tls: no supported signature algorithm found for signing client certificate")
} }
// hashForClientCertificate returns a digest, hash function, and TLS 1.2 hash // hashForClientCertificate returns a digest, hash function, and TLS 1.2 hash
// id suitable for signing by a TLS client certificate. // id suitable for signing by a TLS client certificate.
func (h finishedHash) hashForClientCertificate(signatureAndHash signatureAndHash, masterSecret []byte) ([]byte, crypto.Hash, error) { func (h finishedHash) hashForClientCertificate(sigType uint8, signatureAlgorithm SignatureScheme, masterSecret []byte) ([]byte, crypto.Hash, error) {
if (h.version == VersionSSL30 || h.version >= VersionTLS12) && h.buffer == nil { if (h.version == VersionSSL30 || h.version >= VersionTLS12) && h.buffer == nil {
panic("a handshake hash for a client-certificate was requested after discarding the handshake buffer") panic("a handshake hash for a client-certificate was requested after discarding the handshake buffer")
} }
if h.version == VersionSSL30 { if h.version == VersionSSL30 {
if signatureAndHash.signature != signatureRSA { if sigType != signatureRSA {
return nil, 0, errors.New("tls: unsupported signature type for client certificate") return nil, 0, errors.New("tls: unsupported signature type for client certificate")
} }
...@@ -345,7 +341,7 @@ func (h finishedHash) hashForClientCertificate(signatureAndHash signatureAndHash ...@@ -345,7 +341,7 @@ func (h finishedHash) hashForClientCertificate(signatureAndHash signatureAndHash
return finishedSum30(md5Hash, sha1Hash, masterSecret, nil), crypto.MD5SHA1, nil return finishedSum30(md5Hash, sha1Hash, masterSecret, nil), crypto.MD5SHA1, nil
} }
if h.version >= VersionTLS12 { if h.version >= VersionTLS12 {
hashAlg, err := lookupTLSHash(signatureAndHash.hash) hashAlg, err := lookupTLSHash(signatureAlgorithm)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
...@@ -354,7 +350,7 @@ func (h finishedHash) hashForClientCertificate(signatureAndHash signatureAndHash ...@@ -354,7 +350,7 @@ func (h finishedHash) hashForClientCertificate(signatureAndHash signatureAndHash
return hash.Sum(nil), hashAlg, nil return hash.Sum(nil), hashAlg, nil
} }
if signatureAndHash.signature == signatureECDSA { if sigType == signatureECDSA {
return h.server.Sum(nil), crypto.SHA1, nil return h.server.Sum(nil), crypto.SHA1, nil
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment