Skip to content

Commit

Permalink
Normalize identity issuer (#22)
Browse files Browse the repository at this point in the history
* rpc: normalize issuer for Identity

* rpc(tests): add test for normalized issuer in Identity
  • Loading branch information
patrislav authored Feb 19, 2024
1 parent 8e1b338 commit 3d551be
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 16 deletions.
31 changes: 25 additions & 6 deletions rpc/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func getTestingCtxValue(ctx context.Context, k string) string {
func initRPC(cfg *config.Config, enc *enclave.Enclave, dbClient *dbMock) *rpc.RPC {
svc := &rpc.RPC{
Config: cfg,
HTTPClient: http.DefaultClient,
HTTPClient: httpClient{},
Enclave: enc,
Wallets: newWalletServiceMock(nil),
Tenants: data.NewTenantTable(dbClient, "Tenants"),
Expand Down Expand Up @@ -166,7 +166,7 @@ QwIDAQAB
}
}

func issueAccessTokenAndRunJwksServer(t *testing.T, optTokenBuilderFn ...func(*jwt.Builder)) (iss string, tok string, close func()) {
func issueAccessTokenAndRunJwksServer(t *testing.T, optTokenBuilderFn ...func(*jwt.Builder, string)) (iss string, tok string, close func()) {
jwtKeyRaw, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
jwtKey, err := jwk.FromRaw(jwtKeyRaw)
Expand Down Expand Up @@ -204,7 +204,7 @@ func issueAccessTokenAndRunJwksServer(t *testing.T, optTokenBuilderFn ...func(*j
Subject("subject")

if len(optTokenBuilderFn) > 0 && optTokenBuilderFn[0] != nil {
optTokenBuilderFn[0](tokBuilder)
optTokenBuilderFn[0](tokBuilder, jwksServer.URL)
}

tokRaw, err := tokBuilder.Build()
Expand Down Expand Up @@ -459,9 +459,12 @@ func newTenant(t *testing.T, enc *enclave.Enclave, issuer string) (*data.Tenant,
},
UpgradeCode: "CHANGEME",
WaasAccessToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJwYXJ0bmVyX2lkIjozfQ.g2fWwLrKPhTUpLFc7ZM9pMm4kEHGu8haCMzMOOGiqSM",
OIDCProviders: []*proto.OpenIdProvider{{Issuer: issuer, Audience: []string{"audience"}}},
AllowedOrigins: []string{"http://localhost"},
KMSKeys: []string{"SessionKey"},
OIDCProviders: []*proto.OpenIdProvider{
{Issuer: issuer, Audience: []string{"audience"}},
{Issuer: "https://" + strings.TrimPrefix(issuer, "http://"), Audience: []string{"audience"}},
},
AllowedOrigins: []string{"http://localhost"},
KMSKeys: []string{"SessionKey"},
}

encryptedKey, algorithm, ciphertext, err := crypto.EncryptData(context.Background(), att, "TenantKey", payload)
Expand Down Expand Up @@ -748,3 +751,19 @@ func (w walletServiceMock) FinishValidateSession(ctx context.Context, sessionId
}

var _ proto_wallet.WaaS = (*walletServiceMock)(nil)

type httpClient struct{}

func (httpClient) Do(req *http.Request) (*http.Response, error) {
req.URL.Scheme = "http"
return http.DefaultClient.Do(req)
}

func (httpClient) Get(s string) (*http.Response, error) {
if strings.HasPrefix(s, "https://") {
s = "http://" + strings.TrimPrefix(s, "https://")
}
return http.DefaultClient.Get(s)
}

var _ rpc.HTTPClient = (*httpClient)(nil)
7 changes: 4 additions & 3 deletions rpc/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,10 @@ func verifyIdentity(ctx context.Context, client HTTPClient, idToken string, sess
return proto.Identity{}, fmt.Errorf("parse JWT: %w", err)
}

idp := getOIDCProvider(ctx, normalizeIssuer(tok.Issuer()))
issuer := normalizeIssuer(tok.Issuer())
idp := getOIDCProvider(ctx, issuer)
if idp == nil {
return proto.Identity{}, fmt.Errorf("issuer %q not valid for this tenant", tok.Issuer())
return proto.Identity{}, fmt.Errorf("issuer %q not valid for this tenant", issuer)
}

keySet, err := getProviderKeySet(ctx, client, normalizeIssuer(idp.Issuer))
Expand All @@ -94,7 +95,7 @@ func verifyIdentity(ctx context.Context, client HTTPClient, idToken string, sess

identity := proto.Identity{
Type: proto.IdentityType_OIDC,
Issuer: tok.Issuer(),
Issuer: issuer,
Subject: tok.Subject(),
Email: getEmailFromToken(tok),
}
Expand Down
27 changes: 20 additions & 7 deletions rpc/sessions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
mathrand "math/rand"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -41,7 +42,7 @@ func TestRPC_RegisterSession(t *testing.T) {
}
testCases := map[string]struct {
assertFn func(t *testing.T, sess *proto.Session, err error, p assertionParams)
tokBuilderFn func(b *jwt.Builder)
tokBuilderFn func(b *jwt.Builder, url string)
intentBuilderFn func(t *testing.T, data intents.IntentDataOpenSession) *proto.Intent
}{
"Basic": {
Expand All @@ -60,14 +61,14 @@ func TestRPC_RegisterSession(t *testing.T) {
},
},
"WithInvalidIssuer": {
tokBuilderFn: func(b *jwt.Builder) { b.Issuer("https://id.example.com") },
tokBuilderFn: func(b *jwt.Builder, url string) { b.Issuer("https://id.example.com") },
assertFn: func(t *testing.T, sess *proto.Session, err error, p assertionParams) {
require.Nil(t, sess)
require.ErrorContains(t, err, `issuer "https://id.example.com" not valid for this tenant`)
},
},
"WithValidNonce": {
tokBuilderFn: func(b *jwt.Builder) { b.Claim("nonce", sessHash) },
tokBuilderFn: func(b *jwt.Builder, url string) { b.Claim("nonce", sessHash) },
assertFn: func(t *testing.T, sess *proto.Session, err error, p assertionParams) {
require.NoError(t, err)
require.NotNil(t, sess)
Expand All @@ -76,14 +77,14 @@ func TestRPC_RegisterSession(t *testing.T) {
},
},
"WithInvalidNonce": {
tokBuilderFn: func(b *jwt.Builder) { b.Claim("nonce", "0x1234567890abcdef") },
tokBuilderFn: func(b *jwt.Builder, url string) { b.Claim("nonce", "0x1234567890abcdef") },
assertFn: func(t *testing.T, sess *proto.Session, err error, p assertionParams) {
require.Nil(t, sess)
require.ErrorContains(t, err, "JWT validation: nonce not satisfied")
},
},
"WithInvalidNonceButValidSessionAddressClaim": {
tokBuilderFn: func(b *jwt.Builder) {
tokBuilderFn: func(b *jwt.Builder, url string) {
b.Claim("nonce", "0x1234567890abcdef").
Claim("sequence:session_hash", sessHash)
},
Expand All @@ -95,7 +96,7 @@ func TestRPC_RegisterSession(t *testing.T) {
},
},
"WithVerifiedEmail": {
tokBuilderFn: func(b *jwt.Builder) {
tokBuilderFn: func(b *jwt.Builder, url string) {
b.Claim("email", "user@example.com").Claim("email_verified", "true")
},
assertFn: func(t *testing.T, sess *proto.Session, err error, p assertionParams) {
Expand All @@ -106,7 +107,7 @@ func TestRPC_RegisterSession(t *testing.T) {
},
},
"WithUnverifiedEmail": {
tokBuilderFn: func(b *jwt.Builder) {
tokBuilderFn: func(b *jwt.Builder, url string) {
b.Claim("email", "user@example.com").Claim("email_verified", "false")
},
assertFn: func(t *testing.T, sess *proto.Session, err error, p assertionParams) {
Expand All @@ -131,6 +132,18 @@ func TestRPC_RegisterSession(t *testing.T) {
assert.ErrorContains(t, err, "intent is invalid: no signatures")
},
},
"IssuerMissingScheme": {
tokBuilderFn: func(b *jwt.Builder, url string) {
b.Issuer(strings.TrimPrefix(url, "http://"))
},
assertFn: func(t *testing.T, sess *proto.Session, err error, p assertionParams) {
require.NoError(t, err)
require.NotNil(t, sess)

httpsIssuer := "https://" + strings.TrimPrefix(p.issuer, "http://")
assert.Equal(t, httpsIssuer, sess.Identity.Issuer)
},
},
}

for label, testCase := range testCases {
Expand Down

0 comments on commit 3d551be

Please sign in to comment.