diff --git a/providers/papermill/src/airflow/providers/papermill/hooks/kernel.py b/providers/papermill/src/airflow/providers/papermill/hooks/kernel.py index 6cc730a02195a..6bc5331443987 100644 --- a/providers/papermill/src/airflow/providers/papermill/hooks/kernel.py +++ b/providers/papermill/src/airflow/providers/papermill/hooks/kernel.py @@ -25,10 +25,7 @@ from papermill.utils import merge_kwargs, remove_args from traitlets import Unicode -try: - from airflow.sdk import BaseHook -except ImportError: - from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] +from airflow.providers.papermill.version_compat import BaseHook JUPYTER_KERNEL_SHELL_PORT = 60316 JUPYTER_KERNEL_IOPUB_PORT = 60317 diff --git a/providers/papermill/src/airflow/providers/papermill/version_compat.py b/providers/papermill/src/airflow/providers/papermill/version_compat.py index 89ade07a76753..7c415980c9307 100644 --- a/providers/papermill/src/airflow/providers/papermill/version_compat.py +++ b/providers/papermill/src/airflow/providers/papermill/version_compat.py @@ -28,13 +28,21 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0) +AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0) if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseOperator else: from airflow.models import BaseOperator # type: ignore[no-redef] +if AIRFLOW_V_3_1_PLUS: + from airflow.sdk import BaseHook +else: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] + __all__ = [ "AIRFLOW_V_3_0_PLUS", + "AIRFLOW_V_3_1_PLUS", + "BaseHook", "BaseOperator", ]