Skip to content

Commit

Permalink
langgraph: validate sync/async clients initialized correctly in Remot…
Browse files Browse the repository at this point in the history
…eGraph (#2214)
  • Loading branch information
vbarda authored Oct 29, 2024
1 parent 90195af commit e7dc43b
Showing 1 changed file with 46 additions and 16 deletions.
62 changes: 46 additions & 16 deletions libs/langgraph/langgraph/pregel/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,28 @@ def __init__(
"""
self.name = name
self.config = config
self.client = client or get_client(url=url, api_key=api_key, headers=headers)
self.sync_client = sync_client or get_sync_client(
url=url, api_key=api_key, headers=headers
)

if client is None and url is not None:
client = get_client(url=url, api_key=api_key, headers=headers)
self.client = client

if sync_client is None and url is not None:
sync_client = get_sync_client(url=url, api_key=api_key, headers=headers)
self.sync_client = sync_client

def _validate_client(self) -> LangGraphClient:
if self.client is None:
raise ValueError(
"Async client is not initialized: please provide `url` or `client` when initializing `RemoteGraph`."
)
return self.client

def _validate_sync_client(self) -> SyncLangGraphClient:
if self.sync_client is None:
raise ValueError(
"Sync client is not initialized: please provide `url` or `sync_client` when initializing `RemoteGraph`."
)
return self.sync_client

def copy(self, update: dict[str, Any]) -> Self:
attrs = {**self.__dict__, **update}
Expand Down Expand Up @@ -103,7 +121,8 @@ def get_graph(
*,
xray: Union[int, bool] = False,
) -> DrawableGraph:
graph = self.sync_client.assistants.get_graph(
sync_client = self._validate_sync_client()
graph = sync_client.assistants.get_graph(
assistant_id=self.name,
xray=xray,
)
Expand All @@ -118,7 +137,8 @@ async def aget_graph(
*,
xray: Union[int, bool] = False,
) -> DrawableGraph:
graph = await self.client.assistants.get_graph(
client = self._validate_client()
graph = await client.assistants.get_graph(
assistant_id=self.name,
xray=xray,
)
Expand Down Expand Up @@ -248,9 +268,10 @@ def _sanitize_obj(obj: Any) -> Any:
def get_state(
self, config: RunnableConfig, *, subgraphs: bool = False
) -> StateSnapshot:
sync_client = self._validate_sync_client()
merged_config = merge_configs(self.config, config)

state = self.sync_client.threads.get_state(
state = sync_client.threads.get_state(
thread_id=merged_config["configurable"]["thread_id"],
checkpoint=self._get_checkpoint(merged_config),
subgraphs=subgraphs,
Expand All @@ -260,9 +281,10 @@ def get_state(
async def aget_state(
self, config: RunnableConfig, *, subgraphs: bool = False
) -> StateSnapshot:
client = self._validate_client()
merged_config = merge_configs(self.config, config)

state = await self.client.threads.get_state(
state = await client.threads.get_state(
thread_id=merged_config["configurable"]["thread_id"],
checkpoint=self._get_checkpoint(merged_config),
subgraphs=subgraphs,
Expand All @@ -277,9 +299,10 @@ def get_state_history(
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> Iterator[StateSnapshot]:
sync_client = self._validate_sync_client()
merged_config = merge_configs(self.config, config)

states = self.sync_client.threads.get_history(
states = sync_client.threads.get_history(
thread_id=merged_config["configurable"]["thread_id"],
limit=limit if limit else 10,
before=self._get_checkpoint(before),
Expand All @@ -297,9 +320,10 @@ async def aget_state_history(
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> AsyncIterator[StateSnapshot]:
client = self._validate_client()
merged_config = merge_configs(self.config, config)

states = await self.client.threads.get_history(
states = await client.threads.get_history(
thread_id=merged_config["configurable"]["thread_id"],
limit=limit if limit else 10,
before=self._get_checkpoint(before),
Expand All @@ -315,9 +339,10 @@ def update_state(
values: Optional[Union[dict[str, Any], Any]],
as_node: Optional[str] = None,
) -> RunnableConfig:
sync_client = self._validate_sync_client()
merged_config = merge_configs(self.config, config)

response: dict = self.sync_client.threads.update_state( # type: ignore
response: dict = sync_client.threads.update_state( # type: ignore
thread_id=merged_config["configurable"]["thread_id"],
values=values,
as_node=as_node,
Expand All @@ -331,9 +356,10 @@ async def aupdate_state(
values: Optional[Union[dict[str, Any], Any]],
as_node: Optional[str] = None,
) -> RunnableConfig:
client = self._validate_client()
merged_config = merge_configs(self.config, config)

response: dict = await self.client.threads.update_state( # type: ignore
response: dict = await client.threads.update_state( # type: ignore
thread_id=merged_config["configurable"]["thread_id"],
values=values,
as_node=as_node,
Expand Down Expand Up @@ -382,11 +408,12 @@ def stream(
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
subgraphs: bool = False,
) -> Iterator[Union[dict[str, Any], Any]]:
sync_client = self._validate_sync_client()
merged_config = merge_configs(self.config, config)
sanitized_config = self._sanitize_config(merged_config)
stream_modes, req_updates, req_single = self._get_stream_modes(stream_mode)

for chunk in self.sync_client.runs.stream(
for chunk in sync_client.runs.stream(
thread_id=sanitized_config["configurable"].get("thread_id"),
assistant_id=self.name,
input=input,
Expand Down Expand Up @@ -429,11 +456,12 @@ async def astream(
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
subgraphs: bool = False,
) -> AsyncIterator[Union[dict[str, Any], Any]]:
client = self._validate_client()
merged_config = merge_configs(self.config, config)
sanitized_config = self._sanitize_config(merged_config)
stream_modes, req_updates, req_single = self._get_stream_modes(stream_mode)

async for chunk in self.client.runs.stream(
async for chunk in client.runs.stream(
thread_id=sanitized_config["configurable"].get("thread_id"),
assistant_id=self.name,
input=input,
Expand Down Expand Up @@ -490,10 +518,11 @@ def invoke(
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
) -> Union[dict[str, Any], Any]:
sync_client = self._validate_sync_client()
merged_config = merge_configs(self.config, config)
sanitized_config = self._sanitize_config(merged_config)

return self.sync_client.runs.wait(
return sync_client.runs.wait(
thread_id=sanitized_config["configurable"].get("thread_id"),
assistant_id=self.name,
input=input,
Expand All @@ -511,10 +540,11 @@ async def ainvoke(
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
) -> Union[dict[str, Any], Any]:
client = self._validate_client()
merged_config = merge_configs(self.config, config)
sanitized_config = self._sanitize_config(merged_config)

return await self.client.runs.wait(
return await client.runs.wait(
thread_id=sanitized_config["configurable"].get("thread_id"),
assistant_id=self.name,
input=input,
Expand Down

0 comments on commit e7dc43b

Please sign in to comment.