feat(chat): added streaming #808

Merged 2 commits on Jul 31, 2023
19 changes: 4 additions & 15 deletions backend/core/llm/base.py
@@ -1,14 +1,12 @@
from abc import abstractmethod
from typing import AsyncIterable, List

from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.llms.base import LLM
from logger import get_logger
from models.settings import BrainSettings # Importing settings related to the 'brain'
from pydantic import BaseModel # For data validation and settings management
from utils.constants import streaming_compatible_models

logger = get_logger(__name__)

@@ -33,7 +31,7 @@ class BaseBrainPicking(BaseModel):

openai_api_key: str = None # pyright: ignore reportPrivateUsage=none
callbacks: List[
AsyncCallbackHandler
AsyncIteratorCallbackHandler
] = None # pyright: ignore reportPrivateUsage=none

def _determine_api_key(self, openai_api_key, user_openai_api_key):
@@ -45,23 +43,14 @@ def _determine_api_key(self, openai_api_key, user_openai_api_key):

def _determine_streaming(self, model: str, streaming: bool) -> bool:
"""If the model name allows for streaming and streaming is declared, set streaming to True."""
if model in streaming_compatible_models and streaming:
return True
if model not in streaming_compatible_models and streaming:
logger.warning(
f"Streaming is not compatible with {model}. Streaming will be set to False."
)
return False
else:
return False

return streaming
def _determine_callback_array(
self, streaming
) -> List[AsyncIteratorCallbackHandler]: # pyright: ignore reportPrivateUsage=none
"""If streaming is set, set the AsyncIteratorCallbackHandler as the only callback."""
if streaming:
return [
AsyncIteratorCallbackHandler # pyright: ignore reportPrivateUsage=none
AsyncIteratorCallbackHandler() # pyright: ignore reportPrivateUsage=none
]

def __init__(self, **data):
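Note on base.py: the `callbacks` field is now typed as `AsyncIteratorCallbackHandler`, the `streaming_compatible_models` gate is gone, and `_determine_callback_array` returns an instance (`AsyncIteratorCallbackHandler()`) instead of the bare class. A minimal sketch of why the instantiation matters, not part of this PR and assuming a langchain version whose import paths match the ones in the diff:

```python
# Minimal sketch: AsyncIteratorCallbackHandler keeps a per-instance queue that
# aiter() drains, so only an *instance* can receive streamed tokens. Passing the
# class itself (the pre-fix code) gives LangChain nothing to push tokens into.
import asyncio

from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage


async def main() -> None:
    callback = AsyncIteratorCallbackHandler()  # instance, not the class
    llm = ChatOpenAI(streaming=True, callbacks=[callback], temperature=0)

    # Run the generation in the background so tokens can be consumed as they arrive.
    task = asyncio.create_task(llm.agenerate([[HumanMessage(content="Say hello")]]))
    async for token in callback.aiter():
        print(token, end="", flush=True)
    await task


if __name__ == "__main__":
    asyncio.run(main())
```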
1 change: 1 addition & 0 deletions backend/core/llm/openai.py
@@ -58,5 +58,6 @@ def _create_llm(self, model, streaming=False, callbacks=None) -> BaseLLM:
temperature=self.temperature,
model=model,
streaming=streaming,
verbose=True,
callbacks=callbacks,
) # pyright: ignore reportPrivateUsage=none
50 changes: 29 additions & 21 deletions backend/core/llm/qa_base.py
@@ -1,20 +1,21 @@
import asyncio
import json
from abc import abstractmethod, abstractproperty
from typing import AsyncIterable, Awaitable

from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.base import BaseLLM
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from logger import get_logger
from models.chat import ChatHistory
from repository.chat.format_chat_history import format_chat_history
from repository.chat.get_chat_history import get_chat_history
from repository.chat.update_chat_history import update_chat_history
from repository.chat.update_message_by_id import update_message_by_id
from supabase.client import Client, create_client
from vectorstore.supabase import CustomSupabaseVectorStore
from langchain.chat_models import ChatOpenAI
from repository.chat.update_message_by_id import update_message_by_id
import json

from .base import BaseBrainPicking
from .prompts.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
@@ -60,31 +61,31 @@ def supabase_client(self) -> Client:

@property
def vector_store(self) -> CustomSupabaseVectorStore:

return CustomSupabaseVectorStore(
self.supabase_client,
self.embeddings,
table_name="vectors",
brain_id=self.brain_id,
)

@property
def question_llm(self):
return self._create_llm(model=self.model, streaming=False)

@property
def doc_llm(self):
return self._create_llm(
model=self.model, streaming=self.streaming, callbacks=self.callbacks
model=self.model, streaming=True, callbacks=self.callbacks
)

@property
def question_generator(self) -> LLMChain:
return LLMChain(llm=self.question_llm, prompt=CONDENSE_QUESTION_PROMPT)
return LLMChain(llm=self.question_llm, prompt=CONDENSE_QUESTION_PROMPT, verbose=True)

@property
def doc_chain(self) -> LLMChain:
return load_qa_chain(
llm=self.doc_llm, chain_type="stuff"
llm=self.doc_llm, chain_type="stuff", verbose=True
) # pyright: ignore reportPrivateUsage=none

@property
@@ -170,10 +171,20 @@ async def generate_stream(self, question: str) -> AsyncIterable:
:param question: The question
:return: An async iterable which generates the answer.
"""

history = get_chat_history(self.chat_id)
callback = self.callbacks[0]

callback = AsyncIteratorCallbackHandler()
self.callbacks = [callback]
model = ChatOpenAI(
streaming=True,
verbose=True,
callbacks=[callback],
)
llm = ChatOpenAI(temperature=0)
question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
doc_chain = load_qa_chain(model, chain_type="stuff")
qa = ConversationalRetrievalChain(
retriever=self.vector_store.as_retriever(), combine_docs_chain=doc_chain, question_generator=question_generator)
transformed_history = []

# Format the chat history into a list of tuples (human, ai)
@@ -183,23 +194,21 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
response_tokens = []

# Wrap an awaitable with an event to signal when it's done or an exception is raised.

async def wrap_done(fn: Awaitable, event: asyncio.Event):
try:
await fn
except Exception as e:
logger.error(f"Caught exception: {e}")
finally:
event.set()

task = asyncio.create_task(
wrap_done(
self.qa._acall_chain( # pyright: ignore reportPrivateUsage=none
self.qa, question, transformed_history
),
callback.done, # pyright: ignore reportPrivateUsage=none
)
)

# Begin a task that runs in the background.

run = asyncio.create_task(wrap_done(
qa.acall({"question": question, "chat_history": transformed_history}),
callback.done,
))

streamed_chat_history = update_chat_history(
chat_id=self.chat_id,
user_message=question,
@@ -216,8 +225,7 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):

yield f"data: {json.dumps(streamed_chat_history.to_dict())}"

await task

await run
# Join the tokens to create the assistant's response
assistant = "".join(response_tokens)

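Note on qa_base.py: `generate_stream` now creates its own `AsyncIteratorCallbackHandler`, builds a streaming `ChatOpenAI` for the combine-docs chain and a non-streaming one for question condensation, runs the `ConversationalRetrievalChain` in a background task via `wrap_done`, and yields SSE frames. The token loop itself sits in a collapsed part of the diff; a rough sketch of that pattern, with the SSE payload shape and every name beyond `qa`, `callback`, `question` and the history illustrative:

```python
# Sketch of the streaming pattern generate_stream relies on; the actual token loop
# lives below the visible hunk, so the payload format here is an assumption.
import asyncio
import json
import logging
from typing import AsyncIterable, Awaitable

logger = logging.getLogger(__name__)


async def stream_answer(qa, callback, question: str, chat_history: list) -> AsyncIterable[str]:
    async def wrap_done(fn: Awaitable, event: asyncio.Event):
        try:
            await fn
        except Exception as e:
            logger.error(f"Caught exception: {e}")
        finally:
            event.set()  # unblocks callback.aiter() so the generator can finish

    # Kick off the chain in the background; tokens flow through the callback as it runs.
    run = asyncio.create_task(
        wrap_done(
            qa.acall({"question": question, "chat_history": chat_history}),
            callback.done,
        )
    )

    response_tokens = []
    async for token in callback.aiter():
        response_tokens.append(token)
        yield f"data: {json.dumps({'assistant': token})}\n\n"  # one SSE frame per token

    await run  # ensure the chain (and any exception logging) has finished
```

Judging by the visible `"".join(response_tokens)` and the `update_message_by_id` import, the accumulated tokens are presumably persisted to the chat history once the stream ends.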
25 changes: 8 additions & 17 deletions backend/core/routes/chat_routes.py
@@ -3,6 +3,7 @@
from http.client import HTTPException
from typing import List
from uuid import UUID
from venv import logger

from auth import AuthBearer, get_current_user
from fastapi import APIRouter, Depends, Query, Request
@@ -18,9 +19,6 @@
from repository.chat.get_chat_history import get_chat_history
from repository.chat.get_user_chats import get_user_chats
from repository.chat.update_chat import ChatUpdatableProperties, update_chat
from utils.constants import (
streaming_compatible_models,
)

chat_router = APIRouter()

@@ -228,33 +226,26 @@ async def create_stream_question_handler(
current_user: User = Depends(get_current_user),
) -> StreamingResponse:
# TODO: check if the user has access to the brain
if not brain_id:
brain_id = get_default_user_brain_or_create_new(current_user).id

if chat_question.model not in streaming_compatible_models:
# Forward the request to the none streaming endpoint
return await create_question_handler(
request,
chat_question,
chat_id,
current_user, # pyright: ignore reportPrivateUsage=none
)

try:
user_openai_api_key = request.headers.get("Openai-Api-Key")
streaming = True
logger.info(f"Streaming request for {chat_question.model}")
check_user_limit(current_user)
if not brain_id:
brain_id = get_default_user_brain_or_create_new(current_user).id


gpt_answer_generator = OpenAIBrainPicking(
chat_id=str(chat_id),
model=chat_question.model,
max_tokens=chat_question.max_tokens,
temperature=chat_question.temperature,
brain_id=str(brain_id),
user_openai_api_key=user_openai_api_key, # pyright: ignore reportPrivateUsage=none
streaming=streaming,
user_openai_api_key=current_user.user_openai_api_key, # pyright: ignore reportPrivateUsage=none
streaming=True,
)

print("streaming")
return StreamingResponse(
gpt_answer_generator.generate_stream( # pyright: ignore reportPrivateUsage=none
chat_question.question
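Note on chat_routes.py: the fallback to the non-streaming handler for incompatible models is removed, so every request goes through `generate_stream` and is returned as a `StreamingResponse`. The `media_type` and headers sit below the visible hunk; a self-contained sketch of the SSE plumbing this relies on, with a dummy generator standing in for `OpenAIBrainPicking.generate_stream` and the route path assumed from the frontend call:

```python
# Self-contained sketch of FastAPI's StreamingResponse consuming an async generator.
# The real route swaps dummy_stream() for gpt_answer_generator.generate_stream(question).
import asyncio
import json
from typing import AsyncIterable

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


async def dummy_stream() -> AsyncIterable[str]:
    for token in ["Hello", " ", "world"]:
        await asyncio.sleep(0.1)  # simulate model latency
        yield f"data: {json.dumps({'assistant': token})}\n\n"


@app.post("/chat/{chat_id}/question/stream")
async def question_stream(chat_id: str) -> StreamingResponse:
    # media_type matches the Accept: text/event-stream header the frontend sends
    return StreamingResponse(dummy_stream(), media_type="text/event-stream")
```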
3 changes: 2 additions & 1 deletion backend/core/vectorstore/supabase.py
@@ -24,11 +24,12 @@ def __init__(
def similarity_search(
self,
query: str,
table: str = "match_vectors",
k: int = 6,
table: str = "match_vectors",
threshold: float = 0.5,
**kwargs: Any
) -> List[Document]:

vectors = self._embedding.embed_documents([query])
query_embedding = vectors[0]
res = self._client.rpc(
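Note on supabase.py: `k` now comes before `table` in `similarity_search`, presumably so the signature lines up with LangChain's base `VectorStore.similarity_search(query, k, ...)` and a positionally passed `k` no longer lands in `table`. A toy illustration of that pitfall; the positional caller is an assumption, not something shown in the diff:

```python
# Why the parameter order matters: with the old order, a caller passing k positionally
# would silently bind it to `table` (the RPC name) instead.
def old_signature(query: str, table: str = "match_vectors", k: int = 6):
    return {"table": table, "k": k}


def new_signature(query: str, k: int = 6, table: str = "match_vectors"):
    return {"table": table, "k": k}


print(old_signature("question", 4))  # {'table': 4, 'k': 6} - the intended k clobbers the RPC name
print(new_signature("question", 4))  # {'table': 'match_vectors', 'k': 4} - as intended
```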
12 changes: 6 additions & 6 deletions frontend/app/chat/[chatId]/hooks/useChat.ts
@@ -9,8 +9,10 @@ import { useChatContext } from "@/lib/context/ChatProvider/hooks/useChatContext"
import { useToast } from "@/lib/hooks";
import { useEventTracking } from "@/services/analytics/useEventTracking";

import { useQuestion } from "./useQuestion";
import { ChatQuestion } from "../types";
import { useQuestion } from "./useQuestion";



// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types
export const useChat = () => {
@@ -68,11 +70,9 @@ export const useChat = () => {

void track("QUESTION_ASKED");

if (chatQuestion.model === "gpt-3.5-turbo") {
await addStreamQuestion(currentChatId, chatQuestion);
} else {
await addQuestionToModel(currentChatId, chatQuestion);
}

await addStreamQuestion(currentChatId, chatQuestion);


callback?.();
} catch (error) {
2 changes: 1 addition & 1 deletion frontend/app/chat/[chatId]/hooks/useQuestion.ts
@@ -79,7 +79,7 @@ export const useQuestion = (): UseChatService => {
Accept: "text/event-stream",
};
const body = JSON.stringify(chatQuestion);

console.log("Calling API...");
try {
const response = await fetchInstance.post(
`/chat/${chatId}/question/stream?brain_id=${currentBrain.id}`,
1 change: 1 addition & 0 deletions frontend/lib/context/BrainConfigProvider/types.ts
@@ -21,6 +21,7 @@ export type BrainConfigContextType = {
// export const openAiModels = ["gpt-3.5-turbo", "gpt-4"] as const; ## TODO activate GPT4 when not in demo mode

export const openAiModels = [
"gpt-3.5-turbo",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k",
] as const;