diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py index 48e8773c16a09..1dc3d54f4d31f 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py @@ -24,7 +24,6 @@ from typing import TYPE_CHECKING, Any from uuid import uuid4 -from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook from airflow.providers.amazon.aws.links.emr import ( EmrClusterLink, @@ -657,7 +656,7 @@ class EmrCreateJobFlowOperator(AwsBaseOperator[EmrHook]): :param wait_for_completion: Whether to finish task immediately after creation (False) or wait for jobflow completion (True) (default: None) - :param wait_policy: Deprecated. Use `wait_for_completion` instead. Whether to finish the task immediately after creation (None) or: + :param wait_policy: Whether to finish the task immediately after creation (None) or: - wait for the jobflow completion (WaitPolicy.WAIT_FOR_COMPLETION) - wait for the jobflow completion and cluster to terminate (WaitPolicy.WAIT_FOR_STEPS_COMPLETION) (default: None) @@ -697,29 +696,35 @@ def __init__( super().__init__(**kwargs) self.emr_conn_id = emr_conn_id self.job_flow_overrides = job_flow_overrides or {} - self.wait_for_completion = wait_for_completion self.waiter_max_attempts = waiter_max_attempts or 60 self.waiter_delay = waiter_delay or 60 self.deferrable = deferrable - - if wait_policy is not None: - warnings.warn( - "`wait_policy` parameter is deprecated and will be removed in a future release; " - "please use `wait_for_completion` (bool) instead.", - AirflowProviderDeprecationWarning, - stacklevel=2, - ) - - if wait_for_completion is not None: - raise ValueError( - "Cannot specify both `wait_for_completion` and deprecated `wait_policy`. " - "Please use `wait_for_completion` (bool)." + self.wait_policy = wait_policy + + # Backwards-compatible default: if the user requested waiting for + # completion (wait_for_completion=True) but did not provide an + # explicit wait_policy, default the wait_policy to + # WaitPolicy.WAIT_FOR_COMPLETION + if self.wait_policy is None and wait_for_completion: + self.wait_policy = WaitPolicy.WAIT_FOR_COMPLETION + + # Handle deprecated wait_for_completion parameter. If wait_policy is set, + # we always override wait_for_completion to True (since some form of waiting is + # requested). If wait_policy is not set, we use the value of wait_for_completion + # (defaulting to False if not provided). + if self.wait_policy is not None: + if wait_for_completion is False: + warnings.warn( + "Setting wait_policy while wait_for_completion is False is deprecated. " + "In future, you must set wait_for_completion=True to wait.", + UserWarning, + stacklevel=2, ) - - self.wait_for_completion = wait_policy in ( - WaitPolicy.WAIT_FOR_COMPLETION, - WaitPolicy.WAIT_FOR_STEPS_COMPLETION, - ) + self.wait_for_completion = True + elif wait_for_completion is not None: + self.wait_for_completion = wait_for_completion + else: + self.wait_for_completion = False @property def _hook_parameters(self): @@ -758,15 +763,24 @@ def execute(self, context: Context) -> str | None: log_uri=get_log_uri(emr_client=self.hook.conn, job_flow_id=self._job_flow_id), ) if self.wait_for_completion: - waiter_name = WAITER_POLICY_NAME_MAPPING[WaitPolicy.WAIT_FOR_COMPLETION] + # Determine which waiter to use. Prefer explicit wait_policy when provided, + # otherwise default to WAIT_FOR_COMPLETION. + wp = self.wait_policy + if wp is not None: + waiter_name = WAITER_POLICY_NAME_MAPPING[wp] + else: + waiter_name = WAITER_POLICY_NAME_MAPPING[WaitPolicy.WAIT_FOR_COMPLETION] if self.deferrable: + # Pass the selected waiter_name to the trigger so deferrable mode waits + # according to the requested policy as well. self.defer( trigger=EmrCreateJobFlowTrigger( job_flow_id=self._job_flow_id, aws_conn_id=self.aws_conn_id, waiter_delay=self.waiter_delay, waiter_max_attempts=self.waiter_max_attempts, + waiter_name=waiter_name, ), method_name="execute_complete", # timeout is set to ensure that if a trigger dies, the timeout does not restart diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py index 3e24fc2b6d160..356a7354d78d5 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py @@ -82,10 +82,11 @@ def __init__( aws_conn_id: str | None = None, waiter_delay: int = 30, waiter_max_attempts: int = 60, + waiter_name: str = "job_flow_waiting", ): super().__init__( serialized_fields={"job_flow_id": job_flow_id}, - waiter_name="job_flow_waiting", + waiter_name=waiter_name, waiter_args={"ClusterId": job_flow_id}, failure_message="JobFlow creation failed", status_message="JobFlow creation in progress", diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py index e49f7ba775806..8389979ec95f3 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py @@ -26,7 +26,6 @@ from botocore.waiter import Waiter from jinja2 import StrictUndefined -from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.models import DAG, DagRun, TaskInstance from airflow.providers.amazon.aws.operators.emr import EmrCreateJobFlowOperator from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger @@ -254,10 +253,31 @@ def test_create_job_flow_deferrable_no_wait(self, mocked_hook_client): def test_template_fields(self): validate_template_fields(self.operator) - def test_wait_policy_deprecation_warning(self): - """Test that using wait_policy raises a deprecation warning.""" - with pytest.warns(AirflowProviderDeprecationWarning, match="`wait_policy` parameter is deprecated"): - EmrCreateJobFlowOperator( - task_id=TASK_ID, - wait_policy=WaitPolicy.WAIT_FOR_COMPLETION, - ) + def test_wait_policy_behavior(self): + """Test that using wait_for_completion but not pass wait_policy.""" + op = EmrCreateJobFlowOperator( + task_id=TASK_ID, + wait_for_completion=True, + ) + # wait_policy should be the default WAIT_FOR_COMPLETION + assert getattr(op, "wait_policy") == WaitPolicy.WAIT_FOR_COMPLETION + assert op.wait_for_completion is True + + def test_specify_both_wait_for_completion_and_wait_policy(self): + """Passing both wait_for_completion and wait_policy.""" + op = EmrCreateJobFlowOperator( + task_id=TASK_ID, + wait_for_completion=True, + wait_policy=WaitPolicy.WAIT_FOR_STEPS_COMPLETION, + ) + assert getattr(op, "wait_policy") == WaitPolicy.WAIT_FOR_STEPS_COMPLETION + assert op.wait_for_completion is True + + def test_specify_only_wait_policy(self): + """Passing only wait_policy.""" + op = EmrCreateJobFlowOperator( + task_id=TASK_ID, + wait_policy=WaitPolicy.WAIT_FOR_STEPS_COMPLETION, + ) + assert getattr(op, "wait_policy") == WaitPolicy.WAIT_FOR_STEPS_COMPLETION + assert op.wait_for_completion is True