9 changes: 4 additions & 5 deletions airflow-core/src/airflow/__init__.py
@@ -80,21 +80,20 @@
 
 # Things to lazy import in form {local_name: ('target_module', 'target_name', 'deprecated')}
 __lazy_imports: dict[str, tuple[str, str, bool]] = {
-    "DAG": (".models.dag", "DAG", False),
-    "Asset": (".assets", "Asset", False),
+    "DAG": (".sdk", "DAG", False),
+    "Asset": (".sdk", "Asset", False),
     "XComArg": (".models.xcom_arg", "XComArg", False),
     "version": (".version", "", False),
     # Deprecated lazy imports
     "AirflowException": (".exceptions", "AirflowException", True),
-    "Dataset": (".sdk.definitions.asset", "Asset", True),
+    "Dataset": (".sdk", "Asset", True),
 }
 if TYPE_CHECKING:
     # These objects are imported by PEP-562, however, static analyzers and IDE's
     # have no idea about typing of these objects.
     # Add it under TYPE_CHECKING block should help with it.
-    from airflow.models.dag import DAG
     from airflow.models.xcom_arg import XComArg
-    from airflow.sdk.definitions.asset import Asset, Dataset
+    from airflow.sdk import DAG, Asset, Asset as Dataset
 
 
 def __getattr__(name: str):
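Note: the module-level `__getattr__` that consumes `__lazy_imports` is what the PEP-562 comment above refers to. A minimal sketch of how such a hook can resolve entries from that table — the real Airflow implementation differs in details such as the exact warning text:

import importlib
import warnings


def __getattr__(name: str):
    # PEP 562: called for module attribute accesses that miss the module dict.
    try:
        module_path, attr_name, deprecated = __lazy_imports[name]
    except KeyError:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from None
    if deprecated:
        # Deprecated aliases (e.g. Dataset -> Asset) still resolve, with a warning.
        warnings.warn(f"Import {name!r} from {__name__!r} is deprecated.", DeprecationWarning, stacklevel=2)
    module = importlib.import_module(module_path, __name__)
    # An empty attribute name (as for "version") maps to the module itself.
    value = getattr(module, attr_name) if attr_name else module
    globals()[name] = value  # cache so the hook runs once per name
    return value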
58 changes: 34 additions & 24 deletions airflow-core/src/airflow/api/common/mark_tasks.py
@@ -33,9 +33,8 @@
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session as SASession
 
-    from airflow.models.dag import DAG
     from airflow.models.mappedoperator import MappedOperator
-    from airflow.serialization.serialized_objects import SerializedBaseOperator
+    from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
 
     Operator: TypeAlias = MappedOperator | SerializedBaseOperator
 
@@ -86,9 +85,6 @@ def set_state(
     if not run_id:
         raise ValueError("Received tasks with no run_id")
 
-    if TYPE_CHECKING:
-        assert isinstance(dag, DAG)
-
     dag_run_ids = get_run_ids(dag, run_id, future, past, session=session)
     task_id_map_index_list = list(find_task_relatives(tasks, downstream, upstream))
     # now look for the task instances that are affected
@@ -106,7 +102,7 @@
 
 
 def get_all_dag_task_query(
-    dag: DAG,
+    dag: SerializedDAG,
     state: TaskInstanceState,
     task_ids: list[str | tuple[str, int]],
     run_ids: Iterable[str],
@@ -142,30 +138,44 @@ def find_task_relatives(tasks, downstream, upstream):
 
 
 @provide_session
-def get_run_ids(dag: DAG, run_id: str, future: bool, past: bool, session: SASession = NEW_SESSION):
+def get_run_ids(dag: SerializedDAG, run_id: str, future: bool, past: bool, session: SASession = NEW_SESSION):
     """Return DAG executions' run_ids."""
-    current_dagrun = dag.get_dagrun(run_id=run_id, session=session)
-    if current_dagrun.logical_date is None:
+    current_logical_date = session.scalar(
+        select(DagRun.logical_date).where(DagRun.dag_id == dag.dag_id, DagRun.run_id == run_id)
+    )
+    if current_logical_date is None:
         return [run_id]
 
-    last_dagrun = dag.get_last_dagrun(include_manually_triggered=True, session=session)
-    first_dagrun = session.scalar(
-        select(DagRun)
+    last_logical_date = session.scalar(
+        select(DagRun.logical_date)
         .where(DagRun.dag_id == dag.dag_id, DagRun.logical_date.is_not(None))
-        .order_by(DagRun.logical_date.asc())
+        .order_by(DagRun.logical_date.desc())
         .limit(1)
     )
-    if last_dagrun is None:
+    if last_logical_date is None:
         raise ValueError(f"DagRun for {dag.dag_id} not found")
 
+    first_logical_date = session.scalar(
+        select(DagRun.logical_date)
+        .where(DagRun.dag_id == dag.dag_id, DagRun.logical_date.is_not(None))
+        .order_by(DagRun.logical_date.asc())
+        .limit(1)
+    )
+
     # determine run_id range of dag runs and tasks to consider
-    end_date = last_dagrun.logical_date if future else current_dagrun.logical_date
-    start_date = current_dagrun.logical_date if not past else first_dagrun.logical_date
+    end_date = last_logical_date if future else current_logical_date
+    start_date = current_logical_date if not past else first_logical_date
     if not dag.timetable.can_be_scheduled:
-        # If the DAG never schedules, need to look at existing DagRun if the user wants future or
-        # past runs.
-        dag_runs = dag.get_dagruns_between(start_date=start_date, end_date=end_date, session=session)
-        run_ids = sorted({d.run_id for d in dag_runs})
+        # If the DAG never schedules, need to look at existing DagRun if the
+        # user wants future or past runs.
+        dag_runs = session.scalars(
+            select(DagRun).where(
+                DagRun.dag_id == dag.dag_id,
+                DagRun.logical_date >= start_date,
+                DagRun.logical_date <= end_date,
+            )
+        )
+        run_ids = sorted(d.run_id for d in dag_runs)
     elif not dag.timetable.periodic:
         run_ids = [run_id]
     else:
@@ -195,7 +205,7 @@ def _set_dag_run_state(dag_id: str, run_id: str, state: DagRunState, session: SA
 @provide_session
 def set_dag_run_state_to_success(
     *,
-    dag: DAG,
+    dag: SerializedDAG,
     run_id: str | None = None,
     commit: bool = False,
     session: SASession = NEW_SESSION,
@@ -256,7 +266,7 @@ def set_dag_run_state_to_success(
 @provide_session
 def set_dag_run_state_to_failed(
     *,
-    dag: DAG,
+    dag: SerializedDAG,
     run_id: str | None = None,
     commit: bool = False,
     session: SASession = NEW_SESSION,
@@ -348,7 +358,7 @@ def _set_runing_task(task: Operator) -> Operator:
 def __set_dag_run_state_to_running_or_queued(
     *,
     new_state: DagRunState,
-    dag: DAG,
+    dag: SerializedDAG,
     run_id: str | None = None,
     commit: bool = False,
     session: SASession,
@@ -379,7 +389,7 @@ def __set_dag_run_state_to_running_or_queued(
 @provide_session
 def set_dag_run_state_to_queued(
     *,
-    dag: DAG,
+    dag: SerializedDAG,
     run_id: str | None = None,
     commit: bool = False,
     session: SASession = NEW_SESSION,
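The rewritten get_run_ids resolves its date bounds with plain column queries rather than DAG-model helpers, which is what lets it accept a SerializedDAG. A standalone sketch of that scalar-query pattern, with a hypothetical helper name (assuming a SQLAlchemy session and Airflow's DagRun model):

from sqlalchemy import select
from sqlalchemy.orm import Session

from airflow.models.dagrun import DagRun


def logical_date_bounds(session: Session, dag_id: str):
    """Hypothetical helper: (earliest, latest) logical dates, skipping runs without one."""
    base = select(DagRun.logical_date).where(
        DagRun.dag_id == dag_id, DagRun.logical_date.is_not(None)
    )
    earliest = session.scalar(base.order_by(DagRun.logical_date.asc()).limit(1))
    latest = session.scalar(base.order_by(DagRun.logical_date.desc()).limit(1))
    return earliest, latest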
10 changes: 5 additions & 5 deletions airflow-core/src/airflow/api_fastapi/common/dagbag.py
@@ -24,8 +24,8 @@
 from airflow.models.dagbag import DBDagBag
 
 if TYPE_CHECKING:
-    from airflow.models.dag import DAG
     from airflow.models.dagrun import DagRun
+    from airflow.serialization.serialized_objects import SerializedDAG
 
 
 def create_dag_bag() -> DBDagBag:
@@ -45,7 +45,7 @@ def dag_bag_from_app(request: Request) -> DBDagBag:
 
 def get_latest_version_of_dag(
     dag_bag: DBDagBag, dag_id: str, session: Session, include_reason: bool = False
-) -> DAG:
+) -> SerializedDAG:
     dag = dag_bag.get_latest_version_of_dag(dag_id, session=session)
     if not dag:
         if include_reason:
@@ -60,7 +60,7 @@ def get_latest_version_of_dag(
     return dag
 
 
-def get_dag_for_run(dag_bag: DBDagBag, dag_run: DagRun, session: Session) -> DAG:
+def get_dag_for_run(dag_bag: DBDagBag, dag_run: DagRun, session: Session) -> SerializedDAG:
     dag = dag_bag.get_dag_for_run(dag_run, session=session)
     if not dag:
         raise HTTPException(status.HTTP_404_NOT_FOUND, f"The Dag with ID: `{dag_run.dag_id}` was not found")
@@ -69,8 +69,8 @@ def get_dag_for_run(dag_bag: DBDagBag, dag_run: DagRun, session: Session) -> DAG
 
 def get_dag_for_run_or_latest_version(
     dag_bag: DBDagBag, dag_run: DagRun | None, dag_id: str | None, session: Session
-) -> DAG:
-    dag: DAG | None = None
+) -> SerializedDAG:
+    dag: SerializedDAG | None = None
     if dag_run:
         dag = dag_bag.get_dag_for_run(dag_run, session=session)
     elif dag_id:
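Taken together, these helpers implement a run-first, latest-version fallback. A condensed, hypothetical sketch of that flow, mirroring get_dag_for_run_or_latest_version above (the helper name and error text are illustrative, not part of this PR):

from fastapi import HTTPException, status


def resolve_dag(dag_bag, dag_run=None, dag_id=None, session=None):
    # Prefer the DAG version pinned to the run; fall back to the latest version.
    dag = None
    if dag_run is not None:
        dag = dag_bag.get_dag_for_run(dag_run, session=session)
    elif dag_id is not None:
        dag = dag_bag.get_latest_version_of_dag(dag_id, session=session)
    if dag is None:
        raise HTTPException(status.HTTP_404_NOT_FOUND, "Requested DAG was not found")
    return dag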
@@ -32,7 +32,7 @@
 from airflow.utils.types import DagRunTriggeredByType, DagRunType
 
 if TYPE_CHECKING:
-    from airflow.models import DAG
+    from airflow.serialization.serialized_objects import SerializedDAG
 
 
 class DAGRunPatchStates(str, Enum):
@@ -113,7 +113,7 @@ def check_data_intervals(cls, values):
         )
         return values
 
-    def validate_context(self, dag: DAG) -> dict:
+    def validate_context(self, dag: SerializedDAG) -> dict:
         coerced_logical_date = timezone.coerce_datetime(self.logical_date)
         run_after = self.run_after or timezone.utcnow()
         data_interval = None
@@ -744,7 +744,6 @@ def post_clear_task_instances(
     common_params = {
         "dry_run": True,
         "task_ids": task_ids,
-        "dag_bag": dag_bag,
         "session": session,
         "run_on_latest_version": body.run_on_latest_version,
         "only_failed": body.only_failed,
@@ -39,8 +39,8 @@
 from airflow.api_fastapi.core_api.security import GetUserDep
 from airflow.api_fastapi.core_api.services.public.common import BulkService
 from airflow.listeners.listener import get_listener_manager
-from airflow.models.dag import DAG
 from airflow.models.taskinstance import TaskInstance as TI
+from airflow.serialization.serialized_objects import SerializedDAG
 from airflow.utils.state import TaskInstanceState
 
 log = structlog.get_logger(__name__)
@@ -55,7 +55,7 @@ def _patch_ti_validate_request(
     session: SessionDep,
     map_index: int | None = -1,
     update_mask: list[str] | None = Query(None),
-) -> tuple[DAG, list[TI], dict]:
+) -> tuple[SerializedDAG, list[TI], dict]:
     dag = get_latest_version_of_dag(dag_bag, dag_id, session)
     if not dag.has_task(task_id):
         raise HTTPException(status.HTTP_404_NOT_FOUND, f"Task '{task_id}' not found in DAG '{dag_id}'")
@@ -94,7 +94,7 @@ def _patch_ti_validate_request(
 def _patch_task_instance_state(
     task_id: str,
     dag_run_id: str,
-    dag: DAG,
+    dag: SerializedDAG,
     task_instance_body: BulkTaskInstanceBody | PatchTaskInstanceBody,
     data: dict,
     session: Session,
@@ -34,8 +34,8 @@
     CalendarTimeRangeCollectionResponse,
     CalendarTimeRangeResponse,
 )
-from airflow.models.dag import DAG
 from airflow.models.dagrun import DagRun
+from airflow.serialization.serialized_objects import SerializedDAG
 from airflow.timetables._cron import CronMixin
 from airflow.timetables.base import DataInterval, TimeRestriction
 from airflow.timetables.simple import ContinuousTimetable
@@ -52,7 +52,7 @@ def get_calendar_data(
         self,
         dag_id: str,
         session: Session,
-        dag: DAG,
+        dag: SerializedDAG,
         logical_date: RangeFilter,
         granularity: Literal["hourly", "daily"] = "daily",
     ) -> CalendarTimeRangeCollectionResponse:
@@ -126,7 +126,7 @@ def _get_historical_dag_runs(
 
     def _get_planned_dag_runs(
         self,
-        dag: DAG,
+        dag: SerializedDAG,
         raw_dag_states: list[Row],
         logical_date: RangeFilter,
         granularity: Literal["hourly", "daily"],
@@ -152,7 +152,7 @@ def _get_planned_dag_runs(
             dag, last_data_interval, year, restriction, logical_date, granularity
         )
 
-    def _should_calculate_planned_runs(self, dag: DAG, raw_dag_states: list[Row]) -> bool:
+    def _should_calculate_planned_runs(self, dag: SerializedDAG, raw_dag_states: list[Row]) -> bool:
         """Check if we should calculate planned runs."""
         return (
             bool(raw_dag_states)
@@ -177,7 +177,7 @@ def _get_last_data_interval(self, raw_dag_states: list[Row]) -> DataInterval | N
 
     def _calculate_cron_planned_runs(
         self,
-        dag: DAG,
+        dag: SerializedDAG,
         last_data_interval: DataInterval,
         year: int,
         logical_date: RangeFilter,
@@ -208,7 +208,7 @@ def _calculate_cron_planned_runs(
 
     def _calculate_timetable_planned_runs(
         self,
-        dag: DAG,
+        dag: SerializedDAG,
         last_data_interval: DataInterval,
         year: int,
         restriction: TimeRestriction,
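For background on the planned-run calculations these signatures feed into: Airflow timetables expose next_dagrun_info, which returns the next DagRunInfo (or None) after a given data interval. A minimal sketch of walking future intervals with it — the helper name and the limit are assumptions for illustration:

from airflow.timetables.base import DataInterval, TimeRestriction, Timetable


def iter_planned_intervals(
    timetable: Timetable,
    last_interval: DataInterval | None,
    restriction: TimeRestriction,
    limit: int = 10,
):
    # Repeatedly ask the timetable for the run following the last known interval.
    current = last_interval
    for _ in range(limit):
        info = timetable.next_dagrun_info(
            last_automated_data_interval=current, restriction=restriction
        )
        if info is None:  # nothing further within the restriction
            break
        yield info.data_interval
        current = info.data_interval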
@@ -72,7 +72,8 @@
     from sqlalchemy.sql.dml import Update
 
     from airflow.models.expandinput import SchedulerExpandInput
-    from airflow.sdk.types import Operator
+    from airflow.models.mappedoperator import MappedOperator
+    from airflow.serialization.serialized_objects import SerializedBaseOperator
 
 router = VersionedAPIRouter()
 
@@ -254,7 +255,14 @@ def ti_run(
 
     if dag := dag_bag.get_dag_for_run(dag_run=dr, session=session):
         upstream_map_indexes = dict(
-            _get_upstream_map_indexes(dag.get_task(ti.task_id), ti.map_index, ti.run_id, session)
+            _get_upstream_map_indexes(
+                # TODO (GH-52141): This get_task should return scheduler
+                # types instead, but currently it inherits SDK's DAG.
+                cast("MappedOperator | SerializedBaseOperator", dag.get_task(ti.task_id)),
+                ti.map_index,
+                ti.run_id,
+                session=session,
+            )
         )
     else:
         upstream_map_indexes = None
@@ -285,7 +293,7 @@ def ti_run(
 
 
 def _get_upstream_map_indexes(
-    task: Operator, ti_map_index: int, run_id: str, session: SessionDep
+    task: MappedOperator | SerializedBaseOperator, ti_map_index: int, run_id: str, session: SessionDep
 ) -> Iterator[tuple[str, int | list[int] | None]]:
     task_mapped_group = task.get_closest_mapped_task_group()
     for upstream_task in task.upstream_list:
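One detail worth noting in the ti_run change above: typing.cast only narrows the type for static checkers and does nothing at runtime, so the TODO about get_task's return type is purely a typing concern. For illustration:

from typing import cast

value: object = "hello"
text = cast(str, value)  # no runtime conversion or check; type-checker only
assert text is value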