diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml index 711c7ba4e16e5..cf241db0d076b 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml @@ -629,6 +629,28 @@ paths: schema: type: string title: Dag Id + - name: include_upstream + in: query + required: false + schema: + type: boolean + default: false + title: Include Upstream + - name: include_downstream + in: query + required: false + schema: + type: boolean + default: false + title: Include Downstream + - name: root + in: query + required: false + schema: + anyOf: + - type: string + - type: 'null' + title: Root - name: offset in: query required: false diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py index 6fb6cf03b7a5c..9c8a4e1782066 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py @@ -31,6 +31,8 @@ QueryDagRunRunTypesFilter, QueryDagRunStateFilter, QueryDagRunTriggeringUserSearch, + QueryIncludeDownstream, + QueryIncludeUpstream, QueryLimit, QueryOffset, RangeFilter, @@ -133,11 +135,20 @@ def get_dag_structure( run_type: QueryDagRunRunTypesFilter, state: QueryDagRunStateFilter, triggering_user: QueryDagRunTriggeringUserSearch, + include_upstream: QueryIncludeUpstream = False, + include_downstream: QueryIncludeDownstream = False, + root: str | None = None, ) -> list[GridNodeResponse]: """Return dag structure for grid view.""" latest_serdag = _get_latest_serdag(dag_id, session) latest_dag = latest_serdag.dag + # Apply filtering if root task is specified + if root: + latest_dag = latest_dag.partial_subset( + task_ids=root, include_upstream=include_upstream, include_downstream=include_downstream + ) + # Retrieve, sort the previous DAG Runs base_query = select(DagRun.id).where(DagRun.dag_id == dag_id) # This comparison is to fall back to DAG timetable when no order_by is provided @@ -181,41 +192,17 @@ def get_dag_structure( dags = [latest_dag] for serdag in serdags: if serdag: - dags.append(serdag.dag) + filtered_dag = serdag.dag + # Apply the same filtering to historical DAG versions + if root: + filtered_dag = filtered_dag.partial_subset( + task_ids=root, include_upstream=include_upstream, include_downstream=include_downstream + ) + dags.append(filtered_dag) for dag in dags: nodes = [task_group_to_dict_grid(x) for x in task_group_sort(dag.task_group)] _merge_node_dicts(merged_nodes, nodes) - # Ensure historical tasks (e.g. removed) that exist in TIs for the selected runs are represented - def _collect_ids(nodes: list[dict[str, Any]]) -> set[str]: - ids: set[str] = set() - for n in nodes: - nid = n.get("id") - if nid: - ids.add(nid) - children = n.get("children") - if children: - ids |= _collect_ids(children) # recurse - return ids - - existing_ids = _collect_ids(merged_nodes) - historical_tasks = session.execute( - select(TaskInstance.task_id, TaskInstance.task_display_name) - .join(TaskInstance.dag_run) - .where(TaskInstance.dag_id == dag_id, DagRun.id.in_(run_ids)) - .distinct() - ) - for task_id, task_display_name in historical_tasks: - if task_id not in existing_ids: - merged_nodes.append( - { - "id": task_id, - "label": task_display_name, - "is_mapped": None, - "children": None, - } - ) - return [GridNodeResponse(**n) for n in merged_nodes] diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts index aa544d1108aaf..eecadc2ce242f 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts @@ -808,11 +808,14 @@ export const UseStructureServiceStructureDataKeyFn = ({ dagId, externalDependenc export type GridServiceGetDagStructureDefaultResponse = Awaited>; export type GridServiceGetDagStructureQueryResult = UseQueryResult; export const useGridServiceGetDagStructureKey = "GridServiceGetDagStructure"; -export const UseGridServiceGetDagStructureKeyFn = ({ dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }: { +export const UseGridServiceGetDagStructureKeyFn = ({ dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }: { dagId: string; + includeDownstream?: boolean; + includeUpstream?: boolean; limit?: number; offset?: number; orderBy?: string[]; + root?: string; runAfterGt?: string; runAfterGte?: string; runAfterLt?: string; @@ -820,7 +823,7 @@ export const UseGridServiceGetDagStructureKeyFn = ({ dagId, limit, offset, order runType?: string[]; state?: string[]; triggeringUser?: string; -}, queryKey?: Array) => [useGridServiceGetDagStructureKey, ...(queryKey ?? [{ dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }])]; +}, queryKey?: Array) => [useGridServiceGetDagStructureKey, ...(queryKey ?? [{ dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }])]; export type GridServiceGetGridRunsDefaultResponse = Awaited>; export type GridServiceGetGridRunsQueryResult = UseQueryResult; export const useGridServiceGetGridRunsKey = "GridServiceGetGridRuns"; diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts index 3e921c98eab42..6d051aaabc0ae 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts @@ -1532,6 +1532,9 @@ export const ensureUseStructureServiceStructureDataData = (queryClient: QueryCli * Return dag structure for grid view. * @param data The data for the request. * @param data.dagId +* @param data.includeUpstream +* @param data.includeDownstream +* @param data.root * @param data.offset * @param data.limit * @param data.orderBy Attributes to order by, multi criteria sort is supported. Prefix with `-` for descending order. Supported attributes: `run_after, logical_date, start_date, end_date` @@ -1545,11 +1548,14 @@ export const ensureUseStructureServiceStructureDataData = (queryClient: QueryCli * @returns GridNodeResponse Successful Response * @throws ApiError */ -export const ensureUseGridServiceGetDagStructureData = (queryClient: QueryClient, { dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }: { +export const ensureUseGridServiceGetDagStructureData = (queryClient: QueryClient, { dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }: { dagId: string; + includeDownstream?: boolean; + includeUpstream?: boolean; limit?: number; offset?: number; orderBy?: string[]; + root?: string; runAfterGt?: string; runAfterGte?: string; runAfterLt?: string; @@ -1557,7 +1563,7 @@ export const ensureUseGridServiceGetDagStructureData = (queryClient: QueryClient runType?: string[]; state?: string[]; triggeringUser?: string; -}) => queryClient.ensureQueryData({ queryKey: Common.UseGridServiceGetDagStructureKeyFn({ dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }), queryFn: () => GridService.getDagStructure({ dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }) }); +}) => queryClient.ensureQueryData({ queryKey: Common.UseGridServiceGetDagStructureKeyFn({ dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }), queryFn: () => GridService.getDagStructure({ dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }) }); /** * Get Grid Runs * Get info about a run for the grid. diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts index 4a2fa7fc2c2bc..1020640d3a912 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts @@ -1532,6 +1532,9 @@ export const prefetchUseStructureServiceStructureData = (queryClient: QueryClien * Return dag structure for grid view. * @param data The data for the request. * @param data.dagId +* @param data.includeUpstream +* @param data.includeDownstream +* @param data.root * @param data.offset * @param data.limit * @param data.orderBy Attributes to order by, multi criteria sort is supported. Prefix with `-` for descending order. Supported attributes: `run_after, logical_date, start_date, end_date` @@ -1545,11 +1548,14 @@ export const prefetchUseStructureServiceStructureData = (queryClient: QueryClien * @returns GridNodeResponse Successful Response * @throws ApiError */ -export const prefetchUseGridServiceGetDagStructure = (queryClient: QueryClient, { dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }: { +export const prefetchUseGridServiceGetDagStructure = (queryClient: QueryClient, { dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }: { dagId: string; + includeDownstream?: boolean; + includeUpstream?: boolean; limit?: number; offset?: number; orderBy?: string[]; + root?: string; runAfterGt?: string; runAfterGte?: string; runAfterLt?: string; @@ -1557,7 +1563,7 @@ export const prefetchUseGridServiceGetDagStructure = (queryClient: QueryClient, runType?: string[]; state?: string[]; triggeringUser?: string; -}) => queryClient.prefetchQuery({ queryKey: Common.UseGridServiceGetDagStructureKeyFn({ dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }), queryFn: () => GridService.getDagStructure({ dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }) }); +}) => queryClient.prefetchQuery({ queryKey: Common.UseGridServiceGetDagStructureKeyFn({ dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }), queryFn: () => GridService.getDagStructure({ dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }) }); /** * Get Grid Runs * Get info about a run for the grid. diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts index ccd2ff54d9947..98025319021d7 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts @@ -1532,6 +1532,9 @@ export const useStructureServiceStructureData = = unknown[]>({ dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }: { +export const useGridServiceGetDagStructure = = unknown[]>({ dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }: { dagId: string; + includeDownstream?: boolean; + includeUpstream?: boolean; limit?: number; offset?: number; orderBy?: string[]; + root?: string; runAfterGt?: string; runAfterGte?: string; runAfterLt?: string; @@ -1557,7 +1563,7 @@ export const useGridServiceGetDagStructure = , "queryKey" | "queryFn">) => useQuery({ queryKey: Common.UseGridServiceGetDagStructureKeyFn({ dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }, queryKey), queryFn: () => GridService.getDagStructure({ dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }) as TData, ...options }); +}, queryKey?: TQueryKey, options?: Omit, "queryKey" | "queryFn">) => useQuery({ queryKey: Common.UseGridServiceGetDagStructureKeyFn({ dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }, queryKey), queryFn: () => GridService.getDagStructure({ dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }) as TData, ...options }); /** * Get Grid Runs * Get info about a run for the grid. diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts index 51a339c4d7b08..0f2ca446f9644 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts @@ -1532,6 +1532,9 @@ export const useStructureServiceStructureDataSuspense = = unknown[]>({ dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }: { +export const useGridServiceGetDagStructureSuspense = = unknown[]>({ dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }: { dagId: string; + includeDownstream?: boolean; + includeUpstream?: boolean; limit?: number; offset?: number; orderBy?: string[]; + root?: string; runAfterGt?: string; runAfterGte?: string; runAfterLt?: string; @@ -1557,7 +1563,7 @@ export const useGridServiceGetDagStructureSuspense = , "queryKey" | "queryFn">) => useSuspenseQuery({ queryKey: Common.UseGridServiceGetDagStructureKeyFn({ dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }, queryKey), queryFn: () => GridService.getDagStructure({ dagId, limit, offset, orderBy, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }) as TData, ...options }); +}, queryKey?: TQueryKey, options?: Omit, "queryKey" | "queryFn">) => useSuspenseQuery({ queryKey: Common.UseGridServiceGetDagStructureKeyFn({ dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }, queryKey), queryFn: () => GridService.getDagStructure({ dagId, includeDownstream, includeUpstream, limit, offset, orderBy, root, runAfterGt, runAfterGte, runAfterLt, runAfterLte, runType, state, triggeringUser }) as TData, ...options }); /** * Get Grid Runs * Get info about a run for the grid. diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts index d016eaa00f4fd..a9e999c34275b 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts @@ -3883,6 +3883,9 @@ export class GridService { * Return dag structure for grid view. * @param data The data for the request. * @param data.dagId + * @param data.includeUpstream + * @param data.includeDownstream + * @param data.root * @param data.offset * @param data.limit * @param data.orderBy Attributes to order by, multi criteria sort is supported. Prefix with `-` for descending order. Supported attributes: `run_after, logical_date, start_date, end_date` @@ -3904,6 +3907,9 @@ export class GridService { dag_id: data.dagId }, query: { + include_upstream: data.includeUpstream, + include_downstream: data.includeDownstream, + root: data.root, offset: data.offset, limit: data.limit, order_by: data.orderBy, diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts index 024485ac0df6b..d1b82a1d43694 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts @@ -3398,12 +3398,15 @@ export type StructureDataResponse2 = StructureDataResponse; export type GetDagStructureData = { dagId: string; + includeDownstream?: boolean; + includeUpstream?: boolean; limit?: number; offset?: number; /** * Attributes to order by, multi criteria sort is supported. Prefix with `-` for descending order. Supported attributes: `run_after, logical_date, start_date, end_date` */ orderBy?: Array<(string)>; + root?: string | null; runAfterGt?: string | null; runAfterGte?: string | null; runAfterLt?: string | null; diff --git a/airflow-core/src/airflow/ui/src/layouts/Details/Grid/Grid.tsx b/airflow-core/src/airflow/ui/src/layouts/Details/Grid/Grid.tsx index 3690542e00b41..e566094be968b 100644 --- a/airflow-core/src/airflow/ui/src/layouts/Details/Grid/Grid.tsx +++ b/airflow-core/src/airflow/ui/src/layouts/Details/Grid/Grid.tsx @@ -24,11 +24,9 @@ import { useTranslation } from "react-i18next"; import { FiChevronsRight } from "react-icons/fi"; import { Link, useParams, useSearchParams } from "react-router-dom"; -import { useStructureServiceStructureData } from "openapi/queries"; import type { DagRunState, DagRunType, GridRunsResponse } from "openapi/requests"; import { useOpenGroups } from "src/context/openGroups"; import { useNavigation } from "src/hooks/navigation"; -import useSelectedVersion from "src/hooks/useSelectedVersion"; import { useGridRuns } from "src/queries/useGridRuns.ts"; import { useGridStructure } from "src/queries/useGridStructure.ts"; import { isStatePending } from "src/utils"; @@ -79,53 +77,14 @@ export const Grid = ({ dagRunState, limit, runType, showGantt, triggeringUser }: const { data: dagStructure } = useGridStructure({ dagRunState, hasActiveRun: gridRuns?.some((dr) => isStatePending(dr.state)), + includeDownstream, + includeUpstream, limit, + root: filterRoot, runType, triggeringUser, }); - const selectedVersion = useSelectedVersion(); - - const hasActiveFilter = includeUpstream || includeDownstream; - - // fetch filtered structure when filter is active - const { data: taskStructure } = useStructureServiceStructureData( - { - dagId, - externalDependencies: false, - includeDownstream, - includeUpstream, - root: hasActiveFilter && filterRoot !== undefined ? filterRoot : undefined, - versionNumber: selectedVersion, - }, - undefined, - { - enabled: selectedVersion !== undefined && hasActiveFilter && filterRoot !== undefined, - }, - ); - - // extract allowed task IDs from task structure when filter is active - const allowedTaskIds = useMemo(() => { - if (!hasActiveFilter || filterRoot === undefined || taskStructure === undefined) { - return undefined; - } - - const taskIds = new Set(); - - const addNodeAndChildren = | null; id: string }>(currentNode: T) => { - taskIds.add(currentNode.id); - if (currentNode.children) { - currentNode.children.forEach((child) => addNodeAndChildren(child)); - } - }; - - taskStructure.nodes.forEach((node) => { - addNodeAndChildren(node); - }); - - return taskIds; - }, [hasActiveFilter, filterRoot, taskStructure]); - // calculate dag run bar heights relative to max const max = Math.max.apply( undefined, @@ -136,19 +95,7 @@ export const Grid = ({ dagRunState, limit, runType, showGantt, triggeringUser }: .filter((duration: number | null): duration is number => duration !== null), ); - const { flatNodes } = useMemo(() => { - const nodes = flattenNodes(dagStructure, openGroupIds); - - // filter nodes based on task stream filter if active - if (allowedTaskIds !== undefined) { - return { - ...nodes, - flatNodes: nodes.flatNodes.filter((node) => allowedTaskIds.has(node.id)), - }; - } - - return nodes; - }, [dagStructure, openGroupIds, allowedTaskIds]); + const { flatNodes } = useMemo(() => flattenNodes(dagStructure, openGroupIds), [dagStructure, openGroupIds]); const { setMode } = useNavigation({ onToggleGroup: toggleGroupId, diff --git a/airflow-core/src/airflow/ui/src/queries/useGridStructure.ts b/airflow-core/src/airflow/ui/src/queries/useGridStructure.ts index ce334759e777f..a74483a6f1568 100644 --- a/airflow-core/src/airflow/ui/src/queries/useGridStructure.ts +++ b/airflow-core/src/airflow/ui/src/queries/useGridStructure.ts @@ -25,13 +25,19 @@ import { useAutoRefresh } from "src/utils"; export const useGridStructure = ({ dagRunState, hasActiveRun, + includeDownstream, + includeUpstream, limit, + root, runType, triggeringUser, }: { dagRunState?: DagRunState | undefined; hasActiveRun?: boolean; + includeDownstream?: boolean; + includeUpstream?: boolean; limit?: number; + root?: string; runType?: DagRunType | undefined; triggeringUser?: string | undefined; }) => { @@ -42,8 +48,11 @@ export const useGridStructure = ({ const { data: dagStructure, ...rest } = useGridServiceGetDagStructure( { dagId, + includeDownstream, + includeUpstream, limit, orderBy: ["-run_after"], + root, runType: runType ? [runType] : undefined, state: dagRunState ? [dagRunState] : undefined, triggeringUser: triggeringUser ?? undefined, diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py index a3f8ba4e2ac0d..7fbb81d1fa214 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py @@ -45,6 +45,8 @@ DAG_ID_2 = "test_dag_2" DAG_ID_3 = "test_dag_3" DAG_ID_4 = "test_dag_4" +DAG_ID_5 = "test_dag_5" +DAG_ID_6 = "test_dag_6" TASK_ID = "task" TASK_ID_2 = "task2" TASK_ID_3 = "task3" @@ -253,6 +255,84 @@ def mapped_task_group(arg1): ti.end_date = end_date start_date = end_date end_date = start_date.add(seconds=2) + + # DAG 5 for testing root, include_upstream, include_downstream parameters + # Also includes a Historical task + with dag_maker(dag_id=DAG_ID_5, serialized=True, session=session) as dag_5: + task_a = EmptyOperator(task_id="task_a") + task_b = EmptyOperator(task_id="task_b") + task_c = EmptyOperator(task_id="task_c") + task_d = EmptyOperator(task_id="task_d") + task_f = EmptyOperator(task_id="task_f") + task_a >> task_b >> task_c >> task_d >> task_f + # Create linear dependency: task_a >> task_b >> task_c >> task_d >> task_f (HISTORICAL_TASK) + + logical_date = timezone.datetime(2024, 11, 30) + data_interval = dag_5.timetable.infer_manual_data_interval(run_after=logical_date) + run_5_1 = dag_maker.create_dagrun( + run_id="run_5_1", + state=DagRunState.SUCCESS, + run_type=DagRunType.SCHEDULED, + start_date=logical_date, + logical_date=logical_date, + data_interval=data_interval, + **triggered_by_kwargs, + ) + + with dag_maker(dag_id=DAG_ID_5, serialized=True, session=session) as dag_5: + task_a = EmptyOperator(task_id="task_a") + task_b = EmptyOperator(task_id="task_b") + task_c = EmptyOperator(task_id="task_c") + task_d = EmptyOperator(task_id="task_d") + task_e = EmptyOperator(task_id="task_e") + task_a >> task_b >> task_c >> task_d >> task_e + # Create linear dependency: task_a >> task_b >> task_c >> task_d >> task_e + + run_5_2 = dag_maker.create_dagrun( + run_id="run_5_2", + state=DagRunState.SUCCESS, + run_type=DagRunType.SCHEDULED, + start_date=logical_date, + logical_date=logical_date + timedelta(days=1), + data_interval=data_interval, + **triggered_by_kwargs, + ) + for ti in run_5_1.task_instances: + ti.state = TaskInstanceState.SUCCESS + ti.end_date = None + for ti in run_5_2.task_instances: + ti.state = TaskInstanceState.SUCCESS + ti.end_date = None + + # DAG 6 for testing root, include_upstream, include_downstream with non-linear dependencies + # Structure: start >> [branch_a, branch_b] >> merge >> end + # branch_a >> intermediate >> merge + with dag_maker(dag_id=DAG_ID_6, serialized=True, session=session) as dag_6: + start = EmptyOperator(task_id="start") + branch_a = EmptyOperator(task_id="branch_a") + branch_b = EmptyOperator(task_id="branch_b") + intermediate = EmptyOperator(task_id="intermediate") + merge = EmptyOperator(task_id="merge") + end = EmptyOperator(task_id="end") + # Create non-linear dependencies + start >> [branch_a, branch_b] + branch_a >> intermediate >> merge + branch_b >> merge + merge >> end + + logical_date = timezone.datetime(2024, 11, 30) + data_interval = dag_6.timetable.infer_manual_data_interval(run_after=logical_date) + run_6 = dag_maker.create_dagrun( + run_id="run_6-1", + state=DagRunState.SUCCESS, + run_type=DagRunType.SCHEDULED, + start_date=logical_date, + logical_date=logical_date, + data_interval=data_interval, + **triggered_by_kwargs, + ) + for ti in run_6.task_instances: + ti.state = TaskInstanceState.SUCCESS session.commit() @@ -776,3 +856,80 @@ def test_structure_includes_historical_removed_task_with_proper_shape(self, sess # Optional None fields are excluded from response due to response_model_exclude_none=True assert "is_mapped" not in t4 assert "children" not in t4 + + # Tests for root, include_upstream, and include_downstream parameters + def test_structure_with_root_only(self, test_client): + """Test that specifying only root parameter returns just that task.""" + response = test_client.get(f"/grid/structure/{DAG_ID_5}?root=task_c") + assert response.status_code == 200 + nodes = response.json() + task_ids = [node["id"] for node in nodes] + # Only task_c should be returned when root is specified without upstream/downstream + assert task_ids == ["task_c"] + + def test_structure_with_root_and_include_upstream(self, test_client): + """Test that root + include_upstream returns the root task and all upstream tasks.""" + response = test_client.get(f"/grid/structure/{DAG_ID_5}?root=task_c&include_upstream=true") + assert response.status_code == 200 + nodes = response.json() + task_ids = sorted([node["id"] for node in nodes]) + # Should return task_c and all upstream tasks (task_a, task_b) + assert task_ids == ["task_a", "task_b", "task_c"] + + def test_structure_with_root_and_include_downstream_and_historical(self, test_client): + """Test that root + include_downstream returns the root task and all downstream tasks.""" + response = test_client.get(f"/grid/structure/{DAG_ID_5}?root=task_c&include_downstream=true") + assert response.status_code == 200 + nodes = response.json() + task_ids = sorted([node["id"] for node in nodes]) + # Should return task_c and all downstream tasks, including the historical (task_d, task_e, task_f (HISTORICAL)) + assert task_ids == ["task_c", "task_d", "task_e", "task_f"] + + # Tests for non-linear DAG structure + def test_nonlinear_structure_with_root_downstream_from_branch(self, test_client): + """Test filtering downstream from branch point includes both branches.""" + response = test_client.get(f"/grid/structure/{DAG_ID_6}?root=start&include_downstream=true") + assert response.status_code == 200 + nodes = response.json() + task_ids = sorted([node["id"] for node in nodes]) + # Should return all tasks since start is the root of the DAG + assert task_ids == ["branch_a", "branch_b", "end", "intermediate", "merge", "start"] + + def test_nonlinear_structure_with_root_upstream_from_merge(self, test_client): + """Test filtering upstream from merge point includes all upstream branches.""" + response = test_client.get(f"/grid/structure/{DAG_ID_6}?root=merge&include_upstream=true") + assert response.status_code == 200 + nodes = response.json() + task_ids = sorted([node["id"] for node in nodes]) + # Should return merge and all upstream tasks (both branches and their parents) + assert task_ids == ["branch_a", "branch_b", "intermediate", "merge", "start"] + + def test_nonlinear_structure_with_root_on_branch_include_downstream(self, test_client): + """Test filtering downstream from one branch.""" + response = test_client.get(f"/grid/structure/{DAG_ID_6}?root=branch_a&include_downstream=true") + assert response.status_code == 200 + nodes = response.json() + task_ids = sorted([node["id"] for node in nodes]) + # Should return branch_a and its downstream path (intermediate, merge, end) + assert task_ids == ["branch_a", "end", "intermediate", "merge"] + + def test_nonlinear_structure_with_root_on_branch_include_upstream(self, test_client): + """Test filtering upstream from one branch.""" + response = test_client.get(f"/grid/structure/{DAG_ID_6}?root=branch_a&include_upstream=true") + assert response.status_code == 200 + nodes = response.json() + task_ids = sorted([node["id"] for node in nodes]) + # Should return branch_a and its upstream (start) + assert task_ids == ["branch_a", "start"] + + def test_nonlinear_structure_intermediate_with_both_directions(self, test_client): + """Test filtering from intermediate node with both upstream and downstream.""" + response = test_client.get( + f"/grid/structure/{DAG_ID_6}?root=intermediate&include_upstream=true&include_downstream=true" + ) + assert response.status_code == 200 + nodes = response.json() + task_ids = sorted([node["id"] for node in nodes]) + # Should return intermediate, its upstream path, and downstream path + # upstream: branch_a, start; downstream: merge, end + assert task_ids == ["branch_a", "end", "intermediate", "merge", "start"]