Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AwsSystemsManagerParameterStoreSettingsSource #385

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ refresh-lockfiles:
find requirements/ -name '*.txt' ! -name 'all.txt' -type f -delete
pip-compile -q --no-emit-index-url --resolver backtracking -o requirements/linting.txt requirements/linting.in
pip-compile -q --no-emit-index-url --resolver backtracking -o requirements/testing.txt requirements/testing.in
pip-compile -q --no-emit-index-url --resolver backtracking --extra toml --extra yaml --extra azure-key-vault -o requirements/pyproject.txt pyproject.toml
pip-compile -q --no-emit-index-url --resolver backtracking --extra toml --extra yaml --extra azure-key-vault --extra aws -o requirements/pyproject.txt pyproject.toml
pip install --dry-run -r requirements/all.txt

.PHONY: format
Expand Down
57 changes: 57 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -1306,6 +1306,63 @@ class AzureKeyVaultSettings(BaseSettings):
)
```

## AWS Systems Manager Parameter Store

You must set the following parameters:

- `ssm_client`: An initialized [`boto3` SSM Client](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ssm.html#client).

Optionally, you may specify the following parameters:

- `ssm_path`: The hierarchy for the parameter. Hierarchies start with a forward slash (/). The hierarchy is the parameter name except the last part of the parameter. Under the hood, we make use of the [`get_parameters_by_path` method](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ssm/client/get_parameters_by_path.html) to recursively retrieve all parameters within the a specified path hierarchy.

```py
import os
from typing import Tuple, Type

import boto3
from pydantic import BaseModel

from pydantic_settings import (
AwsSystemsManagerParameterStoreSettingsSource,
BaseSettings,
PydanticBaseSettingsSource,
)


class SubModel(BaseModel):
a: str


class AwsParamStoreSettings(BaseSettings):
foo: str
bar: int
sub: SubModel

@classmethod
def settings_customise_sources(
cls,
settings_cls: Type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> Tuple[PydanticBaseSettingsSource, ...]:
client = boto3.client('ssm')
ssm_param_store_settings = AwsSystemsManagerParameterStoreSettingsSource(
settings_cls,
ssm_client=client,
ssm_path=os.environ.get('SSM_PREFIX', '/api/dev/'),
)
return (
init_settings,
env_settings,
dotenv_settings,
file_secret_settings,
ssm_param_store_settings,
)
``` -->

## Other settings source

Other settings sources are available for common configuration files:
Expand Down
61 changes: 61 additions & 0 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@
import tomli
import yaml

try:
from mypy_boto3_ssm import SSMClient
except ImportError:
SSMClient = None

from pydantic_settings.main import BaseSettings
else:
yaml = None
Expand Down Expand Up @@ -2014,6 +2019,62 @@ def __repr__(self) -> str:
return f'AzureKeyVaultSettingsSource(url={self._url!r}, ' f'env_nested_delimiter={self.env_nested_delimiter!r})'


class AwsSystemsManagerParameterStoreSettingsSource(EnvSettingsSource):
_ssm_client: SSMClient
_ssm_path: str

def __init__(
self,
settings_cls: type[BaseSettings],
ssm_client: SSMClient,
ssm_path: str = '/',
case_sensitive: bool | None = None,
env_prefix: str | None = None,
env_nested_delimiter: str = '/',
env_ignore_empty: bool | None = None,
env_parse_none_str: str | None = None,
env_parse_enums: bool | None = None,
) -> None:
self._ssm_client = ssm_client
self._ssm_path = ssm_path
super().__init__(
settings_cls,
case_sensitive,
env_prefix,
env_nested_delimiter,
env_ignore_empty,
env_parse_none_str,
env_parse_enums,
)

def _load_env_vars(self) -> Mapping[str, Optional[str]]:
paginator = self._ssm_client.get_paginator('get_parameters_by_path')
response_iterator = paginator.paginate(
Path=self._ssm_path, WithDecryption=True, Recursive=True
)

output = {}
try:
for page in response_iterator:
for parameter in page['Parameters']:
name = Path(parameter['Name'])
key = name.relative_to(self._ssm_path).as_posix()

if not self.case_sensitive:
first_key, *rest = key.split(self.env_nested_delimiter)
key = self.env_nested_delimiter.join([first_key.lower(), *rest])

output[key] = parameter['Value']

except self._ssm_client.exceptions.ClientError as e:
warnings.warn(f'Unable to get parameters from {self._ssm_path!r}: {e}')

return output

def __repr__(self) -> str:
return f'AwsSystemsManagerParameterStoreSettingsSource(ssm_path={self._ssm_path!r})'


def _get_env_var_key(key: str, case_sensitive: bool = False) -> str:
return key if case_sensitive else key.lower()

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ dynamic = ['version']
yaml = ["pyyaml>=6.0.1"]
toml = ["tomli>=2.0.1"]
azure-key-vault = ["azure-keyvault-secrets>=4.8.0", "azure-identity>=1.16.0"]
aws = ["boto3>=1.35.0", "boto3-stubs[ssm]>=1.35.0"]

[project.urls]
Homepage = 'https://github.com/pydantic/pydantic-settings'
Expand Down
187 changes: 187 additions & 0 deletions tests/test_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from pydantic_settings.main import BaseSettings, SettingsConfigDict
from pydantic_settings.sources import (
AwsSystemsManagerParameterStoreSettingsSource,
AzureKeyVaultSettingsSource,
PydanticBaseSettingsSource,
PyprojectTomlConfigSettingsSource,
Expand Down Expand Up @@ -210,3 +211,189 @@ def _raise_resource_not_found_when_getting_parent_secret_name(self, secret_name:
raise ResourceNotFoundError()

return key_vault_secret


class TestAwsSystemsManagerParameterStoreSettingsSource:
"""Test AwsSystemsManagerParameterStoreSettingsSource."""

def test___init__(self, mocker: MockerFixture) -> None:
"""Test __init__."""

class AwsSettings(BaseSettings):
"""AWS settings."""

mock_parameters = []
paginator_mock = mocker.Mock()
paginator_mock.paginate.return_value = [{'Parameters': mock_parameters}]

client_mock = mocker.Mock()
client_mock.get_paginator.return_value = paginator_mock
client_mock.exceptions.ClientError = Exception

AwsSystemsManagerParameterStoreSettingsSource(
settings_cls=AwsSettings, ssm_client=client_mock, ssm_path='/my/path'
)

def test___call__case_sensitive(self, mocker: MockerFixture) -> None:
"""Test __call__."""

class SqlServer(BaseModel):
password: str = Field(..., alias='Password')

class AwsSettings(BaseSettings):
"""AWS settings."""

SqlServerUser: str
sql_server_user: str = Field(..., alias='SqlServerUser')
sql_server: SqlServer = Field(..., alias='SqlServer')

mock_parameters = [
{'Name': '/my/path/SqlServerUser', 'Value': 'SecretValue'},
{'Name': '/my/path/SqlServer/Password', 'Value': 'SecretValue'},
]
paginator_mock = mocker.Mock()
paginator_mock.paginate.return_value = [{'Parameters': mock_parameters}]

client_mock = mocker.Mock()
client_mock.get_paginator.return_value = paginator_mock
client_mock.exceptions.ClientError = Exception

obj = AwsSystemsManagerParameterStoreSettingsSource(
settings_cls=AwsSettings,
ssm_client=client_mock,
ssm_path='/my/path',
case_sensitive=True,
)

settings = obj()

assert settings['SqlServerUser'] == 'SecretValue'
assert settings['SqlServer']['Password'] == 'SecretValue'

def test___call__case_insensitive(self, mocker: MockerFixture) -> None:
"""Test __call__."""

class SqlServer(BaseModel):
password: str = Field(..., alias='Password')

class AwsSettings(BaseSettings):
"""AWS settings."""

SqlServerUser: str
sql_server_user: str = Field(..., alias='SqlServerUser')
sql_server: SqlServer = Field(..., alias='SqlServer')

mock_parameters = [
{'Name': '/my/path/SqlServerUser', 'Value': 'SecretValue'},
{'Name': '/my/path/SqlServer/Password', 'Value': 'SecretValue'},
]
paginator_mock = mocker.Mock()
paginator_mock.paginate.return_value = [{'Parameters': mock_parameters}]

client_mock = mocker.Mock()
client_mock.get_paginator.return_value = paginator_mock
client_mock.exceptions.ClientError = Exception

obj = AwsSystemsManagerParameterStoreSettingsSource(
settings_cls=AwsSettings,
ssm_client=client_mock,
ssm_path='/my/path',
case_sensitive=False,
)
settings = obj()

assert settings['SqlServerUser'] == 'SecretValue'
assert settings['SqlServer']['Password'] == 'SecretValue'

def test_aws_ssm_settings_source(self, mocker: MockerFixture) -> None:
"""Test AwsSystemsManagerParameterStoreSettingsSource."""
mock_parameters = [
{'Name': '/my/path/SqlServerUser', 'Value': 'SecretValue'},
{'Name': '/my/path/SqlServer/Password', 'Value': 'SecretValue'},
]
paginator_mock = mocker.Mock()
paginator_mock.paginate.return_value = [{'Parameters': mock_parameters}]

client_mock = mocker.Mock()
client_mock.get_paginator.return_value = paginator_mock
client_mock.exceptions.ClientError = Exception

class SqlServer(BaseModel):
password: str = Field(..., alias='Password')

class AwsSettings(BaseSettings):
"""AWS settings."""

SqlServerUser: str
sql_server_user: str = Field(..., alias='SqlServerUser')
sql_server: SqlServer = Field(..., alias='SqlServer')

@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
return (
AwsSystemsManagerParameterStoreSettingsSource(
settings_cls=AwsSettings,
ssm_client=client_mock,
ssm_path='/my/path',
),
)

settings = AwsSettings() # type: ignore

assert settings.SqlServerUser == 'SecretValue'
assert settings.sql_server_user == 'SecretValue'
assert settings.sql_server.password == 'SecretValue'

def test_aws_ssm_settings_source__delimiter(self, mocker: MockerFixture) -> None:
"""Test AwsSystemsManagerParameterStoreSettingsSource."""
mock_parameters = [
{'Name': '/my/path/SqlServerUser', 'Value': 'SecretValue'},
{'Name': '/my/path/SqlServer__Password', 'Value': 'SecretValue'},
]
paginator_mock = mocker.Mock()
paginator_mock.paginate.return_value = [{'Parameters': mock_parameters}]

client_mock = mocker.Mock()
client_mock.get_paginator.return_value = paginator_mock
client_mock.exceptions.ClientError = Exception

class SqlServer(BaseModel):
password: str = Field(..., alias='Password')

class AwsSettings(BaseSettings):
"""AWS settings."""

SqlServerUser: str
sql_server_user: str = Field(..., alias='SqlServerUser')
sql_server: SqlServer = Field(..., alias='SqlServer')

@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
return (
AwsSystemsManagerParameterStoreSettingsSource(
settings_cls=AwsSettings,
ssm_client=client_mock,
ssm_path='/my/path',
env_nested_delimiter='__',
),
)

settings = AwsSettings() # type: ignore

assert settings.SqlServerUser == 'SecretValue'
assert settings.sql_server_user == 'SecretValue'
assert settings.sql_server.password == 'SecretValue'
Loading