From cd64862dc2190772c23e740bfa8e612cfe479936 Mon Sep 17 00:00:00 2001 From: mse139 Date: Sun, 16 Mar 2025 17:56:51 -0400 Subject: [PATCH 1/5] Update ec2 operator with AwsBaseOperator --- .../providers/amazon/aws/operators/ec2.py | 61 +++++++------------ 1 file changed, 22 insertions(+), 39 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/ec2.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/ec2.py index f3d0e9fc2af25..e6a7def635733 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/ec2.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/ec2.py @@ -21,7 +21,9 @@ from typing import TYPE_CHECKING from airflow.exceptions import AirflowException -from airflow.models import BaseOperator +from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields + from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook from airflow.providers.amazon.aws.links.ec2 import ( EC2InstanceDashboardLink, @@ -32,7 +34,7 @@ from airflow.utils.context import Context -class EC2StartInstanceOperator(BaseOperator): +class EC2StartInstanceOperator(AwsBaseOperator[EC2Hook]): """ Start AWS EC2 instance using boto3. 
@@ -51,8 +53,9 @@ class EC2StartInstanceOperator(BaseOperator): between each instance state checks until operation is completed """ + aws_hook_class = EC2Hook operator_extra_links = (EC2InstanceLink(),) - template_fields: Sequence[str] = ("instance_id", "region_name") + template_fields: Sequence[str] = aws_template_fields("instance_id", "region_name") ui_color = "#eeaa11" ui_fgcolor = "#ffffff" @@ -60,15 +63,11 @@ def __init__( self, *, instance_id: str, - aws_conn_id: str | None = "aws_default", - region_name: str | None = None, check_interval: float = 15, **kwargs, ): super().__init__(**kwargs) self.instance_id = instance_id - self.aws_conn_id = aws_conn_id - self.region_name = region_name self.check_interval = check_interval def execute(self, context: Context): @@ -90,7 +89,7 @@ def execute(self, context: Context): ) -class EC2StopInstanceOperator(BaseOperator): +class EC2StopInstanceOperator(AwsBaseOperator[EC2Hook]): """ Stop AWS EC2 instance using boto3. @@ -108,9 +107,9 @@ class EC2StopInstanceOperator(BaseOperator): :param check_interval: time in seconds that the job should wait in between each instance state checks until operation is completed """ - + aws_hook_class = EC2Hook operator_extra_links = (EC2InstanceLink(),) - template_fields: Sequence[str] = ("instance_id", "region_name") + template_fields: Sequence[str] = aws_template_fields("instance_id", "region_name") ui_color = "#eeaa11" ui_fgcolor = "#ffffff" @@ -118,15 +117,11 @@ def __init__( self, *, instance_id: str, - aws_conn_id: str | None = "aws_default", - region_name: str | None = None, check_interval: float = 15, **kwargs, ): super().__init__(**kwargs) self.instance_id = instance_id - self.aws_conn_id = aws_conn_id - self.region_name = region_name self.check_interval = check_interval def execute(self, context: Context): @@ -149,7 +144,7 @@ def execute(self, context: Context): ) -class EC2CreateInstanceOperator(BaseOperator): +class EC2CreateInstanceOperator(AwsBaseOperator[EC2Hook]): """ Create 
and start a specified number of EC2 Instances using boto3. @@ -175,8 +170,10 @@ class EC2CreateInstanceOperator(BaseOperator): in the `running` state before returning. """ + aws_hook_class = EC2Hook + operator_extra_links = (EC2InstanceDashboardLink(),) - template_fields: Sequence[str] = ( + template_fields: Sequence[str] = aws_template_fields( "image_id", "max_count", "min_count", @@ -191,8 +188,6 @@ def __init__( image_id: str, max_count: int = 1, min_count: int = 1, - aws_conn_id: str | None = "aws_default", - region_name: str | None = None, poll_interval: int = 20, max_attempts: int = 20, config: dict | None = None, @@ -203,8 +198,6 @@ def __init__( self.image_id = image_id self.max_count = max_count self.min_count = min_count - self.aws_conn_id = aws_conn_id - self.region_name = region_name self.poll_interval = poll_interval self.max_attempts = max_attempts self.config = config or {} @@ -258,7 +251,7 @@ def on_kill(self) -> None: super().on_kill() -class EC2TerminateInstanceOperator(BaseOperator): +class EC2TerminateInstanceOperator(AwsBaseOperator[EC2Hook]): """ Terminate EC2 Instances using boto3. @@ -281,13 +274,12 @@ class EC2TerminateInstanceOperator(BaseOperator): in the `terminated` state before returning. 
""" - template_fields: Sequence[str] = ("instance_ids", "region_name", "aws_conn_id", "wait_for_completion") + aws_hook_class = EC2Hook + template_fields: Sequence[str] = aws_template_fields("instance_ids", "region_name", "aws_conn_id", "wait_for_completion") def __init__( self, instance_ids: str | list[str], - aws_conn_id: str | None = "aws_default", - region_name: str | None = None, poll_interval: int = 20, max_attempts: int = 20, wait_for_completion: bool = False, @@ -295,8 +287,6 @@ def __init__( ): super().__init__(**kwargs) self.instance_ids = instance_ids - self.aws_conn_id = aws_conn_id - self.region_name = region_name self.poll_interval = poll_interval self.max_attempts = max_attempts self.wait_for_completion = wait_for_completion @@ -319,7 +309,7 @@ def execute(self, context: Context): ) -class EC2RebootInstanceOperator(BaseOperator): +class EC2RebootInstanceOperator(AwsBaseOperator[EC2Hook]): """ Reboot Amazon EC2 instances. @@ -342,8 +332,9 @@ class EC2RebootInstanceOperator(BaseOperator): in the `running` state before returning. 
""" + aws_hook_class = EC2Hook operator_extra_links = (EC2InstanceDashboardLink(),) - template_fields: Sequence[str] = ("instance_ids", "region_name") + template_fields: Sequence[str] = aws_template_fields("instance_ids", "region_name") ui_color = "#eeaa11" ui_fgcolor = "#ffffff" @@ -351,8 +342,6 @@ def __init__( self, *, instance_ids: str | list[str], - aws_conn_id: str | None = "aws_default", - region_name: str | None = None, poll_interval: int = 20, max_attempts: int = 20, wait_for_completion: bool = False, @@ -360,8 +349,6 @@ def __init__( ): super().__init__(**kwargs) self.instance_ids = instance_ids - self.aws_conn_id = aws_conn_id - self.region_name = region_name self.poll_interval = poll_interval self.max_attempts = max_attempts self.wait_for_completion = wait_for_completion @@ -391,7 +378,7 @@ def execute(self, context: Context): ) -class EC2HibernateInstanceOperator(BaseOperator): +class EC2HibernateInstanceOperator(AwsBaseOperator[EC2Hook]): """ Hibernate Amazon EC2 instances. @@ -413,9 +400,9 @@ class EC2HibernateInstanceOperator(BaseOperator): :param wait_for_completion: If True, the operator will wait for the instance to be in the `stopped` state before returning. 
""" - + aws_hook_class = EC2Hook operator_extra_links = (EC2InstanceDashboardLink(),) - template_fields: Sequence[str] = ("instance_ids", "region_name") + template_fields: Sequence[str] = aws_template_fields("instance_ids", "region_name") ui_color = "#eeaa11" ui_fgcolor = "#ffffff" @@ -423,8 +410,6 @@ def __init__( self, *, instance_ids: str | list[str], - aws_conn_id: str | None = "aws_default", - region_name: str | None = None, poll_interval: int = 20, max_attempts: int = 20, wait_for_completion: bool = False, @@ -432,8 +417,6 @@ def __init__( ): super().__init__(**kwargs) self.instance_ids = instance_ids - self.aws_conn_id = aws_conn_id - self.region_name = region_name self.poll_interval = poll_interval self.max_attempts = max_attempts self.wait_for_completion = wait_for_completion From 337111974d1fded725fff59c3e9d1a3d15f92363 Mon Sep 17 00:00:00 2001 From: mse139 Date: Sun, 16 Mar 2025 18:20:39 -0400 Subject: [PATCH 2/5] Update ec2 sensor with AWS base class --- .../airflow/providers/amazon/aws/sensors/ec2.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/ec2.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/ec2.py index 910cf9f4a8fcd..f4bbe326db530 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/ec2.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/ec2.py @@ -23,16 +23,18 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook from airflow.providers.amazon.aws.triggers.ec2 import EC2StateSensorTrigger from airflow.providers.amazon.aws.utils import validate_execute_complete_event +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: from airflow.utils.context 
import Context -class EC2InstanceStateSensor(BaseSensorOperator): +class EC2InstanceStateSensor(AwsBaseSensor[EC2Hook]): """ Poll the state of the AWS EC2 instance until the instance reaches the target state. @@ -45,8 +47,8 @@ class EC2InstanceStateSensor(BaseSensorOperator): :param region_name: (optional) aws region name associated with the client :param deferrable: if True, the sensor will run in deferrable mode """ - - template_fields: Sequence[str] = ("target_state", "instance_id", "region_name") + aws_hook_class = EC2Hook + template_fields: Sequence[str] = aws_template_fields("target_state", "instance_id", "region_name") ui_color = "#cc8811" ui_fgcolor = "#ffffff" valid_states = ["running", "stopped", "terminated"] @@ -56,8 +58,6 @@ def __init__( *, target_state: str, instance_id: str, - aws_conn_id: str | None = "aws_default", - region_name: str | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ): @@ -66,8 +66,6 @@ def __init__( super().__init__(**kwargs) self.target_state = target_state self.instance_id = instance_id - self.aws_conn_id = aws_conn_id - self.region_name = region_name self.deferrable = deferrable def execute(self, context: Context) -> Any: @@ -85,9 +83,6 @@ def execute(self, context: Context) -> Any: else: super().execute(context=context) - @cached_property - def hook(self): - return EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) def poke(self, context: Context): instance_state = self.hook.get_instance_state(instance_id=self.instance_id) From 638c2a202eac480368aef03f0fa1ee324fc53f1e Mon Sep 17 00:00:00 2001 From: mike ellis Date: Tue, 18 Mar 2025 19:48:19 +0000 Subject: [PATCH 3/5] Use AWS Base classes in EC2 operators and sensors --- providers/amazon/docs/operators/ec2.rst | 5 ++ .../providers/amazon/aws/operators/ec2.py | 73 +++++++++++-------- 2 files changed, 46 insertions(+), 32 deletions(-) diff --git a/providers/amazon/docs/operators/ec2.rst 
b/providers/amazon/docs/operators/ec2.rst index 2d4e07b258b1b..ce104ac7e4f1c 100644 --- a/providers/amazon/docs/operators/ec2.rst +++ b/providers/amazon/docs/operators/ec2.rst @@ -27,6 +27,11 @@ Prerequisite Tasks .. include:: ../_partials/prerequisite_tasks.rst +Generic Parameters +------------------ + +.. include:: ../_partials/generic_parameters.rst + Operators --------- diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/ec2.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/ec2.py index e6a7def635733..37aa89fc3b7c9 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/ec2.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/ec2.py @@ -71,18 +71,17 @@ def __init__( self.check_interval = check_interval def execute(self, context: Context): - ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) self.log.info("Starting EC2 instance %s", self.instance_id) - instance = ec2_hook.get_instance(instance_id=self.instance_id) + instance = self.hook.get_instance(instance_id=self.instance_id) instance.start() EC2InstanceLink.persist( context=context, operator=self, - aws_partition=ec2_hook.conn_partition, + aws_partition=self.hook.conn_partition, instance_id=self.instance_id, - region_name=ec2_hook.conn_region_name, + region_name=self.hook.conn_region_name, ) - ec2_hook.wait_for_state( + self.hook.wait_for_state( instance_id=self.instance_id, target_state="running", check_interval=self.check_interval, @@ -125,19 +124,18 @@ def __init__( self.check_interval = check_interval def execute(self, context: Context): - ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) self.log.info("Stopping EC2 instance %s", self.instance_id) - instance = ec2_hook.get_instance(instance_id=self.instance_id) + instance = self.hook.get_instance(instance_id=self.instance_id) EC2InstanceLink.persist( context=context, operator=self, - aws_partition=ec2_hook.conn_partition, + 
aws_partition=self.hook.conn_partition, instance_id=self.instance_id, - region_name=ec2_hook.conn_region_name, + region_name=self.hook.conn_region_name, ) instance.stop() - ec2_hook.wait_for_state( + self.hook.wait_for_state( instance_id=self.instance_id, target_state="stopped", check_interval=self.check_interval, @@ -203,9 +201,11 @@ def __init__( self.config = config or {} self.wait_for_completion = wait_for_completion + @property + def _hook_parameters(self)->dict[str:any]: + return {**super()._hook_parameters, "api_type": "client_type"} def execute(self, context: Context): - ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type") - instances = ec2_hook.conn.run_instances( + instances = self.hook.conn.run_instances( ImageId=self.image_id, MinCount=self.min_count, MaxCount=self.max_count, @@ -218,15 +218,15 @@ def execute(self, context: Context): EC2InstanceDashboardLink.persist( context=context, operator=self, - region_name=ec2_hook.conn_region_name, - aws_partition=ec2_hook.conn_partition, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, instance_ids=EC2InstanceDashboardLink.format_instance_id_filter(instance_ids), ) for instance_id in instance_ids: self.log.info("Created EC2 instance %s", instance_id) if self.wait_for_completion: - ec2_hook.get_waiter("instance_running").wait( + self.hook.get_waiter("instance_running").wait( InstanceIds=[instance_id], WaiterConfig={ "Delay": self.poll_interval, @@ -242,12 +242,12 @@ def on_kill(self) -> None: if instance_ids: self.log.info("on_kill: Terminating instance/s %s", ", ".join(instance_ids)) - ec2_hook = EC2Hook( + """ ec2_hook = EC2Hook( aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type", - ) - ec2_hook.conn.terminate_instances(InstanceIds=instance_ids) + ) """ + self.hook.terminate_instances(InstanceIds=instance_ids) super().on_kill() @@ -291,16 +291,19 @@ def __init__( self.max_attempts = max_attempts 
self.wait_for_completion = wait_for_completion + @property + def _hook_parameters(self)->dict[str:any]: + return {**super()._hook_parameters, "api_type": "client_type"} + def execute(self, context: Context): if isinstance(self.instance_ids, str): self.instance_ids = [self.instance_ids] - ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type") - ec2_hook.conn.terminate_instances(InstanceIds=self.instance_ids) + self.hook.conn.terminate_instances(InstanceIds=self.instance_ids) for instance_id in self.instance_ids: self.log.info("Terminating EC2 instance %s", instance_id) if self.wait_for_completion: - ec2_hook.get_waiter("instance_terminated").wait( + self.hook.get_waiter("instance_terminated").wait( InstanceIds=[instance_id], WaiterConfig={ "Delay": self.poll_interval, @@ -353,23 +356,26 @@ def __init__( self.max_attempts = max_attempts self.wait_for_completion = wait_for_completion + @property + def _hook_parameters(self)->dict[str:any]: + return {**super()._hook_parameters, "api_type": "client_type"} + def execute(self, context: Context): if isinstance(self.instance_ids, str): self.instance_ids = [self.instance_ids] - ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type") self.log.info("Rebooting EC2 instances %s", ", ".join(self.instance_ids)) - ec2_hook.conn.reboot_instances(InstanceIds=self.instance_ids) + self.hook.conn.reboot_instances(InstanceIds=self.instance_ids) # Console link is for EC2 dashboard list, not individual instances EC2InstanceDashboardLink.persist( context=context, operator=self, - region_name=ec2_hook.conn_region_name, - aws_partition=ec2_hook.conn_partition, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, instance_ids=EC2InstanceDashboardLink.format_instance_id_filter(self.instance_ids), ) if self.wait_for_completion: - ec2_hook.get_waiter("instance_running").wait( + 
self.hook.get_waiter("instance_running").wait( InstanceIds=self.instance_ids, WaiterConfig={ "Delay": self.poll_interval, @@ -421,19 +427,22 @@ def __init__( self.max_attempts = max_attempts self.wait_for_completion = wait_for_completion + @property + def _hook_parameters(self)->dict[str:any]: + return {**super()._hook_parameters, "api_type": "client_type"} + def execute(self, context: Context): if isinstance(self.instance_ids, str): self.instance_ids = [self.instance_ids] - ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type") self.log.info("Hibernating EC2 instances %s", ", ".join(self.instance_ids)) - instances = ec2_hook.get_instances(instance_ids=self.instance_ids) + instances = self.hook.get_instances(instance_ids=self.instance_ids) # Console link is for EC2 dashboard list, not individual instances EC2InstanceDashboardLink.persist( context=context, operator=self, - region_name=ec2_hook.conn_region_name, - aws_partition=ec2_hook.conn_partition, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, instance_ids=EC2InstanceDashboardLink.format_instance_id_filter(self.instance_ids), ) @@ -442,10 +451,10 @@ def execute(self, context: Context): if not hibernation_options or not hibernation_options["Configured"]: raise AirflowException(f"Instance {instance['InstanceId']} is not configured for hibernation") - ec2_hook.conn.stop_instances(InstanceIds=self.instance_ids, Hibernate=True) + self.hook.conn.stop_instances(InstanceIds=self.instance_ids, Hibernate=True) if self.wait_for_completion: - ec2_hook.get_waiter("instance_stopped").wait( + self.hook.get_waiter("instance_stopped").wait( InstanceIds=self.instance_ids, WaiterConfig={ "Delay": self.poll_interval, From 93818836037e64596e6929bf57f08261b568e324 Mon Sep 17 00:00:00 2001 From: mike ellis Date: Wed, 19 Mar 2025 11:41:51 +0000 Subject: [PATCH 4/5] Fixing static check issues --- .../providers/amazon/aws/operators/ec2.py | 28 
+++++++++++-------- .../providers/amazon/aws/sensors/ec2.py | 6 ++-- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/ec2.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/ec2.py index 37aa89fc3b7c9..0c30308a68cb6 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/ec2.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/ec2.py @@ -18,17 +18,16 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from airflow.exceptions import AirflowException -from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator -from airflow.providers.amazon.aws.utils.mixins import aws_template_fields - from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook from airflow.providers.amazon.aws.links.ec2 import ( EC2InstanceDashboardLink, EC2InstanceLink, ) +from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields if TYPE_CHECKING: from airflow.utils.context import Context @@ -106,6 +105,7 @@ class EC2StopInstanceOperator(AwsBaseOperator[EC2Hook]): :param check_interval: time in seconds that the job should wait in between each instance state checks until operation is completed """ + aws_hook_class = EC2Hook operator_extra_links = (EC2InstanceLink(),) template_fields: Sequence[str] = aws_template_fields("instance_id", "region_name") @@ -202,10 +202,11 @@ def __init__( self.wait_for_completion = wait_for_completion @property - def _hook_parameters(self)->dict[str:any]: + def _hook_parameters(self) -> dict[str, Any]: return {**super()._hook_parameters, "api_type": "client_type"} + def execute(self, context: Context): - instances = self.hook.conn.run_instances( + instances = self.hook.conn.run_instances( ImageId=self.image_id, MinCount=self.min_count, 
MaxCount=self.max_count, @@ -275,7 +276,9 @@ class EC2TerminateInstanceOperator(AwsBaseOperator[EC2Hook]): """ aws_hook_class = EC2Hook - template_fields: Sequence[str] = aws_template_fields("instance_ids", "region_name", "aws_conn_id", "wait_for_completion") + template_fields: Sequence[str] = aws_template_fields( + "instance_ids", "region_name", "aws_conn_id", "wait_for_completion" + ) def __init__( self, @@ -292,9 +295,9 @@ def __init__( self.wait_for_completion = wait_for_completion @property - def _hook_parameters(self)->dict[str:any]: + def _hook_parameters(self) -> dict[str, Any]: return {**super()._hook_parameters, "api_type": "client_type"} - + def execute(self, context: Context): if isinstance(self.instance_ids, str): self.instance_ids = [self.instance_ids] @@ -357,7 +360,7 @@ def __init__( self.wait_for_completion = wait_for_completion @property - def _hook_parameters(self)->dict[str:any]: + def _hook_parameters(self) -> dict[str, Any]: return {**super()._hook_parameters, "api_type": "client_type"} def execute(self, context: Context): @@ -406,6 +409,7 @@ class EC2HibernateInstanceOperator(AwsBaseOperator[EC2Hook]): :param wait_for_completion: If True, the operator will wait for the instance to be in the `stopped` state before returning. 
""" + aws_hook_class = EC2Hook operator_extra_links = (EC2InstanceDashboardLink(),) template_fields: Sequence[str] = aws_template_fields("instance_ids", "region_name") @@ -428,9 +432,9 @@ def __init__( self.wait_for_completion = wait_for_completion @property - def _hook_parameters(self)->dict[str:any]: + def _hook_parameters(self) -> dict[str, Any]: return {**super()._hook_parameters, "api_type": "client_type"} - + def execute(self, context: Context): if isinstance(self.instance_ids, str): self.instance_ids = [self.instance_ids] diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/ec2.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/ec2.py index f4bbe326db530..7b9b650784496 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/ec2.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/ec2.py @@ -18,17 +18,15 @@ from __future__ import annotations from collections.abc import Sequence -from functools import cached_property from typing import TYPE_CHECKING, Any from airflow.configuration import conf from airflow.exceptions import AirflowException -from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook +from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor from airflow.providers.amazon.aws.triggers.ec2 import EC2StateSensorTrigger from airflow.providers.amazon.aws.utils import validate_execute_complete_event from airflow.providers.amazon.aws.utils.mixins import aws_template_fields -from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: from airflow.utils.context import Context @@ -47,6 +45,7 @@ class EC2InstanceStateSensor(AwsBaseSensor[EC2Hook]): :param region_name: (optional) aws region name associated with the client :param deferrable: if True, the sensor will run in deferrable mode """ + aws_hook_class = EC2Hook template_fields: Sequence[str] = aws_template_fields("target_state", "instance_id", 
"region_name") ui_color = "#cc8811" @@ -83,7 +82,6 @@ def execute(self, context: Context) -> Any: else: super().execute(context=context) - def poke(self, context: Context): instance_state = self.hook.get_instance_state(instance_id=self.instance_id) self.log.info("instance state: %s", instance_state) From b0d5431eeea0be3cd4b94ddc3e6fead8986cdcaa Mon Sep 17 00:00:00 2001 From: mike ellis Date: Thu, 20 Mar 2025 12:34:38 +0000 Subject: [PATCH 5/5] PR feedback changes --- providers/amazon/docs/operators/ec2.rst | 2 +- .../providers/amazon/aws/operators/ec2.py | 38 ++++++++++++------- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/providers/amazon/docs/operators/ec2.rst b/providers/amazon/docs/operators/ec2.rst index ce104ac7e4f1c..79246e6f13f14 100644 --- a/providers/amazon/docs/operators/ec2.rst +++ b/providers/amazon/docs/operators/ec2.rst @@ -31,7 +31,7 @@ Generic Parameters ------------------ .. include:: ../_partials/generic_parameters.rst - + Operators --------- diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/ec2.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/ec2.py index 0c30308a68cb6..2ade558742b91 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/ec2.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/ec2.py @@ -42,12 +42,14 @@ class EC2StartInstanceOperator(AwsBaseOperator[EC2Hook]): :ref:`howto/operator:EC2StartInstanceOperator` :param instance_id: id of the AWS EC2 instance - :param aws_conn_id: The Airflow connection used for AWS credentials. - If this is None or empty then the default boto3 behaviour is used. If + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If running Airflow in a distributed manner and aws_conn_id is None or empty, then default boto3 configuration would be used (and must be maintained on each worker node). 
- :param region_name: (optional) aws region name associated with the client + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html :param check_interval: time in seconds that the job should wait in between each instance state checks until operation is completed """ @@ -97,11 +99,13 @@ class EC2StopInstanceOperator(AwsBaseOperator[EC2Hook]): :param instance_id: id of the AWS EC2 instance :param aws_conn_id: The Airflow connection used for AWS credentials. - If this is None or empty then the default boto3 behaviour is used. If + If this is ``None`` or empty then the default boto3 behaviour is used. If running Airflow in a distributed manner and aws_conn_id is None or empty, then default boto3 configuration would be used (and must be maintained on each worker node). - :param region_name: (optional) aws region name associated with the client + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html :param check_interval: time in seconds that the job should wait in between each instance state checks until operation is completed """ @@ -154,11 +158,13 @@ class EC2CreateInstanceOperator(AwsBaseOperator[EC2Hook]): :param max_count: Maximum number of instances to launch. Defaults to 1. :param min_count: Minimum number of instances to launch. Defaults to 1. :param aws_conn_id: The Airflow connection used for AWS credentials. - If this is None or empty then the default boto3 behaviour is used. If + If this is ``None`` or empty then the default boto3 behaviour is used. 
If running Airflow in a distributed manner and aws_conn_id is None or empty, then default boto3 configuration would be used (and must be maintained on each worker node). - :param region_name: AWS region name associated with the client. + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html :param poll_interval: Number of seconds to wait before attempting to check state of instance. Only used if wait_for_completion is True. Default is 20. :param max_attempts: Maximum number of attempts when checking state of instance. @@ -262,11 +268,13 @@ class EC2TerminateInstanceOperator(AwsBaseOperator[EC2Hook]): :param instance_id: ID of the instance to be terminated. :param aws_conn_id: The Airflow connection used for AWS credentials. - If this is None or empty then the default boto3 behaviour is used. If + If this is ``None`` or empty then the default boto3 behaviour is used. If running Airflow in a distributed manner and aws_conn_id is None or empty, then default boto3 configuration would be used (and must be maintained on each worker node). - :param region_name: AWS region name associated with the client. + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html :param poll_interval: Number of seconds to wait before attempting to check state of instance. Only used if wait_for_completion is True. Default is 20. :param max_attempts: Maximum number of attempts when checking state of instance. @@ -325,11 +333,13 @@ class EC2RebootInstanceOperator(AwsBaseOperator[EC2Hook]): :param instance_ids: ID of the instance(s) to be rebooted. :param aws_conn_id: The Airflow connection used for AWS credentials. 
- If this is None or empty then the default boto3 behaviour is used. If + If this is ``None`` or empty then the default boto3 behaviour is used. If running Airflow in a distributed manner and aws_conn_id is None or empty, then default boto3 configuration would be used (and must be maintained on each worker node). - :param region_name: AWS region name associated with the client. + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html :param poll_interval: Number of seconds to wait before attempting to check state of instance. Only used if wait_for_completion is True. Default is 20. :param max_attempts: Maximum number of attempts when checking state of instance. @@ -397,11 +407,13 @@ class EC2HibernateInstanceOperator(AwsBaseOperator[EC2Hook]): :param instance_ids: ID of the instance(s) to be hibernated. :param aws_conn_id: The Airflow connection used for AWS credentials. - If this is None or empty then the default boto3 behaviour is used. If + If this is ``None`` or empty then the default boto3 behaviour is used. If running Airflow in a distributed manner and aws_conn_id is None or empty, then default boto3 configuration would be used (and must be maintained on each worker node). - :param region_name: AWS region name associated with the client. + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html :param poll_interval: Number of seconds to wait before attempting to check state of instance. Only used if wait_for_completion is True. Default is 20. :param max_attempts: Maximum number of attempts when checking state of instance.