diff --git a/moto/apigateway/models.py b/moto/apigateway/models.py index a8414594702f..fad7ca7e1485 100644 --- a/moto/apigateway/models.py +++ b/moto/apigateway/models.py @@ -68,7 +68,19 @@ ValidationException, VpcLinkNotFound, ) -from .utils import create_id, to_path +from .utils import ( + ApigwApiKeyIdentifier, + ApigwAuthorizerIdentifier, + ApigwDeploymentIdentifier, + ApigwModelIdentifier, + ApigwRequestValidatorIdentifier, + ApigwResourceIdentifier, + ApigwRestApiIdentifier, + ApigwUsagePlanIdentifier, + ApigwVpcLinkIdentifier, + create_id, + to_path, +) STAGE_URL = "https://{api_id}.execute-api.{region_name}.amazonaws.com/{stage_name}" PATCH_OPERATIONS = ["add", "remove", "replace", "move", "copy", "test"] @@ -789,6 +801,7 @@ def _apply_operation_to_variables(self, op: Dict[str, Any]) -> None: class ApiKey(BaseModel): def __init__( self, + api_key_id: str, name: Optional[str] = None, description: Optional[str] = None, enabled: bool = False, @@ -798,7 +811,7 @@ def __init__( tags: Optional[List[Dict[str, str]]] = None, customerId: Optional[str] = None, ): - self.id = create_id() + self.id = api_key_id self.value = value or "".join( random.sample(string.ascii_letters + string.digits, 40) ) @@ -846,6 +859,7 @@ def _str2bool(self, v: str) -> bool: class UsagePlan(BaseModel): def __init__( self, + usage_plan_id: str, name: Optional[str] = None, description: Optional[str] = None, apiStages: Any = None, @@ -854,7 +868,7 @@ def __init__( productCode: Optional[str] = None, tags: Optional[List[Dict[str, str]]] = None, ): - self.id = create_id() + self.id = usage_plan_id self.name = name self.description = description self.api_stages = apiStages or [] @@ -985,12 +999,13 @@ def to_json(self) -> Dict[str, Any]: class VpcLink(BaseModel): def __init__( self, + vpc_link_id: str, name: str, description: str, target_arns: List[str], tags: List[Dict[str, str]], ): - self.id = create_id() + self.id = vpc_link_id self.name = name self.description = description self.target_arns = target_arns @@ -1162,7 +1177,9 @@ def create_from_cloudformation_json( # type: ignore[misc] ) def add_child(self, path: str, parent_id: Optional[str] = None) -> Resource: - child_id = create_id() + child_id = ApigwResourceIdentifier( + self.account_id, self.region_name, parent_id or "", path + ).generate() child = Resource( resource_id=child_id, account_id=self.account_id, @@ -1181,7 +1198,9 @@ def add_model( schema: str, content_type: str, ) -> "Model": - model_id = create_id() + model_id = ApigwModelIdentifier( + self.account_id, self.region_name, name + ).generate() new_model = Model( model_id=model_id, name=name, @@ -1293,7 +1312,11 @@ def create_deployment( ) -> Deployment: if stage_variables is None: stage_variables = {} - deployment_id = create_id() + # Since there are no unique values to a deployment, we will use the stage name for the deployment. + # We are also passing a list of deployment ids to generate to prevent overwriting deployments. + deployment_id = ApigwDeploymentIdentifier( + self.account_id, self.region_name, stage_name=name + ).generate(list(self.deployments.keys())) deployment = Deployment(deployment_id, name, description) self.deployments[deployment_id] = deployment if name: @@ -1332,7 +1355,9 @@ def create_request_validator( validateRequestBody: Optional[bool], validateRequestParameters: Any, ) -> RequestValidator: - validator_id = create_id() + validator_id = ApigwRequestValidatorIdentifier( + self.account_id, self.region_name, name + ).generate() request_validator = RequestValidator( _id=validator_id, name=name, @@ -1631,7 +1656,9 @@ def create_rest_api( minimum_compression_size: Optional[int] = None, disable_execute_api_endpoint: Optional[bool] = None, ) -> RestAPI: - api_id = create_id() + api_id = ApigwRestApiIdentifier( + self.account_id, self.region_name, name + ).generate() rest_api = RestAPI( api_id, self.account_id, @@ -1882,7 +1909,9 @@ def create_authorizer( self, restapi_id: str, name: str, authorizer_type: str, **kwargs: Any ) -> Authorizer: api = self.get_rest_api(restapi_id) - authorizer_id = create_id() + authorizer_id = ApigwAuthorizerIdentifier( + self.account_id, self.region_name, name + ).generate() return api.create_authorizer( authorizer_id, name, @@ -2146,7 +2175,13 @@ def create_api_key(self, payload: Dict[str, Any]) -> ApiKey: for api_key in self.get_api_keys(): if api_key.value == payload["value"]: raise ApiKeyAlreadyExists() - key = ApiKey(**payload) + api_key_id = ApigwApiKeyIdentifier( + self.account_id, + self.region_name, + # The value of an api key must be unique on aws + payload.get("value", ""), + ).generate() + key = ApiKey(api_key_id=api_key_id, **payload) self.keys[key.id] = key return key @@ -2170,7 +2205,10 @@ def delete_api_key(self, api_key_id: str) -> None: self.keys.pop(api_key_id) def create_usage_plan(self, payload: Any) -> UsagePlan: - plan = UsagePlan(**payload) + usage_plan_id = ApigwUsagePlanIdentifier( + self.account_id, self.region_name, payload["name"] + ).generate() + plan = UsagePlan(usage_plan_id=usage_plan_id, **payload) self.usage_plans[plan.id] = plan return plan @@ -2497,8 +2535,15 @@ def create_vpc_link( target_arns: List[str], tags: List[Dict[str, str]], ) -> VpcLink: + vpc_link_id = ApigwVpcLinkIdentifier( + self.account_id, self.region_name, name + ).generate() vpc_link = VpcLink( - name, description=description, target_arns=target_arns, tags=tags + vpc_link_id, + name, + description=description, + target_arns=target_arns, + tags=tags, ) self.vpc_links[vpc_link.id] = vpc_link return vpc_link diff --git a/moto/apigateway/utils.py b/moto/apigateway/utils.py index 971fd218c496..3fcc25f2828e 100644 --- a/moto/apigateway/utils.py +++ b/moto/apigateway/utils.py @@ -1,10 +1,81 @@ import json import string -from typing import Any, Dict +from typing import Any, Dict, List, Union import yaml from moto.moto_api._internal import mock_random as random +from moto.utilities.id_generator import ResourceIdentifier, Tags, generate_str_id + + +class ApigwIdentifier(ResourceIdentifier): + service = "apigateway" + + def __init__(self, account_id: str, region: str, name: str): + super().__init__(account_id, region, name) + + def generate( + self, existing_ids: Union[List[str], None] = None, tags: Tags = None + ) -> str: + return generate_str_id( + resource_identifier=self, + existing_ids=existing_ids, + tags=tags, + length=10, + include_digits=True, + lower_case=True, + ) + + +class ApigwApiKeyIdentifier(ApigwIdentifier): + resource = "api_key" + + def __init__(self, account_id: str, region: str, value: str): + super().__init__(account_id, region, value) + + +class ApigwAuthorizerIdentifier(ApigwIdentifier): + resource = "authorizer" + + +class ApigwDeploymentIdentifier(ApigwIdentifier): + resource = "deployment" + + def __init__(self, account_id: str, region: str, stage_name: str): + super().__init__(account_id, region, stage_name) + + +class ApigwModelIdentifier(ApigwIdentifier): + resource = "model" + + +class ApigwRequestValidatorIdentifier(ApigwIdentifier): + resource = "request_validator" + + +class ApigwResourceIdentifier(ApigwIdentifier): + resource = "resource" + + def __init__( + self, account_id: str, region: str, parent_id: str = "", path_name: str = "/" + ): + super().__init__( + account_id, + region, + ".".join((parent_id, path_name)), + ) + + +class ApigwRestApiIdentifier(ApigwIdentifier): + resource = "rest_api" + + +class ApigwUsagePlanIdentifier(ApigwIdentifier): + resource = "usage_plan" + + +class ApigwVpcLinkIdentifier(ApigwIdentifier): + resource = "vpc_link" def create_id() -> str: diff --git a/moto/secretsmanager/models.py b/moto/secretsmanager/models.py index bbebebb58099..8d3d1d5a2e81 100644 --- a/moto/secretsmanager/models.py +++ b/moto/secretsmanager/models.py @@ -26,7 +26,11 @@ tag_key, tag_value, ) -from .utils import get_secret_name_from_partial_arn, random_password, secret_arn +from .utils import ( + SecretsManagerSecretIdentifier, + get_secret_name_from_partial_arn, + random_password, +) MAX_RESULTS_DEFAULT = 100 @@ -94,7 +98,9 @@ def __init__( ): self.secret_id = secret_id self.name = secret_id - self.arn = secret_arn(account_id, region_name, secret_id) + self.arn = SecretsManagerSecretIdentifier( + account_id, region_name, secret_id + ).generate() self.account_id = account_id self.region = region_name self.secret_string = secret_string @@ -935,7 +941,9 @@ def delete_secret( if not force_delete_without_recovery: raise SecretNotFoundException() else: - arn = secret_arn(self.account_id, self.region_name, secret_id=secret_id) + arn = SecretsManagerSecretIdentifier( + self.account_id, self.region_name, secret_id=secret_id + ).generate() name = secret_id deletion_date = utcnow() return arn, name, self._unix_time_secs(deletion_date) diff --git a/moto/secretsmanager/utils.py b/moto/secretsmanager/utils.py index e5435a6c9dd8..e4b55622b176 100644 --- a/moto/secretsmanager/utils.py +++ b/moto/secretsmanager/utils.py @@ -2,6 +2,12 @@ import string from moto.moto_api._internal import mock_random as random +from moto.utilities.id_generator import ( + ExistingIds, + ResourceIdentifier, + Tags, + generate_str_id, +) from moto.utilities.utils import ARN_PARTITION_REGEX, get_partition @@ -62,11 +68,6 @@ def random_password( return password -def secret_arn(account_id: str, region: str, secret_id: str) -> str: - id_string = "".join(random.choice(string.ascii_letters) for _ in range(6)) - return f"arn:{get_partition(region)}:secretsmanager:{region}:{account_id}:secret:{secret_id}-{id_string}" - - def get_secret_name_from_partial_arn(partial_arn: str) -> str: # We can retrieve a secret either using a full ARN, or using a partial ARN # name: testsecret @@ -99,3 +100,24 @@ def _add_password_require_each_included_type( password_with_required_char += required_characters return password_with_required_char + + +class SecretsManagerSecretIdentifier(ResourceIdentifier): + service = "secretsmanager" + resource = "secret" + + def __init__(self, account_id: str, region: str, secret_id: str): + super().__init__(account_id, region, name=secret_id) + + def generate(self, existing_ids: ExistingIds = None, tags: Tags = None) -> str: + id_string = generate_str_id( + resource_identifier=self, + existing_ids=existing_ids, + tags=tags, + length=6, + include_digits=False, + ) + return ( + f"arn:{get_partition(self.region)}:secretsmanager:{self.region}:" + f"{self.account_id}:secret:{self.name}-{id_string}" + ) diff --git a/moto/utilities/id_generator.py b/moto/utilities/id_generator.py new file mode 100644 index 000000000000..c5f0a2c938f5 --- /dev/null +++ b/moto/utilities/id_generator.py @@ -0,0 +1,180 @@ +import abc +import logging +import threading +from typing import Any, Callable, Dict, List, TypedDict, Union + +from moto.moto_api._internal import mock_random + +log = logging.getLogger(__name__) + +ExistingIds = Union[List[str], None] +Tags = Union[Dict[str, str], None] + +# Custom resource tag to override the generated resource ID. +TAG_KEY_CUSTOM_ID = "_custom_id_" + + +class IdSourceContext(TypedDict, total=False): + resource_identifier: "ResourceIdentifier" + tags: Tags + existing_ids: ExistingIds + + +class ResourceIdentifier(abc.ABC): + """ + Base class for resource identifiers. When implementing a new resource, it is important to set + the service and resource as they will be used to create the unique identifier for that resource. + + It is recommended to implement the `generate` method using functions decorated with `@moto_id`. + This will ensure that your resource can be assigned a custom id. + """ + + service: str + resource: str + + def __init__(self, account_id: str, region: str, name: str): + self.account_id = account_id + self.region = region + self.name = name or "" + + @abc.abstractmethod + def generate(self, existing_ids: ExistingIds = None, tags: Tags = None) -> str: + """Method to generate a resource id""" + + @property + def unique_identifier(self) -> str: + return ".".join( + [self.account_id, self.region, self.service, self.resource, self.name] + ) + + def __str__(self) -> str: + return self.unique_identifier + + +class MotoIdManager: + """class to manage custom ids. Do not create instance and instead + use the `id_manager` instance created below.""" + + _custom_ids: Dict[str, str] + _id_sources: List[Callable[[IdSourceContext], Union[str, None]]] + + _lock: threading.RLock + + def __init__(self) -> None: + self._custom_ids = {} + self._lock = threading.RLock() + self._id_sources = [] + + self.add_id_source(self.get_id_from_tags) + self.add_id_source(self.get_custom_id_from_context) + + def get_custom_id( + self, resource_identifier: ResourceIdentifier + ) -> Union[str, None]: + # retrieves a custom_id for a resource. Returns None + return self._custom_ids.get(resource_identifier.unique_identifier) + + def set_custom_id( + self, resource_identifier: ResourceIdentifier, custom_id: str + ) -> None: + # Do not set a custom_id for a resource no value was found for the name + if not resource_identifier.name: + return + with self._lock: + self._custom_ids[resource_identifier.unique_identifier] = custom_id + + def unset_custom_id(self, resource_identifier: ResourceIdentifier) -> None: + # removes a set custom_id for a resource + with self._lock: + self._custom_ids.pop(resource_identifier.unique_identifier, None) + + def add_id_source( + self, id_source: Callable[[IdSourceContext], Union[str, None]] + ) -> None: + self._id_sources.append(id_source) + + @staticmethod + def get_id_from_tags(id_source_context: IdSourceContext) -> Union[str, None]: + if tags := id_source_context.get("tags"): + return tags.get(TAG_KEY_CUSTOM_ID) + + return None + + def get_custom_id_from_context( + self, id_source_context: IdSourceContext + ) -> Union[str, None]: + # retrieves a custom_id for a resource. Returns None + if resource_identifier := id_source_context.get("resource_identifier"): + return self.get_custom_id(resource_identifier) + return None + + def find_id_from_sources( + self, id_source_context: IdSourceContext + ) -> Union[str, None]: + existing_ids = id_source_context.get("existing_ids") or [] + for id_source in self._id_sources: + if found_id := id_source(id_source_context): + if found_id in existing_ids: + log.debug( + f"Found id {found_id} for resource {id_source_context.get('resource_identifier')}, " + "but a resource already exists with this id." + ) + else: + return found_id + + return None + + +moto_id_manager = MotoIdManager() + + +def moto_id(fn: Callable[..., str]) -> Callable[..., str]: + """ + Decorator for helping in creation of static ids. + + The decorated function should accept the following parameters + + :param resource_identifier + :param existing_ids + If provided, we will omit returning a custom id if it is already on the list + :param tags + If provided will look for a tag named `_custom_id_`. This will take precedence over registered custom ids + """ + + def _wrapper( + resource_identifier: ResourceIdentifier, + existing_ids: ExistingIds = None, + tags: Tags = None, + **kwargs: Dict[str, Any], + ) -> str: + if resource_identifier and ( + found_id := moto_id_manager.find_id_from_sources( + IdSourceContext( + resource_identifier=resource_identifier, + existing_ids=existing_ids, + tags=tags, + ) + ) + ): + return found_id + + return fn( + resource_identifier=resource_identifier, + existing_ids=existing_ids, + tags=tags, + **kwargs, + ) + + return _wrapper + + +@moto_id +def generate_str_id( # type: ignore + resource_identifier: ResourceIdentifier, + existing_ids: ExistingIds = None, + tags: Tags = None, + length: int = 20, + include_digits: bool = True, + lower_case: bool = False, +) -> str: + return mock_random.get_random_string(length, include_digits, lower_case) diff --git a/tests/conftest.py b/tests/conftest.py index f723cf5f3d16..3a1f20cb0fb8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import pytest from moto import mock_aws +from moto.utilities.id_generator import ResourceIdentifier, moto_id_manager @pytest.fixture(scope="function") @@ -16,3 +17,19 @@ def account_id(): with mock_aws(): identity = boto3.client("sts", "us-east-1").get_caller_identity() yield identity["Account"] + + +@pytest.fixture +def set_custom_id(): + set_ids = [] + + def _set_custom_id(resource_identifier: ResourceIdentifier, custom_id): + moto_id_manager.set_custom_id( + resource_identifier=resource_identifier, custom_id=custom_id + ) + set_ids.append(resource_identifier) + + yield _set_custom_id + + for resource_identifier in set_ids: + moto_id_manager.unset_custom_id(resource_identifier) diff --git a/tests/test_apigateway/test_apigateway_custom_ids.py b/tests/test_apigateway/test_apigateway_custom_ids.py new file mode 100644 index 000000000000..1d2debbbafff --- /dev/null +++ b/tests/test_apigateway/test_apigateway_custom_ids.py @@ -0,0 +1,140 @@ +import boto3 +import pytest + +from moto import mock_aws, settings +from moto.apigateway.utils import ( + ApigwApiKeyIdentifier, + ApigwDeploymentIdentifier, + ApigwModelIdentifier, + ApigwRequestValidatorIdentifier, + ApigwResourceIdentifier, + ApigwRestApiIdentifier, + ApigwUsagePlanIdentifier, +) + +API_ID = "ApiId" +API_KEY_ID = "ApiKeyId" +DEPLOYMENT_ID = "DeployId" +MODEL_ID = "ModelId" +PET_1_RESOURCE_ID = "Pet1Id" +PET_2_RESOURCE_ID = "Pet2Id" +REQUEST_VALIDATOR_ID = "ReqValId" +ROOT_RESOURCE_ID = "RootId" +USAGE_PLAN_ID = "UPlanId" + + +@mock_aws +@pytest.mark.skipif( + not settings.TEST_DECORATOR_MODE, reason="Can't access the id manager in proxy mode" +) +def test_custom_id_rest_api(set_custom_id, account_id): + region_name = "us-west-2" + rest_api_name = "my-api" + model_name = "modelName" + request_validator_name = "request-validator-name" + stage_name = "stage-name" + + client = boto3.client("apigateway", region_name=region_name) + + set_custom_id( + ApigwRestApiIdentifier(account_id, region_name, rest_api_name), API_ID + ) + set_custom_id( + ApigwResourceIdentifier(account_id, region_name, path_name="/"), + ROOT_RESOURCE_ID, + ) + set_custom_id( + ApigwResourceIdentifier( + account_id, region_name, parent_id=ROOT_RESOURCE_ID, path_name="pet" + ), + PET_1_RESOURCE_ID, + ) + set_custom_id( + ApigwResourceIdentifier( + account_id, region_name, parent_id=PET_1_RESOURCE_ID, path_name="pet" + ), + PET_2_RESOURCE_ID, + ) + set_custom_id(ApigwModelIdentifier(account_id, region_name, model_name), MODEL_ID) + set_custom_id( + ApigwRequestValidatorIdentifier( + account_id, region_name, request_validator_name + ), + REQUEST_VALIDATOR_ID, + ) + set_custom_id( + ApigwDeploymentIdentifier(account_id, region_name, stage_name=stage_name), + DEPLOYMENT_ID, + ) + + rest_api = client.create_rest_api(name=rest_api_name) + assert rest_api["id"] == API_ID + assert rest_api["rootResourceId"] == ROOT_RESOURCE_ID + + pet_resource_1 = client.create_resource( + restApiId=API_ID, parentId=ROOT_RESOURCE_ID, pathPart="pet" + ) + assert pet_resource_1["id"] == PET_1_RESOURCE_ID + + # we create a second resource with the same path part to ensure we can pass different ids + pet_resource_2 = client.create_resource( + restApiId=API_ID, parentId=PET_1_RESOURCE_ID, pathPart="pet" + ) + assert pet_resource_2["id"] == PET_2_RESOURCE_ID + + model = client.create_model( + restApiId=API_ID, + name=model_name, + schema="EMPTY", + contentType="application/json", + ) + assert model["id"] == MODEL_ID + + request_validator = client.create_request_validator( + restApiId=API_ID, name=request_validator_name + ) + assert request_validator["id"] == REQUEST_VALIDATOR_ID + + # Creating the resource to make a deployment + client.put_method( + restApiId=API_ID, + httpMethod="ANY", + resourceId=PET_2_RESOURCE_ID, + authorizationType="NONE", + ) + client.put_integration( + restApiId=API_ID, resourceId=PET_2_RESOURCE_ID, httpMethod="ANY", type="MOCK" + ) + deployment = client.create_deployment(restApiId=API_ID, stageName=stage_name) + assert deployment["id"] == DEPLOYMENT_ID + + +@mock_aws +@pytest.mark.skipif( + not settings.TEST_DECORATOR_MODE, reason="Can't access the id manager in proxy mode" +) +def test_custom_id_api_key(account_id, set_custom_id): + region_name = "us-west-2" + api_key_value = "01234567890123456789" + usage_plan_name = "usage-plan" + + client = boto3.client("apigateway", region_name=region_name) + + set_custom_id( + ApigwApiKeyIdentifier(account_id, region_name, value=api_key_value), API_KEY_ID + ) + set_custom_id( + ApigwUsagePlanIdentifier(account_id, region_name, usage_plan_name), + USAGE_PLAN_ID, + ) + + api_key = client.create_api_key(name="api-key", value=api_key_value) + usage_plan = client.create_usage_plan(name=usage_plan_name) + + # verify that we can create a usage plan key using the custom ids + client.create_usage_plan_key( + usagePlanId=USAGE_PLAN_ID, keyId=API_KEY_ID, keyType="API_KEY" + ) + + assert api_key["id"] == API_KEY_ID + assert usage_plan["id"] == USAGE_PLAN_ID diff --git a/tests/test_secretsmanager/test_secretsmanager.py b/tests/test_secretsmanager/test_secretsmanager.py index 8879b64f00b3..734fcf38420c 100644 --- a/tests/test_secretsmanager/test_secretsmanager.py +++ b/tests/test_secretsmanager/test_secretsmanager.py @@ -13,6 +13,7 @@ from moto import mock_aws, settings from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID +from moto.secretsmanager.utils import SecretsManagerSecretIdentifier from . import secretsmanager_aws_verified @@ -1961,3 +1962,23 @@ def test_update_secret_version_stage_dont_specify_current_stage(secret_arn=None) err["Message"] == f"The parameter RemoveFromVersionId can't be empty. Staging label AWSCURRENT is currently attached to version {current_version}, so you must explicitly reference that version in RemoveFromVersionId." ) + + +@mock_aws +@pytest.mark.skipif( + not settings.TEST_DECORATOR_MODE, reason="Can't access the id manager in proxy mode" +) +def test_create_secret_custom_id(account_id, set_custom_id): + secret_suffix = "randomSuffix" + secret_name = "secret-name" + region_name = "us-east-1" + + client = boto3.client("secretsmanager", region_name=region_name) + + set_custom_id( + SecretsManagerSecretIdentifier(account_id, region_name, secret_name), + secret_suffix, + ) + secret = client.create_secret(Name=secret_name, SecretString="my secret") + + assert secret["ARN"].split(":")[-1] == f"{secret_name}-{secret_suffix}" diff --git a/tests/test_utilities/test_id_generator.py b/tests/test_utilities/test_id_generator.py new file mode 100644 index 000000000000..7ee85fd32f3f --- /dev/null +++ b/tests/test_utilities/test_id_generator.py @@ -0,0 +1,134 @@ +from moto.utilities.id_generator import ( + TAG_KEY_CUSTOM_ID, + ExistingIds, + ResourceIdentifier, + Tags, + moto_id, + moto_id_manager, +) + +ACCOUNT = "account" +REGION = "us-east-1" +RESOURCE_NAME = "my-resource" + +CUSTOM_ID = "custom" +GENERATED_ID = "generated" +TAG_ID = "fromTag" +SERVICE = "test-service" +RESOURCE = "test-resource" + + +@moto_id +def generate_test_id( + resource_identifier: ResourceIdentifier, + existing_ids: ExistingIds = None, + tags: Tags = None, +): + return GENERATED_ID + + +class TestResourceIdentifier(ResourceIdentifier): + service = SERVICE + resource = RESOURCE + + def generate(self, existing_ids: ExistingIds = None, tags: Tags = None) -> str: + return generate_test_id( + resource_identifier=self, existing_ids=existing_ids, tags=tags + ) + + +def test_generate_with_no_resource_identifier(): + generated_id = generate_test_id(None) + assert generated_id == GENERATED_ID + + +def test_generate_with_matching_resource_identifier(set_custom_id): + resource_identifier = TestResourceIdentifier(ACCOUNT, REGION, RESOURCE_NAME) + + set_custom_id(resource_identifier, CUSTOM_ID) + + generated_id = generate_test_id(resource_identifier=resource_identifier) + assert generated_id == CUSTOM_ID + + +def test_generate_with_non_matching_resource_identifier(set_custom_id): + resource_identifier = TestResourceIdentifier(ACCOUNT, REGION, RESOURCE_NAME) + resource_identifier_2 = TestResourceIdentifier(ACCOUNT, REGION, "non-matching") + + set_custom_id(resource_identifier, CUSTOM_ID) + + generated_id = generate_test_id(resource_identifier=resource_identifier_2) + assert generated_id == GENERATED_ID + + +def test_generate_with_custom_id_tag(): + resource_identifier = TestResourceIdentifier(ACCOUNT, REGION, RESOURCE_NAME) + + generated_id = generate_test_id( + resource_identifier=resource_identifier, tags={TAG_KEY_CUSTOM_ID: TAG_ID} + ) + assert generated_id == TAG_ID + + +def test_generate_with_custom_id_tag_has_priority(set_custom_id): + resource_identifier = TestResourceIdentifier(ACCOUNT, REGION, RESOURCE_NAME) + + set_custom_id(resource_identifier, CUSTOM_ID) + generated_id = generate_test_id( + resource_identifier=resource_identifier, tags={TAG_KEY_CUSTOM_ID: TAG_ID} + ) + assert generated_id == TAG_ID + + +def test_generate_with_existing_id(set_custom_id): + resource_identifier = TestResourceIdentifier(ACCOUNT, REGION, RESOURCE_NAME) + + set_custom_id(resource_identifier, CUSTOM_ID) + generated_id = generate_test_id( + resource_identifier=resource_identifier, existing_ids=[CUSTOM_ID] + ) + assert generated_id == GENERATED_ID + + +def test_generate_with_tags_and_existing_id(set_custom_id): + resource_identifier = TestResourceIdentifier(ACCOUNT, REGION, RESOURCE_NAME) + + generated_id = generate_test_id( + resource_identifier=resource_identifier, + existing_ids=[TAG_ID], + tags={TAG_KEY_CUSTOM_ID: TAG_ID}, + ) + assert generated_id == GENERATED_ID + + +def test_generate_with_tags_fallback(set_custom_id): + resource_identifier = TestResourceIdentifier(ACCOUNT, REGION, RESOURCE_NAME) + + set_custom_id(resource_identifier, CUSTOM_ID) + generated_id = generate_test_id( + resource_identifier=resource_identifier, + existing_ids=[TAG_ID], + tags={TAG_KEY_CUSTOM_ID: TAG_ID}, + ) + assert generated_id == CUSTOM_ID + + +def test_set_custom_id_lifecycle(): + resource_identifier = TestResourceIdentifier(ACCOUNT, REGION, RESOURCE_NAME) + + moto_id_manager.set_custom_id(resource_identifier, CUSTOM_ID) + + found_id = moto_id_manager.get_custom_id(resource_identifier) + assert found_id == CUSTOM_ID + + moto_id_manager.unset_custom_id(resource_identifier) + + found_id = moto_id_manager.get_custom_id(resource_identifier) + assert found_id is None + + +def test_set_custom_id_name_is_not_set(): + resource_identifier = TestResourceIdentifier(ACCOUNT, REGION, None) + moto_id_manager.set_custom_id(resource_identifier, CUSTOM_ID) + + assert moto_id_manager._custom_ids == {}