Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Windows certificate store improvements #6042

Merged
merged 3 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions server/certstore/certstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,13 @@ type MatchByType int
const (
matchByIssuer MatchByType = iota + 1
matchBySubject
matchByThumbprint
)

var MatchByMap = map[string]MatchByType{
"issuer": matchByIssuer,
"subject": matchBySubject,
"issuer": matchByIssuer,
"subject": matchBySubject,
"thumbprint": matchByThumbprint,
}

var Usage = `
Expand Down
3 changes: 1 addition & 2 deletions server/certstore/certstore_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ var _ = MATCHBYEMPTY
// otherKey implements crypto.Signer and crypto.Decrypter to satisfy linter on platforms that don't implement certstore
type otherKey struct{}

func TLSConfig(certStore StoreType, certMatchBy MatchByType, certMatch string, caCertsMatch []string, config *tls.Config) error {
_, _, _, _, _ = certStore, certMatchBy, certMatch, caCertsMatch, config
func TLSConfig(_ StoreType, _ MatchByType, _ string, _ []string, _ bool, _ *tls.Config) error {
return ErrOSNotCompatCertStore
}

Expand Down
132 changes: 85 additions & 47 deletions server/certstore/certstore_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"crypto/sha1"
"crypto/tls"
"crypto/x509"
"encoding/binary"
"encoding/hex"
"fmt"
"io"
"math/big"
Expand All @@ -41,27 +43,26 @@ import (

const (
// wincrypt.h constants
winAcquireCached = 0x1 // CRYPT_ACQUIRE_CACHE_FLAG
winAcquireSilent = 0x40 // CRYPT_ACQUIRE_SILENT_FLAG
winAcquireOnlyNCryptKey = 0x40000 // CRYPT_ACQUIRE_ONLY_NCRYPT_KEY_FLAG
winEncodingX509ASN = 1 // X509_ASN_ENCODING
winEncodingPKCS7 = 65536 // PKCS_7_ASN_ENCODING
winCertStoreProvSystem = 10 // CERT_STORE_PROV_SYSTEM
winCertStoreCurrentUser = windows.CERT_SYSTEM_STORE_CURRENT_USER // CERT_SYSTEM_STORE_CURRENT_USER
winCertStoreLocalMachine = windows.CERT_SYSTEM_STORE_LOCAL_MACHINE // CERT_SYSTEM_STORE_LOCAL_MACHINE
winCertStoreReadOnly = windows.CERT_STORE_READONLY_FLAG // CERT_STORE_MAXIMUM_ALLOWED_FLAG
winCertStoreCurrentUserID = 1 // CERT_SYSTEM_STORE_CURRENT_USER_ID
winCertStoreLocalMachineID = 2 // CERT_SYSTEM_STORE_LOCAL_MACHINE_ID
winInfoIssuerFlag = 4 // CERT_INFO_ISSUER_FLAG
winInfoSubjectFlag = 7 // CERT_INFO_SUBJECT_FLAG
winCompareNameStrW = 8 // CERT_COMPARE_NAME_STR_A
winCompareShift = 16 // CERT_COMPARE_SHIFT
winAcquireCached = windows.CRYPT_ACQUIRE_CACHE_FLAG
winAcquireSilent = windows.CRYPT_ACQUIRE_SILENT_FLAG
winAcquireOnlyNCryptKey = windows.CRYPT_ACQUIRE_ONLY_NCRYPT_KEY_FLAG
winEncodingX509ASN = windows.X509_ASN_ENCODING
winEncodingPKCS7 = windows.PKCS_7_ASN_ENCODING
winCertStoreProvSystem = windows.CERT_STORE_PROV_SYSTEM
winCertStoreCurrentUser = windows.CERT_SYSTEM_STORE_CURRENT_USER
winCertStoreLocalMachine = windows.CERT_SYSTEM_STORE_LOCAL_MACHINE
winCertStoreReadOnly = windows.CERT_STORE_READONLY_FLAG
winInfoIssuerFlag = windows.CERT_INFO_ISSUER_FLAG
winInfoSubjectFlag = windows.CERT_INFO_SUBJECT_FLAG
winCompareNameStrW = windows.CERT_COMPARE_NAME_STR_W
winCompareShift = windows.CERT_COMPARE_SHIFT

// Reference https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-certfindcertificateinstore
winFindIssuerStr = winCompareNameStrW<<winCompareShift | winInfoIssuerFlag // CERT_FIND_ISSUER_STR_W
winFindSubjectStr = winCompareNameStrW<<winCompareShift | winInfoSubjectFlag // CERT_FIND_SUBJECT_STR_W
winFindIssuerStr = windows.CERT_FIND_ISSUER_STR_W
winFindSubjectStr = windows.CERT_FIND_SUBJECT_STR_W
winFindHashStr = windows.CERT_FIND_HASH_STR

winNcryptKeySpec = 0xFFFFFFFF // CERT_NCRYPT_KEY_SPEC
winNcryptKeySpec = windows.CERT_NCRYPT_KEY_SPEC

winBCryptPadPKCS1 uintptr = 0x2
winBCryptPadPSS uintptr = 0x8 // Modern TLS 1.2+
Expand All @@ -77,7 +78,7 @@ const (
winECK3Magic = 0x334B4345 // "ECK3" BCRYPT_ECDH_PUBLIC_P384_MAGIC
winECK5Magic = 0x354B4345 // "ECK5" BCRYPT_ECDH_PUBLIC_P521_MAGIC

winCryptENotFound = 0x80092004 // CRYPT_E_NOT_FOUND
winCryptENotFound = windows.CRYPT_E_NOT_FOUND

providerMSSoftware = "Microsoft Software Key Storage Provider"
)
Expand Down Expand Up @@ -129,6 +130,7 @@ var (
winNCrypt = windows.NewLazySystemDLL("ncrypt.dll")

winCertFindCertificateInStore = winCrypt32.NewProc("CertFindCertificateInStore")
winCertVerifyTimeValidity = winCrypt32.NewProc("CertVerifyTimeValidity")
winCryptAcquireCertificatePrivateKey = winCrypt32.NewProc("CryptAcquireCertificatePrivateKey")
winNCryptExportKey = winNCrypt.NewProc("NCryptExportKey")
winNCryptOpenStorageProvider = winNCrypt.NewProc("NCryptOpenStorageProvider")
Expand Down Expand Up @@ -170,11 +172,11 @@ type winPSSPaddingInfo struct {
// adding all matching certificates from the caCertsMatch array to the pool.
// All matching certificates (vs first) are added to the pool based on a user
// request. If no certificates are found an error is returned.
func createCACertsPool(cs *winCertStore, storeType uint32, caCertsMatch []string) (*x509.CertPool, error) {
func createCACertsPool(cs *winCertStore, storeType uint32, caCertsMatch []string, skipInvalid bool) (*x509.CertPool, error) {
var errs []error
caPool := x509.NewCertPool()
for _, s := range caCertsMatch {
lfs, err := cs.caCertsBySubjectMatch(s, storeType)
lfs, err := cs.caCertsBySubjectMatch(s, storeType, skipInvalid)
if err != nil {
errs = append(errs, err)
} else {
Expand All @@ -199,7 +201,7 @@ func createCACertsPool(cs *winCertStore, storeType uint32, caCertsMatch []string
// Subjects matching the provided strings. If a match is found, the
// certificate is added to the pool that is used to verify the certificate
// chain.
func TLSConfig(certStore StoreType, certMatchBy MatchByType, certMatch string, caCertsMatch []string, config *tls.Config) error {
func TLSConfig(certStore StoreType, certMatchBy MatchByType, certMatch string, caCertsMatch []string, skipInvalid bool, config *tls.Config) error {
var (
leaf *x509.Certificate
leafCtx *windows.CertContext
Expand All @@ -226,9 +228,11 @@ func TLSConfig(certStore StoreType, certMatchBy MatchByType, certMatch string, c

// certByIssuer or certBySubject
if certMatchBy == matchBySubject || certMatchBy == MATCHBYEMPTY {
leaf, leafCtx, err = cs.certBySubject(certMatch, scope)
leaf, leafCtx, err = cs.certBySubject(certMatch, scope, skipInvalid)
} else if certMatchBy == matchByIssuer {
leaf, leafCtx, err = cs.certByIssuer(certMatch, scope)
leaf, leafCtx, err = cs.certByIssuer(certMatch, scope, skipInvalid)
} else if certMatchBy == matchByThumbprint {
leaf, leafCtx, err = cs.certByThumbprint(certMatch, scope, skipInvalid)
} else {
return ErrBadMatchByType
}
Expand All @@ -248,7 +252,7 @@ func TLSConfig(certStore StoreType, certMatchBy MatchByType, certMatch string, c
}
// Look for CA Certificates
if len(caCertsMatch) != 0 {
caPool, err := createCACertsPool(cs, scope, caCertsMatch)
caPool, err := createCACertsPool(cs, scope, caCertsMatch, skipInvalid)
if err != nil {
return err
}
Expand Down Expand Up @@ -327,7 +331,7 @@ func winFindCert(store windows.Handle, enc, findFlags, findType uint32, para *ui
)
if h == 0 {
// Actual error, or simply not found?
if errno, ok := err.(syscall.Errno); ok && errno == winCryptENotFound {
if errno, ok := err.(syscall.Errno); ok && errno == syscall.Errno(winCryptENotFound) {
return nil, ErrFailedCertSearch
}
return nil, ErrFailedCertSearch
Expand All @@ -336,6 +340,16 @@ func winFindCert(store windows.Handle, enc, findFlags, findType uint32, para *ui
return (*windows.CertContext)(unsafe.Pointer(h)), nil
}

// winVerifyCertValid wraps the CertVerifyTimeValidity and simply returns true if the certificate is valid
func winVerifyCertValid(timeToVerify *windows.Filetime, certInfo *windows.CertInfo) bool {
// this function does not document returning errors / setting lasterror
r, _, _ := winCertVerifyTimeValidity.Call(
uintptr(unsafe.Pointer(timeToVerify)),
uintptr(unsafe.Pointer(certInfo)),
)
return r == 0
}

// winCertStore is a store implementation for the Windows Certificate Store
type winCertStore struct {
Prov uintptr
Expand Down Expand Up @@ -375,16 +389,31 @@ func winCertContextToX509(ctx *windows.CertContext) (*x509.Certificate, error) {
// CertContext pointer returned allows subsequent key operations like Sign. Caller specifies
// current user's personal certs or local machine's personal certs using storeType.
// See CERT_FIND_ISSUER_STR description at https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-certfindcertificateinstore
func (w *winCertStore) certByIssuer(issuer string, storeType uint32) (*x509.Certificate, *windows.CertContext, error) {
return w.certSearch(winFindIssuerStr, issuer, winMyStore, storeType)
func (w *winCertStore) certByIssuer(issuer string, storeType uint32, skipInvalid bool) (*x509.Certificate, *windows.CertContext, error) {
return w.certSearch(winFindIssuerStr, issuer, winMyStore, storeType, skipInvalid)
}

// certBySubject matches and returns the first certificate found by passed subject field.
// CertContext pointer returned allows subsequent key operations like Sign. Caller specifies
// current user's personal certs or local machine's personal certs using storeType.
// See CERT_FIND_SUBJECT_STR description at https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-certfindcertificateinstore
func (w *winCertStore) certBySubject(subject string, storeType uint32) (*x509.Certificate, *windows.CertContext, error) {
return w.certSearch(winFindSubjectStr, subject, winMyStore, storeType)
func (w *winCertStore) certBySubject(subject string, storeType uint32, skipInvalid bool) (*x509.Certificate, *windows.CertContext, error) {
return w.certSearch(winFindSubjectStr, subject, winMyStore, storeType, skipInvalid)
}

// certByThumbprint matches and returns the first certificate found by passed SHA1 thumbprint.
// CertContext pointer returned allows subsequent key operations like Sign. Caller specifies
// current user's personal certs or local machine's personal certs using storeType.
// See CERT_FIND_SUBJECT_STR description at https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-certfindcertificateinstore
func (w *winCertStore) certByThumbprint(hash string, storeType uint32, skipInvalid bool) (*x509.Certificate, *windows.CertContext, error) {
hb, err := hex.DecodeString(hash)
if err != nil {
return nil, nil, err
}
if len(hb) != sha1.Size {
return nil, nil, fmt.Errorf("incorrect thumbprint length %d", len(hb))
}
return w.certSearch(winFindHashStr, string(hb), winMyStore, storeType, skipInvalid)
}

// caCertsBySubjectMatch matches and returns all matching certificates of the subject field.
Expand All @@ -396,7 +425,7 @@ func (w *winCertStore) certBySubject(subject string, storeType uint32) (*x509.Ce
//
// Caller specifies current user's personal certs or local machine's personal certs using storeType.
// See CERT_FIND_SUBJECT_STR description at https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-certfindcertificateinstore
func (w *winCertStore) caCertsBySubjectMatch(subject string, storeType uint32) ([]*x509.Certificate, error) {
func (w *winCertStore) caCertsBySubjectMatch(subject string, storeType uint32, skipInvalid bool) ([]*x509.Certificate, error) {
var (
leaf *x509.Certificate
searchLocations = [3]*uint16{winRootStore, winAuthRootStore, winIntermediateCAStore}
Expand All @@ -408,7 +437,7 @@ func (w *winCertStore) caCertsBySubjectMatch(subject string, storeType uint32) (
}
for _, sr := range searchLocations {
var err error
if leaf, _, err = w.certSearch(winFindSubjectStr, subject, sr, storeType); err == nil {
if leaf, _, err = w.certSearch(winFindSubjectStr, subject, sr, storeType, skipInvalid); err == nil {
rv = append(rv, leaf)
} else {
// Ignore the failed search from a single location. Errors we catch include
Expand All @@ -430,7 +459,7 @@ func (w *winCertStore) caCertsBySubjectMatch(subject string, storeType uint32) (

// certSearch is a helper function to lookup certificates based on search type and match value.
// store is used to specify which store to perform the lookup in (system or user).
func (w *winCertStore) certSearch(searchType uint32, matchValue string, searchRoot *uint16, store uint32) (*x509.Certificate, *windows.CertContext, error) {
func (w *winCertStore) certSearch(searchType uint32, matchValue string, searchRoot *uint16, store uint32, skipInvalid bool) (*x509.Certificate, *windows.CertContext, error) {
// store handle to "MY" store
h, err := w.storeHandle(store, searchRoot)
if err != nil {
Expand All @@ -447,23 +476,32 @@ func (w *winCertStore) certSearch(searchType uint32, matchValue string, searchRo

// pass 0 as the third parameter because it is not used
// https://msdn.microsoft.com/en-us/library/windows/desktop/aa376064(v=vs.85).aspx
nc, err := winFindCert(h, winEncodingX509ASN|winEncodingPKCS7, 0, searchType, i, prev)
if err != nil {
return nil, nil, err
}
if nc != nil {
// certificate found
prev = nc

// Extract the DER-encoded certificate from the cert context
xc, err := winCertContextToX509(nc)
if err == nil {
cert = xc
for {
nc, err := winFindCert(h, winEncodingX509ASN|winEncodingPKCS7, 0, searchType, i, prev)
if err != nil {
return nil, nil, err
}
if nc != nil {
// certificate found
prev = nc

var now *windows.Filetime
if skipInvalid && !winVerifyCertValid(now, nc.CertInfo) {
continue
}

// Extract the DER-encoded certificate from the cert context
xc, err := winCertContextToX509(nc)
if err == nil {
cert = xc
break
} else {
return nil, nil, ErrFailedX509Extract
}
} else {
return nil, nil, ErrFailedX509Extract
return nil, nil, ErrFailedCertSearch
}
} else {
return nil, nil, ErrFailedCertSearch
}

if cert == nil {
Expand Down
3 changes: 3 additions & 0 deletions server/certstore/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ var (
// ErrBadCaCertMatchField represents malformed cert_match option
ErrBadCaCertMatchField = errors.New("expected 'ca_certs_match' to be a valid non-empty string array")

// ErrBadCertMatchSkipInvalidField represents malformed cert_match_skip_invalid option
ErrBadCertMatchSkipInvalidField = errors.New("expected 'cert_match_skip_invalid' to be a boolean")

// ErrOSNotCompatCertStore represents cert_store passed that exists but is not valid on current OS
ErrOSNotCompatCertStore = errors.New("cert_store not compatible with current operating system")
)
57 changes: 56 additions & 1 deletion server/certstore_windows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@ import (
)

func runPowershellScript(scriptFile string, args []string) error {
_ = args
psExec, _ := exec.LookPath("powershell.exe")

execArgs := []string{psExec, "-command", fmt.Sprintf("& '%s'", scriptFile)}
if len(args) > 0 {
execArgs = append(execArgs, args...)
}

cmdImport := &exec.Cmd{
Path: psExec,
Expand Down Expand Up @@ -251,3 +254,55 @@ func TestServerTLSWindowsCertStore(t *testing.T) {
})
}
}

// TestServerIgnoreExpiredCerts tests if the server skips expired certificates in configuration, and finds non-expired ones
func TestServerIgnoreExpiredCerts(t *testing.T) {

// Server Identities: expired.pem; not-expired.pem
// Issuer: OU = NATS.io, CN = localhost
// Subject: OU = NATS.io Operators, CN = localhost

testCases := []struct {
certFile string
expect bool
}{
{"expired.p12", false},
{"not-expired.p12", true},
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("Server certificate: %s", tc.certFile), func(t *testing.T) {
// Make sure windows cert store is reset to avoid conflict with other tests
err := runPowershellScript("../test/configs/certs/tlsauth/certstore/delete-cert-from-store.ps1", nil)
if err != nil {
t.Fatalf("expected powershell cert delete to succeed: %s", err.Error())
}

// Provision Windows cert store with server cert and secret
err = runPowershellScript("../test/configs/certs/tlsauth/certstore/import-p12-server.ps1", []string{tc.certFile})
if err != nil {
t.Fatalf("expected powershell provision to succeed: %s", err.Error())
}
// Fire up the server
srvConfig := createConfFile(t, []byte(`
listen: "localhost:-1"
tls {
cert_store: "WindowsCurrentUser"
cert_match_by: "Subject"
cert_match: "NATS.io Operators"
cert_match_skip_invalid: true
timeout: 5
}
`))
defer removeFile(t, srvConfig)
cfg, _ := ProcessConfigFile(srvConfig)
if (cfg != nil) == tc.expect {
return
}
if tc.expect == false {
t.Fatalf("expected server start to fail with expired certificate")
} else {
t.Fatalf("expected server to start with non expired certificate")
}
})
}
}
Loading