Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions airflow-core/src/airflow/cli/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,6 @@ def string_lower_type(val):
type=str,
default="[CWD]" if BUILD_DOCS else os.getcwd(),
)
ARG_DRY_RUN = Arg(
("-n", "--dry-run"),
help="Perform a dry run for each task. Only renders Template Fields for each task, nothing else",
action="store_true",
)
ARG_PID = Arg(("--pid",), help="PID file location", nargs="?")
ARG_DAEMON = Arg(
("-D", "--daemon"), help="Daemonize instead of running in the foreground", action="store_true"
Expand Down Expand Up @@ -1270,7 +1265,6 @@ class GroupCommand(NamedTuple):
ARG_TASK_ID,
ARG_LOGICAL_DATE_OR_RUN_ID_OPTIONAL,
ARG_BUNDLE_NAME,
ARG_DRY_RUN,
ARG_TASK_PARAMS,
ARG_POST_MORTEM,
ARG_ENV_VARS,
Expand Down
61 changes: 18 additions & 43 deletions airflow-core/src/airflow/cli/commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from __future__ import annotations

import functools
import importlib
import json
import logging
Expand All @@ -31,13 +30,14 @@
from airflow import settings
from airflow.cli.simple_table import AirflowConsole
from airflow.cli.utils import fetch_dag_run_from_run_id_or_logical_date_string
from airflow.exceptions import DagRunNotFound, TaskDeferred, TaskInstanceNotFound
from airflow.exceptions import DagRunNotFound, TaskInstanceNotFound
from airflow.models import TaskInstance
from airflow.models.dag import DAG
from airflow.models.dag import DAG as SchedulerDAG, _get_or_create_dagrun
from airflow.models.dagrun import DagRun
from airflow.sdk.definitions.dag import _run_inline_trigger
from airflow.sdk.definitions.dag import DAG, _run_task
from airflow.sdk.definitions.param import ParamsDict
from airflow.sdk.execution_time.secrets_masker import RedactedIO
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS
from airflow.utils import cli as cli_utils, timezone
Expand All @@ -49,7 +49,7 @@
)
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.state import DagRunState
from airflow.utils.state import DagRunState, State
from airflow.utils.task_instance_session import set_current_task_instance_session
from airflow.utils.types import DagRunTriggeredByType, DagRunType

Expand Down Expand Up @@ -134,15 +134,17 @@ def _get_dag_run(
)
return dag_run, True
if create_if_necessary == "db":
dag_run = dag.create_dagrun(
scheduler_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag)) # type: ignore[arg-type]
dag_run = _get_or_create_dagrun(
dag=scheduler_dag,
run_id=_generate_temporary_run_id(),
logical_date=dag_run_logical_date,
data_interval=data_interval,
run_after=run_after,
run_type=DagRunType.MANUAL,
triggered_by=DagRunTriggeredByType.CLI,
state=DagRunState.RUNNING,
session=session,
start_date=logical_date or run_after,
conf=None,
)
return dag_run, True
raise ValueError(f"unknown create_if_necessary value: {create_if_necessary!r}")
Expand All @@ -161,11 +163,6 @@ def _get_ti(
dag = task.dag
if dag is None:
raise ValueError("Cannot get task instance for a task not assigned to a DAG")
if not isinstance(dag, DAG):
# TODO: Task-SDK: Shouldn't really happen, and this command will go away before 3.0
raise ValueError(
f"We need a {DAG.__module__}.DAG, but we got {type(dag).__module__}.{type(dag).__name__}!"
)

# this check is imperfect because diff dags could have tasks with same name
# but in a task, dag_id is a property that accesses its dag, and we don't
Expand Down Expand Up @@ -274,11 +271,13 @@ def task_list(args, dag: DAG | None = None) -> None:

class _SupportedDebugger(Protocol):
def post_mortem(self) -> None: ...
def set_trace(self) -> None: ...


SUPPORTED_DEBUGGER_MODULES = [
"pudb",
"web_pdb",
"pdbr",
"ipdb",
"pdb",
]
Expand All @@ -294,6 +293,7 @@ def _guess_debugger() -> _SupportedDebugger:

* `pudb <https://github.com/inducer/pudb>`__
* `web_pdb <https://github.com/romanvm/python-web-pdb>`__
* `pdbr <https://github.com/cansarigol/pdbr>`__
* `ipdb <https://github.com/gotcha/ipdb>`__
* `pdb <https://docs.python.org/3/library/pdb.html>`__
"""
Expand Down Expand Up @@ -343,8 +343,7 @@ def format_task_instance(ti: TaskInstance) -> dict[str, str]:


@cli_utils.action_cli(check_db=False)
@provide_session
def task_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> None:
def task_test(args, dag: DAG | None = None) -> None:
"""Test task for a given dag_id."""
# We want to log output from operators etc to show up here. Normally
# airflow.task would redirect to a file, but here we want it to propagate
Expand All @@ -368,8 +367,6 @@ def task_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> N

dag = dag or get_dag(args.bundle_name, args.dag_id)

dag = DAG.from_sdk_dag(dag)

task = dag.get_task(task_id=args.task_id)
# Add CLI provided task_params to task.params
if args.task_params:
Expand All @@ -384,31 +381,10 @@ def task_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> N
)
try:
with redirect_stdout(RedactedIO()):
if args.dry_run:
ti.dry_run()
else:
ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True, raise_on_defer=True)
except TaskDeferred as defer:
ti.defer_task(exception=defer, session=session)
log.info("[TASK TEST] running trigger in line")

event = _run_inline_trigger(defer.trigger)
ti.next_method = defer.method_name
ti.next_kwargs = {"event": event.payload} if event else defer.kwargs

execute_callable = getattr(task, ti.next_method)
if ti.next_kwargs:
execute_callable = functools.partial(execute_callable, **ti.next_kwargs)
context = ti.get_template_context(ignore_param_exceptions=False)
execute_callable(context)

log.info("[TASK TEST] Trigger completed")
except Exception:
if args.post_mortem:
_run_task(ti=ti)
if ti.state == State.FAILED and args.post_mortem:
debugger = _guess_debugger()
debugger.post_mortem()
else:
raise
debugger.set_trace()
finally:
if not already_has_stream_handler:
# Make sure to reset back to normal. When run for CLI this doesn't
Expand All @@ -426,7 +402,6 @@ def task_render(args, dag: DAG | None = None) -> None:
"""Render and displays templated fields for a given task."""
if not dag:
dag = get_dag(args.bundle_name, args.dag_id)
dag = DAG.from_sdk_dag(dag)
task = dag.get_task(task_id=args.task_id)
ti, _ = _get_ti(
task, args.map_index, logical_date_or_run_id=args.logical_date_or_run_id, create_if_necessary="memory"
Expand Down Expand Up @@ -465,7 +440,7 @@ def task_clear(args) -> None:
include_upstream=args.upstream,
)

DAG.clear_dags(
SchedulerDAG.clear_dags(
dags,
start_date=args.start_date,
end_date=args.end_date,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,14 @@ def my_py_command(params, test_mode=None, task=None):


@task(task_id="env_var_test_task")
def print_env_vars(test_mode=None):
def print_env_vars():
"""
Print out the "foo" param passed in via
`airflow tasks test example_passing_params_via_test_command env_var_test_task <date>
--env-vars '{"foo":"bar"}'`
"""
if test_mode:
print(f"foo={os.environ.get('foo')}")
print(f"AIRFLOW_TEST_MODE={os.environ.get('AIRFLOW_TEST_MODE')}")
print(f"foo={os.environ.get('foo')}")
print(f"AIRFLOW_TEST_MODE={os.environ.get('AIRFLOW_TEST_MODE')}")


with DAG(
Expand Down
21 changes: 13 additions & 8 deletions airflow-core/tests/unit/cli/commands/test_task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,9 @@ def test_test_with_existing_dag_run(self, caplog):
args = self.parser.parse_args(["tasks", "test", self.dag_id, task_id, DEFAULT_DATE.isoformat()])
with caplog.at_level("INFO", logger="airflow.task"):
task_command.task_test(args)
assert (
f"Marking task as SUCCESS. dag_id={self.dag_id}, task_id={task_id}, run_id={self.run_id}, "
in caplog.text
)
ti = self.dag_run.get_task_instance(task_id=task_id)
assert ti is not None
assert ti.state == State.SUCCESS

@pytest.mark.enable_redact
def test_test_filters_secrets(self, capsys):
Expand All @@ -177,12 +176,16 @@ def test_test_filters_secrets(self, capsys):
["tasks", "test", "example_python_operator", "print_the_context", "2018-01-01"],
)

with mock.patch("airflow.models.TaskInstance.run", side_effect=lambda *_, **__: print(password)):
with mock.patch(
"airflow.cli.commands.task_command._run_task", side_effect=lambda *_, **__: print(password)
):
task_command.task_test(args)
assert capsys.readouterr().out.endswith("***\n")

not_password = "!4321drowssapemos"
with mock.patch("airflow.models.TaskInstance.run", side_effect=lambda *_, **__: print(not_password)):
with mock.patch(
"airflow.cli.commands.task_command._run_task", side_effect=lambda *_, **__: print(not_password)
):
task_command.task_test(args)
assert capsys.readouterr().out.endswith(f"{not_password}\n")

Expand Down Expand Up @@ -234,7 +237,9 @@ def test_cli_test_with_env_vars(self):
assert "AIRFLOW_TEST_MODE=True" in output

@mock.patch("airflow.providers.standard.triggers.file.os.path.getmtime", return_value=0)
@mock.patch("airflow.providers.standard.triggers.file.glob", return_value=["/tmp/test"])
@mock.patch(
"airflow.providers.standard.triggers.file.glob", return_value=["/tmp/temporary_file_for_testing"]
)
@mock.patch("airflow.providers.standard.triggers.file.os")
@mock.patch("airflow.providers.standard.sensors.filesystem.FileSensor.poke", return_value=False)
def test_cli_test_with_deferrable_operator(self, mock_pock, mock_os, mock_glob, mock_getmtime, caplog):
Expand All @@ -252,7 +257,7 @@ def test_cli_test_with_deferrable_operator(self, mock_pock, mock_os, mock_glob,
)
)
output = caplog.text
assert "wait_for_file_async completed successfully as /tmp/temporary_file_for_testing found" in output
assert "Found File /tmp/temporary_file_for_testing" in output

def test_task_render(self):
"""
Expand Down
3 changes: 2 additions & 1 deletion task-sdk/src/airflow/sdk/definitions/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1244,8 +1244,9 @@ def _run_task(*, ti):
)

msg = taskrun_result.msg
ti.set_state(taskrun_result.ti.state)

if taskrun_result.ti.state == State.DEFERRED and isinstance(msg, DeferTask):
if ti.state == State.DEFERRED and isinstance(msg, DeferTask):
# API Server expects the task instance to be in QUEUED state before
# resuming from deferral.
ti.set_state(State.QUEUED)
Expand Down
Loading