diff --git a/backend/spielberg/core/session.py b/backend/spielberg/core/session.py index 7c656e3..f8f77c2 100644 --- a/backend/spielberg/core/session.py +++ b/backend/spielberg/core/session.py @@ -245,3 +245,6 @@ def get(self): def get_all(self): return self.db.get_sessions() + + def delete(self): + return self.db.delete_session(self.session_id) diff --git a/backend/spielberg/db/sqlite/db.py b/backend/spielberg/db/sqlite/db.py index 4885c4d..97434b6 100644 --- a/backend/spielberg/db/sqlite/db.py +++ b/backend/spielberg/db/sqlite/db.py @@ -157,5 +157,32 @@ def add_or_update_context_msg( ) self.conn.commit() + def delete_conversation(self, session_id: str) -> bool: + self.cursor.execute( + "DELETE FROM conversations WHERE session_id = ?", (session_id,) + ) + self.conn.commit() + return self.cursor.rowcount > 0 + + def delete_context(self, session_id: str) -> bool: + self.cursor.execute( + "DELETE FROM context_messages WHERE session_id = ?", (session_id,) + ) + self.conn.commit() + return self.cursor.rowcount > 0 + + def delete_session(self, session_id: str) -> bool: + failed_components = [] + if not self.delete_conversation(session_id): + failed_components.append("conversation") + if not self.delete_context(session_id): + failed_components.append("context") + self.cursor.execute("DELETE FROM sessions WHERE session_id = ?", (session_id,)) + self.conn.commit() + if not self.cursor.rowcount > 0: + failed_components.append("session") + success = len(failed_components) < 3 + return success, failed_components + def __del__(self): self.conn.close() diff --git a/backend/spielberg/entrypoint/api/routes.py b/backend/spielberg/entrypoint/api/routes.py index a9f4552..8a488fb 100644 --- a/backend/spielberg/entrypoint/api/routes.py +++ b/backend/spielberg/entrypoint/api/routes.py @@ -1,6 +1,6 @@ import os -from flask import Blueprint, current_app as app +from flask import Blueprint, request, current_app as app from spielberg.db import load_db from spielberg.handler import ChatHandler, SessionHandler, VideoDBHandler, ConfigHandler @@ -34,16 +34,31 @@ def get_sessions(): return session_handler.get_sessions() -@session_bp.route("/", methods=["GET"]) +@session_bp.route("/", methods=["GET", "DELETE"]) def get_session(session_id): """ - Get the session details + Get or delete the session details """ + if not session_id: + return {"message": f"Please provide {session_id}."}, 400 + session_handler = SessionHandler( db=load_db(os.getenv("SERVER_DB_TYPE", app.config["DB_TYPE"])) ) session = session_handler.get_session(session_id) - return session + if not session: + return {"message": "Session not found."}, 404 + + if request.method == "GET": + return session + elif request.method == "DELETE": + success, failed_components = session_handler.delete_session(session_id) + if success: + return {"message": "Session deleted successfully."}, 200 + else: + return { + "message": f"Failed to delete the entry for following components: {', '.join(failed_components)}" + }, 500 @videodb_bp.route("/collection", defaults={"collection_id": None}, methods=["GET"]) diff --git a/backend/spielberg/handler.py b/backend/spielberg/handler.py index 93dc95c..98a4c42 100644 --- a/backend/spielberg/handler.py +++ b/backend/spielberg/handler.py @@ -100,13 +100,17 @@ class SessionHandler: def __init__(self, db: BaseDB, **kwargs): self.db = db - def get_session(self, session_id): - session = Session(db=self.db, session_id=session_id) - return session.get() - def get_sessions(self): session = Session(db=self.db) return session.get_all() + + def get_session(self, session_id): + session = Session(db=self.db, session_id=session_id) + return session.get() + + def delete_session(self, session_id): + session = Session(db=self.db, session_id=session_id) + return session.delete() class VideoDBHandler: