Improvement: tmp folder per session in EKR #493

Merged — 6 commits, Jan 8, 2025
3 changes: 2 additions & 1 deletion enterprise_knowledge_retriever/requirements.txt
@@ -22,7 +22,7 @@ langchain==0.3.9
langgraph==0.2.28
mypy==1.11.2
nltk==3.9.1
openpyxl==3.1.4
openpyxl==3.1.5
pillow==10.4.0
pillow_heif==0.16.0
pre-commit==4.0.1
@@ -31,6 +31,7 @@ pydantic==2.9.2
pydantic_core==2.23.4
python-dotenv==1.0.0
ruff==0.6.9
schedule==1.2.2
sentence_transformers==3.3.1
sseclient-py==1.8.0
streamlit-extras==0.4.3
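The new schedule dependency backs the timed cleanup of per-session temporary folders introduced in app.py below. As a rough sketch of the one-shot cleanup pattern it enables (the directory path and delay here are illustrative, not taken from this PR):

import shutil
import time
from threading import Thread

import schedule


def cleanup(path):
    # Remove the directory tree; ignore_errors avoids raising if it is already gone.
    shutil.rmtree(path, ignore_errors=True)
    # Returning CancelJob removes the job so it fires only once.
    return schedule.CancelJob


def schedule_cleanup(path, delay_minutes):
    schedule.every(delay_minutes).minutes.do(cleanup, path)

    def run():
        # Poll until the cleanup job has fired and cancelled itself.
        while schedule.get_jobs():
            schedule.run_pending()
            time.sleep(1)

    Thread(target=run, daemon=True).start()


schedule_cleanup('data/tmp/upload_example', 30)  # hypothetical session folder and delay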
59 changes: 51 additions & 8 deletions enterprise_knowledge_retriever/streamlit/app.py
@@ -2,9 +2,12 @@
import os
import shutil
import sys
import time
import uuid
from threading import Thread
from typing import Any, List, Optional

import schedule
import streamlit as st
import yaml
from streamlit.runtime.uploaded_file_manager import UploadedFile
@@ -25,9 +28,10 @@
CONFIG_PATH = os.path.join(kit_dir, 'config.yaml')
APP_DESCRIPTION_PATH = os.path.join(kit_dir, 'streamlit', 'app_description.yaml')
PERSIST_DIRECTORY = os.path.join(kit_dir, f'data/my-vector-db')
# Minutes for scheduled cache deletion
EXIT_TIME_DELTA = 30

logging.basicConfig(level=logging.INFO)
logging.info('URL: http://localhost:8501')


def load_config() -> Any:
@@ -40,19 +44,46 @@ def load_app_description() -> Any:
return yaml.safe_load(yaml_file)


def save_files_user(docs: List[UploadedFile]) -> str:
def delete_temp_dir(temp_dir: str) -> None:
"""Delete the temporary directory and its contents."""

if os.path.exists(temp_dir):
try:
shutil.rmtree(temp_dir)
logging.info(f'Temporary directory {temp_dir} deleted.')
except Exception:
logging.info(f'Could not delete temporary directory {temp_dir}.')


def schedule_temp_dir_deletion(temp_dir: str, delay_minutes: int) -> None:
"""Schedule the deletion of the temporary directory after a delay."""

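# Tag the job with the temp dir path so this session's cleanup job can be looked up by tag.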
schedule.every(delay_minutes).minutes.do(delete_temp_dir, temp_dir).tag(temp_dir)

def run_scheduler() -> None:
while schedule.get_jobs(temp_dir):
schedule.run_pending()
time.sleep(1)

# Run scheduler in a separate thread to be non-blocking
Thread(target=run_scheduler, daemon=True).start()


def save_files_user(docs: List[UploadedFile], schedule_deletion: bool = True) -> str:
"""
Save all user uploaded files in Streamlit to the tmp dir with their file names

Args:
docs (List[UploadFile]): A list of uploaded files in Streamlit
docs (List[UploadFile]): A list of uploaded files in Streamlit.
schedule_deletion (bool): whether or not to schedule the deletion of the uploaded files'
temporary folder. Defaults to True.

Returns:
str: path where the files are saved.
"""

# Create the data/tmp folder if it doesn't exist
temp_folder = os.path.join(kit_dir, 'data/tmp')
# Create the temporary folder for this session if it doesn't exist
temp_folder = os.path.join(kit_dir, 'data', 'tmp', st.session_state.session_temp_subfolder)
if not os.path.exists(temp_folder):
os.makedirs(temp_folder)
else:
@@ -75,6 +106,13 @@ def save_files_user(docs: List[UploadedFile]) -> str:
with open(temp_file, 'wb') as f:
f.write(doc.getvalue())

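# Optionally schedule the session's upload folder for deletion after EXIT_TIME_DELTA minutes.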
if schedule_deletion:
schedule_temp_dir_deletion(temp_folder, EXIT_TIME_DELTA)
st.toast(
"""your session will be active for the next 30 minutes, after this time files
will be deleted"""
)

return temp_folder


@@ -172,6 +210,8 @@ def main() -> None:
st.session_state.document_retrieval = None
if 'st_session_id' not in st.session_state:
st.session_state.st_session_id = str(uuid.uuid4())
if 'session_temp_subfolder' not in st.session_state:
st.session_state.session_temp_subfolder = 'upload_' + st.session_state.st_session_id
if 'mp_events' not in st.session_state:
st.session_state.mp_events = MixpanelEvents(
os.getenv('MIXPANEL_TOKEN'),
@@ -261,7 +301,7 @@ def main() -> None:
with st.spinner('Processing'):
try:
if docs is not None:
temp_folder = save_files_user(docs)
temp_folder = save_files_user(docs, schedule_deletion=prod_mode)
text_chunks = st.session_state.document_retrieval.parse_doc(temp_folder)
if len(text_chunks) == 0:
st.error(
@@ -270,8 +310,11 @@
)
embeddings = st.session_state.document_retrieval.load_embedding_model()
collection_name = default_collection if not prod_mode else None
save_location = temp_folder + '_db'
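# In prod mode, the per-session vector DB directory gets the same timed cleanup as the uploads.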
if prod_mode:
schedule_temp_dir_deletion(save_location, EXIT_TIME_DELTA)
vectorstore = st.session_state.document_retrieval.create_vector_store(
text_chunks, embeddings, output_db=None, collection_name=collection_name
text_chunks, embeddings, output_db=save_location, collection_name=collection_name
)
st.session_state.vectorstore = vectorstore
st.session_state.document_retrieval.init_retriever(vectorstore)
@@ -290,7 +333,7 @@ def main() -> None:
with st.spinner('Processing'):
try:
if docs is not None:
temp_folder = save_files_user(docs)
temp_folder = save_files_user(docs, schedule_deletion=prod_mode)
text_chunks = st.session_state.document_retrieval.parse_doc(temp_folder)
embeddings = st.session_state.document_retrieval.load_embedding_model()
vectorstore = st.session_state.document_retrieval.create_vector_store(
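For reference, a small sketch (not part of the PR diff) of how the per-session paths fit together: the upload folder is keyed by the Streamlit session UUID, and the vector store is persisted to a sibling '<folder>_db' directory, so both can be scheduled for the same timed deletion. The kit_dir value here is a placeholder:

import os
import uuid

kit_dir = '/path/to/enterprise_knowledge_retriever'  # placeholder, resolved by the kit at runtime
st_session_id = str(uuid.uuid4())
session_temp_subfolder = 'upload_' + st_session_id

temp_folder = os.path.join(kit_dir, 'data', 'tmp', session_temp_subfolder)  # uploaded files
save_location = temp_folder + '_db'  # vector store persisted next to the uploads

print(temp_folder)    # .../data/tmp/upload_<uuid>
print(save_location)  # .../data/tmp/upload_<uuid>_db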