diff --git a/airflow/providers/amazon/aws/sensors/batch.py b/airflow/providers/amazon/aws/sensors/batch.py
index c93fc3d8b3140..475b0ecb71bb6 100644
--- a/airflow/providers/amazon/aws/sensors/batch.py
+++ b/airflow/providers/amazon/aws/sensors/batch.py
@@ -16,13 +16,15 @@
 # under the License.
 from __future__ import annotations

+from datetime import timedelta
 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence

 from deprecated import deprecated

 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
+from airflow.providers.amazon.aws.triggers.batch import BatchSensorTrigger
 from airflow.sensors.base import BaseSensorOperator

 if TYPE_CHECKING:
@@ -41,6 +43,10 @@ class BatchSensor(BaseSensorOperator):
     :param job_id: Batch job_id to check the state for
     :param aws_conn_id: aws connection to use, defaults to 'aws_default'
     :param region_name: aws region name associated with the client
+    :param deferrable: Run the sensor in deferrable mode.
+    :param poke_interval: polling period in seconds to check for the status of the job.
+    :param max_retries: Number of times to poll for job state before
+        returning the current state.
     """

     template_fields: Sequence[str] = ("job_id",)
@@ -53,12 +59,18 @@ def __init__(
         job_id: str,
         aws_conn_id: str = "aws_default",
         region_name: str | None = None,
+        deferrable: bool = False,
+        poke_interval: float = 5,
+        max_retries: int = 5,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.job_id = job_id
         self.aws_conn_id = aws_conn_id
         self.region_name = region_name
+        self.deferrable = deferrable
+        self.poke_interval = poke_interval
+        self.max_retries = max_retries

     def poke(self, context: Context) -> bool:
         job_description = self.hook.get_job_description(self.job_id)
@@ -75,6 +87,36 @@ def poke(self, context: Context) -> bool:

         raise AirflowException(f"Batch sensor failed. Unknown AWS Batch job status: {state}")

+    def execute(self, context: Context) -> None:
+        if not self.deferrable:
+            super().execute(context=context)
+        else:
+            timeout = (
+                timedelta(seconds=self.max_retries * self.poke_interval + 60)
+                if self.max_retries
+                else self.execution_timeout
+            )
+            self.defer(
+                timeout=timeout,
+                trigger=BatchSensorTrigger(
+                    job_id=self.job_id,
+                    aws_conn_id=self.aws_conn_id,
+                    region_name=self.region_name,
+                    poke_interval=self.poke_interval,
+                ),
+                method_name="execute_complete",
+            )
+
+    def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
+        """
+        Callback for when the trigger fires; returns immediately.
+        Relies on the trigger to raise an exception, otherwise it assumes
+        execution was successful.
+ """ + if "status" in event and event["status"] == "failure": + raise AirflowException(event["message"]) + self.log.info(event["message"]) + @deprecated(reason="use `hook` property instead.") def get_hook(self) -> BatchClientHook: """Create and return a BatchClientHook.""" diff --git a/airflow/providers/amazon/aws/triggers/batch.py b/airflow/providers/amazon/aws/triggers/batch.py index dc858a80fd710..f4a5de15254fa 100644 --- a/airflow/providers/amazon/aws/triggers/batch.py +++ b/airflow/providers/amazon/aws/triggers/batch.py @@ -105,3 +105,86 @@ async def run(self): yield TriggerEvent({"status": "failure", "message": "Job Failed - max attempts reached."}) else: yield TriggerEvent({"status": "success", "job_id": self.job_id}) + + +class BatchSensorTrigger(BaseTrigger): + """ + Checks for the status of a submitted job_id to AWS Batch until it reaches a failure or a success state. + BatchSensorTrigger is fired as deferred class with params to poll the job state in Triggerer. + + :param job_id: the job ID, to poll for job completion or not + :param region_name: AWS region name to use + Override the region_name in connection (if provided) + :param aws_conn_id: connection id of AWS credentials / region name. If None, + credential boto3 strategy will be used + :param poke_interval: polling period in seconds to check for the status of the job + """ + + def __init__( + self, + job_id: str, + region_name: str | None, + aws_conn_id: str | None = "aws_default", + poke_interval: float = 5, + ): + super().__init__() + self.job_id = job_id + self.aws_conn_id = aws_conn_id + self.region_name = region_name + self.poke_interval = poke_interval + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serializes BatchSensorTrigger arguments and classpath.""" + return ( + "airflow.providers.amazon.aws.triggers.batch.BatchSensorTrigger", + { + "job_id": self.job_id, + "aws_conn_id": self.aws_conn_id, + "region_name": self.region_name, + "poke_interval": self.poke_interval, + }, + ) + + @cached_property + def hook(self) -> BatchClientHook: + return BatchClientHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) + + async def run(self): + """ + Make async connection using aiobotocore library to AWS Batch, + periodically poll for the Batch job status. + + The status that indicates job completion are: 'SUCCEEDED'|'FAILED'. + """ + async with self.hook.async_conn as client: + waiter = self.hook.get_waiter("batch_job_complete", deferrable=True, client=client) + attempt = 0 + while True: + attempt = attempt + 1 + try: + await waiter.wait( + jobs=[self.job_id], + WaiterConfig={ + "Delay": int(self.poke_interval), + "MaxAttempts": 1, + }, + ) + break + except WaiterError as error: + if "error" in str(error): + yield TriggerEvent({"status": "failure", "message": f"Job Failed: {error}"}) + break + self.log.info( + "Job response is %s. Retrying attempt %s", + error.last_response["Error"]["Message"], + attempt, + ) + await asyncio.sleep(int(self.poke_interval)) + + yield TriggerEvent( + { + "status": "success", + "job_id": self.job_id, + "message": f"Job {self.job_id} Succeeded", + } + ) diff --git a/docs/apache-airflow-providers-amazon/operators/batch.rst b/docs/apache-airflow-providers-amazon/operators/batch.rst index bcfb86dbf73f6..4cc2a2b0cced4 100644 --- a/docs/apache-airflow-providers-amazon/operators/batch.rst +++ b/docs/apache-airflow-providers-amazon/operators/batch.rst @@ -77,6 +77,15 @@ use :class:`~airflow.providers.amazon.aws.sensors.batch.BatchSensor`. 
     :start-after: [START howto_sensor_batch]
     :end-before: [END howto_sensor_batch]

+To monitor the state of the AWS Batch job asynchronously, use
+:class:`~airflow.providers.amazon.aws.sensors.batch.BatchSensor` with the
+parameter ``deferrable`` set to ``True``.
+
+Since this releases the Airflow worker slot while waiting, it leads to more
+efficient utilization of the resources in your Airflow deployment.
+Note that deferrable mode requires the triggerer component to be available
+in your Airflow deployment.
+
 .. _howto/sensor:BatchComputeEnvironmentSensor:

 Wait on an AWS Batch compute environment status
diff --git a/tests/providers/amazon/aws/sensors/test_batch.py b/tests/providers/amazon/aws/sensors/test_batch.py
index 835b99ad0a5c2..42e9bffb5b688 100644
--- a/tests/providers/amazon/aws/sensors/test_batch.py
+++ b/tests/providers/amazon/aws/sensors/test_batch.py
@@ -20,16 +20,18 @@

 import pytest

-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
 from airflow.providers.amazon.aws.sensors.batch import (
     BatchComputeEnvironmentSensor,
     BatchJobQueueSensor,
     BatchSensor,
 )
+from airflow.providers.amazon.aws.triggers.batch import BatchSensorTrigger

 TASK_ID = "batch_job_sensor"
 JOB_ID = "8222a1c2-b246-4e19-b1b8-0039bb4407c0"
+AWS_REGION = "eu-west-1"


 class TestBatchSensor:
@@ -195,3 +197,39 @@ def test_poke_invalid(self, mock_batch_client):
             jobQueues=[self.job_queue],
         )
         assert "AWS Batch job queue failed" in str(ctx.value)
+
+
+class TestBatchAsyncSensor:
+    TASK = BatchSensor(task_id="task", job_id=JOB_ID, region_name=AWS_REGION, deferrable=True)
+
+    def test_batch_sensor_async(self):
+        """
+        Asserts that a task is deferred and a BatchSensorTrigger will be fired
+        when the BatchSensor is executed in deferrable mode.
+ """ + + with pytest.raises(TaskDeferred) as exc: + self.TASK.execute({}) + assert isinstance(exc.value.trigger, BatchSensorTrigger), "Trigger is not a BatchSensorTrigger" + + def test_batch_sensor_async_execute_failure(self): + """Tests that an AirflowException is raised in case of error event""" + + with pytest.raises(AirflowException) as exc_info: + self.TASK.execute_complete( + context={}, event={"status": "failure", "message": "test failure message"} + ) + + assert str(exc_info.value) == "test failure message" + + @pytest.mark.parametrize( + "event", + [{"status": "success", "message": f"AWS Batch job ({JOB_ID}) succeeded"}], + ) + def test_batch_sensor_async_execute_complete(self, caplog, event): + """Tests that execute_complete method returns None and that it prints expected log""" + + with mock.patch.object(self.TASK.log, "info") as mock_log_info: + assert self.TASK.execute_complete(context={}, event=event) is None + + mock_log_info.assert_called_with(event["message"]) diff --git a/tests/providers/amazon/aws/triggers/test_batch.py b/tests/providers/amazon/aws/triggers/test_batch.py index 6f87d92a2da6a..5cf125f8280a5 100644 --- a/tests/providers/amazon/aws/triggers/test_batch.py +++ b/tests/providers/amazon/aws/triggers/test_batch.py @@ -20,8 +20,9 @@ from unittest.mock import AsyncMock import pytest +from botocore.exceptions import WaiterError -from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger +from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger, BatchSensorTrigger from airflow.triggers.base import TriggerEvent BATCH_JOB_ID = "job_id" @@ -29,6 +30,7 @@ MAX_ATTEMPT = 5 AWS_CONN_ID = "aws_batch_job_conn" AWS_REGION = "us-east-2" +pytest.importorskip("aiobotocore") class TestBatchOperatorTrigger: @@ -69,3 +71,113 @@ async def test_batch_job_trigger_run(self, mock_async_conn, mock_get_waiter): response = await generator.asend(None) assert response == TriggerEvent({"status": "success", "job_id": BATCH_JOB_ID}) + + +class TestBatchSensorTrigger: + TRIGGER = BatchSensorTrigger( + job_id=BATCH_JOB_ID, + region_name=AWS_REGION, + aws_conn_id=AWS_CONN_ID, + poke_interval=POLL_INTERVAL, + ) + + def test_batch_sensor_trigger_serialization(self): + """ + Asserts that the BatchSensorTrigger correctly serializes its arguments + and classpath. 
+ """ + + classpath, kwargs = self.TRIGGER.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.batch.BatchSensorTrigger" + assert kwargs == { + "job_id": BATCH_JOB_ID, + "region_name": AWS_REGION, + "aws_conn_id": AWS_CONN_ID, + "poke_interval": POLL_INTERVAL, + } + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.async_conn") + async def test_batch_job_trigger_run(self, mock_async_conn, mock_get_waiter): + the_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = the_mock + + mock_get_waiter().wait = AsyncMock() + + batch_trigger = BatchOperatorTrigger( + job_id=BATCH_JOB_ID, + poll_interval=POLL_INTERVAL, + max_retries=MAX_ATTEMPT, + aws_conn_id=AWS_CONN_ID, + region_name=AWS_REGION, + ) + + generator = batch_trigger.run() + response = await generator.asend(None) + + assert response == TriggerEvent({"status": "success", "job_id": BATCH_JOB_ID}) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.async_conn") + @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_job_description") + async def test_batch_sensor_trigger_completed(self, mock_response, mock_async_conn, mock_get_waiter): + """Test if the success event is returned from trigger.""" + mock_response.return_value = {"status": "SUCCEEDED"} + + the_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = the_mock + + mock_get_waiter().wait = AsyncMock() + + trigger = BatchSensorTrigger( + job_id=BATCH_JOB_ID, + region_name=AWS_REGION, + aws_conn_id=AWS_CONN_ID, + ) + generator = trigger.run() + actual_response = await generator.asend(None) + assert ( + TriggerEvent( + {"status": "success", "job_id": BATCH_JOB_ID, "message": f"Job {BATCH_JOB_ID} Succeeded"} + ) + == actual_response + ) + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_job_description") + @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.async_conn") + async def test_batch_sensor_trigger_failure( + self, mock_async_conn, mock_response, mock_get_waiter, mock_sleep + ): + """Test if the failure event is returned from trigger.""" + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + + mock_response.return_value = {"status": "failed"} + + name = "batch_job_complete" + reason = ( + "An error occurred (UnrecognizedClientException): The security token included in the " + "request is invalid. " + ) + last_response = ({"Error": {"Message": "The security token included in the request is invalid."}},) + + error_failed = WaiterError( + name=name, + reason=reason, + last_response=last_response, + ) + + mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error_failed]) + mock_sleep.return_value = True + + trigger = BatchSensorTrigger(job_id=BATCH_JOB_ID, region_name=AWS_REGION, aws_conn_id=AWS_CONN_ID) + generator = trigger.run() + actual_response = await generator.asend(None) + assert actual_response == TriggerEvent( + {"status": "failure", "message": f"Job Failed: Waiter {name} failed: {reason}"} + )