Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions aws_advanced_python_wrapper/aws_secrets_manager_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
from json import JSONDecodeError, loads
from re import search
from types import SimpleNamespace
from typing import TYPE_CHECKING, Callable, Dict, Optional, Set, Tuple
from typing import TYPE_CHECKING, Callable, Optional, Set, Tuple

import boto3
from botocore.exceptions import ClientError, EndpointConnectionError

from aws_advanced_python_wrapper.utils.cache_map import CacheMap

if TYPE_CHECKING:
from boto3 import Session
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
Expand All @@ -46,8 +48,10 @@ class AwsSecretsManagerPlugin(Plugin):
_SUBSCRIBED_METHODS: Set[str] = {"connect", "force_connect"}

_SECRETS_ARN_PATTERN = r"^arn:aws:secretsmanager:(?P<region>[^:\n]*):[^:\n]*:([^:/\n]*[:/])?(.*)$"
_ONE_YEAR_IN_SECONDS = 60 * 60 * 24 * 365

_secrets_cache: Dict[Tuple, SimpleNamespace] = {}
_secret: Optional[SimpleNamespace] = None
_secrets_cache: CacheMap[Tuple, SimpleNamespace] = CacheMap()
_secret_key: Tuple = ()

@property
Expand Down Expand Up @@ -94,7 +98,13 @@ def force_connect(
return self._connect(props, force_connect_func)

def _connect(self, props: Properties, connect_func: Callable) -> Connection:
secret_fetched: bool = self._update_secret()
token_expiration_sec: int = WrapperProperties.SECRETS_MANAGER_EXPIRATION.get_int(props)
# if value is less than 0, default to one year
if token_expiration_sec < 0:
token_expiration_sec = AwsSecretsManagerPlugin._ONE_YEAR_IN_SECONDS
token_expiration_ns = token_expiration_sec * 1_000_000_000

secret_fetched: bool = self._update_secret(token_expiration_ns=token_expiration_ns)

try:
self._apply_secret_to_properties(props)
Expand All @@ -105,7 +115,7 @@ def _connect(self, props: Properties, connect_func: Callable) -> Connection:
raise AwsWrapperError(
Messages.get_formatted("AwsSecretsManagerPlugin.ConnectException", e)) from e

secret_fetched = self._update_secret(True)
secret_fetched = self._update_secret(token_expiration_ns=token_expiration_ns, force_refetch=True)

if secret_fetched:
try:
Expand All @@ -117,9 +127,10 @@ def _connect(self, props: Properties, connect_func: Callable) -> Connection:
unhandled_error)) from unhandled_error
raise AwsWrapperError(Messages.get_formatted("AwsSecretsManagerPlugin.FailedLogin", e)) from e

def _update_secret(self, force_refetch: bool = False) -> bool:
def _update_secret(self, token_expiration_ns: int, force_refetch: bool = False) -> bool:
"""
Called to update credentials from the cache, or from the AWS Secrets Manager service.
:param token_expiration_ns: Expiration time in nanoseconds for secret stored in cache.
:param force_refetch: Allows ignoring cached credentials and force fetches the latest credentials from the service.
:return: `True`, if credentials were fetched from the service.
"""
Expand All @@ -135,7 +146,7 @@ def _update_secret(self, force_refetch: bool = False) -> bool:
try:
self._secret = self._fetch_latest_credentials()
if self._secret:
AwsSecretsManagerPlugin._secrets_cache[self._secret_key] = self._secret
AwsSecretsManagerPlugin._secrets_cache.put(self._secret_key, self._secret, token_expiration_ns)
Copy link
Contributor

@karenc-bq karenc-bq Jun 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on the way CacheMap checks for expiration:

    def is_expired(self) -> bool:
        return time.perf_counter_ns() > self._expiration_time_ns

This should be

Suggested change
AwsSecretsManagerPlugin._secrets_cache.put(self._secret_key, self._secret, token_expiration_ns)
AwsSecretsManagerPlugin._secrets_cache.put(self._secret_key, self._secret, time.perf_counter_ns() + token_expiration_ns)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put already adds the time.perf_counter_ns()

def put(self, key: K, item: V, item_expiration_ns: int): self._cache[key] = CacheItem(item, time.perf_counter_ns() + item_expiration_ns) self._cleanup()

fetched = True
except (ClientError, AttributeError) as e:
logger.debug("AwsSecretsManagerPlugin.FailedToFetchDbCredentials", e)
Expand Down
8 changes: 4 additions & 4 deletions aws_advanced_python_wrapper/utils/cache_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,12 @@ def _cleanup(self):


class CacheItem(Generic[V]):
def __init__(self, item: V, expiration_time: int):
def __init__(self, item: V, expiration_time_ns: int):
self.item = item
self._expiration_time = expiration_time
self._expiration_time_ns = expiration_time_ns

def __str__(self):
return f"CacheItem [item={str(self.item)}, expiration_time={self._expiration_time}]"
return f"CacheItem [item={str(self.item)}, expiration_time={self._expiration_time_ns}]"

def is_expired(self) -> bool:
return time.perf_counter_ns() > self._expiration_time
return time.perf_counter_ns() > self._expiration_time_ns
7 changes: 6 additions & 1 deletion aws_advanced_python_wrapper/utils/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ class WrapperProperties:
SECRETS_MANAGER_ENDPOINT = WrapperProperty(
"secrets_manager_endpoint",
"The endpoint of the secret to retrieve.")
SECRETS_MANAGER_EXPIRATION = WrapperProperty(
"secrets_manager_expiration",
"Secret cache expiration in seconds",
60 * 60 * 24 * 365)

DIALECT = WrapperProperty("wrapper_dialect", "A unique identifier for the supported database dialect.")
AUXILIARY_QUERY_TIMEOUT_SEC = WrapperProperty(
Expand Down Expand Up @@ -264,7 +268,8 @@ class WrapperProperties:
True)

# Host Selector
ROUND_ROBIN_DEFAULT_WEIGHT = WrapperProperty("round_robin_default_weight", "The default weight for any hosts that have not been " +
ROUND_ROBIN_DEFAULT_WEIGHT = WrapperProperty("round_robin_default_weight",
"The default weight for any hosts that have not been " +
"configured with the `round_robin_host_weight_pairs` parameter.",
1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ The following properties are required for the AWS Secrets Manager Connection Plu
| `secrets_manager_endpoint` | String | No | Set this value to be the endpoint override to retrieve your secret from. This parameter value should be in the form of a URL, with a valid protocol (ex. `http://`) and domain (ex. `localhost`). A port number is not required. | `http://localhost:1234` | `None` |
| `secrets_manager_secret_username` | String | No | Set this value to be the key in the JSON secret that contains the username for database connection. | `username_key` | `username` |
| `secrets_manager_secret_password` | String | No | SSet this value to be the key in the JSON secret that contains the password for database connection. | `password_key` | `password` |
| `secrets_manager_expiration` | int | No | Set this value to be the expiration time in seconds the secret is stored in the cache. If the value is below 0, sets the expiration time to one year in seconds. | 500 | 31536000 |

*NOTE* A Secret ARN has the following format: `arn:aws:secretsmanager:<Region>:<AccountId>:secret:Secre78tName-6RandomCharacters`

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ The host response time is measured at an interval set by `response_measurement_i

## Using the Fastest Response Strategy Plugin

The plugin can be loaded by adding the plugin code `fastest_response_strategy` to the [`plugins`](../UsingThePythonDriver.md#aws-advanced-python-driver-parameters) parameter. The Fastest Response Strategy Plugin is not loaded by default, and must be loaded along with the [`read_write_splitting`](https://github.com/awslabs/aws-advanced-python-wrapper/blob/main/docs/using-the-python-driver/using-plugins/UsingTheReadWriteSplittingPlugin.md) plugin.
The plugin can be loaded by adding the plugin code `fastest_response_strategy` to the [`plugins`](../UsingThePythonDriver.md#aws-advanced-python-driver-parameters) parameter. The Fastest Response Strategy Plugin is not loaded by default, and must be loaded along with the [`read_write_splitting`](./UsingTheReadWriteSplittingPlugin.md) plugin.

> [!IMPORTANT]\
> **`reader_response_strategy` must be set to `fastest_reponse` when using this plugin. Otherwise an error will be thrown:**
Expand Down
8 changes: 5 additions & 3 deletions tests/unit/test_secrets_manager_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from aws_advanced_python_wrapper.aws_secrets_manager_plugin import \
AwsSecretsManagerPlugin
from aws_advanced_python_wrapper.utils.cache_map import CacheMap

if TYPE_CHECKING:
from boto3 import Session, client
Expand All @@ -38,7 +39,7 @@
from aws_advanced_python_wrapper.plugin_service import PluginService

from types import SimpleNamespace
from typing import Callable, Dict, Tuple
from typing import Callable, Tuple
from unittest import TestCase
from unittest.mock import MagicMock, patch

Expand Down Expand Up @@ -66,6 +67,7 @@ class TestAwsSecretsManagerPlugin(TestCase):
_SECRET_CACHE_KEY = (_TEST_SECRET_ID, _TEST_REGION, _TEST_ENDPOINT)
_TEST_HOST_INFO = HostInfo(_TEST_HOST, _TEST_PORT)
_TEST_SECRET = SimpleNamespace(username="testUser", password="testPassword")
_ONE_YEAR_IN_NANOSECONDS = 60 * 60 * 24 * 365 * 1000

_MYSQL_HOST_INFO = HostInfo("mysql.testdb.us-east-2.rds.amazonaws.com")
_PG_HOST_INFO = HostInfo("pg.testdb.us-east-2.rds.amazonaws.com")
Expand All @@ -82,7 +84,7 @@ class TestAwsSecretsManagerPlugin(TestCase):
}
}, "some_operation")

_secrets_cache: Dict[Tuple, SimpleNamespace] = {}
_secrets_cache: CacheMap[Tuple, SimpleNamespace] = CacheMap()

_mock_func: Callable
_mock_plugin_service: PluginService
Expand Down Expand Up @@ -113,7 +115,7 @@ def setUp(self):

@patch("aws_advanced_python_wrapper.aws_secrets_manager_plugin.AwsSecretsManagerPlugin._secrets_cache", _secrets_cache)
def test_connect_with_cached_secrets(self):
self._secrets_cache[self._SECRET_CACHE_KEY] = self._TEST_SECRET
self._secrets_cache.put(self._SECRET_CACHE_KEY, self._TEST_SECRET, self._ONE_YEAR_IN_NANOSECONDS)
target_plugin: AwsSecretsManagerPlugin = AwsSecretsManagerPlugin(self._mock_plugin_service,
self._properties,
self._mock_session)
Expand Down