Skip to content

Commit

Permalink
implement HKDF using the EVP_KDF API in OpenSSL 3
Browse files Browse the repository at this point in the history
  • Loading branch information
qmuntal committed Sep 25, 2024
1 parent 4fb8ffc commit 3180102
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 97 deletions.
14 changes: 14 additions & 0 deletions evp.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,20 @@ func hashToMD(h hash.Hash) C.GO_EVP_MD_PTR {
return nil
}

// hashFuncToMD converts a hash.Hash function to a GO_EVP_MD_PTR.
// See [hashFuncHash] for details on error handling.
func hashFuncToMD(fn func() hash.Hash) (C.GO_EVP_MD_PTR, error) {
h, err := hashFuncHash(fn)
if err != nil {
return nil, err
}
md := hashToMD(h)
if md == nil {
return nil, errors.New("unsupported hash function")
}
return md, nil
}

// cryptoHashToMD converts a crypto.Hash to a GO_EVP_MD_PTR.
func cryptoHashToMD(ch crypto.Hash) (md C.GO_EVP_MD_PTR) {
if v, ok := cacheMD.Load(ch); ok {
Expand Down
283 changes: 186 additions & 97 deletions hkdf.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"hash"
"io"
"runtime"
"sync"
"unsafe"
)

Expand All @@ -18,88 +19,72 @@ func SupportsHKDF() bool {
case 1:
return versionAtOrAbove(1, 1, 1)
case 3:
// Some OpenSSL 3 providers don't support HKDF or don't support it via
// the EVP_PKEY API, which is the one we use.
// See https://github.com/golang-fips/openssl/issues/189.
ctx := C.go_openssl_EVP_PKEY_CTX_new_id(C.GO_EVP_PKEY_HKDF, nil)
if ctx == nil {
return false
}
C.go_openssl_EVP_PKEY_CTX_free(ctx)
return true
_, err := fetchHKDF3()
return err == nil
default:
panic(errUnsupportedVersion())
}
}

func newHKDF(fh func() hash.Hash, mode C.int) (*hkdf, error) {
if !SupportsHKDF() {
return nil, errUnsupportedVersion()
}
func newHKDF1(md C.GO_EVP_MD_PTR, mode C.int, secret, salt, pseudorandomKey, info []byte) (ctx C.GO_EVP_PKEY_CTX_PTR, err error) {
checkMajorVersion(1)

h, err := hashFuncHash(fh)
if err != nil {
return nil, err
}
md := hashToMD(h)
if md == nil {
return nil, errors.New("unsupported hash function")
}

ctx := C.go_openssl_EVP_PKEY_CTX_new_id(C.GO_EVP_PKEY_HKDF, nil)
ctx = C.go_openssl_EVP_PKEY_CTX_new_id(C.GO_EVP_PKEY_HKDF, nil)
if ctx == nil {
return nil, newOpenSSLError("EVP_PKEY_CTX_new_id")
}
defer func() {
C.go_openssl_EVP_PKEY_CTX_free(ctx)
if err != nil {
C.go_openssl_EVP_PKEY_CTX_free(ctx)
}
}()

if C.go_openssl_EVP_PKEY_derive_init(ctx) != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive_init")
return ctx, newOpenSSLError("EVP_PKEY_derive_init")
}
switch vMajor {
case 3:
if C.go_openssl_EVP_PKEY_CTX_set_hkdf_mode(ctx, mode) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_set_hkdf_mode")
}
if C.go_openssl_EVP_PKEY_CTX_set_hkdf_md(ctx, md) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_set_hkdf_md")
}
case 1:
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE,
C.GO_EVP_PKEY_CTRL_HKDF_MODE,
C.int(mode), nil) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_set_hkdf_mode")
}
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE,
C.GO_EVP_PKEY_CTRL_HKDF_MD,
0, unsafe.Pointer(md)) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_set_hkdf_md")

ctrlSlice := func(ctrl int, data []byte) C.int {
if len(data) == 0 {
return 1 // No data to set.
}
return C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE, C.int(ctrl), C.int(len(salt)), unsafe.Pointer(base(data)))
}

c := &hkdf{ctx: ctx, hashLen: h.Size()}
ctx = nil

runtime.SetFinalizer(c, (*hkdf).finalize)

return c, nil
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE, C.GO_EVP_PKEY_CTRL_HKDF_MODE, mode, nil) != 1 {
return ctx, newOpenSSLError("EVP_PKEY_CTX_set_hkdf_mode")
}
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE, C.GO_EVP_PKEY_CTRL_HKDF_MD, 0, unsafe.Pointer(md)) != 1 {
return ctx, newOpenSSLError("EVP_PKEY_CTX_set_hkdf_md")
}
if ctrlSlice(C.GO_EVP_PKEY_CTRL_HKDF_KEY, secret) != 1 {
return ctx, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_key")
}
if ctrlSlice(C.GO_EVP_PKEY_CTRL_HKDF_SALT, salt) != 1 {
return ctx, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_salt")
}
if ctrlSlice(C.GO_EVP_PKEY_CTRL_HKDF_KEY, pseudorandomKey) != 1 {
return ctx, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_key")
}
if ctrlSlice(C.GO_EVP_PKEY_CTRL_HKDF_INFO, info) != 1 {
return ctx, newOpenSSLError("EVP_PKEY_CTX_add1_hkdf_info")
}
return ctx, nil
}

type hkdf struct {
type hkdf1 struct {
ctx C.GO_EVP_PKEY_CTX_PTR

hashLen int
buf []byte
}

func (c *hkdf) finalize() {
func (c *hkdf1) finalize() {
if c.ctx != nil {
C.go_openssl_EVP_PKEY_CTX_free(c.ctx)
}
}

func (c *hkdf) Read(p []byte) (int, error) {
func (c *hkdf1) Read(p []byte) (int, error) {
defer runtime.KeepAlive(c)

// EVP_PKEY_derive doesn't support incremental output, each call
Expand All @@ -125,69 +110,173 @@ func (c *hkdf) Read(p []byte) (int, error) {
}

func ExtractHKDF(h func() hash.Hash, secret, salt []byte) ([]byte, error) {
c, err := newHKDF(h, C.GO_EVP_KDF_HKDF_MODE_EXTRACT_ONLY)
if !SupportsHKDF() {
return nil, errUnsupportedVersion()
}

md, err := hashFuncToMD(h)
if err != nil {
return nil, err
}

switch vMajor {
case 3:
if C.go_openssl_EVP_PKEY_CTX_set1_hkdf_key(c.ctx,
base(secret), C.int(len(secret))) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_key")
case 1:
ctx, err := newHKDF1(md, C.GO_EVP_KDF_HKDF_MODE_EXTRACT_ONLY, secret, salt, nil, nil)
if err != nil {
return nil, err
}
if C.go_openssl_EVP_PKEY_CTX_set1_hkdf_salt(c.ctx,
base(salt), C.int(len(salt))) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_salt")
defer C.go_openssl_EVP_PKEY_CTX_free(ctx)
r := C.go_openssl_EVP_PKEY_derive_wrapper(ctx, nil, 0)
if r.result != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive_init")
}
case 1:
if C.go_openssl_EVP_PKEY_CTX_ctrl(c.ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE,
C.GO_EVP_PKEY_CTRL_HKDF_KEY,
C.int(len(secret)), unsafe.Pointer(base(secret))) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_key")
out := make([]byte, r.keylen)
if C.go_openssl_EVP_PKEY_derive_wrapper(ctx, base(out), r.keylen).result != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive")
}
if C.go_openssl_EVP_PKEY_CTX_ctrl(c.ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE,
C.GO_EVP_PKEY_CTRL_HKDF_SALT,
C.int(len(salt)), unsafe.Pointer(base(salt))) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_salt")
return out[:r.keylen], nil
case 3:
ctx, err := newHKDF3(md, C.GO_EVP_KDF_HKDF_MODE_EXTRACT_ONLY, secret, salt, nil, nil)
if err != nil {
return nil, err
}
defer C.go_openssl_EVP_KDF_CTX_free(ctx)
out := make([]byte, C.go_openssl_EVP_KDF_CTX_get_kdf_size(ctx))
if C.go_openssl_EVP_KDF_derive(ctx, base(out), C.size_t(len(out)), nil) != 1 {
return nil, newOpenSSLError("EVP_KDF_derive")
}
return out, nil
default:
panic(errUnsupportedVersion())
}
r := C.go_openssl_EVP_PKEY_derive_wrapper(c.ctx, nil, 0)
if r.result != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive_init")
}
out := make([]byte, r.keylen)
if C.go_openssl_EVP_PKEY_derive_wrapper(c.ctx, base(out), r.keylen).result != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive")
}
return out[:r.keylen], nil
}

func ExpandHKDF(h func() hash.Hash, pseudorandomKey, info []byte) (io.Reader, error) {
c, err := newHKDF(h, C.GO_EVP_KDF_HKDF_MODE_EXPAND_ONLY)
if !SupportsHKDF() {
return nil, errUnsupportedVersion()
}

md, err := hashFuncToMD(h)
if err != nil {
return nil, err
}

switch vMajor {
case 3:
if C.go_openssl_EVP_PKEY_CTX_set1_hkdf_key(c.ctx,
base(pseudorandomKey), C.int(len(pseudorandomKey))) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_key")
}
if C.go_openssl_EVP_PKEY_CTX_add1_hkdf_info(c.ctx,
base(info), C.int(len(info))) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_add1_hkdf_info")
}
case 1:
if C.go_openssl_EVP_PKEY_CTX_ctrl(c.ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE,
C.GO_EVP_PKEY_CTRL_HKDF_KEY,
C.int(len(pseudorandomKey)), unsafe.Pointer(base(pseudorandomKey))) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_key")
ctx, err := newHKDF1(md, C.GO_EVP_KDF_HKDF_MODE_EXPAND_ONLY, nil, nil, pseudorandomKey, info)
if err != nil {
return nil, err
}
if C.go_openssl_EVP_PKEY_CTX_ctrl(c.ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE,
C.GO_EVP_PKEY_CTRL_HKDF_INFO,
C.int(len(info)), unsafe.Pointer(base(info))) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_add1_hkdf_info")
c := &hkdf1{ctx: ctx, hashLen: int(C.go_openssl_EVP_MD_get_size(md))}
runtime.SetFinalizer(c, (*hkdf1).finalize)
return c, nil
case 3:
ctx, err := newHKDF3(md, C.GO_EVP_KDF_HKDF_MODE_EXPAND_ONLY, nil, nil, pseudorandomKey, info)
if err != nil {
return nil, err
}
c := &hkdf3{ctx: ctx, hashLen: int(C.go_openssl_EVP_MD_get_size(md))}
runtime.SetFinalizer(c, (*hkdf3).finalize)
return c, nil
default:
panic(errUnsupportedVersion())
}
}

type hkdf3 struct {
ctx C.GO_EVP_KDF_CTX_PTR

hashLen int
buf []byte
}

func (c *hkdf3) finalize() {
if c.ctx != nil {
C.go_openssl_EVP_KDF_CTX_free(c.ctx)
}
}

// fetchHKDF3 fetches the HKDF algorithm.
// It is safe to call this function concurrently.
// The returned EVP_KDF_PTR shouldn't be freed.
var fetchHKDF3 = sync.OnceValues(func() (C.GO_EVP_KDF_PTR, error) {
checkMajorVersion(3)

name := C.CString("HKDF")
kdf := C.go_openssl_EVP_KDF_fetch(nil, name, nil)
C.free(unsafe.Pointer(name))
if kdf == nil {
return nil, newOpenSSLError("EVP_KDF_fetch")
}
return kdf, nil
})

// newHKDF3 implements HKDF for OpenSSL 3 using the EVP_KDF API.
func newHKDF3(md C.GO_EVP_MD_PTR, mode C.int, secret, salt, pseudorandomKey, info []byte) (C.GO_EVP_KDF_CTX_PTR, error) {
checkMajorVersion(3)

kdf, err := fetchHKDF3()
if err != nil {
return nil, err
}
ctx := C.go_openssl_EVP_KDF_CTX_new(kdf)
if ctx == nil {
return nil, newOpenSSLError("EVP_KDF_CTX_new")
}

bld, err := newParamBuilder()
if err != nil {
return nil, err
}
bld.addUTF8String(_OSSL_KDF_PARAM_DIGEST, C.go_openssl_EVP_MD_get0_name(md), 0)
bld.addInt32(_OSSL_KDF_PARAM_MODE, int32(mode))
if len(secret) > 0 {
bld.addOctetString(_OSSL_KDF_PARAM_KEY, secret)
}
if len(salt) > 0 {
bld.addOctetString(_OSSL_KDF_PARAM_SALT, salt)
}
if len(pseudorandomKey) > 0 {
bld.addOctetString(_OSSL_KDF_PARAM_KEY, pseudorandomKey)
}
if len(info) > 0 {
bld.addOctetString(_OSSL_KDF_PARAM_INFO, info)
}
return c, nil
params, err := bld.build()
if err != nil {
C.go_openssl_EVP_KDF_CTX_free(ctx)
return nil, err
}
defer C.go_openssl_OSSL_PARAM_free(params)

if C.go_openssl_EVP_KDF_CTX_set_params(ctx, params) != 1 {
C.go_openssl_EVP_KDF_CTX_free(ctx)
return nil, newOpenSSLError("EVP_KDF_CTX_set_params")
}
return ctx, nil
}

func (c *hkdf3) Read(p []byte) (int, error) {
defer runtime.KeepAlive(c)

// EVP_KDF_derive doesn't support incremental output, each call
// derives the key from scratch and returns the requested bytes.
// To implement io.Reader, we need to ask for len(c.buf) + len(p)
// bytes and copy the last derived len(p) bytes to p.
// We use c.buf to know how many bytes we've already derived and
// to avoid allocating the whole output buffer on each call.
prevLen := len(c.buf)
needLen := len(p)
remains := 255*c.hashLen - prevLen
// Check whether enough data can be generated.
if remains < needLen {
return 0, errors.New("hkdf: entropy limit reached")
}
c.buf = append(c.buf, make([]byte, needLen)...)
outLen := C.size_t(prevLen + needLen)
if C.go_openssl_EVP_KDF_derive(c.ctx, base(c.buf), outLen, nil) != 1 {
return 0, newOpenSSLError("EVP_KDF_derive")
}
n := copy(p, c.buf[prevLen:outLen])
return n, nil
}
14 changes: 14 additions & 0 deletions params.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ var (
_OSSL_KDF_PARAM_DIGEST = C.CString("digest")
_OSSL_KDF_PARAM_SECRET = C.CString("secret")
_OSSL_KDF_PARAM_SEED = C.CString("seed")
_OSSL_KDF_PARAM_KEY = C.CString("key")
_OSSL_KDF_PARAM_INFO = C.CString("info")
_OSSL_KDF_PARAM_SALT = C.CString("salt")
_OSSL_KDF_PARAM_MODE = C.CString("mode")
)

// paramBuilder is a helper for building OSSL_PARAMs.
Expand Down Expand Up @@ -103,3 +107,13 @@ func (b *paramBuilder) addOctetString(name *C.char, value []byte) {
b.err = newOpenSSLError("OSSL_PARAM_BLD_push_octet_string(" + C.GoString(name) + ")")
}
}

// addIn32 adds a int32 to the builder.
func (b *paramBuilder) addInt32(name *C.char, value int32) {
if !b.check() {
return
}
if C.go_openssl_OSSL_PARAM_BLD_push_int32(b.bld, name, C.int32_t(value)) != 1 {
b.err = newOpenSSLError("OSSL_PARAM_BLD_push_int32(" + C.GoString(name) + ")")
}
}
Loading

0 comments on commit 3180102

Please sign in to comment.