Skip to content

Commit

Permalink
Add wait_policy option to EmrCreateJobFlowOperator.
Browse files Browse the repository at this point in the history
  • Loading branch information
adrian-adikteev committed Nov 18, 2024
1 parent 123dadd commit 30c0240
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 5 deletions.
15 changes: 13 additions & 2 deletions providers/src/airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@
EmrTerminateJobFlowTrigger,
)
from airflow.providers.amazon.aws.utils import validate_execute_complete_event
from airflow.providers.amazon.aws.utils.waiter import waiter
from airflow.providers.amazon.aws.utils.waiter import (
waiter,
WaitPolicy,
WAITER_POLICY_NAME_MAPPING,
)
from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
from airflow.utils.helpers import exactly_one, prune_dict
from airflow.utils.types import NOTSET, ArgNotSet
Expand Down Expand Up @@ -638,6 +642,10 @@ class EmrCreateJobFlowOperator(BaseOperator):
:param region_name: Region name passed to EmrHook
:param wait_for_completion: Whether to finish task immediately after creation (False) or wait for jobflow
completion (True)
:param wait_policy: Only takes effect when wait_for_completion is True.
Whether to finish the task after the jobflow completion (WaitPolicy.DEFAULT) or wait for the
cluster to terminate (WaitPolicy.WAIT_FOR_STEPS_COMPLETION).
(default: WaitPolicy.DEFAULT)
:param waiter_max_attempts: Maximum number of tries before failing.
:param waiter_delay: Number of seconds between polling the state of the job flow.
:param deferrable: If True, the operator will wait asynchronously for the job flow to complete.
Expand Down Expand Up @@ -666,6 +674,7 @@ def __init__(
job_flow_overrides: str | dict[str, Any] | None = None,
region_name: str | None = None,
wait_for_completion: bool = False,
wait_policy: WaitPolicy = WaitPolicy.DEFAULT,
waiter_max_attempts: int | None = None,
waiter_delay: int | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
Expand All @@ -677,6 +686,7 @@ def __init__(
self.job_flow_overrides = job_flow_overrides or {}
self.region_name = region_name
self.wait_for_completion = wait_for_completion
self.wait_policy = wait_policy
self.waiter_max_attempts = waiter_max_attempts or 60
self.waiter_delay = waiter_delay or 60
self.deferrable = deferrable
Expand Down Expand Up @@ -734,7 +744,8 @@ def execute(self, context: Context) -> str | None:
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
)
if self.wait_for_completion:
self._emr_hook.get_waiter("job_flow_waiting").wait(
waiter_name = WAITER_POLICY_NAME_MAPPING[self.wait_policy]
self._emr_hook.get_waiter(waiter_name).wait(
ClusterId=self._job_flow_id,
WaiterConfig=prune_dict(
{
Expand Down
17 changes: 16 additions & 1 deletion providers/src/airflow/providers/amazon/aws/utils/waiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

from __future__ import annotations

from enum import Enum
import logging
import time
from typing import Callable
from typing import Callable, Mapping

from airflow.exceptions import AirflowException

Expand Down Expand Up @@ -83,3 +84,17 @@ def get_state(response, keys) -> str:
if value is not None:
value = value.get(key, None)
return value


class WaitPolicy(str, Enum):
    """
    Select which EMR waiter ``EmrCreateJobFlowOperator`` blocks on.

    The policy is only consulted when the operator runs with
    ``wait_for_completion=True``; it is translated to a waiter name through
    ``WAITER_POLICY_NAME_MAPPING``.

    Members subclass ``str``, so each one compares equal to its string value.
    """

    # Keep the behaviour the operator had before this option existed:
    # block on the "job_flow_waiting" waiter.
    DEFAULT = "default"
    # Block on the "job_flow_terminated" waiter instead, i.e. keep the task
    # running until the cluster has finished its steps and shut down.
    WAIT_FOR_STEPS_COMPLETION = "wait_for_steps_completion"


# Waiter name passed to the EMR hook's ``get_waiter`` for each policy.
# DEFAULT must stay on "job_flow_waiting": that is the waiter the operator used
# unconditionally before ``wait_policy`` was introduced, and DEFAULT is the
# parameter's default value, so any other name would silently change behaviour
# for existing DAGs.
WAITER_POLICY_NAME_MAPPING: Mapping[WaitPolicy, str] = {
    WaitPolicy.DEFAULT: "job_flow_waiting",
    WaitPolicy.WAIT_FOR_STEPS_COMPLETION: "job_flow_terminated",
}
13 changes: 11 additions & 2 deletions providers/tests/amazon/aws/operators/test_emr_create_job_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from airflow.models import DAG, DagRun, TaskInstance
from airflow.providers.amazon.aws.operators.emr import EmrCreateJobFlowOperator
from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger
from airflow.providers.amazon.aws.utils.waiter import WaitPolicy, WAITER_POLICY_NAME_MAPPING
from airflow.utils import timezone
from airflow.utils.types import DagRunType

Expand Down Expand Up @@ -193,17 +194,25 @@ def test_execute_returns_job_id(self, mocked_hook_client):
mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
assert self.operator.execute(self.mock_context) == JOB_FLOW_ID

@pytest.mark.parametrize(
"wait_policy",
[
pytest.param(WaitPolicy.DEFAULT, id="with default policy"),
pytest.param(WaitPolicy.WAIT_FOR_STEPS_COMPLETION, id="with wait for steps completion policy"),
],
)
@mock.patch("botocore.waiter.get_service_module_name", return_value="emr")
@mock.patch.object(Waiter, "wait")
def test_execute_with_wait(self, mock_waiter, _, mocked_hook_client):
def test_execute_with_wait(self, mock_waiter, _, mocked_hook_client, wait_policy: WaitPolicy):
mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN

# Mock out the emr_client creator
self.operator.wait_for_completion = True
self.operator.wait_policy = wait_policy

assert self.operator.execute(self.mock_context) == JOB_FLOW_ID
mock_waiter.assert_called_once_with(mock.ANY, ClusterId=JOB_FLOW_ID, WaiterConfig=mock.ANY)
assert_expected_waiter_type(mock_waiter, "job_flow_waiting")
assert_expected_waiter_type(mock_waiter, WAITER_POLICY_NAME_MAPPING[wait_policy])

def test_create_job_flow_deferrable(self, mocked_hook_client):
"""
Expand Down

0 comments on commit 30c0240

Please sign in to comment.