diff --git a/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py b/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py
index 3ac04cd4c57aa..bba1af0d7e518 100644
--- a/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py
+++ b/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py
@@ -441,17 +441,17 @@ def execute_on_dataflow(self, context: Context):
         """Execute the Apache Beam Pipeline on Dataflow runner."""
         if not self.dataflow_hook:
             self.dataflow_hook = self.__set_dataflow_hook()
-        with self.dataflow_hook.provide_authorized_gcloud():
-            self.beam_hook.start_python_pipeline(
-                variables=self.snake_case_pipeline_options,
-                py_file=self.py_file,
-                py_options=self.py_options,
-                py_interpreter=self.py_interpreter,
-                py_requirements=self.py_requirements,
-                py_system_site_packages=self.py_system_site_packages,
-                process_line_callback=self.process_line_callback,
-                is_dataflow_job_id_exist_callback=self.is_dataflow_job_id_exist_callback,
-            )
+
+        self.beam_hook.start_python_pipeline(
+            variables=self.snake_case_pipeline_options,
+            py_file=self.py_file,
+            py_options=self.py_options,
+            py_interpreter=self.py_interpreter,
+            py_requirements=self.py_requirements,
+            py_system_site_packages=self.py_system_site_packages,
+            process_line_callback=self.process_line_callback,
+            is_dataflow_job_id_exist_callback=self.is_dataflow_job_id_exist_callback,
+        )
 
         location = self.dataflow_config.location or DEFAULT_DATAFLOW_LOCATION
         DataflowJobLink.persist(context=context, region=location)
@@ -623,14 +623,13 @@ def execute_on_dataflow(self, context: Context):
 
         if not is_running:
             self.pipeline_options["jobName"] = self.dataflow_job_name
-            with self.dataflow_hook.provide_authorized_gcloud():
-                self.beam_hook.start_java_pipeline(
-                    variables=self.pipeline_options,
-                    jar=self.jar,
-                    job_class=self.job_class,
-                    process_line_callback=self.process_line_callback,
-                    is_dataflow_job_id_exist_callback=self.is_dataflow_job_id_exist_callback,
-                )
+            self.beam_hook.start_java_pipeline(
+                variables=self.pipeline_options,
+                jar=self.jar,
+                job_class=self.job_class,
+                process_line_callback=self.process_line_callback,
+                is_dataflow_job_id_exist_callback=self.is_dataflow_job_id_exist_callback,
+            )
         if self.dataflow_job_name and self.dataflow_config.location:
             DataflowJobLink.persist(context=context)
         if self.deferrable:
@@ -790,12 +789,11 @@ def execute(self, context: Context):
                     go_artifact.download_from_gcs(gcs_hook=gcs_hook, tmp_dir=tmp_dir)
 
                 if is_dataflow and self.dataflow_hook:
-                    with self.dataflow_hook.provide_authorized_gcloud():
-                        go_artifact.start_pipeline(
-                            beam_hook=self.beam_hook,
-                            variables=snake_case_pipeline_options,
-                            process_line_callback=process_line_callback,
-                        )
+                    go_artifact.start_pipeline(
+                        beam_hook=self.beam_hook,
+                        variables=snake_case_pipeline_options,
+                        process_line_callback=process_line_callback,
+                    )
                     DataflowJobLink.persist(context=context)
                     if dataflow_job_name and self.dataflow_config.location:
                         self.dataflow_hook.wait_for_done(
diff --git a/providers/apache/beam/tests/unit/apache/beam/operators/test_beam.py b/providers/apache/beam/tests/unit/apache/beam/operators/test_beam.py
index 1616f11803202..c517517c29963 100644
--- a/providers/apache/beam/tests/unit/apache/beam/operators/test_beam.py
+++ b/providers/apache/beam/tests/unit/apache/beam/operators/test_beam.py
@@ -245,7 +245,6 @@ def test_exec_dataflow_runner(
             process_line_callback=mock.ANY,
             is_dataflow_job_id_exist_callback=mock.ANY,
         )
-        dataflow_hook_mock.return_value.provide_authorized_gcloud.assert_called_once_with()
 
     @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
     @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
@@ -760,7 +759,6 @@ def test_exec_dataflow_runner_with_go_file(
             multiple_jobs=False,
             project_id=dataflow_config.project_id,
         )
-        dataflow_hook_mock.return_value.provide_authorized_gcloud.assert_called_once_with()
 
     @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
     @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
@@ -789,8 +787,6 @@ def gcs_download_side_effect(bucket_name: str, object_name: str, filename: str)
         gcs_download_method.side_effect = gcs_download_side_effect
 
         mock_dataflow_hook.build_dataflow_job_name.return_value = "test-job"
-
-        provide_authorized_gcloud_method = mock_dataflow_hook.return_value.provide_authorized_gcloud
         start_go_pipeline_method = mock_beam_hook.return_value.start_go_pipeline_with_binary
         wait_for_done_method = mock_dataflow_hook.return_value.wait_for_done
 
@@ -835,7 +831,6 @@ def gcs_download_side_effect(bucket_name: str, object_name: str, filename: str)
             cancel_timeout=dataflow_config.cancel_timeout,
             wait_until_finished=dataflow_config.wait_until_finished,
         )
-        provide_authorized_gcloud_method.assert_called_once_with()
         start_go_pipeline_method.assert_called_once_with(
             variables=expected_options,
             launcher_binary=expected_launcher_binary,
@@ -971,7 +966,6 @@ def test_exec_dataflow_runner(self, gcs_hook_mock, dataflow_hook_mock, beam_hook
             wait_until_finished=dataflow_config.wait_until_finished,
         )
         beam_hook_mock.return_value.start_python_pipeline.assert_called_once()
-        dataflow_hook_mock.return_value.provide_authorized_gcloud.assert_called_once_with()
 
     @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
     @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
@@ -1100,7 +1094,6 @@ def test_exec_dataflow_runner(self, gcs_hook_mock, dataflow_hook_mock, beam_hook
             wait_until_finished=dataflow_config.wait_until_finished,
        )
         beam_hook_mock.return_value.start_python_pipeline.assert_not_called()
-        dataflow_hook_mock.return_value.provide_authorized_gcloud.assert_called_once_with()
 
     @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
     @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))