From 82c316b3f70ad0a043cb96bbf0af9be30b832f83 Mon Sep 17 00:00:00 2001 From: Andrew Nguonly Date: Tue, 8 Oct 2024 12:04:26 -0700 Subject: [PATCH] sdk-py: Update return type annotation for `Thread.update_state()` methods. (#2050) ### Summary The response body of the endpoint `POST /threads/{thread_id}/state` looks like this: ``` { "checkpoint": { "thread_id": "e2496803-ecd5-4e0c-a779-3226296181c2", "checkpoint_ns": "", "checkpoint_id": "1ef4a9b8-e6fb-67b1-8001-abd5184439d1", "checkpoint_map": {} } } ``` --- libs/sdk-py/langgraph_sdk/client.py | 37 ++++++++++++++++++++++++----- libs/sdk-py/langgraph_sdk/schema.py | 7 ++++++ 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/libs/sdk-py/langgraph_sdk/client.py b/libs/sdk-py/langgraph_sdk/client.py index 8d7fccdc6..409fcd545 100644 --- a/libs/sdk-py/langgraph_sdk/client.py +++ b/libs/sdk-py/langgraph_sdk/client.py @@ -53,6 +53,7 @@ Thread, ThreadState, ThreadStatus, + ThreadUpdateStateResponse, ) logger = logging.getLogger(__name__) @@ -1060,7 +1061,7 @@ async def update_state( as_node: Optional[str] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, # deprecated - ) -> None: + ) -> ThreadUpdateStateResponse: """Update the state of a thread. Args: @@ -1070,15 +1071,27 @@ async def update_state( checkpoint: The checkpoint to update the state of. Returns: - None + ThreadUpdateStateResponse: Response after updating a thread's state. Example Usage: - await client.threads.update_state( + response = await client.threads.update_state( thread_id="my_thread_id", values={"messages":[{"role": "user", "content": "hello!"}]}, as_node="my_node", ) + print(response) + + ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- + + { + 'checkpoint': { + 'thread_id': 'e2496803-ecd5-4e0c-a779-3226296181c2', + 'checkpoint_ns': '', + 'checkpoint_id': '1ef4a9b8-e6fb-67b1-8001-abd5184439d1', + 'checkpoint_map': {} + } + } """ # noqa: E501 payload: Dict[str, Any] = { @@ -3109,7 +3122,7 @@ def update_state( as_node: Optional[str] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, # deprecated - ) -> None: + ) -> ThreadUpdateStateResponse: """Update the state of a thread. Args: @@ -3119,15 +3132,27 @@ def update_state( checkpoint: The checkpoint to update the state of. Returns: - None + ThreadUpdateStateResponse: Response after updating a thread's state. Example Usage: - await client.threads.update_state( + response = client.threads.update_state( thread_id="my_thread_id", values={"messages":[{"role": "user", "content": "hello!"}]}, as_node="my_node", ) + print(response) + + ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- + + { + 'checkpoint': { + 'thread_id': 'e2496803-ecd5-4e0c-a779-3226296181c2', + 'checkpoint_ns': '', + 'checkpoint_id': '1ef4a9b8-e6fb-67b1-8001-abd5184439d1', + 'checkpoint_map': {} + } + } """ # noqa: E501 payload: Dict[str, Any] = { diff --git a/libs/sdk-py/langgraph_sdk/schema.py b/libs/sdk-py/langgraph_sdk/schema.py index c6101bdce..bc7bf558e 100644 --- a/libs/sdk-py/langgraph_sdk/schema.py +++ b/libs/sdk-py/langgraph_sdk/schema.py @@ -207,6 +207,13 @@ class ThreadState(TypedDict): """Tasks to execute in this step. If already attempted, may contain an error.""" +class ThreadUpdateStateResponse(TypedDict): + """Represents the response from updating a thread's state.""" + + checkpoint: Checkpoint + """Checkpoint of the latest state.""" + + class Run(TypedDict): """Represents a single execution run."""