diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index a3e10044b5788..2dd43960e8bcd 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -230,6 +230,7 @@ callsite camelCase Cancelled cancelled +Cancelling carbonite cas Cassanda diff --git a/providers/amazon/docs/operators/ssm.rst b/providers/amazon/docs/operators/ssm.rst index 77e17dc32204e..317f62648b25a 100644 --- a/providers/amazon/docs/operators/ssm.rst +++ b/providers/amazon/docs/operators/ssm.rst @@ -52,12 +52,42 @@ Waiter. Additionally, you can use the following components to track the status o :class:`~airflow.providers.amazon.aws.sensors.ssm.SsmRunCommandCompletedSensor` Sensor, or the :class:`~airflow.providers.amazon.aws.triggers.ssm.SsmRunCommandTrigger` Trigger. - .. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_ssm.py :language: python :dedent: 4 :start-after: [START howto_operator_run_command] :end-before: [END howto_operator_run_command] + +Exit code handling +^^^^^^^^^^^^^^^^^^ + +By default, both :class:`~airflow.providers.amazon.aws.operators.ssm.SsmRunCommandOperator` and +:class:`~airflow.providers.amazon.aws.sensors.ssm.SsmRunCommandCompletedSensor` will fail the task +if the command returns a non-zero exit code. You can change this behavior using the ``fail_on_nonzero_exit`` +parameter: + +.. code-block:: python + + # Default behavior - task fails on non-zero exit codes + run_command = SsmRunCommandOperator( + task_id="run_command", + document_name="AWS-RunShellScript", + run_command_kwargs={...}, + ) + + # Allow non-zero exit codes - task succeeds regardless of exit code + run_command = SsmRunCommandOperator( + task_id="run_command", + document_name="AWS-RunShellScript", + run_command_kwargs={...}, + fail_on_nonzero_exit=False, + ) + +When ``fail_on_nonzero_exit=False``, you can retrieve the exit code using +:class:`~airflow.providers.amazon.aws.operators.ssm.SsmGetCommandInvocationOperator` and use it +for workflow routing decisions. Note that AWS-level failures (TimedOut, Cancelled) will still raise +exceptions regardless of this setting. + .. _howto/operator:SsmGetCommandInvocationOperator: Retrieve output from an SSM command invocation diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/ssm.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/ssm.py index ef8c30323a74c..13485c79a0b3d 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/ssm.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/ssm.py @@ -99,3 +99,25 @@ def list_command_invocations(self, command_id: str) -> dict: :return: Response from SSM list_command_invocations API. """ return self.conn.list_command_invocations(CommandId=command_id) + + @staticmethod + def is_aws_level_failure(status: str) -> bool: + """ + Check if a command status represents an AWS-level failure. + + AWS-level failures are service-level issues that should always raise exceptions, + as opposed to command-level failures (non-zero exit codes) which may be tolerated + depending on the fail_on_nonzero_exit parameter. + + According to AWS SSM documentation, the possible statuses are: + Pending, InProgress, Delayed, Success, Cancelled, TimedOut, Failed, Cancelling + + AWS-level failures are: + - Cancelled: Command was cancelled before completion + - TimedOut: Command exceeded the timeout period + - Cancelling: Command is in the process of being cancelled + + :param status: The command invocation status from SSM. 
+ :return: True if the status represents an AWS-level failure, False otherwise. + """ + return status in ("Cancelled", "TimedOut", "Cancelling") diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/ssm.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/ssm.py index c68cae45577a7..2221e3ea1322b 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/ssm.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/ssm.py @@ -19,6 +19,8 @@ from collections.abc import Sequence from typing import TYPE_CHECKING, Any +from botocore.exceptions import WaiterError + from airflow.providers.amazon.aws.hooks.ssm import SsmHook from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator from airflow.providers.amazon.aws.triggers.ssm import SsmRunCommandTrigger @@ -50,6 +52,12 @@ class SsmRunCommandOperator(AwsBaseOperator[SsmHook]): (default: 120) :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 75) + :param fail_on_nonzero_exit: If True (default), the operator will fail when + the command returns a non-zero exit code. If False, the operator will + complete successfully regardless of the command exit code, allowing + downstream tasks to handle exit codes for workflow routing. Note that + AWS-level failures (Cancelled, TimedOut) will still raise exceptions + even when this is False. (default: True) :param deferrable: If True, the operator will wait asynchronously for the cluster to stop. This implies waiting for completion. This mode requires aiobotocore module to be installed. (default: False) @@ -81,6 +89,7 @@ def __init__( wait_for_completion: bool = True, waiter_delay: int = 120, waiter_max_attempts: int = 75, + fail_on_nonzero_exit: bool = True, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ): @@ -88,6 +97,7 @@ def __init__( self.wait_for_completion = wait_for_completion self.waiter_delay = waiter_delay self.waiter_max_attempts = waiter_max_attempts + self.fail_on_nonzero_exit = fail_on_nonzero_exit self.deferrable = deferrable self.document_name = document_name @@ -118,6 +128,7 @@ def execute(self, context: Context): command_id=command_id, waiter_delay=self.waiter_delay, waiter_max_attempts=self.waiter_max_attempts, + fail_on_nonzero_exit=self.fail_on_nonzero_exit, aws_conn_id=self.aws_conn_id, region_name=self.region_name, verify=self.verify, @@ -132,14 +143,35 @@ def execute(self, context: Context): instance_ids = response["Command"]["InstanceIds"] for instance_id in instance_ids: - waiter.wait( - CommandId=command_id, - InstanceId=instance_id, - WaiterConfig={ - "Delay": self.waiter_delay, - "MaxAttempts": self.waiter_max_attempts, - }, - ) + try: + waiter.wait( + CommandId=command_id, + InstanceId=instance_id, + WaiterConfig={ + "Delay": self.waiter_delay, + "MaxAttempts": self.waiter_max_attempts, + }, + ) + except WaiterError: + if not self.fail_on_nonzero_exit: + # Enhanced mode: distinguish between AWS-level and command-level failures + invocation = self.hook.get_command_invocation(command_id, instance_id) + status = invocation.get("Status", "") + + # AWS-level failures should always raise + if SsmHook.is_aws_level_failure(status): + raise + + # Command-level failure - tolerate it in enhanced mode + self.log.info( + "Command completed with status %s (exit code: %s). 
" + "Continuing due to fail_on_nonzero_exit=False", + status, + invocation.get("ResponseCode", "unknown"), + ) + else: + # Traditional mode: all failures raise + raise return command_id @@ -148,14 +180,6 @@ class SsmGetCommandInvocationOperator(AwsBaseOperator[SsmHook]): """ Retrieves the output and execution details of an SSM command invocation. - This operator allows you to fetch the standard output, standard error, - execution status, and other details from SSM commands. It can be used to - retrieve output from commands executed by SsmRunCommandOperator in previous - tasks, or from commands executed outside of Airflow entirely. - - The operator returns structured data including stdout, stderr, execution - times, and status information for each instance that executed the command. - .. seealso:: For more information on how to use this operator, take a look at the guide: diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/ssm.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/ssm.py index 1ada2fd1f70af..2874a553d5abf 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/ssm.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/ssm.py @@ -44,6 +44,11 @@ class SsmRunCommandCompletedSensor(AwsBaseSensor[SsmHook]): :ref:`howto/sensor:SsmRunCommandCompletedSensor` :param command_id: The ID of the AWS SSM Run Command. + :param fail_on_nonzero_exit: If True (default), the sensor will fail when the command + returns a non-zero exit code. If False, the sensor will complete successfully + for both Success and Failed command statuses, allowing downstream tasks to handle + exit codes. AWS-level failures (Cancelled, TimedOut) will still raise exceptions. + (default: True) :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore module to be installed. (default: False, but can be overridden in config file by setting @@ -85,6 +90,7 @@ def __init__( self, *, command_id, + fail_on_nonzero_exit: bool = True, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), poke_interval: int = 120, max_retries: int = 75, @@ -92,6 +98,7 @@ def __init__( ): super().__init__(**kwargs) self.command_id = command_id + self.fail_on_nonzero_exit = fail_on_nonzero_exit self.deferrable = deferrable self.poke_interval = poke_interval self.max_retries = max_retries @@ -112,7 +119,19 @@ def poke(self, context: Context): state = invocation["Status"] if state in self.FAILURE_STATES: - raise RuntimeError(self.FAILURE_MESSAGE) + # Check if we should tolerate this failure + if self.fail_on_nonzero_exit: + raise RuntimeError(self.FAILURE_MESSAGE) # Traditional behavior + + # Only fail on AWS-level issues, tolerate command failures + if SsmHook.is_aws_level_failure(state): + raise RuntimeError(f"SSM command {self.command_id} {state}") + + # Command failed but we're tolerating it + self.log.info( + "Command invocation has status %s. 
Continuing due to fail_on_nonzero_exit=False", + state, + ) if state in self.INTERMEDIATE_STATES: return False @@ -127,6 +146,7 @@ def execute(self, context: Context): waiter_delay=int(self.poke_interval), waiter_max_attempts=self.max_retries, aws_conn_id=self.aws_conn_id, + fail_on_nonzero_exit=self.fail_on_nonzero_exit, ), method_name="execute_complete", ) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/ssm.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/ssm.py index caa0bb81aacbd..2c66c21c12a18 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/ssm.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/ssm.py @@ -35,6 +35,9 @@ class SsmRunCommandTrigger(AwsBaseWaiterTrigger): :param command_id: The ID of the AWS SSM Run Command. :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 120) :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75) + :param fail_on_nonzero_exit: If True (default), the trigger will fail when the command returns + a non-zero exit code. If False, the trigger will complete successfully regardless of the + command exit code. (default: True) :param aws_conn_id: The Airflow connection used for AWS credentials. :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. :param verify: Whether or not to verify SSL certificates. See: @@ -49,13 +52,14 @@ def __init__( command_id: str, waiter_delay: int = 120, waiter_max_attempts: int = 75, + fail_on_nonzero_exit: bool = True, aws_conn_id: str | None = None, region_name: str | None = None, verify: bool | str | None = None, botocore_config: dict | None = None, ) -> None: super().__init__( - serialized_fields={"command_id": command_id}, + serialized_fields={"command_id": command_id, "fail_on_nonzero_exit": fail_on_nonzero_exit}, waiter_name="command_executed", waiter_args={"CommandId": command_id}, failure_message="SSM run command failed.", @@ -71,6 +75,7 @@ def __init__( botocore_config=botocore_config, ) self.command_id = command_id + self.fail_on_nonzero_exit = fail_on_nonzero_exit def hook(self) -> AwsGenericHook: return SsmHook( @@ -89,14 +94,41 @@ async def run(self) -> AsyncIterator[TriggerEvent]: for instance_id in instance_ids: self.waiter_args["InstanceId"] = instance_id - await async_wait( - waiter, - self.waiter_delay, - self.attempts, - self.waiter_args, - self.failure_message, - self.status_message, - self.status_queries, - ) + try: + await async_wait( + waiter, + self.waiter_delay, + self.attempts, + self.waiter_args, + self.failure_message, + self.status_message, + self.status_queries, + ) + except Exception: + if not self.fail_on_nonzero_exit: + # Enhanced mode: check if it's an AWS-level failure + invocation = await client.get_command_invocation( + CommandId=self.command_id, InstanceId=instance_id + ) + status = invocation.get("Status", "") + + # AWS-level failures should always raise + if SsmHook.is_aws_level_failure(status): + raise + + # Command-level failure - tolerate it in enhanced mode + response_code = invocation.get("ResponseCode", "unknown") + self.log.info( + "Command %s completed with status %s (exit code: %s) for instance %s. 
" + "Continuing due to fail_on_nonzero_exit=False", + self.command_id, + status, + response_code, + instance_id, + ) + continue + else: + # Traditional mode: all failures raise + raise yield TriggerEvent({"status": "success", self.return_key: self.return_value}) diff --git a/providers/amazon/tests/system/amazon/aws/example_ssm.py b/providers/amazon/tests/system/amazon/aws/example_ssm.py index 6357747ba8212..82304a4c0c8b2 100644 --- a/providers/amazon/tests/system/amazon/aws/example_ssm.py +++ b/providers/amazon/tests/system/amazon/aws/example_ssm.py @@ -223,6 +223,79 @@ def wait_until_ssm_ready(instance_id: str, max_attempts: int = 10, delay_seconds ) # [END howto_operator_get_command_invocation] + # [START howto_operator_ssm_enhanced_async] + run_command_async = SsmRunCommandOperator( + task_id="run_command_async", + document_name="AWS-RunShellScript", + run_command_kwargs={ + "InstanceIds": [instance_id], + "Parameters": {"commands": ["echo 'Testing async pattern'", "exit 1"]}, + }, + wait_for_completion=False, + fail_on_nonzero_exit=False, + ) + + wait_command_async = SsmRunCommandCompletedSensor( + task_id="wait_command_async", + command_id="{{ ti.xcom_pull(task_ids='run_command_async') }}", + fail_on_nonzero_exit=False, + ) + # [END howto_operator_ssm_enhanced_async] + + # [START howto_operator_ssm_enhanced_sync] + run_command_sync = SsmRunCommandOperator( + task_id="run_command_sync", + document_name="AWS-RunShellScript", + run_command_kwargs={ + "InstanceIds": [instance_id], + "Parameters": {"commands": ["echo 'Testing sync pattern'", "exit 2"]}, + }, + wait_for_completion=True, + fail_on_nonzero_exit=False, + ) + # [END howto_operator_ssm_enhanced_sync] + + # [START howto_operator_ssm_exit_code_routing] + get_exit_code_output = SsmGetCommandInvocationOperator( + task_id="get_exit_code_output", + command_id="{{ ti.xcom_pull(task_ids='run_command_async') }}", + instance_id=instance_id, + ) + + @task + def route_based_on_exit_code(**context): + output = context["ti"].xcom_pull(task_ids="get_exit_code_output") + exit_code = output.get("response_code") if output else None + log.info("Command exit code: %s", exit_code) + return "handle_exit_code" + + route_task = route_based_on_exit_code() + + @task(task_id="handle_exit_code") + def handle_exit_code(): + log.info("Handling exit code routing") + return "exit_code_handled" + + handle_task = handle_exit_code() + # [END howto_operator_ssm_exit_code_routing] + + # [START howto_operator_ssm_traditional] + run_command_traditional = SsmRunCommandOperator( + task_id="run_command_traditional", + document_name="AWS-RunShellScript", + run_command_kwargs={ + "InstanceIds": [instance_id], + "Parameters": {"commands": ["echo 'Testing traditional pattern'", "exit 0"]}, + }, + wait_for_completion=False, + ) + + wait_command_traditional = SsmRunCommandCompletedSensor( + task_id="wait_command_traditional", + command_id="{{ ti.xcom_pull(task_ids='run_command_traditional') }}", + ) + # [END howto_operator_ssm_traditional] + delete_instance = EC2TerminateInstanceOperator( task_id="terminate_instance", trigger_rule=TriggerRule.ALL_DONE, @@ -244,11 +317,18 @@ def wait_until_ssm_ready(instance_id: str, max_attempts: int = 10, delay_seconds run_command, await_run_command, get_command_output, - # TEST TEARDOWN - delete_instance, - delete_instance_profile(instance_profile_name, role_name), ) + # Exit code handling examples (run in parallel) + wait_until_ssm_ready(instance_id) >> run_command_async >> wait_command_async >> get_exit_code_output + get_exit_code_output 
>> route_task >> handle_task + wait_until_ssm_ready(instance_id) >> run_command_sync + wait_until_ssm_ready(instance_id) >> run_command_traditional >> wait_command_traditional + + # TEST TEARDOWN + [get_command_output, handle_task, run_command_sync, wait_command_traditional] >> delete_instance + delete_instance >> delete_instance_profile(instance_profile_name, role_name) + from tests_common.test_utils.watcher import watcher # This test needs watcher in order to properly mark success/failure diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_ssm.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_ssm.py index 5e6320e695697..6c51b5c563dd1 100644 --- a/providers/amazon/tests/unit/amazon/aws/hooks/test_ssm.py +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_ssm.py @@ -137,3 +137,26 @@ def test_list_command_invocations_empty_response(self, mock_conn): mock_conn.return_value.list_command_invocations.assert_called_once_with(CommandId=command_id) assert result == expected_response + + @pytest.mark.parametrize( + ("status", "expected_result"), + [ + pytest.param("Cancelled", True, id="cancelled_is_aws_level"), + pytest.param("TimedOut", True, id="timedout_is_aws_level"), + pytest.param("Cancelling", True, id="cancelling_is_aws_level"), + pytest.param("Failed", False, id="failed_is_command_level"), + pytest.param("Success", False, id="success_is_not_failure"), + pytest.param("Pending", False, id="pending_is_not_failure"), + pytest.param("InProgress", False, id="inprogress_is_not_failure"), + pytest.param("Delayed", False, id="delayed_is_not_failure"), + ], + ) + def test_is_aws_level_failure(self, status, expected_result): + """ + Test that is_aws_level_failure correctly identifies AWS-level failures. + + AWS-level failures (Cancelled, TimedOut, Cancelling) represent service-level issues + that should always raise exceptions, while command-level failures (Failed) and + other statuses should not be considered AWS-level failures. + """ + assert SsmHook.is_aws_level_failure(status) == expected_result diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_ssm.py b/providers/amazon/tests/unit/amazon/aws/operators/test_ssm.py index 10e91d9eea01b..0dacf4de95da8 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_ssm.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_ssm.py @@ -20,6 +20,7 @@ from unittest import mock import pytest +from botocore.exceptions import WaiterError from airflow.providers.amazon.aws.hooks.ssm import SsmHook from airflow.providers.amazon.aws.operators.ssm import SsmGetCommandInvocationOperator, SsmRunCommandOperator @@ -98,6 +99,144 @@ def test_deferrable_with_region(self, mock_conn): assert trigger.botocore_config == {"retries": {"max_attempts": 5}} assert trigger.aws_conn_id == self.operator.aws_conn_id + def test_operator_default_fails_on_nonzero_exit(self, mock_conn): + """ + Test traditional mode where fail_on_nonzero_exit=True (default). + + Verifies that when fail_on_nonzero_exit is True (the default), the operator + raises an exception when the waiter encounters a command failure. 
+ """ + self.operator.wait_for_completion = True + + # Mock waiter to raise WaiterError (simulating command failure) + mock_waiter = mock.MagicMock() + mock_waiter.wait.side_effect = WaiterError( + name="command_executed", + reason="Waiter encountered a terminal failure state", + last_response={"Status": "Failed"}, + ) + + with mock.patch.object(SsmHook, "get_waiter", return_value=mock_waiter): + # Should raise WaiterError in traditional mode + with pytest.raises(WaiterError): + self.operator.execute({}) + + def test_operator_enhanced_mode_tolerates_failed_status(self, mock_conn): + """ + Test enhanced mode where fail_on_nonzero_exit=False tolerates Failed status. + + Verifies that when fail_on_nonzero_exit is False, the operator completes + successfully even when the command returns a Failed status with non-zero exit code. + """ + self.operator.wait_for_completion = True + self.operator.fail_on_nonzero_exit = False + + # Mock waiter to raise WaiterError + mock_waiter = mock.MagicMock() + mock_waiter.wait.side_effect = WaiterError( + name="command_executed", + reason="Waiter encountered a terminal failure state", + last_response={"Status": "Failed"}, + ) + + # Mock get_command_invocation to return Failed status with exit code + with ( + mock.patch.object(SsmHook, "get_waiter", return_value=mock_waiter), + mock.patch.object( + SsmHook, "get_command_invocation", return_value={"Status": "Failed", "ResponseCode": 1} + ), + ): + # Should NOT raise in enhanced mode for Failed status + command_id = self.operator.execute({}) + assert command_id == COMMAND_ID + + def test_operator_enhanced_mode_fails_on_timeout(self, mock_conn): + """ + Test enhanced mode still fails on TimedOut status. + + Verifies that even when fail_on_nonzero_exit is False, the operator + still raises an exception for AWS-level failures like TimedOut. + """ + self.operator.wait_for_completion = True + self.operator.fail_on_nonzero_exit = False + + # Mock waiter to raise WaiterError + mock_waiter = mock.MagicMock() + mock_waiter.wait.side_effect = WaiterError( + name="command_executed", + reason="Waiter encountered a terminal failure state", + last_response={"Status": "TimedOut"}, + ) + + # Mock get_command_invocation to return TimedOut status + with ( + mock.patch.object(SsmHook, "get_waiter", return_value=mock_waiter), + mock.patch.object( + SsmHook, "get_command_invocation", return_value={"Status": "TimedOut", "ResponseCode": -1} + ), + ): + # Should raise even in enhanced mode for TimedOut + with pytest.raises(WaiterError): + self.operator.execute({}) + + def test_operator_enhanced_mode_fails_on_cancelled(self, mock_conn): + """ + Test enhanced mode still fails on Cancelled status. + + Verifies that even when fail_on_nonzero_exit is False, the operator + still raises an exception for AWS-level failures like Cancelled. 
+ """ + self.operator.wait_for_completion = True + self.operator.fail_on_nonzero_exit = False + + # Mock waiter to raise WaiterError + mock_waiter = mock.MagicMock() + mock_waiter.wait.side_effect = WaiterError( + name="command_executed", + reason="Waiter encountered a terminal failure state", + last_response={"Status": "Cancelled"}, + ) + + # Mock get_command_invocation to return Cancelled status + with ( + mock.patch.object(SsmHook, "get_waiter", return_value=mock_waiter), + mock.patch.object( + SsmHook, "get_command_invocation", return_value={"Status": "Cancelled", "ResponseCode": -1} + ), + ): + # Should raise even in enhanced mode for Cancelled + with pytest.raises(WaiterError): + self.operator.execute({}) + + @mock.patch("airflow.providers.amazon.aws.operators.ssm.SsmRunCommandTrigger") + def test_operator_passes_parameter_to_trigger(self, mock_trigger_class, mock_conn): + """ + Test that fail_on_nonzero_exit parameter is passed to trigger in deferrable mode. + + Verifies that when using deferrable mode, the fail_on_nonzero_exit parameter + is correctly passed to the SsmRunCommandTrigger. + """ + self.operator.deferrable = True + self.operator.fail_on_nonzero_exit = False + + with mock.patch.object(self.operator, "defer") as mock_defer: + command_id = self.operator.execute({}) + + assert command_id == COMMAND_ID + mock_conn.send_command.assert_called_once_with( + DocumentName=DOCUMENT_NAME, InstanceIds=INSTANCE_IDS + ) + + # Verify defer was called + mock_defer.assert_called_once() + + # Verify the trigger was instantiated with correct parameters + mock_trigger_class.assert_called_once() + call_kwargs = mock_trigger_class.call_args[1] + + assert call_kwargs["command_id"] == COMMAND_ID + assert call_kwargs["fail_on_nonzero_exit"] is False + class TestSsmGetCommandInvocationOperator: @pytest.fixture @@ -256,3 +395,104 @@ def test_execute_all_instances_with_error(self, mock_hook): def test_template_fields(self): validate_template_fields(self.operator) + + def test_exit_code_routing_use_case(self, mock_hook): + """ + Test that demonstrates the exit code routing use case. + + This test verifies that SsmGetCommandInvocationOperator correctly retrieves + exit codes and status information that can be used for workflow routing, + particularly when used with SsmRunCommandOperator in enhanced mode + (fail_on_nonzero_exit=False). 
+ """ + # Mock response with various exit codes that might be used for routing + mock_invocation_details = { + "Status": "Failed", # Command failed but we want to route based on exit code + "ResponseCode": 42, # Custom exit code for specific routing logic + "StandardOutputContent": "Partial success - some items processed", + "StandardErrorContent": "Warning: 3 items skipped", + "ExecutionStartDateTime": "2023-01-01T12:00:00Z", + "ExecutionEndDateTime": "2023-01-01T12:00:05Z", + "DocumentName": "AWS-RunShellScript", + "Comment": "Data processing script", + } + mock_hook.get_command_invocation.return_value = mock_invocation_details + + result = self.operator.execute({}) + + # Verify that response_code is available for routing decisions + assert result["invocations"][0]["response_code"] == 42 + assert result["invocations"][0]["status"] == "Failed" + + # Verify that output is available for additional context + assert "Partial success" in result["invocations"][0]["standard_output"] + assert "Warning" in result["invocations"][0]["standard_error"] + + # This demonstrates that the operator provides all necessary information + # for downstream tasks to make routing decisions based on exit codes, + # which is the key use case for the enhanced mode feature. + + def test_multiple_exit_codes_for_routing(self, mock_hook): + """ + Test retrieving multiple instances with different exit codes for routing. + + This demonstrates a common pattern where a command runs on multiple instances + and downstream tasks need to route based on the exit codes from each instance. + """ + operator = SsmGetCommandInvocationOperator( + task_id="test_multi_instance_routing", + command_id=self.command_id, + ) + + # Mock list_command_invocations response + mock_invocations = [ + {"InstanceId": "i-success"}, + {"InstanceId": "i-partial"}, + {"InstanceId": "i-failed"}, + ] + mock_hook.list_command_invocations.return_value = {"CommandInvocations": mock_invocations} + + # Mock different exit codes for routing scenarios + mock_hook.get_command_invocation.side_effect = [ + { + "Status": "Success", + "ResponseCode": 0, # Complete success + "StandardOutputContent": "All items processed", + "StandardErrorContent": "", + "ExecutionStartDateTime": "2023-01-01T12:00:00Z", + "ExecutionEndDateTime": "2023-01-01T12:00:05Z", + "DocumentName": "AWS-RunShellScript", + "Comment": "", + }, + { + "Status": "Failed", + "ResponseCode": 2, # Partial success - custom exit code + "StandardOutputContent": "Some items processed", + "StandardErrorContent": "Warning: partial completion", + "ExecutionStartDateTime": "2023-01-01T12:00:00Z", + "ExecutionEndDateTime": "2023-01-01T12:00:10Z", + "DocumentName": "AWS-RunShellScript", + "Comment": "", + }, + { + "Status": "Failed", + "ResponseCode": 1, # Complete failure + "StandardOutputContent": "", + "StandardErrorContent": "Error: operation failed", + "ExecutionStartDateTime": "2023-01-01T12:00:00Z", + "ExecutionEndDateTime": "2023-01-01T12:00:08Z", + "DocumentName": "AWS-RunShellScript", + "Comment": "", + }, + ] + + result = operator.execute({}) + + # Verify all exit codes are captured for routing logic + assert len(result["invocations"]) == 3 + assert result["invocations"][0]["response_code"] == 0 + assert result["invocations"][1]["response_code"] == 2 + assert result["invocations"][2]["response_code"] == 1 + + # This demonstrates that the operator can retrieve exit codes from multiple + # instances, enabling complex routing logic based on the results from each instance. 
diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_ssm.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_ssm.py index 6679cb1dc5669..4b714d1e90dda 100644 --- a/providers/amazon/tests/unit/amazon/aws/sensors/test_ssm.py +++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_ssm.py @@ -92,3 +92,55 @@ def test_poke_failure_states(self, mock_conn, state, mock_ssm_list_invocations): mock_ssm_list_invocations(mock_conn, state) with pytest.raises(RuntimeError, match=self.SENSOR.FAILURE_MESSAGE): self.sensor.poke({}) + + @mock.patch.object(SsmHook, "conn") + def test_sensor_default_fails_on_failed_status(self, mock_conn, mock_ssm_list_invocations): + """Test that sensor fails on Failed status in traditional mode (fail_on_nonzero_exit=True).""" + mock_ssm_list_invocations(mock_conn, "Failed") + self.sensor.hook.conn = mock_conn + with pytest.raises(RuntimeError, match=self.SENSOR.FAILURE_MESSAGE): + self.sensor.poke({}) + + @mock.patch.object(SsmHook, "conn") + def test_sensor_enhanced_mode_tolerates_failed_status(self, mock_conn, mock_ssm_list_invocations): + """Test that sensor tolerates Failed status in enhanced mode (fail_on_nonzero_exit=False).""" + sensor = self.SENSOR(**self.default_op_kwarg, fail_on_nonzero_exit=False) + mock_ssm_list_invocations(mock_conn, "Failed") + sensor.hook.conn = mock_conn + assert sensor.poke({}) is True + + @mock.patch.object(SsmHook, "conn") + def test_sensor_enhanced_mode_fails_on_timeout(self, mock_conn, mock_ssm_list_invocations): + """Test that sensor still fails on TimedOut status in enhanced mode.""" + sensor = self.SENSOR(**self.default_op_kwarg, fail_on_nonzero_exit=False) + mock_ssm_list_invocations(mock_conn, "TimedOut") + sensor.hook.conn = mock_conn + with pytest.raises(RuntimeError, match=f"SSM command {COMMAND_ID} TimedOut"): + sensor.poke({}) + + @mock.patch.object(SsmHook, "conn") + def test_sensor_enhanced_mode_fails_on_cancelled(self, mock_conn, mock_ssm_list_invocations): + """Test that sensor still fails on Cancelled status in enhanced mode.""" + sensor = self.SENSOR(**self.default_op_kwarg, fail_on_nonzero_exit=False) + mock_ssm_list_invocations(mock_conn, "Cancelled") + sensor.hook.conn = mock_conn + with pytest.raises(RuntimeError, match=f"SSM command {COMMAND_ID} Cancelled"): + sensor.poke({}) + + @mock.patch("airflow.providers.amazon.aws.sensors.ssm.SsmRunCommandTrigger") + def test_sensor_passes_parameter_to_trigger(self, mock_trigger_class): + """Test that fail_on_nonzero_exit parameter is passed correctly to trigger in deferrable mode.""" + sensor = self.SENSOR(**self.default_op_kwarg, fail_on_nonzero_exit=False, deferrable=True) + + with mock.patch.object(sensor, "defer") as mock_defer: + sensor.execute({}) + + # Verify defer was called + assert mock_defer.called + + # Verify the trigger was instantiated with correct parameters + mock_trigger_class.assert_called_once() + call_kwargs = mock_trigger_class.call_args[1] + + assert call_kwargs["command_id"] == COMMAND_ID + assert call_kwargs["fail_on_nonzero_exit"] is False diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_ssm.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_ssm.py index 15b9ebb0d852b..ad328e41e374e 100644 --- a/providers/amazon/tests/unit/amazon/aws/triggers/test_ssm.py +++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_ssm.py @@ -115,3 +115,92 @@ async def test_run_fails(self, mock_get_waiter, mock_get_async_conn, mock_ssm_li with pytest.raises(AirflowException): await generator.asend(None) + + @pytest.mark.asyncio 
+ @mock.patch("airflow.providers.amazon.aws.triggers.ssm.async_wait") + @mock.patch.object(SsmHook, "get_async_conn") + @mock.patch.object(SsmHook, "get_waiter") + async def test_trigger_default_fails_on_waiter_error( + self, mock_get_waiter, mock_get_async_conn, mock_async_wait, mock_ssm_list_invocations + ): + """Test traditional mode (fail_on_nonzero_exit=True) raises exception on waiter error.""" + mock_ssm_list_invocations(mock_get_async_conn) + mock_async_wait.side_effect = AirflowException("SSM run command failed.") + + trigger = SsmRunCommandTrigger(command_id=COMMAND_ID, fail_on_nonzero_exit=True) + generator = trigger.run() + + with pytest.raises(AirflowException): + await generator.asend(None) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.triggers.ssm.async_wait") + @mock.patch.object(SsmHook, "get_async_conn") + @mock.patch.object(SsmHook, "get_waiter") + async def test_trigger_enhanced_mode_tolerates_failed_status( + self, mock_get_waiter, mock_get_async_conn, mock_async_wait, mock_ssm_list_invocations + ): + """Test enhanced mode (fail_on_nonzero_exit=False) tolerates Failed status.""" + mock_client = mock_ssm_list_invocations(mock_get_async_conn) + # Mock async_wait to raise exception (simulating waiter failure) + mock_async_wait.side_effect = AirflowException("SSM run command failed.") + # Mock get_command_invocation to return Failed status + mock_client.get_command_invocation = mock.AsyncMock( + return_value={"Status": "Failed", "ResponseCode": 1} + ) + + trigger = SsmRunCommandTrigger(command_id=COMMAND_ID, fail_on_nonzero_exit=False) + generator = trigger.run() + response = await generator.asend(None) + + assert response == TriggerEvent({"status": "success", "command_id": COMMAND_ID}) + # Verify get_command_invocation was called for both instances + assert mock_client.get_command_invocation.call_count == 2 + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.triggers.ssm.async_wait") + @mock.patch.object(SsmHook, "get_async_conn") + @mock.patch.object(SsmHook, "get_waiter") + async def test_trigger_enhanced_mode_fails_on_aws_errors( + self, mock_get_waiter, mock_get_async_conn, mock_async_wait, mock_ssm_list_invocations + ): + """Test enhanced mode (fail_on_nonzero_exit=False) still fails on AWS-level errors.""" + mock_client = mock_ssm_list_invocations(mock_get_async_conn) + # Mock async_wait to raise exception (simulating waiter failure) + mock_async_wait.side_effect = AirflowException("SSM run command failed.") + # Mock get_command_invocation to return TimedOut status (AWS-level failure) + mock_client.get_command_invocation = mock.AsyncMock( + return_value={"Status": "TimedOut", "ResponseCode": -1} + ) + + trigger = SsmRunCommandTrigger(command_id=COMMAND_ID, fail_on_nonzero_exit=False) + generator = trigger.run() + + with pytest.raises(AirflowException): + await generator.asend(None) + + # Test with Cancelled status as well + mock_client.get_command_invocation = mock.AsyncMock( + return_value={"Status": "Cancelled", "ResponseCode": -1} + ) + + trigger = SsmRunCommandTrigger(command_id=COMMAND_ID, fail_on_nonzero_exit=False) + generator = trigger.run() + + with pytest.raises(AirflowException): + await generator.asend(None) + + def test_trigger_serialization_includes_parameter(self): + """Test that fail_on_nonzero_exit parameter is properly serialized.""" + trigger = SsmRunCommandTrigger(command_id=COMMAND_ID, fail_on_nonzero_exit=False) + classpath, kwargs = trigger.serialize() + + assert classpath == BASE_TRIGGER_CLASSPATH + 
"SsmRunCommandTrigger" + assert kwargs.get("command_id") == COMMAND_ID + assert kwargs.get("fail_on_nonzero_exit") is False + + # Test with default value (True) + trigger_default = SsmRunCommandTrigger(command_id=COMMAND_ID) + classpath, kwargs = trigger_default.serialize() + + assert kwargs.get("fail_on_nonzero_exit") is True