diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py index 2fe072fc87..93ce7638d0 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py @@ -1,6 +1,4 @@ -import json from dataclasses import dataclass -from datetime import datetime from typing import Any, Dict, Optional import cloudpickle @@ -15,7 +13,7 @@ from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate -from .boto3_mixin import Boto3AgentMixin +from .boto3_mixin import Boto3AgentMixin, CustomException @dataclass @@ -39,14 +37,6 @@ def decode(cls, data: bytes) -> "SageMakerEndpointMetadata": } -class DateTimeEncoder(json.JSONEncoder): - def default(self, o): - if isinstance(o, datetime): - return o.isoformat() - - return json.JSONEncoder.default(self, o) - - class SageMakerEndpointAgent(Boto3AgentMixin, AsyncAgentBase): """This agent creates an endpoint.""" @@ -66,22 +56,49 @@ async def create( config = custom.get("config") region = custom.get("region") - await self._call( - method="create_endpoint", - config=config, - inputs=inputs, - region=region, - ) + try: + await self._call( + method="create_endpoint", + config=config, + inputs=inputs, + region=region, + ) + except CustomException as e: + original_exception = e.original_exception + error_code = original_exception.response["Error"]["Code"] + error_message = original_exception.response["Error"]["Message"] + + if error_code == "ValidationException" and "Cannot create already existing" in error_message: + return SageMakerEndpointMetadata(config=config, region=region, inputs=inputs) + elif ( + error_code == "ResourceLimitExceeded" + and "Please use AWS Service Quotas to request an increase for this quota." in error_message + ): + return SageMakerEndpointMetadata(config=config, region=region, inputs=inputs) + raise e + except Exception as e: + raise e return SageMakerEndpointMetadata(config=config, region=region, inputs=inputs) async def get(self, resource_meta: SageMakerEndpointMetadata, **kwargs) -> Resource: - endpoint_status = await self._call( - method="describe_endpoint", - config={"EndpointName": resource_meta.config.get("EndpointName")}, - inputs=resource_meta.inputs, - region=resource_meta.region, - ) + try: + endpoint_status, idempotence_token = await self._call( + method="describe_endpoint", + config={"EndpointName": resource_meta.config.get("EndpointName")}, + inputs=resource_meta.inputs, + region=resource_meta.region, + ) + except CustomException as e: + original_exception = e.original_exception + error_code = original_exception.response["Error"]["Code"] + error_message = original_exception.response["Error"]["Message"] + + if error_code == "ValidationException" and "Could not find endpoint" in error_message: + raise Exception( + "This might be due to resource limits being exceeded, preventing the creation of a new endpoint. Please check your resource usage and limits." + ) from e + raise e current_state = endpoint_status.get("EndpointStatus") flyte_phase = convert_to_flyte_phase(states[current_state]) @@ -92,7 +109,10 @@ async def get(self, resource_meta: SageMakerEndpointMetadata, **kwargs) -> Resou res = None if current_state == "InService": - res = {"result": json.dumps(endpoint_status, cls=DateTimeEncoder)} + res = { + "result": {"EndpointArn": endpoint_status.get("EndpointArn")}, + "idempotence_token": idempotence_token, + } return Resource(phase=flyte_phase, outputs=res, message=message) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py index f5624127fb..7b935e9101 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py @@ -1,3 +1,4 @@ +import re from typing import Optional from flyteidl.core.execution_pb2 import TaskExecution @@ -15,7 +16,7 @@ from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate -from .boto3_mixin import Boto3AgentMixin +from .boto3_mixin import Boto3AgentMixin, CustomException # https://github.com/flyteorg/flyte/issues/4505 @@ -58,12 +59,44 @@ async def do( boto3_object = Boto3AgentMixin(service=service, region=region) - result = await boto3_object._call( - method=method, - config=config, - images=images, - inputs=inputs, - ) + try: + result, idempotence_token = await boto3_object._call( + method=method, + config=config, + images=images, + inputs=inputs, + ) + except CustomException as e: + original_exception = e.original_exception + error_code = original_exception.response["Error"]["Code"] + error_message = original_exception.response["Error"]["Message"] + + if error_code == "ValidationException" and "Cannot create already existing" in error_message: + arn = re.search( + r"arn:aws:[a-zA-Z0-9\-]+:[a-zA-Z0-9\-]+:\d+:[a-zA-Z0-9\-\/]+", + error_message, + ).group(0) + if arn: + return Resource( + phase=TaskExecution.SUCCEEDED, + outputs={ + "result": {"result": f"Entity already exists: {arn}"}, + "idempotence_token": e.idempotence_token, + }, + ) + else: + return Resource( + phase=TaskExecution.SUCCEEDED, + outputs={ + "result": {"result": "Entity already exists."}, + "idempotence_token": e.idempotence_token, + }, + ) + else: + # Re-raise the exception if it's not the specific error we're handling + raise e + except Exception as e: + raise e outputs = {"result": {"result": None}} if result: @@ -83,7 +116,13 @@ async def do( result, Annotated[dict, kwtypes(allow_pickle=True)], TypeEngine.to_literal_type(dict), - ) + ), + "idempotence_token": TypeEngine.to_literal( + new_ctx, + idempotence_token, + str, + TypeEngine.to_literal_type(str), + ), } ) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py index c2596750fc..cf3cc0c14b 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py @@ -1,10 +1,21 @@ +import re from typing import Any, Dict, Optional import aioboto3 +import xxhash +from botocore.exceptions import ClientError from flytekit.interaction.string_literals import literal_map_string_repr from flytekit.models.literals import LiteralMap + +class CustomException(Exception): + def __init__(self, message, idempotence_token, original_exception): + super().__init__(message) + self.idempotence_token = idempotence_token + self.original_exception = original_exception + + account_id_map = { "us-east-1": "785573368785", "us-east-2": "007439368137", @@ -31,7 +42,11 @@ } -def update_dict_fn(original_dict: Any, update_dict: Dict[str, Any]) -> Any: +def update_dict_fn( + original_dict: Any, + update_dict: Dict[str, Any], + idempotence_token: Optional[str] = None, +) -> Any: """ Recursively update a dictionary with values from another dictionary. For example, if original_dict is {"EndpointConfigName": "{endpoint_config_name}"}, @@ -40,6 +55,7 @@ def update_dict_fn(original_dict: Any, update_dict: Dict[str, Any]) -> Any: :param original_dict: The dictionary to update (in place) :param update_dict: The dictionary to use for updating + :param idempotence_token: Hash of config -- this is to ensure the execution ID is deterministic :return: The updated dictionary """ if original_dict is None: @@ -48,44 +64,50 @@ def update_dict_fn(original_dict: Any, update_dict: Dict[str, Any]) -> Any: # If the original value is a string and contains placeholder curly braces if isinstance(original_dict, str): if "{" in original_dict and "}" in original_dict: - # Check if there are nested keys - if "." in original_dict: - # Create a copy of update_dict - update_dict_copy = update_dict.copy() - - # Fetch keys from the original_dict - keys = original_dict.strip("{}").split(".") - - # Get value from the nested dictionary - for key in keys: - try: - update_dict_copy = update_dict_copy[key] - except Exception: - raise ValueError(f"Could not find the key {key} in {update_dict_copy}.") - - return update_dict_copy - - # Retrieve the original value using the key without curly braces - original_value = update_dict.get(original_dict.strip("{}")) - - # Check if original_value exists; if so, return it, - # otherwise, raise a ValueError indicating that the value for the key original_dict could not be found. - if original_value: - return original_value - else: - raise ValueError(f"Could not find value for {original_dict}.") - - # If the string does not contain placeholders, return it as is + matches = re.findall(r"\{([^}]+)\}", original_dict) + for match in matches: + # Check if there are nested keys + if "." in match: + # Create a copy of update_dict + update_dict_copy = update_dict.copy() + + # Fetch keys from the original_dict + keys = match.split(".") + + # Get value from the nested dictionary + for key in keys: + try: + update_dict_copy = update_dict_copy[key] + except Exception: + raise ValueError(f"Could not find the key {key} in {update_dict_copy}.") + + if len(matches) > 1: + # Replace the placeholder in the original_dict + original_dict = original_dict.replace(f"{{{match}}}", update_dict_copy) + else: + # If there's only one match, it needn't always be a string, so not replacing the original dict. + return update_dict_copy + elif match == "idempotence_token" and idempotence_token: + temp_dict = original_dict.replace(f"{{{match}}}", idempotence_token) + if len(temp_dict) > 63: + truncated_idempotence_token = idempotence_token[ + : (63 - len(original_dict.replace("{idempotence_token}", ""))) + ] + original_dict = original_dict.replace(f"{{{match}}}", truncated_idempotence_token) + else: + original_dict = temp_dict + + # If the string does not contain placeholders or if there are multiple placeholders, return the original dict. return original_dict # If the original value is a list, recursively update each element in the list if isinstance(original_dict, list): - return [update_dict_fn(item, update_dict) for item in original_dict] + return [update_dict_fn(item, update_dict, idempotence_token) for item in original_dict] # If the original value is a dictionary, recursively update each key-value pair if isinstance(original_dict, dict): for key, value in original_dict.items(): - original_dict[key] = update_dict_fn(value, update_dict) + original_dict[key] = update_dict_fn(value, update_dict, idempotence_token) # Return the updated original dict return original_dict @@ -116,7 +138,7 @@ async def _call( images: Optional[Dict[str, str]] = None, inputs: Optional[LiteralMap] = None, region: Optional[str] = None, - ) -> Any: + ) -> tuple[Any, str]: """ Utilize this method to invoke any boto3 method (AWS service method). @@ -162,6 +184,12 @@ async def _call( updated_config = update_dict_fn(config, args) + hash = "" + if "idempotence_token" in str(updated_config): + # compute hash of the config + hash = xxhash.xxh64(str(updated_config)).hexdigest() + updated_config = update_dict_fn(updated_config, args, idempotence_token=hash) + # Asynchronous Boto3 session session = aioboto3.Session() async with session.client( @@ -170,7 +198,7 @@ async def _call( ) as client: try: result = await getattr(client, method)(**updated_config) - except Exception as e: - raise e + except ClientError as e: + raise CustomException(f"An error occurred: {e}", hash, e) from e - return result + return result, hash diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_task.py index 1cb59eab08..332523cc8c 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_task.py @@ -34,7 +34,7 @@ def __init__( task_type=self._TASK_TYPE, interface=Interface( inputs=inputs, - outputs=kwtypes(result=dict), + outputs=kwtypes(result=dict, idempotence_token=str), ), **kwargs, ) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py index a381547bf5..8714915776 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py @@ -95,7 +95,7 @@ def __init__( super().__init__( name=name, task_type=self._TASK_TYPE, - interface=Interface(inputs=inputs, outputs=kwtypes(result=str)), + interface=Interface(inputs=inputs, outputs=kwtypes(result=dict, idempotence_token=str)), **kwargs, ) self._config = config diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py index 87a27c7497..13b89e8ec4 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py @@ -27,7 +27,13 @@ def create_deployment_task( else: inputs = kwtypes(region=str) return ( - task_type(name=name, config=config, region=region, inputs=inputs, images=images), + task_type( + name=name, + config=config, + region=region, + inputs=inputs, + images=images, + ), inputs, ) @@ -89,6 +95,11 @@ def create_sagemaker_deployment( nodes = [] for key, value in inputs.items(): input_types = value["input_types"] + if len(nodes) > 0: + if not input_types: + input_types = {} + input_types["idempotence_token"] = str + obj, new_input_types = create_deployment_task( name=f"{value['name']}-{name}", task_type=key, @@ -101,16 +112,29 @@ def create_sagemaker_deployment( input_dict = {} if isinstance(new_input_types, dict): for param, t in new_input_types.items(): - # Handles the scenario when the same input is present during different API calls. - if param not in wf.inputs.keys(): - wf.add_workflow_input(param, t) - input_dict[param] = wf.inputs[param] + if param != "idempotence_token": + # Handles the scenario when the same input is present during different API calls. + if param not in wf.inputs.keys(): + wf.add_workflow_input(param, t) + input_dict[param] = wf.inputs[param] + else: + input_dict["idempotence_token"] = nodes[-1].outputs["idempotence_token"] + node = wf.add_entity(obj, **input_dict) + if len(nodes) > 0: nodes[-1] >> node nodes.append(node) - wf.add_workflow_output("wf_output", nodes[2].outputs["result"], str) + wf.add_workflow_output( + "wf_output", + [ + nodes[0].outputs["result"], + nodes[1].outputs["result"], + nodes[2].outputs["result"], + ], + list[dict], + ) return wf diff --git a/plugins/flytekit-aws-sagemaker/setup.py b/plugins/flytekit-aws-sagemaker/setup.py index cdc4b816b6..c4bfe27026 100644 --- a/plugins/flytekit-aws-sagemaker/setup.py +++ b/plugins/flytekit-aws-sagemaker/setup.py @@ -5,7 +5,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.11.0", "aioboto3>=12.3.0"] +plugin_requires = ["flytekit>=1.11.0", "aioboto3>=12.3.0", "xxhash"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py index f17e50ea6f..fcdffe83fa 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py @@ -12,6 +12,11 @@ from flytekit.models.core.identifier import ResourceType from flytekit.models.task import RuntimeMetadata, TaskMetadata, TaskTemplate +from flytekitplugins.awssagemaker_inference.boto3_mixin import CustomException +from botocore.exceptions import ClientError + +idempotence_token = "74443947857331f7" + @pytest.mark.asyncio @pytest.mark.parametrize( @@ -31,7 +36,8 @@ }, }, "EndpointConfigArn": "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config", - } + }, + idempotence_token, ), ( { @@ -48,9 +54,25 @@ }, "pickle_check": datetime(2024, 5, 5), "EndpointConfigArn": "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config", - } + }, + idempotence_token, + ), + (None, idempotence_token), + ( + CustomException( + message="An error occurred", + idempotence_token=idempotence_token, + original_exception=ClientError( + error_response={ + "Error": { + "Code": "ValidationException", + "Message": "Cannot create already existing endpoint 'arn:aws:sagemaker:us-east-2:123456789:endpoint/stable-diffusion-endpoint-non-finetuned-06716dbe4b2c68e7'", + } + }, + operation_name="DescribeEndpoint", + ), + ) ), - (None), ], ) @mock.patch( @@ -79,7 +101,9 @@ async def test_agent(mock_boto_call, mock_return_value): "InstanceType": "ml.m4.xlarge", }, ], - "AsyncInferenceConfig": {"OutputConfig": {"S3OutputPath": "{inputs.s3_output_path}"}}, + "AsyncInferenceConfig": { + "OutputConfig": {"S3OutputPath": "{inputs.s3_output_path}"} + }, }, "region": "us-east-2", "method": "create_endpoint_config", @@ -87,7 +111,9 @@ async def test_agent(mock_boto_call, mock_return_value): } task_metadata = TaskMetadata( discoverable=True, - runtime=RuntimeMetadata(RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + runtime=RuntimeMetadata( + RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python" + ), timeout=timedelta(days=1), retries=literals.RetryStrategy(3), interruptible=True, @@ -108,28 +134,50 @@ async def test_agent(mock_boto_call, mock_return_value): task_inputs = literals.LiteralMap( { "model_name": literals.Literal( - scalar=literals.Scalar(primitive=literals.Primitive(string_value="sagemaker-model")) + scalar=literals.Scalar( + primitive=literals.Primitive(string_value="sagemaker-model") + ) ), "s3_output_path": literals.Literal( - scalar=literals.Scalar(primitive=literals.Primitive(string_value="s3-output-path")) + scalar=literals.Scalar( + primitive=literals.Primitive(string_value="s3-output-path") + ) ), }, ) ctx = FlyteContext.current_context() output_prefix = ctx.file_access.get_random_remote_directory() - resource = await agent.do(task_template=task_template, inputs=task_inputs, output_prefix=output_prefix) + + if isinstance(mock_return_value, Exception): + mock_boto_call.side_effect = mock_return_value + + resource = await agent.do( + task_template=task_template, + inputs=task_inputs, + output_prefix=output_prefix, + ) + assert resource.outputs["result"] == { + "result": f"Entity already exists: arn:aws:sagemaker:us-east-2:123456789:endpoint/stable-diffusion-endpoint-non-finetuned-06716dbe4b2c68e7" + } + assert resource.outputs["idempotence_token"] == idempotence_token + return + + resource = await agent.do( + task_template=task_template, inputs=task_inputs, output_prefix=output_prefix + ) assert resource.phase == TaskExecution.SUCCEEDED - if mock_return_value: + if mock_return_value[0]: outputs = literal_map_string_repr(resource.outputs) - if "pickle_check" in mock_return_value: + if "pickle_check" in mock_return_value[0]: assert "pickle_file" in outputs["result"] else: assert ( outputs["result"]["EndpointConfigArn"] == "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config" ) - elif mock_return_value is None: + assert outputs["idempotence_token"] == "74443947857331f7" + elif mock_return_value[0] is None: assert resource.outputs["result"] == {"result": None} diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py index c53088cf38..60d0dd45af 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py @@ -101,7 +101,7 @@ async def test_call(mock_session): {"model_name": str, "region": str}, ) - result = await mixin._call( + result, idempotence_token = await mixin._call( method="create_model", config=config, inputs=inputs, @@ -117,3 +117,4 @@ async def test_call(mock_session): ) assert result == mock_method.return_value + assert idempotence_token == "" diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_task.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_task.py index 78dce7eae3..893634536e 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_boto3_task.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_boto3_task.py @@ -13,20 +13,21 @@ def test_boto_task_and_config(): config={ "ModelName": "{inputs.model_name}", "PrimaryContainer": { - "Image": "{container.image}", + "Image": "{images.deployment_image}", "ModelDataUrl": "{inputs.model_data_url}", }, "ExecutionRoleArn": "{inputs.execution_role_arn}", }, region="us-east-2", + images={ + "deployment_image": "1234567890.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost" + }, ), inputs=kwtypes(model_name=str, model_data_url=str, execution_role_arn=str), - outputs=kwtypes(result=dict), - container_image="1234567890.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost", ) assert len(boto_task.interface.inputs) == 3 - assert len(boto_task.interface.outputs) == 1 + assert len(boto_task.interface.outputs) == 2 default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = SerializationSettings( @@ -43,10 +44,14 @@ def test_boto_task_and_config(): assert retrieved_setttings["config"] == { "ModelName": "{inputs.model_name}", "PrimaryContainer": { - "Image": "{container.image}", + "Image": "{images.deployment_image}", "ModelDataUrl": "{inputs.model_data_url}", }, "ExecutionRoleArn": "{inputs.execution_role_arn}", } assert retrieved_setttings["region"] == "us-east-2" assert retrieved_setttings["method"] == "create_model" + assert ( + retrieved_setttings["images"]["deployment_image"] + == "1234567890.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost" + ) diff --git a/plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py b/plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py index 5ee8d11f01..b3c8cba2e6 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py @@ -12,50 +12,82 @@ from flytekit.models.core.identifier import ResourceType from flytekit.models.task import RuntimeMetadata, TaskMetadata, TaskTemplate +from flytekitplugins.awssagemaker_inference.boto3_mixin import CustomException +from botocore.exceptions import ClientError + +idempotence_token = "74443947857331f7" + @pytest.mark.asyncio -@mock.patch( - "flytekitplugins.awssagemaker_inference.agent.Boto3AgentMixin._call", - return_value={ - "EndpointName": "sagemaker-xgboost-endpoint", - "EndpointArn": "arn:aws:sagemaker:us-east-2:1234567890:endpoint/sagemaker-xgboost-endpoint", - "EndpointConfigName": "sagemaker-xgboost-endpoint-config", - "ProductionVariants": [ +@pytest.mark.parametrize( + "mock_return_value", + [ + ( { - "VariantName": "variant-name-1", - "DeployedImages": [ + "EndpointName": "sagemaker-xgboost-endpoint", + "EndpointArn": "arn:aws:sagemaker:us-east-2:1234567890:endpoint/sagemaker-xgboost-endpoint", + "EndpointConfigName": "sagemaker-xgboost-endpoint-config", + "ProductionVariants": [ { - "SpecifiedImage": "1234567890.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost:iL3_jIEY3lQPB4wnlS7HKA..", - "ResolvedImage": "1234567890.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost@sha256:0725042bf15f384c46e93bbf7b22c0502859981fc8830fd3aea4127469e8cf1e", - "ResolutionTime": "2024-01-31T22:14:07.193000+05:30", + "VariantName": "variant-name-1", + "DeployedImages": [ + { + "SpecifiedImage": "1234567890.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost:iL3_jIEY3lQPB4wnlS7HKA..", + "ResolvedImage": "1234567890.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost@sha256:0725042bf15f384c46e93bbf7b22c0502859981fc8830fd3aea4127469e8cf1e", + "ResolutionTime": "2024-01-31T22:14:07.193000+05:30", + } + ], + "CurrentWeight": 1.0, + "DesiredWeight": 1.0, + "CurrentInstanceCount": 1, + "DesiredInstanceCount": 1, } ], - "CurrentWeight": 1.0, - "DesiredWeight": 1.0, - "CurrentInstanceCount": 1, - "DesiredInstanceCount": 1, - } - ], - "EndpointStatus": "InService", - "CreationTime": "2024-01-31T22:14:06.553000+05:30", - "LastModifiedTime": "2024-01-31T22:16:37.001000+05:30", - "AsyncInferenceConfig": { - "OutputConfig": {"S3OutputPath": "s3://sagemaker-agent-xgboost/inference-output/output"} - }, - "ResponseMetadata": { - "RequestId": "50d8bfa7-d84-4bd9-bf11-992832f42793", - "HTTPStatusCode": 200, - "HTTPHeaders": { - "x-amzn-requestid": "50d8bfa7-d840-4bd9-bf11-992832f42793", - "content-type": "application/x-amz-json-1.1", - "content-length": "865", - "date": "Wed, 31 Jan 2024 16:46:38 GMT", + "EndpointStatus": "InService", + "CreationTime": "2024-01-31T22:14:06.553000+05:30", + "LastModifiedTime": "2024-01-31T22:16:37.001000+05:30", + "AsyncInferenceConfig": { + "OutputConfig": { + "S3OutputPath": "s3://sagemaker-agent-xgboost/inference-output/output" + } + }, + "ResponseMetadata": { + "RequestId": "50d8bfa7-d84-4bd9-bf11-992832f42793", + "HTTPStatusCode": 200, + "HTTPHeaders": { + "x-amzn-requestid": "50d8bfa7-d840-4bd9-bf11-992832f42793", + "content-type": "application/x-amz-json-1.1", + "content-length": "865", + "date": "Wed, 31 Jan 2024 16:46:38 GMT", + }, + "RetryAttempts": 0, + }, }, - "RetryAttempts": 0, - }, - }, + idempotence_token, + ), + ( + CustomException( + message="An error occurred", + idempotence_token=idempotence_token, + original_exception=ClientError( + error_response={ + "Error": { + "Code": "ValidationException", + "Message": "Cannot create already existing endpoint 'arn:aws:sagemaker:us-east-2:123456789:endpoint/stable-diffusion-endpoint-non-finetuned-06716dbe4b2c68e7'", + } + }, + operation_name="CreateEndpoint", + ), + ) + ), + ], +) +@mock.patch( + "flytekitplugins.awssagemaker_inference.agent.Boto3AgentMixin._call", ) -async def test_agent(mock_boto_call): +async def test_agent(mock_boto_call, mock_return_value): + mock_boto_call.return_value = mock_return_value + agent = AgentRegistry.get_agent("sagemaker-endpoint") task_id = Identifier( resource_type=ResourceType.TASK, @@ -67,7 +99,7 @@ async def test_agent(mock_boto_call): task_config = { "service": "sagemaker", "config": { - "EndpointName": "sagemaker-endpoint", + "EndpointName": "sagemaker-endpoint-{idempotence_token}", "EndpointConfigName": "sagemaker-endpoint-config", }, "region": "us-east-2", @@ -75,7 +107,9 @@ async def test_agent(mock_boto_call): } task_metadata = TaskMetadata( discoverable=True, - runtime=RuntimeMetadata(RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + runtime=RuntimeMetadata( + RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python" + ), timeout=timedelta(days=1), retries=literals.RetryStrategy(3), interruptible=True, @@ -94,14 +128,38 @@ async def test_agent(mock_boto_call): type="sagemaker-endpoint", ) - # CREATE metadata = SageMakerEndpointMetadata( config={ - "EndpointName": "sagemaker-endpoint", + "EndpointName": "sagemaker-endpoint-{idempotence_token}", "EndpointConfigName": "sagemaker-endpoint-config", }, region="us-east-2", ) + + # Exception check + if isinstance(mock_return_value, Exception): + response = await agent.create(task_template) + assert response == metadata + + mock_boto_call.side_effect = CustomException( + message="An error occurred", + idempotence_token=idempotence_token, + original_exception=ClientError( + error_response={ + "Error": { + "Code": "ValidationException", + "Message": "Could not find endpoint 'arn:aws:sagemaker:us-east-2:123456789:endpoint/stable-diffusion-endpoint-non-finetuned-06716dbe4b2c68e7'", + } + }, + operation_name="DescribeEndpoint", + ), + ) + + with pytest.raises(Exception, match="resource limits being exceeded"): + resource = await agent.get(metadata) + return + + # CREATE response = await agent.create(task_template) assert response == metadata @@ -109,9 +167,11 @@ async def test_agent(mock_boto_call): resource = await agent.get(metadata) assert resource.phase == TaskExecution.SUCCEEDED - from_json = json.loads(resource.outputs["result"]) - assert from_json["EndpointName"] == "sagemaker-xgboost-endpoint" - assert from_json["EndpointArn"] == "arn:aws:sagemaker:us-east-2:1234567890:endpoint/sagemaker-xgboost-endpoint" + assert ( + resource.outputs["result"]["EndpointArn"] + == "arn:aws:sagemaker:us-east-2:1234567890:endpoint/sagemaker-xgboost-endpoint" + ) + assert resource.outputs["idempotence_token"] == idempotence_token # DELETE delete_response = await agent.delete(metadata) diff --git a/plugins/flytekit-aws-sagemaker/tests/test_inference_task.py b/plugins/flytekit-aws-sagemaker/tests/test_inference_task.py index 93e61d909d..f74e0cc4b6 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_inference_task.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_inference_task.py @@ -29,9 +29,11 @@ "sagemaker", "create_model", kwtypes(model_name=str, model_data_url=str, execution_role_arn=str), - {"primary_container_image": "1234567890.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost"}, + { + "primary_container_image": "1234567890.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost" + }, 3, - 1, + 2, "us-east-2", SageMakerModelTask, ), @@ -47,14 +49,16 @@ "InstanceType": "ml.m4.xlarge", }, ], - "AsyncInferenceConfig": {"OutputConfig": {"S3OutputPath": "{inputs.s3_output_path}"}}, + "AsyncInferenceConfig": { + "OutputConfig": {"S3OutputPath": "{inputs.s3_output_path}"} + }, }, "sagemaker", "create_endpoint_config", kwtypes(endpoint_config_name=str, model_name=str, s3_output_path=str), None, 3, - 1, + 2, "us-east-2", SageMakerEndpointConfigTask, ), @@ -69,7 +73,7 @@ kwtypes(endpoint_name=str, endpoint_config_name=str), None, 2, - 1, + 2, "us-east-2", SageMakerEndpointTask, ), @@ -81,7 +85,7 @@ kwtypes(endpoint_name=str), None, 1, - 1, + 2, "us-east-2", SageMakerDeleteEndpointTask, ), @@ -93,7 +97,7 @@ kwtypes(endpoint_config_name=str), None, 1, - 1, + 2, "us-east-2", SageMakerDeleteEndpointConfigTask, ), @@ -105,7 +109,7 @@ kwtypes(model_name=str), None, 1, - 1, + 2, "us-east-2", SageMakerDeleteModelTask, ), @@ -120,7 +124,7 @@ kwtypes(endpoint_name=str), None, 1, - 1, + 2, "us-east-2", SageMakerInvokeEndpointTask, ), @@ -135,7 +139,7 @@ kwtypes(endpoint_name=str, region=str), None, 2, - 1, + 2, None, SageMakerInvokeEndpointTask, ), diff --git a/plugins/flytekit-aws-sagemaker/tests/test_inference_workflow.py b/plugins/flytekit-aws-sagemaker/tests/test_inference_workflow.py index f98bb557fa..3546ec43a0 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_inference_workflow.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_inference_workflow.py @@ -1,4 +1,7 @@ -from flytekitplugins.awssagemaker_inference import create_sagemaker_deployment, delete_sagemaker_deployment +from flytekitplugins.awssagemaker_inference import ( + create_sagemaker_deployment, + delete_sagemaker_deployment, +) from flytekit import kwtypes @@ -17,7 +20,7 @@ def test_sagemaker_deployment_workflow(): }, endpoint_config_input_types=kwtypes(instance_type=str), endpoint_config_config={ - "EndpointConfigName": "sagemaker-xgboost-endpoint-config", + "EndpointConfigName": "sagemaker-xgboost", "ProductionVariants": [ { "VariantName": "variant-name-1", @@ -27,14 +30,18 @@ def test_sagemaker_deployment_workflow(): }, ], "AsyncInferenceConfig": { - "OutputConfig": {"S3OutputPath": "s3://sagemaker-agent-xgboost/inference-output/output"} + "OutputConfig": { + "S3OutputPath": "s3://sagemaker-agent-xgboost/inference-output/output" + } }, }, endpoint_config={ - "EndpointName": "sagemaker-xgboost-endpoint", - "EndpointConfigName": "sagemaker-xgboost-endpoint-config", + "EndpointName": "sagemaker-xgboost", + "EndpointConfigName": "sagemaker-xgboost", + }, + images={ + "primary_container_image": "1234567890.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost" }, - images={"primary_container_image": "1234567890.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost"}, region="us-east-2", ) @@ -57,7 +64,7 @@ def test_sagemaker_deployment_workflow_with_region_at_runtime(): }, endpoint_config_input_types=kwtypes(instance_type=str), endpoint_config_config={ - "EndpointConfigName": "sagemaker-xgboost-endpoint-config", + "EndpointConfigName": "sagemaker-xgboost", "ProductionVariants": [ { "VariantName": "variant-name-1", @@ -67,14 +74,18 @@ def test_sagemaker_deployment_workflow_with_region_at_runtime(): }, ], "AsyncInferenceConfig": { - "OutputConfig": {"S3OutputPath": "s3://sagemaker-agent-xgboost/inference-output/output"} + "OutputConfig": { + "S3OutputPath": "s3://sagemaker-agent-xgboost/inference-output/output" + } }, }, endpoint_config={ - "EndpointName": "sagemaker-xgboost-endpoint", - "EndpointConfigName": "sagemaker-xgboost-endpoint-config", + "EndpointName": "sagemaker-xgboost", + "EndpointConfigName": "sagemaker-xgboost", + }, + images={ + "primary_container_image": "1234567890.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost" }, - images={"primary_container_image": "1234567890.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost"}, region_at_runtime=True, )