Skip to content

Commit

Permalink
Caches user auth in secure data store
Browse files Browse the repository at this point in the history
Introduces a persistent cache using the secure data store to store accounts and tokens, which avoids making users log in every time they start the app.

Fixes [AB#10633343](https://msazure.visualstudio.com/AzureBatch/_workitems/edit/10633343/)
  • Loading branch information
gingi committed Aug 30, 2021
1 parent d41e80d commit b32f4e2
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 31 deletions.
8 changes: 7 additions & 1 deletion src/client/core/aad/aad-constants.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
export const defaultTenant = "organizations";
export enum TenantPlaceholders {
common = "common",
organizations = "organizations",
consumers = "consumers"
}

export const defaultTenant = TenantPlaceholders.organizations;

export interface AuthorizeUrlParams {
response_type: string;
Expand Down
5 changes: 4 additions & 1 deletion src/client/core/aad/auth-provider.spec.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { instrumentAuthProvider, instrumentForAuth } from "test/utils/mocks/auth";
import AuthProvider from "./auth-provider";

describe("AuthProvider", () => {
Expand All @@ -13,6 +14,7 @@ describe("AuthProvider", () => {
}
}
};
instrumentForAuth(appSpy);
const config: any = {
tenant: "common",
redirectUri: "my-redirect-uri",
Expand All @@ -24,12 +26,13 @@ describe("AuthProvider", () => {
});

it("authenticates interactively first, then silently", async () => {
const call = () => authProvider.getToken({
const call = async () => await authProvider.getToken({
resourceURI: "resourceURI1", tenantId: "tenant1",
authCodeCallback: authCodeCallbackSpy
});
const clientSpy = makeClientApplicationSpy();
spyOn<any>(authProvider, "_getClient").and.returnValue(clientSpy);
instrumentAuthProvider(authProvider);

returnToken(clientSpy.acquireTokenByCode, "interactive-token-1");
returnToken(clientSpy.acquireTokenSilent, "silent-token-1");
Expand Down
92 changes: 74 additions & 18 deletions src/client/core/aad/auth-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ import {
ClientApplication,
PublicClientApplication
} from "@azure/msal-node";
import { log } from "@batch-flask/utils";
import { BatchExplorerApplication } from "..";
import { SecureDataStore } from "../secure-data-store";
import { AADConfig } from "./aad-config";
import { defaultTenant } from "./aad-constants";
import { defaultTenant, TenantPlaceholders } from "./aad-constants";
import MSALCachePlugin from "./msal-cache-plugin";

const MSAL_SCOPES = ["user_impersonation"];

Expand All @@ -18,11 +21,17 @@ export type AuthorizationResult = AuthenticationResult;
export default class AuthProvider {
private _clients: StringMap<ClientApplication> = {};
private _accounts: StringMap<AccountInfo> = {};
private _cachePlugin: MSALCachePlugin;
private _logoutPromise?: Promise<void>;
private _primaryClient?: ClientApplication;

constructor(
protected app: BatchExplorerApplication,
protected config: AADConfig
) {}
) {
this._cachePlugin =
new MSALCachePlugin(app.injector.get(SecureDataStore));
}

/**
* Retrieves an access token
Expand All @@ -39,6 +48,10 @@ export default class AuthProvider {
}): Promise<AuthorizationResult> {
const { resourceURI, tenantId = defaultTenant, authCodeCallback } = options;

if (this._logoutPromise) {
await this._logoutPromise;
}

/**
* KLUDGE: msal.js does not handle well access tokens across multiple
* tenants within the same cache. It lets you specify a different
Expand All @@ -49,12 +62,15 @@ export default class AuthProvider {
const client = this._getClient(tenantId);

const authRequest = this._authRequest(resourceURI, tenantId);
if (this._accounts[tenantId]) {
try {
log.debug(`Trying to silently acquire token for '${tenantId}'`);
const account = await this._getAccount(tenantId);
const result = await client.acquireTokenSilent({
...authRequest, account: this._accounts[tenantId]
...authRequest, account
});
return result;
} else {
} catch (silentTokenException) {
log.debug(`Trying silent auth code flow (${silentTokenException})`);
let url, code;

try {
Expand All @@ -63,45 +79,85 @@ export default class AuthProvider {
{ ...authRequest, prompt: "none" }
);
code = await authCodeCallback(url, true);
} catch (e) {
} catch (silentAuthException) {
log.debug(`Trying interactive auth code flow (${silentAuthException})`);
url = await client.getAuthCodeUrl(authRequest);
code = await authCodeCallback(url);
}

const result: AuthorizationResult =
await client.acquireTokenByCode({ ...authRequest, code });
if (result) {
this._accounts[tenantId] = result.account;
}
return result;
}
}

public logout(): void {
this._removeAccount();
public async logout(): Promise<void> {
this._logoutPromise = this._removeAccounts();
return this._logoutPromise;
}

private async _removeAccounts(): Promise<void> {
const cache = this._primaryClient?.getTokenCache();
if (!cache) {
return;
}
const accounts = await cache.getAllAccounts();
for (const account of accounts) {
await cache.removeAccount(account);
}
this._accounts = {};
this._clients = {};
this._primaryClient = undefined;
}

protected _getClient(tenantId: string): ClientApplication {
if (tenantId in this._clients) {
return this._clients[tenantId];
}
return this._clients[tenantId] = new PublicClientApplication({
const client = new PublicClientApplication({
auth: {
clientId: this.config.clientId,
authority:
`${this.app.properties.azureEnvironment.aadUrl}${tenantId}/`
},
cache: {
cachePlugin: this._cachePlugin
}
});
if (!this._primaryClient) {
this._primaryClient = client;
}
this._clients[tenantId] = client;
return client;
}

private async _removeAccount(): Promise<void> {
for (const tenantId in this._clients) {
if (this._accounts[tenantId]) {
const cache = this._clients[tenantId].getTokenCache();
cache.removeAccount(this._accounts[tenantId]);
private async _getAccount(tenantId: string): Promise<AccountInfo> {
if (tenantId in this._accounts) {
return this._accounts[tenantId];
}
const cache = this._clients[tenantId].getTokenCache();
const accounts: AccountInfo[] = await cache.getAllAccounts();
let homeAccountId = null;
for (const account of accounts) {
if (account.tenantId === tenantId) {
return this._accounts[tenantId] = account;
} else if (!homeAccountId) {
homeAccountId = account.homeAccountId;
}
delete this._accounts[tenantId];
}

/* SPECIAL CASE: If the tenant is one of the tenant placeholders (e.g.,
* "common"), fallback to the home tenant account, since the tenant in
* the account dictionary are always resolved IDs.
*/
if (tenantId in TenantPlaceholders) {
return this._accounts[tenantId] =
await cache.getAccountByHomeId(homeAccountId);
}

throw new Error(
`Unable to find a valid AAD account for tenant ${tenantId}`
);
}

private _getScopes(resourceURI: string): string[] {
Expand Down
5 changes: 3 additions & 2 deletions src/client/core/aad/auth/aad.service.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { AzureChina, AzurePublic } from "client/azure-environment";
import { Constants } from "common";
import { DateTime } from "luxon";
import * as proxyquire from "proxyquire";
import { instrumentForAuth } from "test/utils/mocks/auth";
import { MockBrowserWindow, MockSplashScreen } from "test/utils/mocks/windows";
import { AADUser } from "./aad-user";
import { AADService } from "./aad.service";
Expand Down Expand Up @@ -41,9 +42,9 @@ describe("AADService", () => {
localStorage = new InMemoryDataStore();
appSpy = {
mainWindow: new MockBrowserWindow(),
splashScreen: new MockSplashScreen(),
properties: { azureEnvironment: "public" }
splashScreen: new MockSplashScreen()
};
instrumentForAuth(appSpy);

ipcMainMock = {
on: () => null,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { AzurePublic } from "client/azure-environment";
import { delay } from "test/utils/helpers/misc";
import { MockAuthProvider } from "test/utils/mocks/auth";
import { MockAuthenticationWindow, MockSplashScreen } from "test/utils/mocks/windows";
import {
Expand All @@ -21,10 +21,7 @@ describe("AuthenticationService", () => {
beforeEach(() => {
appSpy = {
splashScreen: new MockSplashScreen(),
authenticationWindow: new MockAuthenticationWindow(),
properties: {
azureEnvironment: AzurePublic,
},
authenticationWindow: new MockAuthenticationWindow()
};
fakeAuthProvider = new MockAuthProvider(appSpy, CONFIG);
userAuthorization = new AuthenticationService(appSpy, CONFIG,
Expand All @@ -42,6 +39,7 @@ describe("AuthenticationService", () => {
error = null;
const obs = userAuthorization.authorize("tenant-1");
promise = obs.then((out) => result = out).catch((e) => error = e);
await delay();
});

it("Should have called loadurl", async () => {
Expand Down Expand Up @@ -69,7 +67,7 @@ describe("AuthenticationService", () => {
expect(state).toBe(AuthenticationState.Authenticated);
});

it("Should error it fail to load", async () => {
it("Should error when the window fails to load", async () => {
fakeAuthWindow.notifyError({ code: 4, description: "Foo bar" });
await promise;

Expand Down Expand Up @@ -97,7 +95,7 @@ describe("AuthenticationService", () => {
expect(fakeAuthWindow.destroy).toHaveBeenCalledTimes(1);
});

it("should only authorize 1 tenant at the time and queue the others", async () => {
it("should only authorize 1 tenant at a time and queue the others", async () => {
const obs1 = userAuthorization.authorize("tenant-1");
const obs2 = userAuthorization.authorize("tenant-2");
const tenant1Spy = jasmine.createSpy("Tenant-1");
Expand Down
39 changes: 39 additions & 0 deletions src/client/core/aad/msal-cache-plugin.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import MSALCachePlugin from "./msal-cache-plugin";

describe("MSALCachePlugin", () => {
let plugin, cacheContextSpy;
const storeSpy = jasmine.createSpyObj("DataStore", {
setItem: jasmine.anything(),
getItem: "deserializedValue"
});
const cacheSpy = jasmine.createSpyObj("Cache", {
serialize: "serializedValue",
deserialize: jasmine.anything()
});
beforeEach(() => {
plugin = new MSALCachePlugin(storeSpy);
cacheContextSpy = {
tokenCache: cacheSpy,
cacheHasChanged: false
};
})
it("should get an item before the cache is called", async () => {
expect(storeSpy.getItem).not.toHaveBeenCalled();
await plugin.beforeCacheAccess(cacheContextSpy);
expect(storeSpy.getItem).toHaveBeenCalled();
expect(cacheSpy.deserialize).toHaveBeenCalledWith("deserializedValue");
});
it("should store an item after the cache is called", async () => {
expect(storeSpy.setItem).not.toHaveBeenCalled();

await plugin.afterCacheAccess(cacheContextSpy);
expect(storeSpy.setItem).not.toHaveBeenCalled();

cacheContextSpy.cacheHasChanged = true;
await plugin.afterCacheAccess(cacheContextSpy);
expect(storeSpy.setItem).toHaveBeenCalledWith(
jasmine.anything(),
"serializedValue"
);
});
});
18 changes: 18 additions & 0 deletions src/client/core/aad/msal-cache-plugin.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import { ICachePlugin, TokenCacheContext } from "@azure/msal-node";
import { SecureDataStore } from "../secure-data-store";

const CACHE_KEY = "msal_auth_cache";

export default class MSALCachePlugin implements ICachePlugin {
constructor(private store: SecureDataStore) {}

public async beforeCacheAccess(context: TokenCacheContext): Promise<void> {
context.tokenCache.deserialize(await this.store.getItem(CACHE_KEY));
}

public async afterCacheAccess(context: TokenCacheContext): Promise<void> {
if (context.cacheHasChanged) {
await this.store.setItem(CACHE_KEY, context.tokenCache.serialize());
}
}
}
26 changes: 24 additions & 2 deletions src/test/utils/mocks/auth/auth-provider.mock.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { AuthenticationResult, ClientApplication } from "@azure/msal-node";
import { AzurePublic } from "client/azure-environment";
import { AuthorizeError } from "client/core/aad";
import { AADConfig } from "client/core/aad/aad-config";
import AuthProvider from "client/core/aad/auth-provider";
Expand All @@ -8,11 +9,13 @@ export class MockAuthProvider extends AuthProvider {
public fakeConfig: AADConfig;
public fakeError: Partial<AuthorizeError>;
constructor(app: any, config: AADConfig) {
instrumentForAuth(app);
super(app, config);
this.fakeConfig = config;
this._getClient = jasmine.createSpy("_getClient").and.returnValue(
spyOn<any>(this, "_getClient").and.returnValue(
new MockClientApplication(this)
);
instrumentAuthProvider(this);
}
}
export class MockClientApplication extends ClientApplication {
Expand All @@ -26,7 +29,7 @@ export class MockClientApplication extends ClientApplication {

public getAuthCodeUrl(request) {
if (request?.prompt === "none") {
throw "No silent auth";
throw new Error("No silent auth");
}
return Promise.resolve(this.fakeAuthProvider.fakeConfig.redirectUri);
}
Expand All @@ -50,3 +53,22 @@ export const createMockClientApplication = () => {
});
return new MockClientApplication(fakeAuthProvider);
};

export const instrumentForAuth = app => {
app.injector = {
get: () => jasmine.createSpyObj("DataStore", ["getItem", "setItem"])
};
app.properties = { azureEnvironment: AzurePublic }
}

export const instrumentAuthProvider = (authProvider: AuthProvider) => {
const tenants = {};
spyOn<any>(authProvider, "_getAccount").and.callFake(tenantId => {
if (tenantId in tenants) {
return { tenantId };
} else {
tenants[tenantId] = true;
throw new Error("no account");
}
});
}

0 comments on commit b32f4e2

Please sign in to comment.