Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 41 additions & 40 deletions airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,8 @@
datetime_range_filter_factory,
)
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.datamodels.ui.common import (
GridNodeResponse,
GridRunsResponse,
)
from airflow.api_fastapi.core_api.datamodels.ui.grid import (
GridTISummaries,
)
from airflow.api_fastapi.core_api.datamodels.ui.common import GridNodeResponse, GridRunsResponse
from airflow.api_fastapi.core_api.datamodels.ui.grid import GridTISummaries
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.api_fastapi.core_api.security import requires_access_dag
from airflow.api_fastapi.core_api.services.ui.grid import (
Expand All @@ -68,9 +63,7 @@
def _get_latest_serdag(dag_id, session):
serdag = session.scalar(
select(SerializedDagModel)
.where(
SerializedDagModel.dag_id == dag_id,
)
.where(SerializedDagModel.dag_id == dag_id)
.order_by(SerializedDagModel.id.desc())
.limit(1)
)
Expand All @@ -92,9 +85,7 @@ def _get_serdag(dag_id, dag_version_id, session) -> SerializedDagModel | None:
if not version:
version = session.scalar(
select(DagVersion)
.where(
DagVersion.dag_id == dag_id,
)
.where(DagVersion.dag_id == dag_id)
.options(joinedload(DagVersion.serialized_dag))
.order_by(DagVersion.id) # ascending cus this is mostly for pre-3.0 upgrade
.limit(1)
Expand Down Expand Up @@ -166,11 +157,7 @@ def get_dag_structure(
SerializedDagModel.dag_id == dag_id,
SerializedDagModel.id != latest_serdag.id,
SerializedDagModel.dag_version_id.in_(
select(TaskInstance.dag_version_id)
.join(TaskInstance.dag_run)
.where(
DagRun.id.in_(run_ids),
)
select(TaskInstance.dag_version_id).join(TaskInstance.dag_run).where(DagRun.id.in_(run_ids))
),
)
)
Expand Down Expand Up @@ -346,31 +333,44 @@ def get_grid_ti_summaries(
TaskInstance.dag_version_id,
TaskInstance.start_date,
TaskInstance.end_date,
TaskInstance.duration,
)
.where(TaskInstance.dag_id == dag_id)
.where(
TaskInstance.run_id == run_id,
)
.where(TaskInstance.run_id == run_id)
),
filters=[],
order_by=SortParam(allowed_attrs=["task_id", "run_id"], model=TaskInstance).set_value(["task_id"]),
limit=None,
return_total_entries=False,
)
task_instances = list(session.execute(tis_of_dag_runs))

task_instances = session.scalars(tis_of_dag_runs).all()
if not task_instances:
raise HTTPException(
status.HTTP_404_NOT_FOUND, f"No task instances for dag_id={dag_id} run_id={run_id}"
)
ti_details = collections.defaultdict(list)

ti_details: dict[str, list[dict[str, object | None]]] = collections.defaultdict(list)
for ti in task_instances:
ti_details[ti.task_id].append(
{
"state": ti.state,
"start_date": ti.start_date,
"end_date": ti.end_date,
"duration": ti.duration,
}
)

# Pre-compute min start / max end per leaf task (for tooltip) and pick a representative duration.
task_min_max: dict[str, tuple[object | None, object | None]] = {}
task_first_duration: dict[str, float | None] = {}
for task_id, items in ti_details.items():
starts = [i["start_date"] for i in items if i.get("start_date") is not None]
ends = [i["end_date"] for i in items if i.get("end_date") is not None]
task_min_max[task_id] = (min(starts) if starts else None, max(ends) if ends else None)
dur = next((i["duration"] for i in items if i.get("duration") is not None), None)
task_first_duration[task_id] = dur # may be None

serdag = _get_serdag(
dag_id=dag_id,
dag_version_id=task_instances[0].dag_version_id,
Expand All @@ -379,10 +379,10 @@ def get_grid_ti_summaries(
if TYPE_CHECKING:
assert serdag

def get_node_sumaries():
def gen_nodes():
yielded_task_ids: set[str] = set()

# Yield all nodes discoverable from the serialized DAG structure
# Emit nodes discovered from the serialized DAG structure
for node in _find_aggregates(
node=serdag.dag.task_group,
parent_node=None,
Expand All @@ -391,33 +391,34 @@ def get_node_sumaries():
if node["type"] in {"task", "mapped_task"}:
yielded_task_ids.add(node["task_id"])
if node["type"] == "task":
# Attach min/max and DB-backed duration for leaf tasks.
min_start, max_end = task_min_max.get(node["task_id"], (None, None))
node["min_start_date"] = min_start
node["max_end_date"] = max_end
node["duration"] = task_first_duration.get(node["task_id"])
node["child_states"] = None
node["min_start_date"] = None
node["max_end_date"] = None
yield node

# For good history: add synthetic leaf nodes for task_ids that have TIs in this run
# but are not present in the current DAG structure (e.g. removed tasks)
missing_task_ids = set(ti_details.keys()) - yielded_task_ids
for task_id in sorted(missing_task_ids):
detail = ti_details[task_id]
# Create a leaf task node with aggregated state from its TIs
agg = _get_aggs_for_node(detail)
# Add synthetic leaves for historical/removed tasks present only in TI table
missing = set(ti_details.keys()) - yielded_task_ids
for task_id in sorted(missing):
agg = _get_aggs_for_node(ti_details[task_id])
min_start, max_end = task_min_max.get(task_id, (None, None))
yield {
"task_id": task_id,
"type": "task",
"parent_id": None,
**agg,
# Align with leaf behavior
"child_states": None,
"min_start_date": None,
"max_end_date": None,
"min_start_date": min_start,
"max_end_date": max_end,
"duration": task_first_duration.get(task_id),
}

task_instances = list(get_node_sumaries())
task_nodes = list(gen_nodes())
# If a group id and a task id collide, prefer the group record
group_ids = {n.get("task_id") for n in task_instances if n.get("type") == "group"}
filtered = [n for n in task_instances if not (n.get("type") == "task" and n.get("task_id") in group_ids)]
group_ids = {n.get("task_id") for n in task_nodes if n.get("type") == "group"}
filtered = [n for n in task_nodes if not (n.get("type") == "task" and n.get("task_id") in group_ids)]

return { # type: ignore[return-value]
"run_id": run_id,
Expand Down
75 changes: 50 additions & 25 deletions airflow-core/src/airflow/ui/src/layouts/Details/Grid/GridTI.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
import { Badge, Flex } from "@chakra-ui/react";
import { Badge, Flex, Box, Text } from "@chakra-ui/react";
import type { MouseEvent } from "react";
import React, { useCallback } from "react";
import { useTranslation } from "react-i18next";
Expand All @@ -27,26 +27,25 @@ import { StateIcon } from "src/components/StateIcon";
import Time from "src/components/Time";
import { Tooltip } from "src/components/ui";
import { type HoverContextType, useHover } from "src/context/hover";
import { getDuration, renderDuration } from "src/utils";
import { buildTaskInstanceUrl } from "src/utils/links";

const handleMouseEnter =
(setHoveredTaskId: HoverContextType["setHoveredTaskId"]) => (event: MouseEvent<HTMLDivElement>) => {
const tasks = document.querySelectorAll<HTMLDivElement>(`#${event.currentTarget.id}`);

tasks.forEach((task) => {
task.style.backgroundColor = "var(--chakra-colors-info-subtle)";
tasks.forEach((taskEl) => {
taskEl.style.backgroundColor = "var(--chakra-colors-info-subtle)";
});

setHoveredTaskId(event.currentTarget.id.replaceAll("-", "."));
};

const handleMouseLeave = (taskId: string, setHoveredTaskId: HoverContextType["setHoveredTaskId"]) => () => {
const tasks = document.querySelectorAll<HTMLDivElement>(`#${taskId.replaceAll(".", "-")}`);

tasks.forEach((task) => {
task.style.backgroundColor = "";
tasks.forEach((taskEl) => {
taskEl.style.backgroundColor = "";
});

setHoveredTaskId(undefined);
};

Expand All @@ -65,8 +64,8 @@ type Props = {
const Instance = ({ dagId, instance, isGroup, isMapped, onClick, runId, search, taskId }: Props) => {
const { setHoveredTaskId } = useHover();
const { groupId: selectedGroupId, taskId: selectedTaskId } = useParams();
const { t: translate } = useTranslation();
const location = useLocation();
const { t: translate } = useTranslation("common");

const onMouseEnter = handleMouseEnter(setHoveredTaskId);
const onMouseLeave = handleMouseLeave(taskId, setHoveredTaskId);
Expand All @@ -84,10 +83,25 @@ const Instance = ({ dagId, instance, isGroup, isMapped, onClick, runId, search,
[dagId, isGroup, isMapped, location.pathname, runId, taskId],
);

const start: string | undefined = instance?.min_start_date
const end: string | undefined = instance?.max_end_date
const hasStart = start !== undefined;
const hasEnd = end !== undefined;

const serverDurationUnknown = (instance as unknown as { duration?: unknown }).duration;
const serverDurationSeconds = typeof serverDurationUnknown === "number" ? serverDurationUnknown : undefined;

const durationText =
serverDurationSeconds === undefined ? getDuration(start, end) : renderDuration(serverDurationSeconds);

const isSelected =
((selectedTaskId ?? undefined) !== undefined && selectedTaskId === taskId) ||
((selectedGroupId ?? undefined) !== undefined && selectedGroupId === taskId);
Comment on lines +98 to +99
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(selectedTaskId ?? undefined) that's not necessary I believe.


return (
<Flex
alignItems="center"
bg={selectedTaskId === taskId || selectedGroupId === taskId ? "info.muted" : undefined}
bg={isSelected ? "info.muted" : undefined}
height="20px"
id={taskId.replaceAll(".", "-")}
justifyContent="center"
Expand All @@ -110,24 +124,35 @@ const Instance = ({ dagId, instance, isGroup, isMapped, onClick, runId, search,
>
<Tooltip
content={
<>
{translate("taskId")}: {taskId}
<br />
{translate("state")}: {instance.state}
{instance.min_start_date !== null && (
<>
<br />
{translate("startDate")}: <Time datetime={instance.min_start_date} />
</>
)}
{instance.max_end_date !== null && (
<>
<br />
{translate("endDate")}: <Time datetime={instance.max_end_date} />
</>
<Box>
<Text>
{translate("taskId")}: {taskId}
</Text>
<Text>
{translate("state")}: {instance.state}
</Text>

{hasStart ? (
<Text>
{translate("startDate")}: <Time datetime={start} />
</Text>
) : undefined}

{hasEnd ? (
<Text>
{translate("endDate")}: <Time datetime={end} />
</Text>
) : undefined}

{durationText === undefined ? undefined : (
<Text>
{translate("duration")}: {durationText}
</Text>
)}
</>
</Box>
}
portalled
positioning={{ offset: { crossAxis: 5, mainAxis: 5 }, placement: "bottom-start" }}
>
<Badge
alignItems="center"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ export const TaskInstancesColumn = ({ nodes, onCellClick, runId, taskInstances }
const search = searchParams.toString();

return nodes.map((node) => {
// todo: how does this work with mapped? same task id for multiple tis
const taskInstance = taskInstances.find((ti) => ti.task_id === node.id);

if (!taskInstance) {
Expand Down