diff --git a/airflow-core/src/airflow/__init__.py b/airflow-core/src/airflow/__init__.py index 229b068c76075..5c2d08b41c924 100644 --- a/airflow-core/src/airflow/__init__.py +++ b/airflow-core/src/airflow/__init__.py @@ -80,21 +80,20 @@ # Things to lazy import in form {local_name: ('target_module', 'target_name', 'deprecated')} __lazy_imports: dict[str, tuple[str, str, bool]] = { - "DAG": (".models.dag", "DAG", False), - "Asset": (".assets", "Asset", False), + "DAG": (".sdk", "DAG", False), + "Asset": (".sdk", "Asset", False), "XComArg": (".models.xcom_arg", "XComArg", False), "version": (".version", "", False), # Deprecated lazy imports "AirflowException": (".exceptions", "AirflowException", True), - "Dataset": (".sdk.definitions.asset", "Asset", True), + "Dataset": (".sdk", "Asset", True), } if TYPE_CHECKING: # These objects are imported by PEP-562, however, static analyzers and IDE's # have no idea about typing of these objects. # Add it under TYPE_CHECKING block should help with it. - from airflow.models.dag import DAG from airflow.models.xcom_arg import XComArg - from airflow.sdk.definitions.asset import Asset, Dataset + from airflow.sdk import DAG, Asset, Asset as Dataset def __getattr__(name: str): diff --git a/airflow-core/src/airflow/api/common/mark_tasks.py b/airflow-core/src/airflow/api/common/mark_tasks.py index fe02e3b462636..d424bab603a9b 100644 --- a/airflow-core/src/airflow/api/common/mark_tasks.py +++ b/airflow-core/src/airflow/api/common/mark_tasks.py @@ -33,9 +33,8 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session as SASession - from airflow.models.dag import DAG from airflow.models.mappedoperator import MappedOperator - from airflow.serialization.serialized_objects import SerializedBaseOperator + from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG Operator: TypeAlias = MappedOperator | SerializedBaseOperator @@ -86,9 +85,6 @@ def set_state( if not run_id: raise ValueError("Received tasks with no run_id") - if TYPE_CHECKING: - assert isinstance(dag, DAG) - dag_run_ids = get_run_ids(dag, run_id, future, past, session=session) task_id_map_index_list = list(find_task_relatives(tasks, downstream, upstream)) # now look for the task instances that are affected @@ -106,7 +102,7 @@ def set_state( def get_all_dag_task_query( - dag: DAG, + dag: SerializedDAG, state: TaskInstanceState, task_ids: list[str | tuple[str, int]], run_ids: Iterable[str], @@ -142,30 +138,44 @@ def find_task_relatives(tasks, downstream, upstream): @provide_session -def get_run_ids(dag: DAG, run_id: str, future: bool, past: bool, session: SASession = NEW_SESSION): +def get_run_ids(dag: SerializedDAG, run_id: str, future: bool, past: bool, session: SASession = NEW_SESSION): """Return DAG executions' run_ids.""" - current_dagrun = dag.get_dagrun(run_id=run_id, session=session) - if current_dagrun.logical_date is None: + current_logical_date = session.scalar( + select(DagRun.logical_date).where(DagRun.dag_id == dag.dag_id, DagRun.run_id == run_id) + ) + if current_logical_date is None: return [run_id] - last_dagrun = dag.get_last_dagrun(include_manually_triggered=True, session=session) - first_dagrun = session.scalar( - select(DagRun) + last_logical_date = session.scalar( + select(DagRun.logical_date) .where(DagRun.dag_id == dag.dag_id, DagRun.logical_date.is_not(None)) - .order_by(DagRun.logical_date.asc()) + .order_by(DagRun.logical_date.desc()) .limit(1) ) - if last_dagrun is None: + if last_logical_date is None: raise ValueError(f"DagRun for {dag.dag_id} not 
found") + first_logical_date = session.scalar( + select(DagRun.logical_date) + .where(DagRun.dag_id == dag.dag_id, DagRun.logical_date.is_not(None)) + .order_by(DagRun.logical_date.asc()) + .limit(1) + ) + # determine run_id range of dag runs and tasks to consider - end_date = last_dagrun.logical_date if future else current_dagrun.logical_date - start_date = current_dagrun.logical_date if not past else first_dagrun.logical_date + end_date = last_logical_date if future else current_logical_date + start_date = current_logical_date if not past else first_logical_date if not dag.timetable.can_be_scheduled: - # If the DAG never schedules, need to look at existing DagRun if the user wants future or - # past runs. - dag_runs = dag.get_dagruns_between(start_date=start_date, end_date=end_date, session=session) - run_ids = sorted({d.run_id for d in dag_runs}) + # If the DAG never schedules, need to look at existing DagRun if the + # user wants future or past runs. + dag_runs = session.scalars( + select(DagRun).where( + DagRun.dag_id == dag.dag_id, + DagRun.logical_date >= start_date, + DagRun.logical_date <= end_date, + ) + ) + run_ids = sorted(d.run_id for d in dag_runs) elif not dag.timetable.periodic: run_ids = [run_id] else: @@ -195,7 +205,7 @@ def _set_dag_run_state(dag_id: str, run_id: str, state: DagRunState, session: SA @provide_session def set_dag_run_state_to_success( *, - dag: DAG, + dag: SerializedDAG, run_id: str | None = None, commit: bool = False, session: SASession = NEW_SESSION, @@ -256,7 +266,7 @@ def set_dag_run_state_to_success( @provide_session def set_dag_run_state_to_failed( *, - dag: DAG, + dag: SerializedDAG, run_id: str | None = None, commit: bool = False, session: SASession = NEW_SESSION, @@ -348,7 +358,7 @@ def _set_runing_task(task: Operator) -> Operator: def __set_dag_run_state_to_running_or_queued( *, new_state: DagRunState, - dag: DAG, + dag: SerializedDAG, run_id: str | None = None, commit: bool = False, session: SASession, @@ -379,7 +389,7 @@ def __set_dag_run_state_to_running_or_queued( @provide_session def set_dag_run_state_to_queued( *, - dag: DAG, + dag: SerializedDAG, run_id: str | None = None, commit: bool = False, session: SASession = NEW_SESSION, diff --git a/airflow-core/src/airflow/api_fastapi/common/dagbag.py b/airflow-core/src/airflow/api_fastapi/common/dagbag.py index f1c7271b020f6..491a7131acc07 100644 --- a/airflow-core/src/airflow/api_fastapi/common/dagbag.py +++ b/airflow-core/src/airflow/api_fastapi/common/dagbag.py @@ -24,8 +24,8 @@ from airflow.models.dagbag import DBDagBag if TYPE_CHECKING: - from airflow.models.dag import DAG from airflow.models.dagrun import DagRun + from airflow.serialization.serialized_objects import SerializedDAG def create_dag_bag() -> DBDagBag: @@ -45,7 +45,7 @@ def dag_bag_from_app(request: Request) -> DBDagBag: def get_latest_version_of_dag( dag_bag: DBDagBag, dag_id: str, session: Session, include_reason: bool = False -) -> DAG: +) -> SerializedDAG: dag = dag_bag.get_latest_version_of_dag(dag_id, session=session) if not dag: if include_reason: @@ -60,7 +60,7 @@ def get_latest_version_of_dag( return dag -def get_dag_for_run(dag_bag: DBDagBag, dag_run: DagRun, session: Session) -> DAG: +def get_dag_for_run(dag_bag: DBDagBag, dag_run: DagRun, session: Session) -> SerializedDAG: dag = dag_bag.get_dag_for_run(dag_run, session=session) if not dag: raise HTTPException(status.HTTP_404_NOT_FOUND, f"The Dag with ID: `{dag_run.dag_id}` was not found") @@ -69,8 +69,8 @@ def get_dag_for_run(dag_bag: DBDagBag, dag_run: DagRun, 
session: Session) -> DAG def get_dag_for_run_or_latest_version( dag_bag: DBDagBag, dag_run: DagRun | None, dag_id: str | None, session: Session -) -> DAG: - dag: DAG | None = None +) -> SerializedDAG: + dag: SerializedDAG | None = None if dag_run: dag = dag_bag.get_dag_for_run(dag_run, session=session) elif dag_id: diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/dag_run.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/dag_run.py index 9a0bb16b91884..ed7aac2ae000f 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/dag_run.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/dag_run.py @@ -32,7 +32,7 @@ from airflow.utils.types import DagRunTriggeredByType, DagRunType if TYPE_CHECKING: - from airflow.models import DAG + from airflow.serialization.serialized_objects import SerializedDAG class DAGRunPatchStates(str, Enum): @@ -113,7 +113,7 @@ def check_data_intervals(cls, values): ) return values - def validate_context(self, dag: DAG) -> dict: + def validate_context(self, dag: SerializedDAG) -> dict: coerced_logical_date = timezone.coerce_datetime(self.logical_date) run_after = self.run_after or timezone.utcnow() data_interval = None diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py index eabf620c41be0..437e2ef8d6600 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py @@ -744,7 +744,6 @@ def post_clear_task_instances( common_params = { "dry_run": True, "task_ids": task_ids, - "dag_bag": dag_bag, "session": session, "run_on_latest_version": body.run_on_latest_version, "only_failed": body.only_failed, diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py b/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py index 4f68b808b7cf8..cbfe7bea3ef7b 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py @@ -39,8 +39,8 @@ from airflow.api_fastapi.core_api.security import GetUserDep from airflow.api_fastapi.core_api.services.public.common import BulkService from airflow.listeners.listener import get_listener_manager -from airflow.models.dag import DAG from airflow.models.taskinstance import TaskInstance as TI +from airflow.serialization.serialized_objects import SerializedDAG from airflow.utils.state import TaskInstanceState log = structlog.get_logger(__name__) @@ -55,7 +55,7 @@ def _patch_ti_validate_request( session: SessionDep, map_index: int | None = -1, update_mask: list[str] | None = Query(None), -) -> tuple[DAG, list[TI], dict]: +) -> tuple[SerializedDAG, list[TI], dict]: dag = get_latest_version_of_dag(dag_bag, dag_id, session) if not dag.has_task(task_id): raise HTTPException(status.HTTP_404_NOT_FOUND, f"Task '{task_id}' not found in DAG '{dag_id}'") @@ -94,7 +94,7 @@ def _patch_ti_validate_request( def _patch_task_instance_state( task_id: str, dag_run_id: str, - dag: DAG, + dag: SerializedDAG, task_instance_body: BulkTaskInstanceBody | PatchTaskInstanceBody, data: dict, session: Session, diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/calendar.py b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/calendar.py index 8562d305406ac..2e352d26e2478 100644 --- 
a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/calendar.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/calendar.py @@ -34,8 +34,8 @@ CalendarTimeRangeCollectionResponse, CalendarTimeRangeResponse, ) -from airflow.models.dag import DAG from airflow.models.dagrun import DagRun +from airflow.serialization.serialized_objects import SerializedDAG from airflow.timetables._cron import CronMixin from airflow.timetables.base import DataInterval, TimeRestriction from airflow.timetables.simple import ContinuousTimetable @@ -52,7 +52,7 @@ def get_calendar_data( self, dag_id: str, session: Session, - dag: DAG, + dag: SerializedDAG, logical_date: RangeFilter, granularity: Literal["hourly", "daily"] = "daily", ) -> CalendarTimeRangeCollectionResponse: @@ -126,7 +126,7 @@ def _get_historical_dag_runs( def _get_planned_dag_runs( self, - dag: DAG, + dag: SerializedDAG, raw_dag_states: list[Row], logical_date: RangeFilter, granularity: Literal["hourly", "daily"], @@ -152,7 +152,7 @@ def _get_planned_dag_runs( dag, last_data_interval, year, restriction, logical_date, granularity ) - def _should_calculate_planned_runs(self, dag: DAG, raw_dag_states: list[Row]) -> bool: + def _should_calculate_planned_runs(self, dag: SerializedDAG, raw_dag_states: list[Row]) -> bool: """Check if we should calculate planned runs.""" return ( bool(raw_dag_states) @@ -177,7 +177,7 @@ def _get_last_data_interval(self, raw_dag_states: list[Row]) -> DataInterval | N def _calculate_cron_planned_runs( self, - dag: DAG, + dag: SerializedDAG, last_data_interval: DataInterval, year: int, logical_date: RangeFilter, @@ -208,7 +208,7 @@ def _calculate_cron_planned_runs( def _calculate_timetable_planned_runs( self, - dag: DAG, + dag: SerializedDAG, last_data_interval: DataInterval, year: int, restriction: TimeRestriction, diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index 3892cf8d42967..ec59c5bd26d52 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -72,7 +72,8 @@ from sqlalchemy.sql.dml import Update from airflow.models.expandinput import SchedulerExpandInput - from airflow.sdk.types import Operator + from airflow.models.mappedoperator import MappedOperator + from airflow.serialization.serialized_objects import SerializedBaseOperator router = VersionedAPIRouter() @@ -254,7 +255,14 @@ def ti_run( if dag := dag_bag.get_dag_for_run(dag_run=dr, session=session): upstream_map_indexes = dict( - _get_upstream_map_indexes(dag.get_task(ti.task_id), ti.map_index, ti.run_id, session) + _get_upstream_map_indexes( + # TODO (GH-52141): This get_task should return scheduler + # types instead, but currently it inherits SDK's DAG. 
+ cast("MappedOperator | SerializedBaseOperator", dag.get_task(ti.task_id)), + ti.map_index, + ti.run_id, + session=session, + ) ) else: upstream_map_indexes = None @@ -285,7 +293,7 @@ def ti_run( def _get_upstream_map_indexes( - task: Operator, ti_map_index: int, run_id: str, session: SessionDep + task: MappedOperator | SerializedBaseOperator, ti_map_index: int, run_id: str, session: SessionDep ) -> Iterator[tuple[str, int | list[int] | None]]: task_mapped_group = task.get_closest_mapped_task_group() for upstream_task in task.upstream_list: diff --git a/airflow-core/src/airflow/cli/commands/dag_command.py b/airflow-core/src/airflow/cli/commands/dag_command.py index e1cbb105dda52..0df550aa49360 100644 --- a/airflow-core/src/airflow/cli/commands/dag_command.py +++ b/airflow-core/src/airflow/cli/commands/dag_command.py @@ -40,11 +40,12 @@ from airflow.exceptions import AirflowConfigException, AirflowException from airflow.jobs.job import Job from airflow.models import DagBag, DagModel, DagRun, TaskInstance +from airflow.models.dag import get_next_data_interval from airflow.models.dagbag import sync_bag_to_db from airflow.models.errors import ParseImportError from airflow.models.serialized_dag import SerializedDagModel from airflow.utils import cli as cli_utils -from airflow.utils.cli import get_dag, suppress_logs_and_warning, validate_dag_bundle_arg +from airflow.utils.cli import get_bagged_dag, suppress_logs_and_warning, validate_dag_bundle_arg from airflow.utils.dot_renderer import render_dag, render_dag_dependencies from airflow.utils.helpers import ask_yesno from airflow.utils.platform import getuser @@ -53,10 +54,12 @@ from airflow.utils.state import DagRunState if TYPE_CHECKING: + from collections.abc import Iterable + from graphviz.dot import Dot from sqlalchemy.orm import Session - from airflow.models.dag import DAG + from airflow import DAG from airflow.timetables.base import DataInterval DAG_DETAIL_FIELDS = {*DAGResponse.model_fields, *DAGResponse.model_computed_fields} @@ -125,21 +128,18 @@ def dag_unpause(args) -> None: @providers_configuration_loaded -def set_is_paused(is_paused: bool, args) -> None: +@provide_session +def set_is_paused(is_paused: bool, args, *, session: Session = NEW_SESSION) -> None: """Set is_paused for DAG by a given dag_id.""" - should_apply = True - with create_session() as session: - query = select(DagModel) - - if args.treat_dag_id_as_regex: - query = query.where(DagModel.dag_id.regexp_match(args.dag_id)) - else: - query = query.where(DagModel.dag_id == args.dag_id) - - query = query.where(DagModel.is_paused != is_paused) + query = select(DagModel) + if args.treat_dag_id_as_regex: + query = query.where(DagModel.dag_id.regexp_match(args.dag_id)) + else: + query = query.where(DagModel.dag_id == args.dag_id) - matched_dags = session.scalars(query).all() + query = query.where(DagModel.is_paused != is_paused) + matched_dags: list[DagModel] = session.scalars(query).all() if not matched_dags: print(f"No {'un' if is_paused else ''}paused DAGs were found") return @@ -151,18 +151,20 @@ def set_is_paused(is_paused: bool, args) -> None: f"{','.join(dags_ids)}" f"\n\nAre you sure? 
[y/n]" ) - should_apply = ask_yesno(question) + if not ask_yesno(question): + print("Operation cancelled by user") + return - if should_apply: - for dag_model in matched_dags: - dag_model.set_is_paused(is_paused=is_paused) + def _update_is_paused(dag_model: DagModel) -> bool: + old_is_paused = dag_model.is_paused + dag_model.is_paused = is_paused + return old_is_paused - AirflowConsole().print_as( - data=[{"dag_id": dag.dag_id, "is_paused": not dag.get_is_paused()} for dag in matched_dags], - output=args.output, - ) - else: - print("Operation cancelled by user") + old_values = [ + {"dag_id": dag_model.dag_id, "is_paused": _update_is_paused(dag_model)} for dag_model in matched_dags + ] + session.commit() + AirflowConsole().print_as(data=old_values, output=args.output) @providers_configuration_loaded @@ -192,7 +194,11 @@ def dag_dependencies_show(args) -> None: @providers_configuration_loaded def dag_show(args) -> None: """Display DAG or saves its graphic representation to the file.""" - dag = get_dag(bundle_names=None, dag_id=args.dag_id, from_db=True) + from airflow.models.serialized_dag import SerializedDagModel + + if not (dag := SerializedDagModel.get_dag(dag_id=args.dag_id)): + raise SystemExit(f"Can not find dag {args.dag_id!r} in database") + dot = render_dag(dag) filename = args.save imgcat = args.imgcat @@ -231,15 +237,16 @@ def _save_dot_to_file(dot: Dot, filename: str) -> None: print(f"File {filename} saved") -def _get_dagbag_dag_details(dag: DAG) -> dict: +def _get_dagbag_dag_details(dag: DAG, session: Session) -> dict: """Return a dagbag dag details dict.""" + dag_model: DagModel | None = session.get(DagModel, dag.dag_id) return { "dag_id": dag.dag_id, "dag_display_name": dag.dag_display_name, - "bundle_name": dag.get_bundle_name() if hasattr(dag, "get_bundle_name") else None, - "bundle_version": dag.get_bundle_version() if hasattr(dag, "get_bundle_version") else None, - "is_paused": dag.get_is_paused() if hasattr(dag, "get_is_paused") else None, - "is_stale": dag.get_is_stale() if hasattr(dag, "get_is_stale") else None, + "bundle_name": dag_model.bundle_name if dag_model else None, + "bundle_version": dag_model.bundle_version if dag_model else None, + "is_paused": dag_model.is_paused if dag_model else None, + "is_stale": dag_model.is_stale if dag_model else None, "last_parsed_time": None, "last_expired": None, "relative_fileloc": dag.relative_fileloc, @@ -302,14 +309,18 @@ def dag_next_execution(args) -> None: >>> airflow dags next-execution tutorial 2018-08-31 10:38:00 """ - dag = get_dag(bundle_names=None, dag_id=args.dag_id, from_db=True) + from airflow.models.serialized_dag import SerializedDagModel with create_session() as session: - last_parsed_dag: DagModel = session.scalars( - select(DagModel).where(DagModel.dag_id == dag.dag_id) - ).one() + dag = SerializedDagModel.get_dag(args.dag_id, session=session) + last_parsed_dag: DagModel | None = session.scalars( + select(DagModel).where(DagModel.dag_id == args.dag_id) + ).one_or_none() - if last_parsed_dag.get_is_paused(): + if not dag or not last_parsed_dag: + raise SystemExit(f"DAG: {args.dag_id} does not exist in the database") + + if last_parsed_dag.is_paused: print("[INFO] Please be reminded this DAG is PAUSED now.", file=sys.stderr) def print_execution_interval(interval: DataInterval | None): @@ -323,7 +334,7 @@ def print_execution_interval(interval: DataInterval | None): return print(interval.start.isoformat()) - next_interval = dag.get_next_data_interval(last_parsed_dag) + next_interval = 
get_next_data_interval(dag.timetable, last_parsed_dag) print_execution_interval(next_interval) for _ in range(1, args.num_executions): @@ -391,20 +402,23 @@ def dag_list_dags(args, session: Session = NEW_SESSION) -> None: def get_dag_detail(dag: DAG) -> dict: dag_model = DagModel.get_dagmodel(dag.dag_id, session=session) if dag_model: - dag_detail = DAGResponse.from_orm(dag_model).model_dump() + dag_detail = DAGResponse.model_validate(dag_model, from_attributes=True).model_dump() else: - dag_detail = _get_dagbag_dag_details(dag) + dag_detail = _get_dagbag_dag_details(dag, session) if not cols: return dag_detail return {col: dag_detail[col] for col in cols if col in DAG_DETAIL_FIELDS} - def filter_dags_by_bundle(dags: list[DAG], bundle_names: list[str] | None) -> list[DAG]: + def filter_dags_by_bundle(dags: Iterable[DAG], bundle_names: list[str] | None) -> Iterable[DAG]: """Filter DAGs based on the specified bundle name, if provided.""" if not bundle_names: return dags validate_dag_bundle_arg(bundle_names) - return [dag for dag in dags if dag.get_bundle_name() in bundle_names] + selected_dag_ids = set( + session.scalars(select(DagModel.dag_id).where(DagModel.bundle_name.in_(bundle_names))) + ) + return (dag for dag in dags if dag.dag_id in selected_dag_ids) AirflowConsole().print_as( data=sorted( @@ -612,7 +626,11 @@ def dag_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> No re.compile(args.mark_success_pattern) if args.mark_success_pattern is not None else None ) - dag = dag or get_dag(bundle_names=args.bundle_name, dag_id=args.dag_id, dagfile_path=args.dagfile_path) + dag = dag or get_bagged_dag( + bundle_names=args.bundle_name, + dag_id=args.dag_id, + dagfile_path=args.dagfile_path, + ) if not dag: raise AirflowException( f"Dag {args.dag_id!r} could not be found; either it does not exist or it failed to parse." 
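# --- Editor's note (illustrative sketch, not part of the patch) -------------
# The dag_command.py hunks above migrate from DAG instance methods to the new
# module-level data-interval helpers: the serialized DAG is read straight from
# the metadata database, and get_next_data_interval() now takes the timetable
# plus the DagModel row. A minimal sketch of that call pattern, assuming an
# installed Airflow with an initialized metadata database; "example_dag" is a
# placeholder DAG id.
from sqlalchemy import select

from airflow.models import DagModel
from airflow.models.dag import get_next_data_interval
from airflow.models.serialized_dag import SerializedDagModel
from airflow.utils.session import create_session

with create_session() as session:
    dag = SerializedDagModel.get_dag("example_dag", session=session)
    dag_model = session.scalars(
        select(DagModel).where(DagModel.dag_id == "example_dag")
    ).one_or_none()
    if dag and dag_model:
        # Previously: dag.get_next_data_interval(dag_model)
        print(get_next_data_interval(dag.timetable, dag_model))
# -----------------------------------------------------------------------------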
diff --git a/airflow-core/src/airflow/cli/commands/task_command.py b/airflow-core/src/airflow/cli/commands/task_command.py index 8962c6d640f1f..5b54d76100a50 100644 --- a/airflow-core/src/airflow/cli/commands/task_command.py +++ b/airflow-core/src/airflow/cli/commands/task_command.py @@ -33,9 +33,9 @@ from airflow.cli.utils import fetch_dag_run_from_run_id_or_logical_date_string from airflow.exceptions import AirflowConfigException, DagRunNotFound, TaskInstanceNotFound from airflow.models import TaskInstance -from airflow.models.dag import DAG as SchedulerDAG, _get_or_create_dagrun from airflow.models.dag_version import DagVersion -from airflow.models.dagrun import DagRun +from airflow.models.dagrun import DagRun, get_or_create_dagrun +from airflow.models.serialized_dag import SerializedDagModel from airflow.sdk.definitions.dag import DAG, _run_task from airflow.sdk.definitions.param import ParamsDict from airflow.serialization.serialized_objects import SerializedDAG @@ -43,11 +43,12 @@ from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS from airflow.utils import cli as cli_utils from airflow.utils.cli import ( - get_dag, + get_bagged_dag, get_dag_by_file_location, get_dags, suppress_logs_and_warning, ) +from airflow.utils.helpers import ask_yesno from airflow.utils.platform import getuser from airflow.utils.providers_configuration_loader import providers_configuration_loaded from airflow.utils.session import NEW_SESSION, create_session, provide_session @@ -144,7 +145,7 @@ def _get_dag_run( return dag_run, True if create_if_necessary == "db": scheduler_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag)) # type: ignore[arg-type] - dag_run = _get_or_create_dagrun( + dag_run = get_or_create_dagrun( dag=scheduler_dag, run_id=_generate_temporary_run_id(), logical_date=dag_run_logical_date, @@ -245,7 +246,7 @@ def task_failed_deps(args) -> None: Trigger Rule: Task's trigger rule 'all_success' requires all upstream tasks to have succeeded, but found 1 non-success(es). """ - dag = get_dag(args.bundle_name, args.dag_id) + dag = get_bagged_dag(args.bundle_name, args.dag_id) # TODO (GH-52141): get_task in scheduler needs to return scheduler types # instead, but currently it inherits SDK's DAG. task = cast("Operator", dag.get_task(task_id=args.task_id)) @@ -271,7 +272,8 @@ def task_state(args) -> None: >>> airflow tasks state tutorial sleep 2015-01-01 success """ - dag = get_dag(args.bundle_name, args.dag_id, from_db=True) + if not (dag := SerializedDagModel.get_dag(args.dag_id)): + raise SystemExit(f"Can not find dag {args.dag_id!r}") # TODO (GH-52141): get_task in scheduler needs to return scheduler types # instead, but currently it inherits SDK's DAG. task = cast("Operator", dag.get_task(task_id=args.task_id)) @@ -284,7 +286,7 @@ def task_state(args) -> None: @providers_configuration_loaded def task_list(args, dag: DAG | None = None) -> None: """List the tasks within a DAG at the command line.""" - dag = dag or get_dag(args.bundle_name, args.dag_id) + dag = dag or get_bagged_dag(args.bundle_name, args.dag_id) tasks = sorted(t.task_id for t in dag.tasks) print("\n".join(tasks)) @@ -387,7 +389,7 @@ def task_test(args, dag: DAG | None = None) -> None: env_vars.update(args.env_vars) os.environ.update(env_vars) - dag = dag or get_dag(args.bundle_name, args.dag_id) + dag = dag or get_bagged_dag(args.bundle_name, args.dag_id) # TODO (GH-52141): get_task in scheduler needs to return scheduler types # instead, but currently it inherits SDK's DAG. 
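# --- Editor's note (illustrative sketch, not part of the patch) -------------
# The task_command.py changes (_get_dag_run above, task_render below) obtain a
# scheduler-side DAG by round-tripping the author-facing SDK DAG through
# serialization. A small sketch of that pattern, assuming Airflow is installed;
# the empty DAG is a stand-in for one parsed from a bundle via get_bagged_dag().
from airflow.sdk import DAG
from airflow.serialization.serialized_objects import SerializedDAG

with DAG(dag_id="example_dag") as authored_dag:
    pass  # tasks would be defined here

# Serialize the SDK DAG and deserialize it back into the scheduler-side type.
scheduler_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(authored_dag))
print(type(scheduler_dag).__name__)  # SerializedDAG
# -----------------------------------------------------------------------------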
@@ -429,25 +431,38 @@ def task_test(args, dag: DAG | None = None) -> None: def task_render(args, dag: DAG | None = None) -> None: """Render and displays templated fields for a given task.""" if not dag: - dag = get_dag(args.bundle_name, args.dag_id) - # TODO (GH-52141): get_task in scheduler needs to return scheduler types - # instead, but currently it inherits SDK's DAG. - task = cast("Operator", dag.get_task(task_id=args.task_id)) + dag = get_bagged_dag(args.bundle_name, args.dag_id) + serialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag)) ti, _ = _get_ti( - task, args.map_index, logical_date_or_run_id=args.logical_date_or_run_id, create_if_necessary="memory" + # TODO (GH-52141): get_task in scheduler needs to return scheduler types + # instead, but currently it inherits SDK's DAG. + cast("Operator", serialized_dag.get_task(task_id=args.task_id)), + args.map_index, + logical_date_or_run_id=args.logical_date_or_run_id, + create_if_necessary="memory", ) + with create_session() as session, set_current_task_instance_session(session=session): - ti.render_templates() - for attr in task.template_fields: - print( - textwrap.dedent( - f""" # ---------------------------------------------------------- - # property: {attr} - # ---------------------------------------------------------- - """ + context = ti.get_template_context(session=session) + task = dag.get_task(args.task_id) + # TODO (GH-52141): After sdk separation, ti.get_template_context() would + # contain serialized operators, but we need the real operators for + # rendering. This does not make sense and eventually we should rewrite + # this entire function so "ti" is a RuntimeTaskInstance instead, but for + # now we'll just manually fix it to contain the right objects. + context["task"] = context["ti"].task = task + task.render_template_fields(context) + for attr in context["task"].template_fields: + print( + textwrap.dedent( + f"""\ + # ---------------------------------------------------------- + # property: {attr} + # ---------------------------------------------------------- + """ + ) + + str(getattr(context["task"], attr)) # This shouldn't be dedented. ) - + str(getattr(ti.task, attr)) - ) @cli_utils.action_cli(check_db=False) @@ -470,11 +485,25 @@ def task_clear(args) -> None: include_upstream=args.upstream, ) - SchedulerDAG.clear_dags( + if not args.yes: + tis = SerializedDAG.clear_dags( + dags, + start_date=args.start_date, + end_date=args.end_date, + only_failed=args.only_failed, + only_running=args.only_running, + dry_run=True, + ) + if not tis: + return + if not ask_yesno(f"You are about to delete these {len(tis)} tasks:\n{tis}\n\nAre you sure? 
[y/n]"): + print("Cancelled, nothing was cleared.") + return + + SerializedDAG.clear_dags( dags, start_date=args.start_date, end_date=args.end_date, only_failed=args.only_failed, only_running=args.only_running, - confirm_prompt=not args.yes, ) diff --git a/airflow-core/src/airflow/cli/utils.py b/airflow-core/src/airflow/cli/utils.py index 4c7e6409e53a8..b221521e01d7e 100644 --- a/airflow-core/src/airflow/cli/utils.py +++ b/airflow-core/src/airflow/cli/utils.py @@ -79,10 +79,9 @@ def fetch_dag_run_from_run_id_or_logical_date_string( from sqlalchemy import select from airflow._shared.timezones import timezone - from airflow.models.dag import DAG from airflow.models.dagrun import DagRun - if dag_run := DAG.fetch_dagrun(dag_id=dag_id, run_id=value, session=session): + if dag_run := session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == value)): return dag_run, dag_run.logical_date try: logical_date = timezone.parse(value) diff --git a/airflow-core/src/airflow/dag_processing/collection.py b/airflow-core/src/airflow/dag_processing/collection.py index e9264c38d322e..8857a7f30250b 100644 --- a/airflow-core/src/airflow/dag_processing/collection.py +++ b/airflow-core/src/airflow/dag_processing/collection.py @@ -29,7 +29,7 @@ import logging import traceback -from typing import TYPE_CHECKING, NamedTuple +from typing import TYPE_CHECKING, NamedTuple, TypeVar from sqlalchemy import delete, func, insert, select, tuple_, update from sqlalchemy.exc import OperationalError @@ -48,12 +48,15 @@ TaskInletAssetReference, TaskOutletAssetReference, ) -from airflow.models.dag import DAG, DagModel, DagOwnerAttributes, DagTag +from airflow.models.dag import DagModel, DagOwnerAttributes, DagTag, get_run_data_interval from airflow.models.dagrun import DagRun from airflow.models.dagwarning import DagWarningType from airflow.models.errors import ParseImportError from airflow.models.trigger import Trigger -from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUriRef +from airflow.sdk import Asset, AssetAlias +from airflow.sdk.definitions.asset import AssetNameRef, AssetUriRef, BaseAsset +from airflow.serialization.enums import Encoding +from airflow.serialization.serialized_objects import BaseSerialization, LazyDeserializedDAG, SerializedDAG from airflow.triggers.base import BaseEventTrigger from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries from airflow.utils.sqlalchemy import with_row_locks @@ -66,14 +69,18 @@ from sqlalchemy.sql import Select from airflow.models.dagwarning import DagWarning - from airflow.serialization.serialized_objects import MaybeSerializedDAG from airflow.typing_compat import Self +AssetT = TypeVar("AssetT", bound=BaseAsset) + log = logging.getLogger(__name__) def _create_orm_dags( - bundle_name: str, dags: Iterable[MaybeSerializedDAG], *, session: Session + bundle_name: str, + dags: Iterable[LazyDeserializedDAG], + *, + session: Session, ) -> Iterator[DagModel]: for dag in dags: orm_dag = DagModel(dag_id=dag.dag_id, bundle_name=bundle_name) @@ -129,7 +136,7 @@ class _RunInfo(NamedTuple): num_active_runs: dict[str, int] @classmethod - def calculate(cls, dags: dict[str, MaybeSerializedDAG], *, session: Session) -> Self: + def calculate(cls, dags: dict[str, LazyDeserializedDAG], *, session: Session) -> Self: """ Query the the run counts from the db. 
@@ -175,7 +182,7 @@ def _update_dag_owner_links(dag_owner_links: dict[str, str], dm: DagModel, *, se def _serialize_dag_capturing_errors( - dag: MaybeSerializedDAG, bundle_name, session: Session, bundle_version: str | None + dag: LazyDeserializedDAG, bundle_name, session: Session, bundle_version: str | None ): """ Try to serialize the dag to the DB, but make a note of any errors. @@ -216,7 +223,7 @@ def _serialize_dag_capturing_errors( ] -def _sync_dag_perms(dag: MaybeSerializedDAG, session: Session): +def _sync_dag_perms(dag: LazyDeserializedDAG, session: Session): """Sync DAG specific permissions.""" dag_id = dag.dag_id @@ -334,7 +341,7 @@ def _update_import_errors( def update_dag_parsing_results_in_db( bundle_name: str, bundle_version: str | None, - dags: Collection[MaybeSerializedDAG], + dags: Collection[LazyDeserializedDAG], import_errors: dict[tuple[str, str], str], warnings: set[DagWarning], session: Session, @@ -371,12 +378,15 @@ def update_dag_parsing_results_in_db( ) log.debug("Calling the DAG.bulk_sync_to_db method") try: - DAG.bulk_write_to_db(bundle_name, bundle_version, dags, session=session) + SerializedDAG.bulk_write_to_db(bundle_name, bundle_version, dags, session=session) # Write Serialized DAGs to DB, capturing errors for dag in dags: serialize_errors.extend( _serialize_dag_capturing_errors( - dag=dag, bundle_name=bundle_name, bundle_version=bundle_version, session=session + dag=dag, + bundle_name=bundle_name, + bundle_version=bundle_version, + session=session, ) ) except OperationalError: @@ -384,7 +394,7 @@ def update_dag_parsing_results_in_db( raise # Only now we are "complete" do we update import_errors - don't want to record errors from # previous failed attempts - import_errors.update(dict(serialize_errors)) + import_errors.update(serialize_errors) # Record import errors into the ORM - we don't retry on this one as it's not as critical that it works try: # TODO: This won't clear errors for files that exist that no longer contain DAGs. 
Do we need to pass @@ -416,7 +426,7 @@ def update_dag_parsing_results_in_db( class DagModelOperation(NamedTuple): """Collect DAG objects and perform database operations for them.""" - dags: dict[str, MaybeSerializedDAG] + dags: dict[str, LazyDeserializedDAG] bundle_name: str bundle_version: str | None @@ -454,11 +464,7 @@ def update_dags( from airflow.configuration import conf # we exclude backfill from active run counts since their concurrency is separate - run_info = _RunInfo.calculate( - dags=self.dags, - session=session, - ) - + run_info = _RunInfo.calculate(dags=self.dags, session=session) for dag_id, dm in sorted(orm_dags.items()): dag = self.dags[dag_id] dm.fileloc = dag.fileloc @@ -516,7 +522,7 @@ def update_dags( if last_automated_run is None: last_automated_data_interval = None else: - last_automated_data_interval = dag.get_run_data_interval(last_automated_run) + last_automated_data_interval = get_run_data_interval(dag.timetable, last_automated_run) if run_info.num_active_runs.get(dag.dag_id, 0) >= dm.max_active_runs: dm.next_dagrun_create_after = None else: @@ -590,19 +596,41 @@ def expand_asset_expr(asset_expr: dict[str, list | dict]) -> dict[str, list | di dm.asset_expression = asset_expression -def _find_all_assets(dags: Iterable[MaybeSerializedDAG]) -> Iterator[Asset]: +def _get_task_ports(data: dict, inlets: bool, outlets: bool) -> Iterable[str]: + if inlets: + yield from data.get("inlets") or () + if outlets: + yield from data.get("outlets") or () + + +def _get_dag_assets( + dag: LazyDeserializedDAG, + of: type[AssetT], + *, + inlets: bool = True, + outlets: bool = True, +) -> Iterable[tuple[str, AssetT]]: + for task in dag.data["dag"]["tasks"]: + task = task[Encoding.VAR] + ports = _get_task_ports(task["partial_kwargs"] if task.get("_is_mapped") else task, inlets, outlets) + for port in ports: + if isinstance(obj := BaseSerialization.deserialize(port), of): + yield task["task_id"], obj + + +def _find_all_assets(dags: Iterable[LazyDeserializedDAG]) -> Iterator[Asset]: for dag in dags: for _, asset in dag.timetable.asset_condition.iter_assets(): yield asset - for _, asset in dag.get_task_assets(of_type=Asset): + for _, asset in _get_dag_assets(dag, of=Asset): yield asset -def _find_all_asset_aliases(dags: Iterable[MaybeSerializedDAG]) -> Iterator[AssetAlias]: +def _find_all_asset_aliases(dags: Iterable[LazyDeserializedDAG]) -> Iterator[AssetAlias]: for dag in dags: for _, alias in dag.timetable.asset_condition.iter_asset_aliases(): yield alias - for _, alias in dag.get_task_assets(of_type=AssetAlias): + for _, alias in _get_dag_assets(dag, of=AssetAlias): yield alias @@ -633,7 +661,7 @@ class AssetModelOperation(NamedTuple): asset_aliases: dict[str, AssetAlias] @classmethod - def collect(cls, dags: dict[str, MaybeSerializedDAG]) -> Self: + def collect(cls, dags: dict[str, LazyDeserializedDAG]) -> Self: coll = cls( schedule_asset_references={ dag_id: [asset for _, asset in dag.timetable.asset_condition.iter_assets()] @@ -656,10 +684,12 @@ def collect(cls, dags: dict[str, MaybeSerializedDAG]) -> Self: if isinstance(ref, AssetUriRef) }, inlet_references={ - dag_id: list(dag.get_task_assets(inlets=True, outlets=False)) for dag_id, dag in dags.items() + dag_id: list(_get_dag_assets(dag, Asset, inlets=True, outlets=False)) + for dag_id, dag in dags.items() }, outlet_references={ - dag_id: list(dag.get_task_assets(inlets=False, outlets=True)) for dag_id, dag in dags.items() + dag_id: list(_get_dag_assets(dag, Asset, inlets=False, outlets=True)) + for dag_id, dag in dags.items() }, 
assets={(asset.name, asset.uri): asset for asset in _find_all_assets(dags.values())}, asset_aliases={alias.name: alias for alias in _find_all_asset_aliases(dags.values())}, diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py index 7e29799021612..e5fbd9b436f77 100644 --- a/airflow-core/src/airflow/dag_processing/processor.py +++ b/airflow-core/src/airflow/dag_processing/processor.py @@ -186,10 +186,9 @@ def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileP serialized_dags, serialization_import_errors = _serialize_dags(bag, log) bag.import_errors.update(serialization_import_errors) - dags = [LazyDeserializedDAG(data=serdag) for serdag in serialized_dags] result = DagFileParsingResult( fileloc=msg.file, - serialized_dags=dags, + serialized_dags=serialized_dags, import_errors=bag.import_errors, # TODO: Make `bag.dag_warnings` not return SQLA model objects warnings=[], @@ -197,13 +196,16 @@ def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileP return result -def _serialize_dags(bag: DagBag, log: FilteringBoundLogger) -> tuple[list[dict], dict[str, str]]: +def _serialize_dags( + bag: DagBag, + log: FilteringBoundLogger, +) -> tuple[list[LazyDeserializedDAG], dict[str, str]]: serialization_import_errors = {} serialized_dags = [] for dag in bag.dags.values(): try: - serialized_dag = SerializedDAG.to_dict(dag) - serialized_dags.append(serialized_dag) + data = SerializedDAG.to_dict(dag) + serialized_dags.append(LazyDeserializedDAG(data=data, last_loaded=dag.last_loaded)) except Exception: log.exception("Failed to serialize DAG: %s", dag.fileloc) dagbag_import_error_traceback_depth = conf.getint( diff --git a/airflow-core/src/airflow/datasets/__init__.py b/airflow-core/src/airflow/datasets/__init__.py index d0622b67b1f19..b5aded5fa588d 100644 --- a/airflow-core/src/airflow/datasets/__init__.py +++ b/airflow-core/src/airflow/datasets/__init__.py @@ -30,10 +30,10 @@ # TODO: Remove this module in Airflow 3.2 _names_moved = { - "DatasetAlias": ("airflow.sdk.definitions.asset", "AssetAlias"), - "DatasetAll": ("airflow.sdk.definitions.asset", "AssetAll"), - "DatasetAny": ("airflow.sdk.definitions.asset", "AssetAny"), - "Dataset": ("airflow.sdk.definitions.asset", "Asset"), + "DatasetAlias": ("airflow.sdk", "AssetAlias"), + "DatasetAll": ("airflow.sdk", "AssetAll"), + "DatasetAny": ("airflow.sdk", "AssetAny"), + "Dataset": ("airflow.sdk", "Asset"), "expand_alias_to_datasets": ("airflow.models.asset", "expand_alias_to_assets"), } @@ -45,8 +45,8 @@ def __getattr__(name: str): module_path, new_name = _names_moved[name] warnings.warn( - f"Import 'airflow.dataset.{name}' is deprecated and " - f"will be removed in the Airflow 3.2. Please import it from '{module_path}.{new_name}'.", + f"Import 'airflow.datasets.{name}' is deprecated and " + f"will be removed in Airflow 3.2. Please import it from '{module_path}.{new_name}'.", DeprecationWarning, stacklevel=2, ) diff --git a/airflow-core/src/airflow/datasets/metadata.py b/airflow-core/src/airflow/datasets/metadata.py index ef4e8037faa63..a5abf24650232 100644 --- a/airflow-core/src/airflow/datasets/metadata.py +++ b/airflow-core/src/airflow/datasets/metadata.py @@ -19,13 +19,13 @@ import warnings -from airflow.sdk.definitions.asset.metadata import Metadata +from airflow.sdk import Metadata # TODO: Remove this module in Airflow 3.2 warnings.warn( - "Import from the airflow.dataset module is deprecated and " - "will be removed in the Airflow 3.2. 
Please import it from 'airflow.sdk.definitions.asset.metadata'.", + "Import from the airflow.datasets.metadata module is deprecated and will " + "be removed in Airflow 3.2. Please import it from 'airflow.sdk'.", DeprecationWarning, stacklevel=2, ) diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 89c9c2c7bfcf2..217d8d663c395 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -59,7 +59,7 @@ asset_trigger_association_table, ) from airflow.models.backfill import Backfill -from airflow.models.dag import DAG, DagModel +from airflow.models.dag import DagModel, get_next_data_interval, get_run_data_interval from airflow.models.dag_version import DagVersion from airflow.models.dagbag import DBDagBag from airflow.models.dagrun import DagRun @@ -94,7 +94,7 @@ from airflow.executors.executor_utils import ExecutorName from airflow.models.mappedoperator import MappedOperator from airflow.models.taskinstance import TaskInstanceKey - from airflow.serialization.serialized_objects import SerializedBaseOperator + from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG from airflow.utils.sqlalchemy import CommitProhibitorGuard TI = TaskInstance @@ -105,7 +105,7 @@ """:meta private:""" -def _get_current_dag(dag_id: str, session: Session) -> DAG | None: +def _get_current_dag(dag_id: str, session: Session) -> SerializedDAG | None: serdag = SerializedDagModel.get(dag_id=dag_id, session=session) # grabs the latest version if not serdag: return None @@ -1395,7 +1395,7 @@ def _do_scheduling(self, session: Session) -> int: # Send the callbacks after we commit to ensure the context is up to date when it gets run # cache saves time during scheduling of many dag_runs for same dag - cached_get_dag: Callable[[DagRun], DAG | None] = lru_cache()( + cached_get_dag: Callable[[DagRun], SerializedDAG | None] = lru_cache()( partial(self.scheduler_dag_bag.get_dag_for_run, session=session) ) for dag_run, callback_to_run in callback_tuples: @@ -1521,7 +1521,7 @@ def _create_dag_runs(self, dag_models: Collection[DagModel], session: Session) - self.log.error("DAG '%s' not found in serialized_dag table", dag_model.dag_id) continue - data_interval = dag.get_next_data_interval(dag_model) + data_interval = get_next_data_interval(dag.timetable, dag_model) # Explicitly check if the DagRun already exists. This is an edge case # where a Dag Run is created but `DagModel.next_dagrun` and `DagModel.next_dagrun_create_after` # are not updated. @@ -1642,7 +1642,7 @@ def _create_dag_runs_asset_triggered( def _should_update_dag_next_dagruns( self, - dag: DAG, + dag: SerializedDAG, dag_model: DagModel, *, last_dag_run: DagRun | None = None, @@ -1697,7 +1697,7 @@ def _start_queued_dagruns(self, session: Session) -> None: active_runs_of_dags = Counter({(dag_id, br_id): num for dag_id, br_id, num in session.execute(query)}) @add_debug_span - def _update_state(dag: DAG, dag_run: DagRun): + def _update_state(dag: SerializedDAG, dag_run: DagRun): span = Trace.get_current_span() span.set_attributes( { @@ -1723,7 +1723,7 @@ def _update_state(dag: DAG, dag_run: DagRun): # always happening immediately after the data interval. # We only publish these metrics for scheduled dag runs and only # when ``run_type`` is *MANUAL* and ``clear_number`` is 0. 
- expected_start_date = dag.get_run_data_interval(dag_run).end + expected_start_date = get_run_data_interval(dag.timetable, dag_run).end schedule_delay = dag_run.start_date - expected_start_date # Publish metrics twice with backward compatible name, and then with tags Stats.timing(f"dagrun.schedule_delay.{dag.dag_id}", schedule_delay) @@ -1739,7 +1739,7 @@ def _update_state(dag: DAG, dag_run: DagRun): ) # cache saves time during scheduling of many dag_runs for same dag - cached_get_dag: Callable[[DagRun], DAG | None] = lru_cache()( + cached_get_dag: Callable[[DagRun], SerializedDAG | None] = lru_cache()( partial(self.scheduler_dag_bag.get_dag_for_run, session=session) ) @@ -1858,7 +1858,7 @@ def _schedule_dag_run( if self._should_update_dag_next_dagruns( dag, dag_model, last_dag_run=dag_run, session=session ): - dag_model.calculate_dagrun_date_fields(dag, dag.get_run_data_interval(dag_run)) + dag_model.calculate_dagrun_date_fields(dag, get_run_data_interval(dag.timetable, dag_run)) callback_to_execute = DagCallbackRequest( filepath=dag_model.relative_fileloc, @@ -1918,7 +1918,7 @@ def _schedule_dag_run( schedulable_tis, callback_to_run = dag_run.update_state(session=session, execute_callbacks=False) if self._should_update_dag_next_dagruns(dag, dag_model, last_dag_run=dag_run, session=session): - dag_model.calculate_dagrun_date_fields(dag, dag.get_run_data_interval(dag_run)) + dag_model.calculate_dagrun_date_fields(dag, get_run_data_interval(dag.timetable, dag_run)) # This will do one query per dag run. We "could" build up a complex # query to update all the TIs across all the logical dates and dag # IDs in a single query, but it turns out that can be _very very slow_ @@ -1961,7 +1961,11 @@ def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session: Session) -> return True - def _send_dag_callbacks_to_processor(self, dag: DAG, callback: DagCallbackRequest | None = None) -> None: + def _send_dag_callbacks_to_processor( + self, + dag: SerializedDAG, + callback: DagCallbackRequest | None = None, + ) -> None: if callback: self.job.executor.send_callback(callback) else: diff --git a/airflow-core/src/airflow/models/__init__.py b/airflow-core/src/airflow/models/__init__.py index c87ec2a68d599..06fbeb78a7702 100644 --- a/airflow-core/src/airflow/models/__init__.py +++ b/airflow-core/src/airflow/models/__init__.py @@ -36,6 +36,7 @@ "DagRun", "DagTag", "DbCallbackRequest", + "Deadline", "Log", "MappedOperator", "Operator", @@ -44,6 +45,7 @@ "RenderedTaskInstanceFields", "SkipMixin", "TaskInstance", + "TaskInstanceHistory", "TaskReschedule", "Trigger", "Variable", @@ -89,11 +91,11 @@ def __getattr__(name): __lazy_imports = { "Job": "airflow.jobs.job", - "DAG": "airflow.models.dag", + "DAG": "airflow.sdk", "ID_LEN": "airflow.models.base", "Base": "airflow.models.base", - "BaseOperator": "airflow.sdk.bases.operator", - "BaseOperatorLink": "airflow.sdk.bases.operatorlink", + "BaseOperator": "airflow.sdk", + "BaseOperatorLink": "airflow.sdk", "BaseXCom": "airflow.sdk.bases.xcom", "Connection": "airflow.models.connection", "DagBag": "airflow.models.dagbag", @@ -124,7 +126,7 @@ def __getattr__(name): # having to resort back to this hacky method from airflow.models.base import ID_LEN, Base from airflow.models.connection import Connection - from airflow.models.dag import DAG, DagModel, DagTag + from airflow.models.dag import DagModel, DagTag from airflow.models.dagbag import DagBag from airflow.models.dagrun import DagRun from airflow.models.dagwarning import DagWarning @@ -140,10 +142,8 @@ def 
__getattr__(name): from airflow.models.taskreschedule import TaskReschedule from airflow.models.trigger import Trigger from airflow.models.variable import Variable - from airflow.sdk.bases.operator import BaseOperator - from airflow.sdk.bases.operatorlink import BaseOperatorLink + from airflow.sdk import DAG, BaseOperator, BaseOperatorLink, Param from airflow.sdk.bases.xcom import BaseXCom - from airflow.sdk.definitions.param import Param from airflow.sdk.execution_time.xcom import XCom @@ -157,7 +157,7 @@ def __getattr__(name): "DEFAULT_TASK_EXECUTION_TIMEOUT": "airflow.sdk.definitions._internal.abstractoperator.DEFAULT_TASK_EXECUTION_TIMEOUT", }, "param": { - "Param": "airflow.sdk.definitions.param.Param", + "Param": "airflow.sdk.Param", "ParamsDict": "airflow.sdk.definitions.param.ParamsDict", }, "baseoperator": { @@ -167,10 +167,10 @@ def __getattr__(name): "cross_downstream": "airflow.sdk.bases.operator.cross_downstream", }, "baseoperatorlink": { - "BaseOperatorLink": "airflow.sdk.bases.operatorlink.BaseOperatorLink", + "BaseOperatorLink": "airflow.sdk.BaseOperatorLink", }, "operator": { - "BaseOperator": "airflow.sdk.bases.operator.BaseOperator", + "BaseOperator": "airflow.sdk.BaseOperator", "Operator": "airflow.sdk.types.Operator", }, } diff --git a/airflow-core/src/airflow/models/backfill.py b/airflow-core/src/airflow/models/backfill.py index 4142265b8bd44..4715e0202e292 100644 --- a/airflow-core/src/airflow/models/backfill.py +++ b/airflow-core/src/airflow/models/backfill.py @@ -54,7 +54,7 @@ if TYPE_CHECKING: from datetime import datetime - from airflow.models.dag import DAG + from airflow.serialization.serialized_objects import SerializedDAG from airflow.timetables.base import DagRunInfo log = logging.getLogger(__name__) @@ -284,7 +284,7 @@ def _do_dry_run(*, dag_id, from_date, to_date, reverse, reprocess_behavior, sess def _create_backfill_dag_run( *, - dag: DAG, + dag: SerializedDAG, info: DagRunInfo, reprocess_behavior: ReprocessBehavior, backfill_id, @@ -415,7 +415,6 @@ def _handle_clear_run(session, dag, dr, info, backfill_id, sort_ordinal, run_on_ run_id=dr.run_id, dag_run_state=DagRunState.QUEUED, session=session, - confirm_prompt=False, dry_run=False, run_on_latest_version=run_on_latest, ) diff --git a/airflow-core/src/airflow/models/dag.py b/airflow-core/src/airflow/models/dag.py index 9d801c197e641..40a229e02dea8 100644 --- a/airflow-core/src/airflow/models/dag.py +++ b/airflow-core/src/airflow/models/dag.py @@ -17,26 +17,12 @@ # under the License. 
from __future__ import annotations -import copy -import functools import logging -import re from collections import defaultdict -from collections.abc import Callable, Collection, Generator, Iterable, Sequence +from collections.abc import Callable, Collection from datetime import datetime, timedelta -from functools import cache -from typing import ( - TYPE_CHECKING, - Any, - TypeVar, - Union, - cast, - overload, -) +from typing import TYPE_CHECKING, Any, TypeVar, Union, cast -import attrs -import methodtools -import pendulum import sqlalchemy_jsonfield from dateutil.relativedelta import relativedelta from sqlalchemy import ( @@ -52,68 +38,42 @@ func, or_, select, - tuple_, - update, ) from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import backref, load_only, relationship -from sqlalchemy.sql import Select, expression +from sqlalchemy.sql import expression -from airflow import settings, utils +from airflow import settings from airflow._shared.timezones import timezone from airflow.assets.evaluation import AssetEvaluator from airflow.configuration import conf as airflow_conf -from airflow.exceptions import ( - AirflowException, - UnknownExecutorException, -) -from airflow.executors.executor_loader import ExecutorLoader -from airflow.models import Deadline -from airflow.models.asset import ( - AssetDagRunQueue, - AssetModel, -) +from airflow.exceptions import AirflowException +from airflow.models.asset import AssetDagRunQueue, AssetModel from airflow.models.base import Base, StringID -from airflow.models.dag_version import DagVersion from airflow.models.dagbundle import DagBundleModel -from airflow.models.dagrun import RUN_ID_REGEX, DagRun -from airflow.models.taskinstance import ( - TaskInstance, - TaskInstanceKey, - clear_task_instances, -) -from airflow.models.tasklog import LogTemplate +from airflow.models.dagrun import DagRun from airflow.models.team import Team -from airflow.sdk import TaskGroup from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, BaseAsset -from airflow.sdk.definitions.dag import DAG as TaskSDKDag, dag as task_sdk_dag_decorator -from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineReference +from airflow.sdk.definitions.deadline import DeadlineAlert from airflow.settings import json -from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable +from airflow.timetables.base import DataInterval, Timetable from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable -from airflow.timetables.simple import ( - AssetTriggeredTimetable, - NullTimetable, - OnceTimetable, -) +from airflow.timetables.simple import AssetTriggeredTimetable, NullTimetable, OnceTimetable from airflow.utils.context import Context -from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.sqlalchemy import UtcDateTime, lock_rows, with_row_locks -from airflow.utils.state import DagRunState, TaskInstanceState -from airflow.utils.types import DagRunTriggeredByType, DagRunType +from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks +from airflow.utils.state import DagRunState +from airflow.utils.types import DagRunType if TYPE_CHECKING: - from typing import Literal, TypeAlias + from typing import TypeAlias - from pydantic import NonNegativeInt from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session - from 
airflow.models.dagbag import DBDagBag from airflow.models.mappedoperator import MappedOperator - from airflow.serialization.serialized_objects import MaybeSerializedDAG, SerializedBaseOperator + from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG Operator: TypeAlias = MappedOperator | SerializedBaseOperator @@ -129,6 +89,84 @@ ScheduleArg = ScheduleInterval | Timetable | BaseAsset | Collection[Union["Asset", "AssetAlias"]] +def infer_automated_data_interval(timetable: Timetable, logical_date: datetime) -> DataInterval: + """ + Infer a data interval for a run against this DAG. + + This method is used to bridge runs created prior to AIP-39 + implementation, which do not have an explicit data interval. Therefore, + this method only considers ``schedule_interval`` values valid prior to + Airflow 2.2. + + DO NOT call this method if there is a known data interval. + + :meta private: + """ + timetable_type = type(timetable) + if issubclass(timetable_type, (NullTimetable, OnceTimetable, AssetTriggeredTimetable)): + return DataInterval.exact(timezone.coerce_datetime(logical_date)) + start = timezone.coerce_datetime(logical_date) + if issubclass(timetable_type, CronDataIntervalTimetable): + end = cast("CronDataIntervalTimetable", timetable)._get_next(start) + elif issubclass(timetable_type, DeltaDataIntervalTimetable): + end = cast("DeltaDataIntervalTimetable", timetable)._get_next(start) + # Contributors: When the exception below is raised, you might want to + # add an 'elif' block here to handle custom timetables. Stop! The bug + # you're looking for is instead at when the DAG run (represented by + # logical_date) was created. See GH-31969 for an example: + # * Wrong fix: GH-32074 (modifies this function). + # * Correct fix: GH-32118 (modifies the DAG run creation code). + else: + raise ValueError(f"Not a valid timetable: {timetable!r}") + return DataInterval(start, end) + + +def get_run_data_interval(timetable: Timetable, run: DagRun) -> DataInterval: + """ + Get the data interval of this run. + + For compatibility, this method infers the data interval from the DAG's + schedule if the run does not have an explicit one set, which is possible for + runs created prior to AIP-39. + + This function is private to Airflow core and should not be depended on as a + part of the Python API. + + :meta private: + """ + data_interval = _get_model_data_interval(run, "data_interval_start", "data_interval_end") + if data_interval is not None: + return data_interval + # Compatibility: runs created before AIP-39 implementation don't have an + # explicit data interval. Try to infer from the logical date. + return infer_automated_data_interval(timetable, run.logical_date) + + +def get_next_data_interval(timetable: Timetable, dag_model: DagModel) -> DataInterval | None: + """ + Get the data interval of the next scheduled run. + + For compatibility, this method infers the data interval from the DAG's + schedule if the run does not have an explicit one set, which is possible + for runs created prior to AIP-39. + + This function is private to Airflow core and should not be depended on as a + part of the Python API. + + :meta private: + """ + if dag_model.next_dagrun is None: # Next run not scheduled. + return None + data_interval = dag_model.next_dagrun_data_interval + if data_interval is not None: + return data_interval + + # Compatibility: A run was scheduled without an explicit data interval. + # This means the run was scheduled before AIP-39 implementation. 
Try to + # infer from the logical date. + return infer_automated_data_interval(timetable, dag_model.next_dagrun) + + class InconsistentDataInterval(AirflowException): """ Exception raised when a model populates data interval fields incorrectly. @@ -224,1606 +262,6 @@ def get_asset_triggered_next_run_info( } -@provide_session -def _create_orm_dagrun( - *, - dag: DAG, - run_id: str, - logical_date: datetime | None, - data_interval: DataInterval | None, - run_after: datetime, - start_date: datetime | None, - conf: Any, - state: DagRunState | None, - run_type: DagRunType, - creating_job_id: int | None, - backfill_id: NonNegativeInt | None, - triggered_by: DagRunTriggeredByType, - triggering_user_name: str | None = None, - session: Session = NEW_SESSION, -) -> DagRun: - bundle_version = None - if not dag.disable_bundle_versioning: - bundle_version = session.scalar( - select(DagModel.bundle_version).where(DagModel.dag_id == dag.dag_id), - ) - dag_version = DagVersion.get_latest_version(dag.dag_id, session=session) - if not dag_version: - raise AirflowException(f"Cannot create DagRun for DAG {dag.dag_id} because the dag is not serialized") - - run = DagRun( - dag_id=dag.dag_id, - run_id=run_id, - logical_date=logical_date, - start_date=start_date, - run_after=run_after, - conf=conf, - state=state, - run_type=run_type, - creating_job_id=creating_job_id, - data_interval=data_interval, - triggered_by=triggered_by, - triggering_user_name=triggering_user_name, - backfill_id=backfill_id, - bundle_version=bundle_version, - ) - # Load defaults into the following two fields to ensure result can be serialized detached - run.log_template_id = int(session.scalar(select(func.max(LogTemplate.__table__.c.id)))) - run.created_dag_version = dag_version - run.consumed_asset_events = [] - session.add(run) - session.flush() - run.dag = dag - # create the associated task instances - # state is None at the moment of creation - run.verify_integrity(session=session, dag_version_id=dag_version.id) - return run - - -if TYPE_CHECKING: - dag = task_sdk_dag_decorator -else: - - def dag(dag_id: str = "", **kwargs): - return task_sdk_dag_decorator(dag_id, __DAG_class=DAG, __warnings_stacklevel_delta=3, **kwargs) - - -def _convert_max_consecutive_failed_dag_runs(val: int) -> int: - if val == 0: - val = airflow_conf.getint("core", "max_consecutive_failed_dag_runs_per_dag") - if val < 0: - raise ValueError( - f"Invalid max_consecutive_failed_dag_runs: {val}. Requires max_consecutive_failed_dag_runs >= 0" - ) - return val - - -@functools.total_ordering -@attrs.define(hash=False, repr=False, eq=False, slots=False) -class DAG(TaskSDKDag, LoggingMixin): - """ - A dag is a collection of tasks with directional dependencies. - - A dag also has a schedule, a start date and an end date (optional). For each schedule, - (say daily or hourly), the DAG needs to run each individual tasks as their dependencies - are met. Certain tasks have the property of depending on their own past, meaning that - they can't run until their previous schedule (and upstream tasks) are completed. - - DAGs essentially act as namespaces for tasks. A task_id can only be - added once to a DAG. - - Note that if you plan to use time zones all the dates provided should be pendulum - dates. See :ref:`timezone_aware_dags`. - - .. versionadded:: 2.4 - The *schedule* argument to specify either time-based scheduling logic - (timetable), or asset-driven triggers. - - .. versionchanged:: 3.0 - The default value of *schedule* has been changed to *None* (no schedule). 
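# Hedged sketch (editorial addition, not part of the patch): how a caller might
# use the new module-level data-interval helpers added above instead of the
# removed DAG.get_run_data_interval()/infer_automated_data_interval() methods.
# The import path assumes the helpers live in airflow.models.dag, the file this
# hunk appears to touch.
from airflow.models.dag import get_next_data_interval, get_run_data_interval

def data_interval_for_run(serialized_dag, dag_run):
    # The timetable is passed explicitly now; the run does not need a DAG object.
    return get_run_data_interval(serialized_dag.timetable, dag_run)

def data_interval_for_next_run(serialized_dag, dag_model):
    # Returns None when no next run is scheduled.
    return get_next_data_interval(serialized_dag.timetable, dag_model)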
- The previous default was ``timedelta(days=1)``. - - :param dag_id: The id of the DAG; must consist exclusively of alphanumeric - characters, dashes, dots and underscores (all ASCII) - :param description: The description for the DAG to e.g. be shown on the webserver - :param schedule: If provided, this defines the rules according to which DAG - runs are scheduled. Possible values include a cron expression string, - timedelta object, Timetable, or list of Asset objects. - See also :doc:`/howto/timetable`. - :param start_date: The timestamp from which the scheduler will - attempt to backfill. If this is not provided, backfilling must be done - manually with an explicit time range. - :param end_date: A date beyond which your DAG won't run, leave to None - for open-ended scheduling. - :param template_searchpath: This list of folders (non-relative) - defines where jinja will look for your templates. Order matters. - Note that jinja/airflow includes the path of your DAG file by - default - :param template_undefined: Template undefined type. - :param user_defined_macros: a dictionary of macros that will be exposed - in your jinja templates. For example, passing ``dict(foo='bar')`` - to this argument allows you to ``{{ foo }}`` in all jinja - templates related to this DAG. Note that you can pass any - type of object here. - :param user_defined_filters: a dictionary of filters that will be exposed - in your jinja templates. For example, passing - ``dict(hello=lambda name: 'Hello %s' % name)`` to this argument allows - you to ``{{ 'world' | hello }}`` in all jinja templates related to - this DAG. - :param default_args: A dictionary of default parameters to be used - as constructor keyword parameters when initialising operators. - Note that operators have the same hook, and precede those defined - here, meaning that if your dict contains `'depends_on_past': True` - here and `'depends_on_past': False` in the operator's call - `default_args`, the actual value will be `False`. - :param params: a dictionary of DAG level parameters that are made - accessible in templates, namespaced under `params`. These - params can be overridden at the task level. - :param max_active_tasks: the number of task instances allowed to run - concurrently - :param max_active_runs: maximum number of active DAG runs, beyond this - number of DAG runs in a running state, the scheduler won't create - new active DAG runs - :param max_consecutive_failed_dag_runs: (experimental) maximum number of consecutive failed DAG runs, - beyond this the scheduler will disable the DAG - :param dagrun_timeout: Specify the duration a DagRun should be allowed to run before it times out or - fails. Task instances that are running when a DagRun is timed out will be marked as skipped. - :param sla_miss_callback: DEPRECATED - The SLA feature is removed in Airflow 3.0, to be replaced with a new implementation in 3.1 - :param deadline: Optional Deadline Alert for the DAG. - Specifies a time by which the DAG run should be complete, either in the form of a static datetime - or calculated relative to a reference timestamp. If the deadline passes before completion, the - provided callback is triggered. - - **Example**: To set the deadline for one hour after the DAG run starts you could use :: - - DeadlineAlert( - reference=DeadlineReference.DAGRUN_LOGICAL_DATE, - interval=timedelta(hours=1), - callback=my_callback, - ) - - :param catchup: Perform scheduler catchup (or only run latest)? 
Defaults to False - :param on_failure_callback: A function or list of functions to be called when a DagRun of this dag fails. - A context dictionary is passed as a single parameter to this function. - :param on_success_callback: Much like the ``on_failure_callback`` except - that it is executed when the dag succeeds. - :param access_control: Specify optional DAG-level actions, e.g., - "{'role1': {'can_read'}, 'role2': {'can_read', 'can_edit', 'can_delete'}}" - or it can specify the resource name if there is a DAGs Run resource, e.g., - "{'role1': {'DAG Runs': {'can_create'}}, 'role2': {'DAGs': {'can_read', 'can_edit', 'can_delete'}}" - :param is_paused_upon_creation: Specifies if the dag is paused when created for the first time. - If the dag exists already, this flag will be ignored. If this optional parameter - is not specified, the global config setting will be used. - :param jinja_environment_kwargs: additional configuration options to be passed to Jinja - ``Environment`` for template rendering - - **Example**: to avoid Jinja from removing a trailing newline from template strings :: - - DAG( - dag_id="my-dag", - jinja_environment_kwargs={ - "keep_trailing_newline": True, - # some other jinja2 Environment options here - }, - ) - - **See**: `Jinja Environment documentation - `_ - - :param render_template_as_native_obj: If True, uses a Jinja ``NativeEnvironment`` - to render templates as native Python types. If False, a Jinja - ``Environment`` is used to render templates as string values. - :param tags: List of tags to help filtering DAGs in the UI. - :param owner_links: Dict of owners and their links, that will be clickable on the DAGs view UI. - Can be used as an HTTP link (for example the link to your Slack channel), or a mailto link. - e.g: {"dag_owner": "https://airflow.apache.org/"} - :param auto_register: Automatically register this DAG when it is used in a ``with`` block - :param fail_fast: Fails currently running tasks when task in DAG fails. - **Warning**: A fail fast dag can only have tasks with the default trigger rule ("all_success"). - An exception will be thrown if any task in a fail fast dag has a non default trigger rule. - :param dag_display_name: The display name of the DAG which appears on the UI. - """ - - partial: bool = False - last_loaded: datetime | None = attrs.field(factory=timezone.utcnow) - - # this will only be set at serialization time - # it's only use is for determining the relative fileloc based only on the serialize dag - _processor_dags_folder: str | None = attrs.field(init=False, default=None) - - # Override the default from parent class to use config - max_consecutive_failed_dag_runs: int = attrs.field( - default=0, - converter=_convert_max_consecutive_failed_dag_runs, - validator=attrs.validators.instance_of(int), - ) - - @property - def safe_dag_id(self): - return self.dag_id.replace(".", "__dot__") - - def validate(self): - super().validate() - self.validate_executor_field() - - def validate_executor_field(self): - for task in self.tasks: - if task.executor: - try: - ExecutorLoader.lookup_executor_name_by_str(task.executor) - except UnknownExecutorException: - raise UnknownExecutorException( - f"The specified executor {task.executor} for task {task.task_id} is not " - "configured. Review the core.executors Airflow configuration to add it or " - "update the executor configuration for this task." 
- ) - - @staticmethod - def _upgrade_outdated_dag_access_control(access_control=None): - """Look for outdated dag level actions in DAG access_controls and replace them with updated actions.""" - if access_control is None: - return None - updated_access_control = {} - for role, perms in access_control.items(): - updated_access_control[role] = updated_access_control.get(role, {}) - if isinstance(perms, (set, list)): - # Support for old-style access_control where only the actions are specified - updated_access_control[role]["DAGs"] = set(perms) - else: - updated_access_control[role] = perms - return updated_access_control - - def get_next_data_interval(self, dag_model: DagModel) -> DataInterval | None: - """ - Get the data interval of the next scheduled run. - - For compatibility, this method infers the data interval from the DAG's - schedule if the run does not have an explicit one set, which is possible - for runs created prior to AIP-39. - - This function is private to Airflow core and should not be depended on as a - part of the Python API. - - :meta private: - """ - if self.dag_id != dag_model.dag_id: - raise ValueError(f"Arguments refer to different DAGs: {self.dag_id} != {dag_model.dag_id}") - if dag_model.next_dagrun is None: # Next run not scheduled. - return None - data_interval = dag_model.next_dagrun_data_interval - if data_interval is not None: - return data_interval - - # Compatibility: A run was scheduled without an explicit data interval. - # This means the run was scheduled before AIP-39 implementation. Try to - # infer from the logical date. - return self.infer_automated_data_interval(dag_model.next_dagrun) - - def get_run_data_interval(self, run: DagRun) -> DataInterval: - """ - Get the data interval of this run. - - For compatibility, this method infers the data interval from the DAG's - schedule if the run does not have an explicit one set, which is possible for - runs created prior to AIP-39. - - This function is private to Airflow core and should not be depended on as a - part of the Python API. - - :meta private: - """ - if run.dag_id is not None and run.dag_id != self.dag_id: - raise ValueError(f"Arguments refer to different DAGs: {self.dag_id} != {run.dag_id}") - data_interval = _get_model_data_interval(run, "data_interval_start", "data_interval_end") - if data_interval is not None: - return data_interval - # Compatibility: runs created before AIP-39 implementation don't have an - # explicit data interval. Try to infer from the logical date. - return self.infer_automated_data_interval(run.logical_date) - - def infer_automated_data_interval(self, logical_date: datetime) -> DataInterval: - """ - Infer a data interval for a run against this DAG. - - This method is used to bridge runs created prior to AIP-39 - implementation, which do not have an explicit data interval. Therefore, - this method only considers ``schedule_interval`` values valid prior to - Airflow 2.2. - - DO NOT call this method if there is a known data interval. 
- - :meta private: - """ - timetable_type = type(self.timetable) - if issubclass(timetable_type, (NullTimetable, OnceTimetable, AssetTriggeredTimetable)): - return DataInterval.exact(timezone.coerce_datetime(logical_date)) - start = timezone.coerce_datetime(logical_date) - if issubclass(timetable_type, CronDataIntervalTimetable): - end = cast("CronDataIntervalTimetable", self.timetable)._get_next(start) - elif issubclass(timetable_type, DeltaDataIntervalTimetable): - end = cast("DeltaDataIntervalTimetable", self.timetable)._get_next(start) - # Contributors: When the exception below is raised, you might want to - # add an 'elif' block here to handle custom timetables. Stop! The bug - # you're looking for is instead at when the DAG run (represented by - # logical_date) was created. See GH-31969 for an example: - # * Wrong fix: GH-32074 (modifies this function). - # * Correct fix: GH-32118 (modifies the DAG run creation code). - else: - raise ValueError(f"Not a valid timetable: {self.timetable!r}") - return DataInterval(start, end) - - def next_dagrun_info( - self, - last_automated_dagrun: None | DataInterval, - *, - restricted: bool = True, - ) -> DagRunInfo | None: - """ - Get information about the next DagRun of this dag after ``date_last_automated_dagrun``. - - This calculates what time interval the next DagRun should operate on - (its logical date) and when it can be scheduled, according to the - dag's timetable, start_date, end_date, etc. This doesn't check max - active run or any other "max_active_tasks" type limits, but only - performs calculations based on the various date and interval fields of - this dag and its tasks. - - :param last_automated_dagrun: The ``max(logical_date)`` of - existing "automated" DagRuns for this dag (scheduled or backfill, - but not manual). - :param restricted: If set to *False* (default is *True*), ignore - ``start_date``, ``end_date``, and ``catchup`` specified on the DAG - or tasks. - :return: DagRunInfo of the next dagrun, or None if a dagrun is not - going to be scheduled. - """ - data_interval = None - if isinstance(last_automated_dagrun, datetime): - raise ValueError( - "Passing a datetime to DAG.next_dagrun_info is not supported anymore. Use a DataInterval instead." 
- ) - data_interval = last_automated_dagrun - if restricted: - restriction = self._time_restriction - else: - restriction = TimeRestriction(earliest=None, latest=None, catchup=True) - try: - info = self.timetable.next_dagrun_info( - last_automated_data_interval=data_interval, - restriction=restriction, - ) - except Exception: - self.log.exception( - "Failed to fetch run info after data interval %s for DAG %r", - data_interval, - self.dag_id, - ) - info = None - return info - - @functools.cached_property - def _time_restriction(self) -> TimeRestriction: - start_dates = [t.start_date for t in self.tasks if t.start_date] - if self.start_date is not None: - start_dates.append(self.start_date) - earliest = None - if start_dates: - earliest = timezone.coerce_datetime(min(start_dates)) - latest = timezone.coerce_datetime(self.end_date) - end_dates = [t.end_date for t in self.tasks if t.end_date] - if len(end_dates) == len(self.tasks): # not exists null end_date - if self.end_date is not None: - end_dates.append(self.end_date) - if end_dates: - latest = timezone.coerce_datetime(max(end_dates)) - return TimeRestriction(earliest, latest, self.catchup) - - def iter_dagrun_infos_between( - self, - earliest: pendulum.DateTime | datetime | None, - latest: pendulum.DateTime | datetime, - *, - align: bool = True, - ) -> Iterable[DagRunInfo]: - """ - Yield DagRunInfo using this DAG's timetable between given interval. - - DagRunInfo instances yielded if their ``logical_date`` is not earlier - than ``earliest``, nor later than ``latest``. The instances are ordered - by their ``logical_date`` from earliest to latest. - - If ``align`` is ``False``, the first run will happen immediately on - ``earliest``, even if it does not fall on the logical timetable schedule. - The default is ``True``. - - Example: A DAG is scheduled to run every midnight (``0 0 * * *``). If - ``earliest`` is ``2021-06-03 23:00:00``, the first DagRunInfo would be - ``2021-06-03 23:00:00`` if ``align=False``, and ``2021-06-04 00:00:00`` - if ``align=True``. - """ - if earliest is None: - earliest = self._time_restriction.earliest - if earliest is None: - raise ValueError("earliest was None and we had no value in time_restriction to fallback on") - earliest = timezone.coerce_datetime(earliest) - latest = timezone.coerce_datetime(latest) - - restriction = TimeRestriction(earliest, latest, catchup=True) - - try: - info = self.timetable.next_dagrun_info( - last_automated_data_interval=None, - restriction=restriction, - ) - except Exception: - self.log.exception( - "Failed to fetch run info after data interval %s for DAG %r", - None, - self.dag_id, - ) - info = None - - if info is None: - # No runs to be scheduled between the user-supplied timeframe. But - # if align=False, "invent" a data interval for the timeframe itself. - if not align: - yield DagRunInfo.interval(earliest, latest) - return - - # If align=False and earliest does not fall on the timetable's logical - # schedule, "invent" a data interval for it. - if not align and info.logical_date != earliest: - yield DagRunInfo.interval(earliest, info.data_interval.start) - - # Generate naturally according to schedule. 
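# Hedged sketch (editorial addition, not part of the patch): the timetable call
# underneath the removed DAG.next_dagrun_info()/iter_dagrun_infos_between()
# wrappers nearby.  Argument names are taken from the removed code; the
# unrestricted TimeRestriction mirrors the restricted=False branch above.
from airflow.timetables.base import TimeRestriction

def unrestricted_next_run(timetable, last_data_interval):
    restriction = TimeRestriction(earliest=None, latest=None, catchup=True)
    return timetable.next_dagrun_info(
        last_automated_data_interval=last_data_interval,
        restriction=restriction,
    )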
- while info is not None: - yield info - try: - info = self.timetable.next_dagrun_info( - last_automated_data_interval=info.data_interval, - restriction=restriction, - ) - except Exception: - self.log.exception( - "Failed to fetch run info after data interval %s for DAG %r", - info.data_interval if info else "", - self.dag_id, - ) - break - - @provide_session - def get_last_dagrun(self, session=NEW_SESSION, include_manually_triggered=False): - return get_last_dagrun( - self.dag_id, session=session, include_manually_triggered=include_manually_triggered - ) - - @property - def dag_id(self) -> str: - return self._dag_id - - @dag_id.setter - def dag_id(self, value: str) -> None: - self._dag_id = value - - @provide_session - def get_concurrency_reached(self, session=NEW_SESSION) -> bool: - """Return a boolean indicating whether the max_active_tasks limit for this DAG has been reached.""" - TI = TaskInstance - total_tasks = session.scalar( - select(func.count(TI.task_id)).where( - TI.dag_id == self.dag_id, - TI.state == TaskInstanceState.RUNNING, - ) - ) - return total_tasks >= self.max_active_tasks - - @provide_session - def get_is_active(self, session=NEW_SESSION) -> None: - """Return a boolean indicating whether this DAG is active.""" - return session.scalar(select(~DagModel.is_stale).where(DagModel.dag_id == self.dag_id)) - - @provide_session - def get_is_stale(self, session=NEW_SESSION) -> None: - """Return a boolean indicating whether this DAG is stale.""" - return session.scalar(select(DagModel.is_stale).where(DagModel.dag_id == self.dag_id)) - - @provide_session - def get_is_paused(self, session=NEW_SESSION) -> None: - """Return a boolean indicating whether this DAG is paused.""" - return session.scalar(select(DagModel.is_paused).where(DagModel.dag_id == self.dag_id)) - - @provide_session - def get_bundle_name(self, session=NEW_SESSION) -> str | None: - """Return the bundle name this DAG is in.""" - return session.scalar(select(DagModel.bundle_name).where(DagModel.dag_id == self.dag_id)) - - @provide_session - def get_bundle_version(self, session=NEW_SESSION) -> str | None: - """Return the bundle version that was seen when this dag was processed.""" - return session.scalar(select(DagModel.bundle_version).where(DagModel.dag_id == self.dag_id)) - - @methodtools.lru_cache(maxsize=None) - @classmethod - def get_serialized_fields(cls): - """Stringified DAGs and operators contain exactly these fields.""" - return TaskSDKDag.get_serialized_fields() | {"_processor_dags_folder"} - - def get_active_runs(self): - """ - Return a list of dag run logical dates currently running. - - :return: List of logical dates - """ - runs = DagRun.find(dag_id=self.dag_id, state=DagRunState.RUNNING) - - active_dates = [] - for run in runs: - active_dates.append(run.logical_date) - - return active_dates - - @staticmethod - @provide_session - def fetch_dagrun(dag_id: str, run_id: str, session: Session = NEW_SESSION) -> DagRun: - """ - Return the dag run for a given run_id if it exists, otherwise none. - - :param dag_id: The dag_id of the DAG to find. - :param run_id: The run_id of the DagRun to find. - :param session: - :return: The DagRun if found, otherwise None. 
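# Hedged sketch (editorial addition, not part of the patch): with the DAG-level
# convenience getters removed here, callers can query DagModel directly,
# mirroring the one-line scalar selects the deleted methods used.
from sqlalchemy import select
from airflow.models.dag import DagModel

def dag_is_paused(dag_id, session):
    return session.scalar(select(DagModel.is_paused).where(DagModel.dag_id == dag_id))

def dag_bundle_name(dag_id, session):
    return session.scalar(select(DagModel.bundle_name).where(DagModel.dag_id == dag_id))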
- """ - return session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id)) - - @provide_session - def get_dagrun(self, run_id: str, session: Session = NEW_SESSION) -> DagRun: - return DAG.fetch_dagrun(dag_id=self.dag_id, run_id=run_id, session=session) - - @provide_session - def get_dagruns_between(self, start_date, end_date, session=NEW_SESSION): - """ - Return the list of dag runs between start_date (inclusive) and end_date (inclusive). - - :param start_date: The starting logical date of the DagRun to find. - :param end_date: The ending logical date of the DagRun to find. - :param session: - :return: The list of DagRuns found. - """ - dagruns = session.scalars( - select(DagRun).where( - DagRun.dag_id == self.dag_id, - DagRun.logical_date >= start_date, - DagRun.logical_date <= end_date, - ) - ).all() - - return dagruns - - @provide_session - def get_latest_logical_date(self, session: Session = NEW_SESSION) -> pendulum.DateTime | None: - """Return the latest date for which at least one dag run exists.""" - return session.scalar(select(func.max(DagRun.logical_date)).where(DagRun.dag_id == self.dag_id)) - - @provide_session - def get_task_instances_before( - self, - base_date: datetime, - num: int, - *, - session: Session = NEW_SESSION, - ) -> list[TaskInstance]: - """ - Get ``num`` task instances before (including) ``base_date``. - - The returned list may contain exactly ``num`` task instances - corresponding to any DagRunType. It can have less if there are - less than ``num`` scheduled DAG runs before ``base_date``. - """ - logical_dates: list[Any] = session.execute( - select(DagRun.logical_date) - .where( - DagRun.dag_id == self.dag_id, - DagRun.logical_date <= base_date, - ) - .order_by(DagRun.logical_date.desc()) - .limit(num) - ).all() - - if not logical_dates: - return self.get_task_instances(start_date=base_date, end_date=base_date, session=session) - - min_date: datetime | None = logical_dates[-1]._mapping.get( - "logical_date" - ) # getting the last value from the list - - return self.get_task_instances(start_date=min_date, end_date=base_date, session=session) - - @provide_session - def get_task_instances( - self, - start_date: datetime | None = None, - end_date: datetime | None = None, - state: TaskInstanceState | Sequence[TaskInstanceState] | None = None, - session: Session = NEW_SESSION, - ) -> list[TaskInstance]: - if not start_date: - start_date = (timezone.utcnow() - timedelta(30)).replace( - hour=0, minute=0, second=0, microsecond=0 - ) - - query = self._get_task_instances( - task_ids=None, - start_date=start_date, - end_date=end_date, - run_id=None, - state=state or (), - include_dependent_dags=False, - exclude_task_ids=(), - exclude_run_ids=None, - session=session, - ) - return session.scalars(cast("Select", query).order_by(DagRun.logical_date)).all() - - @overload - def _get_task_instances( - self, - *, - task_ids: Collection[str | tuple[str, int]] | None, - start_date: datetime | None, - end_date: datetime | None, - run_id: str | None, - state: TaskInstanceState | Sequence[TaskInstanceState], - include_dependent_dags: bool, - exclude_task_ids: Collection[str | tuple[str, int]] | None, - exclude_run_ids: frozenset[str] | None, - session: Session, - dag_bag: DBDagBag | None = ..., - ) -> Iterable[TaskInstance]: ... 
# pragma: no cover - - @overload - def _get_task_instances( - self, - *, - task_ids: Collection[str | tuple[str, int]] | None, - as_pk_tuple: Literal[True], - start_date: datetime | None, - end_date: datetime | None, - run_id: str | None, - state: TaskInstanceState | Sequence[TaskInstanceState], - include_dependent_dags: bool, - exclude_task_ids: Collection[str | tuple[str, int]] | None, - exclude_run_ids: frozenset[str] | None, - session: Session, - dag_bag: DBDagBag | None = ..., - recursion_depth: int = ..., - max_recursion_depth: int = ..., - visited_external_tis: set[TaskInstanceKey] = ..., - ) -> set[TaskInstanceKey]: ... # pragma: no cover - - def _get_task_instances( - self, - *, - task_ids: Collection[str | tuple[str, int]] | None, - as_pk_tuple: Literal[True, None] = None, - start_date: datetime | None, - end_date: datetime | None, - run_id: str | None, - state: TaskInstanceState | Sequence[TaskInstanceState], - include_dependent_dags: bool, - exclude_task_ids: Collection[str | tuple[str, int]] | None, - exclude_run_ids: frozenset[str] | None, - session: Session, - dag_bag: DBDagBag | None = None, - recursion_depth: int = 0, - max_recursion_depth: int | None = None, - visited_external_tis: set[TaskInstanceKey] | None = None, - ) -> Iterable[TaskInstance] | set[TaskInstanceKey]: - from airflow.models.dagbag import DBDagBag - - TI = TaskInstance - - # If we are looking at dependent dags we want to avoid UNION calls - # in SQL (it doesn't play nice with fields that have no equality operator, - # like JSON types), we instead build our result set separately. - # - # This will be empty if we are only looking at one dag, in which case - # we can return the filtered TI query object directly. - result: set[TaskInstanceKey] = set() - - # Do we want full objects, or just the primary columns? 
- if as_pk_tuple: - tis = select(TI.dag_id, TI.task_id, TI.run_id, TI.map_index) - else: - tis = select(TaskInstance) - tis = tis.join(TaskInstance.dag_run) - - if self.partial: - tis = tis.where(TaskInstance.dag_id == self.dag_id, TaskInstance.task_id.in_(self.task_ids)) - else: - tis = tis.where(TaskInstance.dag_id == self.dag_id) - if run_id: - tis = tis.where(TaskInstance.run_id == run_id) - if start_date: - tis = tis.where(DagRun.logical_date >= start_date) - if task_ids is not None: - tis = tis.where(TaskInstance.ti_selector_condition(task_ids)) - if end_date: - tis = tis.where(DagRun.logical_date <= end_date) - - if state: - if isinstance(state, (str, TaskInstanceState)): - tis = tis.where(TaskInstance.state == state) - elif len(state) == 1: - tis = tis.where(TaskInstance.state == state[0]) - else: - # this is required to deal with NULL values - if None in state: - if all(x is None for x in state): - tis = tis.where(TaskInstance.state.is_(None)) - else: - not_none_state = [s for s in state if s] - tis = tis.where( - or_(TaskInstance.state.in_(not_none_state), TaskInstance.state.is_(None)) - ) - else: - tis = tis.where(TaskInstance.state.in_(state)) - - if exclude_run_ids: - tis = tis.where(TaskInstance.run_id.not_in(exclude_run_ids)) - - if include_dependent_dags: - # Recursively find external tasks indicated by ExternalTaskMarker - from airflow.providers.standard.sensors.external_task import ExternalTaskMarker - - query = tis - if as_pk_tuple: - all_tis = session.execute(query).all() - condition = TI.filter_for_tis(TaskInstanceKey(*cols) for cols in all_tis) - if condition is not None: - query = select(TI).where(condition) - - if visited_external_tis is None: - visited_external_tis = set() - - external_tasks = session.scalars(query.where(TI.operator == ExternalTaskMarker.__name__)) - - for ti in external_tasks: - ti_key = ti.key.primary - if ti_key in visited_external_tis: - continue - - visited_external_tis.add(ti_key) - - task: ExternalTaskMarker = cast("ExternalTaskMarker", copy.copy(self.get_task(ti.task_id))) - ti.task = task - - if max_recursion_depth is None: - # Maximum recursion depth allowed is the recursion_depth of the first - # ExternalTaskMarker in the tasks to be visited. - max_recursion_depth = task.recursion_depth - - if recursion_depth + 1 > max_recursion_depth: - # Prevent cycles or accidents. - raise AirflowException( - f"Maximum recursion depth {max_recursion_depth} reached for " - f"{ExternalTaskMarker.__name__} {ti.task_id}. " - f"Attempted to clear too many tasks or there may be a cyclic dependency." - ) - ti.render_templates() - external_tis = session.scalars( - select(TI) - .join(TI.dag_run) - .where( - TI.dag_id == task.external_dag_id, - TI.task_id == task.external_task_id, - DagRun.logical_date == pendulum.parse(task.logical_date), - ) - ) - - for tii in external_tis: - if not dag_bag: - dag_bag = DBDagBag() - if not isinstance(dag_bag, DBDagBag): # Compat: This used to take non-db object. 
- external_dag = dag_bag.get_dag(tii.dag_id, session=session) - else: - external_dag = dag_bag.get_dag_for_run(tii.dag_run, session=session) - if not external_dag: - raise AirflowException(f"Could not find dag {tii.dag_id}") - downstream = external_dag.partial_subset( - task_ids=[tii.task_id], - include_upstream=False, - include_downstream=True, - ) - result.update( - downstream._get_task_instances( - task_ids=None, - run_id=tii.run_id, - start_date=None, - end_date=None, - state=state, - include_dependent_dags=include_dependent_dags, - as_pk_tuple=True, - exclude_task_ids=exclude_task_ids, - exclude_run_ids=exclude_run_ids, - dag_bag=dag_bag, - session=session, - recursion_depth=recursion_depth + 1, - max_recursion_depth=max_recursion_depth, - visited_external_tis=visited_external_tis, - ) - ) - - if result or as_pk_tuple: - # Only execute the `ti` query if we have also collected some other results - if as_pk_tuple: - tis_query = session.execute(tis).all() - result.update(TaskInstanceKey(**cols._mapping) for cols in tis_query) - else: - result.update(ti.key for ti in session.scalars(tis)) - - if exclude_task_ids is not None: - result = { - task - for task in result - if task.task_id not in exclude_task_ids - and (task.task_id, task.map_index) not in exclude_task_ids - } - - if as_pk_tuple: - return result - if result: - # We've been asked for objects, lets combine it all back in to a result set - ti_filters = TI.filter_for_tis(result) - if ti_filters is not None: - tis = select(TI).where(ti_filters) - elif exclude_task_ids is None: - pass # Disable filter if not set. - elif isinstance(next(iter(exclude_task_ids), None), str): - tis = tis.where(TI.task_id.notin_(exclude_task_ids)) - else: - tis = tis.where(tuple_(TI.task_id, TI.map_index).not_in(exclude_task_ids)) - - return tis - - @provide_session - def set_task_instance_state( - self, - *, - task_id: str, - map_indexes: Collection[int] | None = None, - run_id: str | None = None, - state: TaskInstanceState, - upstream: bool = False, - downstream: bool = False, - future: bool = False, - past: bool = False, - commit: bool = True, - session=NEW_SESSION, - ) -> list[TaskInstance]: - """ - Set the state of a TaskInstance and clear downstream tasks in failed or upstream_failed state. - - :param task_id: Task ID of the TaskInstance - :param map_indexes: Only set TaskInstance if its map_index matches. - If None (default), all mapped TaskInstances of the task are set. - :param run_id: The run_id of the TaskInstance - :param state: State to set the TaskInstance to - :param upstream: Include all upstream tasks of the given task_id - :param downstream: Include all downstream tasks of the given task_id - :param future: Include all future TaskInstances of the given task_id - :param commit: Commit changes - :param past: Include all past TaskInstances of the given task_id - """ - from airflow.api.common.mark_tasks import set_state - - # TODO (GH-52141): get_task in scheduler needs to return scheduler types - # instead, but currently it inherits SDK's DAG. 
- task = cast("Operator", self.get_task(task_id)) - task.dag = self - - tasks_to_set_state: list[Operator | tuple[Operator, int]] - if map_indexes is None: - tasks_to_set_state = [task] - else: - tasks_to_set_state = [(task, map_index) for map_index in map_indexes] - - altered = set_state( - tasks=tasks_to_set_state, - run_id=run_id, - upstream=upstream, - downstream=downstream, - future=future, - past=past, - state=state, - commit=commit, - session=session, - ) - - if not commit: - return altered - - # Clear downstream tasks that are in failed/upstream_failed state to resume them. - # Flush the session so that the tasks marked success are reflected in the db. - session.flush() - subset = self.partial_subset( - task_ids={task_id}, - include_downstream=True, - include_upstream=False, - ) - - # Raises an error if not found - dr_id, logical_date = session.execute( - select(DagRun.id, DagRun.logical_date).where( - DagRun.run_id == run_id, DagRun.dag_id == self.dag_id - ) - ).one() - - # Now we want to clear downstreams of tasks that had their state set... - clear_kwargs = { - "only_failed": True, - "session": session, - # Exclude the task itself from being cleared. - "exclude_task_ids": frozenset((task_id,)), - } - if not future and not past: # Simple case 1: we're only dealing with exactly one run. - clear_kwargs["run_id"] = run_id - subset.clear(**clear_kwargs) - elif future and past: # Simple case 2: we're clearing ALL runs. - subset.clear(**clear_kwargs) - else: # Complex cases: we may have more than one run, based on a date range. - # Make 'future' and 'past' make some sense when multiple runs exist - # for the same logical date. We order runs by their id and only - # clear runs have larger/smaller ids. - exclude_run_id_stmt = select(DagRun.run_id).where(DagRun.logical_date == logical_date) - if future: - clear_kwargs["start_date"] = logical_date - exclude_run_id_stmt = exclude_run_id_stmt.where(DagRun.id > dr_id) - else: - clear_kwargs["end_date"] = logical_date - exclude_run_id_stmt = exclude_run_id_stmt.where(DagRun.id < dr_id) - subset.clear(exclude_run_ids=frozenset(session.scalars(exclude_run_id_stmt)), **clear_kwargs) - return altered - - @provide_session - def set_task_group_state( - self, - *, - group_id: str, - run_id: str | None = None, - state: TaskInstanceState, - upstream: bool = False, - downstream: bool = False, - future: bool = False, - past: bool = False, - commit: bool = True, - session: Session = NEW_SESSION, - ) -> list[TaskInstance]: - """ - Set TaskGroup to the given state and clear downstream tasks in failed or upstream_failed state. 
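# Hedged sketch (editorial addition, not part of the patch): the
# airflow.api.common.mark_tasks.set_state call shape, as used by the removed
# DAG.set_task_instance_state() above.  The task is assumed to be a serialized
# operator already resolved from the DAG.
from airflow.api.common.mark_tasks import set_state
from airflow.utils.state import TaskInstanceState

def mark_one_task_success(task, run_id, session):
    return set_state(
        tasks=[task],
        run_id=run_id,
        upstream=False,
        downstream=False,
        future=False,
        past=False,
        state=TaskInstanceState.SUCCESS,
        commit=True,
        session=session,
    )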
- - :param group_id: The group_id of the TaskGroup - :param run_id: The run_id of the TaskInstance - :param state: State to set the TaskInstance to - :param upstream: Include all upstream tasks of the given task_id - :param downstream: Include all downstream tasks of the given task_id - :param future: Include all future TaskInstances of the given task_id - :param commit: Commit changes - :param past: Include all past TaskInstances of the given task_id - :param session: new session - """ - from airflow.api.common.mark_tasks import set_state - from airflow.serialization.serialized_objects import SerializedBaseOperator as BaseOperator - - tasks_to_set_state: list - task_ids: list[str] - - task_group_dict = self.task_group.get_task_group_dict() - task_group = task_group_dict.get(group_id) - if task_group is None: - raise ValueError("TaskGroup {group_id} could not be found") - tasks_to_set_state = [task for task in task_group.iter_tasks() if isinstance(task, BaseOperator)] - task_ids = [task.task_id for task in task_group.iter_tasks()] - dag_runs_query = select(DagRun.id).where(DagRun.dag_id == self.dag_id) - - @cache - def get_logical_date() -> datetime: - stmt = select(DagRun.logical_date).where(DagRun.run_id == run_id, DagRun.dag_id == self.dag_id) - return session.scalars(stmt).one() # Raises an error if not found - - end_date = None if future else get_logical_date() - start_date = None if past else get_logical_date() - - if future: - dag_runs_query = dag_runs_query.where(DagRun.logical_date <= start_date) - if past: - dag_runs_query = dag_runs_query.where(DagRun.logical_date >= end_date) - if not future and not past: - dag_runs_query = dag_runs_query.where(DagRun.run_id == run_id) - - with lock_rows(dag_runs_query, session): - altered = set_state( - tasks=tasks_to_set_state, - run_id=run_id, - upstream=upstream, - downstream=downstream, - future=future, - past=past, - state=state, - commit=commit, - session=session, - ) - if not commit: - return altered - - # Clear downstream tasks that are in failed/upstream_failed state to resume them. - # Flush the session so that the tasks marked success are reflected in the db. - session.flush() - subset = self.partial_subset( - task_ids=task_ids, - include_downstream=True, - include_upstream=False, - ) - - subset.clear( - start_date=start_date, - end_date=end_date, - only_failed=True, - session=session, - # Exclude the task from the current group from being cleared - exclude_task_ids=frozenset(task_ids), - ) - - return altered - - @overload - def clear( - self, - *, - dry_run: Literal[True], - task_ids: Collection[str | tuple[str, int]] | None = None, - run_id: str, - only_failed: bool = False, - only_running: bool = False, - confirm_prompt: bool = False, - dag_run_state: DagRunState = DagRunState.QUEUED, - session: Session = NEW_SESSION, - dag_bag: DBDagBag | None = None, - exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(), - exclude_run_ids: frozenset[str] | None = frozenset(), - run_on_latest_version: bool = False, - ) -> list[TaskInstance]: ... 
# pragma: no cover - - @overload - def clear( - self, - *, - task_ids: Collection[str | tuple[str, int]] | None = None, - run_id: str, - only_failed: bool = False, - only_running: bool = False, - confirm_prompt: bool = False, - dag_run_state: DagRunState = DagRunState.QUEUED, - dry_run: Literal[False] = False, - session: Session = NEW_SESSION, - dag_bag: DBDagBag | None = None, - exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(), - exclude_run_ids: frozenset[str] | None = frozenset(), - run_on_latest_version: bool = False, - ) -> int: ... # pragma: no cover - - @overload - def clear( - self, - *, - dry_run: Literal[True], - task_ids: Collection[str | tuple[str, int]] | None = None, - start_date: datetime | None = None, - end_date: datetime | None = None, - only_failed: bool = False, - only_running: bool = False, - confirm_prompt: bool = False, - dag_run_state: DagRunState = DagRunState.QUEUED, - session: Session = NEW_SESSION, - dag_bag: DBDagBag | None = None, - exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(), - exclude_run_ids: frozenset[str] | None = frozenset(), - run_on_latest_version: bool = False, - ) -> list[TaskInstance]: ... # pragma: no cover - - @overload - def clear( - self, - *, - task_ids: Collection[str | tuple[str, int]] | None = None, - start_date: datetime | None = None, - end_date: datetime | None = None, - only_failed: bool = False, - only_running: bool = False, - confirm_prompt: bool = False, - dag_run_state: DagRunState = DagRunState.QUEUED, - dry_run: Literal[False] = False, - session: Session = NEW_SESSION, - dag_bag: DBDagBag | None = None, - exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(), - exclude_run_ids: frozenset[str] | None = frozenset(), - run_on_latest_version: bool = False, - ) -> int: ... # pragma: no cover - - @provide_session - def clear( - self, - task_ids: Collection[str | tuple[str, int]] | None = None, - *, - run_id: str | None = None, - start_date: datetime | None = None, - end_date: datetime | None = None, - only_failed: bool = False, - only_running: bool = False, - confirm_prompt: bool = False, - dag_run_state: DagRunState = DagRunState.QUEUED, - dry_run: bool = False, - session: Session = NEW_SESSION, - dag_bag: DBDagBag | None = None, - exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(), - exclude_run_ids: frozenset[str] | None = frozenset(), - run_on_latest_version: bool = False, - ) -> int | Iterable[TaskInstance]: - """ - Clear a set of task instances associated with the current dag for a specified date range. - - :param task_ids: List of task ids or (``task_id``, ``map_index``) tuples to clear - :param run_id: The run_id for which the tasks should be cleared - :param start_date: The minimum logical_date to clear - :param end_date: The maximum logical_date to clear - :param only_failed: Only clear failed tasks - :param only_running: Only clear running tasks. - :param confirm_prompt: Ask for confirmation - :param dag_run_state: state to set DagRun to. If set to False, dagrun state will not - be changed. - :param dry_run: Find the tasks to clear but don't clear them. 
- :param run_on_latest_version: whether to run on latest serialized DAG and Bundle version - :param session: The sqlalchemy session to use - :param dag_bag: The DagBag used to find the dags (Optional) - :param exclude_task_ids: A set of ``task_id`` or (``task_id``, ``map_index``) - tuples that should not be cleared - :param exclude_run_ids: A set of ``run_id`` or (``run_id``) - """ - state: list[TaskInstanceState] = [] - if only_failed: - state += [TaskInstanceState.FAILED, TaskInstanceState.UPSTREAM_FAILED] - if only_running: - # Yes, having `+=` doesn't make sense, but this was the existing behaviour - state += [TaskInstanceState.RUNNING] - - tis = self._get_task_instances( - task_ids=task_ids, - start_date=start_date, - end_date=end_date, - run_id=run_id, - state=state, - include_dependent_dags=True, - session=session, - dag_bag=dag_bag, - exclude_task_ids=exclude_task_ids, - exclude_run_ids=exclude_run_ids, - ) - - if dry_run: - return session.scalars(tis).all() - - tis = session.scalars(tis).all() - - count = len(list(tis)) - do_it = True - if count == 0: - return 0 - if confirm_prompt: - ti_list = "\n".join(str(t) for t in tis) - question = f"You are about to delete these {count} tasks:\n{ti_list}\n\nAre you sure? [y/n]" - do_it = utils.helpers.ask_yesno(question) - - if do_it: - clear_task_instances( - list(tis), - session, - dag_run_state=dag_run_state, - run_on_latest_version=run_on_latest_version, - ) - else: - count = 0 - print("Cancelled, nothing was cleared.") - - session.flush() - return count - - @classmethod - def clear_dags( - cls, - dags, - start_date=None, - end_date=None, - only_failed=False, - only_running=False, - confirm_prompt=False, - dag_run_state=DagRunState.QUEUED, - dry_run=False, - ): - all_tis = [] - for dag in dags: - if not isinstance(dag, DAG): - dag = DAG.from_sdk_dag(dag) - tis = dag.clear( - start_date=start_date, - end_date=end_date, - only_failed=only_failed, - only_running=only_running, - confirm_prompt=False, - dag_run_state=dag_run_state, - dry_run=True, - ) - all_tis.extend(tis) - - if dry_run: - return all_tis - - count = len(all_tis) - do_it = True - if count == 0: - print("Nothing to clear.") - return 0 - if confirm_prompt: - ti_list = "\n".join(str(t) for t in all_tis) - question = f"You are about to delete these {count} tasks:\n{ti_list}\n\nAre you sure? [y/n]" - do_it = utils.helpers.ask_yesno(question) - - if do_it: - for dag in dags: - if not isinstance(dag, DAG): - dag = DAG.from_sdk_dag(dag) - dag.clear( - start_date=start_date, - end_date=end_date, - only_failed=only_failed, - only_running=only_running, - confirm_prompt=False, - dag_run_state=dag_run_state, - dry_run=False, - ) - else: - count = 0 - print("Cancelled, nothing was cleared.") - return count - - @provide_session - def create_dagrun( - self, - *, - run_id: str, - logical_date: datetime | None = None, - data_interval: tuple[datetime, datetime] | None = None, - run_after: datetime, - conf: dict | None = None, - run_type: DagRunType, - triggered_by: DagRunTriggeredByType, - triggering_user_name: str | None = None, - state: DagRunState, - start_date: datetime | None = None, - creating_job_id: int | None = None, - backfill_id: NonNegativeInt | None = None, - session: Session = NEW_SESSION, - ) -> DagRun: - """ - Create a run for this DAG to run its tasks. 
- - :param run_id: ID of the dag_run - :param logical_date: date of execution - :param run_after: the datetime before which dag won't run - :param conf: Dict containing configuration/parameters to pass to the DAG - :param triggered_by: the entity which triggers the dag_run - :param triggering_user_name: the user name who triggers the dag_run - :param start_date: the date this dag run should be evaluated - :param creating_job_id: ID of the job creating this DagRun - :param backfill_id: ID of the backfill run if one exists - :param session: Unused. Only added in compatibility with database isolation mode - :return: The created DAG run. - - :meta private: - """ - logical_date = timezone.coerce_datetime(logical_date) - # For manual runs where logical_date is None, ensure no data_interval is set. - if logical_date is None and data_interval is not None: - raise ValueError("data_interval must be None when logical_date is None") - - if data_interval and not isinstance(data_interval, DataInterval): - data_interval = DataInterval(*map(timezone.coerce_datetime, data_interval)) - - if isinstance(run_type, DagRunType): - pass - elif isinstance(run_type, str): # Ensure the input value is valid. - run_type = DagRunType(run_type) - else: - raise ValueError(f"run_type should be a DagRunType, not {type(run_type)}") - - if not isinstance(run_id, str): - raise ValueError(f"`run_id` should be a str, not {type(run_id)}") - - # This is also done on the DagRun model class, but SQLAlchemy column - # validator does not work well for some reason. - if not re.match(RUN_ID_REGEX, run_id): - regex = airflow_conf.get("scheduler", "allowed_run_id_pattern").strip() - if not regex or not re.match(regex, run_id): - raise ValueError( - f"The run_id provided '{run_id}' does not match regex pattern " - f"'{regex}' or '{RUN_ID_REGEX}'" - ) - - # Prevent a manual run from using an ID that looks like a scheduled run. - if run_type == DagRunType.MANUAL: - if (inferred_run_type := DagRunType.from_run_id(run_id)) != DagRunType.MANUAL: - raise ValueError( - f"A {run_type.value} DAG run cannot use ID {run_id!r} since it " - f"is reserved for {inferred_run_type.value} runs" - ) - - # todo: AIP-78 add verification that if run type is backfill then we have a backfill id - - # create a copy of params before validating - copied_params = copy.deepcopy(self.params) - if conf: - copied_params.update(conf) - copied_params.validate() - orm_dagrun = _create_orm_dagrun( - dag=self, - run_id=run_id, - logical_date=logical_date, - data_interval=data_interval, - run_after=timezone.coerce_datetime(run_after), - start_date=timezone.coerce_datetime(start_date), - conf=conf, - state=state, - run_type=run_type, - creating_job_id=creating_job_id, - backfill_id=backfill_id, - triggered_by=triggered_by, - triggering_user_name=triggering_user_name, - session=session, - ) - - if self.deadline and isinstance(self.deadline.reference, DeadlineReference.TYPES.DAGRUN): - session.add( - Deadline( - deadline_time=self.deadline.reference.evaluate_with( - session=session, - interval=self.deadline.interval, - dag_id=self.dag_id, - run_id=run_id, - ), - callback=self.deadline.callback, - dagrun_id=orm_dagrun.id, - ) - ) - - return orm_dagrun - - @classmethod - @provide_session - def bulk_write_to_db( - cls, - bundle_name: str, - bundle_version: str | None, - dags: Collection[MaybeSerializedDAG], - session: Session = NEW_SESSION, - ): - """ - Ensure the DagModel rows for the given dags are up-to-date in the dag table in the DB. 
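# Hedged sketch (editorial addition, not part of the patch): the run_id guard
# used by the removed create_dagrun() above -- a manual run may not reuse an ID
# that parses as a scheduled or backfill run.
from airflow.utils.types import DagRunType

def check_manual_run_id(run_id: str) -> None:
    inferred = DagRunType.from_run_id(run_id)
    if inferred != DagRunType.MANUAL:
        raise ValueError(f"run_id {run_id!r} is reserved for {inferred.value} runs")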
- - :param dags: the DAG objects to save to the DB - :return: None - """ - if not dags: - return - - from airflow.dag_processing.collection import AssetModelOperation, DagModelOperation - - log.info("Sync %s DAGs", len(dags)) - dag_op = DagModelOperation( - bundle_name=bundle_name, bundle_version=bundle_version, dags={d.dag_id: d for d in dags} - ) - - orm_dags = dag_op.add_dags(session=session) - dag_op.update_dags(orm_dags, session=session) - - asset_op = AssetModelOperation.collect(dag_op.dags) - - orm_assets = asset_op.sync_assets(session=session) - orm_asset_aliases = asset_op.sync_asset_aliases(session=session) - session.flush() # This populates id so we can create fks in later calls. - - orm_dags = dag_op.find_orm_dags(session=session) # Refetch so relationship is up to date. - asset_op.add_dag_asset_references(orm_dags, orm_assets, session=session) - asset_op.add_dag_asset_alias_references(orm_dags, orm_asset_aliases, session=session) - asset_op.add_dag_asset_name_uri_references(session=session) - asset_op.add_task_asset_references(orm_dags, orm_assets, session=session) - asset_op.activate_assets_if_possible(orm_assets.values(), session=session) - session.flush() # Activation is needed when we add trigger references. - - asset_op.add_asset_trigger_references(orm_assets, session=session) - dag_op.update_dag_asset_expression(orm_dags=orm_dags, orm_assets=orm_assets) - session.flush() - - @provide_session - def sync_to_db(self, session=NEW_SESSION): - """ - Save attributes about this DAG to the DB. - - :return: None - """ - # TODO: AIP-66 should this be in the model? - bundle_name = self.get_bundle_name(session=session) - bundle_version = self.get_bundle_version(session=session) - self.bulk_write_to_db(bundle_name, bundle_version, [self], session=session) - - @staticmethod - @provide_session - def deactivate_unknown_dags(active_dag_ids, session=NEW_SESSION): - """ - Given a list of known DAGs, deactivate any other DAGs that are marked as active in the ORM. - - :param active_dag_ids: list of DAG IDs that are active - :return: None - """ - if not active_dag_ids: - return - for dag in session.scalars(select(DagModel).where(~DagModel.dag_id.in_(active_dag_ids))).all(): - dag.is_stale = True - session.merge(dag) - session.commit() - - @staticmethod - @provide_session - def deactivate_stale_dags(expiration_date, session=NEW_SESSION): - """ - Deactivate any DAGs that were last touched by the scheduler before the expiration date. - - These DAGs were likely deleted. - - :param expiration_date: set inactive DAGs that were touched before this time - :return: None - """ - for dag in session.scalars( - select(DagModel).where(DagModel.last_parsed_time < expiration_date, ~DagModel.is_stale) - ): - log.info( - "Deactivating DAG ID %s since it was last touched by the scheduler at %s", - dag.dag_id, - dag.last_parsed_time.isoformat(), - ) - dag.is_stale = True - session.merge(dag) - session.commit() - - @staticmethod - @provide_session - def get_num_task_instances(dag_id, run_id=None, task_ids=None, states=None, session=NEW_SESSION) -> int: - """ - Return the number of task instances in the given DAG. 
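# Hedged sketch (editorial addition, not part of the patch): the NULL-aware
# state filter idiom used by the removed query helpers above.  SQL's IN never
# matches NULL, so a None entry in `states` needs an explicit IS NULL branch.
from sqlalchemy import func, or_, select
from airflow.models.taskinstance import TaskInstance

def count_task_instances(dag_id, states, session):
    stmt = select(func.count(TaskInstance.task_id)).where(TaskInstance.dag_id == dag_id)
    if states:
        not_none = [s for s in states if s is not None]
        if None in states and not_none:
            stmt = stmt.where(or_(TaskInstance.state.in_(not_none), TaskInstance.state.is_(None)))
        elif None in states:
            stmt = stmt.where(TaskInstance.state.is_(None))
        else:
            stmt = stmt.where(TaskInstance.state.in_(states))
    return session.scalar(stmt)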
- - :param session: ORM session - :param dag_id: ID of the DAG to get the task concurrency of - :param run_id: ID of the DAG run to get the task concurrency of - :param task_ids: A list of valid task IDs for the given DAG - :param states: A list of states to filter by if supplied - :return: The number of running tasks - """ - qry = select(func.count(TaskInstance.task_id)).where( - TaskInstance.dag_id == dag_id, - ) - if run_id: - qry = qry.where( - TaskInstance.run_id == run_id, - ) - if task_ids: - qry = qry.where( - TaskInstance.task_id.in_(task_ids), - ) - - if states: - if None in states: - if all(x is None for x in states): - qry = qry.where(TaskInstance.state.is_(None)) - else: - not_none_states = [state for state in states if state] - qry = qry.where( - or_(TaskInstance.state.in_(not_none_states), TaskInstance.state.is_(None)) - ) - else: - qry = qry.where(TaskInstance.state.in_(states)) - return session.scalar(qry) - - # "default has type "type[Asset]", argument has type "type[AssetT]") [assignment]" :shrug: - def get_task_assets( - self, - inlets: bool = True, - outlets: bool = True, - of_type: type[AssetT] = Asset, # type: ignore[assignment] - ) -> Generator[tuple[str, AssetT], None, None]: - for task in self.task_dict.values(): - directions: tuple[str, ...] = ("inlets",) if inlets else () - if outlets: - directions += ("outlets",) - for direction in directions: - if not (ports := getattr(task, direction, None)): - continue - - for port in ports: - if isinstance(port, of_type): - yield task.task_id, port - - @classmethod - def from_sdk_dag(cls, dag: TaskSDKDag) -> DAG: - """Create a new (Scheduler) DAG object from a TaskSDKDag.""" - if not isinstance(dag, TaskSDKDag): - return dag - - fields = attrs.fields(dag.__class__) - - kwargs = {} - for field in fields: - # Skip fields that are: - # 1. Initialized after creation (init=False) - # 2. 
Internal state fields that shouldn't be copied - if not field.init or field.name in ["edge_info"]: - continue - - kwargs[field.name] = getattr(dag, field.name) - - new_dag = cls(**kwargs) - - task_group_map = {} - - def create_task_groups(task_group, parent_group=None): - new_task_group = copy.deepcopy(task_group) - - new_task_group.dag = new_dag - new_task_group.parent_group = parent_group - new_task_group.children = {} - - task_group_map[task_group.group_id] = new_task_group - - for child in task_group.children.values(): - if isinstance(child, TaskGroup): - create_task_groups(child, new_task_group) - - create_task_groups(dag.task_group) - - def create_tasks(task): - if isinstance(task, TaskGroup): - return task_group_map[task.group_id] - - new_task = copy.copy(task) - - # Only overwrite the specific attributes we want to change - new_task.task_id = task.task_id - new_task.dag = None # Don't set dag yet - new_task.task_group = task_group_map.get(task.task_group.group_id) if task.task_group else None - - return new_task - - # Process all tasks in the original DAG - for task in dag.tasks: - new_task = create_tasks(task) - if not isinstance(new_task, TaskGroup): - # Add the task to the DAG - new_dag.task_dict[new_task.task_id] = new_task - if new_task.task_group: - new_task.task_group.children[new_task.task_id] = new_task - new_task.dag = new_dag - - new_dag.edge_info = dag.edge_info.copy() - - return new_dag - - class DagTag(Base): """A tag name per dag, to allow quick filtering in the DAG view.""" @@ -2041,10 +479,6 @@ def get_last_dagrun(self, session=NEW_SESSION, include_manually_triggered=False) self.dag_id, session=session, include_manually_triggered=include_manually_triggered ) - def get_is_paused(self, *, session: Session | None = None) -> bool: - """Provide interface compatibility to 'DAG'.""" - return self.is_paused - def get_is_active(self, *, session: Session | None = None) -> bool: """Provide interface compatibility to 'DAG'.""" return not self.is_stale @@ -2072,26 +506,6 @@ def get_paused_dag_ids(dag_ids: list[str], session: Session = NEW_SESSION) -> se def safe_dag_id(self): return self.dag_id.replace(".", "__dot__") - @provide_session - def set_is_paused(self, is_paused: bool, session=NEW_SESSION) -> None: - """ - Pause/Un-pause a DAG. - - :param is_paused: Is the DAG paused - :param session: session - """ - filter_query = [ - DagModel.dag_id == self.dag_id, - ] - - session.execute( - update(DagModel) - .where(or_(*filter_query)) - .values(is_paused=is_paused) - .execution_options(synchronize_session="fetch") - ) - session.commit() - @hybrid_property def dag_display_name(self) -> str: return self._dag_display_property_value or self.dag_id @@ -2230,7 +644,7 @@ def dag_ready(dag_id: str, cond: BaseAsset, statuses: dict[AssetUniqueKey, bool] def calculate_dagrun_date_fields( self, - dag: DAG, + dag: SerializedDAG, last_automated_dag_run: None | DataInterval, ) -> None: """ @@ -2293,52 +707,21 @@ def get_team_name(dag_id: str, session=NEW_SESSION) -> str | None: """:sphinx-autoapi-skip:""" -def _get_or_create_dagrun( - *, - dag: DAG, - run_id: str, - logical_date: datetime | None, - data_interval: tuple[datetime, datetime] | None, - run_after: datetime, - conf: dict | None, - triggered_by: DagRunTriggeredByType, - triggering_user_name: str | None, - start_date: datetime, - session: Session, -) -> DagRun: - """ - Create a DAG run, replacing an existing instance if needed to prevent collisions. - - This function is only meant to be used by :meth:`DAG.test` as a helper function. 
+def __getattr__(name: str): + # Add DAG and dag for compatibility. We can't do this in + # airflow/models/__init__.py since this module contains other things. + if name not in ("DAG", "dag"): + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - :param dag: DAG to be used to find run. - :param conf: Configuration to pass to newly created run. - :param start_date: Start date of new run. - :param logical_date: Logical date for finding an existing run. - :param run_id: Run ID for the new DAG run. - :param triggered_by: the entity which triggers the dag_run - :param triggering_user_name: the user name who triggers the dag_run + import warnings - :return: The newly created DAG run. - """ - dr: DagRun = session.scalar( - select(DagRun).where(DagRun.dag_id == dag.dag_id, DagRun.logical_date == logical_date) - ) - if dr: - session.delete(dr) - session.commit() - dr = dag.create_dagrun( - run_id=run_id, - logical_date=logical_date, - data_interval=data_interval, - run_after=run_after, - conf=conf, - run_type=DagRunType.MANUAL, - state=DagRunState.RUNNING, - triggered_by=triggered_by, - triggering_user_name=triggering_user_name, - start_date=start_date or logical_date, - session=session, + warnings.warn( + f"Import {name!r} directly from the airflow module is deprecated and " + f"will be removed in the future. Please import it from 'airflow.sdk'.", + DeprecationWarning, + stacklevel=2, ) - log.info("created dagrun %s", dr) - return dr + + import airflow.sdk + + return getattr(airflow.sdk, name) diff --git a/airflow-core/src/airflow/models/dagbag.py b/airflow-core/src/airflow/models/dagbag.py index a2c6d5462c95a..19572d7f9305f 100644 --- a/airflow-core/src/airflow/models/dagbag.py +++ b/airflow-core/src/airflow/models/dagbag.py @@ -48,10 +48,13 @@ AirflowDagDuplicatedIdException, AirflowException, AirflowTaskTimeout, + UnknownExecutorException, ) +from airflow.executors.executor_loader import ExecutorLoader from airflow.listeners.listener import get_listener_manager from airflow.models.base import Base, StringID from airflow.models.dag_version import DagVersion +from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG from airflow.utils.docs import get_docs_url from airflow.utils.file import ( correct_maybe_zipped, @@ -73,8 +76,8 @@ from sqlalchemy.orm import Session + from airflow import DAG from airflow.models import DagRun - from airflow.models.dag import DAG from airflow.models.dagwarning import DagWarning from airflow.models.serialized_dag import SerializedDagModel from airflow.utils.types import ArgNotSet @@ -143,6 +146,20 @@ def handle_timeout(signum, frame): signal.setitimer(signal.ITIMER_REAL, 0) +def _validate_executor_fields(dag: DAG) -> None: + for task in dag.tasks: + if not task.executor: + continue + try: + ExecutorLoader.lookup_executor_name_by_str(task.executor) + except UnknownExecutorException: + raise UnknownExecutorException( + f"Task '{task.task_id}' specifies executor '{task.executor}', which is not available. " + "Make sure it is listed in your [core] executors configuration, or update the task's " + "executor to use one of the configured executors." + ) + + class DagBag(LoggingMixin): """ A dagbag is a collection of dags, parsed out of a folder tree and has high level configuration settings. 
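A minimal sketch of the kind of DAG file the new _validate_executor_fields check rejects; the DAG, task, and executor names are invented, and the EmptyOperator import from the standard provider is an assumption rather than part of this patch:

from airflow.sdk import DAG
from airflow.providers.standard.operators.empty import EmptyOperator

with DAG(dag_id="bad_executor_example", schedule=None):
    # "NotARealExecutor" cannot be resolved by ExecutorLoader.lookup_executor_name_by_str,
    # so bagging this file now raises UnknownExecutorException with the message above.
    EmptyOperator(task_id="noop", executor="NotARealExecutor")

Since the check is wired into _process_modules right after dag.validate() (see the hunk below), a misconfigured executor is reported at parse time instead of surfacing later in the scheduler.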
@@ -475,11 +492,10 @@ def _load_modules_from_zip(self, filepath, safe_mode): return mods def _process_modules(self, filepath, mods, file_last_changed_on_disk): - from airflow.models.dag import DAG # Avoid circular import - from airflow.sdk import DAG as SDKDAG + from airflow.sdk import DAG from airflow.sdk.definitions._internal.contextmanager import DagContext - top_level_dags = {(o, m) for m in mods for o in m.__dict__.values() if isinstance(o, (DAG, SDKDAG))} + top_level_dags = {(o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG)} top_level_dags.update(DagContext.autoregistered_dags) @@ -494,6 +510,7 @@ def _process_modules(self, filepath, mods, file_last_changed_on_disk): dag.relative_fileloc = relative_fileloc try: dag.validate() + _validate_executor_fields(dag) self.bag_dag(dag=dag) except AirflowClusterPolicySkipDag: pass @@ -642,22 +659,13 @@ def sync_bag_to_db( session: Session = NEW_SESSION, ) -> None: """Save attributes about list of DAG to the DB.""" - import airflow.models.dag from airflow.dag_processing.collection import update_dag_parsing_results_in_db - from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG - - dags = [ - dag - if isinstance(dag, airflow.models.dag.DAG) - else LazyDeserializedDAG(data=SerializedDAG.to_dict(dag)) - for dag in dagbag.dags.values() - ] - import_errors = {(bundle_name, rel_path): error for rel_path, error in dagbag.import_errors.items()} + import_errors = {(bundle_name, rel_path): error for rel_path, error in dagbag.import_errors.items()} update_dag_parsing_results_in_db( bundle_name, bundle_version, - dags, + [LazyDeserializedDAG.from_dag(dag) for dag in dagbag.dags.values()], import_errors, dagbag.dag_warnings, session=session, @@ -671,17 +679,17 @@ class DBDagBag: :meta private: """ - def __init__(self, load_op_links: bool = True): - self._dags: dict[str, DAG] = {} # dag_version_id to dag + def __init__(self, load_op_links: bool = True) -> None: + self._dags: dict[str, SerializedDAG] = {} # dag_version_id to dag self.load_op_links = load_op_links - def _read_dag(self, serdag: SerializedDagModel) -> DAG | None: + def _read_dag(self, serdag: SerializedDagModel) -> SerializedDAG | None: serdag.load_op_links = self.load_op_links if dag := serdag.dag: self._dags[serdag.dag_version_id] = dag return dag - def _get_dag(self, version_id: str, session: Session) -> DAG | None: + def _get_dag(self, version_id: str, session: Session) -> SerializedDAG | None: if dag := self._dags.get(version_id): return dag dag_version = session.get(DagVersion, version_id, options=[joinedload(DagVersion.serialized_dag)]) @@ -706,12 +714,12 @@ def _version_from_dag_run(dag_run: DagRun, *, session: Session) -> DagVersion: # Relationship not loaded, fetch it explicitly from current session return session.get(DagVersion, dag_run.created_dag_version_id) - def get_dag_for_run(self, dag_run: DagRun, session: Session) -> DAG | None: + def get_dag_for_run(self, dag_run: DagRun, session: Session) -> SerializedDAG | None: if version := self._version_from_dag_run(dag_run=dag_run, session=session): return self._get_dag(version_id=version.id, session=session) return None - def iter_all_latest_version_dags(self, *, session: Session) -> Generator[DAG, None, None]: + def iter_all_latest_version_dags(self, *, session: Session) -> Generator[SerializedDAG, None, None]: """Walk through all latest version dags available in the database.""" from airflow.models.serialized_dag import SerializedDagModel @@ -719,7 +727,7 @@ def 
iter_all_latest_version_dags(self, *, session: Session) -> Generator[DAG, No if dag := self._read_dag(sdm): yield dag - def get_latest_version_of_dag(self, dag_id: str, *, session: Session) -> DAG | None: + def get_latest_version_of_dag(self, dag_id: str, *, session: Session) -> SerializedDAG | None: """Get the latest version of a dag by its id.""" from airflow.models.serialized_dag import SerializedDagModel diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index 2b79da564235c..572a711c3015b 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -24,6 +24,7 @@ from collections.abc import Callable, Iterable, Iterator, Sequence from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar, cast, overload +import structlog from natsort import natsorted from sqlalchemy import ( JSON, @@ -94,12 +95,11 @@ from sqlalchemy.orm import Query, Session from sqlalchemy.sql.elements import Case - from airflow.models.dag import DAG from airflow.models.dag_version import DagVersion from airflow.models.mappedoperator import MappedOperator from airflow.models.taskinstancekey import TaskInstanceKey from airflow.sdk import DAG as SDKDAG - from airflow.serialization.serialized_objects import SerializedBaseOperator + from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG from airflow.utils.types import ArgNotSet CreatedTasks = TypeVar("CreatedTasks", Iterator["dict[str, Any]"], Iterator[TI]) @@ -110,6 +110,8 @@ RUN_ID_REGEX = r"^(?:manual|scheduled|asset_triggered)__(?:\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\+00:00)$" +log = structlog.get_logger(__name__) + class TISchedulingDecision(NamedTuple): """Type of return for DagRun.task_instance_scheduling_decisions.""" @@ -209,9 +211,9 @@ class DagRun(Base, LoggingMixin): # Remove this `if` after upgrading Sphinx-AutoAPI if not TYPE_CHECKING and "BUILDING_AIRFLOW_DOCS" in os.environ: - dag: DAG | None + dag: SerializedDAG | None else: - dag: DAG | None = None + dag: SerializedDAG | None = None __table_args__ = ( Index("dag_id_state", dag_id, _state), @@ -893,7 +895,7 @@ def fetch_task_instance( select(TI).filter_by(dag_id=dag_id, run_id=dag_run_id, task_id=task_id, map_index=map_index) ).one_or_none() - def get_dag(self) -> DAG: + def get_dag(self) -> SerializedDAG: """ Return the Dag associated with this DagRun. 
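For orientation, a hedged sketch of how scheduler-side code now obtains a DAG through the DBDagBag helpers typed above; the dag_id is a placeholder and create_session is the usual session utility, neither is introduced by this patch:

from airflow.models.dagbag import DBDagBag
from airflow.utils.session import create_session

dag_bag = DBDagBag()
with create_session() as session:
    dag = dag_bag.get_latest_version_of_dag("example_dag_id", session=session)
    if dag is not None:
        # With this change the returned object is a SerializedDAG, not the Task SDK DAG.
        print(type(dag).__name__)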
@@ -1298,7 +1300,7 @@ def task_instance_scheduling_decisions(self, session: Session = NEW_SESSION) -> tis = self.get_task_instances(session=session, state=State.task_states) self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis)) - def _filter_tis_and_exclude_removed(dag: DAG, tis: list[TI]) -> Iterable[TI]: + def _filter_tis_and_exclude_removed(dag: SerializedDAG, tis: list[TI]) -> Iterable[TI]: """Populate ``ti.task`` while excluding those missing one, marking them as REMOVED.""" for ti in tis: try: @@ -1360,7 +1362,7 @@ def notify_dagrun_state_changed(self, msg: str = ""): # or LocalTaskJob, so we don't want to "falsely advertise" we notify about that @provide_session - def get_last_ti(self, dag: DAG, session: Session = NEW_SESSION) -> TI | None: + def get_last_ti(self, dag: SerializedDAG, session: Session = NEW_SESSION) -> TI | None: """Get Last TI from the dagrun to build and pass Execution context object from server to then run callbacks.""" tis = self.get_task_instances(session=session) # tis from a dagrun may not be a part of dag.partial_subset, @@ -1584,6 +1586,8 @@ def _emit_true_scheduling_delay_stats_for_finished_state(self, finished_tis: lis Note that the stat will only be emitted for scheduler-triggered DAG runs (i.e. when ``run_type`` is *SCHEDULED* and ``clear_number`` is equal to 0). """ + from airflow.models.dag import get_run_data_interval + if self.state == TaskInstanceState.RUNNING: return if self.run_type != DagRunType.SCHEDULED: @@ -1610,7 +1614,7 @@ def _emit_true_scheduling_delay_stats_for_finished_state(self, finished_tis: lis # execution on DagModel.next_dagrun_create_after. We should add # a field on DagRun for this instead of relying on the run # always happening immediately after the data interval. - data_interval_end = dag.get_run_data_interval(self).end + data_interval_end = get_run_data_interval(dag.timetable, self).end true_delay = first_start_date - data_interval_end if true_delay.total_seconds() > 0: Stats.timing( @@ -1683,7 +1687,7 @@ def task_filter(task: Operator) -> bool: self._create_task_instances(self.dag_id, tis_to_create, created_counts, hook_is_noop, session=session) def _check_for_removed_or_restored_tasks( - self, dag: DAG, ti_mutation_hook, *, session: Session + self, dag: SerializedDAG, ti_mutation_hook, *, session: Session ) -> set[str]: """ Check for removed tasks/restored/missing tasks. @@ -2089,7 +2093,7 @@ def _get_log_template(log_template_id: int | None, session: Session = NEW_SESSIO return template @staticmethod - def _get_partial_task_ids(dag: DAG | None) -> list[str] | None: + def _get_partial_task_ids(dag: SerializedDAG | None) -> list[str] | None: return dag.task_ids if dag and dag.partial else None @@ -2125,3 +2129,54 @@ def __repr__(self): if self.map_index != -1: prefix += f" map_index={self.map_index}" return prefix + ">" + + +def get_or_create_dagrun( + *, + dag: SerializedDAG, + run_id: str, + logical_date: datetime | None, + data_interval: tuple[datetime, datetime] | None, + run_after: datetime, + conf: dict | None, + triggered_by: DagRunTriggeredByType, + triggering_user_name: str | None, + start_date: datetime, + session: Session, +) -> DagRun: + """ + Create a DAG run, replacing an existing instance if needed to prevent collisions. + + This function is only meant to be used by :meth:`DAG.test` as a helper function. + + :param dag: DAG to be used to find run. + :param conf: Configuration to pass to newly created run. + :param start_date: Start date of new run. 
+ :param logical_date: Logical date for finding an existing run. + :param run_id: Run ID for the new DAG run. + :param triggered_by: the entity which triggers the dag_run + :param triggering_user_name: the user name who triggers the dag_run + + :return: The newly created DAG run. + """ + dr: DagRun = session.scalar( + select(DagRun).where(DagRun.dag_id == dag.dag_id, DagRun.logical_date == logical_date) + ) + if dr: + session.delete(dr) + session.commit() + dr = dag.create_dagrun( + run_id=run_id, + logical_date=logical_date, + data_interval=data_interval, + run_after=run_after, + conf=conf, + run_type=DagRunType.MANUAL, + state=DagRunState.RUNNING, + triggered_by=triggered_by, + triggering_user_name=triggering_user_name, + start_date=start_date or logical_date, + session=session, + ) + log.info("Created dag run.", dagrun=dr) + return dr diff --git a/airflow-core/src/airflow/models/mappedoperator.py b/airflow-core/src/airflow/models/mappedoperator.py index 6e086beeaaae3..705c7bad34602 100644 --- a/airflow-core/src/airflow/models/mappedoperator.py +++ b/airflow-core/src/airflow/models/mappedoperator.py @@ -35,6 +35,7 @@ from airflow.sdk.definitions._internal.node import DAGNode from airflow.sdk.definitions.mappedoperator import MappedOperator as TaskSDKMappedOperator from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup +from airflow.serialization.enums import DagAttributeTypes from airflow.serialization.serialized_objects import DEFAULT_OPERATOR_DEPS, SerializedBaseOperator from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy @@ -45,11 +46,11 @@ import pendulum from airflow.models import TaskInstance - from airflow.models.dag import DAG as SchedulerDAG from airflow.models.expandinput import SchedulerExpandInput from airflow.sdk import BaseOperatorLink, Context from airflow.sdk.definitions.operator_resources import Resources from airflow.sdk.definitions.param import ParamsDict + from airflow.serialization.serialized_objects import SerializedDAG from airflow.task.trigger_rule import TriggerRule from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.triggers.base import StartTriggerArgs @@ -102,7 +103,7 @@ class MappedOperator(DAGNode): start_from_trigger: bool = False _needs_expansion: bool = True - dag: SchedulerDAG = attrs.field(init=False) + dag: SerializedDAG = attrs.field(init=False) task_group: TaskGroup = attrs.field(init=False) start_date: pendulum.DateTime | None = attrs.field(init=False, default=None) end_date: pendulum.DateTime | None = attrs.field(init=False, default=None) @@ -380,6 +381,10 @@ def get_extra_links(self, ti: TaskInstance, name: str) -> str | None: return None return link.get_link(self, ti_key=ti.key) # type: ignore[arg-type] # TODO: GH-52141 - BaseOperatorLink.get_link expects BaseOperator but receives MappedOperator + def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]: + """Implement DAGNode.""" + return DagAttributeTypes.OP, self.task_id + # TODO (GH-52141): Copied from sdk. Find a better place for this to live in. def _get_specified_expand_input(self) -> SchedulerExpandInput: """Input received from the expand call on the operator.""" @@ -482,7 +487,7 @@ def _(task: MappedOperator | TaskSDKMappedOperator, run_id: str, *, session: Ses # TODO (GH-52141): 'task' here should be scheduler-bound and returns scheduler expand input. 
if not hasattr(exp_input, "get_total_map_length"): if TYPE_CHECKING: - assert isinstance(task.dag, SchedulerDAG) + assert isinstance(task.dag, SerializedDAG) current_count = ( _ExpandInputRef( exp_input.EXPAND_INPUT_TYPE, @@ -526,7 +531,7 @@ def iter_mapped_task_group_lengths(group) -> Iterator[int]: # TODO (GH-52141): 'group' here should be scheduler-bound and returns scheduler expand input. if not hasattr(exp_input, "get_total_map_length"): if TYPE_CHECKING: - assert isinstance(group.dag, SchedulerDAG) + assert isinstance(group.dag, SerializedDAG) exp_input = _ExpandInputRef( exp_input.EXPAND_INPUT_TYPE, BaseSerialization.deserialize(BaseSerialization.serialize(exp_input.value)), diff --git a/airflow-core/src/airflow/models/serialized_dag.py b/airflow-core/src/airflow/models/serialized_dag.py index 2d9c65f80666a..a9c05c71d5109 100644 --- a/airflow-core/src/airflow/models/serialized_dag.py +++ b/airflow-core/src/airflow/models/serialized_dag.py @@ -45,7 +45,7 @@ from airflow.models.dagrun import DagRun from airflow.sdk.definitions.asset import AssetUniqueKey from airflow.serialization.dag_dependency import DagDependency -from airflow.serialization.serialized_objects import SerializedDAG +from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG from airflow.settings import COMPRESS_SERIALIZED_DAGS, json from airflow.utils.hashlib_wrapper import md5 from airflow.utils.session import NEW_SESSION, provide_session @@ -57,8 +57,6 @@ from sqlalchemy.orm import Session from airflow.models import Operator - from airflow.sdk import DAG - from airflow.serialization.serialized_objects import LazyDeserializedDAG log = logging.getLogger(__name__) @@ -317,15 +315,9 @@ class SerializedDagModel(Base): load_op_links = True - def __init__(self, dag: DAG | LazyDeserializedDAG) -> None: - from airflow.sdk import DAG - + def __init__(self, dag: LazyDeserializedDAG) -> None: self.dag_id = dag.dag_id - if isinstance(dag, DAG): - dag_data = SerializedDAG.to_dict(dag) - else: - dag_data = dag.data - + dag_data = dag.data self.dag_hash = SerializedDagModel.hash(dag_data) # partially ordered json data @@ -382,7 +374,7 @@ def _sort_serialized_dag_dict(cls, serialized_dag: Any): @provide_session def write_dag( cls, - dag: DAG | LazyDeserializedDAG, + dag: LazyDeserializedDAG, bundle_name: str, bundle_version: str | None = None, min_update_interval: int | None = None, diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index e81b1691f7000..ca2aaeae78851 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -118,7 +118,7 @@ from sqlalchemy.sql.elements import BooleanClauseList from sqlalchemy.sql.expression import ColumnOperators - from airflow.models.dag import DAG as SchedulerDAG, DagModel + from airflow.models.dag import DagModel from airflow.models.dagrun import DagRun from airflow.models.mappedoperator import MappedOperator from airflow.sdk import DAG @@ -852,8 +852,6 @@ def get_previous_dagrun( if dag is None: return None - if TYPE_CHECKING: - assert isinstance(dag, SchedulerDAG) dr = self.get_dagrun(session=session) dr.dag = dag @@ -1030,7 +1028,6 @@ def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun: if getattr(self, "task", None) is not None: if TYPE_CHECKING: assert self.task - assert isinstance(self.task.dag, SchedulerDAG) dr.dag = self.task.dag # Record it in the instance for next time. 
This means that `self.logical_date` will work correctly set_committed_value(self, "dag_run", dr) diff --git a/airflow-core/src/airflow/models/taskmap.py b/airflow-core/src/airflow/models/taskmap.py index e48c53aa4034e..edd0b21b114ea 100644 --- a/airflow-core/src/airflow/models/taskmap.py +++ b/airflow-core/src/airflow/models/taskmap.py @@ -35,7 +35,6 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session - from airflow.models.dag import DAG as SchedulerDAG from airflow.models.mappedoperator import MappedOperator from airflow.models.taskinstance import TaskInstance from airflow.serialization.serialized_objects import SerializedBaseOperator @@ -176,7 +175,7 @@ def expand_mapped_task( if unmapped_ti: if TYPE_CHECKING: - assert task.dag is None or isinstance(task.dag, SchedulerDAG) + assert task.dag is None # The unmapped task instance still exists and is unfinished, i.e. we # haven't tried to run it before. diff --git a/airflow-core/src/airflow/models/xcom_arg.py b/airflow-core/src/airflow/models/xcom_arg.py index 78021e5043123..75cccba50334d 100644 --- a/airflow-core/src/airflow/models/xcom_arg.py +++ b/airflow-core/src/airflow/models/xcom_arg.py @@ -36,9 +36,8 @@ __all__ = ["XComArg", "get_task_map_length"] if TYPE_CHECKING: - from airflow.models.dag import DAG as SchedulerDAG from airflow.models.mappedoperator import MappedOperator - from airflow.serialization.serialized_objects import SerializedBaseOperator + from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG from airflow.typing_compat import Self Operator: TypeAlias = MappedOperator | SerializedBaseOperator @@ -46,7 +45,7 @@ class SchedulerXComArg: @classmethod - def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: + def _deserialize(cls, data: dict[str, Any], dag: SerializedDAG) -> Self: """ Deserialize an XComArg. @@ -92,8 +91,8 @@ class SchedulerPlainXComArg(SchedulerXComArg): key: str @classmethod - def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: - # TODO (GH-52141): SchedulerDAG should return scheduler operator instead. + def _deserialize(cls, data: dict[str, Any], dag: SerializedDAG) -> Self: + # TODO (GH-52141): SerializedDAG should return scheduler operator instead. return cls(cast("Operator", dag.get_task(data["task_id"])), data["key"]) def iter_references(self) -> Iterator[tuple[Operator, str]]: @@ -106,7 +105,7 @@ class SchedulerMapXComArg(SchedulerXComArg): callables: Sequence[str] @classmethod - def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: + def _deserialize(cls, data: dict[str, Any], dag: SerializedDAG) -> Self: # We are deliberately NOT deserializing the callables. These are shown # in the UI, and displaying a function object is useless. 
return cls(deserialize_xcom_arg(data["arg"], dag), data["callables"]) @@ -120,7 +119,7 @@ class SchedulerConcatXComArg(SchedulerXComArg): args: Sequence[SchedulerXComArg] @classmethod - def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: + def _deserialize(cls, data: dict[str, Any], dag: SerializedDAG) -> Self: return cls([deserialize_xcom_arg(arg, dag) for arg in data["args"]]) def iter_references(self) -> Iterator[tuple[Operator, str]]: @@ -134,7 +133,7 @@ class SchedulerZipXComArg(SchedulerXComArg): fillvalue: Any @classmethod - def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: + def _deserialize(cls, data: dict[str, Any], dag: SerializedDAG) -> Self: return cls( [deserialize_xcom_arg(arg, dag) for arg in data["args"]], fillvalue=data.get("fillvalue", NOTSET), @@ -219,7 +218,7 @@ def _(xcom_arg: SchedulerConcatXComArg, run_id: str, *, session: Session): return sum(ready_lengths) -def deserialize_xcom_arg(data: dict[str, Any], dag: SchedulerDAG): +def deserialize_xcom_arg(data: dict[str, Any], dag: SerializedDAG): """DAG serialization interface.""" klass = _XCOM_ARG_TYPES[data.get("type", "")] return klass._deserialize(data, dag) diff --git a/airflow-core/src/airflow/policies.py b/airflow-core/src/airflow/policies.py index 933ccaa24522c..7a8311bd67d99 100644 --- a/airflow-core/src/airflow/policies.py +++ b/airflow-core/src/airflow/policies.py @@ -27,13 +27,11 @@ __all__: list[str] = ["hookimpl"] if TYPE_CHECKING: - from airflow.models.dag import DAG from airflow.models.taskinstance import TaskInstance - from airflow.serialization.serialized_objects import SerializedBaseOperator as BaseOperator @local_settings_hookspec -def task_policy(task: BaseOperator) -> None: +def task_policy(task) -> None: """ Allow altering tasks after they are loaded in the DagBag. @@ -51,7 +49,7 @@ def task_policy(task: BaseOperator) -> None: @local_settings_hookspec -def dag_policy(dag: DAG) -> None: +def dag_policy(dag) -> None: """ Allow altering DAGs after they are loaded in the DagBag. 
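With the type hints dropped from the hook specs above, a cluster policy is still just a plain function defined in airflow_local_settings.py. A minimal sketch, with the no-tags rule invented purely for illustration:

from airflow.exceptions import AirflowClusterPolicyViolation


def dag_policy(dag) -> None:
    # Runs once per DAG after it is loaded into the DagBag; reject DAGs without tags.
    if not dag.tags:
        raise AirflowClusterPolicyViolation(f"DAG {dag.dag_id} must declare at least one tag")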
diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index 0a33a1ecf25b8..dcc0458c71771 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -21,17 +21,19 @@ from __future__ import annotations import collections.abc +import copy import datetime import enum import itertools import logging import math +import re import weakref -from collections.abc import Collection, Generator, Iterable, Iterator, Mapping, Sequence +from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence from functools import cached_property, lru_cache from inspect import signature from textwrap import dedent -from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple, TypeAlias, TypeVar, cast, overload +from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple, TypeAlias, TypeVar, cast, overload import attrs import lazy_object_proxy @@ -39,18 +41,24 @@ import pydantic from dateutil import relativedelta from pendulum.tz.timezone import FixedTimezone, Timezone +from sqlalchemy import func, or_, select, tuple_ from airflow import macros -from airflow._shared.timezones.timezone import from_timestamp, parse_timezone +from airflow._shared.timezones.timezone import coerce_datetime, from_timestamp, parse_timezone, utcnow from airflow.callbacks.callback_requests import DagCallbackRequest, TaskCallbackRequest +from airflow.configuration import conf as airflow_conf from airflow.exceptions import AirflowException, SerializationError, TaskDeferred from airflow.models.connection import Connection -from airflow.models.dag import DAG, _get_model_data_interval +from airflow.models.dag import DagModel +from airflow.models.dag_version import DagVersion +from airflow.models.dagrun import RUN_ID_REGEX, DagRun +from airflow.models.deadline import Deadline from airflow.models.expandinput import create_expand_input from airflow.models.taskinstancekey import TaskInstanceKey +from airflow.models.tasklog import LogTemplate from airflow.models.xcom import XComModel from airflow.models.xcom_arg import SchedulerXComArg, deserialize_xcom_arg -from airflow.sdk import Asset, AssetAlias, AssetAll, AssetAny, AssetWatcher, BaseOperator, XComArg +from airflow.sdk import DAG, Asset, AssetAlias, AssetAll, AssetAny, AssetWatcher, BaseOperator, XComArg from airflow.sdk.bases.operator import OPERATOR_DEFAULTS # TODO: Copy this into the scheduler? 
from airflow.sdk.definitions._internal.node import DAGNode from airflow.sdk.definitions.asset import ( @@ -60,7 +68,7 @@ AssetUniqueKey, BaseAsset, ) -from airflow.sdk.definitions.deadline import DeadlineAlert +from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineReference from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.operator_resources import Resources from airflow.sdk.definitions.param import Param, ParamsDict @@ -83,31 +91,31 @@ from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep +from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable from airflow.triggers.base import BaseTrigger, StartTriggerArgs from airflow.utils.code_utils import get_python_source -from airflow.utils.context import ( - ConnectionAccessor, - Context, - VariableAccessor, -) +from airflow.utils.context import ConnectionAccessor, Context, VariableAccessor from airflow.utils.db import LazySelectSequence from airflow.utils.docs import get_docs_url from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.module_loading import import_string, qualname -from airflow.utils.types import NOTSET, ArgNotSet +from airflow.utils.session import NEW_SESSION, provide_session +from airflow.utils.state import DagRunState, TaskInstanceState +from airflow.utils.types import NOTSET, ArgNotSet, DagRunTriggeredByType, DagRunType if TYPE_CHECKING: from inspect import Parameter - from airflow.models import DagRun + from pydantic import NonNegativeInt + from sqlalchemy.orm import Session + from airflow.models.expandinput import SchedulerExpandInput - from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator + from airflow.models.mappedoperator import MappedOperator as SerializedMappedOperator from airflow.models.taskinstance import TaskInstance - from airflow.sdk import DAG as SdkDag, BaseOperatorLink + from airflow.sdk import BaseOperatorLink from airflow.serialization.json_schema import Validator from airflow.task.trigger_rule import TriggerRule from airflow.ti_deps.deps.base_ti_dep import BaseTIDep - from airflow.timetables.base import DagRunInfo, DataInterval, Timetable from airflow.triggers.base import BaseEventTrigger HAS_KUBERNETES: bool @@ -118,7 +126,7 @@ except ImportError: pass - SchedulerOperator: TypeAlias = "SchedulerMappedOperator | SerializedBaseOperator" + SerializedOperator: TypeAlias = "SerializedMappedOperator | SerializedBaseOperator" SdkOperator: TypeAlias = BaseOperator | MappedOperator DEFAULT_OPERATOR_DEPS: frozenset[BaseTIDep] = frozenset( @@ -513,7 +521,7 @@ class _XComRef(NamedTuple): data: dict - def deref(self, dag: DAG) -> SchedulerXComArg: + def deref(self, dag: SerializedDAG) -> SchedulerXComArg: return deserialize_xcom_arg(self.data, dag) @@ -551,7 +559,7 @@ def validate_expand_input_value(cls, value: _ExpandInputOriginalValue) -> None: possible ExpandInput cases. """ - def deref(self, dag: DAG) -> SchedulerExpandInput: + def deref(self, dag: SerializedDAG) -> SchedulerExpandInput: """ De-reference into a concrete ExpandInput object. @@ -580,7 +588,7 @@ class BaseSerialization: # Object types that are always excluded in serialization. 
_excluded_types = (logging.Logger, Connection, type, property) - _json_schema: Validator | None = None + _json_schema: ClassVar[Validator | None] = None # Should the extra operator link be loaded via plugins when # de-serializing the DAG? This flag is set to False in Scheduler so that Extra Operator links @@ -592,12 +600,12 @@ class BaseSerialization: SERIALIZER_VERSION = 2 @classmethod - def to_json(cls, var: DAG | SchedulerOperator | dict | list | set | tuple) -> str: + def to_json(cls, var: Any) -> str: """Stringify DAGs and operators contained by var and returns a JSON string of var.""" return json.dumps(cls.to_dict(var), ensure_ascii=True) @classmethod - def to_dict(cls, var: DAG | SchedulerOperator | dict | list | set | tuple) -> dict: + def to_dict(cls, var: Any) -> dict: """Stringify DAGs and operators contained by var and returns a dict of var.""" # Don't call on this class directly - only SerializedDAG or # SerializedBaseOperator should be used as the "entrypoint" @@ -653,7 +661,7 @@ def _is_excluded(cls, var: Any, attrname: str, instance: Any) -> bool: def serialize_to_json( cls, # TODO (GH-52141): When can we remove scheduler constructs here? - object_to_serialize: SdkOperator | SchedulerOperator | SdkDag | DAG, + object_to_serialize: SdkOperator | SerializedOperator | DAG | SerializedDAG, decorated_fields: set, ) -> dict[str, Any]: """Serialize an object to JSON.""" @@ -1192,7 +1200,7 @@ def detect_task_dependencies(task: SdkOperator) -> list[DagDependency]: return deps @staticmethod - def detect_dag_dependencies(dag: SdkDag | None) -> Iterable[DagDependency]: + def detect_dag_dependencies(dag: DAG | None) -> Iterable[DagDependency]: """Detect dependencies set directly on the DAG object.""" if not dag: return @@ -1219,7 +1227,7 @@ class SerializedBaseOperator(DAGNode, BaseSerialization): _CONSTRUCTOR_PARAMS = {} - _json_schema: Validator = lazy_object_proxy.Proxy(load_dag_schema) + _json_schema: ClassVar[Validator] = lazy_object_proxy.Proxy(load_dag_schema) _can_skip_downstream: bool _is_empty: bool @@ -1227,7 +1235,7 @@ class SerializedBaseOperator(DAGNode, BaseSerialization): _task_display_name: str | None _weight_rule: str | PriorityWeightStrategy = "downstream" - dag: DAG | None = None + dag: SerializedDAG | None = None task_group: TaskGroup | None = None allow_nested_operators: bool = True @@ -1247,7 +1255,7 @@ class SerializedBaseOperator(DAGNode, BaseSerialization): execution_timeout: datetime.timedelta | None executor: str | None - executor_config: dict | None = {} + executor_config: dict = {} ignore_first_depends_on_past: bool = False inlets: Sequence = [] @@ -1267,7 +1275,7 @@ class SerializedBaseOperator(DAGNode, BaseSerialization): has_on_success_callback: bool = False has_on_skipped_callback: bool = False - operator_extra_links: Collection[BaseOperatorLink] = () + operator_extra_links: Collection[BaseOperatorLink] = [] on_failure_fail_dagrun: bool = False outlets: Sequence = [] @@ -1332,7 +1340,7 @@ def __eq__(self, other: Any) -> bool: def node_id(self) -> str: return self.task_id - def get_dag(self) -> SdkDag | None: + def get_dag(self) -> DAG | None: return self.dag @property @@ -1509,7 +1517,7 @@ def _serialize_node(cls, op: SdkOperator) -> dict[str, Any]: @classmethod def populate_operator( cls, - op: SchedulerOperator, + op: SerializedOperator, encoded_op: dict[str, Any], client_defaults: dict[str, Any] | None = None, ) -> None: @@ -1656,7 +1664,7 @@ def populate_operator( setattr(op, "start_from_trigger", bool(encoded_op.get("start_from_trigger", False))) 
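To keep these pieces concrete, a hedged round-trip sketch that only uses entry points visible in this module; the DAG content is illustrative, and the tasks on the restored object are the ones rebuilt by deserialize_operator and populate_operator above:

from airflow.sdk import DAG, task
from airflow.serialization.serialized_objects import SerializedDAG

with DAG(dag_id="roundtrip_example", schedule=None) as dag:

    @task
    def hello():
        print("hello")

    hello()

payload = SerializedDAG.to_dict(dag)  # JSON-safe dict with a "dag" section
restored = SerializedDAG.from_dict(payload)
print(type(restored).__name__, list(restored.task_dict))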
@staticmethod - def set_task_dag_references(task: SchedulerOperator, dag: DAG) -> None: + def set_task_dag_references(task: SerializedOperator, dag: SerializedDAG) -> None: """ Handle DAG references on an operator. @@ -1684,11 +1692,11 @@ def deserialize_operator( cls, encoded_op: dict[str, Any], client_defaults: dict[str, Any] | None = None, - ) -> SchedulerOperator: + ) -> SerializedOperator: """Deserializes an operator from a JSON object.""" - op: SchedulerOperator + op: SerializedOperator if encoded_op.get("_is_mapped", False): - from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator + from airflow.models.mappedoperator import MappedOperator as SerializedMappedOperator try: operator_name = encoded_op["_operator_name"] @@ -1702,7 +1710,7 @@ def deserialize_operator( "_operator_name": operator_name, } - op = SchedulerMappedOperator( + op = SerializedMappedOperator( operator_class=operator_class_info, task_id=encoded_op["task_id"], operator_extra_links=SerializedBaseOperator.operator_extra_links, @@ -2246,6 +2254,63 @@ def get_parse_time_mapped_ti_count(self) -> int: return group.get_parse_time_mapped_ti_count() +@provide_session +def _create_orm_dagrun( + *, + dag: SerializedDAG, + run_id: str, + logical_date: datetime.datetime | None, + data_interval: DataInterval | None, + run_after: datetime.datetime, + start_date: datetime.datetime | None, + conf: Any, + state: DagRunState | None, + run_type: DagRunType, + creating_job_id: int | None, + backfill_id: NonNegativeInt | None, + triggered_by: DagRunTriggeredByType, + triggering_user_name: str | None = None, + session: Session = NEW_SESSION, +) -> DagRun: + bundle_version = None + if not dag.disable_bundle_versioning: + bundle_version = session.scalar( + select(DagModel.bundle_version).where(DagModel.dag_id == dag.dag_id), + ) + dag_version = DagVersion.get_latest_version(dag.dag_id, session=session) + if not dag_version: + raise AirflowException(f"Cannot create DagRun for DAG {dag.dag_id} because the dag is not serialized") + + run = DagRun( + dag_id=dag.dag_id, + run_id=run_id, + logical_date=logical_date, + start_date=start_date, + run_after=run_after, + conf=conf, + state=state, + run_type=run_type, + creating_job_id=creating_job_id, + data_interval=data_interval, + triggered_by=triggered_by, + triggering_user_name=triggering_user_name, + backfill_id=backfill_id, + bundle_version=bundle_version, + ) + # Load defaults into the following two fields to ensure result can be serialized detached + run.log_template_id = int(session.scalar(select(func.max(LogTemplate.__table__.c.id)))) + run.created_dag_version = dag_version + run.consumed_asset_events = [] + session.add(run) + session.flush() + run.dag = dag + # create the associated task instances + # state is None at the moment of creation + run.verify_integrity(session=session, dag_version_id=dag_version.id) + return run + + +@attrs.define(hash=False, repr=False, eq=False, slots=False) class SerializedDAG(DAG, BaseSerialization): """ A JSON serializable representation of DAG. @@ -2255,7 +2320,12 @@ class SerializedDAG(DAG, BaseSerialization): strings. 
""" - _decorated_fields = {"default_args", "access_control"} + _decorated_fields: ClassVar[set[str]] = {"default_args", "access_control"} + + last_loaded: datetime.datetime | None = attrs.field(init=False, factory=utcnow) + # this will only be set at serialization time + # it's only use is for determining the relative fileloc based only on the serialize dag + _processor_dags_folder: str = attrs.field(init=False) @staticmethod def __get_constructor_defaults(): @@ -2271,10 +2341,10 @@ def __get_constructor_defaults(): _CONSTRUCTOR_PARAMS = __get_constructor_defaults.__func__() # type: ignore del __get_constructor_defaults - _json_schema: Validator = lazy_object_proxy.Proxy(load_dag_schema) + _json_schema: ClassVar[Validator] = lazy_object_proxy.Proxy(load_dag_schema) @classmethod - def serialize_dag(cls, dag: SdkDag) -> dict: + def serialize_dag(cls, dag: DAG) -> dict: """Serialize a DAG into a JSON object.""" try: serialized_dag = cls.serialize_to_json(dag, cls._decorated_fields) @@ -2365,7 +2435,7 @@ def deserialize_dag( None, # TODO (GH-52141): SerializedDAG's task_dict should contain # scheduler types instead, but currently it inherits SDK's DAG. - cast("dict[str, SchedulerOperator]", dag.task_dict), + cast("dict[str, SerializedOperator]", dag.task_dict), dag, ) object.__setattr__(dag, "task_group", tg) @@ -2392,7 +2462,7 @@ def deserialize_dag( # TODO (GH-52141): SerializedDAG's task_dict should contain scheduler # types instead, but currently it inherits SDK's DAG. for task in dag.task_dict.values(): - SerializedBaseOperator.set_task_dag_references(cast("SchedulerOperator", task), dag) + SerializedBaseOperator.set_task_dag_references(cast("SerializedOperator", task), dag) return dag @@ -2571,6 +2641,722 @@ def from_dict(cls, serialized_obj: dict) -> SerializedDAG: # Pass client_defaults directly to deserialize_dag return cls.deserialize_dag(serialized_obj["dag"], client_defaults) + @classmethod + @provide_session + def bulk_write_to_db( + cls, + bundle_name: str, + bundle_version: str | None, + dags: Collection[DAG | LazyDeserializedDAG], + session: Session = NEW_SESSION, + ) -> None: + """ + Ensure the DagModel rows for the given dags are up-to-date in the dag table in the DB. + + :param dags: the DAG objects to save to the DB + :return: None + """ + if not dags: + return + + from airflow.dag_processing.collection import AssetModelOperation, DagModelOperation + + log.info("Sync %s DAGs", len(dags)) + dag_op = DagModelOperation( + bundle_name=bundle_name, + bundle_version=bundle_version, + dags={d.dag_id: LazyDeserializedDAG.from_dag(d) for d in dags}, + ) + + orm_dags = dag_op.add_dags(session=session) + dag_op.update_dags(orm_dags, session=session) + + asset_op = AssetModelOperation.collect(dag_op.dags) + + orm_assets = asset_op.sync_assets(session=session) + orm_asset_aliases = asset_op.sync_asset_aliases(session=session) + session.flush() # This populates id so we can create fks in later calls. + + orm_dags = dag_op.find_orm_dags(session=session) # Refetch so relationship is up to date. + asset_op.add_dag_asset_references(orm_dags, orm_assets, session=session) + asset_op.add_dag_asset_alias_references(orm_dags, orm_asset_aliases, session=session) + asset_op.add_dag_asset_name_uri_references(session=session) + asset_op.add_task_asset_references(orm_dags, orm_assets, session=session) + asset_op.activate_assets_if_possible(orm_assets.values(), session=session) + session.flush() # Activation is needed when we add trigger references. 
+ + asset_op.add_asset_trigger_references(orm_assets, session=session) + dag_op.update_dag_asset_expression(orm_dags=orm_dags, orm_assets=orm_assets) + session.flush() + + @cached_property + def _time_restriction(self) -> TimeRestriction: + start_dates = [t.start_date for t in self.tasks if t.start_date] + if self.start_date is not None: + start_dates.append(self.start_date) + earliest = None + if start_dates: + earliest = coerce_datetime(min(start_dates)) + latest = coerce_datetime(self.end_date) + end_dates = [t.end_date for t in self.tasks if t.end_date] + if len(end_dates) == len(self.tasks): # not exists null end_date + if self.end_date is not None: + end_dates.append(self.end_date) + if end_dates: + latest = coerce_datetime(max(end_dates)) + return TimeRestriction(earliest, latest, self.catchup) + + def next_dagrun_info( + self, + last_automated_dagrun: None | DataInterval, + *, + restricted: bool = True, + ) -> DagRunInfo | None: + """ + Get information about the next DagRun of this dag after ``date_last_automated_dagrun``. + + This calculates what time interval the next DagRun should operate on + (its logical date) and when it can be scheduled, according to the + dag's timetable, start_date, end_date, etc. This doesn't check max + active run or any other "max_active_tasks" type limits, but only + performs calculations based on the various date and interval fields of + this dag and its tasks. + + :param last_automated_dagrun: The ``max(logical_date)`` of + existing "automated" DagRuns for this dag (scheduled or backfill, + but not manual). + :param restricted: If set to *False* (default is *True*), ignore + ``start_date``, ``end_date``, and ``catchup`` specified on the DAG + or tasks. + :return: DagRunInfo of the next dagrun, or None if a dagrun is not + going to be scheduled. + """ + if restricted: + restriction = self._time_restriction + else: + restriction = TimeRestriction(earliest=None, latest=None, catchup=True) + try: + info = self.timetable.next_dagrun_info( + last_automated_data_interval=last_automated_dagrun, + restriction=restriction, + ) + except Exception: + log.exception( + "Failed to fetch run info after data interval %s for DAG %r", + last_automated_dagrun, + self.dag_id, + ) + info = None + return info + + def iter_dagrun_infos_between( + self, + earliest: datetime.datetime | None, + latest: datetime.datetime, + *, + align: bool = True, + ) -> Iterable[DagRunInfo]: + """ + Yield DagRunInfo using this DAG's timetable between given interval. + + DagRunInfo instances yielded if their ``logical_date`` is not earlier + than ``earliest``, nor later than ``latest``. The instances are ordered + by their ``logical_date`` from earliest to latest. + + If ``align`` is ``False``, the first run will happen immediately on + ``earliest``, even if it does not fall on the logical timetable schedule. + The default is ``True``. + + Example: A DAG is scheduled to run every midnight (``0 0 * * *``). If + ``earliest`` is ``2021-06-03 23:00:00``, the first DagRunInfo would be + ``2021-06-03 23:00:00`` if ``align=False``, and ``2021-06-04 00:00:00`` + if ``align=True``. 
+ """ + if earliest is None: + earliest = self._time_restriction.earliest + if earliest is None: + raise ValueError("earliest was None and we had no value in time_restriction to fallback on") + earliest = coerce_datetime(earliest) + latest = coerce_datetime(latest) + + restriction = TimeRestriction(earliest, latest, catchup=True) + + try: + info = self.timetable.next_dagrun_info( + last_automated_data_interval=None, + restriction=restriction, + ) + except Exception: + log.exception( + "Failed to fetch run info after data interval %s for DAG %r", + None, + self.dag_id, + ) + info = None + + if info is None: + # No runs to be scheduled between the user-supplied timeframe. But + # if align=False, "invent" a data interval for the timeframe itself. + if not align: + yield DagRunInfo.interval(earliest, latest) + return + + # If align=False and earliest does not fall on the timetable's logical + # schedule, "invent" a data interval for it. + if not align and info.logical_date != earliest: + yield DagRunInfo.interval(earliest, info.data_interval.start) + + # Generate naturally according to schedule. + while info is not None: + yield info + try: + info = self.timetable.next_dagrun_info( + last_automated_data_interval=info.data_interval, + restriction=restriction, + ) + except Exception: + log.exception( + "Failed to fetch run info after data interval %s for DAG %r", + info.data_interval if info else "", + self.dag_id, + ) + break + + @provide_session + def get_concurrency_reached(self, session=NEW_SESSION) -> bool: + """Return a boolean indicating whether the max_active_tasks limit for this DAG has been reached.""" + from airflow.models.taskinstance import TaskInstance + + total_tasks = session.scalar( + select(func.count(TaskInstance.task_id)).where( + TaskInstance.dag_id == self.dag_id, + TaskInstance.state == TaskInstanceState.RUNNING, + ) + ) + return total_tasks >= self.max_active_tasks + + @provide_session + def create_dagrun( + self, + *, + run_id: str, + logical_date: datetime.datetime | None = None, + data_interval: tuple[datetime.datetime, datetime.datetime] | None = None, + run_after: datetime.datetime, + conf: dict | None = None, + run_type: DagRunType, + triggered_by: DagRunTriggeredByType, + triggering_user_name: str | None = None, + state: DagRunState, + start_date: datetime.datetime | None = None, + creating_job_id: int | None = None, + backfill_id: NonNegativeInt | None = None, + session: Session = NEW_SESSION, + ) -> DagRun: + """ + Create a run for this DAG to run its tasks. + + :param run_id: ID of the dag_run + :param logical_date: date of execution + :param run_after: the datetime before which dag won't run + :param conf: Dict containing configuration/parameters to pass to the DAG + :param triggered_by: the entity which triggers the dag_run + :param triggering_user_name: the user name who triggers the dag_run + :param start_date: the date this dag run should be evaluated + :param creating_job_id: ID of the job creating this DagRun + :param backfill_id: ID of the backfill run if one exists + :param session: Unused. Only added in compatibility with database isolation mode + :return: The created DAG run. + + :meta private: + """ + logical_date = coerce_datetime(logical_date) + # For manual runs where logical_date is None, ensure no data_interval is set. 
+ if logical_date is None and data_interval is not None: + raise ValueError("data_interval must be None when logical_date is None") + + if data_interval and not isinstance(data_interval, DataInterval): + data_interval = DataInterval(*map(coerce_datetime, data_interval)) + + if isinstance(run_type, DagRunType): + pass + elif isinstance(run_type, str): # Ensure the input value is valid. + run_type = DagRunType(run_type) + else: + raise ValueError(f"run_type should be a DagRunType, not {type(run_type)}") + + if not isinstance(run_id, str): + raise ValueError(f"`run_id` should be a str, not {type(run_id)}") + + # This is also done on the DagRun model class, but SQLAlchemy column + # validator does not work well for some reason. + if not re.match(RUN_ID_REGEX, run_id): + regex = airflow_conf.get("scheduler", "allowed_run_id_pattern").strip() + if not regex or not re.match(regex, run_id): + raise ValueError( + f"The run_id provided '{run_id}' does not match regex pattern " + f"'{regex}' or '{RUN_ID_REGEX}'" + ) + + # Prevent a manual run from using an ID that looks like a scheduled run. + if run_type == DagRunType.MANUAL: + if (inferred_run_type := DagRunType.from_run_id(run_id)) != DagRunType.MANUAL: + raise ValueError( + f"A {run_type.value} DAG run cannot use ID {run_id!r} since it " + f"is reserved for {inferred_run_type.value} runs" + ) + + # todo: AIP-78 add verification that if run type is backfill then we have a backfill id + + # create a copy of params before validating + copied_params = copy.deepcopy(self.params) + if conf: + copied_params.update(conf) + copied_params.validate() + orm_dagrun = _create_orm_dagrun( + dag=self, + run_id=run_id, + logical_date=logical_date, + data_interval=data_interval, + run_after=coerce_datetime(run_after), + start_date=coerce_datetime(start_date), + conf=conf, + state=state, + run_type=run_type, + creating_job_id=creating_job_id, + backfill_id=backfill_id, + triggered_by=triggered_by, + triggering_user_name=triggering_user_name, + session=session, + ) + + if self.deadline and isinstance(self.deadline.reference, DeadlineReference.TYPES.DAGRUN): + session.add( + Deadline( + deadline_time=self.deadline.reference.evaluate_with( + session=session, + interval=self.deadline.interval, + dag_id=self.dag_id, + run_id=run_id, + ), + callback=self.deadline.callback, + dagrun_id=orm_dagrun.id, + ) + ) + + return orm_dagrun + + @provide_session + def set_task_instance_state( + self, + *, + task_id: str, + map_indexes: Collection[int] | None = None, + run_id: str | None = None, + state: TaskInstanceState, + upstream: bool = False, + downstream: bool = False, + future: bool = False, + past: bool = False, + commit: bool = True, + session=NEW_SESSION, + ) -> list[TaskInstance]: + """ + Set the state of a TaskInstance and clear downstream tasks in failed or upstream_failed state. + + :param task_id: Task ID of the TaskInstance + :param map_indexes: Only set TaskInstance if its map_index matches. + If None (default), all mapped TaskInstances of the task are set. 
+ :param run_id: The run_id of the TaskInstance + :param state: State to set the TaskInstance to + :param upstream: Include all upstream tasks of the given task_id + :param downstream: Include all downstream tasks of the given task_id + :param future: Include all future TaskInstances of the given task_id + :param commit: Commit changes + :param past: Include all past TaskInstances of the given task_id + """ + from airflow.api.common.mark_tasks import set_state + + # TODO (GH-52141): get_task in scheduler needs to return scheduler types + # instead, but currently it inherits SDK's DAG. + task = cast("SerializedOperator", self.get_task(task_id)) + task.dag = self + + tasks_to_set_state: list[SerializedOperator | tuple[SerializedOperator, int]] + if map_indexes is None: + tasks_to_set_state = [task] + else: + tasks_to_set_state = [(task, map_index) for map_index in map_indexes] + + altered = set_state( + tasks=tasks_to_set_state, + run_id=run_id, + upstream=upstream, + downstream=downstream, + future=future, + past=past, + state=state, + commit=commit, + session=session, + ) + + if not commit: + return altered + + # Clear downstream tasks that are in failed/upstream_failed state to resume them. + # Flush the session so that the tasks marked success are reflected in the db. + session.flush() + subset = self.partial_subset( + task_ids={task_id}, + include_downstream=True, + include_upstream=False, + ) + + # Raises an error if not found + dr_id, logical_date = session.execute( + select(DagRun.id, DagRun.logical_date).where( + DagRun.run_id == run_id, DagRun.dag_id == self.dag_id + ) + ).one() + + # Now we want to clear downstreams of tasks that had their state set... + clear_kwargs = { + "only_failed": True, + "session": session, + # Exclude the task itself from being cleared. + "exclude_task_ids": frozenset((task_id,)), + } + if not future and not past: # Simple case 1: we're only dealing with exactly one run. + clear_kwargs["run_id"] = run_id + subset.clear(**clear_kwargs) + elif future and past: # Simple case 2: we're clearing ALL runs. + subset.clear(**clear_kwargs) + else: # Complex cases: we may have more than one run, based on a date range. + # Make 'future' and 'past' make some sense when multiple runs exist + # for the same logical date. We order runs by their id and only + # clear runs have larger/smaller ids. + exclude_run_id_stmt = select(DagRun.run_id).where(DagRun.logical_date == logical_date) + if future: + clear_kwargs["start_date"] = logical_date + exclude_run_id_stmt = exclude_run_id_stmt.where(DagRun.id > dr_id) + else: + clear_kwargs["end_date"] = logical_date + exclude_run_id_stmt = exclude_run_id_stmt.where(DagRun.id < dr_id) + subset.clear(exclude_run_ids=frozenset(session.scalars(exclude_run_id_stmt)), **clear_kwargs) + return altered + + @overload + def _get_task_instances( + self, + *, + task_ids: Collection[str | tuple[str, int]] | None, + start_date: datetime.datetime | None, + end_date: datetime.datetime | None, + run_id: str | None, + state: TaskInstanceState | Sequence[TaskInstanceState], + exclude_task_ids: Collection[str | tuple[str, int]] | None, + exclude_run_ids: frozenset[str] | None, + session: Session, + ) -> Iterable[TaskInstance]: ... 
# pragma: no cover + + @overload + def _get_task_instances( + self, + *, + task_ids: Collection[str | tuple[str, int]] | None, + as_pk_tuple: Literal[True], + start_date: datetime.datetime | None, + end_date: datetime.datetime | None, + run_id: str | None, + state: TaskInstanceState | Sequence[TaskInstanceState], + exclude_task_ids: Collection[str | tuple[str, int]] | None, + exclude_run_ids: frozenset[str] | None, + session: Session, + ) -> set[TaskInstanceKey]: ... # pragma: no cover + + def _get_task_instances( + self, + *, + task_ids: Collection[str | tuple[str, int]] | None, + as_pk_tuple: Literal[True, None] = None, + start_date: datetime.datetime | None, + end_date: datetime.datetime | None, + run_id: str | None, + state: TaskInstanceState | Sequence[TaskInstanceState], + exclude_task_ids: Collection[str | tuple[str, int]] | None, + exclude_run_ids: frozenset[str] | None, + session: Session, + ) -> Iterable[TaskInstance] | set[TaskInstanceKey]: + from airflow.models.taskinstance import TaskInstance + + # If we are looking at dependent dags we want to avoid UNION calls + # in SQL (it doesn't play nice with fields that have no equality operator, + # like JSON types), we instead build our result set separately. + # + # This will be empty if we are only looking at one dag, in which case + # we can return the filtered TI query object directly. + result: set[TaskInstanceKey] = set() + + # Do we want full objects, or just the primary columns? + if as_pk_tuple: + tis = select( + TaskInstance.dag_id, + TaskInstance.task_id, + TaskInstance.run_id, + TaskInstance.map_index, + ) + else: + tis = select(TaskInstance) + tis = tis.join(TaskInstance.dag_run) + + if self.partial: + tis = tis.where(TaskInstance.dag_id == self.dag_id, TaskInstance.task_id.in_(self.task_ids)) + else: + tis = tis.where(TaskInstance.dag_id == self.dag_id) + if run_id: + tis = tis.where(TaskInstance.run_id == run_id) + if start_date: + tis = tis.where(DagRun.logical_date >= start_date) + if task_ids is not None: + tis = tis.where(TaskInstance.ti_selector_condition(task_ids)) + if end_date: + tis = tis.where(DagRun.logical_date <= end_date) + + if state: + if isinstance(state, (str, TaskInstanceState)): + tis = tis.where(TaskInstance.state == state) + elif len(state) == 1: + tis = tis.where(TaskInstance.state == state[0]) + else: + # this is required to deal with NULL values + if None in state: + if all(x is None for x in state): + tis = tis.where(TaskInstance.state.is_(None)) + else: + not_none_state = [s for s in state if s] + tis = tis.where( + or_(TaskInstance.state.in_(not_none_state), TaskInstance.state.is_(None)) + ) + else: + tis = tis.where(TaskInstance.state.in_(state)) + + if exclude_run_ids: + tis = tis.where(TaskInstance.run_id.not_in(exclude_run_ids)) + + if result or as_pk_tuple: + # Only execute the `ti` query if we have also collected some other results + if as_pk_tuple: + tis_query = session.execute(tis).all() + result.update(TaskInstanceKey(**cols._mapping) for cols in tis_query) + else: + result.update(ti.key for ti in session.scalars(tis)) + + if exclude_task_ids is not None: + result = { + task + for task in result + if task.task_id not in exclude_task_ids + and (task.task_id, task.map_index) not in exclude_task_ids + } + + if as_pk_tuple: + return result + if result: + # We've been asked for objects, lets combine it all back in to a result set + ti_filters = TaskInstance.filter_for_tis(result) + if ti_filters is not None: + tis = select(TaskInstance).where(ti_filters) + elif exclude_task_ids is None: 
+ pass # Disable filter if not set. + elif isinstance(next(iter(exclude_task_ids), None), str): + tis = tis.where(TaskInstance.task_id.notin_(exclude_task_ids)) + else: + tis = tis.where(tuple_(TaskInstance.task_id, TaskInstance.map_index).not_in(exclude_task_ids)) + + return tis + + @overload + def clear( + self, + *, + dry_run: Literal[True], + task_ids: Collection[str | tuple[str, int]] | None = None, + run_id: str, + only_failed: bool = False, + only_running: bool = False, + dag_run_state: DagRunState = DagRunState.QUEUED, + session: Session = NEW_SESSION, + exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(), + exclude_run_ids: frozenset[str] | None = frozenset(), + run_on_latest_version: bool = False, + ) -> list[TaskInstance]: ... # pragma: no cover + + @overload + def clear( + self, + *, + task_ids: Collection[str | tuple[str, int]] | None = None, + run_id: str, + only_failed: bool = False, + only_running: bool = False, + dag_run_state: DagRunState = DagRunState.QUEUED, + dry_run: Literal[False] = False, + session: Session = NEW_SESSION, + exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(), + exclude_run_ids: frozenset[str] | None = frozenset(), + run_on_latest_version: bool = False, + ) -> int: ... # pragma: no cover + + @overload + def clear( + self, + *, + dry_run: Literal[True], + task_ids: Collection[str | tuple[str, int]] | None = None, + start_date: datetime.datetime | None = None, + end_date: datetime.datetime | None = None, + only_failed: bool = False, + only_running: bool = False, + dag_run_state: DagRunState = DagRunState.QUEUED, + session: Session = NEW_SESSION, + exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(), + exclude_run_ids: frozenset[str] | None = frozenset(), + run_on_latest_version: bool = False, + ) -> list[TaskInstance]: ... # pragma: no cover + + @overload + def clear( + self, + *, + task_ids: Collection[str | tuple[str, int]] | None = None, + start_date: datetime.datetime | None = None, + end_date: datetime.datetime | None = None, + only_failed: bool = False, + only_running: bool = False, + dag_run_state: DagRunState = DagRunState.QUEUED, + dry_run: Literal[False] = False, + session: Session = NEW_SESSION, + exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(), + exclude_run_ids: frozenset[str] | None = frozenset(), + run_on_latest_version: bool = False, + ) -> int: ... # pragma: no cover + + @provide_session + def clear( + self, + task_ids: Collection[str | tuple[str, int]] | None = None, + *, + run_id: str | None = None, + start_date: datetime.datetime | None = None, + end_date: datetime.datetime | None = None, + only_failed: bool = False, + only_running: bool = False, + dag_run_state: DagRunState = DagRunState.QUEUED, + dry_run: bool = False, + session: Session = NEW_SESSION, + exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(), + exclude_run_ids: frozenset[str] | None = frozenset(), + run_on_latest_version: bool = False, + ) -> int | Iterable[TaskInstance]: + """ + Clear a set of task instances associated with the current dag for a specified date range. 
+ + :param task_ids: List of task ids or (``task_id``, ``map_index``) tuples to clear + :param run_id: The run_id for which the tasks should be cleared + :param start_date: The minimum logical_date to clear + :param end_date: The maximum logical_date to clear + :param only_failed: Only clear failed tasks + :param only_running: Only clear running tasks. + :param dag_run_state: state to set DagRun to. If set to False, dagrun state will not + be changed. + :param dry_run: Find the tasks to clear but don't clear them. + :param run_on_latest_version: whether to run on latest serialized DAG and Bundle version + :param session: The sqlalchemy session to use + :param exclude_task_ids: A set of ``task_id`` or (``task_id``, ``map_index``) + tuples that should not be cleared + :param exclude_run_ids: A set of ``run_id`` or (``run_id``) + """ + from airflow.models.taskinstance import clear_task_instances + + state: list[TaskInstanceState] = [] + if only_failed: + state += [TaskInstanceState.FAILED, TaskInstanceState.UPSTREAM_FAILED] + if only_running: + # Yes, having `+=` doesn't make sense, but this was the existing behaviour + state += [TaskInstanceState.RUNNING] + + tis = self._get_task_instances( + task_ids=task_ids, + start_date=start_date, + end_date=end_date, + run_id=run_id, + state=state, + session=session, + exclude_task_ids=exclude_task_ids, + exclude_run_ids=exclude_run_ids, + ) + + if dry_run: + return session.scalars(tis).all() + + tis = session.scalars(tis).all() + + count = len(list(tis)) + if count == 0: + return 0 + + clear_task_instances( + list(tis), + session, + dag_run_state=dag_run_state, + run_on_latest_version=run_on_latest_version, + ) + + session.flush() + return count + + @classmethod + def clear_dags( + cls, + dags, + start_date=None, + end_date=None, + only_failed=False, + only_running=False, + dag_run_state=DagRunState.QUEUED, + dry_run=False, + ): + def _coerce_dag(dag): + if isinstance(dag, SerializedDAG): + return dag + return SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag)) + + if dry_run: + tis = itertools.chain.from_iterable( + _coerce_dag(dag).clear( + start_date=start_date, + end_date=end_date, + only_failed=only_failed, + only_running=only_running, + dag_run_state=dag_run_state, + dry_run=True, + ) + for dag in dags + ) + return list(tis) + + return sum( + _coerce_dag(dag).clear( + start_date=start_date, + end_date=end_date, + only_failed=only_failed, + only_running=only_running, + dag_run_state=dag_run_state, + dry_run=False, + ) + for dag in dags + ) + class TaskGroupSerialization(BaseSerialization): """JSON serializable representation of a task group.""" @@ -2615,7 +3401,7 @@ def deserialize_task_group( cls, encoded_group: dict[str, Any], parent_group: TaskGroup | None, - task_dict: dict[str, SchedulerOperator], + task_dict: dict[str, SerializedOperator], dag: SerializedDAG, ) -> TaskGroup: """Deserializes a TaskGroup from a JSON object.""" @@ -2638,7 +3424,7 @@ def deserialize_task_group( **kwargs, ) - def set_ref(task: SchedulerOperator) -> SchedulerOperator: + def set_ref(task: SerializedOperator) -> SerializedOperator: task.task_group = weakref.proxy(group) return task @@ -2683,8 +3469,7 @@ def _has_kubernetes() -> bool: return HAS_KUBERNETES -AssetT = TypeVar("AssetT", bound=BaseAsset) -MaybeSerializedDAG: TypeAlias = "DAG | LazyDeserializedDAG" +AssetT = TypeVar("AssetT", bound=BaseAsset, covariant=True) class LazyDeserializedDAG(pydantic.BaseModel): @@ -2696,12 +3481,14 @@ class LazyDeserializedDAG(pydantic.BaseModel): """ data: dict + 
last_loaded: datetime.datetime | None = None NULLABLE_PROPERTIES: ClassVar[set[str]] = { "is_paused_upon_creation", "owner", "dag_display_name", "description", + "relative_fileloc", "max_active_tasks", "max_active_runs", "max_consecutive_failed_dag_runs", @@ -2709,6 +3496,12 @@ class LazyDeserializedDAG(pydantic.BaseModel): "access_control", } + @classmethod + def from_dag(cls, dag: DAG | LazyDeserializedDAG) -> LazyDeserializedDAG: + if isinstance(dag, LazyDeserializedDAG): + return dag + return cls(data=SerializedDAG.to_dict(dag)) + @property def hash(self) -> str: from airflow.models.serialized_dag import SerializedDagModel @@ -2759,50 +3552,6 @@ def owner(self) -> str: set(filter(None, (task[Encoding.VAR].get("owner") for task in self.data["dag"]["tasks"]))) ) - @staticmethod - def _get_mapped_operator_ports(task: dict, direction: str): - return task["partial_kwargs"][direction] - - @staticmethod - def _get_base_operator_ports(task: dict, direction: str): - return task[direction] - - def get_task_assets( - self, - inlets: bool = True, - outlets: bool = True, - of_type: type[AssetT] = Asset, # type: ignore[assignment] - ) -> Generator[tuple[str, AssetT], None, None]: - for task in self.data["dag"]["tasks"]: - task = task[Encoding.VAR] - if task.get("_is_mapped"): - ports_getter = self._get_mapped_operator_ports - else: - ports_getter = self._get_base_operator_ports - directions: tuple[str, ...] = ("inlets",) if inlets else () - if outlets: - directions += ("outlets",) - for direction in directions: - try: - ports = ports_getter(task, direction) - except KeyError: - continue - for port in ports: - obj = BaseSerialization.deserialize(port) - if isinstance(obj, of_type): - yield task["task_id"], obj - - def get_run_data_interval(self, run: DagRun) -> DataInterval | None: - """Get the data interval of this run.""" - if run.dag_id is not None and run.dag_id != self.dag_id: - raise ValueError(f"Arguments refer to different DAGs: {self.dag_id} != {run.dag_id}") - - data_interval = _get_model_data_interval(run, "data_interval_start", "data_interval_end") - if data_interval is None and run.logical_date is not None: - data_interval = self._real_dag.timetable.infer_manual_data_interval(run_after=run.logical_date) - - return data_interval - @attrs.define() class XComOperatorLink(LoggingMixin): @@ -2848,13 +3597,13 @@ def create_scheduler_operator(op: BaseOperator | SerializedBaseOperator) -> Seri @overload -def create_scheduler_operator(op: MappedOperator | SchedulerMappedOperator) -> SchedulerMappedOperator: ... +def create_scheduler_operator(op: MappedOperator | SerializedMappedOperator) -> SerializedMappedOperator: ... 
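Illustrative sketch only, not part of the patch: the LazyDeserializedDAG.from_dag classmethod added above accepts either an SDK DAG or an already-wrapped LazyDeserializedDAG, which is how later call sites in this patch hand DAGs to SerializedDagModel.write_dag. The dag_id and bundle name below are made up for the example; the write_dag call mirrors the test changes further down in this diff.

    from airflow.sdk import DAG
    from airflow.models.serialized_dag import SerializedDagModel
    from airflow.serialization.serialized_objects import LazyDeserializedDAG
    from airflow.utils.session import create_session

    with DAG(dag_id="example_from_dag", schedule=None) as dag:  # hypothetical DAG for illustration
        ...

    lazy = LazyDeserializedDAG.from_dag(dag)           # wraps SerializedDAG.to_dict(dag)
    assert LazyDeserializedDAG.from_dag(lazy) is lazy  # an already-wrapped input is returned as-is

    with create_session() as session:
        # write_dag now takes the lazy wrapper rather than the DAG object itself,
        # matching the updated test call sites in this patch.
        SerializedDagModel.write_dag(lazy, bundle_name="testing", session=session)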
-def create_scheduler_operator(op: SdkOperator | SchedulerOperator) -> SchedulerOperator: - from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator +def create_scheduler_operator(op: SdkOperator | SerializedOperator) -> SerializedOperator: + from airflow.models.mappedoperator import MappedOperator as SerializedMappedOperator - if isinstance(op, (SerializedBaseOperator, SchedulerMappedOperator)): + if isinstance(op, (SerializedBaseOperator, SerializedMappedOperator)): return op if isinstance(op, BaseOperator): d = SerializedBaseOperator.serialize_operator(op) diff --git a/airflow-core/src/airflow/ti_deps/deps/dag_unpaused_dep.py b/airflow-core/src/airflow/ti_deps/deps/dag_unpaused_dep.py index a854d21632d7d..5dec655d4e054 100644 --- a/airflow-core/src/airflow/ti_deps/deps/dag_unpaused_dep.py +++ b/airflow-core/src/airflow/ti_deps/deps/dag_unpaused_dep.py @@ -17,6 +17,8 @@ # under the License. from __future__ import annotations +from sqlalchemy import select + from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.utils.session import provide_session @@ -27,7 +29,14 @@ class DagUnpausedDep(BaseTIDep): NAME = "Dag Not Paused" IGNORABLE = True + @staticmethod + def _is_dag_paused(dag_id: str, session) -> bool: + """Check if a dag is paused. Extracted to simplify testing.""" + from airflow.models.dag import DagModel + + return session.scalar(select(DagModel.is_paused).where(DagModel.dag_id == dag_id)) + @provide_session def _get_dep_statuses(self, ti, session, dep_context): - if ti.task.dag.get_is_paused(session): + if self._is_dag_paused(ti.dag_id, session): yield self._failing_status(reason=f"Task's DAG '{ti.dag_id}' is paused.") diff --git a/airflow-core/src/airflow/utils/cli.py b/airflow-core/src/airflow/utils/cli.py index b33e50cc12c91..bf1a83838c8f2 100644 --- a/airflow-core/src/airflow/utils/cli.py +++ b/airflow-core/src/airflow/utils/cli.py @@ -46,7 +46,8 @@ T = TypeVar("T", bound=Callable) if TYPE_CHECKING: - from airflow.models.dag import DAG + from airflow.sdk import DAG + from airflow.serialization.serialized_objects import SerializedDAG logger = logging.getLogger(__name__) @@ -261,9 +262,7 @@ def _search_for_dag_file(val: str | None) -> str | None: return None -def get_dag( - bundle_names: list | None, dag_id: str, from_db: bool = False, dagfile_path: str | None = None -) -> DAG: +def get_bagged_dag(bundle_names: list | None, dag_id: str, dagfile_path: str | None = None) -> DAG: """ Return DAG of a given dag_id. @@ -272,44 +271,45 @@ def get_dag( dags folder. 
""" from airflow.models.dagbag import DagBag, sync_bag_to_db - from airflow.models.serialized_dag import SerializedDagModel - bundle_names = bundle_names or [] - dag: DAG | None = None - if from_db: - dag = SerializedDagModel.get_dag(dag_id) - elif bundle_names: - manager = DagBundlesManager() - for bundle_name in bundle_names: - bundle = manager.get_bundle(bundle_name) - with _airflow_parsing_context_manager(dag_id=dag_id): - dagbag = DagBag( - dag_folder=dagfile_path or bundle.path, bundle_path=bundle.path, include_examples=False - ) - if dag := dagbag.dags.get(dag_id): - break - if not dag: - if from_db: - raise AirflowException(f"Dag {dag_id!r} could not be found in DagBag read from database.") - manager = DagBundlesManager() - manager.sync_bundles_to_db() - all_bundles = list(manager.get_all_dag_bundles()) - for bundle in all_bundles: - bundle.initialize() - - with _airflow_parsing_context_manager(dag_id=dag_id): - dagbag = DagBag( - dag_folder=dagfile_path or bundle.path, bundle_path=bundle.path, include_examples=False - ) - sync_bag_to_db(dagbag, bundle.name, bundle.version) - dag = dagbag.dags.get(dag_id) - if dag: - break - if not dag: - raise AirflowException( - f"Dag {dag_id!r} could not be found; either it does not exist or it failed to parse." + manager = DagBundlesManager() + for bundle_name in bundle_names or (): + bundle = manager.get_bundle(bundle_name) + with _airflow_parsing_context_manager(dag_id=dag_id): + dagbag = DagBag( + dag_folder=dagfile_path or bundle.path, bundle_path=bundle.path, include_examples=False + ) + if dag := dagbag.dags.get(dag_id): + return dag + + manager.sync_bundles_to_db() + for bundle in manager.get_all_dag_bundles(): + bundle.initialize() + with _airflow_parsing_context_manager(dag_id=dag_id): + dagbag = DagBag( + dag_folder=dagfile_path or bundle.path, bundle_path=bundle.path, include_examples=False ) - return dag + sync_bag_to_db(dagbag, bundle.name, bundle.version) + if dag := dagbag.dags.get(dag_id): + return dag + if dag: + break + raise AirflowException( + f"Dag {dag_id!r} could not be found; either it does not exist or it failed to parse." + ) + + +def _get_db_dag(bundle_names: list | None, dag_id: str, dagfile_path: str | None = None) -> SerializedDAG: + """ + Return DAG of a given dag_id. + + This gets a serialized dag from the database. 
+ """ + from airflow.models.serialized_dag import SerializedDagModel + + if dag := SerializedDagModel.get_dag(dag_id): + return dag + raise AirflowException(f"Dag {dag_id!r} could not be found in the database.") def get_dags(bundle_names: list | None, dag_id: str, use_regex: bool = False, from_db: bool = False): @@ -319,7 +319,9 @@ def get_dags(bundle_names: list | None, dag_id: str, use_regex: bool = False, fr bundle_names = bundle_names or [] if not use_regex: - return [get_dag(bundle_names=bundle_names, dag_id=dag_id, from_db=from_db)] + if from_db: + return [_get_db_dag(bundle_names=bundle_names, dag_id=dag_id)] + return [get_bagged_dag(bundle_names=bundle_names, dag_id=dag_id)] def _find_dag(bundle): dagbag = DagBag(dag_folder=bundle.path, bundle_path=bundle.path) diff --git a/airflow-core/src/airflow/utils/db_manager.py b/airflow-core/src/airflow/utils/db_manager.py index 37d09b4a60a62..59f4921ce9286 100644 --- a/airflow-core/src/airflow/utils/db_manager.py +++ b/airflow-core/src/airflow/utils/db_manager.py @@ -23,7 +23,6 @@ from sqlalchemy import inspect from airflow import settings -from airflow.api_fastapi.app import create_auth_manager from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.utils.log.logging_mixin import LoggingMixin @@ -145,6 +144,8 @@ class RunDBManager(LoggingMixin): """ def __init__(self): + from airflow.api_fastapi.app import create_auth_manager + super().__init__() self._managers: list[BaseDBManager] = [] managers_config = conf.get("database", "external_db_managers", fallback=None) diff --git a/airflow-core/src/airflow/utils/dot_renderer.py b/airflow-core/src/airflow/utils/dot_renderer.py index 50911572d3971..66b834922698a 100644 --- a/airflow-core/src/airflow/utils/dot_renderer.py +++ b/airflow-core/src/airflow/utils/dot_renderer.py @@ -24,7 +24,7 @@ from typing import TYPE_CHECKING, Any from airflow.exceptions import AirflowException -from airflow.sdk import BaseOperator +from airflow.sdk import DAG, BaseOperator from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.taskgroup import TaskGroup from airflow.serialization.serialized_objects import SerializedBaseOperator @@ -35,7 +35,6 @@ import graphviz from airflow.models import TaskInstance - from airflow.models.dag import DAG from airflow.models.taskmixin import DependencyMixin from airflow.serialization.dag_dependency import DagDependency else: diff --git a/airflow-core/tests/integration/otel/test_otel.py b/airflow-core/tests/integration/otel/test_otel.py index 03f427227c113..d7e124740670c 100644 --- a/airflow-core/tests/integration/otel/test_otel.py +++ b/airflow-core/tests/integration/otel/test_otel.py @@ -24,6 +24,7 @@ import time import pytest +from sqlalchemy import select from airflow._shared.timezones import timezone from airflow.dag_processing.bundles.manager import DagBundlesManager @@ -32,6 +33,7 @@ from airflow.models import DAG, DagBag, DagRun from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance +from airflow.serialization.serialized_objects import SerializedDAG from airflow.utils.session import create_session from airflow.utils.span_status import SpanStatus from airflow.utils.state import State @@ -45,7 +47,7 @@ extract_spans_from_output, get_parent_child_dict, ) -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS log = 
logging.getLogger("integration.otel.test_otel") @@ -646,7 +648,7 @@ def setup_class(cls): cls.dags = cls.serialize_and_get_dags() @classmethod - def serialize_and_get_dags(cls) -> dict[str, DAG]: + def serialize_and_get_dags(cls) -> dict[str, SerializedDAG]: log.info("Serializing Dags from directory %s", cls.dag_folder) # Load DAGs from the dag directory. dag_bag = DagBag(dag_folder=cls.dag_folder, include_examples=False) @@ -654,14 +656,11 @@ def serialize_and_get_dags(cls) -> dict[str, DAG]: dag_ids = dag_bag.dag_ids assert len(dag_ids) == 3 - dag_dict: dict[str, DAG] = {} + dag_dict: dict[str, SerializedDAG] = {} with create_session() as session: for dag_id in dag_ids: dag = dag_bag.get_dag(dag_id) - dag_dict[dag_id] = dag - assert dag is not None, f"DAG with ID {dag_id} not found." - # Sync the DAG to the database. if AIRFLOW_V_3_0_PLUS: from airflow.models.dagbundle import DagBundleModel @@ -669,13 +668,22 @@ def serialize_and_get_dags(cls) -> dict[str, DAG]: if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0: session.add(DagBundleModel(name="testing")) session.commit() - dag.bulk_write_to_db( + SerializedDAG.bulk_write_to_db( bundle_name="testing", bundle_version=None, dags=[dag], session=session ) + dag_dict[dag_id] = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag)) else: dag.sync_to_db(session=session) + dag_dict[dag_id] = dag # Manually serialize the dag and write it to the db to avoid a db error. - SerializedDagModel.write_dag(dag, bundle_name="testing", session=session) + if AIRFLOW_V_3_1_PLUS: + from airflow.serialization.serialized_objects import LazyDeserializedDAG + + SerializedDagModel.write_dag( + LazyDeserializedDAG.from_dag(dag), bundle_name="testing", session=session + ) + else: + SerializedDagModel.write_dag(dag, bundle_name="testing", session=session) session.commit() @@ -744,13 +752,12 @@ def test_dag_execution_succeeds(self, monkeypatch, celery_worker_env_vars, capfd time.sleep(10) with create_session() as session: - tis: list[TaskInstance] = dag.get_task_instances(session=session) - - for ti in tis: - # Skip the span_status check. - check_ti_state_and_span_status( - task_id=ti.task_id, run_id=run_id, state=State.SUCCESS, span_status=None - ) + task_ids = session.scalars(select(TaskInstance.task_id).where(TaskInstance.dag_id == dag_id)) + for task_id in task_ids: + # Skip the span_status check. 
+ check_ti_state_and_span_status( + task_id=task_id, run_id=run_id, state=State.SUCCESS, span_status=None + ) print_ti_output_for_dag_run(dag_id=dag_id, run_id=run_id) finally: @@ -818,12 +825,10 @@ def test_same_scheduler_processing_the_entire_dag( time.sleep(10) with create_session() as session: - tis: list[TaskInstance] = dag.get_task_instances(session=session) - - for ti in tis: - check_ti_state_and_span_status( - task_id=ti.task_id, run_id=run_id, state=State.SUCCESS, span_status=SpanStatus.ENDED - ) + for ti in session.scalars(select(TaskInstance).where(TaskInstance.dag_id == dag.dag_id)): + check_ti_state_and_span_status( + task_id=ti.task_id, run_id=run_id, state=State.SUCCESS, span_status=SpanStatus.ENDED + ) print_ti_output_for_dag_run(dag_id=dag_id, run_id=run_id) finally: diff --git a/airflow-core/tests/unit/api/common/test_mark_tasks.py b/airflow-core/tests/unit/api/common/test_mark_tasks.py index 35dfe7b58793b..3d4d8641caeeb 100644 --- a/airflow-core/tests/unit/api/common/test_mark_tasks.py +++ b/airflow-core/tests/unit/api/common/test_mark_tasks.py @@ -28,14 +28,15 @@ if TYPE_CHECKING: from airflow.models.taskinstance import TaskInstance + from airflow.serialization.serialized_objects import SerializedDAG from tests_common.pytest_plugin import DagMaker -pytestmark = pytest.mark.db_test +pytestmark = [pytest.mark.db_test, pytest.mark.need_serialized_dag] -def test_set_dag_run_state_to_failed(dag_maker: DagMaker): - with dag_maker("TEST_DAG_1"): +def test_set_dag_run_state_to_failed(dag_maker: DagMaker[SerializedDAG]): + with dag_maker("TEST_DAG_1") as dag: with EmptyOperator(task_id="teardown").as_teardown(): EmptyOperator(task_id="running") EmptyOperator(task_id="pending") @@ -44,10 +45,9 @@ def test_set_dag_run_state_to_failed(dag_maker: DagMaker): if ti.task_id == "running": ti.set_state(TaskInstanceState.RUNNING) dag_maker.session.flush() - assert dr.dag updated_tis: list[TaskInstance] = set_dag_run_state_to_failed( - dag=dr.dag, run_id=dr.run_id, commit=True, session=dag_maker.session + dag=dag, run_id=dr.run_id, commit=True, session=dag_maker.session ) assert len(updated_tis) == 2 task_dict = {ti.task_id: ti for ti in updated_tis} @@ -59,8 +59,11 @@ def test_set_dag_run_state_to_failed(dag_maker: DagMaker): @pytest.mark.parametrize( "unfinished_state", sorted([state for state in State.unfinished if state is not None]) ) -def test_set_dag_run_state_to_success_unfinished_teardown(dag_maker: DagMaker, unfinished_state): - with dag_maker("TEST_DAG_1"): +def test_set_dag_run_state_to_success_unfinished_teardown( + dag_maker: DagMaker[SerializedDAG], + unfinished_state, +): + with dag_maker("TEST_DAG_1") as dag: with EmptyOperator(task_id="teardown").as_teardown(): EmptyOperator(task_id="running") EmptyOperator(task_id="pending") @@ -73,11 +76,10 @@ def test_set_dag_run_state_to_success_unfinished_teardown(dag_maker: DagMaker, u ti.set_state(unfinished_state) dag_maker.session.flush() - assert dr.dag assert dr.state == DagRunState.RUNNING updated_tis: list[TaskInstance] = set_dag_run_state_to_success( - dag=dr.dag, run_id=dr.run_id, commit=True, session=dag_maker.session + dag=dag, run_id=dr.run_id, commit=True, session=dag_maker.session ) run = dag_maker.session.scalar(select(DagRun).filter_by(dag_id=dr.dag_id, run_id=dr.run_id)) assert run.state != DagRunState.SUCCESS @@ -89,8 +91,8 @@ def test_set_dag_run_state_to_success_unfinished_teardown(dag_maker: DagMaker, u @pytest.mark.parametrize("finished_state", sorted(list(State.finished))) -def 
test_set_dag_run_state_to_success_finished_teardown(dag_maker: DagMaker, finished_state): - with dag_maker("TEST_DAG_1"): +def test_set_dag_run_state_to_success_finished_teardown(dag_maker: DagMaker[SerializedDAG], finished_state): + with dag_maker("TEST_DAG_1") as dag: with EmptyOperator(task_id="teardown").as_teardown(): EmptyOperator(task_id="failed") dr = dag_maker.create_dagrun() @@ -101,10 +103,9 @@ def test_set_dag_run_state_to_success_finished_teardown(dag_maker: DagMaker, fin ti.set_state(finished_state) dag_maker.session.flush() dr.set_state(DagRunState.FAILED) - assert dr.dag updated_tis: list[TaskInstance] = set_dag_run_state_to_success( - dag=dr.dag, run_id=dr.run_id, commit=True, session=dag_maker.session + dag=dag, run_id=dr.run_id, commit=True, session=dag_maker.session ) run = dag_maker.session.scalar(select(DagRun).filter_by(dag_id=dr.dag_id, run_id=dr.run_id)) assert run.state == DagRunState.SUCCESS diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_parsing.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_parsing.py index a513d2ddebab0..d2abb5e672a5a 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_parsing.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_parsing.py @@ -56,7 +56,7 @@ def test_201_and_400_requests(self, url_safe_serializer, session, test_client): assert response.status_code == 201 parsing_requests = session.scalars(select(DagPriorityParsingRequest)).all() assert len(parsing_requests) == 1 - assert parsing_requests[0].bundle_name == test_dag.get_bundle_name() + assert parsing_requests[0].bundle_name == "dags-folder" assert parsing_requests[0].relative_fileloc == test_dag.relative_fileloc _check_last_log(session, dag_id=None, event="reparse_dag_file", logical_date=None) @@ -65,7 +65,7 @@ def test_201_and_400_requests(self, url_safe_serializer, session, test_client): assert response.status_code == 409 parsing_requests = session.scalars(select(DagPriorityParsingRequest)).all() assert len(parsing_requests) == 1 - assert parsing_requests[0].bundle_name == test_dag.get_bundle_name() + assert parsing_requests[0].bundle_name == "dags-folder" assert parsing_requests[0].relative_fileloc == test_dag.relative_fileloc _check_last_log(session, dag_id=None, event="reparse_dag_file", logical_date=None) diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py index e14879119a63e..b990c027239ca 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py @@ -1580,7 +1580,7 @@ def test_post_dag_runs_with_empty_payload(self, test_client): }, ] - @mock.patch("airflow.models.DAG.create_dagrun") + @mock.patch("airflow.serialization.serialized_objects.SerializedDAG.create_dagrun") def test_dagrun_creation_exception_is_handled(self, mock_create_dagrun, test_client): now = timezone.utcnow().isoformat() error_message = "Encountered Error" diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_sources.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_sources.py index d2e42c014db30..ca48080c3a0ec 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_sources.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_sources.py @@ -26,10 +26,10 @@ from 
airflow.models.dagbag import DBDagBag from airflow.models.dagcode import DagCode -from airflow.models.serialized_dag import SerializedDagModel from airflow.utils.state import DagRunState from airflow.utils.types import DagRunTriggeredByType, DagRunType +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.db import clear_db_dags, clear_db_runs, parse_and_sync_to_db from unit.serialization.test_dag_serialization import AIRFLOW_REPO_ROOT_PATH @@ -46,11 +46,24 @@ @pytest.fixture -def test_dag(session): - parse_and_sync_to_db(EXAMPLE_DAG_FILE, include_examples=False) +def real_dag_bag(): + return parse_and_sync_to_db(EXAMPLE_DAG_FILE, include_examples=False) + + +@pytest.fixture +def test_dag(session, real_dag_bag): return DBDagBag().get_latest_version_of_dag(TEST_DAG_ID, session=session) +@pytest.fixture +def force_reserialization(real_dag_bag, session): + def _force_reserialization(dag_id, bundle_name): + dag = real_dag_bag.get_dag(dag_id, session=session) + sync_dag_to_db(dag, bundle_name=bundle_name, session=session) + + return _force_reserialization + + class TestGetDAGSource: @pytest.fixture(autouse=True) def setup(self, url_safe_serializer) -> None: @@ -112,7 +125,7 @@ def test_should_respond_200_json(self, test_client, test_dag, headers): assert response.headers["Content-Type"].startswith("application/json") @pytest.mark.parametrize("accept", ["application/json", "text/plain"]) - def test_should_respond_200_version(self, test_client, accept, session, test_dag): + def test_should_respond_200_version(self, test_client, accept, session, test_dag, force_reserialization): dag_content = self._get_dag_file_code(test_dag.fileloc) test_dag.create_dagrun( run_id="test1", @@ -123,7 +136,7 @@ def test_should_respond_200_version(self, test_client, accept, session, test_dag ) # force reserialization test_dag.doc_md = "new doc" - SerializedDagModel.write_dag(test_dag, bundle_name="dags-folder") + force_reserialization(test_dag.dag_id, "dag-folder") dagcode = ( session.query(DagCode) .filter(DagCode.fileloc == test_dag.fileloc) diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py index a326f7f291527..3989aae2aa545 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py @@ -32,11 +32,11 @@ from airflow.api_fastapi.common.dagbag import create_dag_bag, dag_bag_from_app from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG from airflow.models.dag import DAG -from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.standard.operators.empty import EmptyOperator from airflow.sdk import task from airflow.utils.types import DagRunType +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.db import clear_db_runs from tests_common.test_utils.file_task_handler import convert_list_to_stream @@ -275,8 +275,7 @@ def test_get_logs_of_removed_task(self, request_url, expected_filename, extra_qu # Recreate DAG without tasks dagbag = create_dag_bag() dag = DAG(self.DAG_ID, schedule=None, start_date=timezone.parse(self.default_time)) - dag.sync_to_db() - SerializedDagModel.write_dag(dag, bundle_name="testing") + sync_dag_to_db(dag) self.app.dependency_overrides[dag_bag_from_app] = lambda: dagbag diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py 
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py index b2b06cc9cf122..f9306d20fe0de 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py @@ -34,13 +34,13 @@ from airflow.jobs.triggerer_job_runner import TriggererJobRunner from airflow.listeners.listener import get_listener_manager from airflow.models import DagRun, TaskInstance -from airflow.models.baseoperator import BaseOperator from airflow.models.dag_version import DagVersion from airflow.models.dagbag import DagBag, sync_bag_to_db from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF from airflow.models.taskinstancehistory import TaskInstanceHistory from airflow.models.taskmap import TaskMap from airflow.models.trigger import Trigger +from airflow.sdk import BaseOperator from airflow.utils.platform import getuser from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.types import DagRunType @@ -3378,7 +3378,7 @@ def test_patch_task_instance_notifies_listeners(self, test_client, session, stat assert response2.json()["state"] == state assert listener.state == listener_state - @mock.patch("airflow.models.dag.DAG.set_task_instance_state") + @mock.patch("airflow.serialization.serialized_objects.SerializedDAG.set_task_instance_state") def test_should_call_mocked_api(self, mock_set_ti_state, test_client, session): self.create_task_instances(session) @@ -3711,7 +3711,7 @@ def test_should_raise_422_for_invalid_task_instance_state(self, payload, expecte ), ], ) - @mock.patch("airflow.models.dag.DAG.set_task_instance_state") + @mock.patch("airflow.serialization.serialized_objects.SerializedDAG.set_task_instance_state") def test_update_mask_should_call_mocked_api( self, mock_set_ti_state, @@ -4023,7 +4023,7 @@ def test_set_note_should_respond_200_when_note_is_empty(self, test_client, sessi assert response_ti["note"] == new_note_value _check_task_instance_note(session, response_ti["id"], {"content": new_note_value, "user_id": "test"}) - @mock.patch("airflow.models.dag.DAG.set_task_instance_state") + @mock.patch("airflow.serialization.serialized_objects.SerializedDAG.set_task_instance_state") def test_should_raise_409_for_updating_same_task_instance_state( self, mock_set_ti_state, test_client, session ): @@ -4049,7 +4049,7 @@ class TestPatchTaskInstanceDryRun(TestTaskInstanceEndpoint): RUN_ID = "TEST_DAG_RUN_ID" DAG_DISPLAY_NAME = "example_python_operator" - @mock.patch("airflow.models.dag.DAG.set_task_instance_state") + @mock.patch("airflow.serialization.serialized_objects.SerializedDAG.set_task_instance_state") def test_should_call_mocked_api(self, mock_set_ti_state, test_client, session): self.create_task_instances(session) @@ -4407,7 +4407,7 @@ def test_should_raise_422_for_invalid_task_instance_state(self, payload, expecte ), ], ) - @mock.patch("airflow.models.dag.DAG.set_task_instance_state") + @mock.patch("airflow.serialization.serialized_objects.SerializedDAG.set_task_instance_state") def test_update_mask_should_call_mocked_api( self, mock_set_ti_state, @@ -4440,7 +4440,7 @@ def test_update_mask_should_call_mocked_api( assert response.json() == expected_json assert mock_set_ti_state.call_count == set_ti_state_call_count - @mock.patch("airflow.models.dag.DAG.set_task_instance_state") + @mock.patch("airflow.serialization.serialized_objects.SerializedDAG.set_task_instance_state") def 
test_should_return_empty_list_for_updating_same_task_instance_state( self, mock_set_ti_state, test_client, session ): diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_tasks.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_tasks.py index 06de068baf91d..95b279af5905e 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_tasks.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_tasks.py @@ -20,15 +20,13 @@ import pytest -from airflow import settings from airflow.api_fastapi.common.dagbag import dag_bag_from_app -from airflow.models.dag import DAG from airflow.models.dagbag import DBDagBag -from airflow.models.dagbundle import DagBundleModel -from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.standard.operators.empty import EmptyOperator +from airflow.sdk import DAG from airflow.sdk.definitions._internal.expandinput import EXPAND_INPUT_EMPTY +from tests_common.test_utils.dag import sync_dag_to_db, sync_dags_to_db from tests_common.test_utils.db import ( clear_db_dag_bundles, clear_db_dags, @@ -57,6 +55,7 @@ def create_dags(self, test_client): with DAG(self.dag_id, schedule=None, start_date=self.task1_start_date, doc_md="details") as dag: task1 = EmptyOperator(task_id=self.task_id, params={"foo": "bar"}) task2 = EmptyOperator(task_id=self.task_id2, start_date=self.task2_start_date) + task1 >> task2 with DAG(self.mapped_dag_id, schedule=None, start_date=self.task1_start_date) as mapped_dag: EmptyOperator(task_id=self.task_id3) @@ -67,21 +66,10 @@ def create_dags(self, test_client): with DAG(self.unscheduled_dag_id, start_date=None, schedule=None) as unscheduled_dag: task4 = EmptyOperator(task_id=self.unscheduled_task_id1, params={"is_unscheduled": True}) task5 = EmptyOperator(task_id=self.unscheduled_task_id2, params={"is_unscheduled": True}) + task4 >> task5 - task1 >> task2 - task4 >> task5 - session = settings.Session() - bundle_name = "testing" - dag_bundle = DagBundleModel(name=bundle_name) - session.merge(dag_bundle) - session.commit() - DAG.bulk_write_to_db(bundle_name, None, [dag, mapped_dag, unscheduled_dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) - SerializedDagModel.write_dag(mapped_dag, bundle_name=bundle_name) - SerializedDagModel.write_dag(unscheduled_dag, bundle_name=bundle_name) - dag_bag = DBDagBag() - - test_client.app.dependency_overrides[dag_bag_from_app] = lambda: dag_bag + sync_dags_to_db([dag, mapped_dag, unscheduled_dag]) + test_client.app.dependency_overrides[dag_bag_from_app] = DBDagBag @staticmethod def clear_db(): @@ -245,11 +233,9 @@ def test_should_respond_200_serialized(self, test_client, testing_dag_bundle): with DAG(self.dag_id, schedule=None, start_date=self.task1_start_date, doc_md="details") as dag: task1 = EmptyOperator(task_id=self.task_id, params={"foo": "bar"}) task2 = EmptyOperator(task_id=self.task_id2, start_date=self.task2_start_date) - task1 >> task2 - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + + sync_dag_to_db(dag) dag_bag = DBDagBag() test_client.app.dependency_overrides[dag_bag_from_app] = lambda: dag_bag diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py index e4817ec7ea777..38ca7f48c73ac 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py +++ 
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py @@ -21,22 +21,22 @@ import pytest -from airflow import DAG from airflow._shared.timezones import timezone from airflow.api_fastapi.core_api.datamodels.xcom import XComCreateBody from airflow.models.dag_version import DagVersion from airflow.models.dagbundle import DagBundleModel from airflow.models.dagrun import DagRun -from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance from airflow.models.xcom import XComModel from airflow.providers.standard.operators.empty import EmptyOperator +from airflow.sdk import DAG from airflow.sdk.bases.xcom import BaseXCom from airflow.sdk.execution_time.xcom import resolve_xcom_backend from airflow.utils.session import provide_session from airflow.utils.types import DagRunType from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.db import clear_db_dag_bundles, clear_db_dags, clear_db_runs, clear_db_xcom from tests_common.test_utils.logs import check_last_log @@ -404,8 +404,7 @@ def _create_xcom_entries(self, dag_id, run_id, logical_date, task_id, mapped_ti= session.flush() dag = DAG(dag_id=dag_id) - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + sync_dag_to_db(dag) dagrun = DagRun( dag_id=dag_id, run_id=run_id, diff --git a/airflow-core/tests/unit/cli/commands/test_dag_command.py b/airflow-core/tests/unit/cli/commands/test_dag_command.py index cd02de96d4961..73317d9b44c4f 100644 --- a/airflow-core/tests/unit/cli/commands/test_dag_command.py +++ b/airflow-core/tests/unit/cli/commands/test_dag_command.py @@ -47,6 +47,7 @@ from airflow.utils.types import DagRunType from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.db import ( clear_db_dags, clear_db_import_errors, @@ -55,11 +56,6 @@ ) from unit.models import TEST_DAGS_FOLDER -try: - from airflow.sdk import BaseOperator -except ImportError: - from airflow.models.baseoperator import BaseOperator # type: ignore[no-redef] - DEFAULT_DATE = timezone.make_aware(datetime(2015, 1, 1), timezone=timezone.utc) if pendulum.__version__.startswith("3"): DEFAULT_DATE_REPR = DEFAULT_DATE.isoformat(sep=" ") @@ -372,7 +368,8 @@ def test_list_dags_none_get_dagmodel(self, mock_get_dagmodel, stdout_capture): def test_dagbag_dag_col(self, session): dagbag = DBDagBag() dag_details = dag_command._get_dagbag_dag_details( - dagbag.get_latest_version_of_dag("tutorial_dag", session=session) + dagbag.get_latest_version_of_dag("tutorial_dag", session=session), + session=session, ) assert sorted(dag_details) == sorted(dag_command.DAG_DETAIL_FIELDS) @@ -656,7 +653,7 @@ def test_dag_state(self): is None ) - @mock.patch("airflow.cli.commands.dag_command.get_dag") + @mock.patch("airflow.cli.commands.dag_command.get_bagged_dag") def test_dag_test(self, mock_get_dag): cli_args = self.parser.parse_args(["dags", "test", "example_bash_operator", DEFAULT_DATE.isoformat()]) dag_command.dag_test(cli_args) @@ -674,7 +671,7 @@ def test_dag_test(self, mock_get_dag): ] ) - @mock.patch("airflow.cli.commands.dag_command.get_dag") + @mock.patch("airflow.cli.commands.dag_command.get_bagged_dag") def test_dag_test_fail_raise_error(self, mock_get_dag): logical_date_str = DEFAULT_DATE.isoformat() mock_get_dag.return_value.test.return_value = DagRun( @@ -684,7 +681,7 @@ def 
test_dag_test_fail_raise_error(self, mock_get_dag): with pytest.raises(SystemExit, match=r"DagRun failed"): dag_command.dag_test(cli_args) - @mock.patch("airflow.cli.commands.dag_command.get_dag") + @mock.patch("airflow.cli.commands.dag_command.get_bagged_dag") def test_dag_test_no_logical_date(self, mock_get_dag, time_machine): now = pendulum.now() time_machine.move_to(now, tick=False) @@ -707,7 +704,7 @@ def test_dag_test_no_logical_date(self, mock_get_dag, time_machine): ] ) - @mock.patch("airflow.cli.commands.dag_command.get_dag") + @mock.patch("airflow.cli.commands.dag_command.get_bagged_dag") def test_dag_test_conf(self, mock_get_dag): cli_args = self.parser.parse_args( [ @@ -735,7 +732,7 @@ def test_dag_test_conf(self, mock_get_dag): ) @mock.patch("airflow.cli.commands.dag_command.render_dag", return_value=MagicMock(source="SOURCE")) - @mock.patch("airflow.cli.commands.dag_command.get_dag") + @mock.patch("airflow.cli.commands.dag_command.get_bagged_dag") def test_dag_test_show_dag(self, mock_get_dag, mock_render_dag, stdout_capture): mock_get_dag.return_value.test.return_value.run_id = "__test_dag_test_show_dag_fake_dag_run_run_id__" @@ -841,8 +838,8 @@ def test_dag_test_with_both_bundle_and_dagfile_path(self, mock_dagbag, configure include_examples=False, ) - @mock.patch("airflow.models.dag._get_or_create_dagrun") - def test_dag_test_with_custom_timetable(self, mock__get_or_create_dagrun): + @mock.patch("airflow.models.dagrun.get_or_create_dagrun") + def test_dag_test_with_custom_timetable(self, mock_get_or_create_dagrun): """ when calling `dags test` on dag with custom timetable, the DagRun object should be created with data_intervals. @@ -854,11 +851,11 @@ def test_dag_test_with_custom_timetable(self, mock__get_or_create_dagrun): with mock.patch.object(AfterWorkdayTimetable, "get_next_workday", return_value=DEFAULT_DATE): dag_command.dag_test(cli_args) - assert "data_interval" in mock__get_or_create_dagrun.call_args.kwargs + assert "data_interval" in mock_get_or_create_dagrun.call_args.kwargs - @mock.patch("airflow.models.dag._get_or_create_dagrun") + @mock.patch("airflow.models.dagrun.get_or_create_dagrun") def test_dag_with_parsing_context( - self, mock__get_or_create_dagrun, testing_dag_bundle, configure_testing_dag_bundle + self, mock_get_or_create_dagrun, testing_dag_bundle, configure_testing_dag_bundle ): """ airflow parsing context should be set when calling `dags test`. @@ -874,7 +871,7 @@ def test_dag_with_parsing_context( dag_command.dag_test(cli_args) # if dag_parsing_context is not set, this DAG will only have 1 task - assert len(mock__get_or_create_dagrun.call_args[1]["dag"].task_ids) == 2 + assert len(mock_get_or_create_dagrun.call_args[1]["dag"].task_ids) == 2 def test_dag_test_run_inline_trigger(self, dag_maker): now = timezone.utcnow() @@ -918,6 +915,7 @@ def execute(self, context, event=None): task_two = two(task_one) op = MyOp(task_id="abc", tfield=task_two) task_two >> op + sync_dag_to_db(dag) dr = dag.test() trigger_arg = mock_run.call_args_list[0].args[0] @@ -951,15 +949,18 @@ def test_dag_test_with_mark_success(self, mock__execute_task): @conf_vars({("core", "load_examples"): "false"}) def test_get_dag_excludes_examples_with_bundle(self, configure_testing_dag_bundle): """Test that example DAGs are excluded when bundle names are passed.""" - from airflow.utils.cli import get_dag + try: + from airflow.utils.cli import get_bagged_dag + except ImportError: # Prior to Airflow 3.1.0. 
+ from airflow.utils.cli import get_dag as get_bagged_dag # type: ignore with configure_testing_dag_bundle(TEST_DAGS_FOLDER / "test_sensor.py"): # example DAG should not be found since include_examples=False with pytest.raises(AirflowException, match="could not be found"): - get_dag(bundle_names=["testing"], dag_id="example_simplest_dag") + get_bagged_dag(bundle_names=["testing"], dag_id="example_simplest_dag") # However, "test_sensor.py" should exist - dag = get_dag(bundle_names=["testing"], dag_id="test_sensor") + dag = get_bagged_dag(bundle_names=["testing"], dag_id="test_sensor") assert dag.dag_id == "test_sensor" diff --git a/airflow-core/tests/unit/cli/commands/test_task_command.py b/airflow-core/tests/unit/cli/commands/test_task_command.py index 23743306b39ff..bbe597d0b01c9 100644 --- a/airflow-core/tests/unit/cli/commands/test_task_command.py +++ b/airflow-core/tests/unit/cli/commands/test_task_command.py @@ -43,12 +43,13 @@ from airflow.models.dagbag import DBDagBag from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.standard.operators.bash import BashOperator -from airflow.serialization.serialized_objects import SerializedDAG +from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG from airflow.utils.session import create_session from airflow.utils.state import State, TaskInstanceState from airflow.utils.types import DagRunTriggeredByType, DagRunType from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.db import clear_db_runs, parse_and_sync_to_db if TYPE_CHECKING: @@ -319,8 +320,8 @@ def test_mapped_task_render_with_template(self, dag_maker): {% endfor %} """ commands = [templated_command, "echo 1"] - BashOperator.partial(task_id="some_command").expand(bash_command=commands) + sync_dag_to_db(dag) with redirect_stdout(io.StringIO()) as stdout: task_command.task_render( @@ -351,11 +352,12 @@ def test_task_state(self): def test_task_states_for_dag_run(self): dag2 = DagBag().dags["example_python_operator"] + lazy_deserialized_dag2 = LazyDeserializedDAG.from_dag(dag2) - SerializedDagModel.write_dag(dag2, bundle_name="testing") - task2 = dag2.get_task(task_id="print_the_context") + SerializedDagModel.write_dag(lazy_deserialized_dag2, bundle_name="testing") - dag2 = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag2)) + task2 = dag2.get_task(task_id="print_the_context") + dag2 = SerializedDAG.from_dict(lazy_deserialized_dag2.data) default_date2 = timezone.datetime(2016, 1, 9) dag2.clear() diff --git a/airflow-core/tests/unit/dag_processing/test_collection.py b/airflow-core/tests/unit/dag_processing/test_collection.py index 924f7757b33e1..ae3716438bd3b 100644 --- a/airflow-core/tests/unit/dag_processing/test_collection.py +++ b/airflow-core/tests/unit/dag_processing/test_collection.py @@ -48,13 +48,12 @@ DagScheduleAssetNameReference, DagScheduleAssetUriReference, ) -from airflow.models.dag import DAG from airflow.models.errors import ParseImportError from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger -from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetWatcher -from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG +from airflow.sdk import DAG, Asset, AssetAlias, AssetWatcher +from airflow.serialization.serialized_objects import 
LazyDeserializedDAG from tests_common.test_utils.config import conf_vars from tests_common.test_utils.db import ( @@ -136,17 +135,19 @@ def per_test(self) -> Generator: ], ) @pytest.mark.usefixtures("testing_dag_bundle") - def test_add_asset_trigger_references(self, session, is_active, is_paused, expected_num_triggers): + def test_add_asset_trigger_references( + self, dag_maker, session, is_active, is_paused, expected_num_triggers + ): classpath, kwargs = TimeDeltaTrigger(timedelta(seconds=0)).serialize() asset = Asset( "test_add_asset_trigger_references_asset", watchers=[AssetWatcher(name="test", trigger={"classpath": classpath, "kwargs": kwargs})], ) - with DAG(dag_id="test_add_asset_trigger_references_dag", schedule=[asset]) as dag: + with dag_maker(dag_id="test_add_asset_trigger_references_dag", schedule=[asset]) as dag: EmptyOperator(task_id="mytask") - dags = {dag.dag_id: dag} + dags = {dag.dag_id: LazyDeserializedDAG.from_dag(dag)} orm_dags = DagModelOperation(dags, "testing", None).add_dags(session=session) # Simulate dag unpause and deletion. @@ -189,7 +190,7 @@ def test_add_dag_asset_name_uri_references(self, dag_maker, session, schedule, m with dag_maker(dag_id="test", schedule=schedule, session=session) as dag: pass - op = AssetModelOperation.collect({dag.dag_id: dag}) + op = AssetModelOperation.collect({dag.dag_id: LazyDeserializedDAG.from_dag(dag)}) op.add_dag_asset_name_uri_references(session=session) assert session.execute(select(*columns)).all() == expected @@ -198,14 +199,14 @@ def test_change_asset_property_sync_group(self, dag_maker, session): with dag_maker(schedule=[asset]) as dag: EmptyOperator(task_id="mytask") - asset_op = AssetModelOperation.collect({dag.dag_id: dag}) + asset_op = AssetModelOperation.collect({dag.dag_id: LazyDeserializedDAG.from_dag(dag)}) orm_assets = asset_op.sync_assets(session=session) assert len(orm_assets) == 1 assert next(iter(orm_assets.values())).group == "old_group" # Parser should pick up group change. asset.group = "new_group" - asset_op = AssetModelOperation.collect({dag.dag_id: dag}) + asset_op = AssetModelOperation.collect({dag.dag_id: LazyDeserializedDAG.from_dag(dag)}) orm_assets = asset_op.sync_assets(session=session) assert len(orm_assets) == 1 assert next(iter(orm_assets.values())).group == "new_group" @@ -215,14 +216,14 @@ def test_change_asset_property_sync_extra(self, dag_maker, session): with dag_maker(schedule=asset) as dag: EmptyOperator(task_id="mytask") - asset_op = AssetModelOperation.collect({dag.dag_id: dag}) + asset_op = AssetModelOperation.collect({dag.dag_id: LazyDeserializedDAG.from_dag(dag)}) orm_assets = asset_op.sync_assets(session=session) assert len(orm_assets) == 1 assert next(iter(orm_assets.values())).extra == {"foo": "old"} # Parser should pick up extra change. 
asset.extra = {"foo": "new"} - asset_op = AssetModelOperation.collect({dag.dag_id: dag}) + asset_op = AssetModelOperation.collect({dag.dag_id: LazyDeserializedDAG.from_dag(dag)}) orm_assets = asset_op.sync_assets(session=session) assert len(orm_assets) == 1 assert next(iter(orm_assets.values())).extra == {"foo": "new"} @@ -232,14 +233,14 @@ def test_change_asset_alias_property_sync_group(self, dag_maker, session): with dag_maker(schedule=alias) as dag: EmptyOperator(task_id="mytask") - asset_op = AssetModelOperation.collect({dag.dag_id: dag}) + asset_op = AssetModelOperation.collect({dag.dag_id: LazyDeserializedDAG.from_dag(dag)}) orm_aliases = asset_op.sync_asset_aliases(session=session) assert len(orm_aliases) == 1 assert next(iter(orm_aliases.values())).group == "old_group" # Parser should pick up group change. alias.group = "new_group" - asset_op = AssetModelOperation.collect({dag.dag_id: dag}) + asset_op = AssetModelOperation.collect({dag.dag_id: LazyDeserializedDAG.from_dag(dag)}) orm_aliases = asset_op.sync_asset_aliases(session=session) assert len(orm_aliases) == 1 assert next(iter(orm_aliases.values())).group == "new_group" @@ -265,7 +266,7 @@ def test_add_asset_activate(self, dag_maker, session): with dag_maker(schedule=[asset]) as dag: EmptyOperator(task_id="mytask") - asset_op = AssetModelOperation.collect({dag.dag_id: dag}) + asset_op = AssetModelOperation.collect({dag.dag_id: LazyDeserializedDAG.from_dag(dag)}) orm_assets = asset_op.sync_assets(session=session) session.flush() assert len(orm_assets) == 1 @@ -285,7 +286,7 @@ def test_add_asset_activate_already_exists(self, dag_maker, session): with dag_maker(schedule=[asset]) as dag: EmptyOperator(task_id="mytask") - asset_op = AssetModelOperation.collect({dag.dag_id: dag}) + asset_op = AssetModelOperation.collect({dag.dag_id: LazyDeserializedDAG.from_dag(dag)}) orm_assets = asset_op.sync_assets(session=session) session.flush() assert len(orm_assets) == 1 @@ -311,7 +312,7 @@ def test_add_asset_activate_conflict(self, dag_maker, session, existing_assets): with dag_maker(schedule=[asset]) as dag: EmptyOperator(task_id="mytask") - asset_op = AssetModelOperation.collect({dag.dag_id: dag}) + asset_op = AssetModelOperation.collect({dag.dag_id: LazyDeserializedDAG.from_dag(dag)}) orm_assets = asset_op.sync_assets(session=session) session.flush() assert len(orm_assets) == 1 @@ -321,6 +322,7 @@ def test_add_asset_activate_conflict(self, dag_maker, session, existing_assets): assert orm_assets["myasset", "file://myasset/"].active is None, "should not activate due to conflict" +@pytest.mark.need_serialized_dag @pytest.mark.db_test class TestUpdateDagParsingResults: """Tests centred around the ``update_dag_parsing_results_in_db`` function.""" @@ -341,10 +343,6 @@ def _dag_import_error_listener(self): get_listener_manager().clear() dag_import_error_listener.clear() - def dag_to_lazy_serdag(self, dag: DAG) -> LazyDeserializedDAG: - ser_dict = SerializedDAG.to_dict(dag) - return LazyDeserializedDAG(data=ser_dict) - @mark_fab_auth_manager_test @pytest.mark.usefixtures("clean_db") # sync_perms in fab has bad session commit hygiene def test_sync_perms_syncs_dag_specific_perms_on_update( @@ -388,7 +386,7 @@ def _sync_to_db(): serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar() @patch.object(SerializedDagModel, "write_dag") - @patch("airflow.models.dag.DAG.bulk_write_to_db") + @patch("airflow.serialization.serialized_objects.SerializedDAG.bulk_write_to_db") def test_sync_to_db_is_retried( self, 
mock_bulk_write_to_db, mock_s10n_write_dag, testing_dag_bundle, session ): @@ -444,8 +442,14 @@ def test_serialized_dags_are_written_to_db_on_sync(self, testing_dag_bundle, ses assert serialized_dags_count == 0 dag = DAG(dag_id="test") - - update_dag_parsing_results_in_db("testing", None, [dag], dict(), set(), session) + update_dag_parsing_results_in_db( + bundle_name="testing", + bundle_version=None, + dags=[LazyDeserializedDAG.from_dag(dag)], + import_errors={}, + warnings=set(), + session=session, + ) new_serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar() assert new_serialized_dags_count == 1 @@ -496,6 +500,7 @@ def test_import_error_persist_for_invalid_access_control_role( self, mock_full_path, monkeypatch, + dag_maker, session, time_machine, dag_import_error_listener, @@ -513,12 +518,13 @@ def test_import_error_persist_for_invalid_access_control_role( time_machine.move_to(tz.datetime(2020, 1, 5, 0, 0, 0), tick=False) # create a DAG and assign it a non-exist role. - dag = DAG( + with dag_maker( dag_id="test_nonexist_access_control", access_control={ "non_existing_role": {"can_edit", "can_read", "can_delete"}, }, - ) + ) as dag: + pass dag.fileloc = "test_nonexist_access_control.py" dag.relative_fileloc = "test_nonexist_access_control.py" mock_full_path.return_value = "test_nonexist_access_control.py" @@ -679,7 +685,14 @@ def test_remove_error_clears_import_error(self, testing_dag_bundle, session): dag.relative_fileloc = filename import_errors = {} - update_dag_parsing_results_in_db(bundle_name, None, [dag], import_errors, set(), session) + update_dag_parsing_results_in_db( + bundle_name, + bundle_version=None, + dags=[LazyDeserializedDAG.from_dag(dag)], + import_errors=dict.fromkeys(import_errors), + warnings=set(), + session=session, + ) dag_model: DagModel = session.get(DagModel, (dag.dag_id,)) assert dag_model.has_import_errors is False @@ -709,17 +722,36 @@ def test_remove_error_updates_loaded_dag_model(self, testing_dag_bundle, session ) ) session.flush() + dag = DAG(dag_id="test") dag.fileloc = filename dag.relative_fileloc = filename + lazy_deserialized_dags = [LazyDeserializedDAG.from_dag(dag)] + import_errors = {(bundle_name, filename): "Some error"} - update_dag_parsing_results_in_db(bundle_name, None, [dag], import_errors, set(), session) + update_dag_parsing_results_in_db( + bundle_name, + bundle_version=None, + dags=lazy_deserialized_dags, + import_errors=import_errors, + warnings=set(), + session=session, + ) dag_model = session.get(DagModel, (dag.dag_id,)) assert dag_model.has_import_errors is True + import_errors = {} - update_dag_parsing_results_in_db(bundle_name, None, [dag], import_errors, set(), session) + update_dag_parsing_results_in_db( + bundle_name, + bundle_version=None, + dags=lazy_deserialized_dags, + import_errors=import_errors, + warnings=set(), + session=session, + ) assert dag_model.has_import_errors is False + @pytest.mark.need_serialized_dag(False) @pytest.mark.parametrize( ("attrs", "expected"), [ @@ -770,7 +802,7 @@ def test_remove_error_updates_loaded_dag_model(self, testing_dag_bundle, session @pytest.mark.usefixtures("clean_db") def test_dagmodel_properties(self, attrs, expected, session, time_machine, testing_dag_bundle, dag_maker): """Test that properties on the dag model are correctly set when dealing with a LazySerializedDag""" - dt = tz.datetime(2020, 1, 5, 0, 0, 0) + dt = tz.datetime(2020, 1, 6, 0, 0, 0) time_machine.move_to(dt, tick=False) tasks = attrs.pop("_tasks_", None) @@ -787,8 +819,15 @@ def 
test_dagmodel_properties(self, attrs, expected, session, time_machine, testi } dr1 = DagRun(logical_date=dt, run_id="test_run_id_1", **dr_kwargs, start_date=dt) session.add(dr1) - session.commit() - update_dag_parsing_results_in_db("testing", None, [self.dag_to_lazy_serdag(dag)], {}, set(), session) + update_dag_parsing_results_in_db( + bundle_name="testing", + bundle_version=None, + dags=[LazyDeserializedDAG.from_dag(dag)], + import_errors={}, + warnings=set(), + session=session, + ) + session.flush() orm_dag = session.get(DagModel, ("dag",)) @@ -803,13 +842,13 @@ def test_dagmodel_properties(self, attrs, expected, session, time_machine, testi def test_existing_dag_is_paused_upon_creation(self, testing_dag_bundle, session, dag_maker): with dag_maker("dag_paused", schedule=None) as dag: ... - update_dag_parsing_results_in_db("testing", None, [self.dag_to_lazy_serdag(dag)], {}, set(), session) + update_dag_parsing_results_in_db("testing", None, [dag], {}, set(), session) orm_dag = session.get(DagModel, ("dag_paused",)) assert orm_dag.is_paused is False with dag_maker("dag_paused", schedule=None, is_paused_upon_creation=True) as dag: ... - update_dag_parsing_results_in_db("testing", None, [self.dag_to_lazy_serdag(dag)], {}, set(), session) + update_dag_parsing_results_in_db("testing", None, [dag], {}, set(), session) # Since the dag existed before, it should not follow the pause flag upon creation orm_dag = session.get(DagModel, ("dag_paused",)) assert orm_dag.is_paused is False @@ -817,7 +856,7 @@ def test_existing_dag_is_paused_upon_creation(self, testing_dag_bundle, session, def test_bundle_name_and_version_are_stored(self, testing_dag_bundle, session, dag_maker): with dag_maker("mydag", schedule=None) as dag: ... - update_dag_parsing_results_in_db("testing", "1.0", [self.dag_to_lazy_serdag(dag)], {}, set(), session) + update_dag_parsing_results_in_db("testing", "1.0", [dag], {}, set(), session) orm_dag = session.get(DagModel, "mydag") assert orm_dag.bundle_name == "testing" assert orm_dag.bundle_version == "1.0" @@ -825,7 +864,7 @@ def test_bundle_name_and_version_are_stored(self, testing_dag_bundle, session, d def test_max_active_tasks_explicit_value_is_used(self, testing_dag_bundle, session, dag_maker): with dag_maker("dag_max_tasks", schedule=None, max_active_tasks=5) as dag: ... - update_dag_parsing_results_in_db("testing", None, [self.dag_to_lazy_serdag(dag)], {}, set(), session) + update_dag_parsing_results_in_db("testing", None, [dag], {}, set(), session) orm_dag = session.get(DagModel, "dag_max_tasks") assert orm_dag.max_active_tasks == 5 @@ -834,16 +873,14 @@ def test_max_active_tasks_defaults_from_conf_when_none(self, testing_dag_bundle, with conf_vars({("core", "max_active_tasks_per_dag"): "7"}): with dag_maker("dag_max_tasks_default", schedule=None) as dag: ... - update_dag_parsing_results_in_db( - "testing", None, [self.dag_to_lazy_serdag(dag)], {}, set(), session - ) + update_dag_parsing_results_in_db("testing", None, [dag], {}, set(), session) orm_dag = session.get(DagModel, "dag_max_tasks_default") assert orm_dag.max_active_tasks == 7 def test_max_active_runs_explicit_value_is_used(self, testing_dag_bundle, session, dag_maker): with dag_maker("dag_max_runs", schedule=None, max_active_runs=3) as dag: ... 
- update_dag_parsing_results_in_db("testing", None, [self.dag_to_lazy_serdag(dag)], {}, set(), session) + update_dag_parsing_results_in_db("testing", None, [dag], {}, set(), session) orm_dag = session.get(DagModel, "dag_max_runs") assert orm_dag.max_active_runs == 3 @@ -851,9 +888,7 @@ def test_max_active_runs_defaults_from_conf_when_none(self, testing_dag_bundle, with conf_vars({("core", "max_active_runs_per_dag"): "4"}): with dag_maker("dag_max_runs_default", schedule=None) as dag: ... - update_dag_parsing_results_in_db( - "testing", None, [self.dag_to_lazy_serdag(dag)], {}, set(), session - ) + update_dag_parsing_results_in_db("testing", None, [dag], {}, set(), session) orm_dag = session.get(DagModel, "dag_max_runs_default") assert orm_dag.max_active_runs == 4 @@ -862,7 +897,7 @@ def test_max_consecutive_failed_dag_runs_explicit_value_is_used( ): with dag_maker("dag_max_failed_runs", schedule=None, max_consecutive_failed_dag_runs=2) as dag: ... - update_dag_parsing_results_in_db("testing", None, [self.dag_to_lazy_serdag(dag)], {}, set(), session) + update_dag_parsing_results_in_db("testing", None, [dag], {}, set(), session) orm_dag = session.get(DagModel, "dag_max_failed_runs") assert orm_dag.max_consecutive_failed_dag_runs == 2 @@ -872,8 +907,6 @@ def test_max_consecutive_failed_dag_runs_defaults_from_conf_when_none( with conf_vars({("core", "max_consecutive_failed_dag_runs_per_dag"): "6"}): with dag_maker("dag_max_failed_runs_default", schedule=None) as dag: ... - update_dag_parsing_results_in_db( - "testing", None, [self.dag_to_lazy_serdag(dag)], {}, set(), session - ) + update_dag_parsing_results_in_db("testing", None, [dag], {}, set(), session) orm_dag = session.get(DagModel, "dag_max_failed_runs_default") assert orm_dag.max_consecutive_failed_dag_runs == 6 diff --git a/airflow-core/tests/unit/dag_processing/test_manager.py b/airflow-core/tests/unit/dag_processing/test_manager.py index b2f76e8114dea..982cfb27fd0c5 100644 --- a/airflow-core/tests/unit/dag_processing/test_manager.py +++ b/airflow-core/tests/unit/dag_processing/test_manager.py @@ -50,10 +50,9 @@ DagFileStat, ) from airflow.dag_processing.processor import DagFileProcessorProcess -from airflow.models import DAG, DagBag, DagModel, DbCallbackRequest +from airflow.models import DagBag, DagModel, DbCallbackRequest from airflow.models.asset import TaskOutletAssetReference from airflow.models.dag_version import DagVersion -from airflow.models.dagbag import DBDagBag from airflow.models.dagbundle import DagBundleModel from airflow.models.dagcode import DagCode from airflow.models.serialized_dag import SerializedDagModel @@ -62,6 +61,7 @@ from tests_common.test_utils.compat import ParseImportError from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.db import ( clear_db_assets, clear_db_callbacks, @@ -442,7 +442,8 @@ def test_parsing_requests_only_bundles_being_parsed(self, testing_dag_bundle): assert len(parsing_request_after) == 1 assert parsing_request_after[0].relative_fileloc == "file_x.py" - def test_scan_stale_dags(self, testing_dag_bundle): + @pytest.mark.usefixtures("testing_dag_bundle") + def test_scan_stale_dags(self, session): """ Ensure that DAGs are marked inactive when the file is parsed but the DagModel.last_parsed_time is not updated. 
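Note on the hunks below: in this file and in the scheduler and DAG-model tests further down, the old `DAG.bulk_write_to_db(...)` + `SerializedDagModel.write_dag(...)` pair is replaced by the `sync_dag_to_db` helper imported from `tests_common.test_utils.dag`. The helper body is not part of this diff; the sketch below is only a guess at its shape, inferred from the call sites it replaces, and the exact signature is an assumption. The plural `sync_dags_to_db` used elsewhere in this diff presumably does the same per DAG and returns the scheduler-side objects in order.

# Hypothetical sketch of a sync_dag_to_db-style helper; the real implementation
# lives in tests_common.test_utils.dag and may differ from this guess.
from airflow.models.serialized_dag import SerializedDagModel
from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
from airflow.utils.session import NEW_SESSION, provide_session


@provide_session
def sync_dag_to_db(dag, bundle_name="testing", bundle_version=None, session=NEW_SESSION):
    """Persist an SDK DAG and hand back the scheduler-side representation."""
    # Scheduler code works on the serialized form, so round-trip the SDK DAG first.
    scheduler_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))
    # Create/refresh the DagModel row (what DAG.bulk_write_to_db used to do at these call sites).
    SerializedDAG.bulk_write_to_db(bundle_name, bundle_version, [scheduler_dag], session=session)
    # Store the serialized DAG so SerializedDagModel / DagVersion history exists.
    SerializedDagModel.write_dag(
        LazyDeserializedDAG.from_dag(dag), bundle_name=bundle_name, session=session
    )
    session.flush()
    return scheduler_dag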
@@ -466,57 +467,54 @@ def test_scan_stale_dags(self, testing_dag_bundle): bundle_path=test_dag_path.bundle_path, ) - with create_session() as session: - # Add stale DAG to the DB - dag = dagbag.get_dag("test_example_bash_operator") - dag.last_parsed_time = timezone.utcnow() - DAG.bulk_write_to_db("testing", None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name="testing") - - # Add DAG to the file_parsing_stats - stat = DagFileStat( - num_dags=1, - import_errors=0, - last_finish_time=timezone.utcnow() + timedelta(hours=1), - last_duration=1, - run_count=1, - last_num_of_db_queries=1, - ) - manager._files = [test_dag_path] - manager._file_stats[test_dag_path] = stat - - active_dag_count = ( - session.query(func.count(DagModel.dag_id)) - .filter( - ~DagModel.is_stale, - DagModel.relative_fileloc == str(test_dag_path.rel_path), - DagModel.bundle_name == test_dag_path.bundle_name, - ) - .scalar() + # Add stale DAG to the DB + dag = dagbag.get_dag("test_example_bash_operator") + sync_dag_to_db(dag, session=session) + + # Add DAG to the file_parsing_stats + stat = DagFileStat( + num_dags=1, + import_errors=0, + last_finish_time=timezone.utcnow() + timedelta(hours=1), + last_duration=1, + run_count=1, + last_num_of_db_queries=1, + ) + manager._files = [test_dag_path] + manager._file_stats[test_dag_path] = stat + + active_dag_count = ( + session.query(func.count(DagModel.dag_id)) + .filter( + ~DagModel.is_stale, + DagModel.relative_fileloc == str(test_dag_path.rel_path), + DagModel.bundle_name == test_dag_path.bundle_name, ) - assert active_dag_count == 1 + .scalar() + ) + assert active_dag_count == 1 - manager._scan_stale_dags() + manager._scan_stale_dags() - active_dag_count = ( - session.query(func.count(DagModel.dag_id)) - .filter( - ~DagModel.is_stale, - DagModel.relative_fileloc == str(test_dag_path.rel_path), - DagModel.bundle_name == test_dag_path.bundle_name, - ) - .scalar() + active_dag_count = ( + session.query(func.count(DagModel.dag_id)) + .filter( + ~DagModel.is_stale, + DagModel.relative_fileloc == str(test_dag_path.rel_path), + DagModel.bundle_name == test_dag_path.bundle_name, ) - assert active_dag_count == 0 + .scalar() + ) + assert active_dag_count == 0 - serialized_dag_count = ( - session.query(func.count(SerializedDagModel.dag_id)) - .filter(SerializedDagModel.dag_id == dag.dag_id) - .scalar() - ) - # Deactivating the DagModel should not delete the SerializedDagModel - # SerializedDagModel gives history about Dags - assert serialized_dag_count == 1 + serialized_dag_count = ( + session.query(func.count(SerializedDagModel.dag_id)) + .filter(SerializedDagModel.dag_id == dag.dag_id) + .scalar() + ) + # Deactivating the DagModel should not delete the SerializedDagModel + # SerializedDagModel gives history about Dags + assert serialized_dag_count == 1 def test_kill_timed_out_processors_kill(self): manager = DagFileProcessorManager(max_runs=1, processor_timeout=5) @@ -664,15 +662,15 @@ def test_send_file_processing_statsd_timing( any_order=True, ) + @pytest.mark.usefixtures("testing_dag_bundle") def test_refresh_dags_dir_doesnt_delete_zipped_dags( - self, tmp_path, testing_dag_bundle, configure_testing_dag_bundle, test_zip_path + self, tmp_path, session, configure_testing_dag_bundle, test_zip_path ): """Test DagFileProcessorManager._refresh_dag_dir method""" dagbag = DagBag(dag_folder=tmp_path, include_examples=False) dagbag.process_file(test_zip_path) dag = dagbag.get_dag("test_zip_dag") - DAG.bulk_write_to_db("testing", None, [dag]) - SerializedDagModel.write_dag(dag, 
bundle_name="testing") + sync_dag_to_db(dag) with configure_testing_dag_bundle(test_zip_path): manager = DagFileProcessorManager(max_runs=1) @@ -683,7 +681,7 @@ def test_refresh_dags_dir_doesnt_delete_zipped_dags( # assert code not deleted assert DagCode.has_dag(dag.dag_id) # assert dag still active - assert not dag.get_is_stale() + assert session.get(DagModel, dag.dag_id).is_stale is False @pytest.mark.usefixtures("testing_dag_bundle") def test_refresh_dags_dir_deactivates_deleted_zipped_dags( @@ -738,11 +736,10 @@ def test_deactivate_deleted_dags(self, dag_maker, session): manager = DagFileProcessorManager(max_runs=1) manager.deactivate_deleted_dags("dag_maker", active_files) - dagbag = DBDagBag() # The DAG from test_dag1.py is still active - assert dagbag.get_latest_version_of_dag("test_dag1", session=session).get_is_active() is True + assert session.get(DagModel, "test_dag1").is_stale is False # and the DAG from test_dag2.py is deactivated - assert dagbag.get_latest_version_of_dag("test_dag2", session=session).get_is_active() is False + assert session.get(DagModel, "test_dag2").is_stale is True @conf_vars({("core", "load_examples"): "False"}) def test_fetch_callbacks_from_database(self, configure_testing_dag_bundle): diff --git a/airflow-core/tests/unit/datasets/test_dataset.py b/airflow-core/tests/unit/datasets/test_dataset.py index 4ad0946549937..dbbe537047c26 100644 --- a/airflow-core/tests/unit/datasets/test_dataset.py +++ b/airflow-core/tests/unit/datasets/test_dataset.py @@ -24,51 +24,56 @@ @pytest.mark.parametrize( "module_path, attr_name, expected_value, warning_message", ( - ( + pytest.param( "airflow", "Dataset", "airflow.sdk.definitions.asset.Asset", ( - "Import 'Dataset' directly from the airflow module is deprecated and will be removed in the future. " - "Please import it from 'airflow.sdk.definitions.asset.Asset'." + "Import 'Dataset' directly from the airflow module is deprecated " + "and will be removed in the future. Please import it from 'airflow.sdk.Asset'." ), + id="airflow.Dataset", ), - ( + pytest.param( "airflow.datasets", "Dataset", "airflow.sdk.definitions.asset.Asset", ( - "Import 'airflow.dataset.Dataset' is deprecated and " - "will be removed in the Airflow 3.2. Please import it from 'airflow.sdk.definitions.asset.Asset'." + "Import 'airflow.datasets.Dataset' is deprecated and " + "will be removed in Airflow 3.2. Please import it from 'airflow.sdk.Asset'." ), + id="airflow.datasets.Dataset", ), - ( + pytest.param( "airflow.datasets", "DatasetAlias", "airflow.sdk.definitions.asset.AssetAlias", ( - "Import 'airflow.dataset.DatasetAlias' is deprecated and " - "will be removed in the Airflow 3.2. Please import it from 'airflow.sdk.definitions.asset.AssetAlias'." + "Import 'airflow.datasets.DatasetAlias' is deprecated and " + "will be removed in Airflow 3.2. Please import it from 'airflow.sdk.AssetAlias'." ), + id="airflow.datasets.DatasetAlias", ), - ( + pytest.param( "airflow.datasets", "expand_alias_to_datasets", "airflow.models.asset.expand_alias_to_assets", ( - "Import 'airflow.dataset.expand_alias_to_datasets' is deprecated and " - "will be removed in the Airflow 3.2. Please import it from 'airflow.models.asset.expand_alias_to_assets'." + "Import 'airflow.datasets.expand_alias_to_datasets' is deprecated and will be removed " + "in Airflow 3.2. Please import it from 'airflow.models.asset.expand_alias_to_assets'." 
), + id="airflow.datasets.expand_alias_to_datasets", ), - ( + pytest.param( "airflow.datasets.metadata", "Metadata", "airflow.sdk.definitions.asset.metadata.Metadata", ( - "Import from the airflow.dataset module is deprecated and " - "will be removed in the Airflow 3.2. Please import it from " - "'airflow.sdk.definitions.asset.metadata'." + "Import from the airflow.datasets.metadata module is deprecated and " + "will be removed in Airflow 3.2. Please import it from " + "'airflow.sdk'." ), + id="airflow.datasets.metadata.Metadata", ), ), ) diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index adf3fddc41af9..d52d3eaed1547 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -53,7 +53,7 @@ from airflow.jobs.scheduler_job_runner import SchedulerJobRunner from airflow.models.asset import AssetActive, AssetAliasModel, AssetDagRunQueue, AssetEvent, AssetModel from airflow.models.backfill import Backfill, _create_backfill -from airflow.models.dag import DAG, DagModel +from airflow.models.dag import DagModel, get_last_dagrun, infer_automated_data_interval from airflow.models.dag_version import DagVersion from airflow.models.dagbag import DagBag, sync_bag_to_db from airflow.models.dagrun import DagRun @@ -68,7 +68,7 @@ from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.triggers.temporal import DateTimeTrigger -from airflow.sdk import Asset, AssetAlias, AssetWatcher, task +from airflow.sdk import DAG, Asset, AssetAlias, AssetWatcher, task from airflow.sdk.definitions.deadline import AsyncCallback from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG from airflow.timetables.base import DataInterval @@ -82,6 +82,7 @@ from tests_common.pytest_plugin import AIRFLOW_ROOT_PATH from tests_common.test_utils.asserts import assert_queries_count from tests_common.test_utils.config import conf_vars, env_vars +from tests_common.test_utils.dag import create_scheduler_dag, sync_dag_to_db, sync_dags_to_db from tests_common.test_utils.db import ( clear_db_assets, clear_db_backfills, @@ -148,7 +149,7 @@ def _loader_mock(mock_executors): @pytest.fixture def create_dagrun(session): def _create_dagrun( - dag: DAG, + dag: SerializedDAG, *, logical_date: datetime.datetime, data_interval: DataInterval, @@ -1682,7 +1683,7 @@ def test_critical_section_enqueue_task_instances(self, task1_exec, task2_exec, d # because before scheduler._execute_task_instances would only # check the num tasks once so if max_active_tasks was 3, # we could execute arbitrarily many tasks in the second run - with dag_maker(dag_id=dag_id, max_active_tasks=3, session=session) as dag: + with dag_maker(dag_id=dag_id, max_active_tasks=3, session=session): task1 = EmptyOperator(task_id="t1", executor=task1_exec) task2 = EmptyOperator(task_id="t2", executor=task2_exec) task3 = EmptyOperator(task_id="t3", executor=task2_exec) @@ -1705,14 +1706,16 @@ def test_critical_section_enqueue_task_instances(self, task1_exec, task2_exec, d dr1_ti4.state = State.SCHEDULED session.flush() + def _count_tis(states): + return session.scalar( + select(func.count(TaskInstance.task_id)).where( + TaskInstance.dag_id == dag_id, + TaskInstance.state.in_(states), + ) + ) + assert dr1.state == State.RUNNING - num_tis = DAG.get_num_task_instances( - dag_id=dag_id, - task_ids=dag.task_ids, - 
states=[State.RUNNING], - session=session, - ) - assert num_tis == 3 + assert _count_tis([TaskInstanceState.RUNNING]) == 3 # create second dag run dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED, session=session) @@ -1733,14 +1736,7 @@ def test_critical_section_enqueue_task_instances(self, task1_exec, task2_exec, d assert num_queued == 3 # check that max_active_tasks is respected - - num_tis = DAG.get_num_task_instances( - dag_id=dag_id, - task_ids=dag.task_ids, - states=[State.RUNNING, State.QUEUED], - session=session, - ) - assert num_tis == 6 + assert _count_tis([TaskInstanceState.RUNNING, TaskInstanceState.QUEUED]) == 6 # this doesn't really tell us anything since we set these values manually, but hey dr1_counter = Counter(x.state for x in dr1.get_task_instances(session=session)) @@ -2719,7 +2715,7 @@ def test_dagrun_timeout_fails_run_and_update_next_dagrun(self, dag_maker): @pytest.mark.parametrize( "state, expected_callback_msg", [(State.SUCCESS, "success"), (State.FAILED, "task_failure")] ) - def test_dagrun_callbacks_are_called(self, state, expected_callback_msg, dag_maker): + def test_dagrun_callbacks_are_called(self, state, expected_callback_msg, dag_maker, session): """ Test if DagRun is successful, and if Success callbacks is defined, it is sent to DagFileProcessor. """ @@ -2727,15 +2723,14 @@ def test_dagrun_callbacks_are_called(self, state, expected_callback_msg, dag_mak dag_id="test_dagrun_callbacks_are_called", on_success_callback=lambda x: print("success"), on_failure_callback=lambda x: print("failed"), + session=session, ) as dag: EmptyOperator(task_id="dummy") scheduler_job = Job(executor=self.null_exec) self.job_runner = SchedulerJobRunner(job=scheduler_job) - session = settings.Session() dr = dag_maker.create_dagrun() - ti = dr.get_task_instance("dummy", session) ti.set_state(state, session) session.flush() @@ -2749,8 +2744,8 @@ def test_dagrun_callbacks_are_called(self, state, expected_callback_msg, dag_mak is_failure_callback=bool(state == State.FAILED), run_id=dr.run_id, msg=expected_callback_msg, - bundle_name=dag.get_bundle_name(), - bundle_version=dag.get_bundle_version(), + bundle_name="dag_maker", + bundle_version=None, context_from_server=DagRunContext( dag_run=dr, last_ti=ti, @@ -2803,6 +2798,7 @@ def test_dagrun_timeout_callbacks_are_stored_in_database(self, dag_maker, sessio dag_id="test_dagrun_timeout_callbacks_are_stored_in_database", on_failure_callback=lambda x: print("failed"), dagrun_timeout=timedelta(hours=1), + session=session, ) as dag: EmptyOperator(task_id="empty") @@ -2829,8 +2825,8 @@ def test_dagrun_timeout_callbacks_are_stored_in_database(self, dag_maker, sessio is_failure_callback=True, run_id=dr.run_id, msg="timed_out", - bundle_name=dag.get_bundle_name(), - bundle_version=dag.get_bundle_version(), + bundle_name="dag_maker", + bundle_version=None, context_from_server=DagRunContext( dag_run=dr, last_ti=dr.get_last_ti(dag, session), @@ -3078,8 +3074,6 @@ def test_dagrun_root_after_dagrun_unfinished(self, mock_executor, testing_dag_bu dagbag = DagBag(TEST_DAG_FOLDER, include_examples=False) sync_bag_to_db(dagbag, "testing", None) dag_id = "test_dagrun_states_root_future" - dag = dagbag.get_dag(dag_id) - DAG.bulk_write_to_db("testing", None, [dag]) scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=2) @@ -3099,16 +3093,15 @@ def test_scheduler_start_date(self, testing_dag_bundle): with create_session() as session: dag_id = "test_start_date_scheduling" dag = dagbag.get_dag(dag_id) - 
dag.clear() - assert dag.start_date > datetime.datetime.now(timezone.utc) # Deactivate other dags in this file other_dag = dagbag.get_dag("test_task_start_date_scheduling") other_dag.is_paused_upon_creation = True - DAG.bulk_write_to_db("testing", None, [dag, other_dag]) - session.flush() - SerializedDagModel.write_dag(dag, bundle_name="testing", session=session) - SerializedDagModel.write_dag(other_dag, bundle_name="testing", session=session) + + scheduler_dag, _ = sync_dags_to_db([dag, other_dag]) + scheduler_dag.clear() + assert scheduler_dag.start_date > datetime.datetime.now(timezone.utc) + scheduler_job = Job(executor=self.null_exec) self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=1) run_job(scheduler_job, execute_callable=self.job_runner._execute) @@ -3123,7 +3116,7 @@ def test_scheduler_start_date(self, testing_dag_bundle): # That behavior still exists, but now it will only do so if after the # start date data_interval_end = DEFAULT_DATE + timedelta(days=1) - dag.create_dagrun( + scheduler_dag.create_dagrun( state="success", triggered_by=DagRunTriggeredByType.TIMETABLE, run_id="abc123", @@ -3237,7 +3230,7 @@ def test_scheduler_multiprocessing(self): dag = dagbag.get_dag(dag_id) if not dag: raise ValueError(f"could not find dag {dag_id}") - dag.clear() + create_scheduler_dag(dag).clear() scheduler_job = Job(executor=self.null_exec) self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=1) @@ -3325,32 +3318,11 @@ def _create_dagruns(): # As tasks require 2 slots, only 3 can fit into 6 available assert len(task_instances_list) == 3 + @pytest.mark.need_serialized_dag def test_scheduler_keeps_scheduling_pool_full(self, dag_maker, mock_executor): """ Test task instances in a pool that isn't full keep getting scheduled even when a pool is full. """ - with dag_maker( - dag_id="test_scheduler_keeps_scheduling_pool_full_d1", - start_date=DEFAULT_DATE, - ): - BashOperator( - task_id="test_scheduler_keeps_scheduling_pool_full_t1", - pool="test_scheduler_keeps_scheduling_pool_full_p1", - bash_command="echo hi", - ) - dag_d1 = dag_maker.dag - - with dag_maker( - dag_id="test_scheduler_keeps_scheduling_pool_full_d2", - start_date=DEFAULT_DATE, - ): - BashOperator( - task_id="test_scheduler_keeps_scheduling_pool_full_t2", - pool="test_scheduler_keeps_scheduling_pool_full_p2", - bash_command="echo hi", - ) - dag_d2 = dag_maker.dag - session = settings.Session() pool_p1 = Pool(pool="test_scheduler_keeps_scheduling_pool_full_p1", slots=1, include_deferred=False) pool_p2 = Pool(pool="test_scheduler_keeps_scheduling_pool_full_p2", slots=10, include_deferred=False) @@ -3361,7 +3333,10 @@ def test_scheduler_keeps_scheduling_pool_full(self, dag_maker, mock_executor): scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job) - def _create_dagruns(dag: DAG): + # We'll use this to create 30 dagruns for each DAG. + # To increase the chances the TIs from the "full" pool will get + # retrieved first, we schedule all TIs from the first dag first. + def _create_dagruns(dag: SerializedDAG): next_info = dag.next_dagrun_info(None) assert next_info is not None for i in range(30): @@ -3379,11 +3354,27 @@ def _create_dagruns(dag: DAG): if next_info is None: break - # Create 30 dagruns for each DAG. - # To increase the chances the TIs from the "full" pool will get retrieved first, we schedule all - # TIs from the first dag first. 
+ with dag_maker( + dag_id="test_scheduler_keeps_scheduling_pool_full_d1", + start_date=DEFAULT_DATE, + ) as dag_d1: + BashOperator( + task_id="test_scheduler_keeps_scheduling_pool_full_t1", + pool="test_scheduler_keeps_scheduling_pool_full_p1", + bash_command="echo hi", + ) for dr in _create_dagruns(dag_d1): self.job_runner._schedule_dag_run(dr, session) + + with dag_maker( + dag_id="test_scheduler_keeps_scheduling_pool_full_d2", + start_date=DEFAULT_DATE, + ) as dag_d2: + BashOperator( + task_id="test_scheduler_keeps_scheduling_pool_full_t2", + pool="test_scheduler_keeps_scheduling_pool_full_p2", + bash_command="echo hi", + ) for dr in _create_dagruns(dag_d2): self.job_runner._schedule_dag_run(dr, session) @@ -3540,7 +3531,7 @@ def test_verify_integrity_if_dag_changed(self, dag_maker): session = settings.Session() orm_dag = dag_maker.dag_model assert orm_dag is not None - SerializedDagModel.write_dag(dag, bundle_name="testing") + SerializedDagModel.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name="testing") assert orm_dag.bundle_version is None scheduler_job = Job() @@ -3558,7 +3549,9 @@ def test_verify_integrity_if_dag_changed(self, dag_maker): # Now let's say the DAG got updated (new task got added) BashOperator(task_id="bash_task_1", dag=dag, bash_command="echo hi") - SerializedDagModel.write_dag(dag=dag, bundle_name="testing", session=session) + SerializedDagModel.write_dag( + LazyDeserializedDAG.from_dag(dag), bundle_name="testing", session=session + ) dag_version_2 = DagVersion.get_latest_version(dr.dag_id, session=session) assert dag_version_2 != dag_version_1 @@ -3621,8 +3614,8 @@ def test_verify_integrity_not_called_for_versioned_bundles(self, dag_maker, sess dag_version_1 = DagVersion.get_latest_version(dr.dag_id, session=session) # Now let's say the DAG got updated (new task got added) - BashOperator(task_id="bash_task_1", dag=dag, bash_command="echo hi") - SerializedDagModel.write_dag(dag=dag, bundle_name="dag_maker", session=session) + BashOperator(task_id="bash_task_1", dag=dag_maker.dag, bash_command="echo hi") + sync_dag_to_db(dag_maker.dag, bundle_name="dag_maker", session=session) session.commit() dag_version_2 = DagVersion.get_latest_version(dr.dag_id, session=session) assert dag_version_2 != dag_version_1 @@ -3644,13 +3637,13 @@ def test_retry_still_in_executor(self, dag_maker, session): dag_id="test_retry_still_in_executor", schedule="@once", session=session, - ): + ) as dag: dag_task1 = BashOperator( task_id="test_retry_handling_op", bash_command="exit 1", retries=1, ) - dag_maker.dag_model.calculate_dagrun_date_fields(dag_maker.dag, None) + dag_maker.dag_model.calculate_dagrun_date_fields(dag, None) @provide_session def do_schedule(session): @@ -3709,42 +3702,6 @@ def run_with_error(ti, ignore_ti_state=False): assert ti.try_number == 1 assert ti.state == State.SUCCESS - def test_dag_get_active_runs(self, dag_maker): - """ - Test to check that a DAG returns its active runs - """ - now = timezone.utcnow() - six_hours_ago_to_the_hour = (now - datetime.timedelta(hours=6)).replace( - minute=0, second=0, microsecond=0 - ) - - start_date = six_hours_ago_to_the_hour - dag_name1 = "get_active_runs_test" - - default_args = {"depends_on_past": False, "start_date": start_date} - with dag_maker(dag_name1, schedule="* * * * *", max_active_runs=1, default_args=default_args) as dag1: - run_this_1 = EmptyOperator(task_id="run_this_1") - run_this_2 = EmptyOperator(task_id="run_this_2") - run_this_2.set_upstream(run_this_1) - run_this_3 = EmptyOperator(task_id="run_this_3") - 
run_this_3.set_upstream(run_this_2) - - dr = dag_maker.create_dagrun() - - # We had better get a dag run - assert dr is not None - - logical_date = dr.logical_date - - running_dates = dag1.get_active_runs() - - try: - running_date = running_dates[0] - except Exception: - running_date = "Except" - - assert logical_date == running_date, "Running Date must match Execution Date" - def test_adopt_or_reset_orphaned_tasks_nothing(self): """Try with nothing.""" scheduler_job = Job() @@ -4121,12 +4078,10 @@ def test_create_dag_runs(self, dag_maker): with create_session() as session: self.job_runner._create_dag_runs([dag_model], session) - dr = session.query(DagRun).filter(DagRun.dag_id == dag.dag_id).first() - # Assert dr state is queued + dr = session.query(DagRun).filter(DagRun.dag_id == dag.dag_id).one() assert dr.state == State.QUEUED assert dr.start_date is None - - assert dag.get_last_dagrun().creating_job_id == scheduler_job.id + assert dr.creating_job_id == scheduler_job.id @pytest.mark.need_serialized_dag def test_create_dag_runs_assets(self, session, dag_maker): @@ -4350,26 +4305,25 @@ def test_no_create_dag_runs_when_dag_disabled(self, session, dag_maker, disable, @time_machine.travel(DEFAULT_DATE + datetime.timedelta(days=1, seconds=9), tick=False) @mock.patch("airflow.jobs.scheduler_job_runner.Stats.timing") - def test_start_dagruns(self, stats_timing, dag_maker): + def test_start_dagruns(self, stats_timing, dag_maker, session): """ Test that _start_dagrun: - moves runs to RUNNING State - emit the right DagRun metrics """ + from airflow.models.dag import get_last_dagrun + with dag_maker(dag_id="test_start_dag_runs") as dag: - EmptyOperator( - task_id="dummy", - ) + EmptyOperator(task_id="dummy") dag_model = dag_maker.dag_model scheduler_job = Job(executor=self.null_exec) self.job_runner = SchedulerJobRunner(job=scheduler_job) - with create_session() as session: - self.job_runner._create_dag_runs([dag_model], session) - self.job_runner._start_queued_dagruns(session) + self.job_runner._create_dag_runs([dag_model], session) + self.job_runner._start_queued_dagruns(session) dr = session.query(DagRun).filter(DagRun.dag_id == dag.dag_id).first() # Assert dr state is running @@ -4389,7 +4343,7 @@ def test_start_dagruns(self, stats_timing, dag_maker): ] ) - assert dag.get_last_dagrun().creating_job_id == scheduler_job.id + assert get_last_dagrun(dag.dag_id, session).creating_job_id == scheduler_job.id def test_extra_operator_links_not_loaded_in_scheduler_loop(self, dag_maker): """ @@ -4498,7 +4452,7 @@ def test_bulk_write_to_db_external_trigger_dont_skip_scheduled_run(self, dag_mak assert dag_model.next_dagrun_data_interval_end == DEFAULT_DATE + timedelta(minutes=2) # Trigger the Dag externally - data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE) + data_interval = infer_automated_data_interval(dag.timetable, DEFAULT_LOGICAL_DATE) dr = dag.create_dagrun( run_id="test", state=DagRunState.RUNNING, @@ -4510,8 +4464,6 @@ def test_bulk_write_to_db_external_trigger_dont_skip_scheduled_run(self, dag_mak triggered_by=DagRunTriggeredByType.TEST, ) assert dr is not None - # Run DAG.bulk_write_to_db -- this is run when in DagFileProcessor.process_file - DAG.bulk_write_to_db("testing", None, [dag], session=session) # Test that 'dag_model.next_dagrun' has not been changed because of newly created external # triggered DagRun. 
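Several scheduler-job tests in this file switch from DAG instance methods to the module-level helpers now imported from airflow.models.dag (see the import hunk at the top of this file). A minimal before/after sketch of that migration, assuming a `dag` object, an open `session`, and `DEFAULT_LOGICAL_DATE` are in scope as in the surrounding tests:

# Old style (methods on the DAG object), as removed in the hunks around here:
#   data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)
#   last_run = dag.get_last_dagrun(session)
# New style (module-level functions), as added by this diff:
from airflow.models.dag import get_last_dagrun, infer_automated_data_interval

data_interval = infer_automated_data_interval(dag.timetable, DEFAULT_LOGICAL_DATE)
last_run = get_last_dagrun(dag.dag_id, session)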
@@ -4520,7 +4472,7 @@ def test_bulk_write_to_db_external_trigger_dont_skip_scheduled_run(self, dag_mak assert dag_model.next_dagrun_data_interval_start == DEFAULT_DATE + timedelta(minutes=1) assert dag_model.next_dagrun_data_interval_end == DEFAULT_DATE + timedelta(minutes=2) - def test_scheduler_create_dag_runs_check_existing_run(self, dag_maker): + def test_scheduler_create_dag_runs_check_existing_run(self, dag_maker, session): """ Test that if a dag run exists, scheduler._create_dag_runs does not raise an error. And if a Dag Run does not exist it creates next Dag Run. In both cases the Scheduler @@ -4533,12 +4485,9 @@ def test_scheduler_create_dag_runs_check_existing_run(self, dag_maker): schedule=timedelta(days=1), catchup=True, ) as dag: - EmptyOperator( - task_id="dummy", - ) + EmptyOperator(task_id="dummy") - session = settings.Session() - assert dag.get_last_dagrun(session) is None + assert get_last_dagrun(dag.dag_id, session) is None dag_model = dag_maker.dag_model @@ -4555,7 +4504,7 @@ def test_scheduler_create_dag_runs_check_existing_run(self, dag_maker): ) session.flush() - assert dag.get_last_dagrun(session) == dagrun + assert get_last_dagrun(dag.dag_id, session) == dagrun scheduler_job = Job(executor=self.null_exec) self.job_runner = SchedulerJobRunner(job=scheduler_job) @@ -4584,7 +4533,7 @@ def test_do_schedule_max_active_runs_dag_timed_out(self, dag_maker, session): bash_command=' for((i=1;i<=600;i+=1)); do sleep "$i"; done', ) - data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE) + data_interval = infer_automated_data_interval(dag.timetable, DEFAULT_LOGICAL_DATE) run1 = dag.create_dagrun( run_id="test1", run_type=DagRunType.SCHEDULED, @@ -4802,8 +4751,6 @@ def test_do_schedule_max_active_runs_and_manual_trigger(self, dag_maker, mock_ex session = settings.Session() dag_run = dag_maker.create_dagrun(state=State.QUEUED, session=session) - DAG.bulk_write_to_db("testing", None, [dag], session=session) # Update the date fields - scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job) @@ -5562,7 +5509,7 @@ def test_no_dagruns_would_stuck_in_running(self, dag_maker): task1 = EmptyOperator(task_id="dummy_task") dr1_running = dag_maker.create_dagrun(run_id="dr1_run_1", logical_date=date) - data_interval = dag.infer_automated_data_interval(logical_date) + data_interval = infer_automated_data_interval(dag.timetable, logical_date) dag_maker.create_dagrun( run_id="dr1_run_2", state=State.QUEUED, @@ -5849,7 +5796,7 @@ def test_scheduler_job_add_new_task(self, dag_maker): assert len(tis) == 1 BashOperator(task_id="dummy2", dag=dag, bash_command="echo test") - SerializedDagModel.write_dag(dag=dag, bundle_name="dag_maker", session=session) + sync_dag_to_db(dag_maker.dag, bundle_name="dag_maker", session=session) session.commit() self.job_runner._schedule_dag_run(dr, session) session.expunge_all() @@ -5863,7 +5810,8 @@ def test_scheduler_job_add_new_task(self, dag_maker): tis = dr.get_task_instances(session=session) assert len(tis) == 2 - def test_runs_respected_after_clear(self, dag_maker): + @pytest.mark.need_serialized_dag + def test_runs_respected_after_clear(self, dag_maker, session): """ Test dag after dag.clear, max_active_runs is respected """ @@ -5877,17 +5825,15 @@ def test_runs_respected_after_clear(self, dag_maker): scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job) - session = settings.Session() dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, state=State.QUEUED) dr = 
dag_maker.create_dagrun_after(dr, run_type=DagRunType.SCHEDULED, state=State.QUEUED) dag_maker.create_dagrun_after(dr, run_type=DagRunType.SCHEDULED, state=State.QUEUED) - dag.clear() - + dag.clear(session=session) assert len(DagRun.find(dag_id=dag.dag_id, state=State.QUEUED, session=session)) == 3 - session = settings.Session() self.job_runner._start_queued_dagruns(session) session.flush() + # Assert that only 1 dagrun is active assert len(DagRun.find(dag_id=dag.dag_id, state=State.RUNNING, session=session)) == 1 # Assert that the other two are queued @@ -5953,7 +5899,7 @@ def test_retry_on_db_error_when_update_timeout_triggers(self, dag_maker, testing schedule="@once", max_active_runs=1, session=session, - ) as dag: + ): EmptyOperator(task_id="dummy1") # Mock the db failure within retry times @@ -5978,10 +5924,7 @@ def side_effect(*args, **kwargs): # Create a Task Instance for the task that is allegedly deferred # but past its timeout, and one that is still good. # We don't actually need a linked trigger here; the code doesn't check. - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag=dag, bundle_name=bundle_name) - session.flush() + sync_dag_to_db(dag_maker.dag, session=session) dr1 = dag_maker.create_dagrun() dr2 = dag_maker.create_dagrun( run_id="test2", logical_date=DEFAULT_DATE + datetime.timedelta(seconds=1) @@ -6031,14 +5974,11 @@ def test_find_and_purge_task_instances_without_heartbeats(self, session, create_ dagfile = EXAMPLE_STANDARD_DAGS_FOLDER / "example_branch_operator.py" dagbag = DagBag(dagfile) dag = dagbag.get_dag("example_branch_operator") - dm = LazyDeserializedDAG(data=SerializedDAG.to_dict(dag)) - scheduler_dag = DAG.from_sdk_dag(dag) + scheduler_dag = sync_dag_to_db(dag) - DAG.bulk_write_to_db("testing", None, [dm]) - SerializedDagModel.write_dag(dag=dag, bundle_name="testing") dag_v = DagVersion.get_latest_version(dag.dag_id) - data_interval = scheduler_dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE) + data_interval = infer_automated_data_interval(scheduler_dag.timetable, DEFAULT_LOGICAL_DATE) dag_run = create_dagrun( scheduler_dag, @@ -6102,14 +6042,10 @@ def test_task_instance_heartbeat_timeout_message(self, session, create_dagrun): dagfile = EXAMPLE_STANDARD_DAGS_FOLDER / "example_branch_operator.py" dagbag = DagBag(dagfile) dag = dagbag.get_dag("example_branch_operator") - dm = LazyDeserializedDAG(data=SerializedDAG.to_dict(dag)) - scheduler_dag = DAG.from_sdk_dag(dag) - - DAG.bulk_write_to_db("testing", None, [dm]) - SerializedDagModel.write_dag(dag, bundle_name="testing") + scheduler_dag = sync_dag_to_db(dag, session=session) session.query(Job).delete() - data_interval = scheduler_dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE) + data_interval = infer_automated_data_interval(scheduler_dag.timetable, DEFAULT_LOGICAL_DATE) dag_run = create_dagrun( scheduler_dag, logical_date=DEFAULT_DATE, @@ -6234,7 +6170,7 @@ def test_mapped_dag(self, dag_id, session, testing_dag_bundle): dag = dagbag.get_dag(dag_id) assert dag logical_date = timezone.coerce_datetime(timezone.utcnow() - datetime.timedelta(days=2)) - data_interval = dag.infer_automated_data_interval(logical_date) + data_interval = infer_automated_data_interval(dag.timetable, logical_date) dr = dag.create_dagrun( run_id=f"{dag_id}_1", @@ -6363,8 +6299,8 @@ def test_catchup_works_correctly(self, dag_maker, testing_dag_bundle): self.job_runner._schedule_dag_run(dr, session) session.flush() - dag.catchup = False - 
DAG.bulk_write_to_db("testing", None, [dag]) + dag_maker.dag.catchup = False + dag = sync_dag_to_db(dag_maker.dag, bundle_name="dag_maker", session=session) assert not dag.catchup dm = DagModel.get_dagmodel(dag.dag_id) @@ -6546,6 +6482,7 @@ def test_asset_orphaning_ignore_orphaned_assets(self, dag_maker, session): pytest.param(True, False, None, id="stale-paused"), ], ) + @pytest.mark.need_serialized_dag(False) def test_delete_unreferenced_triggers(self, dag_maker, session, paused, stale, expected_classpath): self.job_runner = SchedulerJobRunner(job=Job()) @@ -6556,7 +6493,7 @@ def test_delete_unreferenced_triggers(self, dag_maker, session, paused, stale, e ) with dag_maker(dag_id="dag", schedule=[asset1], session=session) as dag: EmptyOperator(task_id="task") - dags = {"dag": dag} + dags = {"dag": LazyDeserializedDAG.from_dag(dag)} def _update_references() -> None: asset_op = AssetModelOperation.collect(dags) @@ -6618,8 +6555,7 @@ def test_activate_referenced_assets_with_no_existing_warning(self, session, test asset1_1 = Asset(name=asset1_name, uri="it's duplicate", extra=asset_extra) asset1_2 = Asset(name="it's also a duplicate", uri="s3://bucket/key/1", extra=asset_extra) dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[asset1, asset1_1, asset1_2]) - - DAG.bulk_write_to_db("testing", None, [dag1], session=session) + sync_dag_to_db(dag1, session=session) asset_models = session.scalars(select(AssetModel)).all() assert len(asset_models) == 3 @@ -6645,13 +6581,6 @@ def test_activate_referenced_assets_with_existing_warnings(self, session, testin asset1_name = "asset1" asset_extra = {"foo": "bar"} - session.add_all( - [ - DagWarning(dag_id=dag_id, warning_type="asset conflict", message="will not exist") - for dag_id in dag_ids - ] - ) - asset1 = Asset(name=asset1_name, uri="s3://bucket/key/1", extra=asset_extra) asset1_1 = Asset(name=asset1_name, uri="it's duplicate", extra=asset_extra) asset1_2 = Asset(name=asset1_name, uri="it's duplicate 2", extra=asset_extra) @@ -6659,7 +6588,12 @@ def test_activate_referenced_assets_with_existing_warnings(self, session, testin dag2 = DAG(dag_id=dag_ids[1], start_date=DEFAULT_DATE) dag3 = DAG(dag_id=dag_ids[2], start_date=DEFAULT_DATE, schedule=[asset1_2]) - DAG.bulk_write_to_db("testing", None, [dag1, dag2, dag3], session=session) + sync_dags_to_db([dag1, dag2, dag3], session=session) + session.add_all( + DagWarning(dag_id=dag_id, warning_type="asset conflict", message="will not exist") + for dag_id in dag_ids + ) + session.flush() asset_models = session.scalars(select(AssetModel)).all() @@ -6700,15 +6634,15 @@ def test_activate_referenced_assets_with_multiple_conflict_asset_in_one_dag( asset1_name = "asset1" asset_extra = {"foo": "bar"} - session.add(DagWarning(dag_id=dag_id, warning_type="asset conflict", message="will not exist")) - schedule = [Asset(name=asset1_name, uri="s3://bucket/key/1", extra=asset_extra)] schedule.extend( [Asset(name=asset1_name, uri=f"it's duplicate {i}", extra=asset_extra) for i in range(100)] ) dag1 = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, schedule=schedule) + sync_dag_to_db(dag1, session=session) - DAG.bulk_write_to_db("testing", None, [dag1], session=session) + session.add(DagWarning(dag_id=dag_id, warning_type="asset conflict", message="will not exist")) + session.flush() asset_models = session.scalars(select(AssetModel)).all() @@ -6982,7 +6916,7 @@ def test_execute_queries_count_with_harvested_dags( sync_bag_to_db(dagbag, "testing", None) for i, dag in enumerate(dagbag.dags.values()): - dr = 
dag.create_dagrun( + dr = create_scheduler_dag(dag).create_dagrun( state=State.RUNNING, run_id=f"{DagRunType.MANUAL.value}__{i}", run_after=pendulum.datetime(2025, 1, 1, tz="UTC"), diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index 2c5cc4dd50f24..121088c645eb7 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -54,6 +54,7 @@ from airflow.providers.standard.operators.python import PythonOperator from airflow.providers.standard.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger from airflow.sdk import BaseHook, BaseOperator +from airflow.serialization.serialized_objects import LazyDeserializedDAG from airflow.triggers.base import BaseTrigger, TriggerEvent from airflow.triggers.testing import FailureTrigger, SuccessTrigger from airflow.utils.state import State, TaskInstanceState @@ -116,7 +117,8 @@ def create_trigger_in_db(session, trigger, operator=None): else: operator = BaseOperator(task_id="test_ti", dag=dag) session.add(dag_model) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + + SerializedDagModel.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name=bundle_name) session.add(run) session.add(trigger_orm) session.flush() @@ -425,7 +427,7 @@ async def test_trigger_create_race_condition_38599(session, supervisor_builder, dag = DAG(dag_id="test-dag") dm = DagModel(dag_id="test-dag", bundle_name=bundle_name) session.add(dm) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + SerializedDagModel.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name=bundle_name) dag_run = DagRun(dag.dag_id, run_id="abc", run_type="none", run_after=timezone.utcnow()) dag_version = DagVersion.get_latest_version(dag.dag_id) ti = TaskInstance( diff --git a/airflow-core/tests/unit/models/test_asset.py b/airflow-core/tests/unit/models/test_asset.py index b4412dab1884d..13e0754faad17 100644 --- a/airflow-core/tests/unit/models/test_asset.py +++ b/airflow-core/tests/unit/models/test_asset.py @@ -31,9 +31,11 @@ expand_alias_to_assets, remove_references_to_deleted_dags, ) -from airflow.models.dag import DAG, DagModel +from airflow.models.dag import DagModel from airflow.providers.standard.operators.empty import EmptyOperator -from airflow.sdk.definitions.asset import Asset, AssetAlias +from airflow.sdk import Asset, AssetAlias + +from tests_common.test_utils.dag import sync_dags_to_db pytestmark = pytest.mark.db_test @@ -145,14 +147,10 @@ def test_remove_reference_for_inactive_dag( EmptyOperator(task_id="t2", outlets=Asset(name="a", uri="b://b/")) with dag_maker(dag_id="test2", schedule=schedule, session=session) as dag2: EmptyOperator(task_id="t1", outlets=Asset(name="a", uri="b://b/")) - DAG.bulk_write_to_db( - bundle_name="testing", - bundle_version=None, - dags=[dag1, dag2], - session=session, - ) assert set(session.execute(select_stmt)) == expected_before_clear_1 + sync_dags_to_db([dag1, dag2]) + def _simulate_soft_dag_deletion(dag_id): session.execute(update(DagModel).where(DagModel.dag_id == dag_id).values(is_stale=True)) diff --git a/airflow-core/tests/unit/models/test_cleartasks.py b/airflow-core/tests/unit/models/test_cleartasks.py index 2d70e8a793c98..d0066fffe99f7 100644 --- a/airflow-core/tests/unit/models/test_cleartasks.py +++ b/airflow-core/tests/unit/models/test_cleartasks.py @@ -23,20 +23,20 @@ import pytest from sqlalchemy import select -from airflow.models import DagRun -from airflow.models.dag import DAG from 
airflow.models.dag_version import DagVersion -from airflow.models.serialized_dag import SerializedDagModel +from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance, TaskInstance as TI, clear_task_instances from airflow.models.taskinstancehistory import TaskInstanceHistory from airflow.models.taskreschedule import TaskReschedule from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.sensors.python import PythonSensor +from airflow.serialization.serialized_objects import SerializedDAG from airflow.utils.session import create_session from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.types import DagRunTriggeredByType, DagRunType from tests_common.test_utils import db +from tests_common.test_utils.dag import sync_dag_to_db from unit.models import DEFAULT_DATE pytestmark = [pytest.mark.db_test, pytest.mark.need_serialized_dag] @@ -330,14 +330,7 @@ def test_clear_task_instances_maybe_task_removed(self, delete_tasks, dag_maker, dag.task_dict.clear() dag.task_group.children.clear() assert ti1.max_tries == 2 - SerializedDagModel.write_dag( - dag=dag, - bundle_name="dag_maker", - bundle_version=None, - min_update_interval=0, - session=session, - ) - session.commit() + sync_dag_to_db(dag, session=session) session.refresh(ti1) assert ti0.try_number == 1 assert ti0.max_tries == 0 @@ -631,7 +624,7 @@ def _get_ti(old_ti): ) ) - DAG.clear_dags(dags) + SerializedDAG.clear_dags(dags) session.commit() for i in range(num_of_dags): ti = _get_ti(tis[i]) @@ -650,7 +643,7 @@ def _get_ti(old_ti): assert ti.try_number == 2 assert ti.max_tries == 1 session.commit() - DAG.clear_dags(dags, dry_run=True) + SerializedDAG.clear_dags(dags, dry_run=True) session.commit() for i in range(num_of_dags): ti = _get_ti(tis[i]) @@ -664,7 +657,7 @@ def _get_ti(old_ti): ti_fail.state = State.FAILED session.commit() - DAG.clear_dags(dags, only_failed=True) + SerializedDAG.clear_dags(dags, only_failed=True) for ti in tis: ti = _get_ti(ti) diff --git a/airflow-core/tests/unit/models/test_dag.py b/airflow-core/tests/unit/models/test_dag.py index 867dc5aaaa555..cf3814fa9fdd7 100644 --- a/airflow-core/tests/unit/models/test_dag.py +++ b/airflow-core/tests/unit/models/test_dag.py @@ -38,11 +38,7 @@ from airflow._shared.timezones import timezone from airflow._shared.timezones.timezone import datetime as datetime_tz from airflow.configuration import conf -from airflow.exceptions import ( - AirflowException, - ParamValidationError, - UnknownExecutorException, -) +from airflow.exceptions import AirflowException, ParamValidationError from airflow.models import DagBag from airflow.models.asset import ( AssetAliasModel, @@ -51,16 +47,13 @@ AssetModel, TaskOutletAssetReference, ) -from airflow.models.baseoperator import BaseOperator from airflow.models.dag import ( - DAG, DagModel, DagOwnerAttributes, DagTag, - ExecutorLoader, get_asset_triggered_next_run_info, + get_next_data_interval, ) -from airflow.models.dag_version import DagVersion from airflow.models.dagbundle import DagBundleModel from airflow.models.dagrun import DagRun from airflow.models.serialized_dag import SerializedDagModel @@ -69,12 +62,13 @@ from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator -from airflow.sdk import TaskGroup, setup, task as task_decorator, teardown +from airflow.sdk import DAG, 
BaseOperator, TaskGroup, setup, task as task_decorator, teardown from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext from airflow.sdk.definitions._internal.templater import NativeEnvironment, SandboxedEnvironment from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny from airflow.sdk.definitions.deadline import AsyncCallback, DeadlineAlert, DeadlineReference from airflow.sdk.definitions.param import Param +from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG from airflow.task.trigger_rule import TriggerRule from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable from airflow.timetables.simple import ( @@ -88,6 +82,7 @@ from airflow.utils.types import DagRunTriggeredByType, DagRunType from tests_common.test_utils.asserts import assert_queries_count +from tests_common.test_utils.dag import create_scheduler_dag, sync_dag_to_db from tests_common.test_utils.db import ( clear_db_assets, clear_db_dag_bundles, @@ -170,22 +165,16 @@ def _create_dagrun( start_date: datetime.datetime | None = None, **kwargs, ) -> DagRun: - bundle_name = "testing" - with create_session() as session: - orm_dag_bundle = DagBundleModel(name=bundle_name) - session.merge(orm_dag_bundle) - session.commit() - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) logical_date = timezone.coerce_datetime(logical_date) if not isinstance(data_interval, DataInterval): data_interval = DataInterval(*map(timezone.coerce_datetime, data_interval)) - run_id = dag.timetable.generate_run_id( + scheduler_dag = sync_dag_to_db(dag) + run_id = scheduler_dag.timetable.generate_run_id( run_type=run_type, run_after=logical_date or data_interval.end, data_interval=data_interval, ) - return dag.create_dagrun( + return scheduler_dag.create_dagrun( run_id=run_id, logical_date=logical_date, data_interval=data_interval, @@ -262,230 +251,6 @@ def test_dag_task_custom_weight_strategy(self, cls, expected): ti = dr.get_task_instance(task.task_id) assert ti.priority_weight == expected - def test_get_num_task_instances(self, testing_dag_bundle): - test_dag_id = "test_get_num_task_instances_dag" - test_task_id = "task_1" - - test_dag = DAG(dag_id=test_dag_id, schedule=None, start_date=DEFAULT_DATE) - test_task = EmptyOperator(task_id=test_task_id, dag=test_dag) - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [test_dag]) - SerializedDagModel.write_dag(test_dag, bundle_name=bundle_name) - dag_version = DagVersion.get_latest_version(test_dag_id) - dag_version_id = dag_version.id - - dr1 = _create_dagrun( - test_dag, - run_type=DagRunType.MANUAL, - logical_date=DEFAULT_DATE, - data_interval=(DEFAULT_DATE, DEFAULT_DATE), - ) - dr2 = _create_dagrun( - test_dag, - run_type=DagRunType.MANUAL, - logical_date=DEFAULT_DATE + datetime.timedelta(days=1), - data_interval=( - DEFAULT_DATE + datetime.timedelta(days=1), - DEFAULT_DATE + datetime.timedelta(days=1), - ), - ) - dr3 = _create_dagrun( - test_dag, - run_type=DagRunType.MANUAL, - logical_date=DEFAULT_DATE + datetime.timedelta(days=2), - data_interval=( - DEFAULT_DATE + datetime.timedelta(days=2), - DEFAULT_DATE + datetime.timedelta(days=2), - ), - ) - dr4 = _create_dagrun( - test_dag, - run_type=DagRunType.MANUAL, - logical_date=DEFAULT_DATE + datetime.timedelta(days=3), - data_interval=( - DEFAULT_DATE + datetime.timedelta(days=2), - DEFAULT_DATE + datetime.timedelta(days=2), - ), - ) - ti1 = TI(task=test_task, 
run_id=dr1.run_id, dag_version_id=dag_version_id) - ti1.refresh_from_db() - ti1.state = None - ti2 = TI(task=test_task, run_id=dr2.run_id, dag_version_id=dag_version_id) - ti2.refresh_from_db() - ti2.state = State.RUNNING - ti3 = TI(task=test_task, run_id=dr3.run_id, dag_version_id=dag_version_id) - ti3.refresh_from_db() - ti3.state = State.QUEUED - ti4 = TI(task=test_task, run_id=dr4.run_id, dag_version_id=dag_version_id) - ti4.refresh_from_db() - ti4.state = State.RUNNING - session = settings.Session() - session.merge(ti1) - session.merge(ti2) - session.merge(ti3) - session.merge(ti4) - session.commit() - - assert DAG.get_num_task_instances(test_dag_id, task_ids=["fakename"], session=session) == 0 - assert DAG.get_num_task_instances(test_dag_id, task_ids=[test_task_id], session=session) == 4 - assert ( - DAG.get_num_task_instances(test_dag_id, task_ids=["fakename", test_task_id], session=session) == 4 - ) - assert ( - DAG.get_num_task_instances(test_dag_id, task_ids=[test_task_id], states=[None], session=session) - == 1 - ) - assert ( - DAG.get_num_task_instances( - test_dag_id, task_ids=[test_task_id], states=[State.RUNNING], session=session - ) - == 2 - ) - assert ( - DAG.get_num_task_instances( - test_dag_id, task_ids=[test_task_id], states=[None, State.RUNNING], session=session - ) - == 3 - ) - assert ( - DAG.get_num_task_instances( - test_dag_id, - task_ids=[test_task_id], - states=[None, State.QUEUED, State.RUNNING], - session=session, - ) - == 4 - ) - session.close() - - def test_get_task_instances_before(self, testing_dag_bundle): - BASE_DATE = timezone.datetime(2022, 7, 20, 20) - - test_dag_id = "test_get_task_instances_before" - test_task_id = "the_task" - - test_dag = DAG(dag_id=test_dag_id, schedule=None, start_date=BASE_DATE) - EmptyOperator(task_id=test_task_id, dag=test_dag) - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [test_dag]) - SerializedDagModel.write_dag(test_dag, bundle_name=bundle_name) - - session = settings.Session() - - def dag_run_before(delta_h=0, type=DagRunType.SCHEDULED): - dagrun = test_dag.create_dagrun( - state=State.SUCCESS, - run_type=type, - run_id=f"test_{delta_h}", - logical_date=None, - data_interval=None, - run_after=None, - session=session, - triggered_by=DagRunTriggeredByType.TEST, - ) - dagrun.start_date = dagrun.run_after = dagrun.logical_date = BASE_DATE + timedelta(hours=delta_h) - return dagrun - - dr1 = dag_run_before(delta_h=-1, type=DagRunType.MANUAL) # H19 - dr2 = dag_run_before(delta_h=-2, type=DagRunType.MANUAL) # H18 - dr3 = dag_run_before(delta_h=-3, type=DagRunType.MANUAL) # H17 - dr4 = dag_run_before(delta_h=-4, type=DagRunType.MANUAL) # H16 - dr5 = dag_run_before(delta_h=-5) # H15 - dr6 = dag_run_before(delta_h=-6) # H14 - dr7 = dag_run_before(delta_h=-7) # H13 - dr8 = dag_run_before(delta_h=-8) # H12 - - session.commit() - - REF_DATE = BASE_DATE - - assert set([dr.run_id for dr in [dr1]]) == set( - [ - ti.run_id - for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=1, session=session) - ] - ) - assert set([dr.run_id for dr in [dr1, dr2, dr3]]) == set( - [ - ti.run_id - for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=3, session=session) - ] - ) - assert set([dr.run_id for dr in [dr1, dr2, dr3, dr4, dr5]]) == set( - [ - ti.run_id - for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=5, session=session) - ] - ) - assert set([dr.run_id for dr in [dr1, dr2, dr3, dr4, dr5, dr6, dr7]]) == set( - [ - ti.run_id - for ti in 
test_dag.get_task_instances_before(base_date=REF_DATE, num=7, session=session) - ] - ) - assert set([dr.run_id for dr in [dr1, dr2, dr3, dr4, dr5, dr6, dr7, dr8]]) == set( - [ - ti.run_id - for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=9, session=session) - ] - ) - assert set([dr.run_id for dr in [dr1, dr2, dr3, dr4, dr5, dr6, dr7, dr8]]) == set( - [ - ti.run_id - for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=10, session=session) - ] - ) # stays constrained to available ones - - REF_DATE = BASE_DATE + timedelta(hours=-3.5) - - assert set([dr.run_id for dr in [dr4]]) == set( - [ - ti.run_id - for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=1, session=session) - ] - ) - assert set([dr.run_id for dr in [dr4, dr5, dr6]]) == set( - [ - ti.run_id - for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=3, session=session) - ] - ) - assert set([dr.run_id for dr in [dr4, dr5, dr6, dr7, dr8]]) == set( - [ - ti.run_id - for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=5, session=session) - ] - ) - assert set([dr.run_id for dr in [dr4, dr5, dr6, dr7, dr8]]) == set( - [ - ti.run_id - for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=6, session=session) - ] - ) # stays constrained to available ones - - REF_DATE = BASE_DATE + timedelta(hours=-8) - - assert set([dr.run_id for dr in [dr8]]) == set( - [ - ti.run_id - for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=0, session=session) - ] - ) - assert set([dr.run_id for dr in [dr8]]) == set( - [ - ti.run_id - for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=1, session=session) - ] - ) - assert set([dr.run_id for dr in [dr8]]) == set( - [ - ti.run_id - for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=10, session=session) - ] - ) - - session.close() - def test_user_defined_filters_macros(self): def jinja_udf(name): return f"Hello {name}" @@ -576,10 +341,8 @@ def test_create_dagrun_when_schedule_is_none_and_empty_start_date(self, testing_ # Check that we don't get an AttributeError 'start_date' for self.start_date when schedule is none dag = DAG("dag_with_none_schedule_and_empty_start_date", schedule=None) dag.add_task(BaseOperator(task_id="task_without_start_date")) - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) - dagrun = dag.create_dagrun( + scheduler_dag = sync_dag_to_db(dag) + dagrun = scheduler_dag.create_dagrun( run_id="test", state=State.RUNNING, run_type=DagRunType.MANUAL, @@ -602,7 +365,7 @@ def test_dagtag_repr(self, testing_dag_bundle): session.add(orm_dag) session.flush() - dag.sync_to_db() + sync_dag_to_db(dag) with create_session() as session: assert {"tag-1", "tag-2"} == { repr(t) for t in session.query(DagTag).filter(DagTag.dag_id == "dag-test-dagtag").all() @@ -611,12 +374,16 @@ def test_dagtag_repr(self, testing_dag_bundle): def test_bulk_write_to_db(self, testing_dag_bundle): clear_db_dags() dags = [ - DAG(f"dag-bulk-sync-{i}", schedule=None, start_date=DEFAULT_DATE, tags=["test-dag"]) + SerializedDAG.deserialize_dag( + SerializedDAG.serialize_dag( + DAG(f"dag-bulk-sync-{i}", schedule=None, start_date=DEFAULT_DATE, tags=["test-dag"]) + ) + ) for i in range(4) ] with assert_queries_count(6): - DAG.bulk_write_to_db("testing", None, dags) + SerializedDAG.bulk_write_to_db("testing", None, dags) with create_session() as session: assert {"dag-bulk-sync-0", "dag-bulk-sync-1", 
"dag-bulk-sync-2", "dag-bulk-sync-3"} == { row[0] for row in session.query(DagModel.dag_id).all() @@ -633,14 +400,14 @@ def test_bulk_write_to_db(self, testing_dag_bundle): # Re-sync should do fewer queries with assert_queries_count(9): - DAG.bulk_write_to_db("testing", None, dags) + SerializedDAG.bulk_write_to_db("testing", None, dags) with assert_queries_count(9): - DAG.bulk_write_to_db("testing", None, dags) + SerializedDAG.bulk_write_to_db("testing", None, dags) # Adding tags for dag in dags: dag.tags.add("test-dag2") with assert_queries_count(10): - DAG.bulk_write_to_db("testing", None, dags) + SerializedDAG.bulk_write_to_db("testing", None, dags) with create_session() as session: assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == { row[0] for row in session.query(DagModel.dag_id).all() @@ -659,7 +426,7 @@ def test_bulk_write_to_db(self, testing_dag_bundle): for dag in dags: dag.tags.remove("test-dag") with assert_queries_count(10): - DAG.bulk_write_to_db("testing", None, dags) + SerializedDAG.bulk_write_to_db("testing", None, dags) with create_session() as session: assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == { row[0] for row in session.query(DagModel.dag_id).all() @@ -678,7 +445,7 @@ def test_bulk_write_to_db(self, testing_dag_bundle): for dag in dags: dag.tags = set() with assert_queries_count(10): - DAG.bulk_write_to_db("testing", None, dags) + SerializedDAG.bulk_write_to_db("testing", None, dags) with create_session() as session: assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == { row[0] for row in session.query(DagModel.dag_id).all() @@ -694,12 +461,16 @@ def test_bulk_write_to_db_single_dag(self, testing_dag_bundle): """ clear_db_dags() dags = [ - DAG(f"dag-bulk-sync-{i}", schedule=None, start_date=DEFAULT_DATE, tags=["test-dag"]) + SerializedDAG.deserialize_dag( + SerializedDAG.serialize_dag( + DAG(f"dag-bulk-sync-{i}", schedule=None, start_date=DEFAULT_DATE, tags=["test-dag"]) + ) + ) for i in range(1) ] with assert_queries_count(6): - DAG.bulk_write_to_db("testing", None, dags) + SerializedDAG.bulk_write_to_db("testing", None, dags) with create_session() as session: assert {"dag-bulk-sync-0"} == {row[0] for row in session.query(DagModel.dag_id).all()} assert { @@ -711,9 +482,9 @@ def test_bulk_write_to_db_single_dag(self, testing_dag_bundle): # Re-sync should do fewer queries with assert_queries_count(8): - DAG.bulk_write_to_db("testing", None, dags) + SerializedDAG.bulk_write_to_db("testing", None, dags) with assert_queries_count(8): - DAG.bulk_write_to_db("testing", None, dags) + SerializedDAG.bulk_write_to_db("testing", None, dags) def test_bulk_write_to_db_multiple_dags(self, testing_dag_bundle): """ @@ -721,12 +492,16 @@ def test_bulk_write_to_db_multiple_dags(self, testing_dag_bundle): """ clear_db_dags() dags = [ - DAG(f"dag-bulk-sync-{i}", schedule=None, start_date=DEFAULT_DATE, tags=["test-dag"]) + SerializedDAG.deserialize_dag( + SerializedDAG.serialize_dag( + DAG(f"dag-bulk-sync-{i}", schedule=None, start_date=DEFAULT_DATE, tags=["test-dag"]) + ) + ) for i in range(4) ] with assert_queries_count(6): - DAG.bulk_write_to_db("testing", None, dags) + SerializedDAG.bulk_write_to_db("testing", None, dags) with create_session() as session: assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == { row[0] for row in session.query(DagModel.dag_id).all() @@ -743,19 +518,25 @@ def test_bulk_write_to_db_multiple_dags(self, 
testing_dag_bundle): # Re-sync should do fewer queries with assert_queries_count(9): - DAG.bulk_write_to_db("testing", None, dags) + SerializedDAG.bulk_write_to_db("testing", None, dags) with assert_queries_count(9): - DAG.bulk_write_to_db("testing", None, dags) + SerializedDAG.bulk_write_to_db("testing", None, dags) @pytest.mark.parametrize("interval", [None, "@daily"]) def test_bulk_write_to_db_interval_save_runtime(self, testing_dag_bundle, interval): mock_active_runs_of_dags = mock.MagicMock(side_effect=DagRun.active_runs_of_dags) with mock.patch.object(DagRun, "active_runs_of_dags", mock_active_runs_of_dags): dags_null_timetable = [ - DAG("dag-interval-None", schedule=None, start_date=TEST_DATE), - DAG("dag-interval-test", schedule=interval, start_date=TEST_DATE), + SerializedDAG.deserialize_dag( + SerializedDAG.serialize_dag(DAG("dag-interval-None", schedule=None, start_date=TEST_DATE)) + ), + SerializedDAG.deserialize_dag( + SerializedDAG.serialize_dag( + DAG("dag-interval-test", schedule=interval, start_date=TEST_DATE) + ) + ), ] - DAG.bulk_write_to_db("testing", None, dags_null_timetable, session=settings.Session()) + SerializedDAG.bulk_write_to_db("testing", None, dags_null_timetable) if interval: mock_active_runs_of_dags.assert_called_once() else: @@ -784,13 +565,12 @@ def test_bulk_write_to_db_max_active_runs(self, testing_dag_bundle, state, catch catchup=catchup, ) dag.max_active_runs = 1 - EmptyOperator(task_id="dummy", dag=dag, owner="airflow") - session = settings.Session() - dag.clear() - DAG.bulk_write_to_db("testing", None, [dag], session=session) - SerializedDagModel.write_dag(dag, bundle_name="testing") + scheduler_dag = sync_dag_to_db(dag) + scheduler_dag.clear() + + session = settings.Session() model = session.get(DagModel, dag.dag_id) if expected_next_dagrun is None: @@ -798,7 +578,8 @@ def test_bulk_write_to_db_max_active_runs(self, testing_dag_bundle, state, catch # Instead of comparing exact dates, verify it's relatively recent and not the old start date current_time = timezone.utcnow() - # Verify it's not using the old DEFAULT_DATE from 2016 and is after that since we are picking up present date. + # Verify it's not using the old DEFAULT_DATE from 2016 and is after + # that, since we now pick up the present date.
assert model.next_dagrun.year >= DEFAULT_DATE.year assert model.next_dagrun.month >= DEFAULT_DATE.month @@ -816,7 +597,7 @@ def test_bulk_write_to_db_max_active_runs(self, testing_dag_bundle, state, catch assert model.next_dagrun == expected_next_dagrun assert model.next_dagrun_create_after == expected_next_dagrun + timedelta(days=1) - dr = dag.create_dagrun( + dr = scheduler_dag.create_dagrun( run_id="test", state=state, logical_date=model.next_dagrun, @@ -827,13 +608,13 @@ def test_bulk_write_to_db_max_active_runs(self, testing_dag_bundle, state, catch triggered_by=DagRunTriggeredByType.TEST, ) assert dr is not None - DAG.bulk_write_to_db("testing", None, [dag]) + SerializedDAG.bulk_write_to_db("testing", None, [dag]) model = session.get(DagModel, dag.dag_id) # We signal "at max active runs" by saying this run is never eligible to be created assert model.next_dagrun_create_after is None # test that bulk_write_to_db again doesn't update next_dagrun_create_after - DAG.bulk_write_to_db("testing", None, [dag]) + SerializedDAG.bulk_write_to_db("testing", None, [dag]) model = session.get(DagModel, dag.dag_id) assert model.next_dagrun_create_after is None @@ -842,12 +623,11 @@ def test_bulk_write_to_db_has_import_error(self, testing_dag_bundle): Test that DagModel.has_import_error is set to false if no import errors. """ dag = DAG(dag_id="test_has_import_error", schedule=None, start_date=DEFAULT_DATE) - EmptyOperator(task_id="dummy", dag=dag, owner="airflow") session = settings.Session() - dag.clear() - DAG.bulk_write_to_db("testing", None, [dag], session=session) + scheduler_dag = sync_dag_to_db(dag, session=session) + scheduler_dag.clear() model = session.get(DagModel, dag.dag_id) @@ -861,7 +641,7 @@ def test_bulk_write_to_db_has_import_error(self, testing_dag_bundle): # assert assert model.has_import_errors # parse - DAG.bulk_write_to_db("testing", None, [dag]) + SerializedDAG.bulk_write_to_db("testing", None, [dag]) model = session.get(DagModel, dag.dag_id) # assert that has_import_error is now false @@ -893,8 +673,8 @@ def test_bulk_write_to_db_assets(self, testing_dag_bundle): ) session = settings.Session() - dag1.clear() - DAG.bulk_write_to_db("testing", None, [dag1, dag2], session=session) + SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag1)).clear() + SerializedDAG.bulk_write_to_db("testing", None, [dag1, dag2], session=session) session.commit() stored_assets = {x.uri: x for x in session.query(AssetModel).all()} asset1_orm = stored_assets[a1.uri] @@ -925,7 +705,8 @@ def test_bulk_write_to_db_assets(self, testing_dag_bundle): EmptyOperator(task_id=task_id, dag=dag1, outlets=[a2]) dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE, schedule=None) EmptyOperator(task_id=task_id, dag=dag2) - DAG.bulk_write_to_db("testing", None, [dag1, dag2], session=session) + + SerializedDAG.bulk_write_to_db("testing", None, [dag1, dag2], session=session) session.commit() session.expunge_all() stored_assets = {x.uri: x for x in session.query(AssetModel).all()} @@ -957,8 +738,9 @@ def test_bulk_write_to_db_asset_aliases(self, testing_dag_bundle): EmptyOperator(task_id=task_id, dag=dag1, outlets=[asset_alias_1, asset_alias_2, asset_alias_3]) dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE, schedule=None) EmptyOperator(task_id=task_id, dag=dag2, outlets=[asset_alias_2_2, asset_alias_3]) + session = settings.Session() - DAG.bulk_write_to_db("testing", None, [dag1, dag2], session=session) + SerializedDAG.bulk_write_to_db("testing", None, [dag1, dag2], session=session) session.commit() 
stored_asset_alias_models = {x.name: x for x in session.query(AssetAliasModel).all()} @@ -1007,34 +789,28 @@ def add_failed_dag_run(dag, id, logical_date): op1 = BashOperator(task_id="task", bash_command="exit 1;") dag.add_task(op1) session = settings.Session() - bundle_name = "testing" orm_dag = DagModel( dag_id=dag.dag_id, - bundle_name=bundle_name, + bundle_name="testing", is_stale=False, ) session.add(orm_dag) session.flush() - dag.sync_to_db(session=session) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) - assert not dag.get_is_paused() + scheduler_dag = sync_dag_to_db(dag, session=session) + assert not session.get(DagModel, dag.dag_id).is_paused # dag should be paused after 2 failed dag_runs - add_failed_dag_run( - dag, - "1", - TEST_DATE, - ) - add_failed_dag_run(dag, "2", TEST_DATE + timedelta(days=1)) - assert dag.get_is_paused() + add_failed_dag_run(scheduler_dag, "1", TEST_DATE) + add_failed_dag_run(scheduler_dag, "2", TEST_DATE + timedelta(days=1)) + assert session.get(DagModel, dag.dag_id).is_paused def test_dag_is_deactivated_upon_dagfile_deletion(self, dag_maker): dag_id = "old_existing_dag" with dag_maker(dag_id, schedule=None, is_paused_upon_creation=True) as dag: ... session = settings.Session() - dag.sync_to_db(session=session) + sync_dag_to_db(dag, session=session) orm_dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one() @@ -1068,10 +844,9 @@ def test_schedule_dag_no_previous_runs(self, testing_dag_bundle): dag_id = "test_schedule_dag_no_previous_runs" dag = DAG(dag_id=dag_id, schedule=None) dag.add_task(BaseOperator(task_id="faketastic", owner="Also fake", start_date=TEST_DATE)) - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) - dag_run = dag.create_dagrun( + + scheduler_dag = sync_dag_to_db(dag) + dag_run = scheduler_dag.create_dagrun( run_id="test", run_type=DagRunType.SCHEDULED, logical_date=TEST_DATE, @@ -1107,12 +882,11 @@ def test_dag_handle_callback_crash(self, mock_stats, testing_dag_bundle): ) when = TEST_DATE dag.add_task(BaseOperator(task_id="faketastic", owner="Also fake", start_date=when)) - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + + scheduler_dag = sync_dag_to_db(dag) with create_session() as session: - dag_run = dag.create_dagrun( + dag_run = scheduler_dag.create_dagrun( run_id="test", state=State.RUNNING, logical_date=when, @@ -1146,11 +920,10 @@ def test_dag_handle_callback_with_removed_task(self, dag_maker, session, testing ) as dag: EmptyOperator(task_id="faketastic") task_removed = EmptyOperator(task_id="removed_task") - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + + scheduler_dag = sync_dag_to_db(dag) with create_session() as session: - dag_run = dag.create_dagrun( + dag_run = scheduler_dag.create_dagrun( run_id="test", state=State.RUNNING, logical_date=TEST_DATE, @@ -1166,8 +939,8 @@ def test_dag_handle_callback_with_removed_task(self, dag_maker, session, testing assert dag_run.get_task_instance(task_removed.task_id).state == TaskInstanceState.REMOVED # should not raise any exception - dag_run.handle_dag_callback(dag=dag, success=False) - dag_run.handle_dag_callback(dag=dag, success=True) + dag_run.handle_dag_callback(dag=scheduler_dag, success=False) + dag_run.handle_dag_callback(dag=scheduler_dag, success=True) 
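The tests above, and most of the remaining hunks in this file, replace the old two-step persistence pattern (dag.sync_to_db() or DAG.bulk_write_to_db(...) followed by SerializedDagModel.write_dag(dag, ...)) with the tests_common.test_utils.dag helpers sync_dag_to_db and create_scheduler_dag, which hand back the scheduler-side SerializedDAG that create_dagrun, clear, and next_dagrun_info are now called on. The helper definitions are not part of this diff; the sketch below only illustrates what they plausibly do, by wrapping the serialize/round-trip and bulk-write calls that appear elsewhere in these changes. The signatures, the bundle_name default, and the session handling here are assumptions, not the real implementations.

# Hypothetical sketch of the test helpers referenced throughout this diff; the
# real implementations live in tests_common/test_utils/dag.py and may differ.
from airflow.models.serialized_dag import SerializedDagModel
from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
from airflow.utils.session import NEW_SESSION, provide_session


def create_scheduler_dag(dag) -> SerializedDAG:
    # Round-trip the author-facing airflow.sdk.DAG into the scheduler-side SerializedDAG.
    return SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))


@provide_session
def sync_dag_to_db(dag, bundle_name: str = "testing", session=NEW_SESSION) -> SerializedDAG:
    # Persist the DAG the way the scheduler would, then return the scheduler-side view.
    scheduler_dag = create_scheduler_dag(dag)
    # Create or refresh the DagModel row (replaces the old dag.sync_to_db()).
    SerializedDAG.bulk_write_to_db(bundle_name, None, [scheduler_dag], session=session)
    # Store the serialized representation (replaces SerializedDagModel.write_dag(dag, ...)).
    SerializedDagModel.write_dag(
        LazyDeserializedDAG.from_dag(dag), bundle_name=bundle_name, session=session
    )
    return scheduler_dag

With helpers along these lines, a test that previously synced and serialized by hand can be reduced to scheduler_dag = sync_dag_to_db(dag) followed by scheduler_dag.create_dagrun(...), which is the pattern the hunks before and after this point converge on.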
@pytest.mark.parametrize("catchup,expected_next_dagrun", [(True, DEFAULT_DATE), (False, None)]) def test_next_dagrun_after_fake_scheduled_previous( @@ -1199,7 +972,7 @@ def test_next_dagrun_after_fake_scheduled_previous( state=State.SUCCESS, data_interval=(DEFAULT_DATE, DEFAULT_DATE), ) - dag.sync_to_db() + sync_dag_to_db(dag) with create_session() as session: model = session.get(DagModel, dag.dag_id) @@ -1227,9 +1000,6 @@ def test_schedule_dag_once(self, testing_dag_bundle): assert isinstance(dag.timetable, OnceTimetable) dag.add_task(BaseOperator(task_id="faketastic", owner="Also fake", start_date=TEST_DATE)) - # Sync once to create the DagModel - DAG.bulk_write_to_db("testing", None, [dag]) - _create_dagrun( dag, run_type=DagRunType.SCHEDULED, @@ -1239,7 +1009,7 @@ def test_schedule_dag_once(self, testing_dag_bundle): ) # Then sync again after creating the dag run -- this should update next_dagrun - DAG.bulk_write_to_db("testing", None, [dag]) + SerializedDAG.bulk_write_to_db("testing", None, [dag]) with create_session() as session: model = session.get(DagModel, dag.dag_id) @@ -1281,22 +1051,17 @@ class DAGsubclass(DAG): dag_eq = DAG(test_dag_id, schedule=None, default_args=args) - dag_diff_load_time = DAG(test_dag_id, schedule=None, default_args=args) dag_diff_name = DAG(test_dag_id + "_neq", schedule=None, default_args=args) dag_subclass = DAGsubclass(test_dag_id, schedule=None, default_args=args) dag_subclass_diff_name = DAGsubclass(test_dag_id + "2", schedule=None, default_args=args) - for dag_ in [dag_eq, dag_diff_name, dag_subclass, dag_subclass_diff_name]: - dag_.last_loaded = dag.last_loaded - # test identity equality assert dag == dag # test dag (in)equality based on _comps assert dag_eq == dag assert dag_diff_name != dag - assert dag_diff_load_time != dag # test dag inequality based on type even if _comps happen to match assert dag_subclass != dag @@ -1307,7 +1072,6 @@ class DAGsubclass(DAG): # dags are ordered based on dag_id no matter what the type is assert dag < dag_diff_name - assert dag > dag_diff_load_time assert dag < dag_subclass_diff_name # greater than should have been created automatically by functools @@ -1331,7 +1095,7 @@ def test_get_paused_dag_ids(self, testing_dag_bundle): ) session.add(orm_dag) session.flush() - dag.sync_to_db() + sync_dag_to_db(dag) assert DagModel.get_dagmodel(dag_id) is not None paused_dag_ids = DagModel.get_paused_dag_ids([dag_id]) @@ -1398,9 +1162,8 @@ def test_description_from_timetable(self, timetable, expected_description): def test_create_dagrun_job_id_is_set(self, testing_dag_bundle): job_id = 42 dag = DAG(dag_id="test_create_dagrun_job_id_is_set", schedule=None) - DAG.bulk_write_to_db("testing", None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name="testing") - dr = dag.create_dagrun( + scheduler_dag = sync_dag_to_db(dag) + dr = scheduler_dag.create_dagrun( run_id="test_create_dagrun_job_id_is_set", logical_date=DEFAULT_DATE, data_interval=(DEFAULT_DATE, DEFAULT_DATE), @@ -1515,9 +1278,7 @@ def consumer(value): def test_dag_test_basic(self, testing_dag_bundle): dag = DAG(dag_id="test_local_testing_conn_file", schedule=None, start_date=DEFAULT_DATE) - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + sync_dag_to_db(dag) mock_object = mock.MagicMock() @@ -1528,17 +1289,14 @@ def check_task(): with dag: check_task() - dag.sync_to_db() - SerializedDagModel.write_dag(dag, bundle_name="testing") + sync_dag_to_db(dag) dag.test() 
mock_object.assert_called_once() def test_dag_test_with_dependencies(self, testing_dag_bundle): dag = DAG(dag_id="test_local_testing_conn_file", schedule=None, start_date=DEFAULT_DATE) - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + sync_dag_to_db(dag) mock_object = mock.MagicMock() @task_decorator @@ -1552,9 +1310,7 @@ def check_task_2(my_input): with dag: check_task_2(check_task()) - - dag.sync_to_db() - SerializedDagModel.write_dag(dag, bundle_name="testing") + sync_dag_to_db(dag) dag.test() mock_object.assert_called_with("output of first task") @@ -1581,9 +1337,7 @@ def handle_dag_failure(context): mock_task_object_1 = mock.MagicMock() mock_task_object_2 = mock.MagicMock() - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + sync_dag_to_db(dag) @task_decorator def check_task(): @@ -1597,10 +1351,9 @@ def check_task_2(my_input): with dag: check_task_2(check_task()) - dag.sync_to_db() - SerializedDagModel.write_dag(dag, bundle_name="testing") - dr = dag.test() + sync_dag_to_db(dag) + dr = dag.test() ti1 = dr.get_task_instance("check_task") ti2 = dr.get_task_instance("check_task_2") @@ -1622,9 +1375,7 @@ def test_dag_connection_file(self, tmp_path, testing_dag_bundle): conn_type: postgres """ dag = DAG(dag_id="test_local_testing_conn_file", schedule=None, start_date=DEFAULT_DATE) - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + sync_dag_to_db(dag) @task_decorator def check_task(): @@ -1685,7 +1436,7 @@ def test_clear_dag( task_instance_1.job_id = 123 session.commit() - dag.clear( + create_scheduler_dag(dag).clear( start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=1), session=session, @@ -1699,12 +1450,13 @@ def test_clear_dag( def test_next_dagrun_info_once(self): dag = DAG("test_scheduler_dagrun_once", start_date=timezone.datetime(2015, 1, 1), schedule="@once") + scheduler_dag = create_scheduler_dag(dag) - next_info = dag.next_dagrun_info(None) + next_info = scheduler_dag.next_dagrun_info(None) assert next_info assert next_info.logical_date == timezone.datetime(2015, 1, 1) - next_info = dag.next_dagrun_info(next_info.data_interval) + next_info = scheduler_dag.next_dagrun_info(next_info.data_interval) assert next_info is None def test_next_dagrun_info_catchup(self): @@ -1712,7 +1464,7 @@ def test_next_dagrun_info_catchup(self): Test to check that a DAG with catchup = False only schedules beginning now, not back to the start date """ - def make_dag(dag_id, schedule, start_date, catchup): + def make_dag(dag_id, schedule, start_date, catchup) -> SerializedDAG: default_args = { "owner": "airflow", "depends_on_past": False, @@ -1730,7 +1482,7 @@ def make_dag(dag_id, schedule, start_date, catchup): op3 = EmptyOperator(task_id="t3", dag=dag) op1 >> op2 >> op3 - return dag + return create_scheduler_dag(dag) now = timezone.utcnow() six_hours_ago_to_the_hour = (now - datetime.timedelta(hours=6)).replace( @@ -1787,13 +1539,14 @@ def test_next_dagrun_info_timedelta_schedule_and_catchup_false(self, schedule): schedule=schedule, catchup=False, ) + scheduler_dag = create_scheduler_dag(dag) - next_info = dag.next_dagrun_info(None) + next_info = scheduler_dag.next_dagrun_info(None) assert next_info assert next_info.logical_date == timezone.datetime(2020, 1, 4) # The date to create is in the future, this is handled by 
"DagModel.dags_needing_dagruns" - next_info = dag.next_dagrun_info(next_info.data_interval) + next_info = scheduler_dag.next_dagrun_info(next_info.data_interval) assert next_info assert next_info.logical_date == timezone.datetime(2020, 1, 5) @@ -1809,21 +1562,22 @@ def test_next_dagrun_info_timedelta_schedule_and_catchup_true(self): schedule=timedelta(days=1), catchup=True, ) + scheduler_dag = create_scheduler_dag(dag) - next_info = dag.next_dagrun_info(None) + next_info = scheduler_dag.next_dagrun_info(None) assert next_info assert next_info.logical_date == timezone.datetime(2020, 5, 1) - next_info = dag.next_dagrun_info(next_info.data_interval) + next_info = scheduler_dag.next_dagrun_info(next_info.data_interval) assert next_info assert next_info.logical_date == timezone.datetime(2020, 5, 2) - next_info = dag.next_dagrun_info(next_info.data_interval) + next_info = scheduler_dag.next_dagrun_info(next_info.data_interval) assert next_info assert next_info.logical_date == timezone.datetime(2020, 5, 3) # The date to create is in the future, this is handled by "DagModel.dags_needing_dagruns" - next_info = dag.next_dagrun_info(next_info.data_interval) + next_info = scheduler_dag.next_dagrun_info(next_info.data_interval) assert next_info assert next_info.logical_date == timezone.datetime(2020, 5, 4) @@ -1835,12 +1589,22 @@ class FailingTimetable(Timetable): def next_dagrun_info(self, last_automated_data_interval, restriction): raise RuntimeError("this fails") + def _get_registered_timetable(s): + if s == "unit.models.test_dag.FailingTimetable": + return FailingTimetable + raise ValueError(f"unexpected class {s!r}") + dag = DAG( "test_next_dagrun_info_timetable_exception", start_date=timezone.datetime(2020, 5, 1), schedule=FailingTimetable(), catchup=True, ) + with mock.patch( + "airflow.serialization.serialized_objects._get_registered_timetable", + _get_registered_timetable, + ): + scheduler_dag = create_scheduler_dag(dag) def _check_logs(records: list[logging.LogRecord], data_interval: DataInterval) -> None: assert len(records) == 1 @@ -1852,13 +1616,13 @@ def _check_logs(records: list[logging.LogRecord], data_interval: DataInterval) - ) with caplog.at_level(level=logging.ERROR): - next_info = dag.next_dagrun_info(None) + next_info = scheduler_dag.next_dagrun_info(None) assert next_info is None, "failed next_dagrun_info should return None" _check_logs(caplog.records, data_interval=None) caplog.clear() data_interval = DataInterval(timezone.datetime(2020, 5, 1), timezone.datetime(2020, 5, 2)) with caplog.at_level(level=logging.ERROR): - next_info = dag.next_dagrun_info(data_interval) + next_info = scheduler_dag.next_dagrun_info(data_interval) assert next_info is None, "failed next_dagrun_info should return None" _check_logs(caplog.records, data_interval) @@ -1881,8 +1645,9 @@ def test_next_dagrun_after_auto_align(self): catchup=True, ) EmptyOperator(task_id="dummy", dag=dag, owner="airflow") + scheduler_dag = create_scheduler_dag(dag) - next_info = dag.next_dagrun_info(None) + next_info = scheduler_dag.next_dagrun_info(None) assert next_info assert next_info.logical_date == timezone.datetime(2016, 1, 2, 5, 4) @@ -1893,8 +1658,9 @@ def test_next_dagrun_after_auto_align(self): catchup=True, ) EmptyOperator(task_id="dummy", dag=dag, owner="airflow") + scheduler_dag = create_scheduler_dag(dag) - next_info = dag.next_dagrun_info(None) + next_info = scheduler_dag.next_dagrun_info(None) assert next_info assert next_info.logical_date == timezone.datetime(2016, 1, 1, 10, 10) @@ -1907,8 +1673,9 @@ def 
test_next_dagrun_after_auto_align(self): catchup=False, ) EmptyOperator(task_id="dummy", dag=dag, owner="airflow") + scheduler_dag = create_scheduler_dag(dag) - next_info = dag.next_dagrun_info(None) + next_info = scheduler_dag.next_dagrun_info(None) assert next_info # With catchup=False, next_dagrun should be based on the current date # Verify it's not using the old start_date @@ -1927,8 +1694,9 @@ def test_next_dagrun_after_auto_align(self): catchup=False, ) EmptyOperator(task_id="dummy", dag=dag, owner="airflow") + scheduler_dag = create_scheduler_dag(dag) - next_info = dag.next_dagrun_info(None) + next_info = scheduler_dag.next_dagrun_info(None) assert next_info # With catchup=False, next_dagrun should be based on the current date # Verify it's not using the old start_date @@ -1942,38 +1710,21 @@ def test_next_dagrun_info_on_29_feb(self): dag = DAG( "test_scheduler_dagrun_29_feb", start_date=timezone.datetime(2024, 1, 1), schedule="0 0 29 2 *" ) + scheduler_dag = create_scheduler_dag(dag) - next_info = dag.next_dagrun_info(None) + next_info = scheduler_dag.next_dagrun_info(None) assert next_info assert next_info.logical_date == timezone.datetime(2024, 2, 29) - next_info = dag.next_dagrun_info(next_info.data_interval) + next_info = scheduler_dag.next_dagrun_info(next_info.data_interval) assert next_info.logical_date == timezone.datetime(2028, 2, 29) assert next_info.data_interval.start == timezone.datetime(2028, 2, 29) assert next_info.data_interval.end == timezone.datetime(2032, 2, 29) - def test_validate_executor_field_executor_not_configured(self): - dag = DAG("test-dag", schedule=None) - EmptyOperator(task_id="t1", dag=dag, executor="test.custom.executor") - with pytest.raises( - UnknownExecutorException, - match="The specified executor test.custom.executor for task t1 is not configured", - ): - dag.validate() - - def test_validate_executor_field(self): - with patch.object(ExecutorLoader, "lookup_executor_name_by_str"): - dag = DAG("test-dag", schedule=None) - EmptyOperator(task_id="t1", dag=dag, executor="test.custom.executor") - dag.validate() - def test_validate_params_on_trigger_dag(self, testing_dag_bundle): dag = DAG("dummy-dag", schedule=None, params={"param1": Param(type="string")}) - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) with pytest.raises(ParamValidationError, match="No value passed and Param has no default value"): - dag.create_dagrun( + sync_dag_to_db(dag).create_dagrun( run_id="test_dagrun_missing_param", run_type=DagRunType.MANUAL, state=State.RUNNING, @@ -1987,7 +1738,7 @@ def test_validate_params_on_trigger_dag(self, testing_dag_bundle): with pytest.raises( ParamValidationError, match="Invalid input for param param1: None is not of type 'string'" ): - dag.create_dagrun( + sync_dag_to_db(dag).create_dagrun( run_id="test_dagrun_missing_param", run_type=DagRunType.MANUAL, state=State.RUNNING, @@ -1999,7 +1750,7 @@ def test_validate_params_on_trigger_dag(self, testing_dag_bundle): ) dag = DAG("dummy-dag", schedule=None, params={"param1": Param(type="string")}) - dag.create_dagrun( + sync_dag_to_db(dag).create_dagrun( run_id="test_dagrun_missing_param", run_type=DagRunType.MANUAL, state=State.RUNNING, @@ -2026,7 +1777,7 @@ def test_dag_owner_links(self, testing_dag_bundle): session.add(orm_dag) session.flush() assert dag.owner_links == {"owner1": "https://mylink.com", "owner2": "mailto:someone@yoursite.com"} - dag.sync_to_db(session=session) + sync_dag_to_db(dag, 
session=session) expected_owners = {"dag": {"owner1": "https://mylink.com", "owner2": "mailto:someone@yoursite.com"}} orm_dag_owners = DagOwnerAttributes.get_all(session) @@ -2034,32 +1785,12 @@ def test_dag_owner_links(self, testing_dag_bundle): # Test dag owner links are removed completely dag = DAG("dag", schedule=None, start_date=DEFAULT_DATE) - dag.sync_to_db(session=session) + sync_dag_to_db(dag, session=session) orm_dag_owners = session.query(DagOwnerAttributes).all() assert not orm_dag_owners - def test_get_bundle_name(self, testing_dag_bundle): - dag = DAG("dag") - - # until we've sycned, it'll be None - assert dag.get_bundle_name() is None - - DAG.bulk_write_to_db("testing", None, [dag]) - assert dag.get_bundle_name() == "testing" - - def test_get_bundle_version(self, testing_dag_bundle): - dag = DAG("dag") - - # until we've sycned, it'll be None - assert dag.get_bundle_version() is None - - # Now, it can be none or a str - DAG.bulk_write_to_db("testing", None, [dag]) - assert dag.get_bundle_version() is None - DAG.bulk_write_to_db("testing", "abc", [dag]) - assert dag.get_bundle_version() == "abc" - + @pytest.mark.need_serialized_dag @pytest.mark.parametrize( "reference_type, reference_column", [ @@ -2363,14 +2094,16 @@ def test_relative_fileloc(self, session, testing_dag_bundle): session.merge(dag_model) session.flush() - dag.sync_to_db(session=session) + sync_dag_to_db(dag, session=session) assert dag.fileloc == str(file_path) assert dag.relative_fileloc == str(rel_path) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name, session=session) - session.commit() - session.expunge_all() + SerializedDagModel.write_dag( + LazyDeserializedDAG.from_dag(dag), + bundle_name=bundle_name, + session=session, + ) dm = session.get(DagModel, dag.dag_id) assert dm.fileloc == str(file_path) assert dm.relative_fileloc == str(rel_path) @@ -2392,9 +2125,8 @@ def test__processor_dags_folder(self, session, testing_dag_bundle): session.merge(dag_model) session.flush() - dag.sync_to_db() - assert dag._processor_dags_folder is None - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + scheduler_dag = sync_dag_to_db(dag) + assert scheduler_dag._processor_dags_folder == settings.DAGS_FOLDER sdm = SerializedDagModel.get(dag.dag_id, session) assert sdm.dag._processor_dags_folder == settings.DAGS_FOLDER @@ -2470,7 +2202,12 @@ def test_asset_expression(self, session: Session, testing_dag_bundle) -> None: ), start_date=datetime.datetime.min, ) - DAG.bulk_write_to_db("testing", None, [dag], session=session) + SerializedDAG.bulk_write_to_db( + "testing", + None, + [SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))], + session=session, + ) expression = session.scalars(select(DagModel.asset_expression).filter_by(dag_id=dag.dag_id)).one() assert expression == { @@ -2575,10 +2312,9 @@ def test_count_number_queries(self, tasks_count, testing_dag_bundle): dag = DAG("test_dagrun_query_count", schedule=None, start_date=DEFAULT_DATE) for i in range(tasks_count): EmptyOperator(task_id=f"dummy_task_{i}", owner="test", dag=dag) - DAG.bulk_write_to_db("testing", None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name="testing") + scheduler_dag = sync_dag_to_db(dag) with assert_queries_count(5): - dag.create_dagrun( + scheduler_dag.create_dagrun( run_id="test_dagrun_query_count", run_type=DagRunType.MANUAL, state=State.RUNNING, @@ -2589,6 +2325,7 @@ def test_count_number_queries(self, tasks_count, testing_dag_bundle): ) +@pytest.mark.need_serialized_dag @pytest.mark.parametrize( "run_id", 
["test-run-id"], @@ -2736,89 +2473,6 @@ def consumer(value): ] -def test_set_task_group_state(session, dag_maker): - """Test that set_task_group_state updates the TaskGroup state and clear downstream failed""" - start_date = datetime_tz(2020, 1, 1) - with dag_maker( - "test_set_task_group_state", - start_date=start_date, - session=session, - serialized=True, - ) as dag: - start = EmptyOperator(task_id="start") - - with TaskGroup("section_1", tooltip="Tasks for section_1") as section_1: - task_1 = EmptyOperator(task_id="task_1") - task_2 = EmptyOperator(task_id="task_2") - task_3 = EmptyOperator(task_id="task_3") - - task_1 >> [task_2, task_3] - - task_4 = EmptyOperator(task_id="task_4") - task_5 = EmptyOperator(task_id="task_5") - task_6 = EmptyOperator(task_id="task_6") - task_7 = EmptyOperator(task_id="task_7") - task_8 = EmptyOperator(task_id="task_8") - - start >> section_1 >> [task_4, task_5, task_6, task_7, task_8] - - dagrun = dag_maker.create_dagrun( - run_id="test-run-id", - state=State.FAILED, - run_type=DagRunType.SCHEDULED, - ) - - def get_ti_from_db(task): - return ( - session.query(TI) - .filter( - TI.dag_id == dag.dag_id, - TI.task_id == task.task_id, - TI.run_id == dagrun.run_id, - ) - .one() - ) - - get_ti_from_db(task_1).state = State.FAILED - get_ti_from_db(task_2).state = State.SUCCESS - get_ti_from_db(task_3).state = State.UPSTREAM_FAILED - get_ti_from_db(task_4).state = State.SUCCESS - get_ti_from_db(task_5).state = State.UPSTREAM_FAILED - get_ti_from_db(task_6).state = State.FAILED - get_ti_from_db(task_7).state = State.SKIPPED - - session.flush() - - altered = dag.set_task_group_state( - group_id=section_1.group_id, - run_id="test-run-id", - state=State.SUCCESS, - session=session, - ) - - # After _mark_task_instance_state, task_1 is marked as SUCCESS - assert get_ti_from_db(task_1).state == State.SUCCESS - # task_2 remains as SUCCESS - assert get_ti_from_db(task_2).state == State.SUCCESS - # task_3 should be marked as SUCCESS - assert get_ti_from_db(task_3).state == State.SUCCESS - # task_4 should remain as SUCCESS - assert get_ti_from_db(task_4).state == State.SUCCESS - # task_5 and task_6 are cleared because they were in FAILED/UPSTREAM_FAILED state - assert get_ti_from_db(task_5).state == State.NONE - assert get_ti_from_db(task_6).state == State.NONE - # task_7 remains as SKIPPED - assert get_ti_from_db(task_7).state == State.SKIPPED - dagrun.refresh_from_db(session=session) - # dagrun should be set to QUEUED - assert dagrun.get_state() == State.QUEUED - - assert {t.key for t in altered} == { - ("test_set_task_group_state", "section_1.task_1", dagrun.run_id, 0, -1), - ("test_set_task_group_state", "section_1.task_3", dagrun.run_id, 0, -1), - } - - def test_dag_teardowns_property_lists_all_teardown_tasks(): @setup def setup_task(): @@ -2889,7 +2543,7 @@ def test_iter_dagrun_infos_between(start_date, expected_infos): dag = DAG(dag_id="test_get_dates", start_date=DEFAULT_DATE, schedule="@hourly") EmptyOperator(task_id="dummy", dag=dag) - iterator = dag.iter_dagrun_infos_between( + iterator = create_scheduler_dag(dag).iter_dagrun_infos_between( earliest=pendulum.instance(start_date), latest=pendulum.instance(DEFAULT_DATE), align=True, @@ -2908,13 +2562,23 @@ def next_dagrun_info(self, last_automated_data_interval, restriction): return DagRunInfo.interval(start, end) raise RuntimeError("this fails") + def _get_registered_timetable(s): + if s == "unit.models.test_dag.FailingAfterOneTimetable": + return FailingAfterOneTimetable + raise ValueError(f"unexpected class 
{s!r}") + dag = DAG( dag_id="test_iter_dagrun_infos_between_error", start_date=DEFAULT_DATE, schedule=FailingAfterOneTimetable(), ) + with mock.patch( + "airflow.serialization.serialized_objects._get_registered_timetable", + _get_registered_timetable, + ): + scheduler_dag = create_scheduler_dag(dag) - iterator = dag.iter_dagrun_infos_between(earliest=start, latest=end, align=True) + iterator = scheduler_dag.iter_dagrun_infos_between(earliest=start, latest=end, align=True) with caplog.at_level(logging.ERROR): infos = list(iterator) @@ -2923,7 +2587,7 @@ def next_dagrun_info(self, last_automated_data_interval, restriction): assert caplog.record_tuples == [ ( - "airflow.models.dag.DAG", + "airflow.serialization.serialized_objects", logging.ERROR, f"Failed to fetch run info after data interval {DataInterval(start, end)} for DAG {dag.dag_id!r}", ), @@ -2966,9 +2630,10 @@ def test_get_next_data_interval( next_dagrun_data_interval_end=data_interval_end, ) - assert dag.get_next_data_interval(dag_model) == expected_data_interval + assert get_next_data_interval(dag.timetable, dag_model) == expected_data_interval +@pytest.mark.need_serialized_dag @pytest.mark.parametrize( ("dag_date", "tasks_date", "catchup", "restrict"), [ @@ -3101,7 +2766,7 @@ def test_create_dagrun_disallow_manual_to_use_automated_run_id(run_id_type: DagR run_id = DagRun.generate_run_id(run_type=run_id_type, run_after=DEFAULT_DATE, logical_date=DEFAULT_DATE) with pytest.raises(ValueError) as ctx: - dag.create_dagrun( + SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag)).create_dagrun( run_type=DagRunType.MANUAL, run_id=run_id, logical_date=DEFAULT_DATE, diff --git a/airflow-core/tests/unit/models/test_dag_version.py b/airflow-core/tests/unit/models/test_dag_version.py index caab85104d7e7..6d8e240495477 100644 --- a/airflow-core/tests/unit/models/test_dag_version.py +++ b/airflow-core/tests/unit/models/test_dag_version.py @@ -20,9 +20,9 @@ from sqlalchemy import func, select from airflow.models.dag_version import DagVersion -from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.standard.operators.empty import EmptyOperator +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.db import clear_db_dags pytestmark = pytest.mark.db_test @@ -48,16 +48,13 @@ def test_writing_dag_version_with_changes(self, dag_maker, session): """This also tested the get_latest_version method""" with dag_maker("test1") as dag: EmptyOperator(task_id="task1") - dag.sync_to_db() - SerializedDagModel.write_dag(dag, bundle_name="dag_maker") + sync_dag_to_db(dag) dag_maker.create_dagrun() # Add extra task to change the dag with dag_maker("test1") as dag2: EmptyOperator(task_id="task1") EmptyOperator(task_id="task2") - dag2.sync_to_db() - SerializedDagModel.write_dag(dag2, bundle_name="dag_maker") - + sync_dag_to_db(dag2) latest_version = DagVersion.get_latest_version(dag.dag_id) assert latest_version.version_number == 2 assert session.scalar(select(func.count()).where(DagVersion.dag_id == dag.dag_id)) == 2 diff --git a/airflow-core/tests/unit/models/test_dagbag.py b/airflow-core/tests/unit/models/test_dagbag.py index c2668755c0670..943d9f094b51a 100644 --- a/airflow-core/tests/unit/models/test_dagbag.py +++ b/airflow-core/tests/unit/models/test_dagbag.py @@ -21,6 +21,7 @@ import logging import os import pathlib +import re import sys import textwrap import warnings @@ -34,12 +35,13 @@ from sqlalchemy import select from airflow import settings -from airflow.models.dag import DAG, DagModel 
-from airflow.models.dagbag import DagBag, _capture_with_reraise +from airflow.exceptions import UnknownExecutorException +from airflow.executors.executor_loader import ExecutorLoader +from airflow.models.dag import DagModel +from airflow.models.dagbag import DagBag, _capture_with_reraise, _validate_executor_fields from airflow.models.dagwarning import DagWarning, DagWarningType from airflow.models.serialized_dag import SerializedDagModel -from airflow.sdk import BaseOperator -from airflow.utils.session import create_session +from airflow.sdk import DAG, BaseOperator from tests_common.pytest_plugin import AIRFLOW_ROOT_PATH from tests_common.test_utils import db @@ -60,6 +62,27 @@ INVALID_DAG_WITH_DEPTH_FILE_CONTENTS = "def something():\n return airflow_DAG\nsomething()" +def test_validate_executor_field_executor_not_configured(): + with DAG("test-dag", schedule=None) as dag: + BaseOperator(task_id="t1", executor="test.custom.executor") + with pytest.raises( + UnknownExecutorException, + match=re.escape( + "Task 't1' specifies executor 'test.custom.executor', which is not available. " + "Make sure it is listed in your [core] executors configuration, or update the task's " + "executor to use one of the configured executors." + ), + ): + _validate_executor_fields(dag) + + +def test_validate_executor_field(): + with DAG("test-dag", schedule=None) as dag: + BaseOperator(task_id="t1", executor="test.custom.executor") + with patch.object(ExecutorLoader, "lookup_executor_name_by_str"): + _validate_executor_fields(dag) + + def db_clean_up(): db.clear_db_dags() db.clear_db_runs() @@ -549,36 +572,6 @@ def test_process_file_with_none(self, tmp_path): assert dagbag.process_file(None) == [] - def test_deactivate_unknown_dags(self, testing_dag_bundle): - """ - Test that dag_ids not passed into deactivate_unknown_dags - are deactivated when function is invoked - """ - dagbag = DagBag(include_examples=True) - dag_id = "test_deactivate_unknown_dags" - expected_active_dags = dagbag.dags.keys() - - bundle_name = "testing" - - model_before = DagModel( - dag_id=dag_id, - bundle_name=bundle_name, - is_stale=False, - ) - with create_session() as session: - session.merge(model_before) - session.flush() - - DAG.deactivate_unknown_dags(expected_active_dags) - - after_model = DagModel.get_dagmodel(dag_id) - assert not model_before.is_stale - assert after_model.is_stale - - # clean up - with create_session() as session: - session.query(DagModel).filter(DagModel.dag_id == "test_deactivate_unknown_dags").delete() - def test_timeout_dag_errors_are_import_errors(self, tmp_path, caplog): """ Test that if the DAG contains Timeout error it will be still loaded to DB as import_errors diff --git a/airflow-core/tests/unit/models/test_dagcode.py b/airflow-core/tests/unit/models/test_dagcode.py index 818fb4915fd78..14b733e13eb3c 100644 --- a/airflow-core/tests/unit/models/test_dagcode.py +++ b/airflow-core/tests/unit/models/test_dagcode.py @@ -25,12 +25,10 @@ import airflow.example_dags as example_dags_module from airflow.models import DagBag -from airflow.models.dag import DAG from airflow.models.dag_version import DagVersion from airflow.models.dagcode import DagCode -from airflow.models.serialized_dag import SerializedDagModel as SDM from airflow.sdk import task as task_decorator -from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG +from airflow.serialization.serialized_objects import SerializedDAG # To move it to a shared module. 
from airflow.utils.file import open_maybe_zipped @@ -38,6 +36,7 @@ from airflow.utils.state import DagRunState from airflow.utils.types import DagRunTriggeredByType, DagRunType +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.db import clear_db_dag_code, clear_db_dags pytestmark = pytest.mark.db_test @@ -55,8 +54,7 @@ def make_example_dags(module): session.add(testing) dagbag = DagBag(module.__path__[0]) - dags = [LazyDeserializedDAG(data=SerializedDAG.to_dict(dag)) for dag in dagbag.dags.values()] - DAG.bulk_write_to_db("testing", None, dags) + SerializedDAG.bulk_write_to_db("testing", None, dagbag.dags.values()) return dagbag.dags @@ -74,13 +72,13 @@ def teardown_method(self): def _write_two_example_dags(self, session): example_dags = make_example_dags(example_dags_module) bash_dag = example_dags["example_bash_operator"] - SDM.write_dag(bash_dag, bundle_name="testing") + sync_dag_to_db(bash_dag, session=session) dag_version = DagVersion.get_latest_version("example_bash_operator") x = DagCode(dag_version, bash_dag.fileloc) session.add(x) session.commit() xcom_dag = example_dags["example_xcom"] - SDM.write_dag(xcom_dag, bundle_name="testing") + sync_dag_to_db(xcom_dag, session=session) dag_version = DagVersion.get_latest_version("example_xcom") x = DagCode(dag_version, xcom_dag.fileloc) session.add(x) @@ -89,8 +87,9 @@ def _write_two_example_dags(self, session): def _write_example_dags(self): example_dags = make_example_dags(example_dags_module) - for dag in example_dags.values(): - SDM.write_dag(dag, bundle_name="testing") + with create_session() as session: + for dag in example_dags.values(): + sync_dag_to_db(dag, session=session) return example_dags def test_write_to_db(self, testing_dag_bundle): @@ -130,7 +129,7 @@ def test_code_can_be_read_when_no_access_to_file(self, testing_dag_bundle): Source Code should at least exist in one of DB or File. 
""" example_dag = make_example_dags(example_dags_module).get("example_bash_operator") - SDM.write_dag(example_dag, bundle_name="testing") + sync_dag_to_db(example_dag) # Mock that there is no access to the Dag File with patch("airflow.models.dagcode.open_maybe_zipped") as mock_open: @@ -143,10 +142,7 @@ def test_code_can_be_read_when_no_access_to_file(self, testing_dag_bundle): def test_db_code_created_on_serdag_change(self, session, testing_dag_bundle): """Test new DagCode is created in DB when ser dag is changed""" example_dag = make_example_dags(example_dags_module).get("example_bash_operator") - SDM.write_dag(example_dag, bundle_name="testing") - - dag = DAG.from_sdk_dag(example_dag) - dag.create_dagrun( + sync_dag_to_db(example_dag, session=session).create_dagrun( run_id="test1", run_after=pendulum.datetime(2025, 1, 1, tz="UTC"), state=DagRunState.QUEUED, @@ -166,7 +162,7 @@ def test_db_code_created_on_serdag_change(self, session, testing_dag_bundle): example_dag.doc_md = "new doc" with patch("airflow.models.dagcode.DagCode.get_code_from_file") as mock_code: mock_code.return_value = "# dummy code" - SDM.write_dag(example_dag, bundle_name="testing") + sync_dag_to_db(example_dag, session=session) new_result = ( session.query(DagCode) @@ -184,13 +180,11 @@ def test_has_dag(self, dag_maker): """Test has_dag method.""" with dag_maker("test_has_dag") as dag: pass - dag.sync_to_db() - SDM.write_dag(dag, bundle_name="dag_maker") + sync_dag_to_db(dag) with dag_maker() as dag2: pass - dag2.sync_to_db() - SDM.write_dag(dag2, bundle_name="dag_maker") + sync_dag_to_db(dag2) assert DagCode.has_dag(dag.dag_id) @@ -203,8 +197,7 @@ def mytask(): print("task4") mytask() - dag.sync_to_db() - SDM.write_dag(dag, bundle_name="dag_maker") + sync_dag_to_db(dag) dag_code = DagCode.get_latest_dagcode(dag.dag_id) dag_code.source_code_hash = 2 session.add(dag_code) diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py index 9c85dc041e618..54a5a55544c26 100644 --- a/airflow-core/tests/unit/models/test_dagrun.py +++ b/airflow-core/tests/unit/models/test_dagrun.py @@ -33,7 +33,7 @@ from airflow import settings from airflow._shared.timezones import timezone from airflow.callbacks.callback_requests import DagCallbackRequest, DagRunContext -from airflow.models.dag import DAG, DagModel +from airflow.models.dag import DagModel, infer_automated_data_interval from airflow.models.dag_version import DagVersion from airflow.models.dagrun import DagRun, DagRunNote from airflow.models.serialized_dag import SerializedDagModel @@ -43,9 +43,9 @@ from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator, ShortCircuitOperator -from airflow.sdk import BaseOperator, setup, task, task_group, teardown +from airflow.sdk import DAG, BaseOperator, setup, task, task_group, teardown from airflow.sdk.definitions.deadline import AsyncCallback, DeadlineAlert, DeadlineReference -from airflow.serialization.serialized_objects import SerializedDAG +from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG from airflow.stats import Stats from airflow.task.trigger_rule import TriggerRule from airflow.triggers.base import StartTriggerArgs @@ -56,6 +56,7 @@ from tests_common.test_utils import db from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.dag import sync_dag_to_db from 
tests_common.test_utils.mock_operators import MockOperator from unit.models import DEFAULT_DATE as _DEFAULT_DATE @@ -99,9 +100,9 @@ def _clean_db(): db.clear_db_xcom() db.clear_db_dags() + @staticmethod def create_dag_run( - self, - dag: DAG, + dag: SerializedDAG, *, task_states: Mapping[str, TaskInstanceState] | None = None, logical_date: datetime.datetime | None = None, @@ -113,7 +114,7 @@ def create_dag_run( logical_date = pendulum.instance(logical_date or now) if is_backfill: run_type = DagRunType.BACKFILL_JOB - data_interval = dag.infer_automated_data_interval(logical_date) + data_interval = infer_automated_data_interval(dag.timetable, logical_date) else: run_type = DagRunType.MANUAL data_interval = dag.timetable.infer_manual_data_interval(run_after=logical_date) @@ -465,14 +466,14 @@ def test_on_success_callback_when_task_skipped(self, session, testing_dag_bundle session.merge(dag_model) session.flush() - dag.sync_to_db() - SerializedDagModel.write_dag(dag, bundle_name="testing") + scheduler_dag = sync_dag_to_db(dag, session=session) + scheduler_dag.on_success_callback = mock_on_success initial_task_states = { "test_state_succeeded1": TaskInstanceState.SKIPPED, } - dag_run = self.create_dag_run(dag=dag, task_states=initial_task_states, session=session) + dag_run = self.create_dag_run(scheduler_dag, task_states=initial_task_states, session=session) _, _ = dag_run.update_state(execute_callbacks=True) task = dag_run.get_task_instances()[0] @@ -487,7 +488,7 @@ def test_start_dr_spans_if_needed_new_span(self, testing_dag_bundle, dag_maker, start_date=datetime.datetime(2017, 1, 1), ) as dag: ... - DAG.bulk_write_to_db("testing", None, dags=[dag], session=session) + SerializedDAG.bulk_write_to_db("testing", None, dags=[dag], session=session) dag_task1 = EmptyOperator(task_id="test_task1", dag=dag) dag_task2 = EmptyOperator(task_id="test_task2", dag=dag) @@ -524,7 +525,7 @@ def test_start_dr_spans_if_needed_span_with_continuance(self, testing_dag_bundle start_date=datetime.datetime(2017, 1, 1), ) as dag: ... - DAG.bulk_write_to_db("testing", None, dags=[dag], session=session) + SerializedDAG.bulk_write_to_db("testing", None, dags=[dag], session=session) dag_task1 = EmptyOperator(task_id="test_task1", dag=dag) dag_task2 = EmptyOperator(task_id="test_task2", dag=dag) @@ -570,7 +571,7 @@ def test_end_dr_span_if_needed(self, testing_dag_bundle, dag_maker, session): start_date=datetime.datetime(2017, 1, 1), ) as dag: ... - DAG.bulk_write_to_db("testing", None, dags=[dag], session=session) + SerializedDAG.bulk_write_to_db("testing", None, dags=[dag], session=session) dag_task1 = EmptyOperator(task_id="test_task1", dag=dag) dag_task2 = EmptyOperator(task_id="test_task2", dag=dag) @@ -612,7 +613,7 @@ def test_end_dr_span_if_needed_with_span_from_another_scheduler( start_date=datetime.datetime(2017, 1, 1), ) as dag: ... - DAG.bulk_write_to_db("testing", None, dags=[dag], session=session) + SerializedDAG.bulk_write_to_db("testing", None, dags=[dag], session=session) dag_task1 = EmptyOperator(task_id="test_task1", dag=dag) dag_task2 = EmptyOperator(task_id="test_task2", dag=dag) @@ -652,7 +653,6 @@ def on_success_callable(context): on_success_callback=on_success_callable, ) as dag: ... 
- DAG.bulk_write_to_db("testing", None, dags=[dag], session=session) dm = DagModel.get_dagmodel(dag.dag_id, session=session) dm.relative_fileloc = relative_fileloc session.merge(dm) @@ -670,7 +670,7 @@ def on_success_callable(context): # Scheduler uses Serialized DAG -- so use that instead of the Actual DAG dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) dag.relative_fileloc = relative_fileloc - SerializedDagModel.write_dag(dag, bundle_name="testing") + SerializedDagModel.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name="dag_maker") session.commit() dag_run = self.create_dag_run(dag=dag, task_states=initial_task_states, session=session) @@ -684,7 +684,7 @@ def on_success_callable(context): dag_id="test_dagrun_update_state_with_handle_callback_success", run_id=dag_run.run_id, is_failure_callback=False, - bundle_name="testing", + bundle_name="dag_maker", bundle_version=None, context_from_server=DagRunContext( dag_run=dag_run, @@ -705,7 +705,6 @@ def on_failure_callable(context): on_failure_callback=on_failure_callable, ) as dag: ... - DAG.bulk_write_to_db("testing", None, dags=[dag], session=session) dm = DagModel.get_dagmodel(dag.dag_id, session=session) dm.relative_fileloc = relative_fileloc session.merge(dm) @@ -723,7 +722,7 @@ def on_failure_callable(context): # Scheduler uses Serialized DAG -- so use that instead of the Actual DAG dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) dag.relative_fileloc = relative_fileloc - SerializedDagModel.write_dag(dag, bundle_name="testing") + SerializedDagModel.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name="dag_maker") session.commit() dag_run = self.create_dag_run(dag=dag, task_states=initial_task_states, session=session) @@ -739,7 +738,7 @@ def on_failure_callable(context): run_id=dag_run.run_id, is_failure_callback=True, msg="task_failure", - bundle_name="testing", + bundle_name="dag_maker", bundle_version=None, context_from_server=DagRunContext( dag_run=dag_run, @@ -1068,17 +1067,17 @@ def test_next_dagruns_to_examine_only_unpaused(self, session, state, testing_dag ) session.add(orm_dag) session.flush() - SerializedDagModel.write_dag(dag, bundle_name="testing") - dr = dag.create_dagrun( - run_id=dag.timetable.generate_run_id( + scheduler_dag = sync_dag_to_db(dag, session=session) + dr = scheduler_dag.create_dagrun( + run_id=scheduler_dag.timetable.generate_run_id( run_type=DagRunType.SCHEDULED, run_after=DEFAULT_DATE, - data_interval=dag.infer_automated_data_interval(DEFAULT_DATE), + data_interval=infer_automated_data_interval(scheduler_dag.timetable, DEFAULT_DATE), ), run_type=DagRunType.SCHEDULED, state=state, logical_date=DEFAULT_DATE, - data_interval=dag.infer_automated_data_interval(DEFAULT_DATE), + data_interval=infer_automated_data_interval(scheduler_dag.timetable, DEFAULT_DATE), run_after=DEFAULT_DATE, start_date=DEFAULT_DATE if state == DagRunState.RUNNING else None, session=session, @@ -1117,14 +1116,9 @@ def test_no_scheduling_delay_for_nonscheduled_runs(self, stats_mock, session, te session.merge(dag_model) session.flush() - dag.sync_to_db(session=session) - SerializedDagModel.write_dag(dag, bundle_name="testing") - - initial_task_states = { - dag_task.task_id: TaskInstanceState.SUCCESS, - } - - dag_run = self.create_dag_run(dag=dag, task_states=initial_task_states, session=session) + scheduler_dag = sync_dag_to_db(dag, session=session) + initial_task_states = {dag_task.task_id: TaskInstanceState.SUCCESS} + dag_run = self.create_dag_run(scheduler_dag, task_states=initial_task_states, 
session=session) dag_run.update_state(session=session) assert call(f"dagrun.{dag.dag_id}.first_task_scheduling_delay") not in stats_mock.mock_calls @@ -1144,9 +1138,9 @@ def test_emit_scheduling_delay(self, session, schedule, expected, testing_dag_bu dag = DAG(dag_id="test_emit_dag_stats", start_date=DEFAULT_DATE, schedule=schedule) dag_task = EmptyOperator(task_id="dummy", dag=dag, owner="airflow") expected_stat_tags = {"dag_id": f"{dag.dag_id}", "run_type": DagRunType.SCHEDULED} - + scheduler_dag = sync_dag_to_db(dag, session=session) try: - info = dag.next_dagrun_info(None) + info = scheduler_dag.next_dagrun_info(None) orm_dag_kwargs = { "dag_id": dag.dag_id, "bundle_name": "testing", @@ -1162,19 +1156,18 @@ def test_emit_scheduling_delay(self, session, schedule, expected, testing_dag_bu }, ) orm_dag = DagModel(**orm_dag_kwargs) - session.add(orm_dag) + session.merge(orm_dag) session.flush() - SerializedDagModel.write_dag(dag, bundle_name="testing") - dag_run = dag.create_dagrun( - run_id=dag.timetable.generate_run_id( + dag_run = scheduler_dag.create_dagrun( + run_id=scheduler_dag.timetable.generate_run_id( run_type=DagRunType.SCHEDULED, run_after=dag.start_date, - data_interval=dag.infer_automated_data_interval(dag.start_date), + data_interval=infer_automated_data_interval(scheduler_dag.timetable, dag.start_date), ), run_type=DagRunType.SCHEDULED, state=DagRunState.SUCCESS, logical_date=dag.start_date, - data_interval=dag.infer_automated_data_interval(dag.start_date), + data_interval=infer_automated_data_interval(scheduler_dag.timetable, dag.start_date), run_after=dag.start_date, start_date=dag.start_date, triggered_by=DagRunTriggeredByType.TEST, diff --git a/airflow-core/tests/unit/models/test_mappedoperator.py b/airflow-core/tests/unit/models/test_mappedoperator.py index cb8d3cf3fad65..84d05b3b932e6 100644 --- a/airflow-core/tests/unit/models/test_mappedoperator.py +++ b/airflow-core/tests/unit/models/test_mappedoperator.py @@ -26,18 +26,15 @@ from sqlalchemy import select from airflow.exceptions import AirflowSkipException -from airflow.models.baseoperator import BaseOperator -from airflow.models.dag import DAG from airflow.models.dag_version import DagVersion -from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap from airflow.providers.standard.operators.python import PythonOperator -from airflow.sdk import setup, task, task_group, teardown -from airflow.sdk.definitions.taskgroup import TaskGroup +from airflow.sdk import DAG, BaseOperator, TaskGroup, setup, task, task_group, teardown from airflow.task.trigger_rule import TriggerRule from airflow.utils.state import TaskInstanceState +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.mapping import expand_mapped_task from tests_common.test_utils.mock_operators import MockOperator from unit.models import DEFAULT_DATE @@ -70,8 +67,7 @@ def execute(self, context: Context): unrenderable_values = [UnrenderableClass(), UnrenderableClass()] mapped = CustomOperator.partial(task_id="task_2").expand(arg=unrenderable_values) task1 >> mapped - DAG.bulk_write_to_db("testing", None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name="testing") + sync_dag_to_db(dag) dag.test() assert ( "Unable to check if the value of type 'UnrenderableClass' is False for task 'task_2', field 'arg'" diff --git a/airflow-core/tests/unit/models/test_serialized_dag.py b/airflow-core/tests/unit/models/test_serialized_dag.py index 
f6277038dd96c..ba207cc39ec70 100644 --- a/airflow-core/tests/unit/models/test_serialized_dag.py +++ b/airflow-core/tests/unit/models/test_serialized_dag.py @@ -27,7 +27,7 @@ import airflow.example_dags as example_dags_module from airflow.models.asset import AssetActive, AssetAliasModel, AssetModel -from airflow.models.dag import DAG as SchedulerDAG, DagModel +from airflow.models.dag import DagModel from airflow.models.dag_version import DagVersion from airflow.models.dagbag import DagBag from airflow.models.serialized_dag import SerializedDagModel as SDM @@ -44,6 +44,7 @@ from airflow.utils.types import DagRunTriggeredByType, DagRunType from tests_common.test_utils import db +from tests_common.test_utils.dag import sync_dag_to_db pytestmark = pytest.mark.db_test @@ -62,7 +63,7 @@ def make_example_dags(module): dagbag = DagBag(module.__path__[0]) dags = [LazyDeserializedDAG(data=SerializedDAG.to_dict(dag)) for dag in dagbag.dags.values()] - SchedulerDAG.bulk_write_to_db("testing", None, dags) + SerializedDAG.bulk_write_to_db("testing", None, dags) return dagbag.dags @@ -87,7 +88,7 @@ def setup_test_cases(self, request, monkeypatch): def _write_example_dags(self): example_dags = make_example_dags(example_dags_module) for dag in example_dags.values(): - SDM.write_dag(dag, bundle_name="testing") + SDM.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name="testing") return example_dags def test_write_dag(self, testing_dag_bundle): @@ -107,48 +108,45 @@ def test_write_dag_when_python_callable_name_changes(self, dag_maker, session): def my_callable(): pass - with dag_maker("dag1") as dag: + with dag_maker("dag1"): PythonOperator(task_id="task1", python_callable=my_callable) - dag.sync_to_db() - SDM.write_dag(dag, bundle_name="dag_maker") dag_maker.create_dagrun(run_id="test1") - with dag_maker("dag1") as dag: + + with dag_maker("dag1"): PythonOperator(task_id="task1", python_callable=lambda x: None) - SDM.write_dag(dag, bundle_name="dag_maker") dag_maker.create_dagrun(run_id="test2", logical_date=pendulum.datetime(2025, 1, 1)) assert len(session.query(DagVersion).all()) == 2 - with dag_maker("dag2") as dag: + with dag_maker("dag2"): @task_decorator def my_callable(): pass my_callable() - dag.sync_to_db() - SDM.write_dag(dag, bundle_name="dag_maker") dag_maker.create_dagrun(run_id="test3", logical_date=pendulum.datetime(2025, 1, 2)) - with dag_maker("dag2") as dag: + + with dag_maker("dag2"): @task_decorator def my_callable2(): pass my_callable2() - SDM.write_dag(dag, bundle_name="testing") - assert len(session.query(DagVersion).all()) == 4 def test_serialized_dag_is_updated_if_dag_is_changed(self, testing_dag_bundle): """Test Serialized DAG is updated if DAG is changed""" example_dags = make_example_dags(example_dags_module) example_bash_op_dag = example_dags.get("example_bash_operator") - dag_updated = SDM.write_dag(dag=example_bash_op_dag, bundle_name="testing") + dag_updated = SDM.write_dag( + dag=LazyDeserializedDAG.from_dag(example_bash_op_dag), + bundle_name="testing", + ) assert dag_updated is True - # SchedulerDAG is created to create dagrun - dag = SchedulerDAG.from_sdk_dag(dag=example_bash_op_dag) - dag.create_dagrun( + s_dag = SDM.get(example_bash_op_dag.dag_id) + s_dag.dag.create_dagrun( run_id="test1", run_after=pendulum.datetime(2025, 1, 1, tz="UTC"), state=DagRunState.QUEUED, @@ -156,11 +154,12 @@ def test_serialized_dag_is_updated_if_dag_is_changed(self, testing_dag_bundle): run_type=DagRunType.MANUAL, ) - s_dag = SDM.get(example_bash_op_dag.dag_id) - # Test that if DAG is not 
changed, Serialized DAG is not re-written and last_updated # column is not updated - dag_updated = SDM.write_dag(dag=example_bash_op_dag, bundle_name="testing") + dag_updated = SDM.write_dag( + dag=LazyDeserializedDAG.from_dag(example_bash_op_dag), + bundle_name="testing", + ) s_dag_1 = SDM.get(example_bash_op_dag.dag_id) assert s_dag_1.dag_hash == s_dag.dag_hash @@ -171,7 +170,10 @@ def test_serialized_dag_is_updated_if_dag_is_changed(self, testing_dag_bundle): example_bash_op_dag.tags.add("new_tag") assert example_bash_op_dag.tags == {"example", "example2", "new_tag"} - dag_updated = SDM.write_dag(dag=example_bash_op_dag, bundle_name="testing") + dag_updated = SDM.write_dag( + dag=LazyDeserializedDAG.from_dag(example_bash_op_dag), + bundle_name="testing", + ) s_dag_2 = SDM.get(example_bash_op_dag.dag_id) assert s_dag.created_at != s_dag_2.created_at @@ -196,18 +198,16 @@ def test_read_all_dags_only_picks_the_latest_serdags(self, session): assert len(example_dags) == len(serialized_dags) dag = example_dags.get("example_bash_operator") - - # DAGs are serialized and deserialized to access create_dagrun object - sdag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag=dag)) - sdag.create_dagrun( + SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag=dag)).create_dagrun( run_id="test1", run_after=pendulum.datetime(2025, 1, 1, tz="UTC"), state=DagRunState.QUEUED, triggered_by=DagRunTriggeredByType.TEST, run_type=DagRunType.MANUAL, ) + dag.doc_md = "new doc string" - SDM.write_dag(dag, bundle_name="testing") + SDM.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name="testing") serialized_dags2 = SDM.read_all_dags() sdags = session.query(SDM).all() # assert only the latest SDM is returned @@ -223,7 +223,7 @@ def test_order_of_dag_params_is_stable(self): example_params_trigger_ui = example_dags.get("example_params_trigger_ui") before = list(example_params_trigger_ui.params.keys()) - SDM.write_dag(example_params_trigger_ui, bundle_name="testing") + SDM.write_dag(LazyDeserializedDAG.from_dag(example_params_trigger_ui), bundle_name="testing") retrieved_dag = SDM.get_dag("example_params_trigger_ui") after = list(retrieved_dag.params.keys()) @@ -299,26 +299,22 @@ def test_get_latest_serdag_versions(self, dag_maker, session): # first dag with dag_maker("dag1") as dag: EmptyOperator(task_id="task1") - dag.sync_to_db() - SDM.write_dag(dag, bundle_name="testing") + sync_dag_to_db(dag, session=session) dag_maker.create_dagrun() with dag_maker("dag1") as dag: EmptyOperator(task_id="task1") EmptyOperator(task_id="task2") - dag.sync_to_db() - SDM.write_dag(dag, bundle_name="testing") + sync_dag_to_db(dag, session=session) dag_maker.create_dagrun(run_id="test2", logical_date=pendulum.datetime(2025, 1, 1)) # second dag with dag_maker("dag2") as dag: EmptyOperator(task_id="task1") - dag.sync_to_db() - SDM.write_dag(dag, bundle_name="testing") + sync_dag_to_db(dag, session=session) dag_maker.create_dagrun(run_id="test3", logical_date=pendulum.datetime(2025, 1, 2)) with dag_maker("dag2") as dag: EmptyOperator(task_id="task1") EmptyOperator(task_id="task2") - dag.sync_to_db() - SDM.write_dag(dag, bundle_name="testing") + sync_dag_to_db(dag, session=session) # Total serdags should be 4 assert session.scalar(select(func.count()).select_from(SDM)) == 4 @@ -336,7 +332,7 @@ def test_new_dag_versions_are_not_created_if_no_dagruns(self, dag_maker, session last_updated = sdm1.last_updated # new task PythonOperator(task_id="task2", python_callable=lambda: None, dag=dag) - SDM.write_dag(dag, 
bundle_name="dag_maker") + SDM.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name="dag_maker") sdm2 = SDM.get(dag.dag_id, session=session) assert sdm2.dag_hash != dag_hash # first recorded serdag @@ -353,7 +349,7 @@ def test_new_dag_versions_are_created_if_there_is_a_dagrun(self, dag_maker, sess assert session.query(DagVersion).count() == 1 # new task PythonOperator(task_id="task2", python_callable=lambda: None, dag=dag) - SDM.write_dag(dag, bundle_name="dag_maker") + SDM.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name="dag_maker") assert session.query(DagVersion).count() == 2 assert session.query(SDM).count() == 2 @@ -402,8 +398,7 @@ def test_get_dependencies_with_asset_ref(self, dag_maker, session): schedule=[Asset.ref(uri=asset_uri), Asset.ref(uri="test://no-such-asset/")], ) as dag: BashOperator(task_id="any", bash_command="sleep 5") - dag.sync_to_db() - SDM.write_dag(dag, bundle_name="testing") + sync_dag_to_db(dag, session=session) dependencies = SDM.get_dag_dependencies(session=session) assert dependencies == { @@ -448,8 +443,7 @@ def test_get_dependencies_with_asset_alias(self, dag_maker, session): schedule=[AssetAlias(name="alias_1"), AssetAlias(name="alias_2")], ) as dag: BashOperator(task_id="any", bash_command="sleep 5") - dag.sync_to_db() - SDM.write_dag(dag, bundle_name="testing") + sync_dag_to_db(dag, session=session) dependencies = SDM.get_dag_dependencies(session=session) assert dependencies == { @@ -498,7 +492,7 @@ def test_min_update_interval_is_respected(self, provide_interval, new_task, shou PythonOperator(task_id="task2", python_callable=lambda: None, dag=dag) did_write = SDM.write_dag( - dag, + LazyDeserializedDAG.from_dag(dag), bundle_name="dag_maker", min_update_interval=min_update_interval, ) @@ -518,7 +512,7 @@ def test_new_dag_version_created_when_bundle_name_changes_and_hash_unchanged(sel # Write the same DAG (no changes, so hash is the same) with a new bundle_name new_bundle = "bundleB" - SDM.write_dag(dag, bundle_name=new_bundle) + SDM.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name=new_bundle) # There should now be two versions of the DAG assert session.query(DagVersion).count() == 2 diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index a730fa8d8d8df..094b8f362e8d4 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -42,9 +42,7 @@ AirflowSkipException, ) from airflow.models.asset import AssetActive, AssetAliasModel, AssetEvent, AssetModel -from airflow.models.baseoperator import BaseOperator from airflow.models.connection import Connection -from airflow.models.dag import DAG from airflow.models.dag_version import DagVersion from airflow.models.dagrun import DagRun from airflow.models.pool import Pool @@ -64,7 +62,7 @@ from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator from airflow.providers.standard.sensors.python import PythonSensor -from airflow.sdk import BaseSensorOperator, Metadata, task, task_group +from airflow.sdk import DAG, BaseOperator, BaseSensorOperator, Metadata, task, task_group from airflow.sdk.api.datamodels._generated import AssetEventResponse, AssetResponse from airflow.sdk.definitions.asset import Asset, AssetAlias from airflow.sdk.definitions.param import process_params @@ -72,7 +70,7 @@ from airflow.sdk.execution_time.comms import ( AssetEventsResult, ) -from 
airflow.serialization.serialized_objects import SerializedBaseOperator +from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG from airflow.stats import Stats from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS @@ -91,7 +89,7 @@ from tests_common.test_utils.mock_operators import MockOperator from unit.models import DEFAULT_DATE -pytestmark = [pytest.mark.db_test] +pytestmark = [pytest.mark.db_test, pytest.mark.need_serialized_dag, pytest.mark.want_activate_assets] @pytest.fixture @@ -223,6 +221,7 @@ def test_set_dag(self, dag_maker): assert op.dag is dag assert op in dag.tasks + @pytest.mark.need_serialized_dag(False) def test_infer_dag(self, create_dummy_dag): op1 = EmptyOperator(task_id="test_op_1") op2 = EmptyOperator(task_id="test_op_2") @@ -267,8 +266,8 @@ def test_init_on_load(self, create_task_instance): assert ti.log.name == "airflow.task" assert not ti.test_mode - @patch.object(DAG, "get_concurrency_reached") - def test_requeue_over_dag_concurrency(self, mock_concurrency_reached, create_task_instance, dag_maker): + @patch.object(SerializedDAG, "get_concurrency_reached") + def test_requeue_over_dag_concurrency(self, mock_concurrency_reached, create_task_instance): mock_concurrency_reached.return_value = True ti = create_task_instance( @@ -416,15 +415,17 @@ def test_ti_updates_with_task(self, create_task_instance, session): dag = ti.task.dag ti.run(session=session) - tis = dag.get_task_instances() - assert tis[0].executor_config == {"foo": "bar"} + executor_configs = session.scalars( + select(TaskInstance.executor_config).where(TaskInstance.dag_id == ti.dag_id) + ).all() + assert executor_configs == [{"foo": "bar"}] + task2 = EmptyOperator( task_id="test_run_pooling_task_op2", executor_config={"bar": "baz"}, start_date=timezone.datetime(2016, 2, 1, 0, 0, 0), dag=dag, ) - ti2 = TI(task=task2, run_id=ti.run_id, dag_version_id=ti.dag_version_id) session.add(ti2) session.flush() @@ -1116,15 +1117,15 @@ def test_check_task_dependencies( assert s.success >= s.success_setup assert s.done == s.failed + s.success + s.removed + s.upstream_failed + s.skipped - with dag_maker() as dag: + with dag_maker(): downstream = EmptyOperator(task_id="downstream", trigger_rule=trigger_rule) if set_teardown: downstream.as_teardown() for i in range(5): - task = EmptyOperator(task_id=f"work_{i}", dag=dag) + task = EmptyOperator(task_id=f"work_{i}") task.set_downstream(downstream) for i in range(upstream_setups): - task = EmptyOperator(task_id=f"setup_{i}", dag=dag).as_setup() + task = EmptyOperator(task_id=f"setup_{i}").as_setup() task.set_downstream(downstream) assert task.start_date is not None run_date = task.start_date + datetime.timedelta(days=5) @@ -1341,7 +1342,6 @@ def test_check_and_change_state_before_execution(self, create_task_instance, tes dag_id="test_check_and_change_state_before_execution", external_executor_id=expected_external_executor_id, ) - SerializedDagModel.write_dag(ti.task.dag, bundle_name="testing") serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag ti_from_deserialized_task = TI( @@ -1364,7 +1364,6 @@ def test_check_and_change_state_before_execution_provided_id_overrides( external_executor_id="apple", ) assert ti.external_executor_id == "apple" - SerializedDagModel.write_dag(ti.task.dag, bundle_name="testing") serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag ti_from_deserialized_task = TI( @@ -1384,7 +1383,6 @@ def 
test_check_and_change_state_before_execution_with_exec_id(self, create_task_ expected_external_executor_id = "minions" ti = create_task_instance(dag_id="test_check_and_change_state_before_execution") assert ti.external_executor_id is None - SerializedDagModel.write_dag(ti.task.dag, bundle_name="testing") serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag ti_from_deserialized_task = TI( @@ -1400,18 +1398,14 @@ def test_check_and_change_state_before_execution_with_exec_id(self, create_task_ assert ti_from_deserialized_task.state == State.RUNNING assert ti_from_deserialized_task.try_number == 0 - def test_check_and_change_state_before_execution_dep_not_met( - self, create_task_instance, testing_dag_bundle - ): - ti = create_task_instance(dag_id="test_check_and_change_state_before_execution") - task2 = EmptyOperator(task_id="task2", dag=ti.task.dag, start_date=DEFAULT_DATE) - ti.task >> task2 - SerializedDagModel.write_dag(ti.task.dag, bundle_name="testing") - - serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag - ti2 = TI( - task=serialized_dag.get_task(task2.task_id), run_id=ti.run_id, dag_version_id=ti.dag_version_id - ) + def test_check_and_change_state_before_execution_dep_not_met(self, dag_maker): + with dag_maker(dag_id="test_check_and_change_state_before_execution") as dag: + task1 = EmptyOperator(task_id="task1") + task2 = EmptyOperator(task_id="task2", start_date=DEFAULT_DATE) + task1 >> task2 + dr = dag_maker.create_dagrun() + ti2 = dr.get_task_instance("task2") + ti2.refresh_from_task(dag.get_task("task2")) # Need scheduler task for the check. assert not ti2.check_and_change_state_before_execution() def test_check_and_change_state_before_execution_dep_not_met_already_running( @@ -1422,8 +1416,6 @@ def test_check_and_change_state_before_execution_dep_not_met_already_running( with create_session() as _: ti.state = State.RUNNING - SerializedDagModel.write_dag(ti.task.dag, bundle_name="testing") - serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag ti_from_deserialized_task = TI( task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id, dag_version_id=ti.dag_version_id @@ -1441,8 +1433,6 @@ def test_check_and_change_state_before_execution_dep_not_met_not_runnable_state( with create_session() as _: ti.state = State.FAILED - SerializedDagModel.write_dag(ti.task.dag, bundle_name="testing") - serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag ti_from_deserialized_task = TI( task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id, dag_version_id=ti.dag_version_id @@ -1455,7 +1445,7 @@ def test_try_number(self, create_task_instance): """ Test the try_number accessor behaves in various running states """ - ti = create_task_instance(dag_id="test_check_and_change_state_before_execution") + ti = create_task_instance(dag_id="test_try_number") # TI starts at 0. It's only incremented by the scheduler. 
assert ti.try_number == 0 ti.try_number = 2 @@ -1465,7 +1455,7 @@ def test_try_number(self, create_task_instance): ti.state = State.SUCCESS assert ti.try_number == 2 # unaffected by state - def test_get_num_running_task_instances(self, create_task_instance): + def test_get_num_running_task_instances(self, dag_maker, create_task_instance): session = settings.Session() ti1 = create_task_instance( @@ -1473,7 +1463,7 @@ def test_get_num_running_task_instances(self, create_task_instance): ) logical_date = DEFAULT_DATE + datetime.timedelta(days=1) - dr = ti1.task.dag.create_dagrun( + dr = dag_maker.create_dagrun( logical_date=logical_date, run_type=DagRunType.MANUAL, state=None, @@ -1612,7 +1602,6 @@ def test_set_duration_empty_dates(self): ti.set_duration() assert ti.duration is None - @pytest.mark.want_activate_assets(True) def test_outlet_asset_extra(self, dag_maker, session): from airflow.sdk.definitions.asset import Asset @@ -1653,99 +1642,6 @@ def _write2_post_execute(context, _): assert events["write2"].asset.uri == "test_outlet_asset_extra_2" assert events["write2"].extra == {"x": 1} - @pytest.mark.want_activate_assets(True) - def test_outlet_asset_template_extra(self, dag_maker, session): - from airflow.sdk.definitions.asset import Asset - - with dag_maker(schedule=None, serialized=True, session=session): - - @task( - outlets=Asset( - "test_outlet_asset_template_extra1", - extra={ - "static_extra": "value", - "dag_id": "{{ dag.dag_id }}", - "nested_extra": { - "task_id": "{{ task.task_id }}", - "logical_date": "{{ ds }}", - }, - }, - ) - ) - def write_template1(*, outlet_events): - yield Metadata( - Asset("test_outlet_asset_template_extra1"), - { - "dag_id": "override_dag_id", - "some_other_key": "some_other_value", - }, - ) - - write_template1() - - BashOperator( - task_id="write_template2", - bash_command=":", - outlets=Asset( - "test_outlet_asset_template_extra2", - extra={ - "static_extra": "value", - "dag_id": "{{ dag.dag_id }}", - "nested_extra": { - "task_id": "{{ task.task_id }}", - "logical_date": "{{ ds }}", - }, - }, - ), - ) - - BashOperator( - task_id="write_asset_no_extra", - bash_command=":", - outlets=Asset("test_outlet_asset_no_extra"), - ) - - dr: DagRun = dag_maker.create_dagrun() - for ti in dr.get_task_instances(session=session): - ti.run(session=session) - - events = dict(iter(session.execute(select(AssetEvent.source_task_id, AssetEvent)))) - assert set(events) == {"write_template1", "write_template2", "write_asset_no_extra"} - - assert events["write_template1"].source_dag_id == dr.dag_id - assert events["write_template1"].source_run_id == dr.run_id - assert events["write_template1"].source_task_id == "write_template1" - assert events["write_template1"].asset.uri == "test_outlet_asset_template_extra1" - assert events["write_template1"].extra == { - "static_extra": "value", - "dag_id": "override_dag_id", # Overridden by Metadata - "nested_extra": { - "task_id": "write_template1", - "logical_date": dr.logical_date.strftime("%Y-%m-%d"), - }, - "some_other_key": "some_other_value", # Added by Metadata - } - - assert events["write_template2"].source_dag_id == dr.dag_id - assert events["write_template2"].source_run_id == dr.run_id - assert events["write_template2"].source_task_id == "write_template2" - assert events["write_template2"].asset.uri == "test_outlet_asset_template_extra2" - assert events["write_template2"].extra == { - "static_extra": "value", - "dag_id": dr.dag_id, - "nested_extra": { - "task_id": "write_template2", - "logical_date": 
dr.logical_date.strftime("%Y-%m-%d"), - }, - } - - assert events["write_asset_no_extra"].source_dag_id == dr.dag_id - assert events["write_asset_no_extra"].source_run_id == dr.run_id - assert events["write_asset_no_extra"].source_task_id == "write_asset_no_extra" - assert events["write_asset_no_extra"].asset.uri == "test_outlet_asset_no_extra" - assert events["write_asset_no_extra"].extra == {} - - @pytest.mark.want_activate_assets(True) def test_outlet_asset_extra_ignore_different(self, dag_maker, session): from airflow.sdk.definitions.asset import Asset @@ -1767,7 +1663,6 @@ def write(*, outlet_events): assert event.source_task_id == "write" assert event.extra == {"one": 1} - @pytest.mark.want_activate_assets(True) def test_outlet_asset_alias(self, dag_maker, session): from airflow.sdk.definitions.asset import Asset @@ -1815,7 +1710,6 @@ def producer(*, outlet_events): assert len(asset_alias_obj.assets) == 1 assert asset_alias_obj.assets[0].uri == asset_uri - @pytest.mark.want_activate_assets(True) def test_outlet_multiple_asset_alias(self, dag_maker, session): from airflow.sdk.definitions.asset import Asset @@ -1888,10 +1782,7 @@ def producer(*, outlet_events): assert len(asset_alias_obj.assets) == 1 assert asset_alias_obj.assets[0].uri == asset_uri - @pytest.mark.want_activate_assets(True) def test_outlet_asset_alias_through_metadata(self, dag_maker, session): - from airflow.sdk.definitions.asset.metadata import Metadata - asset_uri = "test_outlet_asset_alias_through_metadata_ds" asset_alias_name = "test_outlet_asset_alias_through_metadata_asset_alias" @@ -1931,7 +1822,6 @@ def producer(*, outlet_events): assert len(asset_alias_obj.assets) == 1 assert asset_alias_obj.assets[0].uri == asset_uri - @pytest.mark.want_activate_assets(True) def test_outlet_asset_alias_asset_not_exists(self, dag_maker, session): from airflow.sdk.definitions.asset import Asset @@ -2026,8 +1916,6 @@ def producer_with_inactive(*, outlet_events): asset_alias_obj = session.scalar(select(AssetAliasModel)) assert sorted(a.name for a in asset_alias_obj.assets) == ["asset1", "asset2"] - @pytest.mark.want_activate_assets(True) - @pytest.mark.need_serialized_dag def test_inlet_asset_extra(self, dag_maker, session, mock_supervisor_comms): from airflow.sdk.definitions.asset import Asset @@ -2102,7 +1990,6 @@ def read(*, inlet_events): assert not dr.task_instance_scheduling_decisions(session=session).schedulable_tis assert read_task_evaluated - @pytest.mark.need_serialized_dag def test_inlet_unresolved_asset_alias(self, dag_maker, session, mock_supervisor_comms): asset_alias_name = "test_inlet_asset_extra_asset_alias" mock_supervisor_comms.send.return_value = AssetEventsResult(asset_events=[]) @@ -2809,7 +2696,6 @@ def test_task_instance_history_is_created_when_ti_goes_for_retry(self, dag_maker @pytest.mark.skip( reason="This test has some issues that were surfaced when dag_maker started allowing multiple serdag versions. Issue #48539 will track fixing this." 
) - @pytest.mark.want_activate_assets(True) def test_run_with_inactive_assets(self, dag_maker, session): from airflow.sdk.definitions.asset import Asset @@ -3021,10 +2907,9 @@ def test_map_literal(self, literal, expected_outputs, dag_maker, session): def show(value): outputs.append(value) - show.expand(value=literal) + show_task = show.expand(value=literal).operator dag_run = dag_maker.create_dagrun() - show_task = dag.get_task("show") mapped_tis = ( session.query(TI) .filter_by(task_id="show", dag_id=dag_run.dag_id, run_id=dag_run.run_id) diff --git a/airflow-core/tests/unit/models/test_xcom.py b/airflow-core/tests/unit/models/test_xcom.py index 57d7558cbaeab..37dc65dfc6508 100644 --- a/airflow-core/tests/unit/models/test_xcom.py +++ b/airflow-core/tests/unit/models/test_xcom.py @@ -24,22 +24,20 @@ import pytest -from airflow import DAG from airflow._shared.timezones import timezone from airflow.configuration import conf +from airflow.models.dag import DAG from airflow.models.dag_version import DagVersion -from airflow.models.dagbundle import DagBundleModel from airflow.models.dagrun import DagRun, DagRunType -from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance from airflow.models.xcom import XComModel from airflow.providers.standard.operators.empty import EmptyOperator from airflow.sdk.bases.xcom import BaseXCom from airflow.sdk.execution_time.xcom import resolve_xcom_backend from airflow.settings import json -from airflow.utils.session import create_session from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.db import clear_db_dag_bundles, clear_db_dags, clear_db_runs, clear_db_xcom from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker @@ -65,14 +63,7 @@ def reset_db(): @pytest.fixture def task_instance_factory(request, session: Session): def func(*, dag_id, task_id, logical_date, run_after=None): - dag = DAG(dag_id=dag_id) - bundle_name = "testing" - with create_session() as session: - orm_dag_bundle = DagBundleModel(name=bundle_name) - session.merge(orm_dag_bundle) - session.commit() - DAG.bulk_write_to_db(bundle_name, None, [dag], session=session) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + sync_dag_to_db(DAG(dag_id=dag_id)) run_id = DagRun.generate_run_id( run_type=DagRunType.SCHEDULED, logical_date=logical_date, diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index e4f0535493359..6d71851dc879a 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -56,13 +56,12 @@ ) from airflow.models.asset import AssetModel from airflow.models.connection import Connection -from airflow.models.dag import DAG from airflow.models.dagbag import DagBag from airflow.models.mappedoperator import MappedOperator from airflow.models.xcom import XCOM_RETURN_KEY, XComModel from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator from airflow.providers.standard.operators.bash import BashOperator -from airflow.sdk import AssetAlias, BaseHook, teardown +from airflow.sdk import DAG, AssetAlias, BaseHook, teardown from airflow.sdk.bases.decorator import DecoratedOperator from airflow.sdk.bases.operator import OPERATOR_DEFAULTS, BaseOperator from airflow.sdk.definitions._internal.expandinput import EXPAND_INPUT_EMPTY 
@@ -156,6 +155,9 @@ "downstream_task_ids": [], }, "is_paused_upon_creation": False, + "max_active_runs": 16, + "max_active_tasks": 16, + "max_consecutive_failed_dag_runs": 0, "dag_id": "simple_dag", "deadline": None, "catchup": False, @@ -256,7 +258,7 @@ ], } }, - } + }, }, }, "edge_info": {}, @@ -336,8 +338,14 @@ def make_user_defined_macro_filter_dag(): (2) templates with function macros have been rendered before serialization. """ + # TODO (GH-52141): Since the worker would not have access to the database in + # production anyway, we should rewrite this test to better match reality. def compute_last_dagrun(dag: DAG): - return dag.get_last_dagrun(include_manually_triggered=True) + from airflow.models.dag import get_last_dagrun + from airflow.utils.session import create_session + + with create_session() as session: + return get_last_dagrun(dag.dag_id, session=session, include_manually_triggered=True) default_args = {"start_date": datetime(2019, 7, 10)} dag = DAG( @@ -661,7 +669,7 @@ def test_dag_roundtrip_from_timetable(self, timetable): roundtripped = SerializedDAG.from_json(SerializedDAG.to_json(dag)) self.validate_deserialized_dag(roundtripped, dag) - def validate_deserialized_dag(self, serialized_dag: DAG, dag: DAG): + def validate_deserialized_dag(self, serialized_dag: SerializedDAG, dag: DAG): """ Verify that all example DAGs work with DAG Serialization by checking fields between Serialized Dags & non-Serialized Dags @@ -1347,6 +1355,7 @@ def test_dag_serialized_fields_with_schema(self): # The parameters we add manually in Serialization need to be ignored ignored_keys: set = { + "_processor_dags_folder", "tasks", "has_on_success_callback", "has_on_failure_callback", @@ -3026,6 +3035,9 @@ def test_handle_v1_serdag(): "downstream_task_ids": [], }, "is_paused_upon_creation": False, + "max_active_runs": 16, + "max_active_tasks": 16, + "max_consecutive_failed_dag_runs": 0, "_dag_id": "simple_dag", "deadline": None, "doc_md": "### DAG Tutorial Documentation", diff --git a/airflow-core/tests/unit/serialization/test_serialized_objects.py b/airflow-core/tests/unit/serialization/test_serialized_objects.py index 6b6fa9a52602e..a5cc17dc3ed16 100644 --- a/airflow-core/tests/unit/serialization/test_serialized_objects.py +++ b/airflow-core/tests/unit/serialization/test_serialized_objects.py @@ -76,7 +76,6 @@ SerializedDAG, create_scheduler_operator, ) -from airflow.timetables.base import DataInterval from airflow.triggers.base import BaseTrigger from airflow.utils.db import LazySelectSequence from airflow.utils.state import DagRunState, State @@ -600,90 +599,6 @@ def map_me_but_slowly(a): assert lazy_serialized_dag.has_task_concurrency_limits -@pytest.mark.db_test -@pytest.mark.parametrize( - "create_dag_run_kwargs", - ( - {}, - { - "data_interval": None, - "logical_date": pendulum.DateTime(2016, 1, 1, 0, 0, 0, tzinfo=Timezone("UTC")), - }, - {"data_interval": None, "logical_date": None}, - ), - ids=["post-AIP-39", "pre-AIP-39-should-infer", "pre-AIP-39"], -) -def test_serialized_dag_get_run_data_interval(create_dag_run_kwargs, dag_maker, session): - """Test whether LazyDeserializedDAG can correctly get dag run data_interval - - post-AIP-39: the dag run itself contains both data_interval start and data_interval end, and thus can - be retrieved directly - pre-AIP-39-should-infer: the dag run itself has neither data_interval_start nor data_interval_end, - and thus needs to infer the data_interval from its timetable - pre-AIP-39: the dag run itself has neither data_interval_start nor 
data_interval_end, and its logical_date - is none. it should return data_interval as none - """ - with dag_maker(dag_id="test_dag", session=session, serialized=False) as dag: - BaseOperator(task_id="test_task") - session.commit() - - dr = dag_maker.create_dagrun(**create_dag_run_kwargs) - ser_dict = SerializedDAG.to_dict(dag) - deser_dag = LazyDeserializedDAG(data=ser_dict) - if "logical_date" in create_dag_run_kwargs and create_dag_run_kwargs["logical_date"] is None: - data_interval = deser_dag.get_run_data_interval(dr) - assert data_interval is None - else: - data_interval = deser_dag.get_run_data_interval(dr) - assert data_interval == DataInterval( - start=pendulum.DateTime(2015, 12, 31, 0, 0, 0, tzinfo=Timezone("UTC")), - end=pendulum.DateTime(2016, 1, 1, 0, 0, 0, tzinfo=Timezone("UTC")), - ) - - -def test_get_task_assets(): - asset1 = Asset("1") - with DAG("testdag") as source_dag: - a = BashOperator(task_id="a", outlets=[asset1], bash_command="echo u") - b = BashOperator(task_id="b", inlets=[asset1], bash_command="echo v") - c = BashOperator.partial(task_id="c", inlets=[asset1]).expand(bash_command=["echo w", "echo x"]) - d = BashOperator.partial(task_id="d", outlets=[asset1]).expand(bash_command=["echo y", "echo z"]) - a >> b >> c >> d - - deser_dag = LazyDeserializedDAG(data=SerializedDAG.to_dict(source_dag)) - assert sorted(deser_dag.get_task_assets()) == [ - ("a", asset1), - ("b", asset1), - ("c", asset1), - ("d", asset1), - ] - - -def test_lazy_dag_run_interval_wrong_dag(): - lazy = LazyDeserializedDAG(data={"dag": {"dag_id": "dag1"}}) - - with pytest.raises(ValueError, match="different DAGs"): - lazy.get_run_data_interval(DAG_RUN) - - -def test_lazy_dag_run_interval_missing_interval(): - lazy = LazyDeserializedDAG(data={"dag": {"dag_id": "test_dag_id"}}) - - with pytest.raises(ValueError, match="Unsure how to deserialize version ''"): - lazy.get_run_data_interval(DAG_RUN) - - -def test_lazy_dag_run_interval_success(): - run = DAG_RUN - run.data_interval_start = datetime(2025, 1, 1) - run.data_interval_end = datetime(2025, 1, 2) - - lazy = LazyDeserializedDAG(data={"dag": {"dag_id": "test_dag_id"}}) - interval = lazy.get_run_data_interval(run) - - assert isinstance(interval, DataInterval) - - def test_hash_property(): from airflow.models.serialized_dag import SerializedDagModel diff --git a/airflow-core/tests/unit/ti_deps/deps/test_dag_unpaused_dep.py b/airflow-core/tests/unit/ti_deps/deps/test_dag_unpaused_dep.py index bf1a8bdf8945e..700adcd7803ae 100644 --- a/airflow-core/tests/unit/ti_deps/deps/test_dag_unpaused_dep.py +++ b/airflow-core/tests/unit/ti_deps/deps/test_dag_unpaused_dep.py @@ -18,7 +18,6 @@ from __future__ import annotations from unittest import mock -from unittest.mock import Mock import pytest @@ -28,23 +27,24 @@ pytestmark = pytest.mark.db_test +@mock.patch.object(DagUnpausedDep, "_is_dag_paused") class TestDagUnpausedDep: - def test_concurrency_reached(self): + def test_concurrency_reached(self, mock_is_dag_paused): """ Test paused DAG should fail dependency """ - dag = Mock(**{"get_is_paused.return_value": True}) - task = Mock(dag=dag) + mock_is_dag_paused.return_value = True + task = mock.Mock() ti = TaskInstance(task=task, dag_version_id=mock.MagicMock()) assert not DagUnpausedDep().is_met(ti=ti) - def test_all_conditions_met(self): + def test_all_conditions_met(self, mock_is_dag_paused): """ Test all conditions met should pass dep """ - dag = Mock(**{"get_is_paused.return_value": False}) - task = Mock(dag=dag) + mock_is_dag_paused.return_value = False + 
task = mock.Mock() ti = TaskInstance(task=task, dag_version_id=mock.MagicMock()) assert DagUnpausedDep().is_met(ti=ti) diff --git a/airflow-core/tests/unit/ti_deps/deps/test_prev_dagrun_dep.py b/airflow-core/tests/unit/ti_deps/deps/test_prev_dagrun_dep.py index 915ad54aad8f6..59a2274c5feb8 100644 --- a/airflow-core/tests/unit/ti_deps/deps/test_prev_dagrun_dep.py +++ b/airflow-core/tests/unit/ti_deps/deps/test_prev_dagrun_dep.py @@ -23,14 +23,13 @@ import pytest from airflow._shared.timezones.timezone import convert_to_utc, datetime -from airflow.models.baseoperator import BaseOperator -from airflow.models.dag import DAG -from airflow.models.serialized_dag import SerializedDagModel +from airflow.sdk import DAG, BaseOperator from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep from airflow.utils.state import DagRunState, TaskInstanceState from airflow.utils.types import DagRunTriggeredByType, DagRunType +from tests_common.test_utils.dag import create_scheduler_dag, sync_dag_to_db from tests_common.test_utils.db import clear_db_runs pytestmark = pytest.mark.db_test @@ -55,10 +54,9 @@ def test_first_task_run_of_new_task(self, testing_dag_bundle): start_date=START_DATE, wait_for_downstream=False, ) - DAG.bulk_write_to_db("testing", None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name="testing") + scheduler_dag = sync_dag_to_db(dag) # Old DAG run will include only TaskInstance of old_task - dag.create_dagrun( + scheduler_dag.create_dagrun( run_id="old_run", state=TaskInstanceState.SUCCESS, logical_date=old_task.start_date, @@ -78,7 +76,7 @@ def test_first_task_run_of_new_task(self, testing_dag_bundle): # New DAG run will include 1st TaskInstance of new_task logical_date = convert_to_utc(datetime(2016, 1, 2)) - dr = dag.create_dagrun( + dr = create_scheduler_dag(dag).create_dagrun( run_id="new_run", state=DagRunState.RUNNING, logical_date=logical_date, diff --git a/airflow-core/tests/unit/utils/test_db_cleanup.py b/airflow-core/tests/unit/utils/test_db_cleanup.py index 27c82057bc822..fb1077e0136d3 100644 --- a/airflow-core/tests/unit/utils/test_db_cleanup.py +++ b/airflow-core/tests/unit/utils/test_db_cleanup.py @@ -36,7 +36,7 @@ from airflow.models import DagModel, DagRun, TaskInstance from airflow.models.dag_version import DagVersion from airflow.models.dagbundle import DagBundleModel -from airflow.models.serialized_dag import SerializedDagModel +from airflow.models.serialized_dag import LazyDeserializedDAG, SerializedDagModel from airflow.providers.standard.operators.python import PythonOperator from airflow.utils.db_cleanup import ( ARCHIVE_TABLE_PREFIX, @@ -684,7 +684,7 @@ def create_tis(base_date, num_tis, run_type=DagRunType.SCHEDULED): dag = DAG(dag_id=dag_id) dm = DagModel(dag_id=dag_id, bundle_name=bundle_name) session.add(dm) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + SerializedDagModel.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name=bundle_name) dag_version = DagVersion.get_latest_version(dag.dag_id) for num in range(num_tis): start_date = base_date.add(days=num) diff --git a/airflow-core/tests/unit/utils/test_sqlalchemy.py b/airflow-core/tests/unit/utils/test_sqlalchemy.py index 212c63571db68..051c91f05b1d7 100644 --- a/airflow-core/tests/unit/utils/test_sqlalchemy.py +++ b/airflow-core/tests/unit/utils/test_sqlalchemy.py @@ -29,9 +29,8 @@ from sqlalchemy.exc import StatementError from airflow import settings -from airflow._shared.timezones.timezone import utcnow -from airflow.models.dag 
import DAG -from airflow.models.serialized_dag import SerializedDagModel +from airflow.sdk import DAG +from airflow.sdk.timezone import utcnow from airflow.serialization.enums import DagAttributeTypes, Encoding from airflow.serialization.serialized_objects import BaseSerialization from airflow.settings import Session @@ -45,6 +44,8 @@ from airflow.utils.state import State from airflow.utils.types import DagRunTriggeredByType, DagRunType +from tests_common.test_utils.dag import sync_dag_to_db + pytestmark = pytest.mark.db_test @@ -72,10 +73,8 @@ def test_utc_transformations(self, testing_dag_bundle): iso_date = start_date.isoformat() logical_date = start_date + datetime.timedelta(hours=1, days=1) - dag = DAG(dag_id=dag_id, schedule=datetime.timedelta(days=1), start_date=start_date) + dag = sync_dag_to_db(DAG(dag_id=dag_id, schedule=datetime.timedelta(days=1), start_date=start_date)) dag.clear() - DAG.bulk_write_to_db("testing", None, [dag], session=self.session) - SerializedDagModel.write_dag(dag, bundle_name="testing", session=self.session) run = dag.create_dagrun( run_id=iso_date, run_type=DagRunType.MANUAL, @@ -107,7 +106,7 @@ def test_process_bind_param_naive(self): # naive start_date = datetime.datetime.now() - dag = DAG(dag_id=dag_id, start_date=start_date, schedule=datetime.timedelta(days=1)) + dag = sync_dag_to_db(DAG(dag_id=dag_id, start_date=start_date, schedule=datetime.timedelta(days=1))) dag.clear() with pytest.raises((ValueError, StatementError)): diff --git a/airflow-core/tests/unit/utils/test_state.py b/airflow-core/tests/unit/utils/test_state.py index 8bc27a1dbdb84..463f943320462 100644 --- a/airflow-core/tests/unit/utils/test_state.py +++ b/airflow-core/tests/unit/utils/test_state.py @@ -20,13 +20,13 @@ import pytest -from airflow.models.dag import DAG from airflow.models.dagrun import DagRun -from airflow.models.serialized_dag import SerializedDagModel +from airflow.sdk import DAG from airflow.utils.session import create_session from airflow.utils.state import DagRunState, IntermediateTIState, State, TaskInstanceState, TerminalTIState from airflow.utils.types import DagRunTriggeredByType, DagRunType +from tests_common.test_utils.dag import sync_dag_to_db from unit.models import DEFAULT_DATE pytestmark = pytest.mark.db_test @@ -38,9 +38,10 @@ def test_dagrun_state_enum_escape(testing_dag_bundle): referenced in DB query """ with create_session() as session: - dag = DAG(dag_id="test_dagrun_state_enum_escape", schedule=timedelta(days=1), start_date=DEFAULT_DATE) - DAG.bulk_write_to_db("testing", None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name="testing") + dag = sync_dag_to_db( + DAG(dag_id="test_dagrun_state_enum_escape", schedule=timedelta(days=1), start_date=DEFAULT_DATE), + session=session, + ) dag.create_dagrun( run_id=dag.timetable.generate_run_id( run_type=DagRunType.SCHEDULED, diff --git a/dev/airflow_perf/dags/elastic_dag.py b/dev/airflow_perf/dags/elastic_dag.py index d8959c1c76c68..502c1fc743d7a 100644 --- a/dev/airflow_perf/dags/elastic_dag.py +++ b/dev/airflow_perf/dags/elastic_dag.py @@ -22,9 +22,8 @@ from datetime import datetime, timedelta from enum import Enum -from airflow.models.dag import DAG from airflow.providers.standard.operators.bash import BashOperator -from airflow.sdk import chain +from airflow.sdk import DAG, chain # DAG File used in performance tests. Its shape can be configured by environment variables. 
RE_TIME_DELTA = re.compile( diff --git a/dev/airflow_perf/dags/perf_dag_1.py b/dev/airflow_perf/dags/perf_dag_1.py index f54b763e13f30..6f974dd5c9e9d 100644 --- a/dev/airflow_perf/dags/perf_dag_1.py +++ b/dev/airflow_perf/dags/perf_dag_1.py @@ -23,8 +23,8 @@ import datetime -from airflow.models.dag import DAG from airflow.operators.bash_operator import BashOperator +from airflow.sdk import DAG args = { "owner": "airflow", diff --git a/dev/airflow_perf/dags/perf_dag_2.py b/dev/airflow_perf/dags/perf_dag_2.py index 592bbe6087838..c6bd4b0151c4d 100644 --- a/dev/airflow_perf/dags/perf_dag_2.py +++ b/dev/airflow_perf/dags/perf_dag_2.py @@ -23,8 +23,8 @@ import datetime -from airflow.models.dag import DAG from airflow.providers.standard.operators.bash import BashOperator +from airflow.sdk import DAG args = { "owner": "airflow", diff --git a/dev/airflow_perf/dags/sql_perf_dag.py b/dev/airflow_perf/dags/sql_perf_dag.py index 2081eb482ae01..0e62bcf19ba3d 100644 --- a/dev/airflow_perf/dags/sql_perf_dag.py +++ b/dev/airflow_perf/dags/sql_perf_dag.py @@ -18,8 +18,8 @@ from datetime import datetime, timedelta -from airflow.models.dag import DAG from airflow.providers.standard.operators.python import PythonOperator +from airflow.sdk import DAG default_args = { "owner": "Airflow", diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index 4e6263491b3d0..af87248da467a 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -30,7 +30,7 @@ from contextlib import ExitStack, suppress from datetime import datetime, timedelta, timezone from pathlib import Path -from typing import TYPE_CHECKING, Any, Protocol, TypeVar +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar from unittest import mock import pytest @@ -42,17 +42,17 @@ from itsdangerous import URLSafeSerializer from sqlalchemy.orm import Session - from airflow.models.baseoperator import BaseOperator - from airflow.models.dag import DAG, ScheduleArg from airflow.models.dagrun import DagRun, DagRunType + from airflow.models.mappedoperator import MappedOperator from airflow.models.taskinstance import TaskInstance from airflow.providers.standard.operators.empty import EmptyOperator - from airflow.sdk import Context, TriggerRule + from airflow.sdk import DAG, BaseOperator, Context, TriggerRule from airflow.sdk.api.datamodels._generated import TaskInstanceState as TIState - from airflow.sdk.bases.operator import BaseOperator as TaskSDKBaseOperator + from airflow.sdk.definitions.dag import ScheduleArg from airflow.sdk.execution_time.comms import StartupDetails, ToSupervisor from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance from airflow.sdk.types import DagRunProtocol + from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG from airflow.timetables.base import DataInterval from airflow.typing_compat import Self from airflow.utils.state import DagRunState, TaskInstanceState @@ -60,7 +60,8 @@ from tests_common._internals.capture_warnings import CaptureWarningsPlugin # noqa: F401 from tests_common._internals.forbidden_warnings import ForbiddenWarningsPlugin # noqa: F401 - Op = TypeVar("Op", bound=BaseOperator) +Dag = TypeVar("Dag", "DAG", "SerializedDAG", covariant=True) +Op = TypeVar("Op", bound="BaseOperator") # NOTE: DO NOT IMPORT AIRFLOW THINGS HERE! 
# @@ -776,7 +777,7 @@ def fake_sleep(seconds): traveller.stop() -class DagMaker(Protocol): +class DagMaker(Generic[Dag], Protocol): """ Interface definition for dag_maker return value. @@ -785,8 +786,9 @@ class DagMaker(Protocol): """ session: Session + dag: DAG - def __enter__(self) -> DAG: ... + def __enter__(self) -> Dag: ... def __exit__(self, type, value, traceback) -> None: ... @@ -949,43 +951,56 @@ def __exit__(self, type, value, traceback): if type is not None: return - if AIRFLOW_V_3_0_PLUS: - dag.bulk_write_to_db(self.bundle_name, self.bundle_version, [dag], session=self.session) - else: - dag.sync_to_db(session=self.session) - if dag.access_control and "FabAuthManager" in conf.get("core", "auth_manager"): if AIRFLOW_V_3_0_PLUS: from airflow.providers.fab.www.security_appless import ApplessAirflowSecurityManager else: - from airflow.www.security_appless import ApplessAirflowSecurityManager + from airflow.www.security_appless import ApplessAirflowSecurityManager # type: ignore security_manager = ApplessAirflowSecurityManager(session=self.session) security_manager.sync_perm_for_dag(dag.dag_id, dag.access_control) - self.dag_model = self.session.get(DagModel, dag.dag_id) - self._make_serdag(dag) - self._bag_dag_compat(self.dag) + self.dag_model = self.session.get(DagModel, dag.dag_id) + self.session.commit() - def _make_serdag(self, dag): + def _make_serdag(self, dag: DAG): from sqlalchemy import select from airflow.models.serialized_dag import SerializedDagModel - self.serialized_model = SerializedDagModel(dag) + if AIRFLOW_V_3_1_PLUS: + from airflow.serialization.serialized_objects import LazyDeserializedDAG + + self.serialized_model = SerializedDagModel(LazyDeserializedDAG.from_dag(dag)) + else: + self.serialized_model = SerializedDagModel(dag) # type: ignore[arg-type] + sdm = self.session.scalar( select(SerializedDagModel).where( SerializedDagModel.dag_id == dag.dag_id, SerializedDagModel.dag_hash == self.serialized_model.dag_hash, ) ) + + if AIRFLOW_V_3_0_PLUS: + from airflow.serialization.serialized_objects import SerializedDAG + + SerializedDAG.bulk_write_to_db( + self.bundle_name, + self.bundle_version, + [dag], + session=self.session, + ) + else: + dag.sync_to_db(session=self.session) # type: ignore[attr-defined] + if AIRFLOW_V_3_0_PLUS and self.serialized_model != sdm: from airflow.models.dag_version import DagVersion from airflow.models.dagcode import DagCode dagv = DagVersion.write_dag( dag_id=dag.dag_id, - bundle_name=self.dag_model.bundle_name, - bundle_version=self.dag_model.bundle_version, + bundle_name=self.bundle_name, + bundle_version=self.bundle_version, session=self.session, ) self.session.add(dagv) @@ -1000,9 +1015,8 @@ def _make_serdag(self, dag): sdm._data = self.serialized_model._data self.serialized_model = sdm else: - self.session.merge(self.serialized_model) - serialized_dag = self._serialized_dag() - self._bag_dag_compat(serialized_dag) + sdm = self.session.merge(self.serialized_model) + self._bag_dag_compat(dag) self.session.flush() def create_dagrun(self, *, logical_date=NOTSET, **kwargs): @@ -1025,7 +1039,7 @@ def create_dagrun(self, *, logical_date=NOTSET, **kwargs): ) logical_date = kwargs.pop("execution_date") - dag = self.dag + dag = self._serialized_dag() kwargs = { "state": DagRunState.RUNNING, "start_date": self.start_date, @@ -1054,6 +1068,10 @@ def create_dagrun(self, *, logical_date=NOTSET, **kwargs): if logical_date is not None: if run_type == DagRunType.MANUAL: data_interval = 
dag.timetable.infer_manual_data_interval(run_after=logical_date) + elif AIRFLOW_V_3_1_PLUS: + from airflow.models.dag import infer_automated_data_interval + + data_interval = infer_automated_data_interval(dag.timetable, logical_date) else: data_interval = dag.infer_automated_data_interval(logical_date) kwargs["data_interval"] = data_interval @@ -1084,18 +1102,21 @@ def create_dagrun(self, *, logical_date=NOTSET, **kwargs): kwargs.pop("triggered_by", None) kwargs["execution_date"] = logical_date - if self.want_serialized: - dag = self.serialized_model.dag self.dag_run = dag.create_dagrun(**kwargs) for ti in self.dag_run.task_instances: # This need to always operate on the _real_ dag ti.refresh_from_task(self.dag.get_task(ti.task_id)) - if self.want_serialized: - self.session.commit() + self.session.commit() return self.dag_run def create_dagrun_after(self, dagrun, **kwargs): - next_info = self.dag.next_dagrun_info(self.dag.get_run_data_interval(dagrun)) + sdag = self._serialized_dag() + if AIRFLOW_V_3_1_PLUS: + from airflow.models.dag import get_run_data_interval + + next_info = sdag.next_dagrun_info(get_run_data_interval(sdag.timetable, dagrun)) + else: + next_info = sdag.next_dagrun_info(sdag.get_run_data_interval(dagrun)) if next_info is None: raise ValueError(f"cannot create run after {dagrun}") return self.create_dagrun( @@ -1169,7 +1190,14 @@ def __call__( **kwargs, ): from airflow import settings - from airflow.models.dag import DAG + + # Don't change this to AIRFLOW_V_3_0_PLUS. Although SDK DAG exists + # before 3.1, things in dag maker setup can't handle it in compat + # tests. They are probably fixable, but it's not worthwhile to. + if AIRFLOW_V_3_1_PLUS: + from airflow.sdk import DAG + else: + from airflow import DAG timezone = _import_timezone() @@ -1617,13 +1645,18 @@ def _get(dag_id: str): return if AIRFLOW_V_3_0_PLUS: - session = settings.Session() + from sqlalchemy import func, select + from airflow.models.dagbundle import DagBundleModel + from airflow.serialization.serialized_objects import SerializedDAG - if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0: + session = settings.Session() + if not session.scalar(select(func.count()).where(DagBundleModel.name == "testing")): session.add(DagBundleModel(name="testing")) - session.commit() - dag.bulk_write_to_db("testing", None, [dag]) + session.flush() + SerializedDAG.bulk_write_to_db("testing", None, [dag], session=session) + session.commit() + session.close() else: dag.sync_to_db() SerializedDagModel.write_dag(dag, bundle_name="testing") @@ -2159,8 +2192,8 @@ def mocked_parse(spy_agency): ) """ - def set_dag(what: StartupDetails, dag_id: str, task: TaskSDKBaseOperator) -> RuntimeTaskInstance: - from airflow.sdk.definitions.dag import DAG + def set_dag(what: StartupDetails, dag_id: str, task: BaseOperator) -> RuntimeTaskInstance: + from airflow.sdk import DAG from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, parse timezone = _import_timezone() @@ -2238,7 +2271,7 @@ def context(self) -> Context: ... 
def __call__( self, - task: TaskSDKBaseOperator, + task: BaseOperator, dag_id: str = ..., run_id: str = ..., logical_date: datetime | None = None, @@ -2274,18 +2307,19 @@ def execute(self, context): """ from uuid6 import uuid7 + from airflow.sdk import DAG from airflow.sdk.api.datamodels._generated import TaskInstance - from airflow.sdk.definitions.dag import DAG from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails + from airflow.serialization.serialized_objects import SerializedDAG from airflow.timetables.base import TimeRestriction timezone = _import_timezone() def _create_task_instance( - task: BaseOperator, + task: MappedOperator | SerializedBaseOperator, dag_id: str = "test_dag", run_id: str = "test_run", - logical_date: str | datetime = "2024-12-01T01:00:00Z", + logical_date: str | datetime | None = "2024-12-01T01:00:00Z", start_date: str | datetime = "2024-12-01T01:00:00Z", run_type: str = "manual", try_number: int = 1, @@ -2300,23 +2334,46 @@ def _create_task_instance( from airflow.sdk.api.datamodels._generated import DagRun, DagRunState, TIRunContext from airflow.utils.types import DagRunType + if isinstance(logical_date, str): + logical_date = timezone.parse(logical_date) + else: + logical_date = timezone.coerce_datetime(logical_date) + if isinstance(start_date, str): + start_date = timezone.parse(start_date) + else: + start_date = timezone.coerce_datetime(start_date) + + if TYPE_CHECKING: + from pendulum import DateTime + + assert logical_date is None or isinstance(logical_date, DateTime) + assert isinstance(start_date, DateTime) + if not ti_id: ti_id = uuid7() if not task.has_dag(): - dag = DAG(dag_id=dag_id, start_date=timezone.datetime(2024, 12, 3)) + dag = SerializedDAG.deserialize_dag( + SerializedDAG.serialize_dag(DAG(dag_id=dag_id, start_date=timezone.datetime(2024, 12, 3))) + ) # Fixture only helps in regular base operator tasks, so mypy is wrong here task.dag = dag - task = dag.task_dict[task.task_id] + # TODO (GH-52141): Scheduler DAG should contain scheduler tasks, but + # currently this inherits from SDK DAG. 
+ task = dag.task_dict[task.task_id] # type: ignore[assignment] + + if TYPE_CHECKING: + assert task.dag is not None data_interval_start = None data_interval_end = None if task.dag.timetable: if run_type == DagRunType.MANUAL: - data_interval_start, data_interval_end = task.dag.timetable.infer_manual_data_interval( - run_after=logical_date - ) + if logical_date is not None: + data_interval_start, data_interval_end = task.dag.timetable.infer_manual_data_interval( + run_after=logical_date, + ) else: drinfo = task.dag.timetable.next_dagrun_info( last_automated_data_interval=None, @@ -2543,7 +2600,7 @@ def context(self) -> Context: def __call__( self, - task: TaskSDKBaseOperator, + task: BaseOperator, dag_id: str = "test_dag", run_id: str = "test_run", logical_date: datetime | None = None, @@ -2636,11 +2693,11 @@ def _create_conn(connection, session=None): def _import_timezone(): try: - from airflow.sdk._shared.timezones import timezone - except ModuleNotFoundError: + from airflow.sdk import timezone + except ImportError: try: from airflow._shared.timezones import timezone - except ModuleNotFoundError: + except ImportError: from airflow.utils import timezone return timezone @@ -2648,7 +2705,12 @@ def _import_timezone(): @pytest.fixture def create_dag_without_db(): def create_dag(dag_id: str): - from airflow.models.dag import DAG + from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS + + if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import DAG + else: + from airflow import DAG return DAG(dag_id=dag_id, schedule=None, render_template_as_native_obj=True) diff --git a/devel-common/src/tests_common/test_utils/dag.py b/devel-common/src/tests_common/test_utils/dag.py new file mode 100644 index 0000000000000..8d5cd1da069ca --- /dev/null +++ b/devel-common/src/tests_common/test_utils/dag.py @@ -0,0 +1,78 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +from collections.abc import Collection, Sequence +from typing import TYPE_CHECKING + +from airflow.utils.session import NEW_SESSION, provide_session + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + + from airflow.sdk import DAG + from airflow.serialization.serialized_objects import SerializedDAG + + +def create_scheduler_dag(dag: DAG | SerializedDAG) -> SerializedDAG: + from airflow.serialization.serialized_objects import SerializedDAG + + if isinstance(dag, SerializedDAG): + return dag + return SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag)) + + +@provide_session +def sync_dag_to_db( + dag: DAG, + bundle_name: str = "testing", + session: Session = NEW_SESSION, +) -> SerializedDAG: + return sync_dags_to_db([dag], bundle_name=bundle_name, session=session)[0] + + +@provide_session +def sync_dags_to_db( + dags: Collection[DAG], + bundle_name: str = "testing", + session: Session = NEW_SESSION, +) -> Sequence[SerializedDAG]: + """ + Sync dags into the database. + + This serializes dags and saves the results to the database. The serialized + (scheduler-oriented) dags are returned. If the input is ordered (e.g. a list), + the returned sequence is guaranteed to be in the same order. + """ + from airflow.models.dagbundle import DagBundleModel + from airflow.models.serialized_dag import SerializedDagModel + from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG + + session.merge(DagBundleModel(name=bundle_name)) + session.flush() + + def _write_dag(dag: DAG) -> SerializedDAG: + data = SerializedDAG.to_dict(dag) + SerializedDagModel.write_dag(LazyDeserializedDAG(data=data), bundle_name, session=session) + return SerializedDAG.from_dict(data) + + SerializedDAG.bulk_write_to_db(bundle_name, None, dags, session=session) + scheduler_dags = [_write_dag(dag) for dag in dags] + session.flush() + return scheduler_dags diff --git a/devel-common/src/tests_common/test_utils/db.py b/devel-common/src/tests_common/test_utils/db.py index 28764aeb90d4f..4c40cc74d0037 100644 --- a/devel-common/src/tests_common/test_utils/db.py +++ b/devel-common/src/tests_common/test_utils/db.py @@ -22,6 +22,8 @@ from tempfile import gettempdir from typing import TYPE_CHECKING +from sqlalchemy import select + from airflow.configuration import conf from airflow.jobs.job import Job from airflow.models import ( @@ -72,8 +74,22 @@ from airflow.models.dag_favorite import DagFavorite +def _deactivate_unknown_dags(active_dag_ids, session): + """ + Given a list of known DAGs, deactivate any other DAGs that are marked as active in the ORM.
+ + :param active_dag_ids: list of DAG IDs that are active + :return: None + """ + if not active_dag_ids: + return + for dag in session.scalars(select(DagModel).where(~DagModel.dag_id.in_(active_dag_ids))): + dag.is_stale = True + session.merge(dag) + session.commit() + + def _bootstrap_dagbag(): - from airflow.models.dag import DAG from airflow.models.dagbag import DagBag if AIRFLOW_V_3_0_PLUS: @@ -100,7 +116,7 @@ def _bootstrap_dagbag(): dagbag.sync_to_db(session=session) # type: ignore[attr-defined] # Deactivate the unknown ones - DAG.deactivate_unknown_dags(dagbag.dags.keys(), session=session) + _deactivate_unknown_dags(dagbag.dags, session=session) def initial_db_init(): @@ -155,7 +171,7 @@ def parse_and_sync_to_db(folder: Path | str, include_examples: bool = False): with create_session() as session: if AIRFLOW_V_3_0_PLUS: DagBundlesManager().sync_bundles_to_db(session=session) - session.commit() + session.flush() dagbag = DagBag(dag_folder=folder, include_examples=include_examples) if AIRFLOW_V_3_1_PLUS: @@ -166,7 +182,6 @@ def parse_and_sync_to_db(folder: Path | str, include_examples: bool = False): dagbag.sync_to_db("dags-folder", None, session) # type: ignore[attr-defined] else: dagbag.sync_to_db(session=session) # type: ignore[attr-defined] - session.commit() return dagbag diff --git a/devel-common/src/tests_common/test_utils/perf/perf_kit/__init__.py b/devel-common/src/tests_common/test_utils/perf/perf_kit/__init__.py index 1b98a49ea411a..d2e72a822f54f 100644 --- a/devel-common/src/tests_common/test_utils/perf/perf_kit/__init__.py +++ b/devel-common/src/tests_common/test_utils/perf/perf_kit/__init__.py @@ -76,7 +76,7 @@ def test_bulk_write_to_db(self): dags = [DAG(f"dag-bulk-sync-{i}", start_date=DEFAULT_DATE, tags=["test-dag"]) for i in range(0, 4)] with assert_queries_count(3): - DAG.bulk_write_to_db(dags) + SerializedDAG.bulk_write_to_db(dags) You can add a code snippet before the method definition, and then perform only one test and count the queries in it. 
@@ -99,7 +99,7 @@ def test_bulk_write_to_db(self): dags = [DAG(f"dag-bulk-sync-{i}", start_date=DEFAULT_DATE, tags=["test-dag"]) for i in range(0, 4)] with assert_queries_count(3): - DAG.bulk_write_to_db(dags) + SerializedDAG.bulk_write_to_db(dags) To run the test, execute the command diff --git a/devel-common/src/tests_common/test_utils/system_tests.py b/devel-common/src/tests_common/test_utils/system_tests.py index 388b0b2377f5e..544b4f3bb6c5e 100644 --- a/devel-common/src/tests_common/test_utils/system_tests.py +++ b/devel-common/src/tests_common/test_utils/system_tests.py @@ -71,6 +71,7 @@ def test_run(): if AIRFLOW_V_3_0_PLUS: from airflow.models.dag import DagModel from airflow.models.serialized_dag import SerializedDagModel + from airflow.serialization.serialized_objects import LazyDeserializedDAG from airflow.settings import Session s = Session() @@ -78,7 +79,7 @@ def test_run(): d = DagModel(dag_id=dag.dag_id, bundle_name=bundle_name) s.add(d) s.commit() - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + SerializedDagModel.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name=bundle_name) dag_run = dag.test( use_executor=os.environ.get("_AIRFLOW__SYSTEM_TEST_USE_EXECUTOR") == "1", diff --git a/providers/amazon/tests/unit/amazon/aws/log/test_cloudwatch_task_handler.py b/providers/amazon/tests/unit/amazon/aws/log/test_cloudwatch_task_handler.py index 8b730f87dc854..2714b036e721a 100644 --- a/providers/amazon/tests/unit/amazon/aws/log/test_cloudwatch_task_handler.py +++ b/providers/amazon/tests/unit/amazon/aws/log/test_cloudwatch_task_handler.py @@ -33,7 +33,6 @@ from watchtower import CloudWatchLogHandler from airflow.models import DAG, DagRun, TaskInstance -from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook from airflow.providers.amazon.aws.log.cloudwatch_task_handler import ( CloudWatchRemoteLogIO, @@ -45,6 +44,7 @@ from airflow.utils.timezone import datetime from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.db import clear_db_dag_bundles, clear_db_dags, clear_db_runs from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS @@ -203,10 +203,7 @@ def setup(self, create_log_template, tmp_path_factory, session, testing_dag_bund self.dag = DAG(dag_id=dag_id, schedule=None, start_date=date) task = EmptyOperator(task_id=task_id, dag=self.dag) if AIRFLOW_V_3_0_PLUS: - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [self.dag]) - self.dag.sync_to_db() - SerializedDagModel.write_dag(self.dag, bundle_name=bundle_name) + sync_dag_to_db(self.dag) dag_run = DagRun( dag_id=self.dag.dag_id, logical_date=date, diff --git a/providers/amazon/tests/unit/amazon/aws/log/test_s3_task_handler.py b/providers/amazon/tests/unit/amazon/aws/log/test_s3_task_handler.py index cb0d32aa5578b..47fa354af535a 100644 --- a/providers/amazon/tests/unit/amazon/aws/log/test_s3_task_handler.py +++ b/providers/amazon/tests/unit/amazon/aws/log/test_s3_task_handler.py @@ -28,17 +28,21 @@ from moto import mock_aws from airflow.models import DAG, DagRun, TaskInstance -from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.aws.log.s3_task_handler import S3TaskHandler from airflow.providers.standard.operators.empty import EmptyOperator from airflow.utils.state import State, TaskInstanceState -from airflow.utils.timezone import 
datetime from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.db import clear_db_dag_bundles, clear_db_dags, clear_db_runs from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +try: + from airflow.sdk.timezone import datetime +except ImportError: + from airflow.utils.timezone import datetime # type: ignore[attr-defined,no-redef] + @pytest.fixture(autouse=True) def s3mock(): @@ -71,9 +75,7 @@ def setup_tests(self, create_log_template, tmp_path_factory, session, testing_da self.dag = DAG("dag_for_testing_s3_task_handler", schedule=None, start_date=date) task = EmptyOperator(task_id="task_for_testing_s3_log_handler", dag=self.dag) if AIRFLOW_V_3_0_PLUS: - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [self.dag]) - SerializedDagModel.write_dag(self.dag, bundle_name=bundle_name) + scheduler_dag = sync_dag_to_db(self.dag) dag_run = DagRun( dag_id=self.dag.dag_id, logical_date=date, @@ -81,6 +83,7 @@ def setup_tests(self, create_log_template, tmp_path_factory, session, testing_da run_type="manual", ) else: + scheduler_dag = self.dag dag_run = DagRun( dag_id=self.dag.dag_id, execution_date=date, @@ -107,7 +110,7 @@ def setup_tests(self, create_log_template, tmp_path_factory, session, testing_da self.conn.create_bucket(Bucket="bucket") yield - self.dag.clear() + scheduler_dag.clear() self.clear_db() if self.s3_task_handler.handler: @@ -204,9 +207,7 @@ def setup_tests(self, create_log_template, tmp_path_factory, session, testing_da self.dag = DAG("dag_for_testing_s3_task_handler", schedule=None, start_date=date) task = EmptyOperator(task_id="task_for_testing_s3_log_handler", dag=self.dag) if AIRFLOW_V_3_0_PLUS: - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [self.dag]) - SerializedDagModel.write_dag(self.dag, bundle_name=bundle_name) + scheduler_dag = sync_dag_to_db(self.dag) dag_run = DagRun( dag_id=self.dag.dag_id, logical_date=date, @@ -214,6 +215,7 @@ def setup_tests(self, create_log_template, tmp_path_factory, session, testing_da run_type="manual", ) else: + scheduler_dag = self.dag dag_run = DagRun( dag_id=self.dag.dag_id, execution_date=date, @@ -240,7 +242,7 @@ def setup_tests(self, create_log_template, tmp_path_factory, session, testing_da self.conn.create_bucket(Bucket="bucket") yield - self.dag.clear() + scheduler_dag.clear() self.clear_db() if self.s3_task_handler.handler: diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_athena.py b/providers/amazon/tests/unit/amazon/aws/operators/test_athena.py index 721db41d036bc..1273b26f873aa 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_athena.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_athena.py @@ -21,10 +21,10 @@ from unittest import mock import pytest +from moto import mock_aws from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import DAG, DagRun, TaskInstance -from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.amazon.aws.hooks.athena import AthenaHook from airflow.providers.amazon.aws.operators.athena import AthenaOperator from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger @@ -38,19 +38,18 @@ SymlinksDatasetFacet, ) from airflow.providers.openlineage.extractors import OperatorLineage - -try: - from airflow.sdk import timezone -except ImportError: - from airflow.utils import timezone # type: ignore[attr-defined,no-redef] -from moto import mock_aws - from 
airflow.utils.state import DagRunState from airflow.utils.types import DagRunType +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS from unit.amazon.aws.utils.test_template_fields import validate_template_fields +try: + from airflow.sdk import timezone +except ImportError: + from airflow.utils import timezone # type: ignore[attr-defined,no-redef] + TEST_DAG_ID = "unit_tests" DEFAULT_DATE = timezone.datetime(2018, 1, 1) ATHENA_QUERY_ID = "eac29bf8-daa1-4ffc-b19a-0db31dc3b784" @@ -250,9 +249,7 @@ def test_return_value( if AIRFLOW_V_3_0_PLUS: from airflow.models.dag_version import DagVersion - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [self.dag]) - SerializedDagModel.write_dag(self.dag, bundle_name=bundle_name) + sync_dag_to_db(self.dag) dag_version = DagVersion.get_latest_version(self.dag.dag_id) ti = TaskInstance(task=self.athena, dag_version_id=dag_version.id) dag_run = DagRun( diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_datasync.py b/providers/amazon/tests/unit/amazon/aws/operators/test_datasync.py index 544204f6d3a4e..da938ceae35e3 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_datasync.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_datasync.py @@ -24,21 +24,21 @@ from airflow.exceptions import AirflowException from airflow.models import DAG, DagRun, TaskInstance -from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.amazon.aws.hooks.datasync import DataSyncHook from airflow.providers.amazon.aws.links.datasync import DataSyncTaskLink from airflow.providers.amazon.aws.operators.datasync import DataSyncOperator - -try: - from airflow.sdk import timezone -except ImportError: - from airflow.utils import timezone # type: ignore[attr-defined,no-redef] from airflow.utils.state import DagRunState from airflow.utils.types import DagRunType +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS from unit.amazon.aws.utils.test_template_fields import validate_template_fields +try: + from airflow.sdk import timezone +except ImportError: + from airflow.utils import timezone # type: ignore[attr-defined,no-redef] + TEST_DAG_ID = "unit_tests" DEFAULT_DATE = timezone.datetime(2018, 1, 1) @@ -361,11 +361,9 @@ def test_return_value( self.set_up_operator() if AIRFLOW_V_3_0_PLUS: - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [self.dag]) - SerializedDagModel.write_dag(self.dag, bundle_name=bundle_name) from airflow.models.dag_version import DagVersion + sync_dag_to_db(self.dag) dag_version = DagVersion.get_latest_version(self.dag.dag_id) dag_run = DagRun( dag_id=self.dag.dag_id, @@ -584,11 +582,9 @@ def test_return_value( self.set_up_operator() if AIRFLOW_V_3_0_PLUS: - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [self.dag]) - SerializedDagModel.write_dag(self.dag, bundle_name=bundle_name) from airflow.models.dag_version import DagVersion + sync_dag_to_db(self.dag) dag_version = DagVersion.get_latest_version(self.dag.dag_id) ti = TaskInstance(task=self.datasync, dag_version_id=dag_version.id) dag_run = DagRun( @@ -709,11 +705,9 @@ def test_return_value( self.set_up_operator() if AIRFLOW_V_3_0_PLUS: - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [self.dag]) - SerializedDagModel.write_dag(self.dag, bundle_name=bundle_name) from airflow.models.dag_version import DagVersion + 
sync_dag_to_db(self.dag) dag_version = DagVersion.get_latest_version(self.dag.dag_id) ti = TaskInstance(task=self.datasync, dag_version_id=dag_version.id) dag_run = DagRun( @@ -927,11 +921,9 @@ def test_return_value( self.set_up_operator() if AIRFLOW_V_3_0_PLUS: - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [self.dag]) - SerializedDagModel.write_dag(self.dag, bundle_name=bundle_name) from airflow.models.dag_version import DagVersion + sync_dag_to_db(self.dag) dag_version = DagVersion.get_latest_version(self.dag.dag_id) ti = TaskInstance(task=self.datasync, dag_version_id=dag_version.id) dag_run = DagRun( @@ -1048,11 +1040,9 @@ def test_return_value( self.set_up_operator() if AIRFLOW_V_3_0_PLUS: - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [self.dag]) - SerializedDagModel.write_dag(self.dag, bundle_name=bundle_name) from airflow.models.dag_version import DagVersion + sync_dag_to_db(self.dag) dag_version = DagVersion.get_latest_version(self.dag.dag_id) ti = TaskInstance(task=self.datasync, dag_version_id=dag_version.id) dag_run = DagRun( diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py b/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py index d5ae886de6a2d..658b02e188745 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py @@ -25,7 +25,6 @@ from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import DAG, DagRun, TaskInstance -from airflow.models.serialized_dag import SerializedDagModel from airflow.models.variable import Variable from airflow.providers.amazon.aws.hooks.dms import DmsHook from airflow.providers.amazon.aws.operators.dms import ( @@ -45,17 +44,18 @@ DmsReplicationDeprovisionedTrigger, DmsReplicationTerminalStatusTrigger, ) - -try: - from airflow.sdk import timezone -except ImportError: - from airflow.utils import timezone # type: ignore[attr-defined,no-redef] from airflow.utils.state import DagRunState from airflow.utils.types import DagRunType +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS from unit.amazon.aws.utils.test_template_fields import validate_template_fields +try: + from airflow.sdk import timezone +except ImportError: + from airflow.utils import timezone # type: ignore[attr-defined,no-redef] + if AIRFLOW_V_3_0_PLUS: from airflow.models.dag_version import DagVersion @@ -330,9 +330,7 @@ def test_describe_tasks_return_value( ) if AIRFLOW_V_3_0_PLUS: - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [self.dag]) - SerializedDagModel.write_dag(self.dag, bundle_name=bundle_name) + sync_dag_to_db(self.dag) dag_version = DagVersion.get_latest_version(self.dag.dag_id) ti = TaskInstance(task=describe_task, dag_version_id=dag_version.id) dag_run = DagRun( @@ -534,9 +532,7 @@ def test_template_fields_native( ) if AIRFLOW_V_3_0_PLUS: - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + sync_dag_to_db(dag) dag_version = DagVersion.get_latest_version(dag.dag_id) ti = TaskInstance(task=op, dag_version_id=dag_version.id) dag_run = DagRun( @@ -546,8 +542,7 @@ def test_template_fields_native( state=DagRunState.RUNNING, logical_date=logical_date, ) - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + sync_dag_to_db(dag) dag_version = 
DagVersion.get_latest_version(dag.dag_id) ti = TaskInstance(task=op, dag_version_id=dag_version.id) dag_run = DagRun( diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_add_steps.py b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_add_steps.py index 56d4d9f5cffbf..b4f571b39cc84 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_add_steps.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_add_steps.py @@ -27,7 +27,6 @@ from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import DAG, DagRun, TaskInstance -from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.amazon.aws.operators.emr import EmrAddStepsOperator from airflow.providers.amazon.aws.triggers.emr import EmrAddStepsTrigger @@ -38,6 +37,7 @@ from airflow.utils.state import DagRunState from airflow.utils.types import DagRunType +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS from unit.amazon.aws.utils.test_template_fields import validate_template_fields @@ -106,11 +106,9 @@ def test_validate_mutually_exclusive_args(self, job_flow_id, job_flow_name): @pytest.mark.db_test def test_render_template(self, session, clean_dags_dagruns_and_dagbundles, testing_dag_bundle): if AIRFLOW_V_3_0_PLUS: - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [self.operator.dag]) - SerializedDagModel.write_dag(self.operator.dag, bundle_name=bundle_name) from airflow.models.dag_version import DagVersion + sync_dag_to_db(self.operator.dag) dag_version = DagVersion.get_latest_version(self.operator.dag.dag_id) ti = TaskInstance(task=self.operator, dag_version_id=dag_version.id) dag_run = DagRun( @@ -182,11 +180,9 @@ def test_render_template_from_file( do_xcom_push=False, ) if AIRFLOW_V_3_0_PLUS: - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) from airflow.models.dag_version import DagVersion + sync_dag_to_db(dag) dag_version = DagVersion.get_latest_version(dag.dag_id) ti = TaskInstance(task=test_task, dag_version_id=dag_version.id) dag_run = DagRun( diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py index cf89cdbb78021..43624ef501dba 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py @@ -28,22 +28,22 @@ from airflow.exceptions import TaskDeferred from airflow.models import DAG, DagRun, TaskInstance -from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.amazon.aws.operators.emr import EmrCreateJobFlowOperator from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger from airflow.providers.amazon.aws.utils.waiter import WAITER_POLICY_NAME_MAPPING, WaitPolicy - -try: - from airflow.sdk import timezone -except ImportError: - from airflow.utils import timezone # type: ignore[attr-defined,no-redef] from airflow.utils.state import DagRunState from airflow.utils.types import DagRunType +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS from unit.amazon.aws.utils.test_template_fields import validate_template_fields from unit.amazon.aws.utils.test_waiter import assert_expected_waiter_type +try: + from 
airflow.sdk import timezone +except ImportError: + from airflow.utils import timezone # type: ignore[attr-defined,no-redef] + TASK_ID = "test_task" TEST_DAG_ID = "test_dag_id" @@ -107,9 +107,7 @@ def test_render_template(self, session, clean_dags_dagruns_and_dagbundles, testi if AIRFLOW_V_3_0_PLUS: from airflow.models.dag_version import DagVersion - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [self.operator.dag]) - SerializedDagModel.write_dag(self.operator.dag, bundle_name=bundle_name) + sync_dag_to_db(self.operator.dag) dag_version = DagVersion.get_latest_version(self.operator.dag.dag_id) ti = TaskInstance(task=self.operator, dag_version_id=dag_version.id) dag_run = DagRun( @@ -164,9 +162,7 @@ def test_render_template_from_file( if AIRFLOW_V_3_0_PLUS: from airflow.models.dag_version import DagVersion - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [self.operator.dag]) - SerializedDagModel.write_dag(self.operator.dag, bundle_name=bundle_name) + sync_dag_to_db(self.operator.dag) dag_version = DagVersion.get_latest_version(self.operator.dag.dag_id) ti = TaskInstance(task=self.operator, dag_version_id=dag_version.id) dag_run = DagRun( diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py b/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py index e02e8b4d4ab15..01404e0682454 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py @@ -33,7 +33,6 @@ from airflow import DAG from airflow.exceptions import AirflowException from airflow.models.dagrun import DagRun -from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.aws.operators.s3 import ( @@ -57,12 +56,17 @@ ) from airflow.providers.openlineage.extractors import OperatorLineage from airflow.utils.state import DagRunState -from airflow.utils.timezone import datetime, utcnow from airflow.utils.types import DagRunType -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.dag import sync_dag_to_db +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS from unit.amazon.aws.utils.test_template_fields import validate_template_fields +if AIRFLOW_V_3_1_PLUS: + from airflow.sdk.timezone import datetime, utcnow +else: + from airflow.utils.timezone import datetime, utcnow # type: ignore[attr-defined,no-redef] + BUCKET_NAME = os.environ.get("BUCKET_NAME", "test-airflow-bucket") S3_KEY = "test-airflow-key" TAG_SET = [{"Key": "Color", "Value": "Green"}] @@ -672,9 +676,7 @@ def test_dates_from_template(self, session, testing_dag_bundle): if AIRFLOW_V_3_0_PLUS: from airflow.models.dag_version import DagVersion - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + sync_dag_to_db(dag) dag_version = DagVersion.get_latest_version(dag.dag_id) ti = TaskInstance(task=op, dag_version_id=dag_version.id) else: diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_base.py b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_base.py index 11cac3602a1fd..887f8adf6b011 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_base.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_base.py @@ -26,22 +26,22 @@ from airflow.exceptions import 
AirflowException from airflow.models import DAG, DagRun, TaskInstance -from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.amazon.aws.operators.sagemaker import ( SageMakerBaseOperator, SageMakerCreateExperimentOperator, ) - -try: - from airflow.sdk import timezone -except ImportError: - from airflow.utils import timezone # type: ignore[attr-defined,no-redef] from airflow.utils.state import DagRunState from airflow.utils.types import DagRunType +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS from unit.amazon.aws.utils.test_template_fields import validate_template_fields +try: + from airflow.sdk import timezone +except ImportError: + from airflow.utils import timezone # type: ignore[attr-defined,no-redef] + CONFIG: dict = { "key1": "1", "key2": {"key3": "3", "key4": "4"}, @@ -218,9 +218,7 @@ def test_create_experiment( if AIRFLOW_V_3_0_PLUS: from airflow.models.dag_version import DagVersion - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + sync_dag_to_db(dag) dag_version = DagVersion.get_latest_version(dag.dag_id) ti = TaskInstance(task=op, dag_version_id=dag_version.id) dag_run = DagRun( diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py index afe449d4e5d9e..64fe824582aab 100644 --- a/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py +++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py @@ -26,19 +26,19 @@ from airflow.exceptions import AirflowException from airflow.models import DAG, DagRun, TaskInstance -from airflow.models.serialized_dag import SerializedDagModel from airflow.models.variable import Variable from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor, S3KeysUnchangedSensor +from airflow.utils.state import DagRunState +from airflow.utils.types import DagRunType + +from tests_common.test_utils.dag import sync_dag_to_db +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS try: from airflow.sdk import timezone except ImportError: from airflow.utils import timezone # type: ignore[attr-defined,no-redef] -from airflow.utils.state import DagRunState -from airflow.utils.types import DagRunType - -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS DEFAULT_DATE = datetime(2015, 1, 1) @@ -138,9 +138,7 @@ def test_parse_bucket_key_from_jinja( if AIRFLOW_V_3_0_PLUS: from airflow.models.dag_version import DagVersion - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + sync_dag_to_db(dag) dag_version = DagVersion.get_latest_version(dag.dag_id) dag_run = DagRun( dag_id=dag.dag_id, @@ -199,9 +197,7 @@ def test_parse_list_of_bucket_keys_from_jinja( else: from airflow.models.dag_version import DagVersion - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + sync_dag_to_db(dag) dag_version = DagVersion.get_latest_version(dag.dag_id) dag_run = DagRun( dag_id=dag.dag_id, diff --git a/providers/amazon/tests/unit/amazon/aws/transfers/test_base.py b/providers/amazon/tests/unit/amazon/aws/transfers/test_base.py index 39fa2c39bd923..f45cbfdad4dae 100644 --- a/providers/amazon/tests/unit/amazon/aws/transfers/test_base.py +++ 
b/providers/amazon/tests/unit/amazon/aws/transfers/test_base.py @@ -21,7 +21,6 @@ from airflow import DAG from airflow.models import DagRun, TaskInstance -from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.amazon.aws.transfers.base import AwsToAwsBaseOperator try: @@ -31,6 +30,7 @@ from airflow.utils.state import DagRunState from airflow.utils.types import DagRunType +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS DEFAULT_DATE = timezone.datetime(2020, 1, 1) @@ -53,9 +53,7 @@ def test_render_template(self, session, clean_dags_dagruns_and_dagbundles, testi if AIRFLOW_V_3_0_PLUS: from airflow.models.dag_version import DagVersion - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [self.dag]) - SerializedDagModel.write_dag(self.dag, bundle_name=bundle_name) + sync_dag_to_db(self.dag) dag_version = DagVersion.get_latest_version(self.dag.dag_id) ti = TaskInstance(operator, run_id="something", dag_version_id=dag_version.id) ti.dag_run = DagRun( diff --git a/providers/amazon/tests/unit/amazon/aws/transfers/test_dynamodb_to_s3.py b/providers/amazon/tests/unit/amazon/aws/transfers/test_dynamodb_to_s3.py index d1a7d8e2e2c67..a2efafa1f330c 100644 --- a/providers/amazon/tests/unit/amazon/aws/transfers/test_dynamodb_to_s3.py +++ b/providers/amazon/tests/unit/amazon/aws/transfers/test_dynamodb_to_s3.py @@ -26,11 +26,7 @@ from airflow import DAG from airflow.models import DagRun, TaskInstance -from airflow.models.serialized_dag import SerializedDagModel -from airflow.providers.amazon.aws.transfers.dynamodb_to_s3 import ( - DynamoDBToS3Operator, - JSONEncoder, -) +from airflow.providers.amazon.aws.transfers.dynamodb_to_s3 import DynamoDBToS3Operator, JSONEncoder try: from airflow.sdk import timezone @@ -39,6 +35,7 @@ from airflow.utils.state import DagRunState from airflow.utils.types import DagRunType +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS @@ -283,9 +280,7 @@ def test_render_template(self, session, clean_dags_dagruns_and_dagbundles, testi if AIRFLOW_V_3_0_PLUS: from airflow.models.dag_version import DagVersion - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + sync_dag_to_db(dag) dag_version = DagVersion.get_latest_version(dag.dag_id) ti = TaskInstance(operator, run_id="something", dag_version_id=dag_version.id) ti.dag_run = DagRun( diff --git a/providers/amazon/tests/unit/amazon/aws/transfers/test_mongo_to_s3.py b/providers/amazon/tests/unit/amazon/aws/transfers/test_mongo_to_s3.py index 48e01f72e6055..e3a112c540de2 100644 --- a/providers/amazon/tests/unit/amazon/aws/transfers/test_mongo_to_s3.py +++ b/providers/amazon/tests/unit/amazon/aws/transfers/test_mongo_to_s3.py @@ -22,17 +22,17 @@ import pytest from airflow.models import DAG, DagRun, TaskInstance -from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.amazon.aws.transfers.mongo_to_s3 import MongoToS3Operator +from airflow.utils.state import DagRunState +from airflow.utils.types import DagRunType + +from tests_common.test_utils.dag import sync_dag_to_db +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS try: from airflow.sdk import timezone except ImportError: from airflow.utils import timezone # type: ignore[attr-defined,no-redef] -from airflow.utils.state import DagRunState -from airflow.utils.types 
import DagRunType - -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS TASK_ID = "test_mongo_to_s3_operator" MONGO_CONN_ID = "default_mongo" @@ -91,9 +91,7 @@ def test_render_template(self, session, clean_dags_dagruns_and_dagbundles, testi if AIRFLOW_V_3_0_PLUS: from airflow.models.dag_version import DagVersion - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [self.dag]) - SerializedDagModel.write_dag(self.dag, bundle_name=bundle_name) + sync_dag_to_db(self.dag) dag_version = DagVersion.get_latest_version(self.mock_operator.dag_id) ti = TaskInstance(self.mock_operator, dag_version_id=dag_version.id) dag_run = DagRun( diff --git a/providers/apache/kylin/tests/unit/apache/kylin/operators/test_kylin_cube.py b/providers/apache/kylin/tests/unit/apache/kylin/operators/test_kylin_cube.py index f64fd420f5739..0033beda1c2c2 100644 --- a/providers/apache/kylin/tests/unit/apache/kylin/operators/test_kylin_cube.py +++ b/providers/apache/kylin/tests/unit/apache/kylin/operators/test_kylin_cube.py @@ -29,6 +29,7 @@ from airflow.utils import state, timezone from airflow.utils.types import DagRunType +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS DEFAULT_DATE = timezone.datetime(2020, 1, 1) @@ -173,11 +174,8 @@ def test_render_template(self, session, testing_dag_bundle): if AIRFLOW_V_3_0_PLUS: from airflow.models.dag_version import DagVersion - from airflow.models.serialized_dag import SerializedDagModel - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [self.dag]) - SerializedDagModel.write_dag(dag=self.dag, bundle_name=bundle_name) + sync_dag_to_db(self.dag) dag_version = DagVersion.get_latest_version(operator.dag_id) ti = TaskInstance(operator, run_id="kylin_test", dag_version_id=dag_version.id) ti.dag_run = DagRun( diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py index f277d8420bac3..7acffb447d07c 100644 --- a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py +++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py @@ -30,6 +30,7 @@ from airflow.utils import timezone from airflow.utils.types import DagRunType +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS DEFAULT_DATE = timezone.datetime(2017, 1, 1) @@ -200,11 +201,8 @@ def test_render_template(self, session, testing_dag_bundle): if AIRFLOW_V_3_0_PLUS: from airflow.models.dag_version import DagVersion - from airflow.models.serialized_dag import SerializedDagModel - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [self.dag]) - SerializedDagModel.write_dag(dag=self.dag, bundle_name=bundle_name) + sync_dag_to_db(self.dag) dag_version = DagVersion.get_latest_version(operator.dag_id) ti = TaskInstance(operator, run_id="spark_test", dag_version_id=dag_version.id) ti.dag_run = DagRun( diff --git a/providers/celery/tests/unit/celery/executors/test_celery_executor.py b/providers/celery/tests/unit/celery/executors/test_celery_executor.py index 214ec70842b76..9384ef626ce1f 100644 --- a/providers/celery/tests/unit/celery/executors/test_celery_executor.py +++ b/providers/celery/tests/unit/celery/executors/test_celery_executor.py @@ -34,21 +34,24 @@ from kombu.asynchronous import set_event_loop from airflow.configuration import conf -from 
airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG -from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance, TaskInstanceKey from airflow.providers.celery.executors import celery_executor, celery_executor_utils, default_celery from airflow.providers.celery.executors.celery_executor import CeleryExecutor -from airflow.utils import timezone from airflow.utils.state import State from tests_common.test_utils import db from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.dag import sync_dag_to_db +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS if AIRFLOW_V_3_0_PLUS: from airflow.models.dag_version import DagVersion +if AIRFLOW_V_3_1_PLUS: + from airflow.sdk import BaseOperator, timezone +else: + from airflow.models.baseoperator import BaseOperator # type: ignore[attr-defined,no-redef] + from airflow.utils import timezone # type: ignore[attr-defined,no-redef] pytestmark = pytest.mark.db_test @@ -200,9 +203,7 @@ def test_try_adopt_task_instances_none(self, clean_dags_dagruns_and_dagbundles, task_1 = BaseOperator(task_id="task_1", start_date=start_date) if AIRFLOW_V_3_0_PLUS: - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + sync_dag_to_db(dag) dag_version = DagVersion.get_latest_version(dag.dag_id) key1 = TaskInstance(task=task_1, run_id=None, dag_version_id=dag_version.id) else: @@ -223,9 +224,7 @@ def test_try_adopt_task_instances(self, clean_dags_dagruns_and_dagbundles, testi task_2 = BaseOperator(task_id="task_2", start_date=start_date) if AIRFLOW_V_3_0_PLUS: - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + sync_dag_to_db(dag) dag_version = DagVersion.get_latest_version(dag.dag_id) ti1 = TaskInstance(task=task_1, run_id=None, dag_version_id=dag_version.id) ti2 = TaskInstance(task=task_2, run_id=None, dag_version_id=dag_version.id) @@ -269,9 +268,7 @@ def test_cleanup_stuck_queued_tasks( task = BaseOperator(task_id="task_1", start_date=start_date) if AIRFLOW_V_3_0_PLUS: - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + sync_dag_to_db(dag) dag_version = DagVersion.get_latest_version(task.dag.dag_id) ti = TaskInstance(task=task, run_id=None, dag_version_id=dag_version.id) else: @@ -305,9 +302,7 @@ def test_revoke_task(self, mock_fail, clean_dags_dagruns_and_dagbundles, testing task = BaseOperator(task_id="task_1", start_date=start_date) if AIRFLOW_V_3_0_PLUS: - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + sync_dag_to_db(dag) dag_version = DagVersion.get_latest_version(task.dag.dag_id) ti = TaskInstance(task=task, run_id=None, dag_version_id=dag_version.id) else: diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py index 41737d713eec2..1301aff7c95d1 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py @@ -32,12 +32,16 @@ from 
airflow.providers.cncf.kubernetes.kube_client import get_kube_client from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import create_unique_id from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator, generate_pod_command_args -from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_0_PLUS +from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS from airflow.utils import cli as cli_utils, yaml -from airflow.utils.cli import get_dag from airflow.utils.providers_configuration_loader import providers_configuration_loaded from airflow.utils.types import DagRunType +if AIRFLOW_V_3_1_PLUS: + from airflow.utils.cli import get_bagged_dag +else: + from airflow.utils.cli import get_dag as get_bagged_dag # type: ignore[attr-defined,no-redef] + @cli_utils.action_cli @providers_configuration_loaded @@ -45,9 +49,9 @@ def generate_pod_yaml(args): """Generate yaml files for each task in the DAG. Used for testing output of KubernetesExecutor.""" logical_date = args.logical_date if AIRFLOW_V_3_0_PLUS else args.execution_date if AIRFLOW_V_3_0_PLUS: - dag = get_dag(bundle_names=args.bundle_name, dag_id=args.dag_id) + dag = get_bagged_dag(bundle_names=args.bundle_name, dag_id=args.dag_id) else: - dag = get_dag(subdir=args.subdir, dag_id=args.dag_id) + dag = get_bagged_dag(subdir=args.subdir, dag_id=args.dag_id) yaml_output_path = args.output_path dm = DagModel(dag_id=dag.dag_id) diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes.py index 8935187778eb1..68bc19437995d 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes.py @@ -36,7 +36,7 @@ class TestKubernetesDecorator(TestKubernetesDecoratorsBase): def test_basic_kubernetes(self): """Test basic proper KubernetesPodOperator creation from @task.kubernetes decorator""" - with self.dag: + with self.dag_maker: @task.kubernetes( image="python:3.10-slim-buster", @@ -77,7 +77,7 @@ def f(): @pytest.mark.asyncio def test_kubernetes_with_input_output(self): """Verify @task.kubernetes will run XCom container if do_xcom_push is set.""" - with self.dag: + with self.dag_maker: @task.kubernetes( image="python:3.10-slim-buster", diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_cmd.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_cmd.py index 684cbf67ac012..17bd38d981eff 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_cmd.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_cmd.py @@ -42,7 +42,7 @@ class TestKubernetesCmdDecorator(TestKubernetesDecoratorsBase): def test_basic_kubernetes(self, args_only: bool): """Test basic proper KubernetesPodOperator creation from @task.kubernetes_cmd decorator""" expected = ["echo", "Hello world!"] - with self.dag: + with self.dag_maker: @task.kubernetes_cmd( image="python:3.10-slim-buster", @@ -102,7 +102,7 @@ def test_kubernetes_cmd_wrong_cmd( Test that @task.kubernetes_cmd raises an error if the python_callable returns an invalid value. 
""" - with self.dag: + with self.dag_maker: @task.kubernetes_cmd( image="python:3.10-slim-buster", @@ -123,7 +123,7 @@ def hello(): @pytest.mark.asyncio def test_kubernetes_cmd_with_input_output(self): """Verify @task.kubernetes_cmd will run XCom container if do_xcom_push is set.""" - with self.dag: + with self.dag_maker: @task.kubernetes_cmd( image="python:3.10-slim-buster", @@ -200,7 +200,7 @@ def test_ignored_decorator_parameters( context_manager = contextlib.nullcontext() # type: ignore expected = ["func", "return"] - with self.dag: + with self.dag_maker: # We need to suppress the warning about `cmds` and `arguments` being unused with context_manager: @@ -256,7 +256,7 @@ def test_rendering_kubernetes_cmd( expected_command: list[str], ): """Test that templating works in function return value""" - with self.dag: + with self.dag_maker: @task.kubernetes_cmd( image="python:3.10-slim-buster", @@ -287,7 +287,7 @@ def hello(add_to_command: str): @pytest.mark.asyncio def test_basic_context_works(self): """Test that decorator works with context as kwargs unpcacked in function arguments""" - with self.dag: + with self.dag_maker: @task.kubernetes_cmd( image="python:3.10-slim-buster", @@ -318,7 +318,7 @@ def hello(**context): @pytest.mark.asyncio def test_named_context_variables(self): """Test that decorator works with specific context variable as kwargs in function arguments""" - with self.dag: + with self.dag_maker: @task.kubernetes_cmd( image="python:3.10-slim-buster", @@ -349,7 +349,7 @@ def hello(ti=None, dag_run=None): @pytest.mark.asyncio def test_rendering_kubernetes_cmd_decorator_params(self): """Test that templating works in decorator parameters""" - with self.dag: + with self.dag_maker: @task.kubernetes_cmd( image="python:{{ dag.dag_id }}", @@ -379,7 +379,7 @@ def hello(): def test_airflow_skip(self): """Test that the operator is skipped if the task is skipped""" - with self.dag: + with self.dag_maker: @task.kubernetes_cmd( image="python:3.10-slim-buster", diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_commons.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_commons.py index 0f5f577913b5d..fbf79353668a5 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_commons.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_commons.py @@ -84,10 +84,10 @@ class TestKubernetesDecoratorsBase: def setup(self, dag_maker): self.dag_maker = dag_maker - with dag_maker(dag_id=DAG_ID) as dag: + with dag_maker(dag_id=DAG_ID): ... 
- self.dag = dag + self.dag = self.dag_maker.dag self.mock_create_pod = mock.patch(f"{POD_MANAGER_CLASS}.create_pod").start() self.mock_await_pod_start = mock.patch(f"{POD_MANAGER_CLASS}.await_pod_start").start() @@ -122,9 +122,7 @@ def teardown_method(self): def execute_task(self, task): session = self.dag_maker.session - dag_run = self.dag_maker.create_dagrun( - run_id=f"k8s_decorator_test_{DEFAULT_DATE.date()}", session=session - ) + dag_run = self.dag_maker.create_dagrun(run_id=f"k8s_decorator_test_{DEFAULT_DATE.date()}") ti = dag_run.get_task_instance(task.operator.task_id, session=session) return_val = task.operator.execute(context=ti.get_template_context(session=session)) @@ -153,7 +151,7 @@ class TestKubernetesDecoratorsCommons(TestKubernetesDecoratorsBase): def test_k8s_decorator_init(self, task_decorator, decorator_name): """Test the initialization of the @task.kubernetes[_cmd] decorated task.""" - with self.dag: + with self.dag_maker: @task_decorator( image="python:3.10-slim-buster", @@ -174,7 +172,7 @@ def k8s_task_function() -> list[str]: def test_decorators_with_marked_as_setup(self, task_decorator, decorator_name): """Test the @task.kubernetes[_cmd] decorated task works with setup decorator.""" - with self.dag: + with self.dag_maker: task_function_name = setup(_prepare_task(task_decorator, decorator_name)) task_function_name() @@ -184,7 +182,7 @@ def test_decorators_with_marked_as_setup(self, task_decorator, decorator_name): def test_decorators_with_marked_as_teardown(self, task_decorator, decorator_name): """Test the @task.kubernetes[_cmd] decorated task works with teardown decorator.""" - with self.dag: + with self.dag_maker: task_function_name = teardown(_prepare_task(task_decorator, decorator_name)) task_function_name() @@ -229,7 +227,7 @@ def test_pod_naming( **extra_kwargs, } - with self.dag: + with self.dag_maker: task_function_name = _prepare_task( task_decorator, decorator_name, diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/log_handlers/test_log_handlers.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/log_handlers/test_log_handlers.py index ed0d3b608c373..954ec3ed80645 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/log_handlers/test_log_handlers.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/log_handlers/test_log_handlers.py @@ -30,23 +30,26 @@ from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG from airflow.executors import executor_loader from airflow.models.dag import DAG -from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance from airflow.utils.log.file_task_handler import ( FileTaskHandler, ) from airflow.utils.log.logging_mixin import set_context from airflow.utils.state import State, TaskInstanceState -from airflow.utils.timezone import datetime from airflow.utils.types import DagRunType from tests_common.test_utils.compat import PythonOperator from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.db import clear_db_dag_bundles, clear_db_dags, clear_db_runs -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS if AIRFLOW_V_3_0_PLUS: from airflow.utils.types import DagRunTriggeredByType +if AIRFLOW_V_3_1_PLUS: + from airflow.sdk.timezone import datetime +else: + from airflow.utils.timezone import datetime # type: 
ignore[attr-defined,no-redef] pytestmark = pytest.mark.db_test @@ -133,9 +136,7 @@ def task_callable(ti): "run_after": DEFAULT_DATE, "triggered_by": DagRunTriggeredByType.TEST, } - bundle_name = "testing" - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + dag = sync_dag_to_db(dag) else: dagrun_kwargs = {"execution_date": DEFAULT_DATE} dagrun = dag.create_dagrun( diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py index c52c3dd2104ec..052bfd5eb8bf3 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py @@ -28,7 +28,6 @@ from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import DAG, DagModel, DagRun, TaskInstance -from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.cncf.kubernetes.operators.job import ( KubernetesDeleteJobOperator, KubernetesJobOperator, @@ -38,6 +37,7 @@ from airflow.utils.session import create_session from airflow.utils.types import DagRunType +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS DEFAULT_DATE = timezone.datetime(2016, 1, 1, 1, 0, 0) @@ -62,15 +62,7 @@ def create_context(task, persist_to_db=False, map_index=None): dag = DAG(dag_id="dag", schedule=None, start_date=pendulum.now()) dag.add_task(task) if AIRFLOW_V_3_0_PLUS: - from airflow.models.dagbundle import DagBundleModel - - with create_session() as session: - bundle_name = "testing" - orm_dag_bundle = DagBundleModel(name=bundle_name) - session.add(orm_dag_bundle) - session.commit() - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + sync_dag_to_db(dag) dag_run = DagRun( run_id=DagRun.generate_run_id( run_type=DagRunType.MANUAL, logical_date=DEFAULT_DATE, run_after=DEFAULT_DATE diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py index 5641d00e9d4d9..adca2955a30a3 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py @@ -36,7 +36,6 @@ TaskDeferred, ) from airflow.models import DAG, DagModel, DagRun, TaskInstance -from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.cncf.kubernetes import pod_generator from airflow.providers.cncf.kubernetes.operators.pod import ( KubernetesPodOperator, @@ -48,13 +47,11 @@ from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction, PodLoggingStatus, PodPhase from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults from airflow.utils import timezone - -if TYPE_CHECKING: - from airflow.utils.context import Context from airflow.utils.session import create_session from airflow.utils.types import DagRunType from tests_common.test_utils import db +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS if AIRFLOW_V_3_0_PLUS: @@ -62,6 +59,9 @@ else: from airflow.models.xcom import XCom # type: ignore[no-redef] +if TYPE_CHECKING: + from airflow.utils.context import Context + pytestmark = pytest.mark.db_test DEFAULT_DATE = 
timezone.datetime(2016, 1, 1, 1, 0, 0) @@ -113,16 +113,7 @@ def create_context(task, persist_to_db=False, map_index=None): dag.add_task(task) now = timezone.utcnow() if AIRFLOW_V_3_0_PLUS: - with create_session() as session: - from airflow.models.dagbundle import DagBundleModel - - bundle_name = "testing" - session.add(DagBundleModel(name=bundle_name)) - session.flush() - session.add(DagModel(dag_id=dag.dag_id, bundle_name=bundle_name)) - session.commit() - dag.sync_to_db() - SerializedDagModel.write_dag(dag, bundle_name="testing") + sync_dag_to_db(dag) dag_run = DagRun( run_id=DagRun.generate_run_id( run_type=DagRunType.MANUAL, logical_date=DEFAULT_DATE, run_after=DEFAULT_DATE diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py index 731c94ecc58ec..c395a78b3618d 100644 --- a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py +++ b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py @@ -36,7 +36,6 @@ from airflow import DAG from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import Connection -from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.common.sql.hooks.handlers import fetch_all_handler from airflow.providers.common.sql.operators.sql import ( BaseSQLOperator, @@ -52,10 +51,10 @@ from airflow.providers.postgres.hooks.postgres import PostgresHook from airflow.providers.standard.operators.empty import EmptyOperator from airflow.utils import timezone # type: ignore[attr-defined] -from airflow.utils.session import create_session from airflow.utils.state import State from airflow.utils.types import DagRunType +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.db import clear_db_dag_bundles, clear_db_dags, clear_db_runs, clear_db_xcom from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker from tests_common.test_utils.providers import get_provider_min_airflow_version @@ -1110,15 +1109,9 @@ def setup_method(self): self.branch_2 = EmptyOperator(task_id="branch_2", dag=self.dag) self.branch_3 = None if AIRFLOW_V_3_0_PLUS: - from airflow.models.dagbundle import DagBundleModel - - with create_session() as session: - bundle_name = "testing" - orm_dag_bundle = DagBundleModel(name=bundle_name) - session.add(orm_dag_bundle) - session.commit() - DAG.bulk_write_to_db(bundle_name, None, [self.dag]) - SerializedDagModel.write_dag(self.dag, bundle_name=bundle_name) + self.scheduler_dag = sync_dag_to_db(self.dag) + else: + self.scheduler_dag = self.dag def get_ti(self, task_id, dr=None): if dr is None: @@ -1130,7 +1123,7 @@ def get_ti(self, task_id, dr=None): } else: dagrun_kwargs = {"execution_date": DEFAULT_DATE} - dr = self.dag.create_dagrun( + dr = self.scheduler_dag.create_dagrun( run_id=f"manual__{timezone.utcnow().isoformat()}", run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), @@ -1158,8 +1151,7 @@ def branch_op(self): dag=self.dag, ) if AIRFLOW_V_3_0_PLUS: - self.dag.sync_to_db() - SerializedDagModel.write_dag(self.dag, bundle_name="testing") + self.scheduler_dag = sync_dag_to_db(self.dag) return branch_op def test_unsupported_conn_type(self): @@ -1172,7 +1164,6 @@ def test_unsupported_conn_type(self): follow_task_ids_if_false=["branch_2"], dag=self.dag, ) - with pytest.raises(AirflowException): op.execute({}) @@ -1232,6 +1223,8 @@ def test_sql_branch_operator_postgres(self): follow_task_ids_if_false=["branch_2"], 
dag=self.dag, ) + if AIRFLOW_V_3_0_PLUS: + self.scheduler_dag = sync_dag_to_db(self.dag) self.get_ti(branch_op.task_id).run() @mock.patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.get_db_hook") @@ -1239,7 +1232,7 @@ def test_branch_single_value_with_dag_run(self, mock_get_db_hook, branch_op): """Check BranchSQLOperator branch operation""" self.branch_1.set_upstream(branch_op) self.branch_2.set_upstream(branch_op) - self.dag.clear() + self.scheduler_dag.clear() if AIRFLOW_V_3_0_PLUS: dagrun_kwargs = { @@ -1249,8 +1242,7 @@ def test_branch_single_value_with_dag_run(self, mock_get_db_hook, branch_op): } else: dagrun_kwargs = {"execution_date": DEFAULT_DATE} - - dr = self.dag.create_dagrun( + dr = self.scheduler_dag.create_dagrun( run_id="manual__", run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), @@ -1260,7 +1252,6 @@ def test_branch_single_value_with_dag_run(self, mock_get_db_hook, branch_op): ) mock_get_records = mock_get_db_hook.return_value.get_first - mock_get_records.return_value = 1 if AIRFLOW_V_3_0_1: @@ -1290,7 +1281,7 @@ def test_branch_true_with_dag_run(self, mock_get_db_hook, true_value, branch_op) """Check BranchSQLOperator branch operation""" self.branch_1.set_upstream(branch_op) self.branch_2.set_upstream(branch_op) - self.dag.clear() + self.scheduler_dag.clear() if AIRFLOW_V_3_0_PLUS: dagrun_kwargs = { @@ -1300,7 +1291,7 @@ def test_branch_true_with_dag_run(self, mock_get_db_hook, true_value, branch_op) } else: dagrun_kwargs = {"execution_date": DEFAULT_DATE} - dr = self.dag.create_dagrun( + dr = self.scheduler_dag.create_dagrun( run_id="manual__", run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), @@ -1338,7 +1329,7 @@ def test_branch_false_with_dag_run(self, mock_get_db_hook, false_value, branch_o """Check BranchSQLOperator branch operation""" self.branch_1.set_upstream(branch_op) self.branch_2.set_upstream(branch_op) - self.dag.clear() + self.scheduler_dag.clear() if AIRFLOW_V_3_0_PLUS: dagrun_kwargs = { @@ -1348,7 +1339,7 @@ def test_branch_false_with_dag_run(self, mock_get_db_hook, false_value, branch_o } else: dagrun_kwargs = {"execution_date": DEFAULT_DATE} - dr = self.dag.create_dagrun( + dr = self.scheduler_dag.create_dagrun( run_id="manual__", run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), @@ -1358,8 +1349,8 @@ def test_branch_false_with_dag_run(self, mock_get_db_hook, false_value, branch_o ) mock_get_records = mock_get_db_hook.return_value.get_first - mock_get_records.return_value = false_value + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped @@ -1396,11 +1387,10 @@ def test_branch_list_with_dag_run(self, mock_get_db_hook): self.branch_2.set_upstream(branch_op) self.branch_3 = EmptyOperator(task_id="branch_3", dag=self.dag) self.branch_3.set_upstream(branch_op) - self.dag.clear() + self.scheduler_dag.clear() if AIRFLOW_V_3_0_PLUS: - self.dag.sync_to_db() - SerializedDagModel.write_dag(self.dag, bundle_name="testing") + self.scheduler_dag = sync_dag_to_db(self.dag) dagrun_kwargs = { "logical_date": DEFAULT_DATE, "run_after": DEFAULT_DATE, @@ -1408,7 +1398,7 @@ def test_branch_list_with_dag_run(self, mock_get_db_hook): } else: dagrun_kwargs = {"execution_date": DEFAULT_DATE} - dr = self.dag.create_dagrun( + dr = self.scheduler_dag.create_dagrun( run_id="manual__", run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), @@ -1444,7 +1434,7 @@ def test_invalid_query_result_with_dag_run(self, mock_get_db_hook, branch_op): """Check BranchSQLOperator branch operation""" 
self.branch_1.set_upstream(branch_op) self.branch_2.set_upstream(branch_op) - self.dag.clear() + self.scheduler_dag.clear() if AIRFLOW_V_3_0_PLUS: dagrun_kwargs = { @@ -1454,7 +1444,7 @@ def test_invalid_query_result_with_dag_run(self, mock_get_db_hook, branch_op): } else: dagrun_kwargs = {"execution_date": DEFAULT_DATE} - self.dag.create_dagrun( + self.scheduler_dag.create_dagrun( run_id="manual__", run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), @@ -1475,7 +1465,7 @@ def test_with_skip_in_branch_downstream_dependencies(self, mock_get_db_hook, bra """Test SQL Branch with skipping all downstream dependencies""" branch_op >> self.branch_1 >> self.branch_2 branch_op >> self.branch_2 - self.dag.clear() + self.scheduler_dag.clear() if AIRFLOW_V_3_0_PLUS: dagrun_kwargs = { @@ -1485,7 +1475,7 @@ def test_with_skip_in_branch_downstream_dependencies(self, mock_get_db_hook, bra } else: dagrun_kwargs = {"execution_date": DEFAULT_DATE} - dr = self.dag.create_dagrun( + dr = self.scheduler_dag.create_dagrun( run_id="manual__", run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), @@ -1516,7 +1506,7 @@ def test_with_skip_in_branch_downstream_dependencies2(self, mock_get_db_hook, fa """Test skipping downstream dependency for false condition""" branch_op >> self.branch_1 >> self.branch_2 branch_op >> self.branch_2 - self.dag.clear() + self.scheduler_dag.clear() if AIRFLOW_V_3_0_PLUS: dagrun_kwargs = { @@ -1526,7 +1516,7 @@ def test_with_skip_in_branch_downstream_dependencies2(self, mock_get_db_hook, fa } else: dagrun_kwargs = {"execution_date": DEFAULT_DATE} - dr = self.dag.create_dagrun( + dr = self.scheduler_dag.create_dagrun( run_id="manual__", run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), diff --git a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py index 3d36507b7f64e..9624cb685fd53 100644 --- a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py +++ b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py @@ -27,9 +27,8 @@ from flask_appbuilder.api import expose from airflow.exceptions import AirflowException, TaskInstanceNotFound -from airflow.models.dag import DAG, clear_task_instances from airflow.models.dagrun import DagRun -from airflow.models.taskinstance import TaskInstance, TaskInstanceKey +from airflow.models.taskinstance import TaskInstance, TaskInstanceKey, clear_task_instances from airflow.plugins_manager import AirflowPlugin from airflow.providers.databricks.hooks.databricks import DatabricksHook from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS, BaseOperatorLink, TaskGroup, XCom @@ -89,7 +88,7 @@ def get_databricks_task_ids( if not AIRFLOW_V_3_0_PLUS: from airflow.utils.session import NEW_SESSION, provide_session - def _get_dag(dag_id: str, session: Session) -> DAG: + def _get_dag(dag_id: str, session: Session): from airflow.models.serialized_dag import SerializedDagModel dag = SerializedDagModel.get_dag(dag_id, session=session) @@ -97,7 +96,7 @@ def _get_dag(dag_id: str, session: Session) -> DAG: raise AirflowException("Dag not found.") return dag - def _get_dagrun(dag: DAG, run_id: str, session: Session) -> DagRun: + def _get_dagrun(dag, run_id: str, session: Session) -> DagRun: """ Retrieve the DagRun object associated with the specified DAG and run_id. 
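The provider test hunks above and below repeat one substitution: the per-test DAG.bulk_write_to_db(bundle_name, None, [dag]) plus SerializedDagModel.write_dag(dag, bundle_name=bundle_name) pair becomes a single sync_dag_to_db(dag) call from tests_common.test_utils.dag, which also returns the serialized DAG used for run creation and cleanup. A minimal sketch of the resulting setup, assuming the session/testing_dag_bundle fixtures seen in these tests and an illustrative dag_id that is not taken from this patch:

    # Illustrative sketch only: the test function, dag_id and dates are assumptions,
    # not part of this patch; the individual calls mirror the surrounding hunks.
    import datetime

    from airflow.models import DagRun, TaskInstance
    from airflow.models.dag_version import DagVersion
    from airflow.providers.standard.operators.empty import EmptyOperator
    from airflow.sdk import DAG
    from airflow.utils.types import DagRunType

    from tests_common.test_utils.dag import sync_dag_to_db

    DEFAULT_DATE = datetime.datetime(2021, 1, 1)

    def test_render_template(session, testing_dag_bundle):
        dag = DAG("example_dag", schedule=None, start_date=DEFAULT_DATE)
        op = EmptyOperator(task_id="task", dag=dag)

        # One call replaces the old DAG.bulk_write_to_db(...) +
        # SerializedDagModel.write_dag(...) pair and returns the serialized DAG.
        scheduler_dag = sync_dag_to_db(dag)

        dag_version = DagVersion.get_latest_version(dag.dag_id)
        ti = TaskInstance(task=op, dag_version_id=dag_version.id)
        ti.dag_run = DagRun(dag_id=dag.dag_id, run_id="manual__", run_type=DagRunType.MANUAL)

        # ... assertions on rendered fields; scheduler_dag.clear() is available for cleanup.

The helper merges the DagBundleModel row itself, so tests no longer create bundle rows by hand, as in the blocks removed from test_sql.py and test_job.py above.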
diff --git a/providers/docker/tests/unit/docker/decorators/test_docker.py b/providers/docker/tests/unit/docker/decorators/test_docker.py index 9ec307e35e500..56818438cbcde 100644 --- a/providers/docker/tests/unit/docker/decorators/test_docker.py +++ b/providers/docker/tests/unit/docker/decorators/test_docker.py @@ -22,19 +22,23 @@ import pytest -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - -if AIRFLOW_V_3_0_PLUS: - from airflow.sdk import setup, task, teardown -else: - from airflow.decorators import setup, task, teardown # type: ignore[attr-defined,no-redef] from airflow.exceptions import AirflowException from airflow.models import TaskInstance -from airflow.models.dag import DAG -from airflow.utils import timezone from airflow.utils.state import TaskInstanceState from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS + +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import DAG, setup, task, teardown +else: + from airflow.decorators import setup, task, teardown # type: ignore[attr-defined,no-redef] + from airflow.models import DAG # type: ignore[attr-defined,no-redef] + +if AIRFLOW_V_3_1_PLUS: + from airflow.sdk import timezone +else: + from airflow.utils import timezone # type: ignore[attr-defined,no-redef] DEFAULT_DATE = timezone.datetime(2021, 9, 1) DILL_INSTALLED = find_spec("dill") is not None diff --git a/providers/edge3/src/airflow/providers/edge3/example_dags/integration_test.py b/providers/edge3/src/airflow/providers/edge3/example_dags/integration_test.py index 535c1746eb38d..b7ffda7ebdd16 100644 --- a/providers/edge3/src/airflow/providers/edge3/example_dags/integration_test.py +++ b/providers/edge3/src/airflow/providers/edge3/example_dags/integration_test.py @@ -46,7 +46,7 @@ except ImportError: # Airflow 2.10 compat from airflow.decorators import task, task_group # type: ignore[attr-defined,no-redef] - from airflow.models.dag import DAG # type: ignore[assignment] + from airflow.models.dag import DAG # type: ignore[no-redef] from airflow.models.param import Param # type: ignore[no-redef] from airflow.models.variable import Variable from airflow.operators.bash import BashOperator # type: ignore[no-redef] diff --git a/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_task_handler.py b/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_task_handler.py index 4001331e8f72b..93532b4efdacf 100644 --- a/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_task_handler.py +++ b/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_task_handler.py @@ -41,17 +41,23 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.models.dagrun import DagRun -from airflow.providers.elasticsearch.log.es_json_formatter import ( - ElasticsearchJSONFormatter, -) +from airflow.providers.elasticsearch.log.es_json_formatter import ElasticsearchJSONFormatter from airflow.providers.elasticsearch.log.es_response import ElasticSearchResponse, Hit -from airflow.providers.elasticsearch.version_compat import AIRFLOW_V_3_0_PLUS, EsLogMsgType -from airflow.utils import timezone +from airflow.providers.elasticsearch.version_compat import ( + AIRFLOW_V_3_0_PLUS, + AIRFLOW_V_3_1_PLUS, + EsLogMsgType, +) from airflow.utils.log.file_task_handler import FileTaskHandler from airflow.utils.log.logging_mixin import ExternalLoggingMixin, LoggingMixin from 
airflow.utils.module_loading import import_string from airflow.utils.session import create_session +if AIRFLOW_V_3_1_PLUS: + from airflow.sdk import timezone +else: + from airflow.utils import timezone # type: ignore[attr-defined,no-redef] + if TYPE_CHECKING: from datetime import datetime @@ -235,29 +241,17 @@ def _render_log_id(self, ti: TaskInstance | TaskInstanceKey, try_number: int) -> if USE_PER_RUN_LOG_ID: log_id_template = dag_run.get_log_template(session=session).elasticsearch_id - if TYPE_CHECKING: - assert ti.task - try: - dag = ti.task.dag - except AttributeError: # ti.task is not always set. - data_interval = (dag_run.data_interval_start, dag_run.data_interval_end) - else: - if TYPE_CHECKING: - assert dag is not None - # TODO: Task-SDK: Where should this function be? - data_interval = dag.get_run_data_interval(dag_run) # type: ignore[attr-defined] - if self.json_format: - data_interval_start = self._clean_date(data_interval[0]) - data_interval_end = self._clean_date(data_interval[1]) + data_interval_start = self._clean_date(dag_run.data_interval_start) + data_interval_end = self._clean_date(dag_run.data_interval_end) logical_date = self._clean_date(dag_run.logical_date) else: - if data_interval[0]: - data_interval_start = data_interval[0].isoformat() + if dag_run.data_interval_start: + data_interval_start = dag_run.data_interval_start.isoformat() else: data_interval_start = "" - if data_interval[1]: - data_interval_end = data_interval[1].isoformat() + if dag_run.data_interval_end: + data_interval_end = dag_run.data_interval_end.isoformat() else: data_interval_end = "" logical_date = dag_run.logical_date.isoformat() diff --git a/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py b/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py index 5675f4b173a01..3cad295e2f709 100644 --- a/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py +++ b/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py @@ -43,6 +43,7 @@ ) from tests_common.test_utils.compat import ignore_provider_compatibility_error +from tests_common.test_utils.dag import sync_dag_to_db with ignore_provider_compatibility_error("2.9.0+", __file__): from airflow.providers.fab.auth_manager.fab_auth_manager import FabAuthManager @@ -764,10 +765,16 @@ def test_get_authorized_dag_ids( ): with dag_maker("test_dag1"): EmptyOperator(task_id="task1") + if AIRFLOW_V_3_1_PLUS: + sync_dag_to_db(dag_maker.dag) with dag_maker("test_dag2"): EmptyOperator(task_id="task1") + if AIRFLOW_V_3_1_PLUS: + sync_dag_to_db(dag_maker.dag) with dag_maker("Connections"): EmptyOperator(task_id="task1") + if AIRFLOW_V_3_1_PLUS: + sync_dag_to_db(dag_maker.dag) auth_manager_with_appbuilder.security_manager.sync_perm_for_dag("test_dag1") auth_manager_with_appbuilder.security_manager.sync_perm_for_dag("test_dag2") diff --git a/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py b/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py index 57ef81b4dc6ad..7763836628b28 100644 --- a/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py +++ b/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py @@ -43,6 +43,7 @@ from tests_common.test_utils.compat import EmptyOperator, PythonOperator from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.dag import create_scheduler_dag from tests_common.test_utils.db import clear_db_runs from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS @@ 
-199,7 +200,7 @@ def sample_callable(**kwargs): ) t = PythonOperator(task_id=f"test_task_{scenario_name}", dag=dag, python_callable=python_callable) run_id = str(uuid.uuid1()) - dagrun = dag.create_dagrun( + dagrun = create_scheduler_dag(dag).create_dagrun( run_id=run_id, data_interval=(date, date), run_type=types.DagRunType.MANUAL, @@ -1059,7 +1060,7 @@ def sample_callable(**kwargs): "triggered_by": types.DagRunTriggeredByType.TEST, } - dagrun = dag.create_dagrun( + dagrun = create_scheduler_dag(dag).create_dagrun( run_id=run_id, data_interval=(date, date), start_date=date, diff --git a/providers/openlineage/tests/unit/openlineage/plugins/test_utils.py b/providers/openlineage/tests/unit/openlineage/plugins/test_utils.py index 7a63f78010bbb..4d4bc2878668a 100644 --- a/providers/openlineage/tests/unit/openlineage/plugins/test_utils.py +++ b/providers/openlineage/tests/unit/openlineage/plugins/test_utils.py @@ -20,7 +20,7 @@ import json import uuid from json import JSONEncoder -from typing import TYPE_CHECKING, Any +from typing import Any from unittest.mock import MagicMock, patch import pytest @@ -28,7 +28,6 @@ from openlineage.client.utils import RedactMixin from pkg_resources import parse_version -from airflow.models import DAG, DagModel from airflow.providers.common.compat.assets import Asset from airflow.providers.openlineage.plugins.facets import AirflowDebugRunFacet from airflow.providers.openlineage.utils.utils import ( @@ -44,28 +43,32 @@ is_operator_disabled, ) from airflow.serialization.enums import DagAttributeTypes, Encoding -from airflow.utils import timezone # type:ignore[attr-defined] from airflow.utils.state import State from airflow.utils.types import DagRunType from tests_common.test_utils.compat import ( BashOperator, ) -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS + +if AIRFLOW_V_3_1_PLUS: + from airflow.models.dag import get_next_data_interval + from airflow.sdk import timezone +else: + from airflow.utils import timezone # type: ignore[attr-defined,no-redef] + +if AIRFLOW_V_3_1_PLUS: + from airflow.sdk._shared.secrets_masker import _secrets_masker +elif AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.secrets_masker import _secrets_masker # type: ignore[no-redef] +else: + from airflow.utils.log.secrets_masker import _secrets_masker # type: ignore[attr-defined,no-redef] if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import DAG from airflow.utils.types import DagRunTriggeredByType - -if TYPE_CHECKING: - from airflow.sdk.execution_time.secrets_masker import _secrets_masker else: - try: - from airflow.sdk._shared.secrets_masker import _secrets_masker - except ImportError: - try: - from airflow.sdk.execution_time.secrets_masker import _secrets_masker - except ImportError: - from airflow.utils.log.secrets_masker import _secrets_masker + from airflow import DAG class SafeStrDict(dict): @@ -102,16 +105,19 @@ def test_get_airflow_debug_facet_logging_set_to_debug(mock_debug_mode, mock_get_ @pytest.mark.db_test +@pytest.mark.need_serialized_dag def test_get_dagrun_start_end(dag_maker): start_date = datetime.datetime(2022, 1, 1) end_date = datetime.datetime(2022, 1, 1, hour=2) with dag_maker("test", start_date=start_date, end_date=end_date, schedule="@once") as dag: pass dag_maker.sync_dagbag_to_db() - dag_model = DagModel.get_dagmodel(dag.dag_id) run_id = str(uuid.uuid1()) - data_interval = dag.get_next_data_interval(dag_model) + if AIRFLOW_V_3_1_PLUS: + 
data_interval = get_next_data_interval(dag.timetable, dag_maker.dag_model) + else: + data_interval = dag.get_next_data_interval(dag_maker.dag_model) if AIRFLOW_V_3_0_PLUS: dagrun_kwargs = { "logical_date": data_interval.start, diff --git a/providers/opensearch/src/airflow/providers/opensearch/log/os_task_handler.py b/providers/opensearch/src/airflow/providers/opensearch/log/os_task_handler.py index ea7d72f152df2..0c5d8f6c9f3c3 100644 --- a/providers/opensearch/src/airflow/providers/opensearch/log/os_task_handler.py +++ b/providers/opensearch/src/airflow/providers/opensearch/log/os_task_handler.py @@ -298,29 +298,17 @@ def _render_log_id(self, ti: TaskInstance | TaskInstanceKey, try_number: int) -> if USE_PER_RUN_LOG_ID: log_id_template = dag_run.get_log_template(session=session).elasticsearch_id - if TYPE_CHECKING: - assert ti.task - try: - dag = ti.task.dag - except AttributeError: # ti.task is not always set. - data_interval = (dag_run.data_interval_start, dag_run.data_interval_end) - else: - if TYPE_CHECKING: - assert dag is not None - # TODO: Task-SDK: Where should this function be? - data_interval = dag.get_run_data_interval(dag_run) # type: ignore[attr-defined] - if self.json_format: - data_interval_start = self._clean_date(data_interval[0]) - data_interval_end = self._clean_date(data_interval[1]) + data_interval_start = self._clean_date(dag_run.data_interval_start) + data_interval_end = self._clean_date(dag_run.data_interval_end) logical_date = self._clean_date(dag_run.logical_date) else: - if data_interval[0]: - data_interval_start = data_interval[0].isoformat() + if dag_run.data_interval_start: + data_interval_start = dag_run.data_interval_start.isoformat() else: data_interval_start = "" - if data_interval[1]: - data_interval_end = data_interval[1].isoformat() + if dag_run.data_interval_end: + data_interval_end = dag_run.data_interval_end.isoformat() else: data_interval_end = "" logical_date = dag_run.logical_date.isoformat() diff --git a/providers/redis/tests/unit/redis/log/test_redis_task_handler.py b/providers/redis/tests/unit/redis/log/test_redis_task_handler.py index eb03386ca7674..4d1e1927f265f 100644 --- a/providers/redis/tests/unit/redis/log/test_redis_task_handler.py +++ b/providers/redis/tests/unit/redis/log/test_redis_task_handler.py @@ -22,22 +22,29 @@ import pytest -from airflow.models import DAG, DagRun, TaskInstance -from airflow.models.serialized_dag import SerializedDagModel +from airflow.models import DagRun, TaskInstance from airflow.providers.redis.log.redis_task_handler import RedisTaskHandler from airflow.providers.standard.operators.empty import EmptyOperator from airflow.utils.session import create_session from airflow.utils.state import State -from airflow.utils.timezone import datetime from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.db import clear_db_dag_bundles, clear_db_dags, clear_db_runs from tests_common.test_utils.file_task_handler import extract_events from tests_common.test_utils.version_compat import ( AIRFLOW_V_3_0_PLUS, + AIRFLOW_V_3_1_PLUS, get_base_airflow_version_tuple, ) +if AIRFLOW_V_3_1_PLUS: + from airflow.sdk import DAG + from airflow.sdk.timezone import datetime +else: + from airflow.models import DAG + from airflow.utils.timezone import datetime # type: ignore[no-redef] + class TestRedisTaskHandler: @staticmethod @@ -72,20 +79,25 @@ def ti(self): dag_run.set_state(State.RUNNING) with create_session() as session: session.add(dag_run) - 
session.commit() + session.flush() session.refresh(dag_run) + bundle_name = "testing" + if AIRFLOW_V_3_1_PLUS: + sync_dag_to_db(dag, bundle_name=bundle_name, session=session) + elif AIRFLOW_V_3_0_PLUS: + from airflow.models.dagbundle import DagBundleModel + from airflow.models.serialized_dag import SerializedDagModel + from airflow.serialization.serialized_objects import SerializedDAG + + session.add(DagBundleModel(name=bundle_name)) + session.flush() + SerializedDAG.bulk_write_to_db(bundle_name, None, [dag]) + SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + if AIRFLOW_V_3_0_PLUS: from airflow.models.dag_version import DagVersion - from airflow.models.dagbundle import DagBundleModel - bundle_name = "testing" - with create_session() as session: - orm_dag_bundle = DagBundleModel(name=bundle_name) - session.add(orm_dag_bundle) - session.commit() - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) dag_version = DagVersion.get_latest_version(dag.dag_id) ti = TaskInstance(task=task, run_id=dag_run.run_id, dag_version_id=dag_version.id) else: diff --git a/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py index 3d91f29bcaff0..9d4d6a5c114ab 100644 --- a/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py +++ b/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py @@ -36,12 +36,16 @@ SnowflakeValueCheckOperator, ) from airflow.providers.snowflake.triggers.snowflake_trigger import SnowflakeSqlApiTrigger -from airflow.utils import timezone -from airflow.utils.session import create_session from airflow.utils.types import DagRunType +from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.db import clear_db_dag_bundles, clear_db_dags, clear_db_runs -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS + +if AIRFLOW_V_3_1_PLUS: + from airflow.sdk import timezone +else: + from airflow.utils import timezone # type: ignore[attr-defined,no-redef] DEFAULT_DATE = timezone.datetime(2015, 1, 1) DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat() @@ -237,16 +241,8 @@ def create_context(task, dag=None): logical_date = timezone.datetime(2022, 1, 1, 1, 0, 0, tzinfo=tzinfo) if AIRFLOW_V_3_0_PLUS: from airflow.models.dag_version import DagVersion - from airflow.models.dagbundle import DagBundleModel - from airflow.models.serialized_dag import SerializedDagModel - - bundle_name = "testing" - with create_session() as session: - orm_dag_bundle = DagBundleModel(name=bundle_name) - session.add(orm_dag_bundle) - session.commit() - DAG.bulk_write_to_db(bundle_name, None, [dag]) - SerializedDagModel.write_dag(dag, bundle_name=bundle_name) + + sync_dag_to_db(dag) dag_version = DagVersion.get_latest_version(dag.dag_id) task_instance = TaskInstance(task=task, run_id="test_run_id", dag_version_id=dag_version.id) dag_run = DagRun( @@ -350,7 +346,10 @@ def test_snowflake_sql_api_to_fails_when_one_query_fails( with pytest.raises(AirflowException): operator.execute(context=None) - @pytest.mark.parametrize("mock_sql, statement_count", [(SQL_MULTIPLE_STMTS, 4), (SINGLE_STMT, 1)]) + @pytest.mark.parametrize( + "mock_sql, statement_count", + [pytest.param(SQL_MULTIPLE_STMTS, 4, id="multi"), pytest.param(SINGLE_STMT, 1, id="single")], + ) 
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.execute_query") def test_snowflake_sql_api_execute_operator_async( self, mock_execute_query, mock_sql, statement_count, mock_get_sql_api_query_status diff --git a/providers/ssh/tests/unit/ssh/operators/test_ssh.py b/providers/ssh/tests/unit/ssh/operators/test_ssh.py index 2222db83cc74a..3b2cc03a3c051 100644 --- a/providers/ssh/tests/unit/ssh/operators/test_ssh.py +++ b/providers/ssh/tests/unit/ssh/operators/test_ssh.py @@ -27,18 +27,22 @@ from airflow.exceptions import AirflowException, AirflowSkipException, AirflowTaskTimeout from airflow.models import TaskInstance -from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.ssh.hooks.ssh import SSHHook from airflow.providers.ssh.operators.ssh import SSHOperator -from airflow.utils.timezone import datetime from airflow.utils.types import NOTSET from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.dag import sync_dag_to_db +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS if AIRFLOW_V_3_0_PLUS: from airflow.models.dag_version import DagVersion +if AIRFLOW_V_3_1_PLUS: + from airflow.sdk.timezone import datetime +else: + from airflow.utils.timezone import datetime # type: ignore[attr-defined,no-redef] + pytestmark = pytest.mark.db_test @@ -272,8 +276,7 @@ def test_push_ssh_exit_to_xcom(self, request, dag_maker): task = SSHOperator(task_id="push_xcom", ssh_hook=self.hook, command=command) dr = dag_maker.create_dagrun(run_id="push_xcom") if AIRFLOW_V_3_0_PLUS: - dag.sync_to_db() - SerializedDagModel.write_dag(dag, bundle_name="testing") + sync_dag_to_db(dag) dag_version = DagVersion.get_latest_version(dag.dag_id) ti = TaskInstance(task=task, run_id=dr.run_id, dag_version_id=dag_version.id) else: diff --git a/providers/standard/tests/unit/standard/decorators/test_bash.py b/providers/standard/tests/unit/standard/decorators/test_bash.py index cf61a137c3bea..0448f93e37d59 100644 --- a/providers/standard/tests/unit/standard/decorators/test_bash.py +++ b/providers/standard/tests/unit/standard/decorators/test_bash.py @@ -27,10 +27,9 @@ from airflow.exceptions import AirflowException, AirflowSkipException from airflow.models.renderedtifields import RenderedTaskInstanceFields -from airflow.utils import timezone from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_rendered_ti_fields -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS if TYPE_CHECKING: from airflow.models import TaskInstance @@ -43,6 +42,10 @@ # bad hack but does the job from airflow.decorators import task # type: ignore[attr-defined,no-redef] from airflow.utils.types import NOTSET as SET_DURING_EXECUTION # type: ignore[assignment] +if AIRFLOW_V_3_1_PLUS: + from airflow.sdk import timezone +else: + from airflow.utils import timezone # type: ignore[attr-defined,no-redef] DEFAULT_DATE = timezone.datetime(2023, 1, 1) @@ -73,16 +76,15 @@ def execute_task(self, task): return ti, return_val - @staticmethod - def validate_bash_command_rtif(ti, expected_command): + def validate_bash_command_rtif(self, ti, expected_command): if AIRFLOW_V_3_0_PLUS: - assert ti.task.overwrite_rtif_after_execution + assert self.dag.get_task(ti.task_id).overwrite_rtif_after_execution else: assert 
RenderedTaskInstanceFields.get_templated_fields(ti)["bash_command"] == expected_command def test_bash_decorator_init(self): """Test the initialization of the @task.bash decorator.""" - with self.dag: + with self.dag_maker: @task.bash def bash(): ... @@ -112,7 +114,7 @@ def bash(): ... ) def test_bash_command(self, command, expected_command, expected_return_val): """Test the runtime bash_command is the function's return string, rendered if needed.""" - with self.dag: + with self.dag_maker: @task.bash def bash(): @@ -130,7 +132,7 @@ def bash(): def test_op_args_kwargs(self): """Test op_args and op_kwargs are passed to the bash_command.""" - with self.dag: + with self.dag_maker: @task.bash def bash(id, other_id): @@ -155,7 +157,7 @@ def test_multiline_command(self): """ excepted_command = command.format(foo="foo") - with self.dag: + with self.dag_maker: @task.bash def bash(foo): @@ -181,7 +183,7 @@ def bash(foo): ) def test_env_variables(self, append_env, user_defined_env, expected_airflow_home, caplog): """Test env variables exist appropriately depending on if the existing env variables are allowed.""" - with self.dag: + with self.dag_maker: @task.bash(env=user_defined_env, append_env=append_env) def bash(): @@ -210,7 +212,7 @@ def bash(): ) def test_exit_code_behavior(self, exit_code, expected): """Test @task.bash tasks behave appropriately relative the exit code from the bash_command.""" - with self.dag: + with self.dag_maker: @task.bash def bash(code): @@ -254,7 +256,7 @@ def bash(code): ) def test_skip_on_exit_code_behavior(self, skip_on_exit_code, exit_code, expected): """Ensure tasks behave appropriately relative to defined skip exit code from the bash_command.""" - with self.dag: + with self.dag_maker: @task.bash(**skip_on_exit_code if skip_on_exit_code else {}) def bash(code): @@ -300,18 +302,17 @@ def test_env_variables_in_bash_command_file( # setting chmod +x test_file.sh cmd_file.chmod(0o755) - with self.dag: + with self.dag_maker: @task.bash(env=user_defined_env, append_env=append_env) def bash(command_file_name): return command_file_name - with mock.patch.dict("os.environ", {"AIRFLOW_HOME": "path/to/airflow/home"}): - bash_task = bash(f"{cmd_file} ") - - assert bash_task.operator.bash_command == SET_DURING_EXECUTION + bash_task = bash(f"{cmd_file} ") + assert bash_task.operator.bash_command == SET_DURING_EXECUTION - ti, return_val = self.execute_task(bash_task) + with mock.patch.dict("os.environ", {"AIRFLOW_HOME": "path/to/airflow/home"}): + ti, return_val = self.execute_task(bash_task) assert f"razz={expected_razz}" in caplog.text assert f"AIRFLOW_HOME={expected_airflow_home}" in caplog.text @@ -323,7 +324,7 @@ def test_valid_cwd(self, tmp_path): cwd_path = tmp_path / "test_cwd" cwd_path.mkdir() - with self.dag: + with self.dag_maker: @task.bash(cwd=os.fspath(cwd_path)) def bash(): @@ -402,7 +403,7 @@ def bash(): def test_multiple_outputs_true(self): """Verify setting `multiple_outputs` for a @task.bash-decorated function is ignored.""" - with self.dag: + with self.dag_maker: @task.bash(multiple_outputs=True) def bash(): @@ -415,7 +416,7 @@ def bash(): assert bash_task.operator.bash_command == SET_DURING_EXECUTION - ti, _ = self.execute_task(bash_task) + ti, _ = self.execute_task(bash_task) assert bash_task.operator.multiple_outputs is False self.validate_bash_command_rtif(ti, "echo") @@ -430,20 +431,18 @@ def test_multiple_outputs(self, multiple_outputs): if multiple_outputs is not SET_DURING_EXECUTION: decorator_kwargs["multiple_outputs"] = multiple_outputs - with self.dag: + 
with self.dag_maker: @task.bash(**decorator_kwargs) def bash(): return "echo" - with warnings.catch_warnings(): - warnings.simplefilter("error", category=UserWarning) - - bash_task = bash() - - assert bash_task.operator.bash_command == SET_DURING_EXECUTION + bash_task = bash() + assert bash_task.operator.bash_command == SET_DURING_EXECUTION - ti, _ = self.execute_task(bash_task) + with warnings.catch_warnings(): + warnings.simplefilter("error", category=UserWarning) + ti, _ = self.execute_task(bash_task) assert bash_task.operator.multiple_outputs is False self.validate_bash_command_rtif(ti, "echo") @@ -465,7 +464,7 @@ def bash(): ) def test_callable_return_is_string(self, return_val, expected): """Ensure the returned value from the decorated callable is a non-empty string.""" - with self.dag: + with self.dag_maker: @task.bash def bash(): diff --git a/providers/standard/tests/unit/standard/decorators/test_python.py b/providers/standard/tests/unit/standard/decorators/test_python.py index e7a3b1697604b..80b7e4f2b56a4 100644 --- a/providers/standard/tests/unit/standard/decorators/test_python.py +++ b/providers/standard/tests/unit/standard/decorators/test_python.py @@ -25,16 +25,14 @@ from airflow.exceptions import AirflowException, XComNotFound from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap - -try: - from airflow.sdk import TriggerRule -except ImportError: - # Compatibility for Airflow < 3.1 - from airflow.utils.trigger_rule import TriggerRule # type: ignore[no-redef,attr-defined] -from airflow.utils import timezone from airflow.utils.task_instance_session import set_current_task_instance_session -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1, AIRFLOW_V_3_0_PLUS, XCOM_RETURN_KEY +from tests_common.test_utils.version_compat import ( + AIRFLOW_V_3_0_1, + AIRFLOW_V_3_0_PLUS, + AIRFLOW_V_3_1_PLUS, + XCOM_RETURN_KEY, +) from unit.standard.operators.test_python import BasePythonTest if AIRFLOW_V_3_0_PLUS: @@ -42,7 +40,6 @@ from airflow.sdk.bases.decorator import DecoratedMappedOperator from airflow.sdk.definitions._internal.expandinput import DictOfListsExpandInput from airflow.sdk.definitions.mappedoperator import MappedOperator - else: from airflow.decorators import ( # type: ignore[attr-defined,no-redef] setup, @@ -51,12 +48,18 @@ ) from airflow.decorators.base import DecoratedMappedOperator # type: ignore[no-redef] from airflow.models.baseoperator import BaseOperator # type: ignore[no-redef] - from airflow.models.dag import DAG # type: ignore[assignment] + from airflow.models.dag import DAG # type: ignore[assignment,no-redef] from airflow.models.expandinput import DictOfListsExpandInput from airflow.models.mappedoperator import MappedOperator # type: ignore[assignment,no-redef] from airflow.models.xcom_arg import XComArg from airflow.utils.task_group import TaskGroup # type: ignore[no-redef] +if AIRFLOW_V_3_1_PLUS: + from airflow.sdk import TriggerRule, timezone +else: + from airflow.utils import timezone # type: ignore[attr-defined,no-redef] + from airflow.utils.trigger_rule import TriggerRule # type: ignore[no-redef,attr-defined] + pytestmark = pytest.mark.db_test diff --git a/providers/standard/tests/unit/standard/operators/test_python.py b/providers/standard/tests/unit/standard/operators/test_python.py index e0c9d70245adb..62ac1f1ae39a5 100644 --- a/providers/standard/tests/unit/standard/operators/test_python.py +++ b/providers/standard/tests/unit/standard/operators/test_python.py @@ -63,30 +63,30 @@ get_current_context, ) 
from airflow.providers.standard.utils.python_virtualenv import execute_in_subprocess, prepare_virtualenv - -try: - from airflow.sdk import timezone -except ImportError: - from airflow.utils import timezone # type: ignore[attr-defined,no-redef] -try: - from airflow.sdk import TriggerRule -except ImportError: - # Compatibility for Airflow < 3.1 - from airflow.utils.trigger_rule import TriggerRule # type: ignore[no-redef,attr-defined] from airflow.utils.session import create_session from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.types import NOTSET, DagRunType from tests_common.test_utils.db import clear_db_runs -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1, AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1, AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseOperator from airflow.sdk.execution_time.context import set_current_context + from airflow.serialization.serialized_objects import LazyDeserializedDAG else: from airflow.models.baseoperator import BaseOperator # type: ignore[no-redef] from airflow.models.taskinstance import set_current_context # type: ignore[attr-defined,no-redef] +try: + from airflow.sdk import timezone +except ImportError: + from airflow.utils import timezone # type: ignore[attr-defined,no-redef] +try: + from airflow.sdk import TriggerRule +except ImportError: + # Compatibility for Airflow < 3.1 + from airflow.utils.trigger_rule import TriggerRule # type: ignore[no-redef,attr-defined] if TYPE_CHECKING: from airflow.models.dag import DAG @@ -164,7 +164,12 @@ def create_dag_run(self) -> DagRun: from airflow.models.serialized_dag import SerializedDagModel # Update the serialized DAG with any tasks added after initial dag was created - self.dag_maker.serialized_model = SerializedDagModel(self.dag_non_serialized) + if AIRFLOW_V_3_1_PLUS: + self.dag_maker.serialized_model = SerializedDagModel( + LazyDeserializedDAG.from_dag(self.dag_non_serialized) + ) + else: + self.dag_maker.serialized_model = SerializedDagModel(self.dag_non_serialized) return self.dag_maker.create_dagrun( state=DagRunState.RUNNING, start_date=self.dag_maker.start_date, diff --git a/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py b/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py index cbd80bb9def43..dc3f32d535035 100644 --- a/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py +++ b/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py @@ -23,6 +23,7 @@ import pytest import time_machine +from sqlalchemy import update from airflow.configuration import conf from airflow.exceptions import AirflowException, DagRunAlreadyExists, TaskDeferred @@ -32,16 +33,19 @@ from airflow.models.taskinstance import TaskInstance from airflow.providers.standard.operators.trigger_dagrun import DagIsPaused, TriggerDagRunOperator from airflow.providers.standard.triggers.external_task import DagStateTrigger -from airflow.utils import timezone from airflow.utils.session import create_session from airflow.utils.state import DagRunState, TaskInstanceState from airflow.utils.types import DagRunType from tests_common.test_utils.db import parse_and_sync_to_db -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS if AIRFLOW_V_3_0_PLUS: from airflow.exceptions import DagRunTriggerException +if AIRFLOW_V_3_1_PLUS: 
+ from airflow.sdk import timezone +else: + from airflow.utils import timezone # type: ignore[attr-defined,no-redef] pytestmark = pytest.mark.db_test @@ -749,9 +753,12 @@ def test_dagstatetrigger_run_id_with_clear_and_reset(self, dag_maker): # The second DagStateTrigger call should still use the original `logical_date` value. assert mock_task_defer.call_args_list[1].kwargs["trigger"].run_ids == [run_id] - def test_trigger_dagrun_with_fail_when_dag_is_paused(self, dag_maker): + def test_trigger_dagrun_with_fail_when_dag_is_paused(self, dag_maker, session): """Test TriggerDagRunOperator with fail_when_dag_is_paused set to True.""" - self.dag_model.set_is_paused(True) + session.execute( + update(DagModel).where(DagModel.dag_id == self.dag_model.dag_id).values(is_paused=True) + ) + session.commit() with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True diff --git a/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py index 76688b5cf3334..a5a068f088f7c 100644 --- a/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py +++ b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py @@ -50,9 +50,10 @@ from airflow.serialization.serialized_objects import SerializedBaseOperator from airflow.timetables.base import DataInterval from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.state import DagRunState, State, TaskInstanceState +from airflow.utils.state import DagRunState, State from airflow.utils.types import DagRunType +from tests_common.test_utils.dag import create_scheduler_dag, sync_dag_to_db, sync_dags_to_db from tests_common.test_utils.db import clear_db_runs from tests_common.test_utils.mock_operators import MockOperator from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS @@ -1685,7 +1686,7 @@ def run_tasks( for dag in dag_bag.dags.values(): data_interval = DataInterval(coerce_datetime(logical_date), coerce_datetime(logical_date)) if AIRFLOW_V_3_0_PLUS: - runs[dag.dag_id] = dagrun = dag.create_dagrun( + runs[dag.dag_id] = dagrun = create_scheduler_dag(dag).create_dagrun( run_id=dag.timetable.generate_run_id( run_type=DagRunType.MANUAL, run_after=logical_date, @@ -1701,7 +1702,7 @@ def run_tasks( session=session, ) else: - runs[dag.dag_id] = dagrun = dag.create_dagrun( # type: ignore[call-arg] + runs[dag.dag_id] = dagrun = dag.create_dagrun( # type: ignore[attr-defined,call-arg] run_id=dag.timetable.generate_run_id( # type: ignore[call-arg] run_type=DagRunType.MANUAL, logical_date=logical_date, @@ -1916,9 +1917,7 @@ def _factory(depth: int) -> DagBag: for dag in dags: if AIRFLOW_V_3_0_PLUS: - dag.sync_to_db() - SerializedDagModel.write_dag(dag, bundle_name="testing") - dag_bag.bag_dag(dag=dag) + dag_bag.bag_dag(dag=sync_dag_to_db(dag)) else: dag_bag.bag_dag(dag=dag, root_dag=dag) # type: ignore[call-arg] @@ -1993,34 +1992,11 @@ def dag_bag_multiple(session): begin >> task if AIRFLOW_V_3_0_PLUS: - from airflow.models.dagbundle import DagBundleModel - - bundle_name = "abcbunhdlerch3rc" - session.merge(DagBundleModel(name=bundle_name)) - session.flush() - DAG.bulk_write_to_db(bundle_name=bundle_name, dags=[daily_dag, agg_dag], bundle_version=None) - SerializedDagModel.write_dag(dag=daily_dag, bundle_name=bundle_name) - SerializedDagModel.write_dag(dag=agg_dag, bundle_name=bundle_name) + sync_dags_to_db([agg_dag, daily_dag]) return 
dag_bag -def test_clear_multiple_external_task_marker(dag_bag_multiple): - """ - Test clearing a dag that has multiple ExternalTaskMarker. - """ - agg_dag = dag_bag_multiple.get_dag("agg_dag") - _, tis = run_tasks(dag_bag_multiple, logical_date=DEFAULT_DATE) - session = settings.Session() - try: - qry = session.query(TaskInstance).filter( - TaskInstance.state.is_(None), TaskInstance.dag_id.in_(dag_bag_multiple.dag_ids) - ) - assert agg_dag.clear(dag_bag=dag_bag_multiple) == len(tis) == qry.count() == 10 - finally: - session.close() - - @pytest.fixture def dag_bag_head_tail(session): """ @@ -2058,101 +2034,16 @@ def dag_bag_head_tail(session): head >> body >> tail if AIRFLOW_V_3_0_PLUS: - from airflow.models.dagbundle import DagBundleModel - - dag_bag.bag_dag(dag=dag) - bundle_name = "9e8uh9odhu9c" - session.merge(DagBundleModel(name=bundle_name)) - session.flush() - DAG.bulk_write_to_db(bundle_name=bundle_name, dags=[dag], bundle_version=None) - SerializedDagModel.write_dag(dag=dag, bundle_name=bundle_name) + dag_bag.bag_dag(dag) + sync_dag_to_db(dag) else: dag_bag.bag_dag(dag=dag, root_dag=dag) return dag_bag -@provide_session -def test_clear_overlapping_external_task_marker(dag_bag_head_tail, session): - dag: DAG = dag_bag_head_tail.get_dag("head_tail") - dag.sync_to_db() - if AIRFLOW_V_3_0_PLUS: - SerializedDagModel.write_dag(dag, bundle_name="testing") - - # "Run" 10 times. - for delta in range(10): - logical_date = DEFAULT_DATE + timedelta(days=delta) - dagrun = DagRun( - dag_id=dag.dag_id, - start_date=logical_date, - state=DagRunState.SUCCESS, - run_type=DagRunType.MANUAL, - run_id=f"test_{delta}", - ) - if AIRFLOW_V_3_0_PLUS: - dagrun.logical_date = logical_date - else: - dagrun.execution_date = logical_date - session.add(dagrun) - for task in dag.tasks: - if AIRFLOW_V_3_0_PLUS: - dag_version = DagVersion.get_latest_version(task.dag_id, session=session) - ti = TaskInstance(task=task, dag_version_id=dag_version.id) - - else: - ti = TaskInstance(task=task) - dagrun.task_instances.append(ti) - ti.state = TaskInstanceState.SUCCESS - session.flush() - - assert dag.clear(start_date=DEFAULT_DATE, dag_bag=dag_bag_head_tail, session=session) == 30 - - -def test_clear_overlapping_external_task_marker_with_end_date(dag_bag_head_tail, session): - dag: DAG = dag_bag_head_tail.get_dag("head_tail") - dag.sync_to_db() - if AIRFLOW_V_3_0_PLUS: - SerializedDagModel.write_dag(dag=dag, bundle_name="testing") - - # "Run" 10 times. - for delta in range(10): - logical_date = DEFAULT_DATE + timedelta(days=delta) - dagrun = DagRun( - dag_id=dag.dag_id, - start_date=logical_date, - state=DagRunState.SUCCESS, - run_type=DagRunType.MANUAL, - run_id=f"test_{delta}", - ) - if AIRFLOW_V_3_0_PLUS: - dagrun.logical_date = logical_date - else: - dagrun.execution_date = logical_date - session.add(dagrun) - - for task in dag.tasks: - if AIRFLOW_V_3_0_PLUS: - dag_version = DagVersion.get_latest_version(dag.dag_id, session=session) - ti = TaskInstance(task=task, dag_version_id=dag_version.id) - else: - ti = TaskInstance(task=task) - dagrun.task_instances.append(ti) - ti.state = TaskInstanceState.SUCCESS - session.flush() - - assert ( - dag.clear( - start_date=DEFAULT_DATE, - end_date=logical_date, - dag_bag=dag_bag_head_tail, - session=session, - ) - == 30 - ) - - @pytest.fixture -def dag_bag_head_tail_mapped_tasks(): +def dag_bag_head_tail_mapped_tasks(session): """ Create a DagBag containing one DAG, with task "head" depending on task "tail" of the previous logical_date. 
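Where a test only needs to create a DagRun and never touches the DB-backed serialized model, the patch instead uses create_scheduler_dag from the same test-utils module (see the openlineage listener tests and the run_tasks helper above): it takes the task-SDK DAG and returns an object exposing the scheduler-side create_dagrun API. Assuming a plain serialization round trip is sufficient (the real helper may well do more), it could look roughly like this:

# Hypothetical sketch; the actual helper lives in tests_common.test_utils.dag.
from airflow.serialization.serialized_objects import SerializedDAG


def create_scheduler_dag(dag):
    """Convert a task-SDK DAG into a scheduler-side SerializedDAG without touching the DB."""
    return SerializedDAG.from_dict(SerializedDAG.to_dict(dag))

Call sites then read create_scheduler_dag(dag).create_dagrun(...) where they previously called dag.create_dagrun(...) on the task-SDK object directly.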
@@ -2194,79 +2085,8 @@ def dummy_task(x: int): head >> body >> tail if AIRFLOW_V_3_0_PLUS: - dag.sync_to_db() - dag_bag.bag_dag(dag=dag) + sync_dag_to_db(dag) else: dag_bag.bag_dag(dag=dag, root_dag=dag) return dag_bag - - -def test_clear_overlapping_external_task_marker_mapped_tasks(dag_bag_head_tail_mapped_tasks, session): - dag: DAG = dag_bag_head_tail_mapped_tasks.get_dag("head_tail") - dag.sync_to_db() - if AIRFLOW_V_3_0_PLUS: - SerializedDagModel.write_dag(dag=dag, bundle_name="testing") - # "Run" 10 times. - for delta in range(10): - logical_date = DEFAULT_DATE + timedelta(days=delta) - dagrun = DagRun( - dag_id=dag.dag_id, - start_date=logical_date, - state=DagRunState.SUCCESS, - run_type=DagRunType.MANUAL, - run_id=f"test_{delta}", - ) - if AIRFLOW_V_3_0_PLUS: - dagrun.logical_date = logical_date - else: - dagrun.execution_date = logical_date - session.add(dagrun) - for task in dag.tasks: - if task.task_id == "dummy_task": - for map_index in range(5): - if AIRFLOW_V_3_0_PLUS: - dag_version = DagVersion.get_latest_version(dag.dag_id, session=session) - ti = TaskInstance( - task=task, - run_id=dagrun.run_id, - map_index=map_index, - dag_version_id=dag_version.id, - ) - else: - ti = TaskInstance(task=task, run_id=dagrun.run_id, map_index=map_index) - ti.state = TaskInstanceState.SUCCESS - dagrun.task_instances.append(ti) - else: - if AIRFLOW_V_3_0_PLUS: - dag_version = DagVersion.get_latest_version(dag.dag_id, session=session) - ti = TaskInstance(task=task, run_id=dagrun.run_id, dag_version_id=dag_version.id) - else: - ti = TaskInstance(task=task, run_id=dagrun.run_id) - ti.state = TaskInstanceState.SUCCESS - dagrun.task_instances.append(ti) - session.flush() - if AIRFLOW_V_3_0_PLUS: - dag = dag.partial_subset( - task_ids=["head"], - include_downstream=True, - include_upstream=False, - ) - else: - dag = dag.partial_subset( - task_ids_or_regex=["head"], - include_downstream=True, - include_upstream=False, - ) - - task_ids = list(dag.task_dict) - assert ( - dag.clear( - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - dag_bag=dag_bag_head_tail_mapped_tasks, - session=session, - task_ids=task_ids, - ) - == 70 - ) diff --git a/providers/standard/tests/unit/standard/sensors/test_time_delta.py b/providers/standard/tests/unit/standard/sensors/test_time_delta.py index 8e85be40f9c3f..4c2ad8129b60b 100644 --- a/providers/standard/tests/unit/standard/sensors/test_time_delta.py +++ b/providers/standard/tests/unit/standard/sensors/test_time_delta.py @@ -82,15 +82,15 @@ def test_timedelta_sensor_run_after_vs_interval(run_after, interval_end, dag_mak if interval_end: context["data_interval_end"] = interval_end delta = timedelta(seconds=1) - with dag_maker() as dag: - op = TimeDeltaSensor(task_id="wait_sensor_check", delta=delta, dag=dag, mode="reschedule") + with dag_maker(): + op = TimeDeltaSensor(task_id="wait_sensor_check", delta=delta, mode="reschedule") kwargs = {} if AIRFLOW_V_3_0_PLUS: from airflow.utils.types import DagRunTriggeredByType kwargs.update(triggered_by=DagRunTriggeredByType.TEST, run_after=run_after) - dr = dag.create_dagrun( + dr = dag_maker.create_dagrun( run_id="abcrhroceuh", run_type=DagRunType.MANUAL, state=None, @@ -120,7 +120,7 @@ def test_timedelta_sensor_deferrable_run_after_vs_interval(run_after, interval_e if interval_end: context["data_interval_end"] = interval_end - with dag_maker() as dag: + with dag_maker(): kwargs = {} if AIRFLOW_V_3_0_PLUS: from airflow.utils.types import DagRunTriggeredByType @@ -131,11 +131,10 @@ def 
test_timedelta_sensor_deferrable_run_after_vs_interval(run_after, interval_e sensor = TimeDeltaSensor( task_id="timedelta_sensor_deferrable", delta=delta, - dag=dag, deferrable=True, # <-- the feature under test ) - dr = dag.create_dagrun( + dr = dag_maker.create_dagrun( run_id="abcrhroceuh", run_type=DagRunType.MANUAL, state=None, @@ -222,7 +221,7 @@ def test_timedelta_sensor_async_run_after_vs_interval(self, run_after, interval_ kwargs.update(triggered_by=DagRunTriggeredByType.TEST, run_after=run_after) - dr = dag.create_dagrun( + dr = dag_maker.create_dagrun( run_id="abcrhroceuh", run_type=DagRunType.MANUAL, state=None, diff --git a/providers/standard/tests/unit/standard/utils/test_sensor_helper.py b/providers/standard/tests/unit/standard/utils/test_sensor_helper.py index 89735e206db5a..156d2958a54d8 100644 --- a/providers/standard/tests/unit/standard/utils/test_sensor_helper.py +++ b/providers/standard/tests/unit/standard/utils/test_sensor_helper.py @@ -33,18 +33,21 @@ _get_count, _get_external_task_group_task_ids, ) -from airflow.utils import timezone from airflow.utils.state import DagRunState, TaskInstanceState from airflow.utils.types import DagRunType from tests_common.test_utils import db +from tests_common.test_utils.dag import create_scheduler_dag from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS try: from airflow.sdk.definitions.taskgroup import TaskGroup -except ImportError: - # Fallback for Airflow < 3.1 +except ImportError: # Fallback for Airflow < 3.1 from airflow.utils.task_group import TaskGroup # type: ignore[no-redef] +try: + from airflow.sdk import timezone +except ImportError: # Fallback for Airflow < 3.1 + from airflow.utils import timezone # type: ignore[attr-defined,no-redef] if TYPE_CHECKING: from sqlalchemy.orm.session import Session @@ -92,7 +95,7 @@ def create_dag_run( execution_date = pendulum.instance(execution_date or now) run_type = DagRunType.MANUAL data_interval = dag.timetable.infer_manual_data_interval(run_after=execution_date) - dag_run = dag.create_dagrun( + dag_run = create_scheduler_dag(dag).create_dagrun( run_id=dag.timetable.generate_run_id( run_type=run_type, run_after=execution_date, diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index 23c0138a4b9b3..551126e5a3ce4 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -93,7 +93,7 @@ DagStateChangeCallback = Callable[[Context], None] ScheduleInterval = None | str | timedelta | relativedelta -ScheduleArg = ScheduleInterval | Timetable | BaseAsset | Collection[BaseAsset] +ScheduleArg: TypeAlias = ScheduleInterval | Timetable | BaseAsset | Collection[BaseAsset] _DAG_HASH_ATTRS = frozenset( @@ -135,12 +135,18 @@ def _create_timetable(interval: ScheduleInterval, timezone: Timezone | FixedTime raise ValueError(f"{interval!r} is not a valid schedule.") -def _config_bool_factory(section: str, key: str): +def _config_bool_factory(section: str, key: str) -> Callable[[], bool]: from airflow.configuration import conf return functools.partial(conf.getboolean, section, key) +def _config_int_factory(section: str, key: str) -> Callable[[], int]: + from airflow.configuration import conf + + return functools.partial(conf.getint, section, key) + + def _convert_params(val: abc.MutableMapping | None, self_: DAG) -> ParamsDict: """ Convert the plain dict into a ParamsDict. 
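_config_int_factory mirrors the existing _config_bool_factory: it returns a zero-argument callable (a functools.partial around conf.getint), so the configuration is read lazily when a DAG object is actually constructed rather than at import time. The field definitions in the next hunk use it twice per attribute, once as the attrs factory for the default and once inside a default_if_none converter, so that passing an explicit None also falls back to the configured value. A small self-contained illustration of that attrs pattern, with a plain callable standing in for the config lookup, behaves like this:

# Illustration of the factory + default_if_none pattern used in the next hunk;
# _default() here just stands in for conf.getint("core", "max_active_tasks_per_dag").
import attrs


def _default() -> int:
    return 16


@attrs.define
class Limits:
    max_active_tasks: int = attrs.field(
        factory=_default,
        converter=attrs.converters.default_if_none(factory=_default),
        validator=attrs.validators.instance_of(int),
    )


assert Limits().max_active_tasks == 16      # no value passed: factory default applies
assert Limits(None).max_active_tasks == 16  # explicit None: converter falls back to the default
assert Limits(4).max_active_tasks == 4      # explicit value wins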
@@ -170,10 +176,18 @@ def _convert_tags(tags: Collection[str] | None) -> MutableSet[str]: return set(tags or []) -def _convert_access_control(value, self_: DAG): - if hasattr(self_, "_upgrade_outdated_dag_access_control"): - return self_._upgrade_outdated_dag_access_control(value) - return value +def _convert_access_control(access_control): + if access_control is None: + return None + updated_access_control = {} + for role, perms in access_control.items(): + updated_access_control[role] = updated_access_control.get(role, {}) + if isinstance(perms, (set, list)): + # Support for old-style access_control where only the actions are specified + updated_access_control[role]["DAGs"] = set(perms) + else: + updated_access_control[role] = perms + return updated_access_control def _convert_doc_md(doc_md: str | None) -> str | None: @@ -398,9 +412,33 @@ def __rich_repr__(self): template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined user_defined_macros: dict | None = None user_defined_filters: dict | None = None - max_active_tasks: int = attrs.field(default=16, validator=attrs.validators.instance_of(int)) - max_active_runs: int = attrs.field(default=16, validator=attrs.validators.instance_of(int)) - max_consecutive_failed_dag_runs: int = attrs.field(default=0, validator=attrs.validators.instance_of(int)) + max_active_tasks: int = attrs.field( + factory=_config_int_factory("core", "max_active_tasks_per_dag"), + converter=attrs.converters.default_if_none( # type: ignore[misc] + # attrs only supports named callables or lambdas, but partial works + # OK here too. This is a false positive from attrs's Mypy plugin. + factory=_config_int_factory("core", "max_active_tasks_per_dag"), + ), + validator=attrs.validators.instance_of(int), + ) + max_active_runs: int = attrs.field( + factory=_config_int_factory("core", "max_active_runs_per_dag"), + converter=attrs.converters.default_if_none( # type: ignore[misc] + # attrs only supports named callables or lambdas, but partial works + # OK here too. This is a false positive from attrs's Mypy plugin. + factory=_config_int_factory("core", "max_active_runs_per_dag"), + ), + validator=attrs.validators.instance_of(int), + ) + max_consecutive_failed_dag_runs: int = attrs.field( + factory=_config_int_factory("core", "max_consecutive_failed_dag_runs_per_dag"), + converter=attrs.converters.default_if_none( # type: ignore[misc] + # attrs only supports named callables or lambdas, but partial works + # OK here too. This is a false positive from attrs's Mypy plugin. + factory=_config_int_factory("core", "max_consecutive_failed_dag_runs_per_dag"), + ), + validator=attrs.validators.instance_of(int), + ) dagrun_timeout: timedelta | None = attrs.field( default=None, validator=attrs.validators.optional(attrs.validators.instance_of(timedelta)), @@ -423,7 +461,7 @@ def __rich_repr__(self): ) access_control: dict[str, dict[str, Collection[str]]] | None = attrs.field( default=None, - converter=attrs.Converter(_convert_access_control, takes_self=True), # type: ignore[misc, call-overload] + converter=attrs.Converter(_convert_access_control), # type: ignore[misc, call-overload] ) is_paused_upon_creation: bool | None = None jinja_environment_kwargs: dict | None = None @@ -455,6 +493,12 @@ def __rich_repr__(self): factory=_config_bool_factory("dag_processor", "disable_bundle_versioning") ) + # TODO (GH-52141): This is never used in the sdk dag (it only makes sense + # after this goes through the dag processor), but various parts of the code + # depends on its existence. 
We should remove this after completely splitting + # DAG classes in the SDK and scheduler. + last_loaded: datetime | None = attrs.field(init=False, default=None) + def __attrs_post_init__(self): from airflow.sdk import timezone @@ -1097,8 +1141,7 @@ def test( from airflow import settings from airflow.configuration import secrets_backend_list - from airflow.models.dag import DAG as SchedulerDAG, _get_or_create_dagrun - from airflow.models.dagrun import DagRun + from airflow.models.dagrun import DagRun, get_or_create_dagrun from airflow.sdk import timezone from airflow.secrets.local_filesystem import LocalFilesystemBackend from airflow.serialization.serialized_objects import SerializedDAG @@ -1147,7 +1190,7 @@ def add_logger_if_needed(ti: TaskInstance): log.debug("Clearing existing task instances for logical date %s", logical_date) # TODO: Replace with calling client.dag_run.clear in Execution API at some point - SchedulerDAG.clear_dags( + SerializedDAG.clear_dags( dags=[self], start_date=logical_date, end_date=logical_date, @@ -1170,7 +1213,7 @@ def add_logger_if_needed(ti: TaskInstance): scheduler_dag.on_success_callback = self.on_success_callback scheduler_dag.on_failure_callback = self.on_failure_callback - dr: DagRun = _get_or_create_dagrun( + dr: DagRun = get_or_create_dagrun( dag=scheduler_dag, start_date=logical_date or run_after, logical_date=logical_date,