Skip to content

Commit

Permalink
add sharepoint retriever app. (#399)
Browse files Browse the repository at this point in the history
* add sharepoint retriever app.

Signed-off-by: Rahul Tripathi <rauhl.psit.ec@gmail.com>

* add requirements.txt

Signed-off-by: Rahul Tripathi <rauhl.psit.ec@gmail.com>

* add sample for non-SDK synchronous fetching of authorised identities.

Signed-off-by: Rahul Tripathi <rauhl.psit.ec@gmail.com>

* rename samples

Signed-off-by: Rahul Tripathi <rauhl.psit.ec@gmail.com>

* Update pebblo_identity_api_rag.py

---------

Signed-off-by: Rahul Tripathi <rauhl.psit.ec@gmail.com>
Co-authored-by: Rahul Tripathi <rauhl.psit.ec@gmail.com>
  • Loading branch information
rahul-trip and Rahul Tripathi authored Jun 21, 2024
1 parent 5f3a80c commit 910a403
Show file tree
Hide file tree
Showing 5 changed files with 450 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from typing import Optional
import os
import requests

from dotenv import load_dotenv
load_dotenv()

class SharepointADHelper:
    """Look up a user's authorized identities via the Microsoft Graph REST API.

    Credentials are taken from the constructor arguments, falling back to the
    ``O365_CLIENT_ID``, ``O365_CLIENT_SECRET`` and ``O365_TENANT_ID``
    environment variables. An access token is requested once, at construction
    time, and reused for every subsequent Graph call made by this instance.
    """

    def __init__(
        self,
        client_id: Optional[str] = None,
        client_secret: Optional[str] = None,
        tenant_id: Optional[str] = None,
    ):
        """
        Store the credentials and fetch an access token.
        Args:
            client_id (Optional[str]): App client ID; defaults to ``O365_CLIENT_ID``.
            client_secret (Optional[str]): App client secret; defaults to ``O365_CLIENT_SECRET``.
            tenant_id (Optional[str]): Tenant ID; defaults to ``O365_TENANT_ID``.
        Raises:
            EnvironmentError: If any credential is missing, or if no access
                token could be obtained with the given credentials.
        """
        self.client_id = client_id or os.environ.get("O365_CLIENT_ID")
        self.client_secret = client_secret or os.environ.get("O365_CLIENT_SECRET")
        self.tenant_id = tenant_id or os.environ.get("O365_TENANT_ID")
        if not all([self.client_id, self.client_secret, self.tenant_id]):
            raise EnvironmentError(
                "At-least one of O365_CLIENT_ID, O365_CLIENT_SECRET or O365_TENANT_ID not provided"
            )
        self.access_token = self.get_access_token()
        if not self.access_token:
            raise EnvironmentError(
                "o365 client id/secret or tenant id is invalid. "
                "Please check the environment variables."
            )
        # The scheme and the token must be separated by a single space,
        # otherwise the Authorization header is malformed.
        self.headers = {"Authorization": f"Bearer {self.access_token}"}

    def get_authorized_identities(self, user_id: str):
        """
        Retrieves the authorized identities for a given user.
        Args:
            user_id (str): The ID of the user.
        Returns:
            list: The mail addresses of the user's associated groups followed
                by the user ID. Only ``[user_id]`` when the user or the groups
                could not be retrieved.
        """
        user = self._get_users(user_id)
        user_index_id = user.get("id")
        if not user_index_id:
            print(f"Could not find the user `{user_id}` information in Microsoft Graph API. Not authorized.")
            return [user_id]
        associated_groups = self._get_associated_groups(user_index_id)
        # The group lookup returns {} on an HTTP error, so "value" may be absent.
        group_entries = associated_groups.get("value") or []
        associated_groups_emails = [
            group.get("mail") for group in group_entries if group.get("mail")
        ]
        return associated_groups_emails + [user_id]

    def _get_json(self, url: str, error_message: str):
        """
        Issues an authenticated GET request and decodes the JSON response.
        Args:
            url (str): The URL to request.
            error_message (str): Message printed when the response has an HTTP error status.
        Returns:
            dict: The decoded JSON body, or an empty dict on an HTTP error status.
        Note:
            Only ``requests.exceptions.HTTPError`` is handled here; any other
            exception raised by ``requests`` propagates to the caller.
        """
        try:
            response = requests.get(url=url, headers=self.headers, timeout=10)
            response.raise_for_status()
        except requests.exceptions.HTTPError:
            print(error_message)
            return {}
        return response.json()

    def _get_associated_groups(self, user_index_id: str):
        """
        Retrieves the associated groups for a given user.
        Args:
            user_index_id (str): The index ID of the user.
        Returns:
            dict: The ``memberOf`` response, or an empty dict on an HTTP error status.
        """
        return self._get_json(
            f"https://graph.microsoft.com/v1.0/users/{user_index_id}/memberOf",
            "Error while retrieving associated groups from Microsoft Graph API",
        )

    def _get_users(self, user_id: str):
        """
        Retrieves information about a specific user from the Microsoft Graph API.
        Args:
            user_id (str): The ID of the user to retrieve information for.
        Returns:
            dict: The user's information, or an empty dict on an HTTP error status.
        """
        return self._get_json(
            f"https://graph.microsoft.com/v1.0/users/{user_id}",
            "Error while retrieving user information from Microsoft Graph API",
        )

    def get_access_token(self):
        """
        Retrieves an access token using the client-credentials grant.
        Returns:
            str: The access token, or an empty string when the token endpoint
                answers with an HTTP error status.
        Note:
            Only ``requests.exceptions.HTTPError`` is handled here; any other
            exception raised by ``requests`` propagates to the caller.
        """
        # ToDo: This access token should be cached and refreshed when it expires
        # It should also be stored in home directory or in a secure location
        url = f"https://login.microsoftonline.com/{self.tenant_id}/oauth2/v2.0/token"
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        data = {
            "grant_type": "client_credentials",
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "scope": "https://graph.microsoft.com/.default",
        }
        try:
            response = requests.post(url, headers=headers, data=data, timeout=10)
            response.raise_for_status()  # Raise exception if the request failed
        except requests.exceptions.HTTPError:
            print("Error while retrieving access token from Microsoft Graph API")
            return ""
        return response.json()["access_token"]

if __name__ == "__main__":
    # Intentionally a no-op: running this module directly does nothing.
    pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from dotenv import load_dotenv
load_dotenv()

import asyncio
import os
from typing import Optional

from msgraph import GraphServiceClient
from azure.identity import ClientSecretCredential
from kiota_abstractions.api_error import APIError

async def get_authorized_identities(
    user_id: str,
    client_id: Optional[str] = None,
    client_secret: Optional[str] = None,
    tenant_id: Optional[str] = None,
):
    """Return the identities a user is authorized as, using the msgraph SDK.

    Args:
        user_id: The user ID to look up.
        client_id: App client ID; defaults to ``O365_CLIENT_ID``.
        client_secret: App client secret; defaults to ``O365_CLIENT_SECRET``.
        tenant_id: Tenant ID; defaults to ``O365_TENANT_ID``.

    Returns:
        The non-empty ``mail`` values of the entries returned by the user's
        ``member_of`` lookup, followed by ``user_id``. Only ``[user_id]`` when
        the lookup raises ``APIError`` or yields no entries.

    Raises:
        EnvironmentError: If any of the three credentials is missing.
    """
    client_id = client_id or os.environ.get("O365_CLIENT_ID")
    client_secret = client_secret or os.environ.get("O365_CLIENT_SECRET")
    tenant_id = tenant_id or os.environ.get("O365_TENANT_ID")
    if not all([client_id, client_secret, tenant_id]):
        raise EnvironmentError(
            "At-least one of O365_CLIENT_ID, O365_CLIENT_SECRET or O365_TENANT_ID not provided"
        )
    credentials = ClientSecretCredential(
        tenant_id,
        client_id,
        client_secret,
    )
    graph_client = GraphServiceClient(credentials)

    try:
        groups = await graph_client.users.by_user_id(user_id).member_of.get()
    except APIError:
        print(f"ms_graph API error: invalid user: {user_id}")
        return [user_id]

    # Guard against a missing response or a missing `value` collection so the
    # comprehension below never iterates over None.
    members = (groups.value if groups is not None else None) or []
    # Not every returned entry necessarily has a `mail` attribute; getattr
    # treats a missing attribute the same as an empty one.
    mails = (getattr(member, "mail", None) for member in members)
    return [mail for mail in mails if mail] + [user_id]

if __name__ == "__main__":
    # Demo: resolve the identities of a sample user and show the result.
    sample_identities = asyncio.run(
        get_authorized_identities("arpit@daxaai.onmicrosoft.com")
    )
    print(sample_identities)

Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Fill-in OPENAI_API_KEY in .env file in this directory before proceeding
from dotenv import load_dotenv
load_dotenv()

import asyncio
import os
from msgraph_api_auth import SharepointADHelper
from langchain_community.chains import PebbloRetrievalQA
from langchain_community.chains.pebblo_retrieval.models import (
AuthContext,
ChainInput,
)
from langchain_community.document_loaders import UnstructuredFileIOLoader
from langchain_community.document_loaders.pebblo import PebbloSafeLoader
from langchain_community.vectorstores.qdrant import Qdrant
from langchain_community.document_loaders import SharePointLoader
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_openai.llms import OpenAI




class PebbloIdentityRAG:
    """Identity-aware RAG sample built on PebbloSafeLoader and PebbloRetrievalQA.

    Construction does all the work up front: documents are loaded from a
    SharePoint document library, embedded into an in-memory Qdrant
    collection, and a retrieval chain is created on top of it.
    """

    def __init__(self, drive_id: str, folder_path: str, collection_name: str):
        """
        Load the documents, build the vector store and the retrieval chain.
        Args:
            drive_id (str): ID passed to SharePointLoader as ``document_library_id``.
            folder_path (str): Folder to load from; "/" is used when empty.
            collection_name (str): Name of the Qdrant collection to create.
        """
        self.loader_app_name = "pebblo-identity-loader"
        self.retrieval_app_name = "pebblo-identity-retriever"
        self.collection_name = collection_name
        self.drive_id = drive_id
        self.folder_path = folder_path

        # Load documents
        print("\nLoading RAG documents ...")
        self.loader = PebbloSafeLoader(
            SharePointLoader(
                document_library_id=self.drive_id,
                folder_path=self.folder_path or "/",
                auth_with_token=True,
                load_auth=True,
                recursive=True,
                load_extended_metadata=True,
            ),
            name=self.loader_app_name,  # App name (Mandatory)
            owner="Joe Smith",  # Owner (Optional)
            description="Identity enabled SafeLoader and SafeRetrival app using Pebblo",  # Description (Optional)
        )
        self.documents = self.loader.load()
        # Show the identities attached to the last document as a quick sanity
        # check; skipped when nothing was loaded so it cannot raise IndexError.
        if self.documents:
            print(self.documents[-1].metadata.get("authorized_identities"))
        print(f"Loaded {len(self.documents)} documents ...\n")

        # Load documents into VectorDB
        print("Hydrating Vector DB ...")
        self.vectordb = self.embeddings()
        print("Finished hydrating Vector DB ...\n")

        # Prepare LLM
        self.llm = OpenAI()
        print("Initializing PebbloRetrievalQA ...")
        self.retrieval_chain = self.init_retrieval_chain()

    def init_retrieval_chain(self):
        """
        Initialize PebbloRetrievalQA chain
        Returns:
            The chain built by ``PebbloRetrievalQA.from_chain_type`` over the
            vector store's retriever.
        """
        return PebbloRetrievalQA.from_chain_type(
            llm=self.llm,
            app_name=self.retrieval_app_name,
            owner="Joe Smith",
            description="Identity enabled filtering using PebbloSafeLoader, and PebbloRetrievalQA",
            chain_type="stuff",
            retriever=self.vectordb.as_retriever(),
            verbose=True,
        )

    def embeddings(self):
        """
        Embed the loaded documents into an in-memory Qdrant collection.
        Returns:
            The Qdrant vector store created from ``self.documents``.
        """
        embeddings = OpenAIEmbeddings()
        vectordb = Qdrant.from_documents(
            self.documents,
            embeddings,
            location=":memory:",
            collection_name=self.collection_name,
        )
        return vectordb

    def ask(self, question: str, user_email: str, auth_identifiers: list):
        """
        Run a question through the retrieval chain on behalf of a user.
        Args:
            question (str): The query text.
            user_email (str): Passed to AuthContext as ``user_id``.
            auth_identifiers (list): Passed to AuthContext as ``user_auth``.
        Returns:
            The result of invoking the retrieval chain.
        """
        auth_context = AuthContext(user_id=user_email, user_auth=auth_identifiers)
        chain_input = ChainInput(query=question, auth_context=auth_context)

        return self.retrieval_chain.invoke(chain_input.dict())


if __name__ == "__main__":
    # Imported here because only this interactive entry point needs it.
    from getpass import getpass

    input_collection_name = "identity-enabled-rag"

    # Every prompt falls back to an environment variable or a sample default
    # when the user just presses Enter.
    print("Please enter ingestion user details for loading data...")
    app_client_id = input("App client id : ") or os.environ.get("O365_CLIENT_ID")
    # Read the secret without echoing it to the terminal.
    app_client_secret = getpass("App client secret : ") or os.environ.get("O365_CLIENT_SECRET")
    tenant_id = input("Tenant id : ") or os.environ.get("O365_TENANT_ID")

    drive_id = input("Drive id : ") or "b!TVvGZhXfGUuSKMdCgOucz08XRKxsDuVCojWCjzBMN-as9t-EstljQKBl332OMJnI"

    rag_app = PebbloIdentityRAG(
        drive_id=drive_id, folder_path="/document", collection_name=input_collection_name
    )

    while True:
        print("Please enter end user details below")
        end_user_email_address = input("User email address : ") or "arpit@daxaai.onmicrosoft.com"
        prompt = input("Please provide the prompt : ") or "tell me about sample.pdf."
        print(f"User: {end_user_email_address}.\nQuery:{prompt}\n")
        # A new helper is built for every question, then asked for the end
        # user's identities.
        authorized_identities = SharepointADHelper(
            client_id=app_client_id,
            client_secret=app_client_secret,
            tenant_id=tenant_id,
        ).get_authorized_identities(end_user_email_address)
        response = rag_app.ask(prompt, end_user_email_address, authorized_identities)
        print(f"Response:\n{response}")
        try:
            keep_going = int(input("\n\nType 1 to continue and 0 to exit : "))
        except ValueError:
            # Non-numeric input: report it and start over with a new question.
            print("Please provide valid input")
            continue

        if not keep_going:
            # Leaving the loop ends the script normally; this avoids the
            # site-provided `exit` helper.
            break

        print("\n\n")

Loading

0 comments on commit 910a403

Please sign in to comment.