diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py
index 0d4ba9fcabb34..8467459188508 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import datetime
 import json
 import time
 import warnings
@@ -375,6 +376,7 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
         finish within max_ingestion_time seconds. If you set this parameter to None it never times out.
     :param operation: Whether to create an endpoint or update an endpoint. Must be either 'create or 'update'.
     :param aws_conn_id: The AWS connection ID to use.
+    :param deferrable: Will wait asynchronously for completion.
     :return Dict: Returns The ARN of the endpoint created in Amazon SageMaker.
     """
@@ -387,15 +389,17 @@ def __init__(
         self,
         config: dict,
         aws_conn_id: str = DEFAULT_CONN_ID,
         wait_for_completion: bool = True,
         check_interval: int = CHECK_INTERVAL_SECOND,
         max_ingestion_time: int | None = None,
         operation: str = "create",
+        deferrable: bool = False,
         **kwargs,
     ):
         super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
         self.wait_for_completion = wait_for_completion
         self.check_interval = check_interval
-        self.max_ingestion_time = max_ingestion_time
+        self.max_ingestion_time = max_ingestion_time or 3600 * 10
         self.operation = operation.lower()
         if self.operation not in ["create", "update"]:
             raise ValueError('Invalid value! Argument operation has to be one of "create" and "update"')
+        self.deferrable = deferrable
 
     def _create_integer_fields(self) -> None:
         """Set fields which should be cast to integers."""
@@ -436,29 +440,54 @@ def execute(self, context: Context) -> dict:
         try:
             response = sagemaker_operation(
                 endpoint_info,
-                wait_for_completion=self.wait_for_completion,
-                check_interval=self.check_interval,
-                max_ingestion_time=self.max_ingestion_time,
+                wait_for_completion=False,
             )
+            # waiting for completion is handled here in the operator
         except ClientError:
             self.operation = "update"
             sagemaker_operation = self.hook.update_endpoint
-            log_str = "Updating"
             response = sagemaker_operation(
                 endpoint_info,
-                wait_for_completion=self.wait_for_completion,
-                check_interval=self.check_interval,
-                max_ingestion_time=self.max_ingestion_time,
+                wait_for_completion=False,
             )
+
         if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
             raise AirflowException(f"Sagemaker endpoint creation failed: {response}")
-        else:
-            return {
-                "EndpointConfig": serialize(
-                    self.hook.describe_endpoint_config(endpoint_info["EndpointConfigName"])
+
+        if self.deferrable:
+            self.defer(
+                trigger=SageMakerTrigger(
+                    job_name=endpoint_info["EndpointName"],
+                    job_type="endpoint",
+                    poke_interval=self.check_interval,
+                    aws_conn_id=self.aws_conn_id,
                 ),
-                "Endpoint": serialize(self.hook.describe_endpoint(endpoint_info["EndpointName"])),
-            }
+                method_name="execute_complete",
+                timeout=datetime.timedelta(seconds=self.max_ingestion_time),
+            )
+        elif self.wait_for_completion:
+            self.hook.get_waiter("endpoint_in_service").wait(
+                EndpointName=endpoint_info["EndpointName"],
+                WaiterConfig={"Delay": self.check_interval, "MaxAttempts": self.max_ingestion_time},
+            )
+
+        return {
+            "EndpointConfig": serialize(
+                self.hook.describe_endpoint_config(endpoint_info["EndpointConfigName"])
+            ),
+            "Endpoint": serialize(self.hook.describe_endpoint(endpoint_info["EndpointName"])),
+        }
+
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error while running job: {event}")
+        endpoint_info = self.config.get("Endpoint", self.config)
+        return {
+            "EndpointConfig": serialize(
+                self.hook.describe_endpoint_config(endpoint_info["EndpointConfigName"])
+            ),
+            "Endpoint": serialize(self.hook.describe_endpoint(endpoint_info["EndpointName"])),
+        }
 
 
 class SageMakerTransformOperator(SageMakerBaseOperator):
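A minimal usage sketch of the new deferrable endpoint path (hypothetical DAG id, endpoint names, and config values; the config keys mirror the test fixtures further down): with `deferrable=True` the create/update call returns immediately and the wait is handed off to the triggerer, so no worker slot is held while the endpoint comes up.

```python
# Hypothetical illustration only: DAG id, names, and config values are made up.
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerEndpointOperator

with DAG("example_sagemaker_endpoint_deferrable", start_date=datetime(2023, 1, 1), schedule=None) as dag:
    deploy = SageMakerEndpointOperator(
        task_id="deploy_endpoint",
        config={
            "Model": {},           # model definition, elided here
            "EndpointConfig": {},  # endpoint config definition, elided here
            "Endpoint": {"EndpointName": "my-endpoint", "EndpointConfigName": "my-endpoint-config"},
        },
        deferrable=True,          # defer to the triggerer instead of polling in the worker
        check_interval=30,        # becomes the trigger's poke_interval
        max_ingestion_time=3600,  # becomes the deferral timeout (defaults to 10 hours when None)
    )
```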
@@ -652,6 +681,7 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
     :param max_ingestion_time: If wait is set to True, the operation fails if the tuning job doesn't
         finish within max_ingestion_time seconds. If you set this parameter to None,
         the operation does not timeout.
+    :param deferrable: Will wait asynchronously for completion.
     :return Dict: Returns The ARN of the tuning job created in Amazon SageMaker.
     """
@@ -663,12 +693,14 @@ def __init__(
         self,
         config: dict,
         aws_conn_id: str = DEFAULT_CONN_ID,
         wait_for_completion: bool = True,
         check_interval: int = CHECK_INTERVAL_SECOND,
         max_ingestion_time: int | None = None,
+        deferrable: bool = False,
         **kwargs,
     ):
         super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
         self.wait_for_completion = wait_for_completion
         self.check_interval = check_interval
         self.max_ingestion_time = max_ingestion_time
+        self.deferrable = deferrable
 
     def expand_role(self) -> None:
         """Expands an IAM role name into an ARN."""
@@ -695,16 +727,46 @@ def execute(self, context: Context) -> dict:
         )
         response = self.hook.create_tuning_job(
             self.config,
-            wait_for_completion=self.wait_for_completion,
+            wait_for_completion=False,  # we handle this here
             check_interval=self.check_interval,
             max_ingestion_time=self.max_ingestion_time,
         )
         if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
             raise AirflowException(f"Sagemaker Tuning Job creation failed: {response}")
+
+        if self.deferrable:
+            self.defer(
+                trigger=SageMakerTrigger(
+                    job_name=self.config["HyperParameterTuningJobName"],
+                    job_type="tuning",
+                    poke_interval=self.check_interval,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+                timeout=datetime.timedelta(seconds=self.max_ingestion_time)
+                if self.max_ingestion_time is not None
+                else None,
+            )
+            description = {}  # never executed but makes static checkers happy
+        elif self.wait_for_completion:
+            description = self.hook.check_status(
+                self.config["HyperParameterTuningJobName"],
+                "HyperParameterTuningJobStatus",
+                self.hook.describe_tuning_job,
+                self.check_interval,
+                self.max_ingestion_time,
+            )
         else:
-            return {
-                "Tuning": serialize(self.hook.describe_tuning_job(self.config["HyperParameterTuningJobName"]))
-            }
+            description = self.hook.describe_tuning_job(self.config["HyperParameterTuningJobName"])
+
+        return {"Tuning": serialize(description)}
+
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error while running job: {event}")
+        return {
+            "Tuning": serialize(self.hook.describe_tuning_job(self.config["HyperParameterTuningJobName"]))
+        }
 
 
 class SageMakerModelOperator(SageMakerBaseOperator):
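For reviewers less familiar with deferral mechanics, here is a rough sketch of the hand-off (assuming a hypothetical operator instance `op` built with `deferrable=True`): `BaseOperator.defer()` raises `TaskDeferred`, the trigger then runs in the triggerer process, and the task is later resumed at `method_name` with the trigger's event payload — which is why both operators above now implement `execute_complete`.

```python
# Sketch of the control flow; `op` is a hypothetical deferrable operator instance.
# This mirrors what the unit tests below assert.
from airflow.exceptions import TaskDeferred

try:
    op.execute(context={})  # submits the SageMaker job, then defers
except TaskDeferred as deferral:
    assert deferral.method_name == "execute_complete"
    trigger = deferral.trigger  # the SageMakerTrigger built in execute()

# Later, once the trigger fires, the task is re-invoked with the event payload:
result = op.execute_complete(context={}, event={"status": "success", "message": "Job completed."})
# A non-success payload raises AirflowException instead of returning a result.
```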
diff --git a/airflow/providers/amazon/aws/triggers/sagemaker.py b/airflow/providers/amazon/aws/triggers/sagemaker.py
index 92266cad5ffdb..ca511a4a46d03 100644
--- a/airflow/providers/amazon/aws/triggers/sagemaker.py
+++ b/airflow/providers/amazon/aws/triggers/sagemaker.py
@@ -21,6 +21,7 @@
 from typing import Any
 
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 
@@ -41,7 +42,7 @@ def __init__(
         job_name: str,
         job_type: str,
         poke_interval: int = 30,
-        max_attempts: int | None = None,
+        max_attempts: int = 480,
         aws_conn_id: str = "aws_default",
     ):
         super().__init__()
@@ -74,14 +75,28 @@ def _get_job_type_waiter(job_type: str) -> str:
             "training": "TrainingJobComplete",
             "transform": "TransformJobComplete",
             "processing": "ProcessingJobComplete",
+            "tuning": "TuningJobComplete",
+            "endpoint": "endpoint_in_service",  # this one is provided by boto
         }[job_type.lower()]
 
     @staticmethod
-    def _get_job_type_waiter_job_name_arg(job_type: str) -> str:
+    def _get_waiter_arg_name(job_type: str) -> str:
         return {
             "training": "TrainingJobName",
             "transform": "TransformJobName",
             "processing": "ProcessingJobName",
+            "tuning": "HyperParameterTuningJobName",
+            "endpoint": "EndpointName",
+        }[job_type.lower()]
+
+    @staticmethod
+    def _get_response_status_key(job_type: str) -> str:
+        return {
+            "training": "TrainingJobStatus",
+            "transform": "TransformJobStatus",
+            "processing": "ProcessingJobStatus",
+            "tuning": "HyperParameterTuningJobStatus",
+            "endpoint": "EndpointStatus",
         }[job_type.lower()]
 
     async def run(self):
@@ -90,12 +105,13 @@ async def run(self):
             waiter = self.hook.get_waiter(
                 self._get_job_type_waiter(self.job_type), deferrable=True, client=client
             )
-            waiter_args = {
-                self._get_job_type_waiter_job_name_arg(self.job_type): self.job_name,
-                "WaiterConfig": {
-                    "Delay": self.poke_interval,
-                    "MaxAttempts": self.max_attempts,
-                },
-            }
-            await waiter.wait(**waiter_args)
-            yield TriggerEvent({"status": "success", "message": "Job completed."})
+            await async_wait(
+                waiter=waiter,
+                waiter_delay=self.poke_interval,
+                waiter_max_attempts=self.max_attempts,
+                args={self._get_waiter_arg_name(self.job_type): self.job_name},
+                failure_message=f"Error while waiting for {self.job_type} job",
+                status_message=f"{self.job_type} job not done yet",
+                status_args=[self._get_response_status_key(self.job_type)],
+            )
+            yield TriggerEvent({"status": "success", "message": "Job completed."})
diff --git a/airflow/providers/amazon/aws/waiters/sagemaker.json b/airflow/providers/amazon/aws/waiters/sagemaker.json
index 73e3f09925c89..2c2760982c5f4 100644
--- a/airflow/providers/amazon/aws/waiters/sagemaker.json
+++ b/airflow/providers/amazon/aws/waiters/sagemaker.json
@@ -78,6 +78,32 @@
                     "state": "failure"
                 }
             ]
+        },
+        "TuningJobComplete": {
+            "delay": 30,
+            "operation": "DescribeHyperParameterTuningJob",
+            "maxAttempts": 60,
+            "description": "Wait until job is COMPLETED",
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "HyperParameterTuningJobStatus",
+                    "expected": "Completed",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "HyperParameterTuningJobStatus",
+                    "expected": "Failed",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "HyperParameterTuningJobStatus",
+                    "expected": "Stopped",
+                    "state": "failure"
+                }
+            ]
         }
     }
 }
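The new `TuningJobComplete` entry is a custom waiter definition. A hedged sketch of exercising it synchronously through the hook (hypothetical job name; as I understand it, the hook's `get_waiter` resolves custom names from this JSON before falling back to botocore's built-in waiters):

```python
# Hedged sketch: "my-tuning-job" is a placeholder job name.
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook

hook = SageMakerHook(aws_conn_id="aws_default")
waiter = hook.get_waiter("TuningJobComplete")  # resolved from the JSON definition above
waiter.wait(
    HyperParameterTuningJobName="my-tuning-job",
    WaiterConfig={"Delay": 30, "MaxAttempts": 60},  # mirrors the defaults in the definition
)
```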
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
index 498a38b816076..8a566535b98e3 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
@@ -22,10 +22,11 @@
 import pytest
 from botocore.exceptions import ClientError
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators import sagemaker
 from airflow.providers.amazon.aws.operators.sagemaker import SageMakerEndpointOperator
+from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
 
 CREATE_MODEL_PARAMS: dict = {
     "ModelName": "model_name",
@@ -83,12 +84,12 @@ def test_integer_fields(self, serialize, mock_endpoint, mock_endpoint_config, mo
     @mock.patch.object(sagemaker, "serialize", return_value="")
     def test_execute(self, serialize, mock_endpoint, mock_endpoint_config, mock_model, mock_client):
         mock_endpoint.return_value = {"EndpointArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}}
+
         self.sagemaker.execute(None)
+
         mock_model.assert_called_once_with(CREATE_MODEL_PARAMS)
         mock_endpoint_config.assert_called_once_with(CREATE_ENDPOINT_CONFIG_PARAMS)
-        mock_endpoint.assert_called_once_with(
-            CREATE_ENDPOINT_PARAMS, wait_for_completion=False, check_interval=5, max_ingestion_time=None
-        )
+        mock_endpoint.assert_called_once_with(CREATE_ENDPOINT_PARAMS, wait_for_completion=False)
         assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS
         for variant in self.sagemaker.config["EndpointConfig"]["ProductionVariants"]:
             assert variant["InitialInstanceCount"] == int(variant["InitialInstanceCount"])
@@ -120,3 +121,18 @@ def test_execute_with_duplicate_endpoint_creation(
             "ResponseMetadata": {"HTTPStatusCode": 200},
         }
         self.sagemaker.execute(None)
+
+    @mock.patch.object(SageMakerHook, "create_model")
+    @mock.patch.object(SageMakerHook, "create_endpoint_config")
+    @mock.patch.object(SageMakerHook, "create_endpoint")
+    def test_deferred(self, mock_create_endpoint, _, __):
+        self.sagemaker.deferrable = True
+
+        mock_create_endpoint.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}}
+
+        with pytest.raises(TaskDeferred) as defer:
+            self.sagemaker.execute(None)
+
+        assert isinstance(defer.value.trigger, SageMakerTrigger)
+        assert defer.value.trigger.job_name == "endpoint_name"
+        assert defer.value.trigger.job_type == "endpoint"
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py b/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
index 4862b930f1d61..4d6805ec74dbe 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
@@ -21,10 +21,11 @@
 
 import pytest
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators import sagemaker
 from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTuningOperator
+from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
 
 EXPECTED_INTEGER_FIELDS: list[list[str]] = [
     ["HyperParameterTuningJobConfig", "ResourceLimits", "MaxNumberOfTrainingJobs"],
@@ -107,3 +108,15 @@ def test_execute_with_failure(self, mock_tuning, mock_client):
         mock_tuning.return_value = {"TrainingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 404}}
         with pytest.raises(AirflowException):
             self.sagemaker.execute(None)
+
+    @mock.patch.object(SageMakerHook, "create_tuning_job")
+    def test_defers(self, create_mock):
+        create_mock.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}}
+        self.sagemaker.deferrable = True
+
+        with pytest.raises(TaskDeferred) as defer:
+            self.sagemaker.execute(None)
+
+        assert isinstance(defer.value.trigger, SageMakerTrigger)
+        assert defer.value.trigger.job_name == "job_name"
+        assert defer.value.trigger.job_type == "tuning"
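The parametrized trigger test that follows drives `run()` with mocked waiters. As context, a sketch of how the trigger's async generator yields its event when driven by hand (illustrative only: outside the triggerer, and without the mocks, this would make real AWS calls):

```python
# Hedged sketch: in production the triggerer process drives run(); AWS calls
# are NOT mocked here, so treat this as illustrative rather than runnable as a test.
import asyncio

from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger

async def first_event(trigger):
    async for event in trigger.run():
        return event

trigger = SageMakerTrigger(job_name="job_name", job_type="tuning", poke_interval=30, max_attempts=480)
event = asyncio.run(first_event(trigger))
assert event.payload == {"status": "success", "message": "Job completed."}
```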
diff --git a/tests/providers/amazon/aws/triggers/test_sagemaker.py b/tests/providers/amazon/aws/triggers/test_sagemaker.py
index 5a7f8e3c8e30b..f2d05f85a68dc 100644
--- a/tests/providers/amazon/aws/triggers/test_sagemaker.py
+++ b/tests/providers/amazon/aws/triggers/test_sagemaker.py
@@ -49,28 +49,26 @@ def test_sagemaker_trigger_serialize(self):
         assert args["aws_conn_id"] == AWS_CONN_ID
 
     @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "job_type",
+        [
+            "training",
+            "transform",
+            "processing",
+            "tuning",
+            "endpoint",
+        ],
+    )
     @mock.patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.get_waiter")
     @mock.patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.async_conn")
-    @mock.patch("airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger._get_job_type_waiter")
-    @mock.patch(
-        "airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger._get_job_type_waiter_job_name_arg"
-    )
-    async def test_sagemaker_trigger_run(
-        self,
-        mock_get_job_type_waiter_job_name_arg,
-        mock_get_job_type_waiter,
-        mock_async_conn,
-        mock_get_waiter,
-    ):
-        mock_get_job_type_waiter_job_name_arg.return_value = "job_name"
-        mock_get_job_type_waiter.return_value = "waiter"
+    async def test_sagemaker_trigger_run_all_job_types(self, mock_async_conn, mock_get_waiter, job_type):
         mock_async_conn.__aenter__.return_value = mock.MagicMock()
         mock_get_waiter().wait = AsyncMock()
 
         sagemaker_trigger = SageMakerTrigger(
             job_name=JOB_NAME,
-            job_type=JOB_TYPE,
+            job_type=job_type,
             poke_interval=POKE_INTERVAL,
             max_attempts=MAX_ATTEMPTS,
             aws_conn_id=AWS_CONN_ID,
diff --git a/tests/system/providers/amazon/aws/example_sagemaker.py b/tests/system/providers/amazon/aws/example_sagemaker.py
index 9506970446320..2b0f3fc6ef0a5 100644
--- a/tests/system/providers/amazon/aws/example_sagemaker.py
+++ b/tests/system/providers/amazon/aws/example_sagemaker.py
@@ -159,12 +159,11 @@ def _build_and_upload_docker_image(preprocess_script, repository_uri):
     docker_build_and_push_commands = f"""
         cp /root/.aws/credentials /tmp/credentials &&
         # login to public ecr repo containing amazonlinux image
-        docker login --username {creds.username} --password {creds.password} public.ecr.aws
+        docker login --username {creds.username} --password {creds.password} public.ecr.aws &&
        docker build --platform=linux/amd64 -f {dockerfile.name} -t {repository_uri} /tmp &&
        rm /tmp/credentials &&
 
        # login again, this time to the private repo we created to hold that specific image
-        aws ecr get-login-password --region {ecr_region} |
        docker login --username {creds.username} --password {creds.password} {repository_uri} &&
        docker push {repository_uri}
        """
@@ -178,7 +177,8 @@ def _build_and_upload_docker_image(preprocess_script, repository_uri):
     if docker_build.returncode != 0:
         raise RuntimeError(
             "Failed to prepare docker image for the preprocessing job.\n"
-            f"The following error happened while executing the sequence of bash commands:\n{stderr}"
+            "The following error happened while executing the sequence of bash commands:\n"
+            f"{stderr.decode()}"
         )
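The last hunk's `.decode()` matters because `Popen.communicate()` returns `bytes` when the streams are piped; interpolating bytes into an f-string keeps the `b'...'` repr with escaped newlines, which made the error message unreadable. A quick illustration (assumes a POSIX shell):

```python
import subprocess

proc = subprocess.Popen("echo oops >&2", shell=True, stderr=subprocess.PIPE)
_, stderr = proc.communicate()

print(f"{stderr}")           # b'oops\n' -- the bytes repr, unreadable for long output
print(f"{stderr.decode()}")  # oops      -- the actual text, with newlines intact
```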