From 7b9cefa7444b1624cdd6ff4160b746fda911ae5e Mon Sep 17 00:00:00 2001 From: San Nguyen Date: Wed, 6 Mar 2024 21:26:00 +0900 Subject: [PATCH 1/6] fix: correct typing for _data_layer Signed-off-by: San Nguyen --- backend/chainlit/data/__init__.py | 50 +++++++++---------------------- 1 file changed, 14 insertions(+), 36 deletions(-) diff --git a/backend/chainlit/data/__init__.py b/backend/chainlit/data/__init__.py index 1618f9cdcd..a8921e5f54 100644 --- a/backend/chainlit/data/__init__.py +++ b/backend/chainlit/data/__init__.py @@ -23,17 +23,12 @@ from chainlit.element import Element, ElementDict from chainlit.step import FeedbackDict, StepDict -_data_layer = None - def queue_until_user_message(): def decorator(method): @functools.wraps(method) async def wrapper(self, *args, **kwargs): - if ( - isinstance(context.session, WebsocketSession) - and not context.session.has_first_interaction - ): + if isinstance(context.session, WebsocketSession) and not context.session.has_first_interaction: # Queue the method invocation waiting for the first user message queues = context.session.thread_queues method_name = method.__name__ @@ -69,9 +64,7 @@ async def upsert_feedback( async def create_element(self, element_dict: "ElementDict"): pass - async def get_element( - self, thread_id: str, element_id: str - ) -> Optional["ElementDict"]: + async def get_element(self, thread_id: str, element_id: str) -> Optional["ElementDict"]: pass @queue_until_user_message() @@ -96,12 +89,8 @@ async def get_thread_author(self, thread_id: str) -> str: async def delete_thread(self, thread_id: str): pass - async def list_threads( - self, pagination: "Pagination", filters: "ThreadFilter" - ) -> "PaginatedResponse[ThreadDict]": - return PaginatedResponse( - data=[], pageInfo=PageInfo(hasNextPage=False, endCursor=None) - ) + async def list_threads(self, pagination: "Pagination", filters: "ThreadFilter") -> "PaginatedResponse[ThreadDict]": + return PaginatedResponse(data=[], pageInfo=PageInfo(hasNextPage=False, endCursor=None)) async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]": return None @@ -120,6 +109,9 @@ async def delete_user_session(self, id: str) -> bool: return True +_data_layer: Optional[BaseDataLayer] = None + + class ChainlitDataLayer: def __init__(self, api_key: str, server: Optional[str]): from literalai import LiteralClient @@ -145,9 +137,7 @@ def attachment_to_element_dict(self, attachment: Attachment) -> "ElementDict": "threadId": attachment.thread_id, } - def feedback_to_feedback_dict( - self, feedback: Optional[ClientFeedback] - ) -> "Optional[FeedbackDict]": + def feedback_to_feedback_dict(self, feedback: Optional[ClientFeedback]) -> "Optional[FeedbackDict]": if not feedback: return None return { @@ -160,9 +150,7 @@ def feedback_to_feedback_dict( def step_to_step_dict(self, step: ClientStep) -> "StepDict": metadata = step.metadata or {} - input = (step.input or {}).get("content") or ( - json.dumps(step.input) if step.input and step.input != {} else "" - ) + input = (step.input or {}).get("content") or (json.dumps(step.input) if step.input and step.input != {} else "") output = (step.output or {}).get("content") or ( json.dumps(step.output) if step.output and step.output != {} else "" ) @@ -202,9 +190,7 @@ async def get_user(self, identifier: str) -> Optional[PersistedUser]: async def create_user(self, user: User) -> Optional[PersistedUser]: _user = await self.client.api.get_user(identifier=user.identifier) if not _user: - _user = await self.client.api.create_user( - identifier=user.identifier, metadata=user.metadata - ) + _user = await self.client.api.create_user(identifier=user.identifier, metadata=user.metadata) elif _user.id: await self.client.api.update_user(id=_user.id, metadata=user.metadata) return PersistedUser( @@ -284,9 +270,7 @@ async def create_element(self, element: "Element"): ] ) - async def get_element( - self, thread_id: str, element_id: str - ) -> Optional["ElementDict"]: + async def get_element(self, thread_id: str, element_id: str) -> Optional["ElementDict"]: attachment = await self.client.api.get_attachment(id=element_id) if not attachment: return None @@ -345,23 +329,17 @@ async def get_thread_author(self, thread_id: str) -> str: async def delete_thread(self, thread_id: str): await self.client.api.delete_thread(id=thread_id) - async def list_threads( - self, pagination: "Pagination", filters: "ThreadFilter" - ) -> "PaginatedResponse[ThreadDict]": + async def list_threads(self, pagination: "Pagination", filters: "ThreadFilter") -> "PaginatedResponse[ThreadDict]": if not filters.userIdentifier: raise ValueError("userIdentifier is required") client_filters = ClientThreadFilter( - participantsIdentifier=StringListFilter( - operator="in", value=[filters.userIdentifier] - ), + participantsIdentifier=StringListFilter(operator="in", value=[filters.userIdentifier]), ) if filters.search: client_filters.search = StringFilter(operator="ilike", value=filters.search) if filters.feedback: - client_filters.feedbacksValue = NumberListFilter( - operator="in", value=[filters.feedback] - ) + client_filters.feedbacksValue = NumberListFilter(operator="in", value=[filters.feedback]) return await self.client.api.list_threads( first=pagination.first, after=pagination.cursor, filters=client_filters ) From f5df914f6ee72ccf1b978d58b900f6384cd8082d Mon Sep 17 00:00:00 2001 From: San Nguyen Date: Thu, 7 Mar 2024 00:59:13 +0900 Subject: [PATCH 2/6] draft implement for mongodb data layer Signed-off-by: San Nguyen --- backend/chainlit/config.py | 58 +++--- backend/chainlit/data/__init__.py | 3 +- backend/chainlit/data/mongodb.py | 281 ++++++++++++++++++++++++++++++ 3 files changed, 312 insertions(+), 30 deletions(-) create mode 100644 backend/chainlit/data/mongodb.py diff --git a/backend/chainlit/config.py b/backend/chainlit/config.py index f56b96ca7a..a642bd8643 100644 --- a/backend/chainlit/config.py +++ b/backend/chainlit/config.py @@ -3,7 +3,8 @@ import sys from importlib import util from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional +from chainlit import data import tomli from chainlit.logger import logger @@ -50,7 +51,7 @@ # Enable third parties caching (e.g LangChain cache) cache = false -# Authorized origins +# Authorized origins allow_origins = ["*"] # Follow symlink for asset mount (see https://github.com/Chainlit/chainlit/issues/317) @@ -75,6 +76,13 @@ # See all languages here https://github.com/JamesBrill/react-speech-recognition/blob/HEAD/docs/API.md#language-string # language = "en-US" +[data_layer] +# Database to use. Currently Literal and MongoDB are supported. +database = "literal" + +# Object storage to use. Currently Literal and S3 are supported. +object_storage = "literal" + [UI] # Name of the app and chatbot. name = "Chatbot" @@ -215,9 +223,7 @@ class CodeSettings: # Bunch of callbacks defined by the developer password_auth_callback: Optional[Callable[[str, str], Optional["User"]]] = None header_auth_callback: Optional[Callable[[Headers], Optional["User"]]] = None - oauth_callback: Optional[ - Callable[[str, str, Dict[str, str], "User"], Optional["User"]] - ] = None + oauth_callback: Optional[Callable[[str, str, Dict[str, str], "User"], Optional["User"]]] = None on_logout: Optional[Callable[["Request", "Response"], Any]] = None on_stop: Optional[Callable[[], Any]] = None on_chat_start: Optional[Callable[[], Any]] = None @@ -226,9 +232,13 @@ class CodeSettings: on_message: Optional[Callable[[str], Any]] = None author_rename: Optional[Callable[[str], str]] = None on_settings_update: Optional[Callable[[Dict[str, Any]], Any]] = None - set_chat_profiles: Optional[Callable[[Optional["User"]], List["ChatProfile"]]] = ( - None - ) + set_chat_profiles: Optional[Callable[[Optional["User"]], List["ChatProfile"]]] = None + + +@dataclass() +class DataLayerSettings: + database: Literal["literal", "mongodb"] + object_storage: Literal["literal", "s3"] @dataclass() @@ -259,25 +269,20 @@ class ChainlitConfig: ui: UISettings project: ProjectSettings code: CodeSettings + data_layer: DataLayerSettings def load_translation(self, language: str): translation = {} default_language = "en-US" - translation_lib_file_path = os.path.join( - config_translation_dir, f"{language}.json" - ) - default_translation_lib_file_path = os.path.join( - config_translation_dir, f"{default_language}.json" - ) + translation_lib_file_path = os.path.join(config_translation_dir, f"{language}.json") + default_translation_lib_file_path = os.path.join(config_translation_dir, f"{default_language}.json") if os.path.exists(translation_lib_file_path): with open(translation_lib_file_path, "r", encoding="utf-8") as f: translation = json.load(f) elif os.path.exists(default_translation_lib_file_path): - logger.warning( - f"Translation file for {language} not found. Using default translation {default_language}." - ) + logger.warning(f"Translation file for {language} not found. Using default translation {default_language}.") with open(default_translation_lib_file_path, "r", encoding="utf-8") as f: translation = json.load(f) @@ -296,9 +301,7 @@ def init_config(log=False): if not os.path.exists(config_translation_dir): os.makedirs(config_translation_dir, exist_ok=True) - logger.info( - f"Created default translation directory at {config_translation_dir}" - ) + logger.info(f"Created default translation directory at {config_translation_dir}") for file in os.listdir(TRANSLATIONS_DIR): if file.endswith(".json"): @@ -324,11 +327,7 @@ def load_module(target: str, force_refresh: bool = False): if force_refresh: # Clear the modules related to the app from sys.modules for module_name, module in list(sys.modules.items()): - if ( - hasattr(module, "__file__") - and module.__file__ - and module.__file__.startswith(target_dir) - ): + if hasattr(module, "__file__") and module.__file__ and module.__file__.startswith(target_dir): sys.modules.pop(module_name, None) spec = util.spec_from_file_location(target, target) @@ -355,11 +354,10 @@ def load_settings(): features_settings = toml_dict.get("features", {}) ui_settings = toml_dict.get("UI", {}) meta = toml_dict.get("meta") + data_layer = toml_dict.get("data_layer") if not meta or meta.get("generated_by") <= "0.3.0": - raise ValueError( - "Your config file is outdated. Please delete it and restart the app to regenerate it." - ) + raise ValueError("Your config file is outdated. Please delete it and restart the app to regenerate it.") lc_cache_path = os.path.join(config_dir, ".langchain.db") @@ -372,11 +370,14 @@ def load_settings(): ui_settings = UISettings(**ui_settings) + data_layer_settings = DataLayerSettings(**data_layer) + return { "features": features_settings, "ui": ui_settings, "project": project_settings, "code": CodeSettings(action_callbacks={}), + "data_layer": data_layer_settings, } @@ -392,6 +393,7 @@ def reload_config(): config.code = settings["code"] config.ui = settings["ui"] config.project = settings["project"] + config.data_layer = settings["data_layer"] def load_config(): diff --git a/backend/chainlit/data/__init__.py b/backend/chainlit/data/__init__.py index a8921e5f54..0753fcb263 100644 --- a/backend/chainlit/data/__init__.py +++ b/backend/chainlit/data/__init__.py @@ -112,7 +112,7 @@ async def delete_user_session(self, id: str) -> bool: _data_layer: Optional[BaseDataLayer] = None -class ChainlitDataLayer: +class ChainlitDataLayer(BaseDataLayer): def __init__(self, api_key: str, server: Optional[str]): from literalai import LiteralClient @@ -173,7 +173,6 @@ def step_to_step_dict(self, step: ClientStep) -> "StepDict": "language": metadata.get("language"), "isError": metadata.get("isError", False), "waitForAnswer": metadata.get("waitForAnswer", False), - "feedback": self.feedback_to_feedback_dict(step.feedback), } async def get_user(self, identifier: str) -> Optional[PersistedUser]: diff --git a/backend/chainlit/data/mongodb.py b/backend/chainlit/data/mongodb.py new file mode 100644 index 0000000000..e8b22480d8 --- /dev/null +++ b/backend/chainlit/data/mongodb.py @@ -0,0 +1,281 @@ +from chainlit.data import BaseDataLayer, queue_until_user_message +from chainlit.types import Feedback, Pagination, ThreadDict, ThreadFilter +from chainlit.user import PersistedUser, User +from literalai import Attachment, Feedback as _ClientFeedback +from literalai import PageInfo, PaginatedResponse +from literalai import Step as _ClientStep +from pymongo import MongoClient, DESCENDING +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional +from datetime import datetime +import json + + +if TYPE_CHECKING: + from chainlit.element import Element, ElementDict + from chainlit.step import FeedbackDict as _FeedbackDict, StepDict + + class ClientFeedback(_ClientFeedback): + """Fix typing for _ClientFeedback""" + + value: Literal[-1, 0, 0] + + class ClientStep(_ClientStep): + """Fix typing for _ClientStep""" + + feedback: Optional[ClientFeedback] + + class FeedbackDict(_FeedbackDict): + """Augment typing for _FeedbackDict""" + + id: str + forId: str + + +class MongoDataLayer(BaseDataLayer): + def __init__(self, db_url: str): + # Connect to the database + self.client = MongoClient(db_url) # type: MongoClient + self.db = self.client.get_database() + + # Get collection references + self.users_collection = self.db.get_collection("users") + self.elements_collection = self.db.get_collection("elements") + self.steps_collection = self.db.get_collection("steps") + self.threads_collection = self.db.get_collection("threads") + + def attachment_to_element_dict(self, attachment: Attachment) -> "ElementDict": + metadata = attachment.metadata or {} + return { + "chainlitKey": None, + "display": metadata.get("display", "side"), + "language": metadata.get("language"), + "page": metadata.get("page"), + "size": metadata.get("size"), + "type": metadata.get("type", "file"), + "forId": attachment.step_id, + "id": attachment.id or "", + "mime": attachment.mime, + "name": attachment.name or "", + "objectKey": attachment.object_key, + "url": attachment.url, + "threadId": attachment.thread_id, + } + + def feedback_to_feedback_dict(self, feedback: Optional[ClientFeedback]) -> "Optional[FeedbackDict]": + if not feedback: + return None + return { + "id": feedback.id or "", + "forId": feedback.step_id or "", + "value": feedback.value or 0, + "comment": feedback.comment, + "strategy": "BINARY", + } + + def step_to_step_dict(self, step: ClientStep) -> "StepDict": + metadata = step.metadata or {} + input = (step.input or {}).get("content") or (json.dumps(step.input) if step.input and step.input != {} else "") + output = (step.output or {}).get("content") or ( + json.dumps(step.output) if step.output and step.output != {} else "" + ) + return { + "createdAt": step.created_at, + "id": step.id or "", + "threadId": step.thread_id or "", + "parentId": step.parent_id, + "feedback": self.feedback_to_feedback_dict(step.feedback), + "start": step.start_time, + "end": step.end_time, + "type": step.type or "undefined", + "name": step.name or "", + "generation": step.generation.to_dict() if step.generation else None, + "input": input, + "output": output, + "showInput": metadata.get("showInput", False), + "disableFeedback": metadata.get("disableFeedback", False), + "indent": metadata.get("indent"), + "language": metadata.get("language"), + "isError": metadata.get("isError", False), + "waitForAnswer": metadata.get("waitForAnswer", False), + } + + async def get_user(self, identifier: str) -> Optional[PersistedUser]: + user_data = self.users_collection.find_one({"identifier": identifier}) + if user_data: + return PersistedUser( + id=user_data["_id"], + identifier=user_data["identifier"], + metadata=user_data.get("metadata", {}), + createdAt=user_data.get("createdAt"), + ) + return None + + async def create_user(self, user: User) -> Optional[PersistedUser]: + user_data: Dict[str, Any] = { + "identifier": user.identifier, + "metadata": user.metadata, + "createdAt": datetime.utcnow(), + } + result = self.users_collection.insert_one(user_data) + user_data["_id"] = result.inserted_id + return PersistedUser(**user_data) + + async def upsert_feedback(self, feedback: Feedback): + feedback_data = { + "step_id": feedback.forId, + "value": feedback.value, + "strategy": feedback.strategy, + "comment": feedback.comment, + } + if feedback.id: + self.steps_collection.update_one({"_id": feedback.id}, {"$set": {"feedback": feedback_data}}) + return feedback.id + else: + result = self.steps_collection.update_one({"_id": feedback.forId}, {"$set": {"feedback": feedback_data}}) + return result.upserted_id or "" + + @queue_until_user_message() + async def create_element(self, element: "Element"): + # TODO: Support file upload from user + pass + + async def get_element(self, thread_id: str, element_id: str) -> Optional["ElementDict"]: + element_data = self.elements_collection.find_one({"id": element_id}) + if not element_data: + return None + return element_data + + @queue_until_user_message() + async def delete_element(self, element_id: str): + self.elements_collection.delete_one({"id": element_id}) + + @queue_until_user_message() + async def create_step(self, step_dict: "StepDict"): + metadata = { + "disableFeedback": step_dict.get("disableFeedback"), + "isError": step_dict.get("isError"), + "waitForAnswer": step_dict.get("waitForAnswer"), + "language": step_dict.get("language"), + "showInput": step_dict.get("showInput"), + } + + step_data = { + "createdAt": step_dict.get("createdAt"), + "id": step_dict.get("id"), + "threadId": step_dict.get("threadId"), + "parentId": step_dict.get("parentId"), + "feedback": step_dict.get("feedback"), + "start": step_dict.get("start"), + "end": step_dict.get("end"), + "type": step_dict.get("type"), + "name": step_dict.get("name"), + "generation": step_dict.get("generation"), + "input": step_dict.get("input"), + "output": step_dict.get("output"), + "metadata": metadata, + } + + self.steps_collection.insert_one(step_data) + + @queue_until_user_message() + async def update_step(self, step_dict: "StepDict"): + self.steps_collection.update_one({"_id": step_dict["id"]}, {"$set": step_dict}) + + @queue_until_user_message() + async def delete_step(self, step_id: str): + self.steps_collection.delete_one({"_id": step_id}) + + async def get_thread_author(self, thread_id: str) -> str: + thread = self.threads_collection.find_one({"_id": thread_id}) + if not thread: + return "" + user = thread.get("user") + if not user: + return "" + return user.get("identifier") or "" + + async def delete_thread(self, thread_id: str): + self.threads_collection.delete_one({"_id": thread_id}) + self.steps_collection.delete_many({"threadId": thread_id}) + self.elements_collection.delete_many({"threadId": thread_id}) + + async def list_threads(self, pagination: Pagination, filters: ThreadFilter) -> PaginatedResponse[ThreadDict]: + if not filters.userIdentifier: + raise ValueError("userIdentifier is required") + + query: Dict[str, Any] = {"participants.identifier": filters.userIdentifier} + if filters.search: + query["$text"] = {"$search": filters.search} + if filters.feedback: + query["feedback.value"] = filters.feedback + + sort = [("createdAt", DESCENDING)] + if pagination.cursor: + query["_id"] = {"$lt": pagination.cursor} + + threads_data = list(self.threads_collection.find(query).sort(sort).limit(pagination.first)) + + threads: List[ThreadDict] = [ + { + "id": thread["_id"], + "createdAt": thread["createdAt"], + "name": thread.get("name"), + "user": thread.get("user"), + "tags": thread.get("tags"), + "metadata": thread.get("metadata"), + "steps": list(self.steps_collection.find({"threadId": thread["_id"]})), + "elements": list(self.elements_collection.find({"threadId": thread["_id"]})), + } + for thread in threads_data + ] + + has_next_page = len(threads) == pagination.first + end_cursor = threads[-1]["id"] if has_next_page else None + + return PaginatedResponse( + data=threads, + pageInfo=PageInfo(hasNextPage=has_next_page, endCursor=end_cursor), + ) + + async def get_thread(self, thread_id: str) -> Optional[ThreadDict]: + thread_data = self.threads_collection.find_one({"_id": thread_id}) + if not thread_data: + return None + + elements = list(self.elements_collection.find({"threadId": thread_data["_id"]})) + steps = list(self.steps_collection.find({"threadId": thread_data["_id"]})) + + return { + "createdAt": thread_data["createdAt"], + "id": thread_data["_id"], + "name": thread_data.get("name"), + "steps": [self.step_to_step_dict(step) for step in steps], + "elements": [self.attachment_to_element_dict(element) for element in elements], + "metadata": thread_data.get("metadata"), + "user": thread_data.get("user"), + "tags": thread_data.get("tags"), + } + + async def update_thread( + self, + thread_id: str, + name: Optional[str] = None, + user_id: Optional[str] = None, + metadata: Optional[Dict] = None, + tags: Optional[List[str]] = None, + ): + update_data: Dict[str, Any] = {} + if name is not None: + update_data["name"] = name + if user_id is not None: + update_data["user"] = {"identifier": user_id} + if metadata is not None: + update_data["metadata"] = metadata + if tags is not None: + update_data["tags"] = tags + + self.threads_collection.update_one({"_id": thread_id}, {"$set": update_data}) + + async def delete_user_session(self, id: str) -> bool: + result = self.threads_collection.delete_many({"metadata.id": id}) + return result.deleted_count > 0 From 3f181a0ff4d117321f986f8a840a945108a47824 Mon Sep 17 00:00:00 2001 From: San Nguyen Date: Sat, 9 Mar 2024 17:51:18 +0900 Subject: [PATCH 3/6] revert config Signed-off-by: San Nguyen --- backend/chainlit/config.py | 58 +++-- backend/chainlit/data/__init__.py | 404 +----------------------------- backend/chainlit/data/base.py | 98 ++++++++ backend/chainlit/data/chainlit.py | 304 ++++++++++++++++++++++ backend/chainlit/data/mongodb.py | 114 ++++++--- 5 files changed, 525 insertions(+), 453 deletions(-) create mode 100644 backend/chainlit/data/base.py create mode 100644 backend/chainlit/data/chainlit.py diff --git a/backend/chainlit/config.py b/backend/chainlit/config.py index a642bd8643..f56b96ca7a 100644 --- a/backend/chainlit/config.py +++ b/backend/chainlit/config.py @@ -3,8 +3,7 @@ import sys from importlib import util from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional -from chainlit import data +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional import tomli from chainlit.logger import logger @@ -51,7 +50,7 @@ # Enable third parties caching (e.g LangChain cache) cache = false -# Authorized origins +# Authorized origins allow_origins = ["*"] # Follow symlink for asset mount (see https://github.com/Chainlit/chainlit/issues/317) @@ -76,13 +75,6 @@ # See all languages here https://github.com/JamesBrill/react-speech-recognition/blob/HEAD/docs/API.md#language-string # language = "en-US" -[data_layer] -# Database to use. Currently Literal and MongoDB are supported. -database = "literal" - -# Object storage to use. Currently Literal and S3 are supported. -object_storage = "literal" - [UI] # Name of the app and chatbot. name = "Chatbot" @@ -223,7 +215,9 @@ class CodeSettings: # Bunch of callbacks defined by the developer password_auth_callback: Optional[Callable[[str, str], Optional["User"]]] = None header_auth_callback: Optional[Callable[[Headers], Optional["User"]]] = None - oauth_callback: Optional[Callable[[str, str, Dict[str, str], "User"], Optional["User"]]] = None + oauth_callback: Optional[ + Callable[[str, str, Dict[str, str], "User"], Optional["User"]] + ] = None on_logout: Optional[Callable[["Request", "Response"], Any]] = None on_stop: Optional[Callable[[], Any]] = None on_chat_start: Optional[Callable[[], Any]] = None @@ -232,13 +226,9 @@ class CodeSettings: on_message: Optional[Callable[[str], Any]] = None author_rename: Optional[Callable[[str], str]] = None on_settings_update: Optional[Callable[[Dict[str, Any]], Any]] = None - set_chat_profiles: Optional[Callable[[Optional["User"]], List["ChatProfile"]]] = None - - -@dataclass() -class DataLayerSettings: - database: Literal["literal", "mongodb"] - object_storage: Literal["literal", "s3"] + set_chat_profiles: Optional[Callable[[Optional["User"]], List["ChatProfile"]]] = ( + None + ) @dataclass() @@ -269,20 +259,25 @@ class ChainlitConfig: ui: UISettings project: ProjectSettings code: CodeSettings - data_layer: DataLayerSettings def load_translation(self, language: str): translation = {} default_language = "en-US" - translation_lib_file_path = os.path.join(config_translation_dir, f"{language}.json") - default_translation_lib_file_path = os.path.join(config_translation_dir, f"{default_language}.json") + translation_lib_file_path = os.path.join( + config_translation_dir, f"{language}.json" + ) + default_translation_lib_file_path = os.path.join( + config_translation_dir, f"{default_language}.json" + ) if os.path.exists(translation_lib_file_path): with open(translation_lib_file_path, "r", encoding="utf-8") as f: translation = json.load(f) elif os.path.exists(default_translation_lib_file_path): - logger.warning(f"Translation file for {language} not found. Using default translation {default_language}.") + logger.warning( + f"Translation file for {language} not found. Using default translation {default_language}." + ) with open(default_translation_lib_file_path, "r", encoding="utf-8") as f: translation = json.load(f) @@ -301,7 +296,9 @@ def init_config(log=False): if not os.path.exists(config_translation_dir): os.makedirs(config_translation_dir, exist_ok=True) - logger.info(f"Created default translation directory at {config_translation_dir}") + logger.info( + f"Created default translation directory at {config_translation_dir}" + ) for file in os.listdir(TRANSLATIONS_DIR): if file.endswith(".json"): @@ -327,7 +324,11 @@ def load_module(target: str, force_refresh: bool = False): if force_refresh: # Clear the modules related to the app from sys.modules for module_name, module in list(sys.modules.items()): - if hasattr(module, "__file__") and module.__file__ and module.__file__.startswith(target_dir): + if ( + hasattr(module, "__file__") + and module.__file__ + and module.__file__.startswith(target_dir) + ): sys.modules.pop(module_name, None) spec = util.spec_from_file_location(target, target) @@ -354,10 +355,11 @@ def load_settings(): features_settings = toml_dict.get("features", {}) ui_settings = toml_dict.get("UI", {}) meta = toml_dict.get("meta") - data_layer = toml_dict.get("data_layer") if not meta or meta.get("generated_by") <= "0.3.0": - raise ValueError("Your config file is outdated. Please delete it and restart the app to regenerate it.") + raise ValueError( + "Your config file is outdated. Please delete it and restart the app to regenerate it." + ) lc_cache_path = os.path.join(config_dir, ".langchain.db") @@ -370,14 +372,11 @@ def load_settings(): ui_settings = UISettings(**ui_settings) - data_layer_settings = DataLayerSettings(**data_layer) - return { "features": features_settings, "ui": ui_settings, "project": project_settings, "code": CodeSettings(action_callbacks={}), - "data_layer": data_layer_settings, } @@ -393,7 +392,6 @@ def reload_config(): config.code = settings["code"] config.ui = settings["ui"] config.project = settings["project"] - config.data_layer = settings["data_layer"] def load_config(): diff --git a/backend/chainlit/data/__init__.py b/backend/chainlit/data/__init__.py index 0753fcb263..eebfaa6ade 100644 --- a/backend/chainlit/data/__init__.py +++ b/backend/chainlit/data/__init__.py @@ -1,404 +1,24 @@ -import functools -import json import os -from collections import deque -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import Optional -import aiofiles from chainlit.config import config -from chainlit.context import context -from chainlit.logger import logger -from chainlit.session import WebsocketSession -from chainlit.types import Feedback, Pagination, ThreadDict, ThreadFilter -from chainlit.user import PersistedUser, User, UserDict -from literalai import Attachment -from literalai import Feedback as ClientFeedback -from literalai import PageInfo, PaginatedResponse -from literalai import Step as ClientStep -from literalai.step import StepDict as ClientStepDict -from literalai.thread import NumberListFilter, StringFilter, StringListFilter -from literalai.thread import ThreadFilter as ClientThreadFilter - -if TYPE_CHECKING: - from chainlit.element import Element, ElementDict - from chainlit.step import FeedbackDict, StepDict - - -def queue_until_user_message(): - def decorator(method): - @functools.wraps(method) - async def wrapper(self, *args, **kwargs): - if isinstance(context.session, WebsocketSession) and not context.session.has_first_interaction: - # Queue the method invocation waiting for the first user message - queues = context.session.thread_queues - method_name = method.__name__ - if method_name not in queues: - queues[method_name] = deque() - queues[method_name].append((method, self, args, kwargs)) - - else: - # Otherwise, Execute the method immediately - return await method(self, *args, **kwargs) - - return wrapper - - return decorator - - -class BaseDataLayer: - """Base class for data persistence.""" - - async def get_user(self, identifier: str) -> Optional["PersistedUser"]: - return None - - async def create_user(self, user: "User") -> Optional["PersistedUser"]: - pass - - async def upsert_feedback( - self, - feedback: Feedback, - ) -> str: - return "" - - @queue_until_user_message() - async def create_element(self, element_dict: "ElementDict"): - pass - - async def get_element(self, thread_id: str, element_id: str) -> Optional["ElementDict"]: - pass - - @queue_until_user_message() - async def delete_element(self, element_id: str): - pass - - @queue_until_user_message() - async def create_step(self, step_dict: "StepDict"): - pass - - @queue_until_user_message() - async def update_step(self, step_dict: "StepDict"): - pass - - @queue_until_user_message() - async def delete_step(self, step_id: str): - pass - - async def get_thread_author(self, thread_id: str) -> str: - return "" - - async def delete_thread(self, thread_id: str): - pass - - async def list_threads(self, pagination: "Pagination", filters: "ThreadFilter") -> "PaginatedResponse[ThreadDict]": - return PaginatedResponse(data=[], pageInfo=PageInfo(hasNextPage=False, endCursor=None)) - - async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]": - return None - - async def update_thread( - self, - thread_id: str, - name: Optional[str] = None, - user_id: Optional[str] = None, - metadata: Optional[Dict] = None, - tags: Optional[List[str]] = None, - ): - pass - - async def delete_user_session(self, id: str) -> bool: - return True +from chainlit.data.base import BaseDataLayer +from chainlit.data.chainlit import ChainlitDataLayer +from chainlit.data.mongodb import MongoDBDataLayer _data_layer: Optional[BaseDataLayer] = None - -class ChainlitDataLayer(BaseDataLayer): - def __init__(self, api_key: str, server: Optional[str]): - from literalai import LiteralClient - - self.client = LiteralClient(api_key=api_key, url=server) - logger.info("Chainlit data layer initialized") - - def attachment_to_element_dict(self, attachment: Attachment) -> "ElementDict": - metadata = attachment.metadata or {} - return { - "chainlitKey": None, - "display": metadata.get("display", "side"), - "language": metadata.get("language"), - "page": metadata.get("page"), - "size": metadata.get("size"), - "type": metadata.get("type", "file"), - "forId": attachment.step_id, - "id": attachment.id or "", - "mime": attachment.mime, - "name": attachment.name or "", - "objectKey": attachment.object_key, - "url": attachment.url, - "threadId": attachment.thread_id, - } - - def feedback_to_feedback_dict(self, feedback: Optional[ClientFeedback]) -> "Optional[FeedbackDict]": - if not feedback: - return None - return { - "id": feedback.id or "", - "forId": feedback.step_id or "", - "value": feedback.value or 0, # type: ignore - "comment": feedback.comment, - "strategy": "BINARY", - } - - def step_to_step_dict(self, step: ClientStep) -> "StepDict": - metadata = step.metadata or {} - input = (step.input or {}).get("content") or (json.dumps(step.input) if step.input and step.input != {} else "") - output = (step.output or {}).get("content") or ( - json.dumps(step.output) if step.output and step.output != {} else "" - ) - return { - "createdAt": step.created_at, - "id": step.id or "", - "threadId": step.thread_id or "", - "parentId": step.parent_id, - "feedback": self.feedback_to_feedback_dict(step.feedback), - "start": step.start_time, - "end": step.end_time, - "type": step.type or "undefined", - "name": step.name or "", - "generation": step.generation.to_dict() if step.generation else None, - "input": input, - "output": output, - "showInput": metadata.get("showInput", False), - "disableFeedback": metadata.get("disableFeedback", False), - "indent": metadata.get("indent"), - "language": metadata.get("language"), - "isError": metadata.get("isError", False), - "waitForAnswer": metadata.get("waitForAnswer", False), - } - - async def get_user(self, identifier: str) -> Optional[PersistedUser]: - user = await self.client.api.get_user(identifier=identifier) - if not user: - return None - return PersistedUser( - id=user.id or "", - identifier=user.identifier or "", - metadata=user.metadata, - createdAt=user.created_at or "", - ) - - async def create_user(self, user: User) -> Optional[PersistedUser]: - _user = await self.client.api.get_user(identifier=user.identifier) - if not _user: - _user = await self.client.api.create_user(identifier=user.identifier, metadata=user.metadata) - elif _user.id: - await self.client.api.update_user(id=_user.id, metadata=user.metadata) - return PersistedUser( - id=_user.id or "", - identifier=_user.identifier or "", - metadata=_user.metadata, - createdAt=_user.created_at or "", - ) - - async def upsert_feedback( - self, - feedback: Feedback, - ): - if feedback.id: - await self.client.api.update_feedback( - id=feedback.id, - update_params={ - "comment": feedback.comment, - "strategy": feedback.strategy, - "value": feedback.value, - }, - ) - return feedback.id - else: - created = await self.client.api.create_feedback( - step_id=feedback.forId, - value=feedback.value, - comment=feedback.comment, - strategy=feedback.strategy, - ) - return created.id or "" - - @queue_until_user_message() - async def create_element(self, element: "Element"): - metadata = { - "size": element.size, - "language": element.language, - "display": element.display, - "type": element.type, - "page": getattr(element, "page", None), - } - - if not element.for_id: - return - - object_key = None - - if not element.url: - if element.path: - async with aiofiles.open(element.path, "rb") as f: - content = await f.read() # type: Union[bytes, str] - elif element.content: - content = element.content - else: - raise ValueError("Either path or content must be provided") - uploaded = await self.client.api.upload_file( - content=content, mime=element.mime, thread_id=element.thread_id - ) - object_key = uploaded["object_key"] - - await self.client.api.send_steps( - [ - { - "id": element.for_id, - "threadId": element.thread_id, - "attachments": [ - { - "id": element.id, - "name": element.name, - "metadata": metadata, - "mime": element.mime, - "url": element.url, - "objectKey": object_key, - } - ], - } - ] - ) - - async def get_element(self, thread_id: str, element_id: str) -> Optional["ElementDict"]: - attachment = await self.client.api.get_attachment(id=element_id) - if not attachment: - return None - return self.attachment_to_element_dict(attachment) - - @queue_until_user_message() - async def delete_element(self, element_id: str): - await self.client.api.delete_attachment(id=element_id) - - @queue_until_user_message() - async def create_step(self, step_dict: "StepDict"): - metadata = { - "disableFeedback": step_dict.get("disableFeedback"), - "isError": step_dict.get("isError"), - "waitForAnswer": step_dict.get("waitForAnswer"), - "language": step_dict.get("language"), - "showInput": step_dict.get("showInput"), - } - - step: ClientStepDict = { - "createdAt": step_dict.get("createdAt"), - "startTime": step_dict.get("start"), - "endTime": step_dict.get("end"), - "generation": step_dict.get("generation"), - "id": step_dict.get("id"), - "parentId": step_dict.get("parentId"), - "name": step_dict.get("name"), - "threadId": step_dict.get("threadId"), - "type": step_dict.get("type"), - "metadata": metadata, - } - if step_dict.get("input"): - step["input"] = {"content": step_dict.get("input")} - if step_dict.get("output"): - step["output"] = {"content": step_dict.get("output")} - - await self.client.api.send_steps([step]) - - @queue_until_user_message() - async def update_step(self, step_dict: "StepDict"): - await self.create_step(step_dict) - - @queue_until_user_message() - async def delete_step(self, step_id: str): - await self.client.api.delete_step(id=step_id) - - async def get_thread_author(self, thread_id: str) -> str: - thread = await self.get_thread(thread_id) - if not thread: - return "" - user = thread.get("user") - if not user: - return "" - return user.get("identifier") or "" - - async def delete_thread(self, thread_id: str): - await self.client.api.delete_thread(id=thread_id) - - async def list_threads(self, pagination: "Pagination", filters: "ThreadFilter") -> "PaginatedResponse[ThreadDict]": - if not filters.userIdentifier: - raise ValueError("userIdentifier is required") - - client_filters = ClientThreadFilter( - participantsIdentifier=StringListFilter(operator="in", value=[filters.userIdentifier]), - ) - if filters.search: - client_filters.search = StringFilter(operator="ilike", value=filters.search) - if filters.feedback: - client_filters.feedbacksValue = NumberListFilter(operator="in", value=[filters.feedback]) - return await self.client.api.list_threads( - first=pagination.first, after=pagination.cursor, filters=client_filters - ) - - async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]": - thread = await self.client.api.get_thread(id=thread_id) - if not thread: - return None - elements = [] # List[ElementDict] - steps = [] # List[StepDict] - if thread.steps: - for step in thread.steps: - if config.ui.hide_cot and step.parent_id: - continue - for attachment in step.attachments: - elements.append(self.attachment_to_element_dict(attachment)) - if not config.features.prompt_playground and step.generation: - step.generation = None - steps.append(self.step_to_step_dict(step)) - - user = None # type: Optional["UserDict"] - - if thread.user: - user = { - "id": thread.user.id or "", - "identifier": thread.user.identifier or "", - "metadata": thread.user.metadata, - } - - return { - "createdAt": thread.created_at or "", - "id": thread.id, - "name": thread.name or None, - "steps": steps, - "elements": elements, - "metadata": thread.metadata, - "user": user, - "tags": thread.tags, - } - - async def update_thread( - self, - thread_id: str, - name: Optional[str] = None, - user_id: Optional[str] = None, - metadata: Optional[Dict] = None, - tags: Optional[List[str]] = None, - ): - await self.client.api.upsert_thread( - thread_id=thread_id, - name=name, - participant_id=user_id, - metadata=metadata, - tags=tags, - ) - - -if api_key := os.environ.get("LITERAL_API_KEY"): +if config.data_layer.database == "chainlit": + api_key = os.environ.get("LITERAL_API_KEY") + assert api_key is not None server = os.environ.get("LITERAL_SERVER") _data_layer = ChainlitDataLayer(api_key=api_key, server=server) +if config.data_layer.database == "mongodb": + if config.data_layer.object_storage != "s3": + raise ValueError("MongoDB data layer requires an S3 object storage") + db_url = os.environ.get("CHAINLIT_MONGODB_URL") + _data_layer = MongoDBDataLayer(db_url) def get_data_layer(): diff --git a/backend/chainlit/data/base.py b/backend/chainlit/data/base.py new file mode 100644 index 0000000000..c671378850 --- /dev/null +++ b/backend/chainlit/data/base.py @@ -0,0 +1,98 @@ +import functools +from collections import deque +from typing import TYPE_CHECKING, Dict, List, Optional + +from chainlit.context import context +from chainlit.session import WebsocketSession +from chainlit.types import Feedback, Pagination, ThreadDict, ThreadFilter +from chainlit.user import PersistedUser, User +from literalai import PageInfo, PaginatedResponse + +if TYPE_CHECKING: + from chainlit.element import ElementDict + from chainlit.step import StepDict + + +def queue_until_user_message(): + def decorator(method): + @functools.wraps(method) + async def wrapper(self, *args, **kwargs): + if isinstance(context.session, WebsocketSession) and not context.session.has_first_interaction: + # Queue the method invocation waiting for the first user message + queues = context.session.thread_queues + method_name = method.__name__ + if method_name not in queues: + queues[method_name] = deque() + queues[method_name].append((method, self, args, kwargs)) + + else: + # Otherwise, Execute the method immediately + return await method(self, *args, **kwargs) + + return wrapper + + return decorator + + +class BaseDataLayer: + """Base class for data persistence.""" + + async def get_user(self, identifier: str) -> Optional["PersistedUser"]: + return None + + async def create_user(self, user: "User") -> Optional["PersistedUser"]: + pass + + async def upsert_feedback( + self, + feedback: Feedback, + ) -> str: + return "" + + @queue_until_user_message() + async def create_element(self, element_dict: "ElementDict"): + pass + + async def get_element(self, thread_id: str, element_id: str) -> Optional["ElementDict"]: + pass + + @queue_until_user_message() + async def delete_element(self, element_id: str): + pass + + @queue_until_user_message() + async def create_step(self, step_dict: "StepDict"): + pass + + @queue_until_user_message() + async def update_step(self, step_dict: "StepDict"): + pass + + @queue_until_user_message() + async def delete_step(self, step_id: str): + pass + + async def get_thread_author(self, thread_id: str) -> str: + return "" + + async def delete_thread(self, thread_id: str): + pass + + async def list_threads(self, pagination: "Pagination", filters: "ThreadFilter") -> "PaginatedResponse[ThreadDict]": + return PaginatedResponse(data=[], pageInfo=PageInfo(hasNextPage=False, endCursor=None)) + + async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]": + return None + + async def update_thread( + self, + thread_id: str, + name: Optional[str] = None, + user_id: Optional[str] = None, + metadata: Optional[Dict] = None, + tags: Optional[List[str]] = None, + ): + pass + + async def delete_user_session(self, id: str) -> bool: + return True diff --git a/backend/chainlit/data/chainlit.py b/backend/chainlit/data/chainlit.py new file mode 100644 index 0000000000..81477aaead --- /dev/null +++ b/backend/chainlit/data/chainlit.py @@ -0,0 +1,304 @@ +import json +from typing import TYPE_CHECKING, Dict, List, Optional, Union + +import aiofiles +from chainlit.config import config +from chainlit.logger import logger +from chainlit.types import Feedback, Pagination, ThreadDict, ThreadFilter +from chainlit.user import PersistedUser, User, UserDict +from chainlit.data.base import BaseDataLayer, queue_until_user_message +from literalai import Attachment +from literalai import Feedback as ClientFeedback +from literalai import PaginatedResponse +from literalai import Step as ClientStep +from literalai.step import StepDict as ClientStepDict +from literalai.thread import NumberListFilter, StringFilter, StringListFilter +from literalai.thread import ThreadFilter as ClientThreadFilter + +if TYPE_CHECKING: + from chainlit.element import Element, ElementDict + from chainlit.step import FeedbackDict, StepDict + + +class ChainlitDataLayer(BaseDataLayer): + def __init__(self, api_key: str, server: Optional[str]): + from literalai import LiteralClient + + self.client = LiteralClient(api_key=api_key, url=server) + logger.info("Chainlit data layer initialized") + + def attachment_to_element_dict(self, attachment: Attachment) -> "ElementDict": + metadata = attachment.metadata or {} + return { + "chainlitKey": None, + "display": metadata.get("display", "side"), + "language": metadata.get("language"), + "page": metadata.get("page"), + "size": metadata.get("size"), + "type": metadata.get("type", "file"), + "forId": attachment.step_id, + "id": attachment.id or "", + "mime": attachment.mime, + "name": attachment.name or "", + "objectKey": attachment.object_key, + "url": attachment.url, + "threadId": attachment.thread_id, + } + + def feedback_to_feedback_dict(self, feedback: Optional[ClientFeedback]) -> "Optional[FeedbackDict]": + if not feedback: + return None + return { + "id": feedback.id or "", + "forId": feedback.step_id or "", + "value": feedback.value or 0, # type: ignore + "comment": feedback.comment, + "strategy": "BINARY", + } + + def step_to_step_dict(self, step: ClientStep) -> "StepDict": + metadata = step.metadata or {} + input = (step.input or {}).get("content") or (json.dumps(step.input) if step.input and step.input != {} else "") + output = (step.output or {}).get("content") or ( + json.dumps(step.output) if step.output and step.output != {} else "" + ) + return { + "createdAt": step.created_at, + "id": step.id or "", + "threadId": step.thread_id or "", + "parentId": step.parent_id, + "feedback": self.feedback_to_feedback_dict(step.feedback), + "start": step.start_time, + "end": step.end_time, + "type": step.type or "undefined", + "name": step.name or "", + "generation": step.generation.to_dict() if step.generation else None, + "input": input, + "output": output, + "showInput": metadata.get("showInput", False), + "disableFeedback": metadata.get("disableFeedback", False), + "indent": metadata.get("indent"), + "language": metadata.get("language"), + "isError": metadata.get("isError", False), + "waitForAnswer": metadata.get("waitForAnswer", False), + } + + async def get_user(self, identifier: str) -> Optional[PersistedUser]: + user = await self.client.api.get_user(identifier=identifier) + if not user: + return None + return PersistedUser( + id=user.id or "", + identifier=user.identifier or "", + metadata=user.metadata, + createdAt=user.created_at or "", + ) + + async def create_user(self, user: User) -> Optional[PersistedUser]: + _user = await self.client.api.get_user(identifier=user.identifier) + if not _user: + _user = await self.client.api.create_user(identifier=user.identifier, metadata=user.metadata) + elif _user.id: + await self.client.api.update_user(id=_user.id, metadata=user.metadata) + return PersistedUser( + id=_user.id or "", + identifier=_user.identifier or "", + metadata=_user.metadata, + createdAt=_user.created_at or "", + ) + + async def upsert_feedback( + self, + feedback: Feedback, + ): + if feedback.id: + await self.client.api.update_feedback( + id=feedback.id, + update_params={ + "comment": feedback.comment, + "strategy": feedback.strategy, + "value": feedback.value, + }, + ) + return feedback.id + else: + created = await self.client.api.create_feedback( + step_id=feedback.forId, + value=feedback.value, + comment=feedback.comment, + strategy=feedback.strategy, + ) + return created.id or "" + + @queue_until_user_message() + async def create_element(self, element: "Element"): + metadata = { + "size": element.size, + "language": element.language, + "display": element.display, + "type": element.type, + "page": getattr(element, "page", None), + } + + if not element.for_id: + return + + object_key = None + + if not element.url: + if element.path: + async with aiofiles.open(element.path, "rb") as f: + content: Union[bytes, str] = await f.read() + elif element.content: + content = element.content + else: + raise ValueError("Either path or content must be provided") + uploaded = await self.client.api.upload_file( + content=content, mime=element.mime, thread_id=element.thread_id + ) + object_key = uploaded["object_key"] + + await self.client.api.send_steps( + [ + { + "id": element.for_id, + "threadId": element.thread_id, + "attachments": [ + { + "id": element.id, + "name": element.name, + "metadata": metadata, + "mime": element.mime, + "url": element.url, + "objectKey": object_key, + } + ], + } + ] + ) + + async def get_element(self, thread_id: str, element_id: str) -> Optional["ElementDict"]: + attachment = await self.client.api.get_attachment(id=element_id) + if not attachment: + return None + return self.attachment_to_element_dict(attachment) + + @queue_until_user_message() + async def delete_element(self, element_id: str): + await self.client.api.delete_attachment(id=element_id) + + @queue_until_user_message() + async def create_step(self, step_dict: "StepDict"): + metadata = { + "disableFeedback": step_dict.get("disableFeedback"), + "isError": step_dict.get("isError"), + "waitForAnswer": step_dict.get("waitForAnswer"), + "language": step_dict.get("language"), + "showInput": step_dict.get("showInput"), + } + + step: ClientStepDict = { + "createdAt": step_dict.get("createdAt"), + "startTime": step_dict.get("start"), + "endTime": step_dict.get("end"), + "generation": step_dict.get("generation"), + "id": step_dict.get("id"), + "parentId": step_dict.get("parentId"), + "name": step_dict.get("name"), + "threadId": step_dict.get("threadId"), + "type": step_dict.get("type"), + "metadata": metadata, + } + if step_dict.get("input"): + step["input"] = {"content": step_dict.get("input")} + if step_dict.get("output"): + step["output"] = {"content": step_dict.get("output")} + + await self.client.api.send_steps([step]) + + @queue_until_user_message() + async def update_step(self, step_dict: "StepDict"): + await self.create_step(step_dict) + + @queue_until_user_message() + async def delete_step(self, step_id: str): + await self.client.api.delete_step(id=step_id) + + async def get_thread_author(self, thread_id: str) -> str: + thread = await self.get_thread(thread_id) + if not thread: + return "" + user = thread.get("user") + if not user: + return "" + return user.get("identifier") or "" + + async def delete_thread(self, thread_id: str): + await self.client.api.delete_thread(id=thread_id) + + async def list_threads(self, pagination: "Pagination", filters: "ThreadFilter") -> "PaginatedResponse[ThreadDict]": + if not filters.userIdentifier: + raise ValueError("userIdentifier is required") + + client_filters = ClientThreadFilter( + participantsIdentifier=StringListFilter(operator="in", value=[filters.userIdentifier]), + ) + if filters.search: + client_filters.search = StringFilter(operator="ilike", value=filters.search) + if filters.feedback: + client_filters.feedbacksValue = NumberListFilter(operator="in", value=[filters.feedback]) + return await self.client.api.list_threads( + first=pagination.first, after=pagination.cursor, filters=client_filters + ) + + async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]": + thread = await self.client.api.get_thread(id=thread_id) + if not thread: + return None + elements = [] # List[ElementDict] + steps = [] # List[StepDict] + if thread.steps: + for step in thread.steps: + if config.ui.hide_cot and step.parent_id: + continue + for attachment in step.attachments: + elements.append(self.attachment_to_element_dict(attachment)) + if not config.features.prompt_playground and step.generation: + step.generation = None + steps.append(self.step_to_step_dict(step)) + + user = None # type: Optional["UserDict"] + + if thread.user: + user = { + "id": thread.user.id or "", + "identifier": thread.user.identifier or "", + "metadata": thread.user.metadata, + } + + return { + "createdAt": thread.created_at or "", + "id": thread.id, + "name": thread.name or None, + "steps": steps, + "elements": elements, + "metadata": thread.metadata, + "user": user, + "tags": thread.tags, + } + + async def update_thread( + self, + thread_id: str, + name: Optional[str] = None, + user_id: Optional[str] = None, + metadata: Optional[Dict] = None, + tags: Optional[List[str]] = None, + ): + await self.client.api.upsert_thread( + thread_id=thread_id, + name=name, + participant_id=user_id, + metadata=metadata, + tags=tags, + ) diff --git a/backend/chainlit/data/mongodb.py b/backend/chainlit/data/mongodb.py index e8b22480d8..4620c183da 100644 --- a/backend/chainlit/data/mongodb.py +++ b/backend/chainlit/data/mongodb.py @@ -1,12 +1,13 @@ from chainlit.data import BaseDataLayer, queue_until_user_message from chainlit.types import Feedback, Pagination, ThreadDict, ThreadFilter from chainlit.user import PersistedUser, User +from chainlit.logger import logger from literalai import Attachment, Feedback as _ClientFeedback from literalai import PageInfo, PaginatedResponse from literalai import Step as _ClientStep -from pymongo import MongoClient, DESCENDING -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union from datetime import datetime +import aiofiles import json @@ -31,8 +32,10 @@ class FeedbackDict(_FeedbackDict): forId: str -class MongoDataLayer(BaseDataLayer): +class MongoDBDataLayer(BaseDataLayer): def __init__(self, db_url: str): + from pymongo import MongoClient + # Connect to the database self.client = MongoClient(db_url) # type: MongoClient self.db = self.client.get_database() @@ -42,6 +45,7 @@ def __init__(self, db_url: str): self.elements_collection = self.db.get_collection("elements") self.steps_collection = self.db.get_collection("steps") self.threads_collection = self.db.get_collection("threads") + logger.info("MongoDB data layer initialized") def attachment_to_element_dict(self, attachment: Attachment) -> "ElementDict": metadata = attachment.metadata or {} @@ -111,33 +115,80 @@ async def get_user(self, identifier: str) -> Optional[PersistedUser]: return None async def create_user(self, user: User) -> Optional[PersistedUser]: - user_data: Dict[str, Any] = { - "identifier": user.identifier, - "metadata": user.metadata, - "createdAt": datetime.utcnow(), - } - result = self.users_collection.insert_one(user_data) - user_data["_id"] = result.inserted_id - return PersistedUser(**user_data) + _user = await self.get_user(user.identifier) + if not _user: + user_data: Dict[str, Any] = { + "identifier": user.identifier, + "metadata": user.metadata, + "createdAt": datetime.utcnow(), + } + result = self.users_collection.insert_one(user_data) + user_data["_id"] = result.inserted_id + return PersistedUser(**user_data) + + return PersistedUser( + id=_user.id or "", + identifier=_user.identifier or "", + metadata=_user.metadata, + createdAt=_user.created_at or "", + ) async def upsert_feedback(self, feedback: Feedback): feedback_data = { - "step_id": feedback.forId, + "id": feedback.id, + "stepId": feedback.forId, "value": feedback.value, "strategy": feedback.strategy, "comment": feedback.comment, } - if feedback.id: - self.steps_collection.update_one({"_id": feedback.id}, {"$set": {"feedback": feedback_data}}) - return feedback.id - else: - result = self.steps_collection.update_one({"_id": feedback.forId}, {"$set": {"feedback": feedback_data}}) - return result.upserted_id or "" + self.steps_collection.update_one({"_id": feedback.id}, {"$set": {"feedback": feedback_data}}) + return feedback.id @queue_until_user_message() async def create_element(self, element: "Element"): - # TODO: Support file upload from user - pass + metadata = { + "size": element.size, + "language": element.language, + "display": element.display, + "type": element.type, + "page": getattr(element, "page", None), + } + + if not element.for_id: + return + + object_key = None + + if not element.url: + if element.path: + async with aiofiles.open(element.path, "rb") as f: + content = await f.read() # type: Union[bytes, str] + elif element.content: + content = element.content + else: + raise ValueError("Either path or content must be provided") + uploaded = await self.client.api.upload_file( + content=content, mime=element.mime, thread_id=element.thread_id + ) + object_key = uploaded["object_key"] + + element_data = { + "id": element.id, + "threadId": element.thread_id, + "stepId": element.for_id, + "type": element.type, + "url": element.url, + "chainlitKey": element.chainlit_key, + "name": element.name, + "display": element.display, + "objectKey": object_key, + "size": element.size, + "page": element.page, + "language": element.language, + "mime": element.mime, + "metadata": metadata, + } + self.elements_collection.insert_one(element_data) async def get_element(self, thread_id: str, element_id: str) -> Optional["ElementDict"]: element_data = self.elements_collection.find_one({"id": element_id}) @@ -160,19 +211,20 @@ async def create_step(self, step_dict: "StepDict"): } step_data = { - "createdAt": step_dict.get("createdAt"), "id": step_dict.get("id"), "threadId": step_dict.get("threadId"), "parentId": step_dict.get("parentId"), - "feedback": step_dict.get("feedback"), - "start": step_dict.get("start"), - "end": step_dict.get("end"), - "type": step_dict.get("type"), "name": step_dict.get("name"), - "generation": step_dict.get("generation"), + "type": step_dict.get("type"), "input": step_dict.get("input"), "output": step_dict.get("output"), + "createdAt": step_dict.get("createdAt"), + "startTime": step_dict.get("start"), + "endTime": step_dict.get("end"), + "generation": step_dict.get("generation"), "metadata": metadata, + "feedback": step_dict.get("feedback"), + "attachments": step_dict.get("attachments"), } self.steps_collection.insert_one(step_data) @@ -203,7 +255,7 @@ async def list_threads(self, pagination: Pagination, filters: ThreadFilter) -> P if not filters.userIdentifier: raise ValueError("userIdentifier is required") - query: Dict[str, Any] = {"participants.identifier": filters.userIdentifier} + query = {"participants.identifier": filters.userIdentifier} if filters.search: query["$text"] = {"$search": filters.search} if filters.feedback: @@ -215,7 +267,7 @@ async def list_threads(self, pagination: Pagination, filters: ThreadFilter) -> P threads_data = list(self.threads_collection.find(query).sort(sort).limit(pagination.first)) - threads: List[ThreadDict] = [ + threads = [ { "id": thread["_id"], "createdAt": thread["createdAt"], @@ -230,7 +282,7 @@ async def list_threads(self, pagination: Pagination, filters: ThreadFilter) -> P ] has_next_page = len(threads) == pagination.first - end_cursor = threads[-1]["id"] if has_next_page else None + end_cursor = threads[-1]["_id"] if has_next_page else None return PaginatedResponse( data=threads, @@ -246,8 +298,8 @@ async def get_thread(self, thread_id: str) -> Optional[ThreadDict]: steps = list(self.steps_collection.find({"threadId": thread_data["_id"]})) return { - "createdAt": thread_data["createdAt"], "id": thread_data["_id"], + "createdAt": thread_data["createdAt"], "name": thread_data.get("name"), "steps": [self.step_to_step_dict(step) for step in steps], "elements": [self.attachment_to_element_dict(element) for element in elements], @@ -264,7 +316,7 @@ async def update_thread( metadata: Optional[Dict] = None, tags: Optional[List[str]] = None, ): - update_data: Dict[str, Any] = {} + update_data = {} if name is not None: update_data["name"] = name if user_id is not None: From 661b8c0ac443839855bbb27f9ab95d789e20f1ff Mon Sep 17 00:00:00 2001 From: San Nguyen Date: Sat, 9 Mar 2024 23:05:32 +0900 Subject: [PATCH 4/6] revert literalai api Signed-off-by: San Nguyen --- backend/chainlit/data/__init__.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/backend/chainlit/data/__init__.py b/backend/chainlit/data/__init__.py index eebfaa6ade..9d3e870083 100644 --- a/backend/chainlit/data/__init__.py +++ b/backend/chainlit/data/__init__.py @@ -1,7 +1,6 @@ import os from typing import Optional -from chainlit.config import config from chainlit.data.base import BaseDataLayer from chainlit.data.chainlit import ChainlitDataLayer from chainlit.data.mongodb import MongoDBDataLayer @@ -9,16 +8,9 @@ _data_layer: Optional[BaseDataLayer] = None -if config.data_layer.database == "chainlit": - api_key = os.environ.get("LITERAL_API_KEY") - assert api_key is not None +if api_key := os.environ.get("LITERAL_API_KEY"): server = os.environ.get("LITERAL_SERVER") _data_layer = ChainlitDataLayer(api_key=api_key, server=server) -if config.data_layer.database == "mongodb": - if config.data_layer.object_storage != "s3": - raise ValueError("MongoDB data layer requires an S3 object storage") - db_url = os.environ.get("CHAINLIT_MONGODB_URL") - _data_layer = MongoDBDataLayer(db_url) def get_data_layer(): From fce8fbc513e12ab6110ef39df63fd1ccc625f562 Mon Sep 17 00:00:00 2001 From: San Nguyen Date: Mon, 11 Mar 2024 01:52:59 +0900 Subject: [PATCH 5/6] update MongoDataLayer Signed-off-by: San Nguyen --- backend/chainlit/data/__init__.py | 7 +- backend/chainlit/data/mongodb.py | 368 ++++++++++++++++++------------ 2 files changed, 222 insertions(+), 153 deletions(-) diff --git a/backend/chainlit/data/__init__.py b/backend/chainlit/data/__init__.py index 9d3e870083..182c33948c 100644 --- a/backend/chainlit/data/__init__.py +++ b/backend/chainlit/data/__init__.py @@ -3,7 +3,7 @@ from chainlit.data.base import BaseDataLayer from chainlit.data.chainlit import ChainlitDataLayer -from chainlit.data.mongodb import MongoDBDataLayer +from chainlit.data.mongodb import MongoDataLayer _data_layer: Optional[BaseDataLayer] = None @@ -12,6 +12,11 @@ server = os.environ.get("LITERAL_SERVER") _data_layer = ChainlitDataLayer(api_key=api_key, server=server) +if mongodb_uri := os.environ.get("CHAINLIT_MONGODB_URI"): + s3_bucket = os.environ.get("CHAINLIT_S3_BUCKET") + assert s3_bucket is not None, "Environment variable CHAINLIT_S3_BUCKET is required" + _data_layer = MongoDataLayer(mongodb_uri=mongodb_uri, s3_bucket=s3_bucket) + def get_data_layer(): return _data_layer diff --git a/backend/chainlit/data/mongodb.py b/backend/chainlit/data/mongodb.py index 4620c183da..3851462a27 100644 --- a/backend/chainlit/data/mongodb.py +++ b/backend/chainlit/data/mongodb.py @@ -1,51 +1,67 @@ -from chainlit.data import BaseDataLayer, queue_until_user_message -from chainlit.types import Feedback, Pagination, ThreadDict, ThreadFilter -from chainlit.user import PersistedUser, User +from dataclasses import asdict +import json +import os +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +import asyncio + +import aiofiles +from chainlit.config import config +from chainlit.context import context from chainlit.logger import logger -from literalai import Attachment, Feedback as _ClientFeedback +from chainlit.session import WebsocketSession +from chainlit.types import Feedback, Pagination, ThreadDict, ThreadFilter +from chainlit.user import PersistedUser, User, UserDict +from chainlit.data import BaseDataLayer, queue_until_user_message +from literalai import Attachment +from literalai import Feedback as ClientFeedback from literalai import PageInfo, PaginatedResponse -from literalai import Step as _ClientStep -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union +from literalai import Step as ClientStep +from literalai.step import StepDict as ClientStepDict +from literalai.thread import NumberListFilter, StringFilter, StringListFilter +from literalai.thread import ThreadFilter as ClientThreadFilter from datetime import datetime -import aiofiles -import json if TYPE_CHECKING: from chainlit.element import Element, ElementDict - from chainlit.step import FeedbackDict as _FeedbackDict, StepDict - - class ClientFeedback(_ClientFeedback): - """Fix typing for _ClientFeedback""" - - value: Literal[-1, 0, 0] - - class ClientStep(_ClientStep): - """Fix typing for _ClientStep""" - - feedback: Optional[ClientFeedback] - - class FeedbackDict(_FeedbackDict): - """Augment typing for _FeedbackDict""" - - id: str - forId: str - - -class MongoDBDataLayer(BaseDataLayer): - def __init__(self, db_url: str): - from pymongo import MongoClient - - # Connect to the database - self.client = MongoClient(db_url) # type: MongoClient - self.db = self.client.get_database() + from chainlit.step import FeedbackDict, StepDict + + +class MongoDataLayer(BaseDataLayer): + def __init__(self, mongodb_uri: str, s3_bucket: str): + import boto3 + from pymongo import MongoClient, ASCENDING, DESCENDING + from pymongo.errors import DuplicateKeyError, PyMongoError + + if TYPE_CHECKING: + from mypy_boto3_s3 import S3Client + + self.mongo_client: MongoClient = MongoClient(mongodb_uri) + self.s3_client: S3Client = boto3.client("s3") + self.s3_bucket = s3_bucket + + self.db = self.mongo_client.get_database() + self.users = self.db.get_collection("users") + self.threads = self.db.get_collection("threads") + self.steps = self.db.get_collection("steps") + self.elements = self.db.get_collection("elements") + self.feedbacks = self.db.get_collection("feedbacks") + + # Create indexes for faster querying + try: + self.threads.create_index( + [("user.identifier", ASCENDING), ("createdAt", DESCENDING)], + background=True, + ) + self.steps.create_index([("threadId", ASCENDING)], background=True) + self.elements.create_index([("threadId", ASCENDING), ("forId", ASCENDING)], background=True) + self.feedbacks.create_index([("forId", ASCENDING)], background=True) + except DuplicateKeyError: + logger.info("Indexes already exist!") + except PyMongoError as e: + logger.warning("Errors creating indexes MongoDB: %r", e) - # Get collection references - self.users_collection = self.db.get_collection("users") - self.elements_collection = self.db.get_collection("elements") - self.steps_collection = self.db.get_collection("steps") - self.threads_collection = self.db.get_collection("threads") - logger.info("MongoDB data layer initialized") + logger.info("Mongo data layer initialized") def attachment_to_element_dict(self, attachment: Attachment) -> "ElementDict": metadata = attachment.metadata or {} @@ -71,7 +87,7 @@ def feedback_to_feedback_dict(self, feedback: Optional[ClientFeedback]) -> "Opti return { "id": feedback.id or "", "forId": feedback.step_id or "", - "value": feedback.value or 0, + "value": feedback.value or 0, # type: ignore "comment": feedback.comment, "strategy": "BINARY", } @@ -104,48 +120,43 @@ def step_to_step_dict(self, step: ClientStep) -> "StepDict": } async def get_user(self, identifier: str) -> Optional[PersistedUser]: - user_data = self.users_collection.find_one({"identifier": identifier}) - if user_data: - return PersistedUser( - id=user_data["_id"], - identifier=user_data["identifier"], - metadata=user_data.get("metadata", {}), - createdAt=user_data.get("createdAt"), - ) - return None + user_dict = self.users.find_one({"identifier": identifier}) + if not user_dict: + return None + return PersistedUser( + id=str(user_dict.pop("_id")), + **user_dict, + ) async def create_user(self, user: User) -> Optional[PersistedUser]: _user = await self.get_user(user.identifier) if not _user: - user_data: Dict[str, Any] = { - "identifier": user.identifier, - "metadata": user.metadata, - "createdAt": datetime.utcnow(), - } - result = self.users_collection.insert_one(user_data) - user_data["_id"] = result.inserted_id - return PersistedUser(**user_data) - - return PersistedUser( - id=_user.id or "", - identifier=_user.identifier or "", - metadata=_user.metadata, - createdAt=_user.created_at or "", - ) + user_dict: Dict[str, Any] = user.to_dict() + user_dict["createdAt"] = datetime.utcnow().isoformat() + user_id = str(self.users.insert_one(user_dict).inserted_id) + return PersistedUser( + id=user_id, + **user_dict, + ) + return _user - async def upsert_feedback(self, feedback: Feedback): - feedback_data = { - "id": feedback.id, - "stepId": feedback.forId, - "value": feedback.value, - "strategy": feedback.strategy, - "comment": feedback.comment, - } - self.steps_collection.update_one({"_id": feedback.id}, {"$set": {"feedback": feedback_data}}) - return feedback.id + async def upsert_feedback( + self, + feedback: Feedback, + ) -> str: + feedback_dict = asdict(feedback) + feedback_id = self.feedbacks.update_one( + {"forId": feedback.forId}, + {"$set": feedback_dict}, + upsert=True, + ).upserted_id + return str(feedback_id) @queue_until_user_message() async def create_element(self, element: "Element"): + if not element.for_id: + return + metadata = { "size": element.size, "language": element.language, @@ -154,54 +165,60 @@ async def create_element(self, element: "Element"): "page": getattr(element, "page", None), } - if not element.for_id: - return - - object_key = None + object_key: Optional[str] = None if not element.url: if element.path: async with aiofiles.open(element.path, "rb") as f: - content = await f.read() # type: Union[bytes, str] + content: Union[bytes, str] = await f.read() elif element.content: content = element.content else: raise ValueError("Either path or content must be provided") - uploaded = await self.client.api.upload_file( - content=content, mime=element.mime, thread_id=element.thread_id + assert ( + context.session and context.session.user and context.session.user.identifier + ), "User is not set in Chainlit context" + object_key = f"{context.session.user.identifier}/{element.id}" + f"/{element.name}" if element.name else "" + self.s3_client.put_object( + Bucket=self.s3_bucket, + Key=object_key, + Body=content, + ContentType=element.mime or "", ) - object_key = uploaded["object_key"] + element.url = f"s3://{self.s3_bucket}/{object_key}" - element_data = { + element_dict = { "id": element.id, "threadId": element.thread_id, "stepId": element.for_id, + "name": element.name, + "metadata": metadata, "type": element.type, + "mime": element.mime, "url": element.url, - "chainlitKey": element.chainlit_key, - "name": element.name, - "display": element.display, "objectKey": object_key, - "size": element.size, - "page": element.page, - "language": element.language, - "mime": element.mime, - "metadata": metadata, } - self.elements_collection.insert_one(element_data) + # Set "_id" to the "element.id" from frontend + self.elements.insert_one({"_id": element.id} | element_dict) async def get_element(self, thread_id: str, element_id: str) -> Optional["ElementDict"]: - element_data = self.elements_collection.find_one({"id": element_id}) - if not element_data: + element_dict = self.elements.find_one({"_id": element_id, "threadId": thread_id}) + if not element_dict: return None - return element_data + return element_dict @queue_until_user_message() async def delete_element(self, element_id: str): - self.elements_collection.delete_one({"id": element_id}) + element_dict = self.elements.find_one_and_delete({"_id": element_id}) + if element_dict and element_dict.get("objectKey"): + self.s3_client.delete_object(Bucket=self.s3_bucket, Key=element_dict["objectKey"]) @queue_until_user_message() async def create_step(self, step_dict: "StepDict"): + await self.update_step(step_dict) + + @queue_until_user_message() + async def update_step(self, step_dict: "StepDict"): metadata = { "disableFeedback": step_dict.get("disableFeedback"), "isError": step_dict.get("isError"), @@ -210,35 +227,46 @@ async def create_step(self, step_dict: "StepDict"): "showInput": step_dict.get("showInput"), } - step_data = { - "id": step_dict.get("id"), - "threadId": step_dict.get("threadId"), - "parentId": step_dict.get("parentId"), - "name": step_dict.get("name"), - "type": step_dict.get("type"), - "input": step_dict.get("input"), - "output": step_dict.get("output"), + step: ClientStepDict = { "createdAt": step_dict.get("createdAt"), "startTime": step_dict.get("start"), "endTime": step_dict.get("end"), "generation": step_dict.get("generation"), + "id": step_dict.get("id"), + "parentId": step_dict.get("parentId"), + "name": step_dict.get("name"), + "threadId": step_dict.get("threadId"), + "type": step_dict.get("type"), "metadata": metadata, - "feedback": step_dict.get("feedback"), - "attachments": step_dict.get("attachments"), } - - self.steps_collection.insert_one(step_data) - - @queue_until_user_message() - async def update_step(self, step_dict: "StepDict"): - self.steps_collection.update_one({"_id": step_dict["id"]}, {"$set": step_dict}) + if step_dict.get("input"): + step["input"] = {"content": step_dict.get("input")} + if step_dict.get("output"): + step["output"] = {"content": step_dict.get("output")} + # Use frontend generated step id for "_id" + step_with_id = {"_id": step["id"]} | step + self.steps.update_one({"_id": step["id"]}, step_with_id, upsert=True) + + async def _delete_steps(self, step_ids: List[str]): + """Delete all elements and steps associated with steps""" + step_ids_to_delete = [] + step_ids # Create new list to avoid modifying the original list + child_steps_cursor = self.steps.find({"parentId": {"$in": step_ids_to_delete}}, {"_id": 1}) + for child_step in child_steps_cursor: + step_ids_to_delete.append(child_step["_id"]) + elements_cursor = self.elements.find({"stepId": {"$in": step_ids_to_delete}}, {"_id": 1, "objectKey": 1}) + await asyncio.gather( + *[self.delete_element(element["_id"]) for element in elements_cursor if element.get("objectKey")] + ) + self.feedbacks.delete_many({"forId": {"$in": step_ids_to_delete}}) + self.steps.delete_many({"_id": {"$in": step_ids_to_delete}}) @queue_until_user_message() async def delete_step(self, step_id: str): - self.steps_collection.delete_one({"_id": step_id}) + """Delete all elements and steps associated with the step""" + await self._delete_steps([step_id]) async def get_thread_author(self, thread_id: str) -> str: - thread = self.threads_collection.find_one({"_id": thread_id}) + thread = await self.get_thread(thread_id) if not thread: return "" user = thread.get("user") @@ -247,27 +275,43 @@ async def get_thread_author(self, thread_id: str) -> str: return user.get("identifier") or "" async def delete_thread(self, thread_id: str): - self.threads_collection.delete_one({"_id": thread_id}) - self.steps_collection.delete_many({"threadId": thread_id}) - self.elements_collection.delete_many({"threadId": thread_id}) + thread_dict = await self.get_thread(thread_id) + if not thread_dict: + return + + # Delete all steps, feedbacks and elements associated with the thread + steps_cursor = self.steps.find({"threadId": thread_id}, {"_id": 1}) # Only return the "_id" of steps + await self._delete_steps([step["_id"] for step in steps_cursor]) + + # Delete the thread itself + self.threads.delete_one({"_id": thread_id}) + + async def list_threads(self, pagination: "Pagination", filters: "ThreadFilter") -> "PaginatedResponse[ThreadDict]": + from pymongo import DESCENDING - async def list_threads(self, pagination: Pagination, filters: ThreadFilter) -> PaginatedResponse[ThreadDict]: if not filters.userIdentifier: raise ValueError("userIdentifier is required") - query = {"participants.identifier": filters.userIdentifier} + client_filters = ClientThreadFilter( + participantsIdentifier=StringListFilter(operator="in", value=[filters.userIdentifier]), + ) + if filters.search: + client_filters.search = StringFilter(operator="ilike", value=filters.search) + if filters.feedback: + client_filters.feedbacksValue = NumberListFilter(operator="in", value=[filters.feedback]) + + # Build MongoDB query based on filters + query: Dict[str, Any] = {"user.identifier": filters.userIdentifier} if filters.search: query["$text"] = {"$search": filters.search} if filters.feedback: query["feedback.value"] = filters.feedback - sort = [("createdAt", DESCENDING)] - if pagination.cursor: - query["_id"] = {"$lt": pagination.cursor} + # Apply pagination + skip = 0 if not pagination.cursor else int(pagination.cursor) + threads_cursor = self.threads.find(query).sort([("createdAt", DESCENDING)]).skip(skip).limit(pagination.first) - threads_data = list(self.threads_collection.find(query).sort(sort).limit(pagination.first)) - - threads = [ + threads: List[ThreadDict] = [ { "id": thread["_id"], "createdAt": thread["createdAt"], @@ -275,37 +319,56 @@ async def list_threads(self, pagination: Pagination, filters: ThreadFilter) -> P "user": thread.get("user"), "tags": thread.get("tags"), "metadata": thread.get("metadata"), - "steps": list(self.steps_collection.find({"threadId": thread["_id"]})), - "elements": list(self.elements_collection.find({"threadId": thread["_id"]})), + "steps": thread.get("steps"), + "elements": thread.get("elements"), } - for thread in threads_data + for thread in threads_cursor ] has_next_page = len(threads) == pagination.first - end_cursor = threads[-1]["_id"] if has_next_page else None + end_cursor = str(threads[-1]["id"]) if has_next_page else None - return PaginatedResponse( + return PaginatedResponse[ThreadDict]( data=threads, pageInfo=PageInfo(hasNextPage=has_next_page, endCursor=end_cursor), ) - async def get_thread(self, thread_id: str) -> Optional[ThreadDict]: - thread_data = self.threads_collection.find_one({"_id": thread_id}) - if not thread_data: + async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]": + thread_dict = self.threads.find_one({"_id": thread_id}) + if not thread_dict: return None - elements = list(self.elements_collection.find({"threadId": thread_data["_id"]})) - steps = list(self.steps_collection.find({"threadId": thread_data["_id"]})) + elements: List[ElementDict] = [] + steps: List[StepDict] = [] + + steps_cursor = self.steps.find({"threadId": thread_id}).sort([("createdAt", 1)]) + for step in steps_cursor: + if config.ui.hide_cot and step.get("parentId"): + continue + for attachment in step.get("attachments", []): + elements.append(self.attachment_to_element_dict(attachment)) + if not config.features.prompt_playground and step.get("generation"): + step.pop("generation", None) + steps.append(self.step_to_step_dict(step)) + + user = None # type: Optional["UserDict"] + + if thread_dict.get("user"): + user = { + "id": str(thread_dict["user"]["_id"]), + "identifier": thread_dict["user"]["identifier"], + "metadata": thread_dict["user"]["metadata"], + } return { - "id": thread_data["_id"], - "createdAt": thread_data["createdAt"], - "name": thread_data.get("name"), - "steps": [self.step_to_step_dict(step) for step in steps], - "elements": [self.attachment_to_element_dict(element) for element in elements], - "metadata": thread_data.get("metadata"), - "user": thread_data.get("user"), - "tags": thread_data.get("tags"), + "createdAt": thread_dict["createdAt"], + "id": str(thread_dict["_id"]), + "name": thread_dict.get("name"), + "steps": steps, + "elements": elements, + "metadata": thread_dict.get("metadata"), + "user": user, + "tags": thread_dict.get("tags"), } async def update_thread( @@ -316,18 +379,19 @@ async def update_thread( metadata: Optional[Dict] = None, tags: Optional[List[str]] = None, ): - update_data = {} + update_dict: Dict[str, Any] = {} if name is not None: - update_data["name"] = name + update_dict["name"] = name if user_id is not None: - update_data["user"] = {"identifier": user_id} + update_dict["user"] = {"_id": user_id} if metadata is not None: - update_data["metadata"] = metadata + update_dict["metadata"] = metadata if tags is not None: - update_data["tags"] = tags + update_dict["tags"] = tags - self.threads_collection.update_one({"_id": thread_id}, {"$set": update_data}) + self.threads.update_one({"_id": thread_id}, {"$set": update_dict}) async def delete_user_session(self, id: str) -> bool: - result = self.threads_collection.delete_many({"metadata.id": id}) - return result.deleted_count > 0 + if not self.threads.delete_one({"metadata.id": id}): + return False + return True From 864da6684bf9554a07500ec8e75052feeaa84ecf Mon Sep 17 00:00:00 2001 From: San Nguyen Date: Fri, 15 Mar 2024 14:52:57 +0900 Subject: [PATCH 6/6] use mongo_api to mimic literalai client instead Signed-off-by: San Nguyen --- backend/chainlit/data/mongodb.py | 399 +---------------- backend/chainlit/data/mongodb_api.py | 632 +++++++++++++++++++++++++++ 2 files changed, 641 insertions(+), 390 deletions(-) create mode 100644 backend/chainlit/data/mongodb_api.py diff --git a/backend/chainlit/data/mongodb.py b/backend/chainlit/data/mongodb.py index 3851462a27..c7fea7cb94 100644 --- a/backend/chainlit/data/mongodb.py +++ b/backend/chainlit/data/mongodb.py @@ -1,397 +1,16 @@ -from dataclasses import asdict -import json -import os -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union -import asyncio +from chainlit import logger +from chainlit.data import ChainlitDataLayer +from literalai.client import LiteralClient -import aiofiles -from chainlit.config import config -from chainlit.context import context -from chainlit.logger import logger -from chainlit.session import WebsocketSession -from chainlit.types import Feedback, Pagination, ThreadDict, ThreadFilter -from chainlit.user import PersistedUser, User, UserDict -from chainlit.data import BaseDataLayer, queue_until_user_message -from literalai import Attachment -from literalai import Feedback as ClientFeedback -from literalai import PageInfo, PaginatedResponse -from literalai import Step as ClientStep -from literalai.step import StepDict as ClientStepDict -from literalai.thread import NumberListFilter, StringFilter, StringListFilter -from literalai.thread import ThreadFilter as ClientThreadFilter -from datetime import datetime +class MongoDataLayer(ChainlitDataLayer): + def __init__(self, mongodb_uri: str): + # Do not call super().__init__() here, because it will create a new LiteralClient -if TYPE_CHECKING: - from chainlit.element import Element, ElementDict - from chainlit.step import FeedbackDict, StepDict + from .mongodb_api import API - -class MongoDataLayer(BaseDataLayer): - def __init__(self, mongodb_uri: str, s3_bucket: str): - import boto3 - from pymongo import MongoClient, ASCENDING, DESCENDING - from pymongo.errors import DuplicateKeyError, PyMongoError - - if TYPE_CHECKING: - from mypy_boto3_s3 import S3Client - - self.mongo_client: MongoClient = MongoClient(mongodb_uri) - self.s3_client: S3Client = boto3.client("s3") - self.s3_bucket = s3_bucket - - self.db = self.mongo_client.get_database() - self.users = self.db.get_collection("users") - self.threads = self.db.get_collection("threads") - self.steps = self.db.get_collection("steps") - self.elements = self.db.get_collection("elements") - self.feedbacks = self.db.get_collection("feedbacks") - - # Create indexes for faster querying - try: - self.threads.create_index( - [("user.identifier", ASCENDING), ("createdAt", DESCENDING)], - background=True, - ) - self.steps.create_index([("threadId", ASCENDING)], background=True) - self.elements.create_index([("threadId", ASCENDING), ("forId", ASCENDING)], background=True) - self.feedbacks.create_index([("forId", ASCENDING)], background=True) - except DuplicateKeyError: - logger.info("Indexes already exist!") - except PyMongoError as e: - logger.warning("Errors creating indexes MongoDB: %r", e) + self.client = LiteralClient(api_key="literalai") # API key is unused + self.client.api = API(mongodb_uri) # type: ignore logger.info("Mongo data layer initialized") - def attachment_to_element_dict(self, attachment: Attachment) -> "ElementDict": - metadata = attachment.metadata or {} - return { - "chainlitKey": None, - "display": metadata.get("display", "side"), - "language": metadata.get("language"), - "page": metadata.get("page"), - "size": metadata.get("size"), - "type": metadata.get("type", "file"), - "forId": attachment.step_id, - "id": attachment.id or "", - "mime": attachment.mime, - "name": attachment.name or "", - "objectKey": attachment.object_key, - "url": attachment.url, - "threadId": attachment.thread_id, - } - - def feedback_to_feedback_dict(self, feedback: Optional[ClientFeedback]) -> "Optional[FeedbackDict]": - if not feedback: - return None - return { - "id": feedback.id or "", - "forId": feedback.step_id or "", - "value": feedback.value or 0, # type: ignore - "comment": feedback.comment, - "strategy": "BINARY", - } - - def step_to_step_dict(self, step: ClientStep) -> "StepDict": - metadata = step.metadata or {} - input = (step.input or {}).get("content") or (json.dumps(step.input) if step.input and step.input != {} else "") - output = (step.output or {}).get("content") or ( - json.dumps(step.output) if step.output and step.output != {} else "" - ) - return { - "createdAt": step.created_at, - "id": step.id or "", - "threadId": step.thread_id or "", - "parentId": step.parent_id, - "feedback": self.feedback_to_feedback_dict(step.feedback), - "start": step.start_time, - "end": step.end_time, - "type": step.type or "undefined", - "name": step.name or "", - "generation": step.generation.to_dict() if step.generation else None, - "input": input, - "output": output, - "showInput": metadata.get("showInput", False), - "disableFeedback": metadata.get("disableFeedback", False), - "indent": metadata.get("indent"), - "language": metadata.get("language"), - "isError": metadata.get("isError", False), - "waitForAnswer": metadata.get("waitForAnswer", False), - } - - async def get_user(self, identifier: str) -> Optional[PersistedUser]: - user_dict = self.users.find_one({"identifier": identifier}) - if not user_dict: - return None - return PersistedUser( - id=str(user_dict.pop("_id")), - **user_dict, - ) - - async def create_user(self, user: User) -> Optional[PersistedUser]: - _user = await self.get_user(user.identifier) - if not _user: - user_dict: Dict[str, Any] = user.to_dict() - user_dict["createdAt"] = datetime.utcnow().isoformat() - user_id = str(self.users.insert_one(user_dict).inserted_id) - return PersistedUser( - id=user_id, - **user_dict, - ) - return _user - - async def upsert_feedback( - self, - feedback: Feedback, - ) -> str: - feedback_dict = asdict(feedback) - feedback_id = self.feedbacks.update_one( - {"forId": feedback.forId}, - {"$set": feedback_dict}, - upsert=True, - ).upserted_id - return str(feedback_id) - - @queue_until_user_message() - async def create_element(self, element: "Element"): - if not element.for_id: - return - - metadata = { - "size": element.size, - "language": element.language, - "display": element.display, - "type": element.type, - "page": getattr(element, "page", None), - } - - object_key: Optional[str] = None - - if not element.url: - if element.path: - async with aiofiles.open(element.path, "rb") as f: - content: Union[bytes, str] = await f.read() - elif element.content: - content = element.content - else: - raise ValueError("Either path or content must be provided") - assert ( - context.session and context.session.user and context.session.user.identifier - ), "User is not set in Chainlit context" - object_key = f"{context.session.user.identifier}/{element.id}" + f"/{element.name}" if element.name else "" - self.s3_client.put_object( - Bucket=self.s3_bucket, - Key=object_key, - Body=content, - ContentType=element.mime or "", - ) - element.url = f"s3://{self.s3_bucket}/{object_key}" - - element_dict = { - "id": element.id, - "threadId": element.thread_id, - "stepId": element.for_id, - "name": element.name, - "metadata": metadata, - "type": element.type, - "mime": element.mime, - "url": element.url, - "objectKey": object_key, - } - # Set "_id" to the "element.id" from frontend - self.elements.insert_one({"_id": element.id} | element_dict) - - async def get_element(self, thread_id: str, element_id: str) -> Optional["ElementDict"]: - element_dict = self.elements.find_one({"_id": element_id, "threadId": thread_id}) - if not element_dict: - return None - return element_dict - - @queue_until_user_message() - async def delete_element(self, element_id: str): - element_dict = self.elements.find_one_and_delete({"_id": element_id}) - if element_dict and element_dict.get("objectKey"): - self.s3_client.delete_object(Bucket=self.s3_bucket, Key=element_dict["objectKey"]) - - @queue_until_user_message() - async def create_step(self, step_dict: "StepDict"): - await self.update_step(step_dict) - - @queue_until_user_message() - async def update_step(self, step_dict: "StepDict"): - metadata = { - "disableFeedback": step_dict.get("disableFeedback"), - "isError": step_dict.get("isError"), - "waitForAnswer": step_dict.get("waitForAnswer"), - "language": step_dict.get("language"), - "showInput": step_dict.get("showInput"), - } - - step: ClientStepDict = { - "createdAt": step_dict.get("createdAt"), - "startTime": step_dict.get("start"), - "endTime": step_dict.get("end"), - "generation": step_dict.get("generation"), - "id": step_dict.get("id"), - "parentId": step_dict.get("parentId"), - "name": step_dict.get("name"), - "threadId": step_dict.get("threadId"), - "type": step_dict.get("type"), - "metadata": metadata, - } - if step_dict.get("input"): - step["input"] = {"content": step_dict.get("input")} - if step_dict.get("output"): - step["output"] = {"content": step_dict.get("output")} - # Use frontend generated step id for "_id" - step_with_id = {"_id": step["id"]} | step - self.steps.update_one({"_id": step["id"]}, step_with_id, upsert=True) - - async def _delete_steps(self, step_ids: List[str]): - """Delete all elements and steps associated with steps""" - step_ids_to_delete = [] + step_ids # Create new list to avoid modifying the original list - child_steps_cursor = self.steps.find({"parentId": {"$in": step_ids_to_delete}}, {"_id": 1}) - for child_step in child_steps_cursor: - step_ids_to_delete.append(child_step["_id"]) - elements_cursor = self.elements.find({"stepId": {"$in": step_ids_to_delete}}, {"_id": 1, "objectKey": 1}) - await asyncio.gather( - *[self.delete_element(element["_id"]) for element in elements_cursor if element.get("objectKey")] - ) - self.feedbacks.delete_many({"forId": {"$in": step_ids_to_delete}}) - self.steps.delete_many({"_id": {"$in": step_ids_to_delete}}) - - @queue_until_user_message() - async def delete_step(self, step_id: str): - """Delete all elements and steps associated with the step""" - await self._delete_steps([step_id]) - - async def get_thread_author(self, thread_id: str) -> str: - thread = await self.get_thread(thread_id) - if not thread: - return "" - user = thread.get("user") - if not user: - return "" - return user.get("identifier") or "" - - async def delete_thread(self, thread_id: str): - thread_dict = await self.get_thread(thread_id) - if not thread_dict: - return - - # Delete all steps, feedbacks and elements associated with the thread - steps_cursor = self.steps.find({"threadId": thread_id}, {"_id": 1}) # Only return the "_id" of steps - await self._delete_steps([step["_id"] for step in steps_cursor]) - - # Delete the thread itself - self.threads.delete_one({"_id": thread_id}) - - async def list_threads(self, pagination: "Pagination", filters: "ThreadFilter") -> "PaginatedResponse[ThreadDict]": - from pymongo import DESCENDING - - if not filters.userIdentifier: - raise ValueError("userIdentifier is required") - - client_filters = ClientThreadFilter( - participantsIdentifier=StringListFilter(operator="in", value=[filters.userIdentifier]), - ) - if filters.search: - client_filters.search = StringFilter(operator="ilike", value=filters.search) - if filters.feedback: - client_filters.feedbacksValue = NumberListFilter(operator="in", value=[filters.feedback]) - - # Build MongoDB query based on filters - query: Dict[str, Any] = {"user.identifier": filters.userIdentifier} - if filters.search: - query["$text"] = {"$search": filters.search} - if filters.feedback: - query["feedback.value"] = filters.feedback - - # Apply pagination - skip = 0 if not pagination.cursor else int(pagination.cursor) - threads_cursor = self.threads.find(query).sort([("createdAt", DESCENDING)]).skip(skip).limit(pagination.first) - - threads: List[ThreadDict] = [ - { - "id": thread["_id"], - "createdAt": thread["createdAt"], - "name": thread.get("name"), - "user": thread.get("user"), - "tags": thread.get("tags"), - "metadata": thread.get("metadata"), - "steps": thread.get("steps"), - "elements": thread.get("elements"), - } - for thread in threads_cursor - ] - - has_next_page = len(threads) == pagination.first - end_cursor = str(threads[-1]["id"]) if has_next_page else None - - return PaginatedResponse[ThreadDict]( - data=threads, - pageInfo=PageInfo(hasNextPage=has_next_page, endCursor=end_cursor), - ) - - async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]": - thread_dict = self.threads.find_one({"_id": thread_id}) - if not thread_dict: - return None - - elements: List[ElementDict] = [] - steps: List[StepDict] = [] - - steps_cursor = self.steps.find({"threadId": thread_id}).sort([("createdAt", 1)]) - for step in steps_cursor: - if config.ui.hide_cot and step.get("parentId"): - continue - for attachment in step.get("attachments", []): - elements.append(self.attachment_to_element_dict(attachment)) - if not config.features.prompt_playground and step.get("generation"): - step.pop("generation", None) - steps.append(self.step_to_step_dict(step)) - - user = None # type: Optional["UserDict"] - - if thread_dict.get("user"): - user = { - "id": str(thread_dict["user"]["_id"]), - "identifier": thread_dict["user"]["identifier"], - "metadata": thread_dict["user"]["metadata"], - } - - return { - "createdAt": thread_dict["createdAt"], - "id": str(thread_dict["_id"]), - "name": thread_dict.get("name"), - "steps": steps, - "elements": elements, - "metadata": thread_dict.get("metadata"), - "user": user, - "tags": thread_dict.get("tags"), - } - - async def update_thread( - self, - thread_id: str, - name: Optional[str] = None, - user_id: Optional[str] = None, - metadata: Optional[Dict] = None, - tags: Optional[List[str]] = None, - ): - update_dict: Dict[str, Any] = {} - if name is not None: - update_dict["name"] = name - if user_id is not None: - update_dict["user"] = {"_id": user_id} - if metadata is not None: - update_dict["metadata"] = metadata - if tags is not None: - update_dict["tags"] = tags - - self.threads.update_one({"_id": thread_id}, {"$set": update_dict}) - - async def delete_user_session(self, id: str) -> bool: - if not self.threads.delete_one({"metadata.id": id}): - return False - return True diff --git a/backend/chainlit/data/mongodb_api.py b/backend/chainlit/data/mongodb_api.py new file mode 100644 index 0000000000..02e67235ed --- /dev/null +++ b/backend/chainlit/data/mongodb_api.py @@ -0,0 +1,632 @@ +import mimetypes +from bson import ObjectId +import os + +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from chainlit import logger +from pymongo import MongoClient, DESCENDING, ASCENDING +from pymongo.errors import DuplicateKeyError, PyMongoError +from literalai.my_types import ( + Attachment, + AttachmentDict, + ChatGeneration, + CompletionGeneration, + Feedback, + FeedbackDict, + FeedbackStrategy, + PageInfo, + PaginatedResponse, + User, + UserDict, +) +from literalai.step import Step, StepDict, StepType +from literalai.thread import Thread +from literalai.helper import utc_now +from literalai.thread import ThreadFilter +from literalai.api import API as LiteralAIAPI + + +class API: + def __init__(self, mongodb_uri: str): + # MongoDB client + self.mongodb_client: MongoClient = MongoClient(mongodb_uri) + db = self.mongodb_client.get_database() + self.users = db["users"] + self.threads = db["threads"] + self.feedback = db["feedback"] + self.attachments = db["attachments"] + self.steps = db["steps"] + self.generations = db["generations"] + + # Create indexes for faster querying + try: + self.users.create_index([("identifier", ASCENDING)], background=True) + self.threads.create_index( + [("user.identifier", ASCENDING), ("createdAt", DESCENDING)], + background=True, + ) + self.steps.create_index([("threadId", ASCENDING)], background=True) + self.attachments.create_index([("threadId", ASCENDING), ("stepId", ASCENDING)], background=True) + self.feedback.create_index([("threadId", ASCENDING), ("stepId", ASCENDING)], background=True) + except DuplicateKeyError: + logger.info("Indexes already exist!") + except PyMongoError as e: + logger.warning("Errors creating indexes MongoDB: %r", e) + + logger.info("Mongo API initialized") + + # User API + + async def create_user(self, identifier: str, metadata: Optional[Dict] = None) -> User: + user_data = { + "identifier": identifier, + "metadata": metadata, + "createdAt": utc_now(), + } + + user_id = self.users.insert_one(user_data).inserted_id + + user_data["id"] = str(user_id) + + logger.info("User created: %r", user_data) + + return User.from_dict(user_data) + + async def update_user(self, id: str, identifier: Optional[str] = None, metadata: Optional[Dict] = None) -> User: + update_data: UserDict = {} + if identifier is not None: + update_data["identifier"] = identifier + if metadata is not None: + update_data["metadata"] = metadata + + self.users.update_one({"_id": ObjectId(id)}, {"$set": update_data}) + + user = self.users.find_one({"_id": ObjectId(id)}) + assert user is not None, "User not found" + user["id"] = str(user["_id"]) + + logger.info("User updated: %r", user) + + return User.from_dict(user) + + async def get_user(self, id: Optional[str] = None, identifier: Optional[str] = None) -> Optional[User]: + query = {} + if id is not None: + query["_id"] = id + elif identifier is not None: + query["identifier"] = identifier + + user = self.users.find_one(query) + + if user: + if user.get("deactivate"): + logger.info("User deactivated: %r", user) + return None + + user["id"] = str(user["_id"]) + + logger.info("User found: %r", user) + + return User.from_dict(user) if user else None + + async def delete_user(self, id: str) -> str: + result = self.users.update_one( + {"_id": ObjectId(id)}, + {"$set": {"deactivate": True, "deactivateAt": utc_now()}}, + ) + + if result.modified_count > 0: + logger.info("User deactivated: %r", result) + else: + logger.warning("User not found: %r", result) + + return id + + # Thread API + + async def list_threads( + self, + first: Optional[int] = None, + after: Optional[str] = None, + filters: Optional[ThreadFilter] = None, + ) -> PaginatedResponse: + # TODO: currently filters is not correct, will need to be fixed so we can get the threads displayed + # Run the mongodb locally and see what the threads collection looks like then update the query below + query: Dict[str, Any] = {} + if filters: + # Implement filter logic using MongoDB query operators + if filters.createdAt: + query["createdAt"] = {"$" + filters.createdAt.operator: filters.createdAt.value} + if filters.afterCreatedAt: + query["createdAt"] = {"$" + filters.afterCreatedAt.operator: filters.afterCreatedAt.value} + if filters.beforeCreatedAt: + query["createdAt"] = {"$" + filters.beforeCreatedAt.operator: filters.beforeCreatedAt.value} + if filters.environment: + query["environment"] = {"$" + filters.environment.operator: filters.environment.value} + if filters.feedbacksValue: + query["feedbacks.value"] = {"$" + filters.feedbacksValue.operator: filters.feedbacksValue.value} + if filters.participantsIdentifier: + query["participant.identifier"] = { + "$" + filters.participantsIdentifier.operator: filters.participantsIdentifier.value + } + if filters.search: + query["$text"] = {"$search": filters.search.value} + + logger.info("Threads query: %r", query) + + threads = self.threads.find(query).sort("createdAt", DESCENDING) + + if first: + threads = threads.limit(first) + + if after: + # Implement pagination logic using MongoDB cursor + threads = threads.skip(int(after)) + + thread_data = [ + Thread.from_dict(thread) for thread in threads + ] # Assuming Thread.from_dict handles conversion from MongoDB document + + # Construct PaginatedResponse with pageInfo and data + has_next_page = len(thread_data) > first if first else False + page_info = PageInfo(hasNextPage=has_next_page, endCursor=str(len(thread_data))) + + logger.info("Threads found: %r", thread_data) + + return PaginatedResponse(pageInfo=page_info, data=thread_data) + + async def create_thread( + self, + name: Optional[str] = None, + metadata: Optional[Dict] = None, + participant_id: Optional[str] = None, + environment: Optional[str] = None, + tags: Optional[List[str]] = None, + ) -> Thread: + thread_data = { + "name": name, + "metadata": metadata, + "participantId": participant_id, + "environment": environment, + "tags": tags, + "createdAt": utc_now(), + } + + thread_id = self.threads.insert_one(thread_data).inserted_id + + thread_data["id"] = str(thread_id) + + logger.info("Thread created: %r", thread_data) + + return Thread.from_dict(thread_data) + + async def upsert_thread( + self, + thread_id: str, + name: Optional[str] = None, + metadata: Optional[Dict] = None, + participant_id: Optional[str] = None, + environment: Optional[str] = None, + tags: Optional[List[str]] = None, + ) -> Thread: + update_data: Dict[str, Any] = {} + if name is not None: + update_data["name"] = name + if metadata is not None: + update_data["metadata"] = metadata + if participant_id is not None: + update_data["participantId"] = participant_id + if environment is not None: + update_data["environment"] = environment + if tags is not None: + update_data["tags"] = tags + + self.threads.update_one({"_id": thread_id}, {"$set": update_data}, upsert=True) + + thread = self.threads.find_one({"_id": thread_id}) + assert thread is not None, "Thread not found" + thread["id"] = str(thread["_id"]) + + logger.info("Thread updated: %r", thread) + + return Thread.from_dict(thread) + + async def update_thread( + self, + id: str, + name: Optional[str] = None, + metadata: Optional[Dict] = None, + participant_id: Optional[str] = None, + environment: Optional[str] = None, + tags: Optional[List[str]] = None, + ) -> Thread: + # Implementation similar to upsert_thread, but without upsert=True + update_data: Dict[str, Any] = {} + if name is not None: + update_data["name"] = name + if metadata is not None: + update_data["metadata"] = metadata + if participant_id is not None: + update_data["participantId"] = participant_id + if environment is not None: + update_data["environment"] = environment + if tags is not None: + update_data["tags"] = tags + + self.threads.update_one({"_id": ObjectId(id)}, {"$set": update_data}) + + thread = self.threads.find_one({"_id": ObjectId(id)}) + assert thread is not None, "Thread not found" + thread["id"] = str(thread["_id"]) + + logger.info("Thread updated: %r", thread) + + return Thread.from_dict(thread) + + async def get_thread(self, id: str) -> Optional[Thread]: + thread = self.threads.find_one({"_id": ObjectId(id)}) + + if thread: + if thread.get("deactivate"): + logger.info("Thread deactivated: %r", thread) + return None + + thread["id"] = str(thread["_id"]) + + logger.info("Thread found: %r", thread) + + return Thread.from_dict(thread) if thread else None + + async def delete_thread(self, id: str) -> bool: + result = self.threads.update_one( + {"_id": ObjectId(id)}, + {"$set": {"deactivate": True, "deactivateAt": utc_now()}}, + ) + + logger.info("Thread deactivated: %r", result) + + return result.modified_count > 0 + + # Feedback API + + async def create_feedback( + self, + step_id: str, + value: int, + comment: Optional[str] = None, + strategy: Optional[FeedbackStrategy] = None, + ) -> Feedback: + assert strategy is not None, "Feedback strategy must be provided" + feedback_data: FeedbackDict = { + "stepId": step_id, + "value": value, + "comment": comment, + "strategy": strategy, + } + + feedback_id = self.feedback.insert_one(feedback_data).inserted_id + + feedback_data["id"] = str(feedback_id) + + logger.info("Feedback created: %r", feedback_data) + + return Feedback.from_dict(feedback_data) + + async def update_feedback( + self, + id: str, + update_params: LiteralAIAPI.FeedbackUpdate, + ) -> "Feedback": + update_data: FeedbackDict = {} + if update_params.get("comment") is not None: + update_data["comment"] = update_params["comment"] + if update_params.get("value") is not None: + update_data["value"] = update_params["value"] + if update_params.get("strategy") is not None: + assert update_params["strategy"] is not None, "Feedback strategy must be provided" + update_data["strategy"] = update_params["strategy"] + + self.feedback.update_one({"_id": ObjectId(id)}, {"$set": update_data}) + + feedback = self.feedback.find_one({"_id": ObjectId(id)}) + assert feedback is not None, "Feedback not found" + feedback["id"] = str(feedback["_id"]) + + logger.info("Feedback updated: %r", feedback) + + return Feedback.from_dict(feedback) + + # Attachment API + + async def create_attachment( + self, + thread_id: str, + step_id: str, + id: Optional[str] = None, + metadata: Optional[Dict] = None, + mime: Optional[str] = None, + name: Optional[str] = None, + object_key: Optional[str] = None, + url: Optional[str] = None, + content: Optional[Union[bytes, str]] = None, + path: Optional[str] = None, + ) -> Attachment: + if not content and not url and not path: + raise Exception("Either content, path or attachment url must be provided") + + if content and path: + raise Exception("Only one of content and path must be provided") + + if (content and url) or (path and url): + raise Exception("Only one of content, path and attachment url must be provided") + + if path: + # TODO: if attachment.mime is text, we could read as text? + with open(path, "rb") as f: + content = f.read() + if not name: + name = path.split("/")[-1] + if not mime: + mime, _ = mimetypes.guess_type(path) + mime = mime or "application/octet-stream" + + if not name: + raise Exception("Attachment name must be provided") + + if content: + uploaded = await self.upload_file(content=content, thread_id=thread_id, mime=mime) + + if uploaded["object_key"] is None or uploaded["url"] is None: + raise Exception("Failed to upload file") + + object_key = uploaded["object_key"] + url = None + if not object_key: + url = uploaded["url"] + + attachment_data: AttachmentDict = { + "threadId": thread_id, + "stepId": step_id, + "metadata": metadata, + "mime": mime, + "name": name, + "objectKey": object_key, + "url": url, + } + + # Use generated id as _id if available + attachment_id = self.attachments.insert_one({"_id": ObjectId(id)} if id else {} | attachment_data).inserted_id + + attachment_data["id"] = str(attachment_id) + + logger.info("Attachment created: %r", attachment_data) + + return Attachment.from_dict(attachment_data) + + async def update_attachment( + self, + id: str, + update_params: LiteralAIAPI.AttachmentUpload, + ) -> Attachment: + update_data: AttachmentDict = {} + if update_params.get("metadata") is not None: + update_data["metadata"] = update_params["metadata"] + if update_params.get("mime") is not None: + update_data["mime"] = update_params["mime"] + if update_params.get("name") is not None: + update_data["name"] = update_params["name"] + if update_params.get("objectKey") is not None: + update_data["objectKey"] = update_params["objectKey"] + if update_params.get("url") is not None: + update_data["url"] = update_params["url"] + + self.attachments.update_one({"_id": ObjectId(id)}, {"$set": update_data}) + + attachment = self.attachments.find_one({"_id": ObjectId(id)}) + assert attachment is not None, "Attachment not found" + attachment["id"] = str(attachment["_id"]) + + logger.info("Attachment updated: %r", attachment) + + return Attachment.from_dict(attachment) + + async def get_attachment(self, id: str) -> Optional[Attachment]: + attachment = self.attachments.find_one({"_id": ObjectId(id)}) + + if attachment: + if attachment.get("deactivate"): + logger.info("Attachment deactivated: %r", attachment) + return None + + attachment["id"] = str(attachment["_id"]) + + logger.info("Attachment found: %r", attachment) + + return Attachment.from_dict(attachment) if attachment else None + + async def delete_attachment(self, id: str) -> str: + attachment = self.attachments.find_one({"_id": ObjectId(id)}) + + if attachment and attachment.get("objectKey"): + await self.delete_file(attachment["objectKey"]) + + self.attachments.update_one( + {"_id": ObjectId(id)}, + {"$set": {"deactivate": True, "deactivateAt": utc_now()}}, + ) + + logger.info("Attachment deactivated: %r", attachment) + + return id + + # Step API + + async def create_step( + self, + thread_id: Optional[str] = None, + type: Optional[StepType] = "undefined", + start_time: Optional[str] = None, + end_time: Optional[str] = None, + input: Optional[Dict] = None, + output: Optional[Dict] = None, + metadata: Optional[Dict] = None, + parent_id: Optional[str] = None, + name: Optional[str] = None, + tags: Optional[List[str]] = None, + ) -> Step: + step_data: StepDict = { + "threadId": thread_id, + "type": type, + "startTime": start_time, + "endTime": end_time, + "input": input, + "output": output, + "metadata": metadata, + "parentId": parent_id, + "name": name, + "tags": tags, + } + + step_id = self.steps.insert_one(step_data).inserted_id + + step_data["id"] = str(step_id) + + logger.info("Step created: %r", step_data) + + return Step.from_dict(step_data) + + async def update_step( + self, + id: str, + type: Optional[StepType] = None, + input: Optional[Dict] = None, + output: Optional[Dict] = None, + metadata: Optional[Dict] = None, + name: Optional[str] = None, + tags: Optional[List[str]] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + parent_id: Optional[str] = None, + ) -> "Step": + update_data: StepDict = {} + if type is not None: + update_data["type"] = type + if input is not None: + update_data["input"] = input + if output is not None: + update_data["output"] = output + if metadata is not None: + update_data["metadata"] = metadata + if name is not None: + update_data["name"] = name + if tags is not None: + update_data["tags"] = tags + if start_time is not None: + update_data["startTime"] = start_time + if end_time is not None: + update_data["endTime"] = end_time + if parent_id is not None: + update_data["parentId"] = parent_id + + self.steps.update_one({"_id": ObjectId(id)}, {"$set": update_data}) + + step = self.steps.find_one({"_id": ObjectId(id)}) + assert step is not None, "Step not found" + step["id"] = str(step["_id"]) + + logger.info("Step updated: %r", step) + + return Step.from_dict(step) + + async def get_step(self, id: str) -> Optional[Step]: + step = self.steps.find_one({"_id": ObjectId(id)}) + + if step: + if step.get("deactivate"): + logger.info("Step deactivated: %r", step) + return None + step["id"] = str(step["_id"]) + + logger.info("Step found: %r", step) + + return Step.from_dict(step) if step else None + + async def delete_step(self, id: str) -> bool: + result = self.steps.update_one( + {"_id": ObjectId(id)}, + {"$set": {"deactivate": True, "deactivateAt": utc_now()}}, + ) + + logger.info("Step deactivated: %r", result) + + return result.modified_count > 0 + + async def send_steps(self, steps: List[Union[StepDict, Step]]) -> Dict: + step_data: List[StepDict] = [step.to_dict() if isinstance(step, Step) else step for step in steps] + + for step in step_data: + if step.get("id") is not None: + self.steps.update_one({"_id": step.pop("id")}, {"$set": step}, upsert=True) + logger.info("Step updated: %r", step) + else: + self.steps.insert_one(step) + logger.info("Step created: %r", step) + + return {"ok": True, "message": "Steps ingested successfully"} + + # Generation API + + async def create_generation(self, generation: Union[ChatGeneration, CompletionGeneration]) -> str: + generation_data = generation.to_dict() + + logger.info("Generation created: %r", generation_data) + + return str(self.generations.insert_one(generation_data).inserted_id) + + # Blob file storage API. Overwrite these methods to use a different blob storage provider. + + async def upload_file( + self, + content: Union[bytes, str], + thread_id: str, + mime: Optional[str] = "application/octet-stream", + ) -> Dict: + id = str(ObjectId()) + s3_object_key = f"attachments/{thread_id}/{id}" + assert mime is not None, "MIME type is required" + + CHAINLIT_S3_BUCKET = os.environ.get("CHAINLIT_S3_BUCKET") + import boto3 + + if TYPE_CHECKING: + from mypy_boto3_s3 import S3Client + + assert CHAINLIT_S3_BUCKET is not None, "CHAINLIT_S3_BUCKET environment variable not set" + + s3_client: S3Client = boto3.client("s3") + s3_client.put_object(Bucket=CHAINLIT_S3_BUCKET, Key=s3_object_key, Body=content, ContentType=mime) + + return { + "object_key": s3_object_key, + "url": f"s3://{CHAINLIT_S3_BUCKET}/{s3_object_key}", + } + + async def delete_file(self, object_key: str): + # Enable bucket versioning to avoid deleting permanently + CHAINLIT_S3_BUCKET = os.environ.get("CHAINLIT_S3_BUCKET") + + assert CHAINLIT_S3_BUCKET is not None, "CHAINLIT_S3_BUCKET environment variable not set" + + import boto3 + + if TYPE_CHECKING: + from mypy_boto3_s3 import S3Client + + s3_client: S3Client = boto3.client("s3") + s3_client.delete_object(Bucket=CHAINLIT_S3_BUCKET, Key=object_key) + + # Dataset API + # TODO: Check if we need Dataset API for custom data layer + + # Prompt API + # TODO: Check if we need Prompt API for custom data layer