Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sunank200 committed Apr 22, 2024
1 parent 468332a commit c133c83
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 16 deletions.
6 changes: 3 additions & 3 deletions airflow/providers/google/cloud/triggers/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
try:
while True:
cluster = await self.fetch_cluster_status()
if self.is_terminal_state(cluster.status.state):
if self.check_cluster_state(cluster.status.state):
if cluster.status.state == ClusterStatus.State.ERROR:
await self.gather_diagnostics_and_maybe_delete(cluster)
else:
Expand All @@ -174,9 +174,9 @@ async def fetch_cluster_status(self) -> Cluster:
project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
)

def is_terminal_state(self, state: ClusterStatus.State) -> bool:
def check_cluster_state(self, state: ClusterStatus.State) -> bool:
"""
Check if the state is terminal.
Check if the state is error or running.
:param state: The state of the cluster.
"""
Expand Down
84 changes: 71 additions & 13 deletions tests/providers/google/cloud/triggers/test_dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import pytest
from google.cloud.dataproc_v1 import Batch, ClusterStatus
from google.protobuf.any_pb2 import Any
from google.rpc.error_details_pb2 import ErrorInfo
from google.rpc.status_pb2 import Status

from airflow.providers.google.cloud.triggers.dataproc import (
Expand Down Expand Up @@ -70,6 +71,7 @@ def batch_trigger():
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=None,
polling_interval_seconds=TEST_POLL_INTERVAL,
delete_on_error=True,
)
return trigger

Expand All @@ -96,6 +98,7 @@ def diagnose_operation_trigger():
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=None,
polling_interval_seconds=TEST_POLL_INTERVAL,
delete_on_error=True,
)


Expand Down Expand Up @@ -176,27 +179,37 @@ async def test_async_cluster_triggers_on_success_should_execute_successfully(

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
@mock.patch(
"airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster",
return_value=asyncio.Future(),
)
@mock.patch("google.auth.default")
async def test_async_cluster_trigger_run_returns_error_event(
self, mock_hook, cluster_trigger, async_get_cluster
self, mock_auth, mock_delete_cluster, mock_get_cluster, cluster_trigger, async_get_cluster, caplog
):
mock_hook.return_value = async_get_cluster(
mock_credentials = mock.MagicMock()
mock_credentials.universe_domain = "googleapis.com"

mock_auth.return_value = (mock_credentials, "project-id")

mock_delete_cluster.return_value = asyncio.Future()
mock_delete_cluster.return_value.set_result(None)

mock_get_cluster.return_value = async_get_cluster(
project_id=TEST_PROJECT_ID,
region=TEST_REGION,
cluster_name=TEST_CLUSTER_NAME,
status=ClusterStatus(state=ClusterStatus.State.ERROR),
)

actual_event = await cluster_trigger.run().asend(None)
await asyncio.sleep(0.5)
caplog.set_level(logging.INFO)

expected_event = TriggerEvent(
{
"cluster_name": TEST_CLUSTER_NAME,
"cluster_state": ClusterStatus.State.ERROR,
"cluster": actual_event.payload["cluster"],
}
)
assert expected_event == actual_event
trigger_event = None
async for event in cluster_trigger.run():
trigger_event = event

assert trigger_event is None, "Expected an event to be emitted"
assert "Cluster is in ERROR state. Gathering diagnostic information." in caplog.text

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
Expand All @@ -216,9 +229,54 @@ async def test_cluster_run_loop_is_still_running(
await asyncio.sleep(0.5)

assert not task.done()
assert f"Current state is: {ClusterStatus.State.CREATING}"
assert f"Current state is: {ClusterStatus.State.CREATING}."
assert f"Sleeping for {TEST_POLL_INTERVAL} seconds."

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
async def test_fetch_cluster_status(self, mock_get_cluster, cluster_trigger, async_get_cluster):
    """Verify fetch_cluster_status returns the cluster object reported by the hook."""
    # The hook is patched to report a cluster that is already RUNNING.
    running_status = ClusterStatus(state=ClusterStatus.State.RUNNING)
    mock_get_cluster.return_value = async_get_cluster(status=running_status)

    result = await cluster_trigger.fetch_cluster_status()

    assert result.status.state == ClusterStatus.State.RUNNING, "The cluster state should be RUNNING"

def test_check_cluster_state(self, cluster_trigger):
    """Test if specific states are correctly identified.

    Fixes a typo in the test name (``test_check_luster_state`` ->
    ``test_check_cluster_state``); pytest discovers tests by name, so no
    other code references it.
    """
    # check_cluster_state flags both RUNNING and ERROR (see its docstring:
    # "Check if the state is error or running").
    assert cluster_trigger.check_cluster_state(
        ClusterStatus.State.RUNNING
    ), "RUNNING should be correct state"
    assert cluster_trigger.check_cluster_state(ClusterStatus.State.ERROR), "ERROR should be correct state"

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.diagnose_cluster")
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster")
async def test_gather_diagnostics_and_maybe_delete(
    self, mock_delete_cluster, mock_diagnose_cluster, cluster_trigger, async_get_cluster
):
    """A cluster in ERROR state is diagnosed and then deleted (delete_on_error=True)."""

    def _resolved(value):
        # Build an already-completed future, as the hook coroutines are awaited.
        fut = asyncio.Future()
        fut.set_result(value)
        return fut

    # Pack an ErrorInfo into a protobuf Status to mimic a diagnose operation result.
    packed = Any()
    packed.Pack(ErrorInfo(reason="DIAGNOSTICS"))
    diagnose_result = Status()
    diagnose_result.details.add().CopyFrom(packed)

    mock_diagnose_cluster.return_value = _resolved(diagnose_result)
    mock_delete_cluster.return_value = _resolved(None)

    errored_cluster = await async_get_cluster(status=ClusterStatus(state=ClusterStatus.State.ERROR))
    event = await cluster_trigger.gather_diagnostics_and_maybe_delete(errored_cluster)

    mock_delete_cluster.assert_called_once()
    assert (
        "deleted" in event.payload["action"]
    ), "The cluster should be deleted due to error state and delete_on_error=True"


@pytest.mark.db_test
class TestDataprocBatchTrigger:
Expand Down

0 comments on commit c133c83

Please sign in to comment.