Skip to content

Commit

Permalink
Move getting latest serdags to SerializedDagModel
Browse files Browse the repository at this point in the history
  • Loading branch information
ephraimbuddy committed Nov 5, 2024
1 parent 569eb53 commit a7bba02
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 65 deletions.
7 changes: 4 additions & 3 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@
)
from airflow.models.base import Base, StringID
from airflow.models.baseoperator import BaseOperator
from airflow.models.dagcode import DagCode
from airflow.models.dag_version import DagVersion
from airflow.models.dagrun import RUN_ID_REGEX, DagRun
from airflow.models.taskinstance import (
Expand Down Expand Up @@ -2258,6 +2257,7 @@ def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, tuple[
you should ensure that any scheduling decisions are made in a single transaction -- as soon as the
transaction is committed it will be unlocked.
"""
from airflow.models.serialized_dag import SerializedDagModel

def dag_ready(dag_id: str, cond: BaseAsset, statuses: dict) -> bool | None:
# if dag was serialized before 2.9 and we *just* upgraded,
Expand All @@ -2278,8 +2278,9 @@ def dag_ready(dag_id: str, cond: BaseAsset, statuses: dict) -> bool | None:
dag_statuses = {}
for dag_id, records in by_dag.items():
dag_statuses[dag_id] = {x.asset.uri: True for x in records}
dag_versions = DagVersion.get_latest_dag_versions(dag_ids=list(dag_statuses.keys()), session=session)
ser_dags = [x.serialized_dag for x in dag_versions]
ser_dags = SerializedDagModel.get_latest_serialized_dags(
dag_ids=list(dag_statuses.keys()), session=session
)

for ser_dag in ser_dags:
dag_id = ser_dag.dag_id
Expand Down
31 changes: 1 addition & 30 deletions airflow/models/dag_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing import TYPE_CHECKING

import uuid6
from sqlalchemy import Column, ForeignKey, Integer, UniqueConstraint, func, select
from sqlalchemy import Column, ForeignKey, Integer, UniqueConstraint, select
from sqlalchemy.orm import relationship
from sqlalchemy_utils import UUIDType

Expand Down Expand Up @@ -165,32 +165,3 @@ def version(self) -> str:
if self.version_name:
name = f"{self.version_name}-{self.version_number}"
return name

@classmethod
@provide_session
def get_latest_dag_versions(
cls, *, dag_ids: list[str], session: Session = NEW_SESSION
) -> list[DagVersion]:
"""
Get the latest version of DAGs.
:param dag_ids: The list of DAG IDs.
:param session: The database session.
:return: The latest version of the DAGs.
"""
# Subquery to get the latest version number per dag_id
latest_version_subquery = (
session.query(cls.dag_id, func.max(cls.created_at).label("created_at"))
.filter(cls.dag_id.in_(dag_ids))
.group_by(cls.dag_id)
.subquery()
)
latest_versions = session.scalars(
select(cls)
.join(
latest_version_subquery,
cls.created_at == latest_version_subquery.c.created_at,
)
.where(cls.dag_id.in_(dag_ids))
).all()
return latest_versions or []
29 changes: 29 additions & 0 deletions airflow/models/serialized_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,35 @@ def write_dag(
def latest_item_select_object(cls, dag_id):
return select(cls).where(cls.dag_id == dag_id).order_by(cls.last_updated.desc()).limit(1)

@classmethod
@provide_session
def get_latest_serialized_dags(
cls, *, dag_ids: list[str], session: Session = NEW_SESSION
) -> list[SerializedDagModel]:
"""
Get the latest serialized dags of given DAGs.
:param dag_ids: The list of DAG IDs.
:param session: The database session.
:return: The latest serialized dag of the DAGs.
"""
# Subquery to get the latest serdag per dag_id
latest_serdag_subquery = (
session.query(cls.dag_id, func.max(cls.last_updated).label("last_updated"))
.filter(cls.dag_id.in_(dag_ids))
.group_by(cls.dag_id)
.subquery()
)
latest_serdags = session.scalars(
select(cls)
.join(
latest_serdag_subquery,
cls.last_updated == latest_serdag_subquery.c.last_updated,
)
.where(cls.dag_id.in_(dag_ids))
).all()
return latest_serdags or []

@classmethod
@provide_session
def read_all_dags(cls, session: Session = NEW_SESSION) -> dict[str, SerializedDAG]:
Expand Down
31 changes: 0 additions & 31 deletions tests/models/test_dag_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,37 +95,6 @@ def test_get_version(self, dag_maker, session):
assert version.dag_id == dag1_id
assert version.version == "my_version-1"

def test_get_latest_dag_versions(self, dag_maker, session):
# first dag
version_name = "test_v"
with dag_maker("dag1", version_name=version_name) as dag:
EmptyOperator(task_id="task1")
dag.sync_to_db()
SerializedDagModel.write_dag(dag)
with dag_maker("dag1", version_name=version_name) as dag:
EmptyOperator(task_id="task1")
EmptyOperator(task_id="task2")
dag.sync_to_db()
SerializedDagModel.write_dag(dag)
# second dag
version_name2 = "test_v2"
with dag_maker("dag2", version_name=version_name2) as dag:
EmptyOperator(task_id="task1")
dag.sync_to_db()
SerializedDagModel.write_dag(dag)
with dag_maker("dag2", version_name=version_name2) as dag:
EmptyOperator(task_id="task1")
EmptyOperator(task_id="task2")
dag.sync_to_db()
SerializedDagModel.write_dag(dag)

# Total versions should be 4
assert session.scalar(select(func.count()).select_from(DagVersion)) == 4

latest_versions_for_the_dags = {f"{version_name}-2", f"{version_name2}-2"}
latest_versions = DagVersion.get_latest_dag_versions(["dag1", "dag2"])
assert latest_versions_for_the_dags == {x.version for x in latest_versions}

@pytest.mark.need_serialized_dag
def test_version_property(self, dag_maker):
version_name = "my_version"
Expand Down
31 changes: 30 additions & 1 deletion tests/models/test_serialized_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@

import pendulum
import pytest
from sqlalchemy import select
from sqlalchemy import func, select

import airflow.example_dags as example_dags_module
from airflow.assets import Asset
from airflow.models.dag import DAG
from airflow.models.dagbag import DagBag
from airflow.models.dagcode import DagCode
from airflow.models.serialized_dag import SerializedDagModel as SDM
from airflow.operators.empty import EmptyOperator
from airflow.providers.standard.operators.bash import BashOperator
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.settings import json
Expand Down Expand Up @@ -284,3 +285,31 @@ def get_hash_set():
first_hashes = get_hash_set()
# assert that the hashes are the same
assert first_hashes == get_hash_set()

def test_get_latest_serdag_versions(self, dag_maker, session):
# first dag
with dag_maker("dag1") as dag:
EmptyOperator(task_id="task1")
dag.sync_to_db()
SDM.write_dag(dag)
with dag_maker("dag1") as dag:
EmptyOperator(task_id="task1")
EmptyOperator(task_id="task2")
dag.sync_to_db()
SDM.write_dag(dag)
# second dag
with dag_maker("dag2") as dag:
EmptyOperator(task_id="task1")
dag.sync_to_db()
SDM.write_dag(dag)
with dag_maker("dag2") as dag:
EmptyOperator(task_id="task1")
EmptyOperator(task_id="task2")
dag.sync_to_db()
SDM.write_dag(dag)

# Total serdags should be 4
assert session.scalar(select(func.count()).select_from(SDM)) == 4

latest_versions = SDM.get_latest_serialized_dags(dag_ids=["dag1", "dag2"], session=session)
assert len(latest_versions) == 2

0 comments on commit a7bba02

Please sign in to comment.