From 8362c58e6087a9e02e64a4dbf5a70b48d3fe1cd1 Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Thu, 27 Apr 2023 12:59:44 -0700 Subject: [PATCH 1/9] Add Deferrable mode to Emr Add Steps operator --- airflow/providers/amazon/aws/operators/emr.py | 28 +++++- airflow/providers/amazon/aws/triggers/emr.py | 86 +++++++++++++++++++ .../aws/operators/test_emr_add_steps.py | 34 +++++++- .../amazon/aws/triggers/test_emr_trigger.py | 76 ++++++++++++++++ 4 files changed, 221 insertions(+), 3 deletions(-) create mode 100644 airflow/providers/amazon/aws/triggers/emr.py create mode 100644 tests/providers/amazon/aws/triggers/test_emr_trigger.py diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 1490775fe61a7..50fd9f4d770dc 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -27,6 +27,7 @@ from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri +from airflow.providers.amazon.aws.triggers.emr import EmrAddStepsTrigger from airflow.providers.amazon.aws.utils.waiter import waiter from airflow.utils.helpers import exactly_one, prune_dict from airflow.utils.types import NOTSET, ArgNotSet @@ -55,6 +56,7 @@ class EmrAddStepsOperator(BaseOperator): :param wait_for_completion: If True, the operator will wait for all the steps to be completed. :param execution_role_arn: The ARN of the runtime role for a step on the cluster. :param do_xcom_push: if True, job_flow_id is pushed to XCom with key job_flow_id. + :param deferrable: if True, the operator will run in deferrable mode. 
""" template_fields: Sequence[str] = ( @@ -84,6 +86,7 @@ def __init__( waiter_delay: int | None = None, waiter_max_attempts: int | None = None, execution_role_arn: str | None = None, + deferrable: bool = False, **kwargs, ): if not exactly_one(job_flow_id is None, job_flow_name is None): @@ -96,10 +99,11 @@ def __init__( self.job_flow_name = job_flow_name self.cluster_states = cluster_states self.steps = steps - self.wait_for_completion = wait_for_completion + self.wait_for_completion = False if deferrable else wait_for_completion self.waiter_delay = waiter_delay self.waiter_max_attempts = waiter_max_attempts self.execution_role_arn = execution_role_arn + self.deferrable = deferrable def execute(self, context: Context) -> list[str]: emr_hook = EmrHook(aws_conn_id=self.aws_conn_id) @@ -137,7 +141,7 @@ def execute(self, context: Context) -> list[str]: steps = self.steps if isinstance(steps, str): steps = ast.literal_eval(steps) - return emr_hook.add_job_flow_steps( + step_ids = emr_hook.add_job_flow_steps( job_flow_id=job_flow_id, steps=steps, wait_for_completion=self.wait_for_completion, @@ -145,6 +149,26 @@ def execute(self, context: Context) -> list[str]: waiter_max_attempts=self.waiter_max_attempts, execution_role_arn=self.execution_role_arn, ) + if self.deferrable: + self.defer( + trigger=EmrAddStepsTrigger( + job_flow_id=job_flow_id, + step_ids=step_ids, + aws_conn_id=self.aws_conn_id, + max_attempts=self.waiter_max_attempts, + poll_interval=self.waiter_delay, + ), + method_name="execute_complete", + ) + + return step_ids + + def execute_complete(self, context, event=None): + if event["status"] != "success": + raise AirflowException(f"Error resuming cluster: {event}") + else: + self.log.info("Steps completed successfully") + return event["step_ids"] class EmrStartNotebookExecutionOperator(BaseOperator): diff --git a/airflow/providers/amazon/aws/triggers/emr.py b/airflow/providers/amazon/aws/triggers/emr.py new file mode 100644 index 0000000000000..e5d756e5cedae 
--- /dev/null +++ b/airflow/providers/amazon/aws/triggers/emr.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +from typing import Any + +from botocore.exceptions import WaiterError + +from airflow.compat.functools import cached_property +from airflow.providers.amazon.aws.hooks.emr import EmrHook +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class EmrAddStepsTrigger(BaseTrigger): + """AWS Emr Add Steps Trigger""" + + def __init__( + self, + job_flow_id: str, + step_ids: list[str], + aws_conn_id: str, + max_attempts: int | None, + poll_interval: int | None, + ): + self.job_flow_id = job_flow_id + self.step_ids = step_ids + self.aws_conn_id = aws_conn_id + self.max_attempts = max_attempts + self.poll_interval = poll_interval + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "airflow.providers.amazon.aws.triggers.emr.EmrAddStepsTrigger", + { + "job_flow_id": str(self.job_flow_id), + "step_ids": self.step_ids, + "poll_interval": str(self.poll_interval), + "max_attempts": str(self.max_attempts), + "aws_conn_id": str(self.aws_conn_id), + }, + ) + + @cached_property + def hook(self) -> EmrHook: + return 
EmrHook(aws_conn_id=self.aws_conn_id) + + async def run(self): + async with self.hook.async_conn as client: + for step_id in self.step_ids: + waiter = client.get_waiter("step_complete") + try: + await waiter.wait( + ClusterId=self.job_flow_id, + StepId=step_id, + WaiterConfig={ + "Delay": int(self.poll_interval), + "MaxAttempts": 1, + }, + ) + break + except WaiterError as error: + if "terminal failure" in str(error): + yield TriggerEvent( + {"status": "failure", "message": f"Steps failed: {error}"} + ) + break + self.log.info( + "Status of step is %s", error.last_response["Step"]["Status"] + ) + await asyncio.sleep(int(self.poll_interval)) + yield TriggerEvent({"status": "success", "message": "Steps completed", "step_ids": self.step_ids}) diff --git a/tests/providers/amazon/aws/operators/test_emr_add_steps.py b/tests/providers/amazon/aws/operators/test_emr_add_steps.py index 5171815eb37b0..1be70d3f095cd 100644 --- a/tests/providers/amazon/aws/operators/test_emr_add_steps.py +++ b/tests/providers/amazon/aws/operators/test_emr_add_steps.py @@ -25,11 +25,12 @@ import pytest from jinja2 import StrictUndefined -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import DAG, DagRun, TaskInstance from airflow.providers.amazon.aws.hooks.emr import EmrHook from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.aws.operators.emr import EmrAddStepsOperator +from airflow.providers.amazon.aws.triggers.emr import EmrAddStepsTrigger from airflow.utils import timezone from tests.test_utils import AIRFLOW_MAIN_FOLDER @@ -244,3 +245,34 @@ def test_wait_for_completion(self, mock_add_job_flow_steps, *_): waiter_max_attempts=None, execution_role_arn=None, ) + + def test_wait_for_completion_false_with_deferrable(self): + job_flow_id = "j-8989898989" + operator = EmrAddStepsOperator( + task_id="test_task", + job_flow_id=job_flow_id, + aws_conn_id="aws_default", + 
dag=DAG("test_dag_id", default_args=self.args), + wait_for_completion=True, + deferrable=True, + ) + + assert operator.wait_for_completion is False + + @patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.add_job_flow_steps") + def test_emr_add_steps_deferrable(self, mock_add_job_flow_steps): + mock_add_job_flow_steps.return_value = "test_step_id" + job_flow_id = "j-8989898989" + operator = EmrAddStepsOperator( + task_id="test_task", + job_flow_id=job_flow_id, + aws_conn_id="aws_default", + dag=DAG("test_dag_id", default_args=self.args), + wait_for_completion=True, + deferrable=True, + ) + + with pytest.raises(TaskDeferred) as exc: + operator.execute(self.mock_context) + + assert isinstance(exc.value.trigger, EmrAddStepsTrigger), "Trigger is not a EmrAddStepsTrigger" diff --git a/tests/providers/amazon/aws/triggers/test_emr_trigger.py b/tests/providers/amazon/aws/triggers/test_emr_trigger.py new file mode 100644 index 0000000000000..e1145a77b5694 --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_emr_trigger.py @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import sys + +import pytest + +from airflow.providers.amazon.aws.triggers.emr import EmrAddStepsTrigger +from airflow.triggers.base import TriggerEvent + +if sys.version_info < (3, 8): + from asynctest import CoroutineMock as AsyncMock, mock as async_mock +else: + from unittest import mock as async_mock + from unittest.mock import AsyncMock + +TEST_JOB_FLOW_ID = "test_job_flow_id" +TEST_STEP_IDS = ["step1", "step2"] +TEST_AWS_CONN_ID = "test-aws-id" +TEST_MAX_ATTEMPT = 10 +TEST_POLL_INTERVAL = 10 + + +class TestEmrAddStepsTrigger: + def test_emr_add_steps_trigger_serialize(self): + emr_add_steps_trigger = EmrAddStepsTrigger( + job_flow_id=TEST_JOB_FLOW_ID, + step_ids=TEST_STEP_IDS, + aws_conn_id=TEST_AWS_CONN_ID, + max_attempts=TEST_MAX_ATTEMPT, + poll_interval=TEST_POLL_INTERVAL, + ) + class_path, args = emr_add_steps_trigger.serialize() + assert class_path == "airflow.providers.amazon.aws.triggers.emr.EmrAddStepsTrigger" + assert args["job_flow_id"] == TEST_JOB_FLOW_ID + assert args["step_ids"] == TEST_STEP_IDS + assert args["poll_interval"] == str(TEST_POLL_INTERVAL) + assert args["max_attempts"] == str(TEST_MAX_ATTEMPT) + assert args["aws_conn_id"] == TEST_AWS_CONN_ID + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.async_conn") + async def test_redshift_create_cluster_trigger_run(self, mock_async_conn): + mock = async_mock.MagicMock() + mock_async_conn.__aenter__.return_value = mock + mock.get_waiter().wait = AsyncMock() + + emr_add_steps_trigger = EmrAddStepsTrigger( + job_flow_id=TEST_JOB_FLOW_ID, + step_ids=TEST_STEP_IDS, + aws_conn_id=TEST_AWS_CONN_ID, + max_attempts=TEST_MAX_ATTEMPT, + poll_interval=TEST_POLL_INTERVAL, + ) + + generator = emr_add_steps_trigger.run() + response = await generator.asend(None) + + assert response == TriggerEvent( + {"status": "success", "message": "Steps completed", "step_ids": TEST_STEP_IDS} + ) From 73c406fd716ee98d2f41b24386285d55efb886eb 
Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Mon, 29 May 2023 10:49:47 -0700 Subject: [PATCH 2/9] Add logging to waiters Add more tests for Trigger --- airflow/providers/amazon/aws/operators/emr.py | 1 + airflow/providers/amazon/aws/triggers/emr.py | 48 ++++--- .../amazon/aws/triggers/test_emr_trigger.py | 126 +++++++++++++++--- .../providers/amazon/aws/example_emr.py | 8 +- 4 files changed, 145 insertions(+), 38 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 50fd9f4d770dc..8f66171b2ae16 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -111,6 +111,7 @@ def execute(self, context: Context) -> list[str]: job_flow_id = self.job_flow_id or emr_hook.get_cluster_id_by_name( str(self.job_flow_name), self.cluster_states ) + self.log.info(f"Deferrable is {self.deferrable} and wait_for_completion is {self.wait_for_completion}.") if not job_flow_id: raise AirflowException(f"No cluster found for name: {self.job_flow_name}") diff --git a/airflow/providers/amazon/aws/triggers/emr.py b/airflow/providers/amazon/aws/triggers/emr.py index e5d756e5cedae..7844022e08fcd 100644 --- a/airflow/providers/amazon/aws/triggers/emr.py +++ b/airflow/providers/amazon/aws/triggers/emr.py @@ -17,11 +17,11 @@ from __future__ import annotations import asyncio +from functools import cached_property from typing import Any from botocore.exceptions import WaiterError -from airflow.compat.functools import cached_property from airflow.providers.amazon.aws.hooks.emr import EmrHook from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -62,25 +62,33 @@ def hook(self) -> EmrHook: async def run(self): async with self.hook.async_conn as client: for step_id in self.step_ids: + attempt = 0 waiter = client.get_waiter("step_complete") - try: - await waiter.wait( - ClusterId=self.job_flow_id, - StepId=step_id, - WaiterConfig={ - "Delay": int(self.poll_interval), 
- "MaxAttempts": 1, - }, - ) - break - except WaiterError as error: - if "terminal failure" in str(error): - yield TriggerEvent( - {"status": "failure", "message": f"Steps failed: {error}"} + while attempt < int(self.max_attempts): + attempt += 1 + try: + await waiter.wait( + ClusterId=self.job_flow_id, + StepId=step_id, + WaiterConfig={ + "Delay": int(self.poll_interval), + "MaxAttempts": 1, + }, ) break - self.log.info( - "Status of step is %s", error.last_response["Step"]["Status"] - ) - await asyncio.sleep(int(self.poll_interval)) - yield TriggerEvent({"status": "success", "message": "Steps completed", "step_ids": self.step_ids}) + except WaiterError as error: + if "terminal failure" in str(error): + yield TriggerEvent( + {"status": "failure", "message": f"Step {step_id} failed: {error}"} + ) + break + self.log.info( + "Status of step is %s - %s", + error.last_response["Step"]["Status"]["State"], + error.last_response["Step"]["Status"]["StateChangeReason"], + ) + await asyncio.sleep(int(self.poll_interval)) + if attempt >= int(self.max_attempts): + yield TriggerEvent({"status": "failure", "message": "Steps failed: max attempts reached"}) + else: + yield TriggerEvent({"status": "success", "message": "Steps completed", "step_ids": self.step_ids}) diff --git a/tests/providers/amazon/aws/triggers/test_emr_trigger.py b/tests/providers/amazon/aws/triggers/test_emr_trigger.py index e1145a77b5694..0ec3b5af6eb8c 100644 --- a/tests/providers/amazon/aws/triggers/test_emr_trigger.py +++ b/tests/providers/amazon/aws/triggers/test_emr_trigger.py @@ -16,23 +16,20 @@ # under the License. 
from __future__ import annotations -import sys +from unittest import mock +from unittest.mock import AsyncMock import pytest +from botocore.exceptions import WaiterError +from airflow.providers.amazon.aws.hooks.emr import EmrHook from airflow.providers.amazon.aws.triggers.emr import EmrAddStepsTrigger from airflow.triggers.base import TriggerEvent -if sys.version_info < (3, 8): - from asynctest import CoroutineMock as AsyncMock, mock as async_mock -else: - from unittest import mock as async_mock - from unittest.mock import AsyncMock - TEST_JOB_FLOW_ID = "test_job_flow_id" TEST_STEP_IDS = ["step1", "step2"] TEST_AWS_CONN_ID = "test-aws-id" -TEST_MAX_ATTEMPT = 10 +TEST_MAX_ATTEMPTS = 10 TEST_POLL_INTERVAL = 10 @@ -42,7 +39,7 @@ def test_emr_add_steps_trigger_serialize(self): job_flow_id=TEST_JOB_FLOW_ID, step_ids=TEST_STEP_IDS, aws_conn_id=TEST_AWS_CONN_ID, - max_attempts=TEST_MAX_ATTEMPT, + max_attempts=TEST_MAX_ATTEMPTS, poll_interval=TEST_POLL_INTERVAL, ) class_path, args = emr_add_steps_trigger.serialize() @@ -50,21 +47,21 @@ def test_emr_add_steps_trigger_serialize(self): assert args["job_flow_id"] == TEST_JOB_FLOW_ID assert args["step_ids"] == TEST_STEP_IDS assert args["poll_interval"] == str(TEST_POLL_INTERVAL) - assert args["max_attempts"] == str(TEST_MAX_ATTEMPT) + assert args["max_attempts"] == str(TEST_MAX_ATTEMPTS) assert args["aws_conn_id"] == TEST_AWS_CONN_ID @pytest.mark.asyncio - @async_mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.async_conn") - async def test_redshift_create_cluster_trigger_run(self, mock_async_conn): - mock = async_mock.MagicMock() - mock_async_conn.__aenter__.return_value = mock - mock.get_waiter().wait = AsyncMock() + @mock.patch.object(EmrHook, "async_conn") + async def test_emr_add_steps_trigger_run(self, mock_async_conn): + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + a_mock.get_waiter().wait = AsyncMock() emr_add_steps_trigger = EmrAddStepsTrigger( job_flow_id=TEST_JOB_FLOW_ID, 
step_ids=TEST_STEP_IDS, aws_conn_id=TEST_AWS_CONN_ID, - max_attempts=TEST_MAX_ATTEMPT, + max_attempts=TEST_MAX_ATTEMPTS, poll_interval=TEST_POLL_INTERVAL, ) @@ -74,3 +71,100 @@ async def test_redshift_create_cluster_trigger_run(self, mock_async_conn): assert response == TriggerEvent( {"status": "success", "message": "Steps completed", "step_ids": TEST_STEP_IDS} ) + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch.object(EmrHook, "async_conn") + async def test_emr_add_steps_trigger_run_multiple_attempts(self, mock_async_conn, mock_sleep): + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + error = WaiterError( + name="test_name", + reason="test_reason", + last_response={"Step": {"Status": {"State": "Running", "StateChangeReason": "test_reason"}}}, + ) + a_mock.get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True, error, error, True]) + mock_sleep.return_value = True + + emr_add_steps_trigger = EmrAddStepsTrigger( + job_flow_id=TEST_JOB_FLOW_ID, + step_ids=TEST_STEP_IDS, + aws_conn_id=TEST_AWS_CONN_ID, + max_attempts=TEST_MAX_ATTEMPTS, + poll_interval=TEST_POLL_INTERVAL, + ) + + generator = emr_add_steps_trigger.run() + response = await generator.asend(None) + + assert a_mock.get_waiter().wait.call_count == 6 + assert response == TriggerEvent( + {"status": "success", "message": "Steps completed", "step_ids": TEST_STEP_IDS} + ) + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch.object(EmrHook, "async_conn") + async def test_emr_add_steps_trigger_run_attempts_exceeded(self, mock_async_conn, mock_sleep): + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + error = WaiterError( + name="test_name", + reason="test_reason", + last_response={"Step": {"Status": {"State": "Running", "StateChangeReason": "test_reason"}}}, + ) + a_mock.get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True]) + mock_sleep.return_value = True + + 
emr_add_steps_trigger = EmrAddStepsTrigger( + job_flow_id=TEST_JOB_FLOW_ID, + step_ids=[TEST_STEP_IDS[0]], + aws_conn_id=TEST_AWS_CONN_ID, + max_attempts=2, + poll_interval=TEST_POLL_INTERVAL, + ) + + generator = emr_add_steps_trigger.run() + response = await generator.asend(None) + + assert a_mock.get_waiter().wait.call_count == 2 + assert response == TriggerEvent( + {"status": "failure", "message": "Steps failed: max attempts reached"} + ) + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch.object(EmrHook, "async_conn") + async def test_emr_add_steps_trigger_run_attempts_failed(self, mock_async_conn, mock_sleep): + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + error_running = WaiterError( + name="test_name", + reason="test_reason", + last_response={"Step": {"Status": {"State": "Running", "StateChangeReason": "test_reason"}}}, + ) + error_failed = WaiterError( + name="test_name", + reason="Waiter encountered a terminal failure state:", + last_response={"Step": {"Status": {"State": "FAILED", "StateChangeReason": "test_reason"}}}, + ) + a_mock.get_waiter().wait.side_effect = AsyncMock( + side_effect=[error_running, error_running, error_failed] + ) + mock_sleep.return_value = True + + emr_add_steps_trigger = EmrAddStepsTrigger( + job_flow_id=TEST_JOB_FLOW_ID, + step_ids=[TEST_STEP_IDS[0]], + aws_conn_id=TEST_AWS_CONN_ID, + max_attempts=TEST_MAX_ATTEMPTS, + poll_interval=TEST_POLL_INTERVAL, + ) + + generator = emr_add_steps_trigger.run() + response = await generator.asend(None) + + assert a_mock.get_waiter().wait.call_count == 3 + assert response == TriggerEvent( + {"status": "failure", "message": f"Step {TEST_STEP_IDS[0]} failed: {error_failed}"} + ) diff --git a/tests/system/providers/amazon/aws/example_emr.py b/tests/system/providers/amazon/aws/example_emr.py index 5974afb7e70c7..1d8d4aaa04232 100644 --- a/tests/system/providers/amazon/aws/example_emr.py +++ b/tests/system/providers/amazon/aws/example_emr.py @@ 
-165,12 +165,16 @@ def get_step_id(step_ids: list): job_flow_id=create_job_flow.output, steps=SPARK_STEPS, execution_role_arn=execution_role_arn, + waiter_delay=30, + deferrable=True, + waiter_max_attempts=200, + ) # [END howto_operator_emr_add_steps] - add_steps.wait_for_completion = True + # add_steps.wait_for_completion = True # On rare occasion (1 in 50ish?) this system test times out. Extending the # max_attempts from the default 60 to attempt to mitigate the flaky test. - add_steps.waiter_max_attempts = 90 + # add_steps.waiter_max_attempts = 90 # [START howto_sensor_emr_step] wait_for_step = EmrStepSensor( From 56c96aa6c43cbd2d0afa34c692ed9657f2312de4 Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Wed, 31 May 2023 01:48:42 -0700 Subject: [PATCH 3/9] Remove changes to emr system test --- tests/system/providers/amazon/aws/example_emr.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/system/providers/amazon/aws/example_emr.py b/tests/system/providers/amazon/aws/example_emr.py index 1d8d4aaa04232..5974afb7e70c7 100644 --- a/tests/system/providers/amazon/aws/example_emr.py +++ b/tests/system/providers/amazon/aws/example_emr.py @@ -165,16 +165,12 @@ def get_step_id(step_ids: list): job_flow_id=create_job_flow.output, steps=SPARK_STEPS, execution_role_arn=execution_role_arn, - waiter_delay=30, - deferrable=True, - waiter_max_attempts=200, - ) # [END howto_operator_emr_add_steps] - # add_steps.wait_for_completion = True + add_steps.wait_for_completion = True # On rare occasion (1 in 50ish?) this system test times out. Extending the # max_attempts from the default 60 to attempt to mitigate the flaky test. 
- # add_steps.waiter_max_attempts = 90 + add_steps.waiter_max_attempts = 90 # [START howto_sensor_emr_step] wait_for_step = EmrStepSensor( From f03982d16191dd73086b4737410dc2fdde54f959 Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Wed, 31 May 2023 02:42:24 -0700 Subject: [PATCH 4/9] mock call to get_log_uri in unit test --- tests/providers/amazon/aws/operators/test_emr_add_steps.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/providers/amazon/aws/operators/test_emr_add_steps.py b/tests/providers/amazon/aws/operators/test_emr_add_steps.py index 1be70d3f095cd..0b279c051f9c6 100644 --- a/tests/providers/amazon/aws/operators/test_emr_add_steps.py +++ b/tests/providers/amazon/aws/operators/test_emr_add_steps.py @@ -259,9 +259,11 @@ def test_wait_for_completion_false_with_deferrable(self): assert operator.wait_for_completion is False + @patch("airflow.providers.amazon.aws.operators.emr.get_log_uri") @patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.add_job_flow_steps") - def test_emr_add_steps_deferrable(self, mock_add_job_flow_steps): + def test_emr_add_steps_deferrable(self, mock_add_job_flow_steps, mock_get_log_uri): mock_add_job_flow_steps.return_value = "test_step_id" + mock_get_log_uri.return_value = "test/log/uri" job_flow_id = "j-8989898989" operator = EmrAddStepsOperator( task_id="test_task", From a34e44763fb5c206e99a74ceba402d750fe45320 Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Wed, 31 May 2023 07:12:57 -0700 Subject: [PATCH 5/9] Remove debugging message --- airflow/providers/amazon/aws/operators/emr.py | 1 - 1 file changed, 1 deletion(-) diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 8f66171b2ae16..50fd9f4d770dc 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -111,7 +111,6 @@ def execute(self, context: Context) -> list[str]: job_flow_id = self.job_flow_id or 
emr_hook.get_cluster_id_by_name( str(self.job_flow_name), self.cluster_states ) - self.log.info(f"Deferrable is {self.deferrable} and wait_for_completion is {self.wait_for_completion}.") if not job_flow_id: raise AirflowException(f"No cluster found for name: {self.job_flow_name}") From 2d804fb7609cf65ea35a7ca715c1a6d679bfa892 Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Wed, 31 May 2023 07:14:41 -0700 Subject: [PATCH 6/9] Update doc string for deferrable parameter to be more informative --- airflow/providers/amazon/aws/operators/emr.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 50fd9f4d770dc..b355de19843ad 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -56,7 +56,10 @@ class EmrAddStepsOperator(BaseOperator): :param wait_for_completion: If True, the operator will wait for all the steps to be completed. :param execution_role_arn: The ARN of the runtime role for a step on the cluster. :param do_xcom_push: if True, job_flow_id is pushed to XCom with key job_flow_id. - :param deferrable: if True, the operator will run in deferrable mode. + :param wait_for_completion: Whether to wait for job run completion. (default: True) + :param deferrable: If True, the operator will wait asynchronously for the job to complete. + This implies waiting for completion. This mode requires aiobotocore module to be installed. 
+ (default: False) """ template_fields: Sequence[str] = ( From 42364ced11829cdb58b176bd10c6ab8376ffa617 Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Tue, 6 Jun 2023 20:32:43 -0700 Subject: [PATCH 7/9] Add doc string for Trigger --- airflow/providers/amazon/aws/triggers/emr.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/airflow/providers/amazon/aws/triggers/emr.py b/airflow/providers/amazon/aws/triggers/emr.py index 7844022e08fcd..764a9235abe38 100644 --- a/airflow/providers/amazon/aws/triggers/emr.py +++ b/airflow/providers/amazon/aws/triggers/emr.py @@ -27,7 +27,16 @@ class EmrAddStepsTrigger(BaseTrigger): - """AWS Emr Add Steps Trigger""" + """ + AWS Emr Add Steps Trigger + The trigger will asynchronously poll the boto3 API and wait for the + steps to finish executing. + :param job_flow_id: The id of the job flow. + :param step_ids: The id of the steps being waited upon. + :param poll_interval: The amount of time in seconds to wait between attempts. + :param max_attempts: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. + """ def __init__( self, From 758ee1265cfc646437fdcd623a186e0edf2676b4 Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Wed, 7 Jun 2023 13:54:40 -0700 Subject: [PATCH 8/9] Fix static checks Remove hook as a cached property in Trigger. 
--- airflow/providers/amazon/aws/triggers/emr.py | 6 +----- airflow/providers/amazon/provider.yaml | 3 +++ 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/airflow/providers/amazon/aws/triggers/emr.py b/airflow/providers/amazon/aws/triggers/emr.py index 764a9235abe38..2afc1c45afbdf 100644 --- a/airflow/providers/amazon/aws/triggers/emr.py +++ b/airflow/providers/amazon/aws/triggers/emr.py @@ -17,7 +17,6 @@ from __future__ import annotations import asyncio -from functools import cached_property from typing import Any from botocore.exceptions import WaiterError @@ -64,11 +63,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]: }, ) - @cached_property - def hook(self) -> EmrHook: - return EmrHook(aws_conn_id=self.aws_conn_id) - async def run(self): + self.hook = EmrHook(aws_conn_id=self.aws_conn_id) async with self.hook.async_conn as client: for step_id in self.step_ids: attempt = 0 diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 716e3d8432cfb..9725a55bbaa15 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -522,6 +522,9 @@ triggers: python-modules: - airflow.providers.amazon.aws.triggers.glue - airflow.providers.amazon.aws.triggers.glue_crawler + - integration-name: Amazon EMR + python-modules: + - airflow.providers.amazon.aws.triggers.emr transfers: - source-integration-name: Amazon DynamoDB From bd78a9316a36b591913c0a67d8478780e9018c11 Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Wed, 7 Jun 2023 14:49:27 -0700 Subject: [PATCH 9/9] Update docs to mention availability of deferrable mode for EmrAddStepsOperator --- docs/apache-airflow-providers-amazon/operators/emr/emr.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/apache-airflow-providers-amazon/operators/emr/emr.rst b/docs/apache-airflow-providers-amazon/operators/emr/emr.rst index bd4e3e78c3a18..6cd628eb239cd 100644 --- a/docs/apache-airflow-providers-amazon/operators/emr/emr.rst 
+++ b/docs/apache-airflow-providers-amazon/operators/emr/emr.rst @@ -89,6 +89,10 @@ Add Steps to an EMR job flow To add steps to an existing EMR Job flow you can use :class:`~airflow.providers.amazon.aws.operators.emr.EmrAddStepsOperator`. +This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. +Using ``deferrable`` mode will release worker slots and lead to efficient utilization of +resources within the Airflow cluster. However, this mode will need the Airflow triggerer to be +available in your deployment. .. exampleinclude:: /../../tests/system/providers/amazon/aws/example_emr.py :language: python