diff --git a/src/crypto/x509/cert_pool.go b/src/crypto/x509/cert_pool.go index 167390da9f..2cfaeb2d9e 100644 --- a/src/crypto/x509/cert_pool.go +++ b/src/crypto/x509/cert_pool.go @@ -6,35 +6,87 @@ package x509 import ( "bytes" + "crypto/sha256" "encoding/pem" "errors" "runtime" ) +type sum224 [sha256.Size224]byte + // CertPool is a set of certificates. type CertPool struct { - byName map[string][]int - certs []*Certificate + byName map[string][]int // cert.RawSubject => index into lazyCerts + + // lazyCerts contains funcs that return a certificate, + // lazily parsing/decompressing it as needed. + lazyCerts []lazyCert + + // haveSum maps from sum224(cert.Raw) to true. It's used only + // for AddCert duplicate detection, to avoid CertPool.contains + // calls in the AddCert path (because the contains method can + // call getCert and otherwise negate savings from lazy getCert + // funcs). + haveSum map[sum224]bool +} + +// lazyCert is minimal metadata about a Cert and a func to retrieve it +// in its normal expanded *Certificate form. +type lazyCert struct { + // rawSubject is the Certificate.RawSubject value. + // It's the same as the CertPool.byName key, but in []byte + // form to make CertPool.Subjects (as used by crypto/tls) do + // fewer allocations. + rawSubject []byte + + // getCert returns the certificate. + // + // It is not meant to do network operations or anything else + // where a failure is likely; the func is meant to lazily + // parse/decompress data that is already known to be good. The + // error in the signature primarily is meant for use in the + // case where a cert file existed on local disk when the program + // started up is deleted later before it's read. + getCert func() (*Certificate, error) } // NewCertPool returns a new, empty CertPool. func NewCertPool() *CertPool { return &CertPool{ - byName: make(map[string][]int), + byName: make(map[string][]int), + haveSum: make(map[sum224]bool), } } +// len returns the number of certs in the set. +// A nil set is a valid empty set. +func (s *CertPool) len() int { + if s == nil { + return 0 + } + return len(s.lazyCerts) +} + +// cert returns cert index n in s. +func (s *CertPool) cert(n int) (*Certificate, error) { + return s.lazyCerts[n].getCert() +} + func (s *CertPool) copy() *CertPool { p := &CertPool{ - byName: make(map[string][]int, len(s.byName)), - certs: make([]*Certificate, len(s.certs)), + byName: make(map[string][]int, len(s.byName)), + lazyCerts: make([]lazyCert, len(s.lazyCerts)), + haveSum: make(map[sum224]bool, len(s.haveSum)), } for k, v := range s.byName { indexes := make([]int, len(v)) copy(indexes, v) p.byName[k] = indexes } - copy(p.certs, s.certs) + for k := range s.haveSum { + p.haveSum[k] = true + } + copy(p.lazyCerts, s.lazyCerts) return p } @@ -64,7 +116,7 @@ func SystemCertPool() (*CertPool, error) { // findPotentialParents returns the indexes of certificates in s which might // have signed cert. -func (s *CertPool) findPotentialParents(cert *Certificate) []int { +func (s *CertPool) findPotentialParents(cert *Certificate) []*Certificate { if s == nil { return nil } @@ -75,18 +127,21 @@ func (s *CertPool) findPotentialParents(cert *Certificate) []int { // AKID and SKID match // AKID present, SKID missing / AKID missing, SKID present // AKID and SKID don't match - var matchingKeyID, oneKeyID, mismatchKeyID []int + var matchingKeyID, oneKeyID, mismatchKeyID []*Certificate for _, c := range s.byName[string(cert.RawIssuer)] { - candidate := s.certs[c] + candidate, err := s.cert(c) + if err != nil { + continue + } kidMatch := bytes.Equal(candidate.SubjectKeyId, cert.AuthorityKeyId) switch { case kidMatch: - matchingKeyID = append(matchingKeyID, c) + matchingKeyID = append(matchingKeyID, candidate) case (len(candidate.SubjectKeyId) == 0 && len(cert.AuthorityKeyId) > 0) || (len(candidate.SubjectKeyId) > 0 && len(cert.AuthorityKeyId) == 0): - oneKeyID = append(oneKeyID, c) + oneKeyID = append(oneKeyID, candidate) default: - mismatchKeyID = append(mismatchKeyID, c) + mismatchKeyID = append(mismatchKeyID, candidate) } } @@ -94,11 +149,10 @@ func (s *CertPool) findPotentialParents(cert *Certificate) []int { if found == 0 { return nil } - candidates := make([]int, 0, found) + candidates := make([]*Certificate, 0, found) candidates = append(candidates, matchingKeyID...) candidates = append(candidates, oneKeyID...) candidates = append(candidates, mismatchKeyID...) - return candidates } @@ -106,10 +160,13 @@ func (s *CertPool) contains(cert *Certificate) bool { if s == nil { return false } - candidates := s.byName[string(cert.RawSubject)] - for _, c := range candidates { - if s.certs[c].Equal(cert) { + for _, i := range candidates { + c, err := s.cert(i) + if err != nil { + return false + } + if c.Equal(cert) { return true } } @@ -122,17 +179,32 @@ func (s *CertPool) AddCert(cert *Certificate) { if cert == nil { panic("adding nil Certificate to CertPool") } + s.addCertFunc(sha256.Sum224(cert.Raw), string(cert.RawSubject), func() (*Certificate, error) { + return cert, nil + }) +} + +// addCertFunc adds metadata about a certificate to a pool, along with +// a func to fetch that certificate later when needed. +// +// The rawSubject is Certificate.RawSubject and must be non-empty. +// The getCert func may be called 0 or more times. +func (s *CertPool) addCertFunc(rawSum224 sum224, rawSubject string, getCert func() (*Certificate, error)) { + if getCert == nil { + panic("getCert can't be nil") + } // Check that the certificate isn't being added twice. - if s.contains(cert) { + if s.haveSum[rawSum224] { return } - n := len(s.certs) - s.certs = append(s.certs, cert) - - name := string(cert.RawSubject) - s.byName[name] = append(s.byName[name], n) + s.haveSum[rawSum224] = true + s.lazyCerts = append(s.lazyCerts, lazyCert{ + rawSubject: []byte(rawSubject), + getCert: getCert, + }) + s.byName[rawSubject] = append(s.byName[rawSubject], len(s.lazyCerts)-1) } // AppendCertsFromPEM attempts to parse a series of PEM encoded certificates. @@ -167,9 +239,9 @@ func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) { // Subjects returns a list of the DER-encoded subjects of // all of the certificates in the pool. func (s *CertPool) Subjects() [][]byte { - res := make([][]byte, len(s.certs)) - for i, c := range s.certs { - res[i] = c.RawSubject + res := make([][]byte, s.len()) + for i, lc := range s.lazyCerts { + res[i] = lc.rawSubject } return res } diff --git a/src/crypto/x509/name_constraints_test.go b/src/crypto/x509/name_constraints_test.go index 5469e28de2..34055d07b5 100644 --- a/src/crypto/x509/name_constraints_test.go +++ b/src/crypto/x509/name_constraints_test.go @@ -1941,7 +1941,7 @@ func TestConstraintCases(t *testing.T) { // Skip tests with CommonName set because OpenSSL will try to match it // against name constraints, while we ignore it when it's not hostname-looking. if !test.noOpenSSL && testNameConstraintsAgainstOpenSSL && test.leaf.cn == "" { - output, err := testChainAgainstOpenSSL(leafCert, intermediatePool, rootPool) + output, err := testChainAgainstOpenSSL(t, leafCert, intermediatePool, rootPool) if err == nil && len(test.expectedError) > 0 { t.Errorf("#%d: unexpectedly succeeded against OpenSSL", i) if debugOpenSSLFailure { @@ -1993,7 +1993,7 @@ func TestConstraintCases(t *testing.T) { pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}) return buf.String() } - t.Errorf("#%d: root:\n%s", i, certAsPEM(rootPool.certs[0])) + t.Errorf("#%d: root:\n%s", i, certAsPEM(rootPool.mustCert(t, 0))) t.Errorf("#%d: leaf:\n%s", i, certAsPEM(leafCert)) } @@ -2019,10 +2019,10 @@ func writePEMsToTempFile(certs []*Certificate) *os.File { return file } -func testChainAgainstOpenSSL(leaf *Certificate, intermediates, roots *CertPool) (string, error) { +func testChainAgainstOpenSSL(t *testing.T, leaf *Certificate, intermediates, roots *CertPool) (string, error) { args := []string{"verify", "-no_check_time"} - rootsFile := writePEMsToTempFile(roots.certs) + rootsFile := writePEMsToTempFile(allCerts(t, roots)) if debugOpenSSLFailure { println("roots file:", rootsFile.Name()) } else { @@ -2030,8 +2030,8 @@ func testChainAgainstOpenSSL(leaf *Certificate, intermediates, roots *CertPool) } args = append(args, "-CAfile", rootsFile.Name()) - if len(intermediates.certs) > 0 { - intermediatesFile := writePEMsToTempFile(intermediates.certs) + if intermediates.len() > 0 { + intermediatesFile := writePEMsToTempFile(allCerts(t, intermediates)) if debugOpenSSLFailure { println("intermediates file:", intermediatesFile.Name()) } else { diff --git a/src/crypto/x509/root_cgo_darwin.go b/src/crypto/x509/root_cgo_darwin.go index 15c72cc0c8..825e8d4812 100644 --- a/src/crypto/x509/root_cgo_darwin.go +++ b/src/crypto/x509/root_cgo_darwin.go @@ -313,7 +313,11 @@ func _loadSystemRootsWithCgo() (*CertPool, error) { untrustedRoots.AppendCertsFromPEM(buf) trustedRoots := NewCertPool() - for _, c := range roots.certs { + for _, lc := range roots.lazyCerts { + c, err := lc.getCert() + if err != nil { + return nil, err + } if !untrustedRoots.contains(c) { trustedRoots.AddCert(c) } diff --git a/src/crypto/x509/root_darwin_test.go b/src/crypto/x509/root_darwin_test.go index 2c773b9120..69f181c2d4 100644 --- a/src/crypto/x509/root_darwin_test.go +++ b/src/crypto/x509/root_darwin_test.go @@ -24,7 +24,7 @@ func TestSystemRoots(t *testing.T) { // There are 174 system roots on Catalina, and 163 on iOS right now, require // at least 100 to make sure this is not completely broken. - if want, have := 100, len(sysRoots.certs); have < want { + if want, have := 100, sysRoots.len(); have < want { t.Errorf("want at least %d system roots, have %d", want, have) } @@ -43,11 +43,14 @@ func TestSystemRoots(t *testing.T) { t.Logf("loadSystemRootsWithCgo: %v", cgoSysRootsDuration) // Check that the two cert pools are the same. - sysPool := make(map[string]*Certificate, len(sysRoots.certs)) - for _, c := range sysRoots.certs { + sysPool := make(map[string]*Certificate, sysRoots.len()) + for i := 0; i < sysRoots.len(); i++ { + c := sysRoots.mustCert(t, i) sysPool[string(c.Raw)] = c } - for _, c := range cgoRoots.certs { + for i := 0; i < cgoRoots.len(); i++ { + c := cgoRoots.mustCert(t, i) + if _, ok := sysPool[string(c.Raw)]; ok { delete(sysPool, string(c.Raw)) } else { diff --git a/src/crypto/x509/root_unix.go b/src/crypto/x509/root_unix.go index ae72f025c3..1090b69f83 100644 --- a/src/crypto/x509/root_unix.go +++ b/src/crypto/x509/root_unix.go @@ -75,7 +75,7 @@ func loadSystemRoots() (*CertPool, error) { } } - if len(roots.certs) > 0 || firstErr == nil { + if roots.len() > 0 || firstErr == nil { return roots, nil } diff --git a/src/crypto/x509/root_unix_test.go b/src/crypto/x509/root_unix_test.go index 5a8015429c..b2e832ff36 100644 --- a/src/crypto/x509/root_unix_test.go +++ b/src/crypto/x509/root_unix_test.go @@ -113,15 +113,15 @@ func TestEnvVars(t *testing.T) { // Verify that the returned certs match, otherwise report where the mismatch is. for i, cn := range tc.cns { - if i >= len(r.certs) { + if i >= r.len() { t.Errorf("missing cert %v @ %v", cn, i) - } else if r.certs[i].Subject.CommonName != cn { - fmt.Printf("%#v\n", r.certs[0].Subject) - t.Errorf("unexpected cert common name %q, want %q", r.certs[i].Subject.CommonName, cn) + } else if r.mustCert(t, i).Subject.CommonName != cn { + fmt.Printf("%#v\n", r.mustCert(t, 0).Subject) + t.Errorf("unexpected cert common name %q, want %q", r.mustCert(t, i).Subject.CommonName, cn) } } - if len(r.certs) > len(tc.cns) { - t.Errorf("got %v certs, which is more than %v wanted", len(r.certs), len(tc.cns)) + if r.len() > len(tc.cns) { + t.Errorf("got %v certs, which is more than %v wanted", r.len(), len(tc.cns)) } }) } @@ -197,7 +197,8 @@ func TestLoadSystemCertsLoadColonSeparatedDirs(t *testing.T) { strCertPool := func(p *CertPool) string { return string(bytes.Join(p.Subjects(), []byte("\n"))) } - if !reflect.DeepEqual(gotPool, wantPool) { + + if !certPoolEqual(gotPool, wantPool) { g, w := strCertPool(gotPool), strCertPool(wantPool) t.Fatalf("Mismatched certPools\nGot:\n%s\n\nWant:\n%s", g, w) } diff --git a/src/crypto/x509/root_windows.go b/src/crypto/x509/root_windows.go index 1e0f3acb67..22e5a9382b 100644 --- a/src/crypto/x509/root_windows.go +++ b/src/crypto/x509/root_windows.go @@ -38,7 +38,11 @@ func createStoreContext(leaf *Certificate, opts *VerifyOptions) (*syscall.CertCo } if opts.Intermediates != nil { - for _, intermediate := range opts.Intermediates.certs { + for i := 0; i < opts.Intermediates.len(); i++ { + intermediate, err := opts.Intermediates.cert(i) + if err != nil { + return nil, err + } ctx, err := syscall.CertCreateCertificateContext(syscall.X509_ASN_ENCODING|syscall.PKCS_7_ASN_ENCODING, &intermediate.Raw[0], uint32(len(intermediate.Raw))) if err != nil { return nil, err diff --git a/src/crypto/x509/verify.go b/src/crypto/x509/verify.go index 5fdd4cb9fe..46afb2698a 100644 --- a/src/crypto/x509/verify.go +++ b/src/crypto/x509/verify.go @@ -761,11 +761,13 @@ func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err e if len(c.Raw) == 0 { return nil, errNotParsed } - if opts.Intermediates != nil { - for _, intermediate := range opts.Intermediates.certs { - if len(intermediate.Raw) == 0 { - return nil, errNotParsed - } + for i := 0; i < opts.Intermediates.len(); i++ { + c, err := opts.Intermediates.cert(i) + if err != nil { + return nil, fmt.Errorf("crypto/x509: error fetching intermediate: %w", err) + } + if len(c.Raw) == 0 { + return nil, errNotParsed } } @@ -891,11 +893,11 @@ func (c *Certificate) buildChains(cache map[*Certificate][][]*Certificate, curre } } - for _, rootNum := range opts.Roots.findPotentialParents(c) { - considerCandidate(rootCertificate, opts.Roots.certs[rootNum]) + for _, root := range opts.Roots.findPotentialParents(c) { + considerCandidate(rootCertificate, root) } - for _, intermediateNum := range opts.Intermediates.findPotentialParents(c) { - considerCandidate(intermediateCertificate, opts.Intermediates.certs[intermediateNum]) + for _, intermediate := range opts.Intermediates.findPotentialParents(c) { + considerCandidate(intermediateCertificate, intermediate) } if len(chains) > 0 { diff --git a/src/crypto/x509/x509_test.go b/src/crypto/x509/x509_test.go index 47d78cf02a..1ba31aeff3 100644 --- a/src/crypto/x509/x509_test.go +++ b/src/crypto/x509/x509_test.go @@ -1960,7 +1960,7 @@ func TestSystemCertPool(t *testing.T) { if err != nil { t.Fatal(err) } - if !reflect.DeepEqual(a, b) { + if !certPoolEqual(a, b) { t.Fatal("two calls to SystemCertPool had different results") } if ok := b.AppendCertsFromPEM([]byte(` @@ -2912,3 +2912,53 @@ func TestCreateCertificateMD5(t *testing.T) { t.Fatalf("CreateCertificate failed when SignatureAlgorithm = MD5WithRSA: %s", err) } } + +func (s *CertPool) mustCert(t *testing.T, n int) *Certificate { + c, err := s.lazyCerts[n].getCert() + if err != nil { + t.Fatalf("failed to load cert %d: %v", n, err) + } + return c +} + +func allCerts(t *testing.T, p *CertPool) []*Certificate { + all := make([]*Certificate, p.len()) + for i := range all { + all[i] = p.mustCert(t, i) + } + return all +} + +// certPoolEqual reports whether a and b are equal, except for the +// function pointers. +func certPoolEqual(a, b *CertPool) bool { + if (a != nil) != (b != nil) { + return false + } + if a == nil { + return true + } + if !reflect.DeepEqual(a.byName, b.byName) || + len(a.lazyCerts) != len(b.lazyCerts) { + return false + } + for i := range a.lazyCerts { + la, lb := a.lazyCerts[i], b.lazyCerts[i] + if !bytes.Equal(la.rawSubject, lb.rawSubject) { + return false + } + ca, err := la.getCert() + if err != nil { + panic(err) + } + cb, err := la.getCert() + if err != nil { + panic(err) + } + if !ca.Equal(cb) { + return false + } + } + + return true +}