Skip to content

Commit

Permalink
Merge pull request #61 from tjipenk/dev_afn
Browse files Browse the repository at this point in the history
sqlite db for register user and chat history
  • Loading branch information
ruecat authored Sep 26, 2024
2 parents 1d114d6 + d3fb767 commit 58ce0a2
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 3 deletions.
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -200,4 +200,8 @@ $RECYCLE.BIN/
.nfs*

# OpenSSH Keys
id_*
id_*


# user dn
users.db
34 changes: 33 additions & 1 deletion bot/func/interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
import os
import aiohttp
import json
import sqlite3
from aiogram import types
from aiohttp import ClientTimeout
from asyncio import Lock
from functools import wraps
# from bot.run import load_allowed_ids_from_db
from dotenv import load_dotenv
load_dotenv()
token = os.getenv("TOKEN")
allowed_ids = list(map(int, os.getenv("USER_IDS", "").split(",")))
#allowed_ids = list(map(int, os.getenv("USER_IDS", "").split(",")))
admin_ids = list(map(int, os.getenv("ADMIN_IDS", "").split(",")))
ollama_base_url = os.getenv("OLLAMA_BASE_URL")
ollama_port = os.getenv("OLLAMA_PORT", "11434")
Expand Down Expand Up @@ -55,6 +57,36 @@ async def generate(payload: dict, modelname: str, prompt: str):
except aiohttp.ClientError as e:
print(f"Error during request: {e}")

def load_allowed_ids_from_db():
conn = sqlite3.connect('users.db')
c = conn.cursor()
c.execute("SELECT id FROM users")
user_ids = [row[0] for row in c.fetchall()]
conn.close()
return user_ids

allowed_ids = load_allowed_ids_from_db()

def get_all_users_from_db():
conn = sqlite3.connect('users.db')
c = conn.cursor()
c.execute("SELECT id, name FROM users")
users = c.fetchall()
conn.close()
return users

def remove_user_from_db(user_id):
conn = sqlite3.connect('users.db')
c = conn.cursor()
c.execute("DELETE FROM users WHERE id = ?", (user_id,))
removed = c.rowcount > 0
conn.commit()
conn.close()
if removed:
global allowed_ids
allowed_ids = [id for id in allowed_ids if id != user_id]
return removed

def perms_allowed(func):
@wraps(func)
async def wrapper(message: types.Message = None, query: types.CallbackQuery = None):
Expand Down
85 changes: 84 additions & 1 deletion bot/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,29 @@
import traceback
import io
import base64
import sqlite3
bot = Bot(token=token)
dp = Dispatcher()
start_kb = InlineKeyboardBuilder()
settings_kb = InlineKeyboardBuilder()


start_kb.row(
types.InlineKeyboardButton(text="ℹ️ About", callback_data="about"),
types.InlineKeyboardButton(text="⚙️ Settings", callback_data="settings"),
types.InlineKeyboardButton(text="📝 Register", callback_data="register"),

)
settings_kb.row(
types.InlineKeyboardButton(text="🔄 Switch LLM", callback_data="switchllm"),
types.InlineKeyboardButton(text="✏️ Edit system prompt", callback_data="editsystemprompt"),
# types.InlineKeyboardButton(text="✏️ Edit system prompt", callback_data="editsystemprompt"),
)

settings_kb.row(
types.InlineKeyboardButton(text="📋 List Users and remove User", callback_data="list_users"),
)


commands = [
types.BotCommand(command="start", description="Start"),
types.BotCommand(command="reset", description="Reset Chat"),
Expand All @@ -33,6 +43,48 @@
CHAT_TYPE_GROUP = "group"
CHAT_TYPE_SUPERGROUP = "supergroup"

def init_db():
conn = sqlite3.connect('users.db')
c = conn.cursor()
c.execute('''CREATE TABLE IF NOT EXISTS users
(id INTEGER PRIMARY KEY, name TEXT)''')
c.execute('''CREATE TABLE IF NOT EXISTS chats
(id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER,
role TEXT,
content TEXT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users(id))''')
conn.commit()
conn.close()

# Load user IDs from the database and update allowed_ids
# db_user_ids = load_allowed_ids_from_db()
# allowed_ids.extend([user_id for user_id in db_user_ids if user_id not in allowed_ids])

def register_user(user_id, user_name):
conn = sqlite3.connect('users.db')
c = conn.cursor()
c.execute("INSERT OR REPLACE INTO users VALUES (?, ?)", (user_id, user_name))
conn.commit()
conn.close()

def save_chat_message(user_id, role, content):
conn = sqlite3.connect('users.db')
c = conn.cursor()
c.execute("INSERT INTO chats (user_id, role, content) VALUES (?, ?, ?)",
(user_id, role, content))
conn.commit()
conn.close()


@dp.callback_query(lambda query: query.data == "register")
async def register_callback_handler(query: types.CallbackQuery):
user_id = query.from_user.id
user_name = query.from_user.full_name
register_user(user_id, user_name)
await query.answer("You have been registered successfully!")


async def get_bot_info():
global mention
Expand Down Expand Up @@ -133,6 +185,33 @@ async def about_callback_handler(query: types.CallbackQuery):
parse_mode=ParseMode.HTML,
disable_web_page_preview=True,
)

@dp.callback_query(lambda query: query.data == "list_users")
@perms_admins
async def list_users_callback_handler(query: types.CallbackQuery):
users = get_all_users_from_db()
user_kb = InlineKeyboardBuilder()
for user_id, user_name in users:
user_kb.row(types.InlineKeyboardButton(text=f"{user_name} ({user_id})", callback_data=f"remove_{user_id}"))
user_kb.row(types.InlineKeyboardButton(text="Cancel", callback_data="cancel_remove"))
await query.message.answer("Select a user to remove:", reply_markup=user_kb.as_markup())

@dp.callback_query(lambda query: query.data.startswith("remove_"))
@perms_admins
async def remove_user_from_list_handler(query: types.CallbackQuery):
user_id = int(query.data.split("_")[1])
if remove_user_from_db(user_id):
await query.answer(f"User {user_id} has been removed.")
await query.message.edit_text(f"User {user_id} has been removed.")
else:
await query.answer(f"User {user_id} not found.")

@dp.callback_query(lambda query: query.data == "cancel_remove")
@perms_admins
async def cancel_remove_handler(query: types.CallbackQuery):
await query.message.edit_text("User removal cancelled.")


@dp.message()
@perms_allowed
async def handle_message(message: types.Message):
Expand Down Expand Up @@ -254,6 +333,8 @@ async def ollama_request(message: types.Message, prompt: str = None):
if prompt is None:
prompt = message.text or message.caption

save_chat_message(message.from_user.id, "user", prompt)

await add_prompt_to_active_chats(message, prompt, image_base64, modelname)
logging.info(
f"[OllamaAPI]: Processing '{prompt}' for {message.from_user.first_name} {message.from_user.last_name}"
Expand All @@ -268,6 +349,7 @@ async def ollama_request(message: types.Message, prompt: str = None):

if any([c in chunk for c in ".\n!?"]) or response_data.get("done"):
if await handle_response(message, response_data, full_response):
save_chat_message(message.from_user.id, "assistant", full_response)
break

except Exception as e:
Expand All @@ -280,6 +362,7 @@ async def ollama_request(message: types.Message, prompt: str = None):


async def main():
init_db()
await bot.set_my_commands(commands)
await dp.start_polling(bot, skip_update=True)

Expand Down

0 comments on commit 58ce0a2

Please sign in to comment.