diff --git a/providers/qdrant/src/airflow/providers/qdrant/hooks/qdrant.py b/providers/qdrant/src/airflow/providers/qdrant/hooks/qdrant.py index 66e4dfb5372c4..e692491ae8792 100644 --- a/providers/qdrant/src/airflow/providers/qdrant/hooks/qdrant.py +++ b/providers/qdrant/src/airflow/providers/qdrant/hooks/qdrant.py @@ -24,10 +24,7 @@ from qdrant_client import QdrantClient from qdrant_client.http.exceptions import UnexpectedResponse -try: - from airflow.sdk import BaseHook -except ImportError: - from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] +from airflow.providers.qdrant.version_compat import BaseHook class QdrantHook(BaseHook): diff --git a/providers/qdrant/src/airflow/providers/qdrant/version_compat.py b/providers/qdrant/src/airflow/providers/qdrant/version_compat.py index 9e99cafacafa1..41196dbb387c8 100644 --- a/providers/qdrant/src/airflow/providers/qdrant/version_compat.py +++ b/providers/qdrant/src/airflow/providers/qdrant/version_compat.py @@ -28,10 +28,16 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0) +AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0) if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseOperator else: from airflow.models import BaseOperator -__all__ = ["AIRFLOW_V_3_0_PLUS", "BaseOperator"] +if AIRFLOW_V_3_1_PLUS: + from airflow.sdk import BaseHook +else: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + +__all__ = ["AIRFLOW_V_3_0_PLUS", "AIRFLOW_V_3_1_PLUS", "BaseHook", "BaseOperator"]