Skip to content

Commit

Permalink
Fix PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sunank200 committed Apr 23, 2024
1 parent 402c796 commit ee85f59
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 49 deletions.
105 changes: 65 additions & 40 deletions airflow/providers/google/cloud/triggers/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from google.api_core.exceptions import NotFound
from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus

from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook
from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook
from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
from airflow.triggers.base import BaseTrigger, TriggerEvent
Expand Down Expand Up @@ -59,6 +59,12 @@ def get_async_hook(self):
impersonation_chain=self.impersonation_chain,
)

def get_sync_hook(self):
    """Build a synchronous ``DataprocHook`` from this trigger's connection settings."""
    hook_kwargs = {
        "gcp_conn_id": self.gcp_conn_id,
        "impersonation_chain": self.impersonation_chain,
    }
    return DataprocHook(**hook_kwargs)


class DataprocSubmitTrigger(DataprocBaseTrigger):
"""
Expand Down Expand Up @@ -150,39 +156,74 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
"""Run the trigger."""
try:
while True:
cluster = await self.fetch_cluster_status()
if self.check_cluster_state(cluster.status.state):
if cluster.status.state == ClusterStatus.State.ERROR:
await self.gather_diagnostics_and_maybe_delete(cluster)
else:
yield TriggerEvent(
{
"cluster_name": self.cluster_name,
"cluster_state": cluster.status.state,
"cluster": cluster,
}
)
cluster = await self.fetch_cluster()
state = cluster.status.state
if state == ClusterStatus.State.ERROR:
await self.gather_diagnostics_and_delete_on_error(cluster)
break
elif state == ClusterStatus.State.RUNNING:
yield TriggerEvent(
{
"cluster_name": self.cluster_name,
"cluster_state": state,
"cluster": cluster,
}
)
break

self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
await asyncio.sleep(self.polling_interval_seconds)
except asyncio.CancelledError:
await self.handle_cancellation()
try:
if self.delete_on_error:
self.log.info("Deleting cluster %s.", self.cluster_name)
self.get_sync_hook().delete_cluster(
region=self.region, cluster_name=self.cluster_name, project_id=self.project_id
)
self.log.info("Deleted cluster %s during cancellation.", self.cluster_name)
self.log.info("Cluster deletion initiated, awaiting completion...")
async for event in self.wait_until_cluster_deleted():
if event["status"] == "success":
self.log.info("Cluster deletion confirmed.")
elif event["status"] == "error":
self.log.error("Cluster deletion failed with message: %s", event["message"])
self.log.info("Finished handling cluster deletion.")
except Exception as e:
self.log.error("Error during cancellation handling: %s", e)

async def wait_until_cluster_deleted(self):
    """
    Poll the cluster until the API reports that it no longer exists.

    Exactly one ``TriggerEvent`` is yielded, wrapping a dict with ``status`` and ``message`` keys:

    * ``status == "success"`` once ``get_cluster`` raises ``NotFound``;
    * ``status == "error"`` with the exception text if polling fails for any other reason;
    * ``status == "error"`` with a timeout message if the cluster still exists after
      ten polling intervals.

    NOTE(review): consumers should read the dict from the event object rather than
    subscripting the event itself - confirm against ``TriggerEvent``.
    """
    max_polling_intervals = 10  # upper bound on the wait, expressed in polling intervals
    # A monotonic clock keeps the deadline unaffected by wall-clock adjustments.
    deadline = time.monotonic() + self.polling_interval_seconds * max_polling_intervals
    try:
        hook = self.get_async_hook()
        while time.monotonic() < deadline:
            try:
                await hook.get_cluster(
                    region=self.region,
                    cluster_name=self.cluster_name,
                    project_id=self.project_id,
                )
            except NotFound:
                # The lookup failing with NotFound is the confirmation that deletion finished.
                self.log.info("Cluster successfully deleted.")
                yield TriggerEvent({"status": "success", "message": "Cluster deleted successfully."})
                return
            self.log.info("Cluster still exists. Sleeping for %s seconds.", self.polling_interval_seconds)
            await asyncio.sleep(self.polling_interval_seconds)
    except Exception as e:
        self.log.error("Error while checking for cluster deletion: %s", e)
        yield TriggerEvent({"status": "error", "message": str(e)})
        # Stop here: without this return a second, misleading "Timeout" event would follow.
        return
    yield TriggerEvent(
        {"status": "error", "message": "Timeout - cluster deletion not confirmed within expected time."}
    )

async def fetch_cluster_status(self) -> Cluster:
async def fetch_cluster(self) -> Cluster:
    """Retrieve the cluster identified by this trigger's project, region and name via the async hook."""
    hook = self.get_async_hook()
    cluster = await hook.get_cluster(
        project_id=self.project_id,
        region=self.region,
        cluster_name=self.cluster_name,
    )
    return cluster

def check_cluster_state(self, state: ClusterStatus.State) -> bool:
    """
    Tell whether the given state is one of ERROR or RUNNING.

    :param state: The state of the cluster.
    :return: ``True`` when ``state`` is ``ERROR`` or ``RUNNING``, ``False`` otherwise.
    """
    watched_states = (ClusterStatus.State.ERROR, ClusterStatus.State.RUNNING)
    return state in watched_states

async def gather_diagnostics_and_maybe_delete(self, cluster: Cluster):
async def gather_diagnostics_and_delete_on_error(self, cluster: Cluster):
"""
Gather diagnostics and maybe delete the cluster.
Expand Down Expand Up @@ -218,22 +259,6 @@ async def gather_diagnostics_and_maybe_delete(self, cluster: Cluster):
{"cluster_name": self.cluster_name, "cluster_state": cluster.status.state, "cluster": cluster}
)

async def handle_cancellation(self) -> None:
    """
    React to the trigger being cancelled.

    When ``delete_on_error`` is set, the cluster is fetched and, only if it is in the
    ERROR state, deleted through the async hook. Any exception raised along the way is
    logged and not propagated.
    """
    self.log.info("Cancellation requested. Deleting the cluster if created.")
    try:
        if not self.delete_on_error:
            return
        current = await self.fetch_cluster_status()
        in_error = current.status.state == ClusterStatus.State.ERROR
        if not in_error:
            self.log.info("Cancellation did not require cluster deletion.")
            return
        hook = self.get_async_hook()
        await hook.async_delete_cluster(
            region=self.region, cluster_name=self.cluster_name, project_id=self.project_id
        )
        self.log.info("Deleted cluster due to ERROR state during cancellation.")
    except Exception as e:
        self.log.error("Error during cancellation handling: %s", e)


class DataprocBatchTrigger(DataprocBaseTrigger):
"""
Expand Down
11 changes: 2 additions & 9 deletions tests/providers/google/cloud/triggers/test_dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,17 +238,10 @@ async def test_fetch_cluster_status(self, mock_get_cluster, cluster_trigger, asy
mock_get_cluster.return_value = async_get_cluster(
status=ClusterStatus(state=ClusterStatus.State.RUNNING)
)
cluster = await cluster_trigger.fetch_cluster_status()
cluster = await cluster_trigger.fetch_cluster()

assert cluster.status.state == ClusterStatus.State.RUNNING, "The cluster state should be RUNNING"

def test_check_luster_state(self, cluster_trigger):
"""Test if specific states are correctly identified."""
assert cluster_trigger.check_cluster_state(
ClusterStatus.State.RUNNING
), "RUNNING should be correct state"
assert cluster_trigger.check_cluster_state(ClusterStatus.State.ERROR), "ERROR should be correct state"

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.diagnose_cluster")
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster")
Expand All @@ -270,7 +263,7 @@ async def test_gather_diagnostics_and_maybe_delete(
mock_delete_cluster.return_value = delete_future

cluster = await async_get_cluster(status=ClusterStatus(state=ClusterStatus.State.ERROR))
event = await cluster_trigger.gather_diagnostics_and_maybe_delete(cluster)
event = await cluster_trigger.gather_diagnostics_and_delete_on_error(cluster)

mock_delete_cluster.assert_called_once()
assert (
Expand Down

0 comments on commit ee85f59

Please sign in to comment.