Feature/rag function #36

Merged 12 commits on Jan 10, 2024
README.md: 15 changes (11 additions, 4 deletions)
@@ -93,10 +93,17 @@ pip install -r dev-requirements
 ```python
 from whisperplus.pipelines.chatbot import ChatWithVideo
 
-# Run the query
-query = "what is mistral?"
-result = ChatWithVideo.run_query(query)
-print("result : ", result)
+input_file = "transcript.text"
+llm_model_name = "TheBloke/Mistral-7B-v0.1-GGUF"
+llm_model_file = "mistral-7b-v0.1.Q4_K_M.gguf"
+llm_model_type = "mistral"
+embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
+chat = ChatWithVideo(
+    input_file, llm_model_name, llm_model_file, llm_model_type, embedding_model_name
+)
+query = "what is this video about?"
+response = chat.run_query(query)
+print(response)
 ```
 
 ### Contributing
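Note: the README snippet above maps one-to-one onto the constructor this PR adds to `whisperplus/pipelines/chatbot.py` below. As a minimal self-contained sketch, assuming only the signature shown in this PR (the stand-in transcript content and query are illustrative; the keyword arguments simply name the positional parameters):

```python
from whisperplus.pipelines.chatbot import ChatWithVideo

# Write a stand-in transcript so the sketch runs without a real video transcription.
with open("transcript.text", "w") as f:
    f.write("In this video we walk through Mistral 7B and retrieval-augmented generation.")

chat = ChatWithVideo(
    input_file="transcript.text",
    llm_model_name="TheBloke/Mistral-7B-v0.1-GGUF",  # Hugging Face repo id
    llm_model_file="mistral-7b-v0.1.Q4_K_M.gguf",    # quantized GGUF weights file in that repo
    llm_model_type="mistral",                        # CTransformers model type
    embedding_model_name="sentence-transformers/all-MiniLM-L6-v2",
)
print(chat.run_query("what is this video about?"))
```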
Empty file added whisperplus/info_rag.txt
Empty file.
whisperplus/pipelines/chatbot.py: 94 changes (48 additions, 46 deletions)
@@ -1,106 +1,106 @@
 import logging
 
 import lancedb
 from langchain.chains import RetrievalQA
 from langchain.chat_models import ChatOpenAI
 from langchain.document_loaders import TextLoader
 from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.llms import CTransformers
 from langchain.prompts import PromptTemplate
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import LanceDB
 
-# Configuration and Constants
-MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'
-LLM_MODEL_NAME = 'TheBloke/Mistral-7B-v0.1-GGUF'
-LLM_MODEL_FILE = 'mistral-7b-v0.1.Q4_K_M.gguf'
-LLM_MODEL_TYPE = "mistral"
-TEXT_FILE_PATH = "transcript.text"
-DATABASE_PATH = '/tmp/lancedb'
+# Configure basic logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
 
 
 class ChatWithVideo:
 
-    @staticmethod
-    def load_llm_model():
+    def __init__(self, input_file, llm_model_name, llm_model_file, llm_model_type, embedding_model_name):
+        self.input_file = input_file
+        self.llm_model_name = llm_model_name
+        self.llm_model_file = llm_model_file
+        self.llm_model_type = llm_model_type
+        self.embedding_model_name = embedding_model_name
+
+    def load_llm_model(self):
         try:
-            print("Starting to download the Mistral model...")
+            logger.info(f"Starting to download the {self.llm_model_name} model...")
             llm_model = CTransformers(
-                model=LLM_MODEL_NAME, model_file=LLM_MODEL_FILE, model_type=LLM_MODEL_TYPE)
-            print("Mistral model successfully loaded.")
+                model=self.llm_model_name, model_file=self.llm_model_file, model_type=self.llm_model_type)
+            logger.info(f"{self.llm_model_name} model successfully loaded.")
             return llm_model
         except Exception as e:
-            print(f"Error loading the Mistral model: {e}")
+            logger.error(f"Error loading the {self.llm_model_name} model: {e}")
             return None
 
-    @staticmethod
-    def load_text_file(file_path):
+    def load_text_file(self):
         try:
-            print(f"Loading transcript file from {file_path}...")
-            loader = TextLoader(file_path)
+            logger.info(f"Loading transcript file from {self.input_file}...")
+            loader = TextLoader(self.input_file)
             docs = loader.load()
-            print("Transcript file successfully loaded.")
+            logger.info("Transcript file successfully loaded.")
             return docs
         except Exception as e:
-            print(f"Error loading text file: {e}")
+            logger.error(f"Error loading text file: {e}")
             return None
 
     @staticmethod
     def setup_database():
         try:
-            print("Setting up the database...")
-            db = lancedb.connect(DATABASE_PATH)
-            print("Database setup complete.")
+            logger.info("Setting up the database...")
+            db = lancedb.connect('/tmp/lancedb')
+            logger.info("Database setup complete.")
             return db
         except Exception as e:
-            print(f"Error setting up the database: {e}")
+            logger.error(f"Error setting up the database: {e}")
             return None
 
+    # embedding model
     @staticmethod
     def prepare_embeddings(model_name):
         try:
-            print(f"Preparing embeddings with model: {model_name}...")
+            logger.info(f"Preparing embeddings with model: {model_name}...")
             embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs={'device': 'cpu'})
-            print("Embeddings prepared successfully.")
+            logger.info("Embeddings prepared successfully.")
             return embeddings
         except Exception as e:
-            print(f"Error preparing embeddings: {e}")
+            logger.error(f"Error preparing embeddings: {e}")
             return None
 
     @staticmethod
     def prepare_documents(docs):
         if not docs:
-            print("No documents provided for preparation.")
+            logger.info("No documents provided for preparation.")
             return None
         try:
-            print("Preparing documents...")
+            logger.info("Preparing documents...")
             text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=50)
             documents = text_splitter.split_documents(docs)
-            print("Documents prepared successfully.")
+            logger.info("Documents prepared successfully.")
             return documents
         except Exception as e:
-            print(f"Error preparing documents: {e}")
+            logger.error(f"Error preparing documents: {e}")
             return None
 
-    @staticmethod
-    def run_query(query):
+    def run_query(self, query):
         if not query:
-            print("No query provided.")
+            logger.info("No query provided.")
             return "No query provided."
 
-        print(f"Running query: {query}")
-        docs = ChatWithVideo.load_text_file(TEXT_FILE_PATH)
+        logger.info(f"Running query: {query}")
+        docs = self.load_text_file()
         if not docs:
             return "Failed to load documents."
 
-        documents = ChatWithVideo.prepare_documents(docs)
+        documents = self.prepare_documents(docs)
         if not documents:
             return "Failed to prepare documents."
 
-        embeddings = ChatWithVideo.prepare_embeddings(MODEL_NAME)
+        embeddings = self.prepare_embeddings(self.embedding_model_name)
         if not embeddings:
             return "Failed to prepare embeddings."
 
-        db = ChatWithVideo.setup_database()
+        db = self.setup_database()
         if not db:
             return "Failed to setup database."
 
@@ -115,8 +115,7 @@ def run_query(query):
                 mode="overwrite")
             docsearch = LanceDB.from_documents(documents, embeddings, connection=table)
 
-            llm = ChatWithVideo.load_llm_model()
-
+            llm = self.load_llm_model()
             if not llm:
                 return "Failed to load LLM model."
 
@@ -129,14 +128,17 @@ def run_query(query):
             Helpful Answer:"""
 
             QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"], template=template)
-            print("prompt loaded")
+            logger.info("Prompt loaded")
             qa = RetrievalQA.from_chain_type(
                 llm,
                 chain_type='stuff',
                 retriever=docsearch.as_retriever(),
                 chain_type_kwargs={"prompt": QA_CHAIN_PROMPT})
-            print("Query processed successfully.")
-            return qa.run(query)
+            logger.info("Query processed successfully.")
+
+            result = qa.run(query)
+            logger.info(f"Result of the query: {result}")
+            return result
         except Exception as e:
-            print(f"Error running query: {e}")
+            logger.error(f"Error running query: {e}")
             return f"Error: {e}"