feat(backend): add custom prompt (#885)
mamadoudicko authored Aug 7, 2023
1 parent 1160e16 commit 61cd0a6
Showing 1 changed file with 90 additions and 22 deletions.
112 changes: 90 additions & 22 deletions backend/core/llm/qa_base.py
@@ -1,20 +1,25 @@
import asyncio
import json
from abc import abstractmethod, abstractproperty
from typing import AsyncIterable, Awaitable
from uuid import UUID

from langchain import PromptTemplate
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.base import BaseLLM
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from logger import get_logger
from models.chat import ChatHistory
from repository.brain.get_brain_by_id import get_brain_by_id
from repository.chat.format_chat_history import format_chat_history
from repository.chat.get_chat_history import get_chat_history
from repository.chat.update_chat_history import update_chat_history
from repository.chat.update_message_by_id import update_message_by_id
from repository.prompt.get_prompt_by_id import get_prompt_by_id
from supabase.client import Client, create_client
from vectorstore.supabase import CustomSupabaseVectorStore
from repository.chat.update_message_by_id import update_message_by_id
import json

from .base import BaseBrainPicking
from .prompts.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
@@ -60,13 +65,13 @@ def supabase_client(self) -> Client:

@property
def vector_store(self) -> CustomSupabaseVectorStore:

return CustomSupabaseVectorStore(
self.supabase_client,
self.embeddings,
table_name="vectors",
brain_id=self.brain_id,
)

@property
def question_llm(self):
return self._create_llm(model=self.model, streaming=False)
@@ -79,12 +84,26 @@ def doc_llm(self):

@property
def question_generator(self) -> LLMChain:
return LLMChain(llm=self.question_llm, prompt=CONDENSE_QUESTION_PROMPT, verbose=True)
return LLMChain(
llm=self.question_llm, prompt=CONDENSE_QUESTION_PROMPT, verbose=True
)

@property
def doc_chain(self) -> LLMChain:
prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}
Question: {question}
Here are instructions on how to answer the question: {brain_prompt}
Answer:"""
PROMPT = PromptTemplate(
template=prompt_template,
input_variables=["context", "question", "brain_prompt"],
)

return load_qa_chain(
llm=self.doc_llm, chain_type="stuff", verbose=True
llm=self.doc_llm, chain_type="stuff", verbose=True, prompt=PROMPT
) # pyright: ignore reportPrivateUsage=none
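
The change here is that the stuff chain now receives an explicit PromptTemplate with a third input variable, brain_prompt, alongside the retrieved context and the question. A minimal sketch of what the rendered prompt looks like, using hypothetical sample values (the context, question, and instructions below are made up for illustration):

from langchain import PromptTemplate

prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}
Question: {question}
Here are instructions on how to answer the question: {brain_prompt}
Answer:"""

PROMPT = PromptTemplate(
    template=prompt_template,
    input_variables=["context", "question", "brain_prompt"],
)

# Render with hypothetical values to inspect the final prompt text.
print(
    PROMPT.format(
        context="Quivr stores document embeddings in a Supabase vector table.",
        question="Where are the embeddings stored?",
        brain_prompt="Answer in one short sentence.",
    )
)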

@property
@@ -97,7 +116,9 @@ def qa(self) -> ConversationalRetrievalChain:
)

@abstractmethod
def _create_llm(self, model, streaming=False, callbacks=None) -> BaseLLM:
def _create_llm(
self, model, streaming=False, callbacks=None, temperature=0.0
) -> BaseLLM:
"""
Determine the language model to be used.
:param model: Language model name to be used.
@@ -106,7 +127,7 @@ def _create_llm(self, model, streaming=False, callbacks=None) -> BaseLLM:
:return: Language model instance
"""

def _call_chain(self, chain, question, history):
def _call_chain(self, chain, question, history, brain_prompt):
"""
Call a chain with a given question and history.
:param chain: The chain eg QA (ConversationalRetrievalChain)
@@ -118,6 +139,7 @@ def _call_chain(self, chain, question, history):
{
"question": question,
"chat_history": history,
"brain_prompt": brain_prompt,
}
)

@@ -136,7 +158,12 @@ def generate_answer(self, question: str) -> ChatHistory:
transformed_history = format_chat_history(history)

# Generate the model response using the QA chain
model_response = self._call_chain(self.qa, question, transformed_history)
model_response = self._call_chain(
self.qa,
question,
transformed_history,
brain_prompt=self.get_prompt(),
)

answer = model_response["answer"]

@@ -164,6 +191,17 @@ async def _acall_chain(self, chain, question, history):
}
)

def get_prompt(self) -> str:
brain = get_brain_by_id(UUID(self.brain_id))
brain_prompt = "Your name is Quivr. You're a helpful assistant."

if brain and brain.prompt_id:
brain_prompt_object = get_prompt_by_id(brain.prompt_id)
if brain_prompt_object:
brain_prompt = brain_prompt_object.content

return brain_prompt
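
get_prompt resolves the per-brain instructions with a simple precedence: the brain's stored custom prompt if one is attached, otherwise the built-in Quivr persona. A self-contained sketch of the same fallback logic with stub objects (FakeBrain, FakePrompt, and resolve_prompt are illustrative stand-ins, not part of the codebase):

from dataclasses import dataclass
from typing import Callable, Optional

DEFAULT_PROMPT = "Your name is Quivr. You're a helpful assistant."

@dataclass
class FakePrompt:
    content: str

@dataclass
class FakeBrain:
    prompt_id: Optional[str] = None

def resolve_prompt(
    brain: Optional[FakeBrain],
    lookup: Callable[[str], Optional[FakePrompt]],
) -> str:
    # Custom prompt wins; anything missing along the way falls back to the default.
    if brain and brain.prompt_id:
        prompt = lookup(brain.prompt_id)
        if prompt:
            return prompt.content
    return DEFAULT_PROMPT

print(resolve_prompt(FakeBrain(), lambda _id: None))                         # default persona
print(resolve_prompt(FakeBrain("p1"), lambda _id: FakePrompt("Be terse.")))  # custom prompt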

async def generate_stream(self, question: str) -> AsyncIterable:
"""
Generate a streaming answer to a given question by interacting with the language model.
@@ -175,23 +213,44 @@ async def generate_stream(self, question: str) -> AsyncIterable:
self.callbacks = [callback]

# The Model used to answer the question with the context
answering_llm = self._create_llm(model=self.model, streaming=True, callbacks=self.callbacks,temperature=self.temperature)
answering_llm = self._create_llm(
model=self.model,
streaming=True,
callbacks=self.callbacks,
temperature=self.temperature,
)


# The Model used to create the standalone Question
# Temperature = 0 means no randomness
standalone_question_llm = self._create_llm(model=self.model)

# The Chain that generates the standalone question
standalone_question_generator = LLMChain(llm=standalone_question_llm, prompt=CONDENSE_QUESTION_PROMPT)
standalone_question_generator = LLMChain(
llm=standalone_question_llm, prompt=CONDENSE_QUESTION_PROMPT
)

prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}
Question: {question}
Here are instructions on how to answer the question: {brain_prompt}
Answer:"""
PROMPT = PromptTemplate(
template=prompt_template,
input_variables=["context", "question", "brain_prompt"],
)

# The Chain that generates the answer to the question
doc_chain = load_qa_chain(answering_llm, chain_type="stuff")
doc_chain = load_qa_chain(answering_llm, chain_type="stuff", prompt=PROMPT)

# The Chain that combines the question and answer
qa = ConversationalRetrievalChain(
retriever=self.vector_store.as_retriever(), combine_docs_chain=doc_chain, question_generator=standalone_question_generator)

retriever=self.vector_store.as_retriever(),
combine_docs_chain=doc_chain,
question_generator=standalone_question_generator,
)

transformed_history = []

# Format the chat history into a list of tuples (human, ai)
@@ -201,21 +260,30 @@ async def generate_stream(self, question: str) -> AsyncIterable:
response_tokens = []

# Wrap an awaitable with an event to signal when it's done or an exception is raised.

async def wrap_done(fn: Awaitable, event: asyncio.Event):
try:
await fn
except Exception as e:
logger.error(f"Caught exception: {e}")
finally:
event.set()

# Begin a task that runs in the background.

run = asyncio.create_task(wrap_done(
qa.acall({"question": question, "chat_history": transformed_history}),
callback.done,
))


run = asyncio.create_task(
wrap_done(
qa.acall(
{
"question": question,
"chat_history": transformed_history,
"brain_prompt": self.get_prompt(),
}
),
callback.done,
)
)

streamed_chat_history = update_chat_history(
chat_id=self.chat_id,
user_message=question,
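
The diff is truncated at this point, but the streaming setup above is typically completed by draining the callback's token iterator while the chain runs in the background task. A sketch of how that consumption loop usually looks with AsyncIteratorCallbackHandler (this illustrates the pattern, not the file's actual continuation; the assistant field and the serialization shape are assumptions):

# Later in generate_stream, after update_chat_history(...):
async for token in callback.aiter():
    # aiter() yields tokens as the answering LLM streams them.
    response_tokens.append(token)
    streamed_chat_history.assistant = token  # assumed field name
    yield f"data: {json.dumps(streamed_chat_history.dict())}"

await run  # wrap_done already logged any exception and set the done event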
