Optimize DAG run scheduling based on dataset triggers and batching #37707

Closed
85 changes: 42 additions & 43 deletions airflow/models/dag.py
@@ -3785,59 +3785,58 @@ def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, tuple[
"""
from airflow.models.serialized_dag import SerializedDagModel

NUM_DAGS_PER_DAGRUN_QUERY = cls.NUM_DAGS_PER_DAGRUN_QUERY
dataset_triggered_dag_info = {}

def dag_ready(dag_id: str, cond: BaseDatasetEventInput, statuses: dict) -> bool | None:
# if dag was serialized before 2.9 and we *just* upgraded,
# we may be dealing with old version. In that case,
# we may be dealing with old version. In that case,
# just wait for the dag to be reserialized.
try:
return cond.evaluate(statuses)
except AttributeError:
log.warning("dag '%s' has old serialization; skipping DAG run creation.", dag_id)
logging.warning("dag '%s' has old serialization; skipping DAG run creation.", dag_id)
return None

-        # this loads all the DDRQ records.... may need to limit num dags
-        all_records = session.scalars(select(DatasetDagRunQueue)).all()
-        by_dag = defaultdict(list)
-        for r in all_records:
-            by_dag[r.target_dag_id].append(r)
-        del all_records
-        dag_statuses = {}
-        for dag_id, records in by_dag.items():
-            dag_statuses[dag_id] = {x.dataset.uri: True for x in records}
-        ser_dags = session.scalars(
-            select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys()))
-        ).all()
-        for ser_dag in ser_dags:
-            dag_id = ser_dag.dag_id
-            statuses = dag_statuses[dag_id]
-            if not dag_ready(dag_id, cond=ser_dag.dag.dataset_triggers, statuses=statuses):
-                del by_dag[dag_id]
-                del dag_statuses[dag_id]
-        del dag_statuses
-        dataset_triggered_dag_info = {}
-        for dag_id, records in by_dag.items():
-            times = sorted(x.created_at for x in records)
-            dataset_triggered_dag_info[dag_id] = (times[0], times[-1])
-        del by_dag
-        dataset_triggered_dag_ids = set(dataset_triggered_dag_info.keys())
-        if dataset_triggered_dag_ids:
-            exclusion_list = set(
-                session.scalars(
-                    select(DagModel.dag_id)
-                    .join(DagRun.dag_model)
-                    .where(DagRun.state.in_((DagRunState.QUEUED, DagRunState.RUNNING)))
-                    .where(DagModel.dag_id.in_(dataset_triggered_dag_ids))
-                    .group_by(DagModel.dag_id)
-                    .having(func.count() >= func.max(DagModel.max_active_runs))
-                )
-            )
+        dag_statuses: dict[str, dict[str, bool]] = defaultdict(dict)
+
+        # Get distinct target_dag_id from DatasetDagRunQueue
+        distinct_dag_ids_subq = session.query(DatasetDagRunQueue.target_dag_id).distinct().subquery()
+
+        # Process in batches using NUM_DAGS_PER_DAGRUN_QUERY
+        batch_offset = 0
+        while True:
+            batch = (
+                session.query(distinct_dag_ids_subq.c.target_dag_id)
+                .order_by(distinct_dag_ids_subq.c.target_dag_id)
+                .limit(NUM_DAGS_PER_DAGRUN_QUERY)
+                .offset(batch_offset)
+                .all()
+            )
-            if exclusion_list:
-                dataset_triggered_dag_ids -= exclusion_list
-                dataset_triggered_dag_info = {
-                    k: v for k, v in dataset_triggered_dag_info.items() if k not in exclusion_list
-                }
+
+            if not batch:
+                break  # Exit loop if no more batches
+
+            batch_dag_ids = [row[0] for row in batch]
+            batch_offset += NUM_DAGS_PER_DAGRUN_QUERY
+
+            for dag_id in batch_dag_ids:
+                # Populate dag_statuses for the current batch
+                dag_statuses[dag_id] = {
+                    record.dataset.uri: True
+                    for record in session.query(DatasetDagRunQueue).filter_by(target_dag_id=dag_id).all()
+                }
+
+                # We limit so that _one_ scheduler doesn't try to do all the creation of dag runs
+                ser_dag = session.query(SerializedDagModel).filter_by(dag_id=dag_id).first()
+                if ser_dag and dag_ready(dag_id, ser_dag.dag.dataset_triggers, dag_statuses[dag_id]):
+                    # The dag is ready, note down the times for dataset_triggered_dag_info
+                    times = [
+                        record.created_at
+                        for record in session.query(DatasetDagRunQueue).filter_by(target_dag_id=dag_id).all()
+                    ]
+                    if times:  # Ensure times list is not empty
+                        dataset_triggered_dag_info[dag_id] = (min(times), max(times))

         query = (
             select(cls)
             .where(
@@ -3846,7 +3845,7 @@ def dag_ready(dag_id: str, cond: BaseDatasetEventInput, statuses: dict) -> bool
                 cls.has_import_errors == expression.false(),
                 or_(
                     cls.next_dagrun_create_after <= func.now(),
-                    cls.dag_id.in_(dataset_triggered_dag_ids),
+                    cls.dag_id.in_(dataset_triggered_dag_info.keys()),
                 ),
             )
             .order_by(cls.next_dagrun_create_after)
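
The heart of the change is paginating over distinct target_dag_id values with LIMIT/OFFSET instead of loading every DatasetDagRunQueue row at once. Below is a minimal, self-contained sketch of that pattern against an in-memory SQLite database; QueueEntry and BATCH_SIZE are illustrative stand-ins for Airflow's DatasetDagRunQueue model and NUM_DAGS_PER_DAGRUN_QUERY, not the real definitions.

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()


    class QueueEntry(Base):
        # Illustrative stand-in for DatasetDagRunQueue.
        __tablename__ = "queue"
        id = Column(Integer, primary_key=True)
        target_dag_id = Column(String, nullable=False)
        uri = Column(String, nullable=False)


    BATCH_SIZE = 2  # stand-in for NUM_DAGS_PER_DAGRUN_QUERY


    def iter_dag_ids_in_batches(session: Session):
        """Yield lists of distinct target_dag_id values, BATCH_SIZE at a time."""
        subq = session.query(QueueEntry.target_dag_id).distinct().subquery()
        offset = 0
        while True:
            batch = (
                session.query(subq.c.target_dag_id)
                .order_by(subq.c.target_dag_id)  # stable ordering keeps offsets deterministic
                .limit(BATCH_SIZE)
                .offset(offset)
                .all()
            )
            if not batch:
                break
            yield [row[0] for row in batch]
            offset += BATCH_SIZE


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        session.add_all(
            [
                QueueEntry(target_dag_id=d, uri=u)
                for d, u in [("a", "s3://x"), ("a", "s3://y"), ("b", "s3://x"), ("c", "s3://z")]
            ]
        )
        session.commit()
        for batch in iter_dag_ids_in_batches(session):
            print(batch)  # ['a', 'b'] then ['c']

One trade-off worth noting: fixed-offset pagination can skip or double-process ids if rows are inserted or deleted between batches, and each OFFSET scan grows more expensive as the offset increases; a keyset variant (WHERE target_dag_id > last_seen) avoids both at the cost of a slightly more involved query.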
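
For completeness, a hedged sketch of the dag_ready guard the diff keeps in place: a condition object deserialized from a pre-2.9 DAG snapshot may lack an evaluate method, so AttributeError is treated as "wait for reserialization". AnyOf here is a made-up stand-in for BaseDatasetEventInput, purely for illustration.

    import logging


    class AnyOf:
        """Made-up stand-in for BaseDatasetEventInput: ready when any watched URI has an event."""

        def __init__(self, *uris: str) -> None:
            self.uris = uris

        def evaluate(self, statuses: dict) -> bool:
            return any(statuses.get(uri, False) for uri in self.uris)


    def dag_ready(dag_id: str, cond, statuses: dict) -> bool | None:
        try:
            return cond.evaluate(statuses)
        except AttributeError:
            # Pre-2.9 serialization: no evaluate() yet, so skip this DAG for now.
            logging.warning("dag '%s' has old serialization; skipping DAG run creation.", dag_id)
            return None


    print(dag_ready("demo", AnyOf("s3://x", "s3://y"), {"s3://x": True}))  # True
    print(dag_ready("demo", object(), {}))  # None, after logging a warning

The scheduler only needs truthiness from this call, but returning None rather than False leaves the stale-serialization case distinguishable from a condition that genuinely evaluated to not-ready.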