Skip to content

WIP: Pawel/eng 651 threads management #29

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ __pycache__/

# C extensions
*.so
.vscode/

# Distribution / packaging
.Python
Expand Down Expand Up @@ -47,6 +48,7 @@ htmlcov/
.coverage
.coverage.*
.cache
.ruff_cache
nosetests.xml
coverage.xml
*.cover
Expand Down
4 changes: 4 additions & 0 deletions chainlit_client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import nest_asyncio

from .client import ChainlitClient
from .message import Message
from .my_types import * # noqa
from .step import Step
from .thread import Thread
from .version import __version__

nest_asyncio.apply()

__all__ = [
"ChainlitClient",
"Message",
Expand Down
17 changes: 9 additions & 8 deletions chainlit_client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,14 +486,14 @@ async def upsert_thread(
query = (
"""
mutation UpsertThread(
$threadId: String!,
$id: String!,
$metadata: Json,
$participantId: String,
$environment: String,
$tags: [String!],
) {
upsertThread(
id: $threadId
id: $id
metadata: $metadata
participantId: $participantId
environment: $environment
Expand All @@ -507,7 +507,7 @@ async def upsert_thread(
"""
)
variables = {
"threadId": thread_id,
"id": thread_id,
"metadata": metadata,
"participantId": participant_id,
"environment": environment,
Expand Down Expand Up @@ -588,7 +588,7 @@ async def get_thread(self, id: str) -> Optional[Thread]:

return Thread.from_dict(thread) if thread else None

async def delete_thread(self, id: str) -> str:
async def delete_thread(self, id: str) -> bool:
query = """
mutation DeleteThread($thread_id: String!) {
deleteThread(id: $thread_id) {
Expand All @@ -600,8 +600,8 @@ async def delete_thread(self, id: str) -> str:
variables = {"thread_id": id}

result = await self.make_api_call("delete thread", query, variables)

return result["data"]["deleteThread"]["id"]
deleted = bool(result["data"]["deleteThread"])
return deleted

# User Session API

Expand Down Expand Up @@ -1150,7 +1150,7 @@ async def get_step(self, id: str) -> Optional[Step]:

return Step.from_dict(step) if step else None

async def delete_step(self, id: str) -> str:
async def delete_step(self, id: str) -> bool:
query = """
mutation DeleteStep($id: String!) {
deleteStep(id: $id) {
Expand All @@ -1163,7 +1163,8 @@ async def delete_step(self, id: str) -> str:

result = await self.make_api_call("delete step", query, variables)

return result["data"]["deleteStep"]["id"]
deleted = bool(result["data"]["deleteStep"])
return deleted

async def send_steps(self, steps: List[Union[StepDict, "Step"]]) -> "Dict":
query = query_builder(steps)
Expand Down
6 changes: 3 additions & 3 deletions chainlit_client/callback/langchain_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from importlib.metadata import version
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from chainlit_client.context import active_steps_var, active_thread_id_var
from chainlit_client.context import active_steps_var, active_thread_var
from chainlit_client.my_types import (
ChatGeneration,
CompletionGeneration,
Expand Down Expand Up @@ -356,7 +356,7 @@ def __init__(
self.ignored_runs = set()

self.step_context = active_steps_var.get()
self.thread_context = active_thread_id_var.get()
self.thread_context = active_thread_var.get()

if self.thread_context is None:
raise Exception(
Expand Down Expand Up @@ -463,7 +463,7 @@ def on_chat_model_start(
def _start_trace(self, run: Run) -> None:
super()._start_trace(run)

active_thread_id_var.set(self.thread_context)
active_thread_var.set(self.thread_context)
active_steps_var.set(self.step_context)

if run.run_type in ["chain", "prompt"]:
Expand Down
16 changes: 10 additions & 6 deletions chainlit_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from chainlit_client.api import API
from chainlit_client.callback.langchain_callback import get_langchain_callback
from chainlit_client.context import active_steps_var, active_thread_id_var
from chainlit_client.context import active_steps_var, active_thread_var
from chainlit_client.event_processor import EventProcessor
from chainlit_client.instrumentation.openai import instrument_openai
from chainlit_client.message import Message
Expand Down Expand Up @@ -55,11 +55,15 @@ def langchain_callback(
**kwargs,
)

def thread(self, original_function=None, *, thread_id: Optional[str] = None):
def thread(
self, original_function=None, *, thread_id: Optional[str] = None, **kwargs
):
if original_function:
return thread_decorator(self, func=original_function, thread_id=thread_id)
return thread_decorator(
self, func=original_function, thread_id=thread_id, **kwargs
)
else:
return ThreadContextManager(self, thread_id=thread_id)
return ThreadContextManager(self, thread_id=thread_id, **kwargs)

def step(
self,
Expand Down Expand Up @@ -141,8 +145,8 @@ def get_current_step(self):
else:
return None

def get_current_thread_id(self):
return active_thread_id_var.get()
def get_current_thread(self):
return active_thread_var.get()

def wait_until_queue_empty(self):
self.event_processor.wait_until_queue_empty()
3 changes: 2 additions & 1 deletion chainlit_client/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

if TYPE_CHECKING:
from chainlit_client.step import Step
from chainlit_client.thread import Thread

active_steps_var = ContextVar[List["Step"]]("active_steps", default=[])
active_thread_id_var = ContextVar[Optional[str]]("active_thread", default=None)
active_thread_var = ContextVar[Optional["Thread"]]("active_thread", default=None)
6 changes: 3 additions & 3 deletions chainlit_client/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
if TYPE_CHECKING:
from chainlit_client.event_processor import EventProcessor

from chainlit_client.context import active_steps_var, active_thread_id_var
from chainlit_client.context import active_steps_var, active_thread_var
from chainlit_client.my_types import Attachment, Feedback
from chainlit_client.step import MessageStepType, StepDict

Expand Down Expand Up @@ -75,8 +75,8 @@ def end(self):
self.thread_id = parent_step.thread_id

if not self.thread_id:
if active_thread := active_thread_id_var.get():
self.thread_id = active_thread
if active_thread := active_thread_var.get():
self.thread_id = active_thread.id

if not self.thread_id:
raise Exception("Message must be initialized with a thread_id.")
Expand Down
39 changes: 16 additions & 23 deletions chainlit_client/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from chainlit_client.client import ChainlitClient
from chainlit_client.event_processor import EventProcessor

from chainlit_client.context import active_steps_var, active_thread_id_var
from chainlit_client.context import active_steps_var, active_thread_var
from chainlit_client.my_types import (
Attachment,
AttachmentDict,
Expand Down Expand Up @@ -110,8 +110,8 @@ def start(self):
self.thread_id = parent_step.thread_id

if not self.thread_id:
if active_thread := active_thread_id_var.get():
self.thread_id = active_thread
if active_thread := active_thread_var.get():
self.thread_id = active_thread.id

if not self.thread_id:
raise Exception("Step must be initialized with a thread_id.")
Expand Down Expand Up @@ -218,9 +218,7 @@ def __call__(self, func):
return step_decorator(
self.client,
func=func,
type=self.step_type,
name=self.step_name,
thread_id=self.thread_id,
ctx_manager=self,
)

async def __aenter__(self):
Expand Down Expand Up @@ -258,23 +256,25 @@ def step_decorator(
id: Optional[str] = None,
parent_id: Optional[str] = None,
thread_id: Optional[str] = None,
ctx_manager: Optional[StepContextManager] = None,
):
if not name:
name = func.__name__

if not ctx_manager:
ctx_manager = StepContextManager(
client=client,
type=type,
name=name,
id=id,
parent_id=parent_id,
thread_id=thread_id,
)
# Handle async decorator
if inspect.iscoroutinefunction(func):

@wraps(func)
async def async_wrapper(*args, **kwargs):
with StepContextManager(
client=client,
type=type,
name=name,
id=id,
parent_id=parent_id,
thread_id=thread_id,
) as step:
with ctx_manager as step:
try:
step.input = json.dumps({"args": args, "kwargs": kwargs})
except Exception:
Expand All @@ -292,14 +292,7 @@ async def async_wrapper(*args, **kwargs):
# Handle sync decorator
@wraps(func)
def sync_wrapper(*args, **kwargs):
with StepContextManager(
client=client,
type=type,
name=name,
id=id,
parent_id=parent_id,
thread_id=thread_id,
) as step:
with ctx_manager as step:
try:
step.input = json.dumps({"args": args, "kwargs": kwargs})
except Exception:
Expand Down
58 changes: 44 additions & 14 deletions chainlit_client/thread.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import asyncio
import inspect
import uuid
from functools import wraps
from typing import TYPE_CHECKING, Callable, Dict, List, Literal, Optional

from pydantic.dataclasses import dataclass

from chainlit_client.context import active_thread_id_var
from chainlit_client.context import active_thread_var
from chainlit_client.my_types import User
from chainlit_client.step import Step

Expand All @@ -20,6 +21,7 @@ class Thread:
steps: Optional[List[Step]]
user: Optional["User"]
created_at: Optional[str] # read-only, set by server
needs_upsert: Optional[bool]

def __init__(
self,
Expand All @@ -34,6 +36,7 @@ def __init__(
self.metadata = metadata
self.tags = tags
self.user = user
self.needs_upsert = bool(metadata or tags or user)

def to_dict(self):
return {
Expand All @@ -42,7 +45,6 @@ def to_dict(self):
"tags": self.tags,
"steps": [step.to_dict() for step in self.steps] if self.steps else [],
"participant": self.user.to_dict() if self.user else None,
"createdAt": self.created_at,
}

@classmethod
Expand All @@ -68,39 +70,67 @@ class ThreadContextManager:
def __init__(
self,
client: "ChainlitClient",
thread_id: Optional[str] = None,
thread_id: "Optional[str]" = None,
**kwargs,
):
self.client = client
if thread_id is None:
thread_id = str(uuid.uuid4())
self.thread_id = thread_id
active_thread_id_var.set(thread_id)
self.thread = Thread(id=thread_id)
self.kwargs = kwargs

async def upsert(self):
thread = active_thread_var.get()
thread_data = thread.to_dict()
thread_data_to_upsert = {
"thread_id": thread_data["id"],
}
if metadata := thread_data.get("metadata"):
thread_data_to_upsert["metadata"] = metadata
if tags := thread_data.get("tags"):
thread_data_to_upsert["tags"] = tags
if user := thread_data.get("user"):
thread_data_to_upsert["participant_id"] = user
await self.client.api.upsert_thread(**thread_data_to_upsert)

def __call__(self, func):
return thread_decorator(self.client, func=func, thread_id=self.thread_id)
return thread_decorator(self.client, func=func, ctx_manager=self)

def __enter__(self) -> Thread:
return self.thread
def __enter__(self) -> "Optional[Thread]":
active_thread_var.set(Thread(id=self.thread_id, **self.kwargs))
return active_thread_var.get()

def __exit__(self, exc_type, exc_val, exc_tb):
active_thread_id_var.set(None)
if (thread := active_thread_var.get()) and thread.needs_upsert:
asyncio.run(self.upsert())
active_thread_var.set(None)

async def __aenter__(self):
return self.thread
active_thread_var.set(Thread(id=self.thread_id, **self.kwargs))
return active_thread_var.get()

async def __aexit__(self, exc_type, exc_val, exc_tb):
active_thread_id_var.set(None)
if (thread := active_thread_var.get()) and thread.needs_upsert:
await self.upsert()
active_thread_var.set(None)


def thread_decorator(
client: "ChainlitClient", func: Callable, thread_id: Optional[str] = None
client: "ChainlitClient",
func: Callable,
thread_id: Optional[str] = None,
ctx_manager: Optional[ThreadContextManager] = None,
**decorator_kwargs,
):
if not ctx_manager:
ctx_manager = ThreadContextManager(
client, thread_id=thread_id, **decorator_kwargs
)
if inspect.iscoroutinefunction(func):

@wraps(func)
async def async_wrapper(*args, **kwargs):
with ThreadContextManager(client, thread_id=thread_id):
with ctx_manager:
result = await func(*args, **kwargs)
return result

Expand All @@ -109,7 +139,7 @@ async def async_wrapper(*args, **kwargs):

@wraps(func)
def sync_wrapper(*args, **kwargs):
with ThreadContextManager(client, thread_id=thread_id):
with ctx_manager:
return func(*args, **kwargs)

return sync_wrapper
Expand Down
Loading