Merge branch 'chatpire:dev' into dev
maxduke authored Feb 3, 2024
2 parents c52654b + e97a5a1 commit 5b3051e
Showing 32 changed files with 1,956 additions and 576 deletions.
28 additions & 0 deletions: new Alembic migration script
@@ -0,0 +1,28 @@
"""Add source_id for BaseConversation
Revision ID: 333722b0921e
Revises: 7d94b5503088
Create Date: 2024-02-03 12:29:17.095124
"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '333722b0921e'
down_revision = '7d94b5503088'
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('conversation', sa.Column('source_id', sa.String(length=256), nullable=True, comment='对话来源id'))
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('conversation', 'source_id')
# ### end Alembic commands ###
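
As a minimal sketch, this revision can be applied through Alembic's Python command API; the alembic.ini path below is an assumption, and running alembic upgrade head from the backend directory is equivalent.

from alembic import command
from alembic.config import Config as AlembicConfig

# Assumption: alembic.ini sits in the backend directory of this repository.
alembic_cfg = AlembicConfig("backend/alembic.ini")

# Apply this revision (or pass "head"); downgrade() drops the source_id column again.
command.upgrade(alembic_cfg, "333722b0921e")
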
2 changes: 2 additions & 0 deletions backend/api/conf/config.py
@@ -68,6 +68,8 @@ class AuthSetting(BaseModel):
class OpenaiWebChatGPTSetting(BaseModel):
enabled: bool = True
is_plus_account: bool = True
enable_team_subscription: bool = False
team_account_id: Optional[str] = None
chatgpt_base_url: Optional[str] = None
proxy: Optional[str] = None
common_timeout: int = Field(20, ge=1,
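
A self-contained sketch of the two new team-subscription fields in use, based on a trimmed copy of the model above; the team_account_id value is a placeholder, not taken from the commit.

from typing import Optional

from pydantic import BaseModel, Field


class OpenaiWebChatGPTSetting(BaseModel):
    # Trimmed copy of the fields shown in this hunk.
    enabled: bool = True
    is_plus_account: bool = True
    enable_team_subscription: bool = False
    team_account_id: Optional[str] = None
    chatgpt_base_url: Optional[str] = None
    proxy: Optional[str] = None
    common_timeout: int = Field(20, ge=1)


# Enable the team subscription; the account id is illustrative only.
setting = OpenaiWebChatGPTSetting(
    enable_team_subscription=True,
    team_account_id="example-team-account-id",
)
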
1 change: 1 addition & 0 deletions backend/api/models/db.py
@@ -72,6 +72,7 @@ class BaseConversation(Base):

id: Mapped[int] = mapped_column(Integer, primary_key=True)
source: Mapped[ChatSourceTypes] = mapped_column(Enum(ChatSourceTypes), comment="对话类型")
source_id: Mapped[Optional[str]] = mapped_column(String(256), comment="对话来源id")
conversation_id: Mapped[uuid.UUID] = mapped_column(GUID, index=True, unique=True, comment="uuid")
current_model: Mapped[Optional[str]] = mapped_column(default=None, use_existing_column=True)
title: Mapped[Optional[str]] = mapped_column(comment="对话标题")
3 changes: 3 additions & 0 deletions backend/api/models/doc/__init__.py
@@ -204,6 +204,9 @@ class OpenaiWebConversationHistoryMeta(BaseModel):
source: Literal["openai_web"]
moderation_results: Optional[list[Any]] = None
plugin_ids: Optional[list[str]] = None
gizmo_id: Optional[str] = None
is_archived: Optional[bool] = None
conversation_template_id: Optional[str] = None


class OpenaiApiConversationHistoryMeta(BaseModel):
108 changes: 74 additions & 34 deletions backend/api/routers/chat.py
Expand Up @@ -10,7 +10,7 @@
from fastapi.encoders import jsonable_encoder
from fastapi_cache.decorator import cache
from httpx import HTTPError
from pydantic import ValidationError
from pydantic import ValidationError, BaseModel
from sqlalchemy import select, func, and_
from starlette.websockets import WebSocket, WebSocketState
from websockets.exceptions import ConnectionClosed
@@ -39,25 +39,39 @@
config = Config()

INSTALLED_PLUGINS_CACHE_FILE_PATH = os.path.join(config.data.data_dir, "installed_plugin_manifests.json")
INSTALLED_PLUGINS_CACHE_EXPIRE = 3600 * 24
INSTALLED_PLUGINS_TEAM_CACHE_FILE_PATH = os.path.join(config.data.data_dir, "installed_plugin_manifests_team.json")
CACHE_EXPIRE_DURATION = 3600 * 24

_installed_plugins: OpenaiChatPluginListResponse | None = None
_installed_plugins_map: dict[str, OpenaiChatPlugin] | None = None
_installed_plugins_last_update_time = None

# TODO: improve plugin cache handling and isolate plugins from different sources

def _load_installed_plugins_from_cache():
global _installed_plugins, _installed_plugins_map, _installed_plugins_last_update_time
if os.path.exists(INSTALLED_PLUGINS_CACHE_FILE_PATH):
with open(INSTALLED_PLUGINS_CACHE_FILE_PATH, "r") as f:
data = json.load(f)
_installed_plugins = OpenaiChatPluginListResponse.model_validate(data["installed_plugins"])
_installed_plugins_map = {plugin.id: plugin for plugin in _installed_plugins.items}
_installed_plugins_last_update_time = data["installed_plugins_last_update_time"]
class PluginsCache(BaseModel):
response: Optional[OpenaiChatPluginListResponse] = None
map: Optional[dict[str, OpenaiChatPlugin]] = None
last_update_time: Optional[float] = None


_cache_by_use_team = {
False: PluginsCache(),
True: PluginsCache()
}


def _save_installed_plugins_to_cache(installed_plugins, installed_plugins_last_update_time):
with open(INSTALLED_PLUGINS_CACHE_FILE_PATH, "w") as f:
def _load_installed_plugins_from_cache():
global _cache_by_use_team
for use_team in [False, True]:
_cache = _cache_by_use_team[use_team]
path = INSTALLED_PLUGINS_CACHE_FILE_PATH if not use_team else INSTALLED_PLUGINS_TEAM_CACHE_FILE_PATH
if os.path.exists(path):
with open(path, "r") as f:
data = json.load(f)
_cache.response = OpenaiChatPluginListResponse.model_validate(data["installed_plugins"])
_cache.map = {plugin.id: plugin for plugin in _cache.response.items}
_cache.last_update_time = data["installed_plugins_last_update_time"]


def _save_installed_plugins_to_cache(installed_plugins, installed_plugins_last_update_time, dest_path: str):
with open(dest_path, "w") as f:
json.dump(jsonable_encoder({
"installed_plugins": installed_plugins,
"installed_plugins_last_update_time": installed_plugins_last_update_time
@@ -67,34 +81,41 @@ def _save_installed_plugins_to_cache(installed_plugins, installed_plugins_last_u
_load_installed_plugins_from_cache()


async def _refresh_installed_plugins():
global _installed_plugins, _installed_plugins_map, _installed_plugins_last_update_time
if _installed_plugins is None or time.time() - _installed_plugins_last_update_time > INSTALLED_PLUGINS_CACHE_EXPIRE:
_installed_plugins = await openai_web_manager.get_installed_plugin_manifests()
_installed_plugins_map = {plugin.id: plugin for plugin in _installed_plugins.items}
_installed_plugins_last_update_time = time.time()
_save_installed_plugins_to_cache(_installed_plugins, _installed_plugins_last_update_time)
return _installed_plugins
async def _refresh_installed_plugins(use_team: bool = False):
global _cache_by_use_team

_cache = _cache_by_use_team[use_team]
if _cache.response is None or time.time() - _cache.last_update_time > CACHE_EXPIRE_DURATION:
_cache.response = await openai_web_manager.get_installed_plugin_manifests(use_team=use_team)
_cache.map = {plugin.id: plugin for plugin in _cache.response.items}
_cache.last_update_time = time.time()
_save_installed_plugins_to_cache(_cache.response, _cache.last_update_time,
INSTALLED_PLUGINS_TEAM_CACHE_FILE_PATH if use_team else INSTALLED_PLUGINS_CACHE_FILE_PATH)

return _cache.response


@router.get("/chat/openai-plugins", tags=["chat"], response_model=OpenaiChatPluginListResponse)
@cache(expire=60 * 60 * 24)
async def get_openai_web_chat_plugins(offset: int = 0, limit: int = 0, category: str = "", search: str = "",
_user: User = Depends(current_active_user)):
plugins = await openai_web_manager.get_plugin_manifests(offset, limit, category, search)
user: User = Depends(current_active_user)):
plugins = await openai_web_manager.get_plugin_manifests(offset, limit, category, search,
user.setting.openai_web.use_team)
return plugins


@router.get("/chat/openai-plugins/installed", tags=["chat"], response_model=OpenaiChatPluginListResponse)
async def get_installed_openai_web_chat_plugins(_user: User = Depends(current_active_user)):
async def get_installed_openai_web_chat_plugins(user: User = Depends(current_active_user)):
plugins = await _refresh_installed_plugins()
return plugins


@router.get("/chat/openai-plugins/installed/{plugin_id}", tags=["chat"], response_model=OpenaiChatPlugin)
async def get_installed_openai_web_plugin(plugin_id: str, _user: User = Depends(current_active_user)):
await _refresh_installed_plugins()
global _installed_plugins_map
async def get_installed_openai_web_plugin(plugin_id: str, user: User = Depends(current_active_user)):
use_team = user.setting.openai_web.use_team
await _refresh_installed_plugins(use_team)
global _cache_by_use_team
_installed_plugins_map = _cache_by_use_team[use_team].map
if plugin_id in _installed_plugins_map:
return _installed_plugins_map[plugin_id]
else:
@@ -103,12 +124,13 @@ async def get_installed_openai_web_plugin(plugin_id: str, _user: User = Depends(

@router.patch("/chat/openai-plugins/{plugin_id}/user-settings", tags=["chat"], response_model=OpenaiChatPlugin)
async def update_chat_plugin_user_settings(plugin_id: str, settings: OpenaiChatPluginUserSettings,
use_team: Optional[bool] = config.openai_web.enable_team_subscription,
_user: User = Depends(current_super_user)):
if settings.is_authenticated is not None:
raise InvalidParamsException("can not set is_authenticated")
result = await openai_web_manager.change_plugin_user_settings(plugin_id, settings)
result = await openai_web_manager.change_plugin_user_settings(plugin_id, settings, use_team)
assert isinstance(result, OpenaiChatPlugin)
await _refresh_installed_plugins()
await _refresh_installed_plugins(use_team)
return result


@@ -245,6 +267,8 @@ async def reply(response: AskResponse):

params = await websocket.receive_json()

use_team = user.setting.openai_web.use_team and config.openai_web.enable_team_subscription

try:
ask_request = AskRequest.model_validate(params)
except ValidationError as e:
@@ -269,6 +293,13 @@ async def reply(response: AskResponse):
conversation_id = ask_request.conversation_id
conversation = await _get_conversation_by_id(ask_request.conversation_id, user_db)

# Check whether the team conversation is allowed
if conversation is not None and conversation.source_id is not None and use_team == False:
e = WebsocketException(1008, "errors.teamConversationNotAllowed")
await reply(AskResponse(type=AskResponseType.error, tip=e.tip, error_detail=e.error_detail))
await websocket.close(e.code, e.tip)
return

request_start_time = datetime.now()

websocket_code = 1001
@@ -282,6 +313,7 @@ async def reply(response: AskResponse):
queueing_end_time = None

# Queueing
# TODO: make queueing optional
if ask_request.source == ChatSourceTypes.openai_web:
if openai_web_manager.is_busy():
await reply(AskResponse(
@@ -323,10 +355,11 @@ async def reply(response: AskResponse):
model = OpenaiApiChatModels(ask_request.model)

# streamed transfer
async for data in manager.complete(text_content=ask_request.text_content,
async for data in manager.complete(model=model,
text_content=ask_request.text_content,
use_team=use_team,
conversation_id=ask_request.conversation_id,
parent_message_id=ask_request.parent,
model=model,
plugin_ids=ask_request.openai_web_plugin_ids if ask_request.new_conversation else None,
attachments=ask_request.openai_web_attachments,
multimodal_image_parts=ask_request.openai_web_multimodal_image_parts,
@@ -370,6 +403,7 @@ async def reply(response: AskResponse):
except OpenaiException as e:
logger.error(str(e))
error_detail_map = {
400: "errors.openai.400",
401: "errors.openai.401",
403: "errors.openai.403",
404: "errors.openai.404",
@@ -513,7 +547,11 @@ async def reply(response: AskResponse):
if ask_request.source == ChatSourceTypes.openai_web and ask_request.new_title is not None and \
ask_request.new_title.strip() != "":
try:
await openai_web_manager.set_conversation_title(str(conversation_id), ask_request.new_title)
source_id = None
if use_team:
source_id = config.openai_web.team_account_id
await openai_web_manager.set_conversation_title(str(conversation_id), ask_request.new_title,
source_id=source_id)
except Exception as e:
logger.warning(f"set_conversation_title error {e.__class__.__name__}: {str(e)}")

@@ -528,6 +566,8 @@ async def reply(response: AskResponse):
create_time=current_time,
update_time=current_time
)
if use_team:
new_conv.source_id = config.openai_web.team_account_id
conversation = BaseConversation(**new_conv.model_dump(exclude_unset=True))
session.add(conversation)

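
The chat.py hunks above replace the single module-level plugin cache with one PluginsCache per source, keyed by use_team. A stripped-down, synchronous sketch of that pattern follows, with simplified types and a hypothetical fetch callback; these names are not from the repository.

import time
from typing import Callable, Optional

from pydantic import BaseModel

CACHE_EXPIRE_DURATION = 3600 * 24  # one day, matching the diff


class PluginsCache(BaseModel):
    # Simplified: the real code stores an OpenaiChatPluginListResponse and an id -> plugin map.
    response: Optional[list[str]] = None
    last_update_time: Optional[float] = None


_cache_by_use_team = {False: PluginsCache(), True: PluginsCache()}


def get_installed_plugins(use_team: bool, fetch: Callable[[bool], list[str]]) -> list[str]:
    """Return the cached plugin list for the given source, refreshing it when stale."""
    cache = _cache_by_use_team[use_team]
    if cache.response is None or time.time() - cache.last_update_time > CACHE_EXPIRE_DURATION:
        cache.response = fetch(use_team)  # hypothetical stand-in for openai_web_manager
        cache.last_update_time = time.time()
    return cache.response
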
39 changes: 26 additions & 13 deletions backend/api/routers/conv.py
@@ -67,10 +67,12 @@ async def get_all_conversations(_user: User = Depends(current_super_user), valid

@router.get("/conv/{conversation_id}", tags=["conversation"],
response_model=OpenaiApiConversationHistoryDocument | OpenaiWebConversationHistoryDocument | BaseConversationHistory)
async def get_conversation_history(conversation: BaseConversation = Depends(_get_conversation_by_id)):
async def get_conversation_history(conversation: BaseConversation = Depends(_get_conversation_by_id),
user: User = Depends(current_active_user)):
if conversation.source == ChatSourceTypes.openai_web:
try:
result = await openai_web_manager.get_conversation_history(conversation.conversation_id)
result = await openai_web_manager.get_conversation_history(conversation.conversation_id,
conversation.source_id)
if result.current_model != conversation.current_model or not conversation.is_valid:
async with get_async_session_context() as session:
conversation = await session.get(BaseConversation, conversation.id)
@@ -115,15 +117,16 @@ async def get_conversation_history_from_cache(conversation_id, user: User = Depe


@router.delete("/conv/{conversation_id}", tags=["conversation"])
async def delete_conversation(conversation: BaseConversation = Depends(_get_conversation_by_id)):
async def delete_conversation(conversation: BaseConversation = Depends(_get_conversation_by_id),
user: User = Depends(current_active_user)):
"""
Soft delete: mark the conversation as invalid and remove it from the chatgpt account, without deleting the history records in mongodb
"""
if not conversation.is_valid:
raise InvalidParamsException("errors.conversationAlreadyDeleted")
if conversation.source == ChatSourceTypes.openai_web:
try:
await openai_web_manager.delete_conversation(conversation.conversation_id)
await openai_web_manager.delete_conversation(conversation.conversation_id, conversation.source_id)
except OpenaiWebException as e:
logger.warning(f"delete conversation {conversation.conversation_id} failed: {e.code} {e.message}")
except httpx.HTTPStatusError as e:
@@ -143,7 +146,7 @@ async def vanish_conversation(conversation: BaseConversation = Depends(_get_conv
Hard delete: remove the conversation and its history records from both the database and the account
"""
if conversation.is_valid:
await delete_conversation(conversation)
await delete_conversation(conversation, conversation.source_id)
if conversation.source == ChatSourceTypes.openai_web:
doc = await OpenaiWebConversationHistoryDocument.get(conversation.conversation_id)
else: # api
@@ -158,10 +161,11 @@ async def vanish_conversation(conversation: BaseConversation = Depends(_get_conv


@router.patch("/conv/{conversation_id}", tags=["conversation"], response_model=BaseConversationSchema)
async def update_conversation_title(title: str, conversation: BaseConversation = Depends(_get_conversation_by_id)):
async def update_conversation_title(title: str, conversation: BaseConversation = Depends(_get_conversation_by_id),
user: User = Depends(current_active_user)):
if conversation.source == ChatSourceTypes.openai_web:
await openai_web_manager.set_conversation_title(conversation.conversation_id,
title)
title, conversation.source_id)
else: # api
doc = await OpenaiApiConversationHistoryDocument.get(conversation.conversation_id)
if doc is None:
Expand Down Expand Up @@ -194,6 +198,7 @@ async def assign_conversation(username: str, conversation: BaseConversation = De
@router.delete("/conv", tags=["conversation"])
async def delete_all_conversation(_user: User = Depends(current_super_user)):
await openai_web_manager.clear_conversations()
await openai_web_manager.clear_conversations(use_team=True)
async with get_async_session_context() as session:
await session.execute(delete(OpenaiWebConversation))
await session.commit()
@@ -202,9 +207,11 @@ async def delete_all_conversation(_user: User = Depends(current_super_user)):

@router.patch("/conv/{conversation_id}/gen_title", tags=["conversation"], response_model=str)
async def generate_conversation_title(message_id: str,
conversation: OpenaiWebConversation = Depends(_get_conversation_by_id)):
conversation: OpenaiWebConversation = Depends(_get_conversation_by_id),
_user: User = Depends(current_active_user)):
async with get_async_session_context() as session:
title = await openai_web_manager.generate_conversation_title(conversation.conversation_id, message_id)
title = await openai_web_manager.generate_conversation_title(conversation.conversation_id, message_id,
conversation.source_id)
if title:
conversation.title = title
session.add(conversation)
@@ -216,14 +223,20 @@ async def generate_conversation_title(message_id: str,


@router.get("/conv/{conversation_id}/interpreter", tags=["conversation"], response_model=OpenaiChatInterpreterInfo)
async def get_conversation_interpreter_info(conversation_id: str):
url = await openai_web_manager.get_interpreter_info(conversation_id)
async def get_conversation_interpreter_info(conversation: OpenaiWebConversation = Depends(_get_conversation_by_id),
_user: User = Depends(current_active_user)):
url = await openai_web_manager.get_interpreter_info(conversation.conversation_id, conversation.source_id)
return response(200, result=url)


@router.get("/conv/{conversation_id}/interpreter/download-url", tags=["conversation"])
async def get_conversation_interpreter_download_url(conversation_id: str, message_id: str, sandbox_path: str):
async def get_conversation_interpreter_download_url(message_id: str, sandbox_path: str,
conversation: OpenaiWebConversation = Depends(
_get_conversation_by_id),
_user: User = Depends(current_active_user)):
if message_id is None or sandbox_path is None:
raise InvalidParamsException("message_id and sandbox_path are required")
url = await openai_web_manager.get_interpreter_file_download_url(conversation_id, message_id, sandbox_path)
url = await openai_web_manager.get_interpreter_file_download_url(conversation.conversation_id, message_id,
sandbox_path,
conversation.source_id)
return response(200, result=url)