Commit

enable querying all documents

davidmwita committed Aug 23, 2024
1 parent a1c16ae commit 167a2ce
Showing 3 changed files with 163 additions and 56 deletions.
97 changes: 79 additions & 18 deletions backend/src/delete_document/main.py
@@ -10,6 +10,7 @@
EMBEDDING_MODEL_ID = os.environ["EMBEDDING_MODEL_ID"]
REGION = os.environ["REGION"]
DATABASE_SECRET_NAME = os.environ["DATABASE_SECRET_NAME"]
ALL_DOCUMENTS = "ALL_DOCUMENTS"

ddb = boto3.resource("dynamodb")
document_table = ddb.Table(DOCUMENT_TABLE)
@@ -28,6 +29,78 @@ def get_db_secret():
secret = json.loads(response)
return secret

def get_embeddings():
bedrock_runtime = boto3.client(
service_name="bedrock-runtime",
region_name=REGION,
)
embeddings = BedrockEmbeddings(
model_id=EMBEDDING_MODEL_ID,
client=bedrock_runtime,
region_name=REGION,
)
return embeddings

def delete_collection(collection_name, connection_str, embeddings):
pgvector = PGVector(
collection_name=collection_name,
connection=connection_str,
embeddings=embeddings,
)
pgvector.delete_collection()

def delete_from_user_collection(user_id, document, connection_str, embeddings):
response = document_table.get_item(
Key={"userid": user_id, "documentid": ALL_DOCUMENTS}
)
if "Item" in response:
documents_all = response["Item"]
# Remove this document's split ids from the aggregate item's document_split_ids
documents_all["document_split_ids"] = [
id for id in documents_all["document_split_ids"] if id not in document["document_split_ids"]
]

collection_name = f"{user_id}_{ALL_DOCUMENTS}"
pgvector = PGVector(
collection_name=collection_name,
connection=connection_str,
embeddings=embeddings,
)

if not documents_all["document_split_ids"]:
# Delete documents_all and related conversations since document_split_ids is empty
with memory_table.batch_writer() as batch:
for item in documents_all["conversations"]:
batch.delete_item(Key={"SessionId": item["conversationid"]})

document_table.delete_item(
Key={"userid": user_id, "documentid": ALL_DOCUMENTS}
)

# Delete user collection since empty
pgvector.delete_collection()

else:
# Adjust filesize and pages
documents_all["pages"] = str(int(documents_all["pages"]) - int(document["pages"]))
documents_all["filesize"] = str(int(documents_all["filesize"]) - int(document["filesize"]))

# Update the ALL_DOCUMENTS item in the table
document_table.update_item(
Key={"userid": user_id, "documentid": ALL_DOCUMENTS},
UpdateExpression="SET document_split_ids = :ids, pages = :pages, filesize = :filesize",
ExpressionAttributeValues={
":ids": documents_all["document_split_ids"],
":pages": documents_all["pages"],
":filesize": documents_all["filesize"]
},
)

# Remove embeddings from user collection
pgvector.delete(document["document_split_ids"])



@logger.inject_lambda_context(log_event=True)
def lambda_handler(event, context):
user_id = event["requestContext"]["authorizer"]["claims"]["sub"]
@@ -61,27 +134,15 @@ def lambda_handler(event, context):


logger.info("Deleting from vector store")

bedrock_runtime = boto3.client(
service_name="bedrock-runtime",
region_name=REGION,
)
embeddings = BedrockEmbeddings(
model_id=EMBEDDING_MODEL_ID,
client=bedrock_runtime,
region_name=REGION,
)

collection_name = f"{user_id}_{filename}"
embeddings = get_embeddings()
db_secret = get_db_secret()
connection_str = f"postgresql+psycopg2://{db_secret['username']}:{db_secret['password']}@{db_secret['host']}:5432/{db_secret['dbname']}?sslmode=require"

pgvector = PGVector(
collection_name=collection_name,
connection=connection_str,
embeddings=embeddings,
)
pgvector.delete_collection()
user_file_collection_name = f"{user_id}_{filename}"
delete_collection(user_file_collection_name, connection_str, embeddings)

if document_id != ALL_DOCUMENTS:
delete_from_user_collection(user_id, document, connection_str, embeddings)

logger.info("Deletion complete")

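For orientation, delete_from_user_collection assumes the per-user aggregate item in the document table carries running totals plus a flat list of chunk ids. A hypothetical item of that shape (the field names come from this diff; the values are illustrative only):

documents_all = {
    "userid": "user-123",                 # illustrative value
    "documentid": "ALL_DOCUMENTS",
    "filename": "ALL_DOCUMENTS",
    "created": "2024-08-23T00:00:00.000000Z",
    "pages": "42",                        # stored as strings; adjusted with int() arithmetic
    "filesize": "1048576",
    "docstatus": "READY",
    "conversations": [{"conversationid": "abc123", "created": "2024-08-23T00:00:00.000000Z"}],
    "document_split_ids": ["user-123_report.pdf_0", "user-123_report.pdf_1"],
}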
53 changes: 38 additions & 15 deletions backend/src/generate_embeddings/main.py
@@ -13,18 +13,37 @@
EMBEDDING_MODEL_ID = os.environ["EMBEDDING_MODEL_ID"]
REGION = os.environ["REGION"]
DATABASE_SECRET_NAME = os.environ["DATABASE_SECRET_NAME"]
ALL_DOCUMENTS = "ALL_DOCUMENTS"
PROCESSING = "PROCESSING"
READY = "READY"

s3 = boto3.client("s3")
ddb = boto3.resource("dynamodb")
document_table = ddb.Table(DOCUMENT_TABLE)
logger = Logger()


def set_doc_status(user_id, document_id, status):
def set_doc_status(user_id, document_id, status, ids=None):
if ids:
UpdateExpression="""
SET docstatus = :docstatus,
document_split_ids = list_append(if_not_exists(document_split_ids, :empty_list), :ids)
"""
ExpressionAttributeValues={
":docstatus": status,
":ids": ids,
":empty_list": []
}
else:
UpdateExpression="SET docstatus = :docstatus"
ExpressionAttributeValues={
":docstatus": status
}

document_table.update_item(
Key={"userid": user_id, "documentid": document_id},
UpdateExpression="SET docstatus = :docstatus",
ExpressionAttributeValues={":docstatus": status},
UpdateExpression=UpdateExpression,
ExpressionAttributeValues=ExpressionAttributeValues,
)

def get_db_secret():
@@ -47,7 +66,8 @@ def lambda_handler(event, context):
key = event_body["key"]
file_name_full = key.split("/")[-1]

set_doc_status(user_id, document_id, "PROCESSING")
set_doc_status(user_id, document_id, PROCESSING)
set_doc_status(user_id, ALL_DOCUMENTS, PROCESSING)

s3.download_file(BUCKET, key, f"/tmp/{file_name_full}")

@@ -67,17 +87,20 @@
region_name=REGION,
)

collection_name = f"{user_id}_{file_name_full}"
db_secret = get_db_secret()
connection_str = f"postgresql+psycopg2://{db_secret['username']}:{db_secret['password']}@{db_secret['host']}:5432/{db_secret['dbname']}?sslmode=require"

vector_store = PGVector(
embeddings=embeddings,
collection_name=collection_name,
connection=connection_str,
use_jsonb=True,
)

vector_store.add_documents(split_document)

set_doc_status(user_id, document_id, "READY")
collection_names = [f"{user_id}_{file_name_full}", f"{user_id}_{ALL_DOCUMENTS}"]
ids = [f"{user_id}_{file_name_full}_{i}" for i in range(len(split_document))]
for collection_name in collection_names:
vector_store = PGVector(
embeddings=embeddings,
collection_name=collection_name,
connection=connection_str,
use_jsonb=True,
)

vector_store.add_documents(split_document, ids=ids)

set_doc_status(user_id, document_id, READY, ids)
set_doc_status(user_id, ALL_DOCUMENTS, READY, ids)
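The deterministic chunk ids above are what keep the shared collection maintainable: delete_document can later remove one file's chunks from {user_id}_ALL_DOCUMENTS by id instead of dropping the whole collection. A minimal sketch of that interaction, assuming langchain_postgres' PGVector and the same connection string construction used above (chunk_count is illustrative):

# Same id scheme as generate_embeddings, so the ids recorded on the file's
# DynamoDB item are all that is needed for a targeted delete.
ids_to_remove = [f"{user_id}_{file_name_full}_{i}" for i in range(chunk_count)]

all_docs_store = PGVector(
    embeddings=embeddings,
    collection_name=f"{user_id}_{ALL_DOCUMENTS}",
    connection=connection_str,
    use_jsonb=True,
)
all_docs_store.delete(ids=ids_to_remove)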
69 changes: 46 additions & 23 deletions backend/src/upload_trigger/main.py
@@ -10,6 +10,8 @@
MEMORY_TABLE = os.environ["MEMORY_TABLE"]
QUEUE = os.environ["QUEUE"]
BUCKET = os.environ["BUCKET"]
UPLOADED = "UPLOADED"
ALL_DOCUMENTS = "ALL_DOCUMENTS"


ddb = boto3.resource("dynamodb")
@@ -20,47 +22,68 @@
logger = Logger()


@logger.inject_lambda_context(log_event=True)
def lambda_handler(event, context):
key = urllib.parse.unquote_plus(event["Records"][0]["s3"]["object"]["key"])
split = key.split("/")
user_id = split[0]
file_name = split[1]
def create_document_and_conversation(user_id, filename, pages, filesize):
timestamp = datetime.utcnow()
timestamp_str = timestamp.strftime("%Y-%m-%dT%H:%M:%S.%fZ")

document_id = shortuuid.uuid()

s3.download_file(BUCKET, key, f"/tmp/{file_name}")

with open(f"/tmp/{file_name}", "rb") as f:
reader = PyPDF2.PdfReader(f)
pages = str(len(reader.pages))

conversation_id = shortuuid.uuid()

timestamp = datetime.utcnow()
timestamp_str = timestamp.strftime("%Y-%m-%dT%H:%M:%S.%fZ")

document = {
"userid": user_id,
"documentid": document_id,
"filename": file_name,
"documentid": ALL_DOCUMENTS if (filename == ALL_DOCUMENTS) else document_id,
"filename": filename,
"created": timestamp_str,
"pages": pages,
"filesize": str(event["Records"][0]["s3"]["object"]["size"]),
"docstatus": "UPLOADED",
"filesize": filesize,
"docstatus": UPLOADED,
"conversations": [],
"document_split_ids": [],
}

conversation = {"conversationid": conversation_id, "created": timestamp_str}
document["conversations"].append(conversation)

document_table.put_item(Item=document)

conversation = {"SessionId": conversation_id, "History": []}

return document, conversation


@logger.inject_lambda_context(log_event=True)
def lambda_handler(event, context):
key = urllib.parse.unquote_plus(event["Records"][0]["s3"]["object"]["key"])
split = key.split("/")
user_id = split[0]
file_name = split[1]

s3.download_file(BUCKET, key, f"/tmp/{file_name}")

with open(f"/tmp/{file_name}", "rb") as f:
reader = PyPDF2.PdfReader(f)
pages = str(len(reader.pages))

### Create new document & conversation history
filesize = str(event["Records"][0]["s3"]["object"]["size"])
document, conversation = create_document_and_conversation(user_id, file_name, pages, filesize)

document_table.put_item(Item=document)
memory_table.put_item(Item=conversation)

### Create/Update ALL_DOCUMENTS document
response = document_table.get_item(Key={"userid": user_id, "documentid": ALL_DOCUMENTS})
if "Item" not in response:
documents_all, conversation_all = create_document_and_conversation(user_id, ALL_DOCUMENTS, pages, filesize)
memory_table.put_item(Item=conversation_all)
else:
documents_all = response["Item"]
documents_all["docstatus"] = UPLOADED
documents_all["pages"] = str(int(documents_all["pages"]) + int(pages))
documents_all["filesize"] = str(int(documents_all["filesize"]) + int(filesize))

document_table.put_item(Item=documents_all)

message = {
"documentid": document_id,
"documentid": document["documentid"],
"key": key,
"user": user_id,
}
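The retrieval side is not part of this commit. As a rough sketch under the same assumptions (langchain_postgres' PGVector, with the embeddings client and connection string built as in the files above), querying a user's aggregate collection could look like:

def query_all_documents(user_id, question, embeddings, connection_str, k=5):
    # The function name and k are assumptions; only the collection naming
    # convention ({user_id}_ALL_DOCUMENTS) comes from this commit.
    store = PGVector(
        embeddings=embeddings,
        collection_name=f"{user_id}_ALL_DOCUMENTS",
        connection=connection_str,
        use_jsonb=True,
    )
    return store.similarity_search(question, k=k)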
