Optimize DAG run scheduling based on dataset triggers and batching
sunank200 committed Feb 27, 2024
1 parent cd31715 commit 493a934
Showing 1 changed file with 42 additions and 43 deletions.
85 changes: 42 additions & 43 deletions airflow/models/dag.py
@@ -3785,59 +3785,58 @@ def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, tuple[
         """
         from airflow.models.serialized_dag import SerializedDagModel

+        NUM_DAGS_PER_DAGRUN_QUERY = cls.NUM_DAGS_PER_DAGRUN_QUERY
+        dataset_triggered_dag_info = {}
+
         def dag_ready(dag_id: str, cond: BaseDatasetEventInput, statuses: dict) -> bool | None:
             # if dag was serialized before 2.9 and we *just* upgraded,
             # we may be dealing with old version. In that case,
             # just wait for the dag to be reserialized.
             try:
                 return cond.evaluate(statuses)
             except AttributeError:
-                log.warning("dag '%s' has old serialization; skipping DAG run creation.", dag_id)
+                logging.warning("dag '%s' has old serialization; skipping DAG run creation.", dag_id)
                 return None

-        # this loads all the DDRQ records.... may need to limit num dags
-        all_records = session.scalars(select(DatasetDagRunQueue)).all()
-        by_dag = defaultdict(list)
-        for r in all_records:
-            by_dag[r.target_dag_id].append(r)
-        del all_records
-        dag_statuses = {}
-        for dag_id, records in by_dag.items():
-            dag_statuses[dag_id] = {x.dataset.uri: True for x in records}
-        ser_dags = session.scalars(
-            select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys()))
-        ).all()
-        for ser_dag in ser_dags:
-            dag_id = ser_dag.dag_id
-            statuses = dag_statuses[dag_id]
-            if not dag_ready(dag_id, cond=ser_dag.dag.dataset_triggers, statuses=statuses):
-                del by_dag[dag_id]
-                del dag_statuses[dag_id]
-        del dag_statuses
-        dataset_triggered_dag_info = {}
-        for dag_id, records in by_dag.items():
-            times = sorted(x.created_at for x in records)
-            dataset_triggered_dag_info[dag_id] = (times[0], times[-1])
-        del by_dag
-        dataset_triggered_dag_ids = set(dataset_triggered_dag_info.keys())
-        if dataset_triggered_dag_ids:
-            exclusion_list = set(
-                session.scalars(
-                    select(DagModel.dag_id)
-                    .join(DagRun.dag_model)
-                    .where(DagRun.state.in_((DagRunState.QUEUED, DagRunState.RUNNING)))
-                    .where(DagModel.dag_id.in_(dataset_triggered_dag_ids))
-                    .group_by(DagModel.dag_id)
-                    .having(func.count() >= func.max(DagModel.max_active_runs))
-                )
-            )
-            if exclusion_list:
-                dataset_triggered_dag_ids -= exclusion_list
-                dataset_triggered_dag_info = {
-                    k: v for k, v in dataset_triggered_dag_info.items() if k not in exclusion_list
-                }
+        dag_statuses: dict[str, dict[str, bool]] = defaultdict(dict)
+
+        # Get distinct target_dag_id from DatasetDagRunQueue
+        distinct_dag_ids_subq = session.query(DatasetDagRunQueue.target_dag_id).distinct().subquery()
+
+        # Process in batches using NUM_DAGS_PER_DAGRUN_QUERY
+        batch_offset = 0
+        while True:
+            batch = (
+                session.query(distinct_dag_ids_subq.c.target_dag_id)
+                .order_by(distinct_dag_ids_subq.c.target_dag_id)
+                .limit(NUM_DAGS_PER_DAGRUN_QUERY)
+                .offset(batch_offset)
+                .all()
+            )
+
+            if not batch:
+                break  # Exit loop if no more batches
+
+            batch_dag_ids = [row[0] for row in batch]
+            batch_offset += NUM_DAGS_PER_DAGRUN_QUERY
+
+            for dag_id in batch_dag_ids:
+                # Populate dag_statuses for the current batch
+                dag_statuses[dag_id] = {
+                    record.dataset.uri: True
+                    for record in session.query(DatasetDagRunQueue).filter_by(target_dag_id=dag_id).all()
+                }
+
+                # We limit so that _one_ scheduler doesn't try to do all the creation of dag runs
+                ser_dag = session.query(SerializedDagModel).filter_by(dag_id=dag_id).first()
+                if ser_dag and dag_ready(dag_id, ser_dag.dag.dataset_triggers, dag_statuses[dag_id]):
+                    # The dag is ready, note down the times for dataset_triggered_dag_info
+                    times = [
+                        record.created_at
+                        for record in session.query(DatasetDagRunQueue).filter_by(target_dag_id=dag_id).all()
+                    ]
+                    if times:  # Ensure times list is not empty
+                        dataset_triggered_dag_info[dag_id] = (min(times), max(times))

         query = (
             select(cls)
             .where(
@@ -3846,7 +3845,7 @@ def dag_ready(dag_id: str, cond: BaseDatasetEventInput, statuses: dict) -> bool
                 cls.has_import_errors == expression.false(),
                 or_(
                     cls.next_dagrun_create_after <= func.now(),
-                    cls.dag_id.in_(dataset_triggered_dag_ids),
+                    cls.dag_id.in_(dataset_triggered_dag_info.keys()),
                 ),
             )
             .order_by(cls.next_dagrun_create_after)
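For readers who want the batching pattern in isolation, here is a minimal, runnable sketch of the LIMIT/OFFSET pagination the new code applies to DatasetDagRunQueue. It uses SQLAlchemy with an in-memory SQLite database; QueueRecord, the "ddrq" table, and BATCH_SIZE are hypothetical stand-ins for DatasetDagRunQueue and NUM_DAGS_PER_DAGRUN_QUERY, so only the shape of the queries mirrors the diff:

# Illustrative code, not part of Airflow: page through distinct target IDs
# in fixed-size chunks instead of loading every queue record at once.
from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class QueueRecord(Base):
    __tablename__ = "ddrq"  # hypothetical stand-in for DatasetDagRunQueue
    id = Column(Integer, primary_key=True)
    target_dag_id = Column(String, nullable=False)


engine = create_engine("sqlite://")  # in-memory database for the demo
Base.metadata.create_all(engine)

BATCH_SIZE = 2  # stand-in for NUM_DAGS_PER_DAGRUN_QUERY

with Session(engine) as session:
    # 20 queue records spread across 5 distinct dag_ids
    session.add_all([QueueRecord(target_dag_id=f"dag_{i % 5}") for i in range(20)])
    session.commit()

    offset = 0
    while True:
        # One bounded query per iteration: memory stays proportional to
        # BATCH_SIZE rather than the total number of queue records.
        batch = session.scalars(
            select(QueueRecord.target_dag_id)
            .distinct()
            .order_by(QueueRecord.target_dag_id)  # stable order across pages
            .limit(BATCH_SIZE)
            .offset(offset)
        ).all()
        if not batch:
            break  # no more distinct IDs left
        offset += BATCH_SIZE
        print(batch)  # ['dag_0', 'dag_1'], then ['dag_2', 'dag_3'], then ['dag_4']

One design note: OFFSET pagination re-scans the rows it skips on every page, so its cost grows with the offset. Because the batches are ordered by target_dag_id, keyset pagination (a WHERE target_dag_id > last_seen filter instead of OFFSET) would be a natural alternative if the distinct set grows large.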