diff --git a/docs/oidc-client-ts.api.md b/docs/oidc-client-ts.api.md index 8235fbec6..d673d283b 100644 --- a/docs/oidc-client-ts.api.md +++ b/docs/oidc-client-ts.api.md @@ -287,6 +287,10 @@ export interface OidcAddressClaim { // @public export class OidcClient { constructor(settings: OidcClientSettings); + // Warning: (ae-forgotten-export) The symbol "ClaimsService" needs to be exported by the entry point index.d.ts + // + // (undocumented) + protected readonly _claimsService: ClaimsService; // (undocumented) clearStaleState(): Promise; // (undocumented) @@ -294,6 +298,8 @@ export class OidcClient { // (undocumented) createSignoutRequest({ state, id_token_hint, request_type, post_logout_redirect_uri, extraQueryParams, }?: CreateSignoutRequestArgs): Promise; // (undocumented) + getUserInfo(token: string, profile: IdTokenClaims): Promise; + // (undocumented) protected readonly _logger: Logger; // (undocumented) readonly metadataService: MetadataService; @@ -323,6 +329,10 @@ export class OidcClient { protected readonly _tokenClient: TokenClient; // (undocumented) useRefreshToken({ state, timeoutInSeconds, }: UseRefreshTokenArgs): Promise; + // Warning: (ae-forgotten-export) The symbol "UserInfoService" needs to be exported by the entry point index.d.ts + // + // (undocumented) + protected readonly _userInfoService: UserInfoService; // Warning: (ae-forgotten-export) The symbol "ResponseValidator" needs to be exported by the entry point index.d.ts // // (undocumented) @@ -345,6 +355,7 @@ export interface OidcClientSettings { extraTokenParams?: Record; fetchRequestCredentials?: RequestCredentials; filterProtocolClaims?: boolean | string[]; + legacyMergeClaimsBehavior?: boolean; loadUserInfo?: boolean; max_age?: number; mergeClaims?: boolean; @@ -373,7 +384,7 @@ export interface OidcClientSettings { // @public export class OidcClientSettingsStore { - constructor({ authority, metadataUrl, metadata, signingKeys, metadataSeed, client_id, client_secret, response_type, scope, redirect_uri, post_logout_redirect_uri, client_authentication, prompt, display, max_age, ui_locales, acr_values, resource, response_mode, filterProtocolClaims, loadUserInfo, staleStateAgeInSeconds, clockSkewInSeconds, userInfoJwtIssuer, mergeClaims, stateStore, refreshTokenCredentials, revokeTokenAdditionalContentTypes, fetchRequestCredentials, refreshTokenAllowedScope, extraQueryParams, extraTokenParams, }: OidcClientSettings); + constructor({ authority, metadataUrl, metadata, signingKeys, metadataSeed, client_id, client_secret, response_type, scope, redirect_uri, post_logout_redirect_uri, client_authentication, prompt, display, max_age, ui_locales, acr_values, resource, response_mode, filterProtocolClaims, loadUserInfo, staleStateAgeInSeconds, clockSkewInSeconds, userInfoJwtIssuer, mergeClaims, legacyMergeClaimsBehavior, stateStore, refreshTokenCredentials, revokeTokenAdditionalContentTypes, fetchRequestCredentials, refreshTokenAllowedScope, extraQueryParams, extraTokenParams, }: OidcClientSettings); // (undocumented) readonly acr_values: string | undefined; // (undocumented) @@ -396,6 +407,7 @@ export class OidcClientSettingsStore { readonly fetchRequestCredentials: RequestCredentials; // (undocumented) readonly filterProtocolClaims: boolean | string[]; + readonly legacyMergeClaimsBehavior: boolean; // (undocumented) readonly loadUserInfo: boolean; // (undocumented) @@ -903,7 +915,7 @@ export class UserManager { get events(): UserManagerEvents; // (undocumented) protected readonly _events: UserManagerEvents; - getUser(): Promise; + getUser(refreshUserInfo?: boolean): Promise; // Warning: (ae-forgotten-export) The symbol "IFrameNavigator" needs to be exported by the entry point index.d.ts // // (undocumented) diff --git a/src/ClaimsService.test.ts b/src/ClaimsService.test.ts new file mode 100644 index 000000000..29a10278d --- /dev/null +++ b/src/ClaimsService.test.ts @@ -0,0 +1,337 @@ +// Copyright (c) Brock Allen & Dominick Baier. All rights reserved. +// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information. + +import { ClaimsService } from "./ClaimsService"; +import type { OidcClientSettingsStore } from "./OidcClientSettings"; +import type { UserProfile } from "./User"; + +describe("ClaimsService", () => { + let settings: OidcClientSettingsStore; + let subject: ClaimsService; + + beforeEach(() => { + settings = { + authority: "op", + client_id: "client", + loadUserInfo: true, + } as OidcClientSettingsStore; + + subject = new ClaimsService(settings); + }); + + afterEach(() => { + jest.restoreAllMocks(); + }); + + describe("filterProtocolClaims", () => { + it("should filter protocol claims if enabled on settings", () => { + // arrange + Object.assign(settings, { filterProtocolClaims: true }); + const claims = { + foo: 1, bar: "test", + aud: "some_aud", iss: "issuer", + sub: "123", email: "foo@gmail.com", + role: ["admin", "dev"], + iat: 5, exp: 20, + nbf: 10, at_hash: "athash", + }; + + // act + const result = subject["filterProtocolClaims"](claims); + + // assert + expect(result).toEqual({ + foo: 1, bar: "test", + aud: "some_aud", iss: "issuer", + sub: "123", email: "foo@gmail.com", + role: ["admin", "dev"], + iat: 5, exp: 20, + }); + }); + + it("should not filter protocol claims if not enabled on settings", () => { + // arrange + Object.assign(settings, { filterProtocolClaims: false }); + const claims = { + foo: 1, bar: "test", + aud: "some_aud", iss: "issuer", + sub: "123", email: "foo@gmail.com", + role: ["admin", "dev"], + at_hash: "athash", + iat: 5, nbf: 10, exp: 20, + }; + + // act + const result = subject["filterProtocolClaims"](claims); + + // assert + expect(result).toEqual({ + foo: 1, bar: "test", + aud: "some_aud", iss: "issuer", + sub: "123", email: "foo@gmail.com", + role: ["admin", "dev"], + at_hash: "athash", + iat: 5, nbf: 10, exp: 20, + }); + }); + + it("should filter protocol claims if specified in settings", () => { + // arrange + Object.assign(settings, { filterProtocolClaims: ["foo", "bar", "role", "nbf", "email"] }); + const claims = { + foo: 1, bar: "test", + aud: "some_aud", iss: "issuer", + sub: "123", email: "foo@gmail.com", + role: ["admin", "dev"], + iat: 5, exp: 20, + nbf: 10, at_hash: "athash", + }; + + // act + const result = subject["filterProtocolClaims"](claims); + + // assert + expect(result).toEqual({ + aud: "some_aud", iss: "issuer", + sub: "123", + iat: 5, exp: 20, + at_hash: "athash", + }); + }); + + it("should filter only protocol claims defined by default by the library", () => { + // arrange + Object.assign(settings, { filterProtocolClaims: true }); + const defaultProtocolClaims = { + nbf: 3, jti: "jti", + auth_time: 123, + nonce: "nonce", + acr: "acr", + amr: "amr", + azp: "azp", + at_hash: "athash", + }; + const claims = { + foo: 1, bar: "test", + aud: "some_aud", iss: "issuer", + sub: "123", email: "foo@gmail.com", + role: ["admin", "dev"], + iat: 5, exp: 20, + }; + + // act + const result = subject["filterProtocolClaims"]({ ...defaultProtocolClaims, ...claims }); + + // assert + expect(result).toEqual(claims); + }); + + it("should not filter protocol claims that are required by the library", () => { + // arrange + Object.assign(settings, { filterProtocolClaims: true }); + const internalRequiredProtocolClaims = { + sub: "sub", + iss: "issuer", + aud: "some_aud", + exp: 20, + iat: 5, + }; + const claims = { + foo: 1, bar: "test", + email: "foo@gmail.com", + role: ["admin", "dev"], + nbf: 10, + }; + + // act + let items = { ...internalRequiredProtocolClaims, ...claims }; + let result = subject["filterProtocolClaims"](items); + + // assert + // nbf is part of the claims that should be filtered by the library by default, so we need to remove it + delete (items as Partial).nbf; + expect(result).toEqual(items); + + // ... even if specified in settings + + // arrange + Object.assign(settings, { filterProtocolClaims: ["sub", "iss", "aud", "exp", "iat"] }); + + // act + items = { ...internalRequiredProtocolClaims, ...claims }; + result = subject["filterProtocolClaims"](items); + + // assert + expect(result).toEqual(items); + }); + }); + + describe("mergeClaims", () => { + it("should merge claims", () => { + // arrange + const c1 = { a: "apple", b: "banana" } as unknown as UserProfile; + const c2 = { c: "carrot" }; + + // act + const result = subject["mergeClaims"](c1, c2); + + // assert + expect(result).toEqual({ a: "apple", c: "carrot", b: "banana" }); + }); + + it("should not merge (but append) claims when claim types are objects, if using legacy merge behavior", () => { + // arrange + Object.assign(settings, { legacyMergeClaimsBehavior: true }); + + const c1 = { custom: { "apple": "foo", "pear": "bar" } } as unknown as UserProfile; + const c2 = { custom: { "apple": "foo", "orange": "peel" }, b: "banana" }; + + // act + const result = subject["mergeClaims"](c1, c2); + + // assert + expect(result).toEqual({ custom: [{ "apple": "foo", "pear": "bar" }, { "apple": "foo", "orange": "peel" }], b: "banana" }); + }); + + it("should overwrite claims when claim types are objects", () => { + // arrange + const c1 = { custom: { "apple": "foo", "pear": "bar" } } as unknown as UserProfile; + const c2 = { custom: { "apple": "foo", "orange": "peel" }, b: "banana" }; + + // act + const result = subject["mergeClaims"](c1, c2); + + // assert + expect(result).toEqual({ custom: { "apple": "foo", "orange": "peel" }, b: "banana" }); + }); + + it("should merge claims when claim types are objects when mergeClaims settings is true, if using legacy merge behavior", () => { + // arrange + Object.assign(settings, { mergeClaims: true, legacyMergeClaimsBehavior: true }); + + const c1 = { custom: { "apple": "foo", "pear": "bar" } } as unknown as UserProfile; + const c2 = { custom: { "apple": "foo", "orange": "peel" }, b: "banana" }; + + // act + const result = subject["mergeClaims"](c1, c2); + + // assert + expect(result).toEqual({ custom: { "apple": "foo", "pear": "bar", "orange": "peel" }, b: "banana" }); + }); + + it("should merge claims when claim types are objects when mergeClaims settings is true", () => { + // arrange + Object.assign(settings, { mergeClaims: true }); + + const c1 = { custom: { "apple": "foo", "pear": "bar" } } as unknown as UserProfile; + const c2 = { custom: { "apple": "foo", "orange": "peel" }, b: "banana" }; + + // act + const result = subject["mergeClaims"](c1, c2); + + // assert + expect(result).toEqual({ custom: { "apple": "foo", "pear": "bar", "orange": "peel" }, b: "banana" }); + }); + + it("should merge same claim types into array, if using legacy merge behavior", () => { + // arrange + Object.assign(settings, { legacyMergeClaimsBehavior: true }); + + const c1 = { a: "apple", b: "banana" } as unknown as UserProfile; + const c2 = { a: "carrot" }; + + // act + const result = subject["mergeClaims"](c1, c2); + + // assert + expect(result).toEqual({ a: ["apple", "carrot"], b: "banana" }); + }); + + it("should overwrite same claim", () => { + // arrange + const c1 = { a: "apple", b: "banana" } as unknown as UserProfile; + const c2 = { a: "carrot" }; + + // act + const result = subject["mergeClaims"](c1, c2); + + // assert + expect(result).toEqual({ a: "carrot", b: "banana" }); + }); + + it("should merge arrays of same claim types into array, if using legacy merge behavior", () => { + // arrange + Object.assign(settings, { legacyMergeClaimsBehavior: true }); + + const c1 = { a: "apple", b: "banana" } as unknown as UserProfile; + const c2 = { a: ["carrot", "durian"] }; + + // act + let result = subject["mergeClaims"](c1, c2); + + // assert + expect(result).toEqual({ a: ["apple", "carrot", "durian"], b: "banana" }); + + // arrange + const d1 = { a: ["apple", "carrot"], b: "banana" } as unknown as UserProfile; + const d2 = { a: ["durian"] }; + + // act + result = subject["mergeClaims"](d1, d2); + + // assert + expect(result).toEqual({ a: ["apple", "carrot", "durian"], b: "banana" }); + + // arrange + const e1 = { a: ["apple", "carrot"], b: "banana" } as unknown as UserProfile; + const e2 = { a: "durian" }; + + // act + result = subject["mergeClaims"](e1, e2); + + // assert + expect(result).toEqual({ a: ["apple", "carrot", "durian"], b: "banana" }); + }); + + it("should remove duplicates when producing arrays, if using legacy merge behavior", () => { + // arrange + Object.assign(settings, { legacyMergeClaimsBehavior: true }); + + const c1 = { a: "apple", b: "banana" } as unknown as UserProfile; + const c2 = { a: ["apple", "durian"] }; + + // act + const result = subject["mergeClaims"](c1, c2); + + // assert + expect(result).toEqual({ a: ["apple", "durian"], b: "banana" }); + }); + + it("should not add if already present in array, if using legacy merge behavior", () => { + // arrange + Object.assign(settings, { legacyMergeClaimsBehavior: true }); + + const c1 = { a: ["apple", "durian"], b: "banana" } as unknown as UserProfile; + const c2 = { a: "apple" }; + + // act + const result = subject["mergeClaims"](c1, c2); + + // assert + expect(result).toEqual({ a: ["apple", "durian"], b: "banana" }); + }); + + it("should override array", () => { + // arrange + const c1 = { a: ["apple", "banana"], b: "banana" } as unknown as UserProfile; + const c2 = { a: ["orange", "durian"] }; + + // act + const result = subject["mergeClaims"](c1, c2); + + // assert + expect(result).toEqual({ a: ["orange", "durian"], b: "banana" }); + }); + + }); +}); diff --git a/src/ClaimsService.ts b/src/ClaimsService.ts new file mode 100644 index 000000000..998149684 --- /dev/null +++ b/src/ClaimsService.ts @@ -0,0 +1,119 @@ +// Copyright (c) Brock Allen & Dominick Baier. All rights reserved. +// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information. + +import type { JwtClaims } from "./Claims"; +import type { OidcClientSettingsStore } from "./OidcClientSettings"; +import type { UserProfile } from "./User"; +import { Logger } from "./utils"; + +/** + * Protocol claims that could be removed by default from profile. + * Derived from the following sets of claims: + * - {@link https://datatracker.ietf.org/doc/html/rfc7519.html#section-4.1} + * - {@link https://openid.net/specs/openid-connect-core-1_0.html#IDToken} + * - {@link https://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken} + * + * @internal + */ +const DefaultProtocolClaims = [ + "nbf", + "jti", + "auth_time", + "nonce", + "acr", + "amr", + "azp", + "at_hash", // https://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken +] as const; + +/** + * Protocol claims that should never be removed from profile. + * "sub" is needed internally and others should remain required as per the OIDC specs. + * + * @internal + */ +const InternalRequiredProtocolClaims = ["sub", "iss", "aud", "exp", "iat"]; + +/** + * @internal + */ +export class ClaimsService { + protected readonly _logger = new Logger("ClaimsService"); + public constructor( + protected readonly _settings: OidcClientSettingsStore, + ) {} + + public filterProtocolClaims(claims: UserProfile): UserProfile { + const result = { ...claims }; + + if (this._settings.filterProtocolClaims) { + let protocolClaims; + if (Array.isArray(this._settings.filterProtocolClaims)) { + protocolClaims = this._settings.filterProtocolClaims; + } else { + protocolClaims = DefaultProtocolClaims; + } + + for (const claim of protocolClaims) { + if (!InternalRequiredProtocolClaims.includes(claim)) { + delete result[claim]; + } + } + } + + return result; + } + + private legacyMergeClaims(claims1: UserProfile, claims2: JwtClaims): UserProfile { + const result = { ...claims1 }; + for (const [claim, values] of Object.entries(claims2)) { + for (const value of Array.isArray(values) ? values : [values]) { + const previousValue = result[claim]; + if (!previousValue) { + result[claim] = value; + } + else if (Array.isArray(previousValue)) { + if (!previousValue.includes(value)) { + previousValue.push(value); + } + } + else if (result[claim] !== value) { + if (typeof value === "object" && this._settings.mergeClaims) { + result[claim] = this.mergeClaims(previousValue as UserProfile, value); + } + else { + result[claim] = [previousValue, value]; + } + } + } + } + + return result; + } + + public mergeClaims(claims1: UserProfile, claims2: JwtClaims): UserProfile; + public mergeClaims(claims1: JwtClaims, claims2: JwtClaims): JwtClaims; + public mergeClaims(claims1: JwtClaims | UserProfile, claims2: JwtClaims): JwtClaims | UserProfile { + // TODO: remove on next major version + if (this._settings.legacyMergeClaimsBehavior) { + return this.legacyMergeClaims(claims1 as UserProfile, claims2); + } + + const result = { ...claims1 }; + + for (const [claim, value] of Object.entries(claims2)) { + if (result[claim] !== value) { + if (typeof result[claim] === "object" + && typeof value === "object" + && this._settings.mergeClaims + ) { + result[claim] = this.mergeClaims(result[claim] as JwtClaims, value as JwtClaims); + } else { + result[claim] = value; + } + } + } + + return result; + } +} diff --git a/src/OidcClient.test.ts b/src/OidcClient.test.ts index dacbe39fc..af44e1f98 100644 --- a/src/OidcClient.test.ts +++ b/src/OidcClient.test.ts @@ -814,4 +814,36 @@ describe("OidcClient", () => { }); }); }); + + describe("getUserInfo", () => { + it("gets user info", async () => { + // arrange + const claims = { + aud: "aud", + exp: 0, + iat: 0, + iss: "iss", + sub: "sub", + }; + + const getClaimsSpy = jest.spyOn(subject["_userInfoService"], "getClaims").mockResolvedValue({ + aud: "aud", + exp: 0, + iat: 0, + iss: "iss", + sub: "sub", + a: "apple", + }); + + // act + await subject.getUserInfo("access_token", claims); + + // assert + expect(getClaimsSpy).toHaveBeenCalledWith( + "access_token", + claims, + true, + ); + }); + }); }); diff --git a/src/OidcClient.ts b/src/OidcClient.ts index 30508bd67..e7b2bbf02 100644 --- a/src/OidcClient.ts +++ b/src/OidcClient.ts @@ -14,6 +14,9 @@ import { SignoutResponse } from "./SignoutResponse"; import { SigninState } from "./SigninState"; import { State } from "./State"; import { TokenClient } from "./TokenClient"; +import { UserInfoService } from "./UserInfoService"; +import { ClaimsService } from "./ClaimsService"; +import type { IdTokenClaims } from "./Claims"; /** * @public @@ -83,12 +86,16 @@ export class OidcClient { public readonly metadataService: MetadataService; protected readonly _validator: ResponseValidator; protected readonly _tokenClient: TokenClient; + protected readonly _userInfoService: UserInfoService; + protected readonly _claimsService: ClaimsService; public constructor(settings: OidcClientSettings) { this.settings = new OidcClientSettingsStore(settings); this.metadataService = new MetadataService(this.settings); - this._validator = new ResponseValidator(this.settings, this.metadataService); + this._claimsService = new ClaimsService(this.settings); + this._userInfoService = new UserInfoService(this.settings, this.metadataService, this._claimsService); + this._validator = new ResponseValidator(this.settings, this.metadataService, this._claimsService, this._userInfoService); this._tokenClient = new TokenClient(this.settings, this.metadataService); } @@ -314,4 +321,8 @@ export class OidcClient { token_type_hint: type, }); } + + public getUserInfo(token: string, profile: IdTokenClaims): Promise { + return this._userInfoService.getClaims(token, profile, true); + } } diff --git a/src/OidcClientSettings.ts b/src/OidcClientSettings.ts index 53ea1abac..00a91ed41 100644 --- a/src/OidcClientSettings.ts +++ b/src/OidcClientSettings.ts @@ -93,6 +93,12 @@ export interface OidcClientSettings { */ mergeClaims?: boolean; + /** + * Indicates if the library should use the legacy merge behavior, which mutates string claims into arrays whenever the data is updated on the remote. + * This behavior is enabled by default on v2 and will be removed on v3, since it's not a deterministic way of handling claims. + */ + legacyMergeClaimsBehavior?: boolean; + /** * Storage object used to persist interaction state (default: window.localStorage, InMemoryWebStorage iff no window). * E.g. `stateStore: new WebStorageStateStore({ store: window.localStorage })` @@ -168,6 +174,10 @@ export class OidcClientSettingsStore { public readonly clockSkewInSeconds: number; public readonly userInfoJwtIssuer: "ANY" | "OP" | string; public readonly mergeClaims: boolean; + /** + * TODO: remove me on v3 + */ + public readonly legacyMergeClaimsBehavior: boolean; public readonly stateStore: StateStore; @@ -195,6 +205,7 @@ export class OidcClientSettingsStore { clockSkewInSeconds = DefaultClockSkewInSeconds, userInfoJwtIssuer = "OP", mergeClaims = false, + legacyMergeClaimsBehavior = true, // other behavior stateStore, refreshTokenCredentials, @@ -246,6 +257,7 @@ export class OidcClientSettingsStore { this.clockSkewInSeconds = clockSkewInSeconds; this.userInfoJwtIssuer = userInfoJwtIssuer; this.mergeClaims = !!mergeClaims; + this.legacyMergeClaimsBehavior = !!legacyMergeClaimsBehavior; this.revokeTokenAdditionalContentTypes = revokeTokenAdditionalContentTypes; diff --git a/src/ResponseValidator.test.ts b/src/ResponseValidator.test.ts index 7f86a3224..2bef77ca6 100644 --- a/src/ResponseValidator.test.ts +++ b/src/ResponseValidator.test.ts @@ -8,15 +8,20 @@ import { MetadataService } from "./MetadataService"; import type { SigninState } from "./SigninState"; import type { SigninResponse } from "./SigninResponse"; import type { SignoutResponse } from "./SignoutResponse"; -import type { UserProfile } from "./User"; import type { OidcClientSettingsStore } from "./OidcClientSettings"; import { mocked } from "jest-mock"; +import { ClaimsService } from "./ClaimsService"; +import { UserInfoService } from "./UserInfoService"; +import type { IdTokenClaims } from "./Claims"; describe("ResponseValidator", () => { let stubState: SigninState; let stubResponse: SigninResponse & SignoutResponse; + let stubClaimsResponse: IdTokenClaims; let settings: OidcClientSettingsStore; let metadataService: MetadataService; + let claimsService: ClaimsService; + let userInfoService: UserInfoService; let subject: ResponseValidator; beforeEach(() => { @@ -36,11 +41,15 @@ describe("ResponseValidator", () => { client_id: "client", loadUserInfo: true, } as OidcClientSettingsStore; + stubClaimsResponse = { nickname: "Nick", sub: "sub", iss: "iss", aud: "aud", exp: 0, iat: 0 }; metadataService = new MetadataService(settings); + claimsService = new ClaimsService(settings); + userInfoService = new UserInfoService(settings, metadataService, claimsService); + + subject = new ResponseValidator(settings, metadataService, claimsService, userInfoService); - subject = new ResponseValidator(settings, metadataService); jest.spyOn(subject["_tokenClient"], "exchangeCode").mockResolvedValue({}); - jest.spyOn(subject["_userInfoService"], "getClaims").mockResolvedValue({ nickname: "Nick" }); + jest.spyOn(subject["_userInfoService"], "getClaims").mockResolvedValue(stubClaimsResponse); }); afterEach(() => { @@ -111,14 +120,15 @@ describe("ResponseValidator", () => { access_token: "access_token", id_token: "id_token", }); + const claims = { sub: "sub", iss: "iss", aud: "aud", exp: 0, iat: 0 }; jest.spyOn(JwtUtils, "decode").mockReturnValue({ sub: "sub" }); - mocked(subject["_userInfoService"].getClaims).mockResolvedValue({ sub: "sub" }); + mocked(subject["_userInfoService"].getClaims).mockResolvedValue(claims); // act await subject.validateSigninResponse(stubResponse, stubState); // assert - expect(subject["_userInfoService"].getClaims).toHaveBeenCalledWith("access_token"); + expect(subject["_userInfoService"].getClaims).toHaveBeenCalledWith("access_token", { sub: "sub" }, true); }); it("should not process claims if state fails", async () => { @@ -262,55 +272,6 @@ describe("ResponseValidator", () => { expect(stubResponse.profile).toHaveProperty("iss", "foo"); }); - it("should fail if sub from user info endpoint does not match sub in id_token", async () => { - // arrange - Object.assign(settings, { loadUserInfo: true }); - Object.assign(stubResponse, { - isOpenId: true, - access_token: "access_token", - id_token: "id_token", - }); - jest.spyOn(JwtUtils, "decode").mockReturnValue({ - sub: "sub", - a: "apple", - b: "banana", - }); - mocked(subject["_userInfoService"].getClaims).mockResolvedValue({ sub: "sub different" }); - - // act - await expect(subject.validateSigninResponse(stubResponse, stubState)) - // assert - .rejects.toThrow("subject from UserInfo response does not match subject in ID Token"); - }); - - it("should load and merge user info claims when loadUserInfo configured", async () => { - // arrange - Object.assign(settings, { loadUserInfo: true }); - Object.assign(stubResponse, { - isOpenId: true, - access_token: "access_token", - id_token: "id_token", - }); - jest.spyOn(JwtUtils, "decode").mockReturnValue({ - sub: "sub", - a: "apple", - b: "banana", - }); - mocked(subject["_userInfoService"].getClaims).mockResolvedValue({ sub: "sub", c: "carrot" }); - - // act - await subject.validateSigninResponse(stubResponse, stubState); - - // assert - expect(subject["_userInfoService"].getClaims).toHaveBeenCalledWith("access_token"); - expect(stubResponse.profile).toEqual({ - sub: "sub", - a: "apple", - b: "banana", - c: "carrot", - }); - }); - it("should run if request was not openid", async () => { // arrange Object.assign(settings, { loadUserInfo: true }); @@ -566,290 +527,25 @@ describe("ResponseValidator", () => { // assert expect(JwtUtils.decode).not.toHaveBeenCalledWith("id_token"); - expect(subject["_userInfoService"].getClaims).toHaveBeenCalledWith("access_token"); - expect(stubResponse).toHaveProperty("profile", { nickname: "Nick" }); - }); - - it("should not process a valid openid signin response with wrong userInfo", async () => { - // arrange - Object.assign(stubResponse, { id_token: "id_token", isOpenId: true, access_token: "access_token" }); - jest.spyOn(JwtUtils, "decode").mockReturnValue({ sub: "subsub" }); - jest.spyOn(subject["_userInfoService"], "getClaims").mockResolvedValue({ sub: "anotherSub", nickname: "Nick" }); - - // act - await expect(subject.validateCredentialsResponse(stubResponse, false)) - // assert - .rejects.toThrow(Error); - expect(JwtUtils.decode).toHaveBeenCalledWith("id_token"); - expect(subject["_userInfoService"].getClaims).toHaveBeenCalledWith("access_token"); - expect(stubResponse).toHaveProperty("profile", { sub: "subsub" }); + expect(subject["_userInfoService"].getClaims).toHaveBeenCalledWith("access_token", {}, false); + expect(stubResponse).toHaveProperty("profile", stubClaimsResponse); }); it("should process a valid openid signin response with correct userInfo", async () => { // arrange Object.assign(stubResponse, { id_token: "id_token", isOpenId: true, access_token: "access_token" }); jest.spyOn(JwtUtils, "decode").mockReturnValue({ sub: "subsub" }); - jest.spyOn(subject["_userInfoService"], "getClaims").mockResolvedValue({ sub: "subsub", nickname: "Nick" }); + const claimResponse = { sub: "subsub", nickname: "Nick", iss: "iss", aud: "aud", exp: 0, iat: 0 }; + jest.spyOn(subject["_userInfoService"], "getClaims").mockResolvedValue(claimResponse); // act await subject.validateCredentialsResponse(stubResponse, false); // assert expect(JwtUtils.decode).toHaveBeenCalledWith("id_token"); - expect(subject["_userInfoService"].getClaims).toHaveBeenCalledWith("access_token"); - expect(stubResponse).toHaveProperty("profile", { sub: "subsub", nickname: "Nick" }); - }); - - }); - - describe("_mergeClaims", () => { - it("should merge claims", () => { - // arrange - const c1 = { a: "apple", b: "banana" } as unknown as UserProfile; - const c2 = { c: "carrot" }; - - // act - const result = subject["_mergeClaims"](c1, c2); - - // assert - expect(result).toEqual({ a: "apple", c: "carrot", b: "banana" }); - }); - - it("should not merge claims when claim types are objects", () => { - // arrange - const c1 = { custom: { "apple": "foo", "pear": "bar" } } as unknown as UserProfile; - const c2 = { custom: { "apple": "foo", "orange": "peel" }, b: "banana" }; - - // act - const result = subject["_mergeClaims"](c1, c2); - - // assert - expect(result).toEqual({ custom: [{ "apple": "foo", "pear": "bar" }, { "apple": "foo", "orange": "peel" }], b: "banana" }); - }); - - it("should merge claims when claim types are objects when mergeClaims settings is true", () => { - // arrange - Object.assign(settings, { mergeClaims: true }); - - const c1 = { custom: { "apple": "foo", "pear": "bar" } } as unknown as UserProfile; - const c2 = { custom: { "apple": "foo", "orange": "peel" }, b: "banana" }; - - // act - const result = subject["_mergeClaims"](c1, c2); - - // assert - expect(result).toEqual({ custom: { "apple": "foo", "pear": "bar", "orange": "peel" }, b: "banana" }); - }); - - it("should merge same claim types into array", () => { - // arrange - const c1 = { a: "apple", b: "banana" } as unknown as UserProfile; - const c2 = { a: "carrot" }; - - // act - const result = subject["_mergeClaims"](c1, c2); - - // assert - expect(result).toEqual({ a: ["apple", "carrot"], b: "banana" }); - }); - - it("should merge arrays of same claim types into array", () => { - // arrange - const c1 = { a: "apple", b: "banana" } as unknown as UserProfile; - const c2 = { a: ["carrot", "durian"] }; - - // act - let result = subject["_mergeClaims"](c1, c2); - - // assert - expect(result).toEqual({ a: ["apple", "carrot", "durian"], b: "banana" }); - - // arrange - const d1 = { a: ["apple", "carrot"], b: "banana" } as unknown as UserProfile; - const d2 = { a: ["durian"] }; - - // act - result = subject["_mergeClaims"](d1, d2); - - // assert - expect(result).toEqual({ a: ["apple", "carrot", "durian"], b: "banana" }); - - // arrange - const e1 = { a: ["apple", "carrot"], b: "banana" } as unknown as UserProfile; - const e2 = { a: "durian" }; - - // act - result = subject["_mergeClaims"](e1, e2); - - // assert - expect(result).toEqual({ a: ["apple", "carrot", "durian"], b: "banana" }); - }); - - it("should remove duplicates when producing arrays", () => { - // arrange - const c1 = { a: "apple", b: "banana" } as unknown as UserProfile; - const c2 = { a: ["apple", "durian"] }; - - // act - const result = subject["_mergeClaims"](c1, c2); - - // assert - expect(result).toEqual({ a: ["apple", "durian"], b: "banana" }); - }); - - it("should not add if already present in array", () => { - // arrange - const c1 = { a: ["apple", "durian"], b: "banana" } as unknown as UserProfile; - const c2 = { a: "apple" }; - - // act - const result = subject["_mergeClaims"](c1, c2); - - // assert - expect(result).toEqual({ a: ["apple", "durian"], b: "banana" }); - }); - }); - - describe("_filterProtocolClaims", () => { - it("should filter protocol claims if enabled on settings", () => { - // arrange - Object.assign(settings, { filterProtocolClaims: true }); - const claims = { - foo: 1, bar: "test", - aud: "some_aud", iss: "issuer", - sub: "123", email: "foo@gmail.com", - role: ["admin", "dev"], - iat: 5, exp: 20, - nbf: 10, at_hash: "athash", - }; - - // act - const result = subject["_filterProtocolClaims"](claims); - - // assert - expect(result).toEqual({ - foo: 1, bar: "test", - aud: "some_aud", iss: "issuer", - sub: "123", email: "foo@gmail.com", - role: ["admin", "dev"], - iat: 5, exp: 20, - }); + expect(subject["_userInfoService"].getClaims).toHaveBeenCalledWith("access_token", { sub: "subsub" }, true); + expect(stubResponse).toHaveProperty("profile", claimResponse); }); - it("should not filter protocol claims if not enabled on settings", () => { - // arrange - Object.assign(settings, { filterProtocolClaims: false }); - const claims = { - foo: 1, bar: "test", - aud: "some_aud", iss: "issuer", - sub: "123", email: "foo@gmail.com", - role: ["admin", "dev"], - at_hash: "athash", - iat: 5, nbf: 10, exp: 20, - }; - - // act - const result = subject["_filterProtocolClaims"](claims); - - // assert - expect(result).toEqual({ - foo: 1, bar: "test", - aud: "some_aud", iss: "issuer", - sub: "123", email: "foo@gmail.com", - role: ["admin", "dev"], - at_hash: "athash", - iat: 5, nbf: 10, exp: 20, - }); - }); - - it("should filter protocol claims if specified in settings", () => { - // arrange - Object.assign(settings, { filterProtocolClaims: ["foo", "bar", "role", "nbf", "email"] }); - const claims = { - foo: 1, bar: "test", - aud: "some_aud", iss: "issuer", - sub: "123", email: "foo@gmail.com", - role: ["admin", "dev"], - iat: 5, exp: 20, - nbf: 10, at_hash: "athash", - }; - - // act - const result = subject["_filterProtocolClaims"](claims); - - // assert - expect(result).toEqual({ - aud: "some_aud", iss: "issuer", - sub: "123", - iat: 5, exp: 20, - at_hash: "athash", - }); - }); - - it("should filter only protocol claims defined by default by the library", () => { - // arrange - Object.assign(settings, { filterProtocolClaims: true }); - const defaultProtocolClaims = { - nbf: 3, jti: "jti", - auth_time: 123, - nonce: "nonce", - acr: "acr", - amr: "amr", - azp: "azp", - at_hash: "athash", - }; - const claims = { - foo: 1, bar: "test", - aud: "some_aud", iss: "issuer", - sub: "123", email: "foo@gmail.com", - role: ["admin", "dev"], - iat: 5, exp: 20, - }; - - // act - const result = subject["_filterProtocolClaims"]({ ...defaultProtocolClaims, ...claims }); - - // assert - expect(result).toEqual(claims); - }); - - it("should not filter protocol claims that are required by the library", () => { - // arrange - Object.assign(settings, { filterProtocolClaims: true }); - const internalRequiredProtocolClaims = { - sub: "sub", - iss: "issuer", - aud: "some_aud", - exp: 20, - iat: 5, - }; - const claims = { - foo: 1, bar: "test", - email: "foo@gmail.com", - role: ["admin", "dev"], - nbf: 10, - }; - - // act - let items = { ...internalRequiredProtocolClaims, ...claims }; - let result = subject["_filterProtocolClaims"](items); - - // assert - // nbf is part of the claims that should be filtered by the library by default, so we need to remove it - delete (items as Partial).nbf; - expect(result).toEqual(items); - - // ... even if specified in settings - - // arrange - Object.assign(settings, { filterProtocolClaims: ["sub", "iss", "aud", "exp", "iat"] }); - - // act - items = { ...internalRequiredProtocolClaims, ...claims }; - result = subject["_filterProtocolClaims"](items); - - // assert - expect(result).toEqual(items); - }); }); }); diff --git a/src/ResponseValidator.ts b/src/ResponseValidator.ts index 637a3a156..fac2e82c4 100644 --- a/src/ResponseValidator.ts +++ b/src/ResponseValidator.ts @@ -4,7 +4,7 @@ import { Logger, JwtUtils } from "./utils"; import { ErrorResponse } from "./errors"; import type { MetadataService } from "./MetadataService"; -import { UserInfoService } from "./UserInfoService"; +import type { UserInfoService } from "./UserInfoService"; import { TokenClient } from "./TokenClient"; import type { OidcClientSettingsStore } from "./OidcClientSettings"; import type { SigninState } from "./SigninState"; @@ -13,47 +13,20 @@ import type { State } from "./State"; import type { SignoutResponse } from "./SignoutResponse"; import type { UserProfile } from "./User"; import type { RefreshState } from "./RefreshState"; -import type { JwtClaims, IdTokenClaims } from "./Claims"; - -/** - * Protocol claims that could be removed by default from profile. - * Derived from the following sets of claims: - * - {@link https://datatracker.ietf.org/doc/html/rfc7519.html#section-4.1} - * - {@link https://openid.net/specs/openid-connect-core-1_0.html#IDToken} - * - {@link https://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken} - * - * @internal - */ -const DefaultProtocolClaims = [ - "nbf", - "jti", - "auth_time", - "nonce", - "acr", - "amr", - "azp", - "at_hash", // https://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken -] as const; - -/** - * Protocol claims that should never be removed from profile. - * "sub" is needed internally and others should remain required as per the OIDC specs. - * - * @internal - */ -const InternalRequiredProtocolClaims = ["sub", "iss", "aud", "exp", "iat"]; +import type { ClaimsService } from "./ClaimsService"; /** * @internal */ export class ResponseValidator { protected readonly _logger = new Logger("ResponseValidator"); - protected readonly _userInfoService = new UserInfoService(this._settings, this._metadataService); protected readonly _tokenClient = new TokenClient(this._settings, this._metadataService); public constructor( protected readonly _settings: OidcClientSettingsStore, protected readonly _metadataService: MetadataService, + protected readonly _claimsService: ClaimsService, + protected readonly _userInfoService: UserInfoService, ) {} public async validateSigninResponse(response: SigninResponse, state: SigninState): Promise { @@ -176,9 +149,9 @@ export class ResponseValidator { } } - protected async _processClaims(response: SigninResponse, skipUserInfo = false, validateSub = true): Promise { + protected async _processClaims(response: SigninResponse, skipUserInfo = false, validateSub?: boolean): Promise { const logger = this._logger.create("_processClaims"); - response.profile = this._filterProtocolClaims(response.profile); + response.profile = this._claimsService.filterProtocolClaims(response.profile); if (skipUserInfo || !this._settings.loadUserInfo || !response.access_token) { logger.debug("not loading user info"); @@ -186,66 +159,10 @@ export class ResponseValidator { } logger.debug("loading user info"); - const claims = await this._userInfoService.getClaims(response.access_token); - logger.debug("user info claims received from user info endpoint"); - - if (validateSub && claims.sub !== response.profile.sub) { - logger.throw(new Error("subject from UserInfo response does not match subject in ID Token")); - } - - response.profile = this._mergeClaims(response.profile, this._filterProtocolClaims(claims as IdTokenClaims)); + response.profile = await this._userInfoService.getClaims(response.access_token, response.profile, validateSub); logger.debug("user info claims received, updated profile:", response.profile); } - protected _mergeClaims(claims1: UserProfile, claims2: JwtClaims): UserProfile { - const result = { ...claims1 }; - - for (const [claim, values] of Object.entries(claims2)) { - for (const value of Array.isArray(values) ? values : [values]) { - const previousValue = result[claim]; - if (!previousValue) { - result[claim] = value; - } - else if (Array.isArray(previousValue)) { - if (!previousValue.includes(value)) { - previousValue.push(value); - } - } - else if (result[claim] !== value) { - if (typeof value === "object" && this._settings.mergeClaims) { - result[claim] = this._mergeClaims(previousValue as UserProfile, value); - } - else { - result[claim] = [previousValue, value]; - } - } - } - } - - return result; - } - - protected _filterProtocolClaims(claims: UserProfile): UserProfile { - const result = { ...claims }; - - if (this._settings.filterProtocolClaims) { - let protocolClaims; - if (Array.isArray(this._settings.filterProtocolClaims)) { - protocolClaims = this._settings.filterProtocolClaims; - } else { - protocolClaims = DefaultProtocolClaims; - } - - for (const claim of protocolClaims) { - if (!InternalRequiredProtocolClaims.includes(claim)) { - delete result[claim]; - } - } - } - - return result; - } - protected async _processCode(response: SigninResponse, state: SigninState): Promise { const logger = this._logger.create("_processCode"); if (response.code) { diff --git a/src/UserInfoService.test.ts b/src/UserInfoService.test.ts index f57675031..d399e0eb3 100644 --- a/src/UserInfoService.test.ts +++ b/src/UserInfoService.test.ts @@ -5,22 +5,33 @@ import { UserInfoService } from "./UserInfoService"; import { MetadataService } from "./MetadataService"; import type { JsonService } from "./JsonService"; import { OidcClientSettingsStore } from "./OidcClientSettings"; +import { ClaimsService } from "./ClaimsService"; +import type { IdTokenClaims } from "./Claims"; describe("UserInfoService", () => { + let stubProfile: IdTokenClaims; + let stubToken: string; + + let settings: OidcClientSettingsStore; let subject: UserInfoService; let metadataService: MetadataService; + let claimsService: ClaimsService; let jsonService: JsonService; beforeEach(() => { - const settings = new OidcClientSettingsStore({ + settings = new OidcClientSettingsStore({ authority: "authority", client_id: "client", redirect_uri: "redirect", fetchRequestCredentials: "include", }); + stubProfile = { sub: "subsub", iss: "iss", aud: "aud", exp: 0, iat: 0 }; + stubToken = "access_token"; + metadataService = new MetadataService(settings); + claimsService = new ClaimsService(settings); - subject = new UserInfoService(settings, metadataService); + subject = new UserInfoService(settings, metadataService, claimsService); // access private members jsonService = subject["_jsonService"]; @@ -30,7 +41,7 @@ describe("UserInfoService", () => { it("should return a promise", async () => { // act - const p = subject.getClaims(""); + const p = subject.getClaims("", { sub: "sub", iss: "iss", aud: "aud", exp: 0, iat: 0 }); // assert expect(p).toBeInstanceOf(Promise); @@ -41,7 +52,7 @@ describe("UserInfoService", () => { it("should require a token", async () => { // act try { - await subject.getClaims(""); + await subject.getClaims("", { sub: "sub", iss: "iss", aud: "aud", exp: 0, iat: 0 }); fail("should not come here"); } catch (err) { @@ -57,7 +68,7 @@ describe("UserInfoService", () => { .mockResolvedValue({ foo: "bar" }); // act - await subject.getClaims("token"); + await subject.getClaims("token", { sub: "sub", iss: "iss", aud: "aud", exp: 0, iat: 0 }, false); // assert expect(getJsonMock).toBeCalledWith( @@ -74,7 +85,7 @@ describe("UserInfoService", () => { // act try { - await subject.getClaims("token"); + await subject.getClaims("token", { sub: "sub", iss: "iss", aud: "aud", exp: 0, iat: 0 }); fail("should not come here"); } catch (err) { @@ -83,21 +94,48 @@ describe("UserInfoService", () => { } }); - it("should return claims", async () => { + it("should return claims respecting claims filtering rules", async () => { // arrange jest.spyOn(metadataService, "getUserInfoEndpoint").mockImplementation(() => Promise.resolve("http://sts/userinfo")); + Object.assign(settings, { filterProtocolClaims: ["a", "b", "c"] }); + const expectedClaims = { foo: 1, bar: "test", aud:"some_aud", iss:"issuer", sub:"123", email:"foo@gmail.com", role:["admin", "dev"], - nonce:"nonce", at_hash:"athash", - iat:5, nbf:10, exp:20, + iat:5, exp:20, + nonce:"nonce", at_hash:"athash", nbf:10, }; - jest.spyOn(jsonService, "getJson").mockImplementation(() => Promise.resolve(expectedClaims)); + + jest.spyOn(jsonService, "getJson").mockImplementation(() => Promise.resolve( + { ...expectedClaims, a: "apple" }, + )); // act - const claims = await subject.getClaims("token"); + const claims = await subject.getClaims("token", { sub: "123", aud: "some_aud", iss: "issuer", exp: 0, iat: 0 }); + + // assert + expect(claims).toEqual(expectedClaims); + }); + + it("should return claims removing filtered claims", async () => { + // arrange + jest.spyOn(metadataService, "getUserInfoEndpoint").mockImplementation(() => Promise.resolve("http://sts/userinfo")); + const expectedClaims = { + foo: 1, bar: "test", + aud:"some_aud", iss:"issuer", + sub:"123", email:"foo@gmail.com", + role:["admin", "dev"], + + iat:5, exp:20, + }; + jest.spyOn(jsonService, "getJson").mockImplementation(() => Promise.resolve( + { ...expectedClaims, nonce:"nonce", at_hash:"athash", nbf:10 }, + )); + + // act + const claims = await subject.getClaims("token", { sub: "123", aud: "some_aud", iss: "issuer", exp: 0, iat: 0 }); // assert expect(claims).toEqual(expectedClaims); @@ -109,7 +147,7 @@ describe("UserInfoService", () => { const getJsonMock = jest.spyOn(jsonService, "getJson").mockImplementation(() => Promise.resolve({})); // act - await subject.getClaims("token"); + await subject.getClaims("token", { sub: "sub", iss: "iss", aud: "aud", exp: 0, iat: 0 }, false); // assert expect(getJsonMock).toBeCalledWith( @@ -119,5 +157,39 @@ describe("UserInfoService", () => { }), ); }); + + it("should fail if sub from user info endpoint does not match sub in id_token", async () => { + // arrange + jest.spyOn(metadataService, "getUserInfoEndpoint").mockImplementation(() => Promise.resolve("http://sts/userinfo")); + const getJsonMock = jest.spyOn(jsonService, "getJson").mockImplementation(() => Promise.resolve({})); + + // act + await expect(subject.getClaims(stubToken, stubProfile, true)) + // assert + .rejects.toThrow("subject from UserInfo response does not match subject in ID Token"); + expect(getJsonMock).toBeCalledWith( + "http://sts/userinfo", + expect.objectContaining({ + credentials: "include", + }), + ); + expect(stubProfile).toMatchObject({ sub: "subsub" }); + }); + + it("should load and merge user info claims", async () => { + // arrange + jest.spyOn(metadataService, "getUserInfoEndpoint").mockImplementation(() => Promise.resolve("http://sts/userinfo")); + jest.spyOn(jsonService, "getJson").mockImplementation(() => Promise.resolve({ sub: stubProfile.sub, c: "carrot" })); + Object.assign(stubProfile, { a: "apple", b: "banana" }); + + // act + const claims = await subject.getClaims(stubToken, stubProfile, true); + + // assert + expect(claims).toEqual({ + ...stubProfile, + c: "carrot", + }); + }); }); }); diff --git a/src/UserInfoService.ts b/src/UserInfoService.ts index 20a536c7b..44c6f8935 100644 --- a/src/UserInfoService.ts +++ b/src/UserInfoService.ts @@ -4,8 +4,9 @@ import { Logger, JwtUtils } from "./utils"; import { JsonService } from "./JsonService"; import type { MetadataService } from "./MetadataService"; -import type { JwtClaims } from "./Claims"; +import type { IdTokenClaims, JwtClaims } from "./Claims"; import type { OidcClientSettingsStore } from "./OidcClientSettings"; +import type { ClaimsService } from "./ClaimsService"; /** * @internal @@ -16,14 +17,15 @@ export class UserInfoService { public constructor(private readonly _settings: OidcClientSettingsStore, private readonly _metadataService: MetadataService, + private readonly _claimsService: ClaimsService, ) { this._jsonService = new JsonService(undefined, this._getClaimsFromJwt); } - public async getClaims(token: string): Promise { + public async getClaims(token: string, profile: IdTokenClaims, validateSub = true): Promise { const logger = this._logger.create("getClaims"); if (!token) { - this._logger.throw(new Error("No token passed")); + logger.throw(new Error("No token passed")); } const url = await this._metadataService.getUserInfoEndpoint(); @@ -33,9 +35,15 @@ export class UserInfoService { token, credentials: this._settings.fetchRequestCredentials, }); - logger.debug("got claims", claims); + logger.debug("user info claims received from user info endpoint"); - return claims; + if (validateSub && profile && claims.sub !== profile.sub) { + logger.throw(new Error("subject from UserInfo response does not match subject in ID Token")); + } + + const filteredClaims = this._claimsService.filterProtocolClaims(claims as IdTokenClaims); + + return this._claimsService.mergeClaims(profile, filteredClaims); } protected _getClaimsFromJwt = async (responseText: string): Promise => { diff --git a/src/UserManager.test.ts b/src/UserManager.test.ts index 111f09617..94cfad0c3 100644 --- a/src/UserManager.test.ts +++ b/src/UserManager.test.ts @@ -109,6 +109,131 @@ describe("UserManager", () => { expect(result).toBeNull(); expect(loadMock).not.toBeCalled(); }); + + it("should refresh the userinfo", async () => { + Object.assign(subject.settings, { + loadUserInfo: true, + }); + // arrange + const user = new User({ + access_token: "access_token", + token_type: "token_type", + profile: { + sub: "my sub", + iss: "issuer", + aud: "audience", + a: "apple", + exp: 123, + iat: 543, + } as UserProfile, + }); + subject["_loadUser"] = jest.fn().mockReturnValue(user); + const updateUserInfo = jest.spyOn(subject["_client"], "getUserInfo").mockResolvedValue({ + sub: "my sub", + iss: "issuer", + aud: "audience", + a: "orange", + b: "banana", + exp: 456, + iat: 789, + }); + const loadMock = jest.spyOn(subject["_events"], "load"); + + // act + const result = await subject.getUser(true); + + // assert + expect(result).toEqual(user); + expect(result?.profile).toHaveProperty("a", "orange"); + expect(result?.profile).toHaveProperty("b", "banana"); + expect(updateUserInfo).toHaveBeenCalled(); + expect(loadMock).toBeCalledWith(user, false); + }); + + it("while refreshing userinfo, should refresh token if expired", async () => { + Object.assign(subject.settings, { + loadUserInfo: true, + }); + // arrange + const user = new User({ + access_token: "access_token", + token_type: "token_type", + expires_at: 100, + refresh_token: "refresh_token", + profile: { + sub: "my sub", + iss: "issuer", + aud: "audience", + a: "apple", + exp: 123, + iat: 543, + } as UserProfile, + }); + + subject["_loadUser"] = jest.fn().mockReturnValue(user); + + const useRefreshMock = jest.spyOn(subject["_client"], "useRefreshToken").mockResolvedValue({ + access_token: "new_access_token", + profile: { + sub: "my sub", + iss: "issuer", + aud: "audience", + exp: 123, + iat: 543, + + a: "orange", + }, + } as unknown as SigninResponse); + + // act + const result = await subject.getUser(true); + + // assert + expect(result?.profile).toHaveProperty("a", "orange"); + expect(result?.profile).not.toHaveProperty("b"); + expect(useRefreshMock).toHaveBeenCalled(); + }); + it("should retrieve updated userinfo and persist it", async () => { + Object.assign(subject.settings, { + loadUserInfo: true, + }); + // arrange + const user = new User({ + access_token: "access_token", + token_type: "token_type", + profile: { + sub: "my sub", + iss: "issuer", + aud: "audience", + a: "apple", + exp: 123, + iat: 543, + } as UserProfile, + }); + subject["_loadUser"] = jest.fn().mockReturnValue(user); + const updateUserInfo = jest.spyOn(subject["_client"], "getUserInfo").mockResolvedValue({ + sub: "my sub", + iss: "issuer", + aud: "audience", + a: "orange", + exp: 456, + iat: 789, + }); + const loadMock = jest.spyOn(subject["_events"], "load"); + + // act + const result = await subject.getUser(true); + const result2 = await subject.getUser(); + + // assert + expect(result).toEqual(user); + expect(result?.profile).toHaveProperty("a", "orange"); + expect(updateUserInfo).toHaveBeenCalledTimes(1); + expect(loadMock).toBeCalledWith(user, false); + + expect(result2?.profile).toHaveProperty("a", "orange"); + expect(result2?.profile).toEqual(result?.profile); + }); }); describe("removeUser", () => { diff --git a/src/UserManager.ts b/src/UserManager.ts index 45d49b4cf..b40b45b5d 100644 --- a/src/UserManager.ts +++ b/src/UserManager.ts @@ -126,11 +126,30 @@ export class UserManager { /** * Returns promise to load the `User` object for the currently authenticated user. + * Can optionally refresh userInfo data */ - public async getUser(): Promise { + public async getUser(refreshUserInfo?: boolean): Promise { const logger = this._logger.create("getUser"); - const user = await this._loadUser(); + let user = await this._loadUser(); if (user) { + if (this.settings.loadUserInfo && refreshUserInfo) { + if (user.expired) { + logger.debug("refreshing token"); + user = await this.signinSilent(); + + if (!user) { + logger.info("after refresh, user is not logged anymore"); + return null; + } + } else { + logger.debug("refreshing user info"); + user.profile = await this._client.getUserInfo(user.access_token, user.profile); + await this.storeUser(user); + logger.debug("user updated in storage"); + } + logger.debug("user info refreshed"); + } + logger.info("user loaded"); this._events.load(user, false); return user;