diff --git a/packages/api/amplify_api/example/integration_test/graphql_tests.dart b/packages/api/amplify_api/example/integration_test/graphql_tests.dart index f1a9a42362..20404fdea3 100644 --- a/packages/api/amplify_api/example/integration_test/graphql_tests.dart +++ b/packages/api/amplify_api/example/integration_test/graphql_tests.dart @@ -529,136 +529,142 @@ void main() { }); }); - group('subscriptions', () { - // Some local helper methods to help with establishing subscriptions and such. - - // Wait for subscription established for given request. - Future>> - _getEstablishedSubscriptionOperation( - GraphQLRequest subscriptionRequest, - void Function(GraphQLResponse) onData) async { - Completer establishedCompleter = Completer(); - final stream = - Amplify.API.subscribe(subscriptionRequest, onEstablished: () { - establishedCompleter.complete(); - }); - final subscription = stream.listen( - onData, - onError: (Object e) => fail('Error in subscription stream: $e'), - ); - - await establishedCompleter.future - .timeout(const Duration(seconds: _subscriptionTimeoutInterval)); - return subscription; - } + group( + 'subscriptions', + () { + // Some local helper methods to help with establishing subscriptions and such. + + // Wait for subscription established for given request. + Future>> + _getEstablishedSubscriptionOperation( + GraphQLRequest subscriptionRequest, + void Function(GraphQLResponse) onData) async { + Completer establishedCompleter = Completer(); + final stream = + Amplify.API.subscribe(subscriptionRequest, onEstablished: () { + establishedCompleter.complete(); + }); + final subscription = stream.listen( + onData, + onError: (Object e) => fail('Error in subscription stream: $e'), + ); + + await establishedCompleter.future + .timeout(const Duration(seconds: _subscriptionTimeoutInterval)); + return subscription; + } - // Establish subscription for request, do the mutationFunction, then wait - // for the stream event, cancel the operation, return response from event. - Future> _establishSubscriptionAndMutate( - GraphQLRequest subscriptionRequest, - Future Function() mutationFunction) async { - Completer> dataCompleter = Completer(); - // With stream established, exec callback with stream events. - final subscription = await _getEstablishedSubscriptionOperation( - subscriptionRequest, (event) { - if (event.hasErrors) { - fail('subscription errors: ${event.errors}'); - } - dataCompleter.complete(event); - }); - await mutationFunction(); - final response = await dataCompleter.future - .timeout((const Duration(seconds: _subscriptionTimeoutInterval))); + // Establish subscription for request, do the mutationFunction, then wait + // for the stream event, cancel the operation, return response from event. + Future> _establishSubscriptionAndMutate( + GraphQLRequest subscriptionRequest, + Future Function() mutationFunction) async { + Completer> dataCompleter = Completer(); + // With stream established, exec callback with stream events. + final subscription = await _getEstablishedSubscriptionOperation( + subscriptionRequest, (event) { + if (event.hasErrors) { + fail('subscription errors: ${event.errors}'); + } + dataCompleter.complete(event); + }); + await mutationFunction(); + final response = await dataCompleter.future + .timeout((const Duration(seconds: _subscriptionTimeoutInterval))); + + await subscription.cancel(); + return response; + } - await subscription.cancel(); - return response; - } + testWidgets( + 'should emit event when onCreate subscription made with model helper', + (WidgetTester tester) async { + String name = + 'Integration Test Blog - subscription create ${UUID.getUUID()}'; + final subscriptionRequest = + ModelSubscriptions.onCreate(Blog.classType); - testWidgets( - 'should emit event when onCreate subscription made with model helper', - (WidgetTester tester) async { - String name = - 'Integration Test Blog - subscription create ${UUID.getUUID()}'; - final subscriptionRequest = ModelSubscriptions.onCreate(Blog.classType); + final eventResponse = await _establishSubscriptionAndMutate( + subscriptionRequest, () => addBlog(name)); + Blog? blogFromEvent = eventResponse.data; - final eventResponse = await _establishSubscriptionAndMutate( - subscriptionRequest, () => addBlog(name)); - Blog? blogFromEvent = eventResponse.data; + expect(blogFromEvent?.name, equals(name)); + }); - expect(blogFromEvent?.name, equals(name)); - }); + testWidgets( + 'should emit event when onUpdate subscription made with model helper', + (WidgetTester tester) async { + const originalName = 'Integration Test Blog - subscription update'; + final updatedName = + 'Integration Test Blog - subscription update, name now ${UUID.getUUID()}'; + Blog blogToUpdate = await addBlog(originalName); + + final subscriptionRequest = + ModelSubscriptions.onUpdate(Blog.classType); + final eventResponse = + await _establishSubscriptionAndMutate(subscriptionRequest, () { + blogToUpdate = blogToUpdate.copyWith(name: updatedName); + final updateReq = ModelMutations.update(blogToUpdate); + return Amplify.API.mutate(request: updateReq).response; + }); + Blog? blogFromEvent = eventResponse.data; + + expect(blogFromEvent?.name, equals(updatedName)); + }); - testWidgets( - 'should emit event when onUpdate subscription made with model helper', - (WidgetTester tester) async { - const originalName = 'Integration Test Blog - subscription update'; - final updatedName = - 'Integration Test Blog - subscription update, name now ${UUID.getUUID()}'; - Blog blogToUpdate = await addBlog(originalName); - - final subscriptionRequest = ModelSubscriptions.onUpdate(Blog.classType); - final eventResponse = - await _establishSubscriptionAndMutate(subscriptionRequest, () { - blogToUpdate = blogToUpdate.copyWith(name: updatedName); - final updateReq = ModelMutations.update(blogToUpdate); - return Amplify.API.mutate(request: updateReq).response; + testWidgets( + 'should emit event when onDelete subscription made with model helper', + (WidgetTester tester) async { + const name = 'Integration Test Blog - subscription delete'; + Blog blogToDelete = await addBlog(name); + + final subscriptionRequest = + ModelSubscriptions.onDelete(Blog.classType); + final eventResponse = + await _establishSubscriptionAndMutate(subscriptionRequest, () { + final deleteReq = + ModelMutations.deleteById(Blog.classType, blogToDelete.id); + return Amplify.API.mutate(request: deleteReq).response; + }); + Blog? blogFromEvent = eventResponse.data; + + expect(blogFromEvent?.name, equals(name)); }); - Blog? blogFromEvent = eventResponse.data; - expect(blogFromEvent?.name, equals(updatedName)); - }); + testWidgets('should cancel subscription', (WidgetTester tester) async { + const name = 'Integration Test Blog - subscription to cancel'; + Blog blogToDelete = await addBlog(name); - testWidgets( - 'should emit event when onDelete subscription made with model helper', - (WidgetTester tester) async { - const name = 'Integration Test Blog - subscription delete'; - Blog blogToDelete = await addBlog(name); + final subReq = ModelSubscriptions.onDelete(Blog.classType); + final subscription = + await _getEstablishedSubscriptionOperation(subReq, (_) { + fail('Subscription event triggered. Should be canceled.'); + }); + await subscription.cancel(); - final subscriptionRequest = ModelSubscriptions.onDelete(Blog.classType); - final eventResponse = - await _establishSubscriptionAndMutate(subscriptionRequest, () { + // delete the blog, wait for update final deleteReq = ModelMutations.deleteById(Blog.classType, blogToDelete.id); - return Amplify.API.mutate(request: deleteReq).response; + await Amplify.API.mutate(request: deleteReq).response; + await Future.delayed(const Duration(seconds: 5)); }); - Blog? blogFromEvent = eventResponse.data; - expect(blogFromEvent?.name, equals(name)); - }); + testWidgets( + 'should emit event when onCreate subscription made with model helper for post (model with parent).', + (WidgetTester tester) async { + String title = + 'Integration Test post - subscription create ${UUID.getUUID()}'; + final subscriptionRequest = + ModelSubscriptions.onCreate(Post.classType); - testWidgets('should cancel subscription', (WidgetTester tester) async { - const name = 'Integration Test Blog - subscription to cancel'; - Blog blogToDelete = await addBlog(name); + final eventResponse = await _establishSubscriptionAndMutate( + subscriptionRequest, + () => addPostAndBlogWithModelHelper(title, 0)); + Post? postFromEvent = eventResponse.data; - final subReq = ModelSubscriptions.onDelete(Blog.classType); - final subscription = - await _getEstablishedSubscriptionOperation(subReq, (_) { - fail('Subscription event triggered. Should be canceled.'); + expect(postFromEvent?.title, equals(title)); }); - await subscription.cancel(); - - // delete the blog, wait for update - final deleteReq = - ModelMutations.deleteById(Blog.classType, blogToDelete.id); - await Amplify.API.mutate(request: deleteReq).response; - await Future.delayed(const Duration(seconds: 5)); - }); - - testWidgets( - 'should emit event when onCreate subscription made with model helper for post (model with parent).', - (WidgetTester tester) async { - String title = - 'Integration Test post - subscription create ${UUID.getUUID()}'; - final subscriptionRequest = ModelSubscriptions.onCreate(Post.classType); - - final eventResponse = await _establishSubscriptionAndMutate( - subscriptionRequest, () => addPostAndBlogWithModelHelper(title, 0)); - Post? postFromEvent = eventResponse.data; - - expect(postFromEvent?.title, equals(title)); - }); - }, - skip: - 'TODO(ragingsquirrel3): re-enable tests once subscriptions are implemented.'); + }, + ); }); } diff --git a/packages/api/amplify_api/example/lib/graphql_api_view.dart b/packages/api/amplify_api/example/lib/graphql_api_view.dart index 6644dad380..fa0f2f345f 100644 --- a/packages/api/amplify_api/example/lib/graphql_api_view.dart +++ b/packages/api/amplify_api/example/lib/graphql_api_view.dart @@ -45,13 +45,19 @@ class _GraphQLApiViewState extends State { onEstablished: () => print('Subscription established'), ); - try { - await for (var event in operation) { - print('Subscription event data received: ${event.data}'); - } - } on Exception catch (e) { - print('Error in subscription stream: $e'); - } + final streamSubscription = operation.listen( + (event) { + final result = 'Subscription event data received: ${event.data}'; + print(result); + setState(() { + _result = result; + }); + }, + onError: (Object error) => print( + 'Error in GraphQL subscription: $error', + ), + ); + _unsubscribe = streamSubscription.cancel; } Future query() async { diff --git a/packages/api/amplify_api/lib/src/api_plugin_impl.dart b/packages/api/amplify_api/lib/src/api_plugin_impl.dart index e353c70a31..66e0d0ca91 100644 --- a/packages/api/amplify_api/lib/src/api_plugin_impl.dart +++ b/packages/api/amplify_api/lib/src/api_plugin_impl.dart @@ -17,6 +17,7 @@ library amplify_api; import 'dart:io'; import 'package:amplify_api/amplify_api.dart'; +import 'package:amplify_api/src/graphql/ws/web_socket_connection.dart'; import 'package:amplify_api/src/native_api_plugin.dart'; import 'package:amplify_core/amplify_core.dart'; import 'package:async/async.dart'; @@ -37,11 +38,16 @@ class AmplifyAPIDart extends AmplifyAPI { late final AWSApiPluginConfig _apiConfig; final http.Client? _baseHttpClient; late final AmplifyAuthProviderRepository _authProviderRepo; + final _logger = AmplifyLogger.category(Category.api); /// A map of the keys from the Amplify API config to HTTP clients to use for /// requests to that endpoint. final Map _clientPool = {}; + /// A map of the keys from the Amplify API config websocket connections to use + /// for that endpoint. + final Map _webSocketConnectionPool = {}; + /// The registered [APIAuthProvider] instances. final Map _authProviders = {}; @@ -123,6 +129,24 @@ class AmplifyAPIDart extends AmplifyAPI { )); } + /// Returns the websocket connection to use for a given endpoint. + /// + /// Use [apiName] if there are multiple endpoints. + @visibleForTesting + WebSocketConnection getWebSocketConnection({String? apiName}) { + final endpoint = _apiConfig.getEndpoint( + type: EndpointType.graphQL, + apiName: apiName, + ); + return _webSocketConnectionPool[endpoint.name] ??= WebSocketConnection( + endpoint.config, + _authProviderRepo, + logger: _logger.createChild( + 'webSocketConnection${endpoint.name}', + ), + ); + } + Uri _getGraphQLUri(String? apiName) { final endpoint = _apiConfig.getEndpoint( type: EndpointType.graphQL, @@ -187,6 +211,15 @@ class AmplifyAPIDart extends AmplifyAPI { return _makeCancelable>(responseFuture); } + @override + Stream> subscribe( + GraphQLRequest request, { + void Function()? onEstablished, + }) { + return getWebSocketConnection(apiName: request.apiName) + .subscribe(request, onEstablished); + } + // ====== REST ======= @override diff --git a/packages/api/amplify_api/lib/src/decorators/web_socket_auth_utils.dart b/packages/api/amplify_api/lib/src/decorators/web_socket_auth_utils.dart new file mode 100644 index 0000000000..d1520c731e --- /dev/null +++ b/packages/api/amplify_api/lib/src/decorators/web_socket_auth_utils.dart @@ -0,0 +1,125 @@ +// Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +@internal +library amplify_api.decorators.web_socket_auth_utils; + +import 'dart:convert'; + +import 'package:amplify_core/amplify_core.dart'; +import 'package:http/http.dart' as http; +import 'package:meta/meta.dart'; + +import '../graphql/ws/web_socket_types.dart'; +import 'authorize_http_request.dart'; + +// Constants for header values as noted in https://docs.aws.amazon.com/appsync/latest/devguide/real-time-websocket-client.html. +const _requiredHeaders = { + AWSHeaders.accept: 'application/json, text/javascript', + AWSHeaders.contentEncoding: 'amz-1.0', + AWSHeaders.contentType: 'application/json; charset=UTF-8', +}; + +// AppSync expects "{}" encoded in the URI as the payload during handshake. +const _emptyBody = '{}'; + +/// Generate a URI for the connection and all subscriptions. +/// +/// See https://docs.aws.amazon.com/appsync/latest/devguide/real-time-websocket-client.html#handshake-details-to-establish-the-websocket-connection= +Future generateConnectionUri( + AWSApiConfig config, AmplifyAuthProviderRepository authRepo) async { + final authorizationHeaders = await _generateAuthorizationHeaders( + config, + isConnectionInit: true, + authRepo: authRepo, + body: _emptyBody, + ); + final encodedAuthHeaders = + base64.encode(json.encode(authorizationHeaders).codeUnits); + final endpointUri = Uri.parse( + config.endpoint.replaceFirst('appsync-api', 'appsync-realtime-api'), + ); + return Uri(scheme: 'wss', host: endpointUri.host, path: 'graphql') + .replace(queryParameters: { + 'header': encodedAuthHeaders, + 'payload': base64.encode(utf8.encode(_emptyBody)), + }); +} + +/// Generate websocket message with authorized payload to register subscription. +/// +/// See https://docs.aws.amazon.com/appsync/latest/devguide/real-time-websocket-client.html#subscription-registration-message +Future + generateSubscriptionRegistrationMessage( + AWSApiConfig config, { + required String id, + required AmplifyAuthProviderRepository authRepo, + required GraphQLRequest request, +}) async { + final body = + jsonEncode({'variables': request.variables, 'query': request.document}); + final authorizationHeaders = await _generateAuthorizationHeaders( + config, + isConnectionInit: false, + authRepo: authRepo, + body: body, + ); + + return WebSocketSubscriptionRegistrationMessage( + id: id, + payload: SubscriptionRegistrationPayload( + request: request, + config: config, + authorizationHeaders: authorizationHeaders, + ), + ); +} + +/// For either connection URI or subscription registration, authorization headers +/// are formatted correctly to be either encoded into URI query params or subscription +/// registration payload headers. +/// +/// If `isConnectionInit` true then headers are formatted like connection URI. +/// Otherwise body will be formatted as subscription registration. This is done by creating +/// a canonical HTTP request that is authorized but never sent. The headers from +/// the HTTP request are reformatted and returned. This logic applies for all auth +/// modes as determined by [authRepo] parameter. +Future> _generateAuthorizationHeaders( + AWSApiConfig config, { + required bool isConnectionInit, + required AmplifyAuthProviderRepository authRepo, + required String body, +}) async { + final endpointHost = Uri.parse(config.endpoint).host; + // Create canonical HTTP request to authorize but never send. + // + // The canonical request URL is a little different depending on if authorizing + // connection URI or start message (subscription registration). + final maybeConnect = isConnectionInit ? '/connect' : ''; + final canonicalHttpRequest = + http.Request('POST', Uri.parse('${config.endpoint}$maybeConnect')); + canonicalHttpRequest.headers.addAll(_requiredHeaders); + canonicalHttpRequest.body = body; + final authorizedHttpRequest = await authorizeHttpRequest( + canonicalHttpRequest, + endpointConfig: config, + authProviderRepo: authRepo, + ); + + // Take authorized HTTP headers as map with "host" value added. + return { + ...authorizedHttpRequest.headers, + AWSHeaders.host: endpointHost, + }; +} diff --git a/packages/api/amplify_api/lib/src/graphql/ws/web_socket_connection.dart b/packages/api/amplify_api/lib/src/graphql/ws/web_socket_connection.dart new file mode 100644 index 0000000000..939aab96b8 --- /dev/null +++ b/packages/api/amplify_api/lib/src/graphql/ws/web_socket_connection.dart @@ -0,0 +1,258 @@ +// Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import 'dart:async'; +import 'dart:convert'; + +import 'package:amplify_api/src/decorators/web_socket_auth_utils.dart'; +import 'package:amplify_core/amplify_core.dart'; +import 'package:async/async.dart'; +import 'package:meta/meta.dart'; +import 'package:web_socket_channel/status.dart' as status; +import 'package:web_socket_channel/web_socket_channel.dart'; + +import 'web_socket_message_stream_transformer.dart'; +import 'web_socket_types.dart'; + +/// 1001, going away +const _defaultCloseStatus = status.goingAway; + +/// {@template amplify_api.ws.web_socket_connection} +/// Manages connection with an AppSync backend and subscription routing. +/// {@endtemplate} +@internal +class WebSocketConnection implements Closeable { + /// Allowed protocols for this connection. + static const webSocketProtocols = ['graphql-ws']; + final AmplifyLogger _logger; + + // Config and auth repo together determine how to authorize connection URLs + // and subscription registration messages. + final AmplifyAuthProviderRepository _authProviderRepo; + final AWSApiConfig _config; + + // Manages all incoming messages from server. Primarily handles messages related + // to the entire connection. E.g. connection_ack, connection_error, ka, error. + // Other events (for single subscriptions) rebroadcast to _rebroadcastController. + WebSocketChannel? _channel; + StreamSubscription? _subscription; + RestartableTimer? _timeoutTimer; + + // Re-broadcasts incoming messages for child streams (single GraphQL subscriptions). + // start_ack, data, error + final StreamController _rebroadcastController = + StreamController.broadcast(); + Stream get _messageStream => _rebroadcastController.stream; + + // Manage initial connection state. + var _initMemo = AsyncMemoizer(); + Completer _connectionReady = Completer(); + + /// Fires when the connection is ready to be listened to after the first + /// `connection_ack` message. + Future get ready => _connectionReady.future; + + /// {@macro amplify_api.ws.web_socket_connection} + WebSocketConnection(this._config, this._authProviderRepo, + {required AmplifyLogger logger}) + : _logger = logger; + + /// Connects _subscription stream to _onData handler. + @visibleForTesting + StreamSubscription getStreamSubscription( + Stream stream) { + return stream.transform(const WebSocketMessageStreamTransformer()).listen( + _onData, + cancelOnError: true, + onError: (Object e) { + _connectionError( + ApiException('Connection failed.', underlyingException: e.toString()), + ); + }, + ); + } + + /// Connects WebSocket channel to _subscription stream but does not send connection + /// init message. + @visibleForTesting + Future connect(Uri connectionUri) async { + _channel = WebSocketChannel.connect( + connectionUri, + protocols: webSocketProtocols, + ); + _subscription = getStreamSubscription(_channel!.stream); + } + + void _connectionError(ApiException exception) { + _connectionReady.completeError(exception); + _channel?.sink.close(); + _resetConnectionInit(); + } + + // Reset connection init variables so it can be re-attempted. + void _resetConnectionInit() { + _initMemo = AsyncMemoizer(); + _connectionReady = Completer(); + } + + /// Closes the WebSocket connection and cleans up local variables. + @override + void close([int closeStatus = _defaultCloseStatus]) { + _logger.verbose('Closing web socket connection.'); + final reason = + closeStatus == _defaultCloseStatus ? 'client closed' : 'unknown'; + _subscription?.cancel(); + _channel?.sink.done.whenComplete(() => _channel = null); + _channel?.sink.close(closeStatus, reason); + _rebroadcastController.close(); + _timeoutTimer?.cancel(); + _resetConnectionInit(); + } + + /// Initializes the connection. + /// + /// Connects to WebSocket, sends connection message and resolves future once + /// connection_ack message received from server. If the connection was previously + /// established then will return previously completed future. + Future init() => _initMemo.runOnce(_init); + + Future _init() async { + final connectionUri = + await generateConnectionUri(_config, _authProviderRepo); + await connect(connectionUri); + + send(WebSocketConnectionInitMessage()); + + return ready; + } + + Future _sendSubscriptionRegistrationMessage( + GraphQLRequest request) async { + await init(); // no-op if already connected + final subscriptionRegistrationMessage = + await generateSubscriptionRegistrationMessage( + _config, + id: request.id, + authRepo: _authProviderRepo, + request: request, + ); + send(subscriptionRegistrationMessage); + } + + /// Subscribes to the given GraphQL request. Returns the subscription object, + /// or throws an [Exception] if there's an error. + Stream> subscribe( + GraphQLRequest request, + void Function()? onEstablished, + ) { + // Create controller for this subscription so we can add errors. + late StreamController> controller; + controller = StreamController>.broadcast( + onCancel: () { + _cancel(request.id); + controller.close(); + }, + ); + + // Filter incoming messages that have the subscription ID and return as new + // stream with messages converted to GraphQLResponse. + _messageStream + .where((msg) => msg.id == request.id) + .transform(WebSocketSubscriptionStreamTransformer( + request, + onEstablished, + logger: _logger, + )) + .listen( + controller.add, + onError: controller.addError, + onDone: controller.close, + cancelOnError: true, + ); + + _sendSubscriptionRegistrationMessage(request) + .catchError(controller.addError); + + return controller.stream; + } + + /// Cancels a subscription. + void _cancel(String subscriptionId) { + _logger.info('Attempting to cancel Operation $subscriptionId'); + send(WebSocketStopMessage(id: subscriptionId)); + // TODO(equartey): if this is the only subscription, close the connection. + } + + /// Serializes a message as JSON string and sends over WebSocket channel. + @visibleForTesting + void send(WebSocketMessage message) { + final msgJson = json.encode(message.toJson()); + _channel!.sink.add(msgJson); + } + + /// Times out the connection (usually if a keep alive has not been received in time). + void _timeout(Duration timeoutDuration) { + _rebroadcastController.addError( + TimeoutException( + 'Connection timeout', + timeoutDuration, + ), + ); + } + + /// Handles incoming data on the WebSocket. + /// + /// Here, handle connection-wide messages and pass subscription events to + /// `_rebroadcastController`. + void _onData(WebSocketMessage message) { + _logger.verbose('websocket received message: ${prettyPrintJson(message)}'); + + switch (message.messageType) { + case MessageType.connectionAck: + final messageAck = message.payload as ConnectionAckMessagePayload; + final timeoutDuration = Duration( + milliseconds: messageAck.connectionTimeoutMs, + ); + _timeoutTimer = RestartableTimer( + timeoutDuration, + () => _timeout(timeoutDuration), + ); + _connectionReady.complete(); + _logger.verbose('Connection established. Registered timer'); + return; + case MessageType.connectionError: + _connectionError(const ApiException( + 'Error occurred while connecting to the websocket')); + return; + case MessageType.keepAlive: + _timeoutTimer?.reset(); + _logger.verbose('Reset timer'); + return; + case MessageType.error: + // Only handle general messages, not subscription-specific ones + if (message.id != null) { + break; + } + final wsError = message.payload as WebSocketError; + _rebroadcastController.addError(wsError); + return; + default: + break; + } + + // Re-broadcast other message types related to single subscriptions. + + _rebroadcastController.add(message); + } +} diff --git a/packages/api/amplify_api/lib/src/graphql/ws/web_socket_message_stream_transformer.dart b/packages/api/amplify_api/lib/src/graphql/ws/web_socket_message_stream_transformer.dart new file mode 100644 index 0000000000..e037bd1ba5 --- /dev/null +++ b/packages/api/amplify_api/lib/src/graphql/ws/web_socket_message_stream_transformer.dart @@ -0,0 +1,94 @@ +// Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +@internal +library amplify_api.graphql.ws.web_socket_message_stream_transformer; + +import 'dart:async'; +import 'dart:convert'; + +import 'package:amplify_api/src/util.dart'; +import 'package:amplify_core/amplify_core.dart'; +import 'package:meta/meta.dart'; + +import '../graphql_response_decoder.dart'; +import 'web_socket_types.dart'; + +/// Top-level transformer. +class WebSocketMessageStreamTransformer + extends StreamTransformerBase { + /// Transforms raw web socket response (String) to `WebSocketMessage` for all input + /// from channel. + const WebSocketMessageStreamTransformer(); + + @override + Stream bind(Stream stream) { + return stream.cast().map>((str) { + return json.decode(str) as Map; + }).map(WebSocketMessage.fromJson); + } +} + +/// Final level of transformation for converting `WebSocketMessage`s to stream +/// of `GraphQLResponse` that is eventually passed to public API `Amplify.API.subscribe`. +class WebSocketSubscriptionStreamTransformer + extends StreamTransformerBase> { + /// request for this stream, needed to properly decode response events + final GraphQLRequest request; + + /// logs complete messages to better provide visibility to cancels + final AmplifyLogger logger; + + /// executes when start_ack message received + final void Function()? onEstablished; + + /// [request] is used to properly decode response events + /// [onEstablished] is executed when start_ack message received + /// [logger] logs cancel messages when complete message received + const WebSocketSubscriptionStreamTransformer( + this.request, + this.onEstablished, { + required this.logger, + }); + + @override + Stream> bind(Stream stream) async* { + await for (var event in stream) { + switch (event.messageType) { + case MessageType.startAck: + onEstablished?.call(); + break; + case MessageType.data: + final payload = event.payload as SubscriptionDataPayload; + // TODO(ragingsquirrel3): refactor decoder + final errors = deserializeGraphQLResponseErrors(payload.toJson()); + yield GraphQLResponseDecoder.instance.decode( + request: request, + data: json.encode(payload.data), + errors: errors, + ); + + break; + case MessageType.error: + final error = event.payload as WebSocketError; + throw error; + case MessageType.complete: + logger.info('Cancel succeeded for Operation: ${event.id}'); + return; + default: + break; + } + } + } +} diff --git a/packages/api/amplify_api/lib/src/graphql/ws/web_socket_types.dart b/packages/api/amplify_api/lib/src/graphql/ws/web_socket_types.dart new file mode 100644 index 0000000000..c957b82641 --- /dev/null +++ b/packages/api/amplify_api/lib/src/graphql/ws/web_socket_types.dart @@ -0,0 +1,228 @@ +// Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ignore_for_file: public_member_api_docs + +@internal +library amplify_api.graphql.ws.web_socket_types; + +import 'dart:convert'; + +import 'package:amplify_core/amplify_core.dart'; +import 'package:json_annotation/json_annotation.dart'; +import 'package:meta/meta.dart'; + +enum MessageType { + @JsonValue('connection_init') + connectionInit('connection_init'), + + @JsonValue('connection_ack') + connectionAck('connection_ack'), + + @JsonValue('connection_error') + connectionError('connection_error'), + + @JsonValue('start') + start('start'), + + @JsonValue('start_ack') + startAck('start_ack'), + + @JsonValue('connection_error') + error('connection_error'), + + @JsonValue('data') + data('data'), + + @JsonValue('stop') + stop('stop'), + + @JsonValue('ka') + keepAlive('ka'), + + @JsonValue('complete') + complete('complete'); + + final String type; + + const MessageType(this.type); + + factory MessageType.fromJson(dynamic json) { + return MessageType.values.firstWhere((el) => json == el.type); + } +} + +@immutable +abstract class WebSocketMessagePayload { + const WebSocketMessagePayload(); + + static const Map + _factories = { + MessageType.connectionAck: ConnectionAckMessagePayload.fromJson, + MessageType.data: SubscriptionDataPayload.fromJson, + MessageType.error: WebSocketError.fromJson, + }; + + static WebSocketMessagePayload? fromJson(Map json, MessageType type) { + return _factories[type]?.call(json); + } + + Map toJson(); + + @override + String toString() => prettyPrintJson(toJson()); +} + +@internal +class ConnectionAckMessagePayload extends WebSocketMessagePayload { + final int connectionTimeoutMs; + + const ConnectionAckMessagePayload(this.connectionTimeoutMs); + + static ConnectionAckMessagePayload fromJson(Map json) { + final connectionTimeoutMs = json['connectionTimeoutMs'] as int; + return ConnectionAckMessagePayload(connectionTimeoutMs); + } + + @override + Map toJson() => { + 'connectionTimeoutMs': connectionTimeoutMs, + }; +} + +class SubscriptionRegistrationPayload extends WebSocketMessagePayload { + final GraphQLRequest request; + final AWSApiConfig config; + final Map authorizationHeaders; + + const SubscriptionRegistrationPayload({ + required this.request, + required this.config, + required this.authorizationHeaders, + }); + + @override + Map toJson() { + return { + 'data': jsonEncode( + {'variables': request.variables, 'query': request.document}), + 'extensions': >{ + 'authorization': authorizationHeaders + } + }; + } +} + +class SubscriptionDataPayload extends WebSocketMessagePayload { + final Map? data; + final Map? errors; + + const SubscriptionDataPayload(this.data, this.errors); + + static SubscriptionDataPayload fromJson(Map json) { + final data = json['data'] as Map?; + final errors = json['errors'] as Map?; + return SubscriptionDataPayload( + data?.cast(), + errors?.cast(), + ); + } + + @override + Map toJson() => { + 'data': data, + 'errors': errors, + }; +} + +class WebSocketError extends WebSocketMessagePayload implements Exception { + final List errors; + + const WebSocketError(this.errors); + + static WebSocketError fromJson(Map json) { + final errors = json['errors'] as List?; + return WebSocketError(errors?.cast() ?? []); + } + + @override + Map toJson() => { + 'errors': errors, + }; +} + +@immutable +class WebSocketMessage { + final String? id; + final MessageType messageType; + final WebSocketMessagePayload? payload; + + WebSocketMessage({ + String? id, + required this.messageType, + this.payload, + }) : id = id ?? uuid(); + + const WebSocketMessage._({ + this.id, + required this.messageType, + this.payload, + }); + + static WebSocketMessage fromJson(Map json) { + final id = json['id'] as String?; + final type = json['type'] as String; + final messageType = MessageType.fromJson(type); + final payloadMap = json['payload'] as Map?; + final payload = payloadMap == null + ? null + : WebSocketMessagePayload.fromJson( + payloadMap, + messageType, + ); + return WebSocketMessage._( + id: id, + messageType: messageType, + payload: payload, + ); + } + + Map toJson() => { + if (id != null) 'id': id, + 'type': messageType.type, + if (payload != null) 'payload': payload?.toJson(), + }; + + @override + String toString() { + return prettyPrintJson(this); + } +} + +class WebSocketConnectionInitMessage extends WebSocketMessage { + WebSocketConnectionInitMessage() + : super(messageType: MessageType.connectionInit); +} + +class WebSocketSubscriptionRegistrationMessage extends WebSocketMessage { + WebSocketSubscriptionRegistrationMessage({ + required String id, + required SubscriptionRegistrationPayload payload, + }) : super(messageType: MessageType.start, payload: payload, id: id); +} + +class WebSocketStopMessage extends WebSocketMessage { + WebSocketStopMessage({required String id}) + : super(messageType: MessageType.stop, id: id); +} diff --git a/packages/api/amplify_api/pubspec.yaml b/packages/api/amplify_api/pubspec.yaml index a4b2121efe..aa5240a437 100644 --- a/packages/api/amplify_api/pubspec.yaml +++ b/packages/api/amplify_api/pubspec.yaml @@ -21,8 +21,11 @@ dependencies: flutter: sdk: flutter http: ^0.13.4 + json_annotation: ^4.6.0 meta: ^1.7.0 plugin_platform_interface: ^2.0.0 + web_socket_channel: ^2.2.0 + dev_dependencies: amplify_lints: diff --git a/packages/api/amplify_api/test/dart_graphql_test.dart b/packages/api/amplify_api/test/dart_graphql_test.dart index 4d9d8ec47f..b37a2611f3 100644 --- a/packages/api/amplify_api/test/dart_graphql_test.dart +++ b/packages/api/amplify_api/test/dart_graphql_test.dart @@ -12,10 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +import 'dart:async'; import 'dart:convert'; import 'package:amplify_api/amplify_api.dart'; import 'package:amplify_api/src/api_plugin_impl.dart'; +import 'package:amplify_api/src/graphql/ws/web_socket_connection.dart'; import 'package:amplify_core/amplify_core.dart'; import 'package:amplify_test/test_models/ModelProvider.dart'; import 'package:collection/collection.dart'; @@ -24,6 +26,7 @@ import 'package:http/http.dart' as http; import 'package:http/testing.dart'; import 'test_data/fake_amplify_configuration.dart'; +import 'util.dart'; final _deepEquals = const DeepCollectionEquality().equals; @@ -107,6 +110,10 @@ class MockAmplifyAPI extends AmplifyAPIDart { return http.Response( json.encode(_expectedQuerySuccessResponseBody), 200); }); + + @override + WebSocketConnection getWebSocketConnection({String? apiName}) => + MockWebSocketConnection(testApiKeyConfig, getTestAuthProviderRepo()); } void main() { @@ -127,7 +134,34 @@ void main() { } } } '''; - final req = GraphQLRequest(document: graphQLDocument, variables: {}); + final req = GraphQLRequest( + document: graphQLDocument, + variables: {}, + ); + + final operation = Amplify.API.query(request: req); + final res = await operation.value; + + final expected = json.encode(_expectedQuerySuccessResponseBody['data']); + + expect(res.data, equals(expected)); + expect(res.errors, equals(null)); + }); + + test('Query returns proper response.data with dynamic type', () async { + String graphQLDocument = ''' query TestQuery { + listBlogs { + items { + id + name + createdAt + } + } + } '''; + final req = GraphQLRequest( + document: graphQLDocument, + variables: {}, + ); final operation = Amplify.API.query(request: req); final res = await operation.value; @@ -147,8 +181,10 @@ void main() { } } '''; final graphQLVariables = {'name': 'Test Blog 1'}; - final req = GraphQLRequest( - document: graphQLDocument, variables: graphQLVariables); + final req = GraphQLRequest( + document: graphQLDocument, + variables: graphQLVariables, + ); final operation = Amplify.API.mutate(request: req); final res = await operation.value; @@ -158,6 +194,33 @@ void main() { expect(res.data, equals(expected)); expect(res.errors, equals(null)); }); + + test('subscribe() should return a subscription stream', () async { + Completer establishedCompleter = Completer(); + Completer dataCompleter = Completer(); + const graphQLDocument = '''subscription MySubscription { + onCreateBlog { + id + name + createdAt + } + }'''; + final subscriptionRequest = + GraphQLRequest(document: graphQLDocument); + final subscription = Amplify.API.subscribe( + subscriptionRequest, + onEstablished: () => establishedCompleter.complete(), + ); + + final streamSub = subscription.listen( + (event) => dataCompleter.complete(event.data), + ); + await expectLater(establishedCompleter.future, completes); + + final subscriptionData = await dataCompleter.future; + expect(subscriptionData, json.encode(mockSubscriptionData)); + streamSub.cancel(); + }); }); group('Model Helpers', () { const blogSelectionSet = @@ -184,12 +247,33 @@ void main() { expect(res.data?.id, _modelQueryId); expect(res.errors, equals(null)); }); + + test('subscribe() should decode model data', () async { + Completer establishedCompleter = Completer(); + final subscriptionRequest = ModelSubscriptions.onCreate(Post.classType); + final subscription = Amplify.API.subscribe( + subscriptionRequest, + onEstablished: () => establishedCompleter.complete(), + ); + await establishedCompleter.future; + + late StreamSubscription> streamSub; + streamSub = subscription.listen( + expectAsync1((event) { + expect(event.data, isA()); + streamSub.cancel(); + }), + ); + }); }); group('Error Handling', () { test('response errors are decoded', () async { String graphQLDocument = ''' TestError '''; - final req = GraphQLRequest(document: graphQLDocument, variables: {}); + final req = GraphQLRequest( + document: graphQLDocument, + variables: {}, + ); final operation = Amplify.API.query(request: req); final res = await operation.value; @@ -209,7 +293,7 @@ void main() { }); test('canceled query request should never resolve', () async { - final req = GraphQLRequest(document: '', variables: {}); + final req = GraphQLRequest(document: '', variables: {}); final operation = Amplify.API.query(request: req); operation.cancel(); operation.then((p0) => fail('Request should have been cancelled.')); @@ -218,7 +302,7 @@ void main() { }); test('canceled mutation request should never resolve', () async { - final req = GraphQLRequest(document: '', variables: {}); + final req = GraphQLRequest(document: '', variables: {}); final operation = Amplify.API.mutate(request: req); operation.cancel(); operation.then((p0) => fail('Request should have been cancelled.')); diff --git a/packages/api/amplify_api/test/util.dart b/packages/api/amplify_api/test/util.dart index cd06f8c13c..7da7c56c1b 100644 --- a/packages/api/amplify_api/test/util.dart +++ b/packages/api/amplify_api/test/util.dart @@ -12,8 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +import 'dart:async'; +import 'dart:convert'; + +import 'package:amplify_api/src/graphql/app_sync_api_key_auth_provider.dart'; +import 'package:amplify_api/src/graphql/ws/web_socket_connection.dart'; +import 'package:amplify_api/src/graphql/ws/web_socket_types.dart'; import 'package:amplify_core/amplify_core.dart'; import 'package:aws_signature_v4/aws_signature_v4.dart'; +import 'package:collection/collection.dart'; import 'package:flutter_test/flutter_test.dart'; import 'package:http/http.dart' as http; @@ -60,3 +67,111 @@ void validateSignedRequest(http.BaseRequest request) { contains('aws-sigv4'), ); } + +const testApiKeyConfig = AWSApiConfig( + endpointType: EndpointType.graphQL, + endpoint: 'https://abc123.appsync-api.us-east-1.amazonaws.com/graphql', + region: 'us-east-1', + authorizationType: APIAuthorizationType.apiKey, + apiKey: 'abc-123', +); + +const expectedApiKeyWebSocketConnectionUrl = + 'wss://abc123.appsync-realtime-api.us-east-1.amazonaws.com/graphql?header=eyJDb250ZW50LVR5cGUiOiJhcHBsaWNhdGlvbi9qc29uOyBjaGFyc2V0PVVURi04IiwiWC1BcGktS2V5IjoiYWJjLTEyMyIsIkFjY2VwdCI6ImFwcGxpY2F0aW9uL2pzb24sIHRleHQvamF2YXNjcmlwdCIsIkNvbnRlbnQtRW5jb2RpbmciOiJhbXotMS4wIiwiSG9zdCI6ImFiYzEyMy5hcHBzeW5jLWFwaS51cy1lYXN0LTEuYW1hem9uYXdzLmNvbSJ9&payload=e30%3D'; + +AmplifyAuthProviderRepository getTestAuthProviderRepo() { + final testAuthProviderRepo = AmplifyAuthProviderRepository(); + testAuthProviderRepo.registerAuthProvider( + APIAuthorizationType.apiKey.authProviderToken, + AppSyncApiKeyAuthProvider(), + ); + + return testAuthProviderRepo; +} + +const mockSubscriptionData = { + 'onCreatePost': { + 'id': '49d54440-cb80-4f20-964b-91c142761e82', + 'title': + 'Integration Test post - subscription create aa779f0d-0c92-4677-af32-e43f71b3eb55', + 'rating': 0, + 'created': null, + 'createdAt': '2022-08-15T18:22:15.410Z', + 'updatedAt': '2022-08-15T18:22:15.410Z', + 'blog': { + 'id': '164bd1f1-544c-40cb-a656-a7563b046e71', + 'name': 'Integration Test Blog with a post - create', + 'createdAt': '2022-08-15T18:22:15.164Z', + 'file': null, + 'files': null, + 'updatedAt': '2022-08-15T18:22:15.164Z' + } + } +}; + +/// Extension of [WebSocketConnection] that stores messages internally instead +/// of sending them. +class MockWebSocketConnection extends WebSocketConnection { + /// Instead of actually connecting, just set the URI here so it can be inspected + /// for testing. + Uri? connectedUri; + + /// Instead of sending messages, they are pushed to end of list so they can be + /// inspected for testing. + final List sentMessages = []; + + MockWebSocketConnection( + AWSApiConfig config, AmplifyAuthProviderRepository authProviderRepo) + : super(config, authProviderRepo, logger: AmplifyLogger()); + + WebSocketMessage? get lastSentMessage => sentMessages.lastOrNull; + + final messageStream = StreamController(); + + @override + Future connect(Uri connectionUri) async { + connectedUri = connectionUri; + + // mock some message responses (acks and mock data) from server + final broadcast = messageStream.stream.asBroadcastStream(); + broadcast.listen((event) { + final eventJson = json.decode(event as String); + final messageFromEvent = WebSocketMessage.fromJson(eventJson as Map); + + // connection_init, respond with connection_ack + final mockResponseMessages = []; + if (messageFromEvent.messageType == MessageType.connectionInit) { + mockResponseMessages.add(WebSocketMessage( + messageType: MessageType.connectionAck, + payload: const ConnectionAckMessagePayload(10000), + )); + // start, respond with start_ack and mock data + } else if (messageFromEvent.messageType == MessageType.start) { + mockResponseMessages.add(WebSocketMessage( + messageType: MessageType.startAck, + id: messageFromEvent.id, + )); + mockResponseMessages.add(WebSocketMessage( + messageType: MessageType.data, + id: messageFromEvent.id, + payload: const SubscriptionDataPayload(mockSubscriptionData, null), + )); + } + + for (var mockMessage in mockResponseMessages) { + messageStream.add(json.encode(mockMessage)); + } + }); + + // ensures connected to _onDone events in parent class + getStreamSubscription(broadcast); + } + + /// Pushes message in sentMessages and adds to stream (to support mocking result). + @override + void send(WebSocketMessage message) { + sentMessages.add(message); + final messageStr = json.encode(message.toJson()); + messageStream.add(messageStr); + } +} diff --git a/packages/api/amplify_api/test/ws/web_socket_auth_utils_test.dart b/packages/api/amplify_api/test/ws/web_socket_auth_utils_test.dart new file mode 100644 index 0000000000..19cb61a647 --- /dev/null +++ b/packages/api/amplify_api/test/ws/web_socket_auth_utils_test.dart @@ -0,0 +1,85 @@ +// Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import 'package:amplify_api/src/decorators/web_socket_auth_utils.dart'; +import 'package:amplify_api/src/graphql/app_sync_api_key_auth_provider.dart'; +import 'package:amplify_api/src/graphql/ws/web_socket_types.dart'; +import 'package:amplify_core/amplify_core.dart'; +import 'package:flutter_test/flutter_test.dart'; + +import '../util.dart'; + +void main() { + TestWidgetsFlutterBinding.ensureInitialized(); + + final authProviderRepo = AmplifyAuthProviderRepository(); + authProviderRepo.registerAuthProvider( + APIAuthorizationType.apiKey.authProviderToken, + AppSyncApiKeyAuthProvider()); + + const graphQLDocument = '''subscription MySubscription { + onCreateBlog { + id + name + createdAt + } + }'''; + final subscriptionRequest = GraphQLRequest(document: graphQLDocument); + + void _assertBasicSubscriptionPayloadHeaders( + SubscriptionRegistrationPayload payload) { + expect( + payload.authorizationHeaders[AWSHeaders.contentType], + 'application/json; charset=UTF-8', + ); + expect( + payload.authorizationHeaders[AWSHeaders.accept], + 'application/json, text/javascript', + ); + expect( + payload.authorizationHeaders[AWSHeaders.host], + 'abc123.appsync-api.us-east-1.amazonaws.com', + ); + } + + group('generateConnectionUri', () { + test('should generate authorized connection URI', () async { + final actualConnectionUri = + await generateConnectionUri(testApiKeyConfig, authProviderRepo); + expect( + actualConnectionUri.toString(), + expectedApiKeyWebSocketConnectionUrl, + ); + }); + }); + + group('generateSubscriptionRegistrationMessage', () { + test('should generate an authorized message', () async { + final authorizedMessage = await generateSubscriptionRegistrationMessage( + testApiKeyConfig, + id: 'abc123', + authRepo: authProviderRepo, + request: subscriptionRequest, + ); + final payload = + authorizedMessage.payload as SubscriptionRegistrationPayload; + + _assertBasicSubscriptionPayloadHeaders(payload); + expect( + payload.authorizationHeaders[xApiKey], + testApiKeyConfig.apiKey, + ); + }); + }); +} diff --git a/packages/api/amplify_api/test/ws/web_socket_connection_test.dart b/packages/api/amplify_api/test/ws/web_socket_connection_test.dart new file mode 100644 index 0000000000..9a1e3e6545 --- /dev/null +++ b/packages/api/amplify_api/test/ws/web_socket_connection_test.dart @@ -0,0 +1,111 @@ +// Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import 'dart:async'; +import 'dart:convert'; + +import 'package:amplify_api/src/graphql/app_sync_api_key_auth_provider.dart'; +import 'package:amplify_api/src/graphql/ws/web_socket_types.dart'; +import 'package:amplify_core/amplify_core.dart'; +import 'package:flutter_test/flutter_test.dart'; + +import '../util.dart'; + +void main() { + TestWidgetsFlutterBinding.ensureInitialized(); + + late MockWebSocketConnection connection; + + const graphQLDocument = '''subscription MySubscription { + onCreateBlog { + id + name + createdAt + } + }'''; + final subscriptionRequest = GraphQLRequest(document: graphQLDocument); + + setUp(() { + connection = MockWebSocketConnection( + testApiKeyConfig, + getTestAuthProviderRepo(), + ); + }); + + group('WebSocketConnection', () { + test( + 'init() should connect with authorized query params in URI and send connection init message', + () async { + await connection.init(); + expectLater(connection.ready, completes); + expect( + connection.connectedUri.toString(), + expectedApiKeyWebSocketConnectionUrl, + ); + expect( + connection.lastSentMessage?.messageType, MessageType.connectionInit); + }); + + test('subscribe() should initialize the connection and call onEstablished', + () async { + connection.subscribe(subscriptionRequest, expectAsync0(() {})); + expectLater(connection.ready, completes); + }); + + test( + 'subscribe() should send SubscriptionRegistrationMessage with authorized payload correctly serialized', + () async { + connection.init(); + await connection.ready; + Completer establishedCompleter = Completer(); + connection.subscribe(subscriptionRequest, () { + establishedCompleter.complete(); + }); + await establishedCompleter.future; + + final lastMessage = connection.lastSentMessage; + expect(lastMessage?.messageType, MessageType.start); + final payloadJson = lastMessage?.payload?.toJson(); + final apiKeyFromPayload = + payloadJson?['extensions']['authorization'][xApiKey]; + expect(apiKeyFromPayload, testApiKeyConfig.apiKey); + }); + + test('subscribe() should return a subscription stream', () async { + final subscription = connection.subscribe( + subscriptionRequest, + null, + ); + + late StreamSubscription> streamSub; + streamSub = subscription.listen( + expectAsync1((event) { + expect(event.data, json.encode(mockSubscriptionData)); + streamSub.cancel(); + }), + ); + }); + + test('cancel() should send a stop message', () async { + Completer dataCompleter = Completer(); + final subscription = connection.subscribe(subscriptionRequest, null); + final streamSub = subscription.listen( + (event) => dataCompleter.complete(event.data), + ); + await dataCompleter.future; + streamSub.cancel(); + expect(connection.lastSentMessage?.messageType, MessageType.stop); + }); + }); +}