Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: chat_routes #1512

Merged
merged 9 commits into from
Oct 30, 2023
Empty file added backend/routes/chat/__init_.py
Empty file.
43 changes: 43 additions & 0 deletions backend/routes/chat/brainful_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from llm.qa_base import QABaseBrainPicking
from routes.authorizations.brain_authorization import validate_brain_authorization
from routes.authorizations.types import RoleEnum
from routes.chat.interface import ChatInterface

from repository.brain import get_brain_details


class BrainfulChat(ChatInterface):
def validate_authorization(self, user_id, brain_id):
if brain_id:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

le if ici n'est plus nécessaire car tu fais déjà le if dans la factory

validate_brain_authorization(
brain_id=brain_id,
user_id=user_id,
required_roles=[RoleEnum.Viewer, RoleEnum.Editor, RoleEnum.Owner],
)

def get_openai_api_key(self, brain_id, user_id):
brain_details = get_brain_details(brain_id)
if brain_details:
return brain_details.openai_api_key

def get_answer_generator(
self,
brain_id,
chat_id,
model,
max_tokens,
temperature,
user_openai_api_key,
streaming,
prompt_id,
):
return QABaseBrainPicking(
chat_id=chat_id,
model=model,
max_tokens=max_tokens,
temperature=temperature,
brain_id=brain_id,
user_openai_api_key=user_openai_api_key,
streaming=streaming,
prompt_id=prompt_id,
)
36 changes: 36 additions & 0 deletions backend/routes/chat/brainless_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from llm.qa_headless import HeadlessQA
from routes.chat.interface import ChatInterface

from repository.user_identity import get_user_identity


class BrainlessChat(ChatInterface):
def validate_authorization(self, user_id, brain_id):
pass

def get_openai_api_key(self, brain_id, user_id):
user_identity = get_user_identity(user_id)

if user_identity is not None:
return user_identity.openai_api_key

def get_answer_generator(
self,
brain_id,
chat_id,
model,
max_tokens,
temperature,
user_openai_api_key,
streaming,
prompt_id,
):
return HeadlessQA(
chat_id=chat_id,
model=model,
max_tokens=max_tokens,
temperature=temperature,
user_openai_api_key=user_openai_api_key,
streaming=streaming,
prompt_id=prompt_id,
)
11 changes: 11 additions & 0 deletions backend/routes/chat/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from uuid import UUID

from .brainful_chat import BrainfulChat
from .brainless_chat import BrainlessChat


def get_chat_strategy(brain_id: UUID | None = None):
if brain_id:
return BrainfulChat()
else:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Le else n'est pas nécessaire ici

return BrainlessChat()
25 changes: 25 additions & 0 deletions backend/routes/chat/interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod


class ChatInterface(ABC):
@abstractmethod
def validate_authorization(self, user_id, required_roles):
pass

@abstractmethod
def get_openai_api_key(self, brain_id, user_id):
pass

@abstractmethod
def get_answer_generator(
self,
brain_id,
chat_id,
model,
max_tokens,
temperature,
user_openai_api_key,
streaming,
prompt_id,
):
pass
57 changes: 57 additions & 0 deletions backend/routes/chat/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import time
from uuid import UUID

from fastapi import HTTPException
from models import UserIdentity, UserUsage
from models.databases.supabase.supabase import SupabaseDB


class NullableUUID(UUID):
@classmethod
def __get_validators__(cls):
yield cls.validate

@classmethod
def validate(cls, v) -> UUID | None:
if v == "":
return None
try:
return UUID(v)
except ValueError:
return None


def delete_chat_from_db(supabase_db: SupabaseDB, chat_id):
try:
supabase_db.delete_chat_history(chat_id)
except Exception as e:
print(e)
pass
try:
supabase_db.delete_chat(chat_id)
except Exception as e:
print(e)
pass


def check_user_requests_limit(
user: UserIdentity,
):
userDailyUsage = UserUsage(
id=user.id, email=user.email, openai_api_key=user.openai_api_key
)

userSettings = userDailyUsage.get_user_settings()

date = time.strftime("%Y%m%d")
userDailyUsage.handle_increment_user_request_count(date)

if user.openai_api_key is None:
daily_chat_credit = userSettings.get("daily_chat_credit", 0)
if int(userDailyUsage.daily_requests_count) >= int(daily_chat_credit):
raise HTTPException(
status_code=429, # pyright: ignore reportPrivateUsage=none
detail="You have reached the maximum number of requests for today.", # pyright: ignore reportPrivateUsage=none
)
else:
pass
Loading
Loading