Commit 77013065 authored by Filippo Valsorda's avatar Filippo Valsorda Committed by Dmitri Shuralyov

crypto/x509: limit number of signature checks for each verification

That number grows quadratically with the number of intermediate
certificates in certain pathological cases (for example if they all have
the same Subject) leading to a CPU DoS. Set a fixed budget that should
fit all real world chains, given we only look at intermediates provided
by the peer.

The algorithm can be improved, but that's left for follow-up CLs:

    * the cache logic should be reviewed for correctness, as it seems to
      override the entire chain with the cached one
    * the equality check should compare Subject and public key, not the
      whole certificate
    * certificates with the right SKID but the wrong Subject should not
      be considered, and in particular should not take priority over
      certificates with the right Subject

Fixes #29233

Change-Id: Ib257c12cd5563df7723f9c81231d82b882854213
Reviewed-on: https://team-review.git.corp.google.com/c/370475Reviewed-by: 's avatarAndrew Bonventre <andybons@google.com>
Reviewed-on: https://go-review.googlesource.com/c/154105Reviewed-by: 's avatarFilippo Valsorda <filippo@golang.org>
Run-TryBot: Filippo Valsorda <filippo@golang.org>
parent 9c075b7c
...@@ -65,32 +65,16 @@ func SystemCertPool() (*CertPool, error) { ...@@ -65,32 +65,16 @@ func SystemCertPool() (*CertPool, error) {
return loadSystemRoots() return loadSystemRoots()
} }
// findVerifiedParents attempts to find certificates in s which have signed the // findPotentialParents returns the indexes of certificates in s which might
// given certificate. If any candidates were rejected then errCert will be set // have signed cert. The caller must not modify the returned slice.
// to one of them, arbitrarily, and err will contain the reason that it was func (s *CertPool) findPotentialParents(cert *Certificate) []int {
// rejected.
func (s *CertPool) findVerifiedParents(cert *Certificate) (parents []int, errCert *Certificate, err error) {
if s == nil { if s == nil {
return return nil
} }
var candidates []int
if len(cert.AuthorityKeyId) > 0 { if len(cert.AuthorityKeyId) > 0 {
candidates = s.bySubjectKeyId[string(cert.AuthorityKeyId)] return s.bySubjectKeyId[string(cert.AuthorityKeyId)]
}
if len(candidates) == 0 {
candidates = s.byName[string(cert.RawIssuer)]
} }
return s.byName[string(cert.RawIssuer)]
for _, c := range candidates {
if err = cert.CheckSignatureFrom(s.certs[c]); err == nil {
parents = append(parents, c)
} else {
errCert = s.certs[c]
}
}
return
} }
func (s *CertPool) contains(cert *Certificate) bool { func (s *CertPool) contains(cert *Certificate) bool {
......
...@@ -763,7 +763,7 @@ func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err e ...@@ -763,7 +763,7 @@ func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err e
if opts.Roots.contains(c) { if opts.Roots.contains(c) {
candidateChains = append(candidateChains, []*Certificate{c}) candidateChains = append(candidateChains, []*Certificate{c})
} else { } else {
if candidateChains, err = c.buildChains(make(map[int][][]*Certificate), []*Certificate{c}, &opts); err != nil { if candidateChains, err = c.buildChains(nil, []*Certificate{c}, nil, &opts); err != nil {
return nil, err return nil, err
} }
} }
...@@ -800,58 +800,74 @@ func appendToFreshChain(chain []*Certificate, cert *Certificate) []*Certificate ...@@ -800,58 +800,74 @@ func appendToFreshChain(chain []*Certificate, cert *Certificate) []*Certificate
return n return n
} }
func (c *Certificate) buildChains(cache map[int][][]*Certificate, currentChain []*Certificate, opts *VerifyOptions) (chains [][]*Certificate, err error) { // maxChainSignatureChecks is the maximum number of CheckSignatureFrom calls
possibleRoots, failedRoot, rootErr := opts.Roots.findVerifiedParents(c) // that an invocation of buildChains will (tranistively) make. Most chains are
nextRoot: // less than 15 certificates long, so this leaves space for multiple chains and
for _, rootNum := range possibleRoots { // for failed checks due to different intermediates having the same Subject.
root := opts.Roots.certs[rootNum] const maxChainSignatureChecks = 100
func (c *Certificate) buildChains(cache map[*Certificate][][]*Certificate, currentChain []*Certificate, sigChecks *int, opts *VerifyOptions) (chains [][]*Certificate, err error) {
var (
hintErr error
hintCert *Certificate
)
considerCandidate := func(certType int, candidate *Certificate) {
for _, cert := range currentChain { for _, cert := range currentChain {
if cert.Equal(root) { if cert.Equal(candidate) {
continue nextRoot return
} }
} }
err = root.isValid(rootCertificate, currentChain, opts) if sigChecks == nil {
if err != nil { sigChecks = new(int)
continue }
*sigChecks++
if *sigChecks > maxChainSignatureChecks {
err = errors.New("x509: signature check attempts limit reached while verifying certificate chain")
return
} }
chains = append(chains, appendToFreshChain(currentChain, root))
}
possibleIntermediates, failedIntermediate, intermediateErr := opts.Intermediates.findVerifiedParents(c) if err := c.CheckSignatureFrom(candidate); err != nil {
nextIntermediate: if hintErr == nil {
for _, intermediateNum := range possibleIntermediates { hintErr = err
intermediate := opts.Intermediates.certs[intermediateNum] hintCert = candidate
for _, cert := range currentChain {
if cert.Equal(intermediate) {
continue nextIntermediate
} }
return
} }
err = intermediate.isValid(intermediateCertificate, currentChain, opts)
err = candidate.isValid(certType, currentChain, opts)
if err != nil { if err != nil {
continue return
} }
var childChains [][]*Certificate
childChains, ok := cache[intermediateNum] switch certType {
if !ok { case rootCertificate:
childChains, err = intermediate.buildChains(cache, appendToFreshChain(currentChain, intermediate), opts) chains = append(chains, appendToFreshChain(currentChain, candidate))
cache[intermediateNum] = childChains case intermediateCertificate:
if cache == nil {
cache = make(map[*Certificate][][]*Certificate)
}
childChains, ok := cache[candidate]
if !ok {
childChains, err = candidate.buildChains(cache, appendToFreshChain(currentChain, candidate), sigChecks, opts)
cache[candidate] = childChains
}
chains = append(chains, childChains...)
} }
chains = append(chains, childChains...) }
for _, rootNum := range opts.Roots.findPotentialParents(c) {
considerCandidate(rootCertificate, opts.Roots.certs[rootNum])
}
for _, intermediateNum := range opts.Intermediates.findPotentialParents(c) {
considerCandidate(intermediateCertificate, opts.Intermediates.certs[intermediateNum])
} }
if len(chains) > 0 { if len(chains) > 0 {
err = nil err = nil
} }
if len(chains) == 0 && err == nil { if len(chains) == 0 && err == nil {
hintErr := rootErr
hintCert := failedRoot
if hintErr == nil {
hintErr = intermediateErr
hintCert = failedIntermediate
}
err = UnknownAuthorityError{c, hintErr, hintCert} err = UnknownAuthorityError{c, hintErr, hintCert}
} }
......
...@@ -5,10 +5,15 @@ ...@@ -5,10 +5,15 @@
package x509 package x509
import ( import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/pem" "encoding/pem"
"errors" "errors"
"fmt" "fmt"
"math/big"
"runtime" "runtime"
"strings" "strings"
"testing" "testing"
...@@ -1889,3 +1894,117 @@ func TestValidHostname(t *testing.T) { ...@@ -1889,3 +1894,117 @@ func TestValidHostname(t *testing.T) {
} }
} }
} }
func generateCert(cn string, isCA bool, issuer *Certificate, issuerKey crypto.PrivateKey) (*Certificate, crypto.PrivateKey, error) {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, nil, err
}
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, _ := rand.Int(rand.Reader, serialNumberLimit)
template := &Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{CommonName: cn},
NotBefore: time.Now().Add(-1 * time.Hour),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: KeyUsageKeyEncipherment | KeyUsageDigitalSignature | KeyUsageCertSign,
ExtKeyUsage: []ExtKeyUsage{ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IsCA: isCA,
}
if issuer == nil {
issuer = template
issuerKey = priv
}
derBytes, err := CreateCertificate(rand.Reader, template, issuer, priv.Public(), issuerKey)
if err != nil {
return nil, nil, err
}
cert, err := ParseCertificate(derBytes)
if err != nil {
return nil, nil, err
}
return cert, priv, nil
}
func TestPathologicalChain(t *testing.T) {
if testing.Short() {
t.Skip("skipping generation of a long chain of certificates in short mode")
}
// Build a chain where all intermediates share the same subject, to hit the
// path building worst behavior.
roots, intermediates := NewCertPool(), NewCertPool()
parent, parentKey, err := generateCert("Root CA", true, nil, nil)
if err != nil {
t.Fatal(err)
}
roots.AddCert(parent)
for i := 1; i < 100; i++ {
parent, parentKey, err = generateCert("Intermediate CA", true, parent, parentKey)
if err != nil {
t.Fatal(err)
}
intermediates.AddCert(parent)
}
leaf, _, err := generateCert("Leaf", false, parent, parentKey)
if err != nil {
t.Fatal(err)
}
start := time.Now()
_, err = leaf.Verify(VerifyOptions{
Roots: roots,
Intermediates: intermediates,
})
t.Logf("verification took %v", time.Since(start))
if err == nil || !strings.Contains(err.Error(), "signature check attempts limit") {
t.Errorf("expected verification to fail with a signature checks limit error; got %v", err)
}
}
func TestLongChain(t *testing.T) {
if testing.Short() {
t.Skip("skipping generation of a long chain of certificates in short mode")
}
roots, intermediates := NewCertPool(), NewCertPool()
parent, parentKey, err := generateCert("Root CA", true, nil, nil)
if err != nil {
t.Fatal(err)
}
roots.AddCert(parent)
for i := 1; i < 15; i++ {
name := fmt.Sprintf("Intermediate CA #%d", i)
parent, parentKey, err = generateCert(name, true, parent, parentKey)
if err != nil {
t.Fatal(err)
}
intermediates.AddCert(parent)
}
leaf, _, err := generateCert("Leaf", false, parent, parentKey)
if err != nil {
t.Fatal(err)
}
start := time.Now()
if _, err := leaf.Verify(VerifyOptions{
Roots: roots,
Intermediates: intermediates,
}); err != nil {
t.Error(err)
}
t.Logf("verification took %v", time.Since(start))
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment