diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py
index a0fabb324d0ba..15ef379aad7d3 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -19,6 +19,7 @@
 import ast
 import warnings
+from datetime import timedelta
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Sequence
 from uuid import uuid4
@@ -27,7 +28,7 @@
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
 from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri
-from airflow.providers.amazon.aws.triggers.emr import EmrAddStepsTrigger
+from airflow.providers.amazon.aws.triggers.emr import EmrAddStepsTrigger, EmrCreateJobFlowTrigger
 from airflow.providers.amazon.aws.utils.waiter import waiter
 from airflow.utils.helpers import exactly_one, prune_dict
 from airflow.utils.types import NOTSET, ArgNotSet
@@ -624,6 +625,9 @@ class EmrCreateJobFlowOperator(BaseOperator):
         wait_for_completion=True, None = no limit) (Deprecated. Please use waiter_max_attempts.)
     :param waiter_check_interval_seconds: Number of seconds between polling the jobflow state. Defaults
         to 60 seconds. (Deprecated. Please use waiter_delay.)
+    :param deferrable: If True, the operator will wait asynchronously for the job flow to complete.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
     """
 
     template_fields: Sequence[str] = (
@@ -652,6 +656,7 @@ def __init__(
         waiter_delay: int | None | ArgNotSet = NOTSET,
         waiter_countdown: int | None = None,
         waiter_check_interval_seconds: int = 60,
+        deferrable: bool = False,
         **kwargs: Any,
     ):
         if waiter_max_attempts is NOTSET:
@@ -676,10 +681,9 @@ def __init__(
         self.job_flow_overrides = job_flow_overrides or {}
         self.region_name = region_name
         self.wait_for_completion = wait_for_completion
-        self.waiter_max_attempts = waiter_max_attempts
-        self.waiter_delay = waiter_delay
-
-        self._job_flow_id: str | None = None
+        self.waiter_max_attempts = int(waiter_max_attempts)  # type: ignore[arg-type]
+        self.waiter_delay = int(waiter_delay)  # type: ignore[arg-type]
+        self.deferrable = deferrable
 
     @cached_property
     def _emr_hook(self) -> EmrHook:
@@ -720,7 +724,19 @@ def execute(self, context: Context) -> str | None:
             job_flow_id=self._job_flow_id,
             log_uri=get_log_uri(emr_client=self._emr_hook.conn, job_flow_id=self._job_flow_id),
         )
-
+        if self.deferrable:
+            self.defer(
+                trigger=EmrCreateJobFlowTrigger(
+                    job_flow_id=self._job_flow_id,
+                    aws_conn_id=self.aws_conn_id,
+                    poll_interval=self.waiter_delay,
+                    max_attempts=self.waiter_max_attempts,
+                ),
+                method_name="execute_complete",
+                # timeout is set to ensure that if a trigger dies, the timeout does not restart
+                # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
+                timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
+            )
         if self.wait_for_completion:
             self._emr_hook.get_waiter("job_flow_waiting").wait(
                 ClusterId=self._job_flow_id,
@@ -734,6 +750,13 @@ def execute(self, context: Context) -> str | None:
         return self._job_flow_id
 
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error creating jobFlow: {event}")
+        else:
+            self.log.info("JobFlow created successfully")
+            return event["job_flow_id"]
+
     def on_kill(self) -> None:
         """
         Terminate the EMR cluster (job flow).
         If TerminationProtected=True on the cluster,
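For context, a minimal sketch of how the new flag would be used from a DAG. The DAG id, schedule, and the empty ``job_flow_overrides`` are illustrative assumptions, not part of this patch:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.emr import EmrCreateJobFlowOperator

    with DAG("example_emr_deferrable", start_date=datetime(2023, 1, 1), schedule=None):
        create_job_flow = EmrCreateJobFlowOperator(
            task_id="create_job_flow",
            job_flow_overrides={},   # illustrative; a real config names instances, steps, etc.
            waiter_delay=60,         # seconds between polls, reused by the trigger
            waiter_max_attempts=60,  # also bounds the deferral timeout set in execute()
            deferrable=True,         # hand polling off to the triggerer instead of the worker
        )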
diff --git a/airflow/providers/amazon/aws/triggers/emr.py b/airflow/providers/amazon/aws/triggers/emr.py
index 2afc1c45afbdf..76ee47bc8baa6 100644
--- a/airflow/providers/amazon/aws/triggers/emr.py
+++ b/airflow/providers/amazon/aws/triggers/emr.py
@@ -21,8 +21,10 @@
 
 from botocore.exceptions import WaiterError
 
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.emr import EmrHook
 from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.utils.helpers import prune_dict
 
 
 class EmrAddStepsTrigger(BaseTrigger):
@@ -97,3 +99,77 @@ async def run(self):
             yield TriggerEvent({"status": "failure", "message": "Steps failed: max attempts reached"})
         else:
             yield TriggerEvent({"status": "success", "message": "Steps completed", "step_ids": self.step_ids})
+
+
+class EmrCreateJobFlowTrigger(BaseTrigger):
+    """
+    Trigger for EmrCreateJobFlowOperator.
+    The trigger will asynchronously poll the boto3 API and wait for the
+    JobFlow to finish executing.
+
+    :param job_flow_id: The id of the job flow to wait for.
+    :param poll_interval: The amount of time in seconds to wait between attempts.
+    :param max_attempts: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        job_flow_id: str,
+        poll_interval: int,
+        max_attempts: int,
+        aws_conn_id: str,
+    ):
+        self.job_flow_id = job_flow_id
+        self.poll_interval = poll_interval
+        self.max_attempts = max_attempts
+        self.aws_conn_id = aws_conn_id
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            self.__class__.__module__ + "." + self.__class__.__qualname__,
+            {
+                "job_flow_id": self.job_flow_id,
+                "poll_interval": str(self.poll_interval),
+                "max_attempts": str(self.max_attempts),
+                "aws_conn_id": self.aws_conn_id,
+            },
+        )
+
+    async def run(self):
+        self.hook = EmrHook(aws_conn_id=self.aws_conn_id)
+        async with self.hook.async_conn as client:
+            attempt = 0
+            waiter = self.hook.get_waiter("job_flow_waiting", deferrable=True, client=client)
+            while attempt < int(self.max_attempts):
+                attempt = attempt + 1
+                try:
+                    await waiter.wait(
+                        ClusterId=self.job_flow_id,
+                        WaiterConfig=prune_dict(
+                            {
+                                "Delay": self.poll_interval,
+                                "MaxAttempts": 1,
+                            }
+                        ),
+                    )
+                    break
+                except WaiterError as error:
+                    if "terminal failure" in str(error):
+                        raise AirflowException(f"JobFlow creation failed: {error}")
+                    self.log.info(
+                        "Status of jobflow is %s - %s",
+                        error.last_response["Cluster"]["Status"]["State"],
+                        error.last_response["Cluster"]["Status"]["StateChangeReason"],
+                    )
+                    await asyncio.sleep(int(self.poll_interval))
+        if attempt >= int(self.max_attempts):
+            raise AirflowException(f"JobFlow creation failed - max attempts reached: {self.max_attempts}")
+        else:
+            yield TriggerEvent(
+                {
+                    "status": "success",
+                    "message": "JobFlow completed successfully",
+                    "job_flow_id": self.job_flow_id,
+                }
+            )
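A small sketch of the serialize round-trip the triggerer performs on this class. Note that because ``serialize()`` stringifies ``poll_interval`` and ``max_attempts``, ``run()`` defensively wraps them in ``int(...)``; the cluster id below is illustrative:

    from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger

    trigger = EmrCreateJobFlowTrigger(
        job_flow_id="j-EXAMPLE",  # illustrative cluster id
        poll_interval=60,
        max_attempts=20,
        aws_conn_id="aws_default",
    )
    # The triggerer persists this (classpath, kwargs) pair and rebuilds the trigger from it.
    classpath, kwargs = trigger.serialize()
    rehydrated = EmrCreateJobFlowTrigger(**kwargs)
    assert rehydrated.job_flow_id == "j-EXAMPLE"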
"argument": "cluster.status", + "argument": "Cluster.Status.State", "expected": "TERMINATED_WITH_ERRORS", "state": "failure" } diff --git a/docs/apache-airflow-providers-amazon/operators/emr/emr.rst b/docs/apache-airflow-providers-amazon/operators/emr/emr.rst index 6cd628eb239cd..d26a427a64e19 100644 --- a/docs/apache-airflow-providers-amazon/operators/emr/emr.rst +++ b/docs/apache-airflow-providers-amazon/operators/emr/emr.rst @@ -47,6 +47,10 @@ Create an EMR job flow You can use :class:`~airflow.providers.amazon.aws.operators.emr.EmrCreateJobFlowOperator` to create a new EMR job flow. The cluster will be terminated automatically after finishing the steps. +This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. +Using ``deferrable`` mode will release worker slots and leads to efficient utilization of +resources within Airflow cluster.However this mode will need the Airflow triggerer to be +available in your deployment. JobFlow configuration """"""""""""""""""""" diff --git a/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py b/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py index 6be459008e0bb..7e531f4a0f183 100644 --- a/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py +++ b/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py @@ -22,12 +22,15 @@ from unittest import mock from unittest.mock import MagicMock, patch +import pytest from botocore.waiter import Waiter from jinja2 import StrictUndefined +from airflow.exceptions import TaskDeferred from airflow.models import DAG, DagRun, TaskInstance from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.aws.operators.emr import EmrCreateJobFlowOperator +from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger from airflow.utils import timezone from tests.providers.amazon.aws.utils.test_waiter import assert_expected_waiter_type from tests.test_utils import AIRFLOW_MAIN_FOLDER @@ -192,3 +195,28 @@ def test_execute_with_wait(self, mock_waiter, *_): assert self.operator.execute(self.mock_context) == JOB_FLOW_ID mock_waiter.assert_called_once_with(mock.ANY, ClusterId=JOB_FLOW_ID, WaiterConfig=mock.ANY) assert_expected_waiter_type(mock_waiter, "job_flow_waiting") + + @patch.object(S3Hook, "parse_s3_url", return_value="valid_uri") + def test_create_job_flow_deferrable(self, _): + """ + Test to make sure that the operator raises a TaskDeferred exception + if run in deferrable mode. 
+ """ + self.emr_client_mock.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN + + # Mock out the emr_client creator + emr_session_mock = MagicMock() + emr_session_mock.client.return_value = self.emr_client_mock + boto3_session_mock = MagicMock(return_value=emr_session_mock) + + self.operator.deferrable = True + with patch("boto3.session.Session", boto3_session_mock), patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True + with pytest.raises(TaskDeferred) as exc: + self.operator.execute(self.mock_context) + + assert isinstance( + exc.value.trigger, EmrCreateJobFlowTrigger + ), "Trigger is not a EmrCreateJobFlowTrigger" diff --git a/tests/providers/amazon/aws/triggers/test_emr.py b/tests/providers/amazon/aws/triggers/test_emr.py new file mode 100644 index 0000000000000..c749c4ee9abcf --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_emr.py @@ -0,0 +1,200 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock +from unittest.mock import AsyncMock + +import pytest +from botocore.exceptions import WaiterError + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.emr import EmrHook +from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger +from airflow.triggers.base import TriggerEvent + +TEST_JOB_FLOW_ID = "test-job-flow-id" +TEST_POLL_INTERVAL = 10 +TEST_MAX_ATTEMPTS = 10 +TEST_AWS_CONN_ID = "test-aws-id" + + +class TestEmrCreateJobFlowTrigger: + def test_emr_create_job_flow_trigger_serialize(self): + """Test serialize method to make sure all parameters are being serialized correctly.""" + emr_create_job_flow_trigger = EmrCreateJobFlowTrigger( + job_flow_id=TEST_JOB_FLOW_ID, + aws_conn_id=TEST_AWS_CONN_ID, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPTS, + ) + class_path, args = emr_create_job_flow_trigger.serialize() + assert class_path == "airflow.providers.amazon.aws.triggers.emr.EmrCreateJobFlowTrigger" + assert args["job_flow_id"] == TEST_JOB_FLOW_ID + assert args["aws_conn_id"] == TEST_AWS_CONN_ID + assert args["poll_interval"] == str(TEST_POLL_INTERVAL) + assert args["max_attempts"] == str(TEST_MAX_ATTEMPTS) + + @pytest.mark.asyncio + @mock.patch.object(EmrHook, "get_waiter") + @mock.patch.object(EmrHook, "async_conn") + async def test_emr_create_job_flow_trigger_run(self, mock_async_conn, mock_get_waiter): + """ + Test run method, with basic success case to assert TriggerEvent contains the + correct payload. 
+ """ + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + + mock_get_waiter().wait = AsyncMock() + + emr_create_job_flow_trigger = EmrCreateJobFlowTrigger( + job_flow_id=TEST_JOB_FLOW_ID, + aws_conn_id=TEST_AWS_CONN_ID, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPTS, + ) + + generator = emr_create_job_flow_trigger.run() + response = await generator.asend(None) + + assert response == TriggerEvent( + { + "status": "success", + "message": "JobFlow completed successfully", + "job_flow_id": TEST_JOB_FLOW_ID, + } + ) + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch.object(EmrHook, "get_waiter") + @mock.patch.object(EmrHook, "async_conn") + async def test_emr_create_job_flow_trigger_run_multiple_attempts( + self, mock_async_conn, mock_get_waiter, mock_sleep + ): + """ + Test run method with multiple attempts to make sure the waiter retries + are working as expected. + """ + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + error = WaiterError( + name="test_name", + reason="test_reason", + last_response={"Cluster": {"Status": {"State": "STARTING", "StateChangeReason": "test-reason"}}}, + ) + mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True]) + mock_sleep.return_value = True + + emr_create_job_flow_trigger = EmrCreateJobFlowTrigger( + job_flow_id=TEST_JOB_FLOW_ID, + aws_conn_id=TEST_AWS_CONN_ID, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPTS, + ) + + generator = emr_create_job_flow_trigger.run() + response = await generator.asend(None) + + assert mock_get_waiter().wait.call_count == 3 + assert response == TriggerEvent( + { + "status": "success", + "message": "JobFlow completed successfully", + "job_flow_id": TEST_JOB_FLOW_ID, + } + ) + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch.object(EmrHook, "get_waiter") + @mock.patch.object(EmrHook, "async_conn") + async def test_emr_create_job_flow_trigger_run_attempts_exceeded( + self, mock_async_conn, mock_get_waiter, mock_sleep + ): + """ + Test run method with max_attempts set to 2 to test the Trigger yields + the correct TriggerEvent in the case of max_attempts being exceeded. + """ + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + error = WaiterError( + name="test_name", + reason="test_reason", + last_response={"Cluster": {"Status": {"State": "STARTING", "StateChangeReason": "test-reason"}}}, + ) + mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True]) + mock_sleep.return_value = True + + emr_create_job_flow_trigger = EmrCreateJobFlowTrigger( + job_flow_id=TEST_JOB_FLOW_ID, + aws_conn_id=TEST_AWS_CONN_ID, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=2, + ) + + with pytest.raises(AirflowException) as exc: + generator = emr_create_job_flow_trigger.run() + await generator.asend(None) + + assert str(exc.value) == "JobFlow creation failed - max attempts reached: 2" + assert mock_get_waiter().wait.call_count == 2 + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch.object(EmrHook, "get_waiter") + @mock.patch.object(EmrHook, "async_conn") + async def test_emr_create_job_flow_trigger_run_attempts_failed( + self, mock_async_conn, mock_get_waiter, mock_sleep + ): + """ + Test run method with a failure case to test Trigger yields the correct + failure TriggerEvent. 
+ """ + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + error_starting = WaiterError( + name="test_name", + reason="test_reason", + last_response={"Cluster": {"Status": {"State": "STARTING", "StateChangeReason": "test-reason"}}}, + ) + error_failed = WaiterError( + name="test_name", + reason="Waiter encountered a terminal failure state:", + last_response={ + "Cluster": {"Status": {"State": "TERMINATED_WITH_ERRORS", "StateChangeReason": "test-reason"}} + }, + ) + mock_get_waiter().wait.side_effect = AsyncMock( + side_effect=[error_starting, error_starting, error_failed] + ) + mock_sleep.return_value = True + + emr_create_job_flow_trigger = EmrCreateJobFlowTrigger( + job_flow_id=TEST_JOB_FLOW_ID, + aws_conn_id=TEST_AWS_CONN_ID, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPTS, + ) + + with pytest.raises(AirflowException) as exc: + generator = emr_create_job_flow_trigger.run() + await generator.asend(None) + + assert str(exc.value) == f"JobFlow creation failed: {error_failed}" + assert mock_get_waiter().wait.call_count == 3