diff --git a/chainlit_client/api.py b/chainlit_client/api.py index 4b5a70e..aa1ce7e 100644 --- a/chainlit_client/api.py +++ b/chainlit_client/api.py @@ -474,7 +474,7 @@ async def create_thread( async def upsert_thread( self, - thread_id: str, + id: str, metadata: Optional[Dict] = None, participant_id: Optional[str] = None, environment: Optional[str] = None, @@ -483,14 +483,14 @@ async def upsert_thread( query = ( """ mutation UpsertThread( - $threadId: String!, + $id: String!, $metadata: Json, $participantId: String, $environment: String, $tags: [String!], ) { upsertThread( - id: $threadId + id: $id metadata: $metadata participantId: $participantId environment: $environment @@ -504,7 +504,7 @@ async def upsert_thread( """ ) variables = { - "threadId": thread_id, + "id": id, "metadata": metadata, "participantId": participant_id, "environment": environment, diff --git a/chainlit_client/client.py b/chainlit_client/client.py index 46f7c6f..d26b789 100644 --- a/chainlit_client/client.py +++ b/chainlit_client/client.py @@ -55,11 +55,15 @@ def langchain_callback( **kwargs, ) - def thread(self, original_function=None, *, thread_id: Optional[str] = None): + def thread( + self, original_function=None, *, thread_id: Optional[str] = None, **kwargs + ): if original_function: - return thread_decorator(self, func=original_function, thread_id=thread_id) + return thread_decorator( + self, func=original_function, thread_id=thread_id, **kwargs + ) else: - return ThreadContextManager(self, thread_id=thread_id) + return ThreadContextManager(self, thread_id=thread_id, **kwargs) def step( self, diff --git a/chainlit_client/context.py b/chainlit_client/context.py index 5620496..668e7b8 100644 --- a/chainlit_client/context.py +++ b/chainlit_client/context.py @@ -3,6 +3,7 @@ if TYPE_CHECKING: from chainlit_client.step import Step + from chainlit_client.thread import Thread active_steps_var = ContextVar[List["Step"]]("active_steps", default=[]) -active_thread_var = ContextVar[Optional[str]]("active_thread", default=None) +active_thread_var = ContextVar[Optional["Thread"]]("active_thread", default=None) diff --git a/chainlit_client/thread.py b/chainlit_client/thread.py index 208d194..ef904aa 100644 --- a/chainlit_client/thread.py +++ b/chainlit_client/thread.py @@ -1,5 +1,6 @@ import inspect import uuid +import asyncio from functools import wraps from typing import TYPE_CHECKING, Callable, Dict, List, Literal, Optional @@ -20,6 +21,7 @@ class Thread: steps: Optional[List[Step]] user: Optional["User"] created_at: Optional[str] # read-only, set by server + needs_upsert: Optional[bool] def __init__( self, @@ -28,12 +30,14 @@ def __init__( metadata: Optional[Dict] = {}, tags: Optional[List[str]] = [], user: Optional["User"] = None, + needs_upsert: "Optional[bool]" = False, ): self.id = id self.steps = steps self.metadata = metadata self.tags = tags self.user = user + self.needs_upsert = needs_upsert def to_dict(self): return { @@ -42,7 +46,6 @@ def to_dict(self): "tags": self.tags, "steps": [step.to_dict() for step in self.steps] if self.steps else [], "participant": self.user.to_dict() if self.user else None, - "createdAt": self.created_at, } @classmethod @@ -68,38 +71,54 @@ class ThreadContextManager: def __init__( self, client: "ChainlitClient", - thread_id: Optional[str] = None, + thread_id: "Optional[str]" = None, + **kwargs, ): self.client = client if thread_id is None: thread_id = str(uuid.uuid4()) self.thread_id = thread_id - active_thread_var.set(Thread(id=thread_id)) + needs_upsert = kwargs != {} + active_thread_var.set(Thread(id=thread_id, needs_upsert=needs_upsert, **kwargs)) + + async def save(self): + thread = active_thread_var.get() + thread_data = thread.to_dict() + thread_data.pop("steps", None) + thread_data.pop("participant", None) + await self.client.api.upsert_thread(**thread_data) def __call__(self, func): return thread_decorator(self.client, func=func, thread_id=self.thread_id) - def __enter__(self) -> Thread: + def __enter__(self) -> "Optional[Thread]": return active_thread_var.get() def __exit__(self, exc_type, exc_val, exc_tb): + if (thread := active_thread_var.get()) and thread.needs_upsert: + asyncio.run(self.save()) active_thread_var.set(None) async def __aenter__(self): return active_thread_var.get() async def __aexit__(self, exc_type, exc_val, exc_tb): + if (thread := active_thread_var.get()) and thread.needs_upsert: + await self.save() active_thread_var.set(None) def thread_decorator( - client: "ChainlitClient", func: Callable, thread_id: Optional[str] = None + client: "ChainlitClient", + func: Callable, + thread_id: Optional[str] = None, + **decorator_kwargs, ): if inspect.iscoroutinefunction(func): @wraps(func) async def async_wrapper(*args, **kwargs): - with ThreadContextManager(client, thread_id=thread_id): + with ThreadContextManager(client, thread_id=thread_id, **decorator_kwargs): result = await func(*args, **kwargs) return result @@ -108,7 +127,7 @@ async def async_wrapper(*args, **kwargs): @wraps(func) def sync_wrapper(*args, **kwargs): - with ThreadContextManager(client, thread_id=thread_id): + with ThreadContextManager(client, thread_id=thread_id, **decorator_kwargs): return func(*args, **kwargs) return sync_wrapper diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index 4cc52b1..37d1997 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -151,8 +151,8 @@ async def test_attachment(self, client): # @pytest.mark.skip(reason="segmentation fault") async def test_ingestion(self, client): - with client.thread(): - with client.step(name="test_ingestion") as step: + async with client.thread(): + async with client.step(name="test_ingestion") as step: step.metadata = {"foo": "bar"} assert client.event_processor.event_queue._qsize() == 0 stack = chainlit_client.context.active_steps_var.get() @@ -161,3 +161,13 @@ async def test_ingestion(self, client): assert client.event_processor.event_queue._qsize() == 1 stack = chainlit_client.context.active_steps_var.get() assert len(stack) == 0 + + async def create_thread(self, client): + async with client.thread(tags=["foo", "bar"]) as thread: + thread_id = thread.id + return thread_id + + async def test_thread_context(self, client): + thread_id = await self.create_thread(client) + new_thread = await client.api.get_thread(id=thread_id) + assert new_thread.tags == ["foo", "bar"] diff --git a/tests/test_threads_context.py b/tests/test_threads_context.py deleted file mode 100644 index 057ee4d..0000000 --- a/tests/test_threads_context.py +++ /dev/null @@ -1,13 +0,0 @@ -from chainlit_client.client import ChainlitClient - - -def function_where_i_want_the_thread(client): - return client.get_current_thread() - - -def test_thread_context(): - client = ChainlitClient() - with client.thread() as thread_from_parent: - thread_from_child = function_where_i_want_the_thread(client) - - assert thread_from_parent.id == thread_from_child.id