Skip to content

Commit

Permalink
feat: upsert thread when other data is passed to thread context manager
Browse files Browse the repository at this point in the history
  • Loading branch information
Pawel Morawian committed Dec 28, 2023
1 parent fe994b0 commit 5a18e5f
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 30 deletions.
8 changes: 4 additions & 4 deletions chainlit_client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ async def create_thread(

async def upsert_thread(
self,
thread_id: str,
id: str,
metadata: Optional[Dict] = None,
participant_id: Optional[str] = None,
environment: Optional[str] = None,
Expand All @@ -483,14 +483,14 @@ async def upsert_thread(
query = (
"""
mutation UpsertThread(
$threadId: String!,
$id: String!,
$metadata: Json,
$participantId: String,
$environment: String,
$tags: [String!],
) {
upsertThread(
id: $threadId
id: $id
metadata: $metadata
participantId: $participantId
environment: $environment
Expand All @@ -504,7 +504,7 @@ async def upsert_thread(
"""
)
variables = {
"threadId": thread_id,
"id": id,
"metadata": metadata,
"participantId": participant_id,
"environment": environment,
Expand Down
10 changes: 7 additions & 3 deletions chainlit_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,15 @@ def langchain_callback(
**kwargs,
)

def thread(self, original_function=None, *, thread_id: Optional[str] = None):
def thread(
self, original_function=None, *, thread_id: Optional[str] = None, **kwargs
):
if original_function:
return thread_decorator(self, func=original_function, thread_id=thread_id)
return thread_decorator(
self, func=original_function, thread_id=thread_id, **kwargs
)
else:
return ThreadContextManager(self, thread_id=thread_id)
return ThreadContextManager(self, thread_id=thread_id, **kwargs)

def step(
self,
Expand Down
3 changes: 2 additions & 1 deletion chainlit_client/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

if TYPE_CHECKING:
from chainlit_client.step import Step
from chainlit_client.thread import Thread

active_steps_var = ContextVar[List["Step"]]("active_steps", default=[])
active_thread_var = ContextVar[Optional[str]]("active_thread", default=None)
active_thread_var = ContextVar[Optional["Thread"]]("active_thread", default=None)
33 changes: 26 additions & 7 deletions chainlit_client/thread.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
import uuid
import asyncio
from functools import wraps
from typing import TYPE_CHECKING, Callable, Dict, List, Literal, Optional

Expand All @@ -20,6 +21,7 @@ class Thread:
steps: Optional[List[Step]]
user: Optional["User"]
created_at: Optional[str] # read-only, set by server
needs_upsert: Optional[bool]

def __init__(
self,
Expand All @@ -28,12 +30,14 @@ def __init__(
metadata: Optional[Dict] = {},
tags: Optional[List[str]] = [],
user: Optional["User"] = None,
needs_upsert: "Optional[bool]" = False,
):
self.id = id
self.steps = steps
self.metadata = metadata
self.tags = tags
self.user = user
self.needs_upsert = needs_upsert

def to_dict(self):
return {
Expand All @@ -42,7 +46,6 @@ def to_dict(self):
"tags": self.tags,
"steps": [step.to_dict() for step in self.steps] if self.steps else [],
"participant": self.user.to_dict() if self.user else None,
"createdAt": self.created_at,
}

@classmethod
Expand All @@ -68,38 +71,54 @@ class ThreadContextManager:
def __init__(
self,
client: "ChainlitClient",
thread_id: Optional[str] = None,
thread_id: "Optional[str]" = None,
**kwargs,
):
self.client = client
if thread_id is None:
thread_id = str(uuid.uuid4())
self.thread_id = thread_id
active_thread_var.set(Thread(id=thread_id))
needs_upsert = kwargs != {}
active_thread_var.set(Thread(id=thread_id, needs_upsert=needs_upsert, **kwargs))

async def save(self):
thread = active_thread_var.get()
thread_data = thread.to_dict()
thread_data.pop("steps", None)
thread_data.pop("participant", None)
await self.client.api.upsert_thread(**thread_data)

def __call__(self, func):
return thread_decorator(self.client, func=func, thread_id=self.thread_id)

def __enter__(self) -> Thread:
def __enter__(self) -> "Optional[Thread]":
return active_thread_var.get()

def __exit__(self, exc_type, exc_val, exc_tb):
if (thread := active_thread_var.get()) and thread.needs_upsert:
asyncio.run(self.save())
active_thread_var.set(None)

async def __aenter__(self):
return active_thread_var.get()

async def __aexit__(self, exc_type, exc_val, exc_tb):
if (thread := active_thread_var.get()) and thread.needs_upsert:
await self.save()
active_thread_var.set(None)


def thread_decorator(
client: "ChainlitClient", func: Callable, thread_id: Optional[str] = None
client: "ChainlitClient",
func: Callable,
thread_id: Optional[str] = None,
**decorator_kwargs,
):
if inspect.iscoroutinefunction(func):

@wraps(func)
async def async_wrapper(*args, **kwargs):
with ThreadContextManager(client, thread_id=thread_id):
with ThreadContextManager(client, thread_id=thread_id, **decorator_kwargs):
result = await func(*args, **kwargs)
return result

Expand All @@ -108,7 +127,7 @@ async def async_wrapper(*args, **kwargs):

@wraps(func)
def sync_wrapper(*args, **kwargs):
with ThreadContextManager(client, thread_id=thread_id):
with ThreadContextManager(client, thread_id=thread_id, **decorator_kwargs):
return func(*args, **kwargs)

return sync_wrapper
Expand Down
14 changes: 12 additions & 2 deletions tests/e2e/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ async def test_attachment(self, client):

# @pytest.mark.skip(reason="segmentation fault")
async def test_ingestion(self, client):
with client.thread():
with client.step(name="test_ingestion") as step:
async with client.thread():
async with client.step(name="test_ingestion") as step:
step.metadata = {"foo": "bar"}
assert client.event_processor.event_queue._qsize() == 0
stack = chainlit_client.context.active_steps_var.get()
Expand All @@ -161,3 +161,13 @@ async def test_ingestion(self, client):
assert client.event_processor.event_queue._qsize() == 1
stack = chainlit_client.context.active_steps_var.get()
assert len(stack) == 0

async def create_thread(self, client):
async with client.thread(tags=["foo", "bar"]) as thread:
thread_id = thread.id
return thread_id

async def test_thread_context(self, client):
thread_id = await self.create_thread(client)
new_thread = await client.api.get_thread(id=thread_id)
assert new_thread.tags == ["foo", "bar"]
13 changes: 0 additions & 13 deletions tests/test_threads_context.py

This file was deleted.

0 comments on commit 5a18e5f

Please sign in to comment.