From c458827594b1d2b6834e638b02ca5a5d023cb7c6 Mon Sep 17 00:00:00 2001 From: beobest2 Date: Mon, 1 Apr 2024 15:57:19 -0400 Subject: [PATCH 1/8] Fix automatic termination issue in EmrOperator by ensuring is set for deferrable triggers --- airflow/providers/amazon/aws/operators/emr.py | 1 + 1 file changed, 1 insertion(+) diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index ec1be30f91346..3575ce81796ef 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -617,6 +617,7 @@ def execute(self, context: Context) -> str | None: job_id=self.job_id, aws_conn_id=self.aws_conn_id, waiter_delay=self.poll_interval, + waiter_max_attempts=self.max_polling_attempts, ), method_name="execute_complete", ) From 699d3f53be9cb5dd798d9285bce78b9af6f15594 Mon Sep 17 00:00:00 2001 From: beobest2 Date: Mon, 8 Apr 2024 16:13:50 -0700 Subject: [PATCH 2/8] Fix: mypy type check error --- airflow/providers/amazon/aws/operators/emr.py | 6 +++++- airflow/providers/amazon/aws/sensors/emr.py | 7 ++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 3575ce81796ef..6f83a4fcb8575 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -618,7 +618,11 @@ def execute(self, context: Context) -> str | None: aws_conn_id=self.aws_conn_id, waiter_delay=self.poll_interval, waiter_max_attempts=self.max_polling_attempts, - ), + ) if self.max_polling_attempts else EmrContainerTrigger( + virtual_cluster_id=self.virtual_cluster_id, + job_id=self.job_id, + aws_conn_id=self.aws_conn_id, + waiter_delay=self.poll_interval), method_name="execute_complete", ) if self.wait_for_completion: diff --git a/airflow/providers/amazon/aws/sensors/emr.py b/airflow/providers/amazon/aws/sensors/emr.py index 30a4c943c69ba..33ee0a6bfbd97 100644 --- a/airflow/providers/amazon/aws/sensors/emr.py +++ b/airflow/providers/amazon/aws/sensors/emr.py @@ -354,7 +354,12 @@ def execute(self, context: Context): job_id=self.job_id, aws_conn_id=self.aws_conn_id, waiter_delay=self.poll_interval, - ), + waiter_max_attempts=self.max_retries, + ) if self.max_retries else EmrContainerTrigger( + virtual_cluster_id=self.virtual_cluster_id, + job_id=self.job_id, + aws_conn_id=self.aws_conn_id, + waiter_delay=self.poll_interval), method_name="execute_complete", ) From e6226176f246452a217c7d78c79fe96a26fce706 Mon Sep 17 00:00:00 2001 From: beobest2 Date: Wed, 10 Apr 2024 00:11:08 -0700 Subject: [PATCH 3/8] Fix: linting --- airflow/providers/amazon/aws/operators/emr.py | 7 +++++-- airflow/providers/amazon/aws/sensors/emr.py | 7 +++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 6f83a4fcb8575..1a4518af706da 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -618,11 +618,14 @@ def execute(self, context: Context) -> str | None: aws_conn_id=self.aws_conn_id, waiter_delay=self.poll_interval, waiter_max_attempts=self.max_polling_attempts, - ) if self.max_polling_attempts else EmrContainerTrigger( + ) + if self.max_polling_attempts + else EmrContainerTrigger( virtual_cluster_id=self.virtual_cluster_id, job_id=self.job_id, aws_conn_id=self.aws_conn_id, - waiter_delay=self.poll_interval), + waiter_delay=self.poll_interval, + ), method_name="execute_complete", ) if self.wait_for_completion: diff --git a/airflow/providers/amazon/aws/sensors/emr.py b/airflow/providers/amazon/aws/sensors/emr.py index 33ee0a6bfbd97..19e026e7a6c4e 100644 --- a/airflow/providers/amazon/aws/sensors/emr.py +++ b/airflow/providers/amazon/aws/sensors/emr.py @@ -355,11 +355,14 @@ def execute(self, context: Context): aws_conn_id=self.aws_conn_id, waiter_delay=self.poll_interval, waiter_max_attempts=self.max_retries, - ) if self.max_retries else EmrContainerTrigger( + ) + if self.max_retries + else EmrContainerTrigger( virtual_cluster_id=self.virtual_cluster_id, job_id=self.job_id, aws_conn_id=self.aws_conn_id, - waiter_delay=self.poll_interval), + waiter_delay=self.poll_interval, + ), method_name="execute_complete", ) From 4155476a2c8ba3bcbde84d824d435affe891802b Mon Sep 17 00:00:00 2001 From: beobest2 Date: Wed, 10 Apr 2024 00:32:43 -0700 Subject: [PATCH 4/8] Add: unit test --- .../amazon/aws/operators/test_emr_containers.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/providers/amazon/aws/operators/test_emr_containers.py b/tests/providers/amazon/aws/operators/test_emr_containers.py index 8e94e744d943a..79ccbf99c99c3 100644 --- a/tests/providers/amazon/aws/operators/test_emr_containers.py +++ b/tests/providers/amazon/aws/operators/test_emr_containers.py @@ -144,6 +144,20 @@ def test_operator_defer(self, mock_submit_job, mock_check_query_status): exc.value.trigger, EmrContainerTrigger ), f"{exc.value.trigger} is not a EmrContainerTrigger" + @mock.patch.object(EmrContainerHook, "submit_job") + @mock.patch.object( + EmrContainerHook, "check_query_status", return_value=EmrContainerHook.INTERMEDIATE_STATES[0] + ) + def test_operator_defer_with_timeout(self, mock_submit_job, mock_check_query_status): + self.emr_container.deferrable = True + self.emr_container.max_polling_attempts = 1000 + + error_match = "Final state of EMR Containers job is SUBMITTED.*Max tries of poll status exceeded" + with pytest.raises(AirflowException, match=error_match): + self.emr_container.execute(context=None) + + assert mock_check_query_status.call_count == 1000 + class TestEmrEksCreateClusterOperator: def setup_method(self): From c708126a7b33fb49b6e63109983d1186d73b5dd1 Mon Sep 17 00:00:00 2001 From: beobest2 Date: Wed, 10 Apr 2024 22:44:59 -0700 Subject: [PATCH 5/8] Fix: error_match string --- tests/providers/amazon/aws/operators/test_emr_containers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/providers/amazon/aws/operators/test_emr_containers.py b/tests/providers/amazon/aws/operators/test_emr_containers.py index 79ccbf99c99c3..99c8b067bdee3 100644 --- a/tests/providers/amazon/aws/operators/test_emr_containers.py +++ b/tests/providers/amazon/aws/operators/test_emr_containers.py @@ -152,7 +152,7 @@ def test_operator_defer_with_timeout(self, mock_submit_job, mock_check_query_sta self.emr_container.deferrable = True self.emr_container.max_polling_attempts = 1000 - error_match = "Final state of EMR Containers job is SUBMITTED.*Max tries of poll status exceeded" + error_match = "Waiter error: max attempts reached" with pytest.raises(AirflowException, match=error_match): self.emr_container.execute(context=None) From 498189c464783093332be70afd55ad315993e94d Mon Sep 17 00:00:00 2001 From: beobest2 Date: Fri, 12 Apr 2024 00:15:01 -0700 Subject: [PATCH 6/8] Add: EMR sensor unit test --- .../amazon/aws/sensors/test_emr_containers.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/providers/amazon/aws/sensors/test_emr_containers.py b/tests/providers/amazon/aws/sensors/test_emr_containers.py index 606281e70a620..273b0c795115f 100644 --- a/tests/providers/amazon/aws/sensors/test_emr_containers.py +++ b/tests/providers/amazon/aws/sensors/test_emr_containers.py @@ -84,3 +84,16 @@ def test_sensor_defer(self, mock_poke): assert isinstance( e.value.trigger, EmrContainerTrigger ), f"{e.value.trigger} is not a EmrContainerTrigger" + + @mock.patch.object( + EmrContainerHook, "check_query_status", return_value=EmrContainerHook.INTERMEDIATE_STATES[0] + ) + def test_sensor_defer_with_timeout(self, mock_check_query_status): + self.sensor.deferrable = True + self.sensor.max_polling_attempts = 1000 + + error_match = "Waiter error: max attempts reached" + with pytest.raises(TaskDeferred, match=error_match): + self.sensor.execute(context=None) + + assert mock_check_query_status.call_count == 1000 From 993b583402e98d742144c2350641115b78504633 Mon Sep 17 00:00:00 2001 From: beobest2 Date: Wed, 1 May 2024 22:39:14 -0400 Subject: [PATCH 7/8] Fix: test for deferable operation --- tests/providers/amazon/aws/operators/test_emr_containers.py | 3 +-- tests/providers/amazon/aws/sensors/test_emr_containers.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/providers/amazon/aws/operators/test_emr_containers.py b/tests/providers/amazon/aws/operators/test_emr_containers.py index 99c8b067bdee3..fa5c4764b87ea 100644 --- a/tests/providers/amazon/aws/operators/test_emr_containers.py +++ b/tests/providers/amazon/aws/operators/test_emr_containers.py @@ -152,8 +152,7 @@ def test_operator_defer_with_timeout(self, mock_submit_job, mock_check_query_sta self.emr_container.deferrable = True self.emr_container.max_polling_attempts = 1000 - error_match = "Waiter error: max attempts reached" - with pytest.raises(AirflowException, match=error_match): + with pytest.raises(TaskDeferred): self.emr_container.execute(context=None) assert mock_check_query_status.call_count == 1000 diff --git a/tests/providers/amazon/aws/sensors/test_emr_containers.py b/tests/providers/amazon/aws/sensors/test_emr_containers.py index 273b0c795115f..048533cb60e6c 100644 --- a/tests/providers/amazon/aws/sensors/test_emr_containers.py +++ b/tests/providers/amazon/aws/sensors/test_emr_containers.py @@ -92,8 +92,7 @@ def test_sensor_defer_with_timeout(self, mock_check_query_status): self.sensor.deferrable = True self.sensor.max_polling_attempts = 1000 - error_match = "Waiter error: max attempts reached" - with pytest.raises(TaskDeferred, match=error_match): + with pytest.raises(TaskDeferred): self.sensor.execute(context=None) assert mock_check_query_status.call_count == 1000 From 23fad983bee20f900a252e81c1282252916a2050 Mon Sep 17 00:00:00 2001 From: beobest2 Date: Thu, 16 May 2024 01:06:16 -0400 Subject: [PATCH 8/8] fix test code --- .../amazon/aws/operators/test_emr_containers.py | 7 +++++-- .../amazon/aws/sensors/test_emr_containers.py | 16 +++++++++------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/tests/providers/amazon/aws/operators/test_emr_containers.py b/tests/providers/amazon/aws/operators/test_emr_containers.py index fa5c4764b87ea..feeec1278e155 100644 --- a/tests/providers/amazon/aws/operators/test_emr_containers.py +++ b/tests/providers/amazon/aws/operators/test_emr_containers.py @@ -152,10 +152,13 @@ def test_operator_defer_with_timeout(self, mock_submit_job, mock_check_query_sta self.emr_container.deferrable = True self.emr_container.max_polling_attempts = 1000 - with pytest.raises(TaskDeferred): + with pytest.raises(TaskDeferred) as e: self.emr_container.execute(context=None) - assert mock_check_query_status.call_count == 1000 + trigger = e.value.trigger + assert isinstance(trigger, EmrContainerTrigger), f"{trigger} is not a EmrContainerTrigger" + assert trigger.waiter_delay == self.emr_container.poll_interval + assert trigger.attempts == self.emr_container.max_polling_attempts class TestEmrEksCreateClusterOperator: diff --git a/tests/providers/amazon/aws/sensors/test_emr_containers.py b/tests/providers/amazon/aws/sensors/test_emr_containers.py index 048533cb60e6c..65ae0729341ff 100644 --- a/tests/providers/amazon/aws/sensors/test_emr_containers.py +++ b/tests/providers/amazon/aws/sensors/test_emr_containers.py @@ -85,14 +85,16 @@ def test_sensor_defer(self, mock_poke): e.value.trigger, EmrContainerTrigger ), f"{e.value.trigger} is not a EmrContainerTrigger" - @mock.patch.object( - EmrContainerHook, "check_query_status", return_value=EmrContainerHook.INTERMEDIATE_STATES[0] - ) - def test_sensor_defer_with_timeout(self, mock_check_query_status): + @mock.patch("airflow.providers.amazon.aws.sensors.emr.EmrContainerSensor.poke") + def test_sensor_defer_with_timeout(self, mock_poke): self.sensor.deferrable = True - self.sensor.max_polling_attempts = 1000 + mock_poke.return_value = False + self.sensor.max_retries = 1000 - with pytest.raises(TaskDeferred): + with pytest.raises(TaskDeferred) as e: self.sensor.execute(context=None) - assert mock_check_query_status.call_count == 1000 + trigger = e.value.trigger + assert isinstance(trigger, EmrContainerTrigger), f"{trigger} is not a EmrContainerTrigger" + assert trigger.waiter_delay == self.sensor.poll_interval + assert trigger.attempts == self.sensor.max_retries