diff --git a/aws_lambda_powertools/utilities/parameters/appconfig.py b/aws_lambda_powertools/utilities/parameters/appconfig.py index 8e10540b186..ad36395c452 100644 --- a/aws_lambda_powertools/utilities/parameters/appconfig.py +++ b/aws_lambda_powertools/utilities/parameters/appconfig.py @@ -107,7 +107,12 @@ def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]: def get_app_config( - name: str, environment: str, application: Optional[str] = None, transform: Optional[str] = None, **sdk_options + name: str, + environment: str, + application: Optional[str] = None, + transform: Optional[str] = None, + force_fetch: bool = False, + **sdk_options ) -> Union[str, list, dict, bytes]: """ Retrieve a configuration value from AWS App Config. @@ -122,6 +127,8 @@ def get_app_config( Application of the configuration transform: str, optional Transforms the content from a JSON object ('json') or base64 binary string ('binary') + force_fetch: bool, optional + Force update even before a cached item has expired, defaults to False sdk_options: dict, optional Dictionary of options that will be passed to the Parameter Store get_parameter API call @@ -160,4 +167,4 @@ def get_app_config( sdk_options["ClientId"] = CLIENT_ID - return DEFAULT_PROVIDERS["appconfig"].get(name, transform=transform, **sdk_options) + return DEFAULT_PROVIDERS["appconfig"].get(name, transform=transform, force_fetch=force_fetch, **sdk_options) diff --git a/aws_lambda_powertools/utilities/parameters/base.py b/aws_lambda_powertools/utilities/parameters/base.py index 7ce0c9e4d2e..b07312f19d3 100644 --- a/aws_lambda_powertools/utilities/parameters/base.py +++ b/aws_lambda_powertools/utilities/parameters/base.py @@ -38,7 +38,12 @@ def _has_not_expired(self, key: Tuple[str, Optional[str]]) -> bool: return key in self.store and self.store[key].ttl >= datetime.now() def get( - self, name: str, max_age: int = DEFAULT_MAX_AGE_SECS, transform: Optional[str] = None, **sdk_options + self, + name: str, + max_age: int = DEFAULT_MAX_AGE_SECS, + transform: Optional[str] = None, + force_fetch: bool = False, + **sdk_options, ) -> Union[str, list, dict, bytes]: """ Retrieve a parameter value or return the cached value @@ -53,6 +58,8 @@ def get( Optional transformation of the parameter value. Supported values are "json" for JSON strings and "binary" for base 64 encoded values. + force_fetch: bool, optional + Force update even before a cached item has expired, defaults to False sdk_options: dict, optional Arguments that will be passed directly to the underlying API call @@ -76,7 +83,7 @@ def get( # an acceptable tradeoff. key = (name, transform) - if self._has_not_expired(key): + if not force_fetch and self._has_not_expired(key): return self.store[key].value try: @@ -105,6 +112,7 @@ def get_multiple( max_age: int = DEFAULT_MAX_AGE_SECS, transform: Optional[str] = None, raise_on_transform_error: bool = False, + force_fetch: bool = False, **sdk_options, ) -> Union[Dict[str, str], Dict[str, dict], Dict[str, bytes]]: """ @@ -123,6 +131,8 @@ def get_multiple( raise_on_transform_error: bool, optional Raises an exception if any transform fails, otherwise this will return a None value for each transform that failed + force_fetch: bool, optional + Force update even before a cached item has expired, defaults to False sdk_options: dict, optional Arguments that will be passed directly to the underlying API call @@ -137,7 +147,7 @@ def get_multiple( key = (path, transform) - if self._has_not_expired(key): + if not force_fetch and self._has_not_expired(key): return self.store[key].value try: diff --git a/aws_lambda_powertools/utilities/parameters/secrets.py b/aws_lambda_powertools/utilities/parameters/secrets.py index e3981d22bcc..f14e4703ba8 100644 --- a/aws_lambda_powertools/utilities/parameters/secrets.py +++ b/aws_lambda_powertools/utilities/parameters/secrets.py @@ -93,7 +93,9 @@ def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]: raise NotImplementedError() -def get_secret(name: str, transform: Optional[str] = None, **sdk_options) -> Union[str, dict, bytes]: +def get_secret( + name: str, transform: Optional[str] = None, force_fetch: bool = False, **sdk_options +) -> Union[str, dict, bytes]: """ Retrieve a parameter value from AWS Secrets Manager @@ -103,6 +105,8 @@ def get_secret(name: str, transform: Optional[str] = None, **sdk_options) -> Uni Name of the parameter transform: str, optional Transforms the content from a JSON object ('json') or base64 binary string ('binary') + force_fetch: bool, optional + Force update even before a cached item has expired, defaults to False sdk_options: dict, optional Dictionary of options that will be passed to the get_secret_value call @@ -139,4 +143,4 @@ def get_secret(name: str, transform: Optional[str] = None, **sdk_options) -> Uni if "secrets" not in DEFAULT_PROVIDERS: DEFAULT_PROVIDERS["secrets"] = SecretsProvider() - return DEFAULT_PROVIDERS["secrets"].get(name, transform=transform, **sdk_options) + return DEFAULT_PROVIDERS["secrets"].get(name, transform=transform, force_fetch=force_fetch, **sdk_options) diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py index 0f39bfac9c0..9c29436342a 100644 --- a/aws_lambda_powertools/utilities/parameters/ssm.py +++ b/aws_lambda_powertools/utilities/parameters/ssm.py @@ -92,6 +92,7 @@ def get( max_age: int = DEFAULT_MAX_AGE_SECS, transform: Optional[str] = None, decrypt: bool = False, + force_fetch: bool = False, **sdk_options ) -> Union[str, list, dict, bytes]: """ @@ -109,6 +110,8 @@ def get( values. decrypt: bool, optional If the parameter value should be decrypted + force_fetch: bool, optional + Force update even before a cached item has expired, defaults to False sdk_options: dict, optional Arguments that will be passed directly to the underlying API call @@ -124,7 +127,7 @@ def get( # Add to `decrypt` sdk_options to we can have an explicit option for this sdk_options["decrypt"] = decrypt - return super().get(name, max_age, transform, **sdk_options) + return super().get(name, max_age, transform, force_fetch, **sdk_options) def _get(self, name: str, decrypt: bool = False, **sdk_options) -> str: """ @@ -185,7 +188,7 @@ def _get_multiple(self, path: str, decrypt: bool = False, recursive: bool = Fals def get_parameter( - name: str, transform: Optional[str] = None, decrypt: bool = False, **sdk_options + name: str, transform: Optional[str] = None, decrypt: bool = False, force_fetch: bool = False, **sdk_options ) -> Union[str, list, dict, bytes]: """ Retrieve a parameter value from AWS Systems Manager (SSM) Parameter Store @@ -198,6 +201,8 @@ def get_parameter( Transforms the content from a JSON object ('json') or base64 binary string ('binary') decrypt: bool, optional If the parameter values should be decrypted + force_fetch: bool, optional + Force update even before a cached item has expired, defaults to False sdk_options: dict, optional Dictionary of options that will be passed to the Parameter Store get_parameter API call @@ -237,11 +242,16 @@ def get_parameter( # Add to `decrypt` sdk_options to we can have an explicit option for this sdk_options["decrypt"] = decrypt - return DEFAULT_PROVIDERS["ssm"].get(name, transform=transform, **sdk_options) + return DEFAULT_PROVIDERS["ssm"].get(name, transform=transform, force_fetch=force_fetch, **sdk_options) def get_parameters( - path: str, transform: Optional[str] = None, recursive: bool = True, decrypt: bool = False, **sdk_options + path: str, + transform: Optional[str] = None, + recursive: bool = True, + decrypt: bool = False, + force_fetch: bool = False, + **sdk_options ) -> Union[Dict[str, str], Dict[str, dict], Dict[str, bytes]]: """ Retrieve multiple parameter values from AWS Systems Manager (SSM) Parameter Store @@ -256,6 +266,8 @@ def get_parameters( If this should retrieve the parameter values recursively or not, defaults to True decrypt: bool, optional If the parameter values should be decrypted + force_fetch: bool, optional + Force update even before a cached item has expired, defaults to False sdk_options: dict, optional Dictionary of options that will be passed to the Parameter Store get_parameters_by_path API call @@ -295,4 +307,4 @@ def get_parameters( sdk_options["recursive"] = recursive sdk_options["decrypt"] = decrypt - return DEFAULT_PROVIDERS["ssm"].get_multiple(path, transform=transform, **sdk_options) + return DEFAULT_PROVIDERS["ssm"].get_multiple(path, transform=transform, force_fetch=force_fetch, **sdk_options) diff --git a/tests/functional/test_utilities_parameters.py b/tests/functional/test_utilities_parameters.py index 5a915f574ae..13c493ef5d6 100644 --- a/tests/functional/test_utilities_parameters.py +++ b/tests/functional/test_utilities_parameters.py @@ -1663,3 +1663,48 @@ def test_get_transform_method_preserve_auto_unhandled(key): transform = parameters.base.get_transform_method(key, "auto") assert transform is None + + +def test_base_provider_get_multiple_force_update(mock_name, mock_value): + """ + Test BaseProvider.get_multiple() with cached values and force_fetch is True + """ + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + raise NotImplementedError() + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + assert path == mock_name + return {"A": mock_value} + + provider = TestProvider() + + provider.store[(mock_name, None)] = ExpirableValue({"B": mock_value}, datetime.now() + timedelta(seconds=60)) + + value = provider.get_multiple(mock_name, force_fetch=True) + + assert isinstance(value, dict) + assert value["A"] == mock_value + + +def test_base_provider_get_force_update(mock_name, mock_value): + """ + Test BaseProvider.get() with cached values and force_fetch is True + """ + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + return mock_value + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + raise NotImplementedError() + + provider = TestProvider() + + provider.store[(mock_name, None)] = ExpirableValue("not-value", datetime.now() + timedelta(seconds=60)) + + value = provider.get(mock_name, force_fetch=True) + + assert isinstance(value, str) + assert value == mock_value