Skip to content

Commit

Permalink
fix: ignore decrypt errors in WithDeclassifiedCredentials (#3731)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonas-jonas authored Feb 22, 2024
1 parent b7e5144 commit 8f5192f
Show file tree
Hide file tree
Showing 8 changed files with 131 additions and 41 deletions.
2 changes: 1 addition & 1 deletion cipher/chacha20.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func (c *XChaCha20Poly1305) Decrypt(ctx context.Context, ciphertext string) ([]b
for i := range secrets {
aead, err := chacha20poly1305.NewX(secrets[i][:])
if err != nil {
return nil, errors.WithStack(herodot.ErrInternalServerError.WithWrap(err).WithReason("Unable to instanciate chacha20"))
return nil, errors.WithStack(herodot.ErrInternalServerError.WithWrap(err).WithReason("Unable to instantiate chacha20"))
}

if len(ciphertext) < aead.NonceSize() {
Expand Down
2 changes: 1 addition & 1 deletion driver/registry_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ func (m *RegistryDefault) Cipher(ctx context.Context) cipher.Cipher {
m.crypter = cipher.NewCryptAES(m)
default:
m.crypter = cipher.NewNoop(m)
m.l.Logger.Warning("No encryption configuration found. Default algorithm (noop) will be use that mean sensitive data will be recorded in plaintext")
m.l.Logger.Warning("No encryption configuration found. The default algorithm (noop) will be used, resulting in sensitive data being stored in plaintext")
}
}
return m.crypter
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
{
"type": "oidc",
"identifiers": [
"bar",
"baz"
],
"config": {
"providers": [
{
"initial_id_token": "foo",
"initial_access_token": "",
"initial_refresh_token": "",
"subject": "",
"provider": "",
"organization": ""
}
]
},
"version": 0,
"created_at": "0001-01-01T00:00:00Z",
"updated_at": "0001-01-01T00:00:00Z"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"type": "password",
"identifiers": [
"zab",
"bar"
],
"version": 0,
"created_at": "0001-01-01T00:00:00Z",
"updated_at": "0001-01-01T00:00:00Z"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"type": "webauthn",
"identifiers": [
"foo",
"bar"
],
"version": 0,
"created_at": "0001-01-01T00:00:00Z",
"updated_at": "0001-01-01T00:00:00Z"
}
48 changes: 39 additions & 9 deletions identity/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,17 +364,19 @@ func TestHandler(t *testing.T) {
identities := res.Array()
require.Equal(t, len(identities), listAmount)
})

})

t.Run("suite=create and update", func(t *testing.T) {
var i identity.Identity
createOidcIdentity := func(t *testing.T, identifier, accessToken, refreshToken, idToken string, encrypt bool) string {
transform := func(token string) string {
transform := func(token, suffix string) string {
if !encrypt {
return token
}
c, err := reg.Cipher(ctx).Encrypt(context.Background(), []byte(token))
if token == "" {
return ""
}
c, err := reg.Cipher(ctx).Encrypt(context.Background(), []byte(token+suffix))
require.NoError(t, err)
return c
}
Expand All @@ -396,16 +398,16 @@ func TestHandler(t *testing.T) {
{
Subject: "foo",
Provider: "bar",
InitialAccessToken: transform(accessToken + "0"),
InitialRefreshToken: transform(refreshToken + "0"),
InitialIDToken: transform(idToken + "0"),
InitialAccessToken: transform(accessToken, "0"),
InitialRefreshToken: transform(refreshToken, "0"),
InitialIDToken: transform(idToken, "0"),
},
{
Subject: "baz",
Provider: "zab",
InitialAccessToken: transform(accessToken + "1"),
InitialRefreshToken: transform(refreshToken + "1"),
InitialIDToken: transform(idToken + "1"),
InitialAccessToken: transform(accessToken, "1"),
InitialRefreshToken: transform(refreshToken, "1"),
InitialIDToken: transform(idToken, "1"),
},
}}),
},
Expand Down Expand Up @@ -537,6 +539,34 @@ func TestHandler(t *testing.T) {
}
})

t.Run("case=should not fail on empty tokens", func(t *testing.T) {
id := createOidcIdentity(t, "foo.oidc.empty-tokens@bar.com", "", "", "", true)
for name, ts := range map[string]*httptest.Server{"public": publicTS, "admin": adminTS} {
t.Run("endpoint="+name, func(t *testing.T) {
res := get(t, ts, "/identities/"+id, http.StatusOK)
assert.False(t, res.Get("credentials.oidc.config").Exists(), "credentials config should be omitted: %s", res.Raw)
assert.False(t, res.Get("credentials.password.config").Exists(), "credentials config should be omitted: %s", res.Raw)

res = get(t, ts, "/identities/"+id+"?include_credential=oidc", http.StatusOK)
assert.True(t, res.Get("credentials").Exists(), "credentials should be included: %s", res.Raw)
assert.True(t, res.Get("credentials.password").Exists(), "password meta should be included: %s", res.Raw)
assert.False(t, res.Get("credentials.password.false").Exists(), "password credentials should not be included: %s", res.Raw)
assert.True(t, res.Get("credentials.oidc.config").Exists(), "oidc credentials should be included: %s", res.Raw)

assert.EqualValues(t, "foo", res.Get("credentials.oidc.config.providers.0.subject").String(), "credentials should be included: %s", res.Raw)
assert.EqualValues(t, "bar", res.Get("credentials.oidc.config.providers.0.provider").String(), "credentials should be included: %s", res.Raw)
assert.EqualValues(t, "access_token0", res.Get("credentials.oidc.config.providers.0.initial_access_token").String(), "credentials should be included: %s", res.Raw)
assert.EqualValues(t, "refresh_token0", res.Get("credentials.oidc.config.providers.0.initial_refresh_token").String(), "credentials should be included: %s", res.Raw)
assert.EqualValues(t, "id_token0", res.Get("credentials.oidc.config.providers.0.initial_id_token").String(), "credentials should be included: %s", res.Raw)
assert.EqualValues(t, "baz", res.Get("credentials.oidc.config.providers.1.subject").String(), "credentials should be included: %s", res.Raw)
assert.EqualValues(t, "zab", res.Get("credentials.oidc.config.providers.1.provider").String(), "credentials should be included: %s", res.Raw)
assert.EqualValues(t, "access_token1", res.Get("credentials.oidc.config.providers.1.initial_access_token").String(), "credentials should be included: %s", res.Raw)
assert.EqualValues(t, "refresh_token1", res.Get("credentials.oidc.config.providers.1.initial_refresh_token").String(), "credentials should be included: %s", res.Raw)
assert.EqualValues(t, "id_token1", res.Get("credentials.oidc.config.providers.1.initial_id_token").String(), "credentials should be included: %s", res.Raw)
})
}
})

t.Run("case=should get identity with credentials", func(t *testing.T) {
i := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID)
credentials := map[identity.CredentialsType]identity.Credentials{
Expand Down
51 changes: 25 additions & 26 deletions identity/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,49 +441,48 @@ func (i *Identity) WithDeclassifiedCredentials(ctx context.Context, c cipher.Pro
toPublish := original
toPublish.Config = []byte{}

for _, token := range []string{"initial_id_token", "initial_access_token", "initial_refresh_token"} {
var i int
var err error
gjson.GetBytes(original.Config, "providers").ForEach(func(_, v gjson.Result) bool {
var i int
var err error
gjson.GetBytes(original.Config, "providers").ForEach(func(_, v gjson.Result) bool {
for _, token := range []string{"initial_id_token", "initial_access_token", "initial_refresh_token"} {
key := fmt.Sprintf("%d.%s", i, token)
ciphertext := v.Get(token).String()

var plaintext []byte
plaintext, err = c.Cipher(ctx).Decrypt(ctx, ciphertext)
plaintext, err := c.Cipher(ctx).Decrypt(ctx, ciphertext)
if err != nil {
return false
plaintext = []byte("")
}

toPublish.Config, err = sjson.SetBytes(toPublish.Config, "providers."+key, string(plaintext))
if err != nil {
return false
}
}

toPublish.Config, err = sjson.SetBytes(toPublish.Config, fmt.Sprintf("providers.%d.subject", i), v.Get("subject").String())
if err != nil {
return false
}

toPublish.Config, err = sjson.SetBytes(toPublish.Config, fmt.Sprintf("providers.%d.provider", i), v.Get("provider").String())
if err != nil {
return false
}

toPublish.Config, err = sjson.SetBytes(toPublish.Config, fmt.Sprintf("providers.%d.organization", i), v.Get("organization").String())
if err != nil {
return false
}
toPublish.Config, err = sjson.SetBytes(toPublish.Config, fmt.Sprintf("providers.%d.subject", i), v.Get("subject").String())
if err != nil {
return false
}

i++
return true
})
toPublish.Config, err = sjson.SetBytes(toPublish.Config, fmt.Sprintf("providers.%d.provider", i), v.Get("provider").String())
if err != nil {
return false
}

toPublish.Config, err = sjson.SetBytes(toPublish.Config, fmt.Sprintf("providers.%d.organization", i), v.Get("organization").String())
if err != nil {
return nil, err
return false
}

credsToPublish[ct] = toPublish
i++
return true
})

if err != nil {
return nil, err
}

credsToPublish[ct] = toPublish
default:
credsToPublish[ct] = original
}
Expand Down
27 changes: 23 additions & 4 deletions identity/identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ package identity

import (
"bytes"
"context"
"encoding/json"
"fmt"
"testing"

"github.com/ory/x/snapshotx"

"github.com/ory/kratos/cipher"
"github.com/ory/kratos/x"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -314,6 +316,12 @@ func TestVerifiableAddresses(t *testing.T) {
assert.Equal(t, addresses, CollectVerifiableAddresses([]*Identity{id1, id2, id3}))
}

type cipherProvider struct{}

func (c *cipherProvider) Cipher(ctx context.Context) cipher.Cipher {
return cipher.NewNoop(nil)
}

func TestWithDeclassifiedCredentials(t *testing.T) {
i := NewIdentity(config.DefaultIdentityTraitsSchemaID)
credentials := map[CredentialsType]Credentials{
Expand All @@ -325,7 +333,7 @@ func TestWithDeclassifiedCredentials(t *testing.T) {
CredentialsTypeOIDC: {
Type: CredentialsTypeOIDC,
Identifiers: []string{"bar", "baz"},
Config: sqlxx.JSONRawMessage("{\"some\" : \"secret\"}"),
Config: sqlxx.JSONRawMessage(`{"providers": [{"initial_id_token": "666f6f"}]}`),
},
CredentialsTypeWebAuthn: {
Type: CredentialsTypeWebAuthn,
Expand All @@ -336,7 +344,7 @@ func TestWithDeclassifiedCredentials(t *testing.T) {
i.Credentials = credentials

t.Run("case=no-include", func(t *testing.T) {
actualIdentity, err := i.WithDeclassifiedCredentials(ctx, nil, nil)
actualIdentity, err := i.WithDeclassifiedCredentials(ctx, &cipherProvider{}, nil)
require.NoError(t, err)

for ct, actual := range actualIdentity.Credentials {
Expand All @@ -347,7 +355,7 @@ func TestWithDeclassifiedCredentials(t *testing.T) {
})

t.Run("case=include-webauthn", func(t *testing.T) {
actualIdentity, err := i.WithDeclassifiedCredentials(ctx, nil, []CredentialsType{CredentialsTypeWebAuthn})
actualIdentity, err := i.WithDeclassifiedCredentials(ctx, &cipherProvider{}, []CredentialsType{CredentialsTypeWebAuthn})
require.NoError(t, err)

for ct, actual := range actualIdentity.Credentials {
Expand All @@ -358,7 +366,18 @@ func TestWithDeclassifiedCredentials(t *testing.T) {
})

t.Run("case=include-multi", func(t *testing.T) {
actualIdentity, err := i.WithDeclassifiedCredentials(ctx, nil, []CredentialsType{CredentialsTypeWebAuthn, CredentialsTypePassword})
actualIdentity, err := i.WithDeclassifiedCredentials(ctx, &cipherProvider{}, []CredentialsType{CredentialsTypeWebAuthn, CredentialsTypePassword})
require.NoError(t, err)

for ct, actual := range actualIdentity.Credentials {
t.Run("credential="+string(ct), func(t *testing.T) {
snapshotx.SnapshotT(t, actual)
})
}
})

t.Run("case=oidc", func(t *testing.T) {
actualIdentity, err := i.WithDeclassifiedCredentials(ctx, &cipherProvider{}, []CredentialsType{CredentialsTypeOIDC})
require.NoError(t, err)

for ct, actual := range actualIdentity.Credentials {
Expand Down

0 comments on commit 8f5192f

Please sign in to comment.