Skip to content

Commit

Permalink
Merge pull request #447 from kreneskyp/graph_backend_deps
Browse files Browse the repository at this point in the history
Misc updates to support LangGraph state machines.
  • Loading branch information
kreneskyp authored Feb 15, 2024
2 parents 67b8ab6 + 6ba5024 commit 1eca74f
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 106 deletions.
2 changes: 1 addition & 1 deletion ix/api/chains/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class Edge(BaseModel):
source_key: Optional[str] = None
target_key: Optional[str] = None
chain_id: UUID
relation: Literal["LINK", "PROP"]
relation: Literal["LINK", "PROP", "GRAPH"]
input_map: Optional[dict] = None

class Config:
Expand Down
93 changes: 72 additions & 21 deletions ix/chains/loaders/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from ix.api.components.types import NodeType as NodeTypePydantic
from ix.api.chains.types import Node as NodePydantic, InputConfig
from ix.chains.components.lcel import init_sequence, init_branch
from ix.chains.fixture_src.flow import ROOT_CLASS_PATH
from ix.chains.loaders.context import IxContext

from ix.chains.loaders.prompts import load_prompt
Expand All @@ -30,7 +29,6 @@
from ix.secrets.models import Secret
from ix.utils.config import format_config
from ix.utils.importlib import import_class
from ix.utils.pydantic import create_args_model_v1
from jsonschema_pydantic import jsonschema_to_pydantic

import_node_class = import_class
Expand Down Expand Up @@ -445,7 +443,7 @@ class SequencePlaceholder:

@property
def id(self):
return self.steps[0][0].id
return self.steps[0].id

def __eq__(self, other):
if isinstance(other, SequencePlaceholder):
Expand All @@ -466,6 +464,51 @@ def __eq__(self, other):
)


def find_roots(
node: ChainNode | List[ChainNode] | MapPlaceholder | BranchPlaceholder,
) -> List[ChainNode]:
"""Finds the first node(s) for a node or node group.
Used to find the node that should receive the incoming edge when connecting
the node group to a sequence
"""
if isinstance(node, list):
return [node[0]]
elif isinstance(node, MapPlaceholder):
nodes = []
for mapped_node in node.map.values():
nodes.extend(find_roots(mapped_node))
return nodes
elif isinstance(node, BranchPlaceholder):
return [node.node]
return [node]


def find_leaves(
node: ChainNode | List[ChainNode] | MapPlaceholder | BranchPlaceholder,
) -> List[ChainNode]:
"""Finds the last node(s) for a node or node group.
Used to find the node that should recieve the outgoing edge when connecting
the node group to a sequence
"""
if isinstance(node, list):
return [node[-1]]
elif isinstance(node, SequencePlaceholder):
return [node.steps[-1]]
elif isinstance(node, MapPlaceholder):
nodes = []
for mapped_node in node.map.values():
nodes.extend(find_leaves(mapped_node))
return nodes
elif isinstance(node, BranchPlaceholder):
nodes = [find_leaves(node.default)]
for key, branch_node in node.branches:
nodes.extend(find_leaves(branch_node))
return nodes
return [node]


def init_chain_flow(
chain: Chain, context: IxContext, variables: Dict[str, Any] = None
) -> Runnable:
Expand Down Expand Up @@ -520,21 +563,14 @@ async def ainit_flow(

def load_chain_flow(chain: Chain) -> Tuple[Type[BaseModel], FlowPlaceholder]:
try:
root = chain.nodes.get(root=True, class_path=ROOT_CLASS_PATH)
nodes = chain.nodes.filter(incoming_edges__source=root)
input_type = create_args_model_v1(
root.config.get("outputs", []), name="ChainInput"
)
nodes = chain.nodes.filter(incoming_edges__source=chain.chat_root)
except ChainNode.DoesNotExist:
# fallback to old style roots:
# TODO: remove this fallback after all chains have been migrated
nodes = chain.nodes.filter(root=True)
logger.debug(f"Loading chain flow with roots: {nodes}")
input_type = create_args_model_v1(
["user_input", "artifact_ids"], name="ChainInput"
)

return input_type, load_flow_node(nodes)
return chain.types.INPUT, load_flow_node(nodes)


async def aload_chain_flow(chain: Chain) -> Tuple[Type[BaseModel], FlowPlaceholder]:
Expand Down Expand Up @@ -608,17 +644,12 @@ def load_flow_map(
return new_nodes


def load_flow_branch(
def build_flow_branch(
node: ChainNode,
outgoing_links: List[ChainEdge],
seen: Dict[UUID, FlowPlaceholder],
branch_depth: Tuple[str] = None,
) -> BranchPlaceholder:
# gather branches
outgoing_links = (
node.outgoing_edges.select_related("target")
.filter(relation="LINK")
.order_by("source_key")
)
):
branches = {}
for key, group in itertools.groupby(outgoing_links, lambda edge: edge.source_key):
_branch_depth = branch_depth + (key,) if branch_depth else (key,)
Expand All @@ -628,7 +659,9 @@ def load_flow_branch(
nodes = load_flow_map(targets, seen=seen, branch_depth=_branch_depth)
else:
nodes = load_flow_sequence(
group_as_list[0].target, seen=seen, branch_depth=_branch_depth
group_as_list[0].target,
seen=seen,
branch_depth=_branch_depth,
)
branches[key] = nodes

Expand All @@ -639,6 +672,24 @@ def load_flow_branch(
(key, branches[branch_uuid])
for key, branch_uuid in zip(branch_keys, branch_uuids)
]
return branches, branch_tuples


def load_flow_branch(
node: ChainNode,
seen: Dict[UUID, FlowPlaceholder],
branch_depth: Tuple[str] = None,
) -> BranchPlaceholder:
# gather branches
outgoing_links = (
node.outgoing_edges.select_related("target")
.filter(relation="LINK")
.order_by("source_key")
)

branches, branch_tuples = build_flow_branch(
node, outgoing_links, seen, branch_depth
)

if "default" not in branches:
raise ValueError("Branch node must have a default branch")
Expand Down
22 changes: 22 additions & 0 deletions ix/chains/migrations/0016_alter_chainedge_relation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Generated by Django 4.2.7 on 2024-02-10 17:53

from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("chains", "0015_langchain_0_1_0_deprecations"),
]

operations = [
migrations.AlterField(
model_name="chainedge",
name="relation",
field=models.CharField(
choices=[("PROP", "prop"), ("LINK", "link")],
default="LINK",
max_length=5,
null=True,
),
),
]
42 changes: 38 additions & 4 deletions ix/chains/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@

from django.db import models
from langchain.schema.runnable import Runnable
from pydantic.v1 import BaseModel

from ix.chains.fixture_src.flow import ROOT_CLASS_PATH
from ix.ix_users.models import OwnedModel
from ix.pg_vector.tests.models import PGVectorMixin
from ix.pg_vector.utils import get_embedding

from ix.utils.pydantic import create_args_model_v1

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -277,7 +279,7 @@ class ChainEdge(models.Model):
)
input_map = models.JSONField(null=True)
relation = models.CharField(
max_length=4, null=True, choices=RELATION_CHOICES, default="LINK"
max_length=5, null=True, choices=RELATION_CHOICES, default="LINK"
)

DoesNotExist: Type[models.ObjectDoesNotExist]
Expand Down Expand Up @@ -311,12 +313,12 @@ def root(self) -> ChainNode:
def __str__(self):
return f"{self.name} ({self.id})"

def load_chain(self, context) -> Runnable:
def load_chain(self, context: "IxContext") -> Runnable: # noqa: F821
from ix.chains.loaders.core import init_chain_flow

return init_chain_flow(self, context=context)

async def aload_chain(self, context) -> Runnable:
async def aload_chain(self, context: "IxContext") -> Runnable: # noqa: F821
from ix.chains.loaders.core import init_chain_flow

return await sync_to_async(init_chain_flow)(self, context=context)
Expand All @@ -325,3 +327,35 @@ def clear_chain(self):
"""removes the chain nodes associated with this chain"""
# clear old chain
ChainNode.objects.filter(chain_id=self.id).delete()

@cached_property
def chat_root(self):
return self.nodes.get(root=True, class_path=ROOT_CLASS_PATH)

@cached_property
def types(self) -> Type[BaseModel]:
"""Build pydantic model for chain input."""
try:
root = self.chat_root
input_type = create_args_model_v1(
root.config.get("outputs", []), name="ChainInput"
)
config_type = create_args_model_v1(
root.config.get("config", []), name="ChainConfig"
)
except ChainNode.DoesNotExist:
# fallback to old style roots:
# TODO: remove this fallback after all chains have been migrated
input_type = create_args_model_v1(
["user_input", "artifact_ids"], name="ChainInput"
)
config_type = create_args_model_v1([], name="ChainConfig")

class ChainConfig(BaseModel):
input: input_type
config: config_type = {}

INPUT = input_type
CONFIG = config_type

return ChainConfig
Loading

0 comments on commit 1eca74f

Please sign in to comment.