diff --git a/airflow-core/src/airflow/cli/cli_config.py b/airflow-core/src/airflow/cli/cli_config.py index dd1fab0456e17..97c72fdbbdf74 100644 --- a/airflow-core/src/airflow/cli/cli_config.py +++ b/airflow-core/src/airflow/cli/cli_config.py @@ -177,11 +177,6 @@ def string_lower_type(val): type=str, default="[CWD]" if BUILD_DOCS else os.getcwd(), ) -ARG_DRY_RUN = Arg( - ("-n", "--dry-run"), - help="Perform a dry run for each task. Only renders Template Fields for each task, nothing else", - action="store_true", -) ARG_PID = Arg(("--pid",), help="PID file location", nargs="?") ARG_DAEMON = Arg( ("-D", "--daemon"), help="Daemonize instead of running in the foreground", action="store_true" @@ -1270,7 +1265,6 @@ class GroupCommand(NamedTuple): ARG_TASK_ID, ARG_LOGICAL_DATE_OR_RUN_ID_OPTIONAL, ARG_BUNDLE_NAME, - ARG_DRY_RUN, ARG_TASK_PARAMS, ARG_POST_MORTEM, ARG_ENV_VARS, diff --git a/airflow-core/src/airflow/cli/commands/task_command.py b/airflow-core/src/airflow/cli/commands/task_command.py index a3492a828c03a..cf074cd32c244 100644 --- a/airflow-core/src/airflow/cli/commands/task_command.py +++ b/airflow-core/src/airflow/cli/commands/task_command.py @@ -19,7 +19,6 @@ from __future__ import annotations -import functools import importlib import json import logging @@ -31,13 +30,14 @@ from airflow import settings from airflow.cli.simple_table import AirflowConsole from airflow.cli.utils import fetch_dag_run_from_run_id_or_logical_date_string -from airflow.exceptions import DagRunNotFound, TaskDeferred, TaskInstanceNotFound +from airflow.exceptions import DagRunNotFound, TaskInstanceNotFound from airflow.models import TaskInstance -from airflow.models.dag import DAG +from airflow.models.dag import DAG as SchedulerDAG, _get_or_create_dagrun from airflow.models.dagrun import DagRun -from airflow.sdk.definitions.dag import _run_inline_trigger +from airflow.sdk.definitions.dag import DAG, _run_task from airflow.sdk.definitions.param import ParamsDict from airflow.sdk.execution_time.secrets_masker import RedactedIO +from airflow.serialization.serialized_objects import SerializedDAG from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS from airflow.utils import cli as cli_utils, timezone @@ -49,7 +49,7 @@ ) from airflow.utils.providers_configuration_loader import providers_configuration_loaded from airflow.utils.session import NEW_SESSION, create_session, provide_session -from airflow.utils.state import DagRunState +from airflow.utils.state import DagRunState, State from airflow.utils.task_instance_session import set_current_task_instance_session from airflow.utils.types import DagRunTriggeredByType, DagRunType @@ -134,15 +134,17 @@ def _get_dag_run( ) return dag_run, True if create_if_necessary == "db": - dag_run = dag.create_dagrun( + scheduler_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag)) # type: ignore[arg-type] + dag_run = _get_or_create_dagrun( + dag=scheduler_dag, run_id=_generate_temporary_run_id(), logical_date=dag_run_logical_date, data_interval=data_interval, run_after=run_after, - run_type=DagRunType.MANUAL, triggered_by=DagRunTriggeredByType.CLI, - state=DagRunState.RUNNING, session=session, + start_date=logical_date or run_after, + conf=None, ) return dag_run, True raise ValueError(f"unknown create_if_necessary value: {create_if_necessary!r}") @@ -161,11 +163,6 @@ def _get_ti( dag = task.dag if dag is None: raise ValueError("Cannot get task instance for a task not assigned to a DAG") - if not isinstance(dag, DAG): - # TODO: Task-SDK: Shouldn't really happen, and this command will go away before 3.0 - raise ValueError( - f"We need a {DAG.__module__}.DAG, but we got {type(dag).__module__}.{type(dag).__name__}!" - ) # this check is imperfect because diff dags could have tasks with same name # but in a task, dag_id is a property that accesses its dag, and we don't @@ -274,11 +271,13 @@ def task_list(args, dag: DAG | None = None) -> None: class _SupportedDebugger(Protocol): def post_mortem(self) -> None: ... + def set_trace(self) -> None: ... SUPPORTED_DEBUGGER_MODULES = [ "pudb", "web_pdb", + "pdbr", "ipdb", "pdb", ] @@ -294,6 +293,7 @@ def _guess_debugger() -> _SupportedDebugger: * `pudb `__ * `web_pdb `__ + * `pdbr `__ * `ipdb `__ * `pdb `__ """ @@ -343,8 +343,7 @@ def format_task_instance(ti: TaskInstance) -> dict[str, str]: @cli_utils.action_cli(check_db=False) -@provide_session -def task_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> None: +def task_test(args, dag: DAG | None = None) -> None: """Test task for a given dag_id.""" # We want to log output from operators etc to show up here. Normally # airflow.task would redirect to a file, but here we want it to propagate @@ -368,8 +367,6 @@ def task_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> N dag = dag or get_dag(args.bundle_name, args.dag_id) - dag = DAG.from_sdk_dag(dag) - task = dag.get_task(task_id=args.task_id) # Add CLI provided task_params to task.params if args.task_params: @@ -384,31 +381,10 @@ def task_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> N ) try: with redirect_stdout(RedactedIO()): - if args.dry_run: - ti.dry_run() - else: - ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True, raise_on_defer=True) - except TaskDeferred as defer: - ti.defer_task(exception=defer, session=session) - log.info("[TASK TEST] running trigger in line") - - event = _run_inline_trigger(defer.trigger) - ti.next_method = defer.method_name - ti.next_kwargs = {"event": event.payload} if event else defer.kwargs - - execute_callable = getattr(task, ti.next_method) - if ti.next_kwargs: - execute_callable = functools.partial(execute_callable, **ti.next_kwargs) - context = ti.get_template_context(ignore_param_exceptions=False) - execute_callable(context) - - log.info("[TASK TEST] Trigger completed") - except Exception: - if args.post_mortem: + _run_task(ti=ti) + if ti.state == State.FAILED and args.post_mortem: debugger = _guess_debugger() - debugger.post_mortem() - else: - raise + debugger.set_trace() finally: if not already_has_stream_handler: # Make sure to reset back to normal. When run for CLI this doesn't @@ -426,7 +402,6 @@ def task_render(args, dag: DAG | None = None) -> None: """Render and displays templated fields for a given task.""" if not dag: dag = get_dag(args.bundle_name, args.dag_id) - dag = DAG.from_sdk_dag(dag) task = dag.get_task(task_id=args.task_id) ti, _ = _get_ti( task, args.map_index, logical_date_or_run_id=args.logical_date_or_run_id, create_if_necessary="memory" @@ -465,7 +440,7 @@ def task_clear(args) -> None: include_upstream=args.upstream, ) - DAG.clear_dags( + SchedulerDAG.clear_dags( dags, start_date=args.start_date, end_date=args.end_date, diff --git a/airflow-core/src/airflow/example_dags/example_passing_params_via_test_command.py b/airflow-core/src/airflow/example_dags/example_passing_params_via_test_command.py index 5114bea07132e..9e4d3bf477845 100644 --- a/airflow-core/src/airflow/example_dags/example_passing_params_via_test_command.py +++ b/airflow-core/src/airflow/example_dags/example_passing_params_via_test_command.py @@ -46,15 +46,14 @@ def my_py_command(params, test_mode=None, task=None): @task(task_id="env_var_test_task") -def print_env_vars(test_mode=None): +def print_env_vars(): """ Print out the "foo" param passed in via `airflow tasks test example_passing_params_via_test_command env_var_test_task --env-vars '{"foo":"bar"}'` """ - if test_mode: - print(f"foo={os.environ.get('foo')}") - print(f"AIRFLOW_TEST_MODE={os.environ.get('AIRFLOW_TEST_MODE')}") + print(f"foo={os.environ.get('foo')}") + print(f"AIRFLOW_TEST_MODE={os.environ.get('AIRFLOW_TEST_MODE')}") with DAG( diff --git a/airflow-core/tests/unit/cli/commands/test_task_command.py b/airflow-core/tests/unit/cli/commands/test_task_command.py index d864ffb808fee..a71cf02930117 100644 --- a/airflow-core/tests/unit/cli/commands/test_task_command.py +++ b/airflow-core/tests/unit/cli/commands/test_task_command.py @@ -159,10 +159,9 @@ def test_test_with_existing_dag_run(self, caplog): args = self.parser.parse_args(["tasks", "test", self.dag_id, task_id, DEFAULT_DATE.isoformat()]) with caplog.at_level("INFO", logger="airflow.task"): task_command.task_test(args) - assert ( - f"Marking task as SUCCESS. dag_id={self.dag_id}, task_id={task_id}, run_id={self.run_id}, " - in caplog.text - ) + ti = self.dag_run.get_task_instance(task_id=task_id) + assert ti is not None + assert ti.state == State.SUCCESS @pytest.mark.enable_redact def test_test_filters_secrets(self, capsys): @@ -177,12 +176,16 @@ def test_test_filters_secrets(self, capsys): ["tasks", "test", "example_python_operator", "print_the_context", "2018-01-01"], ) - with mock.patch("airflow.models.TaskInstance.run", side_effect=lambda *_, **__: print(password)): + with mock.patch( + "airflow.cli.commands.task_command._run_task", side_effect=lambda *_, **__: print(password) + ): task_command.task_test(args) assert capsys.readouterr().out.endswith("***\n") not_password = "!4321drowssapemos" - with mock.patch("airflow.models.TaskInstance.run", side_effect=lambda *_, **__: print(not_password)): + with mock.patch( + "airflow.cli.commands.task_command._run_task", side_effect=lambda *_, **__: print(not_password) + ): task_command.task_test(args) assert capsys.readouterr().out.endswith(f"{not_password}\n") @@ -234,7 +237,9 @@ def test_cli_test_with_env_vars(self): assert "AIRFLOW_TEST_MODE=True" in output @mock.patch("airflow.providers.standard.triggers.file.os.path.getmtime", return_value=0) - @mock.patch("airflow.providers.standard.triggers.file.glob", return_value=["/tmp/test"]) + @mock.patch( + "airflow.providers.standard.triggers.file.glob", return_value=["/tmp/temporary_file_for_testing"] + ) @mock.patch("airflow.providers.standard.triggers.file.os") @mock.patch("airflow.providers.standard.sensors.filesystem.FileSensor.poke", return_value=False) def test_cli_test_with_deferrable_operator(self, mock_pock, mock_os, mock_glob, mock_getmtime, caplog): @@ -252,7 +257,7 @@ def test_cli_test_with_deferrable_operator(self, mock_pock, mock_os, mock_glob, ) ) output = caplog.text - assert "wait_for_file_async completed successfully as /tmp/temporary_file_for_testing found" in output + assert "Found File /tmp/temporary_file_for_testing" in output def test_task_render(self): """ diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index 9706994ed52cb..835e39dcdbad1 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -1244,8 +1244,9 @@ def _run_task(*, ti): ) msg = taskrun_result.msg + ti.set_state(taskrun_result.ti.state) - if taskrun_result.ti.state == State.DEFERRED and isinstance(msg, DeferTask): + if ti.state == State.DEFERRED and isinstance(msg, DeferTask): # API Server expects the task instance to be in QUEUED state before # resuming from deferral. ti.set_state(State.QUEUED)