Skip to content

Commit

Permalink
Add docstrings to methods
Browse files Browse the repository at this point in the history
  • Loading branch information
ephraimbuddy committed Nov 5, 2024
1 parent a248737 commit b95ed50
Show file tree
Hide file tree
Showing 6 changed files with 1,864 additions and 1,851 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
add dag versioning.
Revision ID: 2b47dc6bc8df
Revises: d8cd3297971e
Revises: d03e4a635aa3
Create Date: 2024-10-09 05:44:04.670984
"""
Expand All @@ -38,7 +38,7 @@

# revision identifiers, used by Alembic.
revision = "2b47dc6bc8df"
down_revision = "d8cd3297971e"
down_revision = "d03e4a635aa3"
branch_labels = None
depends_on = None
airflow_version = "3.0.0"
Expand Down
2 changes: 1 addition & 1 deletion airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2278,7 +2278,7 @@ def dag_ready(dag_id: str, cond: BaseAsset, statuses: dict) -> bool | None:
dag_statuses = {}
for dag_id, records in by_dag.items():
dag_statuses[dag_id] = {x.asset.uri: True for x in records}
dag_versions = DagVersion.get_latest_dag_versions(list(dag_statuses.keys()), session=session)
dag_versions = DagVersion.get_latest_dag_versions(dag_ids=list(dag_statuses.keys()), session=session)
ser_dags = [x.serialized_dag for x in dag_versions]

for ser_dag in ser_dags:
Expand Down
51 changes: 46 additions & 5 deletions airflow/models/dag_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class DagVersion(Base):
)

def __repr__(self):
"""Represent the object as a string."""
return f"<DagVersion {self.dag_id} {self.version}>"

@classmethod
Expand All @@ -81,7 +82,17 @@ def write_dag(
version_number: int = 1,
session: Session = NEW_SESSION,
) -> DagVersion:
"""Write a new DagVersion into database."""
"""
Write a new DagVersion into database.
Checks if a version of the DAG exists and increments the version number if it does.
:param dag_id: The DAG ID.
:param version_name: The version name.
:param version_number: The version number.
:param session: The database session.
:return: The DagVersion object.
"""
existing_dag_version = session.scalar(
with_row_locks(cls._latest_version_select(dag_id), of=DagVersion, session=session, nowait=True)
)
Expand All @@ -102,11 +113,24 @@ def write_dag(

@classmethod
def _latest_version_select(cls, dag_id: str) -> Select:
"""
Get the select object to get the latest version of the DAG.
:param dag_id: The DAG ID.
:return: The select object.
"""
return select(cls).where(cls.dag_id == dag_id).order_by(cls.created_at.desc()).limit(1)

@classmethod
@provide_session
def get_latest_version(cls, dag_id: str, session: Session = NEW_SESSION) -> DagVersion | None:
def get_latest_version(cls, dag_id: str, *, session: Session = NEW_SESSION) -> DagVersion | None:
"""
Get the latest version of the DAG.
:param dag_id: The DAG ID.
:param session: The database session.
:return: The latest version of the DAG or None if not found.
"""
return session.scalar(cls._latest_version_select(dag_id))

@classmethod
Expand All @@ -115,8 +139,17 @@ def get_version(
cls,
dag_id: str,
version_number: int = 1,
*,
session: Session = NEW_SESSION,
) -> DagVersion | None:
"""
Get the version of the DAG.
:param dag_id: The DAG ID.
:param version_number: The version number.
:param session: The database session.
:return: The version of the DAG or None if not found.
"""
version_select_obj = (
select(cls)
.where(cls.dag_id == dag_id, cls.version_number == version_number)
Expand All @@ -127,16 +160,24 @@ def get_version(

@property
def version(self) -> str:
"""Return the version name."""
"""A human-friendly representation of the version."""
name = f"{self.version_number}"
if self.version_name:
name = f"{self.version_name}-{self.version_number}"
return name

@classmethod
@provide_session
def get_latest_dag_versions(cls, dag_ids: list[str], session: Session = NEW_SESSION) -> list[DagVersion]:
"""Get the latest version of DAGs."""
def get_latest_dag_versions(
cls, *, dag_ids: list[str], session: Session = NEW_SESSION
) -> list[DagVersion]:
"""
Get the latest version of DAGs.
:param dag_ids: The list of DAG IDs.
:param session: The database session.
:return: The latest version of the DAGs.
"""
# Subquery to get the latest version number per dag_id
latest_version_subquery = (
session.query(cls.dag_id, func.max(cls.created_at).label("created_at"))
Expand Down
2 changes: 1 addition & 1 deletion docs/apache-airflow/img/airflow_erd.sha256
Original file line number Diff line number Diff line change
@@ -1 +1 @@
5bdd94eb42b63cd676d0bd51afd6a52112a30927881ae8990c50c5ef95bf8898
f997746cdee45147831f81bcd2d43ec3ca45d7429afa691e385104987ed51d88
Loading

0 comments on commit b95ed50

Please sign in to comment.