diff --git a/firebase_admin/_gapic_utils.py b/firebase_admin/_gapic_utils.py index 3c975808c..99ae6cfaa 100644 --- a/firebase_admin/_gapic_utils.py +++ b/firebase_admin/_gapic_utils.py @@ -17,8 +17,8 @@ import io import socket -import googleapiclient -import httplib2 +import googleapiclient # type: ignore +import httplib2 # type: ignore import requests from firebase_admin import exceptions @@ -92,15 +92,15 @@ def handle_googleapiclient_error(error, message=None, code=None, http_response=N if isinstance(error, socket.timeout) or ( isinstance(error, socket.error) and 'timed out' in str(error)): return exceptions.DeadlineExceededError( - message='Timed out while making an API call: {0}'.format(error), + message=f'Timed out while making an API call: {error}', cause=error) if isinstance(error, httplib2.ServerNotFoundError): return exceptions.UnavailableError( - message='Failed to establish a connection: {0}'.format(error), + message=f'Failed to establish a connection: {error}', cause=error) if not isinstance(error, googleapiclient.errors.HttpError): return exceptions.UnknownError( - message='Unknown error while making a remote service call: {0}'.format(error), + message=f'Unknown error while making a remote service call: {error}', cause=error) if not code: diff --git a/firebase_admin/_http_client.py b/firebase_admin/_http_client.py index d259faddf..e30b421cd 100644 --- a/firebase_admin/_http_client.py +++ b/firebase_admin/_http_client.py @@ -17,9 +17,9 @@ This module provides utilities for making HTTP calls using the requests library. """ -from google.auth import transport +from google.auth import transport # type: ignore import requests -from requests.packages.urllib3.util import retry # pylint: disable=import-error +from requests.packages.urllib3.util import retry # type: ignore # pylint: disable=import-error if hasattr(retry.Retry.DEFAULT, 'allowed_methods'): diff --git a/firebase_admin/_http_client_async.py b/firebase_admin/_http_client_async.py new file mode 100644 index 000000000..e2e8e4549 --- /dev/null +++ b/firebase_admin/_http_client_async.py @@ -0,0 +1,182 @@ +# Copyright 2022 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Internal async HTTP client module. + + This module provides utilities for making async HTTP calls using the aiohttp library. + """ + +import json + +import aiohttp +from aiohttp.client_exceptions import ClientResponseError +from google.auth.transport import _aiohttp_requests # type: ignore +from google.auth.transport._aiohttp_requests import _CombinedResponse # type: ignore + + +DEFAULT_RETRY_ATTEMPTS = 4 +DEFAULT_RETRY_CODES = (500, 503) +DEFAULT_TIMEOUT_SECONDS = 120 + + +class HttpClientAsync: + """Base HTTP client used to make aiohttp calls. + + HttpClientAsync maintains an aiohttp session, and handles request authentication and retries if + necessary. + """ + + def __init__( + self, + credential=None, + session=None, + base_url='', + headers=None, + retry_attempts=DEFAULT_RETRY_ATTEMPTS, + retry_codes=DEFAULT_RETRY_CODES, + timeout=DEFAULT_TIMEOUT_SECONDS + ): + """Creates a new HttpClientAsync instance from the provided arguments. + + If a credential is provided, initializes a new aiohttp client session authorized with it. + If neither a credential nor a session is provided, initializes a new unauthorized client + session. + + Args: + credential: A Google credential that can be used to authenticate requests (optional). + session: A custom aiohttp session (optional). + base_url: A URL prefix to be added to all outgoing requests (optional). + headers: A map of headers to be added to all outgoing requests (optional). + retry_attempts: The maximum number of retries that should be attempeted for a request + (optional). + retry_codes: A list of status codes for which the request retry should be attempted + (optional). + timeout: A request timeout in seconds. Defaults to 120 seconds when not specified. Set to + None to disable timeouts (optional). + """ + if credential: + self._session = _aiohttp_requests.AuthorizedSession( + credential, + max_refresh_attempts=retry_attempts, + refresh_status_codes=retry_codes, + refresh_timeout=timeout + ) + elif session: + self._session = session + else: + self._session = aiohttp.ClientSession() # pylint: disable=redefined-variable-type + + if headers: + self._session.headers.update(headers) + self._base_url = base_url + self._timeout = timeout + + @property + def session(self): + return self._session + + @property + def base_url(self): + return self._base_url + + @property + def timeout(self): + return self._timeout + + async def parse_body(self, resp): + raise NotImplementedError + + async def request(self, method, url, **kwargs): + """Makes an async HTTP call using the aiohttp library. + + This is the sole entry point to the aiohttp library. All other helper methods in this + class call this method to send async HTTP requests out. Refer to + http://docs.python-requests.org/en/master/api/ for more information on supported options + and features. + + Args: + method: HTTP method name as a string (e.g. get, post). + url: URL of the remote endpoint. + **kwargs: An additional set of keyword arguments to be passed into the aiohttp API + (e.g. json, params, timeout). + + Returns: + Response: A ``_CombinedResponse`` wrapped ``ClientResponse`` object. + + Raises: + ClientResponseWithBodyError: Any requests exceptions encountered while making the async + HTTP call. + """ + if 'timeout' not in kwargs: + kwargs['timeout'] = self.timeout + resp = await self._session.request(method, self.base_url + url, **kwargs) + wrapped_resp = _CombinedResponse(resp) + + # Get response content from StreamReader before it is closed by error throw. + resp_content = await wrapped_resp.content() + + # Catch response error and re-release it after appending response body needed to + # determine the underlying reason for the error. + try: + resp.raise_for_status() + except ClientResponseError as err: + raise ClientResponseWithBodyError( + err.request_info, + err.history, + wrapped_resp, + resp_content + ) from err + return wrapped_resp + + async def headers(self, method, url, **kwargs): + resp = await self.request(method, url, **kwargs) + return resp.headers + + async def body_and_response(self, method, url, **kwargs): + resp = await self.request(method, url, **kwargs) + return await self.parse_body(resp), resp + + async def body(self, method, url, **kwargs): + resp = await self.request(method, url, **kwargs) + return await self.parse_body(resp) + + async def headers_and_body(self, method, url, **kwargs): + resp = await self.request(method, url, **kwargs) + return await resp.headers, self.parse_body(resp) + + async def close(self): + if self._session is not None: + await self._session.close() + self._session = None + + +class JsonHttpClientAsync(HttpClientAsync): + """An async HTTP client that parses response messages as JSON.""" + + def __init__(self, **kwargs): + HttpClientAsync.__init__(self, **kwargs) + + async def parse_body(self, resp): + content = await resp.content() + return json.loads(content) + + +class ClientResponseWithBodyError(aiohttp.ClientResponseError): + """A ClientResponseError wrapper to hold the response body of the underlying failed + aiohttp request. + """ + def __init__(self, request_info, history, response, response_content): + super().__init__(request_info, history) + self.response = response + self.response_content = response_content diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index dcfb520d2..6b5523176 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -16,7 +16,7 @@ import json -import google.auth +import google.auth # type: ignore import requests import firebase_admin @@ -89,7 +89,7 @@ def _get_initialized_app(app): return app raise ValueError('Illegal app argument. Argument must be of type ' - ' firebase_admin.App, but given "{0}".'.format(type(app))) + f' firebase_admin.App, but given "{type(app)}".') @@ -125,6 +125,34 @@ def handle_platform_error_from_requests(error, handle_func=None): return exc if exc else _handle_func_requests(error, message, error_dict) +async def handle_platform_error_from_aiohttp(error, handle_func=None): + """Constructs a ``FirebaseError`` from the given requests error. + + This can be used to handle errors returned by Google Cloud Platform (GCP) APIs. + + Args: + error: An error raised by the aiohttp module while making an HTTP call to a GCP API. + handle_func: A function that can be used to handle platform errors in a custom way. When + specified, this function will be called with three arguments. It has the same + signature as ```_handle_func_requests``, but may return ``None``. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code. + """ + if error.response is None: + return handle_requests_error(error) + + response = error.response + content = error.response_content.decode() + status_code = response.status + error_dict, message = _parse_platform_error(content, status_code) + exc = None + if handle_func: + exc = handle_func(error, message, error_dict) + + # TODO: Implement aiohttp version of ``_handle_func_requests``. + return exc if exc else _handle_func_requests(error, message, error_dict) + def handle_operation_error(error): """Constructs a ``FirebaseError`` from the given operation error. @@ -137,7 +165,7 @@ def handle_operation_error(error): """ if not isinstance(error, dict): return exceptions.UnknownError( - message='Unknown error while making a remote service call: {0}'.format(error), + message=f'Unknown error while making a remote service call: {error}', cause=error) rpc_code = error.get('code') @@ -182,15 +210,15 @@ def handle_requests_error(error, message=None, code=None): """ if isinstance(error, requests.exceptions.Timeout): return exceptions.DeadlineExceededError( - message='Timed out while making an API call: {0}'.format(error), + message=f'Timed out while making an API call: {error}', cause=error) if isinstance(error, requests.exceptions.ConnectionError): return exceptions.UnavailableError( - message='Failed to establish a connection: {0}'.format(error), + message=f'Failed to establish a connection: {error}', cause=error) if error.response is None: return exceptions.UnknownError( - message='Unknown error while making a remote service call: {0}'.format(error), + message=f'Unknown error while making a remote service call: {error}', cause=error) if not code: @@ -237,7 +265,7 @@ def _parse_platform_error(content, status_code): error_dict = data.get('error', {}) msg = error_dict.get('message') if not msg: - msg = 'Unexpected HTTP response with status: {0}; body: {1}'.format(status_code, content) + msg = f'Unexpected HTTP response with status: {status_code}; body: {content}' return error_dict, msg diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 46dd7d410..a8c2b89a8 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -16,7 +16,7 @@ import json -from googleapiclient import http +from googleapiclient import http # type: ignore from googleapiclient import _auth import requests @@ -228,7 +228,7 @@ class TopicManagementResponse: def __init__(self, resp): if not isinstance(resp, dict) or 'results' not in resp: - raise ValueError('Unexpected topic management response: {0}.'.format(resp)) + raise ValueError(f'Unexpected topic management response: {resp}.') self._success_count = 0 self._failure_count = 0 self._errors = [] @@ -328,7 +328,7 @@ def __init__(self, app): self._fcm_url = _MessagingService.FCM_URL.format(project_id) self._fcm_headers = { 'X-GOOG-API-FORMAT-VERSION': '2', - 'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__), + 'X-FIREBASE-CLIENT': f'fire-admin-python/{firebase_admin.__version__}', } timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) self._credential = app.credential.get_credential() @@ -407,12 +407,12 @@ def make_topic_management_request(self, tokens, topic, operation): if not isinstance(topic, str) or not topic: raise ValueError('Topic must be a non-empty string.') if not topic.startswith('/topics/'): - topic = '/topics/{0}'.format(topic) + topic = f'/topics/{topic}' data = { 'to': topic, 'registration_tokens': tokens, } - url = '{0}/{1}'.format(_MessagingService.IID_URL, operation) + url = f'{_MessagingService.IID_URL}/{operation}' try: resp = self._client.body( 'post', @@ -458,10 +458,10 @@ def _handle_iid_error(self, error): code = data.get('error') msg = None if code: - msg = 'Error while calling the IID service: {0}'.format(code) + msg = f'Error while calling the IID service: {code}' else: - msg = 'Unexpected HTTP response with status: {0}; body: {1}'.format( - error.response.status_code, error.response.content.decode()) + msg = (f'Unexpected HTTP response with status: {error.response.status_code}; ' + f'body: {error.response.content.decode()}') return _utils.handle_requests_error(error, msg) diff --git a/firebase_admin/messaging_async.py b/firebase_admin/messaging_async.py new file mode 100644 index 000000000..4e3306dbd --- /dev/null +++ b/firebase_admin/messaging_async.py @@ -0,0 +1,289 @@ +# Copyright 2022 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Firebase Cloud Messaging Async module.""" + +import asyncio + +from typing import ( + Optional, + Any, + Type, + List, + Dict +) + +import firebase_admin +from firebase_admin.exceptions import FirebaseError +from firebase_admin import ( + App +) +from firebase_admin.messaging import TopicManagementResponse +from firebase_admin._http_client_async import ( + JsonHttpClientAsync, + ClientResponseWithBodyError, + DEFAULT_TIMEOUT_SECONDS +) +from firebase_admin._messaging_encoder import ( + Message, + MessageEncoder +) +from firebase_admin._messaging_utils import ( + QuotaExceededError, + SenderIdMismatchError, + ThirdPartyAuthError, + UnregisteredError +) +from firebase_admin import _utils + + + +_MESSAGING_ATTRIBUTE = '_messaging_async' + + +__all__: List[str] = [ + 'send', + # 'send_all', + # 'send_multicast', + 'subscribe_to_topic', + 'unsubscribe_from_topic', +] + +# pylint: disable=unsubscriptable-object +# TODO: Remove false positive unsubscriptable-object lint warnings caused by type hints Optional +# type. This is fixed in pylint 2.7.0 but this version introduces new lint rules and requires +# multiple file changes. +def _get_messaging_service(app: Optional[App]) -> "_MessagingServiceAsync": + return _utils.get_app_service(app, _MESSAGING_ATTRIBUTE, _MessagingServiceAsync) + +async def send(message: Message, dry_run: bool = False, app: Optional[App] = None) -> str: + """Sends the given message via Firebase Cloud Messaging (FCM). + + If the ``dry_run`` mode is enabled, the message will not be actually delivered to the + recipients. Instead FCM performs all the usual validations, and emulates the send operation. + + Args: + message: An instance of ``messaging.Message``. + dry_run: A boolean indicating whether to run the operation in dry run mode (optional). + app: An App instance (optional). + + Returns: + string: A message ID string that uniquely identifies the sent message. + + Raises: + FirebaseError: If an error occurs while sending the message to the FCM service. + ValueError: If the input arguments are invalid. + """ + return await _get_messaging_service(app).send(message, dry_run) + +async def subscribe_to_topic( + tokens: List[str], + topic: str, app: Optional[App] = None + ) -> TopicManagementResponse: + """Subscribes a list of registration tokens to an FCM topic. + + Args: + tokens: A non-empty list of device registration tokens. List may not have more than 1000 + elements. + topic: Name of the topic to subscribe to. May contain the ``/topics/`` prefix. + app: An App instance (optional). + + Returns: + TopicManagementResponse: A ``TopicManagementResponse`` instance. + + Raises: + FirebaseError: If an error occurs while communicating with instance ID service. + ValueError: If the input arguments are invalid. + """ + return await _get_messaging_service(app).make_topic_management_request( + tokens, topic, 'iid/v1:batchAdd') + +async def unsubscribe_from_topic( + tokens: List[str], + topic: str, + app: Optional[App] = None + ) -> TopicManagementResponse: + """Unsubscribes a list of registration tokens from an FCM topic. + + Args: + tokens: A non-empty list of device registration tokens. List may not have more than 1000 + elements. + topic: Name of the topic to unsubscribe from. May contain the ``/topics/`` prefix. + app: An App instance (optional). + + Returns: + TopicManagementResponse: A ``TopicManagementResponse`` instance. + + Raises: + FirebaseError: If an error occurs while communicating with instance ID service. + ValueError: If the input arguments are invalid. + """ + return await _get_messaging_service(app).make_topic_management_request( + tokens, topic, 'iid/v1:batchRemove') + + +class _MessagingServiceAsync: + """Service class that implements Firebase Cloud Messaging (FCM) functionality asynchronously.""" + + FCM_URL: str = 'https://fcm.googleapis.com/v1/projects/{0}/messages:send' + FCM_BATCH_URL: str = 'https://fcm.googleapis.com/batch' + IID_URL: str = 'https://iid.googleapis.com' + IID_HEADERS: Dict[str, str] = {'access_token_auth': 'true'} + JSON_ENCODER: MessageEncoder = MessageEncoder() + + FCM_ERROR_TYPES: Dict[str, Type[FirebaseError]] = { + 'APNS_AUTH_ERROR': ThirdPartyAuthError, + 'QUOTA_EXCEEDED': QuotaExceededError, + 'SENDER_ID_MISMATCH': SenderIdMismatchError, + 'THIRD_PARTY_AUTH_ERROR': ThirdPartyAuthError, + 'UNREGISTERED': UnregisteredError, + } + + def __init__(self, app: App) -> None: + project_id = app.project_id + if not project_id: + raise ValueError( + 'Project ID is required to access Cloud Messaging service. Either set the ' + 'projectId option, or use service account credentials. Alternatively, set the ' + 'GOOGLE_CLOUD_PROJECT environment variable.') + self._fcm_url = _MessagingServiceAsync.FCM_URL.format(project_id) + self._fcm_headers = { + 'X-GOOG-API-FORMAT-VERSION': '2', + 'X-FIREBASE-CLIENT': f'fire-admin-python/{firebase_admin.__version__}' + } + timeout = app.options.get('httpTimeout', DEFAULT_TIMEOUT_SECONDS) + self._credential = app.credential.get_credential_async() + self._client = JsonHttpClientAsync(credential=self._credential, timeout=timeout) + self._loop = asyncio.get_event_loop() + + def close(self) -> None: + if self._client is not None: + if self._loop.is_closed(): + self._loop = asyncio.get_event_loop() + self._loop.run_until_complete(self._client.close()) + self._client = None # type: ignore[assignment] + + @classmethod + def encode_message(cls, message: Message) -> Dict[str, Any]: + if not isinstance(message, Message): + raise ValueError('Message must be an instance of messaging.Message class.') + return cls.JSON_ENCODER.default(message) + + async def send(self, message: Message, dry_run: bool = False) -> str: + """Sends the given message to FCM via the FCM v1 API.""" + data = self._message_data(message, dry_run) + try: + resp = await self._client.body( + 'post', + url=self._fcm_url, + headers=self._fcm_headers, + json=data + ) + except ClientResponseWithBodyError as error: + raise await self._handle_fcm_error(error) + else: + return resp['name'] + + async def make_topic_management_request(self, tokens, topic, operation): + """Invokes the IID service for topic management functionality.""" + if isinstance(tokens, str): + tokens = [tokens] + if not isinstance(tokens, list) or not tokens: + raise ValueError('Tokens must be a string or a non-empty list of strings.') + invalid_str = [t for t in tokens if not isinstance(t, str) or not t] + if invalid_str: + raise ValueError('Tokens must be non-empty strings.') + + if not isinstance(topic, str) or not topic: + raise ValueError('Topic must be a non-empty string.') + if not topic.startswith('/topics/'): + topic = f'/topics/{topic}' + data = { + 'to': topic, + 'registration_tokens': tokens, + } + url = f'{_MessagingServiceAsync.IID_URL}/{operation}' + try: + resp = await self._client.body( + 'post', + url=url, + json=data, + headers=_MessagingServiceAsync.IID_HEADERS + ) + except ClientResponseWithBodyError as error: + raise self._handle_iid_error(error) + else: + return TopicManagementResponse(resp) + + def _message_data(self, message: Message, dry_run: bool) -> Dict[str, Any]: + data = {'message': _MessagingServiceAsync.encode_message(message)} + if dry_run: + data['validate_only'] = True # type: ignore[assignment] + return data + + async def _handle_fcm_error(self, error: ClientResponseWithBodyError) -> FirebaseError: + """Handles errors received from the FCM API.""" + return await _utils.handle_platform_error_from_aiohttp( + error, _MessagingServiceAsync._build_fcm_error_aiohttp) + + def _handle_iid_error(self, error: ClientResponseWithBodyError) -> FirebaseError: + """Handles errors received from the Instance ID API.""" + if error.response is None: + raise _utils.handle_requests_error(error) + + data = {} + try: + parsed_body = error.response.json() + if isinstance(parsed_body, dict): + data = parsed_body + except ValueError: + pass + + # IID error response format: {"error": "ErrorCode"} + code = data.get('error') + msg = None + if code: + msg = f'Error while calling the IID service: {code}' + else: + msg = (f'Unexpected HTTP response with status: {error.response.status_code}; ' + f'body: {error.response.content.decode()}') + + return _utils.handle_requests_error(error, msg) + + @classmethod + def _build_fcm_error_aiohttp( + cls, + error: ClientResponseWithBodyError, + message: Message, + error_dict: Dict[Any, Any] + ) -> Optional[FirebaseError]: + """Parses an aiohttp error response from the FCM API and creates a FCM-specific exception if + appropriate.""" + exc_type: Optional[Type[FirebaseError]] = cls._build_fcm_error(error_dict) + return exc_type( # type: ignore[call-arg] + message, + cause=error, + http_response=error.request_info + ) if exc_type else None + + @classmethod + def _build_fcm_error(cls, error_dict: Dict[str, Any]) -> Optional[Type[FirebaseError]]: + if not error_dict: + return None + fcm_code: Optional[str] = None + for detail in error_dict.get('details', []): + if detail.get('@type') == 'type.googleapis.com/google.firebase.fcm.v1.FcmError': + fcm_code = detail.get('errorCode') + break + return _MessagingServiceAsync.FCM_ERROR_TYPES.get(fcm_code) # type: ignore[arg-type] diff --git a/integration/conftest.py b/integration/conftest.py index 169e02d5b..d844d4c6b 100644 --- a/integration/conftest.py +++ b/integration/conftest.py @@ -15,6 +15,7 @@ """pytest configuration and global fixtures for integration tests.""" import json +import asyncio import pytest import firebase_admin @@ -36,7 +37,7 @@ def _get_cert_path(request): def integration_conf(request): cert_path = _get_cert_path(request) - with open(cert_path) as cert: + with open(cert_path, encoding='utf-8') as cert: project_id = json.load(cert).get('project_id') if not project_id: raise ValueError('Failed to determine project ID from service account certificate.') @@ -57,10 +58,12 @@ def default_app(request): """ cred, project_id = integration_conf(request) ops = { - 'databaseURL' : 'https://{0}.firebaseio.com'.format(project_id), - 'storageBucket' : '{0}.appspot.com'.format(project_id) + 'databaseURL' : f'https://{project_id}.firebaseio.com', + 'storageBucket' : f'{project_id}.appspot.com' } - return firebase_admin.initialize_app(cred, ops) + app = firebase_admin.initialize_app(cred, ops) + yield app + firebase_admin.delete_app(app) @pytest.fixture(scope='session') def api_key(request): @@ -68,5 +71,14 @@ def api_key(request): if not path: raise ValueError('API key file not specified. Make sure to specify the "--apikey" ' 'command-line option.') - with open(path) as keyfile: + with open(path, encoding='utf-8') as keyfile: return keyfile.read().strip() + +@pytest.fixture(scope="session") +def event_loop(): + """Create an instance of the default event loop for test session. + This avoids early eventloop closure. + """ + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() diff --git a/integration/test_messaging_async.py b/integration/test_messaging_async.py new file mode 100644 index 000000000..bb58f82bd --- /dev/null +++ b/integration/test_messaging_async.py @@ -0,0 +1,102 @@ +# Copyright 2022 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for firebase_admin.messaging module.""" + +import re +from datetime import datetime + +import pytest + +from firebase_admin import ( + exceptions, + messaging, + messaging_async, +) + + +_REGISTRATION_TOKEN = ('fGw0qy4TGgk:APA91bGtWGjuhp4WRhHXgbabIYp1jxEKI08ofj_v1bKhWAGJQ4e3arRCWzeTf' + 'HaLz83mBnDh0aPWB1AykXAVUUGl2h1wT4XI6XazWpvY7RBUSYfoxtqSWGIm2nvWh2BOP1YG50' + '1SsRoE') + +@pytest.mark.asyncio +async def test_send(): + msg = messaging.Message( + topic='foo-bar', + notification=messaging.Notification('test-title', 'test-body', + 'https://images.unsplash.com/photo-1494438639946' + '-1ebd1d20bf85?fit=crop&w=900&q=60'), + android=messaging.AndroidConfig( + restricted_package_name='com.google.firebase.demos', + notification=messaging.AndroidNotification( + title='android-title', + body='android-body', + image='https://images.unsplash.com/' + 'photo-1494438639946-1ebd1d20bf85?fit=crop&w=900&q=60', + event_timestamp=datetime.now(), + priority='high', + vibrate_timings_millis=[100, 200, 300, 400], + visibility='public', + sticky=True, + local_only=False, + default_vibrate_timings=False, + default_sound=True, + default_light_settings=False, + light_settings=messaging.LightSettings( + color='#aabbcc', + light_off_duration_millis=200, + light_on_duration_millis=300 + ), + notification_count=1 + ) + ), + apns=messaging.APNSConfig(payload=messaging.APNSPayload( + aps=messaging.Aps( + alert=messaging.ApsAlert( + title='apns-title', + body='apns-body' + ) + ) + )) + ) + msg_id = await messaging_async.send(msg, dry_run=True) + assert re.match('^projects/.*/messages/.*$', msg_id) + +@pytest.mark.asyncio +async def test_send_invalid_token(): + msg = messaging.Message( + token=_REGISTRATION_TOKEN, + notification=messaging.Notification('test-title', 'test-body') + ) + with pytest.raises(messaging.UnregisteredError): + await messaging_async.send(msg, dry_run=True) + +@pytest.mark.asyncio +async def test_send_malformed_token(): + msg = messaging.Message( + token='not-a-token', + notification=messaging.Notification('test-title', 'test-body') + ) + with pytest.raises(exceptions.InvalidArgumentError): + await messaging_async.send(msg, dry_run=True) + +@pytest.mark.asyncio +async def test_subscribe(): + resp = await messaging_async.subscribe_to_topic(_REGISTRATION_TOKEN, 'mock-topic') + assert resp.success_count + resp.failure_count == 1 + +@pytest.mark.asyncio +async def test_unsubscribe(): + resp = await messaging_async.unsubscribe_from_topic(_REGISTRATION_TOKEN, 'mock-topic') + assert resp.success_count + resp.failure_count == 1 diff --git a/tests/test_http_client_async.py b/tests/test_http_client_async.py new file mode 100644 index 000000000..8719f4909 --- /dev/null +++ b/tests/test_http_client_async.py @@ -0,0 +1,180 @@ +# Copyright 2022 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for firebase_admin._http_client_async.""" +from __future__ import absolute_import +import asyncio + +import aiohttp +import pytest +from pytest_localserver import http + +from firebase_admin import _http_client_async +from tests import testutils + + +_TEST_URL = 'http://firebase.test.url/' + +def make_mock_client_session(payload='body', status=200): + recorder = [] + session = testutils.MockClientSession(payload, status, recorder) + return session, recorder + +def make_mock_authorized_session(credentials, payload='body', status=200): + recorder = [] + session = testutils.MockAuthorizedSession(payload, status, recorder, credentials) + return session, recorder + +class TestHttpClient: + def seutp_method(self): + self.client = None + + def teardown_method(self): + if self.client is not None: + asyncio.get_event_loop().run_until_complete(self.client.close()) + + @pytest.mark.asyncio + async def test_http_client_default_session(self): + self.client = _http_client_async.HttpClientAsync() + assert self.client.session is not None + assert isinstance(self.client.session, aiohttp.ClientSession) + assert self.client.base_url == '' + + @pytest.mark.asyncio + async def test_http_client_custom_session(self): + session, recorder = make_mock_client_session() + self.client = _http_client_async.HttpClientAsync(session=session) + assert self.client.session is session + assert self.client.base_url == '' + resp = await self.client.request('GET', _TEST_URL) + assert resp.status == 200 + content = await resp.content() + assert content.decode() == 'body' + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == _TEST_URL + + @pytest.mark.asyncio + async def test_base_url(self): + session, recorder = make_mock_client_session() + self.client = _http_client_async.HttpClientAsync(base_url=_TEST_URL, session=session) + assert self.client.session is not None + assert self.client.base_url == _TEST_URL + resp = await self.client.request('GET', 'foo') + assert resp.status == 200 + content = await resp.content() + assert content.decode() == 'body' + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == _TEST_URL + 'foo' + + @pytest.mark.asyncio + async def test_credential_async(self): + credential = testutils.MockGoogleCredentialAsync() + self.client = _http_client_async.HttpClientAsync( + credential=credential) + assert self.client.session is not None + session, recorder = make_mock_authorized_session(credential) + self.client._session = session + resp = await self.client.request('GET', _TEST_URL) + assert resp.status == 200 + content = await resp.content() + assert content.decode() == 'body' + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == _TEST_URL + assert recorder[0].extra_kwargs['headers']['authorization'] == 'Bearer mock-token' + + @pytest.mark.asyncio + @pytest.mark.parametrize('options, timeout', [ + ({}, _http_client_async.DEFAULT_TIMEOUT_SECONDS), + ({'timeout': 7}, 7), + ({'timeout': 0}, 0), + ({'timeout': None}, None), + ]) + async def test_timeout(self, options, timeout): + session, recorder = make_mock_client_session() + self.client = _http_client_async.HttpClientAsync(**options, session=session) + assert self.client.timeout == timeout + await self.client.request('get', _TEST_URL) + assert len(recorder) == 1 + if timeout is None: + assert recorder[0].extra_kwargs['timeout'] is None + else: + assert recorder[0].extra_kwargs['timeout'] == pytest.approx(timeout, 0.001) + + +class TestHttpRetry: + """Unit tests for the default HTTP retry configuration.""" + + ENTITY_ENCLOSING_METHODS = ['post', 'put', 'patch'] + ALL_METHODS = ENTITY_ENCLOSING_METHODS + ['get', 'delete', 'head', 'options'] + + @classmethod + def setup_class(cls): + # Start a test server instance scoped to the class. + server = http.ContentServer() + server.start() + cls.httpserver = server + + @classmethod + def teardown_class(cls): + cls.httpserver.stop() + + def setup_method(self): + self.client = None + # Clean up any state in the server before starting a new test case. + self.httpserver.requests = [] + + def teardown_method(self): + if self.client is not None: + asyncio.get_event_loop().run_until_complete(self.client.close()) + + @pytest.mark.asyncio + @pytest.mark.parametrize('method', ALL_METHODS) + async def test_retry_on_503(self, method): + self.httpserver.serve_content({}, 503) + self.client = _http_client_async.JsonHttpClientAsync( + credential=testutils.MockGoogleCredentialAsync(), base_url=self.httpserver.url) + body = None + if method in self.ENTITY_ENCLOSING_METHODS: + body = {'key': 'value'} + with pytest.raises(aiohttp.ClientError) as excinfo: + await self.client.request(method, '/', json=body) + assert excinfo.value.response.status == 503 + assert len(self.httpserver.requests) == 5 + + @pytest.mark.asyncio + @pytest.mark.parametrize('method', ALL_METHODS) + async def test_retry_on_500(self, method): + self.httpserver.serve_content({}, 500) + self.client = _http_client_async.JsonHttpClientAsync( + credential=testutils.MockGoogleCredentialAsync(), base_url=self.httpserver.url) + body = None + if method in self.ENTITY_ENCLOSING_METHODS: + body = {'key': 'value'} + with pytest.raises(aiohttp.ClientError) as excinfo: + await self.client.request(method, '/', json=body) + assert excinfo.value.response.status == 500 + assert len(self.httpserver.requests) == 5 + + @pytest.mark.asyncio + async def test_no_retry_on_404(self): + self.httpserver.serve_content({}, 404) + self.client = _http_client_async.JsonHttpClientAsync( + credential=testutils.MockGoogleCredentialAsync(), base_url=self.httpserver.url) + with pytest.raises(aiohttp.ClientError) as excinfo: + await self.client.request('get', '/') + assert excinfo.value.response.status == 404 + assert len(self.httpserver.requests) == 1 diff --git a/tests/test_messaging_async.py b/tests/test_messaging_async.py new file mode 100644 index 000000000..199a6584b --- /dev/null +++ b/tests/test_messaging_async.py @@ -0,0 +1,440 @@ +# Copyright 2022 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test cases for the firebase_admin.messaging module.""" +import json + +import pytest + +import firebase_admin +from firebase_admin import exceptions +from firebase_admin import messaging +from firebase_admin import messaging_async +from firebase_admin import _http_client_async +from tests import testutils + + +NON_STRING_ARGS = [[], tuple(), {}, True, False, 1, 0] +NON_DICT_ARGS = ['', [], tuple(), True, False, 1, 0, {1: 'foo'}, {'foo': 1}] +NON_OBJECT_ARGS = [[], tuple(), {}, 'foo', 0, 1, True, False] +NON_LIST_ARGS = ['', tuple(), {}, True, False, 1, 0, [1], ['foo', 1]] +NON_UINT_ARGS = ['1.23s', [], tuple(), {}, -1.23] +HTTP_ERROR_CODES = { + 400: exceptions.InvalidArgumentError, + 403: exceptions.PermissionDeniedError, + 404: exceptions.NotFoundError, + 500: exceptions.InternalError, + 503: exceptions.UnavailableError, +} +FCM_ERROR_CODES = { + 'APNS_AUTH_ERROR': messaging.ThirdPartyAuthError, + 'QUOTA_EXCEEDED': messaging.QuotaExceededError, + 'SENDER_ID_MISMATCH': messaging.SenderIdMismatchError, + 'THIRD_PARTY_AUTH_ERROR': messaging.ThirdPartyAuthError, + 'UNREGISTERED': messaging.UnregisteredError, +} + + +def check_exception(exception, message, status): + assert isinstance(exception, exceptions.FirebaseError) + assert str(exception) == message + assert exception.cause is not None + assert exception.http_response is not None + assert exception.http_response.status_code == status + + +class TestTimeoutAsync: + + def teardown(self): + testutils.cleanup_apps() + + def _instrument_service(self, response): + app = firebase_admin.get_app() + fcm_service_async = messaging_async._get_messaging_service(app) + recorder = [] + credentials = fcm_service_async._client.session.credentials + session = testutils.MockAuthorizedSession(json.dumps(response), 200, recorder, credentials) + fcm_service_async._client._session = session + return recorder + + def _check_timeout(self, recorder, timeout): + assert len(recorder) == 1 + if timeout is None: + assert recorder[0].extra_kwargs['timeout'] is None + else: + assert recorder[0].extra_kwargs['timeout'] == pytest.approx(timeout, 0.001) + + @pytest.mark.parametrize('options, timeout', [ + ({'httpTimeout': 4}, 4), + ({'httpTimeout': None}, None), + ({}, _http_client_async.DEFAULT_TIMEOUT_SECONDS), + ]) + @pytest.mark.asyncio + async def test_send_async(self, options, timeout): + cred = testutils.MockCredentialAsync() + all_options = {'projectId': 'explicit-project-id'} + all_options.update(options) + firebase_admin.initialize_app(cred, all_options) + recorder = self._instrument_service({'name': 'message-id'}) + msg = messaging.Message(topic='foo') + await messaging_async.send(msg) + self._check_timeout(recorder, timeout) + + @pytest.mark.parametrize('options, timeout', [ + ({'httpTimeout': 4}, 4), + ({'httpTimeout': None}, None), + ({}, _http_client_async.DEFAULT_TIMEOUT_SECONDS), + ]) + @pytest.mark.asyncio + async def test_topic_management_custom_timeout(self, options, timeout): + cred = testutils.MockCredentialAsync() + all_options = {'projectId': 'explicit-project-id'} + all_options.update(options) + firebase_admin.initialize_app(cred, all_options) + recorder = self._instrument_service({'results': [{}, {'error': 'error_reason'}]}) + await messaging_async.subscribe_to_topic(['1'], 'a') + self._check_timeout(recorder, timeout) + + +class TestSendAsync: + + _DEFAULT_RESPONSE = json.dumps({'name': 'message-id'}) + _CLIENT_VERSION = f'fire-admin-python/{firebase_admin.__version__}' + + def setup(self): + cred = testutils.MockCredentialAsync() + firebase_admin.initialize_app(cred, {'projectId': 'explicit-project-id'}) + + def teardown(self): + testutils.cleanup_apps() + + def _instrument_messaging_service(self, app=None, status=200, payload=_DEFAULT_RESPONSE): + if not app: + app = firebase_admin.get_app() + fcm_service_async = messaging_async._get_messaging_service(app) + recorder = [] + + credentials = fcm_service_async._client.session.credentials + session = testutils.MockAuthorizedSession(payload, status, recorder, credentials) + fcm_service_async._client._session = session + + return fcm_service_async, recorder + + def _get_url(self, project_id): + return messaging_async._MessagingServiceAsync.FCM_URL.format(project_id) + + @pytest.mark.asyncio + async def test_no_project_id(self): + async def evaluate(): + app = firebase_admin.initialize_app( + testutils.MockCredentialAsync(), + name='no_project_id' + ) + with pytest.raises(ValueError): + await messaging_async.send(messaging.Message(topic='foo'), app=app) + await testutils.run_without_project_id_async(evaluate) + + @pytest.mark.parametrize('msg', NON_OBJECT_ARGS + [None]) + @pytest.mark.asyncio + async def test_invalid_send(self, msg): + with pytest.raises(ValueError) as excinfo: + await messaging_async.send(msg) + assert str(excinfo.value) == 'Message must be an instance of messaging.Message class.' + + @pytest.mark.asyncio + async def test_send_dry_run(self): + _, recorder = self._instrument_messaging_service() + msg = messaging.Message(topic='foo') + msg_id = await messaging_async.send(msg, dry_run=True) + assert msg_id == 'message-id' + assert len(recorder) == 1 + assert recorder[0].method == 'post' + assert recorder[0].url == self._get_url('explicit-project-id') + assert recorder[0].extra_kwargs['headers']['X-GOOG-API-FORMAT-VERSION'] == '2' + assert recorder[0].extra_kwargs['headers']['X-FIREBASE-CLIENT'] == self._CLIENT_VERSION + body = { + 'message': messaging_async._MessagingServiceAsync.encode_message(msg), + 'validate_only': True, + } + assert recorder[0].extra_kwargs['json'] == body + + @pytest.mark.asyncio + async def test_send(self): + _, recorder = self._instrument_messaging_service() + msg = messaging.Message(topic='foo') + msg_id = await messaging_async.send(msg) + assert msg_id == 'message-id' + assert len(recorder) == 1 + assert recorder[0].method == 'post' + assert recorder[0].url == self._get_url('explicit-project-id') + assert recorder[0].extra_kwargs['headers']['X-GOOG-API-FORMAT-VERSION'] == '2' + assert recorder[0].extra_kwargs['headers']['X-FIREBASE-CLIENT'] == self._CLIENT_VERSION + body = {'message': messaging_async._MessagingServiceAsync.encode_message(msg)} + assert recorder[0].extra_kwargs['json'] == body + + # # TODO: Implement remainding FCM error handling for aiohttp requests + # @pytest.mark.parametrize('status,exc_type', HTTP_ERROR_CODES.items()) + # @pytest.mark.asyncio + # async def test_send_error(self, status, exc_type): + # _, recorder = self._instrument_messaging_service(status=status, payload='{}') + # msg = messaging.Message(topic='foo') + # with pytest.raises(exc_type) as excinfo: + # await messaging_async.send(msg) + # expected = f'Unexpected HTTP response with status: {status}; body: {{}}' + # check_exception(excinfo.value, expected, status) + # assert len(recorder) == 1 + # assert recorder[0].method == 'POST' + # assert recorder[0].url == self._get_url('explicit-project-id') + # assert recorder[0].extra_kwargs['headers']['X-GOOG-API-FORMAT-VERSION'] == '2' + # assert recorder[0].extra_kwargs['headers']['X-FIREBASE-CLIENT'] == self._CLIENT_VERSION + # body = {'message': messaging_async._MessagingServiceAsync.JSON_ENCODER.default(msg)} + # assert recorder[0].extra_kwargs['json'] == body + + # @pytest.mark.parametrize('status', HTTP_ERROR_CODES) + # @pytest.mark.asyncio + # async def test_send_detailed_error(self, status): + # payload = json.dumps({ + # 'error': { + # 'status': 'INVALID_ARGUMENT', + # 'message': 'test error' + # } + # }) + # _, recorder = self._instrument_messaging_service(status=status, payload=payload) + # msg = messaging.Message(topic='foo') + # with pytest.raises(exceptions.InvalidArgumentError) as excinfo: + # await messaging_async.send(msg) + # check_exception(excinfo.value, 'test error', status) + # assert len(recorder) == 1 + # assert recorder[0].method == 'post' + # assert recorder[0].url == self._get_url('explicit-project-id') + # body = {'message': messaging_async._MessagingServiceAsync.JSON_ENCODER.default(msg)} + # assert recorder[0].extra_kwargs['json'] == body + + # @pytest.mark.parametrize('status', HTTP_ERROR_CODES) + # @pytest.mark.asyncio + # async def test_send_canonical_error_code(self, status): + # payload = json.dumps({ + # 'error': { + # 'status': 'NOT_FOUND', + # 'message': 'test error' + # } + # }) + # _, recorder = self._instrument_messaging_service(status=status, payload=payload) + # msg = messaging.Message(topic='foo') + # with pytest.raises(exceptions.NotFoundError) as excinfo: + # await messaging_async.send(msg) + # check_exception(excinfo.value, 'test error', status) + # assert len(recorder) == 1 + # assert recorder[0].method == 'post' + # assert recorder[0].url == self._get_url('explicit-project-id') + # body = {'message': messaging_async._MessagingServiceAsync.JSON_ENCODER.default(msg)} + # assert recorder[0].extra_kwargs['json'] == body + + # @pytest.mark.parametrize('status', HTTP_ERROR_CODES) + # @pytest.mark.parametrize('fcm_error_code, exc_type', FCM_ERROR_CODES.items()) + # @pytest.mark.asyncio + # async def test_send_fcm_error_code(self, status, fcm_error_code, exc_type): + # payload = json.dumps({ + # 'error': { + # 'status': 'INVALID_ARGUMENT', + # 'message': 'test error', + # 'details': [ + # { + # '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + # 'errorCode': fcm_error_code, + # }, + # ], + # } + # }) + # _, recorder = self._instrument_messaging_service(status=status, payload=payload) + # msg = messaging.Message(topic='foo') + # with pytest.raises(exc_type) as excinfo: + # await messaging_async.send(msg) + # check_exception(excinfo.value, 'test error', status) + # assert len(recorder) == 1 + # assert recorder[0].method == 'post' + # assert recorder[0].url == self._get_url('explicit-project-id') + # body = {'message': messaging_async._MessagingServiceAsync.JSON_ENCODER.default(msg)} + # assert recorder[0].extra_kwargs['json'] == body + + # @pytest.mark.parametrize('status', HTTP_ERROR_CODES) + # @pytest.mark.asyncio + # async def test_send_unknown_fcm_error_code(self, status): + # payload = json.dumps({ + # 'error': { + # 'status': 'INVALID_ARGUMENT', + # 'message': 'test error', + # 'details': [ + # { + # '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + # 'errorCode': 'SOME_UNKNOWN_CODE', + # }, + # ], + # } + # }) + # _, recorder = self._instrument_messaging_service(status=status, payload=payload) + # msg = messaging.Message(topic='foo') + # with pytest.raises(exceptions.InvalidArgumentError) as excinfo: + # await messaging_async.send(msg) + # check_exception(excinfo.value, 'test error', status) + # assert len(recorder) == 1 + # assert recorder[0].method == 'post' + # assert recorder[0].url == self._get_url('explicit-project-id') + # body = {'message': messaging_async._MessagingServiceAsync.JSON_ENCODER.default(msg)} + # assert recorder[0].extra_kwargs['json'] == body + + +class TestTopicManagementAsync: + + _DEFAULT_RESPONSE = json.dumps({'results': [{}, {'error': 'error_reason'}]}) + _DEFAULT_ERROR_RESPONSE = json.dumps({'error': 'error_reason'}) + _VALID_ARGS = [ + # (tokens, topic, expected) + ( + ['foo', 'bar'], + 'test-topic', + {'to': '/topics/test-topic', 'registration_tokens': ['foo', 'bar']} + ), + ( + 'foo', + '/topics/test-topic', + {'to': '/topics/test-topic', 'registration_tokens': ['foo']} + ), + ] + + def setup(self): + cred = testutils.MockCredentialAsync() + firebase_admin.initialize_app(cred, {'projectId': 'explicit-project-id'}) + + def teardown(self): + testutils.cleanup_apps() + + def _instrument_iid_service(self, app=None, status=200, payload=_DEFAULT_RESPONSE): + if not app: + app = firebase_admin.get_app() + fcm_service_async = messaging_async._get_messaging_service(app) + recorder = [] + + credentials = fcm_service_async._client.session.credentials + session = testutils.MockAuthorizedSession(payload, status, recorder, credentials) + fcm_service_async._client._session = session + + return fcm_service_async, recorder + + def _get_url(self, path): + return f'{messaging_async._MessagingServiceAsync.IID_URL}/{path}' + + @pytest.mark.parametrize('tokens', [None, '', [], {}, tuple()]) + @pytest.mark.asyncio + async def test_invalid_tokens(self, tokens): + expected = 'Tokens must be a string or a non-empty list of strings.' + if isinstance(tokens, str): + expected = 'Tokens must be non-empty strings.' + + with pytest.raises(ValueError) as excinfo: + await messaging_async.subscribe_to_topic(tokens, 'test-topic') + assert str(excinfo.value) == expected + + with pytest.raises(ValueError) as excinfo: + await messaging_async.unsubscribe_from_topic(tokens, 'test-topic') + assert str(excinfo.value) == expected + + @pytest.mark.parametrize('topic', NON_STRING_ARGS + [None, '']) + @pytest.mark.asyncio + async def test_invalid_topic(self, topic): + expected = 'Topic must be a non-empty string.' + with pytest.raises(ValueError) as excinfo: + await messaging_async.subscribe_to_topic('test-token', topic) + assert str(excinfo.value) == expected + + with pytest.raises(ValueError) as excinfo: + await messaging_async.unsubscribe_from_topic('test-tokens', topic) + assert str(excinfo.value) == expected + + @pytest.mark.parametrize('args', _VALID_ARGS) + @pytest.mark.asyncio + async def test_subscribe_to_topic(self, args): + _, recorder = self._instrument_iid_service() + resp = await messaging_async.subscribe_to_topic(args[0], args[1]) + self._check_response(resp) + assert len(recorder) == 1 + assert recorder[0].method == 'post' + assert recorder[0].url == self._get_url('iid/v1:batchAdd') + assert recorder[0].extra_kwargs['json'] == args[2] + + # @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) + # @pytest.mark.asyncio + # async def test_subscribe_to_topic_error(self, status, exc_type): + # _, recorder = self._instrument_iid_service( + # status=status, payload=self._DEFAULT_ERROR_RESPONSE) + # with pytest.raises(exc_type) as excinfo: + # await messaging_async.subscribe_to_topic('foo', 'test-topic') + # assert str(excinfo.value) == 'Error while calling the IID service: error_reason' + # assert len(recorder) == 1 + # assert recorder[0].method == 'POST' + # assert recorder[0].url == self._get_url('iid/v1:batchAdd') + + # @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) + # @pytest.mark.asyncio + # async def test_subscribe_to_topic_non_json_error(self, status, exc_type): + # _, recorder = self._instrument_iid_service(status=status, payload='not json') + # with pytest.raises(exc_type) as excinfo: + # await messaging_async.subscribe_to_topic('foo', 'test-topic') + # reason = f'Unexpected HTTP response with status: {status}; body: not json' + # assert str(excinfo.value) == reason + # assert len(recorder) == 1 + # assert recorder[0].method == 'POST' + # assert recorder[0].url == self._get_url('iid/v1:batchAdd') + + @pytest.mark.parametrize('args', _VALID_ARGS) + @pytest.mark.asyncio + async def test_unsubscribe_from_topic(self, args): + _, recorder = self._instrument_iid_service() + resp = await messaging_async.unsubscribe_from_topic(args[0], args[1]) + self._check_response(resp) + assert len(recorder) == 1 + assert recorder[0].method == 'post' + assert recorder[0].url == self._get_url('iid/v1:batchRemove') + assert recorder[0].extra_kwargs['json'] == args[2] + + # @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) + # @pytest.mark.asyncio + # async def test_unsubscribe_from_topic_error(self, status, exc_type): + # _, recorder = self._instrument_iid_service( + # status=status, payload=self._DEFAULT_ERROR_RESPONSE) + # with pytest.raises(exc_type) as excinfo: + # await messaging_async.unsubscribe_from_topic('foo', 'test-topic') + # assert str(excinfo.value) == 'Error while calling the IID service: error_reason' + # assert len(recorder) == 1 + # assert recorder[0].method == 'POST' + # assert recorder[0].url == self._get_url('iid/v1:batchRemove') + + # @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) + # @pytest.mark.asyncio + # async def test_unsubscribe_from_topic_non_json_error(self, status, exc_type): + # _, recorder = self._instrument_iid_service(status=status, payload='not json') + # with pytest.raises(exc_type) as excinfo: + # await messaging_async.unsubscribe_from_topic('foo', 'test-topic') + # reason = f'Unexpected HTTP response with status: {status}; body: not json' + # assert str(excinfo.value) == reason + # assert len(recorder) == 1 + # assert recorder[0].method == 'POST' + # assert recorder[0].url == self._get_url('iid/v1:batchRemove') + + def _check_response(self, resp): + assert resp.success_count == 1 + assert resp.failure_count == 1 + assert len(resp.errors) == 1 + assert resp.errors[0].index == 1 + assert resp.errors[0].reason == 'error_reason' diff --git a/tests/testutils.py b/tests/testutils.py index 6ab69dda4..eebba307d 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -16,9 +16,13 @@ import io import os +from unittest.mock import MagicMock import pytest +import aiohttp +from aiohttp import streams from google.auth import credentials +from google.auth import _credentials_async from google.auth import transport from google.auth.transport import _aiohttp_requests as aiohttp_requests from requests import adapters @@ -34,7 +38,7 @@ def resource_filename(filename): def resource(filename): """Returns the contents of a test resource.""" - with open(resource_filename(filename), 'r') as file_obj: + with open(resource_filename(filename), 'r', encoding="utf-8") as file_obj: return file_obj.read() @@ -61,6 +65,22 @@ def run_without_project_id(func): os.environ[env_var] = gcloud_project +async def run_without_project_id_async(func): + env_vars = ['GCLOUD_PROJECT', 'GOOGLE_CLOUD_PROJECT'] + env_values = [] + for env_var in env_vars: + gcloud_project = os.environ.get(env_var) + if gcloud_project: + del os.environ[env_var] + env_values.append(gcloud_project) + try: + await func() + finally: + for idx, env_var in enumerate(env_vars): + gcloud_project = env_values[idx] + if gcloud_project: + os.environ[env_var] = gcloud_project + def new_monkeypatch(): return pytest.MonkeyPatch() @@ -114,7 +134,7 @@ def __call__(self, *args, **kwargs): # pylint: disable=arguments-differ class MockAsyncResponse(aiohttp_requests._CombinedResponse): def __init__(self, status, response): - super(MockAsyncResponse, self).__init__(response) + super().__init__(response) self._status = status self._response = response self._raw_content = response @@ -144,7 +164,7 @@ class MockAsyncRequest(aiohttp_requests.Request): """ def __init__(self, status, response): - super(MockAsyncRequest, self).__init__() + super().__init__() self.response = MockAsyncResponse(status, response) self.log = [] @@ -157,7 +177,7 @@ class MockFailedAsyncRequest(aiohttp_requests.Request): """A mock HTTP request that fails by raising an exception.""" def __init__(self, error): - super(MockFailedAsyncRequest, self).__init__() + super().__init__() self.error = error self.log = [] @@ -173,6 +193,11 @@ class MockGoogleCredential(credentials.Credentials): def refresh(self, request): self.token = 'mock-token' +class MockGoogleCredentialAsync(_credentials_async.Credentials): + """A mock Google authentication async credential.""" + async def refresh(self, request): # pylint: disable=invalid-overridden-method + self.token = 'mock-token' + class MockCredential(firebase_admin.credentials.Base): """A mock Firebase credential implementation.""" @@ -183,6 +208,14 @@ def __init__(self): def get_credential(self): return self._g_credential +class MockCredentialAsync(firebase_admin.credentials.Base): + """A mock Firebase async credential implementation.""" + + def __init__(self): + self._g_credential_async = MockGoogleCredentialAsync() + + def get_credential_async(self): + return self._g_credential_async class MockMultiRequestAdapter(adapters.HTTPAdapter): """A mock HTTP adapter that supports multiple responses for the Python requests module.""" @@ -216,7 +249,7 @@ def send(self, request, **kwargs): # pylint: disable=arguments-differ class MockAdapter(MockMultiRequestAdapter): """A mock HTTP adapter for the Python requests module.""" def __init__(self, data, status, recorder): - super(MockAdapter, self).__init__([data], [status], recorder) + super().__init__([data], [status], recorder) @property def status(self): @@ -225,3 +258,46 @@ def status(self): @property def data(self): return self._responses[0] + +class MockClientResponse(aiohttp.ClientResponse): + def __init__(self, method, url, payload, status, recorder): # pylint: disable=super-init-not-called + self._cache = {} + self._url = url + + mock_reader = AsyncMock(spec=streams.StreamReader) + mock_reader.read.return_value = str.encode(payload) + self.content = mock_reader + self.status = status + self.recorder = recorder + self._headers = [] + +class MockSession(aiohttp.ClientSession): + def __init__(self, payload, status, recorder, credentials=None): + super().__init__(credentials) + self.payload = payload + self.status = status + self.recorder = recorder + self.current_response = 0 + + async def _request(self, method, url, *args, **kwargs): # pylint: disable=arguments-differ + self.method = method + self.url = url + self.args = args + self.extra_kwargs = kwargs + self.recorder.append(self) + self.current_response += 1 + return MockClientResponse(method, url, self.payload, self.status, self.recorder) + +class MockClientSession(MockSession): + def __init__(self, payload, status, recorder): + super().__init__(payload, status, recorder) + +class MockAuthorizedSession(MockClientSession, aiohttp_requests.AuthorizedSession): + def __init__(self, payload, status, recorder, credentials): + super().__init__(payload, status, recorder) + self.credentials = credentials + +# Custom async mock class since unittest.mock.AsyncMock is only avaible in python 3.8+ +class AsyncMock(MagicMock): + async def __call__(self, *args, **kwargs): # pylint: disable=invalid-overridden-method + return super().__call__(*args, **kwargs)