Skip to content

Commit

Permalink
don't know who's more retarded, me or claude
Browse files Browse the repository at this point in the history
  • Loading branch information
rmusser01 committed Oct 31, 2024
1 parent 78850ad commit 57a43b1
Show file tree
Hide file tree
Showing 7 changed files with 246 additions and 177 deletions.
5 changes: 2 additions & 3 deletions App_Function_Libraries/Chat/Chat_Functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
# External Imports
#
# Local Imports
from App_Function_Libraries.DB.DB_Manager import get_conversation_name, save_chat_history_to_database, \
start_new_conversation, update_conversation_title, delete_messages_in_conversation, save_message
from App_Function_Libraries.DB.RAG_QA_Chat_DB import get_db_connection
from App_Function_Libraries.DB.DB_Manager import start_new_conversation, delete_messages_in_conversation, save_message
from App_Function_Libraries.DB.RAG_QA_Chat_DB import get_db_connection, get_conversation_name
from App_Function_Libraries.LLM_API_Calls import chat_with_openai, chat_with_anthropic, chat_with_cohere, \
chat_with_groq, chat_with_openrouter, chat_with_deepseek, chat_with_mistral, chat_with_huggingface
from App_Function_Libraries.LLM_API_Calls_Local import chat_with_aphrodite, chat_with_local_llm, chat_with_ollama, \
Expand Down
2 changes: 1 addition & 1 deletion App_Function_Libraries/DB/Character_Chat_DB.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,7 @@ def fetch_character_ids_by_keywords(keywords: List[str]) -> List[int]:

def view_char_keywords():
try:
with sqlite3.connect('character_chat.db') as conn:
with sqlite3.connect(chat_DB_PATH) as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT DISTINCT keyword
Expand Down
5 changes: 3 additions & 2 deletions App_Function_Libraries/DB/DB_Backups.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#
# Local Imports:
from App_Function_Libraries.DB.Character_Chat_DB import chat_DB_PATH
from App_Function_Libraries.DB.RAG_QA_Chat_DB import rag_qa_db_path
from App_Function_Libraries.DB.RAG_QA_Chat_DB import get_rag_qa_db_path
from App_Function_Libraries.Utils.Utils import get_project_relative_path
#
# End of Imports
Expand Down Expand Up @@ -109,8 +109,9 @@ def setup_backup_config():
backup_base_dir = get_project_relative_path('tldw_DB_Backups')

# RAG Chat DB configuration
rag_db_path = get_rag_qa_db_path()
rag_db_config = {
'db_path': rag_qa_db_path,
'db_path': rag_db_path,
'backup_dir': init_backup_directory(backup_base_dir, 'rag_qa'),
'db_name': 'rag_qa'
}
Expand Down
170 changes: 111 additions & 59 deletions App_Function_Libraries/DB/RAG_QA_Chat_DB.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Imports
import configparser
import logging
import os
import re
import sqlite3
import uuid
Expand All @@ -15,26 +16,24 @@
# (No external imports)
#
# Local Imports
from App_Function_Libraries.Utils.Utils import get_project_relative_path, get_database_path
from App_Function_Libraries.Utils.Utils import get_project_relative_path, get_project_root

#
########################################################################################################################
#
# Functions:

# Construct the path to the config file
config_path = get_project_relative_path('Config_Files/config.txt')

# Read the config file
config = configparser.ConfigParser()
config.read(config_path)

# Get the SQLite path from the config, or use the default if not specified
if config.has_section('Database') and config.has_option('Database', 'rag_qa_db_path'):
rag_qa_db_path = config.get('Database', 'rag_qa_db_path')
else:
rag_qa_db_path = get_database_path('RAG_QA_Chat.db')

print(f"RAG QA Chat Database path: {rag_qa_db_path}")
def get_rag_qa_db_path():
config_path = os.path.join(get_project_root(), 'Config_Files', 'config.txt')
config = configparser.ConfigParser()
config.read(config_path)
if config.has_section('Database') and config.has_option('Database', 'rag_qa_db_path'):
rag_qa_db_path = config.get('Database', 'rag_qa_db_path')
if not os.path.isabs(rag_qa_db_path):
rag_qa_db_path = get_project_relative_path(rag_qa_db_path)
return rag_qa_db_path
else:
raise ValueError("Database path not found in config file")

# Set up logging
logging.basicConfig(level=logging.INFO)
Expand All @@ -57,7 +56,8 @@
created_at DATETIME NOT NULL,
last_updated DATETIME NOT NULL,
title TEXT NOT NULL,
media_id INTEGER
media_id INTEGER,
rating INTEGER CHECK(rating BETWEEN 1 AND 3)
);
-- Table for storing keywords
Expand Down Expand Up @@ -163,7 +163,8 @@
# Database connection management
@contextmanager
def get_db_connection():
conn = sqlite3.connect(rag_qa_db_path)
db_path = get_rag_qa_db_path()
conn = sqlite3.connect(db_path)
try:
yield conn
finally:
Expand Down Expand Up @@ -197,37 +198,16 @@ def execute_query(query, params=None, conn=None):
conn.commit()
return cursor.fetchall()


# DIRTY HACK
# FIXME - DELETE AND REMOVE
def update_schema(conn):
cursor = conn.cursor()
# Check if 'media_id' column exists in 'conversation_metadata' table
cursor.execute("PRAGMA table_info(conversation_metadata);")
columns = [info[1] for info in cursor.fetchall()]
if 'media_id' not in columns:
# Add the 'media_id' column
cursor.execute("ALTER TABLE conversation_metadata ADD COLUMN media_id INTEGER;")
conn.commit()
logger.info("'media_id' column added to 'conversation_metadata' table.")
else:
logger.info("'media_id' column already exists in 'conversation_metadata' table.")


def create_tables():
with get_db_connection() as conn:
cursor = conn.cursor()
# Execute the SCHEMA_SQL to create tables if they don't exist
cursor.executescript(SCHEMA_SQL)
logger.info("All RAG QA Chat tables created successfully")

# update the schema to ensure it includes any new columns
update_schema(conn)

# Initialize the database
create_tables()


#
# End of Setup
############################################################
Expand Down Expand Up @@ -337,7 +317,8 @@ def add_keywords_to_conversation(conversation_id, keywords):

def view_rag_keywords():
try:
with sqlite3.connect('RAG_QA_Chat.db') as conn:
rag_db_path = get_rag_qa_db_path()
with sqlite3.connect(rag_db_path) as conn:
cursor = conn.cursor()
cursor.execute("SELECT keyword FROM rag_qa_keywords ORDER BY keyword")
keywords = cursor.fetchall()
Expand Down Expand Up @@ -590,11 +571,13 @@ def start_new_conversation(title="Untitled Conversation", media_id=None):
try:
conversation_id = str(uuid.uuid4())
query = """
INSERT INTO conversation_metadata (conversation_id, created_at, last_updated, title, media_id)
VALUES (?, ?, ?, ?, ?)
INSERT INTO conversation_metadata (
conversation_id, created_at, last_updated, title, media_id, rating
) VALUES (?, ?, ?, ?, ?, ?)
"""
now = datetime.now().isoformat()
execute_query(query, (conversation_id, now, now, title, media_id))
# Set initial rating to NULL
execute_query(query, (conversation_id, now, now, title, media_id, None))
logger.info(f"New conversation '{conversation_id}' started with title '{title}' and media_id '{media_id}'")
return conversation_id
except Exception as e:
Expand All @@ -605,15 +588,15 @@ def start_new_conversation(title="Untitled Conversation", media_id=None):
def get_all_conversations(page=1, page_size=20):
try:
query = """
SELECT conversation_id, title, media_id
FROM conversation_metadata
ORDER BY last_updated DESC
LIMIT ? OFFSET ?
SELECT conversation_id, title, media_id, rating
FROM conversation_metadata
ORDER BY last_updated DESC
LIMIT ? OFFSET ?
"""

count_query = "SELECT COUNT(*) FROM conversation_metadata"

with sqlite3.connect(rag_qa_db_path) as conn:
db_path = get_rag_qa_db_path()
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()

# Get total count
Expand All @@ -629,10 +612,10 @@ def get_all_conversations(page=1, page_size=20):
conversations = [{
'conversation_id': row[0],
'title': row[1],
'media_id': row[2]
'media_id': row[2],
'rating': row[3] # Include rating
} for row in results]

return conversations, total_pages, total_count
return conversations, total_pages, total_count
except Exception as e:
logging.error(f"Error getting conversations: {e}")
raise
Expand All @@ -650,8 +633,8 @@ def get_all_notes(page=1, page_size=20):
"""

count_query = "SELECT COUNT(*) FROM rag_qa_notes"

with sqlite3.connect(rag_qa_db_path) as conn:
db_path = get_rag_qa_db_path()
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()

# Get total count
Expand Down Expand Up @@ -842,7 +825,8 @@ def get_conversation_text(conversation_id):

messages = []
# Use the connection as a context manager
with sqlite3.connect(rag_qa_db_path) as conn:
db_path = get_rag_qa_db_path()
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()
cursor.execute(query, (conversation_id,))
messages = cursor.fetchall()
Expand All @@ -854,12 +838,12 @@ def get_conversation_text(conversation_id):


def get_conversation_details(conversation_id):
query = "SELECT title, media_id FROM conversation_metadata WHERE conversation_id = ?"
query = "SELECT title, media_id, rating FROM conversation_metadata WHERE conversation_id = ?"
result = execute_query(query, (conversation_id,))
if result:
return {'title': result[0][0], 'media_id': result[0][1]}
return {'title': result[0][0], 'media_id': result[0][1], 'rating': result[0][2]}
else:
return {'title': "Untitled Conversation", 'media_id': None}
return {'title': "Untitled Conversation", 'media_id': None, 'rating': None}


def delete_conversation(conversation_id):
Expand All @@ -881,6 +865,72 @@ def delete_conversation(conversation_id):
logger.error(f"Error deleting conversation '{conversation_id}': {e}")
raise

def set_conversation_rating(conversation_id, rating):
"""Set the rating for a conversation."""
# Validate rating
if rating not in [1, 2, 3]:
raise ValueError('Rating must be an integer between 1 and 3.')
try:
query = "UPDATE conversation_metadata SET rating = ? WHERE conversation_id = ?"
execute_query(query, (rating, conversation_id))
logger.info(f"Rating for conversation '{conversation_id}' set to {rating}")
except Exception as e:
logger.error(f"Error setting rating for conversation '{conversation_id}': {e}")
raise

def get_conversation_rating(conversation_id):
"""Get the rating of a conversation."""
try:
query = "SELECT rating FROM conversation_metadata WHERE conversation_id = ?"
result = execute_query(query, (conversation_id,))
if result:
rating = result[0][0]
logger.info(f"Rating for conversation '{conversation_id}' is {rating}")
return rating
else:
logger.warning(f"Conversation '{conversation_id}' not found.")
return None
except Exception as e:
logger.error(f"Error getting rating for conversation '{conversation_id}': {e}")
raise


def get_conversation_name(conversation_id: str) -> str:
"""
Retrieves the title/name of a conversation from the conversation_metadata table.
Args:
conversation_id (str): The unique identifier of the conversation
Returns:
str: The title of the conversation if found, "Untitled Conversation" if not found
Raises:
sqlite3.Error: If there's a database error
"""
try:
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT title FROM conversation_metadata WHERE conversation_id = ?",
(conversation_id,)
)
result = cursor.fetchone()

if result:
return result[0]
else:
logging.warning(f"No conversation found with ID: {conversation_id}")
return "Untitled Conversation"

except sqlite3.Error as e:
logging.error(f"Database error retrieving conversation name for ID {conversation_id}: {e}")
raise
except Exception as e:
logging.error(f"Unexpected error retrieving conversation name for ID {conversation_id}: {e}")
raise


def search_rag_chat(query: str, fts_top_k: int = 10, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]:
"""
Perform a full-text search on the RAG Chat database.
Expand All @@ -897,7 +947,8 @@ def search_rag_chat(query: str, fts_top_k: int = 10, relevant_media_ids: List[st
return []

try:
with sqlite3.connect(rag_qa_db_path) as conn:
db_path = get_rag_qa_db_path()
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()
# Perform the full-text search using the FTS virtual table
cursor.execute("""
Expand Down Expand Up @@ -954,7 +1005,8 @@ def search_rag_notes(query: str, fts_top_k: int = 10, relevant_media_ids: List[s
return []

try:
with sqlite3.connect(rag_qa_db_path) as conn:
db_path = get_rag_qa_db_path()
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()
# Perform the full-text search using the FTS virtual table
cursor.execute("""
Expand Down
Loading

0 comments on commit 57a43b1

Please sign in to comment.