diff --git a/firebase_admin/multi_factor_config_mgt.py b/firebase_admin/multi_factor_config_mgt.py new file mode 100644 index 000000000..6196568a5 --- /dev/null +++ b/firebase_admin/multi_factor_config_mgt.py @@ -0,0 +1,222 @@ +# Copyright 2024 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Firebase multifactor configuration management module. + +This module contains functions for managing multifactor auth configuration at +the project and tenant level. +""" +from enum import Enum +from typing import List + +__all__ = [ + 'validate_keys', + 'MultiFactorServerConfig', + 'TOTPProviderConfig', + 'ProviderConfig', + 'MultiFactorConfig', +] + + +def validate_keys(keys, valid_keys, config_name): + for key in keys: + if key not in valid_keys: + raise ValueError( + '"{0}" is not a valid "{1}" parameter.'.format( + key, config_name)) + + +class MultiFactorServerConfig: + """Represents the multi-factor configuration response received from the server. + """ + + def __init__(self, data): + if not isinstance(data, dict): + raise ValueError( + 'Invalid data argument in MultiFactorServerConfig constructor: {0}, must be a valid' + ' dict'.format(data)) + self._data = data + + @property + def provider_configs(self): + data = self._data.get('providerConfigs', None) + if data is not None: + return [self.ProviderServerConfig(d) for d in data] + return None + + class ProviderServerConfig: + """Represents the provider configuration response received from the server. + """ + + def __init__(self, data): + if not isinstance(data, dict): + raise ValueError( + 'Invalid data argument in ProviderServerConfig constructor: {0}'.format(data)) + self._data = data + + @property + def state(self): + return self._data.get('state', None) + + @property + def totp_provider_config(self): + data = self._data.get('totpProviderConfig', None) + if data is not None: + return self.TOTPProviderServerConfig(data) + return None + + class TOTPProviderServerConfig: + """Represents the TOTP provider configuration response received from the server. + """ + + def __init__(self, data): + if not isinstance(data, dict): + raise ValueError( + 'Invalid data argument in TOTPProviderServerConfig' + ' constructor: {0}'.format(data)) + self._data = data + + @property + def adjacent_intervals(self): + return self._data.get('adjacentIntervals', None) + + +class TOTPProviderConfig: + """A tenant or project's TOTP provider configuration.""" + + def __init__(self, adjacent_intervals: int = None): + self.adjacent_intervals: int = adjacent_intervals + + def to_dict(self) -> dict: + data = {} + if self.adjacent_intervals is not None: + data['adjacentIntervals'] = self.adjacent_intervals + return data + + def validate(self): + """Validates the configuration. + + Raises: + ValueError: In case of an unsuccessful validation. + """ + validate_keys( + keys=vars(self).keys(), + valid_keys={'adjacent_intervals'}, + config_name='TOTPProviderConfig') + if self.adjacent_intervals is not None: + # Because bool types get converted to int here + # pylint: disable=C0123 + if type(self.adjacent_intervals) is not int: + raise ValueError( + 'totp_provider_config.adjacent_intervals must be an integer between' + ' 1 and 10 (inclusive).') + if not 1 <= self.adjacent_intervals <= 10: + raise ValueError( + 'totp_provider_config.adjacent_intervals must be an integer between' + ' 1 and 10 (inclusive).') + + def build_server_request(self): + self.validate() + return self.to_dict() + + +class ProviderConfig: + """A tenant or project's multifactor provider configuration. + Currently, only TOTP can be configured.""" + + class State(Enum): + ENABLED = 'ENABLED' + DISABLED = 'DISABLED' + + def __init__(self, + state: State = None, + totp_provider_config: TOTPProviderConfig = None): + self.state: self.State = state + self.totp_provider_config: TOTPProviderConfig = totp_provider_config + + def to_dict(self) -> dict: + data = {} + if self.state: + data['state'] = self.state.value + if self.totp_provider_config: + data['totpProviderConfig'] = self.totp_provider_config.to_dict() + return data + + def validate(self): + """Validates the provider configuration. + + Raises: + ValueError: In case of an unsuccessful validation. + """ + validate_keys( + keys=vars(self).keys(), + valid_keys={ + 'state', + 'totp_provider_config'}, + config_name='ProviderConfig') + if self.state is None: + raise ValueError('ProviderConfig.state must be defined.') + if not isinstance(self.state, ProviderConfig.State): + raise ValueError( + 'ProviderConfig.state must be of type ProviderConfig.State.') + if self.totp_provider_config is None: + raise ValueError( + 'ProviderConfig.totp_provider_config must be defined.') + if not isinstance(self.totp_provider_config, TOTPProviderConfig): + raise ValueError( + 'ProviderConfig.totp_provider_config must be of type TOTPProviderConfig.') + + def build_server_request(self): + self.validate() + return self.to_dict() + + +class MultiFactorConfig: + """A tenant or project's multi factor configuration.""" + + def __init__(self, + provider_configs: List[ProviderConfig] = None): + self.provider_configs: List[ProviderConfig] = provider_configs + + def to_dict(self) -> dict: + data = {} + if self.provider_configs is not None: + data['providerConfigs'] = [d.to_dict() + for d in self.provider_configs] + return data + + def validate(self): + """Validates the configuration. + + Raises: + ValueError: In case of an unsuccessful validation. + """ + validate_keys( + keys=vars(self).keys(), + valid_keys={'provider_configs'}, + config_name='MultiFactorConfig') + if self.provider_configs is None: + raise ValueError( + 'multi_factor_config.provider_configs must be specified') + if not isinstance(self.provider_configs, list) or not self.provider_configs: + raise ValueError( + 'provider_configs must be an array of type ProviderConfig.') + for provider_config in self.provider_configs: + if not isinstance(provider_config, ProviderConfig): + raise ValueError( + 'provider_configs must be an array of type ProviderConfig.') + provider_config.validate() + + def build_server_request(self): + self.validate() + return self.to_dict() diff --git a/firebase_admin/project_config_mgt.py b/firebase_admin/project_config_mgt.py new file mode 100644 index 000000000..4ae65958a --- /dev/null +++ b/firebase_admin/project_config_mgt.py @@ -0,0 +1,135 @@ +# Copyright 2024 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Firebase project configuration management module. + +This module contains functions for managing projects. +""" + +import requests + +import firebase_admin +from firebase_admin import _auth_utils +from firebase_admin import _http_client +from firebase_admin import _utils +from firebase_admin.multi_factor_config_mgt import MultiFactorConfig +from firebase_admin.multi_factor_config_mgt import MultiFactorServerConfig + +_PROJECT_CONFIG_MGT_ATTRIBUTE = '_project_config_mgt' + +__all__ = [ + 'ProjectConfig', + 'get_project_config', + 'update_project_config', +] + + +def get_project_config(app=None): + """Gets the project config corresponding to the current project_id. + + Args: + app: An App instance (optional). + + Returns: + Project: A project object. + + Raises: + ValueError: If the project ID is None, empty or not a string. + ProjectNotFoundError: If no project exists by the given ID. + FirebaseError: If an error occurs while retrieving the project. + """ + project_config_mgt_service = _get_project_config_mgt_service(app) + return project_config_mgt_service.get_project_config() + +def update_project_config(multi_factor_config: MultiFactorConfig = None, app=None): + """Update the project config with the given options. + + Args: + multi_factor_config: Updated multi-factor authentication configuration + (optional) + app: An App instance (optional). + Returns: + Project: An updated ProjectConfig object. + Raises: + ValueError: If any of the given arguments are invalid. + FirebaseError: If an error occurs while updating the project. + """ + project_config_mgt_service = _get_project_config_mgt_service(app) + return project_config_mgt_service.update_project_config(multi_factor_config=multi_factor_config) + + +def _get_project_config_mgt_service(app): + return _utils.get_app_service(app, _PROJECT_CONFIG_MGT_ATTRIBUTE, + _ProjectConfigManagementService) + +class ProjectConfig: + """Represents a project config in an application. + """ + + def __init__(self, data): + if not isinstance(data, dict): + raise ValueError( + 'Invalid data argument in Project constructor: {0}'.format(data)) + self._data = data + + @property + def multi_factor_config(self): + data = self._data.get('mfa') + if data: + return MultiFactorServerConfig(data) + return None + +class _ProjectConfigManagementService: + """Firebase project management service.""" + + PROJECT_CONFIG_MGT_URL = 'https://identitytoolkit.googleapis.com/v2/projects' + + def __init__(self, app): + credential = app.credential.get_credential() + version_header = 'Python/Admin/{0}'.format(firebase_admin.__version__) + base_url = '{0}/{1}/config'.format( + self.PROJECT_CONFIG_MGT_URL, app.project_id) + self.app = app + self.client = _http_client.JsonHttpClient( + credential=credential, base_url=base_url, headers={'X-Client-Version': version_header}) + + def get_project_config(self) -> ProjectConfig: + """Gets the project config""" + try: + body = self.client.body('get', url='') + except requests.exceptions.RequestException as error: + raise _auth_utils.handle_auth_backend_error(error) + else: + return ProjectConfig(body) + + def update_project_config(self, multi_factor_config: MultiFactorConfig = None) -> ProjectConfig: + """Updates the specified project with the given parameters.""" + + payload = {} + if multi_factor_config is not None: + if not isinstance(multi_factor_config, MultiFactorConfig): + raise ValueError('multi_factor_config must be of type MultiFactorConfig.') + payload['mfa'] = multi_factor_config.build_server_request() + if not payload: + raise ValueError( + 'At least one parameter must be specified for update.') + + update_mask = ','.join(_auth_utils.build_update_mask(payload)) + params = 'updateMask={0}'.format(update_mask) + try: + body = self.client.body( + 'patch', url='', json=payload, params=params) + except requests.exceptions.RequestException as error: + raise _auth_utils.handle_auth_backend_error(error) + else: + return ProjectConfig(body) diff --git a/firebase_admin/tenant_mgt.py b/firebase_admin/tenant_mgt.py index 8c53e30a1..4f943b3e3 100644 --- a/firebase_admin/tenant_mgt.py +++ b/firebase_admin/tenant_mgt.py @@ -28,6 +28,8 @@ from firebase_admin import _auth_utils from firebase_admin import _http_client from firebase_admin import _utils +from firebase_admin.multi_factor_config_mgt import MultiFactorConfig +from firebase_admin.multi_factor_config_mgt import MultiFactorServerConfig _TENANT_MGT_ATTRIBUTE = '_tenant_mgt' @@ -91,7 +93,8 @@ def get_tenant(tenant_id, app=None): def create_tenant( - display_name, allow_password_sign_up=None, enable_email_link_sign_in=None, app=None): + display_name, allow_password_sign_up=None, enable_email_link_sign_in=None, + multi_factor_config: MultiFactorConfig = None, app=None): """Creates a new tenant from the given options. Args: @@ -101,6 +104,7 @@ def create_tenant( provider (optional). enable_email_link_sign_in: A boolean indicating whether to enable or disable email link sign-in (optional). Disabling this makes the password required for email sign-in. + multi_factor_config : A multi factor configuration to add to the tenant (optional). app: An App instance (optional). Returns: @@ -113,12 +117,13 @@ def create_tenant( tenant_mgt_service = _get_tenant_mgt_service(app) return tenant_mgt_service.create_tenant( display_name=display_name, allow_password_sign_up=allow_password_sign_up, - enable_email_link_sign_in=enable_email_link_sign_in) + enable_email_link_sign_in=enable_email_link_sign_in, + multi_factor_config=multi_factor_config,) def update_tenant( tenant_id, display_name=None, allow_password_sign_up=None, enable_email_link_sign_in=None, - app=None): + multi_factor_config: MultiFactorConfig = None, app=None): """Updates an existing tenant with the given options. Args: @@ -128,6 +133,7 @@ def update_tenant( provider. enable_email_link_sign_in: A boolean indicating whether to enable or disable email link sign-in. Disabling this makes the password required for email sign-in. + multi_factor_config : A multi factor configuration to update for the tenant (optional). app: An App instance (optional). Returns: @@ -141,7 +147,8 @@ def update_tenant( tenant_mgt_service = _get_tenant_mgt_service(app) return tenant_mgt_service.update_tenant( tenant_id, display_name=display_name, allow_password_sign_up=allow_password_sign_up, - enable_email_link_sign_in=enable_email_link_sign_in) + enable_email_link_sign_in=enable_email_link_sign_in, + multi_factor_config=multi_factor_config) def delete_tenant(tenant_id, app=None): @@ -183,6 +190,7 @@ def list_tenants(page_token=None, max_results=_MAX_LIST_TENANTS_RESULTS, app=Non FirebaseError: If an error occurs while retrieving the user accounts. """ tenant_mgt_service = _get_tenant_mgt_service(app) + def download(page_token, max_results): return tenant_mgt_service.list_tenants(page_token, max_results) return ListTenantsPage(download, page_token, max_results) @@ -205,7 +213,8 @@ class Tenant: def __init__(self, data): if not isinstance(data, dict): - raise ValueError('Invalid data argument in Tenant constructor: {0}'.format(data)) + raise ValueError( + 'Invalid data argument in Tenant constructor: {0}'.format(data)) if not 'name' in data: raise ValueError('Tenant response missing required keys.') @@ -228,6 +237,13 @@ def allow_password_sign_up(self): def enable_email_link_sign_in(self): return self._data.get('enableEmailLinkSignin', False) + @property + def multi_factor_config(self): + data = self._data.get('mfaConfig', None) + if data is not None: + return MultiFactorServerConfig(data) + return None + class _TenantManagementService: """Firebase tenant management service.""" @@ -237,7 +253,8 @@ class _TenantManagementService: def __init__(self, app): credential = app.credential.get_credential() version_header = 'Python/Admin/{0}'.format(firebase_admin.__version__) - base_url = '{0}/projects/{1}'.format(self.TENANT_MGT_URL, app.project_id) + base_url = '{0}/projects/{1}'.format( + self.TENANT_MGT_URL, app.project_id) self.app = app self.client = _http_client.JsonHttpClient( credential=credential, base_url=base_url, headers={'X-Client-Version': version_header}) @@ -256,7 +273,7 @@ def auth_for_tenant(self, tenant_id): client = auth.Client(self.app, tenant_id=tenant_id) self.tenant_clients[tenant_id] = client - return client + return client def get_tenant(self, tenant_id): """Gets the tenant corresponding to the given ``tenant_id``.""" @@ -272,7 +289,8 @@ def get_tenant(self, tenant_id): return Tenant(body) def create_tenant( - self, display_name, allow_password_sign_up=None, enable_email_link_sign_in=None): + self, display_name, allow_password_sign_up=None, enable_email_link_sign_in=None, + multi_factor_config: MultiFactorConfig = None): """Creates a new tenant from the given parameters.""" payload = {'displayName': _validate_display_name(display_name)} @@ -282,7 +300,11 @@ def create_tenant( if enable_email_link_sign_in is not None: payload['enableEmailLinkSignin'] = _auth_utils.validate_boolean( enable_email_link_sign_in, 'enableEmailLinkSignin') - + if multi_factor_config is not None: + if not isinstance(multi_factor_config, MultiFactorConfig): + raise ValueError( + 'multi_factor_config must be of type MultiFactorConfig.') + payload['mfaConfig'] = multi_factor_config.build_server_request() try: body = self.client.body('post', '/tenants', json=payload) except requests.exceptions.RequestException as error: @@ -292,7 +314,8 @@ def create_tenant( def update_tenant( self, tenant_id, display_name=None, allow_password_sign_up=None, - enable_email_link_sign_in=None): + enable_email_link_sign_in=None, + multi_factor_config: MultiFactorConfig = None): """Updates the specified tenant with the given parameters.""" if not isinstance(tenant_id, str) or not tenant_id: raise ValueError('Tenant ID must be a non-empty string.') @@ -306,9 +329,15 @@ def update_tenant( if enable_email_link_sign_in is not None: payload['enableEmailLinkSignin'] = _auth_utils.validate_boolean( enable_email_link_sign_in, 'enableEmailLinkSignin') + if multi_factor_config is not None: + if not isinstance(multi_factor_config, MultiFactorConfig): + raise ValueError( + 'multi_factor_config must be of type MultiFactorConfig.') + payload['mfaConfig'] = multi_factor_config.build_server_request() if not payload: - raise ValueError('At least one parameter must be specified for update.') + raise ValueError( + 'At least one parameter must be specified for update.') url = '/tenants/{0}'.format(tenant_id) update_mask = ','.join(_auth_utils.build_update_mask(payload)) diff --git a/integration/test_project_config_mgt.py b/integration/test_project_config_mgt.py new file mode 100644 index 000000000..bdfd4ea34 --- /dev/null +++ b/integration/test_project_config_mgt.py @@ -0,0 +1,75 @@ +# Copyright 2024 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for firebase_admin.project_config_mgt module.""" + +import pytest + +from firebase_admin.project_config_mgt import ProjectConfig +from firebase_admin.project_config_mgt import get_project_config +from firebase_admin.project_config_mgt import update_project_config +from firebase_admin.multi_factor_config_mgt import MultiFactorConfig +from firebase_admin.multi_factor_config_mgt import MultiFactorServerConfig +from firebase_admin.multi_factor_config_mgt import ProviderConfig +from firebase_admin.multi_factor_config_mgt import TOTPProviderConfig + +ADJACENT_INTERVALS = 5 + +@pytest.fixture(scope='module') +def sample_mfa_config(): + mfa_config = { + 'providerConfigs': [ + { + 'state': 'ENABLED', + 'totpProviderConfig': { + 'adjacentIntervals': ADJACENT_INTERVALS + } + } + ] + } + return mfa_config + + +def test_update_project_config(): + mfa_object = MultiFactorConfig( + provider_configs=[ + ProviderConfig( + state=ProviderConfig.State.ENABLED, + totp_provider_config=TOTPProviderConfig( + adjacent_intervals=5 + ) + ) + ] + ) + project_config = update_project_config(multi_factor_config=mfa_object) + _assert_multi_factor_config(project_config.multi_factor_config) + + +def test_get_project(): + project_config = get_project_config() + assert isinstance(project_config, ProjectConfig) + _assert_multi_factor_config(project_config.multi_factor_config) + +def _assert_multi_factor_config(multi_factor_config): + assert isinstance(multi_factor_config, MultiFactorServerConfig) + assert len(multi_factor_config.provider_configs) == 1 + assert isinstance(multi_factor_config.provider_configs, list) + for provider_config in multi_factor_config.provider_configs: + assert isinstance(provider_config, MultiFactorServerConfig + .ProviderServerConfig) + assert provider_config.state == 'ENABLED' + assert isinstance(provider_config.totp_provider_config, + MultiFactorServerConfig.ProviderServerConfig + .TOTPProviderServerConfig) + assert provider_config.totp_provider_config.adjacent_intervals == ADJACENT_INTERVALS diff --git a/integration/test_tenant_mgt.py b/integration/test_tenant_mgt.py index c9eefd96e..1766b2b1a 100644 --- a/integration/test_tenant_mgt.py +++ b/integration/test_tenant_mgt.py @@ -25,6 +25,7 @@ from firebase_admin import auth from firebase_admin import tenant_mgt +from firebase_admin import multi_factor_config_mgt from integration import test_auth @@ -35,13 +36,34 @@ @pytest.fixture(scope='module') def sample_tenant(): + mfa_object = multi_factor_config_mgt.MultiFactorConfig( + provider_configs=[multi_factor_config_mgt.ProviderConfig( + state=multi_factor_config_mgt.ProviderConfig.State.ENABLED, + totp_provider_config=multi_factor_config_mgt.TOTPProviderConfig( + adjacent_intervals=5 + ) + )] + ) tenant = tenant_mgt.create_tenant( display_name='admin-python-tenant', allow_password_sign_up=True, - enable_email_link_sign_in=True) + enable_email_link_sign_in=True, + multi_factor_config=mfa_object) yield tenant tenant_mgt.delete_tenant(tenant.tenant_id) +def _assert_multi_factor_config(mfa_config): + assert isinstance(mfa_config, multi_factor_config_mgt.MultiFactorServerConfig) + assert len(mfa_config.provider_configs) == 1 + assert isinstance(mfa_config.provider_configs, list) + for provider_config in mfa_config.provider_configs: + assert isinstance(provider_config, multi_factor_config_mgt.MultiFactorServerConfig.\ + ProviderServerConfig) + assert provider_config.state == 'ENABLED' + assert isinstance(provider_config.totp_provider_config, + multi_factor_config_mgt.MultiFactorServerConfig.ProviderServerConfig + .TOTPProviderServerConfig) + assert provider_config.totp_provider_config.adjacent_intervals == 5 @pytest.fixture(scope='module') def tenant_user(sample_tenant): @@ -59,6 +81,7 @@ def test_get_tenant(sample_tenant): assert tenant.display_name == 'admin-python-tenant' assert tenant.allow_password_sign_up is True assert tenant.enable_email_link_sign_in is True + _assert_multi_factor_config(tenant.multi_factor_config) def test_list_tenants(sample_tenant): @@ -76,8 +99,17 @@ def test_list_tenants(sample_tenant): def test_update_tenant(): + mfa_object = multi_factor_config_mgt.MultiFactorConfig( + provider_configs=[multi_factor_config_mgt.ProviderConfig( + state=multi_factor_config_mgt.ProviderConfig.State.ENABLED, + totp_provider_config=multi_factor_config_mgt.TOTPProviderConfig( + adjacent_intervals=5 + ) + )] + ) tenant = tenant_mgt.create_tenant( - display_name='py-update-test', allow_password_sign_up=True, enable_email_link_sign_in=True) + display_name='py-update-test', allow_password_sign_up=True, enable_email_link_sign_in=True, + multi_factor_config=mfa_object) try: tenant = tenant_mgt.update_tenant( tenant.tenant_id, display_name='updated-py-tenant', allow_password_sign_up=False, @@ -87,6 +119,7 @@ def test_update_tenant(): assert tenant.display_name == 'updated-py-tenant' assert tenant.allow_password_sign_up is False assert tenant.enable_email_link_sign_in is False + _assert_multi_factor_config(tenant.multi_factor_config) finally: tenant_mgt.delete_tenant(tenant.tenant_id) diff --git a/tests/test_multi_factor_config.py b/tests/test_multi_factor_config.py new file mode 100644 index 000000000..50eaeec9a --- /dev/null +++ b/tests/test_multi_factor_config.py @@ -0,0 +1,184 @@ +# Copyright 2024 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from copy import copy + +import pytest + +from firebase_admin.multi_factor_config_mgt import MultiFactorConfig +from firebase_admin.multi_factor_config_mgt import MultiFactorServerConfig +from firebase_admin.multi_factor_config_mgt import TOTPProviderConfig +from firebase_admin.multi_factor_config_mgt import ProviderConfig + +sample_mfa_config = MultiFactorConfig( + provider_configs=[ProviderConfig( + state=ProviderConfig.State.ENABLED, + totp_provider_config=TOTPProviderConfig( + adjacent_intervals=5 + ) + )] +) + + +class TestMultiFactorConfig: + def test_invalid_mfa_config_params(self): + test_config = copy(sample_mfa_config) + test_config.invalid_parameter = 'invalid' + with pytest.raises(ValueError) as excinfo: + test_config.build_server_request() + assert str(excinfo.value).startswith('"invalid_parameter" is not a valid' + ' "MultiFactorConfig" parameter.') + + @pytest.mark.parametrize('provider_configs', + [True, False, 1, 0, list(), tuple(), dict()]) + def test_invalid_provider_configs_type(self, provider_configs): + test_config = copy(sample_mfa_config) + test_config.provider_configs = provider_configs + with pytest.raises(ValueError) as excinfo: + test_config.build_server_request() + assert str(excinfo.value).startswith('provider_configs must be an array of type' + ' ProviderConfig.') + + @pytest.mark.parametrize('provider_configs', + [[True], [1, 2], + [{'state': 'DISABLED', 'totpProviderConfig': {}}, "foo"]]) + def test_invalid_mfa_config_provider_config(self, provider_configs): + test_config = copy(sample_mfa_config) + test_config.provider_configs = provider_configs + with pytest.raises(ValueError) as excinfo: + test_config.build_server_request() + assert str(excinfo.value).startswith('provider_configs must be an array of type' + ' ProviderConfig.') + + +class TestProviderConfig: + def test_invalid_provider_config_params(self): + test_config = copy(sample_mfa_config.provider_configs[0]) + test_config.invalid_parameter = 'invalid' + with pytest.raises(ValueError) as excinfo: + test_config.build_server_request() + assert str(excinfo.value).startswith('"invalid_parameter" is not a valid "ProviderConfig"' + ' parameter.') + + def test_undefined_provider_config_state(self): + test_config = copy(sample_mfa_config.provider_configs[0]) + test_config.state = None + with pytest.raises(ValueError) as excinfo: + test_config.build_server_request() + assert str(excinfo.value).startswith( + 'ProviderConfig.state must be defined.') + + @pytest.mark.parametrize('state', + ['', 1, True, False, [], (), {}, "foo", 'ENABLED']) + def test_invalid_provider_config_state(self, state): + test_config = ProviderConfig( + state=state + ) + with pytest.raises(ValueError) as excinfo: + test_config.build_server_request() + assert str(excinfo.value).startswith('ProviderConfig.state must be of type' + ' ProviderConfig.State.') + + @pytest.mark.parametrize('state', + [ProviderConfig.State.ENABLED, + ProviderConfig.State.DISABLED]) + def test_undefined_totp_provider_config(self, state): + test_config = ProviderConfig(state=state) + with pytest.raises(ValueError) as excinfo: + test_config.build_server_request() + assert str(excinfo.value).startswith('ProviderConfig.totp_provider_config must be' + ' defined.') + + @pytest.mark.parametrize('totp_provider_config', + [True, False, 1, 0, list(), tuple(), dict()]) + def test_invalid_totp_provider_config_type(self, totp_provider_config): + test_config = copy(sample_mfa_config.provider_configs[0]) + test_config.totp_provider_config = totp_provider_config + with pytest.raises(ValueError) as excinfo: + test_config.build_server_request() + assert str(excinfo.value).startswith('ProviderConfig.totp_provider_config must be of type' + ' TOTPProviderConfig.') + + +class TestTOTPProviderConfig: + + def test_invalid_totp_provider_config_params(self): + test_config = copy( + sample_mfa_config.provider_configs[0].totp_provider_config) + test_config.invalid_parameter = 'invalid' + with pytest.raises(ValueError) as excinfo: + test_config.build_server_request() + assert str(excinfo.value).startswith('"invalid_parameter" is not a valid' + ' "TOTPProviderConfig" parameter.') + + @pytest.mark.parametrize('adjacent_intervals', + ['', -1, True, False, [], (), {}, "foo", 11, 1.1]) + def test_invalid_adjacent_intervals_type(self, adjacent_intervals): + test_config = copy( + sample_mfa_config.provider_configs[0].totp_provider_config) + test_config.adjacent_intervals = adjacent_intervals + with pytest.raises(ValueError) as excinfo: + test_config.build_server_request() + assert str(excinfo.value).startswith('totp_provider_config.adjacent_intervals must be an' + ' integer between 1 and 10 (inclusive).') + + +class TestMultiFactorServerConfig: + def test_invalid_multi_factor_config_response(self): + test_config = 'invalid' + with pytest.raises(ValueError) as excinfo: + MultiFactorServerConfig(test_config) + assert str(excinfo.value).startswith('Invalid data argument in MultiFactorServerConfig' + ' constructor: {0}'.format(test_config)) + + def test_invalid_provider_config_response(self): + test_config = 'invalid' + with pytest.raises(ValueError) as excinfo: + MultiFactorServerConfig.ProviderServerConfig(test_config) + assert str(excinfo.value).startswith('Invalid data argument in ProviderServerConfig' + ' constructor: {0}'.format(test_config)) + + def test_invalid_totp_provider_config_response(self): + test_config = 'invalid' + with pytest.raises(ValueError) as excinfo: + MultiFactorServerConfig.ProviderServerConfig.\ + TOTPProviderServerConfig(test_config) + assert str(excinfo.value).startswith('Invalid data argument in TOTPProviderServerConfig' + ' constructor: {0}'.format(test_config)) + + def test_valid_server_response(self): + response = { + 'providerConfigs': [{ + 'state': 'ENABLED', + 'totpProviderConfig': { + 'adjacentIntervals': 5 + } + }] + } + mfa_config = MultiFactorServerConfig(response) + _assert_multi_factor_config(mfa_config) + + +def _assert_multi_factor_config(mfa_config): + assert isinstance(mfa_config, MultiFactorServerConfig) + assert len(mfa_config.provider_configs) == 1 + assert isinstance(mfa_config.provider_configs, list) + for provider_config in mfa_config.provider_configs: + assert isinstance( + provider_config, + MultiFactorServerConfig.ProviderServerConfig) + assert provider_config.state == 'ENABLED' + assert isinstance(provider_config.totp_provider_config, + MultiFactorServerConfig.ProviderServerConfig + .TOTPProviderServerConfig) + assert provider_config.totp_provider_config.adjacent_intervals == 5 diff --git a/tests/test_project_config_mgt.py b/tests/test_project_config_mgt.py new file mode 100644 index 000000000..7abfa552a --- /dev/null +++ b/tests/test_project_config_mgt.py @@ -0,0 +1,182 @@ +# Copyright 2024 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test cases for the firebase_admin.project_config_mgt module.""" + +import json + +import pytest + +from tests import testutils + +import firebase_admin +from firebase_admin import project_config_mgt +from firebase_admin import multi_factor_config_mgt + + +ADJACENT_INTERVALS = 5 + +GET_PROJECT_RESPONSE = """{ + "mfaConfig":{ + "providerConfigs":[ + { + "state":"ENABLED", + "totpProviderConfig": { + "adjacentIntervals": 5 + } + } + ] + } +}""" + +MOCK_GET_USER_RESPONSE = testutils.resource('get_user.json') +INVALID_BOOLEANS = ['', 1, 0, list(), tuple(), dict()] + +PROJECT_CONFIG_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v2/projects' + + +@pytest.fixture(scope='module') +def project_config_mgt_app(): + app = firebase_admin.initialize_app( + testutils.MockCredential(), name='projectMgt', options={'projectId': 'project-id'}) + yield app + firebase_admin.delete_app(app) + + +def _instrument_project_config_mgt(app, status, payload): + service = project_config_mgt._get_project_config_mgt_service(app) + recorder = [] + service.client.session.mount( + project_config_mgt._ProjectConfigManagementService.PROJECT_CONFIG_MGT_URL, + testutils.MockAdapter(payload, status, recorder)) + return service, recorder + + +class TestProjectConfig: + + @pytest.mark.parametrize('data', [None, 'foo', 0, 1, True, False, list(), tuple()]) + def test_invalid_data(self, data): + with pytest.raises(ValueError): + project_config_mgt.ProjectConfig(data) + + def test_project_config(self): + data = { + 'mfa': { + 'providerConfigs': [ + { + 'state': 'ENABLED', + 'totpProviderConfig': { + 'adjacentIntervals': ADJACENT_INTERVALS, + } + } + ] + } + } + project_config = project_config_mgt.ProjectConfig(data) + _assert_project_config(project_config) + + def test_project_optional_params(self): + data = { + 'name': 'test-project', + } + project = project_config_mgt.ProjectConfig(data) + assert project.multi_factor_config is None + + +class TestGetProjectConfig: + + def test_get_project_config(self, project_config_mgt_app): + _, recorder = _instrument_project_config_mgt( + project_config_mgt_app, 200, GET_PROJECT_RESPONSE) + project_config = project_config_mgt.get_project_config(app=project_config_mgt_app) + + _assert_project_config(project_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/project-id/config'.format(PROJECT_CONFIG_MGT_URL_PREFIX) + + +class TestUpdateProjectConfig: + + def test_update_project_no_args(self, project_config_mgt_app): + with pytest.raises(ValueError) as excinfo: + project_config_mgt.update_project_config(app=project_config_mgt_app) + assert str(excinfo.value).startswith('At least one parameter must be specified for update') + + @pytest.mark.parametrize('multi_factor_config', ['foo', 0, 1, True, False, list(), tuple()]) + def test_invalid_multi_factor_config_type(self, multi_factor_config, project_config_mgt_app): + with pytest.raises(ValueError) as excinfo: + project_config_mgt.update_project_config(multi_factor_config=multi_factor_config, + app=project_config_mgt_app) + assert str(excinfo.value).startswith( + 'multi_factor_config must be of type MultiFactorConfig.') + + def test_update_project_config(self, project_config_mgt_app): + _, recorder = _instrument_project_config_mgt( + project_config_mgt_app, 200, GET_PROJECT_RESPONSE) + mfa_object = multi_factor_config_mgt.MultiFactorConfig( + provider_configs=[ + multi_factor_config_mgt.ProviderConfig( + state=multi_factor_config_mgt.ProviderConfig.State.ENABLED, + totp_provider_config=multi_factor_config_mgt.TOTPProviderConfig( + adjacent_intervals=ADJACENT_INTERVALS + ) + ) + ] + ) + project_config = project_config_mgt.update_project_config( + multi_factor_config=mfa_object, app=project_config_mgt_app) + + mask = ['mfa.providerConfigs'] + + _assert_project_config(project_config) + self._assert_request(recorder, { + 'mfa': { + 'providerConfigs': [ + { + 'state': 'ENABLED', + 'totpProviderConfig': { + 'adjacentIntervals': ADJACENT_INTERVALS, + } + } + ] + } + }, mask) + + def _assert_request(self, recorder, body, mask): + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'PATCH' + assert req.url == '{0}/project-id/config?updateMask={1}'.format( + PROJECT_CONFIG_MGT_URL_PREFIX, ','.join(mask)) + got = json.loads(req.body.decode()) + assert got == body + +def _assert_multi_factor_config(multi_factor_config): + assert isinstance(multi_factor_config, multi_factor_config_mgt.MultiFactorServerConfig) + assert len(multi_factor_config.provider_configs) == 1 + assert isinstance(multi_factor_config.provider_configs, list) + for provider_config in multi_factor_config.provider_configs: + assert isinstance(provider_config, multi_factor_config_mgt.MultiFactorServerConfig + .ProviderServerConfig) + assert provider_config.state == 'ENABLED' + assert isinstance(provider_config.totp_provider_config, + multi_factor_config_mgt.MultiFactorServerConfig.ProviderServerConfig + .TOTPProviderServerConfig) + assert provider_config.totp_provider_config.adjacent_intervals == ADJACENT_INTERVALS + +def _assert_project_config(project_config): + if project_config.multi_factor_config is not None: + _assert_multi_factor_config(project_config.multi_factor_config) diff --git a/tests/test_tenant_mgt.py b/tests/test_tenant_mgt.py index 53b766239..5e9963fec 100644 --- a/tests/test_tenant_mgt.py +++ b/tests/test_tenant_mgt.py @@ -26,15 +26,29 @@ from firebase_admin import tenant_mgt from firebase_admin import _auth_providers from firebase_admin import _user_mgt +from firebase_admin.multi_factor_config_mgt import MultiFactorConfig +from firebase_admin.multi_factor_config_mgt import MultiFactorServerConfig +from firebase_admin.multi_factor_config_mgt import ProviderConfig +from firebase_admin.multi_factor_config_mgt import TOTPProviderConfig from tests import testutils from tests import test_token_gen +ADJACENT_INTERVALS = 5 + GET_TENANT_RESPONSE = """{ "name": "projects/mock-project-id/tenants/tenant-id", "displayName": "Test Tenant", "allowPasswordSignup": true, - "enableEmailLinkSignin": true + "enableEmailLinkSignin": true, + "mfaConfig": { + "providerConfigs": [{ + "state":"ENABLED", + "totpProviderConfig": { + "adjacentIntervals": 5 + } + }] + } }""" TENANT_NOT_FOUND_RESPONSE = """{ @@ -236,17 +250,45 @@ def test_invalid_enable_email_link_sign_in(self, enable, tenant_mgt_app): display_name='test', enable_email_link_sign_in=enable, app=tenant_mgt_app) assert str(excinfo.value).startswith('Invalid type for enableEmailLinkSignin') + @pytest.mark.parametrize('multi_factor_config', ['a', 1, True, {}, dict(), list(), tuple()]) + def test_invalid_multi_factor_configs(self, multi_factor_config, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.create_tenant( + display_name='test', multi_factor_config=multi_factor_config, app=tenant_mgt_app) + assert str(excinfo.value).startswith('multi_factor_config must be of type' + ' MultiFactorConfig.') + def test_create_tenant(self, tenant_mgt_app): _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, GET_TENANT_RESPONSE) + mfa_object = MultiFactorConfig( + provider_configs=[ + ProviderConfig( + state=ProviderConfig.State.ENABLED, + totp_provider_config=TOTPProviderConfig( + adjacent_intervals=ADJACENT_INTERVALS + ) + ) + ] + ) tenant = tenant_mgt.create_tenant( display_name='My-Tenant', allow_password_sign_up=True, enable_email_link_sign_in=True, - app=tenant_mgt_app) + multi_factor_config=mfa_object, app=tenant_mgt_app) _assert_tenant(tenant) self._assert_request(recorder, { 'displayName': 'My-Tenant', 'allowPasswordSignup': True, 'enableEmailLinkSignin': True, + 'mfaConfig': { + 'providerConfigs': [ + { + 'state': 'ENABLED', + 'totpProviderConfig': { + 'adjacentIntervals': ADJACENT_INTERVALS + } + } + ] + }, }) def test_create_tenant_false_values(self, tenant_mgt_app): @@ -322,6 +364,14 @@ def test_invalid_enable_email_link_sign_in(self, enable, tenant_mgt_app): 'tenant-id', enable_email_link_sign_in=enable, app=tenant_mgt_app) assert str(excinfo.value).startswith('Invalid type for enableEmailLinkSignin') + @pytest.mark.parametrize('multi_factor_config', ['a', 1, True, {}, dict(), list(), tuple()]) + def test_invalid_multi_factor_configs(self, multi_factor_config, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.update_tenant( + 'tenant-id', multi_factor_config=multi_factor_config, app=tenant_mgt_app) + assert str(excinfo.value).startswith('multi_factor_config must be of type' + ' MultiFactorConfig.') + def test_update_tenant_no_args(self, tenant_mgt_app): with pytest.raises(ValueError) as excinfo: tenant_mgt.update_tenant('tenant-id', app=tenant_mgt_app) @@ -329,17 +379,39 @@ def test_update_tenant_no_args(self, tenant_mgt_app): def test_update_tenant(self, tenant_mgt_app): _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, GET_TENANT_RESPONSE) + mfa_object = MultiFactorConfig( + provider_configs=[ + ProviderConfig( + state=ProviderConfig.State.ENABLED, + totp_provider_config=TOTPProviderConfig( + adjacent_intervals=ADJACENT_INTERVALS + ) + ) + ] + ) tenant = tenant_mgt.update_tenant( 'tenant-id', display_name='My-Tenant', allow_password_sign_up=True, - enable_email_link_sign_in=True, app=tenant_mgt_app) + enable_email_link_sign_in=True, + multi_factor_config=mfa_object, app=tenant_mgt_app) _assert_tenant(tenant) body = { 'displayName': 'My-Tenant', 'allowPasswordSignup': True, 'enableEmailLinkSignin': True, + 'mfaConfig': { + 'providerConfigs': [ + { + 'state': 'ENABLED', + 'totpProviderConfig': { + 'adjacentIntervals': ADJACENT_INTERVALS + } + } + ] + } } - mask = ['allowPasswordSignup', 'displayName', 'enableEmailLinkSignin'] + mask = ['allowPasswordSignup', 'displayName', 'enableEmailLinkSignin', + 'mfaConfig.providerConfigs'] self._assert_request(recorder, body, mask) def test_update_tenant_false_values(self, tenant_mgt_app): @@ -995,6 +1067,18 @@ def test_custom_token_with_claims(self, tenant_aware_custom_token_app): test_token_gen.verify_custom_token( custom_token, expected_claims=claims, tenant_id='test-tenant') +def _assert_multi_factor_config(mfa_config): + assert isinstance(mfa_config, MultiFactorServerConfig) + assert len(mfa_config.provider_configs) == 1 + assert isinstance(mfa_config.provider_configs, list) + for provider_config in mfa_config.provider_configs: + assert isinstance(provider_config, MultiFactorServerConfig.\ + ProviderServerConfig) + assert provider_config.state == 'ENABLED' + assert isinstance(provider_config.totp_provider_config, + MultiFactorServerConfig.ProviderServerConfig + .TOTPProviderServerConfig) + assert provider_config.totp_provider_config.adjacent_intervals == ADJACENT_INTERVALS def _assert_tenant(tenant, tenant_id='tenant-id'): assert isinstance(tenant, tenant_mgt.Tenant) @@ -1002,3 +1086,5 @@ def _assert_tenant(tenant, tenant_id='tenant-id'): assert tenant.display_name == 'Test Tenant' assert tenant.allow_password_sign_up is True assert tenant.enable_email_link_sign_in is True + if tenant.multi_factor_config is not None: + _assert_multi_factor_config(mfa_config=tenant.multi_factor_config)