Skip to content

Commit

Permalink
handle existing sagemaker deployments gracefully (flyteorg#2400)
Browse files Browse the repository at this point in the history
* add support to override model deployment

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* fix task chaining

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* fix task chaining

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* nit

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* add is deployment exists function

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* add override

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* mutable fix

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* nit

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* update workflow

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* add idempotence token and deployment exists check

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* fix init

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* fix workflow output

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* fix wf output

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* idempotence token as task output

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* move try catch to mixin

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* fix wf code

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* move try catch to agents

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* fixed all bugs

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* fix tests

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

---------

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>
Signed-off-by: mao3267 <chenvincent610@gmail.com>
  • Loading branch information
samhita-alla authored and mao3267 committed Jul 29, 2024
1 parent 5d7b03c commit 23b1864
Show file tree
Hide file tree
Showing 13 changed files with 397 additions and 157 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import json
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, Optional

import cloudpickle
Expand All @@ -15,7 +13,7 @@
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate

from .boto3_mixin import Boto3AgentMixin
from .boto3_mixin import Boto3AgentMixin, CustomException


@dataclass
Expand All @@ -39,14 +37,6 @@ def decode(cls, data: bytes) -> "SageMakerEndpointMetadata":
}


class DateTimeEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, datetime):
return o.isoformat()

return json.JSONEncoder.default(self, o)


class SageMakerEndpointAgent(Boto3AgentMixin, AsyncAgentBase):
"""This agent creates an endpoint."""

Expand All @@ -66,22 +56,49 @@ async def create(
config = custom.get("config")
region = custom.get("region")

await self._call(
method="create_endpoint",
config=config,
inputs=inputs,
region=region,
)
try:
await self._call(
method="create_endpoint",
config=config,
inputs=inputs,
region=region,
)
except CustomException as e:
original_exception = e.original_exception
error_code = original_exception.response["Error"]["Code"]
error_message = original_exception.response["Error"]["Message"]

if error_code == "ValidationException" and "Cannot create already existing" in error_message:
return SageMakerEndpointMetadata(config=config, region=region, inputs=inputs)
elif (
error_code == "ResourceLimitExceeded"
and "Please use AWS Service Quotas to request an increase for this quota." in error_message
):
return SageMakerEndpointMetadata(config=config, region=region, inputs=inputs)
raise e
except Exception as e:
raise e

return SageMakerEndpointMetadata(config=config, region=region, inputs=inputs)

async def get(self, resource_meta: SageMakerEndpointMetadata, **kwargs) -> Resource:
endpoint_status = await self._call(
method="describe_endpoint",
config={"EndpointName": resource_meta.config.get("EndpointName")},
inputs=resource_meta.inputs,
region=resource_meta.region,
)
try:
endpoint_status, idempotence_token = await self._call(
method="describe_endpoint",
config={"EndpointName": resource_meta.config.get("EndpointName")},
inputs=resource_meta.inputs,
region=resource_meta.region,
)
except CustomException as e:
original_exception = e.original_exception
error_code = original_exception.response["Error"]["Code"]
error_message = original_exception.response["Error"]["Message"]

if error_code == "ValidationException" and "Could not find endpoint" in error_message:
raise Exception(
"This might be due to resource limits being exceeded, preventing the creation of a new endpoint. Please check your resource usage and limits."
) from e
raise e

current_state = endpoint_status.get("EndpointStatus")
flyte_phase = convert_to_flyte_phase(states[current_state])
Expand All @@ -92,7 +109,10 @@ async def get(self, resource_meta: SageMakerEndpointMetadata, **kwargs) -> Resou

res = None
if current_state == "InService":
res = {"result": json.dumps(endpoint_status, cls=DateTimeEncoder)}
res = {
"result": {"EndpointArn": endpoint_status.get("EndpointArn")},
"idempotence_token": idempotence_token,
}

return Resource(phase=flyte_phase, outputs=res, message=message)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from typing import Optional

from flyteidl.core.execution_pb2 import TaskExecution
Expand All @@ -15,7 +16,7 @@
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate

from .boto3_mixin import Boto3AgentMixin
from .boto3_mixin import Boto3AgentMixin, CustomException


# https://github.com/flyteorg/flyte/issues/4505
Expand Down Expand Up @@ -58,12 +59,44 @@ async def do(

boto3_object = Boto3AgentMixin(service=service, region=region)

result = await boto3_object._call(
method=method,
config=config,
images=images,
inputs=inputs,
)
try:
result, idempotence_token = await boto3_object._call(
method=method,
config=config,
images=images,
inputs=inputs,
)
except CustomException as e:
original_exception = e.original_exception
error_code = original_exception.response["Error"]["Code"]
error_message = original_exception.response["Error"]["Message"]

if error_code == "ValidationException" and "Cannot create already existing" in error_message:
arn = re.search(
r"arn:aws:[a-zA-Z0-9\-]+:[a-zA-Z0-9\-]+:\d+:[a-zA-Z0-9\-\/]+",
error_message,
).group(0)
if arn:
return Resource(
phase=TaskExecution.SUCCEEDED,
outputs={
"result": {"result": f"Entity already exists: {arn}"},
"idempotence_token": e.idempotence_token,
},
)
else:
return Resource(
phase=TaskExecution.SUCCEEDED,
outputs={
"result": {"result": "Entity already exists."},
"idempotence_token": e.idempotence_token,
},
)
else:
# Re-raise the exception if it's not the specific error we're handling
raise e
except Exception as e:
raise e

outputs = {"result": {"result": None}}
if result:
Expand All @@ -83,7 +116,13 @@ async def do(
result,
Annotated[dict, kwtypes(allow_pickle=True)],
TypeEngine.to_literal_type(dict),
)
),
"idempotence_token": TypeEngine.to_literal(
new_ctx,
idempotence_token,
str,
TypeEngine.to_literal_type(str),
),
}
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
import re
from typing import Any, Dict, Optional

import aioboto3
import xxhash
from botocore.exceptions import ClientError

from flytekit.interaction.string_literals import literal_map_string_repr
from flytekit.models.literals import LiteralMap


class CustomException(Exception):
def __init__(self, message, idempotence_token, original_exception):
super().__init__(message)
self.idempotence_token = idempotence_token
self.original_exception = original_exception


account_id_map = {
"us-east-1": "785573368785",
"us-east-2": "007439368137",
Expand All @@ -31,7 +42,11 @@
}


def update_dict_fn(original_dict: Any, update_dict: Dict[str, Any]) -> Any:
def update_dict_fn(
original_dict: Any,
update_dict: Dict[str, Any],
idempotence_token: Optional[str] = None,
) -> Any:
"""
Recursively update a dictionary with values from another dictionary.
For example, if original_dict is {"EndpointConfigName": "{endpoint_config_name}"},
Expand All @@ -40,6 +55,7 @@ def update_dict_fn(original_dict: Any, update_dict: Dict[str, Any]) -> Any:
:param original_dict: The dictionary to update (in place)
:param update_dict: The dictionary to use for updating
:param idempotence_token: Hash of config -- this is to ensure the execution ID is deterministic
:return: The updated dictionary
"""
if original_dict is None:
Expand All @@ -48,44 +64,50 @@ def update_dict_fn(original_dict: Any, update_dict: Dict[str, Any]) -> Any:
# If the original value is a string and contains placeholder curly braces
if isinstance(original_dict, str):
if "{" in original_dict and "}" in original_dict:
# Check if there are nested keys
if "." in original_dict:
# Create a copy of update_dict
update_dict_copy = update_dict.copy()

# Fetch keys from the original_dict
keys = original_dict.strip("{}").split(".")

# Get value from the nested dictionary
for key in keys:
try:
update_dict_copy = update_dict_copy[key]
except Exception:
raise ValueError(f"Could not find the key {key} in {update_dict_copy}.")

return update_dict_copy

# Retrieve the original value using the key without curly braces
original_value = update_dict.get(original_dict.strip("{}"))

# Check if original_value exists; if so, return it,
# otherwise, raise a ValueError indicating that the value for the key original_dict could not be found.
if original_value:
return original_value
else:
raise ValueError(f"Could not find value for {original_dict}.")

# If the string does not contain placeholders, return it as is
matches = re.findall(r"\{([^}]+)\}", original_dict)
for match in matches:
# Check if there are nested keys
if "." in match:
# Create a copy of update_dict
update_dict_copy = update_dict.copy()

# Fetch keys from the original_dict
keys = match.split(".")

# Get value from the nested dictionary
for key in keys:
try:
update_dict_copy = update_dict_copy[key]
except Exception:
raise ValueError(f"Could not find the key {key} in {update_dict_copy}.")

if len(matches) > 1:
# Replace the placeholder in the original_dict
original_dict = original_dict.replace(f"{{{match}}}", update_dict_copy)
else:
# If there's only one match, it needn't always be a string, so not replacing the original dict.
return update_dict_copy
elif match == "idempotence_token" and idempotence_token:
temp_dict = original_dict.replace(f"{{{match}}}", idempotence_token)
if len(temp_dict) > 63:
truncated_idempotence_token = idempotence_token[
: (63 - len(original_dict.replace("{idempotence_token}", "")))
]
original_dict = original_dict.replace(f"{{{match}}}", truncated_idempotence_token)
else:
original_dict = temp_dict

# If the string does not contain placeholders or if there are multiple placeholders, return the original dict.
return original_dict

# If the original value is a list, recursively update each element in the list
if isinstance(original_dict, list):
return [update_dict_fn(item, update_dict) for item in original_dict]
return [update_dict_fn(item, update_dict, idempotence_token) for item in original_dict]

# If the original value is a dictionary, recursively update each key-value pair
if isinstance(original_dict, dict):
for key, value in original_dict.items():
original_dict[key] = update_dict_fn(value, update_dict)
original_dict[key] = update_dict_fn(value, update_dict, idempotence_token)

# Return the updated original dict
return original_dict
Expand Down Expand Up @@ -116,7 +138,7 @@ async def _call(
images: Optional[Dict[str, str]] = None,
inputs: Optional[LiteralMap] = None,
region: Optional[str] = None,
) -> Any:
) -> tuple[Any, str]:
"""
Utilize this method to invoke any boto3 method (AWS service method).
Expand Down Expand Up @@ -162,6 +184,12 @@ async def _call(

updated_config = update_dict_fn(config, args)

hash = ""
if "idempotence_token" in str(updated_config):
# compute hash of the config
hash = xxhash.xxh64(str(updated_config)).hexdigest()
updated_config = update_dict_fn(updated_config, args, idempotence_token=hash)

# Asynchronous Boto3 session
session = aioboto3.Session()
async with session.client(
Expand All @@ -170,7 +198,7 @@ async def _call(
) as client:
try:
result = await getattr(client, method)(**updated_config)
except Exception as e:
raise e
except ClientError as e:
raise CustomException(f"An error occurred: {e}", hash, e) from e

return result
return result, hash
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
task_type=self._TASK_TYPE,
interface=Interface(
inputs=inputs,
outputs=kwtypes(result=dict),
outputs=kwtypes(result=dict, idempotence_token=str),
),
**kwargs,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(
super().__init__(
name=name,
task_type=self._TASK_TYPE,
interface=Interface(inputs=inputs, outputs=kwtypes(result=str)),
interface=Interface(inputs=inputs, outputs=kwtypes(result=dict, idempotence_token=str)),
**kwargs,
)
self._config = config
Expand Down
Loading

0 comments on commit 23b1864

Please sign in to comment.