diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index a146fb417b2df..57fb667c69115 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1156,6 +1156,7 @@ mssql mTLS mtls muldelete +multi-cloud multimodal Multinamespace mutex diff --git a/providers/amazon/docs/operators/ssm.rst b/providers/amazon/docs/operators/ssm.rst new file mode 100644 index 0000000000000..197bd34b4fafc --- /dev/null +++ b/providers/amazon/docs/operators/ssm.rst @@ -0,0 +1,103 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements.  See the NOTICE file + distributed with this work for additional information + regarding copyright ownership.  The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License.  You may obtain a copy of the License at + + ..   http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied.  See the License for the + specific language governing permissions and limitations + under the License. + +=================================== +Amazon Simple Systems Manager (SSM) +=================================== + +`Amazon Simple Systems Manager (Amazon SSM) `__ is a service +that helps you centrally view, manage, and operate nodes at scale in AWS, on-premises, and multi-cloud +environments. Systems Manager consolidates various tools to help you complete common node tasks across AWS +accounts and Regions. +To use Systems Manager, nodes must be managed, which means that SSM Agent is installed on the machine and +that the agent can communicate with the Systems Manager service. + +Prerequisite Tasks +------------------ + +.. include:: ../_partials/prerequisite_tasks.rst + +Generic Parameters +------------------ + +.. include:: ../_partials/generic_parameters.rst + +Operators +--------- + +.. _howto/operator:SsmRunCommandOperator: + +Run commands on one or more managed nodes +========================================== + +To run a command on one or more managed nodes, you can use +:class:`~airflow.providers.amazon.aws.operators.ssm.SsmRunCommandOperator`. + +To monitor the state of the command on a specific instance, you can use the ``command_executed`` +waiter. Alternatively, you can track the status of the command execution with the +:class:`~airflow.providers.amazon.aws.sensors.ssm.SsmRunCommandCompletedSensor` sensor +or the :class:`~airflow.providers.amazon.aws.triggers.ssm.SsmRunCommandTrigger` trigger. + + +.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_ssm.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_run_command] + :end-before: [END howto_operator_run_command] + +Sensors +------- + +.. _howto/sensor:SsmRunCommandCompletedSensor: + +Wait for an Amazon SSM run command +================================== + +To wait on the state of an Amazon SSM run command job until it reaches a terminal state, you can use +:class:`~airflow.providers.amazon.aws.sensors.ssm.SsmRunCommandCompletedSensor`. + +.. 
exampleinclude:: /../../amazon/tests/system/amazon/aws/example_ssm.py + :language: python + :dedent: 4 + :start-after: [START howto_sensor_run_command] + :end-before: [END howto_sensor_run_command] + +IAM Permissions +--------------- + +You need to ensure the following IAM permissions are granted to allow Airflow to run and monitor SSM Run Command executions: + +.. code-block:: json + + { + "Effect": "Allow", + "Action": [ + "ssm:SendCommand", + "ssm:ListCommandInvocations", + "ssm:GetCommandInvocation" + ], + "Resource": "*" + } + +This policy allows access to all SSM documents and managed instances. For production environments, +it is recommended to restrict the ``Resource`` field to specific SSM document ARNs and, if applicable, +to the ARNs of intended target resources (such as EC2 instances), in accordance with the principle of least privilege. + +Reference +--------- + +* `AWS boto3 library documentation for Amazon SSM `__ diff --git a/providers/amazon/provider.yaml b/providers/amazon/provider.yaml index a184a4e996f38..757a7ab92a83a 100644 --- a/providers/amazon/provider.yaml +++ b/providers/amazon/provider.yaml @@ -281,6 +281,8 @@ integrations: - integration-name: Amazon Systems Manager (SSM) external-doc-url: https://aws.amazon.com/systems-manager/ logo: /docs/integration-logos/AWS-Systems-Manager_light-bg@4x.png + how-to-guide: + - /docs/apache-airflow-providers-amazon/operators/ssm.rst tags: [aws] - integration-name: Amazon Web Services external-doc-url: https://aws.amazon.com/ @@ -430,6 +432,9 @@ operators: - integration-name: AWS Step Functions python-modules: - airflow.providers.amazon.aws.operators.step_function + - integration-name: Amazon Systems Manager (SSM) + python-modules: + - airflow.providers.amazon.aws.operators.ssm - integration-name: Amazon RDS python-modules: - airflow.providers.amazon.aws.operators.rds @@ -531,6 +536,9 @@ sensors: - integration-name: AWS Step Functions python-modules: - airflow.providers.amazon.aws.sensors.step_function + - integration-name: Amazon Systems Manager (SSM) + python-modules: + - airflow.providers.amazon.aws.sensors.ssm - integration-name: Amazon QuickSight python-modules: - airflow.providers.amazon.aws.sensors.quicksight @@ -737,6 +745,9 @@ triggers: - integration-name: Amazon Simple Storage Service (S3) python-modules: - airflow.providers.amazon.aws.triggers.s3 + - integration-name: Amazon Systems Manager (SSM) + python-modules: + - airflow.providers.amazon.aws.triggers.ssm - integration-name: Amazon EMR python-modules: - airflow.providers.amazon.aws.triggers.emr diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/ssm.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/ssm.py new file mode 100644 index 0000000000000..52480ed09f068 --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/ssm.py @@ -0,0 +1,131 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.  See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership.  The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied.  See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +from airflow.configuration import conf +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.ssm import SsmHook +from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator +from airflow.providers.amazon.aws.triggers.ssm import SsmRunCommandTrigger +from airflow.providers.amazon.aws.utils import validate_execute_complete_event +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class SsmRunCommandOperator(AwsBaseOperator[SsmHook]): + """ + Executes the SSM Run Command to perform actions on managed instances. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SsmRunCommandOperator` + + :param document_name: The name of the Amazon Web Services Systems Manager document (SSM document) to run. + :param run_command_kwargs: Optional parameters to pass to the send_command API. + + :param wait_for_completion: Whether to wait for the command to complete. (default: True) + :param waiter_delay: Time in seconds to wait between status checks. (default: 120) + :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 75) + :param deferrable: If True, the operator will wait asynchronously for the command to complete. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. 
See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + """ + + aws_hook_class = SsmHook + template_fields: Sequence[str] = aws_template_fields( + "document_name", + "run_command_kwargs", + ) + + def __init__( + self, + *, + document_name: str, + run_command_kwargs: dict[str, Any] | None = None, + wait_for_completion: bool = True, + waiter_delay: int = 120, + waiter_max_attempts: int = 75, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + super().__init__(**kwargs) + self.wait_for_completion = wait_for_completion + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + self.deferrable = deferrable + + self.document_name = document_name + self.run_command_kwargs = run_command_kwargs or {} + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str: + event = validate_execute_complete_event(event) + + if event["status"] != "success": + raise AirflowException(f"Error while running SSM run command: {event}") + + self.log.info("SSM run command `%s` completed.", event["command_id"]) + return event["command_id"] + + def execute(self, context: Context): + response = self.hook.conn.send_command( + DocumentName=self.document_name, + **self.run_command_kwargs, + ) + + command_id = response["Command"]["CommandId"] + task_description = f"SSM run command {command_id} to complete." + + if self.deferrable: + self.log.info("Deferring for %s", task_description) + self.defer( + trigger=SsmRunCommandTrigger( + command_id=command_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + ), + method_name="execute_complete", + ) + + elif self.wait_for_completion: + self.log.info("Waiting for %s", task_description) + waiter = self.hook.get_waiter("command_executed") + + instance_ids = response["Command"]["InstanceIds"] + for instance_id in instance_ids: + waiter.wait( + CommandId=command_id, + InstanceId=instance_id, + WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, + ) + + return command_id diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/ssm.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/ssm.py new file mode 100644 index 0000000000000..f2a6d4351b5a9 --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/ssm.py @@ -0,0 +1,127 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.  See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership.  The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.  You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied.  See the License for the +# specific language governing permissions and limitations +# under the License. 
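For orientation, here is a minimal usage sketch of the ``SsmRunCommandOperator`` added above, outside the system test. The DAG id, instance ID, and shell command are hypothetical placeholders; only ``document_name``, ``run_command_kwargs``, and ``wait_for_completion`` mirror the operator's actual signature.

```python
from __future__ import annotations

import datetime

from airflow.models.dag import DAG
from airflow.providers.amazon.aws.operators.ssm import SsmRunCommandOperator

with DAG(
    dag_id="ssm_run_command_sketch",  # hypothetical DAG id
    schedule=None,
    start_date=datetime.datetime(2024, 1, 1),
    catchup=False,
):
    # Sends AWS-RunShellScript to one (placeholder) managed instance and blocks
    # until the "command_executed" waiter reports completion on that instance.
    run_command = SsmRunCommandOperator(
        task_id="run_command",
        document_name="AWS-RunShellScript",
        run_command_kwargs={
            "InstanceIds": ["i-0123456789abcdef0"],  # placeholder instance ID
            "Parameters": {"commands": ["echo hello"]},
        },
        wait_for_completion=True,  # default: poll the command_executed waiter per instance
    )
```

Because ``execute()`` returns the command ID, downstream tasks can pull it from XCom, which is how the system test below wires the sensor to the operator.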
+ +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +from airflow.configuration import conf +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.ssm import SsmHook +from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor +from airflow.providers.amazon.aws.triggers.ssm import SsmRunCommandTrigger +from airflow.providers.amazon.aws.utils import validate_execute_complete_event +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class SsmRunCommandCompletedSensor(AwsBaseSensor[SsmHook]): + """ + Poll the state of an AWS SSM Run Command until all instance jobs reach a terminal state. Fails if any instance job ends in a failed state. + + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/sensor:SsmRunCommandCompletedSensor` + + :param command_id: The ID of the AWS SSM Run Command. + + :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore + module to be installed. + (default: False, but can be overridden in config file by setting default_deferrable to True) + :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120) + :param max_retries: Number of times before returning the current state. (default: 75) + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + """ + + INTERMEDIATE_STATES: tuple[str, ...] = ("Pending", "Delayed", "InProgress", "Cancelling") + FAILURE_STATES: tuple[str, ...] = ("Cancelled", "TimedOut", "Failed") + SUCCESS_STATES: tuple[str, ...] = ("Success",) + FAILURE_MESSAGE = "SSM run command sensor failed." 
+ + aws_hook_class = SsmHook + template_fields: Sequence[str] = aws_template_fields( + "command_id", + ) + + def __init__( + self, + *, + command_id: str, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + poke_interval: int = 120, + max_retries: int = 75, + **kwargs, + ): + super().__init__(**kwargs) + self.command_id = command_id + self.deferrable = deferrable + self.poke_interval = poke_interval + self.max_retries = max_retries + + def poke(self, context: Context): + response = self.hook.conn.list_command_invocations(CommandId=self.command_id) + command_invocations = response.get("CommandInvocations", []) + + if not command_invocations: + self.log.info("No command invocations found for command_id=%s yet, waiting...", self.command_id) + return False + + for invocation in command_invocations: + state = invocation["Status"] + + if state in self.FAILURE_STATES: + raise AirflowException(self.FAILURE_MESSAGE) + + if state in self.INTERMEDIATE_STATES: + return False + + return True + + def execute(self, context: Context): + if self.deferrable: + self.defer( + trigger=SsmRunCommandTrigger( + command_id=self.command_id, + waiter_delay=int(self.poke_interval), + waiter_max_attempts=self.max_retries, + aws_conn_id=self.aws_conn_id, + ), + method_name="execute_complete", + ) + + else: + super().execute(context=context) + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None: + event = validate_execute_complete_event(event) + + if event["status"] != "success": + raise AirflowException(f"Error while running SSM run command: {event}") + + self.log.info("SSM run command `%s` completed.", event["command_id"]) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/ssm.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/ssm.py new file mode 100644 index 0000000000000..f7efc916c7289 --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/ssm.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.  See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership.  The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.  You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied.  See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from collections.abc import AsyncIterator +from typing import TYPE_CHECKING + +from airflow.providers.amazon.aws.hooks.ssm import SsmHook +from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger +from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait +from airflow.triggers.base import TriggerEvent + +if TYPE_CHECKING: + from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook + + +class SsmRunCommandTrigger(AwsBaseWaiterTrigger): + """ + Trigger when an SSM run command is complete. + + :param command_id: The ID of the AWS SSM Run Command. + :param waiter_delay: The amount of time in seconds to wait between attempts. 
(default: 120) + :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75) + :param aws_conn_id: The Airflow connection used for AWS credentials. + """ + + def __init__( + self, + *, + command_id: str, + waiter_delay: int = 120, + waiter_max_attempts: int = 75, + aws_conn_id: str | None = None, + ) -> None: + super().__init__( + serialized_fields={"command_id": command_id}, + waiter_name="command_executed", + waiter_args={"CommandId": command_id}, + failure_message="SSM run command failed.", + status_message="Status of SSM run command is", + status_queries=["status"], + return_key="command_id", + return_value=command_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + ) + self.command_id = command_id + + def hook(self) -> AwsGenericHook: + return SsmHook(aws_conn_id=self.aws_conn_id) + + async def run(self) -> AsyncIterator[TriggerEvent]: + hook = self.hook() + async with hook.async_conn as client: + # aiobotocore client calls are coroutines, so this must be awaited. + response = await client.list_command_invocations(CommandId=self.command_id) + instance_ids = [invocation["InstanceId"] for invocation in response.get("CommandInvocations", [])] + waiter = hook.get_waiter(self.waiter_name, deferrable=True, client=client) + + for instance_id in instance_ids: + self.waiter_args["InstanceId"] = instance_id + await async_wait( + waiter, + self.waiter_delay, + self.attempts, + self.waiter_args, + self.failure_message, + self.status_message, + self.status_queries, + ) + + yield TriggerEvent({"status": "success", self.return_key: self.return_value}) diff --git a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py index 731e9bb377c3d..fc546e5eb6a24 100644 --- a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py +++ b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py @@ -250,6 +250,7 @@ def get_provider_info(): "integration-name": "Amazon Systems Manager (SSM)", "external-doc-url": "https://aws.amazon.com/systems-manager/", "logo": "/docs/integration-logos/AWS-Systems-Manager_light-bg@4x.png", + "how-to-guide": ["/docs/apache-airflow-providers-amazon/operators/ssm.rst"], "tags": ["aws"], }, { @@ -444,6 +445,10 @@ def get_provider_info(): "integration-name": "AWS Step Functions", "python-modules": ["airflow.providers.amazon.aws.operators.step_function"], }, + { + "integration-name": "Amazon Systems Manager (SSM)", + "python-modules": ["airflow.providers.amazon.aws.operators.ssm"], + }, { "integration-name": "Amazon RDS", "python-modules": ["airflow.providers.amazon.aws.operators.rds"], @@ -581,6 +586,10 @@ def get_provider_info(): "integration-name": "AWS Step Functions", "python-modules": ["airflow.providers.amazon.aws.sensors.step_function"], }, + { + "integration-name": "Amazon Systems Manager (SSM)", + "python-modules": ["airflow.providers.amazon.aws.sensors.ssm"], + }, { "integration-name": "Amazon QuickSight", "python-modules": ["airflow.providers.amazon.aws.sensors.quicksight"], @@ -846,6 +855,10 @@ def get_provider_info(): "integration-name": "Amazon Simple Storage Service (S3)", "python-modules": ["airflow.providers.amazon.aws.triggers.s3"], }, + { + "integration-name": "Amazon Systems Manager (SSM)", + "python-modules": ["airflow.providers.amazon.aws.triggers.ssm"], + }, { "integration-name": "Amazon EMR", "python-modules": ["airflow.providers.amazon.aws.triggers.emr"], diff --git a/providers/amazon/tests/system/amazon/aws/example_ec2.py 
b/providers/amazon/tests/system/amazon/aws/example_ec2.py index 3ae43b357ad77..1290efdce4e40 100644 --- a/providers/amazon/tests/system/amazon/aws/example_ec2.py +++ b/providers/amazon/tests/system/amazon/aws/example_ec2.py @@ -17,7 +17,6 @@ from __future__ import annotations from datetime import datetime -from operator import itemgetter import boto3 @@ -36,42 +35,13 @@ from airflow.utils.trigger_rule import TriggerRule from system.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder +from system.amazon.aws.utils.ec2 import get_latest_ami_id DAG_ID = "example_ec2" sys_test_context_task = SystemTestContextBuilder().build() -@task -def get_latest_ami_id(): - """Returns the AMI ID of the most recently-created Amazon Linux image""" - - # Amazon is retiring AL2 in 2023 and replacing it with Amazon Linux 2022. - # This image prefix should be futureproof, but may need adjusting depending - # on how they name the new images. This page should have AL2022 info when - # it comes available: https://aws.amazon.com/linux/amazon-linux-2022/faqs/ - image_prefix = "Amazon Linux*" - root_device_name = "/dev/xvda" - - images = boto3.client("ec2").describe_images( - Filters=[ - {"Name": "description", "Values": [image_prefix]}, - { - "Name": "architecture", - "Values": ["x86_64"], - }, # t3 instances are only compatible with x86 architecture - { - "Name": "root-device-type", - "Values": ["ebs"], - }, # instances which are capable of hibernation need to use an EBS-backed AMI - {"Name": "root-device-name", "Values": [root_device_name]}, - ], - Owners=["amazon"], - ) - # Sort on CreationDate - return max(images["Images"], key=itemgetter("CreationDate"))["ImageId"] - - @task def create_key_pair(key_name: str): client = boto3.client("ec2") diff --git a/providers/amazon/tests/system/amazon/aws/example_glue.py b/providers/amazon/tests/system/amazon/aws/example_glue.py index 8aaf5f2916a91..6d9b10623ee51 100644 --- a/providers/amazon/tests/system/amazon/aws/example_glue.py +++ b/providers/amazon/tests/system/amazon/aws/example_glue.py @@ -36,7 +36,7 @@ from airflow.providers.amazon.aws.sensors.glue_crawler import GlueCrawlerSensor from airflow.utils.trigger_rule import TriggerRule -from system.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder, prune_logs +from system.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder, get_role_name, prune_logs if TYPE_CHECKING: from botocore.client import BaseClient @@ -71,11 +71,6 @@ """ -@task -def get_role_name(arn: str) -> str: - return arn.split("/")[-1] - - @task(trigger_rule=TriggerRule.ALL_DONE) def glue_cleanup(crawler_name: str, job_name: str, db_name: str) -> None: client: BaseClient = boto3.client("glue") diff --git a/providers/amazon/tests/system/amazon/aws/example_ssm.py b/providers/amazon/tests/system/amazon/aws/example_ssm.py new file mode 100644 index 0000000000000..e6b05145b36f7 --- /dev/null +++ b/providers/amazon/tests/system/amazon/aws/example_ssm.py @@ -0,0 +1,242 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied.  See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import datetime +import logging +import textwrap +import time + +import boto3 + +from airflow.decorators import task +from airflow.models.baseoperator import chain +from airflow.models.dag import DAG +from airflow.providers.amazon.aws.operators.ec2 import EC2CreateInstanceOperator, EC2TerminateInstanceOperator +from airflow.providers.amazon.aws.operators.ssm import SsmRunCommandOperator +from airflow.providers.amazon.aws.sensors.ssm import SsmRunCommandCompletedSensor +from airflow.utils.trigger_rule import TriggerRule + +from system.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder, get_role_name +from system.amazon.aws.utils.ec2 import get_latest_ami_id + +DAG_ID = "example_ssm" + +ROLE_ARN_KEY = "ROLE_ARN" +sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build() + +USER_DATA = textwrap.dedent("""\ + #!/bin/bash + set -e + + # Detect which supported package manager is available + if command -v yum &> /dev/null; then + PACKAGE_MANAGER="yum" + elif command -v dnf &> /dev/null; then + PACKAGE_MANAGER="dnf" + else + echo "No suitable package manager found" + exit 1 + fi + + # Install SSM agent if it's not installed + if ! command -v amazon-ssm-agent &> /dev/null; then + echo "Installing SSM agent..." + $PACKAGE_MANAGER install -y amazon-ssm-agent + else + echo "SSM agent already installed" + fi + + echo "Enabling and starting SSM agent..." + systemctl enable amazon-ssm-agent + systemctl start amazon-ssm-agent + + # Safety net: schedule a shutdown in case teardown never runs; combined with + # InstanceInitiatedShutdownBehavior=terminate this terminates the instance. + shutdown -h +15 + + echo "=== Finished user-data script ===" +""") + +log = logging.getLogger(__name__) + + +@task +def create_instance_profile(role_name: str, instance_profile_name: str): + client = boto3.client("iam") + client.create_instance_profile(InstanceProfileName=instance_profile_name) + client.add_role_to_instance_profile(InstanceProfileName=instance_profile_name, RoleName=role_name) + + +@task +def await_instance_profile_exists(instance_profile_name: str): + client = boto3.client("iam") + client.get_waiter("instance_profile_exists").wait(InstanceProfileName=instance_profile_name) + + +@task +def delete_instance_profile(instance_profile_name: str, role_name: str): + client = boto3.client("iam") + + try: + client.remove_role_from_instance_profile( + InstanceProfileName=instance_profile_name, RoleName=role_name + ) + except client.exceptions.NoSuchEntityException: + log.info("Role %s not attached to %s or already removed.", role_name, instance_profile_name) + + try: + client.delete_instance_profile(InstanceProfileName=instance_profile_name) + except client.exceptions.NoSuchEntityException: + log.info("Instance profile %s already deleted.", instance_profile_name) + + +@task +def extract_instance_id(instance_ids: list) -> str: + return instance_ids[0] + + +@task +def build_run_command_kwargs(instance_id: str): + return { + "InstanceIds": [instance_id], + "Parameters": {"commands": ["touch /tmp/ssm_test_passed"]}, + } + + +@task +def wait_until_ssm_ready(instance_id: str, max_attempts: int = 10, delay_seconds: int = 15): + """ + Waits for an EC2 instance to register with AWS Systems Manager (SSM). 
+ + This may take over a minute even after the instance is running. + Raises an exception if the instance is not ready after max_attempts. + """ + ssm = boto3.client("ssm") + + for _ in range(max_attempts): + response = ssm.describe_instance_information( + Filters=[{"Key": "InstanceIds", "Values": [instance_id]}] + ) + + if ( + response.get("InstanceInformationList") + and response["InstanceInformationList"][0]["PingStatus"] == "Online" + ): + return + + time.sleep(delay_seconds) + + raise Exception(f"Instance {instance_id} not ready in SSM after {max_attempts} attempts.") + + +with DAG( + dag_id=DAG_ID, + schedule="@once", + start_date=datetime.datetime(2021, 1, 1), + tags=["example"], + catchup=False, +) as dag: + # Create EC2 instance with SSM agent + test_context = sys_test_context_task() + env_id = test_context[ENV_ID_KEY] + instance_name = f"{env_id}-instance" + image_id = get_latest_ami_id() + role_name = get_role_name(test_context[ROLE_ARN_KEY]) + instance_profile_name = f"{env_id}-ssm-instance-profile" + + config = { + "InstanceType": "t2.micro", + "IamInstanceProfile": {"Name": instance_profile_name}, + # Optional: Tags for identifying test resources in the AWS console + "TagSpecifications": [ + {"ResourceType": "instance", "Tags": [{"Key": "Name", "Value": instance_name}]} + ], + "UserData": USER_DATA, + # Use IMDSv2 for greater security, see the following doc for more details: + # https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html + "MetadataOptions": {"HttpEndpoint": "enabled", "HttpTokens": "required"}, + "BlockDeviceMappings": [ + {"DeviceName": "/dev/xvda", "Ebs": {"Encrypted": True, "DeleteOnTermination": True}} + ], + "InstanceInitiatedShutdownBehavior": "terminate", + } + + create_instance = EC2CreateInstanceOperator( + task_id="create_instance", + image_id=image_id, + max_count=1, + min_count=1, + config=config, + wait_for_completion=True, + retries=5, + retry_delay=datetime.timedelta(seconds=15), + ) + + instance_id = extract_instance_id(create_instance.output) + + run_command_kwargs = build_run_command_kwargs(instance_id) + + # [START howto_operator_run_command] + run_command = SsmRunCommandOperator( + task_id="run_command", + document_name="AWS-RunShellScript", + run_command_kwargs=run_command_kwargs, + wait_for_completion=False, + ) + # [END howto_operator_run_command] + + # [START howto_sensor_run_command] + await_run_command = SsmRunCommandCompletedSensor( + task_id="await_run_command", command_id=run_command.output + ) + # [END howto_sensor_run_command] + + delete_instance = EC2TerminateInstanceOperator( + task_id="terminate_instance", + trigger_rule=TriggerRule.ALL_DONE, + instance_ids=instance_id, + ) + + chain( + # TEST SETUP + test_context, + image_id, + role_name, + create_instance_profile(role_name, instance_profile_name), + await_instance_profile_exists(instance_profile_name), + create_instance, + instance_id, + run_command_kwargs, + wait_until_ssm_ready(instance_id), + # TEST BODY + run_command, + await_run_command, + # TEST TEARDOWN + delete_instance, + delete_instance_profile(instance_profile_name, role_name), + ) + + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run 
= get_test_run(dag) diff --git a/providers/amazon/tests/system/amazon/aws/utils/__init__.py b/providers/amazon/tests/system/amazon/aws/utils/__init__.py index 5147d8309b7f7..a9a3f9d0ee1ba 100644 --- a/providers/amazon/tests/system/amazon/aws/utils/__init__.py +++ b/providers/amazon/tests/system/amazon/aws/utils/__init__.py @@ -382,3 +382,8 @@ def _purge_logs( @task def split_string(string): return string.split(",") + + +@task +def get_role_name(arn: str) -> str: + return arn.split("/")[-1] diff --git a/providers/amazon/tests/system/amazon/aws/utils/ec2.py b/providers/amazon/tests/system/amazon/aws/utils/ec2.py index c2d1411374733..9fc8354d9db52 100644 --- a/providers/amazon/tests/system/amazon/aws/utils/ec2.py +++ b/providers/amazon/tests/system/amazon/aws/utils/ec2.py @@ -17,6 +17,7 @@ from __future__ import annotations from ipaddress import IPv4Network +from operator import itemgetter import boto3 @@ -45,6 +46,36 @@ def _get_next_available_cidr(vpc_id: str) -> str: return f"{last_reserved_ip + 1}/{last_used_block.prefixlen}" +@task +def get_latest_ami_id(): + """Returns the AMI ID of the most recently-created Amazon Linux image""" + + # Amazon is retiring AL2 in 2023 and replacing it with Amazon Linux 2022. + # This image prefix should be futureproof, but may need adjusting depending + # on how they name the new images. This page should have AL2022 info when + # it comes available: https://aws.amazon.com/linux/amazon-linux-2022/faqs/ + image_prefix = "Amazon Linux*" + root_device_name = "/dev/xvda" + + images = boto3.client("ec2").describe_images( + Filters=[ + {"Name": "description", "Values": [image_prefix]}, + { + "Name": "architecture", + "Values": ["x86_64"], + }, # t3 instances are only compatible with x86 architecture + { + "Name": "root-device-type", + "Values": ["ebs"], + }, # instances which are capable of hibernation need to use an EBS-backed AMI + {"Name": "root-device-name", "Values": [root_device_name]}, + ], + Owners=["amazon"], + ) + # Sort on CreationDate + return max(images["Images"], key=itemgetter("CreationDate"))["ImageId"] + + @task def get_default_vpc_id() -> str: """Returns the VPC ID of the account's default VPC.""" diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_ssm.py b/providers/amazon/tests/unit/amazon/aws/operators/test_ssm.py new file mode 100644 index 0000000000000..2714cb45313d6 --- /dev/null +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_ssm.py @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
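One note before the unit tests: the IAM policy in the new documentation grants `ssm:GetCommandInvocation`, which none of the new operator/sensor/trigger code calls directly, but which is what you would use to fetch per-instance output once a command finishes. A standalone boto3 sketch, with hypothetical command and instance IDs:

```python
import boto3

# Fetch the result of a completed run command on a single instance.
# Both IDs below are hypothetical placeholders.
ssm = boto3.client("ssm")
invocation = ssm.get_command_invocation(
    CommandId="123e4567-e89b-12d3-a456-426614174000",
    InstanceId="i-0123456789abcdef0",
)
print(invocation["Status"])                 # e.g. "Success", "Failed", "TimedOut"
print(invocation["StandardOutputContent"])  # stdout captured by SSM
print(invocation["StandardErrorContent"])   # stderr captured by SSM
```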
+from __future__ import annotations + +from collections.abc import Generator +from unittest import mock + +import pytest + +from airflow.providers.amazon.aws.hooks.ssm import SsmHook +from airflow.providers.amazon.aws.operators.ssm import SsmRunCommandOperator + +from unit.amazon.aws.utils.test_template_fields import validate_template_fields + +COMMAND_ID = "test_command_id" +DOCUMENT_NAME = "test_ssm_custom_document" +INSTANCE_IDS = ["test_instance_id_1", "test_instance_id_2"] + + +class TestSsmRunCommandOperator: + @pytest.fixture + def mock_conn(self) -> Generator[SsmHook, None, None]: + with mock.patch.object(SsmHook, "conn") as _conn: + _conn.send_command.return_value = { + "Command": { + "CommandId": COMMAND_ID, + "InstanceIds": INSTANCE_IDS, + } + } + yield _conn + + def setup_method(self): + self.operator = SsmRunCommandOperator( + task_id="test_run_command_operator", + document_name=DOCUMENT_NAME, + run_command_kwargs={"InstanceIds": INSTANCE_IDS}, + ) + self.operator.defer = mock.MagicMock() + + @pytest.mark.parametrize( + "wait_for_completion, deferrable", + [ + pytest.param(False, False, id="no_wait"), + pytest.param(True, False, id="wait"), + pytest.param(False, True, id="defer"), + ], + ) + @mock.patch.object(SsmHook, "get_waiter") + def test_run_command_wait_combinations(self, mock_get_waiter, wait_for_completion, deferrable, mock_conn): + self.operator.wait_for_completion = wait_for_completion + self.operator.deferrable = deferrable + + command_id = self.operator.execute({}) + + assert command_id == COMMAND_ID + mock_conn.send_command.assert_called_once_with(DocumentName=DOCUMENT_NAME, InstanceIds=INSTANCE_IDS) + assert mock_get_waiter.call_count == wait_for_completion + assert self.operator.defer.call_count == deferrable + + def test_template_fields(self): + validate_template_fields(self.operator) diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_ssm.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_ssm.py new file mode 100644 index 0000000000000..faf91f379b236 --- /dev/null +++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_ssm.py @@ -0,0 +1,95 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
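The parametrized test above covers the operator's three wait modes (no wait, synchronous wait, defer). For comparison, the submit-then-sense pairing the system test uses looks like this in isolation — a sketch with a placeholder instance ID:

```python
from airflow.providers.amazon.aws.operators.ssm import SsmRunCommandOperator
from airflow.providers.amazon.aws.sensors.ssm import SsmRunCommandCompletedSensor

# Submit the command without blocking the worker...
run_command = SsmRunCommandOperator(
    task_id="run_command",
    document_name="AWS-RunShellScript",
    run_command_kwargs={
        "InstanceIds": ["i-0123456789abcdef0"],  # placeholder instance ID
        "Parameters": {"commands": ["echo hello"]},
    },
    wait_for_completion=False,
)

# ...then let the sensor poll list_command_invocations until every invocation
# reaches a terminal state. deferrable=True hands the wait to the triggerer
# and requires aiobotocore.
await_run_command = SsmRunCommandCompletedSensor(
    task_id="await_run_command",
    command_id=run_command.output,  # XCom: the command ID returned by execute()
    deferrable=True,
)

run_command >> await_run_command
```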
+ +from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.ssm import SsmHook +from airflow.providers.amazon.aws.sensors.ssm import SsmRunCommandCompletedSensor + +COMMAND_ID = "123e4567-e89b-12d3-a456-426614174000" + + +@pytest.fixture +def mock_ssm_list_invocations(): + def _setup(mock_conn: mock.MagicMock, state: str): + mock_conn.list_command_invocations.return_value = { + "CommandInvocations": [ + {"CommandId": COMMAND_ID, "InstanceId": "i-1234567890abcdef0", "Status": state} + ] + } + + return _setup + + +class TestSsmRunCommandCompletedSensor: + SENSOR = SsmRunCommandCompletedSensor + + def setup_method(self): + self.default_op_kwarg = dict( + task_id="test_ssm_run_command_sensor", + command_id=COMMAND_ID, + poke_interval=5, + max_retries=1, + ) + + self.sensor = self.SENSOR(**self.default_op_kwarg) + + def test_base_aws_op_attributes(self): + op = self.SENSOR(**self.default_op_kwarg) + assert op.hook.aws_conn_id == "aws_default" + assert op.hook._region_name is None + assert op.hook._verify is None + assert op.hook._config is None + + op = self.SENSOR( + **self.default_op_kwarg, + aws_conn_id="aws-test-custom-conn", + region_name="eu-west-1", + verify=False, + botocore_config={"read_timeout": 42}, + ) + assert op.hook.aws_conn_id == "aws-test-custom-conn" + assert op.hook._region_name == "eu-west-1" + assert op.hook._verify is False + assert op.hook._config is not None + assert op.hook._config.read_timeout == 42 + + @pytest.mark.parametrize("state", SENSOR.SUCCESS_STATES) + @mock.patch.object(SsmHook, "conn") + def test_poke_success_states(self, mock_conn, state, mock_ssm_list_invocations): + mock_ssm_list_invocations(mock_conn, state) + self.sensor.hook.conn = mock_conn + assert self.sensor.poke({}) is True + + @pytest.mark.parametrize("state", SENSOR.INTERMEDIATE_STATES) + @mock.patch.object(SsmHook, "conn") + def test_poke_intermediate_states(self, mock_conn, state, mock_ssm_list_invocations): + mock_ssm_list_invocations(mock_conn, state) + self.sensor.hook.conn = mock_conn + assert self.sensor.poke({}) is False + + @pytest.mark.parametrize("state", SENSOR.FAILURE_STATES) + @mock.patch.object(SsmHook, "conn") + def test_poke_failure_states(self, mock_conn, state, mock_ssm_list_invocations): + mock_ssm_list_invocations(mock_conn, state) + with pytest.raises(AirflowException, match=self.SENSOR.FAILURE_MESSAGE): + self.sensor.poke({}) diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_ssm.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_ssm.py new file mode 100644 index 0000000000000..994a989e8a8de --- /dev/null +++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_ssm.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest +from botocore.exceptions import WaiterError + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.ssm import SsmHook +from airflow.providers.amazon.aws.triggers.ssm import SsmRunCommandTrigger +from airflow.triggers.base import TriggerEvent + +from unit.amazon.aws.utils.test_waiter import assert_expected_waiter_type + +BASE_TRIGGER_CLASSPATH = "airflow.providers.amazon.aws.triggers.ssm." +EXPECTED_WAITER_NAME = "command_executed" +COMMAND_ID = "123e4567-e89b-12d3-a456-426614174000" +INSTANCE_ID_1 = "i-1234567890abcdef0" +INSTANCE_ID_2 = "i-1234567890abcdef1" + + +@pytest.fixture +def mock_ssm_list_invocations(): + def _setup(mock_async_conn): + mock_client = mock.MagicMock() + mock_async_conn.__aenter__.return_value = mock_client + # The trigger awaits list_command_invocations on the aiobotocore client, + # so stub it as an AsyncMock rather than setting return_value on a MagicMock. + mock_client.list_command_invocations = mock.AsyncMock( + return_value={ + "CommandInvocations": [ + {"CommandId": COMMAND_ID, "InstanceId": INSTANCE_ID_1}, + {"CommandId": COMMAND_ID, "InstanceId": INSTANCE_ID_2}, + ] + } + ) + return mock_client + + return _setup + + +class TestSsmRunCommandTrigger: + def test_serialization(self): + trigger = SsmRunCommandTrigger(command_id=COMMAND_ID) + classpath, kwargs = trigger.serialize() + + assert classpath == BASE_TRIGGER_CLASSPATH + "SsmRunCommandTrigger" + assert kwargs.get("command_id") == COMMAND_ID + + @pytest.mark.asyncio + @mock.patch.object(SsmHook, "async_conn") + @mock.patch.object(SsmHook, "get_waiter") + async def test_run_success(self, mock_get_waiter, mock_async_conn, mock_ssm_list_invocations): + mock_client = mock_ssm_list_invocations(mock_async_conn) + mock_get_waiter().wait = mock.AsyncMock(name="wait") + + trigger = SsmRunCommandTrigger(command_id=COMMAND_ID) + generator = trigger.run() + response = await generator.asend(None) + + assert response == TriggerEvent({"status": "success", "command_id": COMMAND_ID}) + assert_expected_waiter_type(mock_get_waiter, EXPECTED_WAITER_NAME) + assert mock_get_waiter().wait.call_count == 2 + mock_get_waiter().wait.assert_any_call( + CommandId=COMMAND_ID, InstanceId=INSTANCE_ID_1, WaiterConfig={"MaxAttempts": 1} + ) + mock_get_waiter().wait.assert_any_call( + CommandId=COMMAND_ID, InstanceId=INSTANCE_ID_2, WaiterConfig={"MaxAttempts": 1} + ) + mock_client.list_command_invocations.assert_called_once_with(CommandId=COMMAND_ID) + + @pytest.mark.asyncio + @mock.patch.object(SsmHook, "async_conn") + @mock.patch.object(SsmHook, "get_waiter") + async def test_run_fails(self, mock_get_waiter, mock_async_conn, mock_ssm_list_invocations): + mock_ssm_list_invocations(mock_async_conn) + mock_get_waiter().wait.side_effect = WaiterError( + "name", "terminal failure", {"CommandInvocations": [{"CommandId": COMMAND_ID}]} + ) + + trigger = SsmRunCommandTrigger(command_id=COMMAND_ID) + generator = trigger.run() + + with pytest.raises(AirflowException): + await generator.asend(None)
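For reference, the synchronous wait that SsmRunCommandOperator performs per instance reduces to boto3's built-in `command_executed` waiter. A standalone sketch with placeholder IDs, using the operator's default delay and attempt counts:

```python
import boto3
from botocore.exceptions import WaiterError

ssm = boto3.client("ssm")

# One wait per instance, mirroring the loop in SsmRunCommandOperator.execute().
waiter = ssm.get_waiter("command_executed")
try:
    waiter.wait(
        CommandId="123e4567-e89b-12d3-a456-426614174000",  # placeholder
        InstanceId="i-0123456789abcdef0",                  # placeholder
        WaiterConfig={"Delay": 120, "MaxAttempts": 75},    # operator defaults
    )
except WaiterError as e:
    # Raised when the command ends in a failed state or attempts are exhausted,
    # which is also the condition the WaiterError-based trigger test simulates.
    raise RuntimeError(f"Run command did not complete: {e}") from e
```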