From 3bb943a9ac2d4309b43d1cb9bf27bac7cabb86f9 Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Thu, 25 Aug 2022 12:48:59 +0200 Subject: [PATCH] fix: session unmarshalling --- ...json => TestUnmarshalSession-v1.11.8.json} | 0 .../TestUnmarshalSession-v1.11.9.json | 49 +++++++++++++++++++ oauth2/session.go | 24 +++++---- oauth2/session_test.go | 19 ++++--- 4 files changed, 72 insertions(+), 20 deletions(-) rename oauth2/.snapshots/{TestUnmarshalSession.json => TestUnmarshalSession-v1.11.8.json} (100%) create mode 100644 oauth2/.snapshots/TestUnmarshalSession-v1.11.9.json diff --git a/oauth2/.snapshots/TestUnmarshalSession.json b/oauth2/.snapshots/TestUnmarshalSession-v1.11.8.json similarity index 100% rename from oauth2/.snapshots/TestUnmarshalSession.json rename to oauth2/.snapshots/TestUnmarshalSession-v1.11.8.json diff --git a/oauth2/.snapshots/TestUnmarshalSession-v1.11.9.json b/oauth2/.snapshots/TestUnmarshalSession-v1.11.9.json new file mode 100644 index 00000000000..723df624f4a --- /dev/null +++ b/oauth2/.snapshots/TestUnmarshalSession-v1.11.9.json @@ -0,0 +1,49 @@ +{ + "id_token": { + "id_token_claims": { + "jti": "", + "iss": "http://127.0.0.1:4444/", + "sub": "foo@bar.com", + "aud": [ + "auth-code-client" + ], + "nonce": "mbxojlzlkefzmlecvrzfkmpm", + "exp": "0001-01-01T00:00:00Z", + "iat": "2022-08-25T09:21:04Z", + "rat": "2022-08-25T09:20:54Z", + "auth_time": "2022-08-25T09:21:01Z", + "at_hash": "", + "acr": "0", + "amr": [], + "c_hash": "", + "ext": { + "sid": "177e1f44-a1e9-415c-bfa3-8b62280b182d" + } + }, + "headers": { + "extra": { + "kid": "public:hydra.openid.id-token" + } + }, + "expires_at": { + "access_token": "2022-08-25T09:26:05Z", + "authorize_code": "2022-08-25T09:23:04.432089764Z", + "refresh_token": "2022-08-26T09:21:05Z" + }, + "username": "", + "subject": "foo@bar.com" + }, + "extra": {}, + "kid": "public:hydra.jwt.access-token", + "client_id": "auth-code-client", + "consent_challenge": "2261efbd447044a1b2f76b05c6aca164", + "exclude_not_before_claim": false, + "allowed_top_level_claims": [ + "persona_id", + "persona_krn", + "grantType", + "market", + "zone", + "login_session_id" + ] +} diff --git a/oauth2/session.go b/oauth2/session.go index 9647a4af60e..f35e3085fad 100644 --- a/oauth2/session.go +++ b/oauth2/session.go @@ -157,40 +157,38 @@ var keyRewrites = map[string]string{ "idToken.Claims.Extra": "id_token.id_token_claims.ext", } -func (s *Session) UnmarshalJSON(in []byte) (err error) { - type t Session - interpret := in - parsed := gjson.ParseBytes(in) +func (s *Session) UnmarshalJSON(original []byte) (err error) { + transformed := original + originalParsed := gjson.ParseBytes(original) - for orig, update := range keyRewrites { - if !parsed.Get(orig).Exists() { + for oldKey, newKey := range keyRewrites { + if !originalParsed.Get(oldKey).Exists() { continue } - interpret, err = sjson.SetRawBytes(interpret, update, []byte(parsed.Get(orig).Raw)) + transformed, err = sjson.SetRawBytes(transformed, newKey, []byte(originalParsed.Get(oldKey).Raw)) if err != nil { return errors.WithStack(err) } } for orig := range keyRewrites { - interpret, err = sjson.DeleteBytes(interpret, orig) + transformed, err = sjson.DeleteBytes(transformed, orig) if err != nil { return errors.WithStack(err) } } - if parsed.Get("idToken").Exists() { - interpret, err = sjson.DeleteBytes(interpret, "idToken") + if originalParsed.Get("idToken").Exists() { + transformed, err = sjson.DeleteBytes(transformed, "idToken") if err != nil { return errors.WithStack(err) } } - var tt t - if err := json.Unmarshal(interpret, &tt); err != nil { + type t Session + if err := json.Unmarshal(transformed, (*t)(s)); err != nil { return errors.WithStack(err) } - *s = Session(tt) return nil } diff --git a/oauth2/session_test.go b/oauth2/session_test.go index 3422645700c..18177de2614 100644 --- a/oauth2/session_test.go +++ b/oauth2/session_test.go @@ -75,12 +75,17 @@ func TestUnmarshalSession(t *testing.T) { }, } - var actual Session - require.NoError(t, json.Unmarshal(v1118Session, &actual)) - assertx.EqualAsJSON(t, expect, &actual) - snapshotx.SnapshotTExcept(t, &actual, nil) + t.Run("v1.11.8", func(t *testing.T) { + var actual Session + require.NoError(t, json.Unmarshal(v1118Session, &actual)) + assertx.EqualAsJSON(t, expect, &actual) + snapshotx.SnapshotTExcept(t, &actual, nil) + }) - require.NoError(t, json.Unmarshal(v1119Session, &actual)) - assertx.EqualAsJSON(t, expect, &actual) - snapshotx.SnapshotTExcept(t, &actual, nil) + t.Run("v1.11.9", func(t *testing.T) { + var actual Session + require.NoError(t, json.Unmarshal(v1119Session, &actual)) + assertx.EqualAsJSON(t, expect, &actual) + snapshotx.SnapshotTExcept(t, &actual, nil) + }) }