-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add sharepoint retriever app. (#399)
* add sharepoint retriever app. Signed-off-by: Rahul Tripathi <rauhl.psit.ec@gmail.com> * add requirements.txt Signed-off-by: Rahul Tripathi <rauhl.psit.ec@gmail.com> * add sample for non-SDK synchronous fetching of authorised identities. Signed-off-by: Rahul Tripathi <rauhl.psit.ec@gmail.com> * rename samples Signed-off-by: Rahul Tripathi <rauhl.psit.ec@gmail.com> * Update pebblo_identity_api_rag.py --------- Signed-off-by: Rahul Tripathi <rauhl.psit.ec@gmail.com> Co-authored-by: Rahul Tripathi <rauhl.psit.ec@gmail.com>
- Loading branch information
1 parent
5f3a80c
commit 910a403
Showing
5 changed files
with
450 additions
and
0 deletions.
There are no files selected for viewing
124 changes: 124 additions & 0 deletions
124
pebblo_saferetriever/langchain/identity-rag/sharepoint/msgraph_api_auth.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
from typing import Optional | ||
import os | ||
import requests | ||
|
||
from dotenv import load_dotenv | ||
load_dotenv() | ||
|
||
class SharepointADHelper:
    """Resolve a user's authorized identities via the Microsoft Graph REST API.

    Credentials come from the constructor arguments, falling back to the
    ``O365_CLIENT_ID``, ``O365_CLIENT_SECRET`` and ``O365_TENANT_ID``
    environment variables. An access token is requested once, while the
    instance is being constructed, and reused for every Graph call made
    through ``self.headers``.
    """

    def __init__(
        self,
        client_id: Optional[str] = None,
        client_secret: Optional[str] = None,
        tenant_id: Optional[str] = None,
    ):
        """
        Collect the app credentials and fetch an access token.

        Args:
            client_id (str, optional): App client id; defaults to ``O365_CLIENT_ID``.
            client_secret (str, optional): App client secret; defaults to ``O365_CLIENT_SECRET``.
            tenant_id (str, optional): Tenant id; defaults to ``O365_TENANT_ID``.

        Raises:
            EnvironmentError: If any credential is missing, or if no access
                token could be obtained with the supplied credentials.
        """
        self.client_id = client_id or os.environ.get("O365_CLIENT_ID")
        self.client_secret = client_secret or os.environ.get("O365_CLIENT_SECRET")
        self.tenant_id = tenant_id or os.environ.get("O365_TENANT_ID")
        if not all([self.client_id, self.client_secret, self.tenant_id]):
            raise EnvironmentError(
                "At-least one of O365_CLIENT_ID, O365_CLIENT_SECRET or O365_TENANT_ID not provided"
            )
        self.access_token = self.get_access_token()
        if not self.access_token:
            raise EnvironmentError(
                "o365 client id/secret or tenant id is invalid. "
                "Please check the environment variables."
            )
        # The Bearer scheme is "Bearer <token>": the separating space is
        # mandatory, without it the header value is not a valid credential.
        self.headers = {"Authorization": f"Bearer {self.access_token}"}

    def get_authorized_identities(self, user_id: str):
        """
        Retrieves the authorized identities for a given user.

        Args:
            user_id (str): The ID of the user.

        Returns:
            list: The ``mail`` values of the user's associated groups followed
                by ``user_id``. Only ``[user_id]`` when the user or the group
                membership could not be retrieved.
        """
        user = self._get_users(user_id)
        user_index_id = user.get("id")
        if not user_index_id:
            print(f"Could not find the user `{user_id}` information in Microsoft Graph API. Not authorized.")
            return [user_id]
        associated_groups = self._get_associated_groups(user_index_id)
        # _get_associated_groups returns {} on an HTTP error, so "value" may
        # be absent; treat that as "no groups" instead of raising KeyError.
        group_entries = associated_groups.get("value") or []
        associated_groups_emails = [
            entry.get("mail") for entry in group_entries if entry.get("mail")
        ]
        return associated_groups_emails + [user_id]

    def _get_associated_groups(self, user_index_id: str):
        """
        Retrieves the associated groups for a given user.

        Args:
            user_index_id (str): The index ID of the user.

        Returns:
            dict: The decoded JSON response of the ``memberOf`` request, or an
                empty dict when the API answered with an HTTP error status.

        Raises:
            requests.exceptions.RequestException: Non-HTTP-status failures
                (only ``HTTPError`` is handled here) propagate to the caller.
        """
        url = f"https://graph.microsoft.com/v1.0/users/{user_index_id}/memberOf"
        try:
            response = requests.get(url=url, headers=self.headers, timeout=10)
            response.raise_for_status()
        except requests.exceptions.HTTPError:
            print("Error while retrieving associated groups from Microsoft Graph API")
            return {}
        else:
            return response.json()

    def _get_users(self, user_id: str):
        """
        Retrieves information about a specific user from the Microsoft Graph API.

        Args:
            user_id (str): The ID of the user to retrieve information for.

        Returns:
            dict: The decoded JSON response describing the user, or an empty
                dict when the API answered with an HTTP error status.

        Raises:
            requests.exceptions.RequestException: Non-HTTP-status failures
                (only ``HTTPError`` is handled here) propagate to the caller.
        """
        url = f"https://graph.microsoft.com/v1.0/users/{user_id}"
        try:
            response = requests.get(url=url, headers=self.headers, timeout=10)
            response.raise_for_status()
        except requests.exceptions.HTTPError:
            print("Error while retrieving user information from Microsoft Graph API")
            return {}
        else:
            return response.json()

    def get_access_token(self):
        """
        Retrieves an access token from Microsoft Graph API using client credentials.

        Returns:
            str: The access token, or an empty string when the token endpoint
                answered with an HTTP error status.

        Raises:
            requests.exceptions.RequestException: Non-HTTP-status failures
                (only ``HTTPError`` is handled here) propagate to the caller.
        """
        # ToDo: This access token should be cached and refreshed when it expires
        # It should also be stored in home directory or in a secure location
        url = f"https://login.microsoftonline.com/{self.tenant_id}/oauth2/v2.0/token"
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        data = {
            "grant_type": "client_credentials",
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "scope": "https://graph.microsoft.com/.default",
        }
        try:
            response = requests.post(url, headers=headers, data=data, timeout=10)
            response.raise_for_status()  # Raise exception if the request failed
        except requests.exceptions.HTTPError:
            print("Error while retrieving access token from Microsoft Graph API")
            return ""
        else:
            return response.json()["access_token"]
|
||
if __name__ == "__main__":
    # Nothing is executed when this file is run as a script.
    pass
47 changes: 47 additions & 0 deletions
47
pebblo_saferetriever/langchain/identity-rag/sharepoint/msgraph_sdk_auth.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from dotenv import load_dotenv | ||
load_dotenv() | ||
|
||
import asyncio | ||
import os | ||
from typing import Optional | ||
|
||
from msgraph import GraphServiceClient | ||
from azure.identity import ClientSecretCredential | ||
from kiota_abstractions.api_error import APIError | ||
|
||
async def get_authorized_identities(
    user_id: str,
    client_id: Optional[str] = None,
    client_secret: Optional[str] = None,
    tenant_id: Optional[str] = None,
):
    """
    Retrieve the authorized identities of a user through the Microsoft Graph SDK.

    Args:
        user_id (str): The ID of the user whose memberships are looked up.
        client_id (str, optional): App client id; defaults to ``O365_CLIENT_ID``.
        client_secret (str, optional): App client secret; defaults to ``O365_CLIENT_SECRET``.
        tenant_id (str, optional): Tenant id; defaults to ``O365_TENANT_ID``.

    Returns:
        list: The ``mail`` values of the returned memberships followed by
            ``user_id``. Only ``[user_id]`` when the Graph call raises
            ``APIError`` or yields no memberships.

    Raises:
        EnvironmentError: If any of the three credentials is missing.
    """
    client_id = client_id or os.environ.get("O365_CLIENT_ID")
    client_secret = client_secret or os.environ.get("O365_CLIENT_SECRET")
    tenant_id = tenant_id or os.environ.get("O365_TENANT_ID")
    if not all([client_id, client_secret, tenant_id]):
        raise EnvironmentError(
            "At-least one of O365_CLIENT_ID, O365_CLIENT_SECRET or O365_TENANT_ID not provided"
        )
    credentials = ClientSecretCredential(
        tenant_id,
        client_id,
        client_secret,
    )
    graph_client = GraphServiceClient(credentials)

    try:
        groups = await graph_client.users.by_user_id(user_id).member_of.get()
    except APIError:
        print(f"ms_graph API error: invalid user: {user_id}")
        return [user_id]

    # Guard against a missing response or a missing ``value`` collection so
    # that case degrades to "no groups" rather than an AttributeError/TypeError.
    memberships = getattr(groups, "value", None) or []
    # Not every membership entry necessarily carries a mail address; read it
    # defensively and keep only the non-empty ones.
    group_mails = [
        mail
        for mail in (getattr(member, "mail", None) for member in memberships)
        if mail
    ]
    return group_mails + [user_id]
|
||
if __name__ == "__main__":
    # Sample run: resolve and print the identities of one hard-coded user.
    sample_identities = asyncio.run(
        get_authorized_identities("arpit@daxaai.onmicrosoft.com")
    )
    print(sample_identities)
|
132 changes: 132 additions & 0 deletions
132
pebblo_saferetriever/langchain/identity-rag/sharepoint/pebblo_identity_api_rag.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
# Fill-in OPENAI_API_KEY in .env file in this directory before proceeding | ||
from dotenv import load_dotenv | ||
load_dotenv() | ||
|
||
import asyncio | ||
import os | ||
from msgraph_api_auth import SharepointADHelper | ||
from langchain_community.chains import PebbloRetrievalQA | ||
from langchain_community.chains.pebblo_retrieval.models import ( | ||
AuthContext, | ||
ChainInput, | ||
) | ||
from langchain_community.document_loaders import UnstructuredFileIOLoader | ||
from langchain_community.document_loaders.pebblo import PebbloSafeLoader | ||
from langchain_community.vectorstores.qdrant import Qdrant | ||
from langchain_community.document_loaders import SharePointLoader | ||
from langchain_openai.embeddings import OpenAIEmbeddings | ||
from langchain_openai.llms import OpenAI | ||
|
||
|
||
|
||
|
||
class PebbloIdentityRAG:
    """Sample identity-aware RAG app over a SharePoint document library.

    Construction does all the heavy lifting: documents are loaded through
    ``PebbloSafeLoader`` wrapping a ``SharePointLoader``, embedded into an
    in-memory Qdrant collection, and a ``PebbloRetrievalQA`` chain is built on
    top of that store. ``ask`` then runs a question with an auth context.
    """

    def __init__(self, drive_id: str, folder_path: str, collection_name: str):
        """
        Load the documents, hydrate the vector DB and build the retrieval chain.

        Args:
            drive_id (str): Passed to the loader as ``document_library_id``.
            folder_path (str): Folder to load from; ``"/"`` is used when empty.
            collection_name (str): Name of the Qdrant collection to create.

        Raises:
            ValueError: If the loader returned no documents.
        """
        self.loader_app_name = "pebblo-identity-loader"
        self.retrieval_app_name = "pebblo-identity-retriever"
        self.collection_name = collection_name
        self.drive_id = drive_id
        self.folder_path = folder_path

        # Load documents
        print("\nLoading RAG documents ...")
        self.loader = PebbloSafeLoader(
            SharePointLoader(
                document_library_id=self.drive_id,
                folder_path=self.folder_path or "/",
                auth_with_token=True,
                load_auth=True,
                recursive=True,
                load_extended_metadata=True,
            ),
            name=self.loader_app_name,  # App name (Mandatory)
            owner="Joe Smith",  # Owner (Optional)
            description="Identity enabled SafeLoader and SafeRetrival app using Pebblo",  # Description (Optional)
        )
        self.documents = self.loader.load()
        if not self.documents:
            # Fail with an actionable message instead of the bare IndexError
            # that the documents[-1] lookup below would raise on an empty list.
            raise ValueError(
                f"No documents were loaded from drive `{self.drive_id}` "
                f"at folder `{self.folder_path or '/'}`."
            )
        print(self.documents[-1].metadata.get("authorized_identities"))
        print(f"Loaded {len(self.documents)} documents ...\n")

        # Load documents into VectorDB
        print("Hydrating Vector DB ...")
        self.vectordb = self.embeddings()
        print("Finished hydrating Vector DB ...\n")

        # Prepare LLM
        self.llm = OpenAI()
        print("Initializing PebbloRetrievalQA ...")
        self.retrieval_chain = self.init_retrieval_chain()

    def init_retrieval_chain(self):
        """
        Initialize PebbloRetrievalQA chain

        Returns:
            The chain built by ``PebbloRetrievalQA.from_chain_type`` over
            ``self.llm`` and a retriever obtained from ``self.vectordb``.
        """
        return PebbloRetrievalQA.from_chain_type(
            llm=self.llm,
            app_name=self.retrieval_app_name,
            owner="Joe Smith",
            description="Identity enabled filtering using PebbloSafeLoader, and PebbloRetrievalQA",
            chain_type="stuff",
            retriever=self.vectordb.as_retriever(),
            verbose=True,
        )

    def embeddings(self):
        """
        Embed ``self.documents`` into an in-memory Qdrant collection.

        Returns:
            The vector store created by ``Qdrant.from_documents``.
        """
        embeddings = OpenAIEmbeddings()
        vectordb = Qdrant.from_documents(
            self.documents,
            embeddings,
            location=":memory:",
            collection_name=self.collection_name,
        )
        return vectordb

    def ask(self, question: str, user_email: str, auth_identifiers: list):
        """
        Run a question through the retrieval chain on behalf of a user.

        Args:
            question (str): The query text.
            user_email (str): Sent as ``user_id`` of the auth context.
            auth_identifiers (list): Sent as ``user_auth`` of the auth context.

        Returns:
            Whatever ``self.retrieval_chain.invoke`` returns for the input.
        """
        auth_context = {
            "user_id": user_email,
            "user_auth": auth_identifiers,
        }
        auth_context = AuthContext(**auth_context)
        chain_input = ChainInput(query=question, auth_context=auth_context)

        return self.retrieval_chain.invoke(chain_input.dict())
|
||
|
||
if __name__ == "__main__":
    # Interactive sample: load the SharePoint documents once, then answer
    # prompts on behalf of end users until the operator chooses to stop.
    input_collection_name = "identity-enabled-rag"

    print("Please enter ingestion user details for loading data...")
    # An empty answer falls back to the corresponding environment variable.
    app_client_id = input("App client id : ") or os.environ.get("O365_CLIENT_ID")
    app_client_secret = input("App client secret : ") or os.environ.get("O365_CLIENT_SECRET")
    tenant_id = input("Tenant id : ") or os.environ.get("O365_TENANT_ID")

    # An empty answer falls back to the sample drive id below.
    drive_id = input("Drive id : ") or "b!TVvGZhXfGUuSKMdCgOucz08XRKxsDuVCojWCjzBMN-as9t-EstljQKBl332OMJnI"

    rag_app = PebbloIdentityRAG(
        drive_id=drive_id, folder_path="/document", collection_name=input_collection_name
    )

    while True:
        print("Please enter end user details below")
        end_user_email_address = input("User email address : ") or "arpit@daxaai.onmicrosoft.com"
        prompt = input("Please provide the prompt : ") or "tell me about sample.pdf."
        print(f"User: {end_user_email_address}.\nQuery:{prompt}\n")
        # Resolve the identities (group mails + the user) used for filtering.
        authorized_identities = SharepointADHelper(
            client_id=app_client_id,
            client_secret=app_client_secret,
            tenant_id=tenant_id,
        ).get_authorized_identities(end_user_email_address)
        response = rag_app.ask(prompt, end_user_email_address, authorized_identities)
        print(f"Response:\n{response}")
        try:
            continue_or_exist = int(input("\n\nType 1 to continue and 0 to exit : "))
        except ValueError:
            # Non-numeric input starts another round rather than exiting.
            print("Please provide valid input")
            continue

        if not continue_or_exist:
            # Raise SystemExit directly: the exit() helper is provided by the
            # site module for the interactive shell and is not meant for programs.
            raise SystemExit(0)

        print("\n\n")
|
Oops, something went wrong.