Skip to content

Commit

Permalink
chore: parse jwks in config initialisation
Browse files Browse the repository at this point in the history
  • Loading branch information
kangmingtay committed Jul 23, 2024
1 parent 76096e8 commit e2751e1
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 51 deletions.
33 changes: 4 additions & 29 deletions internal/api/jwks.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package api

import (
"crypto/x509"
"encoding/base64"
"net/http"

jwk "github.com/lestrrat-go/jwx/v2/jwk"
Expand All @@ -20,35 +18,12 @@ func (a *API) Jwks(w http.ResponseWriter, r *http.Request) error {
Keys: keys,
}

for kid, key := range config.JWT.Keys {
if key.Type == "hmac" {
// don't display hmac key in jwks
for _, key := range config.JWT.Keys {
// don't expose hmac jwk in endpoint
if key.PrivateKey.Algorithm().String() == "HS256" {
continue
}

// public keys are stored as base64 encoded DER
derBytes, err := base64.StdEncoding.DecodeString(key.PublicKey)
if err != nil {
return internalServerError("Error decoding public key for kid: %v", kid).WithInternalError(err)
}

// public keys are assumed to be stored in spki format
// x509 only supports the P256 curve for EC and ED25519
pubKey, err := x509.ParsePKIXPublicKey(derBytes)
if err != nil {
return internalServerError("Error parsing public key for kid: %v", kid).WithInternalError(err)
}
k, err := jwk.FromRaw(pubKey)
if err != nil {
return internalServerError("Error parsing jwk for kid: %v", kid).WithInternalError(err)
}
k.Set(jwk.KeyIDKey, kid)

k.Set(jwk.KeyUsageKey, "enc")
if key.InUse {
k.Set(jwk.KeyUsageKey, "sig")
}
resp.Keys = append(resp.Keys, k)
resp.Keys = append(resp.Keys, key.PublicKey)
}

return sendJSON(w, http.StatusOK, resp)
Expand Down
34 changes: 12 additions & 22 deletions internal/conf/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package conf
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/url"
Expand All @@ -16,6 +15,7 @@ import (
"github.com/gobwas/glob"
"github.com/joho/godotenv"
"github.com/kelseyhightower/envconfig"
"github.com/lestrrat-go/jwx/v2/jwk"
)

const defaultMinPasswordLength int = 6
Expand Down Expand Up @@ -104,22 +104,6 @@ type JWTConfiguration struct {
Keys JwtKeysDecoder `json:"keys"`
}

type JwtKeysDecoder map[string]KeyInfo

func (j *JwtKeysDecoder) Decode(value string) error {
if err := json.Unmarshal([]byte(value), j); err != nil {
return err
}
return nil
}

type KeyInfo struct {
Type string `json:"type"`
PublicKey string `json:"public_key"`
PrivateKey string `json:"private_key"`
InUse bool `json:"in_use"`
}

// MFAConfiguration holds all the MFA related Configuration
type MFAConfiguration struct {
Enabled bool `default:"false"`
Expand Down Expand Up @@ -727,11 +711,16 @@ func (config *GlobalConfiguration) ApplyDefaults() error {

if config.JWT.Keys == nil {
config.JWT.Keys = make(JwtKeysDecoder)
config.JWT.Keys[config.JWT.KeyID] = KeyInfo{
Type: "hmac",
PublicKey: "",
PrivateKey: config.JWT.Secret,
InUse: true,
derBytes, err := base64.StdEncoding.DecodeString(config.JWT.Secret)
if err != nil {
return err
}
privKey, err := jwk.FromRaw(derBytes)
if err != nil {
return err
}
config.JWT.Keys[config.JWT.KeyID] = JwkInfo{
PrivateKey: privKey,
}
}

Expand Down Expand Up @@ -852,6 +841,7 @@ func (c *GlobalConfiguration) Validate() error {
&c.Security,
&c.Sessions,
&c.Hook,
&c.JWT.Keys,
}

for _, validatable := range validatables {
Expand Down
152 changes: 152 additions & 0 deletions internal/conf/jwk.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
package conf

import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/json"
"fmt"

"github.com/lestrrat-go/jwx/v2/jwk"
)

type JwtKeysDecoder map[string]JwkInfo

type JwkInfo struct {
PublicKey jwk.Key `json:"public_key"`
PrivateKey jwk.Key `json:"private_key"`
}

type KeyInfo struct {
Type string `json:"type"`
PrivateKey string `json:"private_key"`
InUse bool `json:"in_use"`
}

// Decode implements the Decoder interface
// which transforms the keys stored as der binary strings into jwks
func (j *JwtKeysDecoder) Decode(value string) error {
data := map[string]KeyInfo{}
if err := json.Unmarshal([]byte(value), &data); err != nil {
return err
}

config := JwtKeysDecoder{}
for kid, key := range data {
// all private keys should be stored as der binary strings in base64
derBytes, err := base64.StdEncoding.DecodeString(key.PrivateKey)
if err != nil {
return err
}

var privKey any
if key.Type == "hmac" {
privKey = derBytes
} else {
// assume key is asymmetric
privKey, err = x509.ParsePKCS8PrivateKey(derBytes)
if err != nil {
return err
}
}
alg := getAlg(privKey)
if alg == "" {
return fmt.Errorf("unsupported key alg: %v", kid)
}

privJwk, err := jwk.FromRaw(privKey)
if err != nil {
return err
}
// Set kid, alg and use claims for private key
if err := privJwk.Set(jwk.KeyIDKey, kid); err != nil {
return err
}
if err := privJwk.Set(jwk.AlgorithmKey, alg); err != nil {
return err
}

switch key.InUse {
case true:
// only the key that's in use should be used for encryption
if err := privJwk.Set(jwk.KeyUsageKey, "enc"); err != nil {
return err
}
default:
if err := privJwk.Set(jwk.KeyUsageKey, "sig"); err != nil {
return err
}
}

pubJwk, err := jwk.PublicKeyOf(privJwk)
if err != nil {
return err
}

// public keys are always used for signature verification only
if err := pubJwk.Set(jwk.KeyUsageKey, "sig"); err != nil {
return err
}

config[kid] = JwkInfo{
PublicKey: pubJwk,
PrivateKey: privJwk,
}
}
*j = config
return nil
}

func (j *JwtKeysDecoder) Validate() error {
// Validate performs _minimal_ checks if the data stored in the key are valid.
// By minimal, we mean that it does not check if the key is valid for use in
// cryptographic operations. For example, it does not check if an RSA key's
// `e` field is a valid exponent, or if the `n` field is a valid modulus.
// Instead, it checks for things such as the _presence_ of some required fields,
// or if certain keys' values are of particular length.
//
// Note that depending on th underlying key type, use of this method requires
// that multiple fields in the key are properly populated. For example, an EC
// key's "x", "y" fields cannot be validated unless the "crv" field is populated first.
for _, key := range *j {
if err := key.PrivateKey.Validate(); err != nil {
return err
}
if err := key.PublicKey.Validate(); err != nil {
return err
}
}
return nil
}

func getAlg(key any) string {
var alg string
switch p := key.(type) {
case []byte:
alg = "HS256"
case *ecdsa.PrivateKey:
switch p.Curve.Params().Name {
case "P-256":
alg = "ES256"
case "P-384":
alg = "ES384"
case "P-521":
alg = "ES512"
}
case *rsa.PrivateKey:
switch p.N.BitLen() {
case 2048:
alg = "RS256"
case 4096:
alg = "RS512"
}
case *ed25519.PrivateKey:
// Ed25519 is still experimental based on https://github.com/lestrrat-go/jwx/tree/develop/v2/jwk#supported-key-types
alg = "EdDSA"
default:
return ""
}
return alg
}

0 comments on commit e2751e1

Please sign in to comment.