From 4dafd34fcccf32e8ad6fc026002907d8447138aa Mon Sep 17 00:00:00 2001 From: cqnguy23 <44353219+cqnguy23@users.noreply.github.com> Date: Wed, 20 Nov 2024 11:48:28 +0700 Subject: [PATCH] Fix client type issues and improve tests (#31617) - Add a test util to parse JWT token payload to validate audience - Add more tests for AAD - Fix mismatched DTO field when calling generateClientToken --------- Co-authored-by: tomnguyen --- .../web-pubsub-express/CHANGELOG.md | 2 + sdk/web-pubsub/web-pubsub/assets.json | 2 +- sdk/web-pubsub/web-pubsub/src/hubClient.ts | 8 +-- sdk/web-pubsub/web-pubsub/test.env | 1 + sdk/web-pubsub/web-pubsub/test/hubs.spec.ts | 65 ++++++++++++++++++- sdk/web-pubsub/web-pubsub/test/testEnv.ts | 1 + sdk/web-pubsub/web-pubsub/test/testUtils.ts | 8 +++ 7 files changed, 80 insertions(+), 7 deletions(-) create mode 100644 sdk/web-pubsub/web-pubsub/test/testUtils.ts diff --git a/sdk/web-pubsub/web-pubsub-express/CHANGELOG.md b/sdk/web-pubsub/web-pubsub-express/CHANGELOG.md index 8dcecf5d4e7f..b3844593c518 100644 --- a/sdk/web-pubsub/web-pubsub-express/CHANGELOG.md +++ b/sdk/web-pubsub/web-pubsub-express/CHANGELOG.md @@ -10,6 +10,8 @@ ### Bugs Fixed +- Fix issue with mismatched DTO for client protocol when calling generate client access URI API, which causes the response to be incorrect. + ### Other Changes ## 1.0.5 (2023-06-28) diff --git a/sdk/web-pubsub/web-pubsub/assets.json b/sdk/web-pubsub/web-pubsub/assets.json index ada23945f50a..7588af5e0fa2 100644 --- a/sdk/web-pubsub/web-pubsub/assets.json +++ b/sdk/web-pubsub/web-pubsub/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "js", "TagPrefix": "js/web-pubsub/web-pubsub", - "Tag": "js/web-pubsub/web-pubsub_0e1077ac6f" + "Tag": "js/web-pubsub/web-pubsub_d125ff0258" } diff --git a/sdk/web-pubsub/web-pubsub/src/hubClient.ts b/sdk/web-pubsub/web-pubsub/src/hubClient.ts index 7042d639872c..f5fa52d3f736 100644 --- a/sdk/web-pubsub/web-pubsub/src/hubClient.ts +++ b/sdk/web-pubsub/web-pubsub/src/hubClient.ts @@ -976,10 +976,10 @@ export class WebPubSubServiceClient { let token: string; if (isTokenCredential(this.credential)) { - const response = await this.client.webPubSub.generateClientToken( - this.hubName, - updatedOptions, - ); + const response = await this.client.webPubSub.generateClientToken(this.hubName, { + ...updatedOptions, + clientType: clientProtocol, + }); token = response.token!; } else { const key = this.credential.key; diff --git a/sdk/web-pubsub/web-pubsub/test.env b/sdk/web-pubsub/web-pubsub/test.env index 4aee7ecf4993..c60946580246 100644 --- a/sdk/web-pubsub/web-pubsub/test.env +++ b/sdk/web-pubsub/web-pubsub/test.env @@ -4,6 +4,7 @@ WPS_CONNECTION_STRING="" WPS_API_KEY="" WPS_ENDPOINT="" WPS_REVERSE_PROXY_ENDPOINT="" +WPS_SOCKETIO_ENDPOINT="" # Used to authenticate using Azure AD as a service principal for role-based # authentication. diff --git a/sdk/web-pubsub/web-pubsub/test/hubs.spec.ts b/sdk/web-pubsub/web-pubsub/test/hubs.spec.ts index f5e8cf291451..c9c3fd9efb90 100644 --- a/sdk/web-pubsub/web-pubsub/test/hubs.spec.ts +++ b/sdk/web-pubsub/web-pubsub/test/hubs.spec.ts @@ -7,6 +7,7 @@ import { assert } from "@azure-tools/test-utils"; import recorderOptions from "./testEnv"; import type { FullOperationResponse } from "@azure/core-client"; import { createTestCredential } from "@azure-tools/test-credential"; +import { parseJwt } from "./testUtils"; /* eslint-disable @typescript-eslint/no-invalid-this */ describe("HubClient", function () { @@ -301,9 +302,11 @@ describe("HubClient", function () { groups: ["group1"], }); const url = new URL(res.url); + const tokenPayload = parseJwt(res.token!); assert.ok(url.searchParams.has("access_token")); assert.equal(url.host, new URL(client.endpoint).host); assert.equal(url.pathname, `/client/hubs/${client.hubName}`); + assert.equal(tokenPayload.aud, client.endpoint + `client/hubs/${client.hubName}`); }); it("can generate default client tokens", async () => { @@ -313,9 +316,11 @@ describe("HubClient", function () { clientProtocol: "default", }); const url = new URL(res.url); + const tokenPayload = parseJwt(res.token!); assert.ok(url.searchParams.has("access_token")); assert.equal(url.host, new URL(client.endpoint).host); assert.equal(url.pathname, `/client/hubs/${client.hubName}`); + assert.equal(tokenPayload.aud, client.endpoint + `client/hubs/${client.hubName}`); }); it("can generate client MQTT tokens", async () => { @@ -325,21 +330,77 @@ describe("HubClient", function () { clientProtocol: "mqtt", }); const url = new URL(res.url); + const tokenPayload = parseJwt(res.token!); assert.ok(url.searchParams.has("access_token")); assert.equal(url.host, new URL(client.endpoint).host); assert.equal(url.pathname, `/clients/mqtt/hubs/${client.hubName}`); + assert.equal(tokenPayload.aud, client.endpoint + `clients/mqtt/hubs/${client.hubName}`); }); - it("can generate socketIO client tokens", async () => { - const res = await client.getClientAccessToken({ + it("can generate default client tokens with DAC", async function () { + // Recording not generated properly, so only run in live mode + if (!isLiveMode()) this.skip(); + const dacClient = new WebPubSubServiceClient( + assertEnvironmentVariable("WPS_ENDPOINT"), + credential, + "simplechat", + recorder.configureClientOptions({}), + ); + const res = await dacClient.getClientAccessToken({ + userId: "brian", + groups: ["group1"], + clientProtocol: "default", + }); + const url = new URL(res.url); + const tokenPayload = parseJwt(res.token!); + assert.ok(url.searchParams.has("access_token")); + assert.equal(url.host, new URL(client.endpoint).host); + assert.equal(url.pathname, `/client/hubs/${client.hubName}`); + assert.equal(tokenPayload.aud, client.endpoint + `client/hubs/${client.hubName}`); + }); + + it("can generate client MQTT tokens with DAC", async function () { + // Recording not generated properly, so only run in live mode + if (!isLiveMode()) this.skip(); + const dacClient = new WebPubSubServiceClient( + assertEnvironmentVariable("WPS_ENDPOINT"), + credential, + "simplechat", + recorder.configureClientOptions({}), + ); + const res = await dacClient.getClientAccessToken({ + userId: "brian", + groups: ["group1"], + clientProtocol: "mqtt", + }); + const url = new URL(res.url); + const tokenPayload = parseJwt(res.token!); + assert.ok(url.searchParams.has("access_token")); + assert.equal(url.host, new URL(client.endpoint).host); + assert.equal(url.pathname, `/clients/mqtt/hubs/${client.hubName}`); + assert.equal(tokenPayload.aud, client.endpoint + `clients/mqtt/hubs/${client.hubName}`); + }); + + it("can generate client socketIO tokens with DAC", async function () { + // Recording not generated properly, so only run in live mode + if (!isLiveMode()) this.skip(); + const dacClient = new WebPubSubServiceClient( + assertEnvironmentVariable("WPS_SOCKETIO_ENDPOINT"), + credential, + "simplechat", + recorder.configureClientOptions({}), + ); + const res = await dacClient.getClientAccessToken({ userId: "brian", groups: ["group1"], clientProtocol: "socketio", }); const url = new URL(res.url); + const tokenPayload = parseJwt(res.token!); assert.ok(url.searchParams.has("access_token")); assert.equal(url.host, new URL(client.endpoint).host); assert.equal(url.pathname, `/clients/socketio/hubs/${client.hubName}`); + assert.equal(tokenPayload.aud, client.endpoint + `clients/socketio/hubs/${client.hubName}`); }); }); }); diff --git a/sdk/web-pubsub/web-pubsub/test/testEnv.ts b/sdk/web-pubsub/web-pubsub/test/testEnv.ts index 82ae4414e3c2..b2d8599b160b 100644 --- a/sdk/web-pubsub/web-pubsub/test/testEnv.ts +++ b/sdk/web-pubsub/web-pubsub/test/testEnv.ts @@ -8,6 +8,7 @@ const envSetupForPlayback: Record = { WPS_API_KEY: "api_key", WPS_ENDPOINT: "https://endpoint", WPS_REVERSE_PROXY_ENDPOINT: "https://endpoint", + WPS_SOCKETIO_ENDPOINT: "https://socketio.endpoint", }; const recorderOptions: RecorderStartOptions = { diff --git a/sdk/web-pubsub/web-pubsub/test/testUtils.ts b/sdk/web-pubsub/web-pubsub/test/testUtils.ts new file mode 100644 index 000000000000..8f0fcb158f6f --- /dev/null +++ b/sdk/web-pubsub/web-pubsub/test/testUtils.ts @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +export function parseJwt(token: string): any { + const base64Payload = token.split(".")[1]; + const payload = Buffer.from(base64Payload, "base64"); + return JSON.parse(payload.toString()); +}