diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/datafusion.py b/providers/google/src/airflow/providers/google/cloud/hooks/datafusion.py
index c27f63eb8399a..7355cca284238 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/datafusion.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/datafusion.py
@@ -469,31 +469,33 @@ def start_pipeline(
             is always default. If your pipeline belongs to an Enterprise edition instance, you can create
             a namespace.
         """
-        # TODO: This API endpoint starts multiple pipelines. There will eventually be a fix
-        # return the run Id as part of the API request to run a single pipeline.
-        # https://github.com/apache/airflow/pull/8954#discussion_r438223116
+        # Use the single-program start endpoint for better error handling
+        # https://cdap.atlassian.net/wiki/spaces/DOCS/pages/477560983/Lifecycle+Microservices#Start-a-Program
+        program_type = self.cdap_program_type(pipeline_type=pipeline_type)
+        program_id = self.cdap_program_id(pipeline_type=pipeline_type)
         url = os.path.join(
-            instance_url,
-            "v3",
-            "namespaces",
-            quote(namespace),
+            self._base_url(instance_url, namespace),
+            quote(pipeline_name),
+            f"{program_type}s",
+            program_id,
             "start",
         )
         runtime_args = runtime_args or {}
-        body = [
-            {
-                "appId": pipeline_name,
-                "runtimeargs": runtime_args,
-                "programType": self.cdap_program_type(pipeline_type=pipeline_type),
-                "programId": self.cdap_program_id(pipeline_type=pipeline_type),
-            }
-        ]
-        response = self._cdap_request(url=url, method="POST", body=body)
+        response = self._cdap_request(url=url, method="POST", body=runtime_args)
         self._check_response_status_and_data(
             response, f"Starting a pipeline failed with code {response.status}"
         )
         response_json = json.loads(response.data)
-        return response_json[0]["runId"]
+
+        # Extract and validate runId from response
+        if "runId" not in response_json:
+            error_message = response_json.get("error", "Unknown error")
+            raise AirflowException(
+                f"Failed to start pipeline '{pipeline_name}'. "
+                f"The response does not contain a runId. Error: {error_message}"
+            )
+
+        return str(response_json["runId"])
 
     def stop_pipeline(self, pipeline_name: str, instance_url: str, namespace: str = "default") -> None:
         """
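For context, here is a minimal sketch of the request the reworked hook now issues, using plain `requests` for illustration only. The URL shape and response body match what the updated tests assert; the instance URL and runtime argument are placeholder values, and the hook itself sends the request through `_cdap_request`, which attaches Google credentials.

```python
import requests

# Placeholders, not real values: substitute your Data Fusion instance endpoint
# and pipeline name.
INSTANCE_URL = "https://example-instance.datafusion.googleusercontent.com/api"
PIPELINE_NAME = "example_pipeline"

# Single-program start endpoint for a batch pipeline, matching the URL the
# updated tests assert: /apps/<app>/workflows/DataPipelineWorkflow/start
url = (
    f"{INSTANCE_URL}/v3/namespaces/default/apps/{PIPELINE_NAME}"
    "/workflows/DataPipelineWorkflow/start"
)

# Runtime arguments now go directly in the JSON body; there is no longer a
# list wrapper around an appId/programType/programId object.
response = requests.post(url, json={"my.arg": "value"}, timeout=60)
response.raise_for_status()

# The single-program endpoint returns one object with a single runId.
print(response.json()["runId"])
```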
diff --git a/providers/google/tests/unit/google/cloud/hooks/test_datafusion.py b/providers/google/tests/unit/google/cloud/hooks/test_datafusion.py
index 61854b4dfc793..6ffdc9c827f49 100644
--- a/providers/google/tests/unit/google/cloud/hooks/test_datafusion.py
+++ b/providers/google/tests/unit/google/cloud/hooks/test_datafusion.py
@@ -21,6 +21,7 @@
 from unittest import mock
 
 import aiohttp
+import google.auth.transport
 import pytest
 from aiohttp.helpers import TimerNoop
 from yarl import URL
@@ -340,42 +341,38 @@ def test_list_pipelines_should_fail_if_status_not_200(self, mock_request, hook):
     @mock.patch(HOOK_STR.format("DataFusionHook._cdap_request"))
     def test_start_pipeline(self, mock_request, hook):
         run_id = 1234
-        mock_request.return_value = mock.MagicMock(status=200, data=f'[{{"runId":{run_id}}}]')
+        mock_request.return_value = mock.MagicMock(
+            spec=google.auth.transport.Response, status=200, data=f'{{"runId":{run_id}}}'
+        )
 
-        hook.start_pipeline(pipeline_name=PIPELINE_NAME, instance_url=INSTANCE_URL, runtime_args=RUNTIME_ARGS)
-        body = [
-            {
-                "appId": PIPELINE_NAME,
-                "programType": "workflow",
-                "programId": "DataPipelineWorkflow",
-                "runtimeargs": RUNTIME_ARGS,
-            }
-        ]
+        result = hook.start_pipeline(
+            pipeline_name=PIPELINE_NAME, instance_url=INSTANCE_URL, runtime_args=RUNTIME_ARGS
+        )
+        assert result == str(run_id)
         mock_request.assert_called_once_with(
-            url=f"{INSTANCE_URL}/v3/namespaces/default/start", method="POST", body=body
+            url=f"{INSTANCE_URL}/v3/namespaces/default/apps/{PIPELINE_NAME}/workflows/DataPipelineWorkflow/start",
+            method="POST",
+            body=RUNTIME_ARGS,
         )
 
     @mock.patch(HOOK_STR.format("DataFusionHook._cdap_request"))
     def test_start_pipeline_stream(self, mock_request, hook):
-        run_id = 1234
-        mock_request.return_value = mock.MagicMock(status=200, data=f'[{{"runId":{run_id}}}]')
+        run_id = "test-run-123"
+        mock_request.return_value = mock.MagicMock(
+            spec=google.auth.transport.Response, status=200, data=f'{{"runId":"{run_id}"}}'
+        )
 
-        hook.start_pipeline(
+        result = hook.start_pipeline(
             pipeline_name=PIPELINE_NAME,
             instance_url=INSTANCE_URL,
             runtime_args=RUNTIME_ARGS,
             pipeline_type=DataFusionPipelineType.STREAM,
         )
-        body = [
-            {
-                "appId": PIPELINE_NAME,
-                "programType": "spark",
-                "programId": "DataStreamsSparkStreaming",
-                "runtimeargs": RUNTIME_ARGS,
-            }
-        ]
+        assert result == run_id
         mock_request.assert_called_once_with(
-            url=f"{INSTANCE_URL}/v3/namespaces/default/start", method="POST", body=body
+            url=f"{INSTANCE_URL}/v3/namespaces/default/apps/{PIPELINE_NAME}/sparks/DataStreamsSparkStreaming/start",
+            method="POST",
+            body=RUNTIME_ARGS,
         )
 
     @mock.patch(HOOK_STR.format("DataFusionHook._cdap_request"))
@@ -390,16 +387,10 @@ def test_start_pipeline_should_fail_if_empty_data_response(self, mock_request, h
         hook.start_pipeline(
             pipeline_name=PIPELINE_NAME, instance_url=INSTANCE_URL, runtime_args=RUNTIME_ARGS
         )
-        body = [
-            {
-                "appId": PIPELINE_NAME,
-                "programType": "workflow",
-                "programId": "DataPipelineWorkflow",
-                "runtimeargs": RUNTIME_ARGS,
-            }
-        ]
         mock_request.assert_called_once_with(
-            url=f"{INSTANCE_URL}/v3/namespaces/default/start", method="POST", body=body
+            url=f"{INSTANCE_URL}/v3/namespaces/default/apps/{PIPELINE_NAME}/workflows/DataPipelineWorkflow/start",
+            method="POST",
+            body=RUNTIME_ARGS,
         )
 
     @mock.patch(HOOK_STR.format("DataFusionHook._cdap_request"))
@@ -409,16 +400,31 @@ def test_start_pipeline_should_fail_if_status_not_200(self, mock_request, hook):
         hook.start_pipeline(
             pipeline_name=PIPELINE_NAME, instance_url=INSTANCE_URL, runtime_args=RUNTIME_ARGS
         )
-        body = [
-            {
-                "appId": PIPELINE_NAME,
-                "programType": "workflow",
-                "programId": "DataPipelineWorkflow",
-                "runtimeargs": RUNTIME_ARGS,
-            }
-        ]
         mock_request.assert_called_once_with(
-            url=f"{INSTANCE_URL}/v3/namespaces/default/start", method="POST", body=body
+            url=f"{INSTANCE_URL}/v3/namespaces/default/apps/{PIPELINE_NAME}/workflows/DataPipelineWorkflow/start",
+            method="POST",
+            body=RUNTIME_ARGS,
+        )
+
+    @mock.patch(HOOK_STR.format("DataFusionHook._cdap_request"))
+    def test_start_pipeline_should_fail_if_no_run_id(self, mock_request, hook):
+        """Test that start_pipeline fails gracefully when response doesn't contain runId."""
+        error_response = '{"error": "Invalid runtime arguments"}'
+        mock_request.return_value = mock.MagicMock(
+            spec=google.auth.transport.Response, status=200, data=error_response
+        )
+        with pytest.raises(
+            AirflowException,
+            match=r"Failed to start pipeline 'shrubberyPipeline'. "
+            r"The response does not contain a runId. Error: Invalid runtime arguments",
+        ):
+            hook.start_pipeline(
+                pipeline_name=PIPELINE_NAME, instance_url=INSTANCE_URL, runtime_args=RUNTIME_ARGS
+            )
+        mock_request.assert_called_once_with(
+            url=f"{INSTANCE_URL}/v3/namespaces/default/apps/{PIPELINE_NAME}/workflows/DataPipelineWorkflow/start",
+            method="POST",
+            body=RUNTIME_ARGS,
         )
 
     @mock.patch(HOOK_STR.format("DataFusionHook._cdap_request"))
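A hypothetical caller-side sketch of the updated behavior; the connection id, instance URL, and runtime arguments below are placeholders, and `start_pipeline`'s signature is unchanged by this patch:

```python
from airflow.providers.google.cloud.hooks.datafusion import DataFusionHook

# "google_cloud_default" and the instance URL are placeholder values.
hook = DataFusionHook(gcp_conn_id="google_cloud_default")
run_id = hook.start_pipeline(
    pipeline_name="example_pipeline",
    instance_url="https://example-instance.datafusion.googleusercontent.com/api",
    runtime_args={"my.arg": "value"},
)

# The run id is now always returned as a string, and a response without a
# runId raises AirflowException with the service's error message, instead of
# the IndexError/KeyError the old list-based response parsing could produce.
print(run_id)
```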