diff --git a/src/uipath_langchain/_cli/_runtime/_context.py b/src/uipath_langchain/_cli/_runtime/_context.py index 14f83ccf..cae98f3d 100644 --- a/src/uipath_langchain/_cli/_runtime/_context.py +++ b/src/uipath_langchain/_cli/_runtime/_context.py @@ -1,17 +1,12 @@ from typing import Any, Optional, Union from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver -from langgraph.graph import StateGraph from uipath._cli._runtime._contracts import UiPathRuntimeContext -from .._utils._graph import LangGraphConfig - class LangGraphRuntimeContext(UiPathRuntimeContext): """Context information passed throughout the runtime execution.""" - langgraph_config: Optional[LangGraphConfig] = None - state_graph: Optional[StateGraph[Any, Any]] = None output: Optional[Any] = None state: Optional[Any] = ( None # TypedDict issue, the actual type is: Optional[langgraph.types.StateSnapshot] diff --git a/src/uipath_langchain/_cli/_runtime/_graph_resolver.py b/src/uipath_langchain/_cli/_runtime/_graph_resolver.py new file mode 100644 index 00000000..c1f77110 --- /dev/null +++ b/src/uipath_langchain/_cli/_runtime/_graph_resolver.py @@ -0,0 +1,143 @@ +import asyncio +from typing import Any, Awaitable, Callable, Optional + +from langgraph.graph.state import CompiledStateGraph, StateGraph +from uipath._cli._runtime._contracts import ( + UiPathErrorCategory, +) + +from .._utils._graph import GraphConfig, LangGraphConfig +from ._exception import LangGraphRuntimeError + + +class LangGraphJsonResolver: + def __init__(self, entrypoint: Optional[str] = None) -> None: + self.entrypoint = entrypoint + self.graph_config: Optional[GraphConfig] = None + self._lock = asyncio.Lock() + self._graph_cache: Optional[StateGraph[Any, Any, Any]] = None + self._resolving: bool = False + + async def __call__(self) -> StateGraph[Any, Any, Any]: + # Fast path: if already resolved, return immediately without locking + if self._graph_cache is not None: + return self._graph_cache + + # Slow path: acquire lock and resolve + async with self._lock: + # Double-check after acquiring lock (another coroutine may have resolved it) + if self._graph_cache is not None: + return self._graph_cache + + self._graph_cache = await self._resolve(self.entrypoint) + return self._graph_cache + + async def _resolve(self, entrypoint: Optional[str]) -> StateGraph[Any, Any, Any]: + config = LangGraphConfig() + if not config.exists: + raise LangGraphRuntimeError( + "CONFIG_MISSING", + "Invalid configuration", + "Failed to load configuration", + UiPathErrorCategory.DEPLOYMENT, + ) + + try: + config.load_config() + except Exception as e: + raise LangGraphRuntimeError( + "CONFIG_INVALID", + "Invalid configuration", + f"Failed to load configuration: {str(e)}", + UiPathErrorCategory.DEPLOYMENT, + ) from e + + # Determine entrypoint if not provided + graphs = config.graphs + if not entrypoint and len(graphs) == 1: + entrypoint = graphs[0].name + elif not entrypoint: + graph_names = ", ".join(g.name for g in graphs) + raise LangGraphRuntimeError( + "ENTRYPOINT_MISSING", + "Entrypoint required", + f"Multiple graphs available. Please specify one of: {graph_names}.", + UiPathErrorCategory.DEPLOYMENT, + ) + + # Get the specified graph + self.graph_config = config.get_graph(entrypoint) + if not self.graph_config: + raise LangGraphRuntimeError( + "GRAPH_NOT_FOUND", + "Graph not found", + f"Graph '{entrypoint}' not found.", + UiPathErrorCategory.DEPLOYMENT, + ) + try: + loaded_graph = await self.graph_config.load_graph() + return ( + loaded_graph.builder + if isinstance(loaded_graph, CompiledStateGraph) + else loaded_graph + ) + except ImportError as e: + raise LangGraphRuntimeError( + "GRAPH_IMPORT_ERROR", + "Graph import failed", + f"Failed to import graph '{entrypoint}': {str(e)}", + UiPathErrorCategory.USER, + ) from e + except TypeError as e: + raise LangGraphRuntimeError( + "GRAPH_TYPE_ERROR", + "Invalid graph type", + f"Graph '{entrypoint}' is not a valid StateGraph or CompiledStateGraph: {str(e)}", + UiPathErrorCategory.USER, + ) from e + except ValueError as e: + raise LangGraphRuntimeError( + "GRAPH_VALUE_ERROR", + "Invalid graph value", + f"Invalid value in graph '{entrypoint}': {str(e)}", + UiPathErrorCategory.USER, + ) from e + except Exception as e: + raise LangGraphRuntimeError( + "GRAPH_LOAD_ERROR", + "Failed to load graph", + f"Unexpected error loading graph '{entrypoint}': {str(e)}", + UiPathErrorCategory.USER, + ) from e + + async def cleanup(self): + """Clean up resources""" + async with self._lock: + if self.graph_config: + await self.graph_config.cleanup() + self.graph_config = None + self._graph_cache = None + + +AsyncResolver = Callable[[], Awaitable[StateGraph[Any, Any, Any]]] + + +class LangGraphJsonResolverContext: + """ + Async context manager wrapping LangGraphJsonResolver. + Returns a callable that can be passed directly as AsyncResolver to LangGraphRuntime. + Thread-safe and reuses the same resolved graph across concurrent executions. + """ + + def __init__(self, entrypoint: Optional[str] = None) -> None: + self._resolver = LangGraphJsonResolver(entrypoint) + + async def __aenter__(self) -> AsyncResolver: + # Return a callable that safely reuses the cached graph + async def resolver_callable() -> StateGraph[Any, Any, Any]: + return await self._resolver() + + return resolver_callable + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + await self._resolver.cleanup() diff --git a/src/uipath_langchain/_cli/_runtime/_runtime.py b/src/uipath_langchain/_cli/_runtime/_runtime.py index 22eb1ff9..cebcef84 100644 --- a/src/uipath_langchain/_cli/_runtime/_runtime.py +++ b/src/uipath_langchain/_cli/_runtime/_runtime.py @@ -15,10 +15,10 @@ UiPathRuntimeResult, ) -from .._utils._graph import LangGraphConfig from ._context import LangGraphRuntimeContext from ._conversation import map_message from ._exception import LangGraphRuntimeError +from ._graph_resolver import AsyncResolver, LangGraphJsonResolver from ._input import LangGraphInputProcessor from ._output import LangGraphOutputProcessor @@ -31,9 +31,10 @@ class LangGraphRuntime(UiPathBaseRuntime): This allows using the class with 'async with' statements. """ - def __init__(self, context: LangGraphRuntimeContext): + def __init__(self, context: LangGraphRuntimeContext, graph_resolver: AsyncResolver): super().__init__(context) self.context: LangGraphRuntimeContext = context + self.graph_resolver: AsyncResolver = graph_resolver async def execute(self) -> Optional[UiPathRuntimeResult]: """ @@ -46,7 +47,8 @@ async def execute(self) -> Optional[UiPathRuntimeResult]: LangGraphRuntimeError: If execution fails """ - if self.context.state_graph is None: + graph = await self.graph_resolver() + if not graph: return None try: @@ -56,9 +58,7 @@ async def execute(self) -> Optional[UiPathRuntimeResult]: self.context.memory = memory # Compile the graph with the checkpointer - graph = self.context.state_graph.compile( - checkpointer=self.context.memory - ) + compiled_graph = graph.compile(checkpointer=self.context.memory) # Process input, handling resume if needed input_processor = LangGraphInputProcessor(context=self.context) @@ -87,7 +87,7 @@ async def execute(self) -> Optional[UiPathRuntimeResult]: graph_config["max_concurrency"] = int(max_concurrency) if self.context.chat_handler: - async for stream_chunk in graph.astream( + async for stream_chunk in compiled_graph.astream( processed_input, graph_config, stream_mode="messages", @@ -109,7 +109,7 @@ async def execute(self) -> Optional[UiPathRuntimeResult]: elif self.is_debug_run(): # Get final chunk while streaming final_chunk = None - async for stream_chunk in graph.astream( + async for stream_chunk in compiled_graph.astream( processed_input, graph_config, stream_mode="updates", @@ -118,16 +118,18 @@ async def execute(self) -> Optional[UiPathRuntimeResult]: self._pretty_print(stream_chunk) final_chunk = stream_chunk - self.context.output = self._extract_graph_result(final_chunk, graph) + self.context.output = self._extract_graph_result( + final_chunk, compiled_graph + ) else: # Execute the graph normally at runtime or eval - self.context.output = await graph.ainvoke( + self.context.output = await compiled_graph.ainvoke( processed_input, graph_config ) # Get the state if available try: - self.context.state = await graph.aget_state(graph_config) + self.context.state = await compiled_graph.aget_state(graph_config) except Exception: pass @@ -177,91 +179,10 @@ async def execute(self) -> Optional[UiPathRuntimeResult]: pass async def validate(self) -> None: - """Validate runtime inputs.""" - """Load and validate the graph configuration .""" - if self.context.langgraph_config is None: - self.context.langgraph_config = LangGraphConfig() - if not self.context.langgraph_config.exists: - raise LangGraphRuntimeError( - "CONFIG_MISSING", - "Invalid configuration", - "Failed to load configuration", - UiPathErrorCategory.DEPLOYMENT, - ) - - try: - self.context.langgraph_config.load_config() - except Exception as e: - raise LangGraphRuntimeError( - "CONFIG_INVALID", - "Invalid configuration", - f"Failed to load configuration: {str(e)}", - UiPathErrorCategory.DEPLOYMENT, - ) from e - - # Determine entrypoint if not provided - graphs = self.context.langgraph_config.graphs - if not self.context.entrypoint and len(graphs) == 1: - self.context.entrypoint = graphs[0].name - elif not self.context.entrypoint: - graph_names = ", ".join(g.name for g in graphs) - raise LangGraphRuntimeError( - "ENTRYPOINT_MISSING", - "Entrypoint required", - f"Multiple graphs available. Please specify one of: {graph_names}.", - UiPathErrorCategory.DEPLOYMENT, - ) - - # Get the specified graph - self.graph_config = self.context.langgraph_config.get_graph( - self.context.entrypoint - ) - if not self.graph_config: - raise LangGraphRuntimeError( - "GRAPH_NOT_FOUND", - "Graph not found", - f"Graph '{self.context.entrypoint}' not found.", - UiPathErrorCategory.DEPLOYMENT, - ) - try: - loaded_graph = await self.graph_config.load_graph() - self.context.state_graph = ( - loaded_graph.builder - if isinstance(loaded_graph, CompiledStateGraph) - else loaded_graph - ) - except ImportError as e: - raise LangGraphRuntimeError( - "GRAPH_IMPORT_ERROR", - "Graph import failed", - f"Failed to import graph '{self.context.entrypoint}': {str(e)}", - UiPathErrorCategory.USER, - ) from e - except TypeError as e: - raise LangGraphRuntimeError( - "GRAPH_TYPE_ERROR", - "Invalid graph type", - f"Graph '{self.context.entrypoint}' is not a valid StateGraph or CompiledStateGraph: {str(e)}", - UiPathErrorCategory.USER, - ) from e - except ValueError as e: - raise LangGraphRuntimeError( - "GRAPH_VALUE_ERROR", - "Invalid graph value", - f"Invalid value in graph '{self.context.entrypoint}': {str(e)}", - UiPathErrorCategory.USER, - ) from e - except Exception as e: - raise LangGraphRuntimeError( - "GRAPH_LOAD_ERROR", - "Failed to load graph", - f"Unexpected error loading graph '{self.context.entrypoint}': {str(e)}", - UiPathErrorCategory.USER, - ) from e + pass async def cleanup(self): - if hasattr(self, "graph_config") and self.graph_config: - await self.graph_config.cleanup() + pass def _extract_graph_result( self, final_chunk, graph: CompiledStateGraph[Any, Any, Any] @@ -377,3 +298,19 @@ def _pretty_print(self, stream_chunk: Union[Tuple[Any, Any], Dict[str, Any], Any logger.info("%s", formatted_metadata) except (TypeError, ValueError): pass + + +class LangGraphScriptRuntime(LangGraphRuntime): + """ + Resolves the graph from langgraph.json config file and passes it to the base runtime. + """ + + def __init__( + self, context: LangGraphRuntimeContext, entrypoint: Optional[str] = None + ): + self.resolver = LangGraphJsonResolver(entrypoint=entrypoint) + super().__init__(context, self.resolver) + + async def cleanup(self): + await super().cleanup() + await self.resolver.cleanup() diff --git a/src/uipath_langchain/_cli/cli_dev.py b/src/uipath_langchain/_cli/cli_dev.py index d0f3d585..3b9ad95e 100644 --- a/src/uipath_langchain/_cli/cli_dev.py +++ b/src/uipath_langchain/_cli/cli_dev.py @@ -12,7 +12,7 @@ from .._tracing import _instrument_traceable_attributes from ._runtime._context import LangGraphRuntimeContext -from ._runtime._runtime import LangGraphRuntime +from ._runtime._runtime import LangGraphScriptRuntime console = ConsoleLogger() @@ -22,8 +22,14 @@ def langgraph_dev_middleware(interface: Optional[str]) -> MiddlewareResult: try: if interface == "terminal": + + def generate_runtime( + ctx: LangGraphRuntimeContext, + ) -> LangGraphScriptRuntime: + return LangGraphScriptRuntime(ctx, ctx.entrypoint) + runtime_factory = UiPathRuntimeFactory( - LangGraphRuntime, LangGraphRuntimeContext + LangGraphScriptRuntime, LangGraphRuntimeContext, generate_runtime ) _instrument_traceable_attributes() diff --git a/src/uipath_langchain/_cli/cli_eval.py b/src/uipath_langchain/_cli/cli_eval.py index 5e52ef5d..51d00eec 100644 --- a/src/uipath_langchain/_cli/cli_eval.py +++ b/src/uipath_langchain/_cli/cli_eval.py @@ -17,7 +17,7 @@ from uipath.eval._helpers import auto_discover_entrypoint from uipath_langchain._cli._runtime._context import LangGraphRuntimeContext -from uipath_langchain._cli._runtime._runtime import LangGraphRuntime +from uipath_langchain._cli._runtime._runtime import LangGraphScriptRuntime from uipath_langchain._cli._utils._graph import LangGraphConfig from uipath_langchain._tracing import ( LangChainExporter, @@ -48,10 +48,9 @@ def langgraph_eval_middleware( asyncio.run(console_reporter.subscribe_to_eval_runtime_events(event_bus)) def generate_runtime_context( - context_entrypoint: str, langgraph_config: LangGraphConfig, **context_kwargs + context_entrypoint: str, **context_kwargs ) -> LangGraphRuntimeContext: context = LangGraphRuntimeContext.with_defaults(**context_kwargs) - context.langgraph_config = langgraph_config context.entrypoint = context_entrypoint return context @@ -63,14 +62,17 @@ def generate_runtime_context( eval_context.eval_set = eval_set or EvalHelpers.auto_discover_eval_set() eval_context.eval_ids = eval_ids + def generate_runtime(ctx: LangGraphRuntimeContext) -> LangGraphScriptRuntime: + return LangGraphScriptRuntime(ctx, ctx.entrypoint) + runtime_factory = UiPathRuntimeFactory( - LangGraphRuntime, + LangGraphScriptRuntime, LangGraphRuntimeContext, context_generator=lambda **context_kwargs: generate_runtime_context( context_entrypoint=runtime_entrypoint, - langgraph_config=config, **context_kwargs, ), + runtime_generator=generate_runtime, ) if eval_context.job_id: diff --git a/src/uipath_langchain/_cli/cli_run.py b/src/uipath_langchain/_cli/cli_run.py index 61d0cc4c..2c154c00 100644 --- a/src/uipath_langchain/_cli/cli_run.py +++ b/src/uipath_langchain/_cli/cli_run.py @@ -14,8 +14,8 @@ from .._tracing import LangChainExporter, _instrument_traceable_attributes from ._runtime._exception import LangGraphRuntimeError from ._runtime._runtime import ( # type: ignore[attr-defined] - LangGraphRuntime, LangGraphRuntimeContext, + LangGraphScriptRuntime, ) from ._utils._graph import LangGraphConfig @@ -32,15 +32,14 @@ def langgraph_run_middleware( try: context = LangGraphRuntimeContext.with_defaults(**kwargs) - context.langgraph_config = config context.entrypoint = entrypoint context.input = input context.resume = resume _instrument_traceable_attributes() - def generate_runtime(ctx: LangGraphRuntimeContext) -> LangGraphRuntime: - runtime = LangGraphRuntime(ctx) + def generate_runtime(ctx: LangGraphRuntimeContext) -> LangGraphScriptRuntime: + runtime = LangGraphScriptRuntime(ctx, ctx.entrypoint) # If not resuming and no job id, delete the previous state file if not ctx.resume and ctx.job_id is None: if os.path.exists(runtime.state_file_path): @@ -49,7 +48,7 @@ def generate_runtime(ctx: LangGraphRuntimeContext) -> LangGraphRuntime: async def execute(): runtime_factory = UiPathRuntimeFactory( - LangGraphRuntime, + LangGraphScriptRuntime, LangGraphRuntimeContext, runtime_generator=generate_runtime, )