Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions airflow-core/src/airflow/cli/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,13 @@ def string_lower_type(val):
)

# test_dag
ARG_DAGFILE_PATH = Arg(
(
"-f",
"--dagfile-path",
),
help="Path to the dag file. Can be absolute or relative to current directory",
)
ARG_SHOW_DAGRUN = Arg(
("--show-dagrun",),
help=(
Expand Down Expand Up @@ -1124,6 +1131,16 @@ class GroupCommand(NamedTuple):
description=(
"Execute one single DagRun for a given DAG and logical date.\n"
"\n"
"You can test a DAG in three ways:\n"
"1. Using default bundle:\n"
" airflow dags test <DAG_ID>\n"
"\n"
"2. Using a specific bundle if multiple DAG bundles are configured:\n"
" airflow dags test <DAG_ID> --bundle-name <BUNDLE_NAME> (or -B <BUNDLE_NAME>)\n"
"\n"
"3. Using a specific DAG file:\n"
" airflow dags test <DAG_ID> --dagfile-path <PATH> (or -f <PATH>)\n"
"\n"
"The --imgcat-dagrun option only works in iTerm.\n"
"\n"
"For more information, see: https://www.iterm2.com/documentation-images.html\n"
Expand All @@ -1144,6 +1161,8 @@ class GroupCommand(NamedTuple):
args=(
ARG_DAG_ID,
ARG_LOGICAL_DATE_OPTIONAL,
ARG_BUNDLE_NAME,
ARG_DAGFILE_PATH,
ARG_CONF,
ARG_SHOW_DAGRUN,
ARG_IMGCAT_DAGRUN,
Expand Down
31 changes: 2 additions & 29 deletions airflow-core/src/airflow/cli/commands/dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,9 @@
import json
import logging
import operator
import os
import re
import subprocess
import sys
from pathlib import Path
from typing import TYPE_CHECKING

from sqlalchemy import func, select
Expand All @@ -41,10 +39,8 @@
from airflow.exceptions import AirflowException
from airflow.jobs.job import Job
from airflow.models import DagBag, DagModel, DagRun, TaskInstance
from airflow.models.dag import DAG
from airflow.models.errors import ParseImportError
from airflow.models.serialized_dag import SerializedDagModel
from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
from airflow.utils import cli as cli_utils, timezone
from airflow.utils.cli import get_dag, suppress_logs_and_warning, validate_dag_bundle_arg
from airflow.utils.dot_renderer import render_dag, render_dag_dependencies
Expand All @@ -57,6 +53,7 @@
from graphviz.dot import Dot
from sqlalchemy.orm import Session

from airflow.models.dag import DAG
from airflow.timetables.base import DataInterval

DAG_DETAIL_FIELDS = {*DAGResponse.model_fields, *DAGResponse.model_computed_fields}
Expand Down Expand Up @@ -591,29 +588,6 @@ def _render_dagrun(dr: DagRun) -> dict[str, str]:
AirflowConsole().print_as(data=dag_runs, output=args.output, mapper=_render_dagrun)


def _parse_and_get_dag(dag_id: str) -> DAG | None:
"""Given a dag_id, determine the bundle and relative fileloc from the db, then parse and return the DAG."""
db_dag = get_dag(bundle_names=None, dag_id=dag_id, from_db=True)
bundle_name = db_dag.get_bundle_name()
if bundle_name is None:
raise AirflowException(
f"Bundle name for DAG {dag_id!r} is not found in the database. This should not happen."
)
if db_dag.relative_fileloc is None:
raise AirflowException(
f"Relative fileloc for DAG {dag_id!r} is not found in the database. This should not happen."
)
bundle = DagBundlesManager().get_bundle(bundle_name)
bundle.initialize()
dag_absolute_path = os.fspath(Path(bundle.path, db_dag.relative_fileloc))

with _airflow_parsing_context_manager(dag_id=dag_id):
bag = DagBag(
dag_folder=dag_absolute_path, include_examples=False, safe_mode=False, load_op_links=False
)
return bag.dags.get(dag_id)


@cli_utils.action_cli
@providers_configuration_loaded
@provide_session
Expand All @@ -632,13 +606,12 @@ def dag_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> No
re.compile(args.mark_success_pattern) if args.mark_success_pattern is not None else None
)

dag = dag or _parse_and_get_dag(args.dag_id)
dag = dag or get_dag(bundle_names=args.bundle_name, dag_id=args.dag_id, dagfile_path=args.dagfile_path)
if not dag:
raise AirflowException(
f"Dag {args.dag_id!r} could not be found; either it does not exist or it failed to parse."
)

dag = DAG.from_sdk_dag(dag)
dr: DagRun = dag.test(
logical_date=logical_date,
run_conf=run_conf,
Expand Down
15 changes: 12 additions & 3 deletions airflow-core/src/airflow/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from airflow import settings
from airflow.dag_processing.bundles.manager import DagBundlesManager
from airflow.exceptions import AirflowException
from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
from airflow.sdk.execution_time.secrets_masker import should_hide_value_for_key
from airflow.utils import cli_action_loggers, timezone
from airflow.utils.log.non_caching_file_handler import NonCachingFileHandler
Expand Down Expand Up @@ -258,7 +259,9 @@ def _search_for_dag_file(val: str | None) -> str | None:
return None


def get_dag(bundle_names: list | None, dag_id: str, from_db: bool = False) -> DAG:
def get_dag(
bundle_names: list | None, dag_id: str, from_db: bool = False, dagfile_path: str | None = None
) -> DAG:
"""
Return DAG of a given dag_id.

Expand All @@ -277,7 +280,10 @@ def get_dag(bundle_names: list | None, dag_id: str, from_db: bool = False) -> DA
manager = DagBundlesManager()
for bundle_name in bundle_names:
bundle = manager.get_bundle(bundle_name)
dagbag = DagBag(dag_folder=bundle.path, bundle_path=bundle.path, include_examples=False)
with _airflow_parsing_context_manager(dag_id=dag_id):
dagbag = DagBag(
dag_folder=dagfile_path or bundle.path, bundle_path=bundle.path, include_examples=False
)
dag = dagbag.dags.get(dag_id)
if dag:
break
Expand All @@ -287,7 +293,10 @@ def get_dag(bundle_names: list | None, dag_id: str, from_db: bool = False) -> DA
manager = DagBundlesManager()
all_bundles = list(manager.get_all_dag_bundles())
for bundle in all_bundles:
dag_bag = DagBag(dag_folder=bundle.path, bundle_path=bundle.path)
with _airflow_parsing_context_manager(dag_id=dag_id):
dag_bag = DagBag(
dag_folder=dagfile_path or bundle.path, bundle_path=bundle.path, include_examples=False
)
dag = dag_bag.dags.get(dag_id)
if dag:
break
Expand Down
123 changes: 100 additions & 23 deletions airflow-core/tests/unit/cli/commands/test_dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,14 +617,14 @@ def test_dag_state(self):
is None
)

@mock.patch("airflow.cli.commands.dag_command._parse_and_get_dag")
def test_dag_test(self, mock_parse_and_get_dag):
@mock.patch("airflow.cli.commands.dag_command.get_dag")
def test_dag_test(self, mock_get_dag):
cli_args = self.parser.parse_args(["dags", "test", "example_bash_operator", DEFAULT_DATE.isoformat()])
dag_command.dag_test(cli_args)

mock_parse_and_get_dag.assert_has_calls(
mock_get_dag.assert_has_calls(
[
mock.call("example_bash_operator"),
mock.call(bundle_names=None, dag_id="example_bash_operator", dagfile_path=None),
mock.call().__bool__(),
mock.call().test(
logical_date=timezone.parse(DEFAULT_DATE.isoformat()),
Expand All @@ -635,19 +635,19 @@ def test_dag_test(self, mock_parse_and_get_dag):
]
)

@mock.patch("airflow.cli.commands.dag_command._parse_and_get_dag")
def test_dag_test_fail_raise_error(self, mock_parse_and_get_dag):
@mock.patch("airflow.cli.commands.dag_command.get_dag")
def test_dag_test_fail_raise_error(self, mock_get_dag):
logical_date_str = DEFAULT_DATE.isoformat()
mock_parse_and_get_dag.return_value.test.return_value = DagRun(
mock_get_dag.return_value.test.return_value = DagRun(
dag_id="example_bash_operator", logical_date=DEFAULT_DATE, state=DagRunState.FAILED
)
cli_args = self.parser.parse_args(["dags", "test", "example_bash_operator", logical_date_str])
with pytest.raises(SystemExit, match=r"DagRun failed"):
dag_command.dag_test(cli_args)

@mock.patch("airflow.cli.commands.dag_command._parse_and_get_dag")
@mock.patch("airflow.cli.commands.dag_command.get_dag")
@mock.patch("airflow.utils.timezone.utcnow")
def test_dag_test_no_logical_date(self, mock_utcnow, mock_parse_and_get_dag):
def test_dag_test_no_logical_date(self, mock_utcnow, mock_get_dag):
now = pendulum.now()
mock_utcnow.return_value = now
cli_args = self.parser.parse_args(["dags", "test", "example_bash_operator"])
Expand All @@ -656,9 +656,9 @@ def test_dag_test_no_logical_date(self, mock_utcnow, mock_parse_and_get_dag):

dag_command.dag_test(cli_args)

mock_parse_and_get_dag.assert_has_calls(
mock_get_dag.assert_has_calls(
[
mock.call("example_bash_operator"),
mock.call(bundle_names=None, dag_id="example_bash_operator", dagfile_path=None),
mock.call().__bool__(),
mock.call().test(
logical_date=mock.ANY,
Expand All @@ -669,8 +669,8 @@ def test_dag_test_no_logical_date(self, mock_utcnow, mock_parse_and_get_dag):
]
)

@mock.patch("airflow.cli.commands.dag_command._parse_and_get_dag")
def test_dag_test_conf(self, mock_parse_and_get_dag):
@mock.patch("airflow.cli.commands.dag_command.get_dag")
def test_dag_test_conf(self, mock_get_dag):
cli_args = self.parser.parse_args(
[
"dags",
Expand All @@ -683,9 +683,9 @@ def test_dag_test_conf(self, mock_parse_and_get_dag):
)
dag_command.dag_test(cli_args)

mock_parse_and_get_dag.assert_has_calls(
mock_get_dag.assert_has_calls(
[
mock.call("example_bash_operator"),
mock.call(bundle_names=None, dag_id="example_bash_operator", dagfile_path=None),
mock.call().__bool__(),
mock.call().test(
logical_date=timezone.parse(DEFAULT_DATE.isoformat()),
Expand All @@ -697,11 +697,9 @@ def test_dag_test_conf(self, mock_parse_and_get_dag):
)

@mock.patch("airflow.cli.commands.dag_command.render_dag", return_value=MagicMock(source="SOURCE"))
@mock.patch("airflow.cli.commands.dag_command._parse_and_get_dag")
def test_dag_test_show_dag(self, mock_parse_and_get_dag, mock_render_dag):
mock_parse_and_get_dag.return_value.test.return_value.run_id = (
"__test_dag_test_show_dag_fake_dag_run_run_id__"
)
@mock.patch("airflow.cli.commands.dag_command.get_dag")
def test_dag_test_show_dag(self, mock_get_dag, mock_render_dag):
mock_get_dag.return_value.test.return_value.run_id = "__test_dag_test_show_dag_fake_dag_run_run_id__"

cli_args = self.parser.parse_args(
["dags", "test", "example_bash_operator", DEFAULT_DATE.isoformat(), "--show-dagrun"]
Expand All @@ -711,9 +709,9 @@ def test_dag_test_show_dag(self, mock_parse_and_get_dag, mock_render_dag):

output = stdout.getvalue()

mock_parse_and_get_dag.assert_has_calls(
mock_get_dag.assert_has_calls(
[
mock.call("example_bash_operator"),
mock.call(bundle_names=None, dag_id="example_bash_operator", dagfile_path=None),
mock.call().__bool__(),
mock.call().test(
logical_date=timezone.parse(DEFAULT_DATE.isoformat()),
Expand All @@ -723,9 +721,88 @@ def test_dag_test_show_dag(self, mock_parse_and_get_dag, mock_render_dag):
),
]
)
mock_render_dag.assert_has_calls([mock.call(mock_parse_and_get_dag.return_value, tis=[])])
mock_render_dag.assert_has_calls([mock.call(mock_get_dag.return_value, tis=[])])
assert "SOURCE" in output

@mock.patch("airflow.models.dagbag.DagBag")
def test_dag_test_with_bundle_name(self, mock_dagbag, configure_dag_bundles):
"""Test that DAG can be tested using bundle name."""
mock_dagbag.return_value.get_dag.return_value.test.return_value = DagRun(
dag_id="test_example_bash_operator", logical_date=DEFAULT_DATE, state=DagRunState.SUCCESS
)

cli_args = self.parser.parse_args(
[
"dags",
"test",
"test_example_bash_operator",
DEFAULT_DATE.isoformat(),
"--bundle-name",
"testing",
]
)

with configure_dag_bundles({"testing": TEST_DAGS_FOLDER}):
dag_command.dag_test(cli_args)

mock_dagbag.assert_called_once_with(
bundle_path=TEST_DAGS_FOLDER,
dag_folder=TEST_DAGS_FOLDER,
include_examples=False,
)

@mock.patch("airflow.models.dagbag.DagBag")
def test_dag_test_with_dagfile_path(self, mock_dagbag, configure_dag_bundles):
"""Test that DAG can be tested using dagfile path."""
mock_dagbag.return_value.get_dag.return_value.test.return_value = DagRun(
dag_id="test_example_bash_operator", logical_date=DEFAULT_DATE, state=DagRunState.SUCCESS
)

dag_file = TEST_DAGS_FOLDER / "test_example_bash_operator.py"

cli_args = self.parser.parse_args(
["dags", "test", "test_example_bash_operator", "--dagfile-path", str(dag_file)]
)
with configure_dag_bundles({"testing": TEST_DAGS_FOLDER}):
dag_command.dag_test(cli_args)

mock_dagbag.assert_called_once_with(
bundle_path=TEST_DAGS_FOLDER,
dag_folder=str(dag_file),
include_examples=False,
)

@mock.patch("airflow.models.dagbag.DagBag")
def test_dag_test_with_both_bundle_and_dagfile_path(self, mock_dagbag, configure_dag_bundles):
"""Test that DAG can be tested using both bundle name and dagfile path."""
mock_dagbag.return_value.get_dag.return_value.test.return_value = DagRun(
dag_id="test_example_bash_operator", logical_date=DEFAULT_DATE, state=DagRunState.SUCCESS
)

dag_file = TEST_DAGS_FOLDER / "test_example_bash_operator.py"

cli_args = self.parser.parse_args(
[
"dags",
"test",
"test_example_bash_operator",
DEFAULT_DATE.isoformat(),
"--bundle-name",
"testing",
"--dagfile-path",
str(dag_file),
]
)

with configure_dag_bundles({"testing": TEST_DAGS_FOLDER}):
dag_command.dag_test(cli_args)

mock_dagbag.assert_called_once_with(
bundle_path=TEST_DAGS_FOLDER,
dag_folder=str(dag_file),
include_examples=False,
)

@mock.patch("airflow.models.dag._get_or_create_dagrun")
def test_dag_test_with_custom_timetable(self, mock__get_or_create_dagrun):
"""
Expand Down
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ Dagbag
dagbag
dagbags
DagCallbackRequest
dagfile
DagFileProcessorManager
dagfolder
dagmodel
Expand Down