Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add sharepoint retriever app. #399

Merged
merged 5 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from typing import Optional
import os
import requests

from dotenv import load_dotenv
load_dotenv()

class SharepointADHelper:
def __init__(
self,
client_id: Optional[str] = None,
client_secret: Optional[str] = None,
tenant_id: Optional[str] = None
):
self.client_id = client_id or os.environ.get("O365_CLIENT_ID")
self.client_secret = client_secret or os.environ.get("O365_CLIENT_SECRET")
self.tenant_id = tenant_id or os.environ.get("O365_TENANT_ID")
if not all([self.client_id, self.client_secret, self.tenant_id]):
raise EnvironmentError(
"At-least one of O365_CLIENT_ID, O365_CLIENT_SECRET or O365_TENANT_ID not provided"
)
self.access_token = self.get_access_token()
if not self.access_token:
raise EnvironmentError("o365 client id/secret or tenant id is invalid."
"Please check the environment variables.")
self.headers = { 'Authorization': 'Bearer' + self.access_token }

def get_authorized_identities(self, user_id: str):
"""
Retrieves the authorized identities for a given user.

Args:
user_id (str): The ID of the user.

Returns:
list: A list of authorized identities, including associated group emails and the user ID.
"""
user = self._get_users(user_id)
user_index_id = user.get("id")
if not user_index_id:
print(f"Could not find the user `{user_id}` information in Microsoft Graph API. Not authorized.")
return [user_id]
associated_groups = self._get_associated_groups(user_index_id)
associated_groups_emails = [
group.get("mail") for group in associated_groups["value"] if group.get("mail")
]
return associated_groups_emails + [user_id]

def _get_associated_groups(self, user_index_id: str):
"""
Retrieves the associated groups for a given user.

Args:
user_index_id (str): The index ID of the user.

Returns:
dict: A dictionary containing the associated groups information.

Raises:
Exception: If there is an error while making the API request.
"""
url = f"https://graph.microsoft.com/v1.0/users/{user_index_id}/memberOf"
try:
response = requests.get(url=url, headers=self.headers, timeout=10)
response.raise_for_status()
except requests.exceptions.HTTPError:
print("Error while retrieving associated groups from Microsoft Graph API")
return {}
else:
return response.json()

def _get_users(self, user_id: str):
"""
Retrieves information about a specific user from the Microsoft Graph API.

Args:
user_id (str): The ID of the user to retrieve information for.

Returns:
dict: A dictionary containing the user's information.

Raises:
Exception: If there is an error while making the API request.
"""
url = f"https://graph.microsoft.com/v1.0/users/{user_id}"
try:
response = requests.get(url=url, headers=self.headers, timeout=10)
response.raise_for_status()
except requests.exceptions.HTTPError:
print("Error while retrieving user information from Microsoft Graph API")
return {}
else:
return response.json()


def get_access_token(self):
"""
Retrieves an access token from Microsoft Graph API using client credentials.
Returns:
str: The access token.
Raises:
requests.exceptions.HTTPError: If the request to retrieve the access token fails.
"""
# ToDo: This access token should be cached and refreshed when it expires
# It should also be stored in home directory or in a secure location
url = f"https://login.microsoftonline.com/{self.tenant_id}/oauth2/v2.0/token"
headers = {"Content-Type": "application/x-www-form-urlencoded"}
data = {
"grant_type": "client_credentials",
"client_id": self.client_id,
"client_secret": self.client_secret,
"scope": "https://graph.microsoft.com/.default"
}
try:
response = requests.post(url, headers=headers, data=data, timeout=10)
response.raise_for_status() # Raise exception if the request failed
except requests.exceptions.HTTPError:
print("Error while retrieving access token from Microsoft Graph API")
return ""
else:
return response.json()["access_token"]

if __name__ == "__main__":
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from dotenv import load_dotenv
load_dotenv()

import asyncio
import os
from typing import Optional

from msgraph import GraphServiceClient
from azure.identity import ClientSecretCredential
from kiota_abstractions.api_error import APIError

async def get_authorized_identities(
user_id: str,
client_id: Optional[str] = None,
client_secret: Optional[str] = None,
tenant_id: Optional[str] = None
):
client_id = client_id or os.environ.get("O365_CLIENT_ID")
client_secret = client_secret or os.environ.get("O365_CLIENT_SECRET")
tenant_id = tenant_id or os.environ.get("O365_TENANT_ID")
if not all([client_id, client_secret, tenant_id]):
raise EnvironmentError(
"atleast one of {O365_CLIENT_ID, O365_CLIENT_SECRET or O365_TENANT_ID not provided"
)
credentials = ClientSecretCredential(
tenant_id,
client_id,
client_secret,
)
graph_client = GraphServiceClient(credentials)

# user = graph_client.users.by_user_id(user_id)
try:
groups = await graph_client.users.by_user_id(user_id).member_of.get()
except APIError:
print(f"ms_graph API error: invalid user: {user_id}")
return [user_id]
auth_iden = [
group.__dict__.get("mail")
for group in groups.value
if group.__dict__.get("mail")
] + [user_id]
return auth_iden

if __name__ == "__main__":
print(asyncio.run(get_authorized_identities("arpit@daxaai.onmicrosoft.com")))

Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Fill-in OPENAI_API_KEY in .env file in this directory before proceeding
from dotenv import load_dotenv
load_dotenv()

import asyncio
import os
from msgraph_api_auth import SharepointADHelper
from langchain_community.chains import PebbloRetrievalQA
from langchain_community.chains.pebblo_retrieval.models import (
AuthContext,
ChainInput,
)
from langchain_community.document_loaders import UnstructuredFileIOLoader
from langchain_community.document_loaders.pebblo import PebbloSafeLoader
from langchain_community.vectorstores.qdrant import Qdrant
from langchain_community.document_loaders import SharePointLoader
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_openai.llms import OpenAI




class PebbloIdentityRAG:
def __init__(self, drive_id: str, folder_path: str, collection_name: str):
self.loader_app_name = "pebblo-identity-loader"
self.retrieval_app_name = "pebblo-identity-retriever"
self.collection_name = collection_name
self.drive_id = drive_id
self.folder_path = folder_path

# Load documents
print("\nLoading RAG documents ...")
self.loader = PebbloSafeLoader(
SharePointLoader(
document_library_id=self.drive_id,
folder_path=self.folder_path or "/",
auth_with_token=True,
load_auth=True,
recursive=True,
load_extended_metadata=True,
),
name=self.loader_app_name, # App name (Mandatory)
owner="Joe Smith", # Owner (Optional)
description="Identity enabled SafeLoader and SafeRetrival app using Pebblo", # Description (Optional)
)
self.documents = self.loader.load()
print(self.documents[-1].metadata.get("authorized_identities"))
print(f"Loaded {len(self.documents)} documents ...\n")

# Load documents into VectorDB

print("Hydrating Vector DB ...")
self.vectordb = self.embeddings()
print("Finished hydrating Vector DB ...\n")

# Prepare LLM
self.llm = OpenAI()
print("Initializing PebbloRetrievalQA ...")
self.retrieval_chain = self.init_retrieval_chain()

def init_retrieval_chain(self):
"""
Initialize PebbloRetrievalQA chain
"""
return PebbloRetrievalQA.from_chain_type(
llm=self.llm,
app_name=self.retrieval_app_name,
owner="Joe Smith",
description="Identity enabled filtering using PebbloSafeLoader, and PebbloRetrievalQA",
chain_type="stuff",
retriever=self.vectordb.as_retriever(),
verbose=True,
)

def embeddings(self):
embeddings = OpenAIEmbeddings()
vectordb = Qdrant.from_documents(
self.documents,
embeddings,
location=":memory:",
collection_name=self.collection_name,
)
return vectordb

def ask(self, question: str, user_email: str, auth_identifiers: list):
auth_context = {
"user_id": user_email,
"user_auth": auth_identifiers,
}
auth_context = AuthContext(**auth_context)
chain_input = ChainInput(query=question, auth_context=auth_context)

return self.retrieval_chain.invoke(chain_input.dict())


if __name__ == "__main__":
input_collection_name = "identity-enabled-rag"

print("Please enter ingestion user details for loading data...")
app_client_id = input("App client id : ") or os.environ.get("O365_CLIENT_ID")
app_client_secret = input("App client secret : ") or os.environ.get("O365_CLIENT_SECRET")
tenant_id = input("Tenant id : ") or os.environ.get("O365_TENANT_ID")

drive_id = input("Drive id : ") or "b!TVvGZhXfGUuSKMdCgOucz08XRKxsDuVCojWCjzBMN-as9t-EstljQKBl332OMJnI"

rag_app = PebbloIdentityRAG(
drive_id = drive_id, folder_path = "/document", collection_name=input_collection_name
)

while True:
print("Please enter end user details below")
end_user_email_address = input("User email address : ") or "arpit@daxaai.onmicrosoft.com"
prompt = input("Please provide the prompt : ") or "tell me about sample.pdf."
print(f"User: {end_user_email_address}.\nQuery:{prompt}\n")
authorized_identities = SharepointADHelper(
client_id = app_client_id,
client_secret = app_client_secret,
tenant_id = tenant_id,
).get_authorized_identities(end_user_email_address)
response = rag_app.ask(prompt, end_user_email_address, authorized_identities)
print(f"Response:\n{response}")
try:
continue_or_exist = int(input("\n\nType 1 to continue and 0 to exit : "))
except ValueError:
print("Please provide valid input")
continue

if not continue_or_exist:
exit(0)

print("\n\n")

Loading