From 49b24e73885d9a251bbaf011e9e86016840d34de Mon Sep 17 00:00:00 2001
From: Stan Girard
Date: Mon, 31 Jul 2023 21:31:27 +0200
Subject: [PATCH] feat(streaming): implemented by changing order

---
 backend/core/llm/base.py             |  1 -
 backend/core/llm/qa_base.py          | 28 +++++++++++++---------------
 backend/core/vectorstore/supabase.py | 10 ++--------
 3 files changed, 15 insertions(+), 24 deletions(-)

diff --git a/backend/core/llm/base.py b/backend/core/llm/base.py
index 7b909649bfc5..721ad772d788 100644
--- a/backend/core/llm/base.py
+++ b/backend/core/llm/base.py
@@ -49,7 +49,6 @@ def _determine_callback_array(
     ) -> List[AsyncIteratorCallbackHandler]:  # pyright: ignore reportPrivateUsage=none
         """If streaming is set, set the AsyncIteratorCallbackHandler as the only callback."""
         if streaming:
-            print("Streaming is enabled. Callbacks will be set to AsyncIteratorCallbackHandler.")
             return [
                 AsyncIteratorCallbackHandler()  # pyright: ignore reportPrivateUsage=none
             ]
diff --git a/backend/core/llm/qa_base.py b/backend/core/llm/qa_base.py
index faf4ae967adb..e482ee49b579 100644
--- a/backend/core/llm/qa_base.py
+++ b/backend/core/llm/qa_base.py
@@ -5,6 +5,7 @@
 from langchain.chains.question_answering import load_qa_chain
 from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.llms.base import BaseLLM
+from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
 from logger import get_logger
 from models.chat import ChatHistory
 from repository.chat.format_chat_history import format_chat_history
@@ -12,8 +13,9 @@
 from repository.chat.update_chat_history import update_chat_history
 from supabase.client import Client, create_client
 from vectorstore.supabase import CustomSupabaseVectorStore
-from langchain.callbacks import AsyncIteratorCallbackHandler
 from langchain.chat_models import ChatOpenAI
+from repository.chat.update_message_by_id import update_message_by_id
+import json
 
 from .base import BaseBrainPicking
 from .prompts.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
@@ -59,9 +61,7 @@ def supabase_client(self) -> Client:
 
     @property
     def vector_store(self) -> CustomSupabaseVectorStore:
-        print("Creating vector store")
-        print("🧠🧠🧠🧠🧠🧠🧠🧠")
-        print("Brain id: ", self.brain_id)
+
         return CustomSupabaseVectorStore(
             self.supabase_client,
             self.embeddings,
@@ -90,7 +90,6 @@ def doc_chain(self) -> LLMChain:
 
     @property
     def qa(self) -> ConversationalRetrievalChain:
-        print("Creating QA chain")
         return ConversationalRetrievalChain(
             retriever=self.vector_store.as_retriever(),
             question_generator=self.question_generator,
@@ -172,11 +171,10 @@ async def generate_stream(self, question: str) -> AsyncIterable:
         :param question: The question
         :return: An async iterable which generates the answer.
         """
-        print("Generating stream")
         history = get_chat_history(self.chat_id)
         callback = self.callbacks[0]
-        print(self.callbacks)
         callback = AsyncIteratorCallbackHandler()
+        self.callbacks = [callback]
         model = ChatOpenAI(
             streaming=True,
             verbose=True,
@@ -196,6 +194,7 @@
         response_tokens = []
 
         # Wrap an awaitable with a event to signal when it's done or an exception is raised.
+
         async def wrap_done(fn: Awaitable, event: asyncio.Event):
             try:
                 await fn
@@ -203,13 +202,13 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
                 logger.error(f"Caught exception: {e}")
             finally:
                 event.set()
-        print("Calling chain")
         # Begin a task that runs in the background.
-        task = asyncio.create_task(wrap_done(
-            await qa.acall({"question": question, "chat_history": transformed_history}, include_run_info=True),
-            callback.done),
-        )
-
+
+        run = asyncio.create_task(wrap_done(
+            qa.acall({"question": question, "chat_history": transformed_history}),
+            callback.done,
+        ))
+
         streamed_chat_history = update_chat_history(
             chat_id=self.chat_id,
             user_message=question,
@@ -226,8 +225,7 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
 
             yield f"data: {json.dumps(streamed_chat_history.to_dict())}"
 
-        await task
-
+        await run
         # Join the tokens to create the assistant's response
         assistant = "".join(response_tokens)
 
diff --git a/backend/core/vectorstore/supabase.py b/backend/core/vectorstore/supabase.py
index 32ee2de216c0..18900e630d36 100644
--- a/backend/core/vectorstore/supabase.py
+++ b/backend/core/vectorstore/supabase.py
@@ -24,18 +24,12 @@ def __init__(
     def similarity_search(
         self,
         query: str,
-        table: str = "match_vectors",
         k: int = 6,
+        table: str = "match_vectors",
         threshold: float = 0.5,
         **kwargs: Any
     ) -> List[Document]:
-        print("Everything is fine 🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥")
-        print("Query: ", query)
-        print("Table: ", table)
-        print("K: ", k)
-        print("Threshold: ", threshold)
-        print("Kwargs: ", kwargs)
-        print("Brain ID: ", self.brain_id)
+
         vectors = self._embedding.embed_documents([query])
         query_embedding = vectors[0]
         res = self._client.rpc(
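
Why changing the order enables streaming: the removed code awaited
`qa.acall(...)` inline, so the chain ran to completion before
`asyncio.create_task` (and the token-reading loop below it) ever started;
the new code hands the un-awaited coroutine to `asyncio.create_task`, so
the chain produces tokens in the background while the generator drains
them. A minimal, self-contained sketch of that pattern in plain asyncio
(here `slow_chain`, the token queue, and the `None` sentinel are
hypothetical stand-ins for `qa.acall` and the AsyncIteratorCallbackHandler,
which are not reproduced):

    import asyncio
    from typing import Awaitable

    async def slow_chain(queue: asyncio.Queue) -> None:
        # Stand-in for qa.acall: produces tokens gradually while it runs.
        for token in ["str", "eam", "ing"]:
            await asyncio.sleep(0.1)
            await queue.put(token)
        await queue.put(None)  # sentinel, playing the role of callback.done

    async def wrap_done(fn: Awaitable, event: asyncio.Event) -> None:
        # Same shape as the patch's wrap_done: await the work, then signal.
        try:
            await fn
        finally:
            event.set()

    async def main() -> None:
        queue: asyncio.Queue = asyncio.Queue()
        done = asyncio.Event()

        # Fixed order: pass the coroutine object itself so it runs in the
        # background. The removed code effectively did `await ...` here,
        # which blocked until the chain finished, so nothing could stream.
        run = asyncio.create_task(wrap_done(slow_chain(queue), done))

        # Tokens arrive while the chain is still running.
        while (token := await queue.get()) is not None:
            print(token, end="", flush=True)
        print()

        await run  # re-raise anything the background task failed with

    asyncio.run(main())

With the old ordering, this program prints nothing until slow_chain has
already finished; with the coroutine passed un-awaited, each token appears
as it is produced, which is what the `yield f"data: ..."` loop relies on.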