Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from typing import TYPE_CHECKING, Any
from uuid import uuid4

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
from airflow.providers.amazon.aws.links.emr import (
EmrClusterLink,
Expand Down Expand Up @@ -657,7 +656,7 @@ class EmrCreateJobFlowOperator(AwsBaseOperator[EmrHook]):
:param wait_for_completion: Whether to finish task immediately after creation (False) or wait for jobflow
completion (True)
(default: None)
:param wait_policy: Deprecated. Use `wait_for_completion` instead. Whether to finish the task immediately after creation (None) or:
:param wait_policy: Whether to finish the task immediately after creation (None) or:
- wait for the jobflow completion (WaitPolicy.WAIT_FOR_COMPLETION)
- wait for the jobflow completion and cluster to terminate (WaitPolicy.WAIT_FOR_STEPS_COMPLETION)
(default: None)
Expand Down Expand Up @@ -697,29 +696,35 @@ def __init__(
super().__init__(**kwargs)
self.emr_conn_id = emr_conn_id
self.job_flow_overrides = job_flow_overrides or {}
self.wait_for_completion = wait_for_completion
self.waiter_max_attempts = waiter_max_attempts or 60
self.waiter_delay = waiter_delay or 60
self.deferrable = deferrable

if wait_policy is not None:
warnings.warn(
"`wait_policy` parameter is deprecated and will be removed in a future release; "
"please use `wait_for_completion` (bool) instead.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)

if wait_for_completion is not None:
raise ValueError(
"Cannot specify both `wait_for_completion` and deprecated `wait_policy`. "
"Please use `wait_for_completion` (bool)."
self.wait_policy = wait_policy

# Backwards-compatible default: if the user requested waiting for
# completion (wait_for_completion=True) but did not provide an
# explicit wait_policy, default the wait_policy to
# WaitPolicy.WAIT_FOR_COMPLETION
if self.wait_policy is None and wait_for_completion:
self.wait_policy = WaitPolicy.WAIT_FOR_COMPLETION

# Handle deprecated wait_for_completion parameter. If wait_policy is set,
# we always override wait_for_completion to True (since some form of waiting is
# requested). If wait_policy is not set, we use the value of wait_for_completion
# (defaulting to False if not provided).
if self.wait_policy is not None:
if wait_for_completion is False:
warnings.warn(
"Setting wait_policy while wait_for_completion is False is deprecated. "
"In future, you must set wait_for_completion=True to wait.",
UserWarning,
stacklevel=2,
)

self.wait_for_completion = wait_policy in (
WaitPolicy.WAIT_FOR_COMPLETION,
WaitPolicy.WAIT_FOR_STEPS_COMPLETION,
)
self.wait_for_completion = True
elif wait_for_completion is not None:
self.wait_for_completion = wait_for_completion
else:
self.wait_for_completion = False

@property
def _hook_parameters(self):
Expand Down Expand Up @@ -758,15 +763,24 @@ def execute(self, context: Context) -> str | None:
log_uri=get_log_uri(emr_client=self.hook.conn, job_flow_id=self._job_flow_id),
)
if self.wait_for_completion:
waiter_name = WAITER_POLICY_NAME_MAPPING[WaitPolicy.WAIT_FOR_COMPLETION]
# Determine which waiter to use. Prefer explicit wait_policy when provided,
# otherwise default to WAIT_FOR_COMPLETION.
wp = self.wait_policy
if wp is not None:
waiter_name = WAITER_POLICY_NAME_MAPPING[wp]
else:
waiter_name = WAITER_POLICY_NAME_MAPPING[WaitPolicy.WAIT_FOR_COMPLETION]

if self.deferrable:
# Pass the selected waiter_name to the trigger so deferrable mode waits
# according to the requested policy as well.
self.defer(
trigger=EmrCreateJobFlowTrigger(
job_flow_id=self._job_flow_id,
aws_conn_id=self.aws_conn_id,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
waiter_name=waiter_name,
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,11 @@ def __init__(
aws_conn_id: str | None = None,
waiter_delay: int = 30,
waiter_max_attempts: int = 60,
waiter_name: str = "job_flow_waiting",
):
super().__init__(
serialized_fields={"job_flow_id": job_flow_id},
waiter_name="job_flow_waiting",
waiter_name=waiter_name,
waiter_args={"ClusterId": job_flow_id},
failure_message="JobFlow creation failed",
status_message="JobFlow creation in progress",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from botocore.waiter import Waiter
from jinja2 import StrictUndefined

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.models import DAG, DagRun, TaskInstance
from airflow.providers.amazon.aws.operators.emr import EmrCreateJobFlowOperator
from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger
Expand Down Expand Up @@ -254,10 +253,31 @@ def test_create_job_flow_deferrable_no_wait(self, mocked_hook_client):
def test_template_fields(self):
    """Run ``validate_template_fields`` against the operator held in ``self.operator``."""
    operator_under_test = self.operator
    validate_template_fields(operator_under_test)

def test_wait_policy_deprecation_warning(self):
"""Check that constructing the operator with wait_policy emits AirflowProviderDeprecationWarning."""
# `match` is a regular expression searched against the warning message text.
with pytest.warns(AirflowProviderDeprecationWarning, match="`wait_policy` parameter is deprecated"):
EmrCreateJobFlowOperator(
task_id=TASK_ID,
wait_policy=WaitPolicy.WAIT_FOR_COMPLETION,
)
def test_wait_policy_behavior(self):
    """Check the operator state when only ``wait_for_completion=True`` is passed.

    With no explicit ``wait_policy``, the operator is expected to expose
    ``WaitPolicy.WAIT_FOR_COMPLETION`` as its policy and keep
    ``wait_for_completion`` set to ``True``.
    """
    op = EmrCreateJobFlowOperator(
        task_id=TASK_ID,
        wait_for_completion=True,
    )
    # Direct attribute access: a missing attribute fails with a plain
    # AttributeError, and getattr() with a constant name adds nothing.
    assert op.wait_policy == WaitPolicy.WAIT_FOR_COMPLETION
    assert op.wait_for_completion is True

def test_specify_both_wait_for_completion_and_wait_policy(self):
    """Check the operator state when both waiting parameters are passed.

    Passing ``wait_for_completion=True`` together with an explicit
    ``wait_policy`` is expected to keep the explicit policy and leave
    ``wait_for_completion`` set to ``True``.
    """
    op = EmrCreateJobFlowOperator(
        task_id=TASK_ID,
        wait_for_completion=True,
        wait_policy=WaitPolicy.WAIT_FOR_STEPS_COMPLETION,
    )
    # Direct attribute access instead of getattr() with a constant name.
    assert op.wait_policy == WaitPolicy.WAIT_FOR_STEPS_COMPLETION
    assert op.wait_for_completion is True

def test_specify_only_wait_policy(self):
    """Check the operator state when only ``wait_policy`` is passed.

    With ``wait_for_completion`` left unset, the operator is expected to keep
    the given policy and report ``wait_for_completion`` as ``True``.
    """
    op = EmrCreateJobFlowOperator(
        task_id=TASK_ID,
        wait_policy=WaitPolicy.WAIT_FOR_STEPS_COMPLETION,
    )
    # Direct attribute access instead of getattr() with a constant name.
    assert op.wait_policy == WaitPolicy.WAIT_FOR_STEPS_COMPLETION
    assert op.wait_for_completion is True