diff --git a/enterprise_knowledge_retriever/requirements.txt b/enterprise_knowledge_retriever/requirements.txt index 8acae439..57c28636 100644 --- a/enterprise_knowledge_retriever/requirements.txt +++ b/enterprise_knowledge_retriever/requirements.txt @@ -22,7 +22,7 @@ langchain==0.3.9 langgraph==0.2.28 mypy==1.11.2 nltk==3.9.1 -openpyxl==3.1.4 +openpyxl==3.1.5 pillow==10.4.0 pillow_heif==0.16.0 pre-commit==4.0.1 @@ -31,6 +31,7 @@ pydantic==2.9.2 pydantic_core==2.23.4 python-dotenv==1.0.0 ruff==0.6.9 +schedule==1.2.2 sentence_transformers==3.3.1 sseclient-py==1.8.0 streamlit-extras==0.4.3 diff --git a/enterprise_knowledge_retriever/streamlit/app.py b/enterprise_knowledge_retriever/streamlit/app.py index 3382d060..000988f8 100644 --- a/enterprise_knowledge_retriever/streamlit/app.py +++ b/enterprise_knowledge_retriever/streamlit/app.py @@ -2,9 +2,12 @@ import os import shutil import sys +import time import uuid +from threading import Thread from typing import Any, List, Optional +import schedule import streamlit as st import yaml from streamlit.runtime.uploaded_file_manager import UploadedFile @@ -25,9 +28,10 @@ CONFIG_PATH = os.path.join(kit_dir, 'config.yaml') APP_DESCRIPTION_PATH = os.path.join(kit_dir, 'streamlit', 'app_description.yaml') PERSIST_DIRECTORY = os.path.join(kit_dir, f'data/my-vector-db') +# Minutes for scheduled cache deletion +EXIT_TIME_DELTA = 30 logging.basicConfig(level=logging.INFO) -logging.info('URL: http://localhost:8501') def load_config() -> Any: @@ -40,19 +44,46 @@ def load_app_description() -> Any: return yaml.safe_load(yaml_file) -def save_files_user(docs: List[UploadedFile]) -> str: +def delete_temp_dir(temp_dir: str) -> None: + """Delete the temporary directory and its contents.""" + + if os.path.exists(temp_dir): + try: + shutil.rmtree(temp_dir) + logging.info(f'Temporary directory {temp_dir} deleted.') + except: + logging.info(f'Could not delete temporary directory {temp_dir}.') + + +def schedule_temp_dir_deletion(temp_dir: str, delay_minutes: int) -> None: + """Schedule the deletion of the temporary directory after a delay.""" + + schedule.every(delay_minutes).minutes.do(delete_temp_dir, temp_dir).tag(temp_dir) + + def run_scheduler() -> None: + while schedule.get_jobs(temp_dir): + schedule.run_pending() + time.sleep(1) + + # Run scheduler in a separate thread to be non-blocking + Thread(target=run_scheduler, daemon=True).start() + + +def save_files_user(docs: List[UploadedFile], schedule_deletion: bool = True) -> str: """ Save all user uploaded files in Streamlit to the tmp dir with their file names Args: - docs (List[UploadFile]): A list of uploaded files in Streamlit + docs (List[UploadFile]): A list of uploaded files in Streamlit. + schedule_deletion (bool): wether or not to schedule the deletion of the uploaded files + temporal folder. default to True. Returns: str: path where the files are saved. """ - # Create the data/tmp folder if it doesn't exist - temp_folder = os.path.join(kit_dir, 'data/tmp') + # Create the temporal folder to this session if it doesn't exist + temp_folder = os.path.join(kit_dir, 'data', 'tmp', st.session_state.session_temp_subfolder) if not os.path.exists(temp_folder): os.makedirs(temp_folder) else: @@ -75,6 +106,13 @@ def save_files_user(docs: List[UploadedFile]) -> str: with open(temp_file, 'wb') as f: f.write(doc.getvalue()) + if schedule_deletion: + schedule_temp_dir_deletion(temp_folder, EXIT_TIME_DELTA) + st.toast( + """your session will be active for the next 30 minutes, after this time files + will be deleted""" + ) + return temp_folder @@ -172,6 +210,8 @@ def main() -> None: st.session_state.document_retrieval = None if 'st_session_id' not in st.session_state: st.session_state.st_session_id = str(uuid.uuid4()) + if 'session_temp_subfolder' not in st.session_state: + st.session_state.session_temp_subfolder = 'upload_' + st.session_state.st_session_id if 'mp_events' not in st.session_state: st.session_state.mp_events = MixpanelEvents( os.getenv('MIXPANEL_TOKEN'), @@ -261,7 +301,7 @@ def main() -> None: with st.spinner('Processing'): try: if docs is not None: - temp_folder = save_files_user(docs) + temp_folder = save_files_user(docs, schedule_deletion=prod_mode) text_chunks = st.session_state.document_retrieval.parse_doc(temp_folder) if len(text_chunks) == 0: st.error( @@ -270,8 +310,11 @@ def main() -> None: ) embeddings = st.session_state.document_retrieval.load_embedding_model() collection_name = default_collection if not prod_mode else None + save_location = temp_folder + '_db' + if prod_mode: + schedule_temp_dir_deletion(save_location, EXIT_TIME_DELTA) vectorstore = st.session_state.document_retrieval.create_vector_store( - text_chunks, embeddings, output_db=None, collection_name=collection_name + text_chunks, embeddings, output_db=save_location, collection_name=collection_name ) st.session_state.vectorstore = vectorstore st.session_state.document_retrieval.init_retriever(vectorstore) @@ -290,7 +333,7 @@ def main() -> None: with st.spinner('Processing'): try: if docs is not None: - temp_folder = save_files_user(docs) + temp_folder = save_files_user(docs, schedule_deletion=prod_mode) text_chunks = st.session_state.document_retrieval.parse_doc(temp_folder) embeddings = st.session_state.document_retrieval.load_embedding_model() vectorstore = st.session_state.document_retrieval.create_vector_store(