diff --git a/providers/src/airflow/providers/apache/beam/operators/beam.py b/providers/src/airflow/providers/apache/beam/operators/beam.py
index 41c55ede2a5bc..65f23336589d2 100644
--- a/providers/src/airflow/providers/apache/beam/operators/beam.py
+++ b/providers/src/airflow/providers/apache/beam/operators/beam.py
@@ -187,7 +187,20 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.beam_hook: BeamHook
         self.dataflow_hook: DataflowHook | None = None
-        self.dataflow_job_id: str | None = None
+        self._dataflow_job_id: str | None = None
+        self._execute_context: Context | None = None
+
+    @property
+    def dataflow_job_id(self):
+        return self._dataflow_job_id
+
+    @dataflow_job_id.setter
+    def dataflow_job_id(self, new_value):
+        if all([new_value, not self._dataflow_job_id, self._execute_context]):
+            # Push the job_id as soon as it is ready, so sensors can use it before the job
+            # finishes and before the job_id is returned as part of the operator's result.
+            self.xcom_push(context=self._execute_context, key="dataflow_job_id", value=new_value)
+        self._dataflow_job_id = new_value
 
     def _cast_dataflow_config(self):
         if isinstance(self.dataflow_config, dict):
@@ -346,6 +359,7 @@ def __init__(
 
     def execute(self, context: Context):
         """Execute the Apache Beam Python Pipeline."""
+        self._execute_context = context
        self._cast_dataflow_config()
         self.pipeline_options.setdefault("labels", {}).update(
             {"airflow-version": "v" + version.replace(".", "-").replace("+", "-")}
@@ -540,6 +554,7 @@ def __init__(
 
     def execute(self, context: Context):
         """Execute the Apache Beam Python Pipeline."""
+        self._execute_context = context
         self._cast_dataflow_config()
         (
             self.is_dataflow,
@@ -738,7 +753,7 @@ def execute(self, context: Context):
         """Execute the Apache Beam Pipeline."""
         if not exactly_one(self.go_file, self.launcher_binary):
             raise ValueError("Exactly one of `go_file` and `launcher_binary` must be set")
-
+        self._execute_context = context
         self._cast_dataflow_config()
         if self.dataflow_config.impersonation_chain:
             self.log.warning(
diff --git a/providers/tests/apache/beam/operators/test_beam.py b/providers/tests/apache/beam/operators/test_beam.py
index 6d1b4b5d1b958..fd2e706c29414 100644
--- a/providers/tests/apache/beam/operators/test_beam.py
+++ b/providers/tests/apache/beam/operators/test_beam.py
@@ -110,6 +110,27 @@ def test_async_execute_logging_should_execute_successfully(self, caplog):
         )
         assert f"{TASK_ID} completed with response Pipeline has finished SUCCESSFULLY" in caplog.text
 
+    def test_early_dataflow_id_xcom_push(self, default_options, pipeline_options):
+        with mock.patch.object(BeamBasePipelineOperator, "xcom_push") as mock_xcom_push:
+            op = BeamBasePipelineOperator(
+                **self.default_op_kwargs,
+                default_pipeline_options=copy.deepcopy(default_options),
+                pipeline_options=copy.deepcopy(pipeline_options),
+                dataflow_config={},
+            )
+            sample_df_job_id = "sample_df_job_id_value"
+            op._execute_context = MagicMock()
+
+            assert op.dataflow_job_id is None
+
+            op.dataflow_job_id = sample_df_job_id
+            mock_xcom_push.assert_called_once_with(
+                context=op._execute_context, key="dataflow_job_id", value=sample_df_job_id
+            )
+            mock_xcom_push.reset_mock()
+            op.dataflow_job_id = "sample_df_job_same_value_id"
+            mock_xcom_push.assert_not_called()
+
 
 class TestBeamRunPythonPipelineOperator:
     @pytest.fixture(autouse=True)
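
With this change, the "dataflow_job_id" XCom key becomes available as soon as the Dataflow backend assigns a job id, instead of only after the operator returns. A minimal consumer-side sketch of what that enables is below; the `run_beam` task id, project id, and location are illustrative placeholders, not part of this diff. A DataflowJobStatusSensor in the same DAG can pull the early key and start polling the job, which is useful when the Beam task runs asynchronously or the sensor sits on a parallel branch.

# Sketch only: a sensor consuming the early-pushed XCom key.
# `run_beam`, `my-gcp-project`, and `us-central1` are assumed names.
from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus
from airflow.providers.google.cloud.sensors.dataflow import DataflowJobStatusSensor

wait_for_job_done = DataflowJobStatusSensor(
    task_id="wait_for_job_done",
    # Pull the key pushed by the dataflow_job_id setter added in this diff.
    job_id="{{ task_instance.xcom_pull(task_ids='run_beam', key='dataflow_job_id') }}",
    expected_statuses={DataflowJobStatus.JOB_STATE_DONE},
    project_id="my-gcp-project",
    location="us-central1",
)

Note the guard in the setter: `all([new_value, not self._dataflow_job_id, self._execute_context])` means the push happens only once, for the first non-empty job id, and only after execute() has stored the context, which is exactly what the new test asserts.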