Skip to content

Commit

Permalink
Windows certificate store improvements (#6042)
Browse files Browse the repository at this point in the history
This PR:

* Removes some magic numbers from the certstore code in favour of
constants already defined in `x/sys/windows`
* Adds the `thumbprint` option to `cert_match_by` by allowing matching a
specific certificate my SHA1 thumbprint rather than possibly matching
multiple certificates by name
* Adds the `cert_match_skip_invalid` option by integrating & rebasing a
community PR along with some fix-ups

Fixes #6024 
Fixes #4383
Closes #4384

Signed-off-by: Neil Twigg <neil@nats.io>
  • Loading branch information
derekcollison authored Oct 25, 2024
2 parents 1370816 + 8fec9e3 commit 899d4cf
Show file tree
Hide file tree
Showing 10 changed files with 183 additions and 76 deletions.
6 changes: 4 additions & 2 deletions server/certstore/certstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,13 @@ type MatchByType int
const (
matchByIssuer MatchByType = iota + 1
matchBySubject
matchByThumbprint
)

var MatchByMap = map[string]MatchByType{
"issuer": matchByIssuer,
"subject": matchBySubject,
"issuer": matchByIssuer,
"subject": matchBySubject,
"thumbprint": matchByThumbprint,
}

var Usage = `
Expand Down
3 changes: 1 addition & 2 deletions server/certstore/certstore_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ var _ = MATCHBYEMPTY
// otherKey implements crypto.Signer and crypto.Decrypter to satisfy linter on platforms that don't implement certstore
type otherKey struct{}

func TLSConfig(certStore StoreType, certMatchBy MatchByType, certMatch string, caCertsMatch []string, config *tls.Config) error {
_, _, _, _, _ = certStore, certMatchBy, certMatch, caCertsMatch, config
func TLSConfig(_ StoreType, _ MatchByType, _ string, _ []string, _ bool, _ *tls.Config) error {
return ErrOSNotCompatCertStore
}

Expand Down
132 changes: 85 additions & 47 deletions server/certstore/certstore_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"crypto/sha1"
"crypto/tls"
"crypto/x509"
"encoding/binary"
"encoding/hex"
"fmt"
"io"
"math/big"
Expand All @@ -41,27 +43,26 @@ import (

const (
// wincrypt.h constants
winAcquireCached = 0x1 // CRYPT_ACQUIRE_CACHE_FLAG
winAcquireSilent = 0x40 // CRYPT_ACQUIRE_SILENT_FLAG
winAcquireOnlyNCryptKey = 0x40000 // CRYPT_ACQUIRE_ONLY_NCRYPT_KEY_FLAG
winEncodingX509ASN = 1 // X509_ASN_ENCODING
winEncodingPKCS7 = 65536 // PKCS_7_ASN_ENCODING
winCertStoreProvSystem = 10 // CERT_STORE_PROV_SYSTEM
winCertStoreCurrentUser = windows.CERT_SYSTEM_STORE_CURRENT_USER // CERT_SYSTEM_STORE_CURRENT_USER
winCertStoreLocalMachine = windows.CERT_SYSTEM_STORE_LOCAL_MACHINE // CERT_SYSTEM_STORE_LOCAL_MACHINE
winCertStoreReadOnly = windows.CERT_STORE_READONLY_FLAG // CERT_STORE_MAXIMUM_ALLOWED_FLAG
winCertStoreCurrentUserID = 1 // CERT_SYSTEM_STORE_CURRENT_USER_ID
winCertStoreLocalMachineID = 2 // CERT_SYSTEM_STORE_LOCAL_MACHINE_ID
winInfoIssuerFlag = 4 // CERT_INFO_ISSUER_FLAG
winInfoSubjectFlag = 7 // CERT_INFO_SUBJECT_FLAG
winCompareNameStrW = 8 // CERT_COMPARE_NAME_STR_A
winCompareShift = 16 // CERT_COMPARE_SHIFT
winAcquireCached = windows.CRYPT_ACQUIRE_CACHE_FLAG
winAcquireSilent = windows.CRYPT_ACQUIRE_SILENT_FLAG
winAcquireOnlyNCryptKey = windows.CRYPT_ACQUIRE_ONLY_NCRYPT_KEY_FLAG
winEncodingX509ASN = windows.X509_ASN_ENCODING
winEncodingPKCS7 = windows.PKCS_7_ASN_ENCODING
winCertStoreProvSystem = windows.CERT_STORE_PROV_SYSTEM
winCertStoreCurrentUser = windows.CERT_SYSTEM_STORE_CURRENT_USER
winCertStoreLocalMachine = windows.CERT_SYSTEM_STORE_LOCAL_MACHINE
winCertStoreReadOnly = windows.CERT_STORE_READONLY_FLAG
winInfoIssuerFlag = windows.CERT_INFO_ISSUER_FLAG
winInfoSubjectFlag = windows.CERT_INFO_SUBJECT_FLAG
winCompareNameStrW = windows.CERT_COMPARE_NAME_STR_W
winCompareShift = windows.CERT_COMPARE_SHIFT

// Reference https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-certfindcertificateinstore
winFindIssuerStr = winCompareNameStrW<<winCompareShift | winInfoIssuerFlag // CERT_FIND_ISSUER_STR_W
winFindSubjectStr = winCompareNameStrW<<winCompareShift | winInfoSubjectFlag // CERT_FIND_SUBJECT_STR_W
winFindIssuerStr = windows.CERT_FIND_ISSUER_STR_W
winFindSubjectStr = windows.CERT_FIND_SUBJECT_STR_W
winFindHashStr = windows.CERT_FIND_HASH_STR

winNcryptKeySpec = 0xFFFFFFFF // CERT_NCRYPT_KEY_SPEC
winNcryptKeySpec = windows.CERT_NCRYPT_KEY_SPEC

winBCryptPadPKCS1 uintptr = 0x2
winBCryptPadPSS uintptr = 0x8 // Modern TLS 1.2+
Expand All @@ -77,7 +78,7 @@ const (
winECK3Magic = 0x334B4345 // "ECK3" BCRYPT_ECDH_PUBLIC_P384_MAGIC
winECK5Magic = 0x354B4345 // "ECK5" BCRYPT_ECDH_PUBLIC_P521_MAGIC

winCryptENotFound = 0x80092004 // CRYPT_E_NOT_FOUND
winCryptENotFound = windows.CRYPT_E_NOT_FOUND

providerMSSoftware = "Microsoft Software Key Storage Provider"
)
Expand Down Expand Up @@ -129,6 +130,7 @@ var (
winNCrypt = windows.NewLazySystemDLL("ncrypt.dll")

winCertFindCertificateInStore = winCrypt32.NewProc("CertFindCertificateInStore")
winCertVerifyTimeValidity = winCrypt32.NewProc("CertVerifyTimeValidity")
winCryptAcquireCertificatePrivateKey = winCrypt32.NewProc("CryptAcquireCertificatePrivateKey")
winNCryptExportKey = winNCrypt.NewProc("NCryptExportKey")
winNCryptOpenStorageProvider = winNCrypt.NewProc("NCryptOpenStorageProvider")
Expand Down Expand Up @@ -170,11 +172,11 @@ type winPSSPaddingInfo struct {
// adding all matching certificates from the caCertsMatch array to the pool.
// All matching certificates (vs first) are added to the pool based on a user
// request. If no certificates are found an error is returned.
func createCACertsPool(cs *winCertStore, storeType uint32, caCertsMatch []string) (*x509.CertPool, error) {
func createCACertsPool(cs *winCertStore, storeType uint32, caCertsMatch []string, skipInvalid bool) (*x509.CertPool, error) {
var errs []error
caPool := x509.NewCertPool()
for _, s := range caCertsMatch {
lfs, err := cs.caCertsBySubjectMatch(s, storeType)
lfs, err := cs.caCertsBySubjectMatch(s, storeType, skipInvalid)
if err != nil {
errs = append(errs, err)
} else {
Expand All @@ -199,7 +201,7 @@ func createCACertsPool(cs *winCertStore, storeType uint32, caCertsMatch []string
// Subjects matching the provided strings. If a match is found, the
// certificate is added to the pool that is used to verify the certificate
// chain.
func TLSConfig(certStore StoreType, certMatchBy MatchByType, certMatch string, caCertsMatch []string, config *tls.Config) error {
func TLSConfig(certStore StoreType, certMatchBy MatchByType, certMatch string, caCertsMatch []string, skipInvalid bool, config *tls.Config) error {
var (
leaf *x509.Certificate
leafCtx *windows.CertContext
Expand All @@ -226,9 +228,11 @@ func TLSConfig(certStore StoreType, certMatchBy MatchByType, certMatch string, c

// certByIssuer or certBySubject
if certMatchBy == matchBySubject || certMatchBy == MATCHBYEMPTY {
leaf, leafCtx, err = cs.certBySubject(certMatch, scope)
leaf, leafCtx, err = cs.certBySubject(certMatch, scope, skipInvalid)
} else if certMatchBy == matchByIssuer {
leaf, leafCtx, err = cs.certByIssuer(certMatch, scope)
leaf, leafCtx, err = cs.certByIssuer(certMatch, scope, skipInvalid)
} else if certMatchBy == matchByThumbprint {
leaf, leafCtx, err = cs.certByThumbprint(certMatch, scope, skipInvalid)
} else {
return ErrBadMatchByType
}
Expand All @@ -248,7 +252,7 @@ func TLSConfig(certStore StoreType, certMatchBy MatchByType, certMatch string, c
}
// Look for CA Certificates
if len(caCertsMatch) != 0 {
caPool, err := createCACertsPool(cs, scope, caCertsMatch)
caPool, err := createCACertsPool(cs, scope, caCertsMatch, skipInvalid)
if err != nil {
return err
}
Expand Down Expand Up @@ -327,7 +331,7 @@ func winFindCert(store windows.Handle, enc, findFlags, findType uint32, para *ui
)
if h == 0 {
// Actual error, or simply not found?
if errno, ok := err.(syscall.Errno); ok && errno == winCryptENotFound {
if errno, ok := err.(syscall.Errno); ok && errno == syscall.Errno(winCryptENotFound) {
return nil, ErrFailedCertSearch
}
return nil, ErrFailedCertSearch
Expand All @@ -336,6 +340,16 @@ func winFindCert(store windows.Handle, enc, findFlags, findType uint32, para *ui
return (*windows.CertContext)(unsafe.Pointer(h)), nil
}

// winVerifyCertValid wraps the CertVerifyTimeValidity and simply returns true if the certificate is valid
func winVerifyCertValid(timeToVerify *windows.Filetime, certInfo *windows.CertInfo) bool {
// this function does not document returning errors / setting lasterror
r, _, _ := winCertVerifyTimeValidity.Call(
uintptr(unsafe.Pointer(timeToVerify)),
uintptr(unsafe.Pointer(certInfo)),
)
return r == 0
}

// winCertStore is a store implementation for the Windows Certificate Store
type winCertStore struct {
Prov uintptr
Expand Down Expand Up @@ -375,16 +389,31 @@ func winCertContextToX509(ctx *windows.CertContext) (*x509.Certificate, error) {
// CertContext pointer returned allows subsequent key operations like Sign. Caller specifies
// current user's personal certs or local machine's personal certs using storeType.
// See CERT_FIND_ISSUER_STR description at https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-certfindcertificateinstore
func (w *winCertStore) certByIssuer(issuer string, storeType uint32) (*x509.Certificate, *windows.CertContext, error) {
return w.certSearch(winFindIssuerStr, issuer, winMyStore, storeType)
func (w *winCertStore) certByIssuer(issuer string, storeType uint32, skipInvalid bool) (*x509.Certificate, *windows.CertContext, error) {
return w.certSearch(winFindIssuerStr, issuer, winMyStore, storeType, skipInvalid)
}

// certBySubject matches and returns the first certificate found by passed subject field.
// CertContext pointer returned allows subsequent key operations like Sign. Caller specifies
// current user's personal certs or local machine's personal certs using storeType.
// See CERT_FIND_SUBJECT_STR description at https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-certfindcertificateinstore
func (w *winCertStore) certBySubject(subject string, storeType uint32) (*x509.Certificate, *windows.CertContext, error) {
return w.certSearch(winFindSubjectStr, subject, winMyStore, storeType)
func (w *winCertStore) certBySubject(subject string, storeType uint32, skipInvalid bool) (*x509.Certificate, *windows.CertContext, error) {
return w.certSearch(winFindSubjectStr, subject, winMyStore, storeType, skipInvalid)
}

// certByThumbprint matches and returns the first certificate found by passed SHA1 thumbprint.
// CertContext pointer returned allows subsequent key operations like Sign. Caller specifies
// current user's personal certs or local machine's personal certs using storeType.
// See CERT_FIND_SUBJECT_STR description at https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-certfindcertificateinstore
func (w *winCertStore) certByThumbprint(hash string, storeType uint32, skipInvalid bool) (*x509.Certificate, *windows.CertContext, error) {
hb, err := hex.DecodeString(hash)
if err != nil {
return nil, nil, err
}
if len(hb) != sha1.Size {
return nil, nil, fmt.Errorf("incorrect thumbprint length %d", len(hb))
}
return w.certSearch(winFindHashStr, string(hb), winMyStore, storeType, skipInvalid)
}

// caCertsBySubjectMatch matches and returns all matching certificates of the subject field.
Expand All @@ -396,7 +425,7 @@ func (w *winCertStore) certBySubject(subject string, storeType uint32) (*x509.Ce
//
// Caller specifies current user's personal certs or local machine's personal certs using storeType.
// See CERT_FIND_SUBJECT_STR description at https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-certfindcertificateinstore
func (w *winCertStore) caCertsBySubjectMatch(subject string, storeType uint32) ([]*x509.Certificate, error) {
func (w *winCertStore) caCertsBySubjectMatch(subject string, storeType uint32, skipInvalid bool) ([]*x509.Certificate, error) {
var (
leaf *x509.Certificate
searchLocations = [3]*uint16{winRootStore, winAuthRootStore, winIntermediateCAStore}
Expand All @@ -408,7 +437,7 @@ func (w *winCertStore) caCertsBySubjectMatch(subject string, storeType uint32) (
}
for _, sr := range searchLocations {
var err error
if leaf, _, err = w.certSearch(winFindSubjectStr, subject, sr, storeType); err == nil {
if leaf, _, err = w.certSearch(winFindSubjectStr, subject, sr, storeType, skipInvalid); err == nil {
rv = append(rv, leaf)
} else {
// Ignore the failed search from a single location. Errors we catch include
Expand All @@ -430,7 +459,7 @@ func (w *winCertStore) caCertsBySubjectMatch(subject string, storeType uint32) (

// certSearch is a helper function to lookup certificates based on search type and match value.
// store is used to specify which store to perform the lookup in (system or user).
func (w *winCertStore) certSearch(searchType uint32, matchValue string, searchRoot *uint16, store uint32) (*x509.Certificate, *windows.CertContext, error) {
func (w *winCertStore) certSearch(searchType uint32, matchValue string, searchRoot *uint16, store uint32, skipInvalid bool) (*x509.Certificate, *windows.CertContext, error) {
// store handle to "MY" store
h, err := w.storeHandle(store, searchRoot)
if err != nil {
Expand All @@ -447,23 +476,32 @@ func (w *winCertStore) certSearch(searchType uint32, matchValue string, searchRo

// pass 0 as the third parameter because it is not used
// https://msdn.microsoft.com/en-us/library/windows/desktop/aa376064(v=vs.85).aspx
nc, err := winFindCert(h, winEncodingX509ASN|winEncodingPKCS7, 0, searchType, i, prev)
if err != nil {
return nil, nil, err
}
if nc != nil {
// certificate found
prev = nc

// Extract the DER-encoded certificate from the cert context
xc, err := winCertContextToX509(nc)
if err == nil {
cert = xc
for {
nc, err := winFindCert(h, winEncodingX509ASN|winEncodingPKCS7, 0, searchType, i, prev)
if err != nil {
return nil, nil, err
}
if nc != nil {
// certificate found
prev = nc

var now *windows.Filetime
if skipInvalid && !winVerifyCertValid(now, nc.CertInfo) {
continue
}

// Extract the DER-encoded certificate from the cert context
xc, err := winCertContextToX509(nc)
if err == nil {
cert = xc
break
} else {
return nil, nil, ErrFailedX509Extract
}
} else {
return nil, nil, ErrFailedX509Extract
return nil, nil, ErrFailedCertSearch
}
} else {
return nil, nil, ErrFailedCertSearch
}

if cert == nil {
Expand Down
3 changes: 3 additions & 0 deletions server/certstore/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ var (
// ErrBadCaCertMatchField represents malformed cert_match option
ErrBadCaCertMatchField = errors.New("expected 'ca_certs_match' to be a valid non-empty string array")

// ErrBadCertMatchSkipInvalidField represents malformed cert_match_skip_invalid option
ErrBadCertMatchSkipInvalidField = errors.New("expected 'cert_match_skip_invalid' to be a boolean")

// ErrOSNotCompatCertStore represents cert_store passed that exists but is not valid on current OS
ErrOSNotCompatCertStore = errors.New("cert_store not compatible with current operating system")
)
57 changes: 56 additions & 1 deletion server/certstore_windows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@ import (
)

func runPowershellScript(scriptFile string, args []string) error {
_ = args
psExec, _ := exec.LookPath("powershell.exe")

execArgs := []string{psExec, "-command", fmt.Sprintf("& '%s'", scriptFile)}
if len(args) > 0 {
execArgs = append(execArgs, args...)
}

cmdImport := &exec.Cmd{
Path: psExec,
Expand Down Expand Up @@ -251,3 +254,55 @@ func TestServerTLSWindowsCertStore(t *testing.T) {
})
}
}

// TestServerIgnoreExpiredCerts tests if the server skips expired certificates in configuration, and finds non-expired ones
func TestServerIgnoreExpiredCerts(t *testing.T) {

// Server Identities: expired.pem; not-expired.pem
// Issuer: OU = NATS.io, CN = localhost
// Subject: OU = NATS.io Operators, CN = localhost

testCases := []struct {
certFile string
expect bool
}{
{"expired.p12", false},
{"not-expired.p12", true},
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("Server certificate: %s", tc.certFile), func(t *testing.T) {
// Make sure windows cert store is reset to avoid conflict with other tests
err := runPowershellScript("../test/configs/certs/tlsauth/certstore/delete-cert-from-store.ps1", nil)
if err != nil {
t.Fatalf("expected powershell cert delete to succeed: %s", err.Error())
}

// Provision Windows cert store with server cert and secret
err = runPowershellScript("../test/configs/certs/tlsauth/certstore/import-p12-server.ps1", []string{tc.certFile})
if err != nil {
t.Fatalf("expected powershell provision to succeed: %s", err.Error())
}
// Fire up the server
srvConfig := createConfFile(t, []byte(`
listen: "localhost:-1"
tls {
cert_store: "WindowsCurrentUser"
cert_match_by: "Subject"
cert_match: "NATS.io Operators"
cert_match_skip_invalid: true
timeout: 5
}
`))
defer removeFile(t, srvConfig)
cfg, _ := ProcessConfigFile(srvConfig)
if (cfg != nil) == tc.expect {
return
}
if tc.expect == false {
t.Fatalf("expected server start to fail with expired certificate")
} else {
t.Fatalf("expected server to start with non expired certificate")
}
})
}
}
Loading

0 comments on commit 899d4cf

Please sign in to comment.