Skip to content

Commit

Permalink
fix: update config to use jwk instead
Browse files Browse the repository at this point in the history
  • Loading branch information
kangmingtay committed Jul 23, 2024
1 parent 7539931 commit d771cc4
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 102 deletions.
6 changes: 2 additions & 4 deletions internal/api/jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,13 @@ type JwksResponse struct {

func (a *API) Jwks(w http.ResponseWriter, r *http.Request) error {
config := a.config

keys := []jwk.Key{}
resp := JwksResponse{
Keys: keys,
Keys: []jwk.Key{},
}

for _, key := range config.JWT.Keys {
// don't expose hmac jwk in endpoint
if key.PrivateKey.Algorithm().String() == "HS256" {
if key.PublicKey == nil || key.PublicKey.KeyType() == "oct" {
continue
}
resp.Keys = append(resp.Keys, key.PublicKey)
Expand Down
21 changes: 7 additions & 14 deletions internal/conf/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"github.com/gobwas/glob"
"github.com/joho/godotenv"
"github.com/kelseyhightower/envconfig"
"github.com/lestrrat-go/jwx/v2/jwk"
)

const defaultMinPasswordLength int = 6
Expand Down Expand Up @@ -710,19 +709,13 @@ func (config *GlobalConfiguration) ApplyDefaults() error {
}

if config.JWT.Keys == nil {
if config.JWT.KeyID != "" && config.JWT.Secret != "" {
config.JWT.Keys = make(JwtKeysDecoder)
derBytes, err := base64.StdEncoding.DecodeString(config.JWT.Secret)
if err != nil {
derBytes = []byte(config.JWT.Secret)
}
privKey, err := jwk.FromRaw(derBytes)
if err != nil {
return err
}
config.JWT.Keys[config.JWT.KeyID] = JwkInfo{
PrivateKey: privKey,
}
key, err := getSymmetricKey(&config.JWT)
if err != nil {
return err
}
config.JWT.Keys = make(JwtKeysDecoder)
config.JWT.Keys[config.JWT.KeyID] = JwkInfo{
PrivateKey: *key,
}
}

Expand Down
106 changes: 22 additions & 84 deletions internal/conf/jwk.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
package conf

import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/json"
"fmt"
Expand All @@ -19,80 +15,34 @@ type JwkInfo struct {
PrivateKey jwk.Key `json:"private_key"`
}

// KeyInfo is used to store the initial config of the keys
// The private key should be in DER format and base64 encoded
type KeyInfo struct {
Type string `json:"type"`
PrivateKey string `json:"private_key"`
InUse bool `json:"in_use"`
}

// Decode implements the Decoder interface
// which transforms the keys stored as der binary strings into jwks
func (j *JwtKeysDecoder) Decode(value string) error {
data := map[string]KeyInfo{}
data := make([]map[string]interface{}, 0)
if err := json.Unmarshal([]byte(value), &data); err != nil {
return err
}

config := JwtKeysDecoder{}
for kid, key := range data {
// all private keys should be stored as der binary strings in base64
derBytes, err := base64.StdEncoding.DecodeString(key.PrivateKey)
if err != nil {
derBytes = []byte(key.PrivateKey)
}

var privKey any
if key.Type == "hmac" {
privKey = derBytes
} else {
// assume key is asymmetric
privKey, err = x509.ParsePKCS8PrivateKey(derBytes)
if err != nil {
return err
}
}
alg := getAlg(privKey)
if alg == "" {
return fmt.Errorf("unsupported key alg: %v", kid)
}

privJwk, err := jwk.FromRaw(privKey)
for _, key := range data {
bytes, err := json.Marshal(key)
if err != nil {
return err
}
// Set kid, alg and use claims for private key
if err := privJwk.Set(jwk.KeyIDKey, kid); err != nil {
return err
}
if err := privJwk.Set(jwk.AlgorithmKey, alg); err != nil {
privJwk, err := jwk.ParseKey(bytes)
if err != nil {
return err
}

switch key.InUse {
case true:
// only the key that's in use should be used for encryption
if err := privJwk.Set(jwk.KeyUsageKey, "enc"); err != nil {
return err
}
default:
if err := privJwk.Set(jwk.KeyUsageKey, "sig"); err != nil {
return err
}
}

pubJwk, err := jwk.PublicKeyOf(privJwk)
if err != nil {
return err
}

// public keys are always used for signature verification only
if err := pubJwk.Set(jwk.KeyUsageKey, "sig"); err != nil {
return err
// all public keys will be used for signature verification
if pubJwk.KeyUsage() == "enc" {
pubJwk.Set(jwk.KeyUsageKey, "sig")
}

config[kid] = JwkInfo{
config[pubJwk.KeyID()] = JwkInfo{
PublicKey: pubJwk,
PrivateKey: privJwk,
}
Expand Down Expand Up @@ -126,32 +76,20 @@ func (j *JwtKeysDecoder) Validate() error {
return nil
}

func getAlg(key any) string {
var alg string
switch p := key.(type) {
case []byte:
alg = "HS256"
case *ecdsa.PrivateKey:
switch p.Curve.Params().Name {
case "P-256":
alg = "ES256"
case "P-384":
alg = "ES384"
case "P-521":
alg = "ES512"
func getSymmetricKey(config *JWTConfiguration) (*jwk.Key, error) {
if config.Secret != "" {
bytes, err := base64.StdEncoding.DecodeString(config.Secret)
if err != nil {
bytes = []byte(config.Secret)
}
privKey, err := jwk.FromRaw(bytes)
if err != nil {
return nil, err
}
case *rsa.PrivateKey:
switch p.N.BitLen() {
case 2048:
alg = "RS256"
case 4096:
alg = "RS512"
if config.KeyID != "" {
privKey.Set(jwk.KeyIDKey, config.KeyID)
}
case *ed25519.PrivateKey:
// Ed25519 is still experimental based on https://github.com/lestrrat-go/jwx/tree/develop/v2/jwk#supported-key-types
alg = "EdDSA"
default:
return ""
return &privKey, nil
}
return alg
return nil, fmt.Errorf("missing symmetric key")
}

0 comments on commit d771cc4

Please sign in to comment.