diff --git a/backend/src/delete_document/main.py b/backend/src/delete_document/main.py
index 41318e2..471055b 100644
--- a/backend/src/delete_document/main.py
+++ b/backend/src/delete_document/main.py
@@ -10,6 +10,7 @@
 EMBEDDING_MODEL_ID = os.environ["EMBEDDING_MODEL_ID"]
 REGION = os.environ["REGION"]
 DATABASE_SECRET_NAME = os.environ["DATABASE_SECRET_NAME"]
+ALL_DOCUMENTS = "ALL_DOCUMENTS"
 
 ddb = boto3.resource("dynamodb")
 document_table = ddb.Table(DOCUMENT_TABLE)
@@ -28,6 +29,78 @@ def get_db_secret():
     secret = json.loads(response)
     return secret
 
+def get_embeddings():
+    bedrock_runtime = boto3.client(
+        service_name="bedrock-runtime",
+        region_name=REGION,
+    )
+    embeddings = BedrockEmbeddings(
+        model_id=EMBEDDING_MODEL_ID,
+        client=bedrock_runtime,
+        region_name=REGION,
+    )
+    return embeddings
+
+def delete_collection(collection_name, connection_str, embeddings):
+    pgvector = PGVector(
+        collection_name=collection_name,
+        connection=connection_str,
+        embeddings=embeddings,
+    )
+    pgvector.delete_collection()
+
+def delete_from_user_collection(user_id, document, connection_str, embeddings):
+    response = document_table.get_item(
+        Key={"userid": user_id, "documentid": ALL_DOCUMENTS}
+    )
+    if "Item" in response:
+        documents_all = response["Item"]
+        # Drop this document's split ids from the ALL_DOCUMENTS aggregate
+        documents_all["document_split_ids"] = [
+            split_id for split_id in documents_all["document_split_ids"] if split_id not in document["document_split_ids"]
+        ]
+
+        collection_name = f"{user_id}_{ALL_DOCUMENTS}"
+        pgvector = PGVector(
+            collection_name=collection_name,
+            connection=connection_str,
+            embeddings=embeddings,
+        )
+
+        if not documents_all["document_split_ids"]:
+            # No split ids left: delete the aggregate item and its conversations
+            with memory_table.batch_writer() as batch:
+                for item in documents_all["conversations"]:
+                    batch.delete_item(Key={"SessionId": item["conversationid"]})
+
+            document_table.delete_item(
+                Key={"userid": user_id, "documentid": ALL_DOCUMENTS}
+            )
+
+            # Delete the now-empty user collection
+            pgvector.delete_collection()
+
+        else:
+            # Adjust filesize and pages
+            documents_all["pages"] = str(int(documents_all["pages"]) - int(document["pages"]))
+            documents_all["filesize"] = str(int(documents_all["filesize"]) - int(document["filesize"]))
+
+            # Update the ALL_DOCUMENTS item in the table
+            document_table.update_item(
+                Key={"userid": user_id, "documentid": ALL_DOCUMENTS},
+                UpdateExpression="SET document_split_ids = :ids, pages = :pages, filesize = :filesize",
+                ExpressionAttributeValues={
+                    ":ids": documents_all["document_split_ids"],
+                    ":pages": documents_all["pages"],
+                    ":filesize": documents_all["filesize"],
+                },
+            )
+
+            # Remove this document's embeddings from the shared collection
+            pgvector.delete(ids=document["document_split_ids"])
+
+
 @logger.inject_lambda_context(log_event=True)
 def lambda_handler(event, context):
     user_id = event["requestContext"]["authorizer"]["claims"]["sub"]
@@ -61,27 +134,15 @@
     logger.info("Deleting from vector store")
-
-    bedrock_runtime = boto3.client(
-        service_name="bedrock-runtime",
-        region_name=REGION,
-    )
-    embeddings = BedrockEmbeddings(
-        model_id=EMBEDDING_MODEL_ID,
-        client=bedrock_runtime,
-        region_name=REGION,
-    )
-
-    collection_name = f"{user_id}_{filename}"
+    embeddings = get_embeddings()
     db_secret = get_db_secret()
     connection_str = f"postgresql+psycopg2://{db_secret['username']}:{db_secret['password']}@{db_secret['host']}:5432/{db_secret['dbname']}?sslmode=require"
-    pgvector = PGVector(
-        collection_name=collection_name,
-        connection=connection_str,
-        embeddings=embeddings,
-    )
-    pgvector.delete_collection()
+    user_file_collection_name = f"{user_id}_{filename}"
+    delete_collection(user_file_collection_name, connection_str, embeddings)
+
+    if document_id != ALL_DOCUMENTS:
+        delete_from_user_collection(user_id, document, connection_str, embeddings)
 
     logger.info("Deletion complete")
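Net effect of the deletion change: the per-file collection is still dropped wholesale, while the shared ALL_DOCUMENTS collection only loses the deleted file's vectors. A condensed sketch of those two paths, assuming langchain_postgres's PGVector (which this handler appears to use); sketch_delete and its arguments are illustrative placeholders, not part of the PR:

    from langchain_postgres import PGVector

    def sketch_delete(user_id, filename, split_ids, connection_str, embeddings):
        # Path 1: drop the whole per-file collection.
        PGVector(
            collection_name=f"{user_id}_{filename}",
            connection=connection_str,
            embeddings=embeddings,
        ).delete_collection()
        # Path 2: remove only this file's vectors from the shared collection,
        # addressed by the ids recorded in the item's document_split_ids.
        PGVector(
            collection_name=f"{user_id}_ALL_DOCUMENTS",
            connection=connection_str,
            embeddings=embeddings,
        ).delete(ids=split_ids)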
diff --git a/backend/src/generate_embeddings/main.py b/backend/src/generate_embeddings/main.py
index 24e0add..edc6c89 100644
--- a/backend/src/generate_embeddings/main.py
+++ b/backend/src/generate_embeddings/main.py
@@ -13,6 +13,9 @@
 EMBEDDING_MODEL_ID = os.environ["EMBEDDING_MODEL_ID"]
 REGION = os.environ["REGION"]
 DATABASE_SECRET_NAME = os.environ["DATABASE_SECRET_NAME"]
+ALL_DOCUMENTS = "ALL_DOCUMENTS"
+PROCESSING = "PROCESSING"
+READY = "READY"
 
 s3 = boto3.client("s3")
 ddb = boto3.resource("dynamodb")
@@ -20,11 +23,27 @@
 logger = Logger()
 
-def set_doc_status(user_id, document_id, status):
+def set_doc_status(user_id, document_id, status, ids=None):
+    if ids:
+        update_expression = """
+            SET docstatus = :docstatus,
+            document_split_ids = list_append(if_not_exists(document_split_ids, :empty_list), :ids)
+        """
+        expression_attribute_values = {
+            ":docstatus": status,
+            ":ids": ids,
+            ":empty_list": [],
+        }
+    else:
+        update_expression = "SET docstatus = :docstatus"
+        expression_attribute_values = {
+            ":docstatus": status,
+        }
+
     document_table.update_item(
         Key={"userid": user_id, "documentid": document_id},
-        UpdateExpression="SET docstatus = :docstatus",
-        ExpressionAttributeValues={":docstatus": status},
+        UpdateExpression=update_expression,
+        ExpressionAttributeValues=expression_attribute_values,
     )
 
 def get_db_secret():
@@ -47,7 +66,8 @@
     key = event_body["key"]
     file_name_full = key.split("/")[-1]
 
-    set_doc_status(user_id, document_id, "PROCESSING")
+    set_doc_status(user_id, document_id, PROCESSING)
+    set_doc_status(user_id, ALL_DOCUMENTS, PROCESSING)
 
     s3.download_file(BUCKET, key, f"/tmp/{file_name_full}")
 
@@ -67,17 +87,20 @@
         region_name=REGION,
     )
 
-    collection_name = f"{user_id}_{file_name_full}"
     db_secret = get_db_secret()
     connection_str = f"postgresql+psycopg2://{db_secret['username']}:{db_secret['password']}@{db_secret['host']}:5432/{db_secret['dbname']}?sslmode=require"
-    vector_store = PGVector(
-        embeddings=embeddings,
-        collection_name=collection_name,
-        connection= connection_str,
-        use_jsonb=True,
-    )
-
-    vector_store.add_documents(split_document)
-
-    set_doc_status(user_id, document_id, "READY")
+    collection_names = [f"{user_id}_{file_name_full}", f"{user_id}_{ALL_DOCUMENTS}"]
+    ids = [f"{user_id}_{file_name_full}_{i}" for i in range(len(split_document))]
+    for collection_name in collection_names:
+        vector_store = PGVector(
+            embeddings=embeddings,
+            collection_name=collection_name,
+            connection=connection_str,
+            use_jsonb=True,
+        )
+
+        vector_store.add_documents(split_document, ids=ids)
+
+    set_doc_status(user_id, document_id, READY, ids)
+    set_doc_status(user_id, ALL_DOCUMENTS, READY, ids)
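The list_append(if_not_exists(...)) expression is what lets set_doc_status append split ids even when the item has no document_split_ids attribute yet. A standalone sketch of that update with boto3; the table name, key values, and id strings below are placeholders, not values from the PR:

    import boto3

    table = boto3.resource("dynamodb").Table("DocumentTable")  # placeholder table name
    table.update_item(
        Key={"userid": "user-123", "documentid": "ALL_DOCUMENTS"},
        UpdateExpression=(
            "SET docstatus = :docstatus, "
            "document_split_ids = list_append(if_not_exists(document_split_ids, :empty_list), :ids)"
        ),
        ExpressionAttributeValues={
            ":docstatus": "READY",
            # Same id scheme as the handler: {user_id}_{file_name}_{chunk_index}
            ":ids": ["user-123_file.pdf_0", "user-123_file.pdf_1"],
            ":empty_list": [],  # fallback when the attribute does not exist yet
        },
    )

Because the same ids list is written to both the per-document item and the ALL_DOCUMENTS item, the delete handler can later subtract exactly this file's vectors from the shared collection.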
diff --git a/backend/src/upload_trigger/main.py b/backend/src/upload_trigger/main.py
index 2388814..d254e33 100644
--- a/backend/src/upload_trigger/main.py
+++ b/backend/src/upload_trigger/main.py
@@ -10,6 +10,8 @@
 MEMORY_TABLE = os.environ["MEMORY_TABLE"]
 QUEUE = os.environ["QUEUE"]
 BUCKET = os.environ["BUCKET"]
+UPLOADED = "UPLOADED"
+ALL_DOCUMENTS = "ALL_DOCUMENTS"
 
 ddb = boto3.resource("dynamodb")
 
@@ -20,47 +22,68 @@
 logger = Logger()
 
-@logger.inject_lambda_context(log_event=True)
-def lambda_handler(event, context):
-    key = urllib.parse.unquote_plus(event["Records"][0]["s3"]["object"]["key"])
-    split = key.split("/")
-    user_id = split[0]
-    file_name = split[1]
+def create_document_and_conversation(user_id, filename, pages, filesize):
+    timestamp = datetime.utcnow()
+    timestamp_str = timestamp.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
 
     document_id = shortuuid.uuid()
-
-    s3.download_file(BUCKET, key, f"/tmp/{file_name}")
-
-    with open(f"/tmp/{file_name}", "rb") as f:
-        reader = PyPDF2.PdfReader(f)
-        pages = str(len(reader.pages))
-
     conversation_id = shortuuid.uuid()
 
-    timestamp = datetime.utcnow()
-    timestamp_str = timestamp.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
-
     document = {
         "userid": user_id,
-        "documentid": document_id,
-        "filename": file_name,
+        "documentid": ALL_DOCUMENTS if filename == ALL_DOCUMENTS else document_id,
+        "filename": filename,
         "created": timestamp_str,
         "pages": pages,
-        "filesize": str(event["Records"][0]["s3"]["object"]["size"]),
-        "docstatus": "UPLOADED",
+        "filesize": filesize,
+        "docstatus": UPLOADED,
         "conversations": [],
+        "document_split_ids": [],
     }
 
     conversation = {"conversationid": conversation_id, "created": timestamp_str}
     document["conversations"].append(conversation)
 
-    document_table.put_item(Item=document)
-
     conversation = {"SessionId": conversation_id, "History": []}
+
+    return document, conversation
+
+
+@logger.inject_lambda_context(log_event=True)
+def lambda_handler(event, context):
+    key = urllib.parse.unquote_plus(event["Records"][0]["s3"]["object"]["key"])
+    split = key.split("/")
+    user_id = split[0]
+    file_name = split[1]
+
+    s3.download_file(BUCKET, key, f"/tmp/{file_name}")
+
+    with open(f"/tmp/{file_name}", "rb") as f:
+        reader = PyPDF2.PdfReader(f)
+        pages = str(len(reader.pages))
+
+    ### Create new document & conversation history
+    filesize = str(event["Records"][0]["s3"]["object"]["size"])
+    document, conversation = create_document_and_conversation(user_id, file_name, pages, filesize)
+
+    document_table.put_item(Item=document)
     memory_table.put_item(Item=conversation)
 
+    ### Create/Update ALL_DOCUMENTS document
+    response = document_table.get_item(Key={"userid": user_id, "documentid": ALL_DOCUMENTS})
+    if "Item" not in response:
+        documents_all, conversation_all = create_document_and_conversation(user_id, ALL_DOCUMENTS, pages, filesize)
+        memory_table.put_item(Item=conversation_all)
+    else:
+        documents_all = response["Item"]
+        documents_all["docstatus"] = UPLOADED
+        documents_all["pages"] = str(int(documents_all["pages"]) + int(pages))
+        documents_all["filesize"] = str(int(documents_all["filesize"]) + int(filesize))
+
+    document_table.put_item(Item=documents_all)
+
     message = {
-        "documentid": document_id,
+        "documentid": document["documentid"],
         "key": key,
         "user": user_id,
     }
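Since pages and filesize are stored as DynamoDB strings, the ALL_DOCUMENTS roll-up in upload_trigger round-trips them through int. A tiny self-contained check of that arithmetic; roll_up and the sample values are illustrative only:

    def roll_up(documents_all, pages, filesize):
        # Mirrors the else-branch above: mark the aggregate re-uploaded and grow the totals.
        documents_all["docstatus"] = "UPLOADED"
        documents_all["pages"] = str(int(documents_all["pages"]) + int(pages))
        documents_all["filesize"] = str(int(documents_all["filesize"]) + int(filesize))
        return documents_all

    assert roll_up({"docstatus": "READY", "pages": "3", "filesize": "1024"}, "2", "512") == {
        "docstatus": "UPLOADED", "pages": "5", "filesize": "1536"
    }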