diff --git a/providers/yandex/pyproject.toml b/providers/yandex/pyproject.toml index 1eb44bd7219b9..876f8add3151c 100644 --- a/providers/yandex/pyproject.toml +++ b/providers/yandex/pyproject.toml @@ -73,7 +73,6 @@ dev = [ "apache-airflow", "apache-airflow-task-sdk", "apache-airflow-devel-common", - "apache-airflow-providers-common-compat", # Additional devel dependencies (do not remove this line and add extra development dependencies) "responses>=0.25.0", ] diff --git a/providers/yandex/src/airflow/providers/yandex/links/yq.py b/providers/yandex/src/airflow/providers/yandex/links/yq.py index 49a42473ca446..7bd416e889831 100644 --- a/providers/yandex/src/airflow/providers/yandex/links/yq.py +++ b/providers/yandex/src/airflow/providers/yandex/links/yq.py @@ -18,24 +18,11 @@ from typing import TYPE_CHECKING +from airflow.providers.yandex.version_compat import BaseOperatorLink, XCom + if TYPE_CHECKING: - from airflow.models import BaseOperator from airflow.models.taskinstancekey import TaskInstanceKey - - try: - from airflow.sdk.definitions.context import Context - except ImportError: - # TODO: Remove once provider drops support for Airflow 2 - from airflow.utils.context import Context - -from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_0_PLUS - -if AIRFLOW_V_3_0_PLUS: - from airflow.sdk import BaseOperatorLink - from airflow.sdk.execution_time.xcom import XCom -else: - from airflow.models import XCom # type: ignore[no-redef] - from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef] + from airflow.providers.yandex.version_compat import BaseOperator, Context XCOM_WEBLINK_KEY = "web_link" @@ -49,5 +36,5 @@ def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey): return XCom.get_value(key=XCOM_WEBLINK_KEY, ti_key=ti_key) or "https://yq.cloud.yandex.ru" @staticmethod - def persist(context: Context, task_instance: BaseOperator, web_link: str) -> None: - task_instance.xcom_push(context, key=XCOM_WEBLINK_KEY, value=web_link) + def persist(context: Context, web_link: str) -> None: + context["ti"].xcom_push(key=XCOM_WEBLINK_KEY, value=web_link) diff --git a/providers/yandex/src/airflow/providers/yandex/operators/dataproc.py b/providers/yandex/src/airflow/providers/yandex/operators/dataproc.py index e389f85abe583..a5418024541e8 100644 --- a/providers/yandex/src/airflow/providers/yandex/operators/dataproc.py +++ b/providers/yandex/src/airflow/providers/yandex/operators/dataproc.py @@ -20,15 +20,11 @@ from dataclasses import dataclass from typing import TYPE_CHECKING -from airflow.models import BaseOperator from airflow.providers.yandex.hooks.dataproc import DataprocHook +from airflow.providers.yandex.version_compat import BaseOperator if TYPE_CHECKING: - try: - from airflow.sdk.definitions.context import Context - except ImportError: - # TODO: Remove once provider drops support for Airflow 2 - from airflow.utils.context import Context + from airflow.providers.yandex.version_compat import Context @dataclass diff --git a/providers/yandex/src/airflow/providers/yandex/operators/yq.py b/providers/yandex/src/airflow/providers/yandex/operators/yq.py index da3890f4adfec..1c450174e1681 100644 --- a/providers/yandex/src/airflow/providers/yandex/operators/yq.py +++ b/providers/yandex/src/airflow/providers/yandex/operators/yq.py @@ -20,16 +20,12 @@ from functools import cached_property from typing import TYPE_CHECKING, Any -from airflow.models import BaseOperator from airflow.providers.yandex.hooks.yq import YQHook from airflow.providers.yandex.links.yq import YQLink +from airflow.providers.yandex.version_compat import BaseOperator if TYPE_CHECKING: - try: - from airflow.sdk.definitions.context import Context - except ImportError: - # TODO: Remove once provider drops support for Airflow 2 - from airflow.utils.context import Context + from airflow.providers.yandex.version_compat import Context class YQExecuteQueryOperator(BaseOperator): @@ -84,7 +80,7 @@ def execute(self, context: Context) -> Any: # pass to YQLink web_link = self.hook.compose_query_web_link(self.query_id) - YQLink.persist(context, self, web_link) + YQLink.persist(context, web_link) results = self.hook.wait_results(self.query_id) # forget query to avoid 'stop_query' in on_kill diff --git a/providers/yandex/src/airflow/providers/yandex/version_compat.py b/providers/yandex/src/airflow/providers/yandex/version_compat.py new file mode 100644 index 0000000000000..a57abc71e5c35 --- /dev/null +++ b/providers/yandex/src/airflow/providers/yandex/version_compat.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + + +def get_base_airflow_version_tuple() -> tuple[int, int, int]: + from packaging.version import Version + + from airflow import __version__ + + airflow_version = Version(__version__) + return airflow_version.major, airflow_version.minor, airflow_version.micro + + +AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0) + +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import BaseOperator, BaseOperatorLink + from airflow.sdk.definitions.context import Context + from airflow.sdk.execution_time.xcom import XCom +else: + from airflow.models import BaseOperator, XCom + from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef] + from airflow.utils.context import Context + + +__all__ = [ + "AIRFLOW_V_3_0_PLUS", + "BaseOperator", + "BaseOperatorLink", + "Context", + "XCom", +] diff --git a/providers/yandex/tests/unit/yandex/links/test_yq.py b/providers/yandex/tests/unit/yandex/links/test_yq.py index 9bfa0fbede6ff..bb6ef4bf3608d 100644 --- a/providers/yandex/tests/unit/yandex/links/test_yq.py +++ b/providers/yandex/tests/unit/yandex/links/test_yq.py @@ -35,18 +35,15 @@ def test_persist(): - mock_context = mock.MagicMock() + mock_ti = mock.MagicMock() + mock_context = {"ti": mock_ti} + if not AIRFLOW_V_3_0_PLUS: + mock_context["task_instance"] = mock_ti - YQLink.persist(context=mock_context, task_instance=MockOperator(task_id="test_task_id"), web_link="g.com") + YQLink.persist(context=mock_context, web_link="g.com") ti = mock_context["ti"] - if AIRFLOW_V_3_0_PLUS: - ti.xcom_push.assert_called_once_with( - key="web_link", - value="g.com", - ) - else: - ti.xcom_push.assert_called_once_with(key="web_link", value="g.com", execution_date=None) + ti.xcom_push.assert_called_once_with(key="web_link", value="g.com") def test_default_link(): diff --git a/providers/yandex/tests/unit/yandex/operators/test_yq.py b/providers/yandex/tests/unit/yandex/operators/test_yq.py index 127e4eb837972..be09ca243d3a8 100644 --- a/providers/yandex/tests/unit/yandex/operators/test_yq.py +++ b/providers/yandex/tests/unit/yandex/operators/test_yq.py @@ -54,7 +54,10 @@ def setup_method(self): def test_execute_query(self, mock_get_connection): mock_get_connection.return_value = Connection(extra={"oauth": OAUTH_TOKEN}) operator = YQExecuteQueryOperator(task_id="simple_sql", sql="select 987", folder_id="my_folder_id") - context = {"ti": MagicMock()} + mock_ti = MagicMock() + context = {"ti": mock_ti} + if not AIRFLOW_V_3_0_PLUS: + context["task_instance"] = operator responses.post( "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries", @@ -90,25 +93,14 @@ def test_execute_query(self, mock_get_connection): results = operator.execute(context) assert results == {"rows": [[777]], "columns": [{"name": "column0", "type": "Int32"}]} - if AIRFLOW_V_3_0_PLUS: - context["ti"].xcom_push.assert_has_calls( - [ - call( - key="web_link", - value=f"https://yq.cloud.yandex.ru/folders/{FOLDER_ID}/ide/queries/query1", - ), - ] - ) - else: - context["ti"].xcom_push.assert_has_calls( - [ - call( - key="web_link", - value=f"https://yq.cloud.yandex.ru/folders/{FOLDER_ID}/ide/queries/query1", - execution_date=None, - ), - ] - ) + context["ti"].xcom_push.assert_has_calls( + [ + call( + key="web_link", + value=f"https://yq.cloud.yandex.ru/folders/{FOLDER_ID}/ide/queries/query1", + ), + ] + ) responses.get( "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/status",