Skip to content

Commit

Permalink
feat(streaming): fix stream generation by scheduling the QA chain as a background task instead of awaiting it before task creation
Browse files Browse the repository at this point in the history
  • Loading branch information
StanGirard committed Jul 31, 2023
1 parent acb5949 commit 49b24e7
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 24 deletions.
1 change: 0 additions & 1 deletion backend/core/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def _determine_callback_array(
) -> List[AsyncIteratorCallbackHandler]: # pyright: ignore reportPrivateUsage=none
"""If streaming is set, set the AsyncIteratorCallbackHandler as the only callback."""
if streaming:
print("Streaming is enabled. Callbacks will be set to AsyncIteratorCallbackHandler.")
return [
AsyncIteratorCallbackHandler() # pyright: ignore reportPrivateUsage=none
]
Expand Down
28 changes: 13 additions & 15 deletions backend/core/llm/qa_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.base import BaseLLM
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from logger import get_logger
from models.chat import ChatHistory
from repository.chat.format_chat_history import format_chat_history
from repository.chat.get_chat_history import get_chat_history
from repository.chat.update_chat_history import update_chat_history
from supabase.client import Client, create_client
from vectorstore.supabase import CustomSupabaseVectorStore
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chat_models import ChatOpenAI
from repository.chat.update_message_by_id import update_message_by_id
import json

from .base import BaseBrainPicking
from .prompts.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
Expand Down Expand Up @@ -59,9 +61,7 @@ def supabase_client(self) -> Client:

@property
def vector_store(self) -> CustomSupabaseVectorStore:
print("Creating vector store")
print("🧠🧠🧠🧠🧠🧠🧠🧠")
print("Brain id: ", self.brain_id)

return CustomSupabaseVectorStore(
self.supabase_client,
self.embeddings,
Expand Down Expand Up @@ -90,7 +90,6 @@ def doc_chain(self) -> LLMChain:

@property
def qa(self) -> ConversationalRetrievalChain:
print("Creating QA chain")
return ConversationalRetrievalChain(
retriever=self.vector_store.as_retriever(),
question_generator=self.question_generator,
Expand Down Expand Up @@ -172,11 +171,10 @@ async def generate_stream(self, question: str) -> AsyncIterable:
:param question: The question
:return: An async iterable which generates the answer.
"""
print("Generating stream")
history = get_chat_history(self.chat_id)
callback = self.callbacks[0]
print(self.callbacks)
callback = AsyncIteratorCallbackHandler()
self.callbacks = [callback]
model = ChatOpenAI(
streaming=True,
verbose=True,
Expand All @@ -196,20 +194,21 @@ async def generate_stream(self, question: str) -> AsyncIterable:
response_tokens = []

# Wrap an awaitable with an event that signals when it completes or raises.

async def wrap_done(fn: Awaitable, event: asyncio.Event):
    """Await *fn*, logging (not re-raising) any exception, and always set *event*.

    Used to run the QA chain in the background: the event lets the streaming
    consumer know the awaitable has finished, even if it failed.

    :param fn: The awaitable to drive to completion.
    :param event: Signalled in all cases once *fn* is done.
    """
    try:
        await fn
    except Exception as e:
        # Swallow the error deliberately — the stream must still terminate —
        # but record it so failures are not silent.
        logger.error(f"Caught exception: {e}")
    finally:
        # Always fire, success or failure, so waiters are never left hanging.
        event.set()
print("Calling chain")
# Begin a task that runs in the background.
task = asyncio.create_task(wrap_done(
await qa.acall({"question": question, "chat_history": transformed_history}, include_run_info=True),
callback.done),
)


run = asyncio.create_task(wrap_done(
qa.acall({"question": question, "chat_history": transformed_history}),
callback.done,
))

streamed_chat_history = update_chat_history(
chat_id=self.chat_id,
user_message=question,
Expand All @@ -226,8 +225,7 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):

yield f"data: {json.dumps(streamed_chat_history.to_dict())}"

await task

await run
# Join the tokens to create the assistant's response
assistant = "".join(response_tokens)

Expand Down
10 changes: 2 additions & 8 deletions backend/core/vectorstore/supabase.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,12 @@ def __init__(
def similarity_search(
self,
query: str,
table: str = "match_vectors",
k: int = 6,
table: str = "match_vectors",
threshold: float = 0.5,
**kwargs: Any
) -> List[Document]:
print("Everything is fine 🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥")
print("Query: ", query)
print("Table: ", table)
print("K: ", k)
print("Threshold: ", threshold)
print("Kwargs: ", kwargs)
print("Brain ID: ", self.brain_id)

vectors = self._embedding.embed_documents([query])
query_embedding = vectors[0]
res = self._client.rpc(
Expand Down

0 comments on commit 49b24e7

Please sign in to comment.