Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix automatic termination issue in EmrOperator by ensuring waiter_max_attempts is set for deferrable triggers #38658

Merged
merged 9 commits into from
May 21, 2024
8 changes: 8 additions & 0 deletions airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,14 @@ def execute(self, context: Context) -> str | None:
job_id=self.job_id,
aws_conn_id=self.aws_conn_id,
waiter_delay=self.poll_interval,
waiter_max_attempts=self.max_polling_attempts,
)
if self.max_polling_attempts
vincbeck marked this conversation as resolved.
Show resolved Hide resolved
else EmrContainerTrigger(
virtual_cluster_id=self.virtual_cluster_id,
job_id=self.job_id,
aws_conn_id=self.aws_conn_id,
waiter_delay=self.poll_interval,
),
method_name="execute_complete",
)
Expand Down
8 changes: 8 additions & 0 deletions airflow/providers/amazon/aws/sensors/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,14 @@ def execute(self, context: Context):
job_id=self.job_id,
aws_conn_id=self.aws_conn_id,
waiter_delay=self.poll_interval,
waiter_max_attempts=self.max_retries,
)
if self.max_retries
else EmrContainerTrigger(
virtual_cluster_id=self.virtual_cluster_id,
job_id=self.job_id,
aws_conn_id=self.aws_conn_id,
waiter_delay=self.poll_interval,
),
method_name="execute_complete",
)
Expand Down
16 changes: 16 additions & 0 deletions tests/providers/amazon/aws/operators/test_emr_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,22 @@ def test_operator_defer(self, mock_submit_job, mock_check_query_status):
exc.value.trigger, EmrContainerTrigger
), f"{exc.value.trigger} is not a EmrContainerTrigger"

@mock.patch.object(EmrContainerHook, "submit_job")
@mock.patch.object(
    EmrContainerHook, "check_query_status", return_value=EmrContainerHook.INTERMEDIATE_STATES[0]
)
def test_operator_defer_with_timeout(self, mock_check_query_status, mock_submit_job):
    """Check that a deferred operator forwards ``max_polling_attempts`` to its trigger.

    ``check_query_status`` is patched to return the first entry of
    ``EmrContainerHook.INTERMEDIATE_STATES`` and ``submit_job`` is patched out.
    With ``deferrable`` enabled and ``max_polling_attempts`` set, ``execute``
    is expected to raise ``TaskDeferred`` carrying an ``EmrContainerTrigger``
    whose ``waiter_delay`` and ``attempts`` match the operator's
    ``poll_interval`` and ``max_polling_attempts``.
    """
    # Stacked patch decorators inject their mocks bottom-up: the decorator
    # closest to ``def`` (check_query_status) supplies the first argument.
    self.emr_container.deferrable = True
    self.emr_container.max_polling_attempts = 1000

    with pytest.raises(TaskDeferred) as e:
        self.emr_container.execute(context=None)

    trigger = e.value.trigger
    assert isinstance(trigger, EmrContainerTrigger), f"{trigger} is not a EmrContainerTrigger"
    assert trigger.waiter_delay == self.emr_container.poll_interval
    assert trigger.attempts == self.emr_container.max_polling_attempts


class TestEmrEksCreateClusterOperator:
def setup_method(self):
Expand Down
14 changes: 14 additions & 0 deletions tests/providers/amazon/aws/sensors/test_emr_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,17 @@ def test_sensor_defer(self, mock_poke):
assert isinstance(
e.value.trigger, EmrContainerTrigger
), f"{e.value.trigger} is not a EmrContainerTrigger"

@mock.patch("airflow.providers.amazon.aws.sensors.emr.EmrContainerSensor.poke")
def test_sensor_defer_with_timeout(self, mock_poke):
    """Check that a deferred sensor forwards ``max_retries`` to its trigger.

    ``poke`` is patched to return ``False``. With ``deferrable`` enabled and
    ``max_retries`` set, ``execute`` is expected to raise ``TaskDeferred``
    carrying an ``EmrContainerTrigger`` whose ``waiter_delay`` and
    ``attempts`` match the sensor's ``poll_interval`` and ``max_retries``.
    """
    sensor = self.sensor
    mock_poke.return_value = False
    sensor.max_retries = 1000
    sensor.deferrable = True

    with pytest.raises(TaskDeferred) as deferral:
        sensor.execute(context=None)

    trigger = deferral.value.trigger
    assert isinstance(trigger, EmrContainerTrigger), f"{trigger} is not a EmrContainerTrigger"
    assert trigger.waiter_delay == sensor.poll_interval
    assert trigger.attempts == sensor.max_retries