Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions src/uipath_langchain/_cli/_runtime/_context.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
from typing import Any, Optional, Union

from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from langgraph.graph import StateGraph
from uipath._cli._runtime._contracts import UiPathRuntimeContext

from .._utils._graph import LangGraphConfig


class LangGraphRuntimeContext(UiPathRuntimeContext):
"""Context information passed throughout the runtime execution."""

langgraph_config: Optional[LangGraphConfig] = None
state_graph: Optional[StateGraph[Any, Any]] = None
output: Optional[Any] = None
state: Optional[Any] = (
None # TypedDict issue, the actual type is: Optional[langgraph.types.StateSnapshot]
Expand Down
143 changes: 143 additions & 0 deletions src/uipath_langchain/_cli/_runtime/_graph_resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import asyncio
from typing import Any, Awaitable, Callable, Optional

from langgraph.graph.state import CompiledStateGraph, StateGraph
from uipath._cli._runtime._contracts import (
UiPathErrorCategory,
)

from .._utils._graph import GraphConfig, LangGraphConfig
from ._exception import LangGraphRuntimeError


class LangGraphJsonResolver:
    """Resolve a ``StateGraph`` from the project's ``langgraph.json`` config.

    The resolver is an async callable: awaiting an instance returns the
    resolved graph, loading and caching it on first use so concurrent
    executions share a single resolved instance. Resolution is serialized
    with an ``asyncio.Lock``.
    """

    def __init__(self, entrypoint: Optional[str] = None) -> None:
        # Name of the graph to resolve; if None and exactly one graph is
        # configured, that single graph is selected automatically.
        self.entrypoint = entrypoint
        # Populated by _resolve(); kept so cleanup() can release its resources.
        self.graph_config: Optional[GraphConfig] = None
        self._lock = asyncio.Lock()
        # Cached resolved graph; None until the first successful resolution.
        self._graph_cache: Optional[StateGraph[Any, Any, Any]] = None
        # NOTE: the previous revision also kept a `_resolving` flag here; it
        # was never read, so it has been removed.

    async def __call__(self) -> StateGraph[Any, Any, Any]:
        """Return the resolved graph, resolving it on first call.

        Raises:
            LangGraphRuntimeError: if the configuration is missing/invalid or
                the graph cannot be loaded.
        """
        # Fast path: if already resolved, return immediately without locking.
        # Safe in asyncio: there is no suspension point between the check and
        # the return, so the cache cannot be cleared underneath us here.
        if self._graph_cache is not None:
            return self._graph_cache

        # Slow path: acquire lock and resolve
        async with self._lock:
            # Double-check after acquiring lock (another coroutine may have resolved it)
            if self._graph_cache is not None:
                return self._graph_cache

            self._graph_cache = await self._resolve(self.entrypoint)
            return self._graph_cache

    async def _resolve(self, entrypoint: Optional[str]) -> StateGraph[Any, Any, Any]:
        """Load the config, select the graph named *entrypoint*, and load it.

        Returns the underlying ``StateGraph`` builder even when the configured
        graph is already compiled.
        """
        config = LangGraphConfig()
        if not config.exists:
            raise LangGraphRuntimeError(
                "CONFIG_MISSING",
                "Invalid configuration",
                "Failed to load configuration",
                UiPathErrorCategory.DEPLOYMENT,
            )

        try:
            config.load_config()
        except Exception as e:
            raise LangGraphRuntimeError(
                "CONFIG_INVALID",
                "Invalid configuration",
                f"Failed to load configuration: {str(e)}",
                UiPathErrorCategory.DEPLOYMENT,
            ) from e

        # Determine entrypoint if not provided
        graphs = config.graphs
        if not entrypoint and len(graphs) == 1:
            entrypoint = graphs[0].name
        elif not entrypoint:
            graph_names = ", ".join(g.name for g in graphs)
            raise LangGraphRuntimeError(
                "ENTRYPOINT_MISSING",
                "Entrypoint required",
                f"Multiple graphs available. Please specify one of: {graph_names}.",
                UiPathErrorCategory.DEPLOYMENT,
            )

        # Get the specified graph
        self.graph_config = config.get_graph(entrypoint)
        if not self.graph_config:
            raise LangGraphRuntimeError(
                "GRAPH_NOT_FOUND",
                "Graph not found",
                f"Graph '{entrypoint}' not found.",
                UiPathErrorCategory.DEPLOYMENT,
            )
        try:
            loaded_graph = await self.graph_config.load_graph()
            # A compiled graph carries its original builder; the runtime
            # re-compiles with its own checkpointer, so return the builder.
            return (
                loaded_graph.builder
                if isinstance(loaded_graph, CompiledStateGraph)
                else loaded_graph
            )
        except ImportError as e:
            raise LangGraphRuntimeError(
                "GRAPH_IMPORT_ERROR",
                "Graph import failed",
                f"Failed to import graph '{entrypoint}': {str(e)}",
                UiPathErrorCategory.USER,
            ) from e
        except TypeError as e:
            raise LangGraphRuntimeError(
                "GRAPH_TYPE_ERROR",
                "Invalid graph type",
                f"Graph '{entrypoint}' is not a valid StateGraph or CompiledStateGraph: {str(e)}",
                UiPathErrorCategory.USER,
            ) from e
        except ValueError as e:
            raise LangGraphRuntimeError(
                "GRAPH_VALUE_ERROR",
                "Invalid graph value",
                f"Invalid value in graph '{entrypoint}': {str(e)}",
                UiPathErrorCategory.USER,
            ) from e
        except Exception as e:
            raise LangGraphRuntimeError(
                "GRAPH_LOAD_ERROR",
                "Failed to load graph",
                f"Unexpected error loading graph '{entrypoint}': {str(e)}",
                UiPathErrorCategory.USER,
            ) from e

    async def cleanup(self):
        """Release the loaded graph config and drop the cached graph."""
        async with self._lock:
            if self.graph_config:
                await self.graph_config.cleanup()
                self.graph_config = None
            self._graph_cache = None


# Type alias: a zero-argument async callable that yields a StateGraph.
# LangGraphJsonResolver instances satisfy this signature via __call__.
AsyncResolver = Callable[[], Awaitable[StateGraph[Any, Any, Any]]]


class LangGraphJsonResolverContext:
    """Async context manager around :class:`LangGraphJsonResolver`.

    Entering the context yields an ``AsyncResolver`` callable that can be
    passed directly to ``LangGraphRuntime``; the wrapped resolver caches the
    resolved graph, so concurrent executions reuse the same instance.
    Resolver resources are released when the context exits.
    """

    def __init__(self, entrypoint: Optional[str] = None) -> None:
        self._resolver = LangGraphJsonResolver(entrypoint)

    async def __aenter__(self) -> AsyncResolver:
        resolver = self._resolver

        # Hand out a thin closure so callers depend only on the
        # AsyncResolver signature, not on the resolver object itself.
        async def resolve() -> StateGraph[Any, Any, Any]:
            return await resolver()

        return resolve

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        # Always release resolver resources, regardless of any exception.
        await self._resolver.cleanup()
125 changes: 31 additions & 94 deletions src/uipath_langchain/_cli/_runtime/_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
UiPathRuntimeResult,
)

from .._utils._graph import LangGraphConfig
from ._context import LangGraphRuntimeContext
from ._conversation import map_message
from ._exception import LangGraphRuntimeError
from ._graph_resolver import AsyncResolver, LangGraphJsonResolver
from ._input import LangGraphInputProcessor
from ._output import LangGraphOutputProcessor

Expand All @@ -31,9 +31,10 @@ class LangGraphRuntime(UiPathBaseRuntime):
This allows using the class with 'async with' statements.
"""

def __init__(self, context: LangGraphRuntimeContext):
def __init__(self, context: LangGraphRuntimeContext, graph_resolver: AsyncResolver):
super().__init__(context)
self.context: LangGraphRuntimeContext = context
self.graph_resolver: AsyncResolver = graph_resolver

async def execute(self) -> Optional[UiPathRuntimeResult]:
"""
Expand All @@ -46,7 +47,8 @@ async def execute(self) -> Optional[UiPathRuntimeResult]:
LangGraphRuntimeError: If execution fails
"""

if self.context.state_graph is None:
graph = await self.graph_resolver()
if not graph:
return None

try:
Expand All @@ -56,9 +58,7 @@ async def execute(self) -> Optional[UiPathRuntimeResult]:
self.context.memory = memory

# Compile the graph with the checkpointer
graph = self.context.state_graph.compile(
checkpointer=self.context.memory
)
compiled_graph = graph.compile(checkpointer=self.context.memory)

# Process input, handling resume if needed
input_processor = LangGraphInputProcessor(context=self.context)
Expand Down Expand Up @@ -87,7 +87,7 @@ async def execute(self) -> Optional[UiPathRuntimeResult]:
graph_config["max_concurrency"] = int(max_concurrency)

if self.context.chat_handler:
async for stream_chunk in graph.astream(
async for stream_chunk in compiled_graph.astream(
processed_input,
graph_config,
stream_mode="messages",
Expand All @@ -109,7 +109,7 @@ async def execute(self) -> Optional[UiPathRuntimeResult]:
elif self.is_debug_run():
# Get final chunk while streaming
final_chunk = None
async for stream_chunk in graph.astream(
async for stream_chunk in compiled_graph.astream(
processed_input,
graph_config,
stream_mode="updates",
Expand All @@ -118,16 +118,18 @@ async def execute(self) -> Optional[UiPathRuntimeResult]:
self._pretty_print(stream_chunk)
final_chunk = stream_chunk

self.context.output = self._extract_graph_result(final_chunk, graph)
self.context.output = self._extract_graph_result(
final_chunk, compiled_graph
)
else:
# Execute the graph normally at runtime or eval
self.context.output = await graph.ainvoke(
self.context.output = await compiled_graph.ainvoke(
processed_input, graph_config
)

# Get the state if available
try:
self.context.state = await graph.aget_state(graph_config)
self.context.state = await compiled_graph.aget_state(graph_config)
except Exception:
pass

Expand Down Expand Up @@ -177,91 +179,10 @@ async def execute(self) -> Optional[UiPathRuntimeResult]:
pass

async def validate(self) -> None:
"""Validate runtime inputs."""
"""Load and validate the graph configuration ."""
if self.context.langgraph_config is None:
self.context.langgraph_config = LangGraphConfig()
if not self.context.langgraph_config.exists:
raise LangGraphRuntimeError(
"CONFIG_MISSING",
"Invalid configuration",
"Failed to load configuration",
UiPathErrorCategory.DEPLOYMENT,
)

try:
self.context.langgraph_config.load_config()
except Exception as e:
raise LangGraphRuntimeError(
"CONFIG_INVALID",
"Invalid configuration",
f"Failed to load configuration: {str(e)}",
UiPathErrorCategory.DEPLOYMENT,
) from e

# Determine entrypoint if not provided
graphs = self.context.langgraph_config.graphs
if not self.context.entrypoint and len(graphs) == 1:
self.context.entrypoint = graphs[0].name
elif not self.context.entrypoint:
graph_names = ", ".join(g.name for g in graphs)
raise LangGraphRuntimeError(
"ENTRYPOINT_MISSING",
"Entrypoint required",
f"Multiple graphs available. Please specify one of: {graph_names}.",
UiPathErrorCategory.DEPLOYMENT,
)

# Get the specified graph
self.graph_config = self.context.langgraph_config.get_graph(
self.context.entrypoint
)
if not self.graph_config:
raise LangGraphRuntimeError(
"GRAPH_NOT_FOUND",
"Graph not found",
f"Graph '{self.context.entrypoint}' not found.",
UiPathErrorCategory.DEPLOYMENT,
)
try:
loaded_graph = await self.graph_config.load_graph()
self.context.state_graph = (
loaded_graph.builder
if isinstance(loaded_graph, CompiledStateGraph)
else loaded_graph
)
except ImportError as e:
raise LangGraphRuntimeError(
"GRAPH_IMPORT_ERROR",
"Graph import failed",
f"Failed to import graph '{self.context.entrypoint}': {str(e)}",
UiPathErrorCategory.USER,
) from e
except TypeError as e:
raise LangGraphRuntimeError(
"GRAPH_TYPE_ERROR",
"Invalid graph type",
f"Graph '{self.context.entrypoint}' is not a valid StateGraph or CompiledStateGraph: {str(e)}",
UiPathErrorCategory.USER,
) from e
except ValueError as e:
raise LangGraphRuntimeError(
"GRAPH_VALUE_ERROR",
"Invalid graph value",
f"Invalid value in graph '{self.context.entrypoint}': {str(e)}",
UiPathErrorCategory.USER,
) from e
except Exception as e:
raise LangGraphRuntimeError(
"GRAPH_LOAD_ERROR",
"Failed to load graph",
f"Unexpected error loading graph '{self.context.entrypoint}': {str(e)}",
UiPathErrorCategory.USER,
) from e
pass

async def cleanup(self):
if hasattr(self, "graph_config") and self.graph_config:
await self.graph_config.cleanup()
pass

def _extract_graph_result(
self, final_chunk, graph: CompiledStateGraph[Any, Any, Any]
Expand Down Expand Up @@ -377,3 +298,19 @@ def _pretty_print(self, stream_chunk: Union[Tuple[Any, Any], Dict[str, Any], Any
logger.info("%s", formatted_metadata)
except (TypeError, ValueError):
pass


class LangGraphScriptRuntime(LangGraphRuntime):
    """Runtime that resolves its graph from the ``langgraph.json`` config file.

    Builds a :class:`LangGraphJsonResolver` for the given entrypoint and
    hands it to the base runtime as the graph resolver.
    """

    def __init__(
        self, context: LangGraphRuntimeContext, entrypoint: Optional[str] = None
    ):
        resolver = LangGraphJsonResolver(entrypoint=entrypoint)
        # Kept on the instance so cleanup() can release resolver resources.
        self.resolver = resolver
        super().__init__(context, resolver)

    async def cleanup(self):
        """Run base-runtime cleanup, then release the resolver's resources."""
        await super().cleanup()
        await self.resolver.cleanup()
10 changes: 8 additions & 2 deletions src/uipath_langchain/_cli/cli_dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .._tracing import _instrument_traceable_attributes
from ._runtime._context import LangGraphRuntimeContext
from ._runtime._runtime import LangGraphRuntime
from ._runtime._runtime import LangGraphScriptRuntime

console = ConsoleLogger()

Expand All @@ -22,8 +22,14 @@ def langgraph_dev_middleware(interface: Optional[str]) -> MiddlewareResult:

try:
if interface == "terminal":

def generate_runtime(
ctx: LangGraphRuntimeContext,
) -> LangGraphScriptRuntime:
return LangGraphScriptRuntime(ctx, ctx.entrypoint)

runtime_factory = UiPathRuntimeFactory(
LangGraphRuntime, LangGraphRuntimeContext
LangGraphScriptRuntime, LangGraphRuntimeContext, generate_runtime
)

_instrument_traceable_attributes()
Expand Down
Loading