Skip to content

Commit

Permalink
feat: add ECDH-ES, ECDH-ES+A128KW, ECDH-ES+A192KW, ECDH-ES+A256KW enc…
Browse files Browse the repository at this point in the history
…ryption alg (#51)
  • Loading branch information
vdbulcke committed Oct 21, 2023
1 parent c8abffc commit 3444088
Show file tree
Hide file tree
Showing 9 changed files with 192 additions and 32 deletions.
13 changes: 13 additions & 0 deletions example/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,19 @@ scopes:
token_signing_alg:
- "RS256"

## Token Encryption Alg: (optional)
### List allowed encryption algorithm for token validation
token_encryption_alg:
## RSA
- RSA-OAEP
- RSA-OAEP-256
## Eliptic curve
- ECDH-ES
- ECDH-ES+A128KW
- ECDH-ES+A192KW
- ECDH-ES+A256KW


## IDP Config: (Mandatory)
### NOTE: this 'issuer' will be used to find the ./well-known/openid-configuration
### by adding ./well-known/openid-configuration after the issuer base url
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/vdbulcke/oidc-client-demo

go 1.20
go 1.21

require (
github.com/coreos/go-oidc/v3 v3.4.0
Expand Down
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
Expand Down
3 changes: 2 additions & 1 deletion src/client/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ type OIDCClientConfig struct {
PARAdditionalParameter map[string]string `yaml:"par_additional_parameters"`
AuthorizeAdditionalParameter map[string]string `yaml:"authorize_additional_parameters"`

TokenSigningAlg []string `yaml:"token_signing_alg" validate:"required"`
TokenSigningAlg []string `yaml:"token_signing_alg" validate:"required"`
TokenEncryptionAlg []string `yaml:"token_encryption_alg" validate:"dive,oneof=ECDH-ES RSA-OAEP RSA-OAEP-256 ECDH-ES+A128KW ECDH-ES+A192KW ECDH-ES+A256KW"`

AMRWhitelist []string `yaml:"amr_list"`
ACRWhitelist []string `yaml:"acr_list"`
Expand Down
37 changes: 23 additions & 14 deletions src/client/id_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"encoding/json"

"slices"

"github.com/coreos/go-oidc/v3/oidc"
)

Expand All @@ -18,23 +20,30 @@ func (c *OIDCClient) processIdToken(idTokenRaw string) (*oidc.IDToken, error) {
// pretty print header
c.logger.Info("IDToken header", "header", header)

if alg, ok := headerClaims["alg"]; ok {
// ony supported encryption alg
if alg == "RSA-OAEP-256" || alg == "RSA-OAEP" {
jwtPayload, err := c.jwtsigner.DecryptJWT(idTokenRaw, alg)
if err != nil {
c.logger.Error("error decrypting jwt", "error", err)
} else {

// nested jwt payload
idTokenRaw = jwtPayload
c.logger.Info("Encryped JWT", "payload", jwtPayload)
nestedHeader, _, err := c.parseJWTHeader(idTokenRaw)
if err == nil {
c.logger.Info("IDToken nested token header", "header", nestedHeader)
if algI, ok := headerClaims["alg"]; ok {

// check string
if alg, ok := algI.(string); ok {
// ony supported encryption alg
if slices.Contains(c.config.TokenEncryptionAlg, alg) {
// if alg == "RSA-OAEP-256" || alg == "RSA-OAEP" || alg == "ECDH-ES" {
jwtPayload, err := c.jwtsigner.DecryptJWT(idTokenRaw, alg)
if err != nil {
c.logger.Error("error decrypting jwt", "error", err)
} else {

// nested jwt payload
idTokenRaw = jwtPayload
c.logger.Info("Encryped JWT", "payload", jwtPayload)
nestedHeader, _, err := c.parseJWTHeader(idTokenRaw)
if err == nil {
c.logger.Info("IDToken nested token header", "header", nestedHeader)
}
}
}

}

}
}

Expand Down
91 changes: 84 additions & 7 deletions src/client/jwt/signer/key_ec.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,18 @@ package signer

import (
"crypto/ecdsa"
"crypto/rand"
"crypto/sha1"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"fmt"
"math/big"
"time"

"github.com/golang-jwt/jwt/v5"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwe"
"gopkg.in/square/go-jose.v2"
)

Expand Down Expand Up @@ -54,15 +62,33 @@ func NewECJWTSigner(k *ecdsa.PrivateKey, alg, mockKid string) (*ECJWTSigner, err
// JWKS is the JSON JWKS representation of the rsa.PublicKey
func (k *ECJWTSigner) JWKS() ([]byte, error) {

cert, err := k.genX509Cert()
if err != nil {
return nil, err
}

fingerprint := sha1.Sum(cert.Raw)

// TODO: support mutli signing alg
jwk := jose.JSONWebKey{
Use: "sig",
Algorithm: k.alg,
Key: k.PublicKey,
KeyID: k.Kid,
sig := jose.JSONWebKey{
Use: "sig",
// Algorithm: k.alg,
Key: k.PublicKey,
KeyID: k.Kid,
Certificates: []*x509.Certificate{cert},
CertificateThumbprintSHA1: fingerprint[:],
}

enc := jose.JSONWebKey{
Use: "enc",

Key: k.PublicKey,
KeyID: k.Kid,
Certificates: []*x509.Certificate{cert},
CertificateThumbprintSHA1: fingerprint[:],
}
jwks := &jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{jwk},
Keys: []jose.JSONWebKey{sig, enc},
}

return json.Marshal(jwks)
Expand All @@ -77,6 +103,57 @@ func (k *ECJWTSigner) SignJWT(claims jwt.Claims) (string, error) {
return token.SignedString(k.PrivateKey)
}
func (k *ECJWTSigner) DecryptJWT(encryptedJwt, alg string) (string, error) {
return "", fmt.Errorf("unsupported encryption alg %s", alg)
// return "", fmt.Errorf("unsupported encryption alg %s", alg)

var method jwa.KeyAlgorithm

switch alg {
case "ECDH-ES":
method = jwa.ECDH_ES
case "ECDH-ES+A128KW":
method = jwa.ECDH_ES_A128KW
case "ECDH-ES+A192KW":
method = jwa.ECDH_ES_A192KW
case "ECDH-ES+A256KW":
method = jwa.ECDH_ES_A256KW

default:
return "", fmt.Errorf("unsupported encryption alg %s", alg)
}

decrypted, err := jwe.Decrypt([]byte(encryptedJwt), jwe.WithKey(method, k.PrivateKey))
if err != nil {
return "", err
}

return string(decrypted), nil
}

func (k *ECJWTSigner) genX509Cert() (*x509.Certificate, error) {
serialNumber := big.NewInt(100000000000000000)
template := &x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: "oidc-client-demo",
},
Issuer: pkix.Name{
CommonName: "oidc-client-demo",
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(5, 0, 0),
PublicKeyAlgorithm: x509.ECDSA,
SignatureAlgorithm: x509.ECDSAWithSHA512,
}

derBytes, err := x509.CreateCertificate(rand.Reader, template, template, k.PublicKey, k.PrivateKey)
if err != nil {
return nil, err
}

cert, err := x509.ParseCertificate(derBytes)
if err != nil {
return nil, err
}

return cert, nil
}
65 changes: 59 additions & 6 deletions src/client/jwt/signer/key_rsa.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
package signer

import (
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"fmt"
"math/big"
"time"

"github.com/golang-jwt/jwt/v5"
"github.com/lestrrat-go/jwx/v2/jwa"
Expand Down Expand Up @@ -54,15 +60,33 @@ func NewRSAJWTSigner(k *rsa.PrivateKey, alg, mockKid string) (*RSAJWTSigner, err
// JWKS is the JSON JWKS representation of the rsa.PublicKey
func (k *RSAJWTSigner) JWKS() ([]byte, error) {

cert, err := k.genX509Cert()
if err != nil {
return nil, err
}

fingerprint := sha1.Sum(cert.Raw)

// TODO: support mutli signing alg
jwk := jose.JSONWebKey{
Use: "sig",
Algorithm: k.alg,
Key: k.PublicKey,
KeyID: k.Kid,
sig := jose.JSONWebKey{
Use: "sig",
// Algorithm: k.alg,
Key: k.PublicKey,
KeyID: k.Kid,
Certificates: []*x509.Certificate{cert},
CertificateThumbprintSHA1: fingerprint[:],
}

enc := jose.JSONWebKey{
Use: "enc",

Key: k.PublicKey,
KeyID: k.Kid,
Certificates: []*x509.Certificate{cert},
CertificateThumbprintSHA1: fingerprint[:],
}
jwks := &jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{jwk},
Keys: []jose.JSONWebKey{sig, enc},
}

return json.Marshal(jwks)
Expand Down Expand Up @@ -99,3 +123,32 @@ func (k *RSAJWTSigner) DecryptJWT(encryptedJwt, alg string) (string, error) {
return string(decrypted), nil

}

func (k *RSAJWTSigner) genX509Cert() (*x509.Certificate, error) {
serialNumber := big.NewInt(100000000000000000)
template := &x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: "oidc-client-demo",
},
Issuer: pkix.Name{
CommonName: "oidc-client-demo",
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(5, 0, 0),
PublicKeyAlgorithm: x509.RSA,
SignatureAlgorithm: x509.SHA512WithRSA,
}

derBytes, err := x509.CreateCertificate(rand.Reader, template, template, k.PublicKey, k.PrivateKey)
if err != nil {
return nil, err
}

cert, err := x509.ParseCertificate(derBytes)
if err != nil {
return nil, err
}

return cert, nil
}
8 changes: 7 additions & 1 deletion src/client/jwt/signer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@ func ParsePrivateKey(filename string) (crypto.PrivateKey, error) {

privKey, err := x509.ParsePKCS8PrivateKey(key.Bytes)
if err != nil {
return nil, err
privKey, err = x509.ParsePKCS1PrivateKey(key.Bytes)
if err != nil {
privKey, err = x509.ParseECPrivateKey(key.Bytes)
if err != nil {
return nil, err
}
}
}

return privKey, nil
Expand Down
4 changes: 2 additions & 2 deletions src/client/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func (c *OIDCClient) NewCodeChallenge(codeVerifier string) (string, error) {
return pkce.NewCodeChallenge(codeVerifier, c.config.PKCEChallengeMethod)

}
func (c *OIDCClient) parseJWTHeader(rawToken string) (string, map[string]string, error) {
func (c *OIDCClient) parseJWTHeader(rawToken string) (string, map[string]interface{}, error) {

parts := strings.Split(rawToken, ".")
// header must be the first part
Expand All @@ -66,7 +66,7 @@ func (c *OIDCClient) parseJWTHeader(rawToken string) (string, map[string]string,
return "", nil, fmt.Errorf(" malformed jwt header: %v", err)
}

var parsedHeader map[string]string
var parsedHeader map[string]interface{}
if err := json.Unmarshal(header, &parsedHeader); err != nil {
return "", nil, fmt.Errorf("failed to unmarshal jwt header: %v", err)
}
Expand Down

0 comments on commit 3444088

Please sign in to comment.