Commit 0663fe98 authored by Filippo Valsorda's avatar Filippo Valsorda

crypto/tls: implement TLS 1.3 version-specific messages

Note that there is significant code duplication due to extensions with
the same format appearing in different messages in TLS 1.3. This will be
cleaned up in a future refactor once CL 145317 is merged.

Enforcing the presence/absence of each extension in each message is left
to the upper layer, based on both protocol version and extensions
advertised in CH and CR. Duplicated extensions and unknown extensions in
SH, EE, HRR, and CT will be tightened up in a future CL.

The TLS 1.2 CertificateStatus message was restricted to accepting only
type OCSP as any other type (none of which are specified so far) would
have to be negotiated.

Updates #9671

Change-Id: I7c42394c5cc0af01faa84b9b9f25fdc6e7cfbb9e
Reviewed-on: https://go-review.googlesource.com/c/145477Reviewed-by: 's avatarAdam Langley <agl@golang.org>
parent 84d6a7ab
...@@ -60,6 +60,8 @@ const ( ...@@ -60,6 +60,8 @@ const (
typeClientHello uint8 = 1 typeClientHello uint8 = 1
typeServerHello uint8 = 2 typeServerHello uint8 = 2
typeNewSessionTicket uint8 = 4 typeNewSessionTicket uint8 = 4
typeEndOfEarlyData uint8 = 5
typeEncryptedExtensions uint8 = 8
typeCertificate uint8 = 11 typeCertificate uint8 = 11
typeServerKeyExchange uint8 = 12 typeServerKeyExchange uint8 = 12
typeCertificateRequest uint8 = 13 typeCertificateRequest uint8 = 13
...@@ -68,7 +70,9 @@ const ( ...@@ -68,7 +70,9 @@ const (
typeClientKeyExchange uint8 = 16 typeClientKeyExchange uint8 = 16
typeFinished uint8 = 20 typeFinished uint8 = 20
typeCertificateStatus uint8 = 22 typeCertificateStatus uint8 = 22
typeKeyUpdate uint8 = 24
typeNextProtocol uint8 = 67 // Not IANA assigned typeNextProtocol uint8 = 67 // Not IANA assigned
typeMessageHash uint8 = 254 // synthetic message
) )
// TLS compression types. // TLS compression types.
...@@ -87,6 +91,7 @@ const ( ...@@ -87,6 +91,7 @@ const (
extensionSCT uint16 = 18 extensionSCT uint16 = 18
extensionSessionTicket uint16 = 35 extensionSessionTicket uint16 = 35
extensionPreSharedKey uint16 = 41 extensionPreSharedKey uint16 = 41
extensionEarlyData uint16 = 42
extensionSupportedVersions uint16 = 43 extensionSupportedVersions uint16 = 43
extensionCookie uint16 = 44 extensionCookie uint16 = 44
extensionPSKModes uint16 = 45 extensionPSKModes uint16 = 45
......
...@@ -990,13 +990,25 @@ func (c *Conn) readHandshake() (interface{}, error) { ...@@ -990,13 +990,25 @@ func (c *Conn) readHandshake() (interface{}, error) {
case typeServerHello: case typeServerHello:
m = new(serverHelloMsg) m = new(serverHelloMsg)
case typeNewSessionTicket: case typeNewSessionTicket:
if c.vers == VersionTLS13 {
m = new(newSessionTicketMsgTLS13)
} else {
m = new(newSessionTicketMsg) m = new(newSessionTicketMsg)
}
case typeCertificate: case typeCertificate:
if c.vers == VersionTLS13 {
m = new(certificateMsgTLS13)
} else {
m = new(certificateMsg) m = new(certificateMsg)
}
case typeCertificateRequest: case typeCertificateRequest:
if c.vers == VersionTLS13 {
m = new(certificateRequestMsgTLS13)
} else {
m = &certificateRequestMsg{ m = &certificateRequestMsg{
hasSignatureAlgorithm: c.vers >= VersionTLS12, hasSignatureAlgorithm: c.vers >= VersionTLS12,
} }
}
case typeCertificateStatus: case typeCertificateStatus:
m = new(certificateStatusMsg) m = new(certificateStatusMsg)
case typeServerKeyExchange: case typeServerKeyExchange:
...@@ -1013,6 +1025,12 @@ func (c *Conn) readHandshake() (interface{}, error) { ...@@ -1013,6 +1025,12 @@ func (c *Conn) readHandshake() (interface{}, error) {
m = new(nextProtoMsg) m = new(nextProtoMsg)
case typeFinished: case typeFinished:
m = new(finishedMsg) m = new(finishedMsg)
case typeEncryptedExtensions:
m = new(encryptedExtensionsMsg)
case typeEndOfEarlyData:
m = new(endOfEarlyDataMsg)
case typeKeyUpdate:
m = new(keyUpdateMsg)
default: default:
return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
} }
......
...@@ -393,9 +393,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { ...@@ -393,9 +393,7 @@ func (hs *clientHandshakeState) doFullHandshake() error {
} }
hs.finishedHash.Write(cs.marshal()) hs.finishedHash.Write(cs.marshal())
if cs.statusType == statusTypeOCSP {
c.ocspResponse = cs.response c.ocspResponse = cs.response
}
msg, err = c.readHandshake() msg, err = c.readHandshake()
if err != nil { if err != nil {
......
This diff is collapsed.
...@@ -29,6 +29,12 @@ var tests = []interface{}{ ...@@ -29,6 +29,12 @@ var tests = []interface{}{
&nextProtoMsg{}, &nextProtoMsg{},
&newSessionTicketMsg{}, &newSessionTicketMsg{},
&sessionState{}, &sessionState{},
&encryptedExtensionsMsg{},
&endOfEarlyDataMsg{},
&keyUpdateMsg{},
&newSessionTicketMsgTLS13{},
&certificateRequestMsgTLS13{},
&certificateMsgTLS13{},
} }
func TestMarshalUnmarshal(t *testing.T) { func TestMarshalUnmarshal(t *testing.T) {
...@@ -184,6 +190,9 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { ...@@ -184,6 +190,9 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m.pskIdentities = append(m.pskIdentities, psk) m.pskIdentities = append(m.pskIdentities, psk)
m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand)) m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
} }
if rand.Intn(10) > 5 {
m.earlyData = true
}
return reflect.ValueOf(m) return reflect.ValueOf(m)
} }
...@@ -209,7 +218,9 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { ...@@ -209,7 +218,9 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
if rand.Intn(10) > 5 { if rand.Intn(10) > 5 {
m.ticketSupported = true m.ticketSupported = true
} }
if rand.Intn(10) > 5 {
m.alpnProtocol = randomString(rand.Intn(32)+1, rand) m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
}
for i := 0; i < rand.Intn(4); i++ { for i := 0; i < rand.Intn(4); i++ {
m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand)) m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
...@@ -241,6 +252,16 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { ...@@ -241,6 +252,16 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
return reflect.ValueOf(m) return reflect.ValueOf(m)
} }
func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &encryptedExtensionsMsg{}
if rand.Intn(10) > 5 {
m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
}
return reflect.ValueOf(m)
}
func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &certificateMsg{} m := &certificateMsg{}
numCerts := rand.Intn(20) numCerts := rand.Intn(20)
...@@ -270,12 +291,7 @@ func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value { ...@@ -270,12 +291,7 @@ func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &certificateStatusMsg{} m := &certificateStatusMsg{}
if rand.Intn(10) > 5 {
m.statusType = statusTypeOCSP
m.response = randomBytes(rand.Intn(10)+1, rand) m.response = randomBytes(rand.Intn(10)+1, rand)
} else {
m.statusType = 42
}
return reflect.ValueOf(m) return reflect.ValueOf(m)
} }
...@@ -316,6 +332,66 @@ func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value { ...@@ -316,6 +332,66 @@ func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
return reflect.ValueOf(s) return reflect.ValueOf(s)
} }
func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &endOfEarlyDataMsg{}
return reflect.ValueOf(m)
}
func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &keyUpdateMsg{}
m.updateRequested = rand.Intn(10) > 5
return reflect.ValueOf(m)
}
func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
m := &newSessionTicketMsgTLS13{}
m.lifetime = uint32(rand.Intn(500000))
m.ageAdd = uint32(rand.Intn(500000))
m.nonce = randomBytes(rand.Intn(100), rand)
m.label = randomBytes(rand.Intn(1000), rand)
if rand.Intn(10) > 5 {
m.maxEarlyData = uint32(rand.Intn(500000))
}
return reflect.ValueOf(m)
}
func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
m := &certificateRequestMsgTLS13{}
if rand.Intn(10) > 5 {
m.ocspStapling = true
}
if rand.Intn(10) > 5 {
m.scts = true
}
if rand.Intn(10) > 5 {
m.supportedSignatureAlgorithms = supportedSignatureAlgorithms
}
if rand.Intn(10) > 5 {
m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms
}
return reflect.ValueOf(m)
}
func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
m := &certificateMsgTLS13{}
for i := 0; i < rand.Intn(2)+1; i++ {
m.certificate.Certificate = append(
m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
}
if rand.Intn(10) > 5 {
m.ocspStapling = true
m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
}
if rand.Intn(10) > 5 {
m.scts = true
for i := 0; i < rand.Intn(2)+1; i++ {
m.certificate.SignedCertificateTimestamps = append(
m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
}
}
return reflect.ValueOf(m)
}
func TestRejectEmptySCTList(t *testing.T) { func TestRejectEmptySCTList(t *testing.T) {
// RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid. // RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid.
......
...@@ -389,7 +389,6 @@ func (hs *serverHandshakeState) doFullHandshake() error { ...@@ -389,7 +389,6 @@ func (hs *serverHandshakeState) doFullHandshake() error {
if hs.hello.ocspStapling { if hs.hello.ocspStapling {
certStatus := new(certificateStatusMsg) certStatus := new(certificateStatusMsg)
certStatus.statusType = statusTypeOCSP
certStatus.response = hs.cert.OCSPStaple certStatus.response = hs.cert.OCSPStaple
hs.finishedHash.Write(certStatus.marshal()) hs.finishedHash.Write(certStatus.marshal())
if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil { if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment