Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify lazy db sequence implementations #39426

Merged
merged 3 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions airflow/lineage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,10 @@ def wrapper(self, context, *args, **kwargs):
# Remove auto and task_ids
self.inlets = [i for i in self.inlets if not isinstance(i, str)]

# We manually create a session here since xcom_pull returns a LazyXComAccess iterator.
# If we do not pass a session a new session will be created, however that session will not be
# properly closed and will remain open. After we are done iterating we can safely close this
# session.
# We manually create a session here since xcom_pull returns a
# LazySelectSequence proxy. If we do not pass a session, a new one
# will be created, but that session will not be properly closed.
# After we are done iterating, we can safely close this session.
with create_session() as session:
_inlets = self.xcom_pull(
context, task_ids=task_ids, dag_id=self.dag_id, key=PIPELINE_OUTLETS, session=session
Expand Down
7 changes: 5 additions & 2 deletions airflow/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations

from typing import Any
from typing import TYPE_CHECKING, Any

from sqlalchemy import Column, Integer, MetaData, String, text
from sqlalchemy.orm import registry
Expand Down Expand Up @@ -48,7 +48,10 @@ def _get_schema():
mapper_registry = registry(metadata=metadata)
_sentinel = object()

Base: Any = mapper_registry.generate_base()
if TYPE_CHECKING:
Base = Any
else:
Base = mapper_registry.generate_base()

ID_LEN = 250

Expand Down
39 changes: 21 additions & 18 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.models.taskmap import TaskMap
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.xcom import LazyXComAccess, XCom
from airflow.models.xcom import LazyXComSelectSequence, XCom
from airflow.plugins_manager import integrate_macros_plugins
from airflow.sentry import Sentry
from airflow.settings import task_instance_mutation_hook
Expand Down Expand Up @@ -3358,34 +3358,37 @@ def xcom_pull(
return default
if map_indexes is not None or first.map_index < 0:
return XCom.deserialize_value(first)
query = query.order_by(None).order_by(XCom.map_index.asc())
return LazyXComAccess.build_from_xcom_query(query)
return LazyXComSelectSequence.from_select(
query.with_entities(XCom.value).order_by(None).statement,
order_by=[XCom.map_index],
session=session,
)

# At this point either task_ids or map_indexes is explicitly multi-value.
# Order return values to match task_ids and map_indexes ordering.
query = query.order_by(None)
ordering = []
if task_ids is None or isinstance(task_ids, str):
query = query.order_by(XCom.task_id)
ordering.append(XCom.task_id)
elif task_id_whens := {tid: i for i, tid in enumerate(task_ids)}:
ordering.append(case(task_id_whens, value=XCom.task_id))
else:
task_id_whens = {tid: i for i, tid in enumerate(task_ids)}
if task_id_whens:
query = query.order_by(case(task_id_whens, value=XCom.task_id))
else:
query = query.order_by(XCom.task_id)
ordering.append(XCom.task_id)
if map_indexes is None or isinstance(map_indexes, int):
query = query.order_by(XCom.map_index)
ordering.append(XCom.map_index)
elif isinstance(map_indexes, range):
order = XCom.map_index
if map_indexes.step < 0:
order = order.desc()
query = query.order_by(order)
ordering.append(order)
elif map_index_whens := {map_index: i for i, map_index in enumerate(map_indexes)}:
ordering.append(case(map_index_whens, value=XCom.map_index))
else:
map_index_whens = {map_index: i for i, map_index in enumerate(map_indexes)}
if map_index_whens:
query = query.order_by(case(map_index_whens, value=XCom.map_index))
else:
query = query.order_by(XCom.map_index)
return LazyXComAccess.build_from_xcom_query(query)
ordering.append(XCom.map_index)
return LazyXComSelectSequence.from_select(
query.with_entities(XCom.value).order_by(None).statement,
order_by=ordering,
session=session,
)

@provide_session
def get_num_running_task_instances(self, session: Session, same_dagrun: bool = False) -> int:
Expand Down
125 changes: 16 additions & 109 deletions airflow/models/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,14 @@
# under the License.
from __future__ import annotations

import collections.abc
import contextlib
import inspect
import itertools
import json
import logging
import pickle
import warnings
from functools import cached_property, wraps
from typing import TYPE_CHECKING, Any, Generator, Iterable, cast, overload
from functools import wraps
from typing import TYPE_CHECKING, Any, Iterable, cast, overload

import attr
from sqlalchemy import (
Column,
ForeignKeyConstraint,
Expand All @@ -38,19 +34,20 @@
PrimaryKeyConstraint,
String,
delete,
select,
text,
)
from sqlalchemy.dialects.mysql import LONGBLOB
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import Query, reconstructor, relationship
from sqlalchemy.orm.exc import NoResultFound

from airflow import settings
from airflow.api_internal.internal_api_call import internal_api_call
from airflow.configuration import conf
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies
from airflow.utils import timezone
from airflow.utils.db import LazySelectSequence
from airflow.utils.helpers import exactly_one, is_container
from airflow.utils.json import XComDecoder, XComEncoder
from airflow.utils.log.logging_mixin import LoggingMixin
Expand All @@ -70,7 +67,9 @@
import datetime

import pendulum
from sqlalchemy.engine import Row
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import Select, TextClause

from airflow.models.taskinstancekey import TaskInstanceKey

Expand Down Expand Up @@ -222,11 +221,11 @@ def set(
if dag_run_id is None:
raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}")

# Seamlessly resolve LazyXComAccess to a list. This is intended to work
# Seamlessly resolve LazySelectSequence to a list. This is intended to work
# as a "lazy list" to avoid pulling a ton of XComs unnecessarily, but if
# it's pushed into XCom, the user should be aware of the performance
# implications, and this avoids leaking the implementation detail.
if isinstance(value, LazyXComAccess):
if isinstance(value, LazySelectSequence):
warning_message = (
"Coercing mapped lazy proxy %s from task %s (DAG %s, run %s) "
"to list, which may degrade performance. Review resource "
Expand Down Expand Up @@ -716,111 +715,19 @@ def orm_deserialize_value(self) -> Any:
return BaseXCom._deserialize_value(self, True)


class _LazyXComAccessIterator(collections.abc.Iterator):
def __init__(self, cm: contextlib.AbstractContextManager[Query]) -> None:
self._cm = cm
self._entered = False

def __del__(self) -> None:
if self._entered:
self._cm.__exit__(None, None, None)

def __iter__(self) -> collections.abc.Iterator:
return self

def __next__(self) -> Any:
return XCom.deserialize_value(next(self._it))

@cached_property
def _it(self) -> collections.abc.Iterator:
self._entered = True
return iter(self._cm.__enter__())


@attr.define(slots=True)
class LazyXComAccess(collections.abc.Sequence):
"""Wrapper to lazily pull XCom with a sequence-like interface.

Note that since the session bound to the parent query may have died when we
actually access the sequence's content, we must create a new session
for every function call with ``with_session()``.
class LazyXComSelectSequence(LazySelectSequence[Any]):
"""List-like interface to lazily access XCom values.

:meta private:
"""

_query: Query
_len: int | None = attr.ib(init=False, default=None)

@classmethod
def build_from_xcom_query(cls, query: Query) -> LazyXComAccess:
return cls(query=query.with_entities(XCom.value))

def __repr__(self) -> str:
return f"LazyXComAccess([{len(self)} items])"

def __str__(self) -> str:
return str(list(self))

def __eq__(self, other: Any) -> bool:
if isinstance(other, (list, LazyXComAccess)):
z = itertools.zip_longest(iter(self), iter(other), fillvalue=object())
return all(x == y for x, y in z)
return NotImplemented

def __getstate__(self) -> Any:
# We don't want to go to the trouble of serializing the entire Query
# object, including its filters, hints, etc. (plus SQLAlchemy does not
# provide a public API to inspect a query's contents). Converting the
# query into a SQL string is the best we can get. Theoretically we can
# do the same for count(), but I think it should be performant enough to
# calculate only that eagerly.
with self._get_bound_query() as query:
statement = query.statement.compile(
query.session.get_bind(),
# This inlines all the values into the SQL string to simplify
# cross-process communication as much as possible.
compile_kwargs={"literal_binds": True},
)
return (str(statement), query.count())

def __setstate__(self, state: Any) -> None:
statement, self._len = state
self._query = Query(XCom.value).from_statement(text(statement))

def __len__(self):
if self._len is None:
with self._get_bound_query() as query:
self._len = query.count()
return self._len

def __iter__(self):
return _LazyXComAccessIterator(self._get_bound_query())
@staticmethod
def _rebuild_select(stmt: TextClause) -> Select:
return select(XCom.value).from_statement(stmt)

def __getitem__(self, key):
if not isinstance(key, int):
raise ValueError("only support index access for now")
try:
with self._get_bound_query() as query:
r = query.offset(key).limit(1).one()
except NoResultFound:
raise IndexError(key) from None
return XCom.deserialize_value(r)

@contextlib.contextmanager
def _get_bound_query(self) -> Generator[Query, None, None]:
# Do we have a valid session already?
if self._query.session and self._query.session.is_active:
yield self._query
return

Session = getattr(settings, "Session", None)
if Session is None:
raise RuntimeError("Session must be set before!")
session = Session()
try:
yield self._query.with_session(session)
finally:
session.close()
@staticmethod
def _process_row(row: Row) -> Any:
return XCom.deserialize_value(row)


def _patch_outdated_serializer(clazz: type[BaseXCom], params: Iterable[str]) -> None:
Expand Down
6 changes: 6 additions & 0 deletions airflow/typing_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"Literal",
"ParamSpec",
"Protocol",
"Self",
"TypedDict",
"TypeGuard",
"runtime_checkable",
Expand All @@ -45,3 +46,8 @@
from typing import ParamSpec, TypeGuard
else:
from typing_extensions import ParamSpec, TypeGuard

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
Loading