diff --git a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py index 02eed4036ac2d..e0bcc07683042 100644 --- a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py +++ b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py @@ -26,11 +26,7 @@ from elasticsearch import Elasticsearch from airflow.providers.common.sql.hooks.sql import DbApiHook - -try: - from airflow.sdk import BaseHook -except ImportError: - from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] +from airflow.providers.elasticsearch.version_compat import BaseHook if TYPE_CHECKING: from elastic_transport import ObjectApiResponse diff --git a/providers/elasticsearch/src/airflow/providers/elasticsearch/version_compat.py b/providers/elasticsearch/src/airflow/providers/elasticsearch/version_compat.py index 7d2af595fdfde..c9a2c48636350 100644 --- a/providers/elasticsearch/src/airflow/providers/elasticsearch/version_compat.py +++ b/providers/elasticsearch/src/airflow/providers/elasticsearch/version_compat.py @@ -33,6 +33,12 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0) +AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0) + +if AIRFLOW_V_3_1_PLUS: + from airflow.sdk import BaseHook +else: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] if AIRFLOW_V_3_0_PLUS: from airflow.utils.log.file_task_handler import StructuredLogMessage @@ -40,3 +46,5 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: EsLogMsgType = list[StructuredLogMessage] | str else: EsLogMsgType = list[tuple[str, str]] # type: ignore[misc] + +__all__ = ["AIRFLOW_V_3_0_PLUS", "AIRFLOW_V_3_1_PLUS", "BaseHook", "EsLogMsgType"]