diff --git a/providers/ssh/src/airflow/providers/ssh/operators/ssh.py b/providers/ssh/src/airflow/providers/ssh/operators/ssh.py index e6e90f50d765c..8f22a86407137 100644 --- a/providers/ssh/src/airflow/providers/ssh/operators/ssh.py +++ b/providers/ssh/src/airflow/providers/ssh/operators/ssh.py @@ -24,8 +24,8 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowSkipException -from airflow.models import BaseOperator from airflow.providers.ssh.hooks.ssh import SSHHook +from airflow.providers.ssh.version_compat import BaseOperator from airflow.utils.types import NOTSET, ArgNotSet if TYPE_CHECKING: diff --git a/providers/ssh/src/airflow/providers/ssh/version_compat.py b/providers/ssh/src/airflow/providers/ssh/version_compat.py new file mode 100644 index 0000000000000..4f8d5e32bca4a --- /dev/null +++ b/providers/ssh/src/airflow/providers/ssh/version_compat.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# NOTE! THIS FILE IS COPIED MANUALLY IN OTHER PROVIDERS DELIBERATELY TO AVOID ADDING UNNECESSARY +# DEPENDENCIES BETWEEN PROVIDERS. IF YOU WANT TO ADD CONDITIONAL CODE IN YOUR PROVIDER THAT DEPENDS +# ON AIRFLOW VERSION, PLEASE COPY THIS FILE TO THE ROOT PACKAGE OF YOUR PROVIDER AND IMPORT +# THOSE CONSTANTS FROM IT RATHER THAN IMPORTING THEM FROM ANOTHER PROVIDER OR TEST CODE +# +from __future__ import annotations + + +def get_base_airflow_version_tuple() -> tuple[int, int, int]: + from packaging.version import Version + + from airflow import __version__ + + airflow_version = Version(__version__) + return airflow_version.major, airflow_version.minor, airflow_version.micro + + +AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0) + +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import BaseOperator +else: + from airflow.models import BaseOperator # type: ignore[no-redef] + +__all__ = ["AIRFLOW_V_3_0_PLUS", "BaseOperator"] diff --git a/providers/ssh/tests/unit/ssh/operators/test_ssh.py b/providers/ssh/tests/unit/ssh/operators/test_ssh.py index 06ea865fc4481..b100fc13d31a4 100644 --- a/providers/ssh/tests/unit/ssh/operators/test_ssh.py +++ b/providers/ssh/tests/unit/ssh/operators/test_ssh.py @@ -268,7 +268,7 @@ def test_push_ssh_exit_to_xcom(self, request, dag_maker): dr = dag_maker.create_dagrun(run_id="push_xcom") ti = TaskInstance(task=task, run_id=dr.run_id) with pytest.raises(AirflowException, match=f"SSH operator error: exit status = {ssh_exit_code}"): - ti.run() + dag_maker.run_ti("push_xcom", dr) assert ti.xcom_pull(task_ids=task.task_id, key="ssh_exit") == ssh_exit_code def test_timeout_triggers_on_kill(self, request, dag_maker): @@ -278,18 +278,17 @@ def command_sleep_forever(*args, **kwargs): self.exec_ssh_client_command.side_effect = command_sleep_forever with dag_maker(dag_id=f"dag_{request.node.name}"): - task = SSHOperator( + _ = SSHOperator( task_id="test_timeout", ssh_hook=self.hook, command="sleep 100", execution_timeout=timedelta(seconds=1), ) dr = dag_maker.create_dagrun(run_id="test_timeout") - ti = TaskInstance(task=task, run_id=dr.run_id) with mock.patch.object(SSHOperator, "on_kill") as mock_on_kill: with pytest.raises(AirflowTaskTimeout): - ti.run() + dag_maker.run_ti("test_timeout", dr) # Wait a bit to ensure on_kill has time to be called time.sleep(1)