Skip to content

Commit

Permalink
Support PSS padding.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 563089484
  • Loading branch information
CertoToStore Team authored and copybara-github committed Sep 7, 2023
1 parent c4d4bf9 commit cc6f2a7
Showing 1 changed file with 46 additions and 16 deletions.
62 changes: 46 additions & 16 deletions certtostore_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ import (
"unsafe"

"github.com/google/deck"
"github.com/hashicorp/go-multierror"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/cryptobyte/asn1"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/sys/windows"
"github.com/hashicorp/go-multierror"
)

// WinCertStorage provides windows-specific additions to the CertStorage interface.
Expand Down Expand Up @@ -106,6 +106,7 @@ const (

// Legacy CryptoAPI flags
bCryptPadPKCS1 uintptr = 0x2
bCryptPadPSS uintptr = 0x8

// Magic numbers for public key blobs.
rsa1Magic = 0x31415352 // "RSA1" BCRYPT_RSAPUBLIC_MAGIC
Expand Down Expand Up @@ -219,11 +220,17 @@ var (
fnGetProperty = getProperty
)

// paddingInfo is the BCRYPT_PKCS1_PADDING_INFO struct in bcrypt.h.
type paddingInfo struct {
// pkcs1PaddingInfo is the BCRYPT_PKCS1_PADDING_INFO struct in bcrypt.h.
type pkcs1PaddingInfo struct {
pszAlgID *uint16
}

// pssPaddingInfo is the BCRYPT_PSS_PADDING_INFO struct in bcrypt.h.
type pssPaddingInfo struct {
pszAlgID *uint16
cbSalt uint32
}

// wide returns a pointer to a a uint16 representing the equivalent
// to a Windows LPCWSTR.
func wide(s string) *uint16 {
Expand Down Expand Up @@ -777,12 +784,7 @@ func (k Key) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, e
case "ECDSA":
return signECDSA(k.handle, digest)
case "RSA":
hf := opts.HashFunc()
algID, ok := algIDs[hf]
if !ok {
return nil, fmt.Errorf("unsupported RSA hash algorithm %v", hf)
}
return signRSA(k.handle, digest, algID)
return signRSA(k.handle, digest, opts)
default:
return nil, fmt.Errorf("unsupported algorithm group %v", k.AlgorithmGroup)
}
Expand Down Expand Up @@ -849,19 +851,22 @@ func packECDSASigValue(r io.Reader, digestLength int) ([]byte, error) {
return b.Bytes()
}

func signRSA(kh uintptr, digest []byte, algID *uint16) ([]byte, error) {
padInfo := paddingInfo{pszAlgID: algID}
func signRSA(kh uintptr, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
paddingInfo, flags, err := rsaPadding(opts)
if err != nil {
return nil, fmt.Errorf("failed to construct padding info: %v", err)
}
var size uint32
// Obtain the size of the signature
r, _, err := nCryptSignHash.Call(
kh,
uintptr(unsafe.Pointer(&padInfo)),
uintptr(paddingInfo),
uintptr(unsafe.Pointer(&digest[0])),
uintptr(len(digest)),
0,
0,
uintptr(unsafe.Pointer(&size)),
bCryptPadPKCS1)
flags)
if r != 0 {
return nil, fmt.Errorf("NCryptSignHash returned %X during size check: %v", r, err)
}
Expand All @@ -870,20 +875,45 @@ func signRSA(kh uintptr, digest []byte, algID *uint16) ([]byte, error) {
sig := make([]byte, size)
r, _, err = nCryptSignHash.Call(
kh,
uintptr(unsafe.Pointer(&padInfo)),
uintptr(paddingInfo),
uintptr(unsafe.Pointer(&digest[0])),
uintptr(len(digest)),
uintptr(unsafe.Pointer(&sig[0])),
uintptr(size),
uintptr(unsafe.Pointer(&size)),
bCryptPadPKCS1)
flags)
if r != 0 {
return nil, fmt.Errorf("NCryptSignHash returned %X during signing: %v", r, err)
}

return sig[:size], nil
}

// rsaPadding constructs the padding info structure and flags from the crypto.SignerOpts.
// https://learn.microsoft.com/en-us/windows/win32/api/bcrypt/nf-bcrypt-bcryptsignhash
func rsaPadding(opts crypto.SignerOpts) (unsafe.Pointer, uintptr, error) {
algID, ok := algIDs[opts.HashFunc()]
if !ok {
return nil, 0, fmt.Errorf("unsupported RSA hash algorithm %v", opts.HashFunc())
}
if o, ok := opts.(*rsa.PSSOptions); ok {
saltLength := o.SaltLength
switch saltLength {
case rsa.PSSSaltLengthAuto:
return nil, 0, fmt.Errorf("rsa.PSSSaltLengthAuto is not supported")
case rsa.PSSSaltLengthEqualsHash:
saltLength = o.HashFunc().Size()
}
return unsafe.Pointer(&pssPaddingInfo{
pszAlgID: algID,
cbSalt: uint32(saltLength),
}), bCryptPadPSS, nil
}
return unsafe.Pointer(&pkcs1PaddingInfo{
pszAlgID: algID,
}), bCryptPadPKCS1, nil
}

// DecrypterOpts implements crypto.DecrypterOpts and contains the
// flags required for the NCryptDecrypt system call.
type DecrypterOpts struct {
Expand Down

0 comments on commit cc6f2a7

Please sign in to comment.