Skip to content

Commit

Permalink
Merge pull request #384 from LlmKira/dev
Browse files Browse the repository at this point in the history
LMDB fallback && Fix bug on tool_call
  • Loading branch information
sudoskys authored Apr 18, 2024
2 parents ee2e629 + 960673a commit 1a58d1f
Show file tree
Hide file tree
Showing 20 changed files with 631 additions and 154 deletions.
12 changes: 12 additions & 0 deletions app/middleware/llm_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ def unique_function(tools: List[Tool]):
return functions


def mock_tool_message(assistant_message: AssistantMessage, mock_content: str):
_tool_message = []
if assistant_message.tool_calls:
for tool_call in assistant_message.tool_calls:
_tool_message.append(
ToolMessage(content=mock_content, tool_call_id=tool_call.id)
)
return _tool_message


async def validate_mock(messages: List[Message]):
"""
所有的具有 tool_calls 的 AssistantMessage 后面必须有对应的 ToolMessage 响应,其他消息类型按照原顺序
Expand Down Expand Up @@ -90,6 +100,8 @@ def pair_check(_messages):
else:
new_list.append(_messages[i])
new_list.append(_messages[-1])
if isinstance(_messages[-1], AssistantMessage) and _messages[-1].tool_calls:
new_list.extend(mock_tool_message(_messages[-1], "[On Queue]"))
return new_list

final_messages = pair_check(paired_messages)
Expand Down
82 changes: 49 additions & 33 deletions app/receiver/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,19 @@
from llmkira.task.schema import EventMessage, Location, Sign, Snapshot
from llmkira.task.snapshot import global_snapshot_storage, SnapData

# 记录上次调用时间的字典
TOOL_CALL_LAST_TIME = {}


def has_been_called_recently(userid, n_seconds):
current_time = time.time()
if userid in TOOL_CALL_LAST_TIME:
last_call_time = TOOL_CALL_LAST_TIME[userid]
if current_time - last_call_time <= n_seconds:
return True
TOOL_CALL_LAST_TIME[userid] = current_time
return False


async def append_snapshot(
snapshot_credential: Optional[str],
Expand Down Expand Up @@ -189,19 +202,19 @@ async def run_pending_task(task: TaskHeader, pending_task: ToolCall):
)

# Resign Chain
# 时序实现,防止过度注册
if len(task.task_sign.tool_calls_pending) == 1:
logger.debug("ToolCall run out, resign a new request to request stop sign.")
# NOTE:因为 ToolCall 破坏了递归的链式调用,所以这里不再继续调用
"""
await create_snapshot(
task=task,
tool_calls_pending_now=pending_task,
memory_able=True,
channel=task.receiver.platform,
)
"""
pass
# 运行函数, 传递模型的信息,以及上一条的结果的openai raw信息
if has_been_called_recently(userid=task.receiver.uid, n_seconds=5):
logger.debug(
"ToolCall run out, resign a new request to request stop sign."
)
await create_snapshot(
task=task,
tool_calls_pending_now=pending_task,
memory_able=True,
channel=task.receiver.platform,
)
# 运行函数, 传递模型的信息,以及上一条的结果的openai raw信息
run_result = await _tool_obj.load(
task=task,
receiver=task.receiver,
Expand All @@ -225,28 +238,31 @@ async def process_function_call(self, message: AbstractIncomingMessage):
if os.getenv("STOP_REPLY"):
logger.warning("🚫 STOP_REPLY is set in env, stop reply message")
return None
task: TaskHeader = TaskHeader.model_validate_json(
json_data=message.body.decode("utf-8")
logger.debug(
f"[552351] Received A Function Call from {message.body.decode('utf-8')}"
)
logger.debug(f"[552351] Received A Function Call from {task.receiver.platform}")
# Get Function Call
pending_task: ToolCall = await task.task_sign.get_pending_tool_call(
tool_calls_pending_now=task.task_sign.snapshot_credential,
return_default_if_empty=True,
)
if not pending_task:
return logger.debug("But No ToolCall")
logger.debug("Received A ToolCall")
try:
await self.run_pending_task(task=task, pending_task=pending_task)
except Exception as e:
await task.task_sign.complete_task(
tool_calls=pending_task, success_or_not=False, run_result=str(e)
task: TaskHeader = TaskHeader.model_validate_json(message.body.decode("utf-8"))
RUN_LIMIT = 6
while task.task_sign.tool_calls_pending and RUN_LIMIT > 0:
RUN_LIMIT -= 1
# Get Function Call
pending_task: ToolCall = await task.task_sign.get_pending_tool_call(
tool_calls_pending_now=task.task_sign.snapshot_credential,
return_default_if_empty=True,
)
logger.error(f"Function Call Error {e}")
raise e
finally:
logger.trace("Function Call Finished")
if not pending_task:
return logger.debug("But No ToolCall")
logger.debug("Received A ToolCall")
try:
await self.run_pending_task(task=task, pending_task=pending_task)
except Exception as e:
await task.task_sign.complete_task(
tool_calls=pending_task, success_or_not=False, run_result=str(e)
)
logger.error(f"Function Call Error {e}")
raise e
finally:
logger.trace("Function Call Finished")

async def on_message(self, message: AbstractIncomingMessage):
"""
Expand All @@ -257,7 +273,7 @@ async def on_message(self, message: AbstractIncomingMessage):
try:
await self.process_function_call(message=message)
except Exception as e:
logger.exception(f"Function Receiver Error {e}")
logger.exception(f"Function Receiver Error:{e}")
await message.reject(requeue=False)
raise e
else:
Expand Down
141 changes: 84 additions & 57 deletions app/receiver/receiver_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#####
# This file is not a top-level schematic file!
#####
import asyncio
import os
import time
from abc import ABCMeta, abstractmethod
Expand All @@ -31,6 +32,22 @@
from llmkira.task.snapshot import global_snapshot_storage


class UserLocks:
def __init__(self):
self.locks = {}
self.locks_lock = asyncio.Lock()

async def get(self, user_id):
async with self.locks_lock:
if user_id not in self.locks:
self.locks[user_id] = asyncio.Lock()
return self.locks[user_id]


# UserLocks 的实例
user_locks = UserLocks()


async def read_user_credential(user_id: str) -> Optional[Credential]:
user = await USER_MANAGER.read(user_id=user_id)
return user.credential
Expand Down Expand Up @@ -289,16 +306,11 @@ async def _flash(
except Exception as e:
raise e

async def deal_message(self, message) -> Tuple:
async def deal_message(self, task_head: TaskHeader) -> Tuple:
"""
:param message: 消息
:param task_head: 任务头
:return: 任务,中间件,路由类型,是否释放函数快照
"""
logger.debug(f"Received MQ Message 📩{message.message_id}")
task_head: TaskHeader = TaskHeader.model_validate_json(
json_data=message.body.decode("utf-8")
)
logger.debug(f"Received Task:{task_head.model_dump_json(indent=2)}")
router = task_head.task_sign.router
# Deliver 直接转发
if router == Router.DELIVER:
Expand All @@ -309,6 +321,11 @@ async def deal_message(self, message) -> Tuple:

tools = await reorganize_tools(task=task_head, error_times_limit=10)
"""函数组建,自动过滤拉黑后的插件和错误过多的插件"""
if task_head.task_sign.layer == 0:
task_head.task_sign.tools_ghost.extend(tools)
else:
tools.extend(task_head.task_sign.tools_ghost)
"""当首条链确定工具组成后,传递给子链使用"""
llm_middleware = OpenaiMiddleware(
task=task_head,
tools=tools,
Expand All @@ -328,7 +345,7 @@ async def deal_message(self, message) -> Tuple:
llm=llm_middleware,
task=task_head,
intercept_function=True,
disable_tool=True,
disable_tool=False,
remember=True,
)
return (
Expand Down Expand Up @@ -367,60 +384,70 @@ async def deal_message(self, message) -> Tuple:
async def on_message(self, message: AbstractIncomingMessage):
if not self.task or not self.sender:
raise ValueError("receiver not set core")
if os.getenv("STOP_REPLY"):
logger.warning("🚫 STOP_REPLY is set in env, stop reply message")
return None
logger.debug(f"Received MQ Message 📩{message.message_id}")
try:
if os.getenv("STOP_REPLY"):
logger.warning("🚫 STOP_REPLY is set in env, stop reply message")
return None
task_head: TaskHeader = TaskHeader.model_validate_json(
json_data=message.body.decode("utf-8")
)
logger.debug(f"Received Task:{task_head.model_dump_json(indent=2)}")
# 处理消息
task_head, llm, router, response_snapshot = await self.deal_message(message)
task_head: TaskHeader
logger.debug(f"Message Success:Router {router}")
# 启动链式函数应答循环
if task_head and response_snapshot:
snap_data = await global_snapshot_storage.read(
user_id=task_head.receiver.uid
async with await user_locks.get(task_head.receiver.uid):
task_head, llm, router, response_snapshot = await self.deal_message(
task_head=task_head
)
if snap_data is not None:
data = snap_data.data
renew_snap_data = []
for task in data:
if task.expire_at < int(time.time()):
logger.info(
f"🧀 Expire snapshot {task.snap_uuid} at {router}"
)
# 跳过过期的任何任务
continue
# 不是认证任务
if not task.snapshot_credential:
# 没有被处理
if not task.processed:
try:
# await asyncio.sleep(10)
logger.debug(
f"🧀 Send snapshot {task.snap_uuid} at {router}"
)
await Task.create_and_send(
queue_name=task.channel, task=task.snapshot_data
)
except Exception as e:
logger.exception(f"Response to snapshot error {e}")
logger.debug(f"Message Success:Router {router}")
# 启动链式函数应答循环
if task_head and response_snapshot:
snap_data = await global_snapshot_storage.read(
user_id=task_head.receiver.uid
)
if snap_data is not None:
data = snap_data.data
renew_snap_data = []
for task in data:
if task.expire_at < int(time.time()):
logger.info(
f"🧀 Expire snapshot {task.snap_uuid} at {router}"
)
# 跳过过期的任何任务
continue
# 不是认证任务
if not task.snapshot_credential:
# 没有被处理
if not task.processed:
try:
# await asyncio.sleep(10)
logger.debug(
f"🧀 Send snapshot {task.snap_uuid} at {router}"
)
await Task.create_and_send(
queue_name=task.channel,
task=task.snapshot_data,
)
except Exception as e:
logger.exception(
f"Response to snapshot error {e}"
)
else:
logger.info(
f"🧀 Response to snapshot {task.snap_uuid} at {router}"
)
finally:
task.processed_at = int(time.time())
# renew_snap_data.append(task)
else:
logger.info(
f"🧀 Response to snapshot {task.snap_uuid} at {router}"
)
finally:
task.processed_at = int(time.time())
# renew_snap_data.append(task)
# 被处理过的任务。不再处理
pass
else:
# 被处理过的任务。不再处理
pass
else:
# 认证任务
renew_snap_data.append(task)
snap_data.data = renew_snap_data
await global_snapshot_storage.write(
user_id=task_head.receiver.uid, snapshot=snap_data
)
# 认证任务
renew_snap_data.append(task)
snap_data.data = renew_snap_data
await global_snapshot_storage.write(
user_id=task_head.receiver.uid, snapshot=snap_data
)
except Exception as e:
logger.exception(e)
await message.reject(requeue=False)
Expand Down
8 changes: 8 additions & 0 deletions app/receiver/telegram/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ async def file_forward(self, receiver: Location, file_list: List[File]):
sticker=file_downloaded,
)
elif file_obj.file_name.endswith(".ogg"):
await self.bot.send_chat_action(
chat_id=receiver.chat_id, action="record_voice"
)
try:
await self.bot.send_voice(
chat_id=receiver.chat_id,
Expand All @@ -78,6 +81,9 @@ async def file_forward(self, receiver: Location, file_list: List[File]):
else:
raise e
else:
await self.bot.send_chat_action(
chat_id=receiver.chat_id, action="upload_document"
)
await self.bot.send_document(
chat_id=receiver.chat_id,
document=file_downloaded,
Expand Down Expand Up @@ -118,6 +124,8 @@ async def reply(
:param messages: OPENAI Format Message
:param reply_to_message: 是否回复消息
"""
if receiver.chat_id is not None:
await self.bot.send_chat_action(chat_id=receiver.chat_id, action="typing")
event_message = [
EventMessage.from_openai_message(message=item, locate=receiver)
for item in messages
Expand Down
1 change: 1 addition & 0 deletions app/sender/telegram/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ async def listen_login_command(message: types.Message):
"You can set it via `https://api.com/v1$key$model` format, "
"or you can log in via URL using `token$https://provider.com`."
),
parse_mode="MarkdownV2",
)
if len(settings) == 2:
try:
Expand Down
4 changes: 4 additions & 0 deletions docs/dev_note/time.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,7 @@ author@sudoskys
=====

不管是哪种方案,都很麻烦,在实时性强的系统里,对性能额外要求。向量匹配只能放在后台处理。

#### 投入 Openai 生态

Openai 有一个很好的方案,向量匹配,检索主题。代价是兼容性(其他模型)。
Loading

0 comments on commit 1a58d1f

Please sign in to comment.