diff --git a/providers/amazon/docs/operators/ssm.rst b/providers/amazon/docs/operators/ssm.rst index 197bd34b4fafc..77e17dc32204e 100644 --- a/providers/amazon/docs/operators/ssm.rst +++ b/providers/amazon/docs/operators/ssm.rst @@ -58,6 +58,43 @@ or the :class:`~airflow.providers.amazon.aws.triggers.ssm.SsmRunCommandTrigger` :dedent: 4 :start-after: [START howto_operator_run_command] :end-before: [END howto_operator_run_command] +.. _howto/operator:SsmGetCommandInvocationOperator: + +Retrieve output from an SSM command invocation +============================================== + +To retrieve the output and execution details from an SSM command that has been executed, you can use +:class:`~airflow.providers.amazon.aws.operators.ssm.SsmGetCommandInvocationOperator`. + +This operator is useful for: + +* Retrieving output from commands executed by :class:`~airflow.providers.amazon.aws.operators.ssm.SsmRunCommandOperator` in previous tasks +* Getting output from SSM commands executed outside of Airflow +* Inspecting command results for debugging or data processing purposes + +To retrieve output from all instances that executed a command: + +.. code-block:: python + + get_all_output = SsmGetCommandInvocationOperator( + task_id="get_command_output", + command_id='{{ ti.xcom_pull(task_ids="run_command") }}', # From previous task + ) + +To retrieve output from a specific instance: + +.. 
exampleinclude:: /../../amazon/tests/system/amazon/aws/example_ssm.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_get_command_invocation] + :end-before: [END howto_operator_get_command_invocation] + +The operator returns structured data including: + +* Standard output and error content +* Execution status and response codes +* Execution start and end times +* Document name and comments Sensors ------- @@ -79,7 +116,7 @@ To wait on the state of an Amazon SSM run command job until it reaches a termina IAM Permissions --------------- -You need to ensure the following IAM permissions are granted to allow Airflow to run and monitor SSM Run Command executions: +You need to ensure the following IAM permissions are granted to allow Airflow to run, retrieve and monitor SSM Run Command executions: .. code-block:: diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/ssm.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/ssm.py index d40b88be6417b..7bc2c8214264d 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/ssm.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/ssm.py @@ -38,7 +38,8 @@ class SsmHook(AwsBaseHook): """ Interact with Amazon Systems Manager (SSM). - Provide thin wrapper around :external+boto3:py:class:`boto3.client("ssm") `. + Provide thin wrapper around + :external+boto3:py:class:`boto3.client("ssm") `. Additional arguments (such as ``aws_conn_id``) may be specified and are passed down to the underlying AwsBaseHook. @@ -53,7 +54,9 @@ def __init__(self, *args, **kwargs) -> None: def get_parameter_value(self, parameter: str, default: str | ArgNotSet = NOTSET) -> str: """ - Return the provided Parameter or an optional default; if it is encrypted, then decrypt and mask. + Return the provided Parameter or an optional default. + + If it is encrypted, then decrypt and mask. .. 
seealso:: - :external+boto3:py:meth:`SSM.Client.get_parameter` @@ -71,3 +74,28 @@ def get_parameter_value(self, parameter: str, default: str | ArgNotSet = NOTSET) if isinstance(default, ArgNotSet): raise return default + + def get_command_invocation(self, command_id: str, instance_id: str) -> dict: + """ + Get the output of a command invocation for a specific instance. + + .. seealso:: + - :external+boto3:py:meth:`SSM.Client.get_command_invocation` + + :param command_id: The ID of the command. + :param instance_id: The ID of the instance. + :return: The command invocation details including output. + """ + return self.conn.get_command_invocation(CommandId=command_id, InstanceId=instance_id) + + def list_command_invocations(self, command_id: str) -> dict: + """ + List all command invocations for a given command ID. + + .. seealso:: + - :external+boto3:py:meth:`SSM.Client.list_command_invocations` + + :param command_id: The ID of the command. + :return: Response from SSM list_command_invocations API. + """ + return self.conn.list_command_invocations(CommandId=command_id) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/ssm.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/ssm.py index 52480ed09f068..7c2af24919a70 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/ssm.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/ssm.py @@ -20,7 +20,6 @@ from typing import TYPE_CHECKING, Any from airflow.configuration import conf -from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.ssm import SsmHook from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator from airflow.providers.amazon.aws.triggers.ssm import SsmRunCommandTrigger @@ -36,27 +35,35 @@ class SsmRunCommandOperator(AwsBaseOperator[SsmHook]): Executes the SSM Run Command to perform actions on managed instances. .. 
seealso:: - For more information on how to use this operator, take a look at the guide: + For more information on how to use this operator, take a look at the + guide: :ref:`howto/operator:SsmRunCommandOperator` - :param document_name: The name of the Amazon Web Services Systems Manager document (SSM document) to run. - :param run_command_kwargs: Optional parameters to pass to the send_command API. - - :param wait_for_completion: Whether to wait for cluster to stop. (default: True) - :param waiter_delay: Time in seconds to wait between status checks. (default: 120) - :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 75) - :param deferrable: If True, the operator will wait asynchronously for the cluster to stop. - This implies waiting for completion. This mode requires aiobotocore module to be installed. - (default: False) + :param document_name: The name of the Amazon Web Services Systems Manager + document (SSM document) to run. + :param run_command_kwargs: Optional parameters to pass to the send_command + API. + + :param wait_for_completion: Whether to wait for the command to complete. + (default: True) + :param waiter_delay: Time in seconds to wait between status checks. + (default: 120) + :param waiter_max_attempts: Maximum number of attempts to check for job + completion. (default: 75) + :param deferrable: If True, the operator will wait asynchronously for the + command to complete. This implies waiting for completion. This mode + requires aiobotocore module to be installed. (default: False) :param aws_conn_id: The Airflow connection used for AWS credentials. - If this is ``None`` or empty then the default boto3 behaviour is used. If - running Airflow in a distributed manner and aws_conn_id is None or + If this is ``None`` or empty then the default boto3 behaviour is used.
+ If running Airflow in a distributed manner and aws_conn_id is None or empty, then default boto3 configuration would be used (and must be maintained on each worker node). - :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param region_name: AWS region_name. If not specified then the default + boto3 behaviour is used. :param verify: Whether or not to verify SSL certificates. See: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html - :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + :param botocore_config: Configuration dictionary (key-values) for botocore + client. See: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ @@ -90,7 +97,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None event = validate_execute_complete_event(event) if event["status"] != "success": - raise AirflowException(f"Error while running run command: {event}") + raise RuntimeError(f"Error while running run command: {event}") self.log.info("SSM run command `%s` completed.", event["command_id"]) return event["command_id"] @@ -112,6 +119,9 @@ def execute(self, context: Context): waiter_delay=self.waiter_delay, waiter_max_attempts=self.waiter_max_attempts, aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + botocore_config=self.botocore_config, ), method_name="execute_complete", ) @@ -125,7 +135,102 @@ def execute(self, context: Context): waiter.wait( CommandId=command_id, InstanceId=instance_id, - WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, + WaiterConfig={ + "Delay": self.waiter_delay, + "MaxAttempts": self.waiter_max_attempts, + }, ) return command_id + + +class SsmGetCommandInvocationOperator(AwsBaseOperator[SsmHook]): + """ + Retrieves the output and execution details of an SSM command invocation. 
+ + This operator allows you to fetch the standard output, standard error, + execution status, and other details from SSM commands. It can be used to + retrieve output from commands executed by SsmRunCommandOperator in previous + tasks, or from commands executed outside of Airflow entirely. + + The operator returns structured data including stdout, stderr, execution + times, and status information for each instance that executed the command. + + .. seealso:: + For more information on how to use this operator, take a look at the + guide: + :ref:`howto/operator:SsmGetCommandInvocationOperator` + + :param command_id: The ID of the SSM command to retrieve output for. + :param instance_id: The ID of the specific instance to retrieve output + for. If not provided, retrieves output from all instances that + executed the command. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. + If running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default + boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore + client. 
See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + """ + + aws_hook_class = SsmHook + template_fields: Sequence[str] = aws_template_fields( + "command_id", + "instance_id", + ) + + def __init__( + self, + *, + command_id: str, + instance_id: str | None = None, + **kwargs, + ): + super().__init__(**kwargs) + self.command_id = command_id + self.instance_id = instance_id + + def execute(self, context: Context) -> dict[str, Any]: + """Execute the operator to retrieve command invocation output.""" + if self.instance_id: + self.log.info( + "Retrieving output for command %s on instance %s", + self.command_id, + self.instance_id, + ) + invocations = [{"InstanceId": self.instance_id}] + else: + self.log.info("Retrieving output for command %s from all instances", self.command_id) + response = self.hook.list_command_invocations(self.command_id) + invocations = response.get("CommandInvocations", []) + + output_data: dict[str, Any] = {"command_id": self.command_id, "invocations": []} + + for invocation in invocations: + instance_id = invocation["InstanceId"] + try: + invocation_details = self.hook.get_command_invocation(self.command_id, instance_id) + output_data["invocations"].append( + { + "instance_id": instance_id, + "status": invocation_details.get("Status", ""), + "response_code": invocation_details.get("ResponseCode", ""), + "standard_output": invocation_details.get("StandardOutputContent", ""), + "standard_error": invocation_details.get("StandardErrorContent", ""), + "execution_start_time": invocation_details.get("ExecutionStartDateTime", ""), + "execution_end_time": invocation_details.get("ExecutionEndDateTime", ""), + "document_name": invocation_details.get("DocumentName", ""), + "comment": invocation_details.get("Comment", ""), + } + ) + except Exception as e: + self.log.warning("Failed to get output for instance %s: %s", instance_id, e) + output_data["invocations"].append({"instance_id": instance_id, "error": str(e)}) + + 
return output_data diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/ssm.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/ssm.py index f2a6d4351b5a9..514c382067d53 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/ssm.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/ssm.py @@ -21,7 +21,6 @@ from typing import TYPE_CHECKING, Any from airflow.configuration import conf -from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.ssm import SsmHook from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor from airflow.providers.amazon.aws.triggers.ssm import SsmRunCommandTrigger @@ -34,32 +33,45 @@ class SsmRunCommandCompletedSensor(AwsBaseSensor[SsmHook]): """ - Poll the state of an AWS SSM Run Command until all instance jobs reach a terminal state. Fails if any instance job ends in a failed state. + Poll the state of an AWS SSM Run Command until completion. + + Waits until all instance jobs reach a terminal state. Fails if any + instance job ends in a failed state. .. seealso:: - For more information on how to use this sensor, take a look at the guide: + For more information on how to use this sensor, take a look at the + guide: :ref:`howto/sensor:SsmRunCommandCompletedSensor` :param command_id: The ID of the AWS SSM Run Command. - - :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore - module to be installed. - (default: False, but can be overridden in config file by setting default_deferrable to True) - :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120) - :param max_retries: Number of times before returning the current state. (default: 75) + :param deferrable: If True, the sensor will operate in deferrable mode. + This mode requires aiobotocore module to be installed. 
+ (default: False, but can be overridden in config file by setting + default_deferrable to True) + :param poke_interval: Polling period in seconds to check for the status + of the job. (default: 120) + :param max_retries: Number of times before returning the current state. + (default: 75) :param aws_conn_id: The Airflow connection used for AWS credentials. - If this is ``None`` or empty then the default boto3 behaviour is used. If - running Airflow in a distributed manner and aws_conn_id is None or + If this is ``None`` or empty then the default boto3 behaviour is used. + If running Airflow in a distributed manner and aws_conn_id is None or empty, then default boto3 configuration would be used (and must be maintained on each worker node). - :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param region_name: AWS region_name. If not specified then the default + boto3 behaviour is used. :param verify: Whether or not to verify SSL certificates. See: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html - :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + :param botocore_config: Configuration dictionary (key-values) for botocore + client. See: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - INTERMEDIATE_STATES: tuple[str, ...] = ("Pending", "Delayed", "InProgress", "Cancelling") + INTERMEDIATE_STATES: tuple[str, ...] = ( + "Pending", + "Delayed", + "InProgress", + "Cancelling", + ) FAILURE_STATES: tuple[str, ...] = ("Cancelled", "TimedOut", "Failed") SUCCESS_STATES: tuple[str, ...] = ("Success",) FAILURE_MESSAGE = "SSM run command sensor failed." 
@@ -89,14 +101,18 @@ def poke(self, context: Context): command_invocations = response.get("CommandInvocations", []) if not command_invocations: - self.log.info("No command invocations found for command_id=%s yet, waiting...", self.command_id) + self.log.info( + "No command invocations found for " + "command_id=%s yet, waiting...", + self.command_id, + ) return False for invocation in command_invocations: state = invocation["Status"] if state in self.FAILURE_STATES: - raise AirflowException(self.FAILURE_MESSAGE) + raise RuntimeError(self.FAILURE_MESSAGE) if state in self.INTERMEDIATE_STATES: return False @@ -122,6 +138,6 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None event = validate_execute_complete_event(event) if event["status"] != "success": - raise AirflowException(f"Error while running run command: {event}") + raise RuntimeError(f"Error while running run command: {event}") self.log.info("SSM run command `%s` completed.", event["command_id"]) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/ssm.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/ssm.py index 94d0697584a7f..caa0bb81aacbd 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/ssm.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/ssm.py @@ -36,6 +36,11 @@ class SsmRunCommandTrigger(AwsBaseWaiterTrigger): :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 120) :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75) :param aws_conn_id: The Airflow connection used for AWS credentials. + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client.
See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ def __init__( @@ -45,6 +50,9 @@ def __init__( waiter_delay: int = 120, waiter_max_attempts: int = 75, aws_conn_id: str | None = None, + region_name: str | None = None, + verify: bool | str | None = None, + botocore_config: dict | None = None, ) -> None: super().__init__( serialized_fields={"command_id": command_id}, @@ -58,11 +66,19 @@ def __init__( waiter_delay=waiter_delay, waiter_max_attempts=waiter_max_attempts, aws_conn_id=aws_conn_id, + region_name=region_name, + verify=verify, + botocore_config=botocore_config, ) self.command_id = command_id def hook(self) -> AwsGenericHook: - return SsmHook(aws_conn_id=self.aws_conn_id) + return SsmHook( + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, + ) async def run(self) -> AsyncIterator[TriggerEvent]: hook = self.hook() diff --git a/providers/amazon/tests/system/amazon/aws/example_ssm.py b/providers/amazon/tests/system/amazon/aws/example_ssm.py index 3af62dcb63ea4..1ac80ac9f5be6 100644 --- a/providers/amazon/tests/system/amazon/aws/example_ssm.py +++ b/providers/amazon/tests/system/amazon/aws/example_ssm.py @@ -24,7 +24,7 @@ import boto3 from airflow.providers.amazon.aws.operators.ec2 import EC2CreateInstanceOperator, EC2TerminateInstanceOperator -from airflow.providers.amazon.aws.operators.ssm import SsmRunCommandOperator +from airflow.providers.amazon.aws.operators.ssm import SsmGetCommandInvocationOperator, SsmRunCommandOperator from airflow.providers.amazon.aws.sensors.ssm import SsmRunCommandCompletedSensor from airflow.sdk import DAG, chain, task @@ -207,6 +207,14 @@ def wait_until_ssm_ready(instance_id: str, max_attempts: int = 10, delay_seconds ) # [END howto_sensor_run_command] + # [START howto_operator_get_command_invocation] + get_command_output = SsmGetCommandInvocationOperator( + task_id="get_command_output", + command_id="{{ 
ti.xcom_pull(task_ids='run_command') }}", + instance_id=instance_id, + ) + # [END howto_operator_get_command_invocation] + delete_instance = EC2TerminateInstanceOperator( task_id="terminate_instance", trigger_rule=TriggerRule.ALL_DONE, @@ -227,6 +235,7 @@ def wait_until_ssm_ready(instance_id: str, max_attempts: int = 10, delay_seconds # TEST BODY run_command, await_run_command, + get_command_output, # TEST TEARDOWN delete_instance, delete_instance_profile(instance_profile_name, role_name), diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_ssm.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_ssm.py index a083f36b9defa..1ce6a3ece7217 100644 --- a/providers/amazon/tests/unit/amazon/aws/hooks/test_ssm.py +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_ssm.py @@ -87,3 +87,53 @@ def test_get_parameter_value_param_does_not_exist_no_default_provided(self) -> N error = raised_exception.value.response["Error"] assert error["Code"] == "ParameterNotFound" assert BAD_PARAM_NAME in error["Message"] + + @mock.patch("airflow.providers.amazon.aws.hooks.ssm.SsmHook.conn", new_callable=mock.PropertyMock) + def test_get_command_invocation(self, mock_conn): + command_id = "12345678-1234-1234-1234-123456789012" + instance_id = "i-1234567890abcdef0" + expected_response = { + "CommandId": command_id, + "InstanceId": instance_id, + "Status": "Success", + "ResponseCode": 0, + "StandardOutputContent": "Hello World", + "StandardErrorContent": "", + } + + mock_conn.return_value.get_command_invocation.return_value = expected_response + + result = self.hook.get_command_invocation(command_id, instance_id) + + mock_conn.return_value.get_command_invocation.assert_called_once_with( + CommandId=command_id, InstanceId=instance_id + ) + assert result == expected_response + + @mock.patch("airflow.providers.amazon.aws.hooks.ssm.SsmHook.conn", new_callable=mock.PropertyMock) + def test_list_command_invocations(self, mock_conn): + command_id = 
"12345678-1234-1234-1234-123456789012" + expected_invocations = [ + {"InstanceId": "i-111", "Status": "Success"}, + {"InstanceId": "i-222", "Status": "Failed"}, + ] + expected_response = {"CommandInvocations": expected_invocations} + + mock_conn.return_value.list_command_invocations.return_value = expected_response + + result = self.hook.list_command_invocations(command_id) + + mock_conn.return_value.list_command_invocations.assert_called_once_with(CommandId=command_id) + assert result == expected_response + + @mock.patch("airflow.providers.amazon.aws.hooks.ssm.SsmHook.conn", new_callable=mock.PropertyMock) + def test_list_command_invocations_empty_response(self, mock_conn): + command_id = "12345678-1234-1234-1234-123456789012" + expected_response = {} # No CommandInvocations key + + mock_conn.return_value.list_command_invocations.return_value = expected_response + + result = self.hook.list_command_invocations(command_id) + + mock_conn.return_value.list_command_invocations.assert_called_once_with(CommandId=command_id) + assert result == expected_response diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_ssm.py b/providers/amazon/tests/unit/amazon/aws/operators/test_ssm.py index 2714cb45313d6..f0027da110f77 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_ssm.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_ssm.py @@ -22,7 +22,7 @@ import pytest from airflow.providers.amazon.aws.hooks.ssm import SsmHook -from airflow.providers.amazon.aws.operators.ssm import SsmRunCommandOperator +from airflow.providers.amazon.aws.operators.ssm import SsmGetCommandInvocationOperator, SsmRunCommandOperator from unit.amazon.aws.utils.test_template_fields import validate_template_fields @@ -73,3 +73,186 @@ def test_run_command_wait_combinations(self, mock_get_waiter, wait_for_completio def test_template_fields(self): validate_template_fields(self.operator) + + def test_deferrable_with_region(self, mock_conn): + """Test that deferrable 
mode properly passes region and other AWS parameters to trigger.""" + self.operator.deferrable = True + self.operator.region_name = "us-west-2" + self.operator.verify = False + self.operator.botocore_config = {"retries": {"max_attempts": 5}} + + command_id = self.operator.execute({}) + + assert command_id == COMMAND_ID + mock_conn.send_command.assert_called_once_with(DocumentName=DOCUMENT_NAME, InstanceIds=INSTANCE_IDS) + + # Verify defer was called with correct trigger parameters + self.operator.defer.assert_called_once() + call_args = self.operator.defer.call_args + trigger = call_args[1]["trigger"] # Get the trigger from kwargs + + # Verify the trigger has the correct parameters + assert trigger.command_id == COMMAND_ID + assert trigger.region_name == "us-west-2" + assert trigger.verify is False + assert trigger.botocore_config == {"retries": {"max_attempts": 5}} + assert trigger.aws_conn_id == self.operator.aws_conn_id + + +class TestSsmGetCommandInvocationOperator: + @pytest.fixture + def mock_hook(self) -> Generator[mock.MagicMock, None, None]: + with mock.patch.object(SsmGetCommandInvocationOperator, "hook") as _hook: + yield _hook + + def setup_method(self): + self.command_id = "test-command-id-123" + self.instance_id = "i-1234567890abcdef0" + self.operator = SsmGetCommandInvocationOperator( + task_id="test_get_command_invocation", + command_id=self.command_id, + instance_id=self.instance_id, + ) + + def test_execute_with_specific_instance(self, mock_hook): + # Mock response for specific instance + mock_invocation_details = { + "Status": "Success", + "ResponseCode": 0, + "StandardOutputContent": "Hello World", + "StandardErrorContent": "", + "ExecutionStartDateTime": "2023-01-01T12:00:00Z", + "ExecutionEndDateTime": "2023-01-01T12:00:05Z", + "DocumentName": "AWS-RunShellScript", + "Comment": "Test command", + } + mock_hook.get_command_invocation.return_value = mock_invocation_details + + result = self.operator.execute({}) + + # Verify hook was called 
correctly + mock_hook.get_command_invocation.assert_called_once_with(self.command_id, self.instance_id) + + # Verify returned data structure - should use standardized format with invocations array + expected_result = { + "command_id": self.command_id, + "invocations": [ + { + "instance_id": self.instance_id, + "status": "Success", + "response_code": 0, + "standard_output": "Hello World", + "standard_error": "", + "execution_start_time": "2023-01-01T12:00:00Z", + "execution_end_time": "2023-01-01T12:00:05Z", + "document_name": "AWS-RunShellScript", + "comment": "Test command", + } + ], + } + assert result == expected_result + + def test_execute_all_instances(self, mock_hook): + # Setup operator without instance_id to get all instances + operator = SsmGetCommandInvocationOperator( + task_id="test_get_all_invocations", + command_id=self.command_id, + ) + + # Mock list_command_invocations response + mock_invocations = [ + {"InstanceId": "i-111"}, + {"InstanceId": "i-222"}, + ] + mock_hook.list_command_invocations.return_value = {"CommandInvocations": mock_invocations} + + # Mock get_command_invocation responses + mock_invocation_details_1 = { + "Status": "Success", + "ResponseCode": 0, + "StandardOutputContent": "Output 1", + "StandardErrorContent": "", + "ExecutionStartDateTime": "2023-01-01T12:00:00Z", + "ExecutionEndDateTime": "2023-01-01T12:00:05Z", + "DocumentName": "AWS-RunShellScript", + "Comment": "", + } + mock_invocation_details_2 = { + "Status": "Failed", + "ResponseCode": 1, + "StandardOutputContent": "", + "StandardErrorContent": "Error occurred", + "ExecutionStartDateTime": "2023-01-01T12:00:00Z", + "ExecutionEndDateTime": "2023-01-01T12:00:10Z", + "DocumentName": "AWS-RunShellScript", + "Comment": "", + } + + mock_hook.get_command_invocation.side_effect = [ + mock_invocation_details_1, + mock_invocation_details_2, + ] + + result = operator.execute({}) + + # Verify hook calls + mock_hook.list_command_invocations.assert_called_once_with(self.command_id) + 
assert mock_hook.get_command_invocation.call_count == 2 + mock_hook.get_command_invocation.assert_any_call(self.command_id, "i-111") + mock_hook.get_command_invocation.assert_any_call(self.command_id, "i-222") + + # Verify returned data structure + expected_result = { + "command_id": self.command_id, + "invocations": [ + { + "instance_id": "i-111", + "status": "Success", + "response_code": 0, + "standard_output": "Output 1", + "standard_error": "", + "execution_start_time": "2023-01-01T12:00:00Z", + "execution_end_time": "2023-01-01T12:00:05Z", + "document_name": "AWS-RunShellScript", + "comment": "", + }, + { + "instance_id": "i-222", + "status": "Failed", + "response_code": 1, + "standard_output": "", + "standard_error": "Error occurred", + "execution_start_time": "2023-01-01T12:00:00Z", + "execution_end_time": "2023-01-01T12:00:10Z", + "document_name": "AWS-RunShellScript", + "comment": "", + }, + ], + } + assert result == expected_result + + def test_execute_all_instances_with_error(self, mock_hook): + # Setup operator without instance_id + operator = SsmGetCommandInvocationOperator( + task_id="test_get_all_with_error", + command_id=self.command_id, + ) + + # Mock list_command_invocations response + mock_invocations = [{"InstanceId": "i-111"}] + mock_hook.list_command_invocations.return_value = {"CommandInvocations": mock_invocations} + + # Mock get_command_invocation to raise an exception + mock_hook.get_command_invocation.side_effect = Exception("API Error") + + result = operator.execute({}) + + # Verify error handling + expected_result = { + "command_id": self.command_id, + "invocations": [{"instance_id": "i-111", "error": "API Error"}], + } + assert result == expected_result + + def test_template_fields(self): + validate_template_fields(self.operator) diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_ssm.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_ssm.py index faf91f379b236..6679cb1dc5669 100644 --- 
a/providers/amazon/tests/unit/amazon/aws/sensors/test_ssm.py +++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_ssm.py @@ -21,7 +21,6 @@ import pytest -from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.ssm import SsmHook from airflow.providers.amazon.aws.sensors.ssm import SsmRunCommandCompletedSensor @@ -91,5 +90,5 @@ def test_poke_intermediate_states(self, mock_conn, state, mock_ssm_list_invocati @mock.patch.object(SsmHook, "conn") def test_poke_failure_states(self, mock_conn, state, mock_ssm_list_invocations): mock_ssm_list_invocations(mock_conn, state) - with pytest.raises(AirflowException, match=self.SENSOR.FAILURE_MESSAGE): + with pytest.raises(RuntimeError, match=self.SENSOR.FAILURE_MESSAGE): self.sensor.poke({}) diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_ssm.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_ssm.py index d19acc0de1b62..f5b75cf302e53 100644 --- a/providers/amazon/tests/unit/amazon/aws/triggers/test_ssm.py +++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_ssm.py @@ -61,6 +61,24 @@ def test_serialization(self): assert classpath == BASE_TRIGGER_CLASSPATH + "SsmRunCommandTrigger" assert kwargs.get("command_id") == COMMAND_ID + def test_serialization_with_region(self): + """Test that region_name and other AWS parameters are properly serialized.""" + trigger = SsmRunCommandTrigger( + command_id=COMMAND_ID, + region_name="us-east-1", + aws_conn_id="test_conn", + verify=True, + botocore_config={"retries": {"max_attempts": 3}}, + ) + classpath, kwargs = trigger.serialize() + + assert classpath == BASE_TRIGGER_CLASSPATH + "SsmRunCommandTrigger" + assert kwargs.get("command_id") == COMMAND_ID + assert kwargs.get("region_name") == "us-east-1" + assert kwargs.get("aws_conn_id") == "test_conn" + assert kwargs.get("verify") is True + assert kwargs.get("botocore_config") == {"retries": {"max_attempts": 3}} + @pytest.mark.asyncio @mock.patch.object(SsmHook, "get_async_conn") 
@mock.patch.object(SsmHook, "get_waiter")