Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(parameters): Add force_fetch option #341

Merged
merged 6 commits into from
Mar 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions aws_lambda_powertools/utilities/parameters/appconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,12 @@ def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]:


def get_app_config(
name: str, environment: str, application: Optional[str] = None, transform: Optional[str] = None, **sdk_options
name: str,
environment: str,
application: Optional[str] = None,
transform: Optional[str] = None,
force_fetch: bool = False,
**sdk_options
) -> Union[str, list, dict, bytes]:
"""
Retrieve a configuration value from AWS App Config.
Expand All @@ -122,6 +127,8 @@ def get_app_config(
Application of the configuration
transform: str, optional
Transforms the content from a JSON object ('json') or base64 binary string ('binary')
force_fetch: bool, optional
Force update even before a cached item has expired, defaults to False
sdk_options: dict, optional
Dictionary of options that will be passed to the Parameter Store get_parameter API call

Expand Down Expand Up @@ -160,4 +167,4 @@ def get_app_config(

sdk_options["ClientId"] = CLIENT_ID

return DEFAULT_PROVIDERS["appconfig"].get(name, transform=transform, **sdk_options)
return DEFAULT_PROVIDERS["appconfig"].get(name, transform=transform, force_fetch=force_fetch, **sdk_options)
16 changes: 13 additions & 3 deletions aws_lambda_powertools/utilities/parameters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@ def _has_not_expired(self, key: Tuple[str, Optional[str]]) -> bool:
return key in self.store and self.store[key].ttl >= datetime.now()

def get(
self, name: str, max_age: int = DEFAULT_MAX_AGE_SECS, transform: Optional[str] = None, **sdk_options
self,
name: str,
max_age: int = DEFAULT_MAX_AGE_SECS,
transform: Optional[str] = None,
force_fetch: bool = False,
**sdk_options,
) -> Union[str, list, dict, bytes]:
"""
Retrieve a parameter value or return the cached value
Expand All @@ -53,6 +58,8 @@ def get(
Optional transformation of the parameter value. Supported values
are "json" for JSON strings and "binary" for base 64 encoded
values.
force_fetch: bool, optional
Force update even before a cached item has expired, defaults to False
sdk_options: dict, optional
Arguments that will be passed directly to the underlying API call

Expand All @@ -76,7 +83,7 @@ def get(
# an acceptable tradeoff.
key = (name, transform)

if self._has_not_expired(key):
if not force_fetch and self._has_not_expired(key):
return self.store[key].value

try:
Expand Down Expand Up @@ -105,6 +112,7 @@ def get_multiple(
max_age: int = DEFAULT_MAX_AGE_SECS,
transform: Optional[str] = None,
raise_on_transform_error: bool = False,
force_fetch: bool = False,
**sdk_options,
) -> Union[Dict[str, str], Dict[str, dict], Dict[str, bytes]]:
"""
Expand All @@ -123,6 +131,8 @@ def get_multiple(
raise_on_transform_error: bool, optional
Raises an exception if any transform fails, otherwise this will
return a None value for each transform that failed
force_fetch: bool, optional
Force update even before a cached item has expired, defaults to False
sdk_options: dict, optional
Arguments that will be passed directly to the underlying API call

Expand All @@ -137,7 +147,7 @@ def get_multiple(

key = (path, transform)

if self._has_not_expired(key):
if not force_fetch and self._has_not_expired(key):
return self.store[key].value

try:
Expand Down
8 changes: 6 additions & 2 deletions aws_lambda_powertools/utilities/parameters/secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]:
raise NotImplementedError()


def get_secret(name: str, transform: Optional[str] = None, **sdk_options) -> Union[str, dict, bytes]:
def get_secret(
name: str, transform: Optional[str] = None, force_fetch: bool = False, **sdk_options
) -> Union[str, dict, bytes]:
"""
Retrieve a parameter value from AWS Secrets Manager

Expand All @@ -103,6 +105,8 @@ def get_secret(name: str, transform: Optional[str] = None, **sdk_options) -> Uni
Name of the parameter
transform: str, optional
Transforms the content from a JSON object ('json') or base64 binary string ('binary')
force_fetch: bool, optional
Force update even before a cached item has expired, defaults to False
sdk_options: dict, optional
Dictionary of options that will be passed to the get_secret_value call

Expand Down Expand Up @@ -139,4 +143,4 @@ def get_secret(name: str, transform: Optional[str] = None, **sdk_options) -> Uni
if "secrets" not in DEFAULT_PROVIDERS:
DEFAULT_PROVIDERS["secrets"] = SecretsProvider()

return DEFAULT_PROVIDERS["secrets"].get(name, transform=transform, **sdk_options)
return DEFAULT_PROVIDERS["secrets"].get(name, transform=transform, force_fetch=force_fetch, **sdk_options)
22 changes: 17 additions & 5 deletions aws_lambda_powertools/utilities/parameters/ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def get(
max_age: int = DEFAULT_MAX_AGE_SECS,
transform: Optional[str] = None,
decrypt: bool = False,
force_fetch: bool = False,
**sdk_options
) -> Union[str, list, dict, bytes]:
"""
Expand All @@ -109,6 +110,8 @@ def get(
values.
decrypt: bool, optional
If the parameter value should be decrypted
force_fetch: bool, optional
Force update even before a cached item has expired, defaults to False
sdk_options: dict, optional
Arguments that will be passed directly to the underlying API call

Expand All @@ -124,7 +127,7 @@ def get(
# Add to `decrypt` sdk_options to we can have an explicit option for this
sdk_options["decrypt"] = decrypt

return super().get(name, max_age, transform, **sdk_options)
return super().get(name, max_age, transform, force_fetch, **sdk_options)

def _get(self, name: str, decrypt: bool = False, **sdk_options) -> str:
"""
Expand Down Expand Up @@ -185,7 +188,7 @@ def _get_multiple(self, path: str, decrypt: bool = False, recursive: bool = Fals


def get_parameter(
name: str, transform: Optional[str] = None, decrypt: bool = False, **sdk_options
name: str, transform: Optional[str] = None, decrypt: bool = False, force_fetch: bool = False, **sdk_options
) -> Union[str, list, dict, bytes]:
"""
Retrieve a parameter value from AWS Systems Manager (SSM) Parameter Store
Expand All @@ -198,6 +201,8 @@ def get_parameter(
Transforms the content from a JSON object ('json') or base64 binary string ('binary')
decrypt: bool, optional
If the parameter values should be decrypted
force_fetch: bool, optional
Force update even before a cached item has expired, defaults to False
sdk_options: dict, optional
Dictionary of options that will be passed to the Parameter Store get_parameter API call

Expand Down Expand Up @@ -237,11 +242,16 @@ def get_parameter(
# Add to `decrypt` sdk_options to we can have an explicit option for this
sdk_options["decrypt"] = decrypt

return DEFAULT_PROVIDERS["ssm"].get(name, transform=transform, **sdk_options)
return DEFAULT_PROVIDERS["ssm"].get(name, transform=transform, force_fetch=force_fetch, **sdk_options)


def get_parameters(
path: str, transform: Optional[str] = None, recursive: bool = True, decrypt: bool = False, **sdk_options
path: str,
transform: Optional[str] = None,
recursive: bool = True,
decrypt: bool = False,
force_fetch: bool = False,
**sdk_options
) -> Union[Dict[str, str], Dict[str, dict], Dict[str, bytes]]:
"""
Retrieve multiple parameter values from AWS Systems Manager (SSM) Parameter Store
Expand All @@ -256,6 +266,8 @@ def get_parameters(
If this should retrieve the parameter values recursively or not, defaults to True
decrypt: bool, optional
If the parameter values should be decrypted
force_fetch: bool, optional
Force update even before a cached item has expired, defaults to False
sdk_options: dict, optional
Dictionary of options that will be passed to the Parameter Store get_parameters_by_path API call

Expand Down Expand Up @@ -295,4 +307,4 @@ def get_parameters(
sdk_options["recursive"] = recursive
sdk_options["decrypt"] = decrypt

return DEFAULT_PROVIDERS["ssm"].get_multiple(path, transform=transform, **sdk_options)
return DEFAULT_PROVIDERS["ssm"].get_multiple(path, transform=transform, force_fetch=force_fetch, **sdk_options)
45 changes: 45 additions & 0 deletions tests/functional/test_utilities_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1663,3 +1663,48 @@ def test_get_transform_method_preserve_auto_unhandled(key):
transform = parameters.base.get_transform_method(key, "auto")

assert transform is None


def test_base_provider_get_multiple_force_update(mock_name, mock_value):
"""
Test BaseProvider.get_multiple() with cached values and force_fetch is True
"""

class TestProvider(BaseProvider):
def _get(self, name: str, **kwargs) -> str:
raise NotImplementedError()

def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
assert path == mock_name
return {"A": mock_value}

provider = TestProvider()

provider.store[(mock_name, None)] = ExpirableValue({"B": mock_value}, datetime.now() + timedelta(seconds=60))

value = provider.get_multiple(mock_name, force_fetch=True)

assert isinstance(value, dict)
assert value["A"] == mock_value


def test_base_provider_get_force_update(mock_name, mock_value):
"""
Test BaseProvider.get() with cached values and force_fetch is True
"""

class TestProvider(BaseProvider):
def _get(self, name: str, **kwargs) -> str:
return mock_value

def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
raise NotImplementedError()

provider = TestProvider()

provider.store[(mock_name, None)] = ExpirableValue("not-value", datetime.now() + timedelta(seconds=60))

value = provider.get(mock_name, force_fetch=True)

assert isinstance(value, str)
assert value == mock_value