Skip to content

Commit

Permalink
Make filters param optional and fix typing
Browse files Browse the repository at this point in the history
Given that sometimes we don't want to apply any filters, it makes sense to make the param optional.  I also fix the typing on `paginated_select`.
  • Loading branch information
dstandish committed Nov 20, 2024
1 parent a825c95 commit 2b3838f
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 38 deletions.
30 changes: 20 additions & 10 deletions airflow/api_fastapi/common/db/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Database helpers for Airflow REST API.
:meta private:
"""

from __future__ import annotations

Expand Down Expand Up @@ -47,40 +52,45 @@ def your_route(session: Annotated[Session, Depends(get_session)]):
yield session


def apply_filters_to_select(base_select: Select, filters: Sequence[BaseParam | None]) -> Select:
base_select = base_select
for filter in filters:
if filter is None:
def apply_filters_to_select(
*, base_select: Select, filters: Sequence[BaseParam | None] | None = None
) -> Select:
if filters is None:
return base_select
for f in filters:
if f is None:
continue
base_select = filter.to_orm(base_select)
base_select = f.to_orm(base_select)

return base_select


@provide_session
def paginated_select(
base_select: Select,
filters: Sequence[BaseParam],
filters: Sequence[BaseParam] | None = None,
order_by: BaseParam | None = None,
offset: BaseParam | None = None,
limit: BaseParam | None = None,
session: Session = NEW_SESSION,
return_total_entries: bool = True,
) -> Select:
) -> tuple[Select, int | None]:
base_select = apply_filters_to_select(
base_select,
filters,
base_select=base_select,
filters=filters,
)

total_entries = None
if return_total_entries:
total_entries = get_query_count(base_select, session=session)
if TYPE_CHECKING:
assert isinstance(total_entries, int)

# TODO: Re-enable when permissions are handled. Readable / writable entities,
# for instance:
# readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user)
# dags_select = dags_select.where(DagModel.dag_id.in_(readable_dags))

base_select = apply_filters_to_select(base_select, [order_by, offset, limit])
base_select = apply_filters_to_select(base_select=base_select, filters=[order_by, offset, limit])

return base_select, total_entries
22 changes: 13 additions & 9 deletions airflow/api_fastapi/core_api/routes/public/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

from datetime import datetime
from typing import Annotated
from typing import TYPE_CHECKING, Annotated

from fastapi import Depends, HTTPException, status
from sqlalchemy import delete, select
Expand Down Expand Up @@ -102,6 +102,9 @@ def get_assets(
limit=limit,
session=session,
)
if TYPE_CHECKING:
assert isinstance(total_entries, int)

assets = session.scalars(
assets_select.options(
subqueryload(AssetModel.consuming_dags), subqueryload(AssetModel.producing_tasks)
Expand Down Expand Up @@ -151,6 +154,8 @@ def get_asset_events(
limit=limit,
session=session,
)
if TYPE_CHECKING:
assert isinstance(total_entries, int)

assets_event_select = assets_event_select.options(subqueryload(AssetEvent.created_dagruns))
assets_events = session.scalars(assets_event_select).all()
Expand Down Expand Up @@ -211,11 +216,10 @@ def get_asset_queued_events(
.where(*where_clause)
)

dag_asset_queued_events_select, total_entries = paginated_select(
query,
[],
)
dag_asset_queued_events_select, total_entries = paginated_select(query)
adrqs = session.execute(dag_asset_queued_events_select).all()
if TYPE_CHECKING:
assert isinstance(total_entries, int)

if not adrqs:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Queue event with uri: `{uri}` was not found")
Expand Down Expand Up @@ -273,10 +277,10 @@ def get_dag_asset_queued_events(
.where(*where_clause)
)

dag_asset_queued_events_select, total_entries = paginated_select(
query,
[],
)
dag_asset_queued_events_select, total_entries = paginated_select(query)
if TYPE_CHECKING:
assert isinstance(total_entries, int)

adrqs = session.execute(dag_asset_queued_events_select).all()

if not adrqs:
Expand Down
1 change: 0 additions & 1 deletion airflow/api_fastapi/core_api/routes/public/backfills.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def list_backfills(
) -> BackfillCollectionResponse:
select_stmt, total_entries = paginated_select(
select(Backfill).where(Backfill.dag_id == dag_id),
[],
order_by=order_by,
offset=offset,
limit=limit,
Expand Down
5 changes: 3 additions & 2 deletions airflow/api_fastapi/core_api/routes/public/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

import os
from typing import Annotated
from typing import TYPE_CHECKING, Annotated

from fastapi import Depends, HTTPException, Query, status
from sqlalchemy import select
Expand Down Expand Up @@ -101,12 +101,13 @@ def get_connections(
"""Get all connection entries."""
connection_select, total_entries = paginated_select(
select(Connection),
[],
order_by=order_by,
offset=offset,
limit=limit,
session=session,
)
if TYPE_CHECKING:
assert isinstance(total_entries, int)

connections = session.scalars(connection_select).all()

Expand Down
7 changes: 5 additions & 2 deletions airflow/api_fastapi/core_api/routes/public/event_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

from datetime import datetime
from typing import Annotated
from typing import TYPE_CHECKING, Annotated

from fastapi import Depends, HTTPException, Query, status
from sqlalchemy import select
Expand Down Expand Up @@ -126,12 +126,15 @@ def get_event_logs(
base_select = base_select.where(Log.dttm > after)
event_logs_select, total_entries = paginated_select(
base_select,
[],
None,
order_by,
offset,
limit,
session,
)
if TYPE_CHECKING:
assert isinstance(total_entries, int)

event_logs = session.scalars(event_logs_select).all()

return EventLogCollectionResponse(
Expand Down
6 changes: 4 additions & 2 deletions airflow/api_fastapi/core_api/routes/public/import_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations

from typing import Annotated
from typing import TYPE_CHECKING, Annotated

from fastapi import Depends, HTTPException, status
from sqlalchemy import select
Expand Down Expand Up @@ -90,12 +90,14 @@ def get_import_errors(
"""Get all import errors."""
import_errors_select, total_entries = paginated_select(
select(ParseImportError),
[],
None,
order_by,
offset,
limit,
session,
)
if TYPE_CHECKING:
assert isinstance(total_entries, int)
import_errors = session.scalars(import_errors_select).all()

return ImportErrorCollectionResponse(
Expand Down
5 changes: 3 additions & 2 deletions airflow/api_fastapi/core_api/routes/public/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations

from typing import Annotated
from typing import TYPE_CHECKING, Annotated

from fastapi import Depends, HTTPException, Query, status
from fastapi.exceptions import RequestValidationError
Expand Down Expand Up @@ -98,12 +98,13 @@ def get_pools(
"""Get all pools entries."""
pools_select, total_entries = paginated_select(
select(Pool),
[],
order_by=order_by,
offset=offset,
limit=limit,
session=session,
)
if TYPE_CHECKING:
assert isinstance(total_entries, int)

pools = session.scalars(pools_select).all()

Expand Down
5 changes: 3 additions & 2 deletions airflow/api_fastapi/core_api/routes/public/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations

from typing import Annotated
from typing import TYPE_CHECKING, Annotated

from fastapi import Depends, HTTPException, Query, status
from sqlalchemy import select
Expand Down Expand Up @@ -91,12 +91,13 @@ def get_variables(
"""Get all Variables entries."""
variable_select, total_entries = paginated_select(
select(Variable),
[],
order_by=order_by,
offset=offset,
limit=limit,
session=session,
)
if TYPE_CHECKING:
assert isinstance(total_entries, int)

variables = session.scalars(variable_select).all()

Expand Down
17 changes: 9 additions & 8 deletions airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,15 +838,15 @@ def process_executor_events(
"scheduler.tasks.killed_externally",
tags={"dag_id": ti.dag_id, "task_id": ti.task_id},
)
msg = (
"Executor %s reported that the task instance %s finished with state %s, but the task instance's state attribute is %s. " # noqa: RUF100, UP031, flynt
"Learn more: https://airflow.apache.org/docs/apache-airflow/stable/troubleshooting.html#task-state-changed-externally"
% (executor, ti, state, ti.state)
)
if info is not None:
msg += " Extra info: %s" % info # noqa: RUF100, UP031, flynt
msg = f"Received executor event {state} but task state is {ti.state}. task id={ti.id} info={info}Learn more: https://airflow.apache.org/docs/apache-airflow/stable/troubleshooting.html#task-state-changed-externally"
cls.logger().error(msg)
session.add(Log(event="state mismatch", extra=msg, task_instance=ti.key))
session.add(
Log(
event="state mismatch",
extra=f"Received executor event {state} but task state is {ti.state}.",
task_instance=ti.key,
)
)

# Get task from the Serialized DAG
try:
Expand Down Expand Up @@ -1883,6 +1883,7 @@ def _reschedule_stuck_task(self, ti: TaskInstance, session: Session):
.values(
state=TaskInstanceState.SCHEDULED,
queued_dttm=None,
executor_config=None,
)
.execution_options(synchronize_session=False)
)
Expand Down

0 comments on commit 2b3838f

Please sign in to comment.