diff --git a/providers/yandex/src/airflow/providers/yandex/hooks/yandex.py b/providers/yandex/src/airflow/providers/yandex/hooks/yandex.py index 6d4edd8c4c4b1..78c89368adf56 100644 --- a/providers/yandex/src/airflow/providers/yandex/hooks/yandex.py +++ b/providers/yandex/src/airflow/providers/yandex/hooks/yandex.py @@ -28,11 +28,7 @@ from airflow.providers.yandex.utils.defaults import conn_name_attr, conn_type, default_conn_name, hook_name from airflow.providers.yandex.utils.fields import get_field_from_extras from airflow.providers.yandex.utils.user_agent import provider_user_agent - -try: - from airflow.sdk import BaseHook -except ImportError: - from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] +from airflow.providers.yandex.version_compat import BaseHook class YandexCloudBaseHook(BaseHook): diff --git a/providers/yandex/src/airflow/providers/yandex/version_compat.py b/providers/yandex/src/airflow/providers/yandex/version_compat.py index a57abc71e5c35..d315b0d25897b 100644 --- a/providers/yandex/src/airflow/providers/yandex/version_compat.py +++ b/providers/yandex/src/airflow/providers/yandex/version_compat.py @@ -28,6 +28,13 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0) +AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0) + + +if AIRFLOW_V_3_1_PLUS: + from airflow.sdk import BaseHook +else: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseOperator, BaseOperatorLink @@ -41,6 +48,8 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: __all__ = [ "AIRFLOW_V_3_0_PLUS", + "AIRFLOW_V_3_1_PLUS", + "BaseHook", "BaseOperator", "BaseOperatorLink", "Context",