Skip to content

Commit b365fba

Browse files
author
Michael Brewer
authored
feat(parameters): Add force_fetch option (#341)
1 parent 33bb7c8 commit b365fba

File tree

5 files changed

+90
-12
lines changed

5 files changed

+90
-12
lines changed

Diff for: aws_lambda_powertools/utilities/parameters/appconfig.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,12 @@ def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]:
107107

108108

109109
def get_app_config(
110-
name: str, environment: str, application: Optional[str] = None, transform: Optional[str] = None, **sdk_options
110+
name: str,
111+
environment: str,
112+
application: Optional[str] = None,
113+
transform: Optional[str] = None,
114+
force_fetch: bool = False,
115+
**sdk_options
111116
) -> Union[str, list, dict, bytes]:
112117
"""
113118
Retrieve a configuration value from AWS App Config.
@@ -122,6 +127,8 @@ def get_app_config(
122127
Application of the configuration
123128
transform: str, optional
124129
Transforms the content from a JSON object ('json') or base64 binary string ('binary')
130+
force_fetch: bool, optional
131+
Force update even before a cached item has expired, defaults to False
125132
sdk_options: dict, optional
126133
Dictionary of options that will be passed to the Parameter Store get_parameter API call
127134
@@ -160,4 +167,4 @@ def get_app_config(
160167

161168
sdk_options["ClientId"] = CLIENT_ID
162169

163-
return DEFAULT_PROVIDERS["appconfig"].get(name, transform=transform, **sdk_options)
170+
return DEFAULT_PROVIDERS["appconfig"].get(name, transform=transform, force_fetch=force_fetch, **sdk_options)

Diff for: aws_lambda_powertools/utilities/parameters/base.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,12 @@ def _has_not_expired(self, key: Tuple[str, Optional[str]]) -> bool:
3838
return key in self.store and self.store[key].ttl >= datetime.now()
3939

4040
def get(
41-
self, name: str, max_age: int = DEFAULT_MAX_AGE_SECS, transform: Optional[str] = None, **sdk_options
41+
self,
42+
name: str,
43+
max_age: int = DEFAULT_MAX_AGE_SECS,
44+
transform: Optional[str] = None,
45+
force_fetch: bool = False,
46+
**sdk_options,
4247
) -> Union[str, list, dict, bytes]:
4348
"""
4449
Retrieve a parameter value or return the cached value
@@ -53,6 +58,8 @@ def get(
5358
Optional transformation of the parameter value. Supported values
5459
are "json" for JSON strings and "binary" for base 64 encoded
5560
values.
61+
force_fetch: bool, optional
62+
Force update even before a cached item has expired, defaults to False
5663
sdk_options: dict, optional
5764
Arguments that will be passed directly to the underlying API call
5865
@@ -76,7 +83,7 @@ def get(
7683
# an acceptable tradeoff.
7784
key = (name, transform)
7885

79-
if self._has_not_expired(key):
86+
if not force_fetch and self._has_not_expired(key):
8087
return self.store[key].value
8188

8289
try:
@@ -105,6 +112,7 @@ def get_multiple(
105112
max_age: int = DEFAULT_MAX_AGE_SECS,
106113
transform: Optional[str] = None,
107114
raise_on_transform_error: bool = False,
115+
force_fetch: bool = False,
108116
**sdk_options,
109117
) -> Union[Dict[str, str], Dict[str, dict], Dict[str, bytes]]:
110118
"""
@@ -123,6 +131,8 @@ def get_multiple(
123131
raise_on_transform_error: bool, optional
124132
Raises an exception if any transform fails, otherwise this will
125133
return a None value for each transform that failed
134+
force_fetch: bool, optional
135+
Force update even before a cached item has expired, defaults to False
126136
sdk_options: dict, optional
127137
Arguments that will be passed directly to the underlying API call
128138
@@ -137,7 +147,7 @@ def get_multiple(
137147

138148
key = (path, transform)
139149

140-
if self._has_not_expired(key):
150+
if not force_fetch and self._has_not_expired(key):
141151
return self.store[key].value
142152

143153
try:

Diff for: aws_lambda_powertools/utilities/parameters/secrets.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,9 @@ def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]:
9393
raise NotImplementedError()
9494

9595

96-
def get_secret(name: str, transform: Optional[str] = None, **sdk_options) -> Union[str, dict, bytes]:
96+
def get_secret(
97+
name: str, transform: Optional[str] = None, force_fetch: bool = False, **sdk_options
98+
) -> Union[str, dict, bytes]:
9799
"""
98100
Retrieve a parameter value from AWS Secrets Manager
99101
@@ -103,6 +105,8 @@ def get_secret(name: str, transform: Optional[str] = None, **sdk_options) -> Uni
103105
Name of the parameter
104106
transform: str, optional
105107
Transforms the content from a JSON object ('json') or base64 binary string ('binary')
108+
force_fetch: bool, optional
109+
Force update even before a cached item has expired, defaults to False
106110
sdk_options: dict, optional
107111
Dictionary of options that will be passed to the get_secret_value call
108112
@@ -139,4 +143,4 @@ def get_secret(name: str, transform: Optional[str] = None, **sdk_options) -> Uni
139143
if "secrets" not in DEFAULT_PROVIDERS:
140144
DEFAULT_PROVIDERS["secrets"] = SecretsProvider()
141145

142-
return DEFAULT_PROVIDERS["secrets"].get(name, transform=transform, **sdk_options)
146+
return DEFAULT_PROVIDERS["secrets"].get(name, transform=transform, force_fetch=force_fetch, **sdk_options)

Diff for: aws_lambda_powertools/utilities/parameters/ssm.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def get(
9292
max_age: int = DEFAULT_MAX_AGE_SECS,
9393
transform: Optional[str] = None,
9494
decrypt: bool = False,
95+
force_fetch: bool = False,
9596
**sdk_options
9697
) -> Union[str, list, dict, bytes]:
9798
"""
@@ -109,6 +110,8 @@ def get(
109110
values.
110111
decrypt: bool, optional
111112
If the parameter value should be decrypted
113+
force_fetch: bool, optional
114+
Force update even before a cached item has expired, defaults to False
112115
sdk_options: dict, optional
113116
Arguments that will be passed directly to the underlying API call
114117
@@ -124,7 +127,7 @@ def get(
124127
# Add to `decrypt` sdk_options to we can have an explicit option for this
125128
sdk_options["decrypt"] = decrypt
126129

127-
return super().get(name, max_age, transform, **sdk_options)
130+
return super().get(name, max_age, transform, force_fetch, **sdk_options)
128131

129132
def _get(self, name: str, decrypt: bool = False, **sdk_options) -> str:
130133
"""
@@ -185,7 +188,7 @@ def _get_multiple(self, path: str, decrypt: bool = False, recursive: bool = Fals
185188

186189

187190
def get_parameter(
188-
name: str, transform: Optional[str] = None, decrypt: bool = False, **sdk_options
191+
name: str, transform: Optional[str] = None, decrypt: bool = False, force_fetch: bool = False, **sdk_options
189192
) -> Union[str, list, dict, bytes]:
190193
"""
191194
Retrieve a parameter value from AWS Systems Manager (SSM) Parameter Store
@@ -198,6 +201,8 @@ def get_parameter(
198201
Transforms the content from a JSON object ('json') or base64 binary string ('binary')
199202
decrypt: bool, optional
200203
If the parameter values should be decrypted
204+
force_fetch: bool, optional
205+
Force update even before a cached item has expired, defaults to False
201206
sdk_options: dict, optional
202207
Dictionary of options that will be passed to the Parameter Store get_parameter API call
203208
@@ -237,11 +242,16 @@ def get_parameter(
237242
# Add to `decrypt` sdk_options to we can have an explicit option for this
238243
sdk_options["decrypt"] = decrypt
239244

240-
return DEFAULT_PROVIDERS["ssm"].get(name, transform=transform, **sdk_options)
245+
return DEFAULT_PROVIDERS["ssm"].get(name, transform=transform, force_fetch=force_fetch, **sdk_options)
241246

242247

243248
def get_parameters(
244-
path: str, transform: Optional[str] = None, recursive: bool = True, decrypt: bool = False, **sdk_options
249+
path: str,
250+
transform: Optional[str] = None,
251+
recursive: bool = True,
252+
decrypt: bool = False,
253+
force_fetch: bool = False,
254+
**sdk_options
245255
) -> Union[Dict[str, str], Dict[str, dict], Dict[str, bytes]]:
246256
"""
247257
Retrieve multiple parameter values from AWS Systems Manager (SSM) Parameter Store
@@ -256,6 +266,8 @@ def get_parameters(
256266
If this should retrieve the parameter values recursively or not, defaults to True
257267
decrypt: bool, optional
258268
If the parameter values should be decrypted
269+
force_fetch: bool, optional
270+
Force update even before a cached item has expired, defaults to False
259271
sdk_options: dict, optional
260272
Dictionary of options that will be passed to the Parameter Store get_parameters_by_path API call
261273
@@ -295,4 +307,4 @@ def get_parameters(
295307
sdk_options["recursive"] = recursive
296308
sdk_options["decrypt"] = decrypt
297309

298-
return DEFAULT_PROVIDERS["ssm"].get_multiple(path, transform=transform, **sdk_options)
310+
return DEFAULT_PROVIDERS["ssm"].get_multiple(path, transform=transform, force_fetch=force_fetch, **sdk_options)

Diff for: tests/functional/test_utilities_parameters.py

+45
Original file line numberDiff line numberDiff line change
@@ -1663,3 +1663,48 @@ def test_get_transform_method_preserve_auto_unhandled(key):
16631663
transform = parameters.base.get_transform_method(key, "auto")
16641664

16651665
assert transform is None
1666+
1667+
1668+
def test_base_provider_get_multiple_force_update(mock_name, mock_value):
1669+
"""
1670+
Test BaseProvider.get_multiple() with cached values and force_fetch is True
1671+
"""
1672+
1673+
class TestProvider(BaseProvider):
1674+
def _get(self, name: str, **kwargs) -> str:
1675+
raise NotImplementedError()
1676+
1677+
def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
1678+
assert path == mock_name
1679+
return {"A": mock_value}
1680+
1681+
provider = TestProvider()
1682+
1683+
provider.store[(mock_name, None)] = ExpirableValue({"B": mock_value}, datetime.now() + timedelta(seconds=60))
1684+
1685+
value = provider.get_multiple(mock_name, force_fetch=True)
1686+
1687+
assert isinstance(value, dict)
1688+
assert value["A"] == mock_value
1689+
1690+
1691+
def test_base_provider_get_force_update(mock_name, mock_value):
1692+
"""
1693+
Test BaseProvider.get() with cached values and force_fetch is True
1694+
"""
1695+
1696+
class TestProvider(BaseProvider):
1697+
def _get(self, name: str, **kwargs) -> str:
1698+
return mock_value
1699+
1700+
def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
1701+
raise NotImplementedError()
1702+
1703+
provider = TestProvider()
1704+
1705+
provider.store[(mock_name, None)] = ExpirableValue("not-value", datetime.now() + timedelta(seconds=60))
1706+
1707+
value = provider.get(mock_name, force_fetch=True)
1708+
1709+
assert isinstance(value, str)
1710+
assert value == mock_value

0 commit comments

Comments
 (0)