-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Separate data loading and retrieval (#286)
* Separate data loading and retrieval
- Loading branch information
Showing
4 changed files
with
145 additions
and
213 deletions.
There are no files selected for viewing
78 changes: 78 additions & 0 deletions
78
pebblo_safeloader/langchain/identity-rag/pebblo_identity_safeload.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# data_loader.py | ||
from dotenv import load_dotenv | ||
|
||
from typing import List | ||
|
||
from langchain.schema import Document | ||
from langchain_community.document_loaders import ( | ||
GoogleDriveLoader, | ||
UnstructuredFileIOLoader, | ||
) | ||
from langchain_community.document_loaders.pebblo import PebbloSafeLoader | ||
from langchain_community.vectorstores.qdrant import Qdrant | ||
from langchain_openai.embeddings import OpenAIEmbeddings | ||
|
||
load_dotenv() | ||
|
||
# Qdrant DB path | ||
QDRANT_PATH = "qdrant.db" | ||
# Qdrant DB collection name | ||
COLLECTION_NAME = "identity-enabled-rag" | ||
|
||
|
||
class QdrantDataLoader: | ||
def __init__(self, folder_id: str, collection_name: str = COLLECTION_NAME): | ||
self.app_name = "acme-corp-rag-1" | ||
self.folder_id = folder_id | ||
self.qdrant_collection_name = collection_name | ||
|
||
def load_documents(self): | ||
print("\nLoading RAG documents ...") | ||
loader = PebbloSafeLoader( | ||
GoogleDriveLoader( | ||
folder_id=self.folder_id, | ||
token_path="./google_token.json", | ||
recursive=True, | ||
file_loader_cls=UnstructuredFileIOLoader, | ||
file_loader_kwargs={"mode": "elements"}, | ||
load_auth=True, | ||
), | ||
name=self.app_name, # App name (Mandatory) | ||
owner="Joe Smith", # Owner (Optional) | ||
description="Identity enabled SafeLoader and SafeRetrival app using Pebblo", # Description (Optional) | ||
) | ||
documents = loader.load() | ||
for doc in documents: | ||
print(f"{doc.metadata}") | ||
|
||
# print(documents[-1].metadata.get("authorized_identities")) | ||
print(f"Loaded {len(documents)} documents ...\n") | ||
return documents | ||
|
||
def add_docs_to_qdrant(self, documents: List[Document]): | ||
""" | ||
Load documents into Qdrant | ||
""" | ||
print("\nAdding documents to Qdrant ...") | ||
embeddings = OpenAIEmbeddings() | ||
vectordb = Qdrant.from_documents( | ||
documents, | ||
embeddings, | ||
path=QDRANT_PATH, | ||
collection_name=self.qdrant_collection_name, | ||
) | ||
print(f"Added {len(documents)} documents to Qdrant ...\n") | ||
return vectordb | ||
|
||
|
||
if __name__ == "__main__": | ||
print("Loading documents to Qdrant ...") | ||
# def_folder_id = "1FQ-LrarHhWBJRGHc8yiH2ZtirpUXERYP" | ||
def_folder_id = "15CyFIWOPJOR5BxDID7G6tUisfHU1szrg" | ||
collection_name = "identity-enabled-rag" | ||
|
||
qloader = QdrantDataLoader(def_folder_id, collection_name) | ||
|
||
documents = qloader.load_documents() | ||
|
||
vectordb = qloader.add_docs_to_qdrant(documents) |
94 changes: 0 additions & 94 deletions
94
pebblo_saferetriever/langchain/identity-rag/pebblo_identity_rag-pinecone.py
This file was deleted.
Oops, something went wrong.
127 changes: 67 additions & 60 deletions
127
pebblo_saferetriever/langchain/identity-rag/pebblo_identity_rag-qdrant.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,91 +1,98 @@ | ||
from typing import List | ||
|
||
# Fill-in OPENAI_API_KEY in .env file | ||
# in this directory before proceeding | ||
from dotenv import load_dotenv | ||
from langchain.chains import PebbloRetrievalQA | ||
from langchain.schema import Document | ||
from langchain_community.document_loaders import ( | ||
GoogleDriveLoader, | ||
UnstructuredFileIOLoader, | ||
) | ||
from langchain_community.document_loaders.pebblo import PebbloSafeLoader | ||
from langchain_community.vectorstores.qdrant import Qdrant | ||
from langchain_openai.embeddings import OpenAIEmbeddings | ||
from langchain_openai import OpenAIEmbeddings | ||
from langchain_openai.llms import OpenAI | ||
from qdrant_client import QdrantClient | ||
|
||
from google_auth import get_authorized_identities | ||
|
||
load_dotenv() | ||
|
||
# Qdrant DB path | ||
QDRANT_PATH = "qdrant.db" | ||
# Qdrant DB collection name | ||
DEFAULT_COLLECTION_NAME = "identity-enabled-rag" | ||
|
||
class PebbloIdentityRAG: | ||
def __init__(self, folder_id: str, collection_name: str): | ||
self.app_name = "pebblo-identity-rag-1" | ||
self.collection_name = collection_name | ||
|
||
# Load documents | ||
print("Loading RAG documents ...") | ||
self.loader = PebbloSafeLoader( | ||
GoogleDriveLoader( | ||
folder_id=folder_id, | ||
token_path="./google_token.json", | ||
recursive=True, | ||
file_loader_cls=UnstructuredFileIOLoader, | ||
file_loader_kwargs={"mode": "elements"}, | ||
load_auth=True, | ||
), | ||
name=self.app_name, # App name (Mandatory) | ||
owner="Joe Smith", # Owner (Optional) | ||
description="Identity enabled SafeLoader and SafeRetrival app using Pebblo", # Description (Optional) | ||
) | ||
self.documents = self.loader.load() | ||
|
||
print(self.documents[-1].metadata.get("authorized_identities")) | ||
print(f"Loaded {len(self.documents)} documents ...\n") | ||
# Load documents into VectorDB | ||
print("Hydrating Vector DB ...") | ||
self.vectordb = self.embeddings(self.documents) | ||
print("Finished hydrating Vector DB ...\n") | ||
|
||
# Prepare LLM | ||
class PebbloIdentityRAG: | ||
def __init__(self, collection_name: str = DEFAULT_COLLECTION_NAME): | ||
self.app_name = "acme-corp-rag-1" | ||
self.qdrant_collection_name = collection_name | ||
self.llm = OpenAI() | ||
self.embeddings = OpenAIEmbeddings() | ||
self.vectordb = self.init_vector_db() | ||
|
||
def embeddings(self, docs: List[Document]): | ||
embeddings = OpenAIEmbeddings() | ||
vectordb = Qdrant.from_documents( | ||
docs, | ||
embeddings, | ||
location=":memory:", | ||
collection_name=self.collection_name, | ||
def init_vector_db(self): | ||
""" | ||
Load Vector DB from file | ||
""" | ||
client = QdrantClient( | ||
path=QDRANT_PATH, | ||
) | ||
vectordb = Qdrant( | ||
client=client, | ||
collection_name=self.qdrant_collection_name, | ||
embeddings=self.embeddings, | ||
) | ||
return vectordb | ||
|
||
def ask(self, question: str, auth: dict): | ||
def ask(self, question: str, auth_context: dict): | ||
# Prepare retriever QA chain | ||
retriever = PebbloRetrievalQA.from_chain_type( | ||
llm=self.llm, | ||
chain_type="stuff", | ||
retriever=self.vectordb.as_retriever(), | ||
verbose=True, | ||
auth_context=auth, | ||
auth_context=auth_context, | ||
) | ||
return retriever.invoke(question) | ||
|
||
|
||
if __name__ == "__main__": | ||
# TODO: pass the actual GoogleDrive folder id | ||
folder_id = "1sRvP0j6L6M_Ll0y_8Qp7cFWUOlpdbfN5" | ||
collection_name = "identity-enabled-rag" | ||
rag_app = PebbloIdentityRAG(folder_id, collection_name) | ||
def_service_acc_path = "credentials/service-account.json" | ||
def_ingestion_user_email_address = "admin@clouddefense.io" | ||
input_collection_name = "identity-enabled-rag" | ||
|
||
rag_app = PebbloIdentityRAG(input_collection_name) | ||
|
||
print("Please enter ingestion user details for loading data...") | ||
ingestion_user_email_address = ( | ||
input(f"email address ({def_ingestion_user_email_address}) : ") | ||
or def_ingestion_user_email_address | ||
) | ||
ingestion_user_service_account_path = ( | ||
input(f"service-account.json path ({def_service_acc_path}) : ") | ||
or def_service_acc_path | ||
) | ||
|
||
def_end_user = "demo-user-hr@daxa.ai" | ||
|
||
while True: | ||
print("Please enter end user details below") | ||
end_user_email_address = ( | ||
input(f"User email address ({def_end_user}): ") or def_end_user | ||
) | ||
prompt = input("Please provide the prompt : ") | ||
print(f"User: {end_user_email_address}.\nQuery:{prompt}\n") | ||
|
||
prompt = "What criteria are used to evaluate employee performance during performance reviews?" | ||
print(f"Query:\n{prompt}") | ||
auth = { | ||
"authorized_identities": get_authorized_identities( | ||
admin_user_email_address=ingestion_user_email_address, | ||
credentials_file_path=ingestion_user_service_account_path, | ||
user_email=end_user_email_address, | ||
) | ||
} | ||
response = rag_app.ask(prompt, auth) | ||
print(f"Response:\n{response}") | ||
try: | ||
continue_or_exist = int(input("\n\nType 1 to continue and 0 to exit : ")) | ||
except ValueError: | ||
print("Please provide valid input") | ||
continue | ||
|
||
user_1 = "user@clouddefense.io" | ||
auth = { | ||
"authorized_identities": get_authorized_identities(user_1), | ||
} | ||
if not continue_or_exist: | ||
exit(0) | ||
|
||
response = rag_app.ask(prompt, auth) | ||
print(f"Response:\n{response}") | ||
print("\n\n") |
59 changes: 0 additions & 59 deletions
59
pebblo_saferetriever/langchain/identity-rag/pinecone_data_loader.py
This file was deleted.
Oops, something went wrong.