Skip to content

Commit

Permalink
Fix automatic termination issue in EmrOperator by ensuring `waiter_max_attempts` is set for deferrable triggers (apache#38658)
Browse files Browse the repository at this point in the history
  • Loading branch information
beobest2 authored and RNHTTR committed Jun 1, 2024
1 parent ad8dff6 commit 424eabd
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 0 deletions.
8 changes: 8 additions & 0 deletions airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,14 @@ def execute(self, context: Context) -> str | None:
job_id=self.job_id,
aws_conn_id=self.aws_conn_id,
waiter_delay=self.poll_interval,
waiter_max_attempts=self.max_polling_attempts,
)
if self.max_polling_attempts
else EmrContainerTrigger(
virtual_cluster_id=self.virtual_cluster_id,
job_id=self.job_id,
aws_conn_id=self.aws_conn_id,
waiter_delay=self.poll_interval,
),
method_name="execute_complete",
)
Expand Down
8 changes: 8 additions & 0 deletions airflow/providers/amazon/aws/sensors/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,14 @@ def execute(self, context: Context):
job_id=self.job_id,
aws_conn_id=self.aws_conn_id,
waiter_delay=self.poll_interval,
waiter_max_attempts=self.max_retries,
)
if self.max_retries
else EmrContainerTrigger(
virtual_cluster_id=self.virtual_cluster_id,
job_id=self.job_id,
aws_conn_id=self.aws_conn_id,
waiter_delay=self.poll_interval,
),
method_name="execute_complete",
)
Expand Down
16 changes: 16 additions & 0 deletions tests/providers/amazon/aws/operators/test_emr_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,22 @@ def test_operator_defer(self, mock_submit_job, mock_check_query_status):
exc.value.trigger, EmrContainerTrigger
), f"{exc.value.trigger} is not a EmrContainerTrigger"

@mock.patch.object(EmrContainerHook, "submit_job")
@mock.patch.object(
    EmrContainerHook, "check_query_status", return_value=EmrContainerHook.INTERMEDIATE_STATES[0]
)
def test_operator_defer_with_timeout(self, mock_check_query_status, mock_submit_job):
    """Check that a deferred operator forwards its polling limits to the trigger.

    With ``deferrable`` enabled and ``max_polling_attempts`` set, ``execute``
    is expected to raise ``TaskDeferred`` carrying an ``EmrContainerTrigger``
    whose ``waiter_delay`` and ``attempts`` mirror the operator's
    ``poll_interval`` and ``max_polling_attempts``.
    """
    # Stacked ``mock.patch`` decorators are applied bottom-up, so the first
    # injected mock belongs to ``check_query_status`` (patched to report an
    # intermediate state) and the second to ``submit_job``.
    self.emr_container.deferrable = True
    self.emr_container.max_polling_attempts = 1000

    with pytest.raises(TaskDeferred) as e:
        self.emr_container.execute(context=None)

    trigger = e.value.trigger
    assert isinstance(trigger, EmrContainerTrigger), f"{trigger} is not a EmrContainerTrigger"
    assert trigger.waiter_delay == self.emr_container.poll_interval
    assert trigger.attempts == self.emr_container.max_polling_attempts


class TestEmrEksCreateClusterOperator:
def setup_method(self):
Expand Down
14 changes: 14 additions & 0 deletions tests/providers/amazon/aws/sensors/test_emr_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,17 @@ def test_sensor_defer(self, mock_poke):
assert isinstance(
e.value.trigger, EmrContainerTrigger
), f"{e.value.trigger} is not a EmrContainerTrigger"

@mock.patch("airflow.providers.amazon.aws.sensors.emr.EmrContainerSensor.poke")
def test_sensor_defer_with_timeout(self, mock_poke):
    """Check that a deferred sensor forwards its retry limit to the trigger.

    ``poke`` is patched to return ``False`` and ``max_retries`` is set, so
    ``execute`` is expected to raise ``TaskDeferred`` carrying an
    ``EmrContainerTrigger`` whose ``waiter_delay`` and ``attempts`` mirror the
    sensor's ``poll_interval`` and ``max_retries``.
    """
    mock_poke.return_value = False
    sensor = self.sensor
    sensor.deferrable = True
    sensor.max_retries = 1000

    with pytest.raises(TaskDeferred) as deferral:
        sensor.execute(context=None)

    trigger = deferral.value.trigger
    assert isinstance(trigger, EmrContainerTrigger), f"{trigger} is not a EmrContainerTrigger"
    assert trigger.waiter_delay == sensor.poll_interval
    assert trigger.attempts == sensor.max_retries

0 comments on commit 424eabd

Please sign in to comment.