diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py
index 8708c9e24512fb..926b1f6d6ad852 100644
--- a/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/airflow/providers/google/cloud/triggers/dataproc.py
@@ -151,7 +151,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
         try:
             while True:
                 cluster = await self.fetch_cluster_status()
-                if self.is_terminal_state(cluster.status.state):
+                if self.check_cluster_state(cluster.status.state):
                     if cluster.status.state == ClusterStatus.State.ERROR:
                         await self.gather_diagnostics_and_maybe_delete(cluster)
                     else:
@@ -174,9 +174,9 @@ async def fetch_cluster_status(self) -> Cluster:
             project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
         )
 
-    def is_terminal_state(self, state: ClusterStatus.State) -> bool:
+    def check_cluster_state(self, state: ClusterStatus.State) -> bool:
         """
-        Check if the state is terminal.
+        Check if the state is error or running.
 
         :param state: The state of the cluster.
         """
diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py
index 404d05eda37154..0b19e6fe366e47 100644
--- a/tests/providers/google/cloud/triggers/test_dataproc.py
+++ b/tests/providers/google/cloud/triggers/test_dataproc.py
@@ -24,6 +24,7 @@
 import pytest
 from google.cloud.dataproc_v1 import Batch, ClusterStatus
 from google.protobuf.any_pb2 import Any
+from google.rpc.error_details_pb2 import ErrorInfo
 from google.rpc.status_pb2 import Status
 
 from airflow.providers.google.cloud.triggers.dataproc import (
@@ -70,6 +71,7 @@ def batch_trigger():
         gcp_conn_id=TEST_GCP_CONN_ID,
         impersonation_chain=None,
         polling_interval_seconds=TEST_POLL_INTERVAL,
+        delete_on_error=True,
     )
     return trigger
 
@@ -174,30 +176,6 @@ async def test_async_cluster_triggers_on_success_should_execute_successfully(
         )
         assert expected_event == actual_event
 
-    @pytest.mark.asyncio
-    @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
-    async def test_async_cluster_trigger_run_returns_error_event(
-        self, mock_hook, cluster_trigger, async_get_cluster
-    ):
-        mock_hook.return_value = async_get_cluster(
-            project_id=TEST_PROJECT_ID,
-            region=TEST_REGION,
-            cluster_name=TEST_CLUSTER_NAME,
-            status=ClusterStatus(state=ClusterStatus.State.ERROR),
-        )
-
-        actual_event = await cluster_trigger.run().asend(None)
-        await asyncio.sleep(0.5)
-
-        expected_event = TriggerEvent(
-            {
-                "cluster_name": TEST_CLUSTER_NAME,
-                "cluster_state": ClusterStatus.State.ERROR,
-                "cluster": actual_event.payload["cluster"],
-            }
-        )
-        assert expected_event == actual_event
-
     @pytest.mark.asyncio
     @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
     async def test_cluster_run_loop_is_still_running(
@@ -216,9 +194,54 @@ async def test_cluster_run_loop_is_still_running(
         await asyncio.sleep(0.5)
 
         assert not task.done()
-        assert f"Current state is: {ClusterStatus.State.CREATING}"
+        assert f"Current state is: {ClusterStatus.State.CREATING}."
         assert f"Sleeping for {TEST_POLL_INTERVAL} seconds."
 
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
+    async def test_fetch_cluster_status(self, mock_get_cluster, cluster_trigger, async_get_cluster):
+        mock_get_cluster.return_value = async_get_cluster(
+            status=ClusterStatus(state=ClusterStatus.State.RUNNING)
+        )
+        cluster = await cluster_trigger.fetch_cluster_status()
+
+        assert cluster.status.state == ClusterStatus.State.RUNNING, "The cluster state should be RUNNING"
+
+    def test_check_cluster_state(self, cluster_trigger):
+        """Test if specific states are correctly identified."""
+        assert cluster_trigger.check_cluster_state(
+            ClusterStatus.State.RUNNING
+        ), "RUNNING should be correct state"
+        assert cluster_trigger.check_cluster_state(ClusterStatus.State.ERROR), "ERROR should be correct state"
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.diagnose_cluster")
+    @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster")
+    async def test_gather_diagnostics_and_maybe_delete(
+        self, mock_delete_cluster, mock_diagnose_cluster, cluster_trigger, async_get_cluster
+    ):
+        error_info = ErrorInfo(reason="DIAGNOSTICS")
+        any_message = Any()
+        any_message.Pack(error_info)
+
+        diagnose_future = asyncio.Future()
+        status = Status()
+        status.details.add().CopyFrom(any_message)
+        diagnose_future.set_result(status)
+        mock_diagnose_cluster.return_value = diagnose_future
+
+        delete_future = asyncio.Future()
+        delete_future.set_result(None)
+        mock_delete_cluster.return_value = delete_future
+
+        cluster = await async_get_cluster(status=ClusterStatus(state=ClusterStatus.State.ERROR))
+        event = await cluster_trigger.gather_diagnostics_and_maybe_delete(cluster)
+
+        mock_delete_cluster.assert_called_once()
+        assert (
+            "deleted" in event.payload["action"]
+        ), "The cluster should be deleted due to error state and delete_on_error=True"
+
 
 @pytest.mark.db_test
 class TestDataprocBatchTrigger: