diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/ec2.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/ec2.py index 78202e7bd8f73..acacaaf779f32 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/ec2.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/ec2.py @@ -218,31 +218,46 @@ def execute(self, context: Context): MaxCount=self.max_count, **self.config, )["Instances"] - - instance_ids = self._on_kill_instance_ids = [instance["InstanceId"] for instance in instances] - # Console link is for EC2 dashboard list, not individual instances when more than 1 instance - - EC2InstanceDashboardLink.persist( - context=context, - operator=self, - region_name=self.hook.conn_region_name, - aws_partition=self.hook.conn_partition, - instance_ids=EC2InstanceDashboardLink.format_instance_id_filter(instance_ids), - ) - for instance_id in instance_ids: - self.log.info("Created EC2 instance %s", instance_id) - - if self.wait_for_completion: - self.hook.get_waiter("instance_running").wait( - InstanceIds=[instance_id], - WaiterConfig={ - "Delay": self.poll_interval, - "MaxAttempts": self.max_attempts, - }, + try: + instance_ids = self._on_kill_instance_ids = [instance["InstanceId"] for instance in instances] + # Console link is for EC2 dashboard list, not individual instances when more than 1 instance + + EC2InstanceDashboardLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + instance_ids=EC2InstanceDashboardLink.format_instance_id_filter(instance_ids), + ) + for instance_id in instance_ids: + self.log.info("Created EC2 instance %s", instance_id) + + if self.wait_for_completion: + self.hook.get_waiter("instance_running").wait( + InstanceIds=[instance_id], + WaiterConfig={ + "Delay": self.poll_interval, + "MaxAttempts": self.max_attempts, + }, + ) + + # leave "_on_kill_instance_ids" in place for finishing post-processing + return instance_ids + + # Best-effort cleanup when post-creation steps fail (e.g. IAM/permission errors). + except Exception: + self.log.exception( + "Exception after EC2 instance creation; attempting cleanup for instances %s", + instance_ids, + ) + try: + self.hook.terminate_instances(instance_ids=instance_ids) + except Exception: + self.log.exception( + "Failed to cleanup EC2 instances %s after task failure", + instance_ids, ) - - # leave "_on_kill_instance_ids" in place for finishing post-processing - return instance_ids + raise def on_kill(self) -> None: instance_ids = getattr(self, "_on_kill_instance_ids", []) diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_ec2.py b/providers/amazon/tests/unit/amazon/aws/operators/test_ec2.py index 51d7963e22930..89eae7cdd0fa1 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_ec2.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_ec2.py @@ -17,7 +17,10 @@ # under the License. from __future__ import annotations +from unittest import mock + import pytest +from botocore.exceptions import ClientError, WaiterError from moto import mock_aws from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook @@ -96,6 +99,80 @@ def test_template_fields(self): ) validate_template_fields(ec2_operator) + @mock_aws + def test_cleanup_on_post_create_failure(self): + ec2_hook = EC2Hook() + + operator = EC2CreateInstanceOperator( + task_id="test_cleanup_on_error", + image_id=self._get_image_id(ec2_hook), + wait_for_completion=True, + ) + + waiter_error = WaiterError( + "InstanceRunning", + "You are not authorized to perform this operation", + {}, + ) + + # Force failure after instance creation (e.g. missing DescribeInstances permission). + with mock.patch.object(operator.hook, "get_waiter") as mock_get_waiter: + mock_get_waiter.return_value.wait.side_effect = waiter_error + with pytest.raises(WaiterError) as exc: + operator.execute(None) + + # Ensure the original waiter exception is propagated unchanged. + assert exc.value is waiter_error + + # Instance must have been terminated. + # We know exactly one instance was created. + instances = list(ec2_hook.conn.instances.all()) + assert len(instances) == 1 + + instance = instances[0] + assert instance.state["Name"] == "terminated" + + @mock_aws + def test_cleanup_failure_propagates_original_exception(self): + ec2_hook = EC2Hook() + + operator = EC2CreateInstanceOperator( + task_id="test_cleanup_failure_does_not_mask_error", + image_id=self._get_image_id(ec2_hook), + wait_for_completion=True, + ) + + waiter_error = WaiterError( + "InstanceRunning", + "You are not authorized to perform this operation", + {}, + ) + + cleanup_error = ClientError( + error_response={ + "Error": { + "Code": "UnauthorizedOperation", + "Message": "You are not authorized to perform this operation", + } + }, + operation_name="TerminateInstances", + ) + + with ( + mock.patch.object(operator.hook, "get_waiter") as mock_get_waiter, + mock.patch.object(operator.hook, "terminate_instances") as mock_terminate, + ): + mock_get_waiter.return_value.wait.side_effect = waiter_error + mock_terminate.side_effect = cleanup_error + + with pytest.raises(WaiterError) as exc: + operator.execute(None) + + # Ensure the original waiter exception is propagated unchanged. + assert exc.value is waiter_error + + # Cleanup is best-effort; failure to terminate must not override the original error. + class TestEC2TerminateInstanceOperator(BaseEc2TestClass): def test_init(self):