Skip to content

Commit

Permalink
Add wait_policy option to EmrCreateJobFlowOperator.
Browse files Browse the repository at this point in the history
Possible values:

- None: No wait (default)
- WaitPolicy.WAIT_FOR_COMPLETION: Previous behaviour when wait_for_completion was True
- WaitPolicy.WAIT_FOR_STEPS_COMPLETION: New behaviour - wait for the cluster to terminate.
  • Loading branch information
adrian-adikteev committed Nov 20, 2024
1 parent 8440016 commit f447927
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 11 deletions.
9 changes: 9 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/emr/emr.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ Create an EMR job flow

You can use :class:`~airflow.providers.amazon.aws.operators.emr.EmrCreateJobFlowOperator` to
create a new EMR job flow. The cluster will be terminated automatically after finishing the steps.

The default behaviour is to mark the DAG Task node as success as soon as the cluster is launched
(``wait_policy=None``).
It is possible to modify this behaviour by using a different ``wait_policy``. Available options are:

- ``WaitPolicy.WAIT_FOR_COMPLETION`` - DAG Task node waits for the cluster to be Running
- ``WaitPolicy.WAIT_FOR_STEPS_COMPLETION`` - DAG Task node waits for the cluster to Terminate


This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter.
Using ``deferrable`` mode will release worker slots and lead to efficient utilization of
resources within the Airflow cluster. However, this mode will need the Airflow triggerer to be
Expand Down
1 change: 1 addition & 0 deletions newsfragments/44055.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
New argument ``wait_policy`` to control waiting behaviour when using ``EmrCreateJobFlowOperator``.
36 changes: 29 additions & 7 deletions providers/src/airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
from __future__ import annotations

import ast
import warnings
from datetime import timedelta
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence
from uuid import uuid4

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
from airflow.providers.amazon.aws.links.emr import (
Expand All @@ -49,7 +50,11 @@
EmrTerminateJobFlowTrigger,
)
from airflow.providers.amazon.aws.utils import validate_execute_complete_event
from airflow.providers.amazon.aws.utils.waiter import waiter
from airflow.providers.amazon.aws.utils.waiter import (
WAITER_POLICY_NAME_MAPPING,
WaitPolicy,
waiter,
)
from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
from airflow.utils.helpers import exactly_one, prune_dict
from airflow.utils.types import NOTSET, ArgNotSet
Expand Down Expand Up @@ -636,8 +641,14 @@ class EmrCreateJobFlowOperator(BaseOperator):
:param job_flow_overrides: boto3 style arguments or reference to an arguments file
(must be '.json') to override specific ``emr_conn_id`` extra parameters. (templated)
:param region_name: Region name passed to EmrHook
:param wait_for_completion: Whether to finish task immediately after creation (False) or wait for jobflow
:param wait_for_completion: Deprecated - use `wait_policy` instead.
Whether to finish task immediately after creation (False) or wait for jobflow
completion (True)
(default: None)
:param wait_policy: Whether to finish the task immediately after creation (None) or:
- wait for the jobflow completion (WaitPolicy.WAIT_FOR_COMPLETION)
- wait for the jobflow completion and cluster to terminate (WaitPolicy.WAIT_FOR_STEPS_COMPLETION)
(default: None)
:param waiter_max_attempts: Maximum number of tries before failing.
:param waiter_delay: Number of seconds between polling the state of the job flow.
:param deferrable: If True, the operator will wait asynchronously for the job flow to complete.
Expand Down Expand Up @@ -665,7 +676,8 @@ def __init__(
emr_conn_id: str | None = "emr_default",
job_flow_overrides: str | dict[str, Any] | None = None,
region_name: str | None = None,
wait_for_completion: bool = False,
wait_for_completion: bool | None = None,
wait_policy: WaitPolicy | None = None,
waiter_max_attempts: int | None = None,
waiter_delay: int | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
Expand All @@ -676,11 +688,20 @@ def __init__(
self.emr_conn_id = emr_conn_id
self.job_flow_overrides = job_flow_overrides or {}
self.region_name = region_name
self.wait_for_completion = wait_for_completion
self.wait_policy = wait_policy
self.waiter_max_attempts = waiter_max_attempts or 60
self.waiter_delay = waiter_delay or 60
self.deferrable = deferrable

if wait_for_completion is not None:
warnings.warn(
"`wait_for_completion` parameter is deprecated, please use `wait_policy` instead.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
# preserve previous behaviour
self.wait_policy = WaitPolicy.WAIT_FOR_COMPLETION if wait_for_completion else None

@cached_property
def _emr_hook(self) -> EmrHook:
"""Create and return an EmrHook."""
Expand Down Expand Up @@ -733,8 +754,9 @@ def execute(self, context: Context) -> str | None:
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
)
if self.wait_for_completion:
self._emr_hook.get_waiter("job_flow_waiting").wait(
if self.wait_policy:
waiter_name = WAITER_POLICY_NAME_MAPPING[self.wait_policy]
self._emr_hook.get_waiter(waiter_name).wait(
ClusterId=self._job_flow_id,
WaiterConfig=prune_dict(
{
Expand Down
22 changes: 21 additions & 1 deletion providers/src/airflow/providers/amazon/aws/utils/waiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

import logging
import time
from typing import Callable
from enum import Enum
from typing import Callable, Mapping

from airflow.exceptions import AirflowException

Expand Down Expand Up @@ -83,3 +84,22 @@ def get_state(response, keys) -> str:
if value is not None:
value = value.get(key, None)
return value


class WaitPolicy(str, Enum):
    """
    Select how ``EmrCreateJobFlowOperator`` waits after it has launched a job flow.

    Each member is translated into the name of an EMR waiter through
    ``WAITER_POLICY_NAME_MAPPING``. Because the enum also derives from ``str``,
    members compare equal to their plain string values.

    Choices:

    - WAIT_FOR_COMPLETION - wait using the ``job_flow_waiting`` waiter
      (documented for the operator as waiting for the cluster to be "Running")
    - WAIT_FOR_STEPS_COMPLETION - wait using the ``job_flow_terminated`` waiter,
      i.e. until the cluster reports a "Terminated" state
    """

    WAIT_FOR_COMPLETION = "wait_for_completion"
    WAIT_FOR_STEPS_COMPLETION = "wait_for_steps_completion"


# Translates each WaitPolicy member into the waiter name that
# EmrCreateJobFlowOperator.execute() hands to the EMR hook's ``get_waiter``.
WAITER_POLICY_NAME_MAPPING: Mapping[WaitPolicy, str] = {
    WaitPolicy.WAIT_FOR_COMPLETION: "job_flow_waiting",
    WaitPolicy.WAIT_FOR_STEPS_COMPLETION: "job_flow_terminated",
}
14 changes: 11 additions & 3 deletions providers/tests/amazon/aws/operators/test_emr_create_job_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from airflow.models import DAG, DagRun, TaskInstance
from airflow.providers.amazon.aws.operators.emr import EmrCreateJobFlowOperator
from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger
from airflow.providers.amazon.aws.utils.waiter import WAITER_POLICY_NAME_MAPPING, WaitPolicy
from airflow.utils import timezone
from airflow.utils.types import DagRunType

Expand Down Expand Up @@ -193,17 +194,24 @@ def test_execute_returns_job_id(self, mocked_hook_client):
mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
assert self.operator.execute(self.mock_context) == JOB_FLOW_ID

@pytest.mark.parametrize(
"wait_policy",
[
pytest.param(WaitPolicy.WAIT_FOR_COMPLETION, id="with wait for completion"),
pytest.param(WaitPolicy.WAIT_FOR_STEPS_COMPLETION, id="with wait for steps completion policy"),
],
)
@mock.patch("botocore.waiter.get_service_module_name", return_value="emr")
@mock.patch.object(Waiter, "wait")
def test_execute_with_wait(self, mock_waiter, _, mocked_hook_client):
def test_execute_with_wait_policy(self, mock_waiter, _, mocked_hook_client, wait_policy: WaitPolicy):
mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN

# Select the wait policy under test on the already-built operator
self.operator.wait_for_completion = True
self.operator.wait_policy = wait_policy

assert self.operator.execute(self.mock_context) == JOB_FLOW_ID
mock_waiter.assert_called_once_with(mock.ANY, ClusterId=JOB_FLOW_ID, WaiterConfig=mock.ANY)
assert_expected_waiter_type(mock_waiter, "job_flow_waiting")
assert_expected_waiter_type(mock_waiter, WAITER_POLICY_NAME_MAPPING[wait_policy])

def test_create_job_flow_deferrable(self, mocked_hook_client):
"""
Expand Down

0 comments on commit f447927

Please sign in to comment.