Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(api-graphql): pass authToken via subprotocol #13727

Merged
merged 9 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,8 @@ describe('AWSAppSyncRealTimeProvider', () => {

expect(newSocketSpy).toHaveBeenNthCalledWith(
1,
'ws://localhost:8080/realtime?header=&payload=e30=',
'graphql-ws',
'ws://localhost:8080/realtime',
['graphql-ws', 'header-'],
);
});

Expand All @@ -271,8 +271,8 @@ describe('AWSAppSyncRealTimeProvider', () => {

expect(newSocketSpy).toHaveBeenNthCalledWith(
1,
'wss://localhost:8080/realtime?header=&payload=e30=',
'graphql-ws',
'wss://localhost:8080/realtime',
['graphql-ws', 'header-'],
);
});

Expand All @@ -298,8 +298,84 @@ describe('AWSAppSyncRealTimeProvider', () => {

expect(newSocketSpy).toHaveBeenNthCalledWith(
1,
'wss://testaccounturl123456789123.appsync-realtime-api.us-east-1.amazonaws.com/graphql?header=&payload=e30=',
'graphql-ws',
'wss://testaccounturl123456789123.appsync-realtime-api.us-east-1.amazonaws.com/graphql',
['graphql-ws', 'header-'],
);
});

test('subscription generates expected auth token', async () => {
expect.assertions(1);

const newSocketSpy = jest
.spyOn(provider, 'getNewWebSocket')
.mockImplementation(() => {
fakeWebSocketInterface.newWebSocket();
return fakeWebSocketInterface.webSocket;
});

provider
.subscribe({
appSyncGraphqlEndpoint:
'https://testaccounturl123456789123.appsync-api.us-east-1.amazonaws.com/graphql',
// using custom auth instead of apiKey, because the latter inserts a timestamp header => expected value changes
authenticationType: 'lambda',
additionalHeaders: {
Authorization: 'my-custom-auth-token',
},
})
.subscribe({ error: () => {} });

// Wait for the socket to be initialize
await fakeWebSocketInterface.readyForUse;

/*
Regular base64 encoding of auth header {"Authorization":"my-custom-auth-token","host":"testaccounturl123456789123.appsync-api.us-east-1.amazonaws.com"}
Is: `eyJBdXRob3JpemF0aW9uIjoibXktY3VzdG9tLWF1dGgtdG9rZW4iLCJob3N0IjoidGVzdGFjY291bnR1cmwxMjM0NTY3ODkxMjMuYXBwc3luYy1hcGkudXMtZWFzdC0xLmFtYXpvbmF3cy5jb20ifQ==`
(note `==` at the end of the string)
base64url encoding is expected to drop padding chars `=`
*/

expect(newSocketSpy).toHaveBeenNthCalledWith(
1,
'wss://testaccounturl123456789123.appsync-realtime-api.us-east-1.amazonaws.com/graphql',
[
'graphql-ws',
'header-eyJBdXRob3JpemF0aW9uIjoibXktY3VzdG9tLWF1dGgtdG9rZW4iLCJob3N0IjoidGVzdGFjY291bnR1cmwxMjM0NTY3ODkxMjMuYXBwc3luYy1hcGkudXMtZWFzdC0xLmFtYXpvbmF3cy5jb20ifQ',
],
david-mcafee marked this conversation as resolved.
Show resolved Hide resolved
);
});

test('subscription generates expected auth token - custom domain', async () => {
expect.assertions(1);

const newSocketSpy = jest
.spyOn(provider, 'getNewWebSocket')
.mockImplementation(() => {
fakeWebSocketInterface.newWebSocket();
return fakeWebSocketInterface.webSocket;
});

provider
.subscribe({
appSyncGraphqlEndpoint: 'https://unit-test.testurl.com',
// using custom auth instead of apiKey, because the latter inserts a timestamp header => expected value changes
authenticationType: 'lambda',
additionalHeaders: {
Authorization: 'my-custom-auth-token',
},
})
.subscribe({ error: () => {} });

// Wait for the socket to be initialize
await fakeWebSocketInterface.readyForUse;

expect(newSocketSpy).toHaveBeenNthCalledWith(
1,
'wss://unit-test.testurl.com/realtime',
[
'graphql-ws',
'header-eyJBdXRob3JpemF0aW9uIjoibXktY3VzdG9tLWF1dGgtdG9rZW4iLCJob3N0IjoidW5pdC10ZXN0LnRlc3R1cmwuY29tIn0',
],
);
});

Expand Down
2 changes: 1 addition & 1 deletion packages/api-graphql/__tests__/GraphQLAPI.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1594,7 +1594,7 @@ describe('API test', () => {
`;

const resolvedUrl =
'wss://testaccounturl123456789123.appsync-realtime-api.us-east-1.amazonaws.com/graphql?header=eyJBdXRob3JpemF0aW9uIjoiYWJjMTIzNDUiLCJob3N0IjoidGVzdGFjY291bnR1cmwxMjM0NTY3ODkxMjMuYXBwc3luYy1hcGkudXMtZWFzdC0xLmFtYXpvbmF3cy5jb20ifQ==&payload=e30=&x-amz-user-agent=aws-amplify%2F6.4.0%20api%2F1%20framework%2F2&ex-machina=is%20a%20good%20movie';
'wss://testaccounturl123456789123.appsync-realtime-api.us-east-1.amazonaws.com/graphql?x-amz-user-agent=aws-amplify%2F6.4.0+api%2F1+framework%2F2&ex-machina=is+a+good+movie';

(
client.graphql(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
import { signRequest } from '@aws-amplify/core/internals/aws-client-utils';
import {
AmplifyUrl,
AmplifyUrlSearchParams,
CustomUserAgentDetails,
DocumentType,
GraphQLAuthMode,
Expand Down Expand Up @@ -181,7 +182,7 @@ export class AWSAppSyncRealTimeProvider {
this.reconnectionMonitor.close();
}

getNewWebSocket(url: string, protocol: string) {
getNewWebSocket(url: string, protocol: string[]) {
return new WebSocket(url, protocol);
}

Expand Down Expand Up @@ -734,20 +735,63 @@ export class AWSAppSyncRealTimeProvider {
/**
*
* @param headers - http headers
* @returns query string of uri-encoded parameters derived from custom headers
* @returns uri-encoded query parameters derived from custom headers
*/
private _queryStringFromCustomHeaders(
private _queryParamsFromCustomHeaders(
headers?: AWSAppSyncRealTimeProviderOptions['additionalCustomHeaders'],
): string {
): URLSearchParams {
const nonAuthHeaders = this._extractNonAuthHeaders(headers);

const queryParams: string[] = Object.entries(nonAuthHeaders).map(
([key, val]) => `${encodeURIComponent(key)}=${encodeURIComponent(val)}`,
const params = new AmplifyUrlSearchParams();

Object.entries(nonAuthHeaders).forEach(([k, v]) => {
params.append(k, v);
});

return params;
}

/**
* Normalizes AppSync realtime endpoint URL
*
* @param appSyncGraphqlEndpoint - AppSync endpointUri from config
* @param urlParams - URLSearchParams
* @returns fully resolved string realtime endpoint URL
*/
private _realtimeUrlWithQueryString(
appSyncGraphqlEndpoint: string | undefined,
urlParams: URLSearchParams,
): string {
const protocol = 'wss://';

let realtimeEndpoint = appSyncGraphqlEndpoint ?? '';

if (this.isCustomDomain(realtimeEndpoint)) {
realtimeEndpoint = realtimeEndpoint.concat(customDomainPath);
} else {
realtimeEndpoint = realtimeEndpoint
.replace('appsync-api', 'appsync-realtime-api')
.replace('gogi-beta', 'grt-beta');
}

realtimeEndpoint = realtimeEndpoint
.replace('https://', protocol)
.replace('http://', protocol);

const realtimeEndpointUrl = new AmplifyUrl(realtimeEndpoint);

// preserves any query params a customer might manually set in the configuration
const existingParams = new AmplifyUrlSearchParams(
realtimeEndpointUrl.search,
);

const queryString = queryParams.join('&');
for (const [k, v] of urlParams.entries()) {
existingParams.append(k, v);
}

return queryString;
realtimeEndpointUrl.search = existingParams.toString();

return realtimeEndpointUrl.toString();
}

private _initializeWebSocketConnection({
Expand Down Expand Up @@ -783,38 +827,27 @@ export class AWSAppSyncRealTimeProvider {
});

const headerString = authHeader ? JSON.stringify(authHeader) : '';
const headerQs = base64Encoder.convert(headerString);
// base64url-encoded string
const encodedHeader = base64Encoder.convert(headerString, {
urlSafe: true,
skipPadding: true,
});

const payloadQs = base64Encoder.convert(payloadString);
const authTokenSubprotocol = `header-${encodedHeader}`;

const queryString = this._queryStringFromCustomHeaders(
const queryParams = this._queryParamsFromCustomHeaders(
additionalCustomHeaders,
);

let discoverableEndpoint = appSyncGraphqlEndpoint ?? '';

if (this.isCustomDomain(discoverableEndpoint)) {
discoverableEndpoint =
discoverableEndpoint.concat(customDomainPath);
} else {
discoverableEndpoint = discoverableEndpoint
.replace('appsync-api', 'appsync-realtime-api')
.replace('gogi-beta', 'grt-beta');
}

// Creating websocket url with required query strings
const protocol = 'wss://';
discoverableEndpoint = discoverableEndpoint
.replace('https://', protocol)
.replace('http://', protocol);

let awsRealTimeUrl = `${discoverableEndpoint}?header=${headerQs}&payload=${payloadQs}`;

if (queryString !== '') {
awsRealTimeUrl += `&${queryString}`;
}
const awsRealTimeUrl = this._realtimeUrlWithQueryString(
appSyncGraphqlEndpoint,
queryParams,
);

await this._initializeRetryableHandshake(awsRealTimeUrl);
await this._initializeRetryableHandshake(
awsRealTimeUrl,
authTokenSubprotocol,
);

this.promiseArray.forEach(({ res }) => {
logger.debug('Notifying connection successful');
Expand All @@ -841,23 +874,37 @@ export class AWSAppSyncRealTimeProvider {
});
}

private async _initializeRetryableHandshake(awsRealTimeUrl: string) {
private async _initializeRetryableHandshake(
awsRealTimeUrl: string,
subprotocol: string,
) {
logger.debug(`Initializaling retryable Handshake`);
await jitteredExponentialRetry(
this._initializeHandshake.bind(this),
[awsRealTimeUrl],
[awsRealTimeUrl, subprotocol],
MAX_DELAY_MS,
);
}

private async _initializeHandshake(awsRealTimeUrl: string) {
/**
*
* @param subprotocol -
*/
private async _initializeHandshake(
awsRealTimeUrl: string,
subprotocol: string,
) {
logger.debug(`Initializing handshake ${awsRealTimeUrl}`);
// Because connecting the socket is async, is waiting until connection is open
// Step 1: connect websocket
try {
await (() => {
return new Promise<void>((resolve, reject) => {
const newSocket = this.getNewWebSocket(awsRealTimeUrl, 'graphql-ws');
const newSocket = this.getNewWebSocket(awsRealTimeUrl, [
'graphql-ws',
subprotocol,
]);

newSocket.onerror = () => {
logger.debug(`WebSocket connection error`);
};
Expand Down
2 changes: 1 addition & 1 deletion packages/aws-amplify/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@
"name": "[API] generateClient (AppSync)",
"path": "./dist/esm/api/index.mjs",
"import": "{ generateClient }",
"limit": "41 kB"
"limit": "41.5 kB"
},
{
"name": "[API] REST API handlers",
Expand Down
8 changes: 8 additions & 0 deletions packages/core/__tests__/utils/convert/base64Encoder.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,12 @@ describe('base64Encoder (non-native)', () => {
'test-test_test',
);
});

it('makes the result a base64url string with no padding chars', () => {
const mockResult = 'test+test/test=='; // = is the base64 padding char
mockBtoa.mockReturnValue(mockResult);
expect(
base64Encoder.convert('test', { urlSafe: true, skipPadding: true }),
).toBe('test-test_test');
});
});
33 changes: 26 additions & 7 deletions packages/core/src/utils/convert/base64/base64Encoder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,37 @@
// SPDX-License-Identifier: Apache-2.0

import { getBtoa } from '../../globalHelpers';
import { Base64Encoder } from '../types';
import type { Base64Encoder, Base64EncoderConvertOptions } from '../types';

import { bytesToString } from './bytesToString';

export const base64Encoder: Base64Encoder = {
convert(input, { urlSafe } = { urlSafe: false }) {
/**
* Convert input to base64-encoded string
* @param input - string to convert to base64
* @param options - encoding options that can optionally produce a base64url string
* @returns base64-encoded string
*/
convert(
input,
options: Base64EncoderConvertOptions = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: options should've been typed by the Base64Encoder

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is, I just wanted it to be more explicit, but I can follow up and remove in another PR if you prefer

urlSafe: false,
skipPadding: false,
},
) {
const inputStr = typeof input === 'string' ? input : bytesToString(input);
const encodedStr = getBtoa()(inputStr);
let encodedStr = getBtoa()(inputStr);

// see details about the char replacing at https://datatracker.ietf.org/doc/html/rfc4648#section-5
return urlSafe
? encodedStr.replace(/\+/g, '-').replace(/\//g, '_')
: encodedStr;
// urlSafe char replacement and skipPadding options conform to the base64url spec
// https://datatracker.ietf.org/doc/html/rfc4648#section-5
if (options.urlSafe) {
encodedStr = encodedStr.replace(/\+/g, '-').replace(/\//g, '_');
}

if (options.skipPadding) {
encodedStr = encodedStr.replace(/=/g, '');
}

return encodedStr;
},
};
1 change: 1 addition & 0 deletions packages/core/src/utils/convert/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

export interface Base64EncoderConvertOptions {
urlSafe: boolean;
skipPadding?: boolean;
Comment on lines 5 to +6
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops, should the urlSafe be an optional property too, otherwise when you specify skipPadding you'd need to pass in a value for urlSafe as well. This is a internal tooling, I think it's safe to make a necessary change.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was intentional. skipPadding would only be valid if urlSafe is set to true.

}

export interface Base64Encoder {
Expand Down
Loading