Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable "airflow tasks test" to run deferrable operator #37542

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions airflow/cli/commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""Task sub-commands."""
from __future__ import annotations

import functools
import importlib
import json
import logging
Expand All @@ -34,13 +35,13 @@
from airflow import settings
from airflow.cli.simple_table import AirflowConsole
from airflow.configuration import conf
from airflow.exceptions import AirflowException, DagRunNotFound, TaskInstanceNotFound
from airflow.exceptions import AirflowException, DagRunNotFound, TaskDeferred, TaskInstanceNotFound
from airflow.executors.executor_loader import ExecutorLoader
from airflow.jobs.job import Job, run_job
from airflow.jobs.local_task_job_runner import LocalTaskJobRunner
from airflow.listeners.listener import get_listener_manager
from airflow.models import DagPickle, TaskInstance
from airflow.models.dag import DAG
from airflow.models.dag import DAG, _run_inline_trigger
from airflow.models.dagrun import DagRun
from airflow.models.operator import needs_expansion
from airflow.models.param import ParamsDict
Expand Down Expand Up @@ -588,7 +589,8 @@ def format_task_instance(ti: TaskInstance) -> dict[str, str]:


@cli_utils.action_cli(check_db=False)
def task_test(args, dag: DAG | None = None) -> None:
@provide_session
def task_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> None:
"""Test task for a given dag_id."""
# We want to log output from operators etc to show up here. Normally
# airflow.task would redirect to a file, but here we want it to propagate
Expand Down Expand Up @@ -632,7 +634,22 @@ def task_test(args, dag: DAG | None = None) -> None:
if args.dry_run:
ti.dry_run()
else:
ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True)
ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True, raise_on_defer=True)
except TaskDeferred as defer:
ti.defer_task(defer=defer, session=session)
log.info("[TASK TEST] running trigger in line")

event = _run_inline_trigger(defer.trigger)
ti.next_method = defer.method_name
ti.next_kwargs = {"event": event.payload} if event else defer.kwargs

execute_callable = getattr(task, ti.next_method)
if ti.next_kwargs:
execute_callable = functools.partial(execute_callable, **ti.next_kwargs)
context = ti.get_template_context(ignore_param_exceptions=False)
execute_callable(context)

log.info("[TASK TEST] Trigger completed")
except Exception:
if args.post_mortem:
debugger = _guess_debugger()
Expand Down
8 changes: 4 additions & 4 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -4057,12 +4057,12 @@ def get_current_dag(cls) -> DAG | None:
return None


def _run_trigger(trigger):
async def _run_trigger_main():
def _run_inline_trigger(trigger):
async def _run_inline_trigger_main():
async for event in trigger.run():
return event

return asyncio.run(_run_trigger_main())
return asyncio.run(_run_inline_trigger_main())


def _run_task(*, ti: TaskInstance, inline_trigger: bool = False, session: Session):
Expand All @@ -4083,7 +4083,7 @@ def _run_task(*, ti: TaskInstance, inline_trigger: bool = False, session: Sessio
break
except TaskDeferred as e:
log.info("[DAG TEST] running trigger in line")
event = _run_trigger(e.trigger)
event = _run_inline_trigger(e.trigger)
ti.next_method = e.method_name
ti.next_kwargs = {"event": event.payload} if event else e.kwargs
log.info("[DAG TEST] Trigger completed")
Expand Down
17 changes: 13 additions & 4 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2378,7 +2378,7 @@ def _run_raw_task(
# a trigger.
if raise_on_defer:
raise
self._defer_task(defer=defer, session=session)
self.defer_task(defer=defer, session=session)
self.log.info(
"Pausing task as DEFERRED. dag_id=%s, task_id=%s, execution_date=%s, start_date=%s",
self.dag_id,
Expand Down Expand Up @@ -2565,8 +2565,11 @@ def _execute_task(self, context, task_orig):
return _execute_task(self, context, task_orig)

@provide_session
def _defer_task(self, session: Session, defer: TaskDeferred) -> None:
"""Mark the task as deferred and sets up the trigger that is needed to resume it."""
def defer_task(self, session: Session, defer: TaskDeferred) -> None:
"""Mark the task as deferred and sets up the trigger that is needed to resume it.

:meta: private
"""
from airflow.models.trigger import Trigger

# First, make the trigger entry
Expand Down Expand Up @@ -2625,6 +2628,7 @@ def run(
job_id: str | None = None,
pool: str | None = None,
session: Session = NEW_SESSION,
raise_on_defer: bool = False,
) -> None:
"""Run TaskInstance."""
res = self.check_and_change_state_before_execution(
Expand All @@ -2644,7 +2648,12 @@ def run(
return

self._run_raw_task(
mark_success=mark_success, test_mode=test_mode, job_id=job_id, pool=pool, session=session
mark_success=mark_success,
test_mode=test_mode,
job_id=job_id,
pool=pool,
session=session,
raise_on_defer=raise_on_defer,
)

def dry_run(self) -> None:
Expand Down
8 changes: 4 additions & 4 deletions tests/cli/commands/test_dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from airflow.exceptions import AirflowException
from airflow.models import DagBag, DagModel, DagRun
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import _run_trigger
from airflow.models.dag import _run_inline_trigger
from airflow.models.serialized_dag import SerializedDagModel
from airflow.triggers.base import TriggerEvent
from airflow.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger
Expand Down Expand Up @@ -878,15 +878,15 @@ def test_dag_test_with_custom_timetable(self, mock__get_or_create_dagrun, _):
dag_command.dag_test(cli_args)
assert "data_interval" in mock__get_or_create_dagrun.call_args.kwargs

def test_dag_test_run_trigger(self, dag_maker):
def test_dag_test_run_inline_trigger(self, dag_maker):
now = timezone.utcnow()
trigger = DateTimeTrigger(moment=now)
e = _run_trigger(trigger)
e = _run_inline_trigger(trigger)
assert isinstance(e, TriggerEvent)
assert e.payload == now

def test_dag_test_no_triggerer_running(self, dag_maker):
with mock.patch("airflow.models.dag._run_trigger", wraps=_run_trigger) as mock_run:
with mock.patch("airflow.models.dag._run_inline_trigger", wraps=_run_inline_trigger) as mock_run:
with dag_maker() as dag:

@task
Expand Down
21 changes: 21 additions & 0 deletions tests/cli/commands/test_task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,27 @@ def test_cli_test_with_env_vars(self):
assert "foo=bar" in output
assert "AIRFLOW_TEST_MODE=True" in output

@pytest.mark.asyncio
@mock.patch("airflow.triggers.file.os.path.getmtime", return_value=0)
@mock.patch("airflow.triggers.file.glob", return_value=["/tmp/test"])
@mock.patch("airflow.triggers.file.os.path.isfile", return_value=True)
@mock.patch("airflow.sensors.filesystem.FileSensor.poke", return_value=False)
def test_cli_test_with_deferrable_operator(self, mock_pock, mock_is_file, mock_glob, mock_getmtime):
with redirect_stdout(StringIO()) as stdout:
task_command.task_test(
self.parser.parse_args(
[
"tasks",
"test",
"example_sensors",
"wait_for_file_async",
DEFAULT_DATE.isoformat(),
]
)
)
output = stdout.getvalue()
assert "wait_for_file_async completed successfully as /tmp/temporary_file_for_testing found" in output

@pytest.mark.parametrize(
"option",
[
Expand Down
4 changes: 2 additions & 2 deletions tests/jobs/test_triggerer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def test_trigger_lifecycle(session):
class TestTriggerRunner:
@pytest.mark.asyncio
@patch("airflow.jobs.triggerer_job_runner.TriggerRunner.set_individual_trigger_logging")
async def test_run_trigger_canceled(self, session) -> None:
async def test_run_inline_trigger_canceled(self, session) -> None:
trigger_runner = TriggerRunner()
trigger_runner.triggers = {1: {"task": MagicMock(), "name": "mock_name", "events": 0}}
mock_trigger = MagicMock()
Expand All @@ -278,7 +278,7 @@ async def test_run_trigger_canceled(self, session) -> None:

@pytest.mark.asyncio
@patch("airflow.jobs.triggerer_job_runner.TriggerRunner.set_individual_trigger_logging")
async def test_run_trigger_timeout(self, session, caplog) -> None:
async def test_run_inline_trigger_timeout(self, session, caplog) -> None:
trigger_runner = TriggerRunner()
trigger_runner.triggers = {1: {"task": MagicMock(), "name": "mock_name", "events": 0}}
mock_trigger = MagicMock()
Expand Down