-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
HMS-2693 feat: Implement JWK creation and load
Implement functions to create a new JWK with expected properties and load a serialized JWK. `GenericePrivateJWK()` creates a private JWK with EC / P256, kid, and expiration time stamp. `GetPublicJWK()` is a helper to create a public JWK from a private JWK. `ParseJWK()` deserializes a JWK string that has been previously serialized with `json.Marshal`. The function validates various properties including expiration. Signed-off-by: Christian Heimes <cheimes@redhat.com>
- Loading branch information
Showing
2 changed files
with
245 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
package token | ||
|
||
import ( | ||
"crypto" | ||
"crypto/ecdsa" | ||
"crypto/elliptic" | ||
"crypto/rand" | ||
"encoding/base64" | ||
"fmt" | ||
"time" | ||
|
||
"github.com/lestrrat-go/jwx/v2/jwa" | ||
"github.com/lestrrat-go/jwx/v2/jwk" | ||
) | ||
|
||
type KeyState int | ||
|
||
const ( | ||
ValidKey KeyState = iota | ||
ExpiredKey | ||
InvalidKey | ||
) | ||
|
||
const KeyCurve = jwa.P256 | ||
|
||
// Generate a private key with additional properties | ||
// alg: based on key type (ES256 for P256) | ||
// exp: expiration time (Unix timestamp) | ||
// kid: base64 SHA-256 thumbprint | ||
// use: "sig" | ||
func GeneratePrivateJWK(expiration time.Time) (jwk.Key, error) { | ||
var crv elliptic.Curve | ||
var alg jwa.SignatureAlgorithm | ||
|
||
switch KeyCurve { | ||
case jwa.P256: | ||
crv = elliptic.P256() | ||
alg = jwa.ES256 | ||
default: | ||
return nil, fmt.Errorf("Unsupported JWK curve %s", KeyCurve) | ||
} | ||
|
||
raw, err := ecdsa.GenerateKey(crv, rand.Reader) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
key, err := jwk.FromRaw(raw) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
// set kid to truncated SHA-256 thumbprint (RFC 7638) | ||
tp, err := key.Thumbprint(crypto.SHA256) | ||
if err != nil { | ||
return nil, err | ||
} | ||
if err = key.Set(jwk.KeyIDKey, base64.RawURLEncoding.EncodeToString(tp)[:8]); err != nil { | ||
return nil, err | ||
} | ||
|
||
// P-256 key is used for signing with ES256 | ||
if err = key.Set(jwk.KeyUsageKey, jwk.ForSignature); err != nil { | ||
return nil, err | ||
} | ||
if err = key.Set(jwk.AlgorithmKey, alg); err != nil { | ||
return nil, err | ||
} | ||
|
||
// non-standard but common expiration for key | ||
if err = key.Set("exp", expiration.Unix()); err != nil { | ||
return nil, err | ||
} | ||
|
||
return key, nil | ||
} | ||
|
||
// Get public key of a JWK | ||
func GetPublicJWK(key jwk.Key) (jwk.Key, error) { | ||
return key.PublicKey() | ||
} | ||
|
||
// Parse and validate a single JWK | ||
func ParseJWK(src []byte) (key jwk.Key, state KeyState, err error) { | ||
key, err = jwk.ParseKey(src) | ||
if err != nil { | ||
return nil, InvalidKey, err | ||
} | ||
state, err = checkJWK(key) | ||
if state == ValidKey { | ||
return key, state, nil | ||
} else { | ||
return nil, state, err | ||
} | ||
} | ||
|
||
// Parse and validate a JWKSet | ||
func ParseJWKSet(src []byte) (rs jwk.Set, err error) { | ||
s, err := jwk.Parse(src) | ||
if err != nil { | ||
return nil, err | ||
} | ||
rs = jwk.NewSet() | ||
for i := 0; i < s.Len(); i++ { | ||
key, _ := s.Key(i) | ||
state, err := checkJWK((key)) | ||
switch state { | ||
case ValidKey: | ||
if err = rs.AddKey(key); err != nil { | ||
return nil, err | ||
} | ||
case ExpiredKey: | ||
// skip expired key | ||
continue | ||
case InvalidKey: | ||
return nil, fmt.Errorf("Invalid key %d: %v", i, err) | ||
} | ||
} | ||
return rs, nil | ||
} | ||
|
||
// Verify a JWK and check that it matches our requirements | ||
func checkJWK(key jwk.Key) (KeyState, error) { | ||
if key.KeyType() != jwa.EC { | ||
return InvalidKey, fmt.Errorf("Invalid key type %s", key.KeyType()) | ||
} | ||
|
||
// Crv | ||
switch raw := key.(type) { | ||
case jwk.ECDSAPrivateKey: | ||
if raw.Crv() != jwa.P256 { | ||
return InvalidKey, fmt.Errorf("Invalid curve %s", raw.Crv().String()) | ||
} | ||
case jwk.ECDSAPublicKey: | ||
if raw.Crv() != jwa.P256 { | ||
return InvalidKey, fmt.Errorf("Invalid curve %s", raw.Crv().String()) | ||
} | ||
default: | ||
return InvalidKey, fmt.Errorf("Invalid key") | ||
} | ||
|
||
if key.KeyID() == "" { | ||
return InvalidKey, fmt.Errorf("KeyID is empty") | ||
} | ||
if key.KeyUsage() != jwk.ForSignature.String() { | ||
return InvalidKey, fmt.Errorf("Invalid key usage %s", key.KeyUsage()) | ||
} | ||
if key.Algorithm() != jwa.ES256 { | ||
return InvalidKey, fmt.Errorf("Invalid key alg %s", key.Algorithm().String()) | ||
} | ||
|
||
expif, ok := key.Get("exp") | ||
if !ok { | ||
return InvalidKey, fmt.Errorf("Missing or invalid 'exp'") | ||
} | ||
exp, ok := expif.(int64) | ||
if !ok { | ||
return InvalidKey, fmt.Errorf("Missing or invalid 'exp'") | ||
} | ||
if exp <= time.Now().Unix() { | ||
return ExpiredKey, fmt.Errorf("Key has expired") | ||
} | ||
|
||
return ValidKey, nil | ||
} | ||
|
||
func init() { | ||
var exp int64 | ||
jwk.RegisterCustomField("exp", exp) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
package token | ||
|
||
import ( | ||
"encoding/json" | ||
"testing" | ||
"time" | ||
|
||
"github.com/lestrrat-go/jwx/v2/jwa" | ||
"github.com/lestrrat-go/jwx/v2/jwk" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestGeneratePrivateJWK(t *testing.T) { | ||
expiration := time.Now().Add(time.Hour) | ||
key, err := GeneratePrivateJWK(expiration) | ||
assert.NoError(t, err) | ||
|
||
raw, ok := key.(jwk.ECDSAPrivateKey) | ||
assert.True(t, ok) | ||
assert.Equal(t, raw.Crv(), jwa.P256) | ||
|
||
assert.Equal(t, key.KeyType(), jwa.EC) | ||
assert.Equal(t, len(key.KeyID()), 8) | ||
|
||
// exp is an int64 almost equal to current Unix time | ||
expif, ok := key.Get("exp") | ||
assert.True(t, ok) | ||
exp, ok := expif.(int64) | ||
assert.True(t, ok) | ||
assert.Equal(t, exp, expiration.Unix()) | ||
|
||
state, err := checkJWK(key) | ||
assert.Equal(t, state, ValidKey) | ||
assert.NoError(t, err) | ||
} | ||
|
||
func TestGetPublicK(t *testing.T) { | ||
expiration := time.Now().Add(time.Hour) | ||
key, err := GeneratePrivateJWK(expiration) | ||
assert.NoError(t, err) | ||
|
||
pub, err := GetPublicJWK(key) | ||
_, ok := pub.(jwk.ECDSAPublicKey) | ||
assert.True(t, ok) | ||
_, ok = pub.(jwk.ECDSAPrivateKey) | ||
assert.False(t, ok) | ||
|
||
pub2, err := GetPublicJWK(pub) | ||
assert.NoError(t, err) | ||
assert.Equal(t, pub, pub2) | ||
} | ||
|
||
func TestParseJWK(t *testing.T) { | ||
expiration := time.Now().Add(time.Hour) | ||
key, err := GeneratePrivateJWK(expiration) | ||
assert.NoError(t, err) | ||
|
||
s, err := json.Marshal(key) | ||
assert.NoError(t, err) | ||
parsed, state, err := ParseJWK(s) | ||
assert.NoError(t, err) | ||
assert.Equal(t, state, ValidKey) | ||
assert.Equal(t, parsed, key) | ||
|
||
pub, err := GetPublicJWK(key) | ||
assert.NoError(t, err) | ||
|
||
s, err = json.Marshal(pub) | ||
assert.NoError(t, err) | ||
parsed, state, err = ParseJWK(s) | ||
assert.Equal(t, state, ValidKey) | ||
assert.Equal(t, parsed, pub) | ||
|
||
// TODO add tests for invalid and expired keys | ||
} |