Skip to content

Commit

Permalink
Merge pull request #455 from kreneskyp/nested_run_log_nodes
Browse files Browse the repository at this point in the history
Run log can now load missing nodes.
  • Loading branch information
kreneskyp authored Feb 19, 2024
2 parents ece1bf7 + bdc58db commit 388508f
Show file tree
Hide file tree
Showing 9 changed files with 163 additions and 38 deletions.
10 changes: 1 addition & 9 deletions frontend/chains/ChainGraphEditor.js
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import { TabState } from "chains/hooks/useTabState";
import { NOTIFY_SAVED } from "chains/editor/constants";
import { PropEdge } from "chains/PropEdge";
import { LinkEdge } from "chains/LinkEdge";
import { addNode, addType } from "chains/utils";

// Nodes are either a single node or a group of nodes
// ConfigNode renders class_path specific content
Expand All @@ -61,15 +62,6 @@ const getExpectedTypes = (connector) => {
: new Set([connector.source_type]);
};

const addType = (type, setTypes) => {
setTypes((prev) => {
if (!prev.some((t) => t.id === type.id)) {
return [...prev, type];
}
return prev;
});
};

const ChainGraphEditor = ({ graph }) => {
// editor contexts
const [chain, setChain] = useContext(ChainState);
Expand Down
36 changes: 27 additions & 9 deletions frontend/chains/editor/run_log/ExecutionList.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,35 @@ import { StyledIcon } from "components/StyledIcon";
import { DEFAULT_NODE_STYLE, NODE_STYLES } from "chains/editor/styles";
import { useRunLog } from "chains/editor/run_log/useRunLog";

const useExecutionTree = (executions, nodes, types) => {
const useExecutionTree = (
executions,
nodes,
types,
extra_nodes,
extra_types
) => {
const buildTree = (items, parentId = null) => {
return (
items
?.filter((item) => item.parent_id === parentId)
?.map((item) => {
const node = nodes?.[item.node_id];
const node = nodes?.[item.node_id] || extra_nodes?.[item.node_id];
const type =
types?.find((t) => t.id === node?.node_type_id) ||
extra_types?.[node?.node_type_id];
return {
execution: item,
children: buildTree(items, item.id),
node: node,
type: types?.find((t) => t.id === node?.node_type_id),
node,
type,
};
}) || []
);
};

return React.useMemo(() => {
return buildTree(executions);
}, [executions, nodes, types]);
}, [executions, nodes, types, extra_nodes, extra_types]);
};

const ExecutionIcon = ({ type, isLight }) => {
Expand Down Expand Up @@ -132,7 +141,10 @@ const ExecutionTreeNode = ({ execution, isFirst, isLast }) => {
{...(isSelected ? selectedItemStyle : itemStyle)}
width={"100%"}
>
<TreeItem isFirst={isFirst} isLast={isLast}>
<TreeItem
isFirst={isFirst}
isLast={isLast && (execution.children.length < 0 || !isOpen)}
>
<ExecutionIcon type={execution.type} isLight={isLight} />
</TreeItem>
<ExecutionBrief
Expand Down Expand Up @@ -160,19 +172,25 @@ const ExecutionTreeNode = ({ execution, isFirst, isLast }) => {

{isOpen && (
<HStack bg={"transparent"} height={"100%"} spacing={0}>
<BranchLine height={execution.children.length * 60} />
<BranchLine height={execution.children.length * 60} isLast={isLast} />
{children.length > 1 && <VStack spacing={0}>{children}</VStack>}
</HStack>
)}
</Box>
);
};

export const ExecutionList = ({ log }) => {
export const ExecutionList = ({ log, extra_nodes, extra_types }) => {
const { nodes } = React.useContext(NodeStateContext);
const [types, setTypes] = React.useContext(ChainTypes);

const executions = useExecutionTree(log?.executions, nodes, types);
const executions = useExecutionTree(
log?.executions,
nodes,
types,
extra_nodes,
extra_types
);

return (
<Box width={"200px"}>
Expand Down
5 changes: 4 additions & 1 deletion frontend/chains/editor/run_log/RunLog.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ import { ExecutionDetail } from "chains/editor/run_log/ExecutionDetail";
import { useEditorColorMode } from "chains/editor/useColorMode";

export const RunLog = ({}) => {
const { log, execution, setExecution } = useRunLog();
const { log, execution, setExecution, extra_nodes, extra_types } =
useRunLog();
const { scrollbar } = useEditorColorMode();

return (
Expand All @@ -21,6 +22,8 @@ export const RunLog = ({}) => {
>
<ExecutionList
log={log}
extra_nodes={extra_nodes}
extra_types={extra_types}
selectedExecution={execution}
setExecution={setExecution}
/>
Expand Down
56 changes: 53 additions & 3 deletions frontend/chains/editor/run_log/RunLogProvider.js
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import React from "react";
import { RunLog } from "chains/editor/contexts";
import { ChainTypes, NodeStateContext, RunLog } from "chains/editor/contexts";
import { useAxios } from "utils/hooks/useAxios";
import { useDisclosure } from "@chakra-ui/react";
import { useRunEventStream } from "chains/editor/run_log/useRunEventStream";

export const RunLogProvider = ({ chain_id, children }) => {
const disclosure = useDisclosure();
const { call, response, error } = useAxios();
const { call: fetch_nodes, response: extra_nodes_resp } = useAxios({
method: "post",
});

const [types, setTypes] = React.useContext(ChainTypes);
const { nodes, setNodes } = React.useContext(NodeStateContext);

const [log, setLog] = React.useState(null);
const [execution, setExecution] = React.useState(null);
Expand All @@ -19,6 +25,24 @@ export const RunLogProvider = ({ chain_id, children }) => {
}
}, [chain_id]);

// preemptively load any nodes that are missing from local state
// this loads nodes & types for any executions for nested chains
React.useEffect(() => {
if (log) {
const missing_nodes = log.executions
?.map((e) => e.node_id)
?.filter((id) => !nodes[id]);

if (missing_nodes?.length > 0) {
fetch_nodes(`/api/nodes/bulk`, {
data: missing_nodes,
}).catch((err) => {
console.error("Failed to load missing nodes", err);
});
}
}
}, [log]);

React.useEffect(() => {
if (chain_id !== undefined) {
getLog();
Expand Down Expand Up @@ -64,7 +88,7 @@ export const RunLogProvider = ({ chain_id, children }) => {
});
}
return log_by_node;
}, [log]);
}, [log, nodes, types]);

const state = React.useMemo(() => {
const has_errors = log?.executions?.some((e) => e.completed === false);
Expand All @@ -81,9 +105,35 @@ export const RunLogProvider = ({ chain_id, children }) => {
};
}, [log]);

// HAX: force refresh of log object whenever types or nodes changes.
// this is a cheap way of injecting this dependency into the log object
// so downstream components don't need to know that nodes or types have
// changed.
const _log = React.useMemo(() => ({ ...log }), [log, extra_nodes_resp]);

const extra_nodes = React.useMemo(() => {
// return mapping of extra nodes
const mapping = {};
for (let node of extra_nodes_resp?.data.nodes || []) {
mapping[node.id] = node;
}
return mapping;
}, [extra_nodes_resp]);

const extra_types = React.useMemo(() => {
// return mapping of extra types
const mapping = {};
for (let type of extra_nodes_resp?.data.types || []) {
mapping[type.id] = type;
}
return mapping;
}, [extra_nodes_resp]);

const value = React.useMemo(() => {
return {
log,
log: _log,
extra_nodes,
extra_types,
state,
log_by_node,
execution,
Expand Down
4 changes: 2 additions & 2 deletions frontend/chains/editor/run_log/Tree.js
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ const VerticalLine = ({ flexGrow, ...props }) => {
);
};

export const BranchLine = ({ height }) => {
export const BranchLine = ({ height, isLast }) => {
return (
<Box display="flex" flexDirection="column" height={`${height}px`} ml={2}>
<VStack
Expand All @@ -57,7 +57,7 @@ export const BranchLine = ({ height }) => {
ml={"10px"}
flexGrow={0}
/>
<VerticalLine />
{isLast ? <Spacer /> : <VerticalLine />}
</VStack>
</Box>
);
Expand Down
42 changes: 33 additions & 9 deletions frontend/chains/utils.js
Original file line number Diff line number Diff line change
@@ -1,11 +1,35 @@
import React from "react";
import { LLM_NAME_MAP } from "chains/constants";
export const addType = (type, setTypes) => {
setTypes((prev) => {
if (!prev.some((t) => t.id === type.id)) {
return [...prev, type];
}
return prev;
});
};

export function llm_name(classPath) {
if (classPath === undefined) {
return <i>inherited</i>;
}
export const addTypes = (types, setTypes) => {
setTypes((prev) => {
const new_types = types.filter(
(type) => !prev.some((t) => t.id === type.id)
);
return [...prev, ...new_types];
});
};

// lookup LLM name, default to class name if classPath isn't mapped with a label
return LLM_NAME_MAP[classPath] || classPath.split(".").pop();
}
export const addNode = (node, setNodes) => {
setNodes((prevNodes) => {
return { ...prevNodes, [node.id]: node };
});
};

export const addNodes = (newNodes, setNodes) => {
setNodes((prevNodes) => {
const updatedNodes = { ...prevNodes };
newNodes.forEach((node) => {
if (!prevNodes[node.id]) {
updatedNodes[node.id] = node;
}
});
return updatedNodes;
});
};
27 changes: 27 additions & 0 deletions ix/api/editor/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import asyncio
import logging
from typing import List
from uuid import UUID

from django.contrib.auth.models import AbstractUser
from django.db.models import Q
from fastapi import APIRouter, HTTPException, Depends
from pydantic import UUID4

from ix.agents.models import Agent
from ix.api.auth import get_request_user
Expand All @@ -21,6 +23,7 @@
UpdatedRoot,
AddNode,
UpdateRoot,
GraphNodes,
)
from ix.chains.models import Chain, ChainNode, NodeType, ChainEdge

Expand All @@ -32,6 +35,7 @@
from ix.api.chains.types import Node as NodePydantic
from ix.api.chains.types import Edge as EdgePydantic
from ix.chat.models import Chat
from ix.ix_users.models import User, OwnedModel

logger = logging.getLogger(__name__)
router = APIRouter()
Expand Down Expand Up @@ -251,3 +255,26 @@ async def get_chain_chat(
"""Return test chat instance for the chain"""
chat = await _get_test_chat(chain_id, user)
return ChatPydantic.from_orm(chat)


@router.post(
"/nodes/bulk",
operation_id="get_nodes",
response_model=GraphNodes,
tags=["Chain Editor"],
)
async def get_nodes(node_ids: List[UUID4], user: User = Depends(get_request_user)):
"""Return single node"""
"""Return list of nodes and their types for given node IDs"""
filtered_nodes = OwnedModel.filter_owners(
user, ChainNode.objects.all(), prefix="chain__"
)
nodes = filtered_nodes.filter(id__in=node_ids)
node_pydantics = [NodePydantic.model_validate(node) async for node in nodes]

# Fetch types for these nodes
type_ids = {node.node_type_id for node in node_pydantics}
types = NodeType.objects.filter(id__in=type_ids)
type_pydantics = [NodeTypePydantic.model_validate(type_) async for type_ in types]

return GraphNodes(nodes=node_pydantics, types=type_pydantics)
5 changes: 5 additions & 0 deletions ix/api/editor/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,8 @@ class GraphModel(BaseModel):
nodes: List[NodePydantic]
edges: List[EdgePydantic]
types: List[NodeTypePydantic]


class GraphNodes(BaseModel):
nodes: List[NodePydantic]
types: List[NodeTypePydantic]
16 changes: 11 additions & 5 deletions ix/ix_users/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def filtered_owners(cls, user: User, global_restricted=False) -> QuerySet:

@staticmethod
def filter_owners(
user: User, queryset: QuerySet, global_restricted=False
user: User, queryset: QuerySet, global_restricted=False, prefix=""
) -> QuerySet:
"""Filter a queryset to only include objects available to the given user:
Expand All @@ -47,16 +47,22 @@ def filter_owners(
Assumes they inherit from OwnedMixin.
"""

# disable filtering for local deployments
# Disable filtering for local deployments
if not settings.OWNER_FILTERING:
return queryset

if not user:
return queryset.none()

user_owned = Q(user_id=user.id)
group_owned = Q(group__user=user)
global_owned = Q(user_id=None, group_id=None)
# Prepend prefix to field lookups if provided
user_field = f"{prefix}user_id" if prefix else "user_id"
group_field = f"{prefix}group__user" if prefix else "group__user"
global_user_field = f"{prefix}user_id" if prefix else "user_id"
global_group_field = f"{prefix}group_id" if prefix else "group_id"

user_owned = Q(**{user_field: user.id})
group_owned = Q(**{group_field: user})
global_owned = Q(**{global_user_field: None, global_group_field: None})

if global_restricted and not user.is_superuser:
return queryset.exclude(global_owned).filter(user_owned | group_owned)
Expand Down

0 comments on commit 388508f

Please sign in to comment.