From 3bb943a9ac2d4309b43d1cb9bf27bac7cabb86f9 Mon Sep 17 00:00:00 2001
From: aeneasr <3372410+aeneasr@users.noreply.github.com>
Date: Thu, 25 Aug 2022 12:48:59 +0200
Subject: [PATCH] fix: session unmarshalling

---
 ...json => TestUnmarshalSession-v1.11.8.json} |  0
 .../TestUnmarshalSession-v1.11.9.json         | 49 +++++++++++++++++++
 oauth2/session.go                             | 24 +++++----
 oauth2/session_test.go                        | 19 ++++---
 4 files changed, 72 insertions(+), 20 deletions(-)
 rename oauth2/.snapshots/{TestUnmarshalSession.json => TestUnmarshalSession-v1.11.8.json} (100%)
 create mode 100644 oauth2/.snapshots/TestUnmarshalSession-v1.11.9.json

diff --git a/oauth2/.snapshots/TestUnmarshalSession.json b/oauth2/.snapshots/TestUnmarshalSession-v1.11.8.json
similarity index 100%
rename from oauth2/.snapshots/TestUnmarshalSession.json
rename to oauth2/.snapshots/TestUnmarshalSession-v1.11.8.json
diff --git a/oauth2/.snapshots/TestUnmarshalSession-v1.11.9.json b/oauth2/.snapshots/TestUnmarshalSession-v1.11.9.json
new file mode 100644
index 00000000000..723df624f4a
--- /dev/null
+++ b/oauth2/.snapshots/TestUnmarshalSession-v1.11.9.json
@@ -0,0 +1,49 @@
+{
+  "id_token": {
+    "id_token_claims": {
+      "jti": "",
+      "iss": "http://127.0.0.1:4444/",
+      "sub": "foo@bar.com",
+      "aud": [
+        "auth-code-client"
+      ],
+      "nonce": "mbxojlzlkefzmlecvrzfkmpm",
+      "exp": "0001-01-01T00:00:00Z",
+      "iat": "2022-08-25T09:21:04Z",
+      "rat": "2022-08-25T09:20:54Z",
+      "auth_time": "2022-08-25T09:21:01Z",
+      "at_hash": "",
+      "acr": "0",
+      "amr": [],
+      "c_hash": "",
+      "ext": {
+        "sid": "177e1f44-a1e9-415c-bfa3-8b62280b182d"
+      }
+    },
+    "headers": {
+      "extra": {
+        "kid": "public:hydra.openid.id-token"
+      }
+    },
+    "expires_at": {
+      "access_token": "2022-08-25T09:26:05Z",
+      "authorize_code": "2022-08-25T09:23:04.432089764Z",
+      "refresh_token": "2022-08-26T09:21:05Z"
+    },
+    "username": "",
+    "subject": "foo@bar.com"
+  },
+  "extra": {},
+  "kid": "public:hydra.jwt.access-token",
+  "client_id": "auth-code-client",
+  "consent_challenge": "2261efbd447044a1b2f76b05c6aca164",
+  "exclude_not_before_claim": false,
+  "allowed_top_level_claims": [
+    "persona_id",
+    "persona_krn",
+    "grantType",
+    "market",
+    "zone",
+    "login_session_id"
+  ]
+}
diff --git a/oauth2/session.go b/oauth2/session.go
index 9647a4af60e..f35e3085fad 100644
--- a/oauth2/session.go
+++ b/oauth2/session.go
@@ -157,40 +157,38 @@ var keyRewrites = map[string]string{
 	"idToken.Claims.Extra":                               "id_token.id_token_claims.ext",
 }
 
-func (s *Session) UnmarshalJSON(in []byte) (err error) {
-	type t Session
-	interpret := in
-	parsed := gjson.ParseBytes(in)
+func (s *Session) UnmarshalJSON(original []byte) (err error) {
+	transformed := original
+	originalParsed := gjson.ParseBytes(original)
 
-	for orig, update := range keyRewrites {
-		if !parsed.Get(orig).Exists() {
+	for oldKey, newKey := range keyRewrites {
+		if !originalParsed.Get(oldKey).Exists() {
 			continue
 		}
-		interpret, err = sjson.SetRawBytes(interpret, update, []byte(parsed.Get(orig).Raw))
+		transformed, err = sjson.SetRawBytes(transformed, newKey, []byte(originalParsed.Get(oldKey).Raw))
 		if err != nil {
 			return errors.WithStack(err)
 		}
 	}
 
 	for orig := range keyRewrites {
-		interpret, err = sjson.DeleteBytes(interpret, orig)
+		transformed, err = sjson.DeleteBytes(transformed, orig)
 		if err != nil {
 			return errors.WithStack(err)
 		}
 	}
 
-	if parsed.Get("idToken").Exists() {
-		interpret, err = sjson.DeleteBytes(interpret, "idToken")
+	if originalParsed.Get("idToken").Exists() {
+		transformed, err = sjson.DeleteBytes(transformed, "idToken")
 		if err != nil {
 			return errors.WithStack(err)
 		}
 	}
 
-	var tt t
-	if err := json.Unmarshal(interpret, &tt); err != nil {
+	type t Session
+	if err := json.Unmarshal(transformed, (*t)(s)); err != nil {
 		return errors.WithStack(err)
 	}
 
-	*s = Session(tt)
 	return nil
 }
diff --git a/oauth2/session_test.go b/oauth2/session_test.go
index 3422645700c..18177de2614 100644
--- a/oauth2/session_test.go
+++ b/oauth2/session_test.go
@@ -75,12 +75,17 @@ func TestUnmarshalSession(t *testing.T) {
 		},
 	}
 
-	var actual Session
-	require.NoError(t, json.Unmarshal(v1118Session, &actual))
-	assertx.EqualAsJSON(t, expect, &actual)
-	snapshotx.SnapshotTExcept(t, &actual, nil)
+	t.Run("v1.11.8", func(t *testing.T) {
+		var actual Session
+		require.NoError(t, json.Unmarshal(v1118Session, &actual))
+		assertx.EqualAsJSON(t, expect, &actual)
+		snapshotx.SnapshotTExcept(t, &actual, nil)
+	})
 
-	require.NoError(t, json.Unmarshal(v1119Session, &actual))
-	assertx.EqualAsJSON(t, expect, &actual)
-	snapshotx.SnapshotTExcept(t, &actual, nil)
+	t.Run("v1.11.9", func(t *testing.T) {
+		var actual Session
+		require.NoError(t, json.Unmarshal(v1119Session, &actual))
+		assertx.EqualAsJSON(t, expect, &actual)
+		snapshotx.SnapshotTExcept(t, &actual, nil)
+	})
 }