@@ -22,10 +22,12 @@

from aiohttp import ClientSession
from aiohttp.client_exceptions import ClientResponseError
from asgiref.sync import sync_to_async

from airflow.exceptions import AirflowException
from airflow.models.taskinstance import TaskInstance
from airflow.providers.google.cloud.hooks.bigquery import BigQueryAsyncHook, BigQueryTableAsyncHook
from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.session import provide_session
from airflow.utils.state import TaskInstanceState
@@ -116,16 +118,41 @@ def get_task_instance(self, session: Session) -> TaskInstance:
)
return task_instance

def safe_to_cancel(self) -> bool:
async def get_task_state(self):
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance

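# get_task_states is a synchronous Task SDK call, so it is wrapped in sync_to_async to avoid
# blocking the triggerer event loop; the response is keyed by run_id and then task_id.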
task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
dag_id=self.task_instance.dag_id,
task_ids=[self.task_instance.task_id],
run_ids=[self.task_instance.run_id],
map_index=self.task_instance.map_index,
)
try:
task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
except Exception:
raise AirflowException(
f"TaskInstance with dag_id: {self.task_instance.dag_id}, task_id: {self.task_instance.task_id}, "
f"run_id: {self.task_instance.run_id} and map_index: {self.task_instance.map_index} was not found"
)
return task_state

async def safe_to_cancel(self) -> bool:
"""
Whether it is safe to cancel the external job which is being executed by this trigger.

This is to avoid the case where `asyncio.CancelledError` is raised because the trigger itself is being
stopped; in that case, we should NOT cancel the external job.
"""
# Database query is needed to get the latest state of the task instance.
task_instance = self.get_task_instance() # type: ignore[call-arg]
return task_instance.state != TaskInstanceState.DEFERRED
if AIRFLOW_V_3_0_PLUS:
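# On Airflow 3+, fetch the latest task state through the Task SDK instead of a direct
# metadata-DB query (the pre-3.0 branch below still queries the DB).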
task_state = await self.get_task_state()
else:
# Database query is needed to get the latest state of the task instance.
task_instance = self.get_task_instance() # type: ignore[call-arg]
task_state = task_instance.state
return task_state != TaskInstanceState.DEFERRED

async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"""Get current job execution status and yields a TriggerEvent."""
@@ -155,7 +182,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
)
await asyncio.sleep(self.poll_interval)
except asyncio.CancelledError:
if self.job_id and self.cancel_on_kill and self.safe_to_cancel():
if self.job_id and self.cancel_on_kill and await self.safe_to_cancel():
self.log.info(
"The job is safe to cancel as the airflow TaskInstance is not in deferred state."
)
@@ -25,6 +25,7 @@
from collections.abc import AsyncIterator, Sequence
from typing import TYPE_CHECKING, Any

from asgiref.sync import sync_to_async
from google.api_core.exceptions import NotFound
from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus

@@ -33,6 +34,7 @@
from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook
from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.session import provide_session
from airflow.utils.state import TaskInstanceState
@@ -141,16 +143,41 @@ def get_task_instance(self, session: Session) -> TaskInstance:
)
return task_instance

def safe_to_cancel(self) -> bool:
async def get_task_state(self):
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance

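# get_task_states is synchronous, so wrap it in sync_to_async to keep the triggerer event loop free.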
task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
dag_id=self.task_instance.dag_id,
task_ids=[self.task_instance.task_id],
run_ids=[self.task_instance.run_id],
map_index=self.task_instance.map_index,
)
try:
task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
except Exception:
raise AirflowException(
f"TaskInstance with dag_id: {self.task_instance.dag_id}, task_id: {self.task_instance.task_id}, "
f"run_id: {self.task_instance.run_id} and map_index: {self.task_instance.map_index} was not found"
)
return task_state

async def safe_to_cancel(self) -> bool:
"""
Whether it is safe to cancel the external job which is being executed by this trigger.

This is to avoid the case where `asyncio.CancelledError` is raised because the trigger itself is being
stopped; in that case, we should NOT cancel the external job.
"""
# Database query is needed to get the latest state of the task instance.
task_instance = self.get_task_instance() # type: ignore[call-arg]
return task_instance.state != TaskInstanceState.DEFERRED
if AIRFLOW_V_3_0_PLUS:
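# Airflow 3+: read the task state via the Task SDK; older versions query the metadata DB below.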
task_state = await self.get_task_state()
else:
# Database query is needed to get the latest state of the task instance.
task_instance = self.get_task_instance() # type: ignore[call-arg]
task_state = task_instance.state
return task_state != TaskInstanceState.DEFERRED

async def run(self):
try:
@@ -167,7 +194,7 @@ async def run(self):
except asyncio.CancelledError:
self.log.info("Task got cancelled.")
try:
if self.job_id and self.cancel_on_kill and self.safe_to_cancel():
if self.job_id and self.cancel_on_kill and await self.safe_to_cancel():
self.log.info(
"Cancelling the job as it is safe to do so. Note that the airflow TaskInstance is not"
" in deferred state."
@@ -243,16 +270,41 @@ def get_task_instance(self, session: Session) -> TaskInstance:
)
return task_instance

def safe_to_cancel(self) -> bool:
async def get_task_state(self):
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
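# get_task_states is synchronous, so wrap it in sync_to_async to keep the triggerer event loop free.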

task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
dag_id=self.task_instance.dag_id,
task_ids=[self.task_instance.task_id],
run_ids=[self.task_instance.run_id],
map_index=self.task_instance.map_index,
)
try:
task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
except Exception:
raise AirflowException(
f"TaskInstance with dag_id: {self.task_instance.dag_id}, task_id: {self.task_instance.task_id}, "
f"run_id: {self.task_instance.run_id} and map_index: {self.task_instance.map_index} was not found"
)
return task_state

async def safe_to_cancel(self) -> bool:
"""
Whether it is safe to cancel the external job which is being executed by this trigger.

This is to avoid the case where `asyncio.CancelledError` is raised because the trigger itself is being
stopped; in that case, we should NOT cancel the external job.
"""
# Database query is needed to get the latest state of the task instance.
task_instance = self.get_task_instance() # type: ignore[call-arg]
return task_instance.state != TaskInstanceState.DEFERRED
if AIRFLOW_V_3_0_PLUS:
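# Airflow 3+: read the task state via the Task SDK; older versions query the metadata DB below.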
task_state = await self.get_task_state()
else:
# Database query is needed to get the latest state of the task instance.
task_instance = self.get_task_instance() # type: ignore[call-arg]
task_state = task_instance.state
return task_state != TaskInstanceState.DEFERRED

async def run(self) -> AsyncIterator[TriggerEvent]:
try:
@@ -283,7 +335,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
await asyncio.sleep(self.polling_interval_seconds)
except asyncio.CancelledError:
try:
if self.delete_on_error and self.safe_to_cancel():
if self.delete_on_error and await self.safe_to_cancel():
self.log.info(
"Deleting the cluster as it is safe to do so since the airflow TaskInstance is not in "
"deferred state."