diff --git a/.gitignore b/.gitignore index 7016c33..78d7a27 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ __pycache__/** insf_venv/** *.pyc +.env/* \ No newline at end of file diff --git a/app.py b/app.py index e9e6a40..0062810 100644 --- a/app.py +++ b/app.py @@ -1,6 +1,8 @@ import os import uuid import datasets +import tempfile + from langchain_huggingface import HuggingFaceEndpointEmbeddings from langchain_openai import ChatOpenAI from langchain_core.output_parsers import StrOutputParser @@ -19,6 +21,8 @@ from urllib3.exceptions import ProtocolError from langchain.retrievers import ContextualCompressionRetriever from transformers import AutoTokenizer +from unstructured.cleaners.core import clean_extra_whitespace, group_broken_paragraphs +from langchain_community.document_loaders import PyPDFLoader from tools import get_tools from tei_rerank import TEIRerank @@ -29,10 +33,10 @@ import yaml from yaml.loader import SafeLoader -from langchain.globals import set_verbose, set_debug +# from langchain.globals import set_verbose, set_debug -set_verbose(True) -set_debug(True) +# set_verbose(True) +# set_debug(True) st.set_page_config(layout="wide", page_title="InSightful") @@ -129,7 +133,8 @@ def setup_huggingface_embeddings(): @st.cache_resource def load_prompt_and_system_ins( - template_file_path="templates/prompt_template.tmpl", template=None + template_file_path: str = "templates/prompt_template.tmpl", + template: str | None = None, ): # prompt = hub.pull("hwchase17/react-chat") prompt = PromptTemplate.from_file(template_file_path) @@ -149,10 +154,11 @@ def load_prompt_and_system_ins( return prompt, system_instructions -class RAG: - def __init__(self, collection_name, db_client): - self.collection_name = collection_name +class RAG(object): + def __init__(self, llm: ChatOpenAI, db_client, embedding_function): + self.llm = llm self.db_client = db_client + self.embedding_function = embedding_function @retry( retry=retry_if_exception_type(ProtocolError), @@ -182,14 +188,14 @@ def chunk_doc(self, pages, chunk_size=512, chunk_overlap=30): print("Document chunked") return chunks - def insert_embeddings(self, chunks, chroma_embedding_function, batch_size=32): + def insert_embeddings(self, chunks, collection_name, batch_size=32): print( "Inserting embeddings into collection: {collection_name}".format( - collection_name=self.collection_name + collection_name=collection_name ) ) collection = self.db_client.get_or_create_collection( - self.collection_name, embedding_function=chroma_embedding_function + collection_name, embedding_function=self.embedding_function ) for i in range(0, len(chunks), batch_size): batch = chunks[i : i + batch_size] @@ -219,44 +225,39 @@ def get_retriever(self, vector_store, use_reranker=False): return retriever def query_docs( - self, model, question, vector_store, prompt, chat_history, use_reranker=False + self, question, vector_store, prompt, chat_history, use_reranker=False ): retriever = self.get_retriever(vector_store, use_reranker) pass_question = lambda input: input["question"] rag_chain = ( RunnablePassthrough.assign(context=pass_question | retriever | format_docs) | prompt - | model + | self.llm | StrOutputParser() ) return rag_chain.stream({"question": question, "chat_history": chat_history}) + def load_pdf(self, doc): + with tempfile.NamedTemporaryFile( + delete=False, suffix=os.path.splitext(doc.name)[1] + ) as tmp: + tmp.write(doc.getvalue()) + tmp_path = tmp.name + loader = PyPDFLoader(tmp_path) + documents = loader.load() + cleaned_pages = [] + for doc in documents: + doc.page_content = clean_extra_whitespace(doc.page_content) + doc.page_content = group_broken_paragraphs(doc.page_content) + cleaned_pages.append(doc) + return cleaned_pages + def format_docs(docs): return "\n\n".join(doc.page_content for doc in docs) -def create_retriever( - name, description, client, chroma_embedding_function, embedding_svc, reranker=False -): - collection_name = "software-slacks" - rag = RAG(collection_name=collection_name, db_client=client) - pages = rag.load_documents("spencer/software_slacks", num_docs=100) - chunks = rag.chunk_doc(pages) - rag.insert_embeddings(chunks, chroma_embedding_function) - vector_store = Chroma( - embedding_function=embedding_svc, - collection_name=collection_name, - client=client, - ) - retriever = rag.get_retriever(vector_store, use_reranker=reranker) - - retriever = vector_store.as_retriever( - search_type="similarity", search_kwargs={"k": 10} - ) - return create_retriever_tool(retriever, name, description) - @st.cache_resource def setup_agent(_model, _prompt, _tools): agent = create_react_agent( @@ -280,17 +281,25 @@ def main(): model = setup_chat_endpoint() embedder = setup_huggingface_embeddings() use_reranker = os.getenv("USE_RERANKER", "False") == "True" - - retriever_tool = create_retriever( - "slack_conversations_retriever", - "Useful for when you need to answer from Slack conversations.", - client, - chroma_embedding_function, - embedder, - reranker=use_reranker, + rag = RAG(llm=model, db_client=client, embedding_function=chroma_embedding_function) + collection_name = "software-slacks" + pages = rag.load_documents("spencer/software_slacks", num_docs=100) + chunks = rag.chunk_doc(pages) + rag.insert_embeddings(chunks, collection_name) + vector_store = Chroma( + embedding_function=embedder, + collection_name=collection_name, + client=client, ) + retriever = rag.get_retriever(vector_store, use_reranker=use_reranker) _tools = get_tools() - _tools.append(retriever_tool) + _tools.append( + create_retriever_tool( + retriever, + "slack_conversations_retriever", + "Useful for when you need to answer from Slack conversations.", + ) + ) agent_executor = setup_agent(model, prompt, _tools) @@ -328,7 +337,4 @@ def main(): if __name__ == "__main__": - # authenticator = authenticate() - # if st.session_state['authentication_status']: - # authenticator.logout() main() diff --git a/multi_tenant_rag.py b/multi_tenant_rag.py index 68dcaa4..a8d3443 100644 --- a/multi_tenant_rag.py +++ b/multi_tenant_rag.py @@ -1,14 +1,13 @@ -import os import logging -import tempfile +import os import yaml from yaml.loader import SafeLoader import streamlit as st import streamlit_authenticator as stauth from streamlit_authenticator.utilities import RegisterError -from langchain_community.document_loaders import PyPDFLoader from langchain_community.vectorstores.chroma import Chroma -from unstructured.cleaners.core import clean_extra_whitespace, group_broken_paragraphs +from langchain_chroma import Chroma + from tools import get_tools from app import ( @@ -21,17 +20,19 @@ setup_agent, ) + SYSTEM = "system" USER = "user" ASSISTANT = "assistant" logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO) +log: logging.Logger = logging.getLogger(__name__) def configure_authenticator(): auth_config = os.getenv("AUTH_CONFIG_FILE_PATH", default=".streamlit/config.yaml") - print(f"auth_config: {auth_config}") + log.info(f"auth_config: {auth_config}") with open(file=auth_config) as file: config = yaml.load(file, Loader=SafeLoader) @@ -67,49 +68,32 @@ def authenticate(op): return authenticator -class MultiTenantRAG(RAG): - def __init__(self, user_id, collection_name, db_client): - self.user_id = user_id - super().__init__(collection_name, db_client) - - def load_documents(self, doc): - with tempfile.NamedTemporaryFile( - delete=False, suffix=os.path.splitext(doc.name)[1] - ) as tmp: - tmp.write(doc.getvalue()) - tmp_path = tmp.name - loader = PyPDFLoader(tmp_path) - documents = loader.load() - cleaned_pages = [] - for doc in documents: - doc.page_content = clean_extra_whitespace(doc.page_content) - doc.page_content = group_broken_paragraphs(doc.page_content) - cleaned_pages.append(doc) - return cleaned_pages - - def main(): + authenticator = authenticate("login") + if st.session_state["authentication_status"]: + st.sidebar.text(f"Welcome {st.session_state['username']}") + authenticator.logout(location="sidebar") + user_id = st.session_state["username"] + if not user_id: + st.error("Please login to continue") + return + use_reranker = st.sidebar.toggle("Use reranker", False) use_tools = st.sidebar.toggle("Use tools", False) uploaded_file = st.sidebar.file_uploader("Upload a document", type=["pdf"]) question = st.chat_input("Chat with your docs or apis") llm = setup_chat_endpoint() - embedding_svc = setup_huggingface_embeddings() - chroma_embeddings = hf_embedding_server() - - user_id = st.session_state["username"] - client = setup_chroma_client() + # Set up prompt template template = """ Based on the retrieved context, respond with an accurate answer. Be concise and always provide accurate, specific, and relevant information. """ - template_file_path = "templates/multi_tenant_rag_prompt_template.tmpl" if use_tools: template_file_path = "templates/multi_tenant_rag_prompt_template_tools.tmpl" @@ -118,6 +102,7 @@ def main(): template_file_path=template_file_path, template=template, ) + log.info(f"prompt: {prompt} system_instructions: {system_instructions}") chat_history = st.session_state.get( "chat_history", [{"role": SYSTEM, "content": system_instructions.content}] @@ -127,26 +112,19 @@ def main(): with st.chat_message(message["role"]): st.markdown(message["content"]) - if not user_id: - st.error("Please login to continue") - return - collection = client.get_or_create_collection( f"user-collection-{user_id}", embedding_function=chroma_embeddings ) - logger = logging.getLogger(__name__) - logger.info( + log.info( f"user_id: {user_id} use_reranker: {use_reranker} use_tools: {use_tools} question: {question}" ) - rag = MultiTenantRAG(user_id, collection.name, client) + rag = RAG(llm=llm, db_client=client, embedding_function=chroma_embeddings) if use_tools: tools = get_tools() agent_executor = setup_agent(llm, prompt, tools) - # prompt = hub.pull("rlm/rag-prompt") - vectorstore = Chroma( embedding_function=embedding_svc, collection_name=collection.name, @@ -154,11 +132,11 @@ def main(): ) if uploaded_file: - document = rag.load_documents(uploaded_file) + document = rag.load_pdf(uploaded_file) chunks = rag.chunk_doc(document) rag.insert_embeddings( chunks=chunks, - chroma_embedding_function=chroma_embeddings, + collection_name=collection.name, batch_size=32, ) @@ -174,10 +152,9 @@ def main(): )["output"] with st.chat_message(ASSISTANT): st.write(answer) - logger.info(f"answer: {answer}") + log.info(f"answer: {answer}") else: answer = rag.query_docs( - model=llm, question=question, vector_store=vectorstore, prompt=prompt, @@ -186,7 +163,7 @@ def main(): ) with st.chat_message(ASSISTANT): answer = st.write_stream(answer) - logger.info(f"answer: {answer}") + log.info(f"answer: {answer}") chat_history.append({"role": USER, "content": question}) chat_history.append({"role": ASSISTANT, "content": answer}) @@ -194,8 +171,4 @@ def main(): if __name__ == "__main__": - authenticator = authenticate("login") - if st.session_state["authentication_status"]: - st.sidebar.text(f"Welcome {st.session_state['username']}") - authenticator.logout(location="sidebar") - main() + main() diff --git a/requirements.txt b/requirements.txt index 8492c8a..81329be 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ chromadb==0.5.3 datasets==2.20.0 langchain==0.2.12 -langchain_chroma==0.1.2 +langchain_chroma==0.1.3 langchain_community==0.2.11 langchain_core==0.2.28 langchain_huggingface==0.0.3