23 changes: 22 additions & 1 deletion airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -50,8 +50,10 @@
GetConnection,
GetDagRunState,
GetDRCount,
GetTICount,
GetVariable,
GetXCom,
TICount,
VariableResult,
XComResult,
)
@@ -222,6 +224,7 @@ class TriggerStateSync(BaseModel):
XComResult,
DagRunStateResult,
DRCount,
TICount,
ErrorResponse,
],
Field(discriminator="type"),
@@ -233,7 +236,15 @@ class TriggerStateSync(BaseModel):


ToTriggerSupervisor = Annotated[
Union[messages.TriggerStateChanges, GetConnection, GetVariable, GetXCom, GetDagRunState, GetDRCount],
Union[
messages.TriggerStateChanges,
GetConnection,
GetVariable,
GetXCom,
GetTICount,
GetDagRunState,
GetDRCount,
],
Field(discriminator="type"),
]
"""
@@ -411,6 +422,16 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger) -
dr_resp = self.client.dag_runs.get_state(msg.dag_id, msg.run_id)
resp = DagRunStateResult.from_api_response(dr_resp).model_dump_json().encode()

elif isinstance(msg, GetTICount):
ti_count = self.client.task_instances.get_count(
dag_id=msg.dag_id,
task_ids=msg.task_ids,
task_group_id=msg.task_group_id,
logical_dates=msg.logical_dates,
run_ids=msg.run_ids,
states=msg.states,
)
resp = ti_count.model_dump_json().encode()
else:
raise ValueError(f"Unknown message type {type(msg)}")
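For context, a rough sketch of the request/response pair this new branch serves. All field values below are invented for illustration; only the class names and the model_dump_json() encoding step come from the diff above.

    from airflow.sdk.execution_time.comms import GetTICount, TICount

    # The trigger side sends GetTICount over the supervisor comms socket.
    request = GetTICount(
        dag_id="parent_dag",
        task_ids=["parent_task"],
        task_group_id=None,
        logical_dates=None,
        run_ids=["manual__2024-01-01"],  # assumed run id
        states=["success"],
    )
    # The supervisor maps the message fields onto the API client call and
    # ships the serialized TICount result back, mirroring the branch above.
    response = TICount(count=1)
    payload = response.model_dump_json().encode()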

100 changes: 100 additions & 0 deletions airflow-core/tests/unit/jobs/test_triggerer_job.py
@@ -767,3 +767,103 @@ async def test_trigger_can_fetch_trigger_dag_run_count_in_deferrable(session, dag_maker):
assert task_instance.state == TaskInstanceState.SCHEDULED
assert task_instance.next_method != "__fail__"
assert task_instance.next_kwargs == {"event": {"count": 1}}


class CustomTriggerWorkflowStateTrigger(BaseTrigger):
"""Custom Trigger to check the triggerer can access the get_ti_count and get_dr_count."""

def __init__(self, external_dag_id, execution_dates, external_task_ids, allowed_states, run_ids):
super().__init__()
self.external_dag_id = external_dag_id
self.execution_dates = execution_dates
self.external_task_ids = external_task_ids
self.allowed_states = allowed_states
self.run_ids = run_ids

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
f"{type(self).__module__}.{type(self).__qualname__}",
{
"external_dag_id": self.external_dag_id,
"execution_dates": self.execution_dates,
"external_task_ids": self.external_task_ids,
"allowed_states": self.allowed_states,
"run_ids": self.run_ids,
},
)

async def run(self, **args) -> AsyncIterator[TriggerEvent]:
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance

ti_count = await sync_to_async(RuntimeTaskInstance.get_ti_count)(
dag_id=self.external_dag_id,
task_ids=self.external_task_ids,
task_group_id=None,
run_ids=self.run_ids,
logical_dates=self.execution_dates,
states=self.allowed_states,
)
dr_count = await sync_to_async(RuntimeTaskInstance.get_dr_count)(
dag_id=self.external_dag_id,
run_ids=self.run_ids,
logical_dates=self.execution_dates,
states=["running"],
)
yield TriggerEvent({"ti_count": ti_count, "dr_count": dr_count})


@pytest.mark.xfail(
reason="We know this test is flaky and have no time to fix it before 3.0. "
"We should fix it later. TODO: AIP-72"
)
@pytest.mark.asyncio
@pytest.mark.flaky(reruns=2, reruns_delay=10)
@pytest.mark.execution_timeout(30)
async def test_trigger_can_fetch_dag_run_count_ti_count_in_deferrable(session, dag_maker):
"""Checks that the trigger will successfully fetch the count of trigger DAG runs."""
# Create the test DAG and task
with dag_maker(dag_id="parent_dag", session=session):
EmptyOperator(task_id="parent_task")
parent_dag_run = dag_maker.create_dagrun()
parent_task = parent_dag_run.task_instances[0]
parent_task.state = TaskInstanceState.SUCCESS

with dag_maker(dag_id="trigger_can_fetch_dag_run_count_ti_count_in_deferrable", session=session):
EmptyOperator(task_id="dummy1")
dr = dag_maker.create_dagrun()
task_instance = dr.task_instances[0]
task_instance.state = TaskInstanceState.DEFERRED

# Point the trigger at the parent DAG run created above to fetch the counts
trigger = CustomTriggerWorkflowStateTrigger(
external_dag_id=parent_task.dag_id,
execution_dates=[parent_task.logical_date],
external_task_ids=[parent_task.task_id],
allowed_states=[State.SUCCESS],
run_ids=[parent_task.run_id],
)
trigger_orm = Trigger(
classpath=trigger.serialize()[0],
kwargs={
"external_dag_id": parent_dag_run.dag_id,
"execution_dates": [parent_dag_run.logical_date],
"external_task_ids": [parent_task.task_id],
"allowed_states": [State.SUCCESS],
"run_ids": [parent_dag_run.run_id],
},
)
session.add(trigger_orm)
session.commit()
task_instance.trigger_id = trigger_orm.id

job = Job()
session.add(job)
session.commit()

supervisor = DummyTriggerRunnerSupervisor.start(job=job, capacity=1, logger=None)
supervisor.run()

parent_task.refresh_from_db()
task_instance.refresh_from_db()
assert task_instance.state == TaskInstanceState.SCHEDULED
assert task_instance.next_method != "__fail__"
assert task_instance.next_kwargs == {"event": {"ti_count": 1, "dr_count": 1}}
@@ -346,8 +346,6 @@ def _calculate_count(self, count: int, dttm_filter: list[datetime.datetime]) ->
"""Calculate the normalized count based on the type of check."""
if self.external_task_ids:
return count / len(self.external_task_ids)
elif self.external_task_group_id:
return count / len(dttm_filter)
else:
return count
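A quick worked example of the normalization that survives this change, with assumed values: when external_task_ids is set, the raw count of matching task instances is divided by the number of task ids, so each fully-satisfied logical date contributes exactly one.

    external_task_ids = ["task_a", "task_b"]
    raw_count = 6  # assumed: 2 tasks x 3 logical dates, all in allowed states
    normalized = raw_count / len(external_task_ids)  # -> 3.0, one per date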

@@ -421,16 +419,22 @@ def execute(self, context: Context) -> None:
if not self.deferrable:
super().execute(context)
else:
dttm_filter = self._get_dttm_filter(context)
logical_or_execution_dates = (
{"logical_dates": dttm_filter} if AIRFLOW_V_3_0_PLUS else {"execution_dates": dttm_filter}
)
self.defer(
timeout=self.execution_timeout,
trigger=WorkflowTrigger(
external_dag_id=self.external_dag_id,
external_task_group_id=self.external_task_group_id,
external_task_ids=self.external_task_ids,
logical_dates=self._get_dttm_filter(context),
allowed_states=self.allowed_states,
failed_states=self.failed_states,
skipped_states=self.skipped_states,
poke_interval=self.poll_interval,
soft_fail=self.soft_fail,
**logical_or_execution_dates,
),
method_name="execute_complete",
)
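To make the version gate concrete, a minimal sketch with an assumed filter value — the dict built above unpacks into whichever keyword argument the installed Airflow major version understands:

    from datetime import datetime, timezone

    AIRFLOW_V_3_0_PLUS = True  # assumed for the example
    dttm_filter = [datetime(2024, 1, 1, tzinfo=timezone.utc)]
    logical_or_execution_dates = (
        {"logical_dates": dttm_filter}
        if AIRFLOW_V_3_0_PLUS
        else {"execution_dates": dttm_filter}
    )
    # WorkflowTrigger(external_dag_id=..., **logical_or_execution_dates) then
    # receives logical_dates= on Airflow 3 and execution_dates= on Airflow 2.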
@@ -50,13 +50,15 @@ class WorkflowTrigger(BaseTrigger):
:param allowed_states: States considered as successful for external tasks.
:param poke_interval: The interval (in seconds) for poking the external tasks.
:param soft_fail: If True, the trigger will not fail the entire DAG on external task failure.
:param logical_dates: A list of logical dates for the external DAG.
"""

def __init__(
self,
external_dag_id: str,
run_ids: list[str] | None = None,
execution_dates: list[datetime] | None = None,
logical_dates: list[datetime] | None = None,
external_task_ids: typing.Collection[str] | None = None,
external_task_group_id: str | None = None,
failed_states: typing.Iterable[str] | None = None,
@@ -76,6 +78,7 @@ def __init__(
self.poke_interval = poke_interval
self.soft_fail = soft_fail
self.execution_dates = execution_dates
self.logical_dates = logical_dates
super().__init__(**kwargs)

def serialize(self) -> tuple[str, dict[str, Any]]:
@@ -92,35 +95,68 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
}
if AIRFLOW_V_3_0_PLUS:
data["run_ids"] = self.run_ids
data["logical_dates"] = self.logical_dates
else:
data["execution_dates"] = self.execution_dates

return "airflow.providers.standard.triggers.external_task.WorkflowTrigger", data

async def run(self) -> typing.AsyncIterator[TriggerEvent]:
"""Periodically check the status of the external tasks, task group, or DAG."""
if AIRFLOW_V_3_0_PLUS:
get_count_func = self._get_count_af_3
run_id_or_dates = (self.run_ids or self.logical_dates) or []
else:
get_count_func = self._get_count
run_id_or_dates = self.execution_dates or []

while True:
if self.failed_states:
failed_count = await self._get_count(self.failed_states)
failed_count = await get_count_func(self.failed_states)
if failed_count > 0:
yield TriggerEvent({"status": "failed"})
return
else:
yield TriggerEvent({"status": "success"})
return
if self.skipped_states:
skipped_count = await self._get_count(self.skipped_states)
skipped_count = await get_count_func(self.skipped_states)
if skipped_count > 0:
yield TriggerEvent({"status": "skipped"})
return
allowed_count = await self._get_count(self.allowed_states)
_dates = self.run_ids if AIRFLOW_V_3_0_PLUS else self.execution_dates
if allowed_count == len(_dates): # type: ignore[arg-type]
allowed_count = await get_count_func(self.allowed_states)

if allowed_count == len(run_id_or_dates): # type: ignore[arg-type]
yield TriggerEvent({"status": "success"})
return
self.log.info("Sleeping for %s seconds", self.poke_interval)
await asyncio.sleep(self.poke_interval)
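One subtlety worth spelling out: the success condition compares the (normalized) allowed-state count against however many run ids — or, failing that, logical dates — were supplied. A tiny worked example under assumed values:

    run_ids = ["backfill__1", "backfill__2"]
    logical_dates = None
    run_id_or_dates = (run_ids or logical_dates) or []  # -> the two run ids
    allowed_count = 2  # assumed: both runs fully in allowed states
    assert allowed_count == len(run_id_or_dates)  # -> yields {"status": "success"}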

async def _get_count_af_3(self, states: typing.Iterable[str] | None) -> int | float:
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance

if self.external_task_ids or self.external_task_group_id:
count = await sync_to_async(RuntimeTaskInstance.get_ti_count)(
dag_id=self.external_dag_id,
task_ids=self.external_task_ids,
task_group_id=self.external_task_group_id,
logical_dates=self.logical_dates,
run_ids=self.run_ids,
states=states,
)
else:
count = await sync_to_async(RuntimeTaskInstance.get_dr_count)(
dag_id=self.external_dag_id,
logical_dates=self.logical_dates,
run_ids=self.run_ids,
states=states,
)

if self.external_task_ids:
return count / len(self.external_task_ids)
else:
return count
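The dispatch above mirrors the sensor-side _calculate_count change earlier in this diff: task-level filters route to get_ti_count, anything else counts DAG runs, and only external_task_ids triggers normalization. A sketch under assumed inputs:

    external_task_ids = ["task_a"]
    external_task_group_id = None
    uses_ti_count = bool(external_task_ids or external_task_group_id)  # -> True
    raw_count = 3  # assumed: one matching TI per queried run id
    count = raw_count / len(external_task_ids) if external_task_ids else raw_count  # -> 3.0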

@sync_to_async
def _get_count(self, states: typing.Iterable[str] | None) -> int:
"""