Skip to content

Commit

Permalink
Support custom ids (#8216)
Browse files Browse the repository at this point in the history
  • Loading branch information
cloutierMat authored Oct 13, 2024
1 parent 81d6f95 commit 1ace1db
Show file tree
Hide file tree
Showing 9 changed files with 660 additions and 22 deletions.
71 changes: 58 additions & 13 deletions moto/apigateway/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,19 @@
ValidationException,
VpcLinkNotFound,
)
from .utils import create_id, to_path
from .utils import (
ApigwApiKeyIdentifier,
ApigwAuthorizerIdentifier,
ApigwDeploymentIdentifier,
ApigwModelIdentifier,
ApigwRequestValidatorIdentifier,
ApigwResourceIdentifier,
ApigwRestApiIdentifier,
ApigwUsagePlanIdentifier,
ApigwVpcLinkIdentifier,
create_id,
to_path,
)

STAGE_URL = "https://{api_id}.execute-api.{region_name}.amazonaws.com/{stage_name}"
PATCH_OPERATIONS = ["add", "remove", "replace", "move", "copy", "test"]
Expand Down Expand Up @@ -789,6 +801,7 @@ def _apply_operation_to_variables(self, op: Dict[str, Any]) -> None:
class ApiKey(BaseModel):
def __init__(
self,
api_key_id: str,
name: Optional[str] = None,
description: Optional[str] = None,
enabled: bool = False,
Expand All @@ -798,7 +811,7 @@ def __init__(
tags: Optional[List[Dict[str, str]]] = None,
customerId: Optional[str] = None,
):
self.id = create_id()
self.id = api_key_id
self.value = value or "".join(
random.sample(string.ascii_letters + string.digits, 40)
)
Expand Down Expand Up @@ -846,6 +859,7 @@ def _str2bool(self, v: str) -> bool:
class UsagePlan(BaseModel):
def __init__(
self,
usage_plan_id: str,
name: Optional[str] = None,
description: Optional[str] = None,
apiStages: Any = None,
Expand All @@ -854,7 +868,7 @@ def __init__(
productCode: Optional[str] = None,
tags: Optional[List[Dict[str, str]]] = None,
):
self.id = create_id()
self.id = usage_plan_id
self.name = name
self.description = description
self.api_stages = apiStages or []
Expand Down Expand Up @@ -985,12 +999,13 @@ def to_json(self) -> Dict[str, Any]:
class VpcLink(BaseModel):
def __init__(
self,
vpc_link_id: str,
name: str,
description: str,
target_arns: List[str],
tags: List[Dict[str, str]],
):
self.id = create_id()
self.id = vpc_link_id
self.name = name
self.description = description
self.target_arns = target_arns
Expand Down Expand Up @@ -1162,7 +1177,9 @@ def create_from_cloudformation_json( # type: ignore[misc]
)

def add_child(self, path: str, parent_id: Optional[str] = None) -> Resource:
child_id = create_id()
child_id = ApigwResourceIdentifier(
self.account_id, self.region_name, parent_id or "", path
).generate()
child = Resource(
resource_id=child_id,
account_id=self.account_id,
Expand All @@ -1181,7 +1198,9 @@ def add_model(
schema: str,
content_type: str,
) -> "Model":
model_id = create_id()
model_id = ApigwModelIdentifier(
self.account_id, self.region_name, name
).generate()
new_model = Model(
model_id=model_id,
name=name,
Expand Down Expand Up @@ -1293,7 +1312,11 @@ def create_deployment(
) -> Deployment:
if stage_variables is None:
stage_variables = {}
deployment_id = create_id()
# Since there are no unique values to a deployment, we will use the stage name for the deployment.
# We are also passing a list of deployment ids to generate to prevent overwriting deployments.
deployment_id = ApigwDeploymentIdentifier(
self.account_id, self.region_name, stage_name=name
).generate(list(self.deployments.keys()))
deployment = Deployment(deployment_id, name, description)
self.deployments[deployment_id] = deployment
if name:
Expand Down Expand Up @@ -1332,7 +1355,9 @@ def create_request_validator(
validateRequestBody: Optional[bool],
validateRequestParameters: Any,
) -> RequestValidator:
validator_id = create_id()
validator_id = ApigwRequestValidatorIdentifier(
self.account_id, self.region_name, name
).generate()
request_validator = RequestValidator(
_id=validator_id,
name=name,
Expand Down Expand Up @@ -1631,7 +1656,9 @@ def create_rest_api(
minimum_compression_size: Optional[int] = None,
disable_execute_api_endpoint: Optional[bool] = None,
) -> RestAPI:
api_id = create_id()
api_id = ApigwRestApiIdentifier(
self.account_id, self.region_name, name
).generate()
rest_api = RestAPI(
api_id,
self.account_id,
Expand Down Expand Up @@ -1882,7 +1909,9 @@ def create_authorizer(
self, restapi_id: str, name: str, authorizer_type: str, **kwargs: Any
) -> Authorizer:
api = self.get_rest_api(restapi_id)
authorizer_id = create_id()
authorizer_id = ApigwAuthorizerIdentifier(
self.account_id, self.region_name, name
).generate()
return api.create_authorizer(
authorizer_id,
name,
Expand Down Expand Up @@ -2146,7 +2175,13 @@ def create_api_key(self, payload: Dict[str, Any]) -> ApiKey:
for api_key in self.get_api_keys():
if api_key.value == payload["value"]:
raise ApiKeyAlreadyExists()
key = ApiKey(**payload)
api_key_id = ApigwApiKeyIdentifier(
self.account_id,
self.region_name,
# The value of an api key must be unique on aws
payload.get("value", ""),
).generate()
key = ApiKey(api_key_id=api_key_id, **payload)
self.keys[key.id] = key
return key

Expand All @@ -2170,7 +2205,10 @@ def delete_api_key(self, api_key_id: str) -> None:
self.keys.pop(api_key_id)

def create_usage_plan(self, payload: Any) -> UsagePlan:
plan = UsagePlan(**payload)
usage_plan_id = ApigwUsagePlanIdentifier(
self.account_id, self.region_name, payload["name"]
).generate()
plan = UsagePlan(usage_plan_id=usage_plan_id, **payload)
self.usage_plans[plan.id] = plan
return plan

Expand Down Expand Up @@ -2497,8 +2535,15 @@ def create_vpc_link(
target_arns: List[str],
tags: List[Dict[str, str]],
) -> VpcLink:
vpc_link_id = ApigwVpcLinkIdentifier(
self.account_id, self.region_name, name
).generate()
vpc_link = VpcLink(
name, description=description, target_arns=target_arns, tags=tags
vpc_link_id,
name,
description=description,
target_arns=target_arns,
tags=tags,
)
self.vpc_links[vpc_link.id] = vpc_link
return vpc_link
Expand Down
73 changes: 72 additions & 1 deletion moto/apigateway/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,81 @@
import json
import string
from typing import Any, Dict
from typing import Any, Dict, List, Union

import yaml

from moto.moto_api._internal import mock_random as random
from moto.utilities.id_generator import ResourceIdentifier, Tags, generate_str_id


class ApigwIdentifier(ResourceIdentifier):
service = "apigateway"

def __init__(self, account_id: str, region: str, name: str):
super().__init__(account_id, region, name)

def generate(
self, existing_ids: Union[List[str], None] = None, tags: Tags = None
) -> str:
return generate_str_id(
resource_identifier=self,
existing_ids=existing_ids,
tags=tags,
length=10,
include_digits=True,
lower_case=True,
)


class ApigwApiKeyIdentifier(ApigwIdentifier):
resource = "api_key"

def __init__(self, account_id: str, region: str, value: str):
super().__init__(account_id, region, value)


class ApigwAuthorizerIdentifier(ApigwIdentifier):
resource = "authorizer"


class ApigwDeploymentIdentifier(ApigwIdentifier):
resource = "deployment"

def __init__(self, account_id: str, region: str, stage_name: str):
super().__init__(account_id, region, stage_name)


class ApigwModelIdentifier(ApigwIdentifier):
resource = "model"


class ApigwRequestValidatorIdentifier(ApigwIdentifier):
resource = "request_validator"


class ApigwResourceIdentifier(ApigwIdentifier):
resource = "resource"

def __init__(
self, account_id: str, region: str, parent_id: str = "", path_name: str = "/"
):
super().__init__(
account_id,
region,
".".join((parent_id, path_name)),
)


class ApigwRestApiIdentifier(ApigwIdentifier):
resource = "rest_api"


class ApigwUsagePlanIdentifier(ApigwIdentifier):
resource = "usage_plan"


class ApigwVpcLinkIdentifier(ApigwIdentifier):
resource = "vpc_link"


def create_id() -> str:
Expand Down
14 changes: 11 additions & 3 deletions moto/secretsmanager/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
tag_key,
tag_value,
)
from .utils import get_secret_name_from_partial_arn, random_password, secret_arn
from .utils import (
SecretsManagerSecretIdentifier,
get_secret_name_from_partial_arn,
random_password,
)

MAX_RESULTS_DEFAULT = 100

Expand Down Expand Up @@ -94,7 +98,9 @@ def __init__(
):
self.secret_id = secret_id
self.name = secret_id
self.arn = secret_arn(account_id, region_name, secret_id)
self.arn = SecretsManagerSecretIdentifier(
account_id, region_name, secret_id
).generate()
self.account_id = account_id
self.region = region_name
self.secret_string = secret_string
Expand Down Expand Up @@ -935,7 +941,9 @@ def delete_secret(
if not force_delete_without_recovery:
raise SecretNotFoundException()
else:
arn = secret_arn(self.account_id, self.region_name, secret_id=secret_id)
arn = SecretsManagerSecretIdentifier(
self.account_id, self.region_name, secret_id=secret_id
).generate()
name = secret_id
deletion_date = utcnow()
return arn, name, self._unix_time_secs(deletion_date)
Expand Down
32 changes: 27 additions & 5 deletions moto/secretsmanager/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
import string

from moto.moto_api._internal import mock_random as random
from moto.utilities.id_generator import (
ExistingIds,
ResourceIdentifier,
Tags,
generate_str_id,
)
from moto.utilities.utils import ARN_PARTITION_REGEX, get_partition


Expand Down Expand Up @@ -62,11 +68,6 @@ def random_password(
return password


def secret_arn(account_id: str, region: str, secret_id: str) -> str:
id_string = "".join(random.choice(string.ascii_letters) for _ in range(6))
return f"arn:{get_partition(region)}:secretsmanager:{region}:{account_id}:secret:{secret_id}-{id_string}"


def get_secret_name_from_partial_arn(partial_arn: str) -> str:
# We can retrieve a secret either using a full ARN, or using a partial ARN
# name: testsecret
Expand Down Expand Up @@ -99,3 +100,24 @@ def _add_password_require_each_included_type(
password_with_required_char += required_characters

return password_with_required_char


class SecretsManagerSecretIdentifier(ResourceIdentifier):
service = "secretsmanager"
resource = "secret"

def __init__(self, account_id: str, region: str, secret_id: str):
super().__init__(account_id, region, name=secret_id)

def generate(self, existing_ids: ExistingIds = None, tags: Tags = None) -> str:
id_string = generate_str_id(
resource_identifier=self,
existing_ids=existing_ids,
tags=tags,
length=6,
include_digits=False,
)
return (
f"arn:{get_partition(self.region)}:secretsmanager:{self.region}:"
f"{self.account_id}:secret:{self.name}-{id_string}"
)
Loading

0 comments on commit 1ace1db

Please sign in to comment.