Skip to content

Commit

Permalink
HMS-2693 feat: Implement JWK creation and load
Browse files Browse the repository at this point in the history
Implement functions to create a new JWK with expected properties and
load a serialized JWK.

`GenericePrivateJWK()` creates a private JWK with EC / P256, kid, and
expiration time stamp.

`GetPublicJWK()` is a helper to create a public JWK from a private JWK.

`ParseJWK()` deserializes a JWK string that has been previously
serialized with `json.Marshal`. The function validates various
properties including expiration.

Signed-off-by: Christian Heimes <cheimes@redhat.com>
  • Loading branch information
tiran committed Oct 5, 2023
1 parent 2daaa64 commit 66a8bdd
Show file tree
Hide file tree
Showing 2 changed files with 245 additions and 0 deletions.
170 changes: 170 additions & 0 deletions internal/infrastructure/token/jwk.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
package token

import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"encoding/base64"
"fmt"
"time"

"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
)

type KeyState int

const (
ValidKey KeyState = iota
ExpiredKey
InvalidKey
)

const KeyCurve = jwa.P256

// Generate a private key with additional properties
// alg: based on key type (ES256 for P256)
// exp: expiration time (Unix timestamp)
// kid: base64 SHA-256 thumbprint
// use: "sig"
func GeneratePrivateJWK(expiration time.Time) (jwk.Key, error) {
var crv elliptic.Curve
var alg jwa.SignatureAlgorithm

switch KeyCurve {
case jwa.P256:
crv = elliptic.P256()
alg = jwa.ES256
default:
return nil, fmt.Errorf("Unsupported JWK curve %s", KeyCurve)
}

raw, err := ecdsa.GenerateKey(crv, rand.Reader)
if err != nil {
return nil, err
}

key, err := jwk.FromRaw(raw)
if err != nil {
return nil, err
}

// set kid to truncated SHA-256 thumbprint (RFC 7638)
tp, err := key.Thumbprint(crypto.SHA256)
if err != nil {
return nil, err
}
if err = key.Set(jwk.KeyIDKey, base64.RawURLEncoding.EncodeToString(tp)[:8]); err != nil {
return nil, err
}

// P-256 key is used for signing with ES256
if err = key.Set(jwk.KeyUsageKey, jwk.ForSignature); err != nil {
return nil, err
}
if err = key.Set(jwk.AlgorithmKey, alg); err != nil {
return nil, err
}

// non-standard but common expiration for key
if err = key.Set("exp", expiration.Unix()); err != nil {
return nil, err
}

return key, nil
}

// Get public key of a JWK
func GetPublicJWK(key jwk.Key) (jwk.Key, error) {
return key.PublicKey()
}

// Parse and validate a single JWK
func ParseJWK(src []byte) (key jwk.Key, state KeyState, err error) {
key, err = jwk.ParseKey(src)
if err != nil {
return nil, InvalidKey, err
}
state, err = checkJWK(key)
if state == ValidKey {
return key, state, nil
} else {
return nil, state, err
}
}

// Parse and validate a JWKSet
func ParseJWKSet(src []byte) (rs jwk.Set, err error) {
s, err := jwk.Parse(src)
if err != nil {
return nil, err
}
rs = jwk.NewSet()
for i := 0; i < s.Len(); i++ {
key, _ := s.Key(i)
state, err := checkJWK((key))
switch state {
case ValidKey:
if err = rs.AddKey(key); err != nil {
return nil, err
}
case ExpiredKey:
// skip expired key
continue
case InvalidKey:
return nil, fmt.Errorf("Invalid key %d: %v", i, err)
}
}
return rs, nil
}

// Verify a JWK and check that it matches our requirements
func checkJWK(key jwk.Key) (KeyState, error) {
if key.KeyType() != jwa.EC {
return InvalidKey, fmt.Errorf("Invalid key type %s", key.KeyType())
}

// Crv
switch raw := key.(type) {
case jwk.ECDSAPrivateKey:
if raw.Crv() != jwa.P256 {
return InvalidKey, fmt.Errorf("Invalid curve %s", raw.Crv().String())
}
case jwk.ECDSAPublicKey:
if raw.Crv() != jwa.P256 {
return InvalidKey, fmt.Errorf("Invalid curve %s", raw.Crv().String())
}
default:
return InvalidKey, fmt.Errorf("Invalid key")
}

if key.KeyID() == "" {
return InvalidKey, fmt.Errorf("KeyID is empty")
}
if key.KeyUsage() != jwk.ForSignature.String() {
return InvalidKey, fmt.Errorf("Invalid key usage %s", key.KeyUsage())
}
if key.Algorithm() != jwa.ES256 {
return InvalidKey, fmt.Errorf("Invalid key alg %s", key.Algorithm().String())
}

expif, ok := key.Get("exp")
if !ok {
return InvalidKey, fmt.Errorf("Missing or invalid 'exp'")
}
exp, ok := expif.(int64)
if !ok {
return InvalidKey, fmt.Errorf("Missing or invalid 'exp'")
}
if exp <= time.Now().Unix() {
return ExpiredKey, fmt.Errorf("Key has expired")
}

return ValidKey, nil
}

func init() {
var exp int64
jwk.RegisterCustomField("exp", exp)
}
75 changes: 75 additions & 0 deletions internal/infrastructure/token/jwk_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package token

import (
"encoding/json"
"testing"
"time"

"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/stretchr/testify/assert"
)

func TestGeneratePrivateJWK(t *testing.T) {
expiration := time.Now().Add(time.Hour)
key, err := GeneratePrivateJWK(expiration)
assert.NoError(t, err)

raw, ok := key.(jwk.ECDSAPrivateKey)
assert.True(t, ok)
assert.Equal(t, raw.Crv(), jwa.P256)

assert.Equal(t, key.KeyType(), jwa.EC)
assert.Equal(t, len(key.KeyID()), 8)

// exp is an int64 almost equal to current Unix time
expif, ok := key.Get("exp")
assert.True(t, ok)
exp, ok := expif.(int64)
assert.True(t, ok)
assert.Equal(t, exp, expiration.Unix())

state, err := checkJWK(key)
assert.Equal(t, state, ValidKey)
assert.NoError(t, err)
}

func TestGetPublicK(t *testing.T) {
expiration := time.Now().Add(time.Hour)
key, err := GeneratePrivateJWK(expiration)
assert.NoError(t, err)

pub, err := GetPublicJWK(key)
_, ok := pub.(jwk.ECDSAPublicKey)
assert.True(t, ok)
_, ok = pub.(jwk.ECDSAPrivateKey)
assert.False(t, ok)

pub2, err := GetPublicJWK(pub)
assert.NoError(t, err)
assert.Equal(t, pub, pub2)
}

func TestParseJWK(t *testing.T) {
expiration := time.Now().Add(time.Hour)
key, err := GeneratePrivateJWK(expiration)
assert.NoError(t, err)

s, err := json.Marshal(key)
assert.NoError(t, err)
parsed, state, err := ParseJWK(s)
assert.NoError(t, err)
assert.Equal(t, state, ValidKey)
assert.Equal(t, parsed, key)

pub, err := GetPublicJWK(key)
assert.NoError(t, err)

s, err = json.Marshal(pub)
assert.NoError(t, err)
parsed, state, err = ParseJWK(s)
assert.Equal(t, state, ValidKey)
assert.Equal(t, parsed, pub)

// TODO add tests for invalid and expired keys
}

0 comments on commit 66a8bdd

Please sign in to comment.