Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion airflow-core/src/airflow/api_fastapi/common/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
overload,
)

from fastapi import Depends, HTTPException, Query
from fastapi import Depends, HTTPException, Query, status
from pendulum.parsing.exceptions import ParserError
from pydantic import AfterValidator, BaseModel, NonNegativeInt
from sqlalchemy import Column, and_, func, not_, or_, select as sql_select
Expand Down Expand Up @@ -69,6 +69,7 @@
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql import ColumnElement, Select

from airflow.serialization.serialized_objects import SerializedDAG

T = TypeVar("T")

Expand Down Expand Up @@ -185,6 +186,57 @@ def depends(cls, *args: Any, **kwargs: Any) -> Self:
raise NotImplementedError("Use search_param_factory instead , depends is not implemented.")


class QueryTaskInstanceTaskGroupFilter(BaseParam[str]):
"""Task group filter - returns all tasks in the specified group."""

def __init__(self, dag=None, skip_none: bool = True):
super().__init__(skip_none=skip_none)
self._dag: None | SerializedDAG = dag

@property
def dag(self) -> None | SerializedDAG:
return self._dag

@dag.setter
def dag(self, value: None | SerializedDAG) -> None:
self._dag = value

def to_orm(self, select: Select) -> Select:
if self.value is None and self.skip_none:
return select

if not self.dag:
raise ValueError("Dag must be set before calling to_orm")

if not hasattr(self.dag, "task_group"):
return select

# Exact matching on group_id
task_groups = self.dag.task_group.get_task_group_dict()
task_group = task_groups.get(self.value)
if not task_group:
raise HTTPException(
status.HTTP_404_NOT_FOUND,
detail={
"reason": "not_found",
"message": f"Task group {self.value} not found",
},
)

return select.where(TaskInstance.task_id.in_(task.task_id for task in task_group.iter_tasks()))

@classmethod
def depends(
cls,
value: str | None = Query(
alias="task_group_id",
default=None,
description="Filter by exact task group ID. Returns all tasks within the specified task group.",
),
) -> QueryTaskInstanceTaskGroupFilter:
return cls(dag=None).set_value(value)


def search_param_factory(
attribute: ColumnElement,
pattern_name: str,
Expand Down Expand Up @@ -888,6 +940,9 @@ def _transform_ti_states(states: list[str] | None) -> list[TaskInstanceState | N
QueryTITaskDisplayNamePatternSearch = Annotated[
_SearchParam, Depends(search_param_factory(TaskInstance.task_display_name, "task_display_name_pattern"))
]
QueryTITaskGroupFilter = Annotated[
QueryTaskInstanceTaskGroupFilter, Depends(QueryTaskInstanceTaskGroupFilter.depends)
]
QueryTIDagVersionFilter = Annotated[
FilterParam[list[int]],
Depends(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6597,6 +6597,18 @@ paths:
title: Task Display Name Pattern
description: "SQL LIKE expression \u2014 use `%` / `_` wildcards (e.g. `%customer_%`).\
\ Regular expressions are **not** supported."
- name: task_group_id
in: query
required: false
schema:
anyOf:
- type: string
- type: 'null'
description: Filter by exact task group ID. Returns all tasks within the
specified task group.
title: Task Group Id
description: Filter by exact task group ID. Returns all tasks within the specified
task group.
- name: dag_id_pattern
in: query
required: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
QueryTIQueueNamePatternSearch,
QueryTIStateFilter,
QueryTITaskDisplayNamePatternSearch,
QueryTITaskGroupFilter,
QueryTITryNumberFilter,
Range,
RangeFilter,
Expand Down Expand Up @@ -424,6 +425,7 @@ def get_task_instances(
update_at_range: Annotated[RangeFilter, Depends(datetime_range_filter_factory("updated_at", TI))],
duration_range: Annotated[RangeFilter, Depends(float_range_filter_factory("duration", TI))],
task_display_name_pattern: QueryTITaskDisplayNamePatternSearch,
task_group_id: QueryTITaskGroupFilter,
dag_id_pattern: Annotated[_SearchParam, Depends(search_param_factory(TI.dag_id, "dag_id_pattern"))],
run_id_pattern: Annotated[_SearchParam, Depends(search_param_factory(TI.run_id, "run_id_pattern"))],
state: QueryTIStateFilter,
Expand Down Expand Up @@ -490,8 +492,10 @@ def get_task_instances(
)
query = query.where(TI.run_id == dag_run_id)
if dag_id != "~":
get_dag_for_run_or_latest_version(dag_bag, dag_run, dag_id, session)
dag = get_dag_for_run_or_latest_version(dag_bag, dag_run, dag_id, session)
query = query.where(TI.dag_id == dag_id)
if dag:
task_group_id.dag = dag

task_instance_select, total_entries = paginated_select(
statement=query,
Expand All @@ -510,6 +514,7 @@ def get_task_instances(
executor,
task_id,
task_display_name_pattern,
task_group_id,
dag_id_pattern,
run_id_pattern,
version_number,
Expand Down
5 changes: 3 additions & 2 deletions airflow-core/src/airflow/ui/openapi-gen/queries/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ export const UseTaskInstanceServiceGetMappedTaskInstanceKeyFn = ({ dagId, dagRun
export type TaskInstanceServiceGetTaskInstancesDefaultResponse = Awaited<ReturnType<typeof TaskInstanceService.getTaskInstances>>;
export type TaskInstanceServiceGetTaskInstancesQueryResult<TData = TaskInstanceServiceGetTaskInstancesDefaultResponse, TError = unknown> = UseQueryResult<TData, TError>;
export const useTaskInstanceServiceGetTaskInstancesKey = "TaskInstanceServiceGetTaskInstances";
export const UseTaskInstanceServiceGetTaskInstancesKeyFn = ({ dagId, dagIdPattern, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, operatorNamePattern, orderBy, pool, poolNamePattern, queue, queueNamePattern, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runIdPattern, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }: {
export const UseTaskInstanceServiceGetTaskInstancesKeyFn = ({ dagId, dagIdPattern, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, operatorNamePattern, orderBy, pool, poolNamePattern, queue, queueNamePattern, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runIdPattern, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskGroupId, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }: {
dagId: string;
dagIdPattern?: string;
dagRunId: string;
Expand Down Expand Up @@ -512,14 +512,15 @@ export const UseTaskInstanceServiceGetTaskInstancesKeyFn = ({ dagId, dagIdPatter
startDateLte?: string;
state?: string[];
taskDisplayNamePattern?: string;
taskGroupId?: string;
taskId?: string;
tryNumber?: number[];
updatedAtGt?: string;
updatedAtGte?: string;
updatedAtLt?: string;
updatedAtLte?: string;
versionNumber?: number[];
}, queryKey?: Array<unknown>) => [useTaskInstanceServiceGetTaskInstancesKey, ...(queryKey ?? [{ dagId, dagIdPattern, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, operatorNamePattern, orderBy, pool, poolNamePattern, queue, queueNamePattern, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runIdPattern, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }])];
}, queryKey?: Array<unknown>) => [useTaskInstanceServiceGetTaskInstancesKey, ...(queryKey ?? [{ dagId, dagIdPattern, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, operatorNamePattern, orderBy, pool, poolNamePattern, queue, queueNamePattern, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runIdPattern, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskGroupId, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }])];
export type TaskInstanceServiceGetTaskInstanceTryDetailsDefaultResponse = Awaited<ReturnType<typeof TaskInstanceService.getTaskInstanceTryDetails>>;
export type TaskInstanceServiceGetTaskInstanceTryDetailsQueryResult<TData = TaskInstanceServiceGetTaskInstanceTryDetailsDefaultResponse, TError = unknown> = UseQueryResult<TData, TError>;
export const useTaskInstanceServiceGetTaskInstanceTryDetailsKey = "TaskInstanceServiceGetTaskInstanceTryDetails";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,7 @@ export const ensureUseTaskInstanceServiceGetMappedTaskInstanceData = (queryClien
* @param data.durationLte
* @param data.durationLt
* @param data.taskDisplayNamePattern SQL LIKE expression — use `%` / `_` wildcards (e.g. `%customer_%`). Regular expressions are **not** supported.
* @param data.taskGroupId Filter by exact task group ID. Returns all tasks within the specified task group.
* @param data.dagIdPattern SQL LIKE expression — use `%` / `_` wildcards (e.g. `%customer_%`). Regular expressions are **not** supported.
* @param data.runIdPattern SQL LIKE expression — use `%` / `_` wildcards (e.g. `%customer_%`). Regular expressions are **not** supported.
* @param data.state
Expand All @@ -944,7 +945,7 @@ export const ensureUseTaskInstanceServiceGetMappedTaskInstanceData = (queryClien
* @returns TaskInstanceCollectionResponse Successful Response
* @throws ApiError
*/
export const ensureUseTaskInstanceServiceGetTaskInstancesData = (queryClient: QueryClient, { dagId, dagIdPattern, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, operatorNamePattern, orderBy, pool, poolNamePattern, queue, queueNamePattern, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runIdPattern, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }: {
export const ensureUseTaskInstanceServiceGetTaskInstancesData = (queryClient: QueryClient, { dagId, dagIdPattern, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, operatorNamePattern, orderBy, pool, poolNamePattern, queue, queueNamePattern, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runIdPattern, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskGroupId, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }: {
dagId: string;
dagIdPattern?: string;
dagRunId: string;
Expand Down Expand Up @@ -982,14 +983,15 @@ export const ensureUseTaskInstanceServiceGetTaskInstancesData = (queryClient: Qu
startDateLte?: string;
state?: string[];
taskDisplayNamePattern?: string;
taskGroupId?: string;
taskId?: string;
tryNumber?: number[];
updatedAtGt?: string;
updatedAtGte?: string;
updatedAtLt?: string;
updatedAtLte?: string;
versionNumber?: number[];
}) => queryClient.ensureQueryData({ queryKey: Common.UseTaskInstanceServiceGetTaskInstancesKeyFn({ dagId, dagIdPattern, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, operatorNamePattern, orderBy, pool, poolNamePattern, queue, queueNamePattern, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runIdPattern, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }), queryFn: () => TaskInstanceService.getTaskInstances({ dagId, dagIdPattern, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, operatorNamePattern, orderBy, pool, poolNamePattern, queue, queueNamePattern, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runIdPattern, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }) });
}) => queryClient.ensureQueryData({ queryKey: Common.UseTaskInstanceServiceGetTaskInstancesKeyFn({ dagId, dagIdPattern, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, operatorNamePattern, orderBy, pool, poolNamePattern, queue, queueNamePattern, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runIdPattern, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskGroupId, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }), queryFn: () => TaskInstanceService.getTaskInstances({ dagId, dagIdPattern, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, operatorNamePattern, orderBy, pool, poolNamePattern, queue, queueNamePattern, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runIdPattern, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskGroupId, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }) });
/**
* Get Task Instance Try Details
* Get task instance details by try number.
Expand Down
6 changes: 4 additions & 2 deletions airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,7 @@ export const prefetchUseTaskInstanceServiceGetMappedTaskInstance = (queryClient:
* @param data.durationLte
* @param data.durationLt
* @param data.taskDisplayNamePattern SQL LIKE expression — use `%` / `_` wildcards (e.g. `%customer_%`). Regular expressions are **not** supported.
* @param data.taskGroupId Filter by exact task group ID. Returns all tasks within the specified task group.
* @param data.dagIdPattern SQL LIKE expression — use `%` / `_` wildcards (e.g. `%customer_%`). Regular expressions are **not** supported.
* @param data.runIdPattern SQL LIKE expression — use `%` / `_` wildcards (e.g. `%customer_%`). Regular expressions are **not** supported.
* @param data.state
Expand All @@ -944,7 +945,7 @@ export const prefetchUseTaskInstanceServiceGetMappedTaskInstance = (queryClient:
* @returns TaskInstanceCollectionResponse Successful Response
* @throws ApiError
*/
export const prefetchUseTaskInstanceServiceGetTaskInstances = (queryClient: QueryClient, { dagId, dagIdPattern, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, operatorNamePattern, orderBy, pool, poolNamePattern, queue, queueNamePattern, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runIdPattern, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }: {
export const prefetchUseTaskInstanceServiceGetTaskInstances = (queryClient: QueryClient, { dagId, dagIdPattern, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, operatorNamePattern, orderBy, pool, poolNamePattern, queue, queueNamePattern, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runIdPattern, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskGroupId, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }: {
dagId: string;
dagIdPattern?: string;
dagRunId: string;
Expand Down Expand Up @@ -982,14 +983,15 @@ export const prefetchUseTaskInstanceServiceGetTaskInstances = (queryClient: Quer
startDateLte?: string;
state?: string[];
taskDisplayNamePattern?: string;
taskGroupId?: string;
taskId?: string;
tryNumber?: number[];
updatedAtGt?: string;
updatedAtGte?: string;
updatedAtLt?: string;
updatedAtLte?: string;
versionNumber?: number[];
}) => queryClient.prefetchQuery({ queryKey: Common.UseTaskInstanceServiceGetTaskInstancesKeyFn({ dagId, dagIdPattern, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, operatorNamePattern, orderBy, pool, poolNamePattern, queue, queueNamePattern, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runIdPattern, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }), queryFn: () => TaskInstanceService.getTaskInstances({ dagId, dagIdPattern, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, operatorNamePattern, orderBy, pool, poolNamePattern, queue, queueNamePattern, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runIdPattern, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }) });
}) => queryClient.prefetchQuery({ queryKey: Common.UseTaskInstanceServiceGetTaskInstancesKeyFn({ dagId, dagIdPattern, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, operatorNamePattern, orderBy, pool, poolNamePattern, queue, queueNamePattern, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runIdPattern, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskGroupId, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }), queryFn: () => TaskInstanceService.getTaskInstances({ dagId, dagIdPattern, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, operatorNamePattern, orderBy, pool, poolNamePattern, queue, queueNamePattern, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runIdPattern, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskGroupId, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }) });
/**
* Get Task Instance Try Details
* Get task instance details by try number.
Expand Down
Loading