Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 75 additions & 45 deletions providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from typing import TYPE_CHECKING, Any
from uuid import uuid4

from botocore.exceptions import WaiterError

from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
from airflow.providers.amazon.aws.links.emr import (
EmrClusterLink,
Expand Down Expand Up @@ -665,6 +667,9 @@ class EmrCreateJobFlowOperator(AwsBaseOperator[EmrHook]):
:param deferrable: If True, the operator will wait asynchronously for the crawl to complete.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)
:param terminate_job_flow_on_failure: If True, attempts best-effort termination of the EMR job flow
when a failure occurs after the job flow has been created. Cleanup failures do not mask the
original exception. (default: True)
"""

aws_hook_class = EmrHook
Expand All @@ -691,6 +696,7 @@ def __init__(
waiter_max_attempts: int | None = None,
waiter_delay: int | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
terminate_job_flow_on_failure: bool = True,
**kwargs: Any,
):
super().__init__(**kwargs)
Expand All @@ -699,6 +705,7 @@ def __init__(
self.waiter_max_attempts = waiter_max_attempts or 60
self.waiter_delay = waiter_delay or 60
self.deferrable = deferrable
self.terminate_job_flow_on_failure = terminate_job_flow_on_failure
self.wait_policy = wait_policy

# Backwards-compatible default: if the user requested waiting for
Expand Down Expand Up @@ -746,58 +753,81 @@ def execute(self, context: Context) -> str | None:

self._job_flow_id = response["JobFlowId"]
self.log.info("Job flow with id %s created", self._job_flow_id)
EmrClusterLink.persist(
context=context,
operator=self,
region_name=self.hook.conn_region_name,
aws_partition=self.hook.conn_partition,
job_flow_id=self._job_flow_id,
)
if self._job_flow_id:
EmrLogsLink.persist(
try:
EmrClusterLink.persist(
context=context,
operator=self,
region_name=self.hook.conn_region_name,
aws_partition=self.hook.conn_partition,
job_flow_id=self._job_flow_id,
log_uri=get_log_uri(emr_client=self.hook.conn, job_flow_id=self._job_flow_id),
)
if self.wait_for_completion:
# Determine which waiter to use. Prefer explicit wait_policy when provided,
# otherwise default to WAIT_FOR_COMPLETION.
wp = self.wait_policy
if wp is not None:
waiter_name = WAITER_POLICY_NAME_MAPPING[wp]
else:
waiter_name = WAITER_POLICY_NAME_MAPPING[WaitPolicy.WAIT_FOR_COMPLETION]

if self.deferrable:
# Pass the selected waiter_name to the trigger so deferrable mode waits
# according to the requested policy as well.
self.defer(
trigger=EmrCreateJobFlowTrigger(
job_flow_id=self._job_flow_id,
aws_conn_id=self.aws_conn_id,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
waiter_name=waiter_name,
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
)
else:
self.hook.get_waiter(waiter_name).wait(
ClusterId=self._job_flow_id,
WaiterConfig=prune_dict(
{
"Delay": self.waiter_delay,
"MaxAttempts": self.waiter_max_attempts,
}
),
if self._job_flow_id:
EmrLogsLink.persist(
context=context,
operator=self,
region_name=self.hook.conn_region_name,
aws_partition=self.hook.conn_partition,
job_flow_id=self._job_flow_id,
log_uri=get_log_uri(emr_client=self.hook.conn, job_flow_id=self._job_flow_id),
)
return self._job_flow_id
if self.wait_for_completion:
# Determine which waiter to use. Prefer explicit wait_policy when provided,
# otherwise default to WAIT_FOR_COMPLETION.
wp = self.wait_policy
if wp is not None:
waiter_name = WAITER_POLICY_NAME_MAPPING[wp]
else:
waiter_name = WAITER_POLICY_NAME_MAPPING[WaitPolicy.WAIT_FOR_COMPLETION]

if self.deferrable:
# Pass the selected waiter_name to the trigger so deferrable mode waits
# according to the requested policy as well.
self.defer(
trigger=EmrCreateJobFlowTrigger(
job_flow_id=self._job_flow_id,
aws_conn_id=self.aws_conn_id,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
waiter_name=waiter_name,
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
)
else:
self.hook.get_waiter(waiter_name).wait(
ClusterId=self._job_flow_id,
WaiterConfig=prune_dict(
{
"Delay": self.waiter_delay,
"MaxAttempts": self.waiter_max_attempts,
}
),
)
return self._job_flow_id

# Best-effort cleanup when post-creation steps fail (e.g. IAM/permission errors).
except WaiterError:
if self._job_flow_id:
if self.terminate_job_flow_on_failure:
self.log.warning(
"Task failed after creating EMR job flow %s.",
self._job_flow_id,
)
try:
self.log.info(
"Attempting termination of EMR job flow %s.",
self._job_flow_id,
)

self.hook.conn.terminate_job_flows(JobFlowIds=[self._job_flow_id])
except Exception:
self.log.exception(
"Failed to terminate EMR job flow %s after task failure.",
self._job_flow_id,
)
raise

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
validated_event = validate_execute_complete_event(event)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from unittest.mock import MagicMock, patch

import pytest
from botocore.exceptions import ClientError, WaiterError
from botocore.waiter import Waiter
from jinja2 import StrictUndefined

Expand Down Expand Up @@ -231,6 +232,7 @@ def test_create_job_flow_deferrable(self, mocked_hook_client):

self.operator.deferrable = True
self.operator.wait_for_completion = True

with pytest.raises(TaskDeferred) as exc:
self.operator.execute(self.mock_context)

Expand Down Expand Up @@ -281,3 +283,77 @@ def test_specify_only_wait_policy(self):
)
assert getattr(op, "wait_policy") == WaitPolicy.WAIT_FOR_STEPS_COMPLETION
assert op.wait_for_completion is True

def test_cleanup_on_post_create_failure(self, mocked_hook_client):
    """Verify best-effort termination of the job flow when a post-create step fails.

    The job flow is created successfully, but the subsequent waiter call
    raises (e.g. an IAM/permission error on DescribeCluster). The operator
    must attempt to terminate the cluster and re-raise the original
    exception unchanged.
    """
    mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
    self.operator.terminate_job_flow_on_failure = True
    self.operator.wait_for_completion = True

    original_error = WaiterError(
        "ClusterRunning",
        "You are not authorized to perform this operation",
        {},
    )

    with patch.object(self.operator.hook, "get_waiter") as mock_get_waiter:
        with patch.object(self.operator.hook.conn, "terminate_job_flows") as mock_terminate:
            mock_get_waiter.return_value.wait.side_effect = original_error

            with pytest.raises(WaiterError) as raised:
                self.operator.execute(self.mock_context)

            # The waiter's error must surface unchanged — not a substitute.
            assert raised.value is original_error

            # Termination must have been attempted exactly once for this cluster.
            mock_terminate.assert_called_once_with(JobFlowIds=[JOB_FLOW_ID])

def test_cleanup_failure_does_not_mask_original_exception(self, mocked_hook_client):
    """Verify a failing cleanup call never replaces the original waiter error.

    Both the waiter and the follow-up terminate_job_flows call raise; the
    exception that propagates out of execute() must still be the waiter's.
    """
    mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
    self.operator.terminate_job_flow_on_failure = True
    self.operator.wait_for_completion = True

    original_error = WaiterError(
        "ClusterRunning",
        "You are not authorized to perform this operation",
        {},
    )
    termination_error = ClientError(
        error_response={
            "Error": {
                "Code": "UnauthorizedOperation",
                "Message": "You are not authorized to perform this operation",
            }
        },
        operation_name="TerminateJobFlows",
    )

    with patch.object(self.operator.hook, "get_waiter") as mock_get_waiter:
        with patch.object(self.operator.hook.conn, "terminate_job_flows") as mock_terminate:
            mock_get_waiter.return_value.wait.side_effect = original_error
            mock_terminate.side_effect = termination_error

            with pytest.raises(WaiterError) as raised:
                self.operator.execute(self.mock_context)

            # The original waiter error wins; the cleanup error is only logged.
            assert raised.value is original_error

            # The cleanup path was still exercised despite its own failure.
            mock_terminate.assert_called_once_with(JobFlowIds=[JOB_FLOW_ID])