From acb59495116ff36f1f603af318294ecb563352c0 Mon Sep 17 00:00:00 2001
From: Stan Girard
Date: Mon, 31 Jul 2023 18:27:27 +0200
Subject: [PATCH 1/2] feat(tmp): added streaming

---
 backend/core/llm/base.py                     | 20 +++------
 backend/core/llm/openai.py                   |  1 +
 backend/core/llm/qa_base.py                  | 44 ++++++++++++-------
 backend/core/routes/chat_routes.py           | 25 ++++-------
 backend/core/vectorstore/supabase.py         |  7 +++
 frontend/app/chat/[chatId]/hooks/useChat.ts  | 12 ++---
 .../app/chat/[chatId]/hooks/useQuestion.ts   |  2 +-
 .../lib/context/BrainConfigProvider/types.ts |  1 +
 8 files changed, 56 insertions(+), 56 deletions(-)

diff --git a/backend/core/llm/base.py b/backend/core/llm/base.py
index 0bc0e255d3a0..7b909649bfc5 100644
--- a/backend/core/llm/base.py
+++ b/backend/core/llm/base.py
@@ -1,14 +1,12 @@
 from abc import abstractmethod
 from typing import AsyncIterable, List
 
-from langchain.callbacks import AsyncIteratorCallbackHandler
-from langchain.callbacks.base import AsyncCallbackHandler
+from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
 from langchain.chains import ConversationalRetrievalChain, LLMChain
 from langchain.llms.base import LLM
 from logger import get_logger
 from models.settings import BrainSettings  # Importing settings related to the 'brain'
 from pydantic import BaseModel  # For data validation and settings management
-from utils.constants import streaming_compatible_models
 
 logger = get_logger(__name__)
 
@@ -33,7 +31,7 @@ class BaseBrainPicking(BaseModel):
     openai_api_key: str = None  # pyright: ignore reportPrivateUsage=none
     callbacks: List[
-        AsyncCallbackHandler
+        AsyncIteratorCallbackHandler
     ] = None  # pyright: ignore reportPrivateUsage=none
 
     def _determine_api_key(self, openai_api_key, user_openai_api_key):
@@ -45,23 +43,15 @@ def _determine_api_key(self, openai_api_key, user_openai_api_key):
 
     def _determine_streaming(self, model: str, streaming: bool) -> bool:
         """If the model name allows for streaming and streaming is declared, set streaming to True."""
-        if model in streaming_compatible_models and streaming:
-            return True
-        if model not in streaming_compatible_models and streaming:
-            logger.warning(
-                f"Streaming is not compatible with {model}. Streaming will be set to False."
-            )
-            return False
-        else:
-            return False
-
+        return streaming
 
     def _determine_callback_array(
         self, streaming
     ) -> List[AsyncIteratorCallbackHandler]:  # pyright: ignore reportPrivateUsage=none
         """If streaming is set, set the AsyncIteratorCallbackHandler as the only callback."""
         if streaming:
+            print("Streaming is enabled. Callbacks will be set to AsyncIteratorCallbackHandler.")
             return [
-                AsyncIteratorCallbackHandler  # pyright: ignore reportPrivateUsage=none
+                AsyncIteratorCallbackHandler()  # pyright: ignore reportPrivateUsage=none
             ]
 
     def __init__(self, **data):
diff --git a/backend/core/llm/openai.py b/backend/core/llm/openai.py
index 31bf50301692..f03820e7aa50 100644
--- a/backend/core/llm/openai.py
+++ b/backend/core/llm/openai.py
@@ -58,5 +58,6 @@ def _create_llm(self, model, streaming=False, callbacks=None) -> BaseLLM:
             temperature=self.temperature,
             model=model,
             streaming=streaming,
+            verbose=True,
             callbacks=callbacks,
         )  # pyright: ignore reportPrivateUsage=none
diff --git a/backend/core/llm/qa_base.py b/backend/core/llm/qa_base.py
index 9752af44d276..faf4ae967adb 100644
--- a/backend/core/llm/qa_base.py
+++ b/backend/core/llm/qa_base.py
@@ -1,8 +1,6 @@
 import asyncio
-import json
 from abc import abstractmethod, abstractproperty
 from typing import AsyncIterable, Awaitable
-
 from langchain.chains import ConversationalRetrievalChain, LLMChain
 from langchain.chains.question_answering import load_qa_chain
 from langchain.embeddings.openai import OpenAIEmbeddings
@@ -12,9 +10,10 @@
 from repository.chat.format_chat_history import format_chat_history
 from repository.chat.get_chat_history import get_chat_history
 from repository.chat.update_chat_history import update_chat_history
-from repository.chat.update_message_by_id import update_message_by_id
 from supabase.client import Client, create_client
 from vectorstore.supabase import CustomSupabaseVectorStore
+from langchain.callbacks import AsyncIteratorCallbackHandler
+from langchain.chat_models import ChatOpenAI
 
 from .base import BaseBrainPicking
 from .prompts.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
@@ -60,13 +59,15 @@ def supabase_client(self) -> Client:
 
     @property
     def vector_store(self) -> CustomSupabaseVectorStore:
+        print("Creating vector store")
+        print("🧠🧠🧠🧠🧠🧠🧠🧠")
+        print("Brain id: ", self.brain_id)
         return CustomSupabaseVectorStore(
             self.supabase_client,
             self.embeddings,
             table_name="vectors",
             brain_id=self.brain_id,
         )
-
     @property
     def question_llm(self):
         return self._create_llm(model=self.model, streaming=False)
@@ -74,21 +75,22 @@ def question_llm(self):
     @property
     def doc_llm(self):
         return self._create_llm(
-            model=self.model, streaming=self.streaming, callbacks=self.callbacks
+            model=self.model, streaming=True, callbacks=self.callbacks
         )
 
     @property
     def question_generator(self) -> LLMChain:
-        return LLMChain(llm=self.question_llm, prompt=CONDENSE_QUESTION_PROMPT)
+        return LLMChain(llm=self.question_llm, prompt=CONDENSE_QUESTION_PROMPT, verbose=True)
 
     @property
     def doc_chain(self) -> LLMChain:
         return load_qa_chain(
-            llm=self.doc_llm, chain_type="stuff"
+            llm=self.doc_llm, chain_type="stuff", verbose=True
         )  # pyright: ignore reportPrivateUsage=none
 
     @property
     def qa(self) -> ConversationalRetrievalChain:
+        print("Creating QA chain")
         return ConversationalRetrievalChain(
             retriever=self.vector_store.as_retriever(),
             question_generator=self.question_generator,
@@ -170,10 +172,21 @@ async def generate_stream(self, question: str) -> AsyncIterable:
         :param question: The question
         :return: An async iterable which generates the answer.
         """
-
+        print("Generating stream")
         history = get_chat_history(self.chat_id)
         callback = self.callbacks[0]
-
+        print(self.callbacks)
+        callback = AsyncIteratorCallbackHandler()
+        model = ChatOpenAI(
+            streaming=True,
+            verbose=True,
+            callbacks=[callback],
+        )
+        llm = ChatOpenAI(temperature=0)
+        question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
+        doc_chain = load_qa_chain(model, chain_type="stuff")
+        qa = ConversationalRetrievalChain(
+            retriever=self.vector_store.as_retriever(), combine_docs_chain=doc_chain, question_generator=question_generator)
         transformed_history = []
 
         # Format the chat history into a list of tuples (human, ai)
@@ -190,14 +203,11 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
                 logger.error(f"Caught exception: {e}")
             finally:
                 event.set()
-
-        task = asyncio.create_task(
-            wrap_done(
-                self.qa._acall_chain(  # pyright: ignore reportPrivateUsage=none
-                    self.qa, question, transformed_history
-                ),
-                callback.done,  # pyright: ignore reportPrivateUsage=none
-            )
+        print("Calling chain")
+        # Begin a task that runs in the background.
+        task = asyncio.create_task(wrap_done(
+            await qa.acall({"question": question, "chat_history": transformed_history}, include_run_info=True),
+            callback.done),
         )
 
         streamed_chat_history = update_chat_history(
diff --git a/backend/core/routes/chat_routes.py b/backend/core/routes/chat_routes.py
index b6b445de0906..0991fddef51b 100644
--- a/backend/core/routes/chat_routes.py
+++ b/backend/core/routes/chat_routes.py
@@ -3,6 +3,7 @@
 from http.client import HTTPException
 from typing import List
 from uuid import UUID
+from venv import logger
 
 from auth import AuthBearer, get_current_user
 from fastapi import APIRouter, Depends, Query, Request
@@ -18,9 +19,6 @@
 from repository.chat.get_chat_history import get_chat_history
 from repository.chat.get_user_chats import get_user_chats
 from repository.chat.update_chat import ChatUpdatableProperties, update_chat
-from utils.constants import (
-    streaming_compatible_models,
-)
 
 
 chat_router = APIRouter()
@@ -228,22 +226,14 @@ async def create_stream_question_handler(
     current_user: User = Depends(get_current_user),
 ) -> StreamingResponse:
     # TODO: check if the user has access to the brain
-    if not brain_id:
-        brain_id = get_default_user_brain_or_create_new(current_user).id
-
-    if chat_question.model not in streaming_compatible_models:
-        # Forward the request to the none streaming endpoint
-        return await create_question_handler(
-            request,
-            chat_question,
-            chat_id,
-            current_user,  # pyright: ignore reportPrivateUsage=none
-        )
 
     try:
         user_openai_api_key = request.headers.get("Openai-Api-Key")
-        streaming = True
+        logger.info(f"Streaming request for {chat_question.model}")
         check_user_limit(current_user)
+        if not brain_id:
+            brain_id = get_default_user_brain_or_create_new(current_user).id
+
 
         gpt_answer_generator = OpenAIBrainPicking(
             chat_id=str(chat_id),
@@ -251,10 +241,11 @@ async def create_stream_question_handler(
             model=chat_question.model,
             max_tokens=chat_question.max_tokens,
             temperature=chat_question.temperature,
             brain_id=str(brain_id),
-            user_openai_api_key=user_openai_api_key,  # pyright: ignore reportPrivateUsage=none
-            streaming=streaming,
+            user_openai_api_key=current_user.user_openai_api_key,  # pyright: ignore reportPrivateUsage=none
+            streaming=True,
         )
+        print("streaming")
         return StreamingResponse(
             gpt_answer_generator.generate_stream(  # pyright: ignore reportPrivateUsage=none
                 chat_question.question
diff --git a/backend/core/vectorstore/supabase.py b/backend/core/vectorstore/supabase.py
index 1018f5a67f49..32ee2de216c0 100644
--- a/backend/core/vectorstore/supabase.py
+++ b/backend/core/vectorstore/supabase.py
@@ -29,6 +29,13 @@ def similarity_search(
         threshold: float = 0.5,
         **kwargs: Any
     ) -> List[Document]:
+        print("Everything is fine 🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥")
+        print("Query: ", query)
+        print("Table: ", table)
+        print("K: ", k)
+        print("Threshold: ", threshold)
+        print("Kwargs: ", kwargs)
+        print("Brain ID: ", self.brain_id)
         vectors = self._embedding.embed_documents([query])
         query_embedding = vectors[0]
         res = self._client.rpc(
diff --git a/frontend/app/chat/[chatId]/hooks/useChat.ts b/frontend/app/chat/[chatId]/hooks/useChat.ts
index 6b3596c260a7..58ea5f992d12 100644
--- a/frontend/app/chat/[chatId]/hooks/useChat.ts
+++ b/frontend/app/chat/[chatId]/hooks/useChat.ts
@@ -9,8 +9,10 @@ import { useChatContext } from "@/lib/context/ChatProvider/hooks/useChatContext
 import { useToast } from "@/lib/hooks";
 import { useEventTracking } from "@/services/analytics/useEventTracking";
 
-import { useQuestion } from "./useQuestion";
 import { ChatQuestion } from "../types";
+import { useQuestion } from "./useQuestion";
+
+
 
 // eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types
 export const useChat = () => {
@@ -68,11 +70,9 @@ export const useChat = () => {
 
       void track("QUESTION_ASKED");
 
-      if (chatQuestion.model === "gpt-3.5-turbo") {
-        await addStreamQuestion(currentChatId, chatQuestion);
-      } else {
-        await addQuestionToModel(currentChatId, chatQuestion);
-      }
+
+      await addStreamQuestion(currentChatId, chatQuestion);
+
       callback?.();
     } catch (error) {
diff --git a/frontend/app/chat/[chatId]/hooks/useQuestion.ts b/frontend/app/chat/[chatId]/hooks/useQuestion.ts
index 0a23d00456a0..6ec0fc4fdedb 100644
--- a/frontend/app/chat/[chatId]/hooks/useQuestion.ts
+++ b/frontend/app/chat/[chatId]/hooks/useQuestion.ts
@@ -79,7 +79,7 @@ export const useQuestion = (): UseChatService => {
       Accept: "text/event-stream",
     };
     const body = JSON.stringify(chatQuestion);
-
+    console.log("Calling API...");
    try {
      const response = await fetchInstance.post(
        `/chat/${chatId}/question/stream?brain_id=${currentBrain.id}`,
diff --git a/frontend/lib/context/BrainConfigProvider/types.ts b/frontend/lib/context/BrainConfigProvider/types.ts
index 28d8ee3048e9..8deba8bba921 100644
--- a/frontend/lib/context/BrainConfigProvider/types.ts
+++ b/frontend/lib/context/BrainConfigProvider/types.ts
@@ -21,6 +21,7 @@ export type BrainConfigContextType = {
 
 // export const openAiModels = ["gpt-3.5-turbo", "gpt-4"] as const; ## TODO activate GPT4 when not in demo mode
 export const openAiModels = [
+  "gpt-3.5-turbo",
  "gpt-3.5-turbo-0613",
  "gpt-3.5-turbo-16k",
 ] as const;

From 49b24e73885d9a251bbaf011e9e86016840d34de Mon Sep 17 00:00:00 2001
From: Stan Girard
Date: Mon, 31 Jul 2023 21:31:27 +0200
Subject: [PATCH 2/2] feat(streaming): implemented by changing order

---
 backend/core/llm/base.py             |  1 -
 backend/core/llm/qa_base.py          | 28 +++++++++++++---------------
 backend/core/vectorstore/supabase.py | 10 ++--------
 3 files changed, 15 insertions(+), 24 deletions(-)

diff --git a/backend/core/llm/base.py b/backend/core/llm/base.py
index 7b909649bfc5..721ad772d788 100644
--- a/backend/core/llm/base.py
+++ b/backend/core/llm/base.py
@@ -49,7 +49,6 @@ def _determine_callback_array(
         self, streaming
     ) -> List[AsyncIteratorCallbackHandler]:  # pyright: ignore reportPrivateUsage=none
         """If streaming is set, set the AsyncIteratorCallbackHandler as the only callback."""
         if streaming:
-            print("Streaming is enabled. Callbacks will be set to AsyncIteratorCallbackHandler.")
             return [
                 AsyncIteratorCallbackHandler()  # pyright: ignore reportPrivateUsage=none
             ]
diff --git a/backend/core/llm/qa_base.py b/backend/core/llm/qa_base.py
index faf4ae967adb..e482ee49b579 100644
--- a/backend/core/llm/qa_base.py
+++ b/backend/core/llm/qa_base.py
@@ -5,6 +5,7 @@
 from langchain.chains.question_answering import load_qa_chain
 from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.llms.base import BaseLLM
+from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
 from logger import get_logger
 from models.chat import ChatHistory
 from repository.chat.format_chat_history import format_chat_history
@@ -12,8 +13,9 @@
 from repository.chat.update_chat_history import update_chat_history
 from supabase.client import Client, create_client
 from vectorstore.supabase import CustomSupabaseVectorStore
-from langchain.callbacks import AsyncIteratorCallbackHandler
 from langchain.chat_models import ChatOpenAI
+from repository.chat.update_message_by_id import update_message_by_id
+import json
 
 from .base import BaseBrainPicking
 from .prompts.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
@@ -59,9 +61,7 @@ def supabase_client(self) -> Client:
 
     @property
     def vector_store(self) -> CustomSupabaseVectorStore:
-        print("Creating vector store")
-        print("🧠🧠🧠🧠🧠🧠🧠🧠")
-        print("Brain id: ", self.brain_id)
+
         return CustomSupabaseVectorStore(
             self.supabase_client,
             self.embeddings,
@@ -90,7 +90,6 @@ def doc_chain(self) -> LLMChain:
 
     @property
     def qa(self) -> ConversationalRetrievalChain:
-        print("Creating QA chain")
         return ConversationalRetrievalChain(
             retriever=self.vector_store.as_retriever(),
             question_generator=self.question_generator,
@@ -172,11 +171,10 @@ async def generate_stream(self, question: str) -> AsyncIterable:
         :param question: The question
         :return: An async iterable which generates the answer.
         """
-        print("Generating stream")
         history = get_chat_history(self.chat_id)
         callback = self.callbacks[0]
-        print(self.callbacks)
         callback = AsyncIteratorCallbackHandler()
+        self.callbacks = [callback]
         model = ChatOpenAI(
             streaming=True,
             verbose=True,
@@ -196,6 +194,7 @@ async def generate_stream(self, question: str) -> AsyncIterable:
         response_tokens = []
 
         # Wrap an awaitable with a event to signal when it's done or an exception is raised.
+
         async def wrap_done(fn: Awaitable, event: asyncio.Event):
             try:
                 await fn
@@ -203,13 +202,13 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
                 logger.error(f"Caught exception: {e}")
             finally:
                 event.set()
-        print("Calling chain")
         # Begin a task that runs in the background.
-        task = asyncio.create_task(wrap_done(
-            await qa.acall({"question": question, "chat_history": transformed_history}, include_run_info=True),
-            callback.done),
-        )
-
+
+        run = asyncio.create_task(wrap_done(
+            qa.acall({"question": question, "chat_history": transformed_history}),
+            callback.done,
+        ))
+
         streamed_chat_history = update_chat_history(
             chat_id=self.chat_id,
             user_message=question,
@@ -226,8 +225,7 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
             yield f"data: {json.dumps(streamed_chat_history.to_dict())}"
 
-        await task
-
+        await run
 
         # Join the tokens to create the assistant's response
         assistant = "".join(response_tokens)
diff --git a/backend/core/vectorstore/supabase.py b/backend/core/vectorstore/supabase.py
index 32ee2de216c0..18900e630d36 100644
--- a/backend/core/vectorstore/supabase.py
+++ b/backend/core/vectorstore/supabase.py
@@ -24,18 +24,12 @@ def __init__(
     def similarity_search(
         self,
         query: str,
-        table: str = "match_vectors",
         k: int = 6,
+        table: str = "match_vectors",
         threshold: float = 0.5,
         **kwargs: Any
     ) -> List[Document]:
-        print("Everything is fine 🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥")
-        print("Query: ", query)
-        print("Table: ", table)
-        print("K: ", k)
-        print("Threshold: ", threshold)
-        print("Kwargs: ", kwargs)
-        print("Brain ID: ", self.brain_id)
+
         vectors = self._embedding.embed_documents([query])
         query_embedding = vectors[0]
         res = self._client.rpc(