Skip to content

Commit

Permalink
Make filters param optional and fix typing
Browse files Browse the repository at this point in the history
Given that sometimes we don't want to apply any filters, it makes sense to make the param optional.  I also fix the typing on `paginated_select`.
  • Loading branch information
dstandish committed Nov 21, 2024
1 parent 22d1406 commit c2eb8e2
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 18 deletions.
58 changes: 55 additions & 3 deletions airflow/api_fastapi/common/db/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@

from typing import TYPE_CHECKING, Literal, Sequence, overload

from airflow.utils.db import get_query_count
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from sqlalchemy.ext.asyncio import AsyncSession

from airflow.utils.db import get_query_count, get_query_count_async
from airflow.utils.session import NEW_SESSION, create_session, create_session_async, provide_session

if TYPE_CHECKING:
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -53,7 +55,9 @@ def your_route(session: Annotated[Session, Depends(get_session)]):


def apply_filters_to_select(
*, base_select: Select, filters: Sequence[BaseParam | None] | None = None
*,
base_select: Select,
filters: Sequence[BaseParam | None] | None = None,
) -> Select:
if filters is None:
return base_select
Expand All @@ -65,6 +69,54 @@ def apply_filters_to_select(
return base_select


async def get_async_session() -> AsyncSession:
"""
Dependency for providing a session.
Example usage:
.. code:: python
@router.get("/your_path")
def your_route(session: Annotated[AsyncSession, Depends(get_async_session)]):
pass
"""
async with create_session_async() as session:
yield session


async def paginated_select_async(
*,
base_select: Select,
filters: Sequence[BaseParam | None] | None = None,
order_by: BaseParam | None = None,
offset: BaseParam | None = None,
limit: BaseParam | None = None,
session: AsyncSession,
return_total_entries: bool = True,
) -> tuple[Select, int | None]:
base_select = apply_filters_to_select(
base_select=base_select,
filters=filters,
)

total_entries = None
if return_total_entries:
total_entries = await get_query_count_async(base_select, session=session)

# TODO: Re-enable when permissions are handled. Readable / writable entities,
# for instance:
# readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user)
# dags_select = dags_select.where(DagModel.dag_id.in_(readable_dags))

base_select = apply_filters_to_select(
base_select=base_select,
filters=[order_by, offset, limit],
)

return base_select, total_entries


@overload
def paginated_select(
*,
Expand Down
15 changes: 7 additions & 8 deletions airflow/api_fastapi/core_api/routes/public/backfills.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@

from fastapi import Depends, HTTPException, status
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from airflow.api_fastapi.common.db.common import get_session, paginated_select
from airflow.api_fastapi.common.db.common import get_async_session, get_session, paginated_select_async
from airflow.api_fastapi.common.parameters import QueryLimit, QueryOffset, SortParam
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.datamodels.backfills import (
Expand All @@ -49,26 +50,24 @@
@backfills_router.get(
path="",
)
def list_backfills(
async def list_backfills(
dag_id: str,
limit: QueryLimit,
offset: QueryOffset,
order_by: Annotated[
SortParam,
Depends(SortParam(["id"], Backfill).dynamic_depends()),
],
session: Annotated[Session, Depends(get_session)],
session: Annotated[AsyncSession, Depends(get_async_session)],
) -> BackfillCollectionResponse:
select_stmt, total_entries = paginated_select(
select=select(Backfill).where(Backfill.dag_id == dag_id),
select_stmt, total_entries = await paginated_select_async(
base_select=select(Backfill).where(Backfill.dag_id == dag_id),
order_by=order_by,
offset=offset,
limit=limit,
session=session,
)

backfills = session.scalars(select_stmt)

backfills = await session.scalars(select_stmt)
return BackfillCollectionResponse(
backfills=[BackfillResponse.model_validate(b, from_attributes=True) for b in backfills],
total_entries=total_entries,
Expand Down
10 changes: 5 additions & 5 deletions airflow/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import pluggy
from packaging.version import Version
from sqlalchemy import create_engine, exc, text
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession as SAAsyncSession, create_async_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.pool import NullPool

Expand Down Expand Up @@ -111,7 +111,7 @@
# this is achieved by the Session factory above.
NonScopedSession: Callable[..., SASession]
async_engine: AsyncEngine
create_async_session: Callable[..., AsyncSession]
AsyncSession: Callable[..., SAAsyncSession]

# The JSON library to use for DAG Serialization and De-Serialization
json = json
Expand Down Expand Up @@ -469,7 +469,7 @@ def configure_orm(disable_connection_pool=False, pool_class=None):
global Session
global engine
global async_engine
global create_async_session
global AsyncSession
global NonScopedSession

if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true":
Expand Down Expand Up @@ -498,11 +498,11 @@ def configure_orm(disable_connection_pool=False, pool_class=None):

engine = create_engine(SQL_ALCHEMY_CONN, connect_args=connect_args, **engine_args, future=True)
async_engine = create_async_engine(SQL_ALCHEMY_CONN_ASYNC, future=True)
create_async_session = sessionmaker(
AsyncSession = sessionmaker(
bind=async_engine,
autocommit=False,
autoflush=False,
class_=AsyncSession,
class_=SAAsyncSession,
expire_on_commit=False,
)
mask_secret(engine.url.password)
Expand Down
16 changes: 16 additions & 0 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
from alembic.runtime.environment import EnvironmentContext
from alembic.script import ScriptDirectory
from sqlalchemy.engine import Row
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from sqlalchemy.sql.elements import ClauseElement, TextClause
from sqlalchemy.sql.selectable import Select
Expand Down Expand Up @@ -1447,6 +1448,21 @@ def get_query_count(query_stmt: Select, *, session: Session) -> int:
return session.scalar(count_stmt)


async def get_query_count_async(query_stmt: Select, *, session: AsyncSession) -> int:
"""
Get count of a query.
A SELECT COUNT() FROM is issued against the subquery built from the
given statement. The ORDER BY clause is stripped from the statement
since it's unnecessary for COUNT, and can impact query planning and
degrade performance.
:meta private:
"""
count_stmt = select(func.count()).select_from(query_stmt.order_by(None).subquery())
return await session.scalar(count_stmt)


def check_query_exists(query_stmt: Select, *, session: Session) -> bool:
"""
Check whether there is at least one row matching a query.
Expand Down
18 changes: 18 additions & 0 deletions airflow/utils/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,24 @@ def create_session(scoped: bool = True) -> Generator[SASession, None, None]:
session.close()


@contextlib.asynccontextmanager
async def create_session_async():
"""
Context manager to create async session.
:meta private:
"""
from airflow.settings import AsyncSession

async with AsyncSession() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise


PS = ParamSpec("PS")
RT = TypeVar("RT")

Expand Down
4 changes: 2 additions & 2 deletions tests/utils/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ def test_provide_session_with_kwargs(self):

@pytest.mark.asyncio
async def test_async_session(self):
from airflow.settings import create_async_session
from airflow.settings import AsyncSession

session = create_async_session()
session = AsyncSession()
session.add(Log(event="hihi1234"))
await session.commit()
my_special_log_event = await session.scalar(select(Log).where(Log.event == "hihi1234").limit(1))
Expand Down

0 comments on commit c2eb8e2

Please sign in to comment.