diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py b/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py
index 13d9f55bacb45..2cac99543a637 100644
--- a/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py
+++ b/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py
@@ -22,10 +22,12 @@
 
 from aiohttp import ClientSession
 from aiohttp.client_exceptions import ClientResponseError
+from asgiref.sync import sync_to_async
 
 from airflow.exceptions import AirflowException
 from airflow.models.taskinstance import TaskInstance
 from airflow.providers.google.cloud.hooks.bigquery import BigQueryAsyncHook, BigQueryTableAsyncHook
+from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 from airflow.utils.session import provide_session
 from airflow.utils.state import TaskInstanceState
@@ -116,16 +118,41 @@ def get_task_instance(self, session: Session) -> TaskInstance:
             )
         return task_instance
 
-    def safe_to_cancel(self) -> bool:
+    async def get_task_state(self):
+        """Fetch this task instance's state via ``RuntimeTaskInstance.get_task_states``."""
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
+            dag_id=self.task_instance.dag_id,
+            task_ids=[self.task_instance.task_id],
+            run_ids=[self.task_instance.run_id],
+            map_index=self.task_instance.map_index,
+        )
+        try:
+            task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+        except Exception as err:
+            raise AirflowException(
+                f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+                f"task_id: {self.task_instance.task_id}, "
+                f"run_id: {self.task_instance.run_id} "
+                f"and map_index: {self.task_instance.map_index} is not found"
+            ) from err
+        return task_state
+
+    async def safe_to_cancel(self) -> bool:
         """
         Whether it is safe to cancel the external job which is being executed by this trigger.
 
         This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
         Because in those cases, we should NOT cancel the external job.
         """
-        # Database query is needed to get the latest state of the task instance.
-        task_instance = self.get_task_instance()  # type: ignore[call-arg]
-        return task_instance.state != TaskInstanceState.DEFERRED
+        if AIRFLOW_V_3_0_PLUS:
+            task_state = await self.get_task_state()
+        else:
+            # Database query is needed to get the latest state of the task instance.
+            task_instance = self.get_task_instance()  # type: ignore[call-arg]
+            task_state = task_instance.state
+        return task_state != TaskInstanceState.DEFERRED
 
     async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
         """Get current job execution status and yields a TriggerEvent."""
@@ -155,7 +182,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
                     )
                     await asyncio.sleep(self.poll_interval)
         except asyncio.CancelledError:
-            if self.job_id and self.cancel_on_kill and self.safe_to_cancel():
+            if self.job_id and self.cancel_on_kill and await self.safe_to_cancel():
                 self.log.info(
                     "The job is safe to cancel the as airflow TaskInstance is not in deferred state."
                 )
diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
index 80afd381cc389..2f44d37043271 100644
--- a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
@@ -25,6 +25,7 @@
 from collections.abc import AsyncIterator, Sequence
 from typing import TYPE_CHECKING, Any
 
+from asgiref.sync import sync_to_async
 from google.api_core.exceptions import NotFound
 from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus
 
@@ -33,6 +34,7 @@
 from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook
 from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
+from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 from airflow.utils.session import provide_session
 from airflow.utils.state import TaskInstanceState
@@ -141,16 +143,41 @@ def get_task_instance(self, session: Session) -> TaskInstance:
             )
         return task_instance
 
-    def safe_to_cancel(self) -> bool:
+    async def get_task_state(self):
+        """Fetch this task instance's state via ``RuntimeTaskInstance.get_task_states``."""
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
+            dag_id=self.task_instance.dag_id,
+            task_ids=[self.task_instance.task_id],
+            run_ids=[self.task_instance.run_id],
+            map_index=self.task_instance.map_index,
+        )
+        try:
+            task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+        except Exception as err:
+            raise AirflowException(
+                f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+                f"task_id: {self.task_instance.task_id}, "
+                f"run_id: {self.task_instance.run_id} "
+                f"and map_index: {self.task_instance.map_index} is not found"
+            ) from err
+        return task_state
+
+    async def safe_to_cancel(self) -> bool:
         """
         Whether it is safe to cancel the external job which is being executed by this trigger.
 
         This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
         Because in those cases, we should NOT cancel the external job.
         """
-        # Database query is needed to get the latest state of the task instance.
-        task_instance = self.get_task_instance()  # type: ignore[call-arg]
-        return task_instance.state != TaskInstanceState.DEFERRED
+        if AIRFLOW_V_3_0_PLUS:
+            task_state = await self.get_task_state()
+        else:
+            # Database query is needed to get the latest state of the task instance.
+            task_instance = self.get_task_instance()  # type: ignore[call-arg]
+            task_state = task_instance.state
+        return task_state != TaskInstanceState.DEFERRED
 
     async def run(self):
         try:
@@ -167,7 +194,7 @@ async def run(self):
         except asyncio.CancelledError:
             self.log.info("Task got cancelled.")
             try:
-                if self.job_id and self.cancel_on_kill and self.safe_to_cancel():
+                if self.job_id and self.cancel_on_kill and await self.safe_to_cancel():
                     self.log.info(
                         "Cancelling the job as it is safe to do so. Note that the airflow TaskInstance is not"
                         " in deferred state."
@@ -243,16 +270,41 @@ def get_task_instance(self, session: Session) -> TaskInstance:
             )
         return task_instance
 
-    def safe_to_cancel(self) -> bool:
+    async def get_task_state(self):
+        """Fetch this task instance's state via ``RuntimeTaskInstance.get_task_states``."""
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
+            dag_id=self.task_instance.dag_id,
+            task_ids=[self.task_instance.task_id],
+            run_ids=[self.task_instance.run_id],
+            map_index=self.task_instance.map_index,
+        )
+        try:
+            task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+        except Exception as err:
+            raise AirflowException(
+                f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+                f"task_id: {self.task_instance.task_id}, "
+                f"run_id: {self.task_instance.run_id} "
+                f"and map_index: {self.task_instance.map_index} is not found"
+            ) from err
+        return task_state
+
+    async def safe_to_cancel(self) -> bool:
         """
         Whether it is safe to cancel the external job which is being executed by this trigger.
 
         This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
         Because in those cases, we should NOT cancel the external job.
         """
-        # Database query is needed to get the latest state of the task instance.
-        task_instance = self.get_task_instance()  # type: ignore[call-arg]
-        return task_instance.state != TaskInstanceState.DEFERRED
+        if AIRFLOW_V_3_0_PLUS:
+            task_state = await self.get_task_state()
+        else:
+            # Database query is needed to get the latest state of the task instance.
+            task_instance = self.get_task_instance()  # type: ignore[call-arg]
+            task_state = task_instance.state
+        return task_state != TaskInstanceState.DEFERRED
 
     async def run(self) -> AsyncIterator[TriggerEvent]:
         try:
@@ -283,7 +335,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
                 await asyncio.sleep(self.polling_interval_seconds)
         except asyncio.CancelledError:
             try:
-                if self.delete_on_error and self.safe_to_cancel():
+                if self.delete_on_error and await self.safe_to_cancel():
                     self.log.info(
                         "Deleting the cluster as it is safe to delete as the airflow TaskInstance is not in "
                         "deferred state."