diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/datafusion.py b/providers/google/src/airflow/providers/google/cloud/hooks/datafusion.py
index 063f7bbc2e3a1..73183ca2236be 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/datafusion.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/datafusion.py
@@ -469,33 +469,31 @@ def start_pipeline(
             is always default. If your pipeline belongs to an Enterprise edition instance, you
             can create a namespace.
         """
-        # Use the single-program start endpoint for better error handling
-        # https://cdap.atlassian.net/wiki/spaces/DOCS/pages/477560983/Lifecycle+Microservices#Start-a-Program
-        program_type = self.cdap_program_type(pipeline_type=pipeline_type)
-        program_id = self.cdap_program_id(pipeline_type=pipeline_type)
+        # TODO: This API endpoint starts multiple pipelines. There will eventually be a fix
+        #  to return the run Id as part of the API request to run a single pipeline.
+        #  https://github.com/apache/airflow/pull/8954#discussion_r438223116
         url = os.path.join(
-            self._base_url(instance_url, namespace),
-            quote(pipeline_name),
-            f"{program_type}s",
-            program_id,
+            instance_url,
+            "v3",
+            "namespaces",
+            quote(namespace),
             "start",
         )
         runtime_args = runtime_args or {}
-        response = self._cdap_request(url=url, method="POST", body=runtime_args)
+        body = [
+            {
+                "appId": pipeline_name,
+                "runtimeargs": runtime_args,
+                "programType": self.cdap_program_type(pipeline_type=pipeline_type),
+                "programId": self.cdap_program_id(pipeline_type=pipeline_type),
+            }
+        ]
+        response = self._cdap_request(url=url, method="POST", body=body)
         self._check_response_status_and_data(
             response, f"Starting a pipeline failed with code {response.status}"
         )
         response_json = json.loads(response.data)
-
-        # Extract and validate runId from response
-        if "runId" not in response_json:
-            error_message = response_json.get("error", "Unknown error")
-            raise AirflowException(
-                f"Failed to start pipeline '{pipeline_name}'. "
-                f"The response does not contain a runId. Error: {error_message}"
-            )
-
-        return str(response_json["runId"])
+        return response_json[0]["runId"]
 
     def stop_pipeline(self, pipeline_name: str, instance_url: str, namespace: str = "default") -> None:
         """
diff --git a/providers/google/tests/unit/google/cloud/hooks/test_datafusion.py b/providers/google/tests/unit/google/cloud/hooks/test_datafusion.py
index 49479446cdeed..d98a64fc17d45 100644
--- a/providers/google/tests/unit/google/cloud/hooks/test_datafusion.py
+++ b/providers/google/tests/unit/google/cloud/hooks/test_datafusion.py
@@ -21,7 +21,6 @@
 from unittest import mock
 
 import aiohttp
-import google.auth.transport
 import pytest
 from aiohttp.helpers import TimerNoop
 from yarl import URL
@@ -341,38 +340,42 @@ def test_list_pipelines_should_fail_if_status_not_200(self, mock_request, hook):
     @mock.patch(HOOK_STR.format("DataFusionHook._cdap_request"))
     def test_start_pipeline(self, mock_request, hook):
         run_id = 1234
-        mock_request.return_value = mock.MagicMock(
-            spec=google.auth.transport.Response, status=200, data=f'{{"runId":{run_id}}}'
-        )
+        mock_request.return_value = mock.MagicMock(status=200, data=f'[{{"runId":{run_id}}}]')
 
-        result = hook.start_pipeline(
-            pipeline_name=PIPELINE_NAME, instance_url=INSTANCE_URL, runtime_args=RUNTIME_ARGS
-        )
-        assert result == str(run_id)
+        hook.start_pipeline(pipeline_name=PIPELINE_NAME, instance_url=INSTANCE_URL, runtime_args=RUNTIME_ARGS)
+        body = [
+            {
+                "appId": PIPELINE_NAME,
+                "programType": "workflow",
+                "programId": "DataPipelineWorkflow",
+                "runtimeargs": RUNTIME_ARGS,
+            }
+        ]
         mock_request.assert_called_once_with(
-            url=f"{INSTANCE_URL}/v3/namespaces/default/apps/{PIPELINE_NAME}/workflows/DataPipelineWorkflow/start",
-            method="POST",
-            body=RUNTIME_ARGS,
+            url=f"{INSTANCE_URL}/v3/namespaces/default/start", method="POST", body=body
         )
 
     @mock.patch(HOOK_STR.format("DataFusionHook._cdap_request"))
     def test_start_pipeline_stream(self, mock_request, hook):
-        run_id = "test-run-123"
-        mock_request.return_value = mock.MagicMock(
-            spec=google.auth.transport.Response, status=200, data=f'{{"runId":"{run_id}"}}'
-        )
+        run_id = 1234
+        mock_request.return_value = mock.MagicMock(status=200, data=f'[{{"runId":{run_id}}}]')
 
-        result = hook.start_pipeline(
+        hook.start_pipeline(
             pipeline_name=PIPELINE_NAME,
             instance_url=INSTANCE_URL,
             runtime_args=RUNTIME_ARGS,
             pipeline_type=DataFusionPipelineType.STREAM,
         )
-        assert result == run_id
+        body = [
+            {
+                "appId": PIPELINE_NAME,
+                "programType": "spark",
+                "programId": "DataStreamsSparkStreaming",
+                "runtimeargs": RUNTIME_ARGS,
+            }
+        ]
         mock_request.assert_called_once_with(
-            url=f"{INSTANCE_URL}/v3/namespaces/default/apps/{PIPELINE_NAME}/sparks/DataStreamsSparkStreaming/start",
-            method="POST",
-            body=RUNTIME_ARGS,
+            url=f"{INSTANCE_URL}/v3/namespaces/default/start", method="POST", body=body
         )
 
     @mock.patch(HOOK_STR.format("DataFusionHook._cdap_request"))
@@ -387,10 +390,16 @@ def test_start_pipeline_should_fail_if_empty_data_response(self, mock_request, h
             hook.start_pipeline(
                 pipeline_name=PIPELINE_NAME, instance_url=INSTANCE_URL, runtime_args=RUNTIME_ARGS
             )
+        body = [
+            {
+                "appId": PIPELINE_NAME,
+                "programType": "workflow",
+                "programId": "DataPipelineWorkflow",
+                "runtimeargs": RUNTIME_ARGS,
+            }
+        ]
         mock_request.assert_called_once_with(
-            url=f"{INSTANCE_URL}/v3/namespaces/default/apps/{PIPELINE_NAME}/workflows/DataPipelineWorkflow/start",
-            method="POST",
-            body=RUNTIME_ARGS,
+            url=f"{INSTANCE_URL}/v3/namespaces/default/start", method="POST", body=body
         )
 
     @mock.patch(HOOK_STR.format("DataFusionHook._cdap_request"))
@@ -400,31 +409,16 @@ def test_start_pipeline_should_fail_if_status_not_200(self, mock_request, hook):
             hook.start_pipeline(
                 pipeline_name=PIPELINE_NAME, instance_url=INSTANCE_URL, runtime_args=RUNTIME_ARGS
             )
+        body = [
+            {
+                "appId": PIPELINE_NAME,
+                "programType": "workflow",
+                "programId": "DataPipelineWorkflow",
+                "runtimeargs": RUNTIME_ARGS,
+            }
+        ]
         mock_request.assert_called_once_with(
-            url=f"{INSTANCE_URL}/v3/namespaces/default/apps/{PIPELINE_NAME}/workflows/DataPipelineWorkflow/start",
-            method="POST",
-            body=RUNTIME_ARGS,
-        )
-
-    @mock.patch(HOOK_STR.format("DataFusionHook._cdap_request"))
-    def test_start_pipeline_should_fail_if_no_run_id(self, mock_request, hook):
-        """Test that start_pipeline fails gracefully when response doesn't contain runId."""
-        error_response = '{"error": "Invalid runtime arguments"}'
-        mock_request.return_value = mock.MagicMock(
-            spec=google.auth.transport.Response, status=200, data=error_response
-        )
-        with pytest.raises(
-            AirflowException,
-            match=r"Failed to start pipeline 'shrubberyPipeline'. "
-            r"The response does not contain a runId. Error: Invalid runtime arguments",
-        ):
-            hook.start_pipeline(
-                pipeline_name=PIPELINE_NAME, instance_url=INSTANCE_URL, runtime_args=RUNTIME_ARGS
-            )
-        mock_request.assert_called_once_with(
-            url=f"{INSTANCE_URL}/v3/namespaces/default/apps/{PIPELINE_NAME}/workflows/DataPipelineWorkflow/start",
-            method="POST",
-            body=RUNTIME_ARGS,
+            url=f"{INSTANCE_URL}/v3/namespaces/default/start", method="POST", body=body
         )
 
     @mock.patch(HOOK_STR.format("DataFusionHook._cdap_request"))
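
Note (reviewer illustration, not part of the patch): the rewritten hook relies on the CDAP batch-start contract, in which POST /v3/namespaces/{namespace}/start takes a JSON array of program descriptors and returns a JSON array with one status object per descriptor, which is why the run id is read from the first element of the response. A minimal sketch of that exchange, with placeholder values and plain requests standing in for the hook's authenticated _cdap_request transport:

    # Hypothetical standalone sketch; the instance URL, pipeline name, and
    # runtime args below are placeholders, not values taken from the patch.
    import json

    import requests

    instance_url = "https://example-instance.datafusion.example.com/api"
    url = f"{instance_url}/v3/namespaces/default/start"
    body = [
        {
            "appId": "my_pipeline",              # pipeline (application) name
            "programType": "workflow",           # batch pipelines run as a workflow
            "programId": "DataPipelineWorkflow",
            "runtimeargs": {"key": "value"},
        }
    ]
    response = requests.post(url, json=body)
    # Per the test fixtures above, the response body has the shape
    # [{"runId": ...}], one element per requested program.
    run_id = json.loads(response.text)[0]["runId"]
    print(run_id)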