Skip to content

Commit bf31ec5

Browse files
committed
feat: adding expiration time for secret cache in secret manager plugin
1 parent 383b605 commit bf31ec5

File tree

7 files changed

+67
-16
lines changed

7 files changed

+67
-16
lines changed

aws_advanced_python_wrapper/aws_secrets_manager_plugin.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717
from json import JSONDecodeError, loads
1818
from re import search
1919
from types import SimpleNamespace
20-
from typing import TYPE_CHECKING, Callable, Dict, Optional, Set, Tuple
20+
from typing import TYPE_CHECKING, Callable, Optional, Set, Tuple
2121

2222
import boto3
23+
from aws_advanced_python_wrapper.utils.cache_map import CacheMap
2324
from botocore.exceptions import ClientError, EndpointConnectionError
2425

2526
if TYPE_CHECKING:
@@ -46,8 +47,10 @@ class AwsSecretsManagerPlugin(Plugin):
4647
_SUBSCRIBED_METHODS: Set[str] = {"connect", "force_connect"}
4748

4849
_SECRETS_ARN_PATTERN = r"^arn:aws:secretsmanager:(?P<region>[^:\n]*):[^:\n]*:([^:/\n]*[:/])?(.*)$"
50+
_ONE_YEAR_IN_SECONDS = 60 * 60 * 24 * 365
4951

50-
_secrets_cache: Dict[Tuple, SimpleNamespace] = {}
52+
_secret: Optional[SimpleNamespace] = None
53+
_secrets_cache: CacheMap[Tuple, SimpleNamespace] = CacheMap()
5154
_secret_key: Tuple = ()
5255

5356
@property
@@ -94,7 +97,13 @@ def force_connect(
9497
return self._connect(props, force_connect_func)
9598

9699
def _connect(self, props: Properties, connect_func: Callable) -> Connection:
97-
secret_fetched: bool = self._update_secret()
100+
token_expiration_sec: int = WrapperProperties.SECRETS_MANAGER_EXPIRATION.get_int(props)
101+
# if value is -1, default to one year
102+
if token_expiration_sec == -1:
103+
token_expiration_sec = AwsSecretsManagerPlugin._ONE_YEAR_IN_SECONDS
104+
token_expiration_ns = token_expiration_sec * 1000
105+
106+
secret_fetched: bool = self._update_secret(token_expiration_ns=token_expiration_ns)
98107

99108
try:
100109
self._apply_secret_to_properties(props)
@@ -105,7 +114,7 @@ def _connect(self, props: Properties, connect_func: Callable) -> Connection:
105114
raise AwsWrapperError(
106115
Messages.get_formatted("AwsSecretsManagerPlugin.ConnectException", e)) from e
107116

108-
secret_fetched = self._update_secret(True)
117+
secret_fetched = self._update_secret(token_expiration_ns=token_expiration_ns, force_refetch=True)
109118

110119
if secret_fetched:
111120
try:
@@ -117,9 +126,10 @@ def _connect(self, props: Properties, connect_func: Callable) -> Connection:
117126
unhandled_error)) from unhandled_error
118127
raise AwsWrapperError(Messages.get_formatted("AwsSecretsManagerPlugin.FailedLogin", e)) from e
119128

120-
def _update_secret(self, force_refetch: bool = False) -> bool:
129+
def _update_secret(self, token_expiration_ns: int, force_refetch: bool = False) -> bool:
121130
"""
122131
Called to update credentials from the cache, or from the AWS Secrets Manager service.
132+
:param token_expiration_ns: Expiration time in nanoseconds for secret stored in cache.
123133
:param force_refetch: Allows ignoring cached credentials and force fetches the latest credentials from the service.
124134
:return: `True`, if credentials were fetched from the service.
125135
"""
@@ -135,7 +145,7 @@ def _update_secret(self, force_refetch: bool = False) -> bool:
135145
try:
136146
self._secret = self._fetch_latest_credentials()
137147
if self._secret:
138-
AwsSecretsManagerPlugin._secrets_cache[self._secret_key] = self._secret
148+
AwsSecretsManagerPlugin._secrets_cache.put(self._secret_key, self._secret, token_expiration_ns)
139149
fetched = True
140150
except (ClientError, AttributeError) as e:
141151
logger.debug("AwsSecretsManagerPlugin.FailedToFetchDbCredentials", e)

aws_advanced_python_wrapper/utils/cache_map.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,12 @@ def _cleanup(self):
8888

8989

9090
class CacheItem(Generic[V]):
91-
def __init__(self, item: V, expiration_time: int):
91+
def __init__(self, item: V, expiration_time_ns: int):
9292
self.item = item
93-
self._expiration_time = expiration_time
93+
self._expiration_time_ns = expiration_time_ns
9494

9595
def __str__(self):
96-
return f"CacheItem [item={str(self.item)}, expiration_time={self._expiration_time}]"
96+
return f"CacheItem [item={str(self.item)}, expiration_time={self._expiration_time_ns}]"
9797

9898
def is_expired(self) -> bool:
99-
return time.perf_counter_ns() > self._expiration_time
99+
return time.perf_counter_ns() > self._expiration_time_ns

aws_advanced_python_wrapper/utils/iam_utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from aws_advanced_python_wrapper.hostinfo import HostInfo
3232
from aws_advanced_python_wrapper.plugin_service import PluginService
3333
from boto3 import Session
34+
from types import SimpleNamespace
3435

3536
from aws_advanced_python_wrapper.utils.properties import (Properties,
3637
WrapperProperties)
@@ -132,3 +133,22 @@ def __init__(self, token: str, expiration: datetime):
132133

133134
def is_expired(self) -> bool:
134135
return datetime.now() > self._expiration
136+
137+
138+
class SecretInfo:
139+
@property
140+
def secret(self):
141+
return self._secret
142+
143+
@property
144+
def expiration(self):
145+
return self._expiration
146+
147+
def __init__(self, secret: SimpleNamespace, expiration: Optional[datetime] = None):
148+
self._secret = secret
149+
self._expiration = expiration
150+
151+
def is_expired(self) -> bool:
152+
if self._expiration is None:
153+
return False
154+
return datetime.now() > self._expiration

aws_advanced_python_wrapper/utils/properties.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,10 @@ class WrapperProperties:
137137
SECRETS_MANAGER_ENDPOINT = WrapperProperty(
138138
"secrets_manager_endpoint",
139139
"The endpoint of the secret to retrieve.")
140+
SECRETS_MANAGER_EXPIRATION = WrapperProperty(
141+
"secrets_manager_expiration",
142+
"Secret cache expiration in seconds",
143+
-1)
140144

141145
DIALECT = WrapperProperty("wrapper_dialect", "A unique identifier for the supported database dialect.")
142146
AUXILIARY_QUERY_TIMEOUT_SEC = WrapperProperty(
@@ -255,7 +259,8 @@ class WrapperProperties:
255259
True)
256260

257261
# Host Selector
258-
ROUND_ROBIN_DEFAULT_WEIGHT = WrapperProperty("round_robin_default_weight", "The default weight for any hosts that have not been " +
262+
ROUND_ROBIN_DEFAULT_WEIGHT = WrapperProperty("round_robin_default_weight",
263+
"The default weight for any hosts that have not been " +
259264
"configured with the `round_robin_host_weight_pairs` parameter.",
260265
1)
261266

docs/examples/MySQLSecretsManager.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,24 @@
2121
if __name__ == "__main__":
2222
with AwsWrapperConnection.connect(
2323
mysql.connector.Connect,
24-
host="database.cluster-xyz.us-east-1.rds.amazonaws.com",
24+
host="atlas-mysql.cluster-cx422ywmsto6.us-east-2.rds.amazonaws.com",
2525
database="mysql",
26-
secrets_manager_secret_id="arn:aws:secretsmanager:<Region>:<AccountId>:secret:Secre78tName-6RandomCharacters",
26+
secrets_manager_secret_id="arn:aws:secretsmanager:us-east-2:851725167871:secret:mysql_test_1-ZjindE",
2727
secrets_manager_region="us-east-2",
2828
plugins="aws_secrets_manager"
2929
) as awsconn, awsconn.cursor() as cursor:
3030
cursor.execute("SELECT @@aurora_server_id")
3131
for record in cursor.fetchone():
3232
print(record)
33+
with AwsWrapperConnection.connect(
34+
mysql.connector.Connect,
35+
host="atlas-mysql.cluster-cx422ywmsto6.us-east-2.rds.amazonaws.com",
36+
database="mysql",
37+
secrets_manager_secret_id="arn:aws:secretsmanager:us-east-2:851725167871:secret:mysql_test_1-ZjindE",
38+
secrets_manager_region="us-east-2",
39+
plugins="aws_secrets_manager",
40+
secrets_manager_expiration=50
41+
) as awsconn, awsconn.cursor() as cursor:
42+
cursor.execute("SELECT @@aurora_server_id")
43+
for record in cursor.fetchone():
44+
print(record)

docs/using-the-python-driver/using-plugins/UsingTheAwsSecretsManagerPlugin.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ The following properties are required for the AWS Secrets Manager Connection Plu
2222
| `secrets_manager_secret_id` | String | Yes | Set this value to be the secret name or the secret ARN. | `secret_id` | `None` |
2323
| `secrets_manager_region` | String | Yes unless the `secrets_manager_secret_id` is a Secret ARN. | Set this value to be the region your secret is in. | `us-east-2` | `us-east-1` |
2424
| `secrets_manager_endpoint` | String | No | Set this value to be the endpoint override to retrieve your secret from. This parameter value should be in the form of a URL, with a valid protocol (ex. `http://`) and domain (ex. `localhost`). A port number is not required. | `http://localhost:1234` | `None` |
25+
| `secrets_manager_expiration`| int | No | Set this value to be the expiration time the secret is stored in the cache. If the value is -1, sets the expiration time to one year. | 500 | `-1` |
2526

2627
*NOTE* A Secret ARN has the following format: `arn:aws:secretsmanager:<Region>:<AccountId>:secret:Secre78tName-6RandomCharacters`
2728

tests/unit/test_secrets_manager_plugin.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828

2929
from typing import TYPE_CHECKING
3030

31+
from aws_advanced_python_wrapper.utils.cache_map import CacheItem, CacheMap
32+
3133
from aws_advanced_python_wrapper.aws_secrets_manager_plugin import \
3234
AwsSecretsManagerPlugin
3335

@@ -38,7 +40,7 @@
3840
from aws_advanced_python_wrapper.plugin_service import PluginService
3941

4042
from types import SimpleNamespace
41-
from typing import Callable, Dict, Tuple
43+
from typing import Callable, Tuple
4244
from unittest import TestCase
4345
from unittest.mock import MagicMock, patch
4446

@@ -64,6 +66,7 @@ class TestAwsSecretsManagerPlugin(TestCase):
6466
_SECRET_CACHE_KEY = (_TEST_SECRET_ID, _TEST_REGION, _TEST_ENDPOINT)
6567
_TEST_HOST_INFO = HostInfo(_TEST_HOST, _TEST_PORT)
6668
_TEST_SECRET = SimpleNamespace(username="testUser", password="testPassword")
69+
_ONE_YEAR_IN_NANOSECONDS = 60 * 60 * 24 * 365 * 1000
6770

6871
_MYSQL_HOST_INFO = HostInfo("mysql.testdb.us-east-2.rds.amazonaws.com")
6972
_PG_HOST_INFO = HostInfo("pg.testdb.us-east-2.rds.amazonaws.com")
@@ -80,7 +83,7 @@ class TestAwsSecretsManagerPlugin(TestCase):
8083
}
8184
}, "some_operation")
8285

83-
_secrets_cache: Dict[Tuple, SimpleNamespace] = {}
86+
_secrets_cache: CacheMap[Tuple, SimpleNamespace] = CacheMap()
8487

8588
_mock_func: Callable
8689
_mock_plugin_service: PluginService
@@ -111,7 +114,7 @@ def setUp(self):
111114

112115
@patch("aws_advanced_python_wrapper.aws_secrets_manager_plugin.AwsSecretsManagerPlugin._secrets_cache", _secrets_cache)
113116
def test_connect_with_cached_secrets(self):
114-
self._secrets_cache[self._SECRET_CACHE_KEY] = self._TEST_SECRET
117+
self._secrets_cache.put(self._SECRET_CACHE_KEY, self._TEST_SECRET, self._ONE_YEAR_IN_NANOSECONDS)
115118
target_plugin: AwsSecretsManagerPlugin = AwsSecretsManagerPlugin(self._mock_plugin_service,
116119
self._properties,
117120
self._mock_session)

0 commit comments

Comments
 (0)