From 31801029d9e1fad1156452a6aa72fb407f04bc1d Mon Sep 17 00:00:00 2001 From: qmuntal Date: Wed, 25 Sep 2024 12:09:07 +0000 Subject: [PATCH] implement HKDF using the EVP_KDF API in OpenSSL 3 --- evp.go | 14 +++ hkdf.go | 283 +++++++++++++++++++++++++++++++++++------------------- params.go | 14 +++ shims.h | 4 + 4 files changed, 218 insertions(+), 97 deletions(-) diff --git a/evp.go b/evp.go index 1c39ade..aa85db6 100644 --- a/evp.go +++ b/evp.go @@ -69,6 +69,20 @@ func hashToMD(h hash.Hash) C.GO_EVP_MD_PTR { return nil } +// hashFuncToMD converts a hash.Hash function to a GO_EVP_MD_PTR. +// See [hashFuncHash] for details on error handling. +func hashFuncToMD(fn func() hash.Hash) (C.GO_EVP_MD_PTR, error) { + h, err := hashFuncHash(fn) + if err != nil { + return nil, err + } + md := hashToMD(h) + if md == nil { + return nil, errors.New("unsupported hash function") + } + return md, nil +} + // cryptoHashToMD converts a crypto.Hash to a GO_EVP_MD_PTR. func cryptoHashToMD(ch crypto.Hash) (md C.GO_EVP_MD_PTR) { if v, ok := cacheMD.Load(ch); ok { diff --git a/hkdf.go b/hkdf.go index 51a05a8..0665970 100644 --- a/hkdf.go +++ b/hkdf.go @@ -9,6 +9,7 @@ import ( "hash" "io" "runtime" + "sync" "unsafe" ) @@ -18,88 +19,72 @@ func SupportsHKDF() bool { case 1: return versionAtOrAbove(1, 1, 1) case 3: - // Some OpenSSL 3 providers don't support HKDF or don't support it via - // the EVP_PKEY API, which is the one we use. - // See https://github.com/golang-fips/openssl/issues/189. - ctx := C.go_openssl_EVP_PKEY_CTX_new_id(C.GO_EVP_PKEY_HKDF, nil) - if ctx == nil { - return false - } - C.go_openssl_EVP_PKEY_CTX_free(ctx) - return true + _, err := fetchHKDF3() + return err == nil default: panic(errUnsupportedVersion()) } } -func newHKDF(fh func() hash.Hash, mode C.int) (*hkdf, error) { - if !SupportsHKDF() { - return nil, errUnsupportedVersion() - } +func newHKDF1(md C.GO_EVP_MD_PTR, mode C.int, secret, salt, pseudorandomKey, info []byte) (ctx C.GO_EVP_PKEY_CTX_PTR, err error) { + checkMajorVersion(1) - h, err := hashFuncHash(fh) - if err != nil { - return nil, err - } - md := hashToMD(h) - if md == nil { - return nil, errors.New("unsupported hash function") - } - - ctx := C.go_openssl_EVP_PKEY_CTX_new_id(C.GO_EVP_PKEY_HKDF, nil) + ctx = C.go_openssl_EVP_PKEY_CTX_new_id(C.GO_EVP_PKEY_HKDF, nil) if ctx == nil { return nil, newOpenSSLError("EVP_PKEY_CTX_new_id") } defer func() { - C.go_openssl_EVP_PKEY_CTX_free(ctx) + if err != nil { + C.go_openssl_EVP_PKEY_CTX_free(ctx) + } }() if C.go_openssl_EVP_PKEY_derive_init(ctx) != 1 { - return nil, newOpenSSLError("EVP_PKEY_derive_init") + return ctx, newOpenSSLError("EVP_PKEY_derive_init") } - switch vMajor { - case 3: - if C.go_openssl_EVP_PKEY_CTX_set_hkdf_mode(ctx, mode) != 1 { - return nil, newOpenSSLError("EVP_PKEY_CTX_set_hkdf_mode") - } - if C.go_openssl_EVP_PKEY_CTX_set_hkdf_md(ctx, md) != 1 { - return nil, newOpenSSLError("EVP_PKEY_CTX_set_hkdf_md") - } - case 1: - if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE, - C.GO_EVP_PKEY_CTRL_HKDF_MODE, - C.int(mode), nil) != 1 { - return nil, newOpenSSLError("EVP_PKEY_CTX_set_hkdf_mode") - } - if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE, - C.GO_EVP_PKEY_CTRL_HKDF_MD, - 0, unsafe.Pointer(md)) != 1 { - return nil, newOpenSSLError("EVP_PKEY_CTX_set_hkdf_md") + + ctrlSlice := func(ctrl int, data []byte) C.int { + if len(data) == 0 { + return 1 // No data to set. } + return C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE, C.int(ctrl), C.int(len(salt)), unsafe.Pointer(base(data))) } - c := &hkdf{ctx: ctx, hashLen: h.Size()} - ctx = nil - - runtime.SetFinalizer(c, (*hkdf).finalize) - - return c, nil + if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE, C.GO_EVP_PKEY_CTRL_HKDF_MODE, mode, nil) != 1 { + return ctx, newOpenSSLError("EVP_PKEY_CTX_set_hkdf_mode") + } + if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE, C.GO_EVP_PKEY_CTRL_HKDF_MD, 0, unsafe.Pointer(md)) != 1 { + return ctx, newOpenSSLError("EVP_PKEY_CTX_set_hkdf_md") + } + if ctrlSlice(C.GO_EVP_PKEY_CTRL_HKDF_KEY, secret) != 1 { + return ctx, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_key") + } + if ctrlSlice(C.GO_EVP_PKEY_CTRL_HKDF_SALT, salt) != 1 { + return ctx, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_salt") + } + if ctrlSlice(C.GO_EVP_PKEY_CTRL_HKDF_KEY, pseudorandomKey) != 1 { + return ctx, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_key") + } + if ctrlSlice(C.GO_EVP_PKEY_CTRL_HKDF_INFO, info) != 1 { + return ctx, newOpenSSLError("EVP_PKEY_CTX_add1_hkdf_info") + } + return ctx, nil } -type hkdf struct { +type hkdf1 struct { ctx C.GO_EVP_PKEY_CTX_PTR hashLen int buf []byte } -func (c *hkdf) finalize() { +func (c *hkdf1) finalize() { if c.ctx != nil { C.go_openssl_EVP_PKEY_CTX_free(c.ctx) } } -func (c *hkdf) Read(p []byte) (int, error) { +func (c *hkdf1) Read(p []byte) (int, error) { defer runtime.KeepAlive(c) // EVP_PKEY_derive doesn't support incremental output, each call @@ -125,69 +110,173 @@ func (c *hkdf) Read(p []byte) (int, error) { } func ExtractHKDF(h func() hash.Hash, secret, salt []byte) ([]byte, error) { - c, err := newHKDF(h, C.GO_EVP_KDF_HKDF_MODE_EXTRACT_ONLY) + if !SupportsHKDF() { + return nil, errUnsupportedVersion() + } + + md, err := hashFuncToMD(h) if err != nil { return nil, err } + switch vMajor { - case 3: - if C.go_openssl_EVP_PKEY_CTX_set1_hkdf_key(c.ctx, - base(secret), C.int(len(secret))) != 1 { - return nil, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_key") + case 1: + ctx, err := newHKDF1(md, C.GO_EVP_KDF_HKDF_MODE_EXTRACT_ONLY, secret, salt, nil, nil) + if err != nil { + return nil, err } - if C.go_openssl_EVP_PKEY_CTX_set1_hkdf_salt(c.ctx, - base(salt), C.int(len(salt))) != 1 { - return nil, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_salt") + defer C.go_openssl_EVP_PKEY_CTX_free(ctx) + r := C.go_openssl_EVP_PKEY_derive_wrapper(ctx, nil, 0) + if r.result != 1 { + return nil, newOpenSSLError("EVP_PKEY_derive_init") } - case 1: - if C.go_openssl_EVP_PKEY_CTX_ctrl(c.ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE, - C.GO_EVP_PKEY_CTRL_HKDF_KEY, - C.int(len(secret)), unsafe.Pointer(base(secret))) != 1 { - return nil, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_key") + out := make([]byte, r.keylen) + if C.go_openssl_EVP_PKEY_derive_wrapper(ctx, base(out), r.keylen).result != 1 { + return nil, newOpenSSLError("EVP_PKEY_derive") } - if C.go_openssl_EVP_PKEY_CTX_ctrl(c.ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE, - C.GO_EVP_PKEY_CTRL_HKDF_SALT, - C.int(len(salt)), unsafe.Pointer(base(salt))) != 1 { - return nil, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_salt") + return out[:r.keylen], nil + case 3: + ctx, err := newHKDF3(md, C.GO_EVP_KDF_HKDF_MODE_EXTRACT_ONLY, secret, salt, nil, nil) + if err != nil { + return nil, err } + defer C.go_openssl_EVP_KDF_CTX_free(ctx) + out := make([]byte, C.go_openssl_EVP_KDF_CTX_get_kdf_size(ctx)) + if C.go_openssl_EVP_KDF_derive(ctx, base(out), C.size_t(len(out)), nil) != 1 { + return nil, newOpenSSLError("EVP_KDF_derive") + } + return out, nil + default: + panic(errUnsupportedVersion()) } - r := C.go_openssl_EVP_PKEY_derive_wrapper(c.ctx, nil, 0) - if r.result != 1 { - return nil, newOpenSSLError("EVP_PKEY_derive_init") - } - out := make([]byte, r.keylen) - if C.go_openssl_EVP_PKEY_derive_wrapper(c.ctx, base(out), r.keylen).result != 1 { - return nil, newOpenSSLError("EVP_PKEY_derive") - } - return out[:r.keylen], nil } func ExpandHKDF(h func() hash.Hash, pseudorandomKey, info []byte) (io.Reader, error) { - c, err := newHKDF(h, C.GO_EVP_KDF_HKDF_MODE_EXPAND_ONLY) + if !SupportsHKDF() { + return nil, errUnsupportedVersion() + } + + md, err := hashFuncToMD(h) if err != nil { return nil, err } + switch vMajor { - case 3: - if C.go_openssl_EVP_PKEY_CTX_set1_hkdf_key(c.ctx, - base(pseudorandomKey), C.int(len(pseudorandomKey))) != 1 { - return nil, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_key") - } - if C.go_openssl_EVP_PKEY_CTX_add1_hkdf_info(c.ctx, - base(info), C.int(len(info))) != 1 { - return nil, newOpenSSLError("EVP_PKEY_CTX_add1_hkdf_info") - } case 1: - if C.go_openssl_EVP_PKEY_CTX_ctrl(c.ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE, - C.GO_EVP_PKEY_CTRL_HKDF_KEY, - C.int(len(pseudorandomKey)), unsafe.Pointer(base(pseudorandomKey))) != 1 { - return nil, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_key") + ctx, err := newHKDF1(md, C.GO_EVP_KDF_HKDF_MODE_EXPAND_ONLY, nil, nil, pseudorandomKey, info) + if err != nil { + return nil, err } - if C.go_openssl_EVP_PKEY_CTX_ctrl(c.ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE, - C.GO_EVP_PKEY_CTRL_HKDF_INFO, - C.int(len(info)), unsafe.Pointer(base(info))) != 1 { - return nil, newOpenSSLError("EVP_PKEY_CTX_add1_hkdf_info") + c := &hkdf1{ctx: ctx, hashLen: int(C.go_openssl_EVP_MD_get_size(md))} + runtime.SetFinalizer(c, (*hkdf1).finalize) + return c, nil + case 3: + ctx, err := newHKDF3(md, C.GO_EVP_KDF_HKDF_MODE_EXPAND_ONLY, nil, nil, pseudorandomKey, info) + if err != nil { + return nil, err } + c := &hkdf3{ctx: ctx, hashLen: int(C.go_openssl_EVP_MD_get_size(md))} + runtime.SetFinalizer(c, (*hkdf3).finalize) + return c, nil + default: + panic(errUnsupportedVersion()) + } +} + +type hkdf3 struct { + ctx C.GO_EVP_KDF_CTX_PTR + + hashLen int + buf []byte +} + +func (c *hkdf3) finalize() { + if c.ctx != nil { + C.go_openssl_EVP_KDF_CTX_free(c.ctx) + } +} + +// fetchHKDF3 fetches the HKDF algorithm. +// It is safe to call this function concurrently. +// The returned EVP_KDF_PTR shouldn't be freed. +var fetchHKDF3 = sync.OnceValues(func() (C.GO_EVP_KDF_PTR, error) { + checkMajorVersion(3) + + name := C.CString("HKDF") + kdf := C.go_openssl_EVP_KDF_fetch(nil, name, nil) + C.free(unsafe.Pointer(name)) + if kdf == nil { + return nil, newOpenSSLError("EVP_KDF_fetch") + } + return kdf, nil +}) + +// newHKDF3 implements HKDF for OpenSSL 3 using the EVP_KDF API. +func newHKDF3(md C.GO_EVP_MD_PTR, mode C.int, secret, salt, pseudorandomKey, info []byte) (C.GO_EVP_KDF_CTX_PTR, error) { + checkMajorVersion(3) + + kdf, err := fetchHKDF3() + if err != nil { + return nil, err + } + ctx := C.go_openssl_EVP_KDF_CTX_new(kdf) + if ctx == nil { + return nil, newOpenSSLError("EVP_KDF_CTX_new") + } + + bld, err := newParamBuilder() + if err != nil { + return nil, err + } + bld.addUTF8String(_OSSL_KDF_PARAM_DIGEST, C.go_openssl_EVP_MD_get0_name(md), 0) + bld.addInt32(_OSSL_KDF_PARAM_MODE, int32(mode)) + if len(secret) > 0 { + bld.addOctetString(_OSSL_KDF_PARAM_KEY, secret) + } + if len(salt) > 0 { + bld.addOctetString(_OSSL_KDF_PARAM_SALT, salt) + } + if len(pseudorandomKey) > 0 { + bld.addOctetString(_OSSL_KDF_PARAM_KEY, pseudorandomKey) + } + if len(info) > 0 { + bld.addOctetString(_OSSL_KDF_PARAM_INFO, info) } - return c, nil + params, err := bld.build() + if err != nil { + C.go_openssl_EVP_KDF_CTX_free(ctx) + return nil, err + } + defer C.go_openssl_OSSL_PARAM_free(params) + + if C.go_openssl_EVP_KDF_CTX_set_params(ctx, params) != 1 { + C.go_openssl_EVP_KDF_CTX_free(ctx) + return nil, newOpenSSLError("EVP_KDF_CTX_set_params") + } + return ctx, nil +} + +func (c *hkdf3) Read(p []byte) (int, error) { + defer runtime.KeepAlive(c) + + // EVP_KDF_derive doesn't support incremental output, each call + // derives the key from scratch and returns the requested bytes. + // To implement io.Reader, we need to ask for len(c.buf) + len(p) + // bytes and copy the last derived len(p) bytes to p. + // We use c.buf to know how many bytes we've already derived and + // to avoid allocating the whole output buffer on each call. + prevLen := len(c.buf) + needLen := len(p) + remains := 255*c.hashLen - prevLen + // Check whether enough data can be generated. + if remains < needLen { + return 0, errors.New("hkdf: entropy limit reached") + } + c.buf = append(c.buf, make([]byte, needLen)...) + outLen := C.size_t(prevLen + needLen) + if C.go_openssl_EVP_KDF_derive(c.ctx, base(c.buf), outLen, nil) != 1 { + return 0, newOpenSSLError("EVP_KDF_derive") + } + n := copy(p, c.buf[prevLen:outLen]) + return n, nil } diff --git a/params.go b/params.go index 867f81b..2da09ac 100644 --- a/params.go +++ b/params.go @@ -14,6 +14,10 @@ var ( _OSSL_KDF_PARAM_DIGEST = C.CString("digest") _OSSL_KDF_PARAM_SECRET = C.CString("secret") _OSSL_KDF_PARAM_SEED = C.CString("seed") + _OSSL_KDF_PARAM_KEY = C.CString("key") + _OSSL_KDF_PARAM_INFO = C.CString("info") + _OSSL_KDF_PARAM_SALT = C.CString("salt") + _OSSL_KDF_PARAM_MODE = C.CString("mode") ) // paramBuilder is a helper for building OSSL_PARAMs. @@ -103,3 +107,13 @@ func (b *paramBuilder) addOctetString(name *C.char, value []byte) { b.err = newOpenSSLError("OSSL_PARAM_BLD_push_octet_string(" + C.GoString(name) + ")") } } + +// addIn32 adds a int32 to the builder. +func (b *paramBuilder) addInt32(name *C.char, value int32) { + if !b.check() { + return + } + if C.go_openssl_OSSL_PARAM_BLD_push_int32(b.bld, name, C.int32_t(value)) != 1 { + b.err = newOpenSSLError("OSSL_PARAM_BLD_push_int32(" + C.GoString(name) + ")") + } +} diff --git a/shims.h b/shims.h index a28ddde..93739a5 100644 --- a/shims.h +++ b/shims.h @@ -202,6 +202,7 @@ DEFINEFUNC_3_0(GO_EVP_MD_PTR, EVP_MD_fetch, (GO_OSSL_LIB_CTX_PTR ctx, const char DEFINEFUNC_3_0(void, EVP_MD_free, (GO_EVP_MD_PTR md), (md)) \ DEFINEFUNC_3_0(const char *, EVP_MD_get0_name, (const GO_EVP_MD_PTR md), (md)) \ DEFINEFUNC_3_0(const GO_OSSL_PROVIDER_PTR, EVP_MD_get0_provider, (const GO_EVP_MD_PTR md), (md)) \ +DEFINEFUNC_RENAMED_3_0(int, EVP_MD_get_size, EVP_MD_size, (const GO_EVP_MD_PTR md), (md)) \ DEFINEFUNC_RENAMED_3_0(int, EVP_MD_get_block_size, EVP_MD_block_size, (const GO_EVP_MD_PTR md), (md)) \ DEFINEFUNC(int, RAND_bytes, (unsigned char *arg0, int arg1), (arg0, arg1)) \ DEFINEFUNC_RENAMED_1_1(GO_EVP_MD_CTX_PTR, EVP_MD_CTX_new, EVP_MD_CTX_create, (void), ()) \ @@ -368,6 +369,7 @@ DEFINEFUNC_3_0(GO_OSSL_PARAM_PTR, OSSL_PARAM_BLD_to_param, (GO_OSSL_PARAM_BLD_PT DEFINEFUNC_3_0(int, OSSL_PARAM_BLD_push_utf8_string, (GO_OSSL_PARAM_BLD_PTR bld, const char *key, const char *buf, size_t bsize), (bld, key, buf, bsize)) \ DEFINEFUNC_3_0(int, OSSL_PARAM_BLD_push_octet_string, (GO_OSSL_PARAM_BLD_PTR bld, const char *key, const void *buf, size_t bsize), (bld, key, buf, bsize)) \ DEFINEFUNC_3_0(int, OSSL_PARAM_BLD_push_BN, (GO_OSSL_PARAM_BLD_PTR bld, const char *key, const GO_BIGNUM_PTR bn), (bld, key, bn)) \ +DEFINEFUNC_3_0(int, OSSL_PARAM_BLD_push_int32, (GO_OSSL_PARAM_BLD_PTR bld, const char *key, int32_t num), (bld, key, num)) \ DEFINEFUNC_3_0(int, EVP_PKEY_CTX_set_hkdf_mode, (GO_EVP_PKEY_CTX_PTR arg0, int arg1), (arg0, arg1)) \ DEFINEFUNC_3_0(int, EVP_PKEY_CTX_set_hkdf_md, (GO_EVP_PKEY_CTX_PTR arg0, const GO_EVP_MD_PTR arg1), (arg0, arg1)) \ DEFINEFUNC_3_0(int, EVP_PKEY_CTX_set1_hkdf_salt, (GO_EVP_PKEY_CTX_PTR arg0, const unsigned char *arg1, int arg2), (arg0, arg1, arg2)) \ @@ -391,6 +393,8 @@ DEFINEFUNC_LEGACY_1_1(int, DSA_set0_key, (GO_DSA_PTR d, GO_BIGNUM_PTR pub_key, G DEFINEFUNC_3_0(GO_EVP_KDF_PTR, EVP_KDF_fetch, (GO_OSSL_LIB_CTX_PTR libctx, const char *algorithm, const char *properties), (libctx, algorithm, properties)) \ DEFINEFUNC_3_0(void, EVP_KDF_free, (GO_EVP_KDF_PTR kdf), (kdf)) \ DEFINEFUNC_3_0(GO_EVP_KDF_CTX_PTR, EVP_KDF_CTX_new, (GO_EVP_KDF_PTR kdf), (kdf)) \ +DEFINEFUNC_3_0(int, EVP_KDF_CTX_set_params, (GO_EVP_KDF_CTX_PTR ctx, const GO_OSSL_PARAM_PTR params), (ctx, params)) \ DEFINEFUNC_3_0(void, EVP_KDF_CTX_free, (GO_EVP_KDF_CTX_PTR ctx), (ctx)) \ +DEFINEFUNC_3_0(size_t, EVP_KDF_CTX_get_kdf_size, (GO_EVP_KDF_CTX_PTR ctx), (ctx)) \ DEFINEFUNC_3_0(int, EVP_KDF_derive, (GO_EVP_KDF_CTX_PTR ctx, unsigned char *key, size_t keylen, const GO_OSSL_PARAM_PTR params), (ctx, key, keylen, params)) \