Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve DataprocCreateClusterOperator in Triggers for Enhanced Error Handling and Resource Cleanup #39130

Merged
merged 9 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions airflow/providers/google/cloud/operators/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,7 @@ def execute(self, context: Context) -> dict:
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
polling_interval_seconds=self.polling_interval_seconds,
delete_on_error=self.delete_on_error,
),
method_name="execute_complete",
)
Expand Down
92 changes: 77 additions & 15 deletions airflow/providers/google/cloud/triggers/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@
from typing import Any, AsyncIterator, Sequence

from google.api_core.exceptions import NotFound
from google.cloud.dataproc_v1 import Batch, ClusterStatus, JobStatus
from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus

from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook
from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
from airflow.triggers.base import BaseTrigger, TriggerEvent
Expand All @@ -43,20 +44,32 @@ def __init__(
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
polling_interval_seconds: int = 30,
delete_on_error: bool = True,
):
super().__init__()
self.region = region
self.project_id = project_id
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.polling_interval_seconds = polling_interval_seconds
self.delete_on_error = delete_on_error

def get_async_hook(self):
    """Build a ``DataprocAsyncHook`` from this trigger's connection settings."""
    async_hook = DataprocAsyncHook(
        impersonation_chain=self.impersonation_chain,
        gcp_conn_id=self.gcp_conn_id,
    )
    return async_hook

def get_sync_hook(self):
    """
    Build a synchronous ``DataprocHook`` from this trigger's connection settings.

    The synchronous hook is used to delete the cluster when the trigger task is
    cancelled, because a deletion issued through the asynchronous hook is not
    awaited once the trigger task has been cancelled. Per the original author's
    note, the delete call made through the sync hook is not a blocking call: it
    does not wait until the cluster is actually deleted.
    """
    sync_hook = DataprocHook(
        impersonation_chain=self.impersonation_chain,
        gcp_conn_id=self.gcp_conn_id,
    )
    return sync_hook


class DataprocSubmitTrigger(DataprocBaseTrigger):
"""
Expand Down Expand Up @@ -140,24 +153,73 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"gcp_conn_id": self.gcp_conn_id,
"impersonation_chain": self.impersonation_chain,
"polling_interval_seconds": self.polling_interval_seconds,
"delete_on_error": self.delete_on_error,
},
)

async def run(self) -> AsyncIterator[TriggerEvent]:
while True:
cluster = await self.get_async_hook().get_cluster(
project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
try:
while True:
cluster = await self.fetch_cluster()
state = cluster.status.state
if state == ClusterStatus.State.ERROR:
await self.delete_when_error_occurred(cluster)
yield TriggerEvent(
{
"cluster_name": self.cluster_name,
"cluster_state": ClusterStatus.State.DELETING,
"cluster": cluster,
}
)
return
elif state == ClusterStatus.State.RUNNING:
yield TriggerEvent(
{
"cluster_name": self.cluster_name,
"cluster_state": state,
"cluster": cluster,
}
)
return
self.log.info("Current state is %s", state)
self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
sunank200 marked this conversation as resolved.
Show resolved Hide resolved
await asyncio.sleep(self.polling_interval_seconds)
except asyncio.CancelledError:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to understand exactly when the CancelledError is raised.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is raised when the user marks the task as failed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like there are other cases in which CancelledError is raised too, e.g. when the trigger restarts, as mentioned in #36090 (comment).

We might need some additional measures here.

try:
    # Cancellation clean-up: runs inside the ``except asyncio.CancelledError`` handler of ``run``.
    if self.delete_on_error:
        self.log.info("Deleting cluster %s.", self.cluster_name)
        # The synchronous hook is utilized to delete the cluster when a task is cancelled.
        # This is because the asynchronous hook deletion is not awaited when the trigger task
        # is cancelled. The call for deleting the cluster through the sync hook is not a blocking
        # call, which means it does not wait until the cluster is deleted.
        self.get_sync_hook().delete_cluster(
            region=self.region, cluster_name=self.cluster_name, project_id=self.project_id
        )
        self.log.info("Deleted cluster %s during cancellation.", self.cluster_name)
except Exception as e:
    self.log.error("Error during cancellation handling: %s", e)
    # Unlike logging calls, exception constructors do not interpolate "%s" placeholders:
    # passing (template, e) would surface a raw two-element args tuple instead of a message.
    # Format the message explicitly and chain the original error so its traceback is kept.
    raise AirflowException(f"Error during cancellation handling: {e}") from e

async def fetch_cluster(self) -> Cluster:
    """Retrieve this trigger's cluster through the asynchronous hook and return it."""
    async_hook = self.get_async_hook()
    cluster = await async_hook.get_cluster(
        cluster_name=self.cluster_name,
        region=self.region,
        project_id=self.project_id,
    )
    return cluster
sunank200 marked this conversation as resolved.
Show resolved Hide resolved

async def delete_when_error_occurred(self, cluster: Cluster) -> None:
"""
Delete the cluster on error.

:param cluster: The cluster to delete.
"""
if self.delete_on_error:
self.log.info("Deleting cluster %s.", self.cluster_name)
await self.get_async_hook().delete_cluster(
region=self.region, cluster_name=self.cluster_name, project_id=self.project_id
)
state = cluster.status.state
self.log.info("Dataproc cluster: %s is in state: %s", self.cluster_name, state)
if state in (
ClusterStatus.State.ERROR,
ClusterStatus.State.RUNNING,
):
break
self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
await asyncio.sleep(self.polling_interval_seconds)
yield TriggerEvent({"cluster_name": self.cluster_name, "cluster_state": state, "cluster": cluster})
self.log.info("Cluster %s has been deleted.", self.cluster_name)
else:
self.log.info("Cluster %s is not deleted as delete_on_error is set to False.", self.cluster_name)


class DataprocBatchTrigger(DataprocBaseTrigger):
Expand Down
125 changes: 111 additions & 14 deletions tests/providers/google/cloud/triggers/test_dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from unittest import mock

import pytest
from google.cloud.dataproc_v1 import Batch, ClusterStatus
from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus
from google.protobuf.any_pb2 import Any
from google.rpc.status_pb2 import Status

Expand Down Expand Up @@ -70,6 +70,7 @@ def batch_trigger():
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=None,
polling_interval_seconds=TEST_POLL_INTERVAL,
delete_on_error=True,
)
return trigger

Expand All @@ -96,6 +97,7 @@ def diagnose_operation_trigger():
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=None,
polling_interval_seconds=TEST_POLL_INTERVAL,
delete_on_error=True,
)


Expand Down Expand Up @@ -147,6 +149,7 @@ def test_async_cluster_trigger_serialization_should_execute_successfully(self, c
"gcp_conn_id": TEST_GCP_CONN_ID,
"impersonation_chain": None,
"polling_interval_seconds": TEST_POLL_INTERVAL,
"delete_on_error": True,
}

@pytest.mark.asyncio
Expand Down Expand Up @@ -175,27 +178,37 @@ async def test_async_cluster_triggers_on_success_should_execute_successfully(

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
@mock.patch(
"airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster",
return_value=asyncio.Future(),
)
@mock.patch("google.auth.default")
async def test_async_cluster_trigger_run_returns_error_event(
self, mock_hook, cluster_trigger, async_get_cluster
self, mock_auth, mock_delete_cluster, mock_get_cluster, cluster_trigger, async_get_cluster, caplog
):
mock_hook.return_value = async_get_cluster(
mock_credentials = mock.MagicMock()
mock_credentials.universe_domain = "googleapis.com"

mock_auth.return_value = (mock_credentials, "project-id")

mock_delete_cluster.return_value = asyncio.Future()
mock_delete_cluster.return_value.set_result(None)

mock_get_cluster.return_value = async_get_cluster(
project_id=TEST_PROJECT_ID,
region=TEST_REGION,
cluster_name=TEST_CLUSTER_NAME,
status=ClusterStatus(state=ClusterStatus.State.ERROR),
)

actual_event = await cluster_trigger.run().asend(None)
await asyncio.sleep(0.5)
caplog.set_level(logging.INFO)

expected_event = TriggerEvent(
{
"cluster_name": TEST_CLUSTER_NAME,
"cluster_state": ClusterStatus.State.ERROR,
"cluster": actual_event.payload["cluster"],
}
)
assert expected_event == actual_event
trigger_event = None
async for event in cluster_trigger.run():
trigger_event = event

assert trigger_event.payload["cluster_name"] == TEST_CLUSTER_NAME
assert trigger_event.payload["cluster_state"] == ClusterStatus.State.DELETING

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
Expand All @@ -215,9 +228,93 @@ async def test_cluster_run_loop_is_still_running(
await asyncio.sleep(0.5)

assert not task.done()
assert f"Current state is: {ClusterStatus.State.CREATING}"
assert f"Current state is: {ClusterStatus.State.CREATING}."
assert f"Sleeping for {TEST_POLL_INTERVAL} seconds."

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_sync_hook")
async def test_cluster_trigger_cancellation_handling(
    self, mock_get_sync_hook, mock_get_async_hook, caplog
):
    """Cancelling a trigger that is still polling must delete the cluster via the sync hook."""
    caplog.set_level(logging.INFO)

    # Keep the cluster in a non-terminal state so that run() parks in its polling sleep;
    # with a RUNNING cluster the trigger would finish before it could ever be cancelled.
    cluster = Cluster(status=ClusterStatus(state=ClusterStatus.State.CREATING))
    mock_get_async_hook.return_value.get_cluster.return_value = asyncio.Future()
    mock_get_async_hook.return_value.get_cluster.return_value.set_result(cluster)

    mock_delete_cluster = mock.MagicMock()
    mock_get_sync_hook.return_value.delete_cluster = mock_delete_cluster

    cluster_trigger = DataprocClusterTrigger(
        cluster_name="cluster_name",
        project_id="project-id",
        region="region",
        gcp_conn_id="google_cloud_default",
        impersonation_chain=None,
        polling_interval_seconds=5,
        delete_on_error=True,
    )

    task = asyncio.create_task(cluster_trigger.run().__anext__())
    # Give the trigger time to fetch the cluster once and start sleeping.
    await asyncio.sleep(0.1)
    assert not task.done()

    task.cancel()
    try:
        await task
    except (asyncio.CancelledError, StopAsyncIteration):
        # The generator either propagates the cancellation or simply finishes
        # without yielding once its clean-up has run; both outcomes are acceptable.
        pass

    # Assertions are unconditional so the test fails if the clean-up did not happen.
    mock_get_sync_hook.assert_called_once()
    mock_delete_cluster.assert_called_once_with(
        region=cluster_trigger.region,
        cluster_name=cluster_trigger.cluster_name,
        project_id=cluster_trigger.project_id,
    )
    assert "Deleting cluster" in caplog.text

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
async def test_fetch_cluster_status(self, mock_get_cluster, cluster_trigger, async_get_cluster):
    """fetch_cluster should hand back the cluster produced by the async hook."""
    running_status = ClusterStatus(state=ClusterStatus.State.RUNNING)
    mock_get_cluster.return_value = async_get_cluster(status=running_status)

    fetched = await cluster_trigger.fetch_cluster()

    assert fetched.status.state == ClusterStatus.State.RUNNING, "The cluster state should be RUNNING"
pankajkoti marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster")
async def test_delete_when_error_occurred(self, mock_delete_cluster, cluster_trigger):
    """The async delete is requested only while ``delete_on_error`` is enabled."""
    errored_status = mock.MagicMock(state=ClusterStatus.State.ERROR)
    errored_cluster = mock.MagicMock(spec=Cluster)
    type(errored_cluster).status = mock.PropertyMock(return_value=errored_status)

    # An already-completed future stands in for the awaited delete call.
    finished = asyncio.Future()
    finished.set_result(None)
    mock_delete_cluster.return_value = finished

    # Case 1: deletion enabled -> exactly one delete request with the trigger's identifiers.
    cluster_trigger.delete_on_error = True
    await cluster_trigger.delete_when_error_occurred(errored_cluster)
    mock_delete_cluster.assert_called_once_with(
        project_id=cluster_trigger.project_id,
        cluster_name=cluster_trigger.cluster_name,
        region=cluster_trigger.region,
    )

    # Case 2: deletion disabled -> no delete request at all.
    mock_delete_cluster.reset_mock()
    cluster_trigger.delete_on_error = False
    await cluster_trigger.delete_when_error_occurred(errored_cluster)
    mock_delete_cluster.assert_not_called()

sunank200 marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.db_test
class TestDataprocBatchTrigger:
Expand Down