diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py index edbfbd3f39b45..e4fccfedd87b6 100644 --- a/airflow/providers/google/cloud/operators/dataproc.py +++ b/airflow/providers/google/cloud/operators/dataproc.py @@ -816,6 +816,7 @@ def execute(self, context: Context) -> dict: gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, polling_interval_seconds=self.polling_interval_seconds, + delete_on_error=self.delete_on_error, ), method_name="execute_complete", ) diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py index f0aecddb4a8ed..32b536a2ecaa3 100644 --- a/airflow/providers/google/cloud/triggers/dataproc.py +++ b/airflow/providers/google/cloud/triggers/dataproc.py @@ -25,9 +25,10 @@ from typing import Any, AsyncIterator, Sequence from google.api_core.exceptions import NotFound -from google.cloud.dataproc_v1 import Batch, ClusterStatus, JobStatus +from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus -from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -43,6 +44,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, polling_interval_seconds: int = 30, + delete_on_error: bool = True, ): super().__init__() self.region = region @@ -50,6 +52,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain self.polling_interval_seconds = polling_interval_seconds + self.delete_on_error = delete_on_error def get_async_hook(self): return DataprocAsyncHook( @@ -57,6 +60,16 @@ def get_async_hook(self): impersonation_chain=self.impersonation_chain, ) + def get_sync_hook(self): + # The synchronous hook is utilized to delete the cluster when a task is cancelled. + # This is because the asynchronous hook deletion is not awaited when the trigger task + # is cancelled. The call for deleting the cluster through the sync hook is not a blocking + # call, which means it does not wait until the cluster is deleted. + return DataprocHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + class DataprocSubmitTrigger(DataprocBaseTrigger): """ @@ -140,24 +153,73 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "gcp_conn_id": self.gcp_conn_id, "impersonation_chain": self.impersonation_chain, "polling_interval_seconds": self.polling_interval_seconds, + "delete_on_error": self.delete_on_error, }, ) async def run(self) -> AsyncIterator[TriggerEvent]: - while True: - cluster = await self.get_async_hook().get_cluster( - project_id=self.project_id, region=self.region, cluster_name=self.cluster_name + try: + while True: + cluster = await self.fetch_cluster() + state = cluster.status.state + if state == ClusterStatus.State.ERROR: + await self.delete_when_error_occurred(cluster) + yield TriggerEvent( + { + "cluster_name": self.cluster_name, + "cluster_state": ClusterStatus.State.DELETING, + "cluster": cluster, + } + ) + return + elif state == ClusterStatus.State.RUNNING: + yield TriggerEvent( + { + "cluster_name": self.cluster_name, + "cluster_state": state, + "cluster": cluster, + } + ) + return + self.log.info("Current state is %s", state) + self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds) + await asyncio.sleep(self.polling_interval_seconds) + except asyncio.CancelledError: + try: + if self.delete_on_error: + self.log.info("Deleting cluster %s.", self.cluster_name) + # The synchronous hook is utilized to delete the cluster when a task is cancelled. + # This is because the asynchronous hook deletion is not awaited when the trigger task + # is cancelled. The call for deleting the cluster through the sync hook is not a blocking + # call, which means it does not wait until the cluster is deleted. + self.get_sync_hook().delete_cluster( + region=self.region, cluster_name=self.cluster_name, project_id=self.project_id + ) + self.log.info("Deleted cluster %s during cancellation.", self.cluster_name) + except Exception as e: + self.log.error("Error during cancellation handling: %s", e) + raise AirflowException("Error during cancellation handling: %s", e) + + async def fetch_cluster(self) -> Cluster: + """Fetch the cluster status.""" + return await self.get_async_hook().get_cluster( + project_id=self.project_id, region=self.region, cluster_name=self.cluster_name + ) + + async def delete_when_error_occurred(self, cluster: Cluster) -> None: + """ + Delete the cluster on error. + + :param cluster: The cluster to delete. + """ + if self.delete_on_error: + self.log.info("Deleting cluster %s.", self.cluster_name) + await self.get_async_hook().delete_cluster( + region=self.region, cluster_name=self.cluster_name, project_id=self.project_id ) - state = cluster.status.state - self.log.info("Dataproc cluster: %s is in state: %s", self.cluster_name, state) - if state in ( - ClusterStatus.State.ERROR, - ClusterStatus.State.RUNNING, - ): - break - self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds) - await asyncio.sleep(self.polling_interval_seconds) - yield TriggerEvent({"cluster_name": self.cluster_name, "cluster_state": state, "cluster": cluster}) + self.log.info("Cluster %s has been deleted.", self.cluster_name) + else: + self.log.info("Cluster %s is not deleted as delete_on_error is set to False.", self.cluster_name) class DataprocBatchTrigger(DataprocBaseTrigger): diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py index 45607d51b8a59..e310f2e0dfc9e 100644 --- a/tests/providers/google/cloud/triggers/test_dataproc.py +++ b/tests/providers/google/cloud/triggers/test_dataproc.py @@ -22,7 +22,7 @@ from unittest import mock import pytest -from google.cloud.dataproc_v1 import Batch, ClusterStatus +from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus from google.protobuf.any_pb2 import Any from google.rpc.status_pb2 import Status @@ -70,6 +70,7 @@ def batch_trigger(): gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=None, polling_interval_seconds=TEST_POLL_INTERVAL, + delete_on_error=True, ) return trigger @@ -96,6 +97,7 @@ def diagnose_operation_trigger(): gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=None, polling_interval_seconds=TEST_POLL_INTERVAL, + delete_on_error=True, ) @@ -147,6 +149,7 @@ def test_async_cluster_trigger_serialization_should_execute_successfully(self, c "gcp_conn_id": TEST_GCP_CONN_ID, "impersonation_chain": None, "polling_interval_seconds": TEST_POLL_INTERVAL, + "delete_on_error": True, } @pytest.mark.asyncio @@ -175,27 +178,37 @@ async def test_async_cluster_triggers_on_success_should_execute_successfully( @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster") + @mock.patch( + "airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster", + return_value=asyncio.Future(), + ) + @mock.patch("google.auth.default") async def test_async_cluster_trigger_run_returns_error_event( - self, mock_hook, cluster_trigger, async_get_cluster + self, mock_auth, mock_delete_cluster, mock_get_cluster, cluster_trigger, async_get_cluster, caplog ): - mock_hook.return_value = async_get_cluster( + mock_credentials = mock.MagicMock() + mock_credentials.universe_domain = "googleapis.com" + + mock_auth.return_value = (mock_credentials, "project-id") + + mock_delete_cluster.return_value = asyncio.Future() + mock_delete_cluster.return_value.set_result(None) + + mock_get_cluster.return_value = async_get_cluster( project_id=TEST_PROJECT_ID, region=TEST_REGION, cluster_name=TEST_CLUSTER_NAME, status=ClusterStatus(state=ClusterStatus.State.ERROR), ) - actual_event = await cluster_trigger.run().asend(None) - await asyncio.sleep(0.5) + caplog.set_level(logging.INFO) - expected_event = TriggerEvent( - { - "cluster_name": TEST_CLUSTER_NAME, - "cluster_state": ClusterStatus.State.ERROR, - "cluster": actual_event.payload["cluster"], - } - ) - assert expected_event == actual_event + trigger_event = None + async for event in cluster_trigger.run(): + trigger_event = event + + assert trigger_event.payload["cluster_name"] == TEST_CLUSTER_NAME + assert trigger_event.payload["cluster_state"] == ClusterStatus.State.DELETING @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster") @@ -215,9 +228,93 @@ async def test_cluster_run_loop_is_still_running( await asyncio.sleep(0.5) assert not task.done() - assert f"Current state is: {ClusterStatus.State.CREATING}" + assert f"Current state is: {ClusterStatus.State.CREATING}." assert f"Sleeping for {TEST_POLL_INTERVAL} seconds." + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook") + @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_sync_hook") + async def test_cluster_trigger_cancellation_handling( + self, mock_get_sync_hook, mock_get_async_hook, caplog + ): + cluster = Cluster(status=ClusterStatus(state=ClusterStatus.State.RUNNING)) + mock_get_async_hook.return_value.get_cluster.return_value = asyncio.Future() + mock_get_async_hook.return_value.get_cluster.return_value.set_result(cluster) + + mock_delete_cluster = mock.MagicMock() + mock_get_sync_hook.return_value.delete_cluster = mock_delete_cluster + + cluster_trigger = DataprocClusterTrigger( + cluster_name="cluster_name", + project_id="project-id", + region="region", + gcp_conn_id="google_cloud_default", + impersonation_chain=None, + polling_interval_seconds=5, + delete_on_error=True, + ) + + cluster_trigger_gen = cluster_trigger.run() + + try: + await cluster_trigger_gen.__anext__() + await cluster_trigger_gen.aclose() + + except asyncio.CancelledError: + # Verify that cancellation was handled as expected + if cluster_trigger.delete_on_error: + mock_get_sync_hook.assert_called_once() + mock_delete_cluster.assert_called_once_with( + region=cluster_trigger.region, + cluster_name=cluster_trigger.cluster_name, + project_id=cluster_trigger.project_id, + ) + assert "Deleting cluster" in caplog.text + assert "Deleted cluster" in caplog.text + else: + mock_delete_cluster.assert_not_called() + except Exception as e: + pytest.fail(f"Unexpected exception raised: {e}") + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster") + async def test_fetch_cluster_status(self, mock_get_cluster, cluster_trigger, async_get_cluster): + mock_get_cluster.return_value = async_get_cluster( + status=ClusterStatus(state=ClusterStatus.State.RUNNING) + ) + cluster = await cluster_trigger.fetch_cluster() + + assert cluster.status.state == ClusterStatus.State.RUNNING, "The cluster state should be RUNNING" + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster") + async def test_delete_when_error_occurred(self, mock_delete_cluster, cluster_trigger): + mock_cluster = mock.MagicMock(spec=Cluster) + type(mock_cluster).status = mock.PropertyMock( + return_value=mock.MagicMock(state=ClusterStatus.State.ERROR) + ) + + mock_delete_future = asyncio.Future() + mock_delete_future.set_result(None) + mock_delete_cluster.return_value = mock_delete_future + + cluster_trigger.delete_on_error = True + + await cluster_trigger.delete_when_error_occurred(mock_cluster) + + mock_delete_cluster.assert_called_once_with( + region=cluster_trigger.region, + cluster_name=cluster_trigger.cluster_name, + project_id=cluster_trigger.project_id, + ) + + mock_delete_cluster.reset_mock() + cluster_trigger.delete_on_error = False + + await cluster_trigger.delete_when_error_occurred(mock_cluster) + + mock_delete_cluster.assert_not_called() + @pytest.mark.db_test class TestDataprocBatchTrigger: