diff --git a/pkg/ca/fileca/fileca.go b/pkg/ca/fileca/fileca.go index 32e92d65a..63851101c 100644 --- a/pkg/ca/fileca/fileca.go +++ b/pkg/ca/fileca/fileca.go @@ -16,6 +16,7 @@ package fileca import ( + "bytes" "context" "crypto" "crypto/rand" @@ -32,8 +33,8 @@ import ( type fileCA struct { sync.RWMutex - cert *x509.Certificate - key crypto.Signer + certs []*x509.Certificate + key crypto.Signer } // NewFileCA returns a file backed certificate authority. Expects paths to a @@ -43,7 +44,7 @@ func NewFileCA(certPath, keyPath, keyPass string, watch bool) (ca.CertificateAut var fca fileCA var err error - fca.cert, fca.key, err = loadKeyPair(certPath, keyPath, keyPass) + fca.certs, fca.key, err = loadKeyPair(certPath, keyPath, keyPass) if err != nil { return nil, err } @@ -68,21 +69,21 @@ func NewFileCA(certPath, keyPath, keyPass string, watch bool) (ca.CertificateAut return &fca, err } -func (fca *fileCA) updateX509KeyPair(cert *x509.Certificate, key crypto.Signer) { +func (fca *fileCA) updateX509KeyPair(certs []*x509.Certificate, key crypto.Signer) { fca.Lock() defer fca.Unlock() // NB: We use the RWLock to unsure a reading thread can't get a mismatching // cert / key pair by reading the attributes halfway through the update // below. - fca.cert = cert + fca.certs = certs fca.key = key } func (fca *fileCA) getX509KeyPair() (*x509.Certificate, crypto.Signer) { fca.RLock() defer fca.RUnlock() - return fca.cert, fca.key + return fca.certs[0], fca.key } // CreateCertificate issues code signing certificates @@ -103,8 +104,19 @@ func (fca *fileCA) CreateCertificate(_ context.Context, subject *challenges.Chal } func (fca *fileCA) Root(ctx context.Context) ([]byte, error) { - return pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: fca.cert.Raw, - }), nil + fca.RLock() + defer fca.RUnlock() + + buf := new(bytes.Buffer) + for _, cert := range fca.certs { + err := pem.Encode(buf, &pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Raw, + }) + if err != nil { + return nil, err + } + } + + return buf.Bytes(), nil } diff --git a/pkg/ca/fileca/load.go b/pkg/ca/fileca/load.go index 61b84c136..f8f372b30 100644 --- a/pkg/ca/fileca/load.go +++ b/pkg/ca/fileca/load.go @@ -27,21 +27,34 @@ import ( "go.step.sm/crypto/pemutil" ) -func loadKeyPair(certPath, keyPath, keyPass string) (*x509.Certificate, crypto.Signer, error) { +func loadKeyPair(certPath, keyPath, keyPass string) ([]*x509.Certificate, crypto.Signer, error) { var ( - cert *x509.Certificate - err error - key crypto.Signer + certs []*x509.Certificate + err error + key crypto.Signer ) - // TODO: Load chain of certs (intermediates and root) instead of just one - // certificate. - cert, err = pemutil.ReadCertificate(certPath) + certs, err = pemutil.ReadCertificateBundle(certPath) if err != nil { return nil, nil, err } + // Verify certificate chain + { + roots := x509.NewCertPool() + for _, cert := range certs { + roots.AddCert(cert) + } + + opts := x509.VerifyOptions{ + Roots: roots, + } + if _, err := certs[0].Verify(opts); err != nil { + return nil, nil, err + } + } + { opaqueKey, err := pemutil.Read(keyPath, pemutil.WithPassword([]byte(keyPass))) if err != nil { @@ -55,15 +68,15 @@ func loadKeyPair(certPath, keyPath, keyPass string) (*x509.Certificate, crypto.S } } - if !valid(cert, key) { + if !valid(certs[0], key) { return nil, nil, errors.New(`fileca: certificate public key and private key don't match`) } - if !cert.IsCA { + if !certs[0].IsCA { return nil, nil, errors.New(`fileca: certificate is not a CA`) } - return cert, key, nil + return certs, key, nil } func valid(cert *x509.Certificate, key crypto.Signer) bool { diff --git a/pkg/ca/fileca/load_test.go b/pkg/ca/fileca/load_test.go index 7da90e1e6..b731b3379 100644 --- a/pkg/ca/fileca/load_test.go +++ b/pkg/ca/fileca/load_test.go @@ -25,6 +25,7 @@ func TestValidLoadKeyPair(t *testing.T) { "ecdsa", "ed25519", "rsa4096", + "intermediate", } for _, keypair := range keypairs { diff --git a/pkg/ca/fileca/watch.go b/pkg/ca/fileca/watch.go index b509bf990..2a8e68404 100644 --- a/pkg/ca/fileca/watch.go +++ b/pkg/ca/fileca/watch.go @@ -22,10 +22,10 @@ import ( "github.com/fsnotify/fsnotify" ) -func ioWatch(certPath, keyPath, keyPass string, watcher *fsnotify.Watcher, callback func(*x509.Certificate, crypto.Signer)) { +func ioWatch(certPath, keyPath, keyPass string, watcher *fsnotify.Watcher, callback func([]*x509.Certificate, crypto.Signer)) { for event := range watcher.Events { if event.Op&fsnotify.Write == fsnotify.Write { - cert, key, err := loadKeyPair(certPath, keyPath, keyPass) + certs, key, err := loadKeyPair(certPath, keyPath, keyPass) if err != nil { // Don't sweat it if this errors out. One file might // have updated and the other isn't causing a key-pair @@ -33,7 +33,7 @@ func ioWatch(certPath, keyPath, keyPass string, watcher *fsnotify.Watcher, callb continue } - callback(cert, key) + callback(certs, key) } } } diff --git a/pkg/ca/fileca/watch_test.go b/pkg/ca/fileca/watch_test.go index 76d11442b..6ebd65927 100644 --- a/pkg/ca/fileca/watch_test.go +++ b/pkg/ca/fileca/watch_test.go @@ -57,14 +57,14 @@ func TestIOWatch(t *testing.T) { // Set up callback trap var received []struct { - cert *x509.Certificate - key crypto.Signer + certs []*x509.Certificate + key crypto.Signer } - callback := func(cert *x509.Certificate, key crypto.Signer) { + callback := func(certs []*x509.Certificate, key crypto.Signer) { received = append(received, struct { - cert *x509.Certificate - key crypto.Signer - }{cert, key}) + certs []*x509.Certificate + key crypto.Signer + }{certs, key}) } // Set up watcher