diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
index 48e8773c16a09..e4cd06c005fa3 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
@@ -1150,6 +1150,9 @@ class EmrServerlessStartJobOperator(AwsBaseOperator[EmrServerlessHook]):
     :param enable_application_ui_links: If True, the operator will generate one-time links to EMR Serverless
         application UIs. The generated links will allow any user with access to the DAG to see the Spark or
         Tez UI or Spark stdout logs. Defaults to False.
+    :param cancel_on_kill: If True, the EMR Serverless job will be cancelled when the task is killed
+        while in deferrable mode. This ensures that orphan jobs are not left running in EMR Serverless
+        when an Airflow task is cancelled. Defaults to True.
     """
 
     aws_hook_class = EmrServerlessHook
@@ -1188,6 +1191,7 @@ def __init__(
         waiter_delay: int | ArgNotSet = NOTSET,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         enable_application_ui_links: bool = False,
+        cancel_on_kill: bool = True,
         **kwargs,
     ):
         waiter_delay = 60 if waiter_delay is NOTSET else waiter_delay
@@ -1205,6 +1209,7 @@ def __init__(
         self.job_id: str | None = None
         self.deferrable = deferrable
         self.enable_application_ui_links = enable_application_ui_links
+        self.cancel_on_kill = cancel_on_kill
         super().__init__(**kwargs)
 
         self.client_request_token = client_request_token or str(uuid4())
@@ -1269,6 +1274,7 @@ def execute(self, context: Context, event: dict[str, Any] | None = None) -> str
                     waiter_delay=self.waiter_delay,
                     waiter_max_attempts=self.waiter_max_attempts,
                     aws_conn_id=self.aws_conn_id,
+                    cancel_on_kill=self.cancel_on_kill,
                 ),
                 method_name="execute_complete",
                 timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
@@ -1320,7 +1326,8 @@ def on_kill(self) -> None:
         """
         Cancel the submitted job run.
 
-        Note: this method will not run in deferrable mode.
+        Note: In deferrable mode, this method will not run. Instead, job cancellation
+        is handled by the trigger's cancel_on_kill parameter when the task is killed.
         """
         if self.job_id:
             self.log.info("Stopping job run with jobId - %s", self.job_id)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py
index 3e24fc2b6d160..d8f7b2000238b 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py
@@ -16,15 +16,29 @@
 # under the License.
 from __future__ import annotations
 
+import asyncio
 import sys
+from collections.abc import AsyncIterator
 from typing import TYPE_CHECKING
 
+from asgiref.sync import sync_to_async
+
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
 from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
+from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
+from airflow.triggers.base import TriggerEvent
+from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
+    from sqlalchemy.orm.session import Session
+
     from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
 
+if not AIRFLOW_V_3_0_PLUS:
+    from airflow.models.taskinstance import TaskInstance
+    from airflow.utils.session import provide_session
+
 
 class EmrAddStepsTrigger(AwsBaseWaiterTrigger):
     """
@@ -331,9 +345,10 @@ class EmrServerlessStartJobTrigger(AwsBaseWaiterTrigger):
 
-    :param application_id: The ID of the application the job in being run on.
+    :param application_id: The ID of the application the job is being run on.
     :param job_id: The ID of the job run.
-    :waiter_delay: polling period in seconds to check for the status
+    :param waiter_delay: polling period in seconds to check for the status
     :param waiter_max_attempts: The maximum number of attempts to be made
     :param aws_conn_id: Reference to AWS connection id
+    :param cancel_on_kill: Flag to indicate whether to cancel the job when the task is killed.
     """
 
     def __init__(
@@ -343,9 +358,14 @@ def __init__(
         waiter_delay: int = 30,
         waiter_max_attempts: int = 60,
         aws_conn_id: str | None = "aws_default",
+        cancel_on_kill: bool = True,
     ) -> None:
         super().__init__(
-            serialized_fields={"application_id": application_id, "job_id": job_id},
+            serialized_fields={
+                "application_id": application_id,
+                "job_id": job_id,
+                "cancel_on_kill": cancel_on_kill,
+            },
             waiter_name="serverless_job_completed",
             waiter_args={"applicationId": application_id, "jobRunId": job_id},
             failure_message="Serverless Job failed",
@@ -357,10 +377,140 @@ def __init__(
             waiter_max_attempts=waiter_max_attempts,
             aws_conn_id=aws_conn_id,
         )
+        self.application_id = application_id
+        self.job_id = job_id
+        self.cancel_on_kill = cancel_on_kill
 
     def hook(self) -> AwsGenericHook:
         return EmrServerlessHook(self.aws_conn_id)
 
+    if not AIRFLOW_V_3_0_PLUS:
+
+        @provide_session
+        def get_task_instance(self, session: Session) -> TaskInstance:
+            """Get the task instance for the current trigger (Airflow 2.x compatibility)."""
+            from sqlalchemy import select
+
+            query = select(TaskInstance).where(
+                TaskInstance.dag_id == self.task_instance.dag_id,
+                TaskInstance.task_id == self.task_instance.task_id,
+                TaskInstance.run_id == self.task_instance.run_id,
+                TaskInstance.map_index == self.task_instance.map_index,
+            )
+            task_instance = session.scalars(query).one_or_none()
+            if task_instance is None:
+                raise ValueError(
+                    f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+                    f"task_id: {self.task_instance.task_id}, "
+                    f"run_id: {self.task_instance.run_id} and "
+                    f"map_index: {self.task_instance.map_index} is not found"
+                )
+            return task_instance
+
+    async def get_task_state(self):
+        """Get the current state of the task instance (Airflow 3.x)."""
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
+            dag_id=self.task_instance.dag_id,
+            task_ids=[self.task_instance.task_id],
+            run_ids=[self.task_instance.run_id],
+            map_index=self.task_instance.map_index,
+        )
+        try:
+            task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+        except Exception as err:
+            raise ValueError(
+                f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+                f"task_id: {self.task_instance.task_id}, "
+                f"run_id: {self.task_instance.run_id} and "
+                f"map_index: {self.task_instance.map_index} is not found"
+            ) from err
+        return task_state
+
+    async def safe_to_cancel(self) -> bool:
+        """
+        Whether it is safe to cancel the EMR Serverless job.
+
+        Returns True if task is NOT DEFERRED (user-initiated cancellation).
+        Returns False if task is DEFERRED (triggerer restart - don't cancel job).
+        """
+        if AIRFLOW_V_3_0_PLUS:
+            task_state = await self.get_task_state()
+        else:
+            task_instance = self.get_task_instance()  # type: ignore[call-arg]
+            task_state = task_instance.state
+        return task_state != TaskInstanceState.DEFERRED
+
+    async def _cancel_job_on_kill(self, hook: AwsGenericHook) -> None:
+        """
+        Cancel the EMR Serverless job run when the task itself was killed.
+
+        Nothing is cancelled when cancel_on_kill is disabled or when the task is still
+        DEFERRED, which means the triggerer is shutting down or restarting.
+        """
+        if self.job_id and self.cancel_on_kill and await self.safe_to_cancel():
+            self.log.info(
+                "Task was cancelled. Cancelling EMR Serverless job. Application ID: %s, Job ID: %s",
+                self.application_id,
+                self.job_id,
+            )
+            # boto3 is blocking: build the client and issue the call in a worker thread
+            # so the triggerer's event loop is not stalled while AWS responds.
+            await sync_to_async(
+                lambda: hook.conn.cancel_job_run(applicationId=self.application_id, jobRunId=self.job_id)
+            )()
+            self.log.info("EMR Serverless job %s cancelled successfully.", self.job_id)
+        else:
+            self.log.info(
+                "Trigger may have shutdown or cancel_on_kill is disabled. "
+                "Skipping job cancellation. Application ID: %s, Job ID: %s",
+                self.application_id,
+                self.job_id,
+            )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """
+        Run the trigger and wait for the job to complete.
+
+        If the task is cancelled while waiting, attempt to cancel the EMR Serverless job
+        if cancel_on_kill is enabled and it's safe to do so. Cancellation is best-effort:
+        a failure to cancel is logged and the CancelledError is still re-raised.
+        """
+        hook = self.hook()
+        try:
+            async with await hook.get_async_conn() as client:
+                waiter = hook.get_waiter(
+                    self.waiter_name,
+                    deferrable=True,
+                    client=client,
+                    config_overrides=self.waiter_config_overrides,
+                )
+                await async_wait(
+                    waiter,
+                    self.waiter_delay,
+                    self.attempts,
+                    self.waiter_args,
+                    self.failure_message,
+                    self.status_message,
+                    self.status_queries,
+                )
+            yield TriggerEvent({"status": "success", self.return_key: self.return_value})
+        except asyncio.CancelledError:
+            try:
+                await self._cancel_job_on_kill(hook)
+            except Exception:
+                # Cleanup must never replace the CancelledError: a failed state lookup or
+                # cancel call is logged and the cancellation still propagates below.
+                self.log.exception(
+                    "Failed to cancel EMR Serverless job. Application ID: %s, Job ID: %s",
+                    self.application_id,
+                    self.job_id,
+                )
+            raise
+        except Exception as e:
+            yield TriggerEvent({"status": "failure", "message": str(e)})
+
 
 class EmrServerlessDeleteApplicationTrigger(AwsBaseWaiterTrigger):
     """
diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_emr.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_emr.py
index eb7f1851155ac..fef643dc92776 100644
--- a/providers/amazon/tests/unit/amazon/aws/triggers/test_emr.py
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_emr.py
@@ -16,7 +16,11 @@
 # under the License.
 from __future__ import annotations
 
+import asyncio
 import sys
+from unittest import mock
+
+import pytest
 
 from airflow.providers.amazon.aws.triggers.emr import (
     EmrAddStepsTrigger,
@@ -269,8 +273,175 @@ def test_serialization(self):
             "waiter_max_attempts": 60,
             "job_id": "job_id",
             "aws_conn_id": "aws_default",
+            "cancel_on_kill": True,
         }
 
+    def test_serialization_cancel_on_kill_false(self):
+        """Test that cancel_on_kill=False is correctly serialized."""
+        trigger = EmrServerlessStartJobTrigger(
+            application_id="test_app",
+            job_id="test_job",
+            waiter_delay=30,
+            waiter_max_attempts=60,
+            aws_conn_id="aws_default",
+            cancel_on_kill=False,
+        )
+        classpath, kwargs = trigger.serialize()
+        assert classpath == "airflow.providers.amazon.aws.triggers.emr.EmrServerlessStartJobTrigger"
+        assert kwargs["cancel_on_kill"] is False
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+    @mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrServerlessStartJobTrigger.safe_to_cancel")
+    async def test_emr_serverless_trigger_cancellation(self, mock_safe_to_cancel, mock_async_wait):
+        """
+        Test that EmrServerlessStartJobTrigger cancels the job when task is killed
+        and safe_to_cancel returns True.
+        """
+        mock_safe_to_cancel.return_value = True
+        mock_async_wait.side_effect = asyncio.CancelledError()
+
+        trigger = EmrServerlessStartJobTrigger(
+            application_id="test_app",
+            job_id="test_job",
+            waiter_delay=30,
+            waiter_max_attempts=60,
+            aws_conn_id="aws_default",
+            cancel_on_kill=True,
+        )
+
+        mock_hook = mock.MagicMock()
+        mock_hook.get_waiter.return_value = mock.MagicMock()
+        mock_hook.conn.cancel_job_run.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}}
+
+        mock_client = mock.MagicMock()
+        mock_async_cm = mock.MagicMock()
+        mock_async_cm.__aenter__ = mock.AsyncMock(return_value=mock_client)
+        mock_async_cm.__aexit__ = mock.AsyncMock(return_value=None)
+        mock_hook.get_async_conn = mock.AsyncMock(return_value=mock_async_cm)
+
+        with mock.patch.object(trigger, "hook", return_value=mock_hook):
+            with pytest.raises(asyncio.CancelledError):
+                async for _ in trigger.run():
+                    pass
+
+        mock_hook.conn.cancel_job_run.assert_called_once_with(applicationId="test_app", jobRunId="test_job")
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+    @mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrServerlessStartJobTrigger.safe_to_cancel")
+    async def test_emr_serverless_trigger_no_cancellation_when_unsafe(
+        self, mock_safe_to_cancel, mock_async_wait
+    ):
+        """
+        Test that EmrServerlessStartJobTrigger does NOT cancel the job when
+        safe_to_cancel returns False (e.g., triggerer shutdown).
+        """
+        mock_safe_to_cancel.return_value = False
+        mock_async_wait.side_effect = asyncio.CancelledError()
+
+        trigger = EmrServerlessStartJobTrigger(
+            application_id="test_app",
+            job_id="test_job",
+            waiter_delay=30,
+            waiter_max_attempts=60,
+            aws_conn_id="aws_default",
+            cancel_on_kill=True,
+        )
+
+        mock_hook = mock.MagicMock()
+        mock_hook.get_waiter.return_value = mock.MagicMock()
+
+        mock_client = mock.MagicMock()
+        mock_async_cm = mock.MagicMock()
+        mock_async_cm.__aenter__ = mock.AsyncMock(return_value=mock_client)
+        mock_async_cm.__aexit__ = mock.AsyncMock(return_value=None)
+        mock_hook.get_async_conn = mock.AsyncMock(return_value=mock_async_cm)
+
+        with mock.patch.object(trigger, "hook", return_value=mock_hook):
+            with pytest.raises(asyncio.CancelledError):
+                async for _ in trigger.run():
+                    pass
+
+        mock_hook.conn.cancel_job_run.assert_not_called()
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+    @mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrServerlessStartJobTrigger.safe_to_cancel")
+    async def test_emr_serverless_trigger_no_cancellation_when_disabled(
+        self, mock_safe_to_cancel, mock_async_wait
+    ):
+        """
+        Test that EmrServerlessStartJobTrigger does NOT cancel the job when
+        cancel_on_kill=False.
+        """
+        mock_safe_to_cancel.return_value = True
+        mock_async_wait.side_effect = asyncio.CancelledError()
+
+        trigger = EmrServerlessStartJobTrigger(
+            application_id="test_app",
+            job_id="test_job",
+            waiter_delay=30,
+            waiter_max_attempts=60,
+            aws_conn_id="aws_default",
+            cancel_on_kill=False,  # Disabled
+        )
+
+        mock_hook = mock.MagicMock()
+        mock_hook.get_waiter.return_value = mock.MagicMock()
+
+        mock_client = mock.MagicMock()
+        mock_async_cm = mock.MagicMock()
+        mock_async_cm.__aenter__ = mock.AsyncMock(return_value=mock_client)
+        mock_async_cm.__aexit__ = mock.AsyncMock(return_value=None)
+        mock_hook.get_async_conn = mock.AsyncMock(return_value=mock_async_cm)
+
+        with mock.patch.object(trigger, "hook", return_value=mock_hook):
+            with pytest.raises(asyncio.CancelledError):
+                async for _ in trigger.run():
+                    pass
+
+        mock_hook.conn.cancel_job_run.assert_not_called()
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+    @mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrServerlessStartJobTrigger.safe_to_cancel")
+    async def test_emr_serverless_trigger_cancellation_failure_keeps_cancelled_error(
+        self, mock_safe_to_cancel, mock_async_wait
+    ):
+        """
+        Test that a failing cancel_job_run call does NOT replace the CancelledError
+        raised by EmrServerlessStartJobTrigger.
+        """
+        mock_safe_to_cancel.return_value = True
+        mock_async_wait.side_effect = asyncio.CancelledError()
+
+        trigger = EmrServerlessStartJobTrigger(
+            application_id="test_app",
+            job_id="test_job",
+            waiter_delay=30,
+            waiter_max_attempts=60,
+            aws_conn_id="aws_default",
+            cancel_on_kill=True,
+        )
+
+        mock_hook = mock.MagicMock()
+        mock_hook.get_waiter.return_value = mock.MagicMock()
+        mock_hook.conn.cancel_job_run.side_effect = RuntimeError("cancel failed")
+
+        mock_client = mock.MagicMock()
+        mock_async_cm = mock.MagicMock()
+        mock_async_cm.__aenter__ = mock.AsyncMock(return_value=mock_client)
+        mock_async_cm.__aexit__ = mock.AsyncMock(return_value=None)
+        mock_hook.get_async_conn = mock.AsyncMock(return_value=mock_async_cm)
+
+        with mock.patch.object(trigger, "hook", return_value=mock_hook):
+            with pytest.raises(asyncio.CancelledError):
+                async for _ in trigger.run():
+                    pass
+
+        mock_hook.conn.cancel_job_run.assert_called_once_with(applicationId="test_app", jobRunId="test_job")
+
 
 class TestEmrServerlessDeleteApplicationTrigger:
     def test_serialization(self):