diff --git a/providers/amazon/src/airflow/providers/amazon/aws/links/base_aws.py b/providers/amazon/src/airflow/providers/amazon/aws/links/base_aws.py
index c97497b7d8aad..083fffdcd6345 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/links/base_aws.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/links/base_aws.py
@@ -20,20 +20,13 @@
 from typing import TYPE_CHECKING, ClassVar

 from airflow.providers.amazon.aws.utils.suppress import return_on_error
-from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
+from airflow.providers.amazon.version_compat import BaseOperatorLink, XCom

 if TYPE_CHECKING:
     from airflow.models import BaseOperator
     from airflow.models.taskinstancekey import TaskInstanceKey
     from airflow.utils.context import Context

-if AIRFLOW_V_3_0_PLUS:
-    from airflow.sdk import BaseOperatorLink
-    from airflow.sdk.execution_time.xcom import XCom
-else:
-    from airflow.models import XCom  # type: ignore[no-redef]
-    from airflow.models.baseoperatorlink import BaseOperatorLink  # type: ignore[no-redef]
-
 BASE_AWS_CONSOLE_LINK = "https://console.{aws_domain}"
@@ -94,8 +87,7 @@ def persist(
         if not operator.do_xcom_push:
             return

-        operator.xcom_push(
-            context,
+        context["ti"].xcom_push(
             key=cls.key,
             value={
                 "region_name": region_name,
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/base_aws.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/base_aws.py
index ce9855393319f..c998ebc046630 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/base_aws.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/base_aws.py
@@ -19,13 +19,13 @@

 from collections.abc import Sequence

-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.utils.mixins import (
     AwsBaseHookMixin,
     AwsHookParams,
     AwsHookType,
     aws_template_fields,
 )
+from airflow.providers.amazon.version_compat import BaseOperator
 from airflow.utils.types import NOTSET, ArgNotSet
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/ecs.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/ecs.py
index 7205c396129b7..f7d74eee0dd15 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/ecs.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/ecs.py
@@ -526,7 +526,7 @@ def execute(self, context):
         self._start_task()

         if self.do_xcom_push:
-            self.xcom_push(context, key="ecs_task_arn", value=self.arn)
+            context["ti"].xcom_push(key="ecs_task_arn", value=self.arn)

         if self.deferrable:
             self.defer(
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/redshift_data.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/redshift_data.py
index 6bc62b0cedb63..b7284acf1602a 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/redshift_data.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/redshift_data.py
@@ -159,7 +159,7 @@ def execute(self, context: Context) -> list[GetStatementResultResponseTypeDef] |
         self.statement_id: str = query_execution_output.statement_id

         if query_execution_output.session_id:
-            self.xcom_push(context, key="session_id", value=query_execution_output.session_id)
+            context["ti"].xcom_push(key="session_id", value=query_execution_output.session_id)

         if self.deferrable and self.wait_for_completion:
             is_finished: bool = self.hook.check_query_is_finished(self.statement_id)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py
index c872c56afa634..cf204b84c6e0d 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py
@@ -24,7 +24,6 @@
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio import (
     SageMakerNotebookHook,
 )
@@ -34,6 +33,7 @@
 from airflow.providers.amazon.aws.triggers.sagemaker_unified_studio import (
     SageMakerNotebookJobTrigger,
 )
+from airflow.providers.amazon.version_compat import BaseOperator

 if TYPE_CHECKING:
     from airflow.utils.context import Context
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/base_aws.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/base_aws.py
index 7c3fad685b1f9..58db967e292c5 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/base_aws.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/base_aws.py
@@ -25,13 +25,7 @@
     AwsHookType,
     aws_template_fields,
 )
-from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
-
-if AIRFLOW_V_3_0_PLUS:
-    from airflow.sdk import BaseSensorOperator
-else:
-    from airflow.sensors.base import BaseSensorOperator  # type: ignore[no-redef]
-
+from airflow.providers.amazon.version_compat import BaseSensorOperator
 from airflow.utils.types import NOTSET, ArgNotSet
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py
index 9cec1215b009a..072e7f1e48feb 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py
@@ -25,12 +25,7 @@
 from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio import (
     SageMakerNotebookHook,
 )
-from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
-
-if AIRFLOW_V_3_0_PLUS:
-    from airflow.sdk import BaseSensorOperator
-else:
-    from airflow.sensors.base import BaseSensorOperator  # type: ignore[no-redef]
+from airflow.providers.amazon.version_compat import BaseSensorOperator

 if TYPE_CHECKING:
     from airflow.utils.context import Context
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py
index 60a09bdc33f4c..e58d43393c85a 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py
@@ -22,8 +22,8 @@
 from collections.abc import Sequence
 from typing import TYPE_CHECKING

-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.version_compat import BaseOperator

 try:
     from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/base.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/base.py
index f05f3c5c68000..d678a43cdf846 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/base.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/base.py
@@ -21,8 +21,8 @@
 from collections.abc import Sequence

-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.version_compat import BaseOperator
 from airflow.utils.types import NOTSET, ArgNotSet
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/exasol_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/exasol_to_s3.py
index a0f1816c77836..3b231b9621cfb 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/exasol_to_s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/exasol_to_s3.py
@@ -23,8 +23,8 @@
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING

-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.version_compat import BaseOperator
 from airflow.providers.exasol.hooks.exasol import ExasolHook

 if TYPE_CHECKING:
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/ftp_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/ftp_to_s3.py
index 9b6f7ea4a1027..92cfcd870806c 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/ftp_to_s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/ftp_to_s3.py
@@ -21,8 +21,8 @@
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING

-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.version_compat import BaseOperator
 from airflow.providers.ftp.hooks.ftp import FTPHook

 if TYPE_CHECKING:
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/gcs_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
index 6899f7daffc9c..61e6386cc8f9e 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
@@ -26,8 +26,8 @@
 from packaging.version import Version

 from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.version_compat import BaseOperator
 from airflow.providers.google.cloud.hooks.gcs import GCSHook

 if TYPE_CHECKING:
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/glacier_to_gcs.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/glacier_to_gcs.py
index 966afec02a788..1e0e51841514b 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/glacier_to_gcs.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/glacier_to_gcs.py
@@ -21,8 +21,8 @@
 from collections.abc import Sequence
 from typing import TYPE_CHECKING

-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.glacier import GlacierHook
+from airflow.providers.amazon.version_compat import BaseOperator
 from airflow.providers.google.cloud.hooks.gcs import GCSHook

 if TYPE_CHECKING:
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/google_api_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/google_api_to_s3.py
index 157477341b44d..d6dddea7eaea3 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/google_api_to_s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/google_api_to_s3.py
@@ -24,9 +24,9 @@
 from collections.abc import Sequence
 from typing import TYPE_CHECKING

-from airflow.models import BaseOperator
 from airflow.models.xcom import MAX_XCOM_SIZE, XCOM_RETURN_KEY
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.version_compat import BaseOperator
 from airflow.providers.google.common.hooks.discovery_api import GoogleDiscoveryApiHook

 if TYPE_CHECKING:
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py
index c96ae1d96c155..1cd5e4ce22d15 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py
@@ -23,8 +23,8 @@
 from collections.abc import Callable, Sequence
 from typing import TYPE_CHECKING

-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook
+from airflow.providers.amazon.version_compat import BaseOperator
 from airflow.providers.apache.hive.hooks.hive import HiveServer2Hook

 if TYPE_CHECKING:
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/http_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/http_to_s3.py
index 373a31c87e1c3..3a91312cc07b7 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/http_to_s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/http_to_s3.py
@@ -22,8 +22,8 @@
 from functools import cached_property
 from typing import TYPE_CHECKING, Any

-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.version_compat import BaseOperator
 from airflow.providers.http.hooks.http import HttpHook

 if TYPE_CHECKING:
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py
index cf1b96bfa2e7d..5b9979b2ac6a2 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py
@@ -22,8 +22,8 @@
 from collections.abc import Sequence
 from typing import TYPE_CHECKING

-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.version_compat import BaseOperator
 from airflow.providers.imap.hooks.imap import ImapHook

 if TYPE_CHECKING:
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/local_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/local_to_s3.py
index aa6678af6b8f7..367aedc34add0 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/local_to_s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/local_to_s3.py
@@ -20,8 +20,8 @@
 from collections.abc import Sequence
 from typing import TYPE_CHECKING

-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.version_compat import BaseOperator

 if TYPE_CHECKING:
     from airflow.utils.context import Context
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/mongo_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
index d729655ae8e52..1515b5a951bea 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
@@ -23,8 +23,8 @@
 from bson import json_util

-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.version_compat import BaseOperator
 from airflow.providers.mongo.hooks.mongo import MongoHook

 if TYPE_CHECKING:
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/redshift_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
index 1d39ac2affd8d..4485eb44779cc 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -24,11 +24,11 @@
 from typing import TYPE_CHECKING

 from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
 from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
+from airflow.providers.amazon.version_compat import BaseOperator
 from airflow.utils.types import NOTSET, ArgNotSet

 if TYPE_CHECKING:
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py
index f3e7291e142eb..e82c3dc5209a1 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py
@@ -23,8 +23,8 @@
 from botocore.exceptions import ClientError, WaiterError

 from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook
+from airflow.providers.amazon.version_compat import BaseOperator

 if TYPE_CHECKING:
     from airflow.utils.context import Context
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_ftp.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_ftp.py
index ee7dd0306b768..8ce41b7816d6c 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_ftp.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_ftp.py
@@ -21,8 +21,8 @@
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING

-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.version_compat import BaseOperator
 from airflow.providers.ftp.hooks.ftp import FTPHook

 if TYPE_CHECKING:
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
index 40609dc755d92..906710e622bca 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -20,11 +20,11 @@
 from typing import TYPE_CHECKING

 from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
 from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
+from airflow.providers.amazon.version_compat import BaseOperator
 from airflow.utils.types import NOTSET, ArgNotSet

 if TYPE_CHECKING:
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_sftp.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_sftp.py
index 7abf300844bab..817e88b8a4850 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_sftp.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_sftp.py
@@ -22,8 +22,8 @@
 from typing import TYPE_CHECKING
 from urllib.parse import urlsplit

-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.version_compat import BaseOperator
 from airflow.providers.ssh.hooks.ssh import SSHHook

 if TYPE_CHECKING:
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_sql.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_sql.py
index baeceb4387650..7d5a69aed431a 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_sql.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_sql.py
@@ -23,8 +23,8 @@
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.version_compat import BaseOperator

 if TYPE_CHECKING:
     from airflow.utils.context import Context
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/salesforce_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/salesforce_to_s3.py
index 79bdfb983f8a3..ac50323ccb7d5 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/salesforce_to_s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/salesforce_to_s3.py
@@ -21,8 +21,8 @@
 from collections.abc import Sequence
 from typing import TYPE_CHECKING

-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.version_compat import BaseOperator
 from airflow.providers.salesforce.hooks.salesforce import SalesforceHook

 if TYPE_CHECKING:
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/sftp_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/sftp_to_s3.py
index 141785fb88663..8ba11155f2f8c 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/sftp_to_s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/sftp_to_s3.py
@@ -22,8 +22,8 @@
 from typing import TYPE_CHECKING
 from urllib.parse import urlsplit

-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.version_compat import BaseOperator
 from airflow.providers.ssh.hooks.ssh import SSHHook

 if TYPE_CHECKING:
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py
index 29f58beaf4f58..fb92a4b4569b1 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -26,8 +26,8 @@
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.version_compat import BaseOperator

 if TYPE_CHECKING:
     import pandas as pd
diff --git a/providers/amazon/src/airflow/providers/amazon/version_compat.py b/providers/amazon/src/airflow/providers/amazon/version_compat.py
index 48d122b669696..badf23bd1ee77 100644
--- a/providers/amazon/src/airflow/providers/amazon/version_compat.py
+++ b/providers/amazon/src/airflow/providers/amazon/version_compat.py
@@ -33,3 +33,13 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:


 AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
+
+if AIRFLOW_V_3_0_PLUS:
+    from airflow.sdk import BaseOperator, BaseOperatorLink, BaseSensorOperator
+    from airflow.sdk.execution_time.xcom import XCom
+else:
+    from airflow.models import BaseOperator, XCom  # type: ignore[no-redef]
+    from airflow.models.baseoperatorlink import BaseOperatorLink  # type: ignore[no-redef]
+    from airflow.sensors.base import BaseSensorOperator  # type: ignore[no-redef]
+
+__all__ = ["AIRFLOW_V_3_0_PLUS", "BaseOperator", "BaseOperatorLink", "BaseSensorOperator", "XCom"]
diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_base_aws.py b/providers/amazon/tests/unit/amazon/aws/links/test_base_aws.py
index 6155f1747bff6..383f154f3db8a 100644
--- a/providers/amazon/tests/unit/amazon/aws/links/test_base_aws.py
+++ b/providers/amazon/tests/unit/amazon/aws/links/test_base_aws.py
@@ -23,18 +23,14 @@
 import pytest

 from airflow.providers.amazon.aws.links.base_aws import BaseAwsLink
+from airflow.providers.amazon.version_compat import XCom
 from airflow.serialization.serialized_objects import SerializedDAG

 from tests_common.test_utils.mock_operators import MockOperator
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

 if TYPE_CHECKING:
     from airflow.models.taskinstance import TaskInstance

-if AIRFLOW_V_3_0_PLUS:
-    from airflow.sdk.execution_time.xcom import XCom
-else:
-    from airflow.models import XCom  # type: ignore[no-redef]

 XCOM_KEY = "test_xcom_key"
 CUSTOM_KEYS = {
@@ -81,13 +77,10 @@ def test_persist(self, region_name, aws_partition, keywords, expected_value):
         )

         ti = mock_context["ti"]
-        if AIRFLOW_V_3_0_PLUS:
-            ti.xcom_push.assert_called_once_with(
-                key=XCOM_KEY,
-                value=expected_value,
-            )
-        else:
-            ti.xcom_push.assert_called_once_with(key=XCOM_KEY, value=expected_value, execution_date=None)
+        ti.xcom_push.assert_called_once_with(
+            key=XCOM_KEY,
+            value=expected_value,
+        )

     def test_disable_xcom_push(self):
         mock_context = mock.MagicMock()
@@ -102,18 +95,19 @@ def test_disable_xcom_push(self):

     def test_suppress_error_on_xcom_push(self):
         mock_context = mock.MagicMock()
-        with mock.patch.object(MockOperator, "xcom_push", side_effect=PermissionError("FakeError")) as m:
-            SimpleBaseAwsLink.persist(
-                context=mock_context,
-                operator=MockOperator(task_id="test_task_id"),
-                region_name="eu-east-1",
-                aws_partition="aws",
-            )
-            m.assert_called_once_with(
-                mock_context,
-                key="test_xcom_key",
-                value={"region_name": "eu-east-1", "aws_domain": "aws.amazon.com"},
-            )
+        mock_context["ti"].xcom_push.side_effect = PermissionError("FakeError")
+
+        SimpleBaseAwsLink.persist(
+            context=mock_context,
+            operator=MockOperator(task_id="test_task_id"),
+            region_name="eu-east-1",
+            aws_partition="aws",
+        )
+
+        mock_context["ti"].xcom_push.assert_called_once_with(
+            key="test_xcom_key",
+            value={"region_name": "eu-east-1", "aws_domain": "aws.amazon.com"},
+        )


 def link_test_operator(*links):
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_ecs.py b/providers/amazon/tests/unit/amazon/aws/operators/test_ecs.py
index b1987852b887e..f55d4b5bd74f5 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_ecs.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_ecs.py
@@ -332,7 +332,6 @@ def test_template_fields_overrides(self):
             ],
         ],
     )
-    @mock.patch.object(EcsRunTaskOperator, "xcom_push")
    @mock.patch.object(EcsRunTaskOperator, "_wait_for_task_ended")
    @mock.patch.object(EcsRunTaskOperator, "_check_success_task")
    @mock.patch.object(EcsBaseOperator, "client")
@@ -341,7 +340,6 @@ def test_execute_without_failures(
         client_mock,
         check_mock,
         wait_mock,
-        xcom_mock,
         launch_type,
         capacity_provider_strategy,
         platform_version,
@@ -358,7 +356,10 @@
         )
         client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES

-        self.ecs.execute(None)
+        mock_ti = mock.MagicMock()
+        mock_context = {"ti": mock_ti, "task_instance": mock_ti}
+
+        self.ecs.execute(mock_context)  # type: ignore[arg-type]

         client_mock.run_task.assert_called_once_with(
             cluster="c",
@@ -389,8 +390,11 @@ def test_execute_with_failures(self, client_mock):
         resp_failures["failures"].append("dummy error")
         client_mock.run_task.return_value = resp_failures

+        mock_ti = mock.MagicMock()
+        mock_context = {"ti": mock_ti, "task_instance": mock_ti}
+
         with pytest.raises(EcsOperatorError):
-            self.ecs.execute(None)
+            self.ecs.execute(mock_context)  # type: ignore[arg-type]

         client_mock.run_task.assert_called_once_with(
             cluster="c",
@@ -700,49 +704,62 @@ def test_reattach_save_task_arn_xcom(
         assert self.ecs.arn == f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
         assert "No active previously launched task found to reattach" in caplog.messages

-    @mock.patch.object(EcsRunTaskOperator, "xcom_push")
     @mock.patch.object(EcsBaseOperator, "client")
     @mock.patch("airflow.providers.amazon.aws.utils.task_log_fetcher.AwsTaskLogFetcher")
-    def test_execute_xcom_with_log(self, log_fetcher_mock, client_mock, xcom_mock):
+    def test_execute_xcom_with_log(self, log_fetcher_mock, client_mock):
         self.ecs.do_xcom_push = True
         self.ecs.task_log_fetcher = log_fetcher_mock

         log_fetcher_mock.get_last_log_message.return_value = "Log output"

-        assert self.ecs.execute(None) == "Log output"
+        mock_ti = mock.MagicMock()
+        mock_context = {"ti": mock_ti, "task_instance": mock_ti}
+
+        assert self.ecs.execute(mock_context) == "Log output"  # type: ignore[arg-type]

-    @mock.patch.object(EcsRunTaskOperator, "xcom_push")
     @mock.patch.object(EcsBaseOperator, "client")
     @mock.patch("airflow.providers.amazon.aws.utils.task_log_fetcher.AwsTaskLogFetcher")
-    def test_execute_xcom_with_no_log(self, log_fetcher_mock, client_mock, xcom_mock):
+    def test_execute_xcom_with_no_log(self, log_fetcher_mock, client_mock):
         self.ecs.do_xcom_push = True
         self.ecs.task_log_fetcher = log_fetcher_mock

         log_fetcher_mock.get_last_log_message.return_value = None

-        assert self.ecs.execute(None) is None
+        mock_ti = mock.MagicMock()
+        mock_context = {"ti": mock_ti, "task_instance": mock_ti}
+
+        assert self.ecs.execute(mock_context) is None  # type: ignore[arg-type]

-    @mock.patch.object(EcsRunTaskOperator, "xcom_push")
     @mock.patch.object(EcsBaseOperator, "client")
-    def test_execute_xcom_with_no_log_fetcher(self, client_mock, xcom_mock):
+    def test_execute_xcom_with_no_log_fetcher(self, client_mock):
         self.ecs.do_xcom_push = True
-        assert self.ecs.execute(None) is None
+
+        mock_ti = mock.MagicMock()
+        mock_context = {"ti": mock_ti, "task_instance": mock_ti}
+
+        assert self.ecs.execute(mock_context) is None  # type: ignore[arg-type]

     @mock.patch.object(EcsBaseOperator, "client")
     @mock.patch.object(AwsTaskLogFetcher, "get_last_log_message", return_value="Log output")
     def test_execute_xcom_disabled(self, log_fetcher_mock, client_mock):
         self.ecs.do_xcom_push = False
-        assert self.ecs.execute(None) is None

-    @mock.patch.object(EcsRunTaskOperator, "xcom_push")
+        mock_ti = mock.MagicMock()
+        mock_context = {"ti": mock_ti, "task_instance": mock_ti}
+
+        assert self.ecs.execute(mock_context) is None  # type: ignore[arg-type]
+
     @mock.patch.object(EcsRunTaskOperator, "client")
-    def test_with_defer(self, client_mock, xcom_mock):
+    def test_with_defer(self, client_mock):
         self.ecs.deferrable = True

         client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES

+        mock_ti = mock.MagicMock()
+        mock_context = {"ti": mock_ti, "task_instance": mock_ti}
+
         with pytest.raises(TaskDeferred) as deferred:
-            self.ecs.execute(None)
+            self.ecs.execute(mock_context)  # type: ignore[arg-type]

         assert isinstance(deferred.value.trigger, TaskDoneTrigger)
         assert deferred.value.trigger.task_arn == f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
@@ -752,7 +769,10 @@ def test_execute_complete(self, client_mock):
         event = {"status": "success", "task_arn": "my_arn", "cluster": "test_cluster"}
         self.ecs.reattach = True

-        self.ecs.execute_complete(None, event)
+        mock_ti = mock.MagicMock()
+        mock_context = {"ti": mock_ti, "task_instance": mock_ti}
+
+        self.ecs.execute_complete(mock_context, event)  # type: ignore[arg-type]

         # task gets described to assert its success
         client_mock().describe_tasks.assert_called_once_with(cluster="test_cluster", tasks=["my_arn"])
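
Note: a minimal sketch of how a provider module would consume the two patterns this diff standardizes, namely importing base classes through the consolidated version_compat shim and pushing XCom values via the task instance in the execution context. The operator name, XCom key, and return value below are hypothetical illustrations, not part of this change:

from __future__ import annotations

from typing import TYPE_CHECKING

# Single import point: version_compat resolves BaseOperator to airflow.sdk on
# Airflow 3 and to airflow.models on Airflow 2, as defined in this diff.
from airflow.providers.amazon.version_compat import BaseOperator

if TYPE_CHECKING:
    from airflow.utils.context import Context


class ExampleAwsOperator(BaseOperator):  # hypothetical operator, for illustration only
    def execute(self, context: Context) -> str:
        result = "arn:aws:ecs:..."  # placeholder for a real AWS API response
        if self.do_xcom_push:
            # Push through the TI carried in the execution context. This works on
            # both Airflow 2 and 3, unlike the operator-level
            # self.xcom_push(context, key=..., value=...) call this diff migrates away from.
            context["ti"].xcom_push(key="example_arn", value=result)
        return result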