From cc6f2a766a42b689497fe6c577c3247a458fbafa Mon Sep 17 00:00:00 2001 From: CertoToStore Team Date: Wed, 6 Sep 2023 06:45:12 -0700 Subject: [PATCH] Support PSS padding. PiperOrigin-RevId: 563089484 --- certtostore_windows.go | 62 +++++++++++++++++++++++++++++++----------- 1 file changed, 46 insertions(+), 16 deletions(-) diff --git a/certtostore_windows.go b/certtostore_windows.go index 014ee0e..8b852b5 100644 --- a/certtostore_windows.go +++ b/certtostore_windows.go @@ -42,10 +42,10 @@ import ( "unsafe" "github.com/google/deck" - "github.com/hashicorp/go-multierror" - "golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte/asn1" + "golang.org/x/crypto/cryptobyte" "golang.org/x/sys/windows" + "github.com/hashicorp/go-multierror" ) // WinCertStorage provides windows-specific additions to the CertStorage interface. @@ -106,6 +106,7 @@ const ( // Legacy CryptoAPI flags bCryptPadPKCS1 uintptr = 0x2 + bCryptPadPSS uintptr = 0x8 // Magic numbers for public key blobs. rsa1Magic = 0x31415352 // "RSA1" BCRYPT_RSAPUBLIC_MAGIC @@ -219,11 +220,17 @@ var ( fnGetProperty = getProperty ) -// paddingInfo is the BCRYPT_PKCS1_PADDING_INFO struct in bcrypt.h. -type paddingInfo struct { +// pkcs1PaddingInfo is the BCRYPT_PKCS1_PADDING_INFO struct in bcrypt.h. +type pkcs1PaddingInfo struct { pszAlgID *uint16 } +// pssPaddingInfo is the BCRYPT_PSS_PADDING_INFO struct in bcrypt.h. +type pssPaddingInfo struct { + pszAlgID *uint16 + cbSalt uint32 +} + // wide returns a pointer to a a uint16 representing the equivalent // to a Windows LPCWSTR. func wide(s string) *uint16 { @@ -777,12 +784,7 @@ func (k Key) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, e case "ECDSA": return signECDSA(k.handle, digest) case "RSA": - hf := opts.HashFunc() - algID, ok := algIDs[hf] - if !ok { - return nil, fmt.Errorf("unsupported RSA hash algorithm %v", hf) - } - return signRSA(k.handle, digest, algID) + return signRSA(k.handle, digest, opts) default: return nil, fmt.Errorf("unsupported algorithm group %v", k.AlgorithmGroup) } @@ -849,19 +851,22 @@ func packECDSASigValue(r io.Reader, digestLength int) ([]byte, error) { return b.Bytes() } -func signRSA(kh uintptr, digest []byte, algID *uint16) ([]byte, error) { - padInfo := paddingInfo{pszAlgID: algID} +func signRSA(kh uintptr, digest []byte, opts crypto.SignerOpts) ([]byte, error) { + paddingInfo, flags, err := rsaPadding(opts) + if err != nil { + return nil, fmt.Errorf("failed to construct padding info: %v", err) + } var size uint32 // Obtain the size of the signature r, _, err := nCryptSignHash.Call( kh, - uintptr(unsafe.Pointer(&padInfo)), + uintptr(paddingInfo), uintptr(unsafe.Pointer(&digest[0])), uintptr(len(digest)), 0, 0, uintptr(unsafe.Pointer(&size)), - bCryptPadPKCS1) + flags) if r != 0 { return nil, fmt.Errorf("NCryptSignHash returned %X during size check: %v", r, err) } @@ -870,13 +875,13 @@ func signRSA(kh uintptr, digest []byte, algID *uint16) ([]byte, error) { sig := make([]byte, size) r, _, err = nCryptSignHash.Call( kh, - uintptr(unsafe.Pointer(&padInfo)), + uintptr(paddingInfo), uintptr(unsafe.Pointer(&digest[0])), uintptr(len(digest)), uintptr(unsafe.Pointer(&sig[0])), uintptr(size), uintptr(unsafe.Pointer(&size)), - bCryptPadPKCS1) + flags) if r != 0 { return nil, fmt.Errorf("NCryptSignHash returned %X during signing: %v", r, err) } @@ -884,6 +889,31 @@ func signRSA(kh uintptr, digest []byte, algID *uint16) ([]byte, error) { return sig[:size], nil } +// rsaPadding constructs the padding info structure and flags from the crypto.SignerOpts. +// https://learn.microsoft.com/en-us/windows/win32/api/bcrypt/nf-bcrypt-bcryptsignhash +func rsaPadding(opts crypto.SignerOpts) (unsafe.Pointer, uintptr, error) { + algID, ok := algIDs[opts.HashFunc()] + if !ok { + return nil, 0, fmt.Errorf("unsupported RSA hash algorithm %v", opts.HashFunc()) + } + if o, ok := opts.(*rsa.PSSOptions); ok { + saltLength := o.SaltLength + switch saltLength { + case rsa.PSSSaltLengthAuto: + return nil, 0, fmt.Errorf("rsa.PSSSaltLengthAuto is not supported") + case rsa.PSSSaltLengthEqualsHash: + saltLength = o.HashFunc().Size() + } + return unsafe.Pointer(&pssPaddingInfo{ + pszAlgID: algID, + cbSalt: uint32(saltLength), + }), bCryptPadPSS, nil + } + return unsafe.Pointer(&pkcs1PaddingInfo{ + pszAlgID: algID, + }), bCryptPadPKCS1, nil +} + // DecrypterOpts implements crypto.DecrypterOpts and contains the // flags required for the NCryptDecrypt system call. type DecrypterOpts struct {