Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions airflow-core/src/airflow/api_fastapi/common/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@
if TYPE_CHECKING:
from sqlalchemy.sql import ColumnElement, Select

from airflow.serialization.serialized_objects import SerializedDAG

T = TypeVar("T")


Expand Down Expand Up @@ -181,6 +183,57 @@ def depends(cls, *args: Any, **kwargs: Any) -> Self:
raise NotImplementedError("Use search_param_factory instead , depends is not implemented.")


class QueryTaskInstanceTaskGroupFilter(BaseParam[str]):
"""Task group filter - returns all tasks in the specified group."""

def __init__(self, dag=None, skip_none: bool = True):
super().__init__(skip_none=skip_none)
self._dag: None | SerializedDAG = dag

@property
def dag(self) -> None | SerializedDAG:
return self._dag

@dag.setter
def dag(self, value: None | SerializedDAG) -> None:
self._dag = value

def to_orm(self, select: Select) -> Select:
if self.value is None and self.skip_none:
return select

if not self.dag:
raise ValueError("Dag must be set before calling to_orm")

if not hasattr(self.dag, "task_group"):
return select

# Exact matching on group_id
task_groups = self.dag.task_group.get_task_group_dict()
task_group = task_groups.get(self.value)
if not task_group:
raise HTTPException(
status.HTTP_404_NOT_FOUND,
detail={
"reason": "not_found",
"message": f"Task group {self.value} not found",
},
)

return select.where(TaskInstance.task_id.in_(task.task_id for task in task_group.iter_tasks()))

@classmethod
def depends(
cls,
value: str | None = Query(
alias="task_group_id",
default=None,
description="Filter by exact task group ID. Returns all tasks within the specified task group.",
),
) -> QueryTaskInstanceTaskGroupFilter:
return cls(dag=None).set_value(value)


def search_param_factory(
attribute: ColumnElement,
pattern_name: str,
Expand Down Expand Up @@ -862,6 +915,9 @@ def _transform_ti_states(states: list[str] | None) -> list[TaskInstanceState | N
QueryTITaskDisplayNamePatternSearch = Annotated[
_SearchParam, Depends(search_param_factory(TaskInstance.task_display_name, "task_display_name_pattern"))
]
QueryTITaskGroupFilter = Annotated[
QueryTaskInstanceTaskGroupFilter, Depends(QueryTaskInstanceTaskGroupFilter.depends)
]
QueryTIDagVersionFilter = Annotated[
FilterParam[list[int]],
Depends(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6499,6 +6499,18 @@ paths:
title: Task Display Name Pattern
description: "SQL LIKE expression \u2014 use `%` / `_` wildcards (e.g. `%customer_%`).\
\ Regular expressions are **not** supported."
- name: task_group_id
in: query
required: false
schema:
anyOf:
- type: string
- type: 'null'
description: Filter by exact task group ID. Returns all tasks within the
specified task group.
title: Task Group Id
description: Filter by exact task group ID. Returns all tasks within the specified
task group.
- name: state
in: query
required: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
QueryTIQueueFilter,
QueryTIStateFilter,
QueryTITaskDisplayNamePatternSearch,
QueryTITaskGroupFilter,
QueryTITryNumberFilter,
Range,
RangeFilter,
Expand Down Expand Up @@ -407,6 +408,7 @@ def get_task_instances(
update_at_range: Annotated[RangeFilter, Depends(datetime_range_filter_factory("updated_at", TI))],
duration_range: Annotated[RangeFilter, Depends(float_range_filter_factory("duration", TI))],
task_display_name_pattern: QueryTITaskDisplayNamePatternSearch,
task_group_id: QueryTITaskGroupFilter,
state: QueryTIStateFilter,
pool: QueryTIPoolFilter,
queue: QueryTIQueueFilter,
Expand Down Expand Up @@ -468,8 +470,10 @@ def get_task_instances(
)
query = query.where(TI.run_id == dag_run_id)
if dag_id != "~":
get_dag_for_run_or_latest_version(dag_bag, dag_run, dag_id, session)
dag = get_dag_for_run_or_latest_version(dag_bag, dag_run, dag_id, session)
query = query.where(TI.dag_id == dag_id)
if dag:
task_group_id.dag = dag

task_instance_select, total_entries = paginated_select(
statement=query,
Expand All @@ -486,6 +490,7 @@ def get_task_instances(
executor,
task_id,
task_display_name_pattern,
task_group_id,
version_number,
readable_ti_filter,
try_number,
Expand Down
5 changes: 3 additions & 2 deletions airflow-core/src/airflow/ui/openapi-gen/queries/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ export const UseTaskInstanceServiceGetMappedTaskInstanceKeyFn = ({ dagId, dagRun
export type TaskInstanceServiceGetTaskInstancesDefaultResponse = Awaited<ReturnType<typeof TaskInstanceService.getTaskInstances>>;
export type TaskInstanceServiceGetTaskInstancesQueryResult<TData = TaskInstanceServiceGetTaskInstancesDefaultResponse, TError = unknown> = UseQueryResult<TData, TError>;
export const useTaskInstanceServiceGetTaskInstancesKey = "TaskInstanceServiceGetTaskInstances";
export const UseTaskInstanceServiceGetTaskInstancesKeyFn = ({ dagId, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, orderBy, pool, queue, runAfterGt, runAfterGte, runAfterLt, runAfterLte, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }: {
export const UseTaskInstanceServiceGetTaskInstancesKeyFn = ({ dagId, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, orderBy, pool, queue, runAfterGt, runAfterGte, runAfterLt, runAfterLte, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskGroupId, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }: {
dagId: string;
dagRunId: string;
durationGt?: number;
Expand Down Expand Up @@ -497,14 +497,15 @@ export const UseTaskInstanceServiceGetTaskInstancesKeyFn = ({ dagId, dagRunId, d
startDateLte?: string;
state?: string[];
taskDisplayNamePattern?: string;
taskGroupId?: string;
taskId?: string;
tryNumber?: number[];
updatedAtGt?: string;
updatedAtGte?: string;
updatedAtLt?: string;
updatedAtLte?: string;
versionNumber?: number[];
}, queryKey?: Array<unknown>) => [useTaskInstanceServiceGetTaskInstancesKey, ...(queryKey ?? [{ dagId, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, orderBy, pool, queue, runAfterGt, runAfterGte, runAfterLt, runAfterLte, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }])];
}, queryKey?: Array<unknown>) => [useTaskInstanceServiceGetTaskInstancesKey, ...(queryKey ?? [{ dagId, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, orderBy, pool, queue, runAfterGt, runAfterGte, runAfterLt, runAfterLte, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskGroupId, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }])];
export type TaskInstanceServiceGetTaskInstanceTryDetailsDefaultResponse = Awaited<ReturnType<typeof TaskInstanceService.getTaskInstanceTryDetails>>;
export type TaskInstanceServiceGetTaskInstanceTryDetailsQueryResult<TData = TaskInstanceServiceGetTaskInstanceTryDetailsDefaultResponse, TError = unknown> = UseQueryResult<TData, TError>;
export const useTaskInstanceServiceGetTaskInstanceTryDetailsKey = "TaskInstanceServiceGetTaskInstanceTryDetails";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,7 @@ export const ensureUseTaskInstanceServiceGetMappedTaskInstanceData = (queryClien
* @param data.durationLte
* @param data.durationLt
* @param data.taskDisplayNamePattern SQL LIKE expression — use `%` / `_` wildcards (e.g. `%customer_%`). Regular expressions are **not** supported.
* @param data.taskGroupId Filter by exact task group ID. Returns all tasks within the specified task group.
* @param data.state
* @param data.pool
* @param data.queue
Expand All @@ -919,7 +920,7 @@ export const ensureUseTaskInstanceServiceGetMappedTaskInstanceData = (queryClien
* @returns TaskInstanceCollectionResponse Successful Response
* @throws ApiError
*/
export const ensureUseTaskInstanceServiceGetTaskInstancesData = (queryClient: QueryClient, { dagId, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, orderBy, pool, queue, runAfterGt, runAfterGte, runAfterLt, runAfterLte, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }: {
export const ensureUseTaskInstanceServiceGetTaskInstancesData = (queryClient: QueryClient, { dagId, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, orderBy, pool, queue, runAfterGt, runAfterGte, runAfterLt, runAfterLte, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskGroupId, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }: {
dagId: string;
dagRunId: string;
durationGt?: number;
Expand Down Expand Up @@ -952,14 +953,15 @@ export const ensureUseTaskInstanceServiceGetTaskInstancesData = (queryClient: Qu
startDateLte?: string;
state?: string[];
taskDisplayNamePattern?: string;
taskGroupId?: string;
taskId?: string;
tryNumber?: number[];
updatedAtGt?: string;
updatedAtGte?: string;
updatedAtLt?: string;
updatedAtLte?: string;
versionNumber?: number[];
}) => queryClient.ensureQueryData({ queryKey: Common.UseTaskInstanceServiceGetTaskInstancesKeyFn({ dagId, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, orderBy, pool, queue, runAfterGt, runAfterGte, runAfterLt, runAfterLte, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }), queryFn: () => TaskInstanceService.getTaskInstances({ dagId, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, orderBy, pool, queue, runAfterGt, runAfterGte, runAfterLt, runAfterLte, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }) });
}) => queryClient.ensureQueryData({ queryKey: Common.UseTaskInstanceServiceGetTaskInstancesKeyFn({ dagId, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, orderBy, pool, queue, runAfterGt, runAfterGte, runAfterLt, runAfterLte, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskGroupId, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }), queryFn: () => TaskInstanceService.getTaskInstances({ dagId, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, orderBy, pool, queue, runAfterGt, runAfterGte, runAfterLt, runAfterLte, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskGroupId, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }) });
/**
* Get Task Instance Try Details
* Get task instance details by try number.
Expand Down
6 changes: 4 additions & 2 deletions airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,7 @@ export const prefetchUseTaskInstanceServiceGetMappedTaskInstance = (queryClient:
* @param data.durationLte
* @param data.durationLt
* @param data.taskDisplayNamePattern SQL LIKE expression — use `%` / `_` wildcards (e.g. `%customer_%`). Regular expressions are **not** supported.
* @param data.taskGroupId Filter by exact task group ID. Returns all tasks within the specified task group.
* @param data.state
* @param data.pool
* @param data.queue
Expand All @@ -919,7 +920,7 @@ export const prefetchUseTaskInstanceServiceGetMappedTaskInstance = (queryClient:
* @returns TaskInstanceCollectionResponse Successful Response
* @throws ApiError
*/
export const prefetchUseTaskInstanceServiceGetTaskInstances = (queryClient: QueryClient, { dagId, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, orderBy, pool, queue, runAfterGt, runAfterGte, runAfterLt, runAfterLte, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }: {
export const prefetchUseTaskInstanceServiceGetTaskInstances = (queryClient: QueryClient, { dagId, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, orderBy, pool, queue, runAfterGt, runAfterGte, runAfterLt, runAfterLte, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskGroupId, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }: {
dagId: string;
dagRunId: string;
durationGt?: number;
Expand Down Expand Up @@ -952,14 +953,15 @@ export const prefetchUseTaskInstanceServiceGetTaskInstances = (queryClient: Quer
startDateLte?: string;
state?: string[];
taskDisplayNamePattern?: string;
taskGroupId?: string;
taskId?: string;
tryNumber?: number[];
updatedAtGt?: string;
updatedAtGte?: string;
updatedAtLt?: string;
updatedAtLte?: string;
versionNumber?: number[];
}) => queryClient.prefetchQuery({ queryKey: Common.UseTaskInstanceServiceGetTaskInstancesKeyFn({ dagId, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, orderBy, pool, queue, runAfterGt, runAfterGte, runAfterLt, runAfterLte, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }), queryFn: () => TaskInstanceService.getTaskInstances({ dagId, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, orderBy, pool, queue, runAfterGt, runAfterGte, runAfterLt, runAfterLte, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }) });
}) => queryClient.prefetchQuery({ queryKey: Common.UseTaskInstanceServiceGetTaskInstancesKeyFn({ dagId, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, orderBy, pool, queue, runAfterGt, runAfterGte, runAfterLt, runAfterLte, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskGroupId, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }), queryFn: () => TaskInstanceService.getTaskInstances({ dagId, dagRunId, durationGt, durationGte, durationLt, durationLte, endDateGt, endDateGte, endDateLt, endDateLte, executor, limit, logicalDateGt, logicalDateGte, logicalDateLt, logicalDateLte, mapIndex, offset, operator, orderBy, pool, queue, runAfterGt, runAfterGte, runAfterLt, runAfterLte, startDateGt, startDateGte, startDateLt, startDateLte, state, taskDisplayNamePattern, taskGroupId, taskId, tryNumber, updatedAtGt, updatedAtGte, updatedAtLt, updatedAtLte, versionNumber }) });
/**
* Get Task Instance Try Details
* Get task instance details by try number.
Expand Down
Loading
Loading