diff --git a/airflow-core/src/airflow/cli/cli_config.py b/airflow-core/src/airflow/cli/cli_config.py
index 61622dbeefc4c..dd1fab0456e17 100644
--- a/airflow-core/src/airflow/cli/cli_config.py
+++ b/airflow-core/src/airflow/cli/cli_config.py
@@ -355,6 +355,13 @@ def string_lower_type(val):
 )
 
 # test_dag
+ARG_DAGFILE_PATH = Arg(
+    (
+        "-f",
+        "--dagfile-path",
+    ),
+    help="Path to the dag file. Can be absolute or relative to current directory",
+)
 ARG_SHOW_DAGRUN = Arg(
     ("--show-dagrun",),
     help=(
@@ -1124,6 +1131,16 @@ class GroupCommand(NamedTuple):
         description=(
             "Execute one single DagRun for a given DAG and logical date.\n"
             "\n"
+            "You can test a DAG in three ways:\n"
+            "1. Using default bundle:\n"
+            "   airflow dags test <dag_id>\n"
+            "\n"
+            "2. Using a specific bundle if multiple DAG bundles are configured:\n"
+            "   airflow dags test <dag_id> --bundle-name <bundle_name> (or -B <bundle_name>)\n"
+            "\n"
+            "3. Using a specific DAG file:\n"
+            "   airflow dags test <dag_id> --dagfile-path <dagfile_path> (or -f <dagfile_path>)\n"
+            "\n"
             "The --imgcat-dagrun option only works in iTerm.\n"
             "\n"
             "For more information, see: https://www.iterm2.com/documentation-images.html\n"
@@ -1144,6 +1161,8 @@ class GroupCommand(NamedTuple):
         args=(
             ARG_DAG_ID,
             ARG_LOGICAL_DATE_OPTIONAL,
+            ARG_BUNDLE_NAME,
+            ARG_DAGFILE_PATH,
             ARG_CONF,
             ARG_SHOW_DAGRUN,
             ARG_IMGCAT_DAGRUN,
diff --git a/airflow-core/src/airflow/cli/commands/dag_command.py b/airflow-core/src/airflow/cli/commands/dag_command.py
index 1b4b017e6c793..3fd3f29f8f429 100644
--- a/airflow-core/src/airflow/cli/commands/dag_command.py
+++ b/airflow-core/src/airflow/cli/commands/dag_command.py
@@ -24,11 +24,9 @@
 import json
 import logging
 import operator
-import os
 import re
 import subprocess
 import sys
-from pathlib import Path
 from typing import TYPE_CHECKING
 
 from sqlalchemy import func, select
@@ -41,10 +39,8 @@
 from airflow.exceptions import AirflowException
 from airflow.jobs.job import Job
 from airflow.models import DagBag, DagModel, DagRun, TaskInstance
-from airflow.models.dag import DAG
 from airflow.models.errors import ParseImportError
 from airflow.models.serialized_dag import SerializedDagModel
-from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
 from airflow.utils import cli as cli_utils, timezone
 from airflow.utils.cli import get_dag, suppress_logs_and_warning, validate_dag_bundle_arg
 from airflow.utils.dot_renderer import render_dag, render_dag_dependencies
@@ -57,6 +53,7 @@
     from graphviz.dot import Dot
     from sqlalchemy.orm import Session
 
+    from airflow.models.dag import DAG
     from airflow.timetables.base import DataInterval
 
 DAG_DETAIL_FIELDS = {*DAGResponse.model_fields, *DAGResponse.model_computed_fields}
@@ -591,29 +588,6 @@ def _render_dagrun(dr: DagRun) -> dict[str, str]:
     AirflowConsole().print_as(data=dag_runs, output=args.output, mapper=_render_dagrun)
 
 
-def _parse_and_get_dag(dag_id: str) -> DAG | None:
-    """Given a dag_id, determine the bundle and relative fileloc from the db, then parse and return the DAG."""
-    db_dag = get_dag(bundle_names=None, dag_id=dag_id, from_db=True)
-    bundle_name = db_dag.get_bundle_name()
-    if bundle_name is None:
-        raise AirflowException(
-            f"Bundle name for DAG {dag_id!r} is not found in the database. This should not happen."
-        )
-    if db_dag.relative_fileloc is None:
-        raise AirflowException(
-            f"Relative fileloc for DAG {dag_id!r} is not found in the database. This should not happen."
-        )
-    bundle = DagBundlesManager().get_bundle(bundle_name)
-    bundle.initialize()
-    dag_absolute_path = os.fspath(Path(bundle.path, db_dag.relative_fileloc))
-
-    with _airflow_parsing_context_manager(dag_id=dag_id):
-        bag = DagBag(
-            dag_folder=dag_absolute_path, include_examples=False, safe_mode=False, load_op_links=False
-        )
-        return bag.dags.get(dag_id)
-
-
 @cli_utils.action_cli
 @providers_configuration_loaded
 @provide_session
@@ -632,13 +606,12 @@ def dag_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> No
         re.compile(args.mark_success_pattern) if args.mark_success_pattern is not None else None
     )
 
-    dag = dag or _parse_and_get_dag(args.dag_id)
+    dag = dag or get_dag(bundle_names=args.bundle_name, dag_id=args.dag_id, dagfile_path=args.dagfile_path)
     if not dag:
         raise AirflowException(
             f"Dag {args.dag_id!r} could not be found; either it does not exist or it failed to parse."
         )
 
-    dag = DAG.from_sdk_dag(dag)
     dr: DagRun = dag.test(
         logical_date=logical_date,
         run_conf=run_conf,
diff --git a/airflow-core/src/airflow/utils/cli.py b/airflow-core/src/airflow/utils/cli.py
index 0a7980ec44cbe..8bde3f9abc8c0 100644
--- a/airflow-core/src/airflow/utils/cli.py
+++ b/airflow-core/src/airflow/utils/cli.py
@@ -35,6 +35,7 @@
 from airflow import settings
 from airflow.dag_processing.bundles.manager import DagBundlesManager
 from airflow.exceptions import AirflowException
+from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
 from airflow.sdk.execution_time.secrets_masker import should_hide_value_for_key
 from airflow.utils import cli_action_loggers, timezone
 from airflow.utils.log.non_caching_file_handler import NonCachingFileHandler
@@ -258,7 +259,9 @@ def _search_for_dag_file(val: str | None) -> str | None:
     return None
 
 
-def get_dag(bundle_names: list | None, dag_id: str, from_db: bool = False) -> DAG:
+def get_dag(
+    bundle_names: list | None, dag_id: str, from_db: bool = False, dagfile_path: str | None = None
+) -> DAG:
     """
     Return DAG of a given dag_id.
 
@@ -277,7 +280,10 @@ def get_dag(bundle_names: list | None, dag_id: str, from_db: bool = False) -> DA
         manager = DagBundlesManager()
         for bundle_name in bundle_names:
             bundle = manager.get_bundle(bundle_name)
-            dagbag = DagBag(dag_folder=bundle.path, bundle_path=bundle.path, include_examples=False)
+            with _airflow_parsing_context_manager(dag_id=dag_id):
+                dagbag = DagBag(
+                    dag_folder=dagfile_path or bundle.path, bundle_path=bundle.path, include_examples=False
+                )
             dag = dagbag.dags.get(dag_id)
             if dag:
                 break
@@ -287,7 +293,10 @@ def get_dag(bundle_names: list | None, dag_id: str, from_db: bool = False) -> DA
         manager = DagBundlesManager()
         all_bundles = list(manager.get_all_dag_bundles())
         for bundle in all_bundles:
-            dag_bag = DagBag(dag_folder=bundle.path, bundle_path=bundle.path)
+            with _airflow_parsing_context_manager(dag_id=dag_id):
+                dag_bag = DagBag(
+                    dag_folder=dagfile_path or bundle.path, bundle_path=bundle.path, include_examples=False
+                )
             dag = dag_bag.dags.get(dag_id)
             if dag:
                 break
diff --git a/airflow-core/tests/unit/cli/commands/test_dag_command.py b/airflow-core/tests/unit/cli/commands/test_dag_command.py
index 052138d3693b7..fb2daebf418f2 100644
--- a/airflow-core/tests/unit/cli/commands/test_dag_command.py
+++ b/airflow-core/tests/unit/cli/commands/test_dag_command.py
@@ -617,14 +617,14 @@ def test_dag_state(self):
             is None
         )
 
-    @mock.patch("airflow.cli.commands.dag_command._parse_and_get_dag")
-    def test_dag_test(self, mock_parse_and_get_dag):
+    @mock.patch("airflow.cli.commands.dag_command.get_dag")
+    def test_dag_test(self, mock_get_dag):
         cli_args = self.parser.parse_args(["dags", "test", "example_bash_operator", DEFAULT_DATE.isoformat()])
         dag_command.dag_test(cli_args)
 
-        mock_parse_and_get_dag.assert_has_calls(
+        mock_get_dag.assert_has_calls(
             [
-                mock.call("example_bash_operator"),
+                mock.call(bundle_names=None, dag_id="example_bash_operator", dagfile_path=None),
                 mock.call().__bool__(),
                 mock.call().test(
                     logical_date=timezone.parse(DEFAULT_DATE.isoformat()),
@@ -635,19 +635,19 @@ def test_dag_test(self, mock_parse_and_get_dag):
             ]
         )
 
-    @mock.patch("airflow.cli.commands.dag_command._parse_and_get_dag")
-    def test_dag_test_fail_raise_error(self, mock_parse_and_get_dag):
+    @mock.patch("airflow.cli.commands.dag_command.get_dag")
+    def test_dag_test_fail_raise_error(self, mock_get_dag):
         logical_date_str = DEFAULT_DATE.isoformat()
-        mock_parse_and_get_dag.return_value.test.return_value = DagRun(
+        mock_get_dag.return_value.test.return_value = DagRun(
             dag_id="example_bash_operator", logical_date=DEFAULT_DATE, state=DagRunState.FAILED
         )
         cli_args = self.parser.parse_args(["dags", "test", "example_bash_operator", logical_date_str])
         with pytest.raises(SystemExit, match=r"DagRun failed"):
             dag_command.dag_test(cli_args)
 
-    @mock.patch("airflow.cli.commands.dag_command._parse_and_get_dag")
+    @mock.patch("airflow.cli.commands.dag_command.get_dag")
     @mock.patch("airflow.utils.timezone.utcnow")
-    def test_dag_test_no_logical_date(self, mock_utcnow, mock_parse_and_get_dag):
+    def test_dag_test_no_logical_date(self, mock_utcnow, mock_get_dag):
         now = pendulum.now()
         mock_utcnow.return_value = now
         cli_args = self.parser.parse_args(["dags", "test", "example_bash_operator"])
@@ -656,9 +656,9 @@ def test_dag_test_no_logical_date(self, mock_utcnow, mock_parse_and_get_dag):
 
         dag_command.dag_test(cli_args)
 
-        mock_parse_and_get_dag.assert_has_calls(
+        mock_get_dag.assert_has_calls(
             [
-                mock.call("example_bash_operator"),
+                mock.call(bundle_names=None, dag_id="example_bash_operator", dagfile_path=None),
                 mock.call().__bool__(),
                 mock.call().test(
                     logical_date=mock.ANY,
@@ -669,8 +669,8 @@ def test_dag_test_no_logical_date(self, mock_utcnow, mock_parse_and_get_dag):
             ]
         )
-    @mock.patch("airflow.cli.commands.dag_command._parse_and_get_dag")
-    def test_dag_test_conf(self, mock_parse_and_get_dag):
+    @mock.patch("airflow.cli.commands.dag_command.get_dag")
+    def test_dag_test_conf(self, mock_get_dag):
         cli_args = self.parser.parse_args(
             [
                 "dags",
                 "test",
@@ -683,9 +683,9 @@ def test_dag_test_conf(self, mock_parse_and_get_dag):
         )
         dag_command.dag_test(cli_args)
 
-        mock_parse_and_get_dag.assert_has_calls(
+        mock_get_dag.assert_has_calls(
             [
-                mock.call("example_bash_operator"),
+                mock.call(bundle_names=None, dag_id="example_bash_operator", dagfile_path=None),
                 mock.call().__bool__(),
                 mock.call().test(
                     logical_date=timezone.parse(DEFAULT_DATE.isoformat()),
@@ -697,11 +697,9 @@ def test_dag_test_conf(self, mock_parse_and_get_dag):
         )
 
     @mock.patch("airflow.cli.commands.dag_command.render_dag", return_value=MagicMock(source="SOURCE"))
-    @mock.patch("airflow.cli.commands.dag_command._parse_and_get_dag")
-    def test_dag_test_show_dag(self, mock_parse_and_get_dag, mock_render_dag):
-        mock_parse_and_get_dag.return_value.test.return_value.run_id = (
-            "__test_dag_test_show_dag_fake_dag_run_run_id__"
-        )
+    @mock.patch("airflow.cli.commands.dag_command.get_dag")
+    def test_dag_test_show_dag(self, mock_get_dag, mock_render_dag):
+        mock_get_dag.return_value.test.return_value.run_id = "__test_dag_test_show_dag_fake_dag_run_run_id__"
 
         cli_args = self.parser.parse_args(
             ["dags", "test", "example_bash_operator", DEFAULT_DATE.isoformat(), "--show-dagrun"]
@@ -711,9 +709,9 @@ def test_dag_test_show_dag(self, mock_parse_and_get_dag, mock_render_dag):
 
         output = stdout.getvalue()
 
-        mock_parse_and_get_dag.assert_has_calls(
+        mock_get_dag.assert_has_calls(
             [
-                mock.call("example_bash_operator"),
+                mock.call(bundle_names=None, dag_id="example_bash_operator", dagfile_path=None),
                 mock.call().__bool__(),
                 mock.call().test(
                     logical_date=timezone.parse(DEFAULT_DATE.isoformat()),
@@ -723,9 +721,88 @@ def test_dag_test_show_dag(self, mock_parse_and_get_dag, mock_render_dag):
                 ),
             ]
         )
-        mock_render_dag.assert_has_calls([mock.call(mock_parse_and_get_dag.return_value, tis=[])])
+        mock_render_dag.assert_has_calls([mock.call(mock_get_dag.return_value, tis=[])])
         assert "SOURCE" in output
 
+    @mock.patch("airflow.models.dagbag.DagBag")
+    def test_dag_test_with_bundle_name(self, mock_dagbag, configure_dag_bundles):
+        """Test that DAG can be tested using bundle name."""
+        mock_dagbag.return_value.get_dag.return_value.test.return_value = DagRun(
+            dag_id="test_example_bash_operator", logical_date=DEFAULT_DATE, state=DagRunState.SUCCESS
+        )
+
+        cli_args = self.parser.parse_args(
+            [
+                "dags",
+                "test",
+                "test_example_bash_operator",
+                DEFAULT_DATE.isoformat(),
+                "--bundle-name",
+                "testing",
+            ]
+        )
+
+        with configure_dag_bundles({"testing": TEST_DAGS_FOLDER}):
+            dag_command.dag_test(cli_args)
+
+        mock_dagbag.assert_called_once_with(
+            bundle_path=TEST_DAGS_FOLDER,
+            dag_folder=TEST_DAGS_FOLDER,
+            include_examples=False,
+        )
+
+    @mock.patch("airflow.models.dagbag.DagBag")
+    def test_dag_test_with_dagfile_path(self, mock_dagbag, configure_dag_bundles):
+        """Test that DAG can be tested using dagfile path."""
+        mock_dagbag.return_value.get_dag.return_value.test.return_value = DagRun(
+            dag_id="test_example_bash_operator", logical_date=DEFAULT_DATE, state=DagRunState.SUCCESS
+        )
+
+        dag_file = TEST_DAGS_FOLDER / "test_example_bash_operator.py"
+
+        cli_args = self.parser.parse_args(
+            ["dags", "test", "test_example_bash_operator", "--dagfile-path", str(dag_file)]
+        )
+        with configure_dag_bundles({"testing": TEST_DAGS_FOLDER}):
+            dag_command.dag_test(cli_args)
+
+        mock_dagbag.assert_called_once_with(
+            bundle_path=TEST_DAGS_FOLDER,
+            dag_folder=str(dag_file),
+            include_examples=False,
+        )
+
+    @mock.patch("airflow.models.dagbag.DagBag")
+    def test_dag_test_with_both_bundle_and_dagfile_path(self, mock_dagbag, configure_dag_bundles):
+        """Test that DAG can be tested using both bundle name and dagfile path."""
+        mock_dagbag.return_value.get_dag.return_value.test.return_value = DagRun(
+            dag_id="test_example_bash_operator", logical_date=DEFAULT_DATE, state=DagRunState.SUCCESS
+        )
+
+        dag_file = TEST_DAGS_FOLDER / "test_example_bash_operator.py"
+
+        cli_args = self.parser.parse_args(
+            [
+                "dags",
+                "test",
+                "test_example_bash_operator",
+                DEFAULT_DATE.isoformat(),
+                "--bundle-name",
+                "testing",
+                "--dagfile-path",
+                str(dag_file),
+            ]
+        )
+
+        with configure_dag_bundles({"testing": TEST_DAGS_FOLDER}):
+            dag_command.dag_test(cli_args)
+
+        mock_dagbag.assert_called_once_with(
+            bundle_path=TEST_DAGS_FOLDER,
+            dag_folder=str(dag_file),
+            include_examples=False,
+        )
+
     @mock.patch("airflow.models.dag._get_or_create_dagrun")
     def test_dag_test_with_custom_timetable(self, mock__get_or_create_dagrun):
         """
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 3c9251a2b55ed..7babea50a69ef 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -374,6 +374,7 @@ Dagbag
 dagbag
 dagbags
 DagCallbackRequest
+dagfile
 DagFileProcessorManager
 dagfolder
 dagmodel
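
For reference, a minimal sketch (not part of the patch) of exercising the new flags programmatically, in the same style as the added unit tests. The DAG id "my_dag", the bundle name "my_bundle", and the file path are hypothetical placeholders; this assumes a working Airflow dev environment with the patch applied.

    from airflow.cli import cli_parser
    from airflow.cli.commands import dag_command

    # Build the real argparse parser Airflow's CLI uses, then run `dags test`
    # against a specific DAG file via the new --dagfile-path / -f option.
    parser = cli_parser.get_parser()
    args = parser.parse_args(
        ["dags", "test", "my_dag", "2025-01-01", "--dagfile-path", "/path/to/my_dag.py"]
    )
    dag_command.dag_test(args)

    # Equivalently, restrict parsing to one configured bundle via --bundle-name / -B.
    args = parser.parse_args(["dags", "test", "my_dag", "2025-01-01", "--bundle-name", "my_bundle"])
    dag_command.dag_test(args)

From the shell, the same invocations are the ones listed in the updated command description above (default bundle, --bundle-name, or --dagfile-path).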