Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update/faiss #13

Merged
merged 3 commits on Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions local_app.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -9,7 +9,7 @@
from parameters import load_config
global DATA_PATH
load_config('test')
from parameters import CHROMA_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL, PROMPT_TEMPLATE, DATA_PATH, REPHRASING_PROMPT, STANDALONE_PROMPT, ROUTER_DECISION_PROMPT
from parameters import DATABASE_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL, PROMPT_TEMPLATE, DATA_PATH, REPHRASING_PROMPT, STANDALONE_PROMPT, ROUTER_DECISION_PROMPT
from get_llm_function import get_llm_function
from get_rag_chain import get_rag_chain
from ConversationalRagChain import ConversationalRagChain
Expand Down Expand Up @@ -82,7 +82,6 @@ def get_Chat_response(query):
"chat_history": []
}
res = rag_conv._call(inputs)
print(res['metadatas'])
output = jsonify({
'response': res['result'],
'context': res['context'],
Expand Down
10 changes: 5 additions & 5 deletions python_script/config.json
Original file line number | Diff line number | Diff line change
@@ -1,31 +1,31 @@
{
"default": {
"data_path": "data/documents/default",
"chroma_root_path": "data/chroma/default",
"database_root_path": "data/database/default",
"embedding_model": "voyage-law-2",
"llm_model": "gpt-3.5-turbo"
},
"seus": {
"data_path": "data/documents/ship_data",
"chroma_root_path": "data/chroma/ship_chroma",
"database_root_path": "data/database/ship",
"embedding_model": "openai",
"llm_model": "gpt-3.5-turbo"
},
"arch-en": {
"data_path": "data/documents/arch_data-en",
"chroma_root_path": "data/chroma/arch_data-en_chroma",
"database_root_path": "data/database/arch_data-en",
"embedding_model": "openai",
"llm_model": "gpt-3.5-turbo"
},
"arch-ru": {
"data_path": "data/documents/arch_data-ru",
"chroma_root_path": "data/chroma/arch_data-ru_chroma",
"database_root_path": "data/database/arch_data-ru",
"embedding_model": "openai",
"llm_model": "gpt-3.5-turbo"
},
"test": {
"data_path": "data/documents/test_data",
"chroma_root_path": "data/chroma/test_chroma",
"database_root_path": "data/database/test",
"embedding_model": "openai",
"llm_model": "gpt-3.5-turbo"
}
Expand Down
34 changes: 21 additions & 13 deletions python_script/get_rag_chain.py
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,14 @@
from parameters import CHROMA_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL
from parameters import DATABASE_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL

from get_embedding_function import get_embedding_function
from get_llm_function import get_llm_function
from populate_database import find_chroma_path
from populate_database import find_database_path

from langchain.vectorstores.chroma import Chroma
from langchain.chains import create_history_aware_retriever
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.vectorstores import FAISS

def get_rag_chain(params = None):
"""
Expand All @@ -21,7 +21,7 @@ def get_rag_chain(params = None):

Parameters:
params (dict, optional): A dictionary of configuration parameters.
- chroma_root_path (str): The root path for Chroma data storage.
- database_root_path (str): The root path for data storage.
- embedding_model (str): The model name for the embedding function.
- llm_model (str): The model name for the language model.
- search_type (str): The type of search to perform. Options are:
Expand All @@ -41,7 +41,7 @@ def get_rag_chain(params = None):
"""

default_params = {
"chroma_root_path": CHROMA_ROOT_PATH,
"database_root_path": DATABASE_ROOT_PATH,
"embedding_model": EMBEDDING_MODEL,
"llm_model": LLM_MODEL,
"search_type": "similarity",
Expand All @@ -59,25 +59,34 @@ def get_rag_chain(params = None):
params = {**default_params, **params}

try:
required_keys = ["chroma_root_path", "embedding_model", "llm_model"]
required_keys = ["database_root_path", "embedding_model", "llm_model"]
for key in required_keys:
if key not in params:
raise NameError(f"Required setting '{key}' not defined.")

embedding_model = get_embedding_function(model_name=params["embedding_model"])
llm = get_llm_function(model_name=params["llm_model"])
db = Chroma(persist_directory=find_chroma_path(model_name=params["embedding_model"], base_path=params["chroma_root_path"]), embedding_function=embedding_model)


# Load the FAISS index from disk
vector_store = FAISS.load_local(find_database_path(EMBEDDING_MODEL,DATABASE_ROOT_PATH)
, embedding_model, allow_dangerous_deserialization=True)

search_type = params["search_type"]
if search_type == "similarity":
retriever = db.as_retriever(search_type=search_type, search_kwargs={"k": params["similarity_doc_nb"]})
retriever = vector_store.as_retriever(search_type=search_type,
search_kwargs={"k": params["similarity_doc_nb"]})
elif search_type == "similarity_score_threshold":
retriever = db.as_retriever(search_type=search_type, search_kwargs={"k": params["max_chunk_return"],"score_threshold": params["score_threshold"]})
retriever = vector_store.as_retriever(search_type=search_type,
search_kwargs={"k": params["max_chunk_return"],
"score_threshold": params["score_threshold"]})
elif search_type == "mmr":
retriever = db.as_retriever(search_type=search_type, search_kwargs={"k": params["mmr_doc_nb"], "fetch_k": params["considered_chunk"], "lambda_mult": params["lambda_mult"]})
retriever = vector_store.as_retriever(search_type=search_type,
search_kwargs={"k": params["mmr_doc_nb"],
"fetch_k": params["considered_chunk"],
"lambda_mult": params["lambda_mult"]})
else:
raise ValueError("Invalid 'search_type' setting")

except NameError as e:
variable_name = str(e).split("'")[1]
raise NameError(f"{variable_name} isn't defined")
Expand Down Expand Up @@ -118,5 +127,4 @@ def get_rag_chain(params = None):
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)

rag_chain = create_retrieval_chain(retriever, question_answer_chain)

return rag_chain
15 changes: 7 additions & 8 deletions python_script/parameters.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,7 +4,7 @@
from langchain.prompts import PromptTemplate

DATA_PATH = None
CHROMA_ROOT_PATH = None
DATABASE_ROOT_PATH = None
EMBEDDING_MODEL = None
LLM_MODEL = None
PROMPT_TEMPLATE = None
Expand All @@ -19,7 +19,7 @@ def load_api_keys():
load_dotenv()
os.environ["HF_API_TOKEN"] = os.getenv("HF_API_TOKEN")
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
os.environ["VOYAGE_API_KEY"] = os.getenv("VOYAGE_API_KEY")
#os.environ["VOYAGE_API_KEY"] = os.getenv("VOYAGE_API_KEY")


def load_config(config_name = 'default', show_config = False):
Expand All @@ -31,7 +31,7 @@ def load_config(config_name = 'default', show_config = False):
{
"config_name": {
"data_path": "", # Path to the data folder
"chroma_root_path": "", # Path to the folder where the Chroma database will be stored
"database_root_path": "", # Path to the folder where the database will be stored
"embedding_model": "", # Model to use for embeddings (e.g., 'sentence-transformers/all-mpnet-base-v2', 'openai', 'voyage-law-2')
"llm_model": "", # Model to use for the language model (e.g., 'gpt-3.5-turbo', 'mistralai/Mistral-7B-Instruct-v0.1', 'nvidia/Llama3-ChatQA-1.5-8B')
}
Expand All @@ -48,7 +48,7 @@ def load_config(config_name = 'default', show_config = False):
- "mistralai/Mixtral-8x7B-Instruct-v0.1"
- "nvidia/Llama3-ChatQA-1.5-8B"
"""
global DATA_PATH, CHROMA_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL
global DATA_PATH, DATABASE_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL
try:
with open('config.json', 'r') as file:
config = json.load(file)
Expand All @@ -60,11 +60,10 @@ def load_config(config_name = 'default', show_config = False):
raise FileNotFoundError("The configuration file cannot be found in the specified paths.")
except json.JSONDecodeError:
raise ValueError("The configuration file is present but contains a JSON format error.")

selected_config = config[config_name]

DATA_PATH = selected_config['data_path']
CHROMA_ROOT_PATH = selected_config['chroma_root_path']
DATABASE_ROOT_PATH = selected_config['database_root_path']
EMBEDDING_MODEL = selected_config['embedding_model']
LLM_MODEL = selected_config['llm_model']

Expand All @@ -79,11 +78,11 @@ def print_config():
Print the current configuration settings.
This function prints the values of the global configuration parameters.
"""
global DATA_PATH, CHROMA_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL
global DATA_PATH, DATABASE_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL

print("\nCurrent Configuration Settings:\n")
print(f"Data Path: {DATA_PATH}")
print(f"Chroma Root Path: {CHROMA_ROOT_PATH}")
print(f"Database Root Path: {DATABASE_ROOT_PATH}")
print(f"Embedding Model: {EMBEDDING_MODEL}")
print(f"Language Model: {LLM_MODEL}\n")

Expand Down
Loading