diff --git a/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py b/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py
index d1e3bcd411064..85d858ff45cdf 100644
--- a/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py
+++ b/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py
@@ -34,9 +34,9 @@
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowOptionalProviderFeatureException
-from airflow.models import BaseOperator
 from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType
 from airflow.providers.apache.beam.triggers.beam import BeamJavaPipelineTrigger, BeamPythonPipelineTrigger
+from airflow.providers.apache.beam.version_compat import BaseOperator
 from airflow.providers_manager import ProvidersManager
 from airflow.utils.helpers import convert_camel_to_snake, exactly_one
 from airflow.version import version
@@ -214,7 +214,8 @@ def dataflow_job_id(self, new_value):
         if all([new_value, not self._dataflow_job_id, self._execute_context]):
             # push job_id as soon as it's ready, to let Sensors work before the job finished
             # and job_id pushed as returned value item.
-            self.xcom_push(context=self._execute_context, key="dataflow_job_id", value=new_value)
+            # Use task instance to push XCom (works for both Airflow 2.x and 3.x)
+            self._execute_context["ti"].xcom_push(key="dataflow_job_id", value=new_value)
         self._dataflow_job_id = new_value
 
     def _cast_dataflow_config(self):
diff --git a/providers/apache/beam/src/airflow/providers/apache/beam/version_compat.py b/providers/apache/beam/src/airflow/providers/apache/beam/version_compat.py
new file mode 100644
index 0000000000000..6756c5b297327
--- /dev/null
+++ b/providers/apache/beam/src/airflow/providers/apache/beam/version_compat.py
@@ -0,0 +1,42 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Version compatibility for Apache Beam provider.""" + +from __future__ import annotations + + +def get_base_airflow_version_tuple() -> tuple[int, int, int]: + from packaging.version import Version + + from airflow import __version__ + + airflow_version = Version(__version__) + return airflow_version.major, airflow_version.minor, airflow_version.micro + + +AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0) + +if AIRFLOW_V_3_1_PLUS: + from airflow.sdk import BaseOperator +else: + from airflow.models import BaseOperator + +__all__ = [ + "AIRFLOW_V_3_1_PLUS", + "BaseOperator", +] diff --git a/providers/apache/beam/tests/unit/apache/beam/operators/test_beam.py b/providers/apache/beam/tests/unit/apache/beam/operators/test_beam.py index d03cde557767f..1616f11803202 100644 --- a/providers/apache/beam/tests/unit/apache/beam/operators/test_beam.py +++ b/providers/apache/beam/tests/unit/apache/beam/operators/test_beam.py @@ -116,25 +116,24 @@ def test_async_execute_logging_should_execute_successfully(self, caplog): assert f"{TASK_ID} completed with response Pipeline has finished SUCCESSFULLY" in caplog.text def test_early_dataflow_id_xcom_push(self, default_options, pipeline_options): - with mock.patch.object(BeamBasePipelineOperator, "xcom_push") as mock_xcom_push: - op = BeamBasePipelineOperator( - **self.default_op_kwargs, - default_pipeline_options=copy.deepcopy(default_options), - pipeline_options=copy.deepcopy(pipeline_options), - dataflow_config={}, - ) - sample_df_job_id = "sample_df_job_id_value" - op._execute_context = MagicMock() - - assert op.dataflow_job_id is None - - op.dataflow_job_id = sample_df_job_id - mock_xcom_push.assert_called_once_with( - context=op._execute_context, key="dataflow_job_id", value=sample_df_job_id - ) - mock_xcom_push.reset_mock() - op.dataflow_job_id = "sample_df_job_same_value_id" - mock_xcom_push.assert_not_called() + op = BeamBasePipelineOperator( + **self.default_op_kwargs, + default_pipeline_options=copy.deepcopy(default_options), + pipeline_options=copy.deepcopy(pipeline_options), + dataflow_config={}, + ) + sample_df_job_id = "sample_df_job_id_value" + # Mock the task instance with xcom_push method + mock_ti = MagicMock() + op._execute_context = {"ti": mock_ti} + + assert op.dataflow_job_id is None + + op.dataflow_job_id = sample_df_job_id + mock_ti.xcom_push.assert_called_once_with(key="dataflow_job_id", value=sample_df_job_id) + mock_ti.xcom_push.reset_mock() + op.dataflow_job_id = "sample_df_job_same_value_id" + mock_ti.xcom_push.assert_not_called() class TestBeamRunPythonPipelineOperator: