Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -654,11 +654,10 @@ class EmrCreateJobFlowOperator(AwsBaseOperator[EmrHook]):
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param wait_for_completion: Deprecated - use `wait_policy` instead.
Whether to finish task immediately after creation (False) or wait for jobflow
:param wait_for_completion: Whether to finish task immediately after creation (False) or wait for jobflow
completion (True)
(default: None)
:param wait_policy: Whether to finish the task immediately after creation (None) or:
:param wait_policy: Deprecated. Use `wait_for_completion` instead. Whether to finish the task immediately after creation (None) or:
- wait for the jobflow completion (WaitPolicy.WAIT_FOR_COMPLETION)
- wait for the jobflow completion and cluster to terminate (WaitPolicy.WAIT_FOR_STEPS_COMPLETION)
(default: None)
Expand Down Expand Up @@ -698,19 +697,29 @@ def __init__(
super().__init__(**kwargs)
self.emr_conn_id = emr_conn_id
self.job_flow_overrides = job_flow_overrides or {}
self.wait_policy = wait_policy
self.wait_for_completion = wait_for_completion
self.waiter_max_attempts = waiter_max_attempts or 60
self.waiter_delay = waiter_delay or 60
self.deferrable = deferrable

if wait_for_completion is not None:
if wait_policy is not None:
warnings.warn(
"`wait_for_completion` parameter is deprecated, please use `wait_policy` instead.",
"`wait_policy` parameter is deprecated and will be removed in a future release; "
"please use `wait_for_completion` (bool) instead.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
# preserve previous behaviour
self.wait_policy = WaitPolicy.WAIT_FOR_COMPLETION if wait_for_completion else None

if wait_for_completion is not None:
raise ValueError(
"Cannot specify both `wait_for_completion` and deprecated `wait_policy`. "
"Please use `wait_for_completion` (bool)."
)

self.wait_for_completion = wait_policy in (
WaitPolicy.WAIT_FOR_COMPLETION,
WaitPolicy.WAIT_FOR_STEPS_COMPLETION,
)

@property
def _hook_parameters(self):
Expand Down Expand Up @@ -748,8 +757,8 @@ def execute(self, context: Context) -> str | None:
job_flow_id=self._job_flow_id,
log_uri=get_log_uri(emr_client=self.hook.conn, job_flow_id=self._job_flow_id),
)
if self.wait_policy:
waiter_name = WAITER_POLICY_NAME_MAPPING[self.wait_policy]
if self.wait_for_completion:
waiter_name = WAITER_POLICY_NAME_MAPPING[WaitPolicy.WAIT_FOR_COMPLETION]

if self.deferrable:
self.defer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from botocore.waiter import Waiter
from jinja2 import StrictUndefined

from airflow.exceptions import TaskDeferred
from airflow.exceptions import AirflowProviderDeprecationWarning, TaskDeferred
from airflow.models import DAG, DagRun, TaskInstance
from airflow.providers.amazon.aws.operators.emr import EmrCreateJobFlowOperator
from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger
Expand Down Expand Up @@ -216,34 +216,26 @@ def test_execute_returns_job_id(self, mocked_hook_client):
mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
assert self.operator.execute(self.mock_context) == JOB_FLOW_ID

@pytest.mark.parametrize(
"wait_policy",
[
pytest.param(WaitPolicy.WAIT_FOR_COMPLETION, id="with wait for completion"),
pytest.param(WaitPolicy.WAIT_FOR_STEPS_COMPLETION, id="with wait for steps completion policy"),
],
)
@mock.patch("botocore.waiter.get_service_module_name", return_value="emr")
@mock.patch.object(Waiter, "wait")
def test_execute_with_wait_policy(self, mock_waiter, _, mocked_hook_client, wait_policy: WaitPolicy):
def test_execute_with_wait_for_completion(self, mock_waiter, _, mocked_hook_client):
mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN

# Mock out the emr_client creator
self.operator.wait_policy = wait_policy
self.operator.wait_for_completion = True

assert self.operator.execute(self.mock_context) == JOB_FLOW_ID
mock_waiter.assert_called_once_with(mock.ANY, ClusterId=JOB_FLOW_ID, WaiterConfig=mock.ANY)
assert_expected_waiter_type(mock_waiter, WAITER_POLICY_NAME_MAPPING[wait_policy])
assert_expected_waiter_type(mock_waiter, WAITER_POLICY_NAME_MAPPING[WaitPolicy.WAIT_FOR_COMPLETION])

def test_create_job_flow_deferrable(self, mocked_hook_client):
"""
Test to make sure that the operator raises a TaskDeferred exception
if run in deferrable mode and wait_policy is set.
if run in deferrable mode and wait_for_completion is set.
"""
mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN

self.operator.deferrable = True
self.operator.wait_policy = WaitPolicy.WAIT_FOR_COMPLETION
self.operator.wait_for_completion = True
with pytest.raises(TaskDeferred) as exc:
self.operator.execute(self.mock_context)

Expand All @@ -254,14 +246,22 @@ def test_create_job_flow_deferrable(self, mocked_hook_client):
def test_create_job_flow_deferrable_no_wait(self, mocked_hook_client):
"""
Test to make sure that the operator does NOT raise a TaskDeferred exception
if run in deferrable mode but wait_policy is not set.
if run in deferrable mode but wait_for_completion is not set.
"""
mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN

self.operator.deferrable = True
# wait_policy is None by default
# wait_for_completion is None by default
result = self.operator.execute(self.mock_context)
assert result == JOB_FLOW_ID

def test_template_fields(self):
validate_template_fields(self.operator)

def test_wait_policy_deprecation_warning(self):
"""Test that using wait_policy raises a deprecation warning."""
with pytest.warns(AirflowProviderDeprecationWarning, match="`wait_policy` parameter is deprecated"):
EmrCreateJobFlowOperator(
task_id=TASK_ID,
wait_policy=WaitPolicy.WAIT_FOR_COMPLETION,
)
Loading