diff --git a/providers/atlassian/jira/src/airflow/providers/atlassian/jira/operators/jira.py b/providers/atlassian/jira/src/airflow/providers/atlassian/jira/operators/jira.py index d1a2fe6f06a55..a9c11be54165d 100644 --- a/providers/atlassian/jira/src/airflow/providers/atlassian/jira/operators/jira.py +++ b/providers/atlassian/jira/src/airflow/providers/atlassian/jira/operators/jira.py @@ -20,8 +20,8 @@ from collections.abc import Callable, Sequence from typing import TYPE_CHECKING, Any -from airflow.models import BaseOperator from airflow.providers.atlassian.jira.hooks.jira import JiraHook +from airflow.providers.atlassian.jira.version_compat import BaseOperator if TYPE_CHECKING: from airflow.utils.context import Context @@ -80,7 +80,7 @@ def execute(self, context: Context) -> Any: jira_result: Any = getattr(resource, self.method_name)(**self.jira_method_args) output = jira_result.get("id", None) if isinstance(jira_result, dict) else None - self.xcom_push(context, key="id", value=output) + context["task_instance"].xcom_push(key="id", value=output) if self.result_processor: return self.result_processor(context, jira_result) diff --git a/providers/atlassian/jira/src/airflow/providers/atlassian/jira/version_compat.py b/providers/atlassian/jira/src/airflow/providers/atlassian/jira/version_compat.py index 48d122b669696..b326387fea2a1 100644 --- a/providers/atlassian/jira/src/airflow/providers/atlassian/jira/version_compat.py +++ b/providers/atlassian/jira/src/airflow/providers/atlassian/jira/version_compat.py @@ -33,3 +33,15 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0) + +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import BaseOperator, BaseSensorOperator +else: + from airflow.models import BaseOperator + from airflow.sensors.base import BaseSensorOperator # type: ignore[no-redef] + +__all__ = [ + "AIRFLOW_V_3_0_PLUS", + "BaseOperator", + "BaseSensorOperator", +] diff --git a/providers/atlassian/jira/tests/unit/atlassian/jira/operators/test_jira.py b/providers/atlassian/jira/tests/unit/atlassian/jira/operators/test_jira.py index 45f06fe46a3ae..473d32cf3f8a9 100644 --- a/providers/atlassian/jira/tests/unit/atlassian/jira/operators/test_jira.py +++ b/providers/atlassian/jira/tests/unit/atlassian/jira/operators/test_jira.py @@ -61,9 +61,11 @@ def setup_test_cases(self, monkeypatch): ) ), ) - with mock.patch("airflow.models.baseoperator.BaseOperator.xcom_push", return_value=None) as m: - self.mocked_xcom_push = m - yield + # Mock task instance for xcom_push + mock_ti = mock.Mock() + mock_ti.xcom_push = mock.Mock(return_value=None) + self.mock_ti = mock_ti + self.mock_context = {"task_instance": mock_ti} def test_operator_init_with_optional_args(self): jira_operator = JiraOperator(task_id="jira_list_issue_types", jira_method="issue_types") @@ -80,11 +82,11 @@ def test_project_issue_count(self, mocked_jira_client): jira_method_args={"project": "ABC"}, ) - op.execute({}) + op.execute(self.mock_context) # type: ignore[arg-type] assert mocked_jira_client.called assert mocked_jira_client.return_value.get_project_issues_count.called - self.mocked_xcom_push.assert_called_once_with(mock.ANY, key="id", value=None) + self.mock_ti.xcom_push.assert_called_once_with(key="id", value=None) def test_issue_search(self, mocked_jira_client): jql_str = "issuekey=TEST-1226" @@ -95,11 +97,11 @@ def test_issue_search(self, mocked_jira_client): jira_method_args={"jql": jql_str, "limit": "1"}, ) - op.execute({}) + op.execute(self.mock_context) # type: ignore[arg-type] assert mocked_jira_client.called assert mocked_jira_client.return_value.jql_get_list_of_tickets.called - self.mocked_xcom_push.assert_called_once_with(mock.ANY, key="id", value="911539") + self.mock_ti.xcom_push.assert_called_once_with(key="id", value="911539") def test_update_issue(self, mocked_jira_client): mocked_jira_client.return_value.issue_add_comment.return_value = MINIMAL_TEST_TICKET @@ -110,8 +112,8 @@ def test_update_issue(self, mocked_jira_client): jira_method_args={"issue_key": MINIMAL_TEST_TICKET.get("key"), "comment": "this is test comment"}, ) - op.execute({}) + op.execute(self.mock_context) # type: ignore[arg-type] assert mocked_jira_client.called assert mocked_jira_client.return_value.issue_add_comment.called - self.mocked_xcom_push.assert_called_once_with(mock.ANY, key="id", value="911539") + self.mock_ti.xcom_push.assert_called_once_with(key="id", value="911539")