diff --git a/backend/core/llm/qa_base.py b/backend/core/llm/qa_base.py
index 33295fe83ce4..17486efb89d7 100644
--- a/backend/core/llm/qa_base.py
+++ b/backend/core/llm/qa_base.py
@@ -1,20 +1,25 @@
 import asyncio
+import json
 from abc import abstractmethod, abstractproperty
 from typing import AsyncIterable, Awaitable
+from uuid import UUID
+
+from langchain import PromptTemplate
+from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
 from langchain.chains import ConversationalRetrievalChain, LLMChain
 from langchain.chains.question_answering import load_qa_chain
 from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.llms.base import BaseLLM
-from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
 from logger import get_logger
 from models.chat import ChatHistory
+from repository.brain.get_brain_by_id import get_brain_by_id
 from repository.chat.format_chat_history import format_chat_history
 from repository.chat.get_chat_history import get_chat_history
 from repository.chat.update_chat_history import update_chat_history
+from repository.chat.update_message_by_id import update_message_by_id
+from repository.prompt.get_prompt_by_id import get_prompt_by_id
 from supabase.client import Client, create_client
 from vectorstore.supabase import CustomSupabaseVectorStore
-from repository.chat.update_message_by_id import update_message_by_id
-import json
 
 from .base import BaseBrainPicking
 from .prompts.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
@@ -60,13 +65,13 @@ def supabase_client(self) -> Client:
 
     @property
     def vector_store(self) -> CustomSupabaseVectorStore:
-
         return CustomSupabaseVectorStore(
             self.supabase_client,
             self.embeddings,
             table_name="vectors",
             brain_id=self.brain_id,
         )
+
     @property
     def question_llm(self):
         return self._create_llm(model=self.model, streaming=False)
@@ -79,12 +84,26 @@ def doc_llm(self):
 
     @property
     def question_generator(self) -> LLMChain:
-        return LLMChain(llm=self.question_llm, prompt=CONDENSE_QUESTION_PROMPT, verbose=True)
+        return LLMChain(
+            llm=self.question_llm, prompt=CONDENSE_QUESTION_PROMPT, verbose=True
+        )
 
     @property
     def doc_chain(self) -> LLMChain:
+        prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
+
+        {context}
+
+        Question: {question}
+        Here is instructions on how to answer the question: {brain_prompt}
+        Answer:"""
+        PROMPT = PromptTemplate(
+            template=prompt_template,
+            input_variables=["context", "question", "brain_prompt"],
+        )
+
         return load_qa_chain(
-            llm=self.doc_llm, chain_type="stuff", verbose=True
+            llm=self.doc_llm, chain_type="stuff", verbose=True, prompt=PROMPT
         ) # pyright: ignore reportPrivateUsage=none
 
     @property
@@ -97,7 +116,9 @@ def qa(self) -> ConversationalRetrievalChain:
         )
 
     @abstractmethod
-    def _create_llm(self, model, streaming=False, callbacks=None) -> BaseLLM:
+    def _create_llm(
+        self, model, streaming=False, callbacks=None, temperature=0.0
+    ) -> BaseLLM:
         """
         Determine the language model to be used.
         :param model: Language model name to be used.
@@ -106,7 +127,7 @@ def _create_llm(self, model, streaming=False, callbacks=None) -> BaseLLM:
         :return: Language model instance
         """
 
-    def _call_chain(self, chain, question, history):
+    def _call_chain(self, chain, question, history, brain_prompt):
         """
         Call a chain with a given question and history.
         :param chain: The chain eg QA (ConversationalRetrievalChain)
@@ -118,6 +139,7 @@ def _call_chain(self, chain, question, history):
             {
                 "question": question,
                 "chat_history": history,
+                "brain_prompt": brain_prompt,
             }
         )
 
@@ -136,7 +158,12 @@ def generate_answer(self, question: str) -> ChatHistory:
         transformed_history = format_chat_history(history)
 
         # Generate the model response using the QA chain
-        model_response = self._call_chain(self.qa, question, transformed_history)
+        model_response = self._call_chain(
+            self.qa,
+            question,
+            transformed_history,
+            brain_prompt=self.get_prompt(),
+        )
 
         answer = model_response["answer"]
 
@@ -164,6 +191,17 @@ async def _acall_chain(self, chain, question, history):
             }
         )
 
+    def get_prompt(self) -> str:
+        brain = get_brain_by_id(UUID(self.brain_id))
+        brain_prompt = "Your name is Quivr. You're a helpful assistant."
+
+        if brain and brain.prompt_id:
+            brain_prompt_object = get_prompt_by_id(brain.prompt_id)
+            if brain_prompt_object:
+                brain_prompt = brain_prompt_object.content
+
+        return brain_prompt
+
     async def generate_stream(self, question: str) -> AsyncIterable:
         """
         Generate a streaming answer to a given question by interacting with the language model.
@@ -175,23 +213,44 @@ async def generate_stream(self, question: str) -> AsyncIterable:
         self.callbacks = [callback]

         # The Model used to answer the question with the context
-        answering_llm = self._create_llm(model=self.model, streaming=True, callbacks=self.callbacks,temperature=self.temperature)
+        answering_llm = self._create_llm(
+            model=self.model,
+            streaming=True,
+            callbacks=self.callbacks,
+            temperature=self.temperature,
+        )
-
 
         # The Model used to create the standalone Question
         # Temperature = 0 means no randomness
         standalone_question_llm = self._create_llm(model=self.model)
 
         # The Chain that generates the standalone question
-        standalone_question_generator = LLMChain(llm=standalone_question_llm, prompt=CONDENSE_QUESTION_PROMPT)
+        standalone_question_generator = LLMChain(
+            llm=standalone_question_llm, prompt=CONDENSE_QUESTION_PROMPT
+        )
+
+        prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
+
+        {context}
+
+        Question: {question}
+        Here is instructions on how to answer the question: {brain_prompt}
+        Answer:"""
+        PROMPT = PromptTemplate(
+            template=prompt_template,
+            input_variables=["context", "question", "brain_prompt"],
+        )
 
         # The Chain that generates the answer to the question
-        doc_chain = load_qa_chain(answering_llm, chain_type="stuff")
+        doc_chain = load_qa_chain(answering_llm, chain_type="stuff", prompt=PROMPT)
 
         # The Chain that combines the question and answer
         qa = ConversationalRetrievalChain(
-            retriever=self.vector_store.as_retriever(), combine_docs_chain=doc_chain, question_generator=standalone_question_generator)
-
+            retriever=self.vector_store.as_retriever(),
+            combine_docs_chain=doc_chain,
+            question_generator=standalone_question_generator,
+        )
+
         transformed_history = []
 
         # Format the chat history into a list of tuples (human, ai)
@@ -201,7 +260,7 @@ async def generate_stream(self, question: str) -> AsyncIterable:
         response_tokens = []
 
         # Wrap an awaitable with a event to signal when it's done or an exception is raised.
-
+
         async def wrap_done(fn: Awaitable, event: asyncio.Event):
            try:
                await fn
@@ -209,13 +268,22 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
            except Exception as e:
                logger.error(f"Caught exception: {e}")
            finally:
                event.set()
+
        # Begin a task that runs in the background.
-
-        run = asyncio.create_task(wrap_done(
-            qa.acall({"question": question, "chat_history": transformed_history}),
-            callback.done,
-        ))
-
+
+        run = asyncio.create_task(
+            wrap_done(
+                qa.acall(
+                    {
+                        "question": question,
+                        "chat_history": transformed_history,
+                        "brain_prompt": self.get_prompt(),
+                    }
+                ),
+                callback.done,
+            )
+        )
+
         streamed_chat_history = update_chat_history(
             chat_id=self.chat_id,
             user_message=question,