diff --git a/airflow-core/docs/public-airflow-interface.rst b/airflow-core/docs/public-airflow-interface.rst index 88eed605eefa4..0d74912cc46db 100644 --- a/airflow-core/docs/public-airflow-interface.rst +++ b/airflow-core/docs/public-airflow-interface.rst @@ -138,7 +138,7 @@ Hooks Hooks are interfaces to external platforms and databases, implementing a common interface when possible and acting as building blocks for operators. All hooks -are derived from :class:`~airflow.hooks.base.BaseHook`. +are derived from :class:`~airflow.sdk.bases.hook.BaseHook`. Airflow has a set of Hooks that are considered public. You are free to extend their functionality by extending them: diff --git a/airflow-core/src/airflow/cli/commands/connection_command.py b/airflow-core/src/airflow/cli/commands/connection_command.py index 9d2358b029b3d..79bd77e53bf50 100644 --- a/airflow-core/src/airflow/cli/commands/connection_command.py +++ b/airflow-core/src/airflow/cli/commands/connection_command.py @@ -33,7 +33,6 @@ from airflow.cli.utils import is_stdout, print_export_output from airflow.configuration import conf from airflow.exceptions import AirflowNotFoundException -from airflow.hooks.base import BaseHook from airflow.models import Connection from airflow.providers_manager import ProvidersManager from airflow.secrets.local_filesystem import load_connections_dict @@ -67,6 +66,8 @@ def _connection_mapper(conn: Connection) -> dict[str, Any]: def connections_get(args): """Get a connection.""" try: + from airflow.sdk import BaseHook + conn = BaseHook.get_connection(args.conn_id) except AirflowNotFoundException: raise SystemExit("Connection not found.") diff --git a/airflow-core/src/airflow/hooks/__init__.py b/airflow-core/src/airflow/hooks/__init__.py index 9b9ef41aa89b9..6bd03fa282b09 100644 --- a/airflow-core/src/airflow/hooks/__init__.py +++ b/airflow-core/src/airflow/hooks/__init__.py @@ -31,5 +31,8 @@ "subprocess": { "SubprocessHook": 
"airflow.providers.standard.hooks.subprocess.SubprocessHook", }, + "base": { + "BaseHook": "airflow.sdk.bases.hook.BaseHook", + }, } add_deprecated_classes(__deprecated_classes, __name__) diff --git a/airflow-core/src/airflow/hooks/base.py b/airflow-core/src/airflow/hooks/base.py index 88a72ddf6e2bf..41a9b26ac9f8e 100644 --- a/airflow-core/src/airflow/hooks/base.py +++ b/airflow-core/src/airflow/hooks/base.py @@ -19,75 +19,7 @@ from __future__ import annotations -import logging -from typing import TYPE_CHECKING, Any, Protocol - -from airflow.utils.log.logging_mixin import LoggingMixin - -if TYPE_CHECKING: - from airflow.models.connection import Connection # Avoid circular imports. - -log = logging.getLogger(__name__) - - -class BaseHook(LoggingMixin): - """ - Abstract base class for hooks. - - Hooks are meant as an interface to - interact with external systems. MySqlHook, HiveHook, PigHook return - object that can handle the connection and interaction to specific - instances of these systems, and expose consistent methods to interact - with them. - - :param logger_name: Name of the logger used by the Hook to emit logs. - If set to `None` (default), the logger name will fall back to - `airflow.task.hooks.{class.__module__}.{class.__name__}` (e.g. DbApiHook will have - *airflow.task.hooks.airflow.providers.common.sql.hooks.sql.DbApiHook* as logger). - """ - - def __init__(self, logger_name: str | None = None): - super().__init__() - self._log_config_logger_name = "airflow.task.hooks" - self._logger_name = logger_name - - @classmethod - def get_connection(cls, conn_id: str) -> Connection: - """ - Get connection, given connection id. 
- - :param conn_id: connection id - :return: connection - """ - from airflow.models.connection import Connection - - conn = Connection.get_connection_from_secrets(conn_id) - log.info("Connection Retrieved '%s'", conn.conn_id) - return conn - - @classmethod - def get_hook(cls, conn_id: str, hook_params: dict | None = None) -> BaseHook: - """ - Return default hook for this connection id. - - :param conn_id: connection id - :param hook_params: hook parameters - :return: default hook for this connection - """ - connection = cls.get_connection(conn_id) - return connection.get_hook(hook_params=hook_params) - - def get_conn(self) -> Any: - """Return connection for the hook.""" - raise NotImplementedError() - - @classmethod - def get_connection_form_widgets(cls) -> dict[str, Any]: - return {} - - @classmethod - def get_ui_field_behaviour(cls) -> dict[str, Any]: - return {} +from typing import Any, Protocol class DiscoverableHook(Protocol): diff --git a/airflow-core/src/airflow/lineage/hook.py b/airflow-core/src/airflow/lineage/hook.py index 712b778aef6aa..b69f12484dc1b 100644 --- a/airflow-core/src/airflow/lineage/hook.py +++ b/airflow-core/src/airflow/lineage/hook.py @@ -29,8 +29,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: - from airflow.hooks.base import BaseHook - from airflow.sdk import ObjectStoragePath + from airflow.sdk import BaseHook, ObjectStoragePath # Store context what sent lineage. 
LineageContext: TypeAlias = BaseHook | ObjectStoragePath diff --git a/airflow-core/src/airflow/providers_manager.py b/airflow-core/src/airflow/providers_manager.py index 757957ad5d039..eba323f869d53 100644 --- a/airflow-core/src/airflow/providers_manager.py +++ b/airflow-core/src/airflow/providers_manager.py @@ -83,7 +83,7 @@ def ensure_prefix(field): if TYPE_CHECKING: from urllib.parse import SplitResult - from airflow.hooks.base import BaseHook + from airflow.sdk import BaseHook from airflow.sdk.bases.decorator import TaskDecorator from airflow.sdk.definitions.asset import Asset diff --git a/airflow-core/src/airflow/utils/email.py b/airflow-core/src/airflow/utils/email.py index b4c60350f5c1b..e47ea49dce80e 100644 --- a/airflow-core/src/airflow/utils/email.py +++ b/airflow-core/src/airflow/utils/email.py @@ -246,7 +246,7 @@ def send_mime_email( if conn_id is not None: try: - from airflow.hooks.base import BaseHook + from airflow.sdk import BaseHook airflow_conn = BaseHook.get_connection(conn_id) smtp_user = airflow_conn.login diff --git a/airflow-core/tests/unit/always/test_connection.py b/airflow-core/tests/unit/always/test_connection.py index 2df03cd94dc2b..ef6fa2213c270 100644 --- a/airflow-core/tests/unit/always/test_connection.py +++ b/airflow-core/tests/unit/always/test_connection.py @@ -29,8 +29,8 @@ from cryptography.fernet import Fernet from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook from airflow.models import Connection, crypto +from airflow.sdk import BaseHook from tests_common.test_utils.version_compat import SQLALCHEMY_V_1_4 @@ -640,8 +640,18 @@ def test_param_setup(self): assert conn.port is None @pytest.mark.db_test - def test_env_var_priority(self): + def test_env_var_priority(self, mock_supervisor_comms): from airflow.providers.sqlite.hooks.sqlite import SqliteHook + from airflow.sdk.execution_time.comms import ConnectionResult + + conn = ConnectionResult( + conn_id="airflow_db", + conn_type="mysql", + 
host="mysql", + login="root", + ) + + mock_supervisor_comms.send.return_value = conn conn = SqliteHook.get_connection(conn_id="airflow_db") assert conn.host != "ec2.compute.com" diff --git a/airflow-core/tests/unit/always/test_example_dags.py b/airflow-core/tests/unit/always/test_example_dags.py index 937e703294afd..442a33db42cdf 100644 --- a/airflow-core/tests/unit/always/test_example_dags.py +++ b/airflow-core/tests/unit/always/test_example_dags.py @@ -27,8 +27,8 @@ from packaging.specifiers import SpecifierSet from packaging.version import Version -from airflow.hooks.base import BaseHook from airflow.models import Connection, DagBag +from airflow.sdk import BaseHook from airflow.utils import yaml from tests_common.test_utils.asserts import assert_queries_count diff --git a/airflow-core/tests/unit/cli/commands/test_connection_command.py b/airflow-core/tests/unit/cli/commands/test_connection_command.py index 1ebb415aa0f82..e8b6f5d98e31d 100644 --- a/airflow-core/tests/unit/cli/commands/test_connection_command.py +++ b/airflow-core/tests/unit/cli/commands/test_connection_command.py @@ -30,7 +30,7 @@ from airflow.cli import cli_config, cli_parser from airflow.cli.commands import connection_command -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.models import Connection from airflow.utils.db import merge_conn from airflow.utils.session import create_session @@ -61,7 +61,8 @@ def test_cli_connection_get(self): stdout = stdout.getvalue() assert "google-cloud-platform:///default" in stdout - def test_cli_connection_get_invalid(self): + def test_cli_connection_get_invalid(self, mock_supervisor_comms): + mock_supervisor_comms.send.side_effect = AirflowNotFoundException with pytest.raises(SystemExit, match=re.escape("Connection not found.")): connection_command.connections_get(self.parser.parse_args(["connections", "get", "INVALID"])) diff --git 
a/airflow-core/tests/unit/cli/commands/test_rotate_fernet_key_command.py b/airflow-core/tests/unit/cli/commands/test_rotate_fernet_key_command.py index 5ec4502962ea4..60f399e29a65c 100644 --- a/airflow-core/tests/unit/cli/commands/test_rotate_fernet_key_command.py +++ b/airflow-core/tests/unit/cli/commands/test_rotate_fernet_key_command.py @@ -23,8 +23,8 @@ from airflow.cli import cli_parser from airflow.cli.commands import rotate_fernet_key_command -from airflow.hooks.base import BaseHook from airflow.models import Connection, Variable +from airflow.sdk import BaseHook from airflow.utils.session import provide_session from tests_common.test_utils.config import conf_vars @@ -84,7 +84,7 @@ def test_should_rotate_variable(self, session): assert Variable.get(key=var2_key) == "value" @provide_session - def test_should_rotate_connection(self, session): + def test_should_rotate_connection(self, session, mock_supervisor_comms): fernet_key1 = Fernet.generate_key() fernet_key2 = Fernet.generate_key() var1_key = f"{__file__}_var1" @@ -111,6 +111,26 @@ def test_should_rotate_connection(self, session): args = self.parser.parse_args(["rotate-fernet-key"]) rotate_fernet_key_command.rotate_fernet_key(args) + def mock_get_connection(conn_id): + conn = session.query(Connection).filter(Connection.conn_id == conn_id).first() + if conn: + from airflow.sdk.execution_time.comms import ConnectionResult + + return ConnectionResult( + conn_id=conn.conn_id, + conn_type=conn.conn_type or "mysql", # Provide a default conn_type + host=conn.host, + login=conn.login, + password=conn.password, + schema_=conn.schema, + port=conn.port, + extra=conn.extra, + ) + raise Exception(f"Connection {conn_id} not found") + + # Mock the send method to return our connection data + mock_supervisor_comms.send.return_value = mock_get_connection(var1_key) + # Assert correctness using a new fernet key with ( conf_vars({("core", "fernet_key"): fernet_key2.decode()}), @@ -119,5 +139,7 @@ def 
test_should_rotate_connection(self, session): # Unencrypted variable should be unchanged conn1: Connection = BaseHook.get_connection(var1_key) assert conn1.password == "pass" - assert conn1._password == "pass" + + # Mock for the second connection + mock_supervisor_comms.send.return_value = mock_get_connection(var2_key) assert BaseHook.get_connection(var2_key).password == "pass" diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index d4351039db186..1125d003c31d6 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -282,8 +282,7 @@ def test_top_level_connection_access( logger_filehandle = MagicMock() def dag_in_a_fn(): - from airflow.hooks.base import BaseHook - from airflow.sdk import DAG + from airflow.sdk import DAG, BaseHook with DAG(f"test_{BaseHook.get_connection(conn_id='my_conn').conn_id}"): ... @@ -312,8 +311,7 @@ def test_top_level_connection_access_not_found(self, tmp_path: pathlib.Path, inp logger_filehandle = MagicMock() def dag_in_a_fn(): - from airflow.hooks.base import BaseHook - from airflow.sdk import DAG + from airflow.sdk import DAG, BaseHook with DAG(f"test_{BaseHook.get_connection(conn_id='my_conn').conn_id}"): ... 
diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index d3cecb5fc94e6..c9f9161861f88 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -31,7 +31,6 @@ from asgiref.sync import sync_to_async from airflow.executors import workloads -from airflow.hooks.base import BaseHook from airflow.jobs.job import Job from airflow.jobs.triggerer_job_runner import ( TriggerCommsDecoder, @@ -49,6 +48,7 @@ from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator from airflow.providers.standard.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger +from airflow.sdk import BaseHook from airflow.triggers.base import BaseTrigger, TriggerEvent from airflow.triggers.testing import FailureTrigger, SuccessTrigger from airflow.utils import timezone diff --git a/airflow-core/tests/unit/lineage/test_hook.py b/airflow-core/tests/unit/lineage/test_hook.py index 4586e59ea76c1..1c2b2809a3c87 100644 --- a/airflow-core/tests/unit/lineage/test_hook.py +++ b/airflow-core/tests/unit/lineage/test_hook.py @@ -22,7 +22,6 @@ import pytest from airflow import plugins_manager -from airflow.hooks.base import BaseHook from airflow.lineage import hook from airflow.lineage.hook import ( AssetLineageInfo, @@ -32,6 +31,7 @@ NoOpCollector, get_hook_lineage_collector, ) +from airflow.sdk import BaseHook from airflow.sdk.definitions.asset import Asset from tests_common.test_utils.mock_plugins import mock_plugin_manager diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index 1a3b55f0a892e..143f2c8acdd51 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -53,7 +53,6 @@ ParamValidationError, 
SerializationError, ) -from airflow.hooks.base import BaseHook from airflow.models.asset import AssetModel from airflow.models.baseoperator import BaseOperator from airflow.models.connection import Connection @@ -64,7 +63,7 @@ from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.sensors.bash import BashSensor -from airflow.sdk import AssetAlias, teardown +from airflow.sdk import AssetAlias, BaseHook, teardown from airflow.sdk.bases.decorator import DecoratedOperator from airflow.sdk.definitions._internal.expandinput import EXPAND_INPUT_EMPTY from airflow.sdk.definitions.asset import Asset, AssetUniqueKey diff --git a/devel-common/src/tests_common/test_utils/common_sql.py b/devel-common/src/tests_common/test_utils/common_sql.py index b191d524ee488..cb38fd47212a3 100644 --- a/devel-common/src/tests_common/test_utils/common_sql.py +++ b/devel-common/src/tests_common/test_utils/common_sql.py @@ -23,7 +23,7 @@ from airflow.models import Connection if TYPE_CHECKING: - from airflow.hooks.base import BaseHook + from airflow.sdk import BaseHook def mock_db_hook(hook_class: type[BaseHook], hook_params=None, conn_params=None): diff --git a/providers/airbyte/src/airflow/providers/airbyte/hooks/airbyte.py b/providers/airbyte/src/airflow/providers/airbyte/hooks/airbyte.py index 8a9c14f9ba8b2..dc7c39683176d 100644 --- a/providers/airbyte/src/airflow/providers/airbyte/hooks/airbyte.py +++ b/providers/airbyte/src/airflow/providers/airbyte/hooks/airbyte.py @@ -26,7 +26,11 @@ from requests import Session from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] T = TypeVar("T", bound=Any) diff --git a/providers/alibaba/src/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py 
b/providers/alibaba/src/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py index 73b6b33e66992..5b8d7f99ec62b 100644 --- a/providers/alibaba/src/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py +++ b/providers/alibaba/src/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py @@ -34,7 +34,11 @@ from alibabacloud_tea_openapi.models import Config from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] from airflow.utils.log.logging_mixin import LoggingMixin diff --git a/providers/alibaba/src/airflow/providers/alibaba/cloud/hooks/base_alibaba.py b/providers/alibaba/src/airflow/providers/alibaba/cloud/hooks/base_alibaba.py index d583c7701176d..1910f77b12bbf 100644 --- a/providers/alibaba/src/airflow/providers/alibaba/cloud/hooks/base_alibaba.py +++ b/providers/alibaba/src/airflow/providers/alibaba/cloud/hooks/base_alibaba.py @@ -18,7 +18,10 @@ from typing import Any, NamedTuple -from airflow.hooks.base import BaseHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] class AccessKeyCredentials(NamedTuple): diff --git a/providers/alibaba/src/airflow/providers/alibaba/cloud/hooks/oss.py b/providers/alibaba/src/airflow/providers/alibaba/cloud/hooks/oss.py index c70cad5d9a27d..fb9834c8802f8 100644 --- a/providers/alibaba/src/airflow/providers/alibaba/cloud/hooks/oss.py +++ b/providers/alibaba/src/airflow/providers/alibaba/cloud/hooks/oss.py @@ -27,10 +27,17 @@ from oss2.exceptions import ClientError from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] if TYPE_CHECKING: - from airflow.models.connection import 
Connection + try: + from airflow.sdk import Connection + except ImportError: + from airflow.models.connection import Connection # type: ignore[assignment] T = TypeVar("T", bound=Callable) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py index d66f788226aa8..f0491153fbd90 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py @@ -127,7 +127,10 @@ def conn_config(self) -> AwsConnectionWrapper: ) return AwsConnectionWrapper( - conn=connection, region_name=self._region_name, botocore_config=self._config, verify=self._verify + conn=connection, + region_name=self._region_name, + botocore_config=self._config, + verify=self._verify, ) @property diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py index 14e20f0775a1d..18e9d396230a9 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py @@ -57,12 +57,16 @@ AirflowNotFoundException, AirflowProviderDeprecationWarning, ) -from airflow.hooks.base import BaseHook from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper from airflow.providers.amazon.aws.utils.identifiers import generate_uuid from airflow.providers.amazon.aws.utils.suppress import return_on_error from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_0_PLUS from airflow.providers_manager import ProvidersManager + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] from airflow.utils.helpers import exactly_one from airflow.utils.log.logging_mixin import LoggingMixin @@ -639,7 +643,10 @@ def conn_config(self) -> AwsConnectionWrapper: 
raise return AwsConnectionWrapper( - conn=connection, region_name=self._region_name, botocore_config=self._config, verify=self._verify + conn=connection, # type: ignore[arg-type] + region_name=self._region_name, + botocore_config=self._config, + verify=self._verify, ) def _resolve_service_name(self, is_resource_type: bool = False) -> str: diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/chime.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/chime.py index f004b1ac8656b..8201f32e76f01 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/chime.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/chime.py @@ -69,6 +69,10 @@ def _get_webhook_endpoint(self, conn_id: str) -> str: token = conn.password if token is None: raise AirflowException("Webhook token field is missing and is required.") + if not conn.schema: + raise AirflowException("Webhook schema field is missing and is required.") + if not conn.host: + raise AirflowException("Webhook host field is missing and is required.") url = conn.schema + "://" + conn.host endpoint = url + token # Check to make sure the endpoint matches what Chime expects diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py index d7ee54543628c..96a9b6edad8a2 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py @@ -25,9 +25,13 @@ from sagemaker_studio.sagemaker_studio_api import SageMakerStudioAPI from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook from airflow.providers.amazon.aws.utils.sagemaker_unified_studio import is_local_runner +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + class
SageMakerNotebookHook(BaseHook): """ diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_sql.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_sql.py index 7d5a69aed431a..d70a7b105887f 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_sql.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_sql.py @@ -22,10 +22,14 @@ from typing import TYPE_CHECKING from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.version_compat import BaseOperator +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + if TYPE_CHECKING: from airflow.utils.context import Context diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py index fb92a4b4569b1..9546db8fc990c 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py @@ -25,10 +25,14 @@ from typing import TYPE_CHECKING, Any, Literal, cast from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.version_compat import BaseOperator +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + if TYPE_CHECKING: import pandas as pd diff --git a/providers/amazon/src/airflow/providers/amazon/aws/utils/connection_wrapper.py b/providers/amazon/src/airflow/providers/amazon/aws/utils/connection_wrapper.py index ef2ca7b4cb343..4d10856490ff6 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/utils/connection_wrapper.py +++ 
b/providers/amazon/src/airflow/providers/amazon/aws/utils/connection_wrapper.py @@ -32,7 +32,10 @@ from airflow.utils.types import NOTSET, ArgNotSet if TYPE_CHECKING: - from airflow.models.connection import Connection # Avoid circular imports. + try: + from airflow.sdk import Connection + except ImportError: + from airflow.models.connection import Connection # type: ignore[assignment] @dataclass diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_s3.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_s3.py index 31b3a625a3971..1664f2f7e8a83 100644 --- a/providers/amazon/tests/unit/amazon/aws/hooks/test_s3.py +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_s3.py @@ -43,6 +43,16 @@ ) from airflow.utils.timezone import datetime +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" + @pytest.fixture def mocked_s3_res(): @@ -1748,7 +1758,7 @@ def test_delete_bucket_tagging_with_no_tags(self): ("rel_key", "with_conn", "with_bucket", "provide", ["kwargs_bucket", "key.txt"]), ], ) -@patch("airflow.hooks.base.BaseHook.get_connection") +@patch(f"{BASEHOOK_PATCH_PATH}.get_connection") def test_unify_and_provide_bucket_name_combination( mock_base, key_kind, has_conn, has_bucket, precedence, expected, caplog ): @@ -1811,7 +1821,7 @@ def do_something(self, bucket_name=None, key=None): ("rel_key", "with_conn", "with_bucket", ["kwargs_bucket", "key.txt"]), ], ) -@patch("airflow.hooks.base.BaseHook.get_connection") +@patch(f"{BASEHOOK_PATCH_PATH}.get_connection") def test_s3_head_object_decorated_behavior(mock_conn, has_conn, has_bucket, key_kind, expected): if has_conn == "with_conn": c = Connection(extra={"service_config": {"s3": {"bucket_name": "conn_bucket"}}}) diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_base_aws.py 
b/providers/amazon/tests/unit/amazon/aws/operators/test_base_aws.py index e95748f0989c2..0b561ce791356 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_base_aws.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_base_aws.py @@ -20,9 +20,13 @@ import pytest -from airflow.hooks.base import BaseHook from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] from airflow.utils import timezone TEST_CONN = "aws_test_conn" diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_base_aws.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_base_aws.py index 435f80c2aed8c..b3257f8b6783d 100644 --- a/providers/amazon/tests/unit/amazon/aws/sensors/test_base_aws.py +++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_base_aws.py @@ -20,9 +20,13 @@ import pytest -from airflow.hooks.base import BaseHook from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] from airflow.utils import timezone TEST_CONN = "aws_test_conn" diff --git a/providers/apache/beam/src/airflow/providers/apache/beam/hooks/beam.py b/providers/apache/beam/src/airflow/providers/apache/beam/hooks/beam.py index 476a5ffc0555a..635e9b84d82d7 100644 --- a/providers/apache/beam/src/airflow/providers/apache/beam/hooks/beam.py +++ b/providers/apache/beam/src/airflow/providers/apache/beam/hooks/beam.py @@ -37,9 +37,13 @@ from packaging.version import Version from airflow.exceptions import AirflowConfigException, AirflowException -from airflow.hooks.base import BaseHook from airflow.providers.common.compat.standard.utils 
import prepare_virtualenv +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + if TYPE_CHECKING: import logging diff --git a/providers/apache/cassandra/src/airflow/providers/apache/cassandra/hooks/cassandra.py b/providers/apache/cassandra/src/airflow/providers/apache/cassandra/hooks/cassandra.py index 9ba9b9915ed4d..3b5af5fa7046a 100644 --- a/providers/apache/cassandra/src/airflow/providers/apache/cassandra/hooks/cassandra.py +++ b/providers/apache/cassandra/src/airflow/providers/apache/cassandra/hooks/cassandra.py @@ -31,7 +31,10 @@ WhiteListRoundRobinPolicy, ) -from airflow.hooks.base import BaseHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] from airflow.utils.log.logging_mixin import LoggingMixin Policy: TypeAlias = DCAwareRoundRobinPolicy | RoundRobinPolicy | TokenAwarePolicy | WhiteListRoundRobinPolicy @@ -90,7 +93,7 @@ def __init__(self, cassandra_conn_id: str = default_conn_name): super().__init__() conn = self.get_connection(cassandra_conn_id) - conn_config = {} + conn_config: dict[str, Any] = {} if conn.host: conn_config["contact_points"] = conn.host.split(",") diff --git a/providers/apache/drill/src/airflow/providers/apache/drill/hooks/drill.py b/providers/apache/drill/src/airflow/providers/apache/drill/hooks/drill.py index 86bacc5836d42..74d07ce29e0a7 100644 --- a/providers/apache/drill/src/airflow/providers/apache/drill/hooks/drill.py +++ b/providers/apache/drill/src/airflow/providers/apache/drill/hooks/drill.py @@ -73,7 +73,7 @@ def get_uri(self) -> str: e.g: ``drill://localhost:8047/dfs`` """ conn_md = self.get_connection(self.get_conn_id()) - host = conn_md.host + host = conn_md.host or "" if conn_md.port is not None: host += f":{conn_md.port}" conn_type = conn_md.conn_type or "drill" diff --git 
a/providers/apache/druid/src/airflow/providers/apache/druid/hooks/druid.py b/providers/apache/druid/src/airflow/providers/apache/druid/hooks/druid.py index 6c7b0a9cb90b4..ccd5228feffb8 100644 --- a/providers/apache/druid/src/airflow/providers/apache/druid/hooks/druid.py +++ b/providers/apache/druid/src/airflow/providers/apache/druid/hooks/druid.py @@ -27,9 +27,13 @@ from pydruid.db import connect from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook from airflow.providers.common.sql.hooks.sql import DbApiHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + if TYPE_CHECKING: from airflow.models import Connection @@ -83,7 +87,7 @@ def __init__( @cached_property def conn(self) -> Connection: - return self.get_connection(self.druid_ingest_conn_id) + return self.get_connection(self.druid_ingest_conn_id) # type: ignore[return-value] @property def get_connection_type(self) -> str: @@ -238,8 +242,8 @@ def get_uri(self) -> str: e.g: druid://localhost:8082/druid/v2/sql/ """ conn = self.get_connection(self.get_conn_id()) - host = conn.host - if conn.port is not None: + host = conn.host or "" + if conn.port: host += f":{conn.port}" conn_type = conn.conn_type or "druid" endpoint = conn.extra_dejson.get("endpoint", "druid/v2/sql") diff --git a/providers/apache/hdfs/src/airflow/providers/apache/hdfs/hooks/webhdfs.py b/providers/apache/hdfs/src/airflow/providers/apache/hdfs/hooks/webhdfs.py index e3b735228fd2f..6edca5ad34167 100644 --- a/providers/apache/hdfs/src/airflow/providers/apache/hdfs/hooks/webhdfs.py +++ b/providers/apache/hdfs/src/airflow/providers/apache/hdfs/hooks/webhdfs.py @@ -19,14 +19,18 @@ import logging import socket -from typing import Any +from typing import Any, cast import requests from hdfs import HdfsError, InsecureClient from airflow.configuration import conf from airflow.exceptions import AirflowException -from 
airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] log = logging.getLogger(__name__) @@ -74,7 +78,7 @@ def get_conn(self) -> Any: def _find_valid_server(self) -> Any: connection = self.get_connection(self.webhdfs_conn_id) - namenodes = connection.host.split(",") + namenodes = cast("str", connection.host).split(",") for namenode in namenodes: host_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.log.info("Trying to connect to %s:%s", namenode, connection.port) @@ -84,10 +88,10 @@ def _find_valid_server(self) -> Any: self.log.info("Trying namenode %s", namenode) client = self._get_client( namenode, - connection.port, - connection.login, + cast("int", connection.port), + cast("str", connection.login), connection.password, - connection.schema, + cast("str", connection.schema), connection.extra_dejson, ) client.status("/") diff --git a/providers/apache/hive/src/airflow/providers/apache/hive/hooks/hive.py b/providers/apache/hive/src/airflow/providers/apache/hive/hooks/hive.py index a7d0f94aff151..6c3f55c015636 100644 --- a/providers/apache/hive/src/airflow/providers/apache/hive/hooks/hive.py +++ b/providers/apache/hive/src/airflow/providers/apache/hive/hooks/hive.py @@ -38,9 +38,13 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning -from airflow.hooks.base import BaseHook from airflow.providers.apache.hive.version_compat import AIRFLOW_VAR_NAME_FORMAT_MAPPING from airflow.providers.common.sql.hooks.sql import DbApiHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] from airflow.security import utils from airflow.utils.helpers import as_flattened_list @@ -270,7 +274,7 @@ def run_cli( True """ conn = self.conn - schema = schema or conn.schema + schema = schema or 
conn.schema or "" invalid_chars_list = re.findall(r"[^a-z0-9_]", schema) if invalid_chars_list: @@ -598,7 +602,9 @@ def sasl_factory() -> sasl.Client: def _find_valid_host(self) -> Any: conn = self.conn - hosts = conn.host.split(",") + hosts = [] + if conn.host: + hosts = conn.host.split(",") for host in hosts: host_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.log.info("Trying to connect to %s:%s", host, conn.port) diff --git a/providers/apache/iceberg/src/airflow/providers/apache/iceberg/hooks/iceberg.py b/providers/apache/iceberg/src/airflow/providers/apache/iceberg/hooks/iceberg.py index 90f9c75078739..a08a29e2be37d 100644 --- a/providers/apache/iceberg/src/airflow/providers/apache/iceberg/hooks/iceberg.py +++ b/providers/apache/iceberg/src/airflow/providers/apache/iceberg/hooks/iceberg.py @@ -16,12 +16,15 @@ # under the License. from __future__ import annotations -from typing import Any +from typing import Any, cast import requests from requests import HTTPError -from airflow.hooks.base import BaseHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] TOKENS_ENDPOINT = "oauth/tokens" @@ -75,7 +78,7 @@ def test_connection(self) -> tuple[bool, str]: def get_conn(self) -> str: """Obtain a short-lived access token via a client_id and client_secret.""" conn = self.get_connection(self.conn_id) - base_url = conn.host + base_url = cast("str", conn.host) base_url = base_url.rstrip("/") client_id = conn.login client_secret = conn.password diff --git a/providers/apache/kafka/src/airflow/providers/apache/kafka/hooks/base.py b/providers/apache/kafka/src/airflow/providers/apache/kafka/hooks/base.py index 63b492ce911af..a9a753fb0b1f6 100644 --- a/providers/apache/kafka/src/airflow/providers/apache/kafka/hooks/base.py +++ b/providers/apache/kafka/src/airflow/providers/apache/kafka/hooks/base.py @@ -21,7 +21,10 @@ from confluent_kafka.admin import AdminClient -from 
airflow.hooks.base import BaseHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] class KafkaBaseHook(BaseHook): diff --git a/providers/apache/kafka/tests/unit/apache/kafka/hooks/test_base.py b/providers/apache/kafka/tests/unit/apache/kafka/hooks/test_base.py index 5701a789b70e6..be790fbab5061 100644 --- a/providers/apache/kafka/tests/unit/apache/kafka/hooks/test_base.py +++ b/providers/apache/kafka/tests/unit/apache/kafka/hooks/test_base.py @@ -23,6 +23,16 @@ from airflow.providers.apache.kafka.hooks.base import KafkaBaseHook +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" + class SomeKafkaHook(KafkaBaseHook): def _get_client(self, config): @@ -38,20 +48,20 @@ def hook(): class TestKafkaBaseHook: - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") def test_get_conn(self, mock_get_connection, hook): config = {"bootstrap.servers": MagicMock()} mock_get_connection.return_value.extra_dejson = config assert hook.get_conn == config - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") def test_get_conn_value_error(self, mock_get_connection, hook): mock_get_connection.return_value.extra_dejson = {} with pytest.raises(ValueError, match="must be provided"): hook.get_conn() @mock.patch("airflow.providers.apache.kafka.hooks.base.AdminClient") - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") def test_test_connection(self, mock_get_connection, admin_client, hook): config = {"bootstrap.servers": MagicMock()} mock_get_connection.return_value.extra_dejson = config @@ -65,7 +75,7 @@ def 
test_test_connection(self, mock_get_connection, admin_client, hook): "airflow.providers.apache.kafka.hooks.base.AdminClient", return_value=MagicMock(list_topics=MagicMock(return_value=[])), ) - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") def test_test_connection_no_topics(self, mock_get_connection, admin_client, hook): config = {"bootstrap.servers": MagicMock()} mock_get_connection.return_value.extra_dejson = config @@ -76,7 +86,7 @@ def test_test_connection_no_topics(self, mock_get_connection, admin_client, hook assert connection == (False, "Failed to establish connection.") @mock.patch("airflow.providers.apache.kafka.hooks.base.AdminClient") - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") def test_test_connection_exception(self, mock_get_connection, admin_client, hook): config = {"bootstrap.servers": MagicMock()} mock_get_connection.return_value.extra_dejson = config diff --git a/providers/apache/kylin/src/airflow/providers/apache/kylin/hooks/kylin.py b/providers/apache/kylin/src/airflow/providers/apache/kylin/hooks/kylin.py index cfc2021f66069..31a76ac1b1f1b 100644 --- a/providers/apache/kylin/src/airflow/providers/apache/kylin/hooks/kylin.py +++ b/providers/apache/kylin/src/airflow/providers/apache/kylin/hooks/kylin.py @@ -20,7 +20,11 @@ from kylinpy import exceptions, kylinpy from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] class KylinHook(BaseHook): diff --git a/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py b/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py index 2c9f4bb7de609..378d3a0ab7b12 100644 --- a/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py +++ 
b/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py @@ -528,7 +528,7 @@ async def _do_api_call_async( if self.http_conn_id: conn = await sync_to_async(self.get_connection)(self.http_conn_id) - self.base_url = self._generate_base_url(conn) + self.base_url = self._generate_base_url(conn) # type: ignore[arg-type] if conn.login: auth = self.auth_type(conn.login, conn.password) if conn.extra: diff --git a/providers/apache/pig/src/airflow/providers/apache/pig/hooks/pig.py b/providers/apache/pig/src/airflow/providers/apache/pig/hooks/pig.py index 610ce28046704..b067d5ab5316d 100644 --- a/providers/apache/pig/src/airflow/providers/apache/pig/hooks/pig.py +++ b/providers/apache/pig/src/airflow/providers/apache/pig/hooks/pig.py @@ -22,7 +22,11 @@ from typing import Any from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] class PigCliHook(BaseHook): diff --git a/providers/apache/pinot/src/airflow/providers/apache/pinot/hooks/pinot.py b/providers/apache/pinot/src/airflow/providers/apache/pinot/hooks/pinot.py index e14136362b6ae..c27612b6d3356 100644 --- a/providers/apache/pinot/src/airflow/providers/apache/pinot/hooks/pinot.py +++ b/providers/apache/pinot/src/airflow/providers/apache/pinot/hooks/pinot.py @@ -26,9 +26,13 @@ from pinotdb import connect from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook from airflow.providers.common.sql.hooks.sql import DbApiHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + if TYPE_CHECKING: from airflow.models import Connection @@ -106,7 +110,8 @@ def add_schema(self, schema_file: str, with_exec: bool = True) -> Any: cmd += ["-user", self.username] if self.password: cmd += ["-password", self.password] - cmd += 
["-controllerHost", self.host] + if self.host is not None: + cmd += ["-controllerHost", self.host] cmd += ["-controllerPort", self.port] cmd += ["-schemaFile", schema_file] if with_exec: @@ -125,7 +130,8 @@ def add_table(self, file_path: str, with_exec: bool = True) -> Any: cmd += ["-user", self.username] if self.password: cmd += ["-password", self.password] - cmd += ["-controllerHost", self.host] + if self.host is not None: + cmd += ["-controllerHost", self.host] cmd += ["-controllerPort", self.port] cmd += ["-filePath", file_path] if with_exec: @@ -230,7 +236,8 @@ def upload_segment(self, segment_dir: str, table_name: str | None = None) -> Any cmd += ["-user", self.username] if self.password: cmd += ["-password", self.password] - cmd += ["-controllerHost", self.host] + if self.host is not None: + cmd += ["-controllerHost", self.host] cmd += ["-controllerPort", self.port] cmd += ["-segmentDir", segment_dir] if table_name: @@ -312,7 +319,7 @@ def get_uri(self) -> str: e.g: http://localhost:9000/query/sql """ conn = self.get_connection(self.get_conn_id()) - host = conn.host + host = conn.host or "" if conn.login and conn.password: host = f"{quote_plus(conn.login)}:{quote_plus(conn.password)}@{host}" if conn.port: diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/decorators/pyspark.py b/providers/apache/spark/src/airflow/providers/apache/spark/decorators/pyspark.py index 8e93185321f1b..d2e2dd8a403e9 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/decorators/pyspark.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/decorators/pyspark.py @@ -33,10 +33,14 @@ ) -from airflow.hooks.base import BaseHook from airflow.providers.apache.spark.hooks.spark_connect import SparkConnectHook from airflow.providers.common.compat.standard.operators import PythonOperator +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + if 
TYPE_CHECKING: from airflow.utils.context import Context diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_connect.py b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_connect.py index acdeaae1f4fd4..b04c87c927afb 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_connect.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_connect.py @@ -17,10 +17,13 @@ # under the License. from __future__ import annotations -from typing import Any +from typing import Any, cast from urllib.parse import quote, urlparse, urlunparse -from airflow.hooks.base import BaseHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] from airflow.utils.log.logging_mixin import LoggingMixin @@ -64,11 +67,11 @@ def __init__(self, conn_id: str = default_conn_name) -> None: def get_connection_url(self) -> str: conn = self.get_connection(self._conn_id) - host = conn.host - if conn.host.find("://") == -1: - host = f"sc://{conn.host}" + host = cast("str", conn.host) + if host.find("://") == -1: + host = f"sc://{host}" if conn.port: - host = f"{conn.host}:{conn.port}" + host = f"{host}:{conn.port}" url = urlparse(host) diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_jdbc.py b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_jdbc.py index b904ca4260e61..3d8f2a646e8e2 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_jdbc.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_jdbc.py @@ -18,7 +18,7 @@ from __future__ import annotations import os -from typing import Any +from typing import Any, cast from airflow.exceptions import AirflowException from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook @@ -162,17 +162,17 @@ def _resolve_jdbc_connection(self) -> dict[str, 
Any]: conn_data = {"url": "", "schema": "", "conn_prefix": "", "user": "", "password": ""} try: conn = self.get_connection(self._jdbc_conn_id) - if "/" in conn.host: + if conn.host is not None and "/" in conn.host: raise ValueError("The jdbc host should not contain a '/'") - if "?" in conn.schema: + if conn.schema is not None and "?" in conn.schema: raise ValueError("The jdbc schema should not contain a '?'") - if conn.port: - conn_data["url"] = f"{conn.host}:{conn.port}" + if conn.port is not None: + conn_data["url"] = f"{cast('str', conn.host)}:{conn.port}" else: - conn_data["url"] = conn.host - conn_data["schema"] = conn.schema - conn_data["user"] = conn.login - conn_data["password"] = conn.password + conn_data["url"] = cast("str", conn.host) + conn_data["schema"] = cast("str", conn.schema) + conn_data["user"] = cast("str", conn.login) + conn_data["password"] = cast("str", conn.password) extra = conn.extra_dejson conn_data["conn_prefix"] = extra.get("conn_prefix", "") except AirflowException: diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_sql.py b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_sql.py index f31dcbaaa27cd..6ebb046e1d3b2 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_sql.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_sql.py @@ -21,10 +21,17 @@ from typing import TYPE_CHECKING, Any from airflow.exceptions import AirflowException, AirflowNotFoundException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] if TYPE_CHECKING: - from airflow.models.connection import Connection + try: + from airflow.sdk import Connection + except ImportError: + from airflow.models.connection import Connection # type: ignore[assignment] class SparkSqlHook(BaseHook): diff --git 
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py index 7a0c0bbd2038e..c846016b56912 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py @@ -32,7 +32,11 @@ from airflow.configuration import conf as airflow_conf from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] from airflow.security.kerberos import renew_from_kt from airflow.utils.log.logging_mixin import LoggingMixin diff --git a/providers/apache/tinkerpop/src/airflow/providers/apache/tinkerpop/hooks/gremlin.py b/providers/apache/tinkerpop/src/airflow/providers/apache/tinkerpop/hooks/gremlin.py index 3c547a02c338c..bc5d96e6225d8 100644 --- a/providers/apache/tinkerpop/src/airflow/providers/apache/tinkerpop/hooks/gremlin.py +++ b/providers/apache/tinkerpop/src/airflow/providers/apache/tinkerpop/hooks/gremlin.py @@ -23,7 +23,10 @@ from gremlin_python.driver.client import Client -from airflow.hooks.base import BaseHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] if TYPE_CHECKING: from airflow.models import Connection @@ -67,11 +70,14 @@ def get_conn(self, serializer=None) -> Client: self.connection = self.get_connection(self.gremlin_conn_id) - uri = self.get_uri(self.connection) + uri = self.get_uri(self.connection) # type: ignore[arg-type] self.log.info("Connecting to URI: %s", uri) self.client = self.get_client( - self.connection, self.traversal_source, uri, message_serializer=serializer + self.connection, # type: ignore[arg-type] + self.traversal_source, + uri, + message_serializer=serializer, ) return 
self.client diff --git a/providers/apprise/src/airflow/providers/apprise/hooks/apprise.py b/providers/apprise/src/airflow/providers/apprise/hooks/apprise.py index 4fe373281f6c7..6998d9dcce5ba 100644 --- a/providers/apprise/src/airflow/providers/apprise/hooks/apprise.py +++ b/providers/apprise/src/airflow/providers/apprise/hooks/apprise.py @@ -24,7 +24,10 @@ import apprise from apprise import AppriseConfig, NotifyFormat, NotifyType -from airflow.hooks.base import BaseHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] if TYPE_CHECKING: from apprise import AppriseAttachment diff --git a/providers/arangodb/src/airflow/providers/arangodb/hooks/arangodb.py b/providers/arangodb/src/airflow/providers/arangodb/hooks/arangodb.py index edd5ad2d5b487..c5ef6a8b9fc38 100644 --- a/providers/arangodb/src/airflow/providers/arangodb/hooks/arangodb.py +++ b/providers/arangodb/src/airflow/providers/arangodb/hooks/arangodb.py @@ -32,7 +32,11 @@ ) from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] if TYPE_CHECKING: from arango.database import StandardDatabase @@ -70,7 +74,7 @@ def db_conn(self) -> StandardDatabase: @cached_property def _conn(self) -> Connection: - return self.get_connection(self.arangodb_conn_id) + return self.get_connection(self.arangodb_conn_id) # type: ignore[return-value] @property def hosts(self) -> list[str]: diff --git a/providers/asana/src/airflow/providers/asana/hooks/asana.py b/providers/asana/src/airflow/providers/asana/hooks/asana.py index dc78ab0ee4e20..4d56f2417f515 100644 --- a/providers/asana/src/airflow/providers/asana/hooks/asana.py +++ b/providers/asana/src/airflow/providers/asana/hooks/asana.py @@ -28,7 +28,10 @@ from asana.configuration import Configuration from asana.rest 
import ApiException -from airflow.hooks.base import BaseHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] class AsanaHook(BaseHook): diff --git a/providers/atlassian/jira/src/airflow/providers/atlassian/jira/hooks/jira.py b/providers/atlassian/jira/src/airflow/providers/atlassian/jira/hooks/jira.py index 7ebe60fb82844..f761dff5fea54 100644 --- a/providers/atlassian/jira/src/airflow/providers/atlassian/jira/hooks/jira.py +++ b/providers/atlassian/jira/src/airflow/providers/atlassian/jira/hooks/jira.py @@ -19,12 +19,16 @@ from __future__ import annotations -from typing import Any +from typing import Any, cast from atlassian import Jira from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] class JiraHook(BaseHook): @@ -62,7 +66,7 @@ def get_conn(self) -> Jira: # more can be added ex: timeout, cloud, session self.client = Jira( - url=conn.host, + url=cast("str", conn.host), username=conn.login, password=conn.password, verify_ssl=verify, diff --git a/providers/cloudant/src/airflow/providers/cloudant/hooks/cloudant.py b/providers/cloudant/src/airflow/providers/cloudant/hooks/cloudant.py index 35cb6064b9ef9..212931bee4db7 100644 --- a/providers/cloudant/src/airflow/providers/cloudant/hooks/cloudant.py +++ b/providers/cloudant/src/airflow/providers/cloudant/hooks/cloudant.py @@ -24,7 +24,11 @@ from ibmcloudant import CloudantV1, CouchDbSessionAuthenticator from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] if TYPE_CHECKING: from airflow.models import Connection @@ -70,13 +74,14 @@ def get_conn(self) -> CloudantV1: """ conn 
= self.get_connection(self.cloudant_conn_id) - self._validate_connection(conn) - - authenticator = CouchDbSessionAuthenticator(username=conn.login, password=conn.password) - service = CloudantV1(authenticator=authenticator) - service.set_service_url(f"https://{conn.host}.cloudant.com") + self._validate_connection(conn) # type: ignore[arg-type] + if conn.login and conn.password: + authenticator = CouchDbSessionAuthenticator(username=conn.login, password=conn.password) + service = CloudantV1(authenticator=authenticator) + service.set_service_url(f"https://{conn.host}.cloudant.com") - return service + return service + raise AirflowException("Missing login or password in Cloudant connection.") @staticmethod def _validate_connection(conn: Connection) -> None: diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py index f90de71855f9d..a1e4436a0b71d 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py @@ -36,7 +36,6 @@ from urllib3.exceptions import HTTPError from airflow.exceptions import AirflowException, AirflowNotFoundException -from airflow.hooks.base import BaseHook from airflow.models import Connection from airflow.providers.cncf.kubernetes.kube_client import _disable_verify_ssl, _enable_tcp_keepalive from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import should_retry_creation @@ -45,6 +44,11 @@ container_is_completed, container_is_running, ) + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] from airflow.utils import yaml if TYPE_CHECKING: @@ -174,7 +178,7 @@ def get_connection(cls, conn_id: str) -> Connection: default to cluster-derived credentials. 
""" try: - return super().get_connection(conn_id) + return super().get_connection(conn_id) # type: ignore[return-value] except AirflowNotFoundException: if conn_id == cls.default_conn_name: return Connection(conn_id=cls.default_conn_name) diff --git a/providers/cohere/src/airflow/providers/cohere/hooks/cohere.py b/providers/cohere/src/airflow/providers/cohere/hooks/cohere.py index ccfdd4b178e18..a0022cfa5e38d 100644 --- a/providers/cohere/src/airflow/providers/cohere/hooks/cohere.py +++ b/providers/cohere/src/airflow/providers/cohere/hooks/cohere.py @@ -25,7 +25,11 @@ from cohere.types import UserChatMessageV2 from airflow.exceptions import AirflowProviderDeprecationWarning -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] if TYPE_CHECKING: from cohere.core.request_options import RequestOptions diff --git a/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py b/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py index d0cf632cf1b17..28cefecf8021f 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py +++ b/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py @@ -39,9 +39,13 @@ AirflowOptionalProviderFeatureException, AirflowProviderDeprecationWarning, ) -from airflow.hooks.base import BaseHook from airflow.providers.common.sql.dialects.dialect import Dialect from airflow.providers.common.sql.hooks import handlers + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] from airflow.utils.module_loading import import_string if TYPE_CHECKING: @@ -49,10 +53,14 @@ from polars import DataFrame as PolarsDataFrame from sqlalchemy.engine import URL, Engine, Inspector - from airflow.models import Connection from airflow.providers.openlineage.extractors import OperatorLineage from 
airflow.providers.openlineage.sqlparser import DatabaseInfo + try: + from airflow.sdk import Connection + except ImportError: + from airflow.models.connection import Connection # type: ignore[assignment] + T = TypeVar("T") SQL_PLACEHOLDERS = frozenset({"%s", "?"}) @@ -269,7 +277,10 @@ def get_conn(self) -> Any: db = self.connection if self.connector is None: raise RuntimeError(f"{type(self).__name__} didn't have `self.connector` set!") - return self.connector.connect(host=db.host, port=db.port, username=db.login, schema=db.schema) + host = db.host or "" + login = db.login or "" + schema = db.schema or "" + return self.connector.connect(host=host, port=cast("int", db.port), username=login, schema=schema) def get_uri(self) -> str: """ diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py b/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py index 4e3b0a87ce027..748a146e0d060 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py +++ b/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py @@ -22,11 +22,15 @@ from typing import TYPE_CHECKING, Any from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.common.sql.triggers.sql import SQLExecuteQueryTrigger from airflow.providers.common.sql.version_compat import BaseOperator +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + if TYPE_CHECKING: import jinja2 diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py index 250a249d5af19..ce0cbe5214bbb 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py +++ 
b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py @@ -24,13 +24,17 @@ from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, SupportsAbs from airflow.exceptions import AirflowException, AirflowFailException -from airflow.hooks.base import BaseHook from airflow.models import SkipMixin from airflow.providers.common.sql.hooks.handlers import fetch_all_handler, return_single_query_results from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.common.sql.version_compat import BaseOperator from airflow.utils.helpers import merge_dicts +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + if TYPE_CHECKING: from airflow.providers.openlineage.extractors import OperatorLineage from airflow.utils.context import Context diff --git a/providers/common/sql/src/airflow/providers/common/sql/sensors/sql.py b/providers/common/sql/src/airflow/providers/common/sql/sensors/sql.py index 0aa83fef2e5b8..687b39f63c0d1 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/sensors/sql.py +++ b/providers/common/sql/src/airflow/providers/common/sql/sensors/sql.py @@ -21,8 +21,13 @@ from typing import TYPE_CHECKING, Any from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook from airflow.providers.common.sql.hooks.sql import DbApiHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + from airflow.providers.common.sql.version_compat import AIRFLOW_V_3_0_PLUS if AIRFLOW_V_3_0_PLUS: diff --git a/providers/common/sql/src/airflow/providers/common/sql/triggers/sql.py b/providers/common/sql/src/airflow/providers/common/sql/triggers/sql.py index f4a41192acda3..64618fc8f5f38 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/triggers/sql.py +++ 
b/providers/common/sql/src/airflow/providers/common/sql/triggers/sql.py @@ -20,8 +20,12 @@ from typing import TYPE_CHECKING from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook from airflow.providers.common.sql.hooks.sql import DbApiHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] from airflow.triggers.base import BaseTrigger, TriggerEvent if TYPE_CHECKING: diff --git a/providers/common/sql/tests/unit/common/sql/hooks/test_dbapi.py b/providers/common/sql/tests/unit/common/sql/hooks/test_dbapi.py index 6cccd66b35d15..29a9f8350e8fe 100644 --- a/providers/common/sql/tests/unit/common/sql/hooks/test_dbapi.py +++ b/providers/common/sql/tests/unit/common/sql/hooks/test_dbapi.py @@ -26,12 +26,16 @@ from pyodbc import Cursor from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG -from airflow.hooks.base import BaseHook from airflow.models import Connection from airflow.providers.common.sql.dialects.dialect import Dialect from airflow.providers.common.sql.hooks.handlers import fetch_all_handler, fetch_one_handler from airflow.providers.common.sql.hooks.sql import DbApiHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + class DbApiHookInProvider(DbApiHook): conn_name_attr = "test_conn_id" diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py b/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py index fe01d68d2f1b4..07a250cc40d4c 100644 --- a/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py +++ b/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py @@ -38,6 +38,15 @@ from tests_common.test_utils.operators.run_deferrable import execute_operator, mock_context from tests_common.test_utils.providers import 
get_provider_min_airflow_version +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" pytestmark = pytest.mark.db_test DEFAULT_DATE = timezone.datetime(2015, 1, 1) @@ -252,8 +261,8 @@ def test_templated_fields(self): assert operator.insert_args == {"commit_every": 5000, "executemany": True, "replace": True} def test_non_paginated_read(self): - with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=self.get_connection): - with mock.patch("airflow.hooks.base.BaseHook.get_hook", side_effect=self.get_hook): + with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=self.get_connection): + with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_hook", side_effect=self.get_hook): operator = GenericTransfer( task_id="transfer_table", source_conn_id="my_source_conn_id", @@ -280,8 +289,8 @@ def test_paginated_read(self): https://medium.com/apache-airflow/transfering-data-from-sap-hana-to-mssql-using-the-airflow-generictransfer-d29f147a9f1f """ - with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=self.get_connection): - with mock.patch("airflow.hooks.base.BaseHook.get_hook", side_effect=self.get_hook): + with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=self.get_connection): + with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_hook", side_effect=self.get_hook): operator = GenericTransfer( task_id="transfer_table", source_conn_id="my_source_conn_id", diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py index b3e02f8d7f8ce..058ccf2b55b25 100644 --- a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py +++ b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py @@ -24,6 +24,15 @@ import pytest +try: + import 
importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" from airflow import DAG from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import Connection, DagRun, TaskInstance as TI @@ -1174,7 +1183,6 @@ def test_invalid_follow_task_true(self): follow_task_ids_if_false=["branch_2"], dag=self.dag, ) - with pytest.raises(AirflowException): op.execute({}) @@ -1188,7 +1196,6 @@ def test_invalid_follow_task_false(self): follow_task_ids_if_false=[], dag=self.dag, ) - with pytest.raises(AirflowException): op.execute({}) @@ -1554,7 +1561,7 @@ def __init__(self, custom_conn_id_field="test_conn", **kwargs): @pytest.mark.parametrize( "operator_class", [NewStyleBaseSQLOperatorSubClass, OldStyleBaseSQLOperatorSubClass] ) - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") def test_new_style_subclass(self, mock_get_connection, operator_class): from airflow.providers.common.sql.hooks.sql import DbApiHook diff --git a/providers/common/sql/tests/unit/common/sql/triggers/test_sql.py b/providers/common/sql/tests/unit/common/sql/triggers/test_sql.py index fbcb5da6fa6fd..e09fba590be73 100644 --- a/providers/common/sql/tests/unit/common/sql/triggers/test_sql.py +++ b/providers/common/sql/tests/unit/common/sql/triggers/test_sql.py @@ -23,11 +23,20 @@ from airflow.providers.common.sql.triggers.sql import SQLExecuteQueryTrigger from airflow.triggers.base import TriggerEvent +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" from tests_common.test_utils.operators.run_deferrable import run_trigger class 
TestSQLExecuteQueryTrigger: - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") def test_run(self, mock_get_connection): data = [(1, "Alice"), (2, "Bob")] mock_connection = mock.MagicMock(spec=Connection) diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py index 2829c7012d32b..328b19d2a477a 100644 --- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py +++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py @@ -50,9 +50,13 @@ from airflow import __version__ from airflow.exceptions import AirflowException, AirflowOptionalProviderFeatureException -from airflow.hooks.base import BaseHook from airflow.providers_manager import ProvidersManager +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook as BaseHook # type: ignore + if TYPE_CHECKING: from airflow.models import Connection @@ -135,7 +139,7 @@ def my_after_func(retry_state): @cached_property def databricks_conn(self) -> Connection: - return self.get_connection(self.databricks_conn_id) + return self.get_connection(self.databricks_conn_id) # type: ignore[return-value] def get_conn(self) -> Connection: return self.databricks_conn diff --git a/providers/datadog/src/airflow/providers/datadog/hooks/datadog.py b/providers/datadog/src/airflow/providers/datadog/hooks/datadog.py index 3530aa7677ec4..1901471860402 100644 --- a/providers/datadog/src/airflow/providers/datadog/hooks/datadog.py +++ b/providers/datadog/src/airflow/providers/datadog/hooks/datadog.py @@ -23,7 +23,11 @@ from datadog import api, initialize # type: ignore[attr-defined] from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # 
type: ignore[attr-defined,no-redef] from airflow.utils.log.logging_mixin import LoggingMixin diff --git a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py index 4abd415480bcc..c7023f1b9230a 100644 --- a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py +++ b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py @@ -283,7 +283,7 @@ def connection(self) -> Connection: if not _connection.password: raise AirflowException("An API token is required to connect to dbt Cloud.") - return _connection + return _connection # type: ignore[return-value] def get_conn(self, *args, **kwargs) -> Session: tenant = self._get_tenant_domain(self.connection) diff --git a/providers/docker/src/airflow/providers/docker/hooks/docker.py b/providers/docker/src/airflow/providers/docker/hooks/docker.py index fa1377bedd021..c07ba908392e6 100644 --- a/providers/docker/src/airflow/providers/docker/hooks/docker.py +++ b/providers/docker/src/airflow/providers/docker/hooks/docker.py @@ -27,7 +27,11 @@ from docker.errors import APIError, DockerException from airflow.exceptions import AirflowException, AirflowNotFoundException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] if TYPE_CHECKING: from airflow.models import Connection @@ -150,7 +154,7 @@ def api_client(self) -> APIClient: raise AirflowException(msg) if self.docker_conn_id: # Obtain connection and try to login to Container Registry only if ``docker_conn_id`` set. 
- self.__login(client, self.get_connection(self.docker_conn_id)) + self.__login(client, self.get_connection(self.docker_conn_id)) # type: ignore[arg-type] except APIError: raise except DockerException as e: diff --git a/providers/docker/tests/conftest.py b/providers/docker/tests/conftest.py index 89feb74cf6913..88e0bc5a753e4 100644 --- a/providers/docker/tests/conftest.py +++ b/providers/docker/tests/conftest.py @@ -40,7 +40,17 @@ def hook_conn(request): except AttributeError: conn = None - with mock.patch("airflow.hooks.base.BaseHook.get_connection") as m: + try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + basehook_pp = "airflow.sdk.bases.hook.BaseHook" + except ImportError: + basehook_pp = "airflow.hooks.base.BaseHook" + + with mock.patch(f"{basehook_pp}.get_connection") as m: if not conn: pass # Don't do anything if param not specified or empty elif isinstance(conn, dict): diff --git a/providers/edge3/src/airflow/providers/edge3/example_dags/integration_test.py b/providers/edge3/src/airflow/providers/edge3/example_dags/integration_test.py index a1fec39e71cbf..b1bcea3027028 100644 --- a/providers/edge3/src/airflow/providers/edge3/example_dags/integration_test.py +++ b/providers/edge3/src/airflow/providers/edge3/example_dags/integration_test.py @@ -27,7 +27,11 @@ from time import sleep from airflow.exceptions import AirflowNotFoundException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] from airflow.utils.trigger_rule import TriggerRule try: diff --git a/providers/edge3/src/airflow/providers/edge3/example_dags/win_test.py b/providers/edge3/src/airflow/providers/edge3/example_dags/win_test.py index 8de7b50201d2c..1209c2499acd2 100644 --- a/providers/edge3/src/airflow/providers/edge3/example_dags/win_test.py +++ 
b/providers/edge3/src/airflow/providers/edge3/example_dags/win_test.py @@ -34,11 +34,15 @@ from airflow.decorators import task, task_group from airflow.exceptions import AirflowException, AirflowNotFoundException, AirflowSkipException -from airflow.hooks.base import BaseHook from airflow.models import BaseOperator from airflow.models.dag import DAG from airflow.models.variable import Variable from airflow.providers.standard.operators.empty import EmptyOperator + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] from airflow.sdk import Param from airflow.sdk.execution_time.context import context_to_airflow_vars from airflow.utils.trigger_rule import TriggerRule diff --git a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py index dba66b44462cb..02eed4036ac2d 100644 --- a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py +++ b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py @@ -20,14 +20,18 @@ from collections.abc import Iterable, Mapping from copy import deepcopy from functools import cached_property -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from urllib import parse from elasticsearch import Elasticsearch -from airflow.hooks.base import BaseHook from airflow.providers.common.sql.hooks.sql import DbApiHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + if TYPE_CHECKING: from elastic_transport import ObjectApiResponse @@ -179,8 +183,8 @@ def get_conn(self) -> ESConnection: conn = self.connection conn_args = { - "host": conn.host, - "port": conn.port, + "host": cast("str", conn.host), + "port": cast("int", conn.port), "user": conn.login or None, "password": 
conn.password or None, "scheme": conn.schema or "http", @@ -191,7 +195,7 @@ def get_conn(self) -> ESConnection: if conn_args.get("http_compress", False): conn_args["http_compress"] = bool(conn_args["http_compress"]) - return connect(**conn_args) + return connect(**conn_args) # type: ignore[arg-type] def get_uri(self) -> str: conn = self.connection @@ -199,7 +203,7 @@ def get_uri(self) -> str: login = "" if conn.login: login = f"{conn.login}:{conn.password}@" - host = conn.host + host = conn.host or "" if conn.port is not None: host += f":{conn.port}" uri = f"{conn.conn_type}+{conn.schema}://{login}{host}/" diff --git a/providers/facebook/src/airflow/providers/facebook/ads/hooks/ads.py b/providers/facebook/src/airflow/providers/facebook/ads/hooks/ads.py index 0cff27bbd68d1..6c5ee194f2bb5 100644 --- a/providers/facebook/src/airflow/providers/facebook/ads/hooks/ads.py +++ b/providers/facebook/src/airflow/providers/facebook/ads/hooks/ads.py @@ -29,7 +29,11 @@ from facebook_business.api import FacebookAdsApi from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] if TYPE_CHECKING: from facebook_business.adobjects.adsinsights import AdsInsights diff --git a/providers/facebook/tests/unit/facebook/ads/hooks/test_ads.py b/providers/facebook/tests/unit/facebook/ads/hooks/test_ads.py index 2656a935731c7..1f633b4bb9cbc 100644 --- a/providers/facebook/tests/unit/facebook/ads/hooks/test_ads.py +++ b/providers/facebook/tests/unit/facebook/ads/hooks/test_ads.py @@ -20,6 +20,16 @@ import pytest +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" + from airflow.providers.facebook.ads.hooks.ads import 
FacebookAdsReportingHook API_VERSION = "api_version" @@ -44,7 +54,7 @@ @pytest.fixture def mock_hook(): - with mock.patch("airflow.hooks.base.BaseHook.get_connection") as conn: + with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") as conn: hook = FacebookAdsReportingHook(api_version=API_VERSION) conn.return_value.extra_dejson = EXTRAS yield hook @@ -52,7 +62,7 @@ def mock_hook(): @pytest.fixture def mock_hook_multiple(): - with mock.patch("airflow.hooks.base.BaseHook.get_connection") as conn: + with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") as conn: hook = FacebookAdsReportingHook(api_version=API_VERSION) conn.return_value.extra_dejson = EXTRAS_MULTIPLE yield hook diff --git a/providers/ftp/src/airflow/providers/ftp/hooks/ftp.py b/providers/ftp/src/airflow/providers/ftp/hooks/ftp.py index 3a2593f7d24a4..3fd389c57e6ba 100644 --- a/providers/ftp/src/airflow/providers/ftp/hooks/ftp.py +++ b/providers/ftp/src/airflow/providers/ftp/hooks/ftp.py @@ -21,9 +21,12 @@ import ftplib # nosec: B402 import logging from collections.abc import Callable -from typing import Any +from typing import Any, cast -from airflow.hooks.base import BaseHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] logger = logging.getLogger(__name__) @@ -64,12 +67,13 @@ def get_conn(self) -> ftplib.FTP: pasv = params.extra_dejson.get("passive", True) self.conn = ftplib.FTP() # nosec: B321 if params.host: - port = ftplib.FTP_PORT + port: int = int(ftplib.FTP_PORT) if params.port is not None: port = params.port logger.info("Connecting via FTP to %s:%d", params.host, port) self.conn.connect(params.host, port) if params.login: + params.password = cast("str", params.password) self.conn.login(params.login, params.password) self.conn.set_pasv(pasv) @@ -293,6 +297,9 @@ def get_conn(self) -> ftplib.FTP: # Construct FTP_TLS instance with SSL context to allow certificates to be validated by default context 
= ssl.create_default_context() + params.host = cast("str", params.host) + params.password = cast("str", params.password) + params.login = cast("str", params.login) self.conn = ftplib.FTP_TLS(params.host, params.login, params.password, context=context) # nosec: B321 self.conn.set_pasv(pasv) diff --git a/providers/git/src/airflow/providers/git/hooks/git.py b/providers/git/src/airflow/providers/git/hooks/git.py index 86ee583dec13d..a88908e116a9c 100644 --- a/providers/git/src/airflow/providers/git/hooks/git.py +++ b/providers/git/src/airflow/providers/git/hooks/git.py @@ -25,7 +25,11 @@ from typing import Any from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] log = logging.getLogger(__name__) diff --git a/providers/github/src/airflow/providers/github/hooks/github.py b/providers/github/src/airflow/providers/github/hooks/github.py index 6be50fd31fec4..d833be9d4b38b 100644 --- a/providers/github/src/airflow/providers/github/hooks/github.py +++ b/providers/github/src/airflow/providers/github/hooks/github.py @@ -24,7 +24,11 @@ from github import Github as GithubClient from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] class GithubHook(BaseHook): diff --git a/providers/google/src/airflow/providers/google/ads/hooks/ads.py b/providers/google/src/airflow/providers/google/ads/hooks/ads.py index 2f734aa8a9c1d..a2a33321fb4e4 100644 --- a/providers/google/src/airflow/providers/google/ads/hooks/ads.py +++ b/providers/google/src/airflow/providers/google/ads/hooks/ads.py @@ -28,9 +28,13 @@ from google.auth.exceptions import GoogleAuthError from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook 
from airflow.providers.google.common.hooks.base_google import get_field +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + if TYPE_CHECKING: from google.ads.googleads.v19.services.services.customer_service import CustomerServiceClient from google.ads.googleads.v19.services.services.google_ads_service import GoogleAdsServiceClient diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py index 8ffdd7c1a484d..ae6c788d47081 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py @@ -38,7 +38,7 @@ from pathlib import Path from subprocess import PIPE, Popen from tempfile import NamedTemporaryFile, _TemporaryFileWrapper, gettempdir -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from urllib.parse import quote_plus import httpx @@ -50,7 +50,6 @@ # Number of retries - used by googleapiclient method calls to perform retries # For requests that are "retriable" from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook from airflow.models import Connection from airflow.providers.google.cloud.hooks.secret_manager import ( GoogleCloudSecretManagerHook, @@ -61,6 +60,11 @@ GoogleBaseHook, get_field, ) + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: @@ -847,8 +851,8 @@ def __init__( self.user = self._get_iam_db_login() self.password = self._generate_login_token(service_account=self.cloudsql_connection.login) else: - self.user = self.cloudsql_connection.login - self.password = self.cloudsql_connection.password + self.user = cast("str", 
self.cloudsql_connection.login) + self.password = cast("str", self.cloudsql_connection.password) self.public_ip = self.cloudsql_connection.host self.public_port = self.cloudsql_connection.port self.ssl_cert = ssl_cert diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/dataprep.py b/providers/google/src/airflow/providers/google/cloud/hooks/dataprep.py index 39d57b004ddca..d1e68024d783d 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/dataprep.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/dataprep.py @@ -28,7 +28,10 @@ from requests import HTTPError from tenacity import retry, stop_after_attempt, wait_exponential -from airflow.hooks.base import BaseHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] def _get_field(extras: dict, field_name: str) -> str | None: diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/looker.py b/providers/google/src/airflow/providers/google/cloud/hooks/looker.py index 34fdaf6f17956..79c2b7e70bcec 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/looker.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/looker.py @@ -29,7 +29,11 @@ from packaging.version import parse as parse_version from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] from airflow.version import version if TYPE_CHECKING: diff --git a/providers/google/src/airflow/providers/google/cloud/operators/cloud_sql.py b/providers/google/src/airflow/providers/google/cloud/operators/cloud_sql.py index 142649603c5f2..0d670dddeae4f 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/cloud_sql.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/cloud_sql.py @@ 
-28,7 +28,6 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLDatabaseHook, CloudSQLHook from airflow.providers.google.cloud.links.cloud_sql import CloudSQLInstanceDatabaseLink, CloudSQLInstanceLink from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator @@ -37,6 +36,11 @@ from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, get_field from airflow.providers.google.common.links.storage import FileDetailsLink +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + if TYPE_CHECKING: from airflow.models import Connection from airflow.providers.openlineage.extractors import OperatorLineage diff --git a/providers/google/src/airflow/providers/google/common/hooks/base_google.py b/providers/google/src/airflow/providers/google/common/hooks/base_google.py index 75ca159906855..76d2248bf1c79 100644 --- a/providers/google/src/airflow/providers/google/common/hooks/base_google.py +++ b/providers/google/src/airflow/providers/google/common/hooks/base_google.py @@ -50,12 +50,16 @@ from airflow import version from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook from airflow.providers.google.cloud.utils.credentials_provider import ( _get_scopes, _get_target_principal_and_delegates, get_credentials_and_project_id, ) + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] from airflow.utils.process_utils import patch_environ if TYPE_CHECKING: diff --git a/providers/google/src/airflow/providers/google/leveldb/hooks/leveldb.py b/providers/google/src/airflow/providers/google/leveldb/hooks/leveldb.py index 9423a604d6ebd..b3212632ec742 100644 --- 
a/providers/google/src/airflow/providers/google/leveldb/hooks/leveldb.py +++ b/providers/google/src/airflow/providers/google/leveldb/hooks/leveldb.py @@ -21,7 +21,11 @@ from typing import Any from airflow.exceptions import AirflowException, AirflowOptionalProviderFeatureException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] try: import plyvel diff --git a/providers/google/tests/unit/google/ads/hooks/test_ads.py b/providers/google/tests/unit/google/ads/hooks/test_ads.py index 5e32a1440a9a2..e077754ee3872 100644 --- a/providers/google/tests/unit/google/ads/hooks/test_ads.py +++ b/providers/google/tests/unit/google/ads/hooks/test_ads.py @@ -21,6 +21,16 @@ import pytest +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" + from airflow.exceptions import AirflowException from airflow.providers.google.ads.hooks.ads import GoogleAdsHook @@ -46,7 +56,7 @@ params=[EXTRAS_DEVELOPER_TOKEN, EXTRAS_SERVICE_ACCOUNT], ids=["developer_token", "service_account"] ) def mock_hook(request): - with mock.patch("airflow.hooks.base.BaseHook.get_connection") as conn: + with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") as conn: hook = GoogleAdsHook(api_version=API_VERSION) conn.return_value.extra_dejson = request.param yield hook @@ -61,7 +71,7 @@ def mock_hook(request): ids=["developer_token", "service_account", "empty"], ) def mock_hook_for_authentication_method(request): - with mock.patch("airflow.hooks.base.BaseHook.get_connection") as conn: + with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") as conn: hook = GoogleAdsHook(api_version=API_VERSION) conn.return_value.extra_dejson = request.param["input"] yield hook, 
request.param["expected_result"] diff --git a/providers/google/tests/unit/google/cloud/hooks/test_alloy_db.py b/providers/google/tests/unit/google/cloud/hooks/test_alloy_db.py index 09891b0cf8083..40e0f1efaa2e5 100644 --- a/providers/google/tests/unit/google/cloud/hooks/test_alloy_db.py +++ b/providers/google/tests/unit/google/cloud/hooks/test_alloy_db.py @@ -29,6 +29,15 @@ from airflow.providers.google.cloud.hooks.alloy_db import AlloyDbHook from airflow.providers.google.common.consts import CLIENT_INFO +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" TEST_GCP_PROJECT = "test-project" TEST_GCP_REGION = "global" TEST_GCP_CONN_ID = "test_conn_id" @@ -58,7 +67,7 @@ class TestAlloyDbHook: def setup_method(self): - with mock.patch("airflow.hooks.base.BaseHook.get_connection"): + with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection"): self.hook = AlloyDbHook( gcp_conn_id=TEST_GCP_CONN_ID, ) diff --git a/providers/google/tests/unit/google/cloud/hooks/test_dataprep.py b/providers/google/tests/unit/google/cloud/hooks/test_dataprep.py index a9056cdd41954..2a260b937b6fa 100644 --- a/providers/google/tests/unit/google/cloud/hooks/test_dataprep.py +++ b/providers/google/tests/unit/google/cloud/hooks/test_dataprep.py @@ -28,6 +28,15 @@ from airflow.providers.google.cloud.hooks.dataprep import GoogleDataprepHook +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" JOB_ID = 1234567 RECIPE_ID = 1234567 TOKEN = "1111" @@ -49,7 +58,7 @@ class TestGoogleDataprepHook: def setup_method(self): - with mock.patch("airflow.hooks.base.BaseHook.get_connection") as conn: + with 
mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") as conn: conn.return_value.extra_dejson = EXTRA self.hook = GoogleDataprepHook(dataprep_conn_id="dataprep_default") self._imported_dataset_id = 12345 @@ -602,7 +611,7 @@ def setup_method(self): "description": "Test description", } ) - with mock.patch("airflow.hooks.base.BaseHook.get_connection") as conn: + with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") as conn: conn.return_value.extra_dejson = EXTRA self.hook = GoogleDataprepHook(dataprep_conn_id="dataprep_default") diff --git a/providers/google/tests/unit/google/cloud/hooks/test_looker.py b/providers/google/tests/unit/google/cloud/hooks/test_looker.py index e06c7a31e0f4f..6de0ed62695d8 100644 --- a/providers/google/tests/unit/google/cloud/hooks/test_looker.py +++ b/providers/google/tests/unit/google/cloud/hooks/test_looker.py @@ -26,6 +26,15 @@ from airflow.providers.google.cloud.hooks.looker import JobStatus, LookerHook from airflow.version import version +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" HOOK_PATH = "airflow.providers.google.cloud.hooks.looker.LookerHook.{}" JOB_ID = "test-id" @@ -39,7 +48,7 @@ class TestLookerHook: def setup_method(self): - with mock.patch("airflow.hooks.base.BaseHook.get_connection") as conn: + with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") as conn: conn.return_value.extra_dejson = CONN_EXTRA self.hook = LookerHook(looker_conn_id="test") diff --git a/providers/google/tests/unit/google/cloud/operators/test_cloud_sql.py b/providers/google/tests/unit/google/cloud/operators/test_cloud_sql.py index fdfcd4ee804e1..e3a27f05d2046 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_cloud_sql.py +++ b/providers/google/tests/unit/google/cloud/operators/test_cloud_sql.py @@ -23,6 +23,15 @@ import pytest 
+try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import Connection from airflow.providers.common.compat.openlineage.facet import ( @@ -791,7 +800,7 @@ def _setup_connections(get_connection, uri): ), ], ) - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") def test_create_operator_with_wrong_parameters( self, get_connection, @@ -816,7 +825,7 @@ def test_create_operator_with_wrong_parameters( err = ctx.value assert message in str(err) - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") def test_create_operator_with_too_long_unix_socket_path(self, get_connection): uri = ( "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" diff --git a/providers/google/tests/unit/google/suite/hooks/test_drive.py b/providers/google/tests/unit/google/suite/hooks/test_drive.py index 3bef29d3079af..dd75c505858ab 100644 --- a/providers/google/tests/unit/google/suite/hooks/test_drive.py +++ b/providers/google/tests/unit/google/suite/hooks/test_drive.py @@ -25,12 +25,22 @@ from unit.google.cloud.utils.base_gcp_mock import GCP_CONNECTION_WITH_PROJECT_ID +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" + @pytest.mark.db_test class TestGoogleDriveHook: def setup_method(self): self.patcher_get_connection = mock.patch( - "airflow.hooks.base.BaseHook.get_connection", return_value=GCP_CONNECTION_WITH_PROJECT_ID + f"{BASEHOOK_PATCH_PATH}.get_connection", 
return_value=GCP_CONNECTION_WITH_PROJECT_ID ) self.patcher_get_connection.start() self.gdrive_hook = GoogleDriveHook(gcp_conn_id="test") diff --git a/providers/grpc/src/airflow/providers/grpc/hooks/grpc.py b/providers/grpc/src/airflow/providers/grpc/hooks/grpc.py index 8171e49d36f4a..7395fcd2a42d2 100644 --- a/providers/grpc/src/airflow/providers/grpc/hooks/grpc.py +++ b/providers/grpc/src/airflow/providers/grpc/hooks/grpc.py @@ -30,7 +30,11 @@ ) from airflow.exceptions import AirflowConfigException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] class GrpcHook(BaseHook): @@ -82,7 +86,7 @@ def __init__( self.custom_connection_func = custom_connection_func def get_conn(self) -> grpc.Channel: - base_url = self.conn.host + base_url = self.conn.host or "" if self.conn.port: base_url += f":{self.conn.port}" diff --git a/providers/grpc/tests/unit/grpc/hooks/test_grpc.py b/providers/grpc/tests/unit/grpc/hooks/test_grpc.py index ed185b4bf2823..42f7ec38eafb2 100644 --- a/providers/grpc/tests/unit/grpc/hooks/test_grpc.py +++ b/providers/grpc/tests/unit/grpc/hooks/test_grpc.py @@ -27,6 +27,16 @@ from airflow.models import Connection from airflow.providers.grpc.hooks.grpc import GrpcHook +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" + def get_airflow_connection(auth_type="NO_AUTH", credential_pem_file=None, scopes=None): extra = { @@ -68,7 +78,7 @@ def channel_mock(): class TestGrpcHook: @mock.patch("grpc.insecure_channel") - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") def test_no_auth_connection(self, mock_get_connection, mock_insecure_channel, channel_mock): conn = 
get_airflow_connection() mock_get_connection.return_value = conn @@ -83,7 +93,7 @@ def test_no_auth_connection(self, mock_get_connection, mock_insecure_channel, ch assert channel == mocked_channel @mock.patch("grpc.insecure_channel") - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") def test_connection_with_port(self, mock_get_connection, mock_insecure_channel, channel_mock): conn = get_airflow_connection_with_port() mock_get_connection.return_value = conn @@ -98,7 +108,7 @@ def test_connection_with_port(self, mock_get_connection, mock_insecure_channel, assert channel == mocked_channel @mock.patch("airflow.providers.grpc.hooks.grpc.open") - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") @mock.patch("grpc.ssl_channel_credentials") @mock.patch("grpc.secure_channel") def test_connection_with_ssl( @@ -122,7 +132,7 @@ def test_connection_with_ssl( assert channel == mocked_channel @mock.patch("airflow.providers.grpc.hooks.grpc.open") - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") @mock.patch("grpc.ssl_channel_credentials") @mock.patch("grpc.secure_channel") def test_connection_with_tls( @@ -145,7 +155,7 @@ def test_connection_with_tls( mock_secure_channel.assert_called_once_with(expected_url, mock_credential_object) assert channel == mocked_channel - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") @mock.patch("google.auth.jwt.OnDemandCredentials.from_signing_credentials") @mock.patch("google.auth.default") @mock.patch("google.auth.transport.grpc.secure_authorized_channel") @@ -173,7 +183,7 @@ def test_connection_with_jwt( mock_secure_channel.assert_called_once_with(mock_credential_object, None, expected_url) assert channel == mocked_channel - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + 
@mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") @mock.patch("google.auth.transport.requests.Request") @mock.patch("google.auth.default") @mock.patch("google.auth.transport.grpc.secure_authorized_channel") @@ -201,7 +211,7 @@ def test_connection_with_google_oauth( mock_secure_channel.assert_called_once_with(mock_credential_object, "request", expected_url) assert channel == mocked_channel - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") def test_custom_connection(self, mock_get_connection, channel_mock): def custom_conn_func(_): mocked_channel = channel_mock.return_value @@ -216,7 +226,7 @@ def custom_conn_func(_): assert channel == mocked_channel - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") def test_custom_connection_with_no_connection_func(self, mock_get_connection, channel_mock): conn = get_airflow_connection("CUSTOM") mock_get_connection.return_value = conn @@ -225,7 +235,7 @@ def test_custom_connection_with_no_connection_func(self, mock_get_connection, ch with pytest.raises(AirflowConfigException): hook.get_conn() - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") def test_connection_type_not_supported(self, mock_get_connection, channel_mock): conn = get_airflow_connection("NOT_SUPPORT") mock_get_connection.return_value = conn @@ -235,7 +245,7 @@ def test_connection_type_not_supported(self, mock_get_connection, channel_mock): hook.get_conn() @mock.patch("grpc.intercept_channel") - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") @mock.patch("grpc.insecure_channel") def test_connection_with_interceptors( self, mock_insecure_channel, mock_get_connection, mock_intercept_channel, channel_mock @@ -252,7 +262,7 @@ def test_connection_with_interceptors( assert channel == mocked_channel 
mock_intercept_channel.assert_called_once_with(mocked_channel, "test1") - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") @mock.patch("airflow.providers.grpc.hooks.grpc.GrpcHook.get_conn") def test_simple_run(self, mock_get_conn, mock_get_connection, channel_mock): conn = get_airflow_connection() @@ -267,7 +277,7 @@ def test_simple_run(self, mock_get_conn, mock_get_connection, channel_mock): assert next(response) == "hello" - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") @mock.patch("airflow.providers.grpc.hooks.grpc.GrpcHook.get_conn") def test_stream_run(self, mock_get_conn, mock_get_connection, channel_mock): conn = get_airflow_connection() diff --git a/providers/hashicorp/src/airflow/providers/hashicorp/hooks/vault.py b/providers/hashicorp/src/airflow/providers/hashicorp/hooks/vault.py index 5029f420bedfb..957d5b7911a70 100644 --- a/providers/hashicorp/src/airflow/providers/hashicorp/hooks/vault.py +++ b/providers/hashicorp/src/airflow/providers/hashicorp/hooks/vault.py @@ -23,12 +23,16 @@ from hvac.exceptions import VaultError -from airflow.hooks.base import BaseHook from airflow.providers.hashicorp._internal_client.vault_client import ( DEFAULT_KUBERNETES_JWT_PATH, DEFAULT_KV_ENGINE_VERSION, _VaultClient, ) + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] from airflow.utils.helpers import merge_dicts if TYPE_CHECKING: diff --git a/providers/http/src/airflow/providers/http/hooks/http.py b/providers/http/src/airflow/providers/http/hooks/http.py index 7635ab67e29b6..8144d4a8e330f 100644 --- a/providers/http/src/airflow/providers/http/hooks/http.py +++ b/providers/http/src/airflow/providers/http/hooks/http.py @@ -33,9 +33,13 @@ from requests_toolbelt.adapters.socket_options import TCPKeepAliveAdapter from airflow.exceptions import 
AirflowException -from airflow.hooks.base import BaseHook from airflow.providers.http.exceptions import HttpErrorException, HttpMethodException +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook as BaseHook # type: ignore + if TYPE_CHECKING: from aiohttp.client_reqrep import ClientResponse from requests.adapters import HTTPAdapter @@ -51,7 +55,7 @@ def _url_from_endpoint(base_url: str | None, endpoint: str | None) -> str: def _process_extra_options_from_connection( - conn: Connection, extra_options: dict[str, Any] + conn, extra_options: dict[str, Any] ) -> tuple[dict[str, Any], dict[str, Any]]: """ Return the updated extra options from the connection, as well as those passed. @@ -174,7 +178,7 @@ def get_conn( session = Session() connection = self.get_connection(self.http_conn_id) self._set_base_url(connection) - session = self._configure_session_from_auth(session, connection) + session = self._configure_session_from_auth(session, connection) # type: ignore[arg-type] # Since get_conn can be called outside of run, we'll check this again extra_options = extra_options or {} @@ -190,7 +194,7 @@ def get_conn( session.headers.update(headers) return session - def _set_base_url(self, connection: Connection) -> None: + def _set_base_url(self, connection) -> None: host = connection.host or self.default_host schema = connection.schema or "http" # RFC 3986 (https://www.rfc-editor.org/rfc/rfc3986.html#page-16) @@ -216,7 +220,7 @@ def _extract_auth(self, connection: Connection) -> Any | None: return None def _configure_session_from_extra( - self, session: Session, connection: Connection, extra_options: dict[str, Any] + self, session: Session, connection, extra_options: dict[str, Any] ) -> Session: """ Configure the session using both the extra field from the Connection and passed in extra_options. 
diff --git a/providers/http/src/airflow/providers/http/operators/http.py b/providers/http/src/airflow/providers/http/operators/http.py index 6e6c48ea94a90..2d3a901a300e7 100644 --- a/providers/http/src/airflow/providers/http/operators/http.py +++ b/providers/http/src/airflow/providers/http/operators/http.py @@ -27,7 +27,11 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] from airflow.providers.http.triggers.http import HttpTrigger, serialize_auth_type from airflow.providers.http.version_compat import BaseOperator from airflow.utils.helpers import merge_dicts diff --git a/providers/imap/src/airflow/providers/imap/hooks/imap.py b/providers/imap/src/airflow/providers/imap/hooks/imap.py index a5ab9ad673cf1..f14b720137f78 100644 --- a/providers/imap/src/airflow/providers/imap/hooks/imap.py +++ b/providers/imap/src/airflow/providers/imap/hooks/imap.py @@ -32,7 +32,11 @@ from typing import TYPE_CHECKING, Any from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: @@ -77,8 +81,9 @@ def get_conn(self) -> ImapHook: """ if not self.mail_client: conn = self.get_connection(self.imap_conn_id) - self.mail_client = self._build_client(conn) - self.mail_client.login(conn.login, conn.password) + self.mail_client = self._build_client(conn) # type: ignore[arg-type] + if conn.login and conn.password: + self.mail_client.login(conn.login, conn.password) return self diff --git a/providers/influxdb/src/airflow/providers/influxdb/hooks/influxdb.py b/providers/influxdb/src/airflow/providers/influxdb/hooks/influxdb.py index 
902b968a05790..7f039eefdde8d 100644 --- a/providers/influxdb/src/airflow/providers/influxdb/hooks/influxdb.py +++ b/providers/influxdb/src/airflow/providers/influxdb/hooks/influxdb.py @@ -31,7 +31,10 @@ from influxdb_client.client.write.point import Point from influxdb_client.client.write_api import SYNCHRONOUS -from airflow.hooks.base import BaseHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] if TYPE_CHECKING: import pandas as pd @@ -92,7 +95,7 @@ def get_conn(self) -> InfluxDBClient: self.connection = self.get_connection(self.influxdb_conn_id) self.extras = self.connection.extra_dejson.copy() - self.uri = self.get_uri(self.connection) + self.uri = self.get_uri(self.connection) # type: ignore[arg-type] self.log.info("URI: %s", self.uri) if self.client is not None: diff --git a/providers/jdbc/src/airflow/providers/jdbc/hooks/jdbc.py b/providers/jdbc/src/airflow/providers/jdbc/hooks/jdbc.py index 705ed847e02df..7622825109acd 100644 --- a/providers/jdbc/src/airflow/providers/jdbc/hooks/jdbc.py +++ b/providers/jdbc/src/airflow/providers/jdbc/hooks/jdbc.py @@ -21,7 +21,7 @@ import warnings from contextlib import contextmanager from threading import RLock -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from urllib.parse import quote_plus, urlencode import jaydebeapi @@ -32,7 +32,11 @@ from airflow.providers.common.sql.hooks.sql import DbApiHook if TYPE_CHECKING: - from airflow.models.connection import Connection + if TYPE_CHECKING: + try: + from airflow.sdk import Connection + except ImportError: + from airflow.models.connection import Connection # type: ignore[assignment] @contextmanager @@ -186,9 +190,9 @@ def get_sqlalchemy_engine(self, engine_kwargs=None): def get_conn(self) -> jaydebeapi.Connection: conn: Connection = self.connection - host: str = conn.host - login: str = conn.login - psw: str = conn.password + host: str = 
cast("str", conn.host) + login: str = cast("str", conn.login) + psw: str = cast("str", conn.password) with self.lock: conn = jaydebeapi.connect( @@ -229,7 +233,7 @@ def get_uri(self) -> str: scheme = extra.get("sqlalchemy_scheme") if not scheme: - return conn.host + return cast("str", conn.host) driver = extra.get("sqlalchemy_driver") uri_prefix = f"{scheme}+{driver}" if driver else scheme diff --git a/providers/jenkins/src/airflow/providers/jenkins/hooks/jenkins.py b/providers/jenkins/src/airflow/providers/jenkins/hooks/jenkins.py index e303b6d0ee072..3d8958a43c19b 100644 --- a/providers/jenkins/src/airflow/providers/jenkins/hooks/jenkins.py +++ b/providers/jenkins/src/airflow/providers/jenkins/hooks/jenkins.py @@ -21,7 +21,10 @@ import jenkins -from airflow.hooks.base import BaseHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] class JenkinsHook(BaseHook): diff --git a/providers/jenkins/tests/unit/jenkins/hooks/test_jenkins.py b/providers/jenkins/tests/unit/jenkins/hooks/test_jenkins.py index a9f7969986455..c7316e64007e4 100644 --- a/providers/jenkins/tests/unit/jenkins/hooks/test_jenkins.py +++ b/providers/jenkins/tests/unit/jenkins/hooks/test_jenkins.py @@ -23,9 +23,19 @@ from airflow.providers.jenkins.hooks.jenkins import JenkinsHook +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" + class TestJenkinsHook: - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") def test_client_created_default_http(self, get_connection_mock): """tests `init` method to validate http client creation when all parameters are passed""" default_connection_id = "jenkins_default" @@ -47,7 +57,7 @@ def 
test_client_created_default_http(self, get_connection_mock): assert hook.jenkins_server is not None assert hook.jenkins_server.server == complete_url - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") def test_client_created_default_https(self, get_connection_mock): """tests `init` method to validate https client creation when all parameters are passed""" @@ -71,7 +81,7 @@ def test_client_created_default_https(self, get_connection_mock): assert hook.jenkins_server.server == complete_url @pytest.mark.parametrize("param_building", [True, False]) - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") @mock.patch("jenkins.Jenkins.get_job_info") @mock.patch("jenkins.Jenkins.get_build_info") def test_get_build_building_state( diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/fs/adls.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/fs/adls.py index 84a242015f5b0..ef9fe01d4872c 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/fs/adls.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/fs/adls.py @@ -20,9 +20,13 @@ from azure.identity import ClientSecretCredential -from airflow.hooks.base import BaseHook from airflow.providers.microsoft.azure.utils import get_field, parse_blob_account_url +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + if TYPE_CHECKING: from fsspec import AbstractFileSystem diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/adx.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/adx.py index e5ccc2342b626..819725ca2c394 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/adx.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/adx.py @@ 
-28,18 +28,22 @@ import warnings from functools import cached_property -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from azure.kusto.data import ClientRequestProperties, KustoClient, KustoConnectionStringBuilder from azure.kusto.data.exceptions import KustoServiceError from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook from airflow.providers.microsoft.azure.utils import ( add_managed_identity_connection_widgets, get_sync_default_azure_credential, ) +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + if TYPE_CHECKING: from azure.kusto.data.response import KustoResponseDataSet @@ -170,7 +174,7 @@ def get_required_param(name: str) -> str: if auth_method == "AAD_APP": tenant = get_required_param("tenant") kcsb = KustoConnectionStringBuilder.with_aad_application_key_authentication( - cluster, conn.login, conn.password, tenant + cluster, cast("str", conn.login), cast("str", conn.password), tenant ) elif auth_method == "AAD_APP_CERT": certificate = get_required_param("certificate") @@ -178,7 +182,7 @@ def get_required_param(name: str) -> str: tenant = get_required_param("tenant") kcsb = KustoConnectionStringBuilder.with_aad_application_certificate_authentication( cluster, - conn.login, + cast("str", conn.login), certificate, thumbprint, tenant, @@ -186,7 +190,7 @@ def get_required_param(name: str) -> str: elif auth_method == "AAD_CREDS": tenant = get_required_param("tenant") kcsb = KustoConnectionStringBuilder.with_aad_user_password_authentication( - cluster, conn.login, conn.password, tenant + cluster, cast("str", conn.login), cast("str", conn.password), tenant ) elif auth_method == "AAD_DEVICE": kcsb = KustoConnectionStringBuilder.with_aad_device_authentication(cluster) diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/asb.py 
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/asb.py index ae51108eb43de..a49c9598d4936 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/asb.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/asb.py @@ -37,13 +37,17 @@ SubscriptionProperties, ) -from airflow.hooks.base import BaseHook from airflow.providers.microsoft.azure.utils import ( add_managed_identity_connection_widgets, get_field, get_sync_default_azure_credential, ) +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + if TYPE_CHECKING: import datetime diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/base_azure.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/base_azure.py index fcc01248b755a..2a625becba763 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/base_azure.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/base_azure.py @@ -22,12 +22,16 @@ from azure.common.credentials import ServicePrincipalCredentials from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook from airflow.providers.microsoft.azure.utils import ( AzureIdentityCredentialAdapter, add_managed_identity_connection_widgets, ) +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + class AzureBaseHook(BaseHook): """ diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/batch.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/batch.py index 71c5ea1ca9978..e60ff63597865 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/batch.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/batch.py @@ -25,12 +25,16 @@ from 
azure.batch import BatchServiceClient, batch_auth, models as batch_models from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook from airflow.providers.microsoft.azure.utils import ( AzureIdentityCredentialAdapter, add_managed_identity_connection_widgets, get_field, ) + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] from airflow.utils import timezone if TYPE_CHECKING: diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/container_instance.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/container_instance.py index 914bfccab5911..bcc44569bfeb5 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/container_instance.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/container_instance.py @@ -85,7 +85,9 @@ def get_conn(self) -> Any: if all([conn.login, conn.password, tenant]): self.log.info("Getting connection using specific credentials and subscription_id.") credential = ClientSecretCredential( - client_id=conn.login, client_secret=conn.password, tenant_id=cast("str", tenant) + client_id=cast("str", conn.login), + client_secret=cast("str", conn.password), + tenant_id=cast("str", tenant), ) else: self.log.info("Using DefaultAzureCredential as credential") diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/container_registry.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/container_registry.py index d860a5a567e95..d53f1dd31766b 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/container_registry.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/container_registry.py @@ -20,18 +20,22 @@ from __future__ import annotations from functools import cached_property -from typing import Any +from typing import Any, cast 
from azure.mgmt.containerinstance.models import ImageRegistryCredential from azure.mgmt.containerregistry import ContainerRegistryManagementClient -from airflow.hooks.base import BaseHook from airflow.providers.microsoft.azure.utils import ( add_managed_identity_connection_widgets, get_field, get_sync_default_azure_credential, ) +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + class AzureContainerRegistryHook(BaseHook): """ @@ -121,4 +125,6 @@ def get_conn(self) -> ImageRegistryCredential: credentials = client.registries.list_credentials(resource_group, conn.login).as_dict() password = credentials["passwords"][0]["value"] - return ImageRegistryCredential(server=conn.host, username=conn.login, password=password) + return ImageRegistryCredential( + server=cast("str", conn.host), username=cast("str", conn.login), password=password + ) diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/container_volume.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/container_volume.py index 6e2700bf1e964..c652940cb5200 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/container_volume.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/container_volume.py @@ -16,18 +16,22 @@ # under the License. 
from __future__ import annotations -from typing import Any +from typing import Any, cast from azure.mgmt.containerinstance.models import AzureFileVolume, Volume from azure.mgmt.storage import StorageManagementClient -from airflow.hooks.base import BaseHook from airflow.providers.microsoft.azure.utils import ( add_managed_identity_connection_widgets, get_field, get_sync_default_azure_credential, ) +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + class AzureContainerVolumeHook(BaseHook): """ @@ -121,7 +125,7 @@ def get_storagekey(self, *, storage_account_name: str | None = None) -> str: ) return storage_account_list_keys_result.as_dict()["keys"][0]["value"] - return conn.password + return cast("str", conn.password) def get_file_volume( self, mount_name: str, share_name: str, storage_account_name: str, read_only: bool = False diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/cosmos.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/cosmos.py index f654c0b386190..296c49297ad35 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/cosmos.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/cosmos.py @@ -27,7 +27,7 @@ from __future__ import annotations import uuid -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from urllib.parse import urlparse from azure.cosmos import PartitionKey @@ -36,13 +36,17 @@ from azure.mgmt.cosmosdb import CosmosDBManagementClient from airflow.exceptions import AirflowBadRequest, AirflowException -from airflow.hooks.base import BaseHook from airflow.providers.microsoft.azure.utils import ( add_managed_identity_connection_widgets, get_field, get_sync_default_azure_credential, ) +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: 
ignore[attr-defined,no-redef] + if TYPE_CHECKING: PartitionKeyType = str | list[str] @@ -131,6 +135,7 @@ def get_conn(self) -> CosmosClient: conn = self.get_connection(self.conn_id) extras = conn.extra_dejson endpoint_uri = conn.login + endpoint_uri = cast("str", endpoint_uri) resource_group_name = self._get_field(extras, "resource_group_name") if conn.password: @@ -147,12 +152,12 @@ def get_conn(self) -> CosmosClient: credential=credential, subscription_id=subscritption_id, ) - + conn.login = cast("str", conn.login) database_account = urlparse(conn.login).netloc.split(".")[0] database_account_keys = management_client.database_accounts.list_keys( resource_group_name, database_account ) - master_key = database_account_keys.primary_master_key + master_key = cast("str", database_account_keys.primary_master_key) else: raise AirflowException("Either password or resource_group_name is required") diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_factory.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_factory.py index 8f90fc3769d52..fc33473f5f957 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_factory.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_factory.py @@ -49,13 +49,17 @@ from azure.mgmt.datafactory.aio import DataFactoryManagementClient as AsyncDataFactoryManagementClient from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook from airflow.providers.microsoft.azure.utils import ( add_managed_identity_connection_widgets, get_async_default_azure_credential, get_sync_default_azure_credential, ) +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + if TYPE_CHECKING: from azure.core.polling import LROPoller from azure.mgmt.datafactory.models import ( diff --git 
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_lake.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_lake.py index c01dadc03b3cd..c132a68f532d6 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_lake.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_lake.py @@ -18,7 +18,7 @@ from __future__ import annotations from functools import cached_property -from typing import Any +from typing import Any, cast from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError from azure.datalake.store import core, lib, multithread @@ -33,7 +33,6 @@ ) from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook from airflow.providers.microsoft.azure.utils import ( AzureIdentityCredentialAdapter, add_managed_identity_connection_widgets, @@ -41,6 +40,11 @@ get_sync_default_azure_credential, ) +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + Credentials = ClientSecretCredential | AzureIdentityCredentialAdapter | DefaultAzureCredential @@ -355,12 +359,13 @@ def get_conn(self) -> DataLakeServiceClient: # type: ignore[override] app_id = conn.login app_secret = conn.password proxies = extra.get("proxies", {}) - + app_id = cast("str", app_id) + app_secret = cast("str", app_secret) credential = ClientSecretCredential( tenant_id=tenant, client_id=app_id, client_secret=app_secret, proxies=proxies ) elif conn.password: - credential = conn.password + credential = conn.password # type: ignore[assignment] else: managed_identity_client_id = self._get_field(extra, "managed_identity_client_id") workload_identity_tenant_id = self._get_field(extra, "workload_identity_tenant_id") diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/fileshare.py 
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/fileshare.py index 24eb45be5d68f..f45c9766e7d72 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/fileshare.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/fileshare.py @@ -21,12 +21,16 @@ from azure.storage.fileshare import FileProperties, ShareDirectoryClient, ShareFileClient, ShareServiceClient -from airflow.hooks.base import BaseHook from airflow.providers.microsoft.azure.utils import ( add_managed_identity_connection_widgets, get_sync_default_azure_credential, ) +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + class AzureFileShareHook(BaseHook): """ diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py index be1725836a3c5..71f286fb641ab 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py @@ -50,7 +50,11 @@ AirflowException, AirflowNotFoundException, ) -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] if TYPE_CHECKING: from azure.identity._internal.client_credential_base import ClientCredentialBase @@ -254,7 +258,7 @@ def get_conn(self) -> RequestAdapter: client_secret = connection.password config = connection.extra_dejson if connection.extra else {} api_version = self.get_api_version(config) - host = self.get_host(connection) + host = self.get_host(connection) # type: ignore[arg-type] base_url = config.get("base_url", urljoin(host, api_version)) authority = config.get("authority") proxies = self.get_proxies(config) diff --git 
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/synapse.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/synapse.py index 9c184b378a558..f470e4e375e94 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/synapse.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/synapse.py @@ -25,13 +25,17 @@ from azure.synapse.spark import SparkClient from airflow.exceptions import AirflowException, AirflowTaskTimeout -from airflow.hooks.base import BaseHook from airflow.providers.microsoft.azure.utils import ( add_managed_identity_connection_widgets, get_field, get_sync_default_azure_credential, ) +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + if TYPE_CHECKING: from azure.synapse.artifacts.models import CreateRunResponse, PipelineRun from azure.synapse.spark.models import SparkBatchJobOptions diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/wasb.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/wasb.py index 05cdc7ef283ce..b981e1bd049e9 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/wasb.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/wasb.py @@ -29,7 +29,7 @@ import logging import os from functools import cached_property -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from asgiref.sync import sync_to_async from azure.core.exceptions import HttpResponseError, ResourceExistsError, ResourceNotFoundError @@ -46,7 +46,6 @@ ) from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook from airflow.providers.microsoft.azure.utils import ( add_managed_identity_connection_widgets, get_async_default_azure_credential, @@ -54,7 +53,13 @@ parse_blob_account_url, ) +try: + from airflow.sdk import 
BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + if TYPE_CHECKING: + from azure.core.credentials import TokenCredential from azure.storage.blob._models import BlobProperties AsyncCredentials = AsyncClientSecretCredential | AsyncDefaultAzureCredential @@ -172,8 +177,8 @@ def get_conn(self) -> BlobServiceClient: tenant = self._get_field(extra, "tenant_id") if tenant: # use Active Directory auth - app_id = conn.login - app_secret = conn.password + app_id = cast("str", conn.login) + app_secret = cast("str", conn.password) token_credential = ClientSecretCredential( tenant_id=tenant, client_id=app_id, client_secret=app_secret, **client_secret_auth_config ) @@ -197,7 +202,7 @@ def get_conn(self) -> BlobServiceClient: return BlobServiceClient(account_url=f"{account_url.rstrip('/')}/{sas_token}", **extra) # Fall back to old auth (password) or use managed identity if not provided. - credential = conn.password + credential: str | TokenCredential | None = conn.password if not credential: # Check for account_key in extra fields before falling back to DefaultAzureCredential account_key = self._get_field(extra, "account_key") diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/data_factory.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/data_factory.py index 20c5006d08004..b733e74f39104 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/data_factory.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/data_factory.py @@ -24,7 +24,6 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook from airflow.providers.microsoft.azure.hooks.data_factory import ( AzureDataFactoryHook, AzureDataFactoryPipelineRunException, @@ -32,6 +31,11 @@ get_field, ) from airflow.providers.microsoft.azure.triggers.data_factory import 
AzureDataFactoryTrigger + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] from airflow.providers.microsoft.azure.version_compat import BaseOperator from airflow.utils.log.logging_mixin import LoggingMixin diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/synapse.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/synapse.py index f83bc00054c3f..85f40863526b2 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/synapse.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/synapse.py @@ -22,7 +22,6 @@ from urllib.parse import urlencode from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook from airflow.providers.microsoft.azure.hooks.synapse import ( AzureSynapseHook, AzureSynapsePipelineHook, @@ -32,6 +31,11 @@ ) from airflow.providers.microsoft.azure.version_compat import BaseOperator +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + if TYPE_CHECKING: from azure.synapse.spark.models import SparkBatchJobOptions diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/base.py b/providers/microsoft/azure/tests/unit/microsoft/azure/base.py index 34550afa014d7..09cd1d29e73cf 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/base.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/base.py @@ -25,6 +25,16 @@ from unit.microsoft.azure.test_utils import get_airflow_connection +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" + class Base: def teardown_method(self, method): @@ -33,7 +43,7 @@ def 
teardown_method(self, method): @contextmanager def patch_hook_and_request_adapter(self, response): with ( - patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection), + patch(f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=get_airflow_connection), patch.object(HttpxRequestAdapter, "get_http_response_message") as mock_get_http_response, ): if isinstance(response, Exception): diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py index e01ab8de6a20d..6e06b61f931e5 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py @@ -62,6 +62,16 @@ AzureIdentityAccessTokenProvider, ) +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" + class TestKiotaRequestAdapterHook: @staticmethod @@ -81,7 +91,7 @@ def assert_tenant_id(request_adapter: RequestAdapter, expected_tenant_id: str): def test_get_conn(self): with patch( - "airflow.hooks.base.BaseHook.get_connection", + f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=get_airflow_connection, ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") @@ -98,7 +108,7 @@ def test_get_conn_with_custom_base_url(self): ) with patch( - "airflow.hooks.base.BaseHook.get_connection", + f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=connection, ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") @@ -116,7 +126,7 @@ def test_get_conn_with_proxies_as_string(self): ) with patch( - "airflow.hooks.base.BaseHook.get_connection", + f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=connection, ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") @@ -135,7 +145,7 @@ def 
test_get_conn_with_proxies_as_invalid_string(self): ) with patch( - "airflow.hooks.base.BaseHook.get_connection", + f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=connection, ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") @@ -152,7 +162,7 @@ def test_get_conn_with_proxies_as_json(self): ) with patch( - "airflow.hooks.base.BaseHook.get_connection", + f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=connection, ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") @@ -164,7 +174,7 @@ def test_get_conn_with_proxies_as_json(self): def test_scopes_when_default(self): with patch( - "airflow.hooks.base.BaseHook.get_connection", + f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=get_airflow_connection, ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") @@ -173,7 +183,7 @@ def test_scopes_when_default(self): def test_scopes_when_passed_as_string(self): with patch( - "airflow.hooks.base.BaseHook.get_connection", + f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=get_airflow_connection, ): hook = KiotaRequestAdapterHook( @@ -184,7 +194,7 @@ def test_scopes_when_passed_as_string(self): def test_scopes_when_passed_as_list(self): with patch( - "airflow.hooks.base.BaseHook.get_connection", + f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=get_airflow_connection, ): hook = KiotaRequestAdapterHook( @@ -195,7 +205,7 @@ def test_scopes_when_passed_as_list(self): def test_api_version(self): with patch( - "airflow.hooks.base.BaseHook.get_connection", + f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=get_airflow_connection, ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") @@ -204,7 +214,7 @@ def test_api_version(self): def test_get_api_version_when_empty_config_dict(self): with patch( - "airflow.hooks.base.BaseHook.get_connection", + f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=get_airflow_connection, ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") @@ -214,7 +224,7 @@ def 
test_get_api_version_when_empty_config_dict(self): def test_get_api_version_when_api_version_in_config_dict(self): with patch( - "airflow.hooks.base.BaseHook.get_connection", + f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=get_airflow_connection, ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") @@ -224,7 +234,7 @@ def test_get_api_version_when_api_version_in_config_dict(self): def test_get_api_version_when_custom_api_version_in_config_dict(self): with patch( - "airflow.hooks.base.BaseHook.get_connection", + f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=get_airflow_connection, ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api", api_version="v1") @@ -234,7 +244,7 @@ def test_get_api_version_when_custom_api_version_in_config_dict(self): def test_get_host_when_connection_has_scheme_and_host(self): with patch( - "airflow.hooks.base.BaseHook.get_connection", + f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=get_airflow_connection, ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") @@ -245,7 +255,7 @@ def test_get_host_when_connection_has_scheme_and_host(self): def test_get_host_when_connection_has_no_scheme_or_host(self): with patch( - "airflow.hooks.base.BaseHook.get_connection", + f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=get_airflow_connection, ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") @@ -256,7 +266,7 @@ def test_get_host_when_connection_has_no_scheme_or_host(self): def test_tenant_id(self): with patch( - "airflow.hooks.base.BaseHook.get_connection", + f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=get_airflow_connection, ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") @@ -271,7 +281,7 @@ def test_azure_tenant_id(self): ) with patch( - "airflow.hooks.base.BaseHook.get_connection", + f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=airflow_connection, ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") @@ -294,7 +304,7 @@ def test_request_information_with_custom_host(self): ) with 
patch( - "airflow.hooks.base.BaseHook.get_connection", + f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=connection, ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") @@ -309,7 +319,7 @@ def test_request_information_with_custom_host(self): @pytest.mark.asyncio async def test_throw_failed_responses_with_text_plain_content_type(self): with patch( - "airflow.hooks.base.BaseHook.get_connection", + f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=get_airflow_connection, ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") @@ -328,7 +338,7 @@ async def test_throw_failed_responses_with_text_plain_content_type(self): @pytest.mark.asyncio async def test_throw_failed_responses_with_application_json_content_type(self): with patch( - "airflow.hooks.base.BaseHook.get_connection", + f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=get_airflow_connection, ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_powerbi.py b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_powerbi.py index 6d1bdddeef90c..3ba4d3cda2975 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_powerbi.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_powerbi.py @@ -35,6 +35,15 @@ from unit.microsoft.azure.base import Base from unit.microsoft.azure.test_utils import get_airflow_connection +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" DEFAULT_CONNECTION_CLIENT_SECRET = "powerbi_conn_id" TASK_ID = "run_powerbi_operator" GROUP_ID = "group_id" @@ -93,7 +102,7 @@ class TestPowerBIDatasetRefreshOperator(Base): - @mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection) + 
@mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=get_airflow_connection) def test_execute_wait_for_termination_with_deferrable(self, connection): operator = PowerBIDatasetRefreshOperator( **CONFIG, @@ -106,7 +115,7 @@ def test_execute_wait_for_termination_with_deferrable(self, connection): assert isinstance(exc.value.trigger, PowerBITrigger) assert exc.value.trigger.dataset_refresh_id is None - @mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection) + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=get_airflow_connection) def test_powerbi_operator_async_get_refresh_status_success(self, connection): """Assert that get_refresh_status log success message""" operator = PowerBIDatasetRefreshOperator( diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_powerbi_list.py b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_powerbi_list.py index 3082940533b81..94c440443d2d4 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_powerbi_list.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_powerbi_list.py @@ -36,6 +36,15 @@ from unit.microsoft.azure.base import Base from unit.microsoft.azure.test_utils import get_airflow_connection +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" DEFAULT_CONNECTION_CLIENT_SECRET = "powerbi_conn_id" TASK_ID = "run_powerbi_operators" GROUP_ID = "group_id" @@ -69,7 +78,7 @@ class TestPowerBIDatasetListOperator(Base): - @mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection) + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=get_airflow_connection) def test_powerbi_operator_async_get_dataset_list_success(self, connection): 
"""Assert that get_dataset_list log success message""" operator = PowerBIDatasetListOperator( @@ -149,7 +158,7 @@ def test_execute_complete_no_event(self): class TestPowerBIWorkspaceListOperator(Base): - @mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection) + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=get_airflow_connection) def test_powerbi_operator_async_get_workspace_list_success(self, connection): """Assert that get_workspace_list log success message""" operator = PowerBIWorkspaceListOperator( diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_msgraph.py b/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_msgraph.py index cde310a7d19ac..978ad15892e2c 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_msgraph.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_msgraph.py @@ -44,6 +44,16 @@ mock_response, ) +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" + class TestMSGraphTrigger(Base): def test_run_when_valid_response(self): @@ -104,7 +114,7 @@ def test_run_when_response_is_bytes(self): def test_serialize(self): with patch( - "airflow.hooks.base.BaseHook.get_connection", + f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=get_airflow_connection, ): url = "https://graph.microsoft.com/v1.0/me/drive/items" diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_powerbi.py b/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_powerbi.py index 658b6e3046f3b..737fd96f3fb66 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_powerbi.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_powerbi.py @@ -35,6 +35,15 @@ from 
unit.microsoft.azure.test_utils import get_airflow_connection +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" POWERBI_CONN_ID = "powerbi_default" DATASET_ID = "dataset_id" GROUP_ID = "group_id" @@ -96,7 +105,7 @@ def powerbi_workspace_list_trigger(timeout=TIMEOUT) -> PowerBIWorkspaceListTrigg class TestPowerBITrigger: - @mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection) + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=get_airflow_connection) def test_powerbi_trigger_serialization(self, connection): """Asserts that the PowerBI Trigger correctly serializes its arguments and classpath.""" powerbi_trigger = PowerBITrigger( diff --git a/providers/microsoft/mssql/src/airflow/providers/microsoft/mssql/hooks/mssql.py b/providers/microsoft/mssql/src/airflow/providers/microsoft/mssql/hooks/mssql.py index 73a646a7f139f..f05e26ec027ca 100644 --- a/providers/microsoft/mssql/src/airflow/providers/microsoft/mssql/hooks/mssql.py +++ b/providers/microsoft/mssql/src/airflow/providers/microsoft/mssql/hooks/mssql.py @@ -101,10 +101,10 @@ def get_conn(self) -> PymssqlConnection: conn = self.connection extra_conn_args = {key: val for key, val in conn.extra_dejson.items() if key != "sqlalchemy_scheme"} return pymssql.connect( - server=conn.host, + server=conn.host or "", user=conn.login, password=conn.password, - database=self.schema or conn.schema, + database=self.schema or conn.schema or "", port=str(conn.port), **extra_conn_args, ) diff --git a/providers/microsoft/psrp/src/airflow/providers/microsoft/psrp/hooks/psrp.py b/providers/microsoft/psrp/src/airflow/providers/microsoft/psrp/hooks/psrp.py index 877c3d237f547..8aed7065a8b8a 100644 --- a/providers/microsoft/psrp/src/airflow/providers/microsoft/psrp/hooks/psrp.py +++ 
b/providers/microsoft/psrp/src/airflow/providers/microsoft/psrp/hooks/psrp.py @@ -21,7 +21,7 @@ from contextlib import contextmanager from copy import copy from logging import DEBUG, ERROR, INFO, WARNING -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from weakref import WeakKeyDictionary from pypsrp.host import PSHost @@ -30,7 +30,11 @@ from pypsrp.wsman import WSMan from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] INFORMATIONAL_RECORD_LEVEL_MAP = { MessageType.DEBUG_RECORD: DEBUG, @@ -151,6 +155,7 @@ def apply_extra(d, keys): "ssl", ), ) + conn.host = cast("str", conn.host) wsman = WSMan(conn.host, username=conn.login, password=conn.password, **wsman_options) runspace_options = apply_extra(self._runspace_options, ("configuration_name",)) diff --git a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm.py b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm.py index b35ee7dfa1999..75efdc1588c36 100644 --- a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm.py +++ b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm.py @@ -26,7 +26,11 @@ from winrm.protocol import Protocol from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] from airflow.utils.platform import getuser # TODO: FIXME please - I have too complex implementation diff --git a/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py b/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py index c38c704a72975..f9b3dade967c7 100644 --- a/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py +++ 
b/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py @@ -28,7 +28,11 @@ from pymongo.errors import CollectionInvalid from airflow.exceptions import AirflowConfigException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] if TYPE_CHECKING: from types import TracebackType @@ -121,7 +125,7 @@ def __init__(self, mongo_conn_id: str = default_conn_name, *args, **kwargs) -> N self.mongo_conn_id = mongo_conn_id conn = self.get_connection(self.mongo_conn_id) - self._validate_connection(conn) + self._validate_connection(conn) # type: ignore[arg-type] self.connection = conn self.extras = self.connection.extra_dejson.copy() diff --git a/providers/mysql/src/airflow/providers/mysql/hooks/mysql.py b/providers/mysql/src/airflow/providers/mysql/hooks/mysql.py index a9edb03b453dc..ffee8ad187018 100644 --- a/providers/mysql/src/airflow/providers/mysql/hooks/mysql.py +++ b/providers/mysql/src/airflow/providers/mysql/hooks/mysql.py @@ -220,7 +220,7 @@ def get_conn(self) -> MySQLConnectionTypes: "installed in case you see compilation error during installation." ) - conn_config = self._get_conn_config_mysql_client(conn) + conn_config = self._get_conn_config_mysql_client(conn) # type: ignore[arg-type] return MySQLdb.connect(**conn_config) if client_name == "mysql-connector-python": @@ -233,7 +233,7 @@ def get_conn(self) -> MySQLConnectionTypes: "'mysql-connector-python'. Warning! It might cause dependency conflicts." 
) - conn_config = self._get_conn_config_mysql_connector_python(conn) + conn_config = self._get_conn_config_mysql_connector_python(conn) # type: ignore[arg-type] return mysql.connector.connect(**conn_config) raise ValueError("Unknown MySQL client name provided!") diff --git a/providers/neo4j/src/airflow/providers/neo4j/hooks/neo4j.py b/providers/neo4j/src/airflow/providers/neo4j/hooks/neo4j.py index 5f58f802086f9..64ded3f9735c5 100644 --- a/providers/neo4j/src/airflow/providers/neo4j/hooks/neo4j.py +++ b/providers/neo4j/src/airflow/providers/neo4j/hooks/neo4j.py @@ -24,7 +24,10 @@ from neo4j import Driver, GraphDatabase -from airflow.hooks.base import BaseHook +try: + from airflow.sdk.bases.hook import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] if TYPE_CHECKING: from airflow.models import Connection @@ -57,12 +60,12 @@ def get_conn(self) -> Driver: self.connection = self.get_connection(self.neo4j_conn_id) - uri = self.get_uri(self.connection) + uri = self.get_uri(self.connection) # type: ignore[arg-type] self.log.info("URI: %s", uri) is_encrypted = self.connection.extra_dejson.get("encrypted", False) - self.client = self.get_client(self.connection, is_encrypted, uri) + self.client = self.get_client(self.connection, is_encrypted, uri) # type: ignore[arg-type] return self.client diff --git a/providers/openai/src/airflow/providers/openai/hooks/openai.py b/providers/openai/src/airflow/providers/openai/hooks/openai.py index b2fd19b96478d..18545aca14248 100644 --- a/providers/openai/src/airflow/providers/openai/hooks/openai.py +++ b/providers/openai/src/airflow/providers/openai/hooks/openai.py @@ -43,9 +43,13 @@ ChatCompletionUserMessageParam, ) from openai.types.vector_stores import VectorStoreFile, VectorStoreFileBatch, VectorStoreFileDeleted -from airflow.hooks.base import BaseHook from airflow.providers.openai.exceptions import OpenAIBatchJobException, OpenAIBatchTimeout +try: + from airflow.sdk import 
BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + class BatchStatus(str, Enum): """Enum for the status of a batch.""" diff --git a/providers/openfaas/src/airflow/providers/openfaas/hooks/openfaas.py b/providers/openfaas/src/airflow/providers/openfaas/hooks/openfaas.py index 9bda14775c3c5..51c3fa529ee76 100644 --- a/providers/openfaas/src/airflow/providers/openfaas/hooks/openfaas.py +++ b/providers/openfaas/src/airflow/providers/openfaas/hooks/openfaas.py @@ -22,7 +22,11 @@ import requests from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] OK_STATUS_CODE = 202 diff --git a/providers/openfaas/tests/unit/openfaas/hooks/test_openfaas.py b/providers/openfaas/tests/unit/openfaas/hooks/test_openfaas.py index 17172cc383425..7a008373545a6 100644 --- a/providers/openfaas/tests/unit/openfaas/hooks/test_openfaas.py +++ b/providers/openfaas/tests/unit/openfaas/hooks/test_openfaas.py @@ -22,10 +22,14 @@ import pytest from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook from airflow.models import Connection from airflow.providers.openfaas.hooks.openfaas import OpenFaasHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + FUNCTION_NAME = "function_name" diff --git a/providers/openlineage/src/airflow/providers/openlineage/sqlparser.py b/providers/openlineage/src/airflow/providers/openlineage/sqlparser.py index 9cb4ef6f319b8..f461513463cc1 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/sqlparser.py +++ b/providers/openlineage/src/airflow/providers/openlineage/sqlparser.py @@ -39,8 +39,8 @@ from openlineage.client.facet_v2 import JobFacet, RunFacet from sqlalchemy.engine import Engine - from 
airflow.hooks.base import BaseHook from airflow.providers.common.sql.hooks.sql import DbApiHook + from airflow.sdk import BaseHook log = logging.getLogger(__name__) diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/sql.py b/providers/openlineage/src/airflow/providers/openlineage/utils/sql.py index 2143dd2f19fd7..904206170e961 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/utils/sql.py +++ b/providers/openlineage/src/airflow/providers/openlineage/utils/sql.py @@ -31,7 +31,7 @@ from sqlalchemy.engine import Engine from sqlalchemy.sql import ClauseElement - from airflow.hooks.base import BaseHook + from airflow.sdk import BaseHook log = logging.getLogger(__name__) diff --git a/providers/opensearch/src/airflow/providers/opensearch/hooks/opensearch.py b/providers/opensearch/src/airflow/providers/opensearch/hooks/opensearch.py index 55184d1b9448f..6a6e69ea4201f 100644 --- a/providers/opensearch/src/airflow/providers/opensearch/hooks/opensearch.py +++ b/providers/opensearch/src/airflow/providers/opensearch/hooks/opensearch.py @@ -27,7 +27,11 @@ from opensearchpy import Connection as OpenSearchConnectionClass from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] from airflow.utils.strings import to_boolean diff --git a/providers/opensearch/tests/unit/opensearch/conftest.py b/providers/opensearch/tests/unit/opensearch/conftest.py index 8882b3e7777b9..93bcf37e00ff6 100644 --- a/providers/opensearch/tests/unit/opensearch/conftest.py +++ b/providers/opensearch/tests/unit/opensearch/conftest.py @@ -20,9 +20,13 @@ import pytest -from airflow.hooks.base import BaseHook from airflow.models import Connection +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + try: from 
opensearchpy import OpenSearch diff --git a/providers/opensearch/tests/unit/opensearch/hooks/test_opensearch.py b/providers/opensearch/tests/unit/opensearch/hooks/test_opensearch.py index 7788d899e791b..49eb8eb8fad7a 100644 --- a/providers/opensearch/tests/unit/opensearch/hooks/test_opensearch.py +++ b/providers/opensearch/tests/unit/opensearch/hooks/test_opensearch.py @@ -26,6 +26,15 @@ from airflow.models import Connection from airflow.providers.opensearch.hooks.opensearch import OpenSearchHook +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" opensearchpy = pytest.importorskip("opensearchpy") MOCK_SEARCH_RETURN = {"status": "test"} @@ -52,7 +61,7 @@ def test_delete_check_parameters(self): with pytest.raises(AirflowException, match="must include one of either a query or a document id"): hook.delete(index_name="test_index") - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") def test_hook_param_bool(self, mock_get_connection): mock_conn = Connection( conn_id="opensearch_default", extra={"use_ssl": "True", "verify_certs": "True"} diff --git a/providers/opsgenie/src/airflow/providers/opsgenie/hooks/opsgenie.py b/providers/opsgenie/src/airflow/providers/opsgenie/hooks/opsgenie.py index fdbbd83f07fb6..23387d0110767 100644 --- a/providers/opsgenie/src/airflow/providers/opsgenie/hooks/opsgenie.py +++ b/providers/opsgenie/src/airflow/providers/opsgenie/hooks/opsgenie.py @@ -17,7 +17,7 @@ # under the License. 
from __future__ import annotations -from typing import Any +from typing import Any, cast from opsgenie_sdk import ( AlertApi, @@ -29,7 +29,10 @@ SuccessResponse, ) -from airflow.hooks.base import BaseHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] class OpsgenieAlertHook(BaseHook): @@ -68,7 +71,7 @@ def _get_api_key(self) -> str: :return: API key """ conn = self.get_connection(self.conn_id) - return conn.password + return cast("str", conn.password) def get_conn(self) -> AlertApi: """ diff --git a/providers/oracle/src/airflow/providers/oracle/hooks/oracle.py b/providers/oracle/src/airflow/providers/oracle/hooks/oracle.py index 39b32fd119494..e28fed8a58018 100644 --- a/providers/oracle/src/airflow/providers/oracle/hooks/oracle.py +++ b/providers/oracle/src/airflow/providers/oracle/hooks/oracle.py @@ -20,6 +20,7 @@ import math import warnings from datetime import datetime +from typing import Any import oracledb @@ -144,7 +145,7 @@ def get_conn(self) -> oracledb.Connection: """ conn = self.get_connection(self.oracle_conn_id) # type: ignore[attr-defined] - conn_config = {"user": conn.login, "password": conn.password} + conn_config: dict[str, Any] = {"user": conn.login, "password": conn.password} sid = conn.extra_dejson.get("sid") mod = conn.extra_dejson.get("module") schema = conn.schema @@ -192,7 +193,7 @@ def get_conn(self) -> oracledb.Connection: else: dsn = conn.extra_dejson.get("dsn") if dsn is None: - dsn = conn.host + dsn = conn.host or "" if conn.port is not None: dsn += f":{conn.port}" if service_name: @@ -232,16 +233,16 @@ def get_conn(self) -> oracledb.Connection: conn = oracledb.connect(**conn_config) # type: ignore[assignment] if mod is not None: - conn.module = mod + conn.module = mod # type: ignore[attr-defined] # if Connection.schema is defined, set schema after connecting successfully # cannot be part of conn_config # 
https://python-oracledb.readthedocs.io/en/latest/api_manual/connection.html?highlight=schema#Connection.current_schema # Only set schema when not using conn.schema as Service Name if schema and service_name: - conn.current_schema = schema + conn.current_schema = schema # type: ignore[attr-defined] - return conn + return conn # type: ignore[return-value] def insert_rows( self, diff --git a/providers/pagerduty/src/airflow/providers/pagerduty/hooks/pagerduty.py b/providers/pagerduty/src/airflow/providers/pagerduty/hooks/pagerduty.py index 2d7bcf9c61d1b..ed6a8e4f0bb7f 100644 --- a/providers/pagerduty/src/airflow/providers/pagerduty/hooks/pagerduty.py +++ b/providers/pagerduty/src/airflow/providers/pagerduty/hooks/pagerduty.py @@ -24,7 +24,11 @@ import pagerduty from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] class PagerdutyHook(BaseHook): diff --git a/providers/pagerduty/src/airflow/providers/pagerduty/hooks/pagerduty_events.py b/providers/pagerduty/src/airflow/providers/pagerduty/hooks/pagerduty_events.py index 2388cf4f834f5..f6338a9b3f87a 100644 --- a/providers/pagerduty/src/airflow/providers/pagerduty/hooks/pagerduty_events.py +++ b/providers/pagerduty/src/airflow/providers/pagerduty/hooks/pagerduty_events.py @@ -24,7 +24,11 @@ import pagerduty from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] if TYPE_CHECKING: from datetime import datetime diff --git a/providers/papermill/src/airflow/providers/papermill/hooks/kernel.py b/providers/papermill/src/airflow/providers/papermill/hooks/kernel.py index 60c4edc6fd7f6..6cc730a02195a 100644 --- a/providers/papermill/src/airflow/providers/papermill/hooks/kernel.py 
+++ b/providers/papermill/src/airflow/providers/papermill/hooks/kernel.py @@ -17,6 +17,7 @@ from __future__ import annotations import typing +from typing import cast from jupyter_client import AsyncKernelManager from papermill.clientwrap import PapermillNotebookClient @@ -24,7 +25,10 @@ from papermill.utils import merge_kwargs, remove_args from traitlets import Unicode -from airflow.hooks.base import BaseHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] JUPYTER_KERNEL_SHELL_PORT = 60316 JUPYTER_KERNEL_IOPUB_PORT = 60317 @@ -68,7 +72,7 @@ def __init__(self, kernel_conn_id: str = default_conn_name, *args, **kwargs) -> def get_conn(self) -> KernelConnection: kernel_connection = KernelConnection() - kernel_connection.ip = self.kernel_conn.host + kernel_connection.ip = cast("str", self.kernel_conn.host) kernel_connection.shell_port = self.kernel_conn.extra_dejson.get( "shell_port", JUPYTER_KERNEL_SHELL_PORT ) diff --git a/providers/pinecone/src/airflow/providers/pinecone/hooks/pinecone.py b/providers/pinecone/src/airflow/providers/pinecone/hooks/pinecone.py index cfb354974b559..7577a21522a53 100644 --- a/providers/pinecone/src/airflow/providers/pinecone/hooks/pinecone.py +++ b/providers/pinecone/src/airflow/providers/pinecone/hooks/pinecone.py @@ -26,7 +26,10 @@ from pinecone import Pinecone, PodSpec, PodType, ServerlessSpec -from airflow.hooks.base import BaseHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] if TYPE_CHECKING: from pinecone import Vector @@ -127,7 +130,7 @@ def pinecone_client(self) -> Pinecone: @cached_property def conn(self) -> Connection: - return self.get_connection(self.conn_id) + return self.get_connection(self.conn_id) # type: ignore[return-value] def test_connection(self) -> tuple[bool, str]: try: diff --git 
a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py index e46839f598179..c8c87734d4454 100644 --- a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py +++ b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py @@ -20,7 +20,7 @@ import os from contextlib import closing from copy import deepcopy -from typing import TYPE_CHECKING, Any, TypeAlias +from typing import TYPE_CHECKING, Any, TypeAlias, cast import psycopg2 import psycopg2.extensions @@ -35,10 +35,14 @@ if TYPE_CHECKING: from psycopg2.extensions import connection - from airflow.models.connection import Connection from airflow.providers.common.sql.dialects.dialect import Dialect from airflow.providers.openlineage.sqlparser import DatabaseInfo + try: + from airflow.sdk import Connection + except ImportError: + from airflow.models.connection import Connection # type: ignore[assignment] + CursorType: TypeAlias = DictCursor | RealDictCursor | NamedTupleCursor @@ -256,7 +260,9 @@ def get_iam_token(self, conn: Connection) -> tuple[str, str, int]: port = conn.port or 5439 # Pull the custer-identifier from the beginning of the Redshift URL # ex. 
my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com returns my-cluster - cluster_identifier = conn.extra_dejson.get("cluster-identifier", conn.host.split(".")[0]) + cluster_identifier = conn.extra_dejson.get( + "cluster-identifier", cast("str", conn.host).split(".")[0] + ) redshift_client = AwsBaseHook(aws_conn_id=aws_conn_id, client_type="redshift").conn # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift/client/get_cluster_credentials.html#Redshift.Client.get_cluster_credentials cluster_creds = redshift_client.get_cluster_credentials( @@ -272,7 +278,7 @@ def get_iam_token(self, conn: Connection) -> tuple[str, str, int]: # Pull the workgroup-name from the query params/extras, if not there then pull it from the # beginning of the Redshift URL # ex. workgroup-name.ccdre4hpd39h.us-east-1.redshift.amazonaws.com returns workgroup-name - workgroup_name = conn.extra_dejson.get("workgroup-name", conn.host.split(".")[0]) + workgroup_name = conn.extra_dejson.get("workgroup-name", cast("str", conn.host).split(".")[0]) redshift_serverless_client = AwsBaseHook( aws_conn_id=aws_conn_id, client_type="redshift-serverless" ).conn @@ -288,7 +294,7 @@ def get_iam_token(self, conn: Connection) -> tuple[str, str, int]: rds_client = AwsBaseHook(aws_conn_id=aws_conn_id, client_type="rds").conn # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds/client/generate_db_auth_token.html#RDS.Client.generate_db_auth_token token = rds_client.generate_db_auth_token(conn.host, port, conn.login) - return login, token, port + return cast("str", login), cast("str", token), port def get_table_primary_key(self, table: str, schema: str | None = "public") -> list[str] | None: """ diff --git a/providers/qdrant/src/airflow/providers/qdrant/hooks/qdrant.py b/providers/qdrant/src/airflow/providers/qdrant/hooks/qdrant.py index 0e501b8ac6922..66e4dfb5372c4 100644 --- a/providers/qdrant/src/airflow/providers/qdrant/hooks/qdrant.py +++ 
b/providers/qdrant/src/airflow/providers/qdrant/hooks/qdrant.py @@ -24,7 +24,10 @@ from qdrant_client import QdrantClient from qdrant_client.http.exceptions import UnexpectedResponse -from airflow.hooks.base import BaseHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] class QdrantHook(BaseHook): diff --git a/providers/redis/src/airflow/providers/redis/hooks/redis.py b/providers/redis/src/airflow/providers/redis/hooks/redis.py index 97ca40ea9bb5c..b66e36fc51784 100644 --- a/providers/redis/src/airflow/providers/redis/hooks/redis.py +++ b/providers/redis/src/airflow/providers/redis/hooks/redis.py @@ -23,7 +23,10 @@ from redis import Redis -from airflow.hooks.base import BaseHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] DEFAULT_SSL_CERT_REQS = "required" ALLOWED_SSL_CERT_REQS = [DEFAULT_SSL_CERT_REQS, "optional", "none"] diff --git a/providers/salesforce/src/airflow/providers/salesforce/hooks/salesforce.py b/providers/salesforce/src/airflow/providers/salesforce/hooks/salesforce.py index 4eb2822945bc5..8a950c9e2b6ca 100644 --- a/providers/salesforce/src/airflow/providers/salesforce/hooks/salesforce.py +++ b/providers/salesforce/src/airflow/providers/salesforce/hooks/salesforce.py @@ -32,7 +32,10 @@ from simple_salesforce import Salesforce, api -from airflow.hooks.base import BaseHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] if TYPE_CHECKING: import pandas as pd diff --git a/providers/samba/src/airflow/providers/samba/hooks/samba.py b/providers/samba/src/airflow/providers/samba/hooks/samba.py index 819cccaa1e19a..3f44ccf060ad8 100644 --- a/providers/samba/src/airflow/providers/samba/hooks/samba.py +++ b/providers/samba/src/airflow/providers/samba/hooks/samba.py @@ -24,7 +24,10 
@@ import smbclient -from airflow.hooks.base import BaseHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] if TYPE_CHECKING: import smbprotocol.connection diff --git a/providers/samba/tests/unit/samba/hooks/test_samba.py b/providers/samba/tests/unit/samba/hooks/test_samba.py index 285eadb1ee906..b843bd6f92cba 100644 --- a/providers/samba/tests/unit/samba/hooks/test_samba.py +++ b/providers/samba/tests/unit/samba/hooks/test_samba.py @@ -26,6 +26,15 @@ from airflow.models import Connection from airflow.providers.samba.hooks.samba import SambaHook +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" PATH_PARAMETER_NAMES = {"path", "src", "dst"} @@ -36,7 +45,7 @@ def test_get_conn_should_fail_if_conn_id_does_not_exist(self): SambaHook("non-existed-connection-id") @mock.patch("smbclient.register_session") - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") def test_context_manager(self, get_conn_mock, register_session): CONNECTION = Connection( host="ip", @@ -92,7 +101,7 @@ def test_context_manager(self, get_conn_mock, register_session): "walk", ], ) - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") def test_method(self, get_conn_mock, name): CONNECTION = Connection( host="ip", @@ -145,7 +154,7 @@ def test_method(self, get_conn_mock, name): ("start/path/without/slash", "//ip/share/start/path/without/slash"), ], ) - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") def test__join_path(self, get_conn_mock, path, full_path): CONNECTION = Connection( host="ip", diff --git 
a/providers/segment/src/airflow/providers/segment/hooks/segment.py b/providers/segment/src/airflow/providers/segment/hooks/segment.py index 2c95c03a86e4e..742722f97e0de 100644 --- a/providers/segment/src/airflow/providers/segment/hooks/segment.py +++ b/providers/segment/src/airflow/providers/segment/hooks/segment.py @@ -27,7 +27,11 @@ import segment.analytics as analytics from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] class SegmentHook(BaseHook): diff --git a/providers/sendgrid/src/airflow/providers/sendgrid/utils/emailer.py b/providers/sendgrid/src/airflow/providers/sendgrid/utils/emailer.py index 6ab97e577aede..59b4db1fab9ec 100644 --- a/providers/sendgrid/src/airflow/providers/sendgrid/utils/emailer.py +++ b/providers/sendgrid/src/airflow/providers/sendgrid/utils/emailer.py @@ -38,7 +38,10 @@ SandBoxMode, ) -from airflow.hooks.base import BaseHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] from airflow.utils.email import get_email_address_list log = logging.getLogger(__name__) diff --git a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py index f5e3c798b1b01..23c18cc2cf47d 100644 --- a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py +++ b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py @@ -35,9 +35,13 @@ from asgiref.sync import sync_to_async from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning -from airflow.hooks.base import BaseHook from airflow.providers.ssh.hooks.ssh import SSHHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + if TYPE_CHECKING: from paramiko import SSHClient 
from paramiko.sftp_attr import SFTPAttributes @@ -723,9 +727,9 @@ async def _get_conn(self) -> asyncssh.SSHClientConnection: """ conn = await sync_to_async(self.get_connection)(self.sftp_conn_id) if conn.extra is not None: - self._parse_extras(conn) + self._parse_extras(conn) # type: ignore[arg-type] - conn_config = { + conn_config: dict[str, Any] = { "host": conn.host, "port": conn.port, "username": conn.login, @@ -737,10 +741,10 @@ async def _get_conn(self) -> asyncssh.SSHClientConnection: if self.known_hosts.lower() == "none": conn_config.update(known_hosts=None) else: - conn_config.update(known_hosts=self.known_hosts) + conn_config.update(known_hosts=self.known_hosts) # type: ignore if self.private_key: _private_key = asyncssh.import_private_key(self.private_key, self.passphrase) - conn_config.update(client_keys=[_private_key]) + conn_config["client_keys"] = [_private_key] if self.passphrase: conn_config.update(passphrase=self.passphrase) ssh_client_conn = await asyncssh.connect(**conn_config) diff --git a/providers/slack/src/airflow/providers/slack/hooks/slack.py b/providers/slack/src/airflow/providers/slack/hooks/slack.py index f2101d4320e40..11ddb264332fb 100644 --- a/providers/slack/src/airflow/providers/slack/hooks/slack.py +++ b/providers/slack/src/airflow/providers/slack/hooks/slack.py @@ -30,8 +30,12 @@ from typing_extensions import NotRequired from airflow.exceptions import AirflowException, AirflowNotFoundException -from airflow.hooks.base import BaseHook from airflow.providers.slack.utils import ConnectionExtraConfig + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] from airflow.utils.helpers import exactly_one if TYPE_CHECKING: diff --git a/providers/slack/src/airflow/providers/slack/hooks/slack_webhook.py b/providers/slack/src/airflow/providers/slack/hooks/slack_webhook.py index f9a66c43d6020..f9d06888dd36d 100644 --- 
a/providers/slack/src/airflow/providers/slack/hooks/slack_webhook.py +++ b/providers/slack/src/airflow/providers/slack/hooks/slack_webhook.py @@ -26,9 +26,13 @@ from slack_sdk import WebhookClient from airflow.exceptions import AirflowException, AirflowNotFoundException -from airflow.hooks.base import BaseHook from airflow.providers.slack.utils import ConnectionExtraConfig +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + if TYPE_CHECKING: from slack_sdk.http_retry import RetryHandler diff --git a/providers/slack/src/airflow/providers/slack/transfers/base_sql_to_slack.py b/providers/slack/src/airflow/providers/slack/transfers/base_sql_to_slack.py index ba544965ea255..334d206bc0287 100644 --- a/providers/slack/src/airflow/providers/slack/transfers/base_sql_to_slack.py +++ b/providers/slack/src/airflow/providers/slack/transfers/base_sql_to_slack.py @@ -20,9 +20,13 @@ from typing import TYPE_CHECKING, Any from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook from airflow.providers.slack.version_compat import BaseOperator +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + if TYPE_CHECKING: import pandas as pd from slack_sdk.http_retry import RetryHandler diff --git a/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py b/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py index 3eb3b6ea6c372..ac06effffa4ec 100644 --- a/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py +++ b/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py @@ -34,13 +34,20 @@ from email.mime.text import MIMEText from email.utils import formatdate from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from airflow.exceptions import AirflowException, AirflowNotFoundException -from airflow.hooks.base import BaseHook 
+ +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] if TYPE_CHECKING: - from airflow.models.connection import Connection + try: + from airflow.sdk import Connection + except ImportError: + from airflow.models.connection import Connection # type: ignore[assignment] class SmtpHook(BaseHook): @@ -366,11 +373,11 @@ def from_email(self) -> str | None: @property def smtp_user(self) -> str: - return self.conn.login + return self.conn.login if self.conn.login else "" @property def smtp_password(self) -> str: - return self.conn.password + return self.conn.password if self.conn.password else "" @property def smtp_starttls(self) -> bool: @@ -378,11 +385,11 @@ def smtp_starttls(self) -> bool: @property def host(self) -> str: - return self.conn.host + return self.conn.host if self.conn.host else "" @property def port(self) -> int: - return self.conn.port + return cast("int", self.conn.port) @property def timeout(self) -> int: diff --git a/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py b/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py index c88f2a1de8ed4..5a5d1ab0fb933 100644 --- a/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py +++ b/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py @@ -33,7 +33,11 @@ from tenacity import Retrying, stop_after_attempt, wait_fixed, wait_random from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] from airflow.utils.platform import getuser from airflow.utils.types import NOTSET, ArgNotSet @@ -155,7 +159,8 @@ def __init__( if self.password is None: self.password = conn.password if not self.remote_host: - self.remote_host = conn.host + if conn.host: + self.remote_host = conn.host if self.port is None: self.port = conn.port diff --git 
a/providers/standard/src/airflow/providers/standard/hooks/filesystem.py b/providers/standard/src/airflow/providers/standard/hooks/filesystem.py index 0877b05ecc631..c3e890d5cf517 100644 --- a/providers/standard/src/airflow/providers/standard/hooks/filesystem.py +++ b/providers/standard/src/airflow/providers/standard/hooks/filesystem.py @@ -20,7 +20,10 @@ from pathlib import Path from typing import Any -from airflow.hooks.base import BaseHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] class FSHook(BaseHook): diff --git a/providers/standard/src/airflow/providers/standard/hooks/package_index.py b/providers/standard/src/airflow/providers/standard/hooks/package_index.py index 8475a09369a43..044a224250aa3 100644 --- a/providers/standard/src/airflow/providers/standard/hooks/package_index.py +++ b/providers/standard/src/airflow/providers/standard/hooks/package_index.py @@ -23,7 +23,10 @@ from typing import Any from urllib.parse import quote, urlparse -from airflow.hooks.base import BaseHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] class PackageIndexHook(BaseHook): diff --git a/providers/standard/src/airflow/providers/standard/hooks/subprocess.py b/providers/standard/src/airflow/providers/standard/hooks/subprocess.py index cb7338be804f9..5dba4867f45ae 100644 --- a/providers/standard/src/airflow/providers/standard/hooks/subprocess.py +++ b/providers/standard/src/airflow/providers/standard/hooks/subprocess.py @@ -24,7 +24,10 @@ from subprocess import PIPE, STDOUT, Popen from tempfile import TemporaryDirectory, gettempdir -from airflow.hooks.base import BaseHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] SubprocessResult = namedtuple("SubprocessResult", ["exit_code", "output"]) diff 
--git a/providers/tableau/src/airflow/providers/tableau/hooks/tableau.py b/providers/tableau/src/airflow/providers/tableau/hooks/tableau.py index ce0ddb2b65e31..0840fa5504bd1 100644 --- a/providers/tableau/src/airflow/providers/tableau/hooks/tableau.py +++ b/providers/tableau/src/airflow/providers/tableau/hooks/tableau.py @@ -18,12 +18,16 @@ import time from enum import Enum -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from tableauserverclient import Pager, Server, TableauAuth from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] if TYPE_CHECKING: from tableauserverclient.server import Auth @@ -115,7 +119,9 @@ def get_conn(self) -> Auth.contextmgr: def _auth_via_password(self) -> Auth.contextmgr: tableau_auth = TableauAuth( - username=self.conn.login, password=self.conn.password, site_id=self.site_id + username=cast("str", self.conn.login), + password=cast("str", self.conn.password), + site_id=self.site_id, ) return self.server.auth.sign_in(tableau_auth) diff --git a/providers/telegram/src/airflow/providers/telegram/hooks/telegram.py b/providers/telegram/src/airflow/providers/telegram/hooks/telegram.py index fd830c2250b69..8aa809eb7f4c0 100644 --- a/providers/telegram/src/airflow/providers/telegram/hooks/telegram.py +++ b/providers/telegram/src/airflow/providers/telegram/hooks/telegram.py @@ -26,7 +26,11 @@ import tenacity from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] class TelegramHook(BaseHook): diff --git a/providers/teradata/src/airflow/providers/teradata/hooks/teradata.py b/providers/teradata/src/airflow/providers/teradata/hooks/teradata.py index 
0742bc374369d..4f0d90e4a298e 100644 --- a/providers/teradata/src/airflow/providers/teradata/hooks/teradata.py +++ b/providers/teradata/src/airflow/providers/teradata/hooks/teradata.py @@ -29,7 +29,10 @@ from airflow.providers.common.sql.hooks.sql import DbApiHook if TYPE_CHECKING: - from airflow.models.connection import Connection + try: + from airflow.sdk import Connection + except ImportError: + from airflow.models.connection import Connection # type: ignore[assignment] PARAM_TYPES = {bool, float, int, str} @@ -176,7 +179,7 @@ def _get_conn_config_teradatasql(self) -> dict[str, Any]: if conn.extra_dejson.get("sslmode", False): conn_config["sslmode"] = conn.extra_dejson["sslmode"] - if "verify" in conn_config["sslmode"]: + if "verify" in str(conn_config["sslmode"]): if conn.extra_dejson.get("sslca", False): conn_config["sslca"] = conn.extra_dejson["sslca"] if conn.extra_dejson.get("sslcapath", False): diff --git a/providers/teradata/src/airflow/providers/teradata/hooks/ttu.py b/providers/teradata/src/airflow/providers/teradata/hooks/ttu.py index 958bd76796e73..9128b6a88ef1f 100644 --- a/providers/teradata/src/airflow/providers/teradata/hooks/ttu.py +++ b/providers/teradata/src/airflow/providers/teradata/hooks/ttu.py @@ -22,7 +22,11 @@ from typing import Any from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook + +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] class TtuHook(BaseHook, ABC): diff --git a/providers/vertica/src/airflow/providers/vertica/hooks/vertica.py b/providers/vertica/src/airflow/providers/vertica/hooks/vertica.py index b19a8ee2a0900..ad09fbc130d4e 100644 --- a/providers/vertica/src/airflow/providers/vertica/hooks/vertica.py +++ b/providers/vertica/src/airflow/providers/vertica/hooks/vertica.py @@ -68,7 +68,7 @@ class VerticaHook(DbApiHook): def get_conn(self) -> connect: """Return vertica connection object.""" conn = 
self.get_connection(self.vertica_conn_id) # type: ignore - conn_config = { + conn_config: dict[str, Any] = { "user": conn.login, "password": conn.password or "", "database": conn.schema, diff --git a/providers/weaviate/src/airflow/providers/weaviate/hooks/weaviate.py b/providers/weaviate/src/airflow/providers/weaviate/hooks/weaviate.py index d9e451d2aa7d0..5ac87cbc19068 100644 --- a/providers/weaviate/src/airflow/providers/weaviate/hooks/weaviate.py +++ b/providers/weaviate/src/airflow/providers/weaviate/hooks/weaviate.py @@ -33,7 +33,10 @@ from weaviate.exceptions import ObjectAlreadyExistsException from weaviate.util import generate_uuid5 -from airflow.hooks.base import BaseHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] if TYPE_CHECKING: from collections.abc import Callable @@ -131,14 +134,14 @@ def get_conn(self) -> WeaviateClient: http_secure = extras.pop("http_secure", False) grpc_secure = extras.pop("grpc_secure", False) return weaviate.connect_to_custom( - http_host=conn.host, + http_host=conn.host, # type: ignore[arg-type] http_port=conn.port or 443 if http_secure else 80, http_secure=http_secure, grpc_host=extras.pop("grpc_host", conn.host), grpc_port=extras.pop("grpc_port", 443 if grpc_secure else 80), grpc_secure=grpc_secure, headers=extras.pop("additional_headers", {}), - auth_credentials=self._extract_auth_credentials(conn), + auth_credentials=self._extract_auth_credentials(conn), # type: ignore[arg-type] ) def _extract_auth_credentials(self, conn: Connection) -> AuthCredentials: diff --git a/providers/yandex/src/airflow/providers/yandex/hooks/yandex.py b/providers/yandex/src/airflow/providers/yandex/hooks/yandex.py index a8e796c08c143..6d4edd8c4c4b1 100644 --- a/providers/yandex/src/airflow/providers/yandex/hooks/yandex.py +++ b/providers/yandex/src/airflow/providers/yandex/hooks/yandex.py @@ -20,7 +20,6 @@ import yandexcloud -from airflow.hooks.base 
import BaseHook from airflow.providers.yandex.utils.credentials import ( CredentialsType, get_credentials, @@ -30,6 +29,11 @@ from airflow.providers.yandex.utils.fields import get_field_from_extras from airflow.providers.yandex.utils.user_agent import provider_user_agent +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + class YandexCloudBaseHook(BaseHook): """ diff --git a/providers/yandex/tests/unit/yandex/hooks/test_dataproc.py b/providers/yandex/tests/unit/yandex/hooks/test_dataproc.py index 212d1e3a9c05a..255fee7125f46 100644 --- a/providers/yandex/tests/unit/yandex/hooks/test_dataproc.py +++ b/providers/yandex/tests/unit/yandex/hooks/test_dataproc.py @@ -21,6 +21,15 @@ import pytest +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" yandexlcloud = pytest.importorskip("yandexcloud") from airflow.models import Connection # noqa: E402 @@ -62,7 +71,7 @@ class TestYandexCloudDataprocHook: def _init_hook(self): - with mock.patch("airflow.hooks.base.BaseHook.get_connection") as mock_get_connection: + with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") as mock_get_connection: mock_get_connection.return_value = self.connection self.hook = DataprocHook() diff --git a/providers/yandex/tests/unit/yandex/hooks/test_yandex.py b/providers/yandex/tests/unit/yandex/hooks/test_yandex.py index 7907c6a434001..5a655fae61ac4 100644 --- a/providers/yandex/tests/unit/yandex/hooks/test_yandex.py +++ b/providers/yandex/tests/unit/yandex/hooks/test_yandex.py @@ -25,11 +25,20 @@ from tests_common.test_utils.config import conf_vars +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = 
"airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" yandexcloud = pytest.importorskip("yandexcloud") class TestYandexHook: - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") @mock.patch("airflow.providers.yandex.utils.credentials.get_credentials") def test_client_created_without_exceptions(self, mock_get_credentials, mock_get_connection): """tests `init` method to validate client creation when all parameters are passed""" @@ -52,7 +61,7 @@ def test_client_created_without_exceptions(self, mock_get_credentials, mock_get_ ) assert hook.client is not None - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") @mock.patch("airflow.providers.yandex.utils.credentials.get_credentials") def test_sdk_user_agent(self, mock_get_credentials, mock_get_connection): mock_get_connection.return_value = mock.Mock(yandex_conn_id="yandexcloud_default", extra_dejson="{}") @@ -63,7 +72,7 @@ def test_sdk_user_agent(self, mock_get_credentials, mock_get_connection): hook = YandexCloudBaseHook() assert hook.sdk._channels._client_user_agent.startswith(sdk_prefix) - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") @mock.patch("airflow.providers.yandex.utils.credentials.get_credentials") def test_get_endpoint_specified(self, mock_get_credentials, mock_get_connection): default_folder_id = "test_id" @@ -83,7 +92,7 @@ def test_get_endpoint_specified(self, mock_get_credentials, mock_get_connection) assert hook._get_endpoint() == {"endpoint": "my_endpoint"} - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") @mock.patch("airflow.providers.yandex.utils.credentials.get_credentials") def test_get_endpoint_unspecified(self, mock_get_credentials, mock_get_connection): default_folder_id = "test_id" @@ 
-103,7 +112,7 @@ def test_get_endpoint_unspecified(self, mock_get_credentials, mock_get_connectio assert hook._get_endpoint() == {} - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") def test__get_field(self, mock_get_connection): field_name = "one" field_value = "value_one" @@ -128,7 +137,7 @@ def test__get_field(self, mock_get_connection): assert res == field_value - @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") def test__get_field_extras_not_found(self, get_connection_mock): field_name = "some_field" default = "some_default" diff --git a/providers/yandex/tests/unit/yandex/hooks/test_yq.py b/providers/yandex/tests/unit/yandex/hooks/test_yq.py index c93a5f53895f6..8472db0989387 100644 --- a/providers/yandex/tests/unit/yandex/hooks/test_yq.py +++ b/providers/yandex/tests/unit/yandex/hooks/test_yq.py @@ -27,6 +27,15 @@ from airflow.models import Connection from airflow.providers.yandex.hooks.yq import YQHook +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" yandexcloud = pytest.importorskip("yandexcloud") OAUTH_TOKEN = "my_oauth_token" @@ -36,7 +45,7 @@ class TestYandexCloudYqHook: def _init_hook(self): - with mock.patch("airflow.hooks.base.BaseHook.get_connection") as mock_get_connection: + with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") as mock_get_connection: mock_get_connection.return_value = self.connection self.hook = YQHook(default_folder_id="my_folder_id") diff --git a/providers/yandex/tests/unit/yandex/operators/test_dataproc.py b/providers/yandex/tests/unit/yandex/operators/test_dataproc.py index cde731d2d407b..ef430b76c0c9f 100644 --- a/providers/yandex/tests/unit/yandex/operators/test_dataproc.py +++ 
b/providers/yandex/tests/unit/yandex/operators/test_dataproc.py @@ -66,6 +66,16 @@ # https://cloud.yandex.com/docs/logging/concepts/log-group LOG_GROUP_ID = "my_log_group_id" +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" + class TestDataprocClusterCreateOperator: def setup_method(self): @@ -81,7 +91,7 @@ def setup_method(self): ) @patch("airflow.providers.yandex.utils.credentials.get_credentials") - @patch("airflow.hooks.base.BaseHook.get_connection") + @patch(f"{BASEHOOK_PATCH_PATH}.get_connection") @patch("yandexcloud._wrappers.dataproc.Dataproc.create_cluster") def test_create_cluster(self, mock_create_cluster, *_): operator = DataprocCreateClusterOperator( @@ -145,7 +155,7 @@ def test_create_cluster(self, mock_create_cluster, *_): ) @patch("airflow.providers.yandex.utils.credentials.get_credentials") - @patch("airflow.hooks.base.BaseHook.get_connection") + @patch(f"{BASEHOOK_PATCH_PATH}.get_connection") @patch("yandexcloud._wrappers.dataproc.Dataproc.delete_cluster") def test_delete_cluster_operator(self, mock_delete_cluster, *_): operator = DataprocDeleteClusterOperator( @@ -159,7 +169,7 @@ def test_delete_cluster_operator(self, mock_delete_cluster, *_): mock_delete_cluster.assert_called_once_with("my_cluster_id") @patch("airflow.providers.yandex.utils.credentials.get_credentials") - @patch("airflow.hooks.base.BaseHook.get_connection") + @patch(f"{BASEHOOK_PATCH_PATH}.get_connection") @patch("yandexcloud._wrappers.dataproc.Dataproc.create_hive_job") def test_create_hive_job_operator(self, mock_create_hive_job, *_): operator = DataprocCreateHiveJobOperator( @@ -188,7 +198,7 @@ def test_create_hive_job_operator(self, mock_create_hive_job, *_): ) @patch("airflow.providers.yandex.utils.credentials.get_credentials") - 
@patch("airflow.hooks.base.BaseHook.get_connection") + @patch(f"{BASEHOOK_PATCH_PATH}.get_connection") @patch("yandexcloud._wrappers.dataproc.Dataproc.create_mapreduce_job") def test_create_mapreduce_job_operator(self, mock_create_mapreduce_job, *_): operator = DataprocCreateMapReduceJobOperator( @@ -258,7 +268,7 @@ def test_create_mapreduce_job_operator(self, mock_create_mapreduce_job, *_): ) @patch("airflow.providers.yandex.utils.credentials.get_credentials") - @patch("airflow.hooks.base.BaseHook.get_connection") + @patch(f"{BASEHOOK_PATCH_PATH}.get_connection") @patch("yandexcloud._wrappers.dataproc.Dataproc.create_spark_job") def test_create_spark_job_operator(self, mock_create_spark_job, *_): operator = DataprocCreateSparkJobOperator( @@ -320,7 +330,7 @@ def test_create_spark_job_operator(self, mock_create_spark_job, *_): ) @patch("airflow.providers.yandex.utils.credentials.get_credentials") - @patch("airflow.hooks.base.BaseHook.get_connection") + @patch(f"{BASEHOOK_PATCH_PATH}.get_connection") @patch("yandexcloud._wrappers.dataproc.Dataproc.create_pyspark_job") def test_create_pyspark_job_operator(self, mock_create_pyspark_job, *_): operator = DataprocCreatePysparkJobOperator( diff --git a/providers/yandex/tests/unit/yandex/operators/test_yq.py b/providers/yandex/tests/unit/yandex/operators/test_yq.py index be09ca243d3a8..4cd778736f2c8 100644 --- a/providers/yandex/tests/unit/yandex/operators/test_yq.py +++ b/providers/yandex/tests/unit/yandex/operators/test_yq.py @@ -32,6 +32,16 @@ yandexcloud = pytest.importorskip("yandexcloud") +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" + OAUTH_TOKEN = "my_oauth_token" FOLDER_ID = "my_folder_id" @@ -50,7 +60,7 @@ def setup_method(self): ) @responses.activate() - @patch("airflow.hooks.base.BaseHook.get_connection") + 
@patch(f"{BASEHOOK_PATCH_PATH}.get_connection") def test_execute_query(self, mock_get_connection): mock_get_connection.return_value = Connection(extra={"oauth": OAUTH_TOKEN}) operator = YQExecuteQueryOperator(task_id="simple_sql", sql="select 987", folder_id="my_folder_id") diff --git a/providers/ydb/src/airflow/providers/ydb/hooks/ydb.py b/providers/ydb/src/airflow/providers/ydb/hooks/ydb.py index 2b1a319d3333f..5d28b49677cba 100644 --- a/providers/ydb/src/airflow/providers/ydb/hooks/ydb.py +++ b/providers/ydb/src/airflow/providers/ydb/hooks/ydb.py @@ -33,7 +33,10 @@ if TYPE_CHECKING: from ydb_dbapi import Cursor as DbApiCursor - from airflow.models.connection import Connection + try: + from airflow.sdk import Connection + except ImportError: + from airflow.models.connection import Connection # type: ignore[assignment] class YDBCursor: @@ -228,7 +231,8 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]: @property def sqlalchemy_url(self) -> URL: - conn: Connection = self.get_connection(self.get_conn_id()) + # TODO: @amoghrajesh: Handle type better + conn: Connection = self.get_connection(self.get_conn_id()) # type: ignore[assignment] return URL.create( drivername="ydb", username=conn.login, diff --git a/providers/ydb/src/airflow/providers/ydb/utils/credentials.py b/providers/ydb/src/airflow/providers/ydb/utils/credentials.py index db468accf51f9..5ba0da0d69434 100644 --- a/providers/ydb/src/airflow/providers/ydb/utils/credentials.py +++ b/providers/ydb/src/airflow/providers/ydb/utils/credentials.py @@ -23,7 +23,10 @@ import ydb.iam.auth as auth if TYPE_CHECKING: - from airflow.models.connection import Connection + try: + from airflow.sdk import Connection + except ImportError: + from airflow.models.connection import Connection # type: ignore[assignment] log = logging.getLogger(__name__) diff --git a/providers/ydb/tests/unit/ydb/hooks/test_ydb.py b/providers/ydb/tests/unit/ydb/hooks/test_ydb.py index 8ef5f9e998c32..002ca264402fa 100644 --- 
a/providers/ydb/tests/unit/ydb/hooks/test_ydb.py +++ b/providers/ydb/tests/unit/ydb/hooks/test_ydb.py @@ -21,6 +21,16 @@ from airflow.models import Connection from airflow.providers.ydb.hooks.ydb import YDBHook +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" + class FakeDriver: def wait(*args, **kwargs): @@ -56,7 +66,7 @@ def rowcount(self): return 1 -@patch("airflow.hooks.base.BaseHook.get_connection") +@patch(f"{BASEHOOK_PATCH_PATH}.get_connection") @patch("ydb.Driver") @patch("ydb.QuerySessionPool") @patch("ydb_dbapi.Connection._cursor_cls", new_callable=PropertyMock) diff --git a/providers/ydb/tests/unit/ydb/operators/test_ydb.py b/providers/ydb/tests/unit/ydb/operators/test_ydb.py index d66861317c7c5..7f828370b7806 100644 --- a/providers/ydb/tests/unit/ydb/operators/test_ydb.py +++ b/providers/ydb/tests/unit/ydb/operators/test_ydb.py @@ -27,6 +27,16 @@ from airflow.providers.common.sql.hooks.handlers import fetch_all_handler, fetch_one_handler from airflow.providers.ydb.operators.ydb import YDBExecuteQueryOperator +try: + import importlib.util + + if not importlib.util.find_spec("airflow.sdk.bases.hook"): + raise ImportError + + BASEHOOK_PATCH_PATH = "airflow.sdk.bases.hook.BaseHook" +except ImportError: + BASEHOOK_PATCH_PATH = "airflow.hooks.base.BaseHook" + @pytest.mark.db_test def test_sql_templating(create_task_instance_of_operator): @@ -98,7 +108,7 @@ def setup_method(self): schedule="@once", ) - @patch("airflow.hooks.base.BaseHook.get_connection") + @patch(f"{BASEHOOK_PATCH_PATH}.get_connection") @patch("ydb.Driver") @patch("ydb.QuerySessionPool") @patch("ydb_dbapi.Connection._cursor_cls", new_callable=PropertyMock) diff --git a/providers/zendesk/src/airflow/providers/zendesk/hooks/zendesk.py b/providers/zendesk/src/airflow/providers/zendesk/hooks/zendesk.py 
index 00ff3b94f0d4e..b493b50a1db4a 100644 --- a/providers/zendesk/src/airflow/providers/zendesk/hooks/zendesk.py +++ b/providers/zendesk/src/airflow/providers/zendesk/hooks/zendesk.py @@ -21,7 +21,10 @@ from zenpy import Zenpy -from airflow.hooks.base import BaseHook +try: + from airflow.sdk import BaseHook +except ImportError: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] if TYPE_CHECKING: from zenpy.lib.api import BaseApi @@ -64,13 +67,16 @@ def _init_conn(self) -> tuple[Zenpy, str]: :return: zenpy.Zenpy client and the url for the API. """ conn = self.get_connection(self.zendesk_conn_id) - url = "https://" + conn.host - domain = conn.host + domain = "" + url = "" subdomain: str | None = None - if conn.host.count(".") >= 2: - dot_splitted_string = conn.host.rsplit(".", 2) - subdomain = dot_splitted_string[0] - domain = ".".join(dot_splitted_string[1:]) + if conn.host: + url = "https://" + conn.host + domain = conn.host + if conn.host.count(".") >= 2: + dot_splitted_string = conn.host.rsplit(".", 2) + subdomain = dot_splitted_string[0] + domain = ".".join(dot_splitted_string[1:]) return Zenpy(domain=domain, subdomain=subdomain, email=conn.login, password=conn.password), url def get_conn(self) -> Zenpy: diff --git a/scripts/in_container/verify_providers.py b/scripts/in_container/verify_providers.py index bfcff18f0f39d..66568c8f4ebd1 100755 --- a/scripts/in_container/verify_providers.py +++ b/scripts/in_container/verify_providers.py @@ -462,8 +462,8 @@ def get_package_class_summary( :return: dictionary of objects usable as context for JINJA2 templates, or None if there are some errors """ - from airflow.hooks.base import BaseHook from airflow.models.baseoperator import BaseOperator + from airflow.sdk import BaseHook from airflow.secrets import BaseSecretsBackend from airflow.sensors.base import BaseSensorOperator from airflow.triggers.base import BaseTrigger diff --git a/scripts/tools/list-integrations.py 
b/scripts/tools/list-integrations.py index 8054d591597bc..9a1dddb9a5e22 100755 --- a/scripts/tools/list-integrations.py +++ b/scripts/tools/list-integrations.py @@ -26,8 +26,8 @@ from pathlib import Path import airflow -from airflow.hooks.base import BaseHook from airflow.models.baseoperator import BaseOperator +from airflow.sdk import BaseHook from airflow.secrets import BaseSecretsBackend from airflow.sensors.base import BaseSensorOperator diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst index 8c6fc526b5f5f..42fcc76dca5a3 100644 --- a/task-sdk/docs/api.rst +++ b/task-sdk/docs/api.rst @@ -70,6 +70,8 @@ Bases .. autoapiclass:: airflow.sdk.PokeReturnValue +.. autoapiclass:: airflow.sdk.BaseHook + Connections & Variables ----------------------- .. autoapiclass:: airflow.sdk.Connection diff --git a/task-sdk/src/airflow/sdk/__init__.py b/task-sdk/src/airflow/sdk/__init__.py index d415333babb10..b58cfcbe5fc6a 100644 --- a/task-sdk/src/airflow/sdk/__init__.py +++ b/task-sdk/src/airflow/sdk/__init__.py @@ -25,6 +25,7 @@ "AssetAll", "AssetAny", "AssetWatcher", + "BaseHook", "BaseNotifier", "BaseOperator", "BaseOperatorLink", @@ -58,6 +59,7 @@ __version__ = "1.1.0" if TYPE_CHECKING: + from airflow.sdk.bases.hook import BaseHook from airflow.sdk.bases.notifier import BaseNotifier from airflow.sdk.bases.operator import BaseOperator, chain, chain_linear, cross_downstream from airflow.sdk.bases.operatorlink import BaseOperatorLink @@ -84,6 +86,7 @@ "AssetAll": ".definitions.asset", "AssetAny": ".definitions.asset", "AssetWatcher": ".definitions.asset", + "BaseHook": ".bases.hook", "BaseNotifier": ".bases.notifier", "BaseOperator": ".bases.operator", "BaseOperatorLink": ".bases.operatorlink", diff --git a/task-sdk/src/airflow/sdk/__init__.pyi b/task-sdk/src/airflow/sdk/__init__.pyi index 985e616af1c8b..bfc7aed518a4d 100644 --- a/task-sdk/src/airflow/sdk/__init__.pyi +++ b/task-sdk/src/airflow/sdk/__init__.pyi @@ -15,6 +15,7 @@ # specific language governing permissions 
and limitations # under the License. +from airflow.sdk.bases.hook import BaseHook as BaseHook from airflow.sdk.bases.notifier import BaseNotifier as BaseNotifier from airflow.sdk.bases.operator import ( BaseOperator as BaseOperator, @@ -61,6 +62,7 @@ __all__ = [ "AssetAll", "AssetAny", "AssetWatcher", + "BaseHook", "BaseNotifier", "BaseOperator", "BaseOperatorLink", diff --git a/task-sdk/src/airflow/sdk/bases/hook.py b/task-sdk/src/airflow/sdk/bases/hook.py new file mode 100644 index 0000000000000..349fb150c6915 --- /dev/null +++ b/task-sdk/src/airflow/sdk/bases/hook.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from airflow.utils.log.logging_mixin import LoggingMixin + +if TYPE_CHECKING: + from airflow.sdk.definitions.connection import Connection + +log = logging.getLogger(__name__) + + +class BaseHook(LoggingMixin): + """ + Abstract base class for hooks. + + Hooks are meant as an interface to + interact with external systems. MySqlHook, HiveHook, PigHook return + object that can handle the connection and interaction to specific + instances of these systems, and expose consistent methods to interact + with them. 
+ + :param logger_name: Name of the logger used by the Hook to emit logs. + If set to `None` (default), the logger name will fall back to + `airflow.task.hooks.{class.__module__}.{class.__name__}` (e.g. DbApiHook will have + *airflow.task.hooks.airflow.providers.common.sql.hooks.sql.DbApiHook* as logger). + """ + + def __init__(self, logger_name: str | None = None): + super().__init__() + self._log_config_logger_name = "airflow.task.hooks" + self._logger_name = logger_name + + @classmethod + def get_connection(cls, conn_id: str) -> Connection: + """ + Get connection, given connection id. + + :param conn_id: connection id + :return: connection + """ + import sys + + # if SUPERVISOR_COMMS is set, we're in task sdk context + if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + from airflow.sdk.definitions.connection import Connection + + conn = Connection.get(conn_id) + log.info("Connection Retrieved '%s' (via task-sdk)", conn.conn_id) + return conn + from airflow.models.connection import Connection as ConnectionModel + + conn = ConnectionModel.get_connection_from_secrets(conn_id) + log.info("Connection Retrieved '%s' (via core Airflow)", conn.conn_id) + return conn + + @classmethod + def get_hook(cls, conn_id: str, hook_params: dict | None = None): + """ + Return default hook for this connection id. 
+ + :param conn_id: connection id + :param hook_params: hook parameters + :return: default hook for this connection + """ + connection = cls.get_connection(conn_id) + return connection.get_hook(hook_params=hook_params) + + def get_conn(self) -> Any: + """Return connection for the hook.""" + raise NotImplementedError() + + @classmethod + def get_connection_form_widgets(cls) -> dict[str, Any]: + return {} + + @classmethod + def get_ui_field_behaviour(cls) -> dict[str, Any]: + return {} diff --git a/airflow-core/tests/unit/hooks/test_base.py b/task-sdk/tests/task_sdk/bases/test_hook.py similarity index 90% rename from airflow-core/tests/unit/hooks/test_base.py rename to task-sdk/tests/task_sdk/bases/test_hook.py index 6d54827e67ced..f17ce8a12fcbb 100644 --- a/airflow-core/tests/unit/hooks/test_base.py +++ b/task-sdk/tests/task_sdk/bases/test_hook.py @@ -19,9 +19,8 @@ import pytest -from airflow.exceptions import AirflowNotFoundException -from airflow.hooks.base import BaseHook -from airflow.sdk.exceptions import ErrorType +from airflow.sdk import BaseHook +from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse, GetConnection from tests_common.test_utils.config import conf_vars @@ -30,7 +29,7 @@ class TestBaseHook: def test_hook_has_default_logger_name(self): hook = BaseHook() - assert hook.log.name == "airflow.task.hooks.airflow.hooks.base.BaseHook" + assert hook.log.name == "airflow.task.hooks.airflow.sdk.bases.hook.BaseHook" def test_custom_logger_name_is_correctly_set(self): hook = BaseHook(logger_name="airflow.custom.logger") @@ -65,7 +64,7 @@ def test_get_connection_not_found(self, mock_supervisor_comms): hook = BaseHook() mock_supervisor_comms.send.return_value = ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND) - with pytest.raises(AirflowNotFoundException, match=rf".*{conn_id}.*"): + with pytest.raises(AirflowRuntimeError, match="CONNECTION_NOT_FOUND"): 
hook.get_connection(conn_id=conn_id) def test_get_connection_secrets_backend_configured(self, mock_supervisor_comms, tmp_path):