diff --git a/airflow-core/src/airflow/models/asset.py b/airflow-core/src/airflow/models/asset.py
index a28addd99819f..4d9bd8ae1cf8d 100644
--- a/airflow-core/src/airflow/models/asset.py
+++ b/airflow-core/src/airflow/models/asset.py
@@ -35,12 +35,12 @@
     text,
 )
 from sqlalchemy.ext.associationproxy import association_proxy
-from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Mapped, relationship

 from airflow._shared.timezones import timezone
 from airflow.models.base import Base, StringID
 from airflow.settings import json
-from airflow.utils.sqlalchemy import UtcDateTime
+from airflow.utils.sqlalchemy import UtcDateTime, mapped_column

 if TYPE_CHECKING:
     from collections.abc import Iterable
@@ -140,7 +140,7 @@ def remove_references_to_deleted_dags(session: Session):
 class AssetWatcherModel(Base):
     """A table to store asset watchers."""

-    name = Column(
+    name: Mapped[str] = mapped_column(
         String(length=1500).with_variant(
             String(
                 length=1500,
@@ -152,8 +152,8 @@ class AssetWatcherModel(Base):
         ),
         nullable=False,
     )
-    asset_id = Column(Integer, primary_key=True, nullable=False)
-    trigger_id = Column(Integer, primary_key=True, nullable=False)
+    asset_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False)
+    trigger_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False)

     asset = relationship("AssetModel", back_populates="watchers")
     trigger = relationship("Trigger", back_populates="asset_watchers")
@@ -187,8 +187,8 @@ class AssetAliasModel(Base):
     :param uri: a string that uniquely identifies the asset alias
     """

-    id = Column(Integer, primary_key=True, autoincrement=True)
-    name = Column(
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+    name: Mapped[str] = mapped_column(
         String(length=1500).with_variant(
             String(
                 length=1500,
@@ -200,7 +200,7 @@ class AssetAliasModel(Base):
         ),
         nullable=False,
     )
-    group = Column(
+    group: Mapped[str] = mapped_column(
         String(length=1500).with_variant(
             String(
                 length=1500,
@@ -263,8 +263,8 @@ class AssetModel(Base):
     :param extra: JSON field for arbitrary extra info
     """

-    id = Column(Integer, primary_key=True, autoincrement=True)
-    name = Column(
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+    name: Mapped[str] = mapped_column(
         String(length=1500).with_variant(
             String(
                 length=1500,
@@ -276,7 +276,7 @@ class AssetModel(Base):
         ),
         nullable=False,
     )
-    uri = Column(
+    uri: Mapped[str] = mapped_column(
         String(length=1500).with_variant(
             String(
                 length=1500,
@@ -288,7 +288,7 @@ class AssetModel(Base):
         ),
         nullable=False,
     )
-    group = Column(
+    group: Mapped[str] = mapped_column(
         String(length=1500).with_variant(
             String(
                 length=1500,
@@ -301,10 +301,12 @@ class AssetModel(Base):
         default=str,
         nullable=False,
     )
-    extra = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})
+    extra: Mapped[dict] = mapped_column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})

-    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
-    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
+    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    updated_at: Mapped[UtcDateTime] = mapped_column(
+        UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False
+    )

     active = relationship("AssetActive", uselist=False, viewonly=True, back_populates="asset")
@@ -374,7 +376,7 @@ class AssetActive(Base):
     *name and URI are each unique* within active assets.
     """

-    name = Column(
+    name: Mapped[str] = mapped_column(
         String(length=1500).with_variant(
             String(
                 length=1500,
@@ -386,7 +388,7 @@ class AssetActive(Base):
         ),
         nullable=False,
     )
-    uri = Column(
+    uri: Mapped[str] = mapped_column(
         String(length=1500).with_variant(
             String(
                 length=1500,
@@ -422,7 +424,7 @@ def for_asset(cls, asset: AssetModel) -> AssetActive:
 class DagScheduleAssetNameReference(Base):
     """Reference from a DAG to an asset name reference of which it is a consumer."""

-    name = Column(
+    name: Mapped[str] = mapped_column(
         String(length=1500).with_variant(
             String(
                 length=1500,
@@ -435,8 +437,8 @@ class DagScheduleAssetNameReference(Base):
         primary_key=True,
         nullable=False,
     )
-    dag_id = Column(StringID(), primary_key=True, nullable=False)
-    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    dag_id: Mapped[str] = mapped_column(StringID(), primary_key=True, nullable=False)
+    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)

     dag = relationship("DagModel", back_populates="schedule_asset_name_references")
@@ -468,7 +470,7 @@ def __repr__(self):
 class DagScheduleAssetUriReference(Base):
     """Reference from a DAG to an asset URI reference of which it is a consumer."""

-    uri = Column(
+    uri: Mapped[str] = mapped_column(
         String(length=1500).with_variant(
             String(
                 length=1500,
@@ -481,8 +483,8 @@ class DagScheduleAssetUriReference(Base):
         primary_key=True,
         nullable=False,
     )
-    dag_id = Column(StringID(), primary_key=True, nullable=False)
-    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    dag_id: Mapped[str] = mapped_column(StringID(), primary_key=True, nullable=False)
+    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)

     dag = relationship("DagModel", back_populates="schedule_asset_uri_references")
@@ -514,10 +516,12 @@ def __repr__(self):
 class DagScheduleAssetAliasReference(Base):
     """References from a DAG to an asset alias of which it is a consumer."""

-    alias_id = Column(Integer, primary_key=True, nullable=False)
-    dag_id = Column(StringID(), primary_key=True, nullable=False)
-    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
-    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
+    alias_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False)
+    dag_id: Mapped[str] = mapped_column(StringID(), primary_key=True, nullable=False)
+    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    updated_at: Mapped[UtcDateTime] = mapped_column(
+        UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False
+    )

     asset_alias = relationship("AssetAliasModel", back_populates="scheduled_dags")
     dag = relationship("DagModel", back_populates="schedule_asset_alias_references")
@@ -556,10 +560,12 @@ def __repr__(self):
 class DagScheduleAssetReference(Base):
     """References from a DAG to an asset of which it is a consumer."""

-    asset_id = Column(Integer, primary_key=True, nullable=False)
-    dag_id = Column(StringID(), primary_key=True, nullable=False)
-    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
-    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
+    asset_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False)
+    dag_id: Mapped[str] = mapped_column(StringID(), primary_key=True, nullable=False)
+    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    updated_at: Mapped[UtcDateTime] = mapped_column(
+        UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False
+    )

     asset = relationship("AssetModel", back_populates="scheduled_dags")
     dag = relationship("DagModel", back_populates="schedule_asset_references")
@@ -607,11 +613,13 @@ def __repr__(self):
 class TaskOutletAssetReference(Base):
     """References from a task to an asset that it updates / produces."""

-    asset_id = Column(Integer, primary_key=True, nullable=False)
-    dag_id = Column(StringID(), primary_key=True, nullable=False)
-    task_id = Column(StringID(), primary_key=True, nullable=False)
-    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
-    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
+    asset_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False)
+    dag_id: Mapped[str] = mapped_column(StringID(), primary_key=True, nullable=False)
+    task_id: Mapped[str] = mapped_column(StringID(), primary_key=True, nullable=False)
+    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    updated_at: Mapped[UtcDateTime] = mapped_column(
+        UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False
+    )

     asset = relationship("AssetModel", back_populates="producing_tasks")
@@ -656,11 +664,13 @@ def __repr__(self):
 class TaskInletAssetReference(Base):
     """References from a task to an asset that it references as an inlet."""

-    asset_id = Column(Integer, primary_key=True, nullable=False)
-    dag_id = Column(StringID(), primary_key=True, nullable=False)
-    task_id = Column(StringID(), primary_key=True, nullable=False)
-    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
-    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
+    asset_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False)
+    dag_id: Mapped[str] = mapped_column(StringID(), primary_key=True, nullable=False)
+    task_id: Mapped[str] = mapped_column(StringID(), primary_key=True, nullable=False)
+    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    updated_at: Mapped[UtcDateTime] = mapped_column(
+        UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False
+    )

     asset = relationship("AssetModel", back_populates="consuming_tasks")
@@ -700,9 +710,9 @@ def __repr__(self):
 class AssetDagRunQueue(Base):
     """Model for storing asset events that need processing."""

-    asset_id = Column(Integer, primary_key=True, nullable=False)
-    target_dag_id = Column(StringID(), primary_key=True, nullable=False)
-    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    asset_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False)
+    target_dag_id: Mapped[str] = mapped_column(StringID(), primary_key=True, nullable=False)
+    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)

     asset = relationship("AssetModel", viewonly=True)
     dag_model = relationship("DagModel", viewonly=True)
@@ -765,14 +775,14 @@ class AssetEvent(Base):
     if the foreign key object is.
     """

-    id = Column(Integer, primary_key=True, autoincrement=True)
-    asset_id = Column(Integer, nullable=False)
-    extra = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})
-    source_task_id = Column(StringID(), nullable=True)
-    source_dag_id = Column(StringID(), nullable=True)
-    source_run_id = Column(StringID(), nullable=True)
-    source_map_index = Column(Integer, nullable=True, server_default=text("-1"))
-    timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+    asset_id: Mapped[int] = mapped_column(Integer, nullable=False)
+    extra: Mapped[dict] = mapped_column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})
+    source_task_id: Mapped[str | None] = mapped_column(StringID(), nullable=True)
+    source_dag_id: Mapped[str | None] = mapped_column(StringID(), nullable=True)
+    source_run_id: Mapped[str | None] = mapped_column(StringID(), nullable=True)
+    source_map_index: Mapped[int | None] = mapped_column(Integer, nullable=True, server_default=text("-1"))
+    timestamp: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)

     __tablename__ = "asset_event"
     __table_args__ = (
""" - reprocess_behavior = Column(StringID(), nullable=False, default=ReprocessBehavior.NONE) - max_active_runs = Column(Integer, default=10, nullable=False) - created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False) - completed_at = Column(UtcDateTime, nullable=True) - updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False) - triggering_user_name = Column( + reprocess_behavior: Mapped[str] = mapped_column( + StringID(), nullable=False, default=ReprocessBehavior.NONE + ) + max_active_runs: Mapped[int] = mapped_column(Integer, default=10, nullable=False) + created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False) + completed_at: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True) + updated_at: Mapped[UtcDateTime] = mapped_column( + UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False + ) + triggering_user_name: Mapped[str | None] = mapped_column( String(512), nullable=True, ) # The user that triggered the Backfill, if applicable @@ -166,12 +169,12 @@ class BackfillDagRun(Base): """Mapping table between backfill run and dag run.""" __tablename__ = "backfill_dag_run" - id = Column(Integer, primary_key=True, autoincrement=True) - backfill_id = Column(Integer, nullable=False) - dag_run_id = Column(Integer, nullable=True) - exception_reason = Column(StringID()) - logical_date = Column(UtcDateTime, nullable=False) - sort_ordinal = Column(Integer, nullable=False) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + backfill_id: Mapped[int] = mapped_column(Integer, nullable=False) + dag_run_id: Mapped[int | None] = mapped_column(Integer, nullable=True) + exception_reason: Mapped[str | None] = mapped_column(StringID(), nullable=True) + logical_date: Mapped[UtcDateTime] = mapped_column(UtcDateTime, nullable=False) + sort_ordinal: Mapped[int] = mapped_column(Integer, nullable=False) backfill = relationship("Backfill", back_populates="backfill_dag_run_associations") dag_run = relationship("DagRun") diff --git a/airflow-core/src/airflow/models/base.py b/airflow-core/src/airflow/models/base.py index 0548e08a8f605..cc9853330625a 100644 --- a/airflow-core/src/airflow/models/base.py +++ b/airflow-core/src/airflow/models/base.py @@ -19,11 +19,11 @@ from typing import TYPE_CHECKING, Any -from sqlalchemy import Column, Integer, MetaData, String, text -from sqlalchemy.orm import registry +from sqlalchemy import Integer, MetaData, String, text +from sqlalchemy.orm import Mapped, registry from airflow.configuration import conf -from airflow.utils.sqlalchemy import is_sqlalchemy_v1 +from airflow.utils.sqlalchemy import is_sqlalchemy_v1, mapped_column SQL_ALCHEMY_SCHEMA = conf.get("database", "SQL_ALCHEMY_SCHEMA") @@ -94,7 +94,7 @@ class TaskInstanceDependencies(Base): __abstract__ = True - task_id = Column(StringID(), nullable=False) - dag_id = Column(StringID(), nullable=False) - run_id = Column(StringID(), nullable=False) - map_index = Column(Integer, nullable=False, server_default=text("-1")) + task_id: Mapped[str] = mapped_column(StringID(), nullable=False) + dag_id: Mapped[str] = mapped_column(StringID(), nullable=False) + run_id: Mapped[str] = mapped_column(StringID(), nullable=False) + map_index: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("-1")) diff --git a/airflow-core/src/airflow/models/connection.py b/airflow-core/src/airflow/models/connection.py index 875d8d0b5711b..06c2ee54c805d 100644 --- 
diff --git a/airflow-core/src/airflow/models/connection.py b/airflow-core/src/airflow/models/connection.py
index 875d8d0b5711b..06c2ee54c805d 100644
--- a/airflow-core/src/airflow/models/connection.py
+++ b/airflow-core/src/airflow/models/connection.py
@@ -27,8 +27,8 @@
 from typing import Any
 from urllib.parse import parse_qsl, quote, unquote, urlencode, urlsplit

-from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Text, select
-from sqlalchemy.orm import declared_attr, reconstructor, synonym
+from sqlalchemy import Boolean, ForeignKey, Integer, String, Text, select
+from sqlalchemy.orm import Mapped, declared_attr, reconstructor, synonym
 from sqlalchemy_utils import UUIDType

 from airflow._shared.secrets_masker import mask_secret
@@ -42,6 +42,7 @@
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.module_loading import import_string
 from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.sqlalchemy import mapped_column

 log = logging.getLogger(__name__)
 # sanitize the `conn_id` pattern by allowing alphanumeric characters plus
@@ -126,19 +127,21 @@ class Connection(Base, LoggingMixin):

     __tablename__ = "connection"

-    id = Column(Integer(), primary_key=True)
-    conn_id = Column(String(ID_LEN), unique=True, nullable=False)
-    conn_type = Column(String(500), nullable=False)
-    description = Column(Text().with_variant(Text(5000), "mysql").with_variant(String(5000), "sqlite"))
-    host = Column(String(500))
-    schema = Column(String(500))
-    login = Column(Text())
-    _password = Column("password", Text())
-    port = Column(Integer())
-    is_encrypted = Column(Boolean, unique=False, default=False)
-    is_extra_encrypted = Column(Boolean, unique=False, default=False)
-    team_id = Column(UUIDType(binary=False), ForeignKey("team.id"), nullable=True)
-    _extra = Column("extra", Text())
+    id: Mapped[int] = mapped_column(Integer(), primary_key=True)
+    conn_id: Mapped[str] = mapped_column(String(ID_LEN), unique=True, nullable=False)
+    conn_type: Mapped[str] = mapped_column(String(500), nullable=False)
+    description: Mapped[str | None] = mapped_column(
+        Text().with_variant(Text(5000), "mysql").with_variant(String(5000), "sqlite"), nullable=True
+    )
+    host: Mapped[str | None] = mapped_column(String(500), nullable=True)
+    schema: Mapped[str | None] = mapped_column(String(500), nullable=True)
+    login: Mapped[str | None] = mapped_column(Text(), nullable=True)
+    _password: Mapped[str | None] = mapped_column("password", Text(), nullable=True)
+    port: Mapped[int | None] = mapped_column(Integer(), nullable=True)
+    is_encrypted: Mapped[bool] = mapped_column(Boolean, unique=False, default=False)
+    is_extra_encrypted: Mapped[bool] = mapped_column(Boolean, unique=False, default=False)
+    team_id: Mapped[str | None] = mapped_column(UUIDType(binary=False), ForeignKey("team.id"), nullable=True)
+    _extra: Mapped[str | None] = mapped_column("extra", Text(), nullable=True)

     def __init__(
         self,
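connection.py also shows the nullability convention used throughout the diff: wherever a column becomes `Mapped[X | None]`, `nullable=True` is added explicitly, even though SQLAlchemy 2.0's `mapped_column()` would infer it from the Optional annotation. Stating it keeps the DDL pinned to what the old `Column` calls produced rather than to inference rules. A small self-contained check, assuming stock SQLAlchemy 2.0 (`DemoConnection` is illustrative):

from __future__ import annotations

from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class DemoConnection(Base):
    __tablename__ = "demo_connection"

    conn_id: Mapped[str] = mapped_column(String(250), primary_key=True)
    # Inferred: the `str | None` annotation alone makes the column nullable.
    host_inferred: Mapped[str | None] = mapped_column(String(500))
    # Explicit: identical DDL, but stated, as the diff does everywhere.
    host_explicit: Mapped[str | None] = mapped_column(String(500), nullable=True)


assert DemoConnection.__table__.c.host_inferred.nullable
assert DemoConnection.__table__.c.host_explicit.nullable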
diff --git a/airflow-core/src/airflow/models/dag.py b/airflow-core/src/airflow/models/dag.py
index 844c28b724c0f..da1412b28561f 100644
--- a/airflow-core/src/airflow/models/dag.py
+++ b/airflow-core/src/airflow/models/dag.py
@@ -27,7 +27,6 @@
 from dateutil.relativedelta import relativedelta
 from sqlalchemy import (
     Boolean,
-    Column,
     Float,
     ForeignKey,
     Index,
@@ -42,7 +41,7 @@
 )
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.ext.hybrid import hybrid_property
-from sqlalchemy.orm import backref, load_only, relationship
+from sqlalchemy.orm import Mapped, backref, load_only, relationship
 from sqlalchemy.sql import expression

 from airflow import settings
@@ -63,7 +62,7 @@
 from airflow.timetables.simple import AssetTriggeredTimetable, NullTimetable, OnceTimetable
 from airflow.utils.context import Context
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks
+from airflow.utils.sqlalchemy import UtcDateTime, mapped_column, with_row_locks
 from airflow.utils.state import DagRunState
 from airflow.utils.types import DagRunType
@@ -267,8 +266,8 @@ class DagTag(Base):
     """A tag name per dag, to allow quick filtering in the DAG view."""

     __tablename__ = "dag_tag"
-    name = Column(String(TAG_MAX_LEN), primary_key=True)
-    dag_id = Column(
+    name: Mapped[str] = mapped_column(String(TAG_MAX_LEN), primary_key=True)
+    dag_id: Mapped[str] = mapped_column(
         StringID(),
         ForeignKey("dag.dag_id", name="dag_tag_dag_id_fkey", ondelete="CASCADE"),
         primary_key=True,
@@ -288,14 +287,14 @@ class DagOwnerAttributes(Base):
     """

     __tablename__ = "dag_owner_attributes"
-    dag_id = Column(
+    dag_id: Mapped[str] = mapped_column(
         StringID(),
         ForeignKey("dag.dag_id", name="dag.dag_id", ondelete="CASCADE"),
         nullable=False,
         primary_key=True,
     )
-    owner = Column(String(500), primary_key=True, nullable=False)
-    link = Column(String(500), nullable=False)
+    owner: Mapped[str] = mapped_column(String(500), primary_key=True, nullable=False)
+    link: Mapped[str] = mapped_column(String(500), nullable=False)

     def __repr__(self):
         return f"<DagOwnerAttributes: dag_id={self.dag_id}, owner={self.owner}, link={self.link}>"
@@ -315,43 +314,49 @@ class DagModel(Base):
     """
     These items are stored in the database for state related information.
     """

-    dag_id = Column(StringID(), primary_key=True)
+    dag_id: Mapped[str] = mapped_column(StringID(), primary_key=True)
     # A DAG can be paused from the UI / DB
     # Set this default value of is_paused based on a configuration value!
     is_paused_at_creation = airflow_conf.getboolean("core", "dags_are_paused_at_creation")
-    is_paused = Column(Boolean, default=is_paused_at_creation)
+    is_paused: Mapped[bool] = mapped_column(Boolean, default=is_paused_at_creation)
     # Whether that DAG was seen on the last DagBag load
-    is_stale = Column(Boolean, default=True)
+    is_stale: Mapped[bool] = mapped_column(Boolean, default=True)
     # Last time the scheduler started
-    last_parsed_time = Column(UtcDateTime)
+    last_parsed_time: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)
     # How long it took to parse this file
-    last_parse_duration = Column(Float)
+    last_parse_duration: Mapped[float | None] = mapped_column(Float, nullable=True)
     # Time when the DAG last received a refresh signal
     # (e.g. the DAG's "refresh" button was clicked in the web UI)
-    last_expired = Column(UtcDateTime)
+    last_expired: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)
     # The location of the file containing the DAG object
     # Note: Do not depend on fileloc pointing to a file; in the case of a
     # packaged DAG, it will point to the subpath of the DAG within the
     # associated zip.
-    fileloc = Column(String(2000))
-    relative_fileloc = Column(String(2000))
-    bundle_name = Column(StringID(), ForeignKey("dag_bundle.name"), nullable=False)
+    fileloc: Mapped[str | None] = mapped_column(String(2000), nullable=True)
+    relative_fileloc: Mapped[str | None] = mapped_column(String(2000), nullable=True)
+    bundle_name: Mapped[str] = mapped_column(StringID(), ForeignKey("dag_bundle.name"), nullable=False)
     # The version of the bundle the last time the DAG was processed
-    bundle_version = Column(String(200), nullable=True)
+    bundle_version: Mapped[str | None] = mapped_column(String(200), nullable=True)
     # String representing the owners
-    owners = Column(String(2000))
+    owners: Mapped[str | None] = mapped_column(String(2000), nullable=True)
     # Display name of the dag
-    _dag_display_property_value = Column("dag_display_name", String(2000), nullable=True)
+    _dag_display_property_value: Mapped[str | None] = mapped_column(
+        "dag_display_name", String(2000), nullable=True
+    )
     # Description of the dag
-    description = Column(Text)
+    description: Mapped[str | None] = mapped_column(Text, nullable=True)
     # Timetable summary
-    timetable_summary = Column(Text, nullable=True)
+    timetable_summary: Mapped[str | None] = mapped_column(Text, nullable=True)
     # Timetable description
-    timetable_description = Column(String(1000), nullable=True)
+    timetable_description: Mapped[str | None] = mapped_column(String(1000), nullable=True)
     # Asset expression based on asset triggers
-    asset_expression = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
+    asset_expression: Mapped[dict[str, Any] | None] = mapped_column(
+        sqlalchemy_jsonfield.JSONField(json=json), nullable=True
+    )
     # DAG deadline information
-    _deadline = Column("deadline", sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
+    _deadline: Mapped[dict[str, Any] | None] = mapped_column(
+        "deadline", sqlalchemy_jsonfield.JSONField(json=json), nullable=True
+    )
     # Tags for view filter
     tags = relationship("DagTag", cascade="all, delete, delete-orphan", backref=backref("dag"))
     # Dag owner links for DAGs view
@@ -359,22 +364,24 @@ class DagModel(Base):
         "DagOwnerAttributes", cascade="all, delete, delete-orphan", backref=backref("dag")
     )

-    max_active_tasks = Column(Integer, nullable=False)
-    max_active_runs = Column(Integer, nullable=True)  # todo: should not be nullable if we have a default
-    max_consecutive_failed_dag_runs = Column(Integer, nullable=False)
+    max_active_tasks: Mapped[int] = mapped_column(Integer, nullable=False)
+    max_active_runs: Mapped[int | None] = mapped_column(
+        Integer, nullable=True
+    )  # todo: should not be nullable if we have a default
+    max_consecutive_failed_dag_runs: Mapped[int] = mapped_column(Integer, nullable=False)

-    has_task_concurrency_limits = Column(Boolean, nullable=False)
-    has_import_errors = Column(Boolean(), default=False, server_default="0")
+    has_task_concurrency_limits: Mapped[bool] = mapped_column(Boolean, nullable=False)
+    has_import_errors: Mapped[bool] = mapped_column(Boolean(), default=False, server_default="0")

     # The logical date of the next dag run.
-    next_dagrun = Column(UtcDateTime)
+    next_dagrun: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)

     # Must be either both NULL or both datetime.
-    next_dagrun_data_interval_start = Column(UtcDateTime)
-    next_dagrun_data_interval_end = Column(UtcDateTime)
+    next_dagrun_data_interval_start: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)
+    next_dagrun_data_interval_end: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)

     # Earliest time at which this ``next_dagrun`` can be created.
-    next_dagrun_create_after = Column(UtcDateTime)
+    next_dagrun_create_after: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)

     __table_args__ = (Index("idx_next_dagrun_create_after", next_dagrun_create_after, unique=False),)
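DagModel's `has_import_errors` carries both `default=False` and `server_default="0"`; the two act at different layers, which matters when a column is added by a migration. A minimal sketch of the distinction, assuming plain SQLAlchemy 2.0 (`DemoFlag` is illustrative):

from sqlalchemy import Boolean, Integer, text
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class DemoFlag(Base):
    __tablename__ = "demo_flag"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # default= is applied client-side by the ORM when Python code INSERTs a
    # row without a value; server_default= bakes DEFAULT '0' into the CREATE
    # TABLE DDL, so rows written by raw SQL, and rows backfilled when the
    # column is first added, also receive a value.
    has_import_errors: Mapped[bool] = mapped_column(
        Boolean, default=False, server_default=text("0")
    )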
diff --git a/airflow-core/src/airflow/models/dag_favorite.py b/airflow-core/src/airflow/models/dag_favorite.py
index 5dfb742fdaf80..8006fe086681f 100644
--- a/airflow-core/src/airflow/models/dag_favorite.py
+++ b/airflow-core/src/airflow/models/dag_favorite.py
@@ -17,9 +17,11 @@
 # under the License.
 from __future__ import annotations

-from sqlalchemy import Column, ForeignKey
+from sqlalchemy import ForeignKey
+from sqlalchemy.orm import Mapped

 from airflow.models.base import Base, StringID
+from airflow.utils.sqlalchemy import mapped_column


 class DagFavorite(Base):
@@ -27,5 +29,7 @@ class DagFavorite(Base):

     __tablename__ = "dag_favorite"

-    user_id = Column(StringID(), primary_key=True)
-    dag_id = Column(StringID(), ForeignKey("dag.dag_id", ondelete="CASCADE"), primary_key=True)
+    user_id: Mapped[str] = mapped_column(StringID(), primary_key=True)
+    dag_id: Mapped[str] = mapped_column(
+        StringID(), ForeignKey("dag.dag_id", ondelete="CASCADE"), primary_key=True
+    )
diff --git a/airflow-core/src/airflow/models/dag_version.py b/airflow-core/src/airflow/models/dag_version.py
index a51f3cb301867..687e1367ece73 100644
--- a/airflow-core/src/airflow/models/dag_version.py
+++ b/airflow-core/src/airflow/models/dag_version.py
@@ -21,19 +21,20 @@
 from typing import TYPE_CHECKING

 import uuid6
-from sqlalchemy import Column, ForeignKey, Integer, UniqueConstraint, select
-from sqlalchemy.orm import joinedload, relationship
+from sqlalchemy import ForeignKey, Integer, UniqueConstraint, select
+from sqlalchemy.orm import Mapped, joinedload, relationship
 from sqlalchemy_utils import UUIDType

 from airflow._shared.timezones import timezone
 from airflow.models.base import Base, StringID
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks
+from airflow.utils.sqlalchemy import UtcDateTime, mapped_column, with_row_locks

 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
     from sqlalchemy.sql import Select

+
 log = logging.getLogger(__name__)
@@ -41,12 +42,14 @@ class DagVersion(Base):
     """Model to track the versions of DAGs in the database."""

     __tablename__ = "dag_version"
-    id = Column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7)
-    version_number = Column(Integer, nullable=False, default=1)
-    dag_id = Column(StringID(), ForeignKey("dag.dag_id", ondelete="CASCADE"), nullable=False)
+    id: Mapped[str] = mapped_column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7)
+    version_number: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
+    dag_id: Mapped[str] = mapped_column(
+        StringID(), ForeignKey("dag.dag_id", ondelete="CASCADE"), nullable=False
+    )
     dag_model = relationship("DagModel", back_populates="dag_versions")
-    bundle_name = Column(StringID(), nullable=True)
-    bundle_version = Column(StringID())
+    bundle_name: Mapped[str | None] = mapped_column(StringID(), nullable=True)
+    bundle_version: Mapped[str | None] = mapped_column(StringID(), nullable=True)
     dag_code = relationship(
         "DagCode",
         back_populates="dag_version",
@@ -62,8 +65,10 @@ class DagVersion(Base):
         cascade_backrefs=False,
     )
     task_instances = relationship("TaskInstance", back_populates="dag_version")
-    created_at = Column(UtcDateTime, nullable=False, default=timezone.utcnow)
-    last_updated = Column(UtcDateTime, nullable=False, default=timezone.utcnow, onupdate=timezone.utcnow)
+    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, nullable=False, default=timezone.utcnow)
+    last_updated: Mapped[UtcDateTime] = mapped_column(
+        UtcDateTime, nullable=False, default=timezone.utcnow, onupdate=timezone.utcnow
+    )

     __table_args__ = (
         UniqueConstraint("dag_id", "version_number", name="dag_id_v_name_v_number_unique_constraint"),
diff --git a/airflow-core/src/airflow/models/dagbag.py b/airflow-core/src/airflow/models/dagbag.py
index 928e8e406d4ec..a7d76bd1f5307 100644
--- a/airflow-core/src/airflow/models/dagbag.py
+++ b/airflow-core/src/airflow/models/dagbag.py
@@ -20,12 +20,13 @@
 import hashlib
 from typing import TYPE_CHECKING, Any

-from sqlalchemy import Column, String, inspect, select
-from sqlalchemy.orm import joinedload
+from sqlalchemy import String, inspect, select
+from sqlalchemy.orm import Mapped, joinedload
 from sqlalchemy.orm.attributes import NO_VALUE

 from airflow.models.base import Base, StringID
 from airflow.models.dag_version import DagVersion
+from airflow.utils.sqlalchemy import mapped_column

 if TYPE_CHECKING:
     from collections.abc import Generator
@@ -115,14 +116,16 @@ class DagPriorityParsingRequest(Base):
     # Adding a unique constraint to fileloc results in the creation of an index and we have a limitation
     # on the size of the string we can use in the index for MySQL DB. We also have to keep the fileloc
     # size consistent with other tables. This is a workaround to enforce the unique constraint.
-    id = Column(String(32), primary_key=True, default=generate_md5_hash, onupdate=generate_md5_hash)
+    id: Mapped[str] = mapped_column(
+        String(32), primary_key=True, default=generate_md5_hash, onupdate=generate_md5_hash
+    )

-    bundle_name = Column(StringID(), nullable=False)
+    bundle_name: Mapped[str] = mapped_column(StringID(), nullable=False)

     # The location of the file containing the DAG object
     # Note: Do not depend on fileloc pointing to a file; in the case of a
     # packaged DAG, it will point to the subpath of the DAG within the
     # associated zip.
-    relative_fileloc = Column(String(2000), nullable=False)
+    relative_fileloc: Mapped[str] = mapped_column(String(2000), nullable=False)

     def __init__(self, bundle_name: str, relative_fileloc: str) -> None:
         super().__init__()
diff --git a/airflow-core/src/airflow/models/dagbundle.py b/airflow-core/src/airflow/models/dagbundle.py
index 91e3f97c402b3..1def1ac7189dc 100644
--- a/airflow-core/src/airflow/models/dagbundle.py
+++ b/airflow-core/src/airflow/models/dagbundle.py
@@ -16,14 +16,14 @@
 # under the License.
 from __future__ import annotations

-from sqlalchemy import Boolean, Column, String
-from sqlalchemy.orm import relationship
+from sqlalchemy import Boolean, String
+from sqlalchemy.orm import Mapped, relationship
 from sqlalchemy_utils import JSONType

 from airflow.models.base import Base, StringID
 from airflow.models.team import dag_bundle_team_association_table
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.sqlalchemy import UtcDateTime
+from airflow.utils.sqlalchemy import UtcDateTime, mapped_column


 class DagBundleModel(Base, LoggingMixin):
@@ -42,12 +42,12 @@ class DagBundleModel(Base, LoggingMixin):
     """

     __tablename__ = "dag_bundle"
-    name = Column(StringID(length=250), primary_key=True, nullable=False)
-    active = Column(Boolean, default=True)
-    version = Column(String(200), nullable=True)
-    last_refreshed = Column(UtcDateTime, nullable=True)
-    signed_url_template = Column(String(200), nullable=True)
-    template_params = Column(JSONType, nullable=True)
+    name: Mapped[str] = mapped_column(StringID(length=250), primary_key=True, nullable=False)
+    active: Mapped[bool | None] = mapped_column(Boolean, default=True, nullable=True)
+    version: Mapped[str | None] = mapped_column(String(200), nullable=True)
+    last_refreshed: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)
+    signed_url_template: Mapped[str | None] = mapped_column(String(200), nullable=True)
+    template_params: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
     teams = relationship("Team", secondary=dag_bundle_team_association_table, back_populates="dag_bundles")

     def __init__(self, *, name: str, version: str | None = None):
diff --git a/airflow-core/src/airflow/models/dagcode.py b/airflow-core/src/airflow/models/dagcode.py
index a68885e1dfdf2..5be3c5395f274 100644
--- a/airflow-core/src/airflow/models/dagcode.py
+++ b/airflow-core/src/airflow/models/dagcode.py
@@ -20,9 +20,9 @@
 from typing import TYPE_CHECKING

 import uuid6
-from sqlalchemy import Column, ForeignKey, String, Text, select
+from sqlalchemy import ForeignKey, String, Text, select
 from sqlalchemy.dialects.mysql import MEDIUMTEXT
-from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Mapped, relationship
 from sqlalchemy.sql.expression import literal
 from sqlalchemy_utils import UUIDType

@@ -33,7 +33,7 @@
 from airflow.utils.file import open_maybe_zipped
 from airflow.utils.hashlib_wrapper import md5
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import UtcDateTime
+from airflow.utils.sqlalchemy import UtcDateTime, mapped_column

 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
@@ -54,15 +54,17 @@ class DagCode(Base):
     """

     __tablename__ = "dag_code"
-    id = Column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7)
-    dag_id = Column(String(ID_LEN), nullable=False)
-    fileloc = Column(String(2000), nullable=False)
+    id: Mapped[str] = mapped_column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7)
+    dag_id: Mapped[str] = mapped_column(String(ID_LEN), nullable=False)
+    fileloc: Mapped[str] = mapped_column(String(2000), nullable=False)
     # The max length of fileloc exceeds the limit of indexing.
-    created_at = Column(UtcDateTime, nullable=False, default=timezone.utcnow)
-    last_updated = Column(UtcDateTime, nullable=False, default=timezone.utcnow, onupdate=timezone.utcnow)
-    source_code = Column(Text().with_variant(MEDIUMTEXT(), "mysql"), nullable=False)
-    source_code_hash = Column(String(32), nullable=False)
-    dag_version_id = Column(
+    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, nullable=False, default=timezone.utcnow)
+    last_updated: Mapped[UtcDateTime] = mapped_column(
+        UtcDateTime, nullable=False, default=timezone.utcnow, onupdate=timezone.utcnow
+    )
+    source_code: Mapped[str] = mapped_column(Text().with_variant(MEDIUMTEXT(), "mysql"), nullable=False)
+    source_code_hash: Mapped[str] = mapped_column(String(32), nullable=False)
+    dag_version_id: Mapped[str] = mapped_column(
         UUIDType(binary=False), ForeignKey("dag_version.id", ondelete="CASCADE"), nullable=False, unique=True
     )
     dag_version = relationship("DagVersion", back_populates="dag_code", uselist=False)
diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py
index aa8f01dd3160c..2c9cac3193ded 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -29,7 +29,6 @@
 from natsort import natsorted
 from sqlalchemy import (
     JSON,
-    Column,
     Enum,
     ForeignKey,
     ForeignKeyConstraint,
@@ -52,7 +51,7 @@
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.ext.mutable import MutableDict
-from sqlalchemy.orm import declared_attr, joinedload, relationship, synonym, validates
+from sqlalchemy.orm import Mapped, declared_attr, joinedload, relationship, synonym, validates
 from sqlalchemy.sql.expression import false, select
 from sqlalchemy.sql.functions import coalesce
 from sqlalchemy_utils import UUIDType
@@ -80,7 +79,7 @@
 from airflow.utils.retries import retry_db_transaction
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.span_status import SpanStatus
-from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, nulls_first, with_row_locks
+from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, mapped_column, nulls_first, with_row_locks
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.strings import get_random_string
 from airflow.utils.thread_safe_dict import ThreadSafeDict
@@ -148,57 +147,65 @@ class DagRun(Base, LoggingMixin):

     __tablename__ = "dag_run"

-    id = Column(Integer, primary_key=True)
-    dag_id = Column(StringID(), nullable=False)
-    queued_at = Column(UtcDateTime)
-    logical_date = Column(UtcDateTime, nullable=True)
-    start_date = Column(UtcDateTime)
-    end_date = Column(UtcDateTime)
-    _state = Column("state", String(50), default=DagRunState.QUEUED)
-    run_id = Column(StringID(), nullable=False)
-    creating_job_id = Column(Integer)
-    run_type = Column(String(50), nullable=False)
-    triggered_by = Column(
-        Enum(DagRunTriggeredByType, native_enum=False, length=50)
+    id: Mapped[int] = mapped_column(Integer, primary_key=True)
+    dag_id: Mapped[str] = mapped_column(StringID(), nullable=False)
+    queued_at: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)
+    logical_date: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)
+    start_date: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)
+    end_date: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)
+    _state: Mapped[str] = mapped_column("state", String(50), default=DagRunState.QUEUED)
+    run_id: Mapped[str] = mapped_column(StringID(), nullable=False)
+    creating_job_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
+    run_type: Mapped[str] = mapped_column(String(50), nullable=False)
+    triggered_by: Mapped[DagRunTriggeredByType | None] = mapped_column(
+        Enum(DagRunTriggeredByType, native_enum=False, length=50), nullable=True
     )  # Airflow component that triggered the run.
-    triggering_user_name = Column(
+    triggering_user_name: Mapped[str | None] = mapped_column(
         String(512),
         nullable=True,
     )  # The user that triggered the DagRun, if applicable
-    conf = Column(JSON().with_variant(postgresql.JSONB, "postgresql"))
+    conf: Mapped[dict[str, Any] | None] = mapped_column(
+        JSON().with_variant(postgresql.JSONB, "postgresql"), nullable=True
+    )
     # These two must be either both NULL or both datetime.
-    data_interval_start = Column(UtcDateTime)
-    data_interval_end = Column(UtcDateTime)
+    data_interval_start: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)
+    data_interval_end: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)
     # Earliest time when this DagRun can start running.
-    run_after = Column(UtcDateTime, default=_default_run_after, nullable=False)
+    run_after: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=_default_run_after, nullable=False)
     # When a scheduler last attempted to schedule TIs for this DagRun
-    last_scheduling_decision = Column(UtcDateTime)
+    last_scheduling_decision: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)
     # Foreign key to LogTemplate. DagRun rows created prior to this column's
     # existence have this set to NULL. Later rows automatically populate this on
     # insert to point to the latest LogTemplate entry.
-    log_template_id = Column(
+    log_template_id: Mapped[int] = mapped_column(
         Integer,
         ForeignKey("log_template.id", name="task_instance_log_template_id_fkey", ondelete="NO ACTION"),
         default=select(func.max(LogTemplate.__table__.c.id)),
     )
-    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow)
+    updated_at: Mapped[UtcDateTime] = mapped_column(
+        UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow
+    )
     # Keeps track of the number of times the dagrun had been cleared.
     # This number is incremented only when the DagRun is re-Queued,
     # when the DagRun is cleared.
-    clear_number = Column(Integer, default=0, nullable=False, server_default="0")
-    backfill_id = Column(Integer, ForeignKey("backfill.id"), nullable=True)
+    clear_number: Mapped[int] = mapped_column(Integer, default=0, nullable=False, server_default="0")
+    backfill_id: Mapped[int | None] = mapped_column(Integer, ForeignKey("backfill.id"), nullable=True)
     """
     The backfill this DagRun is currently associated with.

     It's possible this could change if e.g. the dag run is cleared to be rerun, or perhaps re-backfilled.
     """
-    bundle_version = Column(StringID())
+    bundle_version: Mapped[str | None] = mapped_column(StringID(), nullable=True)

-    scheduled_by_job_id = Column(Integer)
+    scheduled_by_job_id: Mapped[int | None] = mapped_column(Integer, nullable=True)

     # Span context carrier, used for context propagation.
-    context_carrier = Column(MutableDict.as_mutable(ExtendedJSON))
-    span_status = Column(String(250), server_default=SpanStatus.NOT_STARTED, nullable=False)
-    created_dag_version_id = Column(
+    context_carrier: Mapped[dict[str, Any] | None] = mapped_column(
+        MutableDict.as_mutable(ExtendedJSON), nullable=True
+    )
+    span_status: Mapped[str] = mapped_column(
+        String(250), server_default=SpanStatus.NOT_STARTED, nullable=False
+    )
+    created_dag_version_id: Mapped[str | None] = mapped_column(
         UUIDType(binary=False),
         ForeignKey("dag_version.id", name="created_dag_version_id_fkey", ondelete="set null"),
         nullable=True,
@@ -2115,11 +2122,13 @@ class DagRunNote(Base):

     __tablename__ = "dag_run_note"

-    user_id = Column(String(128), nullable=True)
-    dag_run_id = Column(Integer, primary_key=True, nullable=False)
-    content = Column(String(1000).with_variant(Text(1000), "mysql"))
-    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
-    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
+    user_id: Mapped[str | None] = mapped_column(String(128), nullable=True)
+    dag_run_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False)
+    content: Mapped[str | None] = mapped_column(String(1000).with_variant(Text(1000), "mysql"))
+    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    updated_at: Mapped[UtcDateTime] = mapped_column(
+        UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False
+    )

     dag_run = relationship("DagRun", back_populates="dag_run_note")
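DagRun's `context_carrier` keeps its `MutableDict.as_mutable(...)` wrapper through the rewrite; without it, in-place edits to the dict would not mark the attribute dirty and could be lost at flush time. A sketch with the stock `JSON` type standing in for Airflow's own `ExtendedJSON` (a substitution made for the example's sake):

from sqlalchemy import JSON, Integer
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class DemoRun(Base):
    __tablename__ = "demo_run"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # as_mutable() installs change events on the column, so a statement like
    # `run.carrier["traceparent"] = "..."` flags the attribute as dirty and
    # gets flushed, just like a full reassignment would.
    carrier: Mapped[dict | None] = mapped_column(
        MutableDict.as_mutable(JSON), nullable=True
    )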
""" - dag_id = Column(StringID(), primary_key=True) - warning_type = Column(String(50), primary_key=True) - message = Column(Text, nullable=False) - timestamp = Column(UtcDateTime, nullable=False, default=timezone.utcnow) + dag_id: Mapped[str] = mapped_column(StringID(), primary_key=True) + warning_type: Mapped[str] = mapped_column(String(50), primary_key=True) + message: Mapped[str] = mapped_column(Text, nullable=False) + timestamp: Mapped[UtcDateTime] = mapped_column(UtcDateTime, nullable=False, default=timezone.utcnow) dag_model = relationship("DagModel", viewonly=True, lazy="selectin") diff --git a/airflow-core/src/airflow/models/db_callback_request.py b/airflow-core/src/airflow/models/db_callback_request.py index f1009ca1babe4..2176f3cccfb9f 100644 --- a/airflow-core/src/airflow/models/db_callback_request.py +++ b/airflow-core/src/airflow/models/db_callback_request.py @@ -20,11 +20,12 @@ from importlib import import_module from typing import TYPE_CHECKING -from sqlalchemy import Column, Integer, String +from sqlalchemy import Integer, String +from sqlalchemy.orm import Mapped from airflow._shared.timezones import timezone from airflow.models.base import Base -from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime +from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, mapped_column if TYPE_CHECKING: from airflow.callbacks.callback_requests import CallbackRequest @@ -35,11 +36,11 @@ class DbCallbackRequest(Base): __tablename__ = "callback_request" - id = Column(Integer(), nullable=False, primary_key=True) - created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False) - priority_weight = Column(Integer(), nullable=False) - callback_data = Column(ExtendedJSON, nullable=False) - callback_type = Column(String(20), nullable=False) + id: Mapped[int] = mapped_column(Integer(), nullable=False, primary_key=True) + created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False) + priority_weight: Mapped[int] = mapped_column(Integer(), nullable=False) + callback_data: Mapped[dict] = mapped_column(ExtendedJSON, nullable=False) + callback_type: Mapped[str] = mapped_column(String(20), nullable=False) def __init__(self, priority_weight: int, callback: CallbackRequest): self.created_at = timezone.utcnow() diff --git a/airflow-core/src/airflow/models/deadline.py b/airflow-core/src/airflow/models/deadline.py index 21e49a36ed1e7..052901f3ee281 100644 --- a/airflow-core/src/airflow/models/deadline.py +++ b/airflow-core/src/airflow/models/deadline.py @@ -26,9 +26,9 @@ import sqlalchemy_jsonfield import uuid6 -from sqlalchemy import Column, ForeignKey, Index, Integer, String, and_, func, select, text +from sqlalchemy import ForeignKey, Index, Integer, String, and_, func, select, text from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.orm import relationship +from sqlalchemy.orm import Mapped, relationship from sqlalchemy_utils import UUIDType from airflow._shared.timezones import timezone @@ -40,7 +40,7 @@ from airflow.triggers.deadline import PAYLOAD_BODY_KEY, PAYLOAD_STATUS_KEY, DeadlineCallbackTrigger from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import provide_session -from airflow.utils.sqlalchemy import UtcDateTime +from airflow.utils.sqlalchemy import UtcDateTime, mapped_column if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -97,22 +97,26 @@ class Deadline(Base): __tablename__ = "deadline" - id = Column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7) + id: 
Mapped[str] = mapped_column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7) # If the Deadline Alert is for a DAG, store the DAG run ID from the dag_run. - dagrun_id = Column(Integer, ForeignKey("dag_run.id", ondelete="CASCADE")) + dagrun_id: Mapped[int | None] = mapped_column( + Integer, ForeignKey("dag_run.id", ondelete="CASCADE"), nullable=True + ) # The time after which the Deadline has passed and the callback should be triggered. - deadline_time = Column(UtcDateTime, nullable=False) + deadline_time: Mapped[UtcDateTime] = mapped_column(UtcDateTime, nullable=False) # The (serialized) callback to be called when the Deadline has passed. - _callback = Column("callback", sqlalchemy_jsonfield.JSONField(json=json), nullable=False) + _callback: Mapped[dict] = mapped_column( + "callback", sqlalchemy_jsonfield.JSONField(json=json), nullable=False + ) # The state of the deadline callback - callback_state = Column(String(20)) + callback_state: Mapped[str | None] = mapped_column(String(20), nullable=True) dagrun = relationship("DagRun", back_populates="deadlines") # The Trigger where the callback is running - trigger_id = Column(Integer, ForeignKey("trigger.id"), nullable=True) + trigger_id: Mapped[int | None] = mapped_column(Integer, ForeignKey("trigger.id"), nullable=True) trigger = relationship("Trigger", back_populates="deadline") __table_args__ = (Index("deadline_callback_state_time_idx", callback_state, deadline_time, unique=False),) @@ -145,7 +149,7 @@ def _determine_resource() -> tuple[str, str]: ) @classmethod - def prune_deadlines(cls, *, session: Session, conditions: dict[Column, Any]) -> int: + def prune_deadlines(cls, *, session: Session, conditions: dict[Mapped, Any]) -> int: """ Remove deadlines from the table which match the provided conditions and return the number removed. @@ -479,7 +483,7 @@ def deserialize_reference(cls, reference_data: dict): @provide_session -def _fetch_from_db(model_reference: Column, session=None, **conditions) -> datetime: +def _fetch_from_db(model_reference: Mapped, session=None, **conditions) -> datetime: """ Fetch a datetime value from the database using the provided model reference and filtering conditions. diff --git a/airflow-core/src/airflow/models/errors.py b/airflow-core/src/airflow/models/errors.py index 6670df1dfaf62..16d6720662386 100644 --- a/airflow-core/src/airflow/models/errors.py +++ b/airflow-core/src/airflow/models/errors.py @@ -17,22 +17,23 @@ # under the License. 
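Beyond column declarations, deadline.py retypes two signatures from `Column` to `Mapped`, since callers pass mapped class attributes (e.g. `DagRun.id`) as filter keys. A sketch of how such a `dict[Mapped, Any]` can drive a query; `DemoDeadline` and `build_filter` are invented for illustration, assuming plain SQLAlchemy 2.0:

from __future__ import annotations

from typing import Any

from sqlalchemy import Integer, String, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class DemoDeadline(Base):
    __tablename__ = "demo_deadline"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    callback_state: Mapped[str | None] = mapped_column(String(20), nullable=True)


def build_filter(conditions: dict[Mapped, Any]):
    # Each key is a mapped attribute (an InstrumentedAttribute at runtime),
    # so `attr == value` produces a SQL expression; this is why the new
    # annotation is dict[Mapped, Any] rather than dict[Column, Any].
    stmt = select(DemoDeadline)
    for attr, value in conditions.items():
        stmt = stmt.where(attr == value)
    return stmt


stmt = build_filter({DemoDeadline.callback_state: "queued"})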
from __future__ import annotations -from sqlalchemy import Column, Integer, String, Text +from sqlalchemy import Integer, String, Text +from sqlalchemy.orm import Mapped from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.models.base import Base, StringID -from airflow.utils.sqlalchemy import UtcDateTime +from airflow.utils.sqlalchemy import UtcDateTime, mapped_column class ParseImportError(Base): """Stores all Import Errors which are recorded when parsing DAGs and displayed on the Webserver.""" __tablename__ = "import_error" - id = Column(Integer, primary_key=True) - timestamp = Column(UtcDateTime) - filename = Column(String(1024)) - bundle_name = Column(StringID()) - stacktrace = Column(Text) + id: Mapped[int] = mapped_column(Integer, primary_key=True) + timestamp: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True) + filename: Mapped[str | None] = mapped_column(String(1024), nullable=True) + bundle_name: Mapped[str | None] = mapped_column(StringID(), nullable=True) + stacktrace: Mapped[str | None] = mapped_column(Text, nullable=True) def full_file_path(self) -> str: """Return the full file path of the dag.""" diff --git a/airflow-core/src/airflow/models/hitl.py b/airflow-core/src/airflow/models/hitl.py index c1bf9b8fb5c96..fbfee77fbcf27 100644 --- a/airflow-core/src/airflow/models/hitl.py +++ b/airflow-core/src/airflow/models/hitl.py @@ -19,17 +19,17 @@ from typing import TYPE_CHECKING, Any, TypedDict import sqlalchemy_jsonfield -from sqlalchemy import Boolean, Column, ForeignKeyConstraint, String, Text, func, literal +from sqlalchemy import Boolean, ForeignKeyConstraint, String, Text, func, literal from sqlalchemy.dialects import postgresql from sqlalchemy.ext.compiler import compiles from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import relationship +from sqlalchemy.orm import Mapped, relationship from sqlalchemy.sql.functions import FunctionElement from airflow._shared.timezones import timezone from airflow.models.base import Base from airflow.settings import json -from airflow.utils.sqlalchemy import UtcDateTime +from airflow.utils.sqlalchemy import UtcDateTime, mapped_column if TYPE_CHECKING: from sqlalchemy.sql import ColumnElement @@ -137,31 +137,37 @@ class HITLDetail(Base, HITLDetailPropertyMixin): """Human-in-the-loop request and corresponding response.""" __tablename__ = "hitl_detail" - ti_id = Column( + ti_id: Mapped[str] = mapped_column( String(36).with_variant(postgresql.UUID(as_uuid=False), "postgresql"), primary_key=True, nullable=False, ) # User Request Detail - options = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False) - subject = Column(Text, nullable=False) - body = Column(Text, nullable=True) - defaults = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True) - multiple = Column(Boolean, unique=False, default=False) - params = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={}) - assignees = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True) - created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False) + options: Mapped[dict] = mapped_column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False) + subject: Mapped[str] = mapped_column(Text, nullable=False) + body: Mapped[str | None] = mapped_column(Text, nullable=True) + defaults: Mapped[dict | None] = mapped_column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True) + multiple: Mapped[bool | None] = mapped_column(Boolean, unique=False, default=False, 
diff --git a/airflow-core/src/airflow/models/hitl.py b/airflow-core/src/airflow/models/hitl.py
index c1bf9b8fb5c96..fbfee77fbcf27 100644
--- a/airflow-core/src/airflow/models/hitl.py
+++ b/airflow-core/src/airflow/models/hitl.py
@@ -19,17 +19,17 @@
 from typing import TYPE_CHECKING, Any, TypedDict
 
 import sqlalchemy_jsonfield
-from sqlalchemy import Boolean, Column, ForeignKeyConstraint, String, Text, func, literal
+from sqlalchemy import Boolean, ForeignKeyConstraint, String, Text, func, literal
 from sqlalchemy.dialects import postgresql
 from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.ext.hybrid import hybrid_property
-from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Mapped, relationship
 from sqlalchemy.sql.functions import FunctionElement
 
 from airflow._shared.timezones import timezone
 from airflow.models.base import Base
 from airflow.settings import json
-from airflow.utils.sqlalchemy import UtcDateTime
+from airflow.utils.sqlalchemy import UtcDateTime, mapped_column
 
 if TYPE_CHECKING:
     from sqlalchemy.sql import ColumnElement
@@ -137,31 +137,37 @@ class HITLDetail(Base, HITLDetailPropertyMixin):
     """Human-in-the-loop request and corresponding response."""
 
     __tablename__ = "hitl_detail"
-    ti_id = Column(
+    ti_id: Mapped[str] = mapped_column(
         String(36).with_variant(postgresql.UUID(as_uuid=False), "postgresql"),
         primary_key=True,
         nullable=False,
     )
 
     # User Request Detail
-    options = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False)
-    subject = Column(Text, nullable=False)
-    body = Column(Text, nullable=True)
-    defaults = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
-    multiple = Column(Boolean, unique=False, default=False)
-    params = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})
-    assignees = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
-    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    options: Mapped[dict] = mapped_column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False)
+    subject: Mapped[str] = mapped_column(Text, nullable=False)
+    body: Mapped[str | None] = mapped_column(Text, nullable=True)
+    defaults: Mapped[dict | None] = mapped_column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
+    multiple: Mapped[bool | None] = mapped_column(Boolean, unique=False, default=False, nullable=True)
+    params: Mapped[dict] = mapped_column(
+        sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={}
+    )
+    assignees: Mapped[dict | None] = mapped_column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
+    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)
 
     # Response Content Detail
-    responded_at = Column(UtcDateTime, nullable=True)
-    responded_by = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
-    chosen_options = Column(
+    responded_at: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)
+    responded_by: Mapped[dict | None] = mapped_column(
+        sqlalchemy_jsonfield.JSONField(json=json), nullable=True
+    )
+    chosen_options: Mapped[dict | None] = mapped_column(
         sqlalchemy_jsonfield.JSONField(json=json),
         nullable=True,
         default=None,
     )
-    params_input = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})
+    params_input: Mapped[dict] = mapped_column(
+        sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={}
+    )
 
     task_instance = relationship(
         "TaskInstance",
         lazy="joined",
diff --git a/airflow-core/src/airflow/models/log.py b/airflow-core/src/airflow/models/log.py
index 645c3ee6c7783..f0849332bd265 100644
--- a/airflow-core/src/airflow/models/log.py
+++ b/airflow-core/src/airflow/models/log.py
@@ -19,12 +19,12 @@
 
 from typing import TYPE_CHECKING
 
-from sqlalchemy import Column, Index, Integer, String, Text
-from sqlalchemy.orm import relationship
+from sqlalchemy import Index, Integer, String, Text
+from sqlalchemy.orm import Mapped, relationship
 
 from airflow._shared.timezones import timezone
 from airflow.models.base import Base, StringID
-from airflow.utils.sqlalchemy import UtcDateTime
+from airflow.utils.sqlalchemy import UtcDateTime, mapped_column
 
 if TYPE_CHECKING:
     from airflow.models.taskinstance import TaskInstance
@@ -36,18 +36,18 @@ class Log(Base):
 
     __tablename__ = "log"
 
-    id = Column(Integer, primary_key=True)
-    dttm = Column(UtcDateTime)
-    dag_id = Column(StringID())
-    task_id = Column(StringID())
-    map_index = Column(Integer)
-    event = Column(String(60))
-    logical_date = Column(UtcDateTime)
-    run_id = Column(StringID())
-    owner = Column(String(500))
-    owner_display_name = Column(String(500))
-    extra = Column(Text)
-    try_number = Column(Integer)
+    id: Mapped[int] = mapped_column(Integer, primary_key=True)
+    dttm: Mapped[UtcDateTime] = mapped_column(UtcDateTime)
+    dag_id: Mapped[str | None] = mapped_column(StringID(), nullable=True)
+    task_id: Mapped[str | None] = mapped_column(StringID(), nullable=True)
+    map_index: Mapped[int | None] = mapped_column(Integer, nullable=True)
+    event: Mapped[str] = mapped_column(String(60))
+    logical_date: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)
+    run_id: Mapped[str | None] = mapped_column(StringID(), nullable=True)
+    owner: Mapped[str | None] = mapped_column(String(500), nullable=True)
+    owner_display_name: Mapped[str | None] = mapped_column(String(500), nullable=True)
+    extra: Mapped[str | None] = mapped_column(Text, nullable=True)
+    try_number: Mapped[int | None] = mapped_column(Integer, nullable=True)
 
     dag_model = relationship(
         "DagModel",
diff --git a/airflow-core/src/airflow/models/pool.py b/airflow-core/src/airflow/models/pool.py
index 6f9f159c51857..04fa9d5010ae8 100644
--- a/airflow-core/src/airflow/models/pool.py
+++ b/airflow-core/src/airflow/models/pool.py
@@ -19,7 +19,8 @@
 
 from typing import TYPE_CHECKING, Any, TypedDict
 
-from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Text, func, select
+from sqlalchemy import Boolean, ForeignKey, Integer, String, Text, func, select
+from sqlalchemy.orm import Mapped
 from sqlalchemy_utils import UUIDType
 
 from airflow.exceptions import AirflowException, PoolNotFound
@@ -28,7 +29,7 @@
 from airflow.ti_deps.dependencies_states import EXECUTION_STATES
 from airflow.utils.db import exists_query
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import with_row_locks
+from airflow.utils.sqlalchemy import mapped_column, with_row_locks
 from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
@@ -51,13 +52,13 @@ class Pool(Base):
 
     __tablename__ = "slot_pool"
 
-    id = Column(Integer, primary_key=True)
-    pool = Column(String(256), unique=True)
+    id: Mapped[int] = mapped_column(Integer, primary_key=True)
+    pool: Mapped[str] = mapped_column(String(256), unique=True)
     # -1 for infinite
-    slots = Column(Integer, default=0)
-    description = Column(Text)
-    include_deferred = Column(Boolean, nullable=False)
-    team_id = Column(UUIDType(binary=False), ForeignKey("team.id"), nullable=True)
+    slots: Mapped[int] = mapped_column(Integer, default=0)
+    description: Mapped[str | None] = mapped_column(Text, nullable=True)
+    include_deferred: Mapped[bool] = mapped_column(Boolean, nullable=False)
+    team_id: Mapped[str | None] = mapped_column(UUIDType(binary=False), ForeignKey("team.id"), nullable=True)
 
     DEFAULT_POOL_NAME = "default_pool"
diff --git a/airflow-core/src/airflow/models/renderedtifields.py b/airflow-core/src/airflow/models/renderedtifields.py
index b19b7f1f6edf9..1f55d946cb0d2 100644
--- a/airflow-core/src/airflow/models/renderedtifields.py
+++ b/airflow-core/src/airflow/models/renderedtifields.py
@@ -24,7 +24,6 @@
 
 import sqlalchemy_jsonfield
 from sqlalchemy import (
-    Column,
     ForeignKeyConstraint,
     Integer,
     PrimaryKeyConstraint,
@@ -34,7 +33,7 @@
     text,
 )
 from sqlalchemy.ext.associationproxy import association_proxy
-from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Mapped, relationship
 
 from airflow.configuration import conf
 from airflow.models.base import StringID, TaskInstanceDependencies
@@ -42,6 +41,7 @@
 from airflow.settings import json
 from airflow.utils.retries import retry_db_transaction
 from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.sqlalchemy import mapped_column
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
@@ -69,12 +69,14 @@ class RenderedTaskInstanceFields(TaskInstanceDependencies):
 
     __tablename__ = "rendered_task_instance_fields"
 
-    dag_id = Column(StringID(), primary_key=True)
-    task_id = Column(StringID(), primary_key=True)
-    run_id = Column(StringID(), primary_key=True)
-    map_index = Column(Integer, primary_key=True, server_default=text("-1"))
-    rendered_fields = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False)
-    k8s_pod_yaml = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
+    dag_id: Mapped[str] = mapped_column(StringID(), primary_key=True)
+    task_id: Mapped[str] = mapped_column(StringID(), primary_key=True)
+    run_id: Mapped[str] = mapped_column(StringID(), primary_key=True)
+    map_index: Mapped[int] = mapped_column(Integer, primary_key=True, server_default=text("-1"))
+    rendered_fields: Mapped[dict] = mapped_column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False)
+    k8s_pod_yaml: Mapped[dict | None] = mapped_column(
+        sqlalchemy_jsonfield.JSONField(json=json), nullable=True
+    )
 
     __table_args__ = (
         PrimaryKeyConstraint(
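JSON-backed fields follow the same convention: the column type stays `sqlalchemy_jsonfield.JSONField`, while the annotation becomes `Mapped[dict]` (or `Mapped[dict | None]` when nullable). Because the SQL type is always passed explicitly, SQLAlchemy 2.x never has to resolve `dict` through its type-annotation map; the annotation exists purely for type checkers. A sketch with a hypothetical model, not taken from the patch:

from __future__ import annotations

import json

import sqlalchemy_jsonfield
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Payload(Base):
    __tablename__ = "payload"

    id: Mapped[int] = mapped_column(primary_key=True)

    # The explicit JSONField type determines the schema; Mapped[dict]
    # only tells type checkers what .data holds at runtime.
    data: Mapped[dict] = mapped_column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False)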
diff --git a/airflow-core/src/airflow/models/serialized_dag.py b/airflow-core/src/airflow/models/serialized_dag.py
index 32a44baaffccf..c0bf5be286009 100644
--- a/airflow-core/src/airflow/models/serialized_dag.py
+++ b/airflow-core/src/airflow/models/serialized_dag.py
@@ -27,9 +27,9 @@
 
 import sqlalchemy_jsonfield
 import uuid6
-from sqlalchemy import Column, ForeignKey, LargeBinary, String, select, tuple_
+from sqlalchemy import ForeignKey, LargeBinary, String, select, tuple_
 from sqlalchemy.dialects.postgresql import JSONB
-from sqlalchemy.orm import backref, foreign, relationship
+from sqlalchemy.orm import Mapped, backref, foreign, relationship
 from sqlalchemy.sql.expression import func, literal
 from sqlalchemy_utils import UUIDType
 
@@ -49,7 +49,7 @@
 from airflow.settings import COMPRESS_SERIALIZED_DAGS, json
 from airflow.utils.hashlib_wrapper import md5
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import UtcDateTime
+from airflow.utils.sqlalchemy import UtcDateTime, mapped_column
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
@@ -280,15 +280,17 @@ class SerializedDagModel(Base):
     """
 
     __tablename__ = "serialized_dag"
-    id = Column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7)
-    dag_id = Column(String(ID_LEN), nullable=False)
-    _data = Column(
+    id: Mapped[str] = mapped_column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7)
+    dag_id: Mapped[str] = mapped_column(String(ID_LEN), nullable=False)
+    _data: Mapped[dict | None] = mapped_column(
         "data", sqlalchemy_jsonfield.JSONField(json=json).with_variant(JSONB, "postgresql"), nullable=True
     )
-    _data_compressed = Column("data_compressed", LargeBinary, nullable=True)
-    created_at = Column(UtcDateTime, nullable=False, default=timezone.utcnow)
-    last_updated = Column(UtcDateTime, nullable=False, default=timezone.utcnow, onupdate=timezone.utcnow)
-    dag_hash = Column(String(32), nullable=False)
+    _data_compressed: Mapped[bytes | None] = mapped_column("data_compressed", LargeBinary, nullable=True)
+    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, nullable=False, default=timezone.utcnow)
+    last_updated: Mapped[UtcDateTime] = mapped_column(
+        UtcDateTime, nullable=False, default=timezone.utcnow, onupdate=timezone.utcnow
+    )
+    dag_hash: Mapped[str] = mapped_column(String(32), nullable=False)
 
     dag_runs = relationship(
         DagRun,
@@ -304,7 +306,7 @@ class SerializedDagModel(Base):
         innerjoin=True,
         backref=backref("serialized_dag", uselist=False, innerjoin=True),
     )
-    dag_version_id = Column(
+    dag_version_id: Mapped[str] = mapped_column(
         UUIDType(binary=False),
         ForeignKey("dag_version.id", ondelete="CASCADE"),
         nullable=False,
diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py
index e9806696e243b..f0d6294b0a18c 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -25,7 +25,7 @@
 import uuid
 from collections import defaultdict
 from collections.abc import Collection, Iterable
-from datetime import timedelta
+from datetime import datetime, timedelta
 from functools import cache
 from typing import TYPE_CHECKING, Any
 from urllib.parse import quote
@@ -35,7 +35,6 @@
 import lazy_object_proxy
 import uuid6
 from sqlalchemy import (
-    Column,
     Float,
     ForeignKey,
     ForeignKeyConstraint,
@@ -62,7 +61,7 @@
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.ext.mutable import MutableDict
-from sqlalchemy.orm import lazyload, reconstructor, relationship
+from sqlalchemy.orm import Mapped, lazyload, reconstructor, relationship
 from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
 from sqlalchemy_utils import UUIDType
 
@@ -93,7 +92,7 @@
 from airflow.utils.retries import run_with_db_retries
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
 from airflow.utils.span_status import SpanStatus
-from airflow.utils.sqlalchemy import ExecutorConfigType, ExtendedJSON, UtcDateTime
+from airflow.utils.sqlalchemy import ExecutorConfigType, ExtendedJSON, UtcDateTime, mapped_column
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 
 TR = TaskReschedule
@@ -374,61 +373,68 @@ class TaskInstance(Base, LoggingMixin):
     """
 
     __tablename__ = "task_instance"
-    id = Column(
+    id: Mapped[str] = mapped_column(
         String(36).with_variant(postgresql.UUID(as_uuid=False), "postgresql"),
         primary_key=True,
         default=uuid7,
         nullable=False,
     )
-    task_id = Column(StringID(), nullable=False)
-    dag_id = Column(StringID(), nullable=False)
-    run_id = Column(StringID(), nullable=False)
-    map_index = Column(Integer, nullable=False, server_default=text("-1"))
-
-    start_date = Column(UtcDateTime)
-    end_date = Column(UtcDateTime)
-    duration = Column(Float)
-    state = Column(String(20))
-    try_number = Column(Integer, default=0)
-    max_tries = Column(Integer, server_default=text("-1"))
-    hostname = Column(String(1000))
-    unixname = Column(String(1000))
-    pool = Column(String(256), nullable=False)
-    pool_slots = Column(Integer, default=1, nullable=False)
-    queue = Column(String(256))
-    priority_weight = Column(Integer)
-    operator = Column(String(1000))
-    custom_operator_name = Column(String(1000))
-    queued_dttm = Column(UtcDateTime)
-    scheduled_dttm = Column(UtcDateTime)
-    queued_by_job_id = Column(Integer)
-
-    last_heartbeat_at = Column(UtcDateTime)
-    pid = Column(Integer)
-    executor = Column(String(1000))
-    executor_config = Column(ExecutorConfigType(pickler=dill))
-    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow)
-    _rendered_map_index = Column("rendered_map_index", String(250))
-    context_carrier = Column(MutableDict.as_mutable(ExtendedJSON))
-    span_status = Column(String(250), server_default=SpanStatus.NOT_STARTED, nullable=False)
-
-    external_executor_id = Column(StringID())
+    task_id: Mapped[str] = mapped_column(StringID(), nullable=False)
+    dag_id: Mapped[str] = mapped_column(StringID(), nullable=False)
+    run_id: Mapped[str] = mapped_column(StringID(), nullable=False)
+    map_index: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("-1"))
+
+    start_date: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)
+    end_date: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)
+    duration: Mapped[float | None] = mapped_column(Float, nullable=True)
+    state: Mapped[str | None] = mapped_column(String(20), nullable=True)
+    try_number: Mapped[int] = mapped_column(Integer, default=0)
+    max_tries: Mapped[int] = mapped_column(Integer, server_default=text("-1"))
+    hostname: Mapped[str] = mapped_column(String(1000))
+    unixname: Mapped[str] = mapped_column(String(1000))
+    pool: Mapped[str] = mapped_column(String(256), nullable=False)
+    pool_slots: Mapped[int] = mapped_column(Integer, default=1, nullable=False)
+    queue: Mapped[str] = mapped_column(String(256))
+    priority_weight: Mapped[int] = mapped_column(Integer)
+    operator: Mapped[str | None] = mapped_column(String(1000), nullable=True)
+    custom_operator_name: Mapped[str] = mapped_column(String(1000))
+    queued_dttm: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)
+    scheduled_dttm: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)
+    queued_by_job_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
+
+    last_heartbeat_at: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)
+    pid: Mapped[int | None] = mapped_column(Integer, nullable=True)
+    executor: Mapped[str | None] = mapped_column(String(1000), nullable=True)
+    executor_config: Mapped[dict] = mapped_column(ExecutorConfigType(pickler=dill))
+    updated_at: Mapped[UtcDateTime | None] = mapped_column(
+        UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=True
+    )
+    _rendered_map_index: Mapped[str | None] = mapped_column("rendered_map_index", String(250), nullable=True)
+    context_carrier: Mapped[dict | None] = mapped_column(MutableDict.as_mutable(ExtendedJSON), nullable=True)
+    span_status: Mapped[str] = mapped_column(
+        String(250), server_default=SpanStatus.NOT_STARTED, nullable=False
+    )
+
+    external_executor_id: Mapped[str | None] = mapped_column(StringID(), nullable=True)
 
     # The trigger to resume on if we are in state DEFERRED
-    trigger_id = Column(Integer)
+    trigger_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
 
     # Optional timeout utcdatetime for the trigger (past this, we'll fail)
-    trigger_timeout = Column(UtcDateTime)
+    trigger_timeout: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)
 
     # The method to call next, and any extra arguments to pass to it.
     # Usually used when resuming from DEFERRED.
-    next_method = Column(String(1000))
-    next_kwargs = Column(MutableDict.as_mutable(ExtendedJSON))
+    next_method: Mapped[str | None] = mapped_column(String(1000), nullable=True)
+    next_kwargs: Mapped[dict | None] = mapped_column(MutableDict.as_mutable(ExtendedJSON), nullable=True)
 
-    _task_display_property_value = Column("task_display_name", String(2000), nullable=True)
-    dag_version_id = Column(
+    _task_display_property_value: Mapped[str | None] = mapped_column(
+        "task_display_name", String(2000), nullable=True
+    )
+    dag_version_id: Mapped[str | None] = mapped_column(
         UUIDType(binary=False),
         ForeignKey("dag_version.id", ondelete="RESTRICT"),
+        nullable=True,
     )
     dag_version = relationship("DagVersion", back_populates="task_instances")
 
@@ -2202,15 +2208,17 @@ class TaskInstanceNote(Base):
     """For storage of arbitrary notes concerning the task instance."""
 
     __tablename__ = "task_instance_note"
-    ti_id = Column(
+    ti_id: Mapped[str] = mapped_column(
         String(36).with_variant(postgresql.UUID(as_uuid=False), "postgresql"),
         primary_key=True,
         nullable=False,
     )
-    user_id = Column(String(128), nullable=True)
-    content = Column(String(1000).with_variant(Text(1000), "mysql"))
-    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
-    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
+    user_id: Mapped[str | None] = mapped_column(String(128), nullable=True)
+    content: Mapped[str | None] = mapped_column(String(1000).with_variant(Text(1000), "mysql"))
+    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    updated_at: Mapped[UtcDateTime] = mapped_column(
+        UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False
+    )
 
     task_instance = relationship("TaskInstance", back_populates="task_instance_note", uselist=False)
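Annotating a model as large as `TaskInstance` pays off mainly in tooling: with `Mapped[...]`, type checkers can infer attribute types on instances and inside query expressions without the old stubs plugins. A rough illustration against a stand-in model (the names here are hypothetical, and a SQLAlchemy 2.x environment is assumed):

from __future__ import annotations

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class MiniTask(Base):
    __tablename__ = "mini_task"

    id: Mapped[int] = mapped_column(primary_key=True)
    state: Mapped[str | None] = mapped_column(String(20), nullable=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(MiniTask(state="queued"))
    session.commit()
    for task in session.scalars(select(MiniTask).where(MiniTask.state == "queued")):
        # Type checkers now infer `task` as MiniTask and `task.state` as
        # `str | None`; with plain Column both were effectively Any.
        print(task.id, task.state)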
diff --git a/airflow-core/src/airflow/models/taskinstancehistory.py b/airflow-core/src/airflow/models/taskinstancehistory.py
index 4ca4caecc4ce3..8b09b001847f0 100644
--- a/airflow-core/src/airflow/models/taskinstancehistory.py
+++ b/airflow-core/src/airflow/models/taskinstancehistory.py
@@ -21,7 +21,6 @@
 
 import dill
 from sqlalchemy import (
-    Column,
     DateTime,
     Float,
     ForeignKeyConstraint,
@@ -35,7 +34,7 @@
 )
 from sqlalchemy.dialects import postgresql
 from sqlalchemy.ext.mutable import MutableDict
-from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Mapped, relationship
 from sqlalchemy_utils import UUIDType
 
 from airflow._shared.timezones import timezone
@@ -48,6 +47,7 @@
     ExecutorConfigType,
     ExtendedJSON,
     UtcDateTime,
+    mapped_column,
 )
 from airflow.utils.state import State, TaskInstanceState
 
@@ -66,48 +66,52 @@ class TaskInstanceHistory(Base):
     """
 
     __tablename__ = "task_instance_history"
-    task_instance_id = Column(
+    task_instance_id: Mapped[str] = mapped_column(
         String(36).with_variant(postgresql.UUID(as_uuid=False), "postgresql"),
         nullable=False,
         primary_key=True,
     )
-    task_id = Column(StringID(), nullable=False)
-    dag_id = Column(StringID(), nullable=False)
-    run_id = Column(StringID(), nullable=False)
-    map_index = Column(Integer, nullable=False, server_default=text("-1"))
-    try_number = Column(Integer, nullable=False)
-    start_date = Column(UtcDateTime)
-    end_date = Column(UtcDateTime)
-    duration = Column(Float)
-    state = Column(String(20))
-    max_tries = Column(Integer, server_default=text("-1"))
-    hostname = Column(String(1000))
-    unixname = Column(String(1000))
-    pool = Column(String(256), nullable=False)
-    pool_slots = Column(Integer, default=1, nullable=False)
-    queue = Column(String(256))
-    priority_weight = Column(Integer)
-    operator = Column(String(1000))
-    custom_operator_name = Column(String(1000))
-    queued_dttm = Column(UtcDateTime)
-    scheduled_dttm = Column(UtcDateTime)
-    queued_by_job_id = Column(Integer)
-    pid = Column(Integer)
-    executor = Column(String(1000))
-    executor_config = Column(ExecutorConfigType(pickler=dill))
-    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow)
-    rendered_map_index = Column(String(250))
-    context_carrier = Column(MutableDict.as_mutable(ExtendedJSON))
-    span_status = Column(String(250), server_default=SpanStatus.NOT_STARTED, nullable=False)
-
-    external_executor_id = Column(StringID())
-    trigger_id = Column(Integer)
-    trigger_timeout = Column(DateTime)
-    next_method = Column(String(1000))
-    next_kwargs = Column(MutableDict.as_mutable(ExtendedJSON))
-
-    task_display_name = Column(String(2000), nullable=True)
-    dag_version_id = Column(UUIDType(binary=False))
+    task_id: Mapped[str] = mapped_column(StringID(), nullable=False)
+    dag_id: Mapped[str] = mapped_column(StringID(), nullable=False)
+    run_id: Mapped[str] = mapped_column(StringID(), nullable=False)
+    map_index: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("-1"))
+    try_number: Mapped[int] = mapped_column(Integer, nullable=False)
+    start_date: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)
+    end_date: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)
+    duration: Mapped[float | None] = mapped_column(Float, nullable=True)
+    state: Mapped[str | None] = mapped_column(String(20), nullable=True)
+    max_tries: Mapped[int | None] = mapped_column(Integer, server_default=text("-1"), nullable=True)
+    hostname: Mapped[str | None] = mapped_column(String(1000), nullable=True)
+    unixname: Mapped[str | None] = mapped_column(String(1000), nullable=True)
+    pool: Mapped[str] = mapped_column(String(256), nullable=False)
+    pool_slots: Mapped[int] = mapped_column(Integer, default=1, nullable=False)
+    queue: Mapped[str | None] = mapped_column(String(256), nullable=True)
+    priority_weight: Mapped[int | None] = mapped_column(Integer, nullable=True)
+    operator: Mapped[str | None] = mapped_column(String(1000), nullable=True)
+    custom_operator_name: Mapped[str | None] = mapped_column(String(1000), nullable=True)
+    queued_dttm: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)
+    scheduled_dttm: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)
+    queued_by_job_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
+    pid: Mapped[int | None] = mapped_column(Integer, nullable=True)
+    executor: Mapped[str | None] = mapped_column(String(1000), nullable=True)
+    executor_config: Mapped[dict | None] = mapped_column(ExecutorConfigType(pickler=dill), nullable=True)
+    updated_at: Mapped[UtcDateTime | None] = mapped_column(
+        UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=True
+    )
+    rendered_map_index: Mapped[str | None] = mapped_column(String(250), nullable=True)
+    context_carrier: Mapped[dict | None] = mapped_column(MutableDict.as_mutable(ExtendedJSON), nullable=True)
+    span_status: Mapped[str] = mapped_column(
+        String(250), server_default=SpanStatus.NOT_STARTED, nullable=False
+    )
+
+    external_executor_id: Mapped[str | None] = mapped_column(StringID(), nullable=True)
+    trigger_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
+    trigger_timeout: Mapped[DateTime | None] = mapped_column(DateTime, nullable=True)
+    next_method: Mapped[str | None] = mapped_column(String(1000), nullable=True)
+    next_kwargs: Mapped[dict | None] = mapped_column(MutableDict.as_mutable(ExtendedJSON), nullable=True)
+
+    task_display_name: Mapped[str | None] = mapped_column(String(2000), nullable=True)
+    dag_version_id: Mapped[str | None] = mapped_column(UUIDType(binary=False), nullable=True)
 
     dag_version = relationship(
         "DagVersion",
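One convention worth noting when reading these hunks: datetime columns are annotated as `Mapped[UtcDateTime]`, naming the `TypeDecorator` rather than the Python value type. This works because every call passes its SQL type explicitly, so SQLAlchemy never resolves the annotation against its type registry, but a type checker will treat the attribute as a `UtcDateTime` instance rather than a `datetime`. For contrast only, and not as what this patch does, the style the SQLAlchemy 2.0 documentation uses would annotate with the value type:

from __future__ import annotations

from datetime import datetime

from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from airflow.utils.sqlalchemy import UtcDateTime


class Base(DeclarativeBase):
    pass


class Event(Base):
    __tablename__ = "event"

    id: Mapped[int] = mapped_column(primary_key=True)

    # Value type in the annotation, custom TypeDecorator as the column type:
    # a type checker then sees .created_at as a datetime.
    created_at: Mapped[datetime] = mapped_column(UtcDateTime, nullable=False)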
diff --git a/airflow-core/src/airflow/models/tasklog.py b/airflow-core/src/airflow/models/tasklog.py
index d9a5c57c30ac5..c28005bd11e0e 100644
--- a/airflow-core/src/airflow/models/tasklog.py
+++ b/airflow-core/src/airflow/models/tasklog.py
@@ -17,11 +17,12 @@
 # under the License.
 from __future__ import annotations
 
-from sqlalchemy import Column, Integer, Text
+from sqlalchemy import Integer, Text
+from sqlalchemy.orm import Mapped
 
 from airflow._shared.timezones import timezone
 from airflow.models.base import Base
-from airflow.utils.sqlalchemy import UtcDateTime
+from airflow.utils.sqlalchemy import UtcDateTime, mapped_column
 
 
 class LogTemplate(Base):
@@ -34,10 +35,10 @@ class LogTemplate(Base):
 
     __tablename__ = "log_template"
 
-    id = Column(Integer, primary_key=True, autoincrement=True)
-    filename = Column(Text, nullable=False)
-    elasticsearch_id = Column(Text, nullable=False)
-    created_at = Column(UtcDateTime, nullable=False, default=timezone.utcnow)
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+    filename: Mapped[str] = mapped_column(Text, nullable=False)
+    elasticsearch_id: Mapped[str] = mapped_column(Text, nullable=False)
+    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, nullable=False, default=timezone.utcnow)
 
     def __repr__(self) -> str:
         attrs = ", ".join(f"{k}={getattr(self, k)}" for k in ("filename", "elasticsearch_id"))
diff --git a/airflow-core/src/airflow/models/taskmap.py b/airflow-core/src/airflow/models/taskmap.py
index edd0b21b114ea..601a64ebbe737 100644
--- a/airflow-core/src/airflow/models/taskmap.py
+++ b/airflow-core/src/airflow/models/taskmap.py
@@ -24,12 +24,13 @@
 from collections.abc import Collection, Iterable, Sequence
 from typing import TYPE_CHECKING, Any
 
-from sqlalchemy import CheckConstraint, Column, ForeignKeyConstraint, Integer, String, func, or_, select
+from sqlalchemy import CheckConstraint, ForeignKeyConstraint, Integer, String, func, or_, select
+from sqlalchemy.orm import Mapped
 
 from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies
 from airflow.models.dag_version import DagVersion
 from airflow.utils.db import exists_query
-from airflow.utils.sqlalchemy import ExtendedJSON, with_row_locks
+from airflow.utils.sqlalchemy import ExtendedJSON, mapped_column, with_row_locks
 from airflow.utils.state import State, TaskInstanceState
 
 if TYPE_CHECKING:
@@ -63,13 +64,13 @@ class TaskMap(TaskInstanceDependencies):
     __tablename__ = "task_map"
 
     # Link to upstream TaskInstance creating this dynamic mapping information.
-    dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
-    task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
-    run_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
-    map_index = Column(Integer, primary_key=True)
+    dag_id: Mapped[str] = mapped_column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
+    task_id: Mapped[str] = mapped_column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
+    run_id: Mapped[str] = mapped_column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
+    map_index: Mapped[int] = mapped_column(Integer, primary_key=True)
 
-    length = Column(Integer, nullable=False)
-    keys = Column(ExtendedJSON, nullable=True)
+    length: Mapped[int] = mapped_column(Integer, nullable=False)
+    keys: Mapped[list | None] = mapped_column(ExtendedJSON, nullable=True)
 
     __table_args__ = (
         CheckConstraint(length >= 0, name="task_map_length_not_negative"),
diff --git a/airflow-core/src/airflow/models/taskreschedule.py b/airflow-core/src/airflow/models/taskreschedule.py
index e07d750ebdef7..2dc6bcb18b257 100644
--- a/airflow-core/src/airflow/models/taskreschedule.py
+++ b/airflow-core/src/airflow/models/taskreschedule.py
@@ -19,11 +19,11 @@
 
 from __future__ import annotations
 
+import datetime
 import uuid
 from typing import TYPE_CHECKING
 
 from sqlalchemy import (
-    Column,
     ForeignKey,
     Integer,
     String,
@@ -32,10 +32,10 @@
     select,
 )
 from sqlalchemy.dialects import postgresql
-from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Mapped, relationship
 
 from airflow.models.base import Base
-from airflow.utils.sqlalchemy import UtcDateTime
+from airflow.utils.sqlalchemy import UtcDateTime, mapped_column
 
 if TYPE_CHECKING:
     import datetime
@@ -49,16 +49,16 @@ class TaskReschedule(Base):
     """TaskReschedule tracks rescheduled task instances."""
 
     __tablename__ = "task_reschedule"
-    id = Column(Integer, primary_key=True)
-    ti_id = Column(
+    id: Mapped[int] = mapped_column(Integer, primary_key=True)
+    ti_id: Mapped[str] = mapped_column(
         String(36).with_variant(postgresql.UUID(as_uuid=False), "postgresql"),
         ForeignKey("task_instance.id", ondelete="CASCADE", name="task_reschedule_ti_fkey"),
         nullable=False,
     )
-    start_date = Column(UtcDateTime, nullable=False)
-    end_date = Column(UtcDateTime, nullable=False)
-    duration = Column(Integer, nullable=False)
-    reschedule_date = Column(UtcDateTime, nullable=False)
+    start_date: Mapped[UtcDateTime] = mapped_column(UtcDateTime, nullable=False)
+    end_date: Mapped[UtcDateTime] = mapped_column(UtcDateTime, nullable=False)
+    duration: Mapped[int] = mapped_column(Integer, nullable=False)
+    reschedule_date: Mapped[UtcDateTime] = mapped_column(UtcDateTime, nullable=False)
 
     task_instance = relationship(
         "TaskInstance", primaryjoin="TaskReschedule.ti_id == foreign(TaskInstance.id)", uselist=False
diff --git a/airflow-core/src/airflow/models/team.py b/airflow-core/src/airflow/models/team.py
index 3ef31b434095a..551b426abcace 100644
--- a/airflow-core/src/airflow/models/team.py
+++ b/airflow-core/src/airflow/models/team.py
@@ -21,11 +21,12 @@
 
 import uuid6
 from sqlalchemy import Column, ForeignKey, Index, String, Table, select
-from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Mapped, relationship
 from sqlalchemy_utils import UUIDType
 
 from airflow.models.base import Base, StringID
 from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.sqlalchemy import mapped_column
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
@@ -54,8 +55,8 @@ class Team(Base):
 
     __tablename__ = "team"
 
-    id = Column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7)
-    name = Column(String(50), unique=True, nullable=False)
+    id: Mapped[str] = mapped_column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7)
+    name: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)
 
     dag_bundles = relationship(
         "DagBundleModel", secondary=dag_bundle_team_association_table, back_populates="teams"
     )
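Note that team.py keeps its plain `Column` import: the `dag_bundle_team_association_table` it defines is a Core `Table`, and `mapped_column()` applies only to ORM declarative attributes. The split looks like this in miniature (table and class names invented for illustration):

from __future__ import annotations

from sqlalchemy import Column, ForeignKey, Table
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


# Core association tables are not mapped classes, so they keep Column.
team_bundle = Table(
    "team_bundle",
    Base.metadata,
    Column("team_id", ForeignKey("team.id"), primary_key=True),
    Column("bundle_id", ForeignKey("bundle.id"), primary_key=True),
)


class TeamStub(Base):
    __tablename__ = "team"

    # ORM declarative attributes use mapped_column.
    id: Mapped[int] = mapped_column(primary_key=True)


class BundleStub(Base):
    __tablename__ = "bundle"

    id: Mapped[int] = mapped_column(primary_key=True)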
diff --git a/airflow-core/src/airflow/models/trigger.py b/airflow-core/src/airflow/models/trigger.py
index 8de31b312d985..94da2acc6aa84 100644
--- a/airflow-core/src/airflow/models/trigger.py
+++ b/airflow-core/src/airflow/models/trigger.py
@@ -24,9 +24,9 @@
 from traceback import format_exception
 from typing import TYPE_CHECKING, Any
 
-from sqlalchemy import Column, Integer, String, Text, delete, func, or_, select, update
+from sqlalchemy import Integer, String, Text, delete, func, or_, select, update
 from sqlalchemy.ext.associationproxy import association_proxy
-from sqlalchemy.orm import Session, relationship, selectinload
+from sqlalchemy.orm import Mapped, Session, relationship, selectinload
 from sqlalchemy.sql.functions import coalesce
 
 from airflow._shared.timezones import timezone
@@ -37,7 +37,7 @@
 from airflow.triggers.base import BaseTaskEndEvent
 from airflow.utils.retries import run_with_db_retries
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks
+from airflow.utils.sqlalchemy import UtcDateTime, mapped_column, with_row_locks
 from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
@@ -90,11 +90,11 @@ class Trigger(Base):
 
     __tablename__ = "trigger"
 
-    id = Column(Integer, primary_key=True)
-    classpath = Column(String(1000), nullable=False)
-    encrypted_kwargs = Column("kwargs", Text, nullable=False)
-    created_date = Column(UtcDateTime, nullable=False)
-    triggerer_id = Column(Integer, nullable=True)
+    id: Mapped[int] = mapped_column(Integer, primary_key=True)
+    classpath: Mapped[str] = mapped_column(String(1000), nullable=False)
+    encrypted_kwargs: Mapped[str] = mapped_column("kwargs", Text, nullable=False)
+    created_date: Mapped[UtcDateTime] = mapped_column(UtcDateTime, nullable=False)
+    triggerer_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
 
     triggerer_job = relationship(
         "Job",
diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py
index 2c50da9add721..a13eb4fe158dc 100644
--- a/airflow-core/src/airflow/models/variable.py
+++ b/airflow-core/src/airflow/models/variable.py
@@ -24,9 +24,9 @@
 import warnings
 from typing import TYPE_CHECKING, Any
 
-from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Text, delete, select
+from sqlalchemy import Boolean, ForeignKey, Integer, String, Text, delete, select
 from sqlalchemy.dialects.mysql import MEDIUMTEXT
-from sqlalchemy.orm import declared_attr, reconstructor, synonym
+from sqlalchemy.orm import Mapped, declared_attr, reconstructor, synonym
 from sqlalchemy_utils import UUIDType
 
 from airflow._shared.secrets_masker import mask_secret
@@ -38,6 +38,7 @@
 from airflow.secrets.metastore import MetastoreBackend
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
+from airflow.utils.sqlalchemy import mapped_column
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
@@ -51,12 +52,12 @@ class Variable(Base, LoggingMixin):
 
     __tablename__ = "variable"
     __NO_DEFAULT_SENTINEL = object()
 
-    id = Column(Integer, primary_key=True)
-    key = Column(String(ID_LEN), unique=True)
-    _val = Column("val", Text().with_variant(MEDIUMTEXT, "mysql"))
-    description = Column(Text)
-    is_encrypted = Column(Boolean, unique=False, default=False)
-    team_id = Column(UUIDType(binary=False), ForeignKey("team.id"), nullable=True)
+    id: Mapped[int] = mapped_column(Integer, primary_key=True)
+    key: Mapped[str] = mapped_column(String(ID_LEN), unique=True)
+    _val: Mapped[str] = mapped_column("val", Text().with_variant(MEDIUMTEXT, "mysql"))
+    description: Mapped[str | None] = mapped_column(Text, nullable=True)
+    is_encrypted: Mapped[bool] = mapped_column(Boolean, unique=False, default=False)
+    team_id: Mapped[str | None] = mapped_column(UUIDType(binary=False), ForeignKey("team.id"), nullable=True)
 
     def __init__(self, key=None, val=None, description=None, team_id=None):
         super().__init__()
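`mapped_column()` also keeps `Column`'s ability to store an attribute under a different database column name via a leading string argument, which is why `Trigger.encrypted_kwargs` (stored as `kwargs`) and `Variable._val` (stored as `val`) migrate without any rename. In miniature, with an invented model:

from __future__ import annotations

from sqlalchemy import Text
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Secret(Base):
    __tablename__ = "secret"

    id: Mapped[int] = mapped_column(primary_key=True)

    # The attribute is `_val` in Python but stored as the `val` column,
    # exactly as Column("val", ...) behaved before.
    _val: Mapped[str] = mapped_column("val", Text)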
diff --git a/airflow-core/src/airflow/models/xcom.py b/airflow-core/src/airflow/models/xcom.py
index 709d6bf69030c..5c300f1f62029 100644
--- a/airflow-core/src/airflow/models/xcom.py
+++ b/airflow-core/src/airflow/models/xcom.py
@@ -24,7 +24,6 @@
 
 from sqlalchemy import (
     JSON,
-    Column,
     ForeignKeyConstraint,
     Index,
     Integer,
@@ -37,7 +36,7 @@
 )
 from sqlalchemy.dialects import postgresql
 from sqlalchemy.ext.associationproxy import association_proxy
-from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Mapped, relationship
 
 from airflow._shared.timezones import timezone
 from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies
@@ -45,7 +44,7 @@
 from airflow.utils.helpers import is_container
 from airflow.utils.json import XComDecoder, XComEncoder
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import UtcDateTime
+from airflow.utils.sqlalchemy import UtcDateTime, mapped_column
 
 log = logging.getLogger(__name__)
 
@@ -63,17 +62,19 @@ class XComModel(TaskInstanceDependencies):
 
     __tablename__ = "xcom"
 
-    dag_run_id = Column(Integer(), nullable=False, primary_key=True)
-    task_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False, primary_key=True)
-    map_index = Column(Integer, primary_key=True, nullable=False, server_default=text("-1"))
-    key = Column(String(512, **COLLATION_ARGS), nullable=False, primary_key=True)
+    dag_run_id: Mapped[int] = mapped_column(Integer(), nullable=False, primary_key=True)
+    task_id: Mapped[str] = mapped_column(String(ID_LEN, **COLLATION_ARGS), nullable=False, primary_key=True)
+    map_index: Mapped[int] = mapped_column(
+        Integer, primary_key=True, nullable=False, server_default=text("-1")
+    )
+    key: Mapped[str] = mapped_column(String(512, **COLLATION_ARGS), nullable=False, primary_key=True)
 
     # Denormalized for easier lookup.
-    dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
-    run_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
+    dag_id: Mapped[str] = mapped_column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
+    run_id: Mapped[str] = mapped_column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
 
-    value = Column(JSON().with_variant(postgresql.JSONB, "postgresql"))
-    timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    value: Mapped[Any] = mapped_column(JSON().with_variant(postgresql.JSONB, "postgresql"), nullable=True)
+    timestamp: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)
 
     __table_args__ = (
         # Ideally we should create a unique index over (key, dag_id, task_id, run_id),
diff --git a/airflow-core/src/airflow/utils/sqlalchemy.py b/airflow-core/src/airflow/utils/sqlalchemy.py
index 072f767335c6f..98bb68843e812 100644
--- a/airflow-core/src/airflow/utils/sqlalchemy.py
+++ b/airflow-core/src/airflow/utils/sqlalchemy.py
@@ -48,6 +48,15 @@
 
 log = logging.getLogger(__name__)
 
+try:
+    from sqlalchemy.orm import mapped_column
+except ImportError:
+    # fallback for SQLAlchemy < 2.0
+    def mapped_column(*args, **kwargs):
+        from sqlalchemy import Column
+
+        return Column(*args, **kwargs)
+
 
 class UtcDateTime(TypeDecorator):
     """
diff --git a/pyproject.toml b/pyproject.toml
index 1e4790a506d89..5d2eda194eb57 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -871,7 +871,7 @@ banned-module-level-imports = ["numpy", "pandas", "polars"]
 "providers".msg = "You should not import 'providers' as a Python module. Imports in providers should be done starting from 'src' or `tests' folders, for example 'from airflow.providers.airbyte' or 'from unit.airbyte' or 'from system.airbyte'"
 
 [tool.ruff.lint.flake8-type-checking]
-exempt-modules = ["typing", "typing_extensions"]
+exempt-modules = ["typing", "typing_extensions", "sqlalchemy.orm.Mapped"]
 
 [tool.ruff.lint.flake8-pytest-style]
 mark-parentheses = false
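The try/except shim added to airflow/utils/sqlalchemy.py is what lets every model module import `mapped_column` from `airflow.utils.sqlalchemy` regardless of the installed SQLAlchemy: on 2.x it re-exports `sqlalchemy.orm.mapped_column`, and on older versions each call falls back to building a plain `Column`, so the annotated models keep working unchanged. The `exempt-modules` addition in `pyproject.toml` then stops ruff's flake8-type-checking rules from pushing the `Mapped` import into an `if TYPE_CHECKING:` block, where SQLAlchemy could no longer resolve the stringified annotations at mapper-configuration time. A quick sanity check of the fallback path, mirroring the shim's shape (illustrative only, runnable on any SQLAlchemy version):

from __future__ import annotations

from sqlalchemy import Column, Integer


# Same shape as the fallback in airflow.utils.sqlalchemy: pass everything
# mapped_column() would receive straight through to Column.
def mapped_column(*args, **kwargs):
    return Column(*args, **kwargs)


col = mapped_column(Integer, primary_key=True, nullable=False)
print(isinstance(col, Column), col.primary_key)  # True True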