From 894b7057639c9198f7418e44442e65e629efd7a8 Mon Sep 17 00:00:00 2001 From: OscarLigthart Date: Sat, 22 Nov 2025 11:04:08 +0100 Subject: [PATCH 1/6] refactor: move task stream filtering logic to endpoint for Grid --- .../core_api/openapi/_private_ui.yaml | 22 +++++++ .../api_fastapi/core_api/routes/ui/grid.py | 11 ++++ .../airflow/ui/openapi-gen/queries/common.ts | 7 ++- .../ui/openapi-gen/queries/ensureQueryData.ts | 10 ++- .../ui/openapi-gen/queries/prefetch.ts | 10 ++- .../airflow/ui/openapi-gen/queries/queries.ts | 10 ++- .../ui/openapi-gen/queries/suspense.ts | 10 ++- .../ui/openapi-gen/requests/services.gen.ts | 6 ++ .../ui/openapi-gen/requests/types.gen.ts | 3 + .../ui/src/layouts/Details/Grid/Grid.tsx | 61 ++----------------- .../ui/src/queries/useGridStructure.ts | 9 +++ 11 files changed, 93 insertions(+), 66 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml index 711c7ba4e16e5..cf241db0d076b 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml @@ -629,6 +629,28 @@ paths: schema: type: string title: Dag Id + - name: include_upstream + in: query + required: false + schema: + type: boolean + default: false + title: Include Upstream + - name: include_downstream + in: query + required: false + schema: + type: boolean + default: false + title: Include Downstream + - name: root + in: query + required: false + schema: + anyOf: + - type: string + - type: 'null' + title: Root - name: offset in: query required: false diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py index 6fb6cf03b7a5c..34e2ecafd8fd7 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py @@ -31,6 +31,8 @@ QueryDagRunRunTypesFilter, QueryDagRunStateFilter, QueryDagRunTriggeringUserSearch, + QueryIncludeDownstream, + QueryIncludeUpstream, QueryLimit, QueryOffset, RangeFilter, @@ -133,11 +135,20 @@ def get_dag_structure( run_type: QueryDagRunRunTypesFilter, state: QueryDagRunStateFilter, triggering_user: QueryDagRunTriggeringUserSearch, + include_upstream: QueryIncludeUpstream = False, + include_downstream: QueryIncludeDownstream = False, + root: str | None = None, ) -> list[GridNodeResponse]: """Return dag structure for grid view.""" latest_serdag = _get_latest_serdag(dag_id, session) latest_dag = latest_serdag.dag + # Apply filtering if root task is specified + if root: + latest_dag = latest_dag.partial_subset( + task_ids=root, include_upstream=include_upstream, include_downstream=include_downstream + ) + # Retrieve, sort the previous DAG Runs base_query = select(DagRun.id).where(DagRun.dag_id == dag_id) # This comparison is to fall back to DAG timetable when no order_by is provided diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts index aa544d1108aaf..eecadc2ce242f 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts @@ -808,11 +808,14 @@ export const UseStructureServiceStructureDataKeyFn = ({ dagId, externalDependenc export type GridServiceGetDagStructureDefaultResponse = Awaited>; export type GridServiceGetDagStructureQueryResult = UseQueryResult; export const useGridServiceGetDagStructureKey = "GridServiceGetDagStructure"; -export const UseGridServiceGetDagStructureKeyFn = ({ dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }: { +export const UseGridServiceGetDagStructureKeyFn = ({ dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }: { dagId: string; + includeDownstream?: boolean; + includeUpstream?: boolean; limit?: number; offset?: number; orderBy?: string[]; + root?: string; runAfterGt?: string; runAfterGte?: string; runAfterLt?: string; @@ -820,7 +823,7 @@ export const UseGridServiceGetDagStructureKeyFn = ({ dagId, limit, offset, order runType?: string[]; state?: string[]; triggeringUser?: string; -}, queryKey?: Array) => [useGridServiceGetDagStructureKey, ...(queryKey ?? [{ dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }])]; +}, queryKey?: Array) => [useGridServiceGetDagStructureKey, ...(queryKey ?? [{ dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }])]; export type GridServiceGetGridRunsDefaultResponse = Awaited>; export type GridServiceGetGridRunsQueryResult = UseQueryResult; export const useGridServiceGetGridRunsKey = "GridServiceGetGridRuns"; diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts index 3e921c98eab42..6d051aaabc0ae 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts @@ -1532,6 +1532,9 @@ export const ensureUseStructureServiceStructureDataData = (queryClient: QueryCli * Return dag structure for grid view. * @param data The data for the request. * @param data.dagId +* @param data.includeUpstream +* @param data.includeDownstream +* @param data.root * @param data.offset * @param data.limit * @param data.orderBy Attributes to order by, multi criteria sort is supported. Prefix with `-` for descending order. Supported attributes: `run_after, logical_date, start_date, end_date` @@ -1545,11 +1548,14 @@ export const ensureUseStructureServiceStructureDataData = (queryClient: QueryCli * @returns GridNodeResponse Successful Response * @throws ApiError */ -export const ensureUseGridServiceGetDagStructureData = (queryClient: QueryClient, { dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }: { +export const ensureUseGridServiceGetDagStructureData = (queryClient: QueryClient, { dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }: { dagId: string; + includeDownstream?: boolean; + includeUpstream?: boolean; limit?: number; offset?: number; orderBy?: string[]; + root?: string; runAfterGt?: string; runAfterGte?: string; runAfterLt?: string; @@ -1557,7 +1563,7 @@ export const ensureUseGridServiceGetDagStructureData = (queryClient: QueryClient runType?: string[]; state?: string[]; triggeringUser?: string; -}) => queryClient.ensureQueryData({ queryKey: Common.UseGridServiceGetDagStructureKeyFn({ dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }), queryFn: () => GridService.getDagStructure({ dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }) }); +}) => queryClient.ensureQueryData({ queryKey: Common.UseGridServiceGetDagStructureKeyFn({ dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }), queryFn: () => GridService.getDagStructure({ dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }) }); /** * Get Grid Runs * Get info about a run for the grid. diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts index 4a2fa7fc2c2bc..1020640d3a912 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts @@ -1532,6 +1532,9 @@ export const prefetchUseStructureServiceStructureData = (queryClient: QueryClien * Return dag structure for grid view. * @param data The data for the request. * @param data.dagId +* @param data.includeUpstream +* @param data.includeDownstream +* @param data.root * @param data.offset * @param data.limit * @param data.orderBy Attributes to order by, multi criteria sort is supported. Prefix with `-` for descending order. Supported attributes: `run_after, logical_date, start_date, end_date` @@ -1545,11 +1548,14 @@ export const prefetchUseStructureServiceStructureData = (queryClient: QueryClien * @returns GridNodeResponse Successful Response * @throws ApiError */ -export const prefetchUseGridServiceGetDagStructure = (queryClient: QueryClient, { dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }: { +export const prefetchUseGridServiceGetDagStructure = (queryClient: QueryClient, { dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }: { dagId: string; + includeDownstream?: boolean; + includeUpstream?: boolean; limit?: number; offset?: number; orderBy?: string[]; + root?: string; runAfterGt?: string; runAfterGte?: string; runAfterLt?: string; @@ -1557,7 +1563,7 @@ export const prefetchUseGridServiceGetDagStructure = (queryClient: QueryClient, runType?: string[]; state?: string[]; triggeringUser?: string; -}) => queryClient.prefetchQuery({ queryKey: Common.UseGridServiceGetDagStructureKeyFn({ dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }), queryFn: () => GridService.getDagStructure({ dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }) }); +}) => queryClient.prefetchQuery({ queryKey: Common.UseGridServiceGetDagStructureKeyFn({ dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }), queryFn: () => GridService.getDagStructure({ dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }) }); /** * Get Grid Runs * Get info about a run for the grid. diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts index ccd2ff54d9947..98025319021d7 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts @@ -1532,6 +1532,9 @@ export const useStructureServiceStructureData = = unknown[]>({ dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }: { +export const useGridServiceGetDagStructure = = unknown[]>({ dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }: { dagId: string; + includeDownstream?: boolean; + includeUpstream?: boolean; limit?: number; offset?: number; orderBy?: string[]; + root?: string; runAfterGt?: string; runAfterGte?: string; runAfterLt?: string; @@ -1557,7 +1563,7 @@ export const useGridServiceGetDagStructure = , "queryKey" | "queryFn">) => useQuery({ queryKey: Common.UseGridServiceGetDagStructureKeyFn({ dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }, queryKey), queryFn: () => GridService.getDagStructure({ dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }) as TData, ...options }); +}, queryKey?: TQueryKey, options?: Omit, "queryKey" | "queryFn">) => useQuery({ queryKey: Common.UseGridServiceGetDagStructureKeyFn({ dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }, queryKey), queryFn: () => GridService.getDagStructure({ dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }) as TData, ...options }); /** * Get Grid Runs * Get info about a run for the grid. diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts index 51a339c4d7b08..0f2ca446f9644 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts @@ -1532,6 +1532,9 @@ export const useStructureServiceStructureDataSuspense = = unknown[]>({ dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }: { +export const useGridServiceGetDagStructureSuspense = = unknown[]>({ dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }: { dagId: string; + includeDownstream?: boolean; + includeUpstream?: boolean; limit?: number; offset?: number; orderBy?: string[]; + root?: string; runAfterGt?: string; runAfterGte?: string; runAfterLt?: string; @@ -1557,7 +1563,7 @@ export const useGridServiceGetDagStructureSuspense = , "queryKey" | "queryFn">) => useSuspenseQuery({ queryKey: Common.UseGridServiceGetDagStructureKeyFn({ dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }, queryKey), queryFn: () => GridService.getDagStructure({ dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }) as TData, ...options }); +}, queryKey?: TQueryKey, options?: Omit, "queryKey" | "queryFn">) => useSuspenseQuery({ queryKey: Common.UseGridServiceGetDagStructureKeyFn({ dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }, queryKey), queryFn: () => GridService.getDagStructure({ dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }) as TData, ...options }); /** * Get Grid Runs * Get info about a run for the grid. diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts index d016eaa00f4fd..a9e999c34275b 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts @@ -3883,6 +3883,9 @@ export class GridService { * Return dag structure for grid view. * @param data The data for the request. * @param data.dagId + * @param data.includeUpstream + * @param data.includeDownstream + * @param data.root * @param data.offset * @param data.limit * @param data.orderBy Attributes to order by, multi criteria sort is supported. Prefix with `-` for descending order. Supported attributes: `run_after, logical_date, start_date, end_date` @@ -3904,6 +3907,9 @@ export class GridService { dag_id: data.dagId }, query: { + include_upstream: data.includeUpstream, + include_downstream: data.includeDownstream, + root: data.root, offset: data.offset, limit: data.limit, order_by: data.orderBy, diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts index 024485ac0df6b..d1b82a1d43694 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts @@ -3398,12 +3398,15 @@ export type StructureDataResponse2 = StructureDataResponse; export type GetDagStructureData = { dagId: string; + includeDownstream?: boolean; + includeUpstream?: boolean; limit?: number; offset?: number; /** * Attributes to order by, multi criteria sort is supported. Prefix with `-` for descending order. Supported attributes: `run_after, logical_date, start_date, end_date` */ orderBy?: Array<(string)>; + root?: string | null; runAfterGt?: string | null; runAfterGte?: string | null; runAfterLt?: string | null; diff --git a/airflow-core/src/airflow/ui/src/layouts/Details/Grid/Grid.tsx b/airflow-core/src/airflow/ui/src/layouts/Details/Grid/Grid.tsx index 3690542e00b41..9c879723d2952 100644 --- a/airflow-core/src/airflow/ui/src/layouts/Details/Grid/Grid.tsx +++ b/airflow-core/src/airflow/ui/src/layouts/Details/Grid/Grid.tsx @@ -24,11 +24,9 @@ import { useTranslation } from "react-i18next"; import { FiChevronsRight } from "react-icons/fi"; import { Link, useParams, useSearchParams } from "react-router-dom"; -import { useStructureServiceStructureData } from "openapi/queries"; import type { DagRunState, DagRunType, GridRunsResponse } from "openapi/requests"; import { useOpenGroups } from "src/context/openGroups"; import { useNavigation } from "src/hooks/navigation"; -import useSelectedVersion from "src/hooks/useSelectedVersion"; import { useGridRuns } from "src/queries/useGridRuns.ts"; import { useGridStructure } from "src/queries/useGridStructure.ts"; import { isStatePending } from "src/utils"; @@ -79,53 +77,14 @@ export const Grid = ({ dagRunState, limit, runType, showGantt, triggeringUser }: const { data: dagStructure } = useGridStructure({ dagRunState, hasActiveRun: gridRuns?.some((dr) => isStatePending(dr.state)), + includeDownstream, + includeUpstream, limit, + root: filterRoot, runType, triggeringUser, }); - const selectedVersion = useSelectedVersion(); - - const hasActiveFilter = includeUpstream || includeDownstream; - - // fetch filtered structure when filter is active - const { data: taskStructure } = useStructureServiceStructureData( - { - dagId, - externalDependencies: false, - includeDownstream, - includeUpstream, - root: hasActiveFilter && filterRoot !== undefined ? filterRoot : undefined, - versionNumber: selectedVersion, - }, - undefined, - { - enabled: selectedVersion !== undefined && hasActiveFilter && filterRoot !== undefined, - }, - ); - - // extract allowed task IDs from task structure when filter is active - const allowedTaskIds = useMemo(() => { - if (!hasActiveFilter || filterRoot === undefined || taskStructure === undefined) { - return undefined; - } - - const taskIds = new Set(); - - const addNodeAndChildren = | null; id: string }>(currentNode: T) => { - taskIds.add(currentNode.id); - if (currentNode.children) { - currentNode.children.forEach((child) => addNodeAndChildren(child)); - } - }; - - taskStructure.nodes.forEach((node) => { - addNodeAndChildren(node); - }); - - return taskIds; - }, [hasActiveFilter, filterRoot, taskStructure]); - // calculate dag run bar heights relative to max const max = Math.max.apply( undefined, @@ -137,18 +96,8 @@ export const Grid = ({ dagRunState, limit, runType, showGantt, triggeringUser }: ); const { flatNodes } = useMemo(() => { - const nodes = flattenNodes(dagStructure, openGroupIds); - - // filter nodes based on task stream filter if active - if (allowedTaskIds !== undefined) { - return { - ...nodes, - flatNodes: nodes.flatNodes.filter((node) => allowedTaskIds.has(node.id)), - }; - } - - return nodes; - }, [dagStructure, openGroupIds, allowedTaskIds]); + return flattenNodes(dagStructure, openGroupIds); + }, [dagStructure, openGroupIds]); const { setMode } = useNavigation({ onToggleGroup: toggleGroupId, diff --git a/airflow-core/src/airflow/ui/src/queries/useGridStructure.ts b/airflow-core/src/airflow/ui/src/queries/useGridStructure.ts index ce334759e777f..a74483a6f1568 100644 --- a/airflow-core/src/airflow/ui/src/queries/useGridStructure.ts +++ b/airflow-core/src/airflow/ui/src/queries/useGridStructure.ts @@ -25,13 +25,19 @@ import { useAutoRefresh } from "src/utils"; export const useGridStructure = ({ dagRunState, hasActiveRun, + includeDownstream, + includeUpstream, limit, + root, runType, triggeringUser, }: { dagRunState?: DagRunState | undefined; hasActiveRun?: boolean; + includeDownstream?: boolean; + includeUpstream?: boolean; limit?: number; + root?: string; runType?: DagRunType | undefined; triggeringUser?: string | undefined; }) => { @@ -42,8 +48,11 @@ export const useGridStructure = ({ const { data: dagStructure, ...rest } = useGridServiceGetDagStructure( { dagId, + includeDownstream, + includeUpstream, limit, orderBy: ["-run_after"], + root, runType: runType ? [runType] : undefined, state: dagRunState ? [dagRunState] : undefined, triggeringUser: triggeringUser ?? undefined, From 07893b788c813d81aaf96b95c9d1a93247b97c96 Mon Sep 17 00:00:00 2001 From: OscarLigthart Date: Sat, 22 Nov 2025 11:35:10 +0100 Subject: [PATCH 2/6] fix: lint --- airflow-core/src/airflow/ui/src/layouts/Details/Grid/Grid.tsx | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/airflow-core/src/airflow/ui/src/layouts/Details/Grid/Grid.tsx b/airflow-core/src/airflow/ui/src/layouts/Details/Grid/Grid.tsx index 9c879723d2952..e566094be968b 100644 --- a/airflow-core/src/airflow/ui/src/layouts/Details/Grid/Grid.tsx +++ b/airflow-core/src/airflow/ui/src/layouts/Details/Grid/Grid.tsx @@ -95,9 +95,7 @@ export const Grid = ({ dagRunState, limit, runType, showGantt, triggeringUser }: .filter((duration: number | null): duration is number => duration !== null), ); - const { flatNodes } = useMemo(() => { - return flattenNodes(dagStructure, openGroupIds); - }, [dagStructure, openGroupIds]); + const { flatNodes } = useMemo(() => flattenNodes(dagStructure, openGroupIds), [dagStructure, openGroupIds]); const { setMode } = useNavigation({ onToggleGroup: toggleGroupId, From 091b6c8b518431f1fc45f17daa95259b7cbf94a8 Mon Sep 17 00:00:00 2001 From: OscarLigthart Date: Tue, 2 Dec 2025 11:06:33 +0100 Subject: [PATCH 3/6] feat: add tests for stream filter in grid --- .../core_api/routes/ui/test_grid.py | 133 ++++++++++++++++++ 1 file changed, 133 insertions(+) diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py index a3f8ba4e2ac0d..e770b245c9894 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py @@ -45,6 +45,8 @@ DAG_ID_2 = "test_dag_2" DAG_ID_3 = "test_dag_3" DAG_ID_4 = "test_dag_4" +DAG_ID_5 = "test_dag_5" +DAG_ID_6 = "test_dag_6" TASK_ID = "task" TASK_ID_2 = "task2" TASK_ID_3 = "task3" @@ -253,6 +255,60 @@ def mapped_task_group(arg1): ti.end_date = end_date start_date = end_date end_date = start_date.add(seconds=2) + + # DAG 5 for testing root, include_upstream, include_downstream parameters + with dag_maker(dag_id=DAG_ID_5, serialized=True, session=session) as dag_5: + task_a = EmptyOperator(task_id="task_a") + task_b = EmptyOperator(task_id="task_b") + task_c = EmptyOperator(task_id="task_c") + task_d = EmptyOperator(task_id="task_d") + task_e = EmptyOperator(task_id="task_e") + # Create linear dependency: task_a >> task_b >> task_c >> task_d >> task_e + task_a >> task_b >> task_c >> task_d >> task_e + + logical_date = timezone.datetime(2024, 11, 30) + data_interval = dag_5.timetable.infer_manual_data_interval(run_after=logical_date) + run_5 = dag_maker.create_dagrun( + run_id="run_5-1", + state=DagRunState.SUCCESS, + run_type=DagRunType.SCHEDULED, + start_date=logical_date, + logical_date=logical_date, + data_interval=data_interval, + **triggered_by_kwargs, + ) + for ti in run_5.task_instances: + ti.state = TaskInstanceState.SUCCESS + + # DAG 6 for testing root, include_upstream, include_downstream with non-linear dependencies + # Structure: start >> [branch_a, branch_b] >> merge >> end + # branch_a >> intermediate >> merge + with dag_maker(dag_id=DAG_ID_6, serialized=True, session=session) as dag_6: + start = EmptyOperator(task_id="start") + branch_a = EmptyOperator(task_id="branch_a") + branch_b = EmptyOperator(task_id="branch_b") + intermediate = EmptyOperator(task_id="intermediate") + merge = EmptyOperator(task_id="merge") + end = EmptyOperator(task_id="end") + # Create non-linear dependencies + start >> [branch_a, branch_b] + branch_a >> intermediate >> merge + branch_b >> merge + merge >> end + + logical_date = timezone.datetime(2024, 11, 30) + data_interval = dag_6.timetable.infer_manual_data_interval(run_after=logical_date) + run_6 = dag_maker.create_dagrun( + run_id="run_6-1", + state=DagRunState.SUCCESS, + run_type=DagRunType.SCHEDULED, + start_date=logical_date, + logical_date=logical_date, + data_interval=data_interval, + **triggered_by_kwargs, + ) + for ti in run_6.task_instances: + ti.state = TaskInstanceState.SUCCESS session.commit() @@ -776,3 +832,80 @@ def test_structure_includes_historical_removed_task_with_proper_shape(self, sess # Optional None fields are excluded from response due to response_model_exclude_none=True assert "is_mapped" not in t4 assert "children" not in t4 + + # Tests for root, include_upstream, and include_downstream parameters + def test_structure_with_root_only(self, test_client): + """Test that specifying only root parameter returns just that task.""" + response = test_client.get(f"/grid/structure/{DAG_ID_5}?root=task_c") + assert response.status_code == 200 + nodes = response.json() + task_ids = [node["id"] for node in nodes] + # Only task_c should be returned when root is specified without upstream/downstream + assert task_ids == ["task_c"] + + def test_structure_with_root_and_include_upstream(self, test_client): + """Test that root + include_upstream returns the root task and all upstream tasks.""" + response = test_client.get(f"/grid/structure/{DAG_ID_5}?root=task_c&include_upstream=true") + assert response.status_code == 200 + nodes = response.json() + task_ids = sorted([node["id"] for node in nodes]) + # Should return task_c and all upstream tasks (task_a, task_b) + assert task_ids == ["task_a", "task_b", "task_c"] + + def test_structure_with_root_and_include_downstream(self, test_client): + """Test that root + include_downstream returns the root task and all downstream tasks.""" + response = test_client.get(f"/grid/structure/{DAG_ID_5}?root=task_c&include_downstream=true") + assert response.status_code == 200 + nodes = response.json() + task_ids = sorted([node["id"] for node in nodes]) + # Should return task_c and all downstream tasks (task_d, task_e) + assert task_ids == ["task_c", "task_d", "task_e"] + + # Tests for non-linear DAG structure + def test_nonlinear_structure_with_root_downstream_from_branch(self, test_client): + """Test filtering downstream from branch point includes both branches.""" + response = test_client.get(f"/grid/structure/{DAG_ID_6}?root=start&include_downstream=true") + assert response.status_code == 200 + nodes = response.json() + task_ids = sorted([node["id"] for node in nodes]) + # Should return all tasks since start is the root of the DAG + assert task_ids == ["branch_a", "branch_b", "end", "intermediate", "merge", "start"] + + def test_nonlinear_structure_with_root_upstream_from_merge(self, test_client): + """Test filtering upstream from merge point includes all upstream branches.""" + response = test_client.get(f"/grid/structure/{DAG_ID_6}?root=merge&include_upstream=true") + assert response.status_code == 200 + nodes = response.json() + task_ids = sorted([node["id"] for node in nodes]) + # Should return merge and all upstream tasks (both branches and their parents) + assert task_ids == ["branch_a", "branch_b", "intermediate", "merge", "start"] + + def test_nonlinear_structure_with_root_on_branch_include_downstream(self, test_client): + """Test filtering downstream from one branch.""" + response = test_client.get(f"/grid/structure/{DAG_ID_6}?root=branch_a&include_downstream=true") + assert response.status_code == 200 + nodes = response.json() + task_ids = sorted([node["id"] for node in nodes]) + # Should return branch_a and its downstream path (intermediate, merge, end) + assert task_ids == ["branch_a", "end", "intermediate", "merge"] + + def test_nonlinear_structure_with_root_on_branch_include_upstream(self, test_client): + """Test filtering upstream from one branch.""" + response = test_client.get(f"/grid/structure/{DAG_ID_6}?root=branch_a&include_upstream=true") + assert response.status_code == 200 + nodes = response.json() + task_ids = sorted([node["id"] for node in nodes]) + # Should return branch_a and its upstream (start) + assert task_ids == ["branch_a", "start"] + + def test_nonlinear_structure_intermediate_with_both_directions(self, test_client): + """Test filtering from intermediate node with both upstream and downstream.""" + response = test_client.get( + f"/grid/structure/{DAG_ID_6}?root=intermediate&include_upstream=true&include_downstream=true" + ) + assert response.status_code == 200 + nodes = response.json() + task_ids = sorted([node["id"] for node in nodes]) + # Should return intermediate, its upstream path, and downstream path + # upstream: branch_a, start; downstream: merge, end + assert task_ids == ["branch_a", "end", "intermediate", "merge", "start"] From 22d73f3e2a527d8c6136d047b989803b292dea1f Mon Sep 17 00:00:00 2001 From: OscarLigthart Date: Wed, 3 Dec 2025 11:41:01 +0100 Subject: [PATCH 4/6] fix: historical task filtering --- .../api_fastapi/core_api/routes/ui/grid.py | 43 ++++++------------- .../api_fastapi/core_api/services/ui/grid.py | 42 ++++++++++++++++++ 2 files changed, 56 insertions(+), 29 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py index 34e2ecafd8fd7..67b0472ba1394 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py @@ -53,6 +53,7 @@ _find_aggregates, _get_aggs_for_node, _merge_node_dicts, + collect_historical_tasks, ) from airflow.api_fastapi.core_api.services.ui.task_group import ( get_task_group_children_getter, @@ -192,40 +193,24 @@ def get_dag_structure( dags = [latest_dag] for serdag in serdags: if serdag: - dags.append(serdag.dag) + filtered_dag = serdag.dag + # Apply the same filtering to historical DAG versions + if root: + filtered_dag = filtered_dag.partial_subset( + task_ids=root, include_upstream=include_upstream, include_downstream=include_downstream + ) + dags.append(filtered_dag) for dag in dags: nodes = [task_group_to_dict_grid(x) for x in task_group_sort(dag.task_group)] _merge_node_dicts(merged_nodes, nodes) # Ensure historical tasks (e.g. removed) that exist in TIs for the selected runs are represented - def _collect_ids(nodes: list[dict[str, Any]]) -> set[str]: - ids: set[str] = set() - for n in nodes: - nid = n.get("id") - if nid: - ids.add(nid) - children = n.get("children") - if children: - ids |= _collect_ids(children) # recurse - return ids - - existing_ids = _collect_ids(merged_nodes) - historical_tasks = session.execute( - select(TaskInstance.task_id, TaskInstance.task_display_name) - .join(TaskInstance.dag_run) - .where(TaskInstance.dag_id == dag_id, DagRun.id.in_(run_ids)) - .distinct() - ) - for task_id, task_display_name in historical_tasks: - if task_id not in existing_ids: - merged_nodes.append( - { - "id": task_id, - "label": task_display_name, - "is_mapped": None, - "children": None, - } - ) + # Only add in case no filter is applied + if not root: + historical_nodes = collect_historical_tasks( + nodes=merged_nodes, dag_id=dag_id, run_ids=run_ids, session=session + ) + _merge_node_dicts(merged_nodes, historical_nodes) return [GridNodeResponse(**n) for n in merged_nodes] diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py index 5b20511d31f47..7ff37d9d3dc95 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py @@ -19,12 +19,17 @@ from collections import Counter from collections.abc import Iterable +from typing import Any import structlog +from sqlalchemy import select +from sqlalchemy.orm import Session from airflow.api_fastapi.common.parameters import state_priority from airflow.api_fastapi.core_api.services.ui.task_group import get_task_group_children_getter +from airflow.models.dagrun import DagRun from airflow.models.mappedoperator import MappedOperator +from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap from airflow.serialization.definitions.taskgroup import SerializedTaskGroup from airflow.serialization.serialized_objects import SerializedBaseOperator @@ -51,6 +56,43 @@ def _get_node_by_id(nodes, node_id): return {} +def collect_historical_tasks( + nodes: list[dict[str, Any]], dag_id: str, run_ids: list[Any], session: Session +) -> list[dict[str, Any]]: + historical_nodes = [] + existing_ids = _collect_ids(nodes) + historical_tasks = session.execute( + select(TaskInstance.task_id, TaskInstance.task_display_name) + .join(TaskInstance.dag_run) + .where(TaskInstance.dag_id == dag_id, DagRun.id.in_(run_ids)) + .distinct() + ) + for task_id, task_display_name in historical_tasks: + if task_id not in existing_ids: + historical_nodes.append( + { + "id": task_id, + "label": task_display_name, + "is_mapped": None, + "children": None, + } + ) + + return nodes + + +def _collect_ids(nodes: list[dict[str, Any]]) -> set[str]: + ids: set[str] = set() + for n in nodes: + nid = n.get("id") + if nid: + ids.add(nid) + children = n.get("children") + if children: + ids |= _collect_ids(children) # recurse + return ids + + def agg_state(states): states = Counter(states) for state in state_priority: From 65af1e6c29e001603941297b2ff48dfd0c6676f8 Mon Sep 17 00:00:00 2001 From: OscarLigthart Date: Thu, 4 Dec 2025 10:34:15 +0100 Subject: [PATCH 5/6] fix: return value for collect_historical_tasks --- .../src/airflow/api_fastapi/core_api/services/ui/grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py index 7ff37d9d3dc95..f99f5fa134163 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py @@ -78,7 +78,7 @@ def collect_historical_tasks( } ) - return nodes + return historical_nodes def _collect_ids(nodes: list[dict[str, Any]]) -> set[str]: From 2752bfbeaaecfb583de5e48b203a0bcf219d8886 Mon Sep 17 00:00:00 2001 From: OscarLigthart Date: Sat, 13 Dec 2025 19:55:04 +0100 Subject: [PATCH 6/6] fix: remove duplicate historical task retrieval and add historical task to test case --- .../api_fastapi/core_api/routes/ui/grid.py | 9 ---- .../api_fastapi/core_api/services/ui/grid.py | 42 ------------------- .../core_api/routes/ui/test_grid.py | 42 +++++++++++++++---- 3 files changed, 33 insertions(+), 60 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py index 67b0472ba1394..9c8a4e1782066 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py @@ -53,7 +53,6 @@ _find_aggregates, _get_aggs_for_node, _merge_node_dicts, - collect_historical_tasks, ) from airflow.api_fastapi.core_api.services.ui.task_group import ( get_task_group_children_getter, @@ -204,14 +203,6 @@ def get_dag_structure( nodes = [task_group_to_dict_grid(x) for x in task_group_sort(dag.task_group)] _merge_node_dicts(merged_nodes, nodes) - # Ensure historical tasks (e.g. removed) that exist in TIs for the selected runs are represented - # Only add in case no filter is applied - if not root: - historical_nodes = collect_historical_tasks( - nodes=merged_nodes, dag_id=dag_id, run_ids=run_ids, session=session - ) - _merge_node_dicts(merged_nodes, historical_nodes) - return [GridNodeResponse(**n) for n in merged_nodes] diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py index f99f5fa134163..5b20511d31f47 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py @@ -19,17 +19,12 @@ from collections import Counter from collections.abc import Iterable -from typing import Any import structlog -from sqlalchemy import select -from sqlalchemy.orm import Session from airflow.api_fastapi.common.parameters import state_priority from airflow.api_fastapi.core_api.services.ui.task_group import get_task_group_children_getter -from airflow.models.dagrun import DagRun from airflow.models.mappedoperator import MappedOperator -from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap from airflow.serialization.definitions.taskgroup import SerializedTaskGroup from airflow.serialization.serialized_objects import SerializedBaseOperator @@ -56,43 +51,6 @@ def _get_node_by_id(nodes, node_id): return {} -def collect_historical_tasks( - nodes: list[dict[str, Any]], dag_id: str, run_ids: list[Any], session: Session -) -> list[dict[str, Any]]: - historical_nodes = [] - existing_ids = _collect_ids(nodes) - historical_tasks = session.execute( - select(TaskInstance.task_id, TaskInstance.task_display_name) - .join(TaskInstance.dag_run) - .where(TaskInstance.dag_id == dag_id, DagRun.id.in_(run_ids)) - .distinct() - ) - for task_id, task_display_name in historical_tasks: - if task_id not in existing_ids: - historical_nodes.append( - { - "id": task_id, - "label": task_display_name, - "is_mapped": None, - "children": None, - } - ) - - return historical_nodes - - -def _collect_ids(nodes: list[dict[str, Any]]) -> set[str]: - ids: set[str] = set() - for n in nodes: - nid = n.get("id") - if nid: - ids.add(nid) - children = n.get("children") - if children: - ids |= _collect_ids(children) # recurse - return ids - - def agg_state(states): states = Counter(states) for state in state_priority: diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py index e770b245c9894..7fbb81d1fa214 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py @@ -257,19 +257,20 @@ def mapped_task_group(arg1): end_date = start_date.add(seconds=2) # DAG 5 for testing root, include_upstream, include_downstream parameters + # Also includes a Historical task with dag_maker(dag_id=DAG_ID_5, serialized=True, session=session) as dag_5: task_a = EmptyOperator(task_id="task_a") task_b = EmptyOperator(task_id="task_b") task_c = EmptyOperator(task_id="task_c") task_d = EmptyOperator(task_id="task_d") - task_e = EmptyOperator(task_id="task_e") - # Create linear dependency: task_a >> task_b >> task_c >> task_d >> task_e - task_a >> task_b >> task_c >> task_d >> task_e + task_f = EmptyOperator(task_id="task_f") + task_a >> task_b >> task_c >> task_d >> task_f + # Create linear dependency: task_a >> task_b >> task_c >> task_d >> task_f (HISTORICAL_TASK) logical_date = timezone.datetime(2024, 11, 30) data_interval = dag_5.timetable.infer_manual_data_interval(run_after=logical_date) - run_5 = dag_maker.create_dagrun( - run_id="run_5-1", + run_5_1 = dag_maker.create_dagrun( + run_id="run_5_1", state=DagRunState.SUCCESS, run_type=DagRunType.SCHEDULED, start_date=logical_date, @@ -277,8 +278,31 @@ def mapped_task_group(arg1): data_interval=data_interval, **triggered_by_kwargs, ) - for ti in run_5.task_instances: + + with dag_maker(dag_id=DAG_ID_5, serialized=True, session=session) as dag_5: + task_a = EmptyOperator(task_id="task_a") + task_b = EmptyOperator(task_id="task_b") + task_c = EmptyOperator(task_id="task_c") + task_d = EmptyOperator(task_id="task_d") + task_e = EmptyOperator(task_id="task_e") + task_a >> task_b >> task_c >> task_d >> task_e + # Create linear dependency: task_a >> task_b >> task_c >> task_d >> task_e + + run_5_2 = dag_maker.create_dagrun( + run_id="run_5_2", + state=DagRunState.SUCCESS, + run_type=DagRunType.SCHEDULED, + start_date=logical_date, + logical_date=logical_date + timedelta(days=1), + data_interval=data_interval, + **triggered_by_kwargs, + ) + for ti in run_5_1.task_instances: ti.state = TaskInstanceState.SUCCESS + ti.end_date = None + for ti in run_5_2.task_instances: + ti.state = TaskInstanceState.SUCCESS + ti.end_date = None # DAG 6 for testing root, include_upstream, include_downstream with non-linear dependencies # Structure: start >> [branch_a, branch_b] >> merge >> end @@ -852,14 +876,14 @@ def test_structure_with_root_and_include_upstream(self, test_client): # Should return task_c and all upstream tasks (task_a, task_b) assert task_ids == ["task_a", "task_b", "task_c"] - def test_structure_with_root_and_include_downstream(self, test_client): + def test_structure_with_root_and_include_downstream_and_historical(self, test_client): """Test that root + include_downstream returns the root task and all downstream tasks.""" response = test_client.get(f"/grid/structure/{DAG_ID_5}?root=task_c&include_downstream=true") assert response.status_code == 200 nodes = response.json() task_ids = sorted([node["id"] for node in nodes]) - # Should return task_c and all downstream tasks (task_d, task_e) - assert task_ids == ["task_c", "task_d", "task_e"] + # Should return task_c and all downstream tasks, including the historical (task_d, task_e, task_f (HISTORICAL)) + assert task_ids == ["task_c", "task_d", "task_e", "task_f"] # Tests for non-linear DAG structure def test_nonlinear_structure_with_root_downstream_from_branch(self, test_client):