Skip to content

Commit

Permalink
JWT: Let the configuration take a JSON Web Key Set
Browse files Browse the repository at this point in the history
Right now key material needs to be provided in the form of PEM/DER
format. Even though this is native in the sense that Go's crypto
libraries use it, it's not a common way of representing them in the
JWT ecosystem.

This change replaces the existing options with a new one named
'jwks_inline', which may hold a JWKS as specified in RFC 7517, chapter
five. The easiest way to convert existing public keys to JWKS is to
install the 'step' CLI and run the command below:

step crypto key format --jwk < mykey

https://smallstep.com/docs/step-cli/reference/crypto/key/format/

This work is based on a contribution by Morten Mjelva and Robert
Collins. Thanks a lot!

Fixes: #165
Fixes: #179
  • Loading branch information
EdSchouten committed Oct 6, 2023
1 parent 74764b5 commit 0bb5e73
Show file tree
Hide file tree
Showing 21 changed files with 231 additions and 160 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/sts v1.22.0
github.com/bazelbuild/remote-apis v0.0.0-20230822133051-6c32c3b917cc
github.com/fxtlabs/primes v0.0.0-20150821004651-dad82d10a449
github.com/go-jose/go-jose/v3 v3.0.0
github.com/go-redis/redis/extra/redisotel v0.3.0
github.com/go-redis/redis/v8 v8.11.5
github.com/golang/mock v1.6.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/fxtlabs/primes v0.0.0-20150821004651-dad82d10a449 h1:HOYnhuVrhAVGKdg3rZapII640so7QfXQmkLkefUN/uM=
github.com/fxtlabs/primes v0.0.0-20150821004651-dad82d10a449/go.mod h1:i+vbdOOivRRh2j+WwBkjZXloGN/+KAqfKDwNfUJeugc=
github.com/go-jose/go-jose/v3 v3.0.0 h1:s6rrhirfEP/CGIoc6p+PZAeogN2SxKav6Wp7+dyMWVo=
github.com/go-jose/go-jose/v3 v3.0.0/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
Expand Down
6 changes: 6 additions & 0 deletions go_dependencies.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,12 @@ def go_dependencies():
sum = "h1:HOYnhuVrhAVGKdg3rZapII640so7QfXQmkLkefUN/uM=",
version = "v0.0.0-20150821004651-dad82d10a449",
)
go_repository(
name = "com_github_go_jose_go_jose_v3",
importpath = "github.com/go-jose/go-jose/v3",
sum = "h1:s6rrhirfEP/CGIoc6p+PZAeogN2SxKav6Wp7+dyMWVo=",
version = "v3.0.0",
)
go_repository(
name = "com_github_go_kit_log",
importpath = "github.com/go-kit/log",
Expand Down
3 changes: 3 additions & 0 deletions pkg/jwt/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go_library(
srcs = [
"authorization_header_parser.go",
"configuration.go",
"demultiplexing_signature_validator.go",
"ecdsa_sha_signature_generator.go",
"ecdsa_sha_signature_validator.go",
"ed25519_signature_validator.go",
Expand All @@ -23,9 +24,11 @@ go_library(
"//pkg/proto/configuration/jwt",
"//pkg/random",
"//pkg/util",
"@com_github_go_jose_go_jose_v3//:go-jose",
"@com_github_jmespath_go_jmespath//:go-jmespath",
"@org_golang_google_grpc//codes",
"@org_golang_google_grpc//status",
"@org_golang_google_protobuf//encoding/protojson",
],
)

Expand Down
5 changes: 3 additions & 2 deletions pkg/jwt/authorization_header_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,13 @@ func (a *AuthorizationHeaderParser) parseSingleAuthorizationHeader(header string

// Perform signature validation.
headerMessage := struct {
Alg string `json:"alg"`
Alg string `json:"alg"`
KID *string `json:"kid"`
}{}
if json.Unmarshal(decodedFields[0], &headerMessage) != nil {
return unauthenticated
}
if !a.signatureValidator.ValidateSignature(headerMessage.Alg, match[1], decodedFields[2]) {
if !a.signatureValidator.ValidateSignature(headerMessage.Alg, headerMessage.KID, match[1], decodedFields[2]) {
return unauthenticated
}

Expand Down
32 changes: 19 additions & 13 deletions pkg/jwt/authorization_header_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ func TestAuthorizationHeaderParser(t *testing.T) {
1000,
eviction.NewLRUSet[string]())

exampleKeyID := "MyKeyID"

t.Run("NoAuthorizationHeadersProvided", func(t *testing.T) {
clock.EXPECT().Now().Return(time.Unix(1635747849, 0))

Expand All @@ -36,6 +38,7 @@ func TestAuthorizationHeaderParser(t *testing.T) {
clock.EXPECT().Now().Return(time.Unix(1635747849, 0))
signatureValidator.EXPECT().ValidateSignature(
"HS256",
/* keyID = */ nil,
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ",
[]byte{
0x49, 0xf9, 0x4a, 0xc7, 0x04, 0x49, 0x48, 0xc7,
Expand All @@ -59,7 +62,8 @@ func TestAuthorizationHeaderParser(t *testing.T) {
clock.EXPECT().Now().Return(time.Unix(1635781700, 0))
signatureValidator.EXPECT().ValidateSignature(
"HS256",
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ",
&exampleKeyID,
"eyJhbGciOiJIUzI1NiIsImtpZCI6Ik15S2V5SUQiLCJ0eXAiOiJKV1QifQ.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ",
[]byte{
0x69, 0xf2, 0xcf, 0x62, 0xca, 0x9a, 0xa4, 0x3c,
0x6f, 0xc1, 0xe7, 0x61, 0x35, 0x39, 0xd8, 0xaa,
Expand All @@ -69,7 +73,7 @@ func TestAuthorizationHeaderParser(t *testing.T) {
).Return(true)

metadata, ok := authenticator.ParseAuthorizationHeaders([]string{
"Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.afLPYsqapDxvwedhNTnYqpk3YmXo9rSO24UDyCokl9M",
"Bearer eyJhbGciOiJIUzI1NiIsImtpZCI6Ik15S2V5SUQiLCJ0eXAiOiJKV1QifQ.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.afLPYsqapDxvwedhNTnYqpk3YmXo9rSO24UDyCokl9M",
})
require.True(t, ok)
require.Equal(t, map[string]any{
Expand All @@ -84,7 +88,7 @@ func TestAuthorizationHeaderParser(t *testing.T) {
clock.EXPECT().Now().Return(time.Unix(1635781701, 0))

metadata, ok = authenticator.ParseAuthorizationHeaders([]string{
"Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.afLPYsqapDxvwedhNTnYqpk3YmXo9rSO24UDyCokl9M",
"Bearer eyJhbGciOiJIUzI1NiIsImtpZCI6Ik15S2V5SUQiLCJ0eXAiOiJKV1QifQ.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.afLPYsqapDxvwedhNTnYqpk3YmXo9rSO24UDyCokl9M",
})
require.True(t, ok)
require.Equal(t, map[string]any{
Expand All @@ -102,7 +106,8 @@ func TestAuthorizationHeaderParser(t *testing.T) {
clock.EXPECT().Now().Return(time.Unix(1635781778, 0))
signatureValidator.EXPECT().ValidateSignature(
"HS256",
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwibmJmIjoxNjM1NzgxNzgwLCJleHAiOjE2MzU3ODE3OTJ9",
&exampleKeyID,
"eyJhbGciOiJIUzI1NiIsImtpZCI6Ik15S2V5SUQiLCJ0eXAiOiJKV1QifQ.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwibmJmIjoxNjM1NzgxNzgwLCJleHAiOjE2MzU3ODE3OTJ9",
[]byte{
0x9a, 0xf0, 0xa6, 0x11, 0xb2, 0x62, 0xcb, 0xec,
0x48, 0x43, 0x7c, 0xec, 0x21, 0x3a, 0x6a, 0x6e,
Expand All @@ -112,15 +117,15 @@ func TestAuthorizationHeaderParser(t *testing.T) {
).Return(true)

_, ok := authenticator.ParseAuthorizationHeaders([]string{
"Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwibmJmIjoxNjM1NzgxNzgwLCJleHAiOjE2MzU3ODE3OTJ9.mvCmEbJiy-xIQ3zsITpqbthXrSTjtuph1Sd2KGvMXhY",
"Bearer eyJhbGciOiJIUzI1NiIsImtpZCI6Ik15S2V5SUQiLCJ0eXAiOiJKV1QifQ.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwibmJmIjoxNjM1NzgxNzgwLCJleHAiOjE2MzU3ODE3OTJ9.mvCmEbJiy-xIQ3zsITpqbthXrSTjtuph1Sd2KGvMXhY",
})
require.False(t, ok)

// Successive calls for the same token should be cached.
clock.EXPECT().Now().Return(time.Unix(1635781779, 0))

_, ok = authenticator.ParseAuthorizationHeaders([]string{
"Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwibmJmIjoxNjM1NzgxNzgwLCJleHAiOjE2MzU3ODE3OTJ9.mvCmEbJiy-xIQ3zsITpqbthXrSTjtuph1Sd2KGvMXhY",
"Bearer eyJhbGciOiJIUzI1NiIsImtpZCI6Ik15S2V5SUQiLCJ0eXAiOiJKV1QifQ.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwibmJmIjoxNjM1NzgxNzgwLCJleHAiOjE2MzU3ODE3OTJ9.mvCmEbJiy-xIQ3zsITpqbthXrSTjtuph1Sd2KGvMXhY",
})
require.False(t, ok)

Expand All @@ -129,7 +134,7 @@ func TestAuthorizationHeaderParser(t *testing.T) {
clock.EXPECT().Now().Return(time.Unix(1635781780, 0))

metadata, ok := authenticator.ParseAuthorizationHeaders([]string{
"Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwibmJmIjoxNjM1NzgxNzgwLCJleHAiOjE2MzU3ODE3OTJ9.mvCmEbJiy-xIQ3zsITpqbthXrSTjtuph1Sd2KGvMXhY",
"Bearer eyJhbGciOiJIUzI1NiIsImtpZCI6Ik15S2V5SUQiLCJ0eXAiOiJKV1QifQ.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwibmJmIjoxNjM1NzgxNzgwLCJleHAiOjE2MzU3ODE3OTJ9.mvCmEbJiy-xIQ3zsITpqbthXrSTjtuph1Sd2KGvMXhY",
})
require.True(t, ok)
require.Equal(t, map[string]any{
Expand All @@ -146,7 +151,7 @@ func TestAuthorizationHeaderParser(t *testing.T) {
clock.EXPECT().Now().Return(time.Unix(1635781786, 0))

metadata, ok = authenticator.ParseAuthorizationHeaders([]string{
"Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwibmJmIjoxNjM1NzgxNzgwLCJleHAiOjE2MzU3ODE3OTJ9.mvCmEbJiy-xIQ3zsITpqbthXrSTjtuph1Sd2KGvMXhY",
"Bearer eyJhbGciOiJIUzI1NiIsImtpZCI6Ik15S2V5SUQiLCJ0eXAiOiJKV1QifQ.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwibmJmIjoxNjM1NzgxNzgwLCJleHAiOjE2MzU3ODE3OTJ9.mvCmEbJiy-xIQ3zsITpqbthXrSTjtuph1Sd2KGvMXhY",
})
require.True(t, ok)
require.Equal(t, map[string]any{
Expand All @@ -161,7 +166,7 @@ func TestAuthorizationHeaderParser(t *testing.T) {
clock.EXPECT().Now().Return(time.Unix(1635781791, 0))

metadata, ok = authenticator.ParseAuthorizationHeaders([]string{
"Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwibmJmIjoxNjM1NzgxNzgwLCJleHAiOjE2MzU3ODE3OTJ9.mvCmEbJiy-xIQ3zsITpqbthXrSTjtuph1Sd2KGvMXhY",
"Bearer eyJhbGciOiJIUzI1NiIsImtpZCI6Ik15S2V5SUQiLCJ0eXAiOiJKV1QifQ.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwibmJmIjoxNjM1NzgxNzgwLCJleHAiOjE2MzU3ODE3OTJ9.mvCmEbJiy-xIQ3zsITpqbthXrSTjtuph1Sd2KGvMXhY",
})
require.True(t, ok)
require.Equal(t, map[string]any{
Expand All @@ -178,7 +183,7 @@ func TestAuthorizationHeaderParser(t *testing.T) {
clock.EXPECT().Now().Return(time.Unix(1635781792, 0))

_, ok = authenticator.ParseAuthorizationHeaders([]string{
"Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwibmJmIjoxNjM1NzgxNzgwLCJleHAiOjE2MzU3ODE3OTJ9.mvCmEbJiy-xIQ3zsITpqbthXrSTjtuph1Sd2KGvMXhY",
"Bearer eyJhbGciOiJIUzI1NiIsImtpZCI6Ik15S2V5SUQiLCJ0eXAiOiJKV1QifQ.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwibmJmIjoxNjM1NzgxNzgwLCJleHAiOjE2MzU3ODE3OTJ9.mvCmEbJiy-xIQ3zsITpqbthXrSTjtuph1Sd2KGvMXhY",
})
require.False(t, ok)

Expand All @@ -187,7 +192,7 @@ func TestAuthorizationHeaderParser(t *testing.T) {
clock.EXPECT().Now().Return(time.Unix(1635781793, 0))

_, ok = authenticator.ParseAuthorizationHeaders([]string{
"Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwibmJmIjoxNjM1NzgxNzgwLCJleHAiOjE2MzU3ODE3OTJ9.mvCmEbJiy-xIQ3zsITpqbthXrSTjtuph1Sd2KGvMXhY",
"Bearer eyJhbGciOiJIUzI1NiIsImtpZCI6Ik15S2V5SUQiLCJ0eXAiOiJKV1QifQ.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwibmJmIjoxNjM1NzgxNzgwLCJleHAiOjE2MzU3ODE3OTJ9.mvCmEbJiy-xIQ3zsITpqbthXrSTjtuph1Sd2KGvMXhY",
})
require.False(t, ok)
})
Expand All @@ -199,7 +204,8 @@ func TestAuthorizationHeaderParser(t *testing.T) {
clock.EXPECT().Now().Return(time.Unix(1636144433, 0))
signatureValidator.EXPECT().ValidateSignature(
"HS256",
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb3JiaWRkZW5GaWVsZCI6Im9vcHMifQ",
&exampleKeyID,
"eyJhbGciOiJIUzI1NiIsImtpZCI6Ik15S2V5SUQiLCJ0eXAiOiJKV1QifQ.eyJmb3JiaWRkZW5GaWVsZCI6Im9vcHMifQ",
[]byte{
0xf1, 0x5c, 0xbc, 0x0c, 0x47, 0x71, 0x2d, 0x88,
0x42, 0x8a, 0xe3, 0x52, 0x32, 0x77, 0xee, 0xb7,
Expand All @@ -209,7 +215,7 @@ func TestAuthorizationHeaderParser(t *testing.T) {
).Return(true)

_, ok := authenticator.ParseAuthorizationHeaders([]string{
"Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb3JiaWRkZW5GaWVsZCI6Im9vcHMifQ.8Vy8DEdxLYhCiuNSMnfut4c7UJmHjHQWencNhePnKH4",
"Bearer eyJhbGciOiJIUzI1NiIsImtpZCI6Ik15S2V5SUQiLCJ0eXAiOiJKV1QifQ.eyJmb3JiaWRkZW5GaWVsZCI6Im9vcHMifQ.8Vy8DEdxLYhCiuNSMnfut4c7UJmHjHQWencNhePnKH4",
})
require.False(t, ok)
})
Expand Down
90 changes: 58 additions & 32 deletions pkg/jwt/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,54 +4,36 @@ import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"encoding/json"
"reflect"

"github.com/buildbarn/bb-storage/pkg/clock"
"github.com/buildbarn/bb-storage/pkg/eviction"
configuration "github.com/buildbarn/bb-storage/pkg/proto/configuration/jwt"
"github.com/buildbarn/bb-storage/pkg/util"
jose "github.com/go-jose/go-jose/v3"
"github.com/jmespath/go-jmespath"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson"
)

// NewAuthorizationHeaderParserFromConfiguration creates a new HTTP
// "Authorization" header parser based on options stored in a
// configuration file.
func NewAuthorizationHeaderParserFromConfiguration(config *configuration.AuthorizationHeaderParserConfiguration) (*AuthorizationHeaderParser, error) {
var signatureValidator SignatureValidator
switch key := config.Key.(type) {
case *configuration.AuthorizationHeaderParserConfiguration_HmacKey:
signatureValidator = NewHMACSHASignatureValidator(key.HmacKey)
case *configuration.AuthorizationHeaderParserConfiguration_PublicKey:
block, _ := pem.Decode([]byte(key.PublicKey))
if block == nil {
return nil, status.Error(codes.InvalidArgument, "Public key does not use the PEM format")
}
parsedKey, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, util.StatusWrapWithCode(err, codes.InvalidArgument, "Failed to parse public key")
}
switch convertedKey := parsedKey.(type) {
case *ecdsa.PublicKey:
var err error
signatureValidator, err = NewECDSASHASignatureValidator(convertedKey)
if err != nil {
return nil, err
}
case ed25519.PublicKey:
signatureValidator = NewEd25519SignatureValidator(convertedKey)
case *rsa.PublicKey:
signatureValidator = NewRSASHASignatureValidator(convertedKey)
default:
keyType := reflect.TypeOf(parsedKey)
return nil, status.Errorf(codes.InvalidArgument, "Unsupported public key type: %s/%s", keyType.PkgPath(), keyType.Name())
}
default:
return nil, status.Error(codes.InvalidArgument, "No key type provided")
jwksJSON, err := protojson.Marshal(config.JwksInline)
if err != nil {
return nil, util.StatusWrapWithCode(err, codes.InvalidArgument, "Failed to marshal JSON Web Key Set")
}
var jwks jose.JSONWebKeySet
if err := json.Unmarshal(jwksJSON, &jwks); err != nil {
return nil, util.StatusWrapWithCode(err, codes.InvalidArgument, "Failed to unmarshal JSON Web Key Set")
}
signatureValidator, err := NewSignatureValidatorFromJSONWebKeySet(&jwks)
if err != nil {
return nil, err
}

evictionSet, err := eviction.NewSetFromConfiguration[string](config.CacheReplacementPolicy)
Expand All @@ -76,3 +58,47 @@ func NewAuthorizationHeaderParserFromConfiguration(config *configuration.Authori
int(config.MaximumCacheSize),
eviction.NewMetricsSet(evictionSet, "AuthorizationHeaderParser")), nil
}

// NewSignatureValidatorFromJSONWebKeySet creates a new
// SignatureValidator capable of validating JWTs matching keys contained
// in a JSON Web Key Set, as described in RFC 7517, chapter 5.
func NewSignatureValidatorFromJSONWebKeySet(jwks *jose.JSONWebKeySet) (SignatureValidator, error) {
namedSignatureValidators := make(map[string]SignatureValidator, len(jwks.Keys))
allSignatureValidators := make([]SignatureValidator, 0, len(jwks.Keys))
for i, jwk := range jwks.Keys {
if !jwk.Valid() {
return nil, status.Errorf(codes.InvalidArgument, "Invalid JSON Web Key at index %d", i)
}

var signatureValidator SignatureValidator
switch convertedKey := jwk.Key.(type) {
case *ecdsa.PublicKey:
var err error
signatureValidator, err = NewECDSASHASignatureValidator(convertedKey)
if err != nil {
return nil, util.StatusWrapf(err, "Invalid ECDSA key at index %d", i)
}
case ed25519.PublicKey:
signatureValidator = NewEd25519SignatureValidator(convertedKey)
case *rsa.PublicKey:
signatureValidator = NewRSASHASignatureValidator(convertedKey)
default:
keyType := reflect.TypeOf(jwk.Key)
return nil, status.Errorf(codes.InvalidArgument, "Unsupported public key type at index %d: %s/%s", i, keyType.PkgPath(), keyType.Name())
}

if jwk.KeyID != "" {
// JSON Web Key contains a key ID. Ensure that
// JWTs that contain an explicit key ID only get
// matched to this validator if the key ID
// matches.
if _, ok := namedSignatureValidators[jwk.KeyID]; ok {
return nil, status.Errorf(codes.InvalidArgument, "JSON Web Key Set contains multiple keys with ID %#v", jwk.KeyID)
}
namedSignatureValidators[jwk.KeyID] = signatureValidator
}
allSignatureValidators = append(allSignatureValidators, signatureValidator)
}

return NewDemultiplexingSignatureValidator(namedSignatureValidators, allSignatureValidators), nil
}
31 changes: 31 additions & 0 deletions pkg/jwt/demultiplexing_signature_validator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package jwt

type demultiplexingSignatureValidator struct {
namedSignatureValidators map[string]SignatureValidator
allSignatureValidators []SignatureValidator
}

// NewDemultiplexingSignatureValidator creates a SignatureValidator that
// routes signature validation requests based on the key ID ("kid")
// field that's part of a JWT's header.
func NewDemultiplexingSignatureValidator(namedSignatureValidators map[string]SignatureValidator, allSignatureValidators []SignatureValidator) SignatureValidator {
return &demultiplexingSignatureValidator{
namedSignatureValidators: namedSignatureValidators,
allSignatureValidators: allSignatureValidators,
}
}

func (sv *demultiplexingSignatureValidator) ValidateSignature(algorithm string, keyID *string, headerAndPayload string, signature []byte) bool {
if keyID == nil {
// No key ID provided. Simply try all signature validators.
for _, signatureValidator := range sv.allSignatureValidators {
if signatureValidator.ValidateSignature(algorithm, keyID, headerAndPayload, signature) {
return true
}
}
} else if signatureValidator, ok := sv.namedSignatureValidators[*keyID]; ok {
// Exact match on the key ID.
return signatureValidator.ValidateSignature(algorithm, keyID, headerAndPayload, signature)
}
return false
}
2 changes: 1 addition & 1 deletion pkg/jwt/ecdsa_sha_signature_generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ f2EJfEoVNO/YidkVY+J35v8vQoAMS4rRGA==
// Ensure that the generated signature is valid.
signatureValidator, err := jwt.NewECDSASHASignatureValidator(&key.PublicKey)
require.NoError(t, err)
require.True(t, signatureValidator.ValidateSignature("ES256", headerAndPayload, signature))
require.True(t, signatureValidator.ValidateSignature("ES256", nil, headerAndPayload, signature))
})
}
2 changes: 1 addition & 1 deletion pkg/jwt/ecdsa_sha_signature_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func NewECDSASHASignatureValidator(publicKey *ecdsa.PublicKey) (SignatureValidat
}, nil
}

func (sv *ecdsaSHASignatureValidator) ValidateSignature(algorithm, headerAndPayload string, signature []byte) bool {
func (sv *ecdsaSHASignatureValidator) ValidateSignature(algorithm string, keyID *string, headerAndPayload string, signature []byte) bool {
p := sv.parameters
if algorithm != p.algorithm || len(signature) != 2*p.keySizeBytes {
return false
Expand Down
Loading

0 comments on commit 0bb5e73

Please sign in to comment.