diff --git a/Makefile b/Makefile index f0d4395..5a85edc 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ refresh-lockfiles: find requirements/ -name '*.txt' ! -name 'all.txt' -type f -delete pip-compile -q --no-emit-index-url --resolver backtracking -o requirements/linting.txt requirements/linting.in pip-compile -q --no-emit-index-url --resolver backtracking -o requirements/testing.txt requirements/testing.in - pip-compile -q --no-emit-index-url --resolver backtracking --extra toml --extra yaml --extra azure-key-vault -o requirements/pyproject.txt pyproject.toml + pip-compile -q --no-emit-index-url --resolver backtracking --extra toml --extra yaml --extra azure-key-vault --extra aws -o requirements/pyproject.txt pyproject.toml pip install --dry-run -r requirements/all.txt .PHONY: format diff --git a/docs/index.md b/docs/index.md index 5e3cf28..63225f5 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1306,6 +1306,63 @@ class AzureKeyVaultSettings(BaseSettings): ) ``` +## AWS Systems Manager Parameter Store + +You must set the following parameters: + +- `ssm_client`: An initialized [`boto3` SSM Client](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ssm.html#client). + +Optionally, you may specify the following parameters: + +- `ssm_path`: The hierarchy for the parameter. Hierarchies start with a forward slash (/). The hierarchy is the parameter name except the last part of the parameter. Under the hood, we make use of the [`get_parameters_by_path` method](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ssm/client/get_parameters_by_path.html) to recursively retrieve all parameters within the a specified path hierarchy. + +```py +import os +from typing import Tuple, Type + +import boto3 +from pydantic import BaseModel + +from pydantic_settings import ( + AwsSystemsManagerParameterStoreSettingsSource, + BaseSettings, + PydanticBaseSettingsSource, +) + + +class SubModel(BaseModel): + a: str + + +class AwsParamStoreSettings(BaseSettings): + foo: str + bar: int + sub: SubModel + + @classmethod + def settings_customise_sources( + cls, + settings_cls: Type[BaseSettings], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ) -> Tuple[PydanticBaseSettingsSource, ...]: + client = boto3.client('ssm') + ssm_param_store_settings = AwsSystemsManagerParameterStoreSettingsSource( + settings_cls, + ssm_client=client, + ssm_path=os.environ.get('SSM_PREFIX', '/api/dev/'), + ) + return ( + init_settings, + env_settings, + dotenv_settings, + file_secret_settings, + ssm_param_store_settings, + ) +``` --> + ## Other settings source Other settings sources are available for common configuration files: diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index 63520af..44560cf 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -59,6 +59,11 @@ import tomli import yaml + try: + from mypy_boto3_ssm import SSMClient + except ImportError: + SSMClient = None + from pydantic_settings.main import BaseSettings else: yaml = None @@ -2014,6 +2019,62 @@ def __repr__(self) -> str: return f'AzureKeyVaultSettingsSource(url={self._url!r}, ' f'env_nested_delimiter={self.env_nested_delimiter!r})' +class AwsSystemsManagerParameterStoreSettingsSource(EnvSettingsSource): + _ssm_client: SSMClient + _ssm_path: str + + def __init__( + self, + settings_cls: type[BaseSettings], + ssm_client: SSMClient, + ssm_path: str = '/', + case_sensitive: bool | None = None, + env_prefix: str | None = None, + env_nested_delimiter: str = '/', + env_ignore_empty: bool | None = None, + env_parse_none_str: str | None = None, + env_parse_enums: bool | None = None, + ) -> None: + self._ssm_client = ssm_client + self._ssm_path = ssm_path + super().__init__( + settings_cls, + case_sensitive, + env_prefix, + env_nested_delimiter, + env_ignore_empty, + env_parse_none_str, + env_parse_enums, + ) + + def _load_env_vars(self) -> Mapping[str, Optional[str]]: + paginator = self._ssm_client.get_paginator('get_parameters_by_path') + response_iterator = paginator.paginate( + Path=self._ssm_path, WithDecryption=True, Recursive=True + ) + + output = {} + try: + for page in response_iterator: + for parameter in page['Parameters']: + name = Path(parameter['Name']) + key = name.relative_to(self._ssm_path).as_posix() + + if not self.case_sensitive: + first_key, *rest = key.split(self.env_nested_delimiter) + key = self.env_nested_delimiter.join([first_key.lower(), *rest]) + + output[key] = parameter['Value'] + + except self._ssm_client.exceptions.ClientError as e: + warnings.warn(f'Unable to get parameters from {self._ssm_path!r}: {e}') + + return output + + def __repr__(self) -> str: + return f'AwsSystemsManagerParameterStoreSettingsSource(ssm_path={self._ssm_path!r})' + + def _get_env_var_key(key: str, case_sensitive: bool = False) -> str: return key if case_sensitive else key.lower() diff --git a/pyproject.toml b/pyproject.toml index 5df0e89..ffaa21a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ dynamic = ['version'] yaml = ["pyyaml>=6.0.1"] toml = ["tomli>=2.0.1"] azure-key-vault = ["azure-keyvault-secrets>=4.8.0", "azure-identity>=1.16.0"] +aws = ["boto3>=1.35.0", "boto3-stubs[ssm]>=1.35.0"] [project.urls] Homepage = 'https://github.com/pydantic/pydantic-settings' diff --git a/tests/test_sources.py b/tests/test_sources.py index f467c05..282ded7 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -10,6 +10,7 @@ from pydantic_settings.main import BaseSettings, SettingsConfigDict from pydantic_settings.sources import ( + AwsSystemsManagerParameterStoreSettingsSource, AzureKeyVaultSettingsSource, PydanticBaseSettingsSource, PyprojectTomlConfigSettingsSource, @@ -210,3 +211,189 @@ def _raise_resource_not_found_when_getting_parent_secret_name(self, secret_name: raise ResourceNotFoundError() return key_vault_secret + + +class TestAwsSystemsManagerParameterStoreSettingsSource: + """Test AwsSystemsManagerParameterStoreSettingsSource.""" + + def test___init__(self, mocker: MockerFixture) -> None: + """Test __init__.""" + + class AwsSettings(BaseSettings): + """AWS settings.""" + + mock_parameters = [] + paginator_mock = mocker.Mock() + paginator_mock.paginate.return_value = [{'Parameters': mock_parameters}] + + client_mock = mocker.Mock() + client_mock.get_paginator.return_value = paginator_mock + client_mock.exceptions.ClientError = Exception + + AwsSystemsManagerParameterStoreSettingsSource( + settings_cls=AwsSettings, ssm_client=client_mock, ssm_path='/my/path' + ) + + def test___call__case_sensitive(self, mocker: MockerFixture) -> None: + """Test __call__.""" + + class SqlServer(BaseModel): + password: str = Field(..., alias='Password') + + class AwsSettings(BaseSettings): + """AWS settings.""" + + SqlServerUser: str + sql_server_user: str = Field(..., alias='SqlServerUser') + sql_server: SqlServer = Field(..., alias='SqlServer') + + mock_parameters = [ + {'Name': '/my/path/SqlServerUser', 'Value': 'SecretValue'}, + {'Name': '/my/path/SqlServer/Password', 'Value': 'SecretValue'}, + ] + paginator_mock = mocker.Mock() + paginator_mock.paginate.return_value = [{'Parameters': mock_parameters}] + + client_mock = mocker.Mock() + client_mock.get_paginator.return_value = paginator_mock + client_mock.exceptions.ClientError = Exception + + obj = AwsSystemsManagerParameterStoreSettingsSource( + settings_cls=AwsSettings, + ssm_client=client_mock, + ssm_path='/my/path', + case_sensitive=True, + ) + + settings = obj() + + assert settings['SqlServerUser'] == 'SecretValue' + assert settings['SqlServer']['Password'] == 'SecretValue' + + def test___call__case_insensitive(self, mocker: MockerFixture) -> None: + """Test __call__.""" + + class SqlServer(BaseModel): + password: str = Field(..., alias='Password') + + class AwsSettings(BaseSettings): + """AWS settings.""" + + SqlServerUser: str + sql_server_user: str = Field(..., alias='SqlServerUser') + sql_server: SqlServer = Field(..., alias='SqlServer') + + mock_parameters = [ + {'Name': '/my/path/SqlServerUser', 'Value': 'SecretValue'}, + {'Name': '/my/path/SqlServer/Password', 'Value': 'SecretValue'}, + ] + paginator_mock = mocker.Mock() + paginator_mock.paginate.return_value = [{'Parameters': mock_parameters}] + + client_mock = mocker.Mock() + client_mock.get_paginator.return_value = paginator_mock + client_mock.exceptions.ClientError = Exception + + obj = AwsSystemsManagerParameterStoreSettingsSource( + settings_cls=AwsSettings, + ssm_client=client_mock, + ssm_path='/my/path', + case_sensitive=False, + ) + settings = obj() + + assert settings['SqlServerUser'] == 'SecretValue' + assert settings['SqlServer']['Password'] == 'SecretValue' + + def test_aws_ssm_settings_source(self, mocker: MockerFixture) -> None: + """Test AwsSystemsManagerParameterStoreSettingsSource.""" + mock_parameters = [ + {'Name': '/my/path/SqlServerUser', 'Value': 'SecretValue'}, + {'Name': '/my/path/SqlServer/Password', 'Value': 'SecretValue'}, + ] + paginator_mock = mocker.Mock() + paginator_mock.paginate.return_value = [{'Parameters': mock_parameters}] + + client_mock = mocker.Mock() + client_mock.get_paginator.return_value = paginator_mock + client_mock.exceptions.ClientError = Exception + + class SqlServer(BaseModel): + password: str = Field(..., alias='Password') + + class AwsSettings(BaseSettings): + """AWS settings.""" + + SqlServerUser: str + sql_server_user: str = Field(..., alias='SqlServerUser') + sql_server: SqlServer = Field(..., alias='SqlServer') + + @classmethod + def settings_customise_sources( + cls, + settings_cls: type[BaseSettings], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ) -> tuple[PydanticBaseSettingsSource, ...]: + return ( + AwsSystemsManagerParameterStoreSettingsSource( + settings_cls=AwsSettings, + ssm_client=client_mock, + ssm_path='/my/path', + ), + ) + + settings = AwsSettings() # type: ignore + + assert settings.SqlServerUser == 'SecretValue' + assert settings.sql_server_user == 'SecretValue' + assert settings.sql_server.password == 'SecretValue' + + def test_aws_ssm_settings_source__delimiter(self, mocker: MockerFixture) -> None: + """Test AwsSystemsManagerParameterStoreSettingsSource.""" + mock_parameters = [ + {'Name': '/my/path/SqlServerUser', 'Value': 'SecretValue'}, + {'Name': '/my/path/SqlServer__Password', 'Value': 'SecretValue'}, + ] + paginator_mock = mocker.Mock() + paginator_mock.paginate.return_value = [{'Parameters': mock_parameters}] + + client_mock = mocker.Mock() + client_mock.get_paginator.return_value = paginator_mock + client_mock.exceptions.ClientError = Exception + + class SqlServer(BaseModel): + password: str = Field(..., alias='Password') + + class AwsSettings(BaseSettings): + """AWS settings.""" + + SqlServerUser: str + sql_server_user: str = Field(..., alias='SqlServerUser') + sql_server: SqlServer = Field(..., alias='SqlServer') + + @classmethod + def settings_customise_sources( + cls, + settings_cls: type[BaseSettings], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ) -> tuple[PydanticBaseSettingsSource, ...]: + return ( + AwsSystemsManagerParameterStoreSettingsSource( + settings_cls=AwsSettings, + ssm_client=client_mock, + ssm_path='/my/path', + env_nested_delimiter='__', + ), + ) + + settings = AwsSettings() # type: ignore + + assert settings.SqlServerUser == 'SecretValue' + assert settings.sql_server_user == 'SecretValue' + assert settings.sql_server.password == 'SecretValue'