Skip to content

Pr@main@feat is return #921

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions apps/application/flow/i_step_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from abc import abstractmethod
from typing import Type, Dict, List

from django.core import cache
from django.db.models import QuerySet
from rest_framework import serializers

Expand All @@ -18,7 +19,6 @@
from common.constants.authentication_type import AuthenticationType
from common.field.common import InstanceField
from common.util.field_message import ErrMessage
from django.core import cache

chat_cache = cache.caches['chat_cache']

Expand All @@ -27,9 +27,13 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
if step_variable is not None:
for key in step_variable:
node.context[key] = step_variable[key]
if workflow.is_result() and 'answer' in step_variable:
yield step_variable['answer']
workflow.answer += step_variable['answer']
if global_variable is not None:
for key in global_variable:
workflow.context[key] = global_variable[key]
node.context['run_time'] = time.time() - node.context['start_time']


class WorkFlowPostHandler:
Expand Down Expand Up @@ -70,18 +74,14 @@ def handler(self, chat_id,


class NodeResult:
def __init__(self, node_variable: Dict, workflow_variable: Dict, _to_response=None, _write_context=write_context):
def __init__(self, node_variable: Dict, workflow_variable: Dict,
_write_context=write_context):
self._write_context = _write_context
self.node_variable = node_variable
self.workflow_variable = workflow_variable
self._to_response = _to_response

def write_context(self, node, workflow):
self._write_context(self.node_variable, self.workflow_variable, node, workflow)

def to_response(self, chat_id, chat_record_id, node, workflow, post_handler: WorkFlowPostHandler):
return self._to_response(chat_id, chat_record_id, self.node_variable, self.workflow_variable, node, workflow,
post_handler)
return self._write_context(self.node_variable, self.workflow_variable, node, workflow)

def is_assertion_result(self):
return 'branch_id' in self.node_variable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ class ChatNodeSerializer(serializers.Serializer):
# 多轮对话数量
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))

is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))


class IChatNode(INode):
type = 'ai-chat-node'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,25 @@
from langchain.schema import HumanMessage, SystemMessage
from langchain_core.messages import BaseMessage

from application.flow import tools
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
from setting.models_provider.tools import get_model_instance_by_model_user_id


def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
chat_model = node_variable.get('chat_model')
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
answer_tokens = chat_model.get_num_tokens(answer)
node.context['message_tokens'] = message_tokens
node.context['answer_tokens'] = answer_tokens
node.context['answer'] = answer
node.context['history_message'] = node_variable['history_message']
node.context['question'] = node_variable['question']
node.context['run_time'] = time.time() - node.context['start_time']
if workflow.is_result():
workflow.answer += answer


def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
"""
写入上下文数据 (流式)
Expand All @@ -31,15 +44,8 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
answer = ''
for chunk in response:
answer += chunk.content
chat_model = node_variable.get('chat_model')
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
answer_tokens = chat_model.get_num_tokens(answer)
node.context['message_tokens'] = message_tokens
node.context['answer_tokens'] = answer_tokens
node.context['answer'] = answer
node.context['history_message'] = node_variable['history_message']
node.context['question'] = node_variable['question']
node.context['run_time'] = time.time() - node.context['start_time']
yield answer
_write_context(node_variable, workflow_variable, node, workflow, answer)


def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
Expand All @@ -51,71 +57,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
@param workflow: 工作流管理器
"""
response = node_variable.get('result')
chat_model = node_variable.get('chat_model')
answer = response.content
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
answer_tokens = chat_model.get_num_tokens(answer)
node.context['message_tokens'] = message_tokens
node.context['answer_tokens'] = answer_tokens
node.context['answer'] = answer
node.context['history_message'] = node_variable['history_message']
node.context['question'] = node_variable['question']


def get_to_response_write_context(node_variable: Dict, node: INode):
def _write_context(answer, status=200):
chat_model = node_variable.get('chat_model')

if status == 200:
answer_tokens = chat_model.get_num_tokens(answer)
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
else:
answer_tokens = 0
message_tokens = 0
node.err_message = answer
node.status = status
node.context['message_tokens'] = message_tokens
node.context['answer_tokens'] = answer_tokens
node.context['answer'] = answer
node.context['run_time'] = time.time() - node.context['start_time']

return _write_context


def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
post_handler):
"""
将流式数据 转换为 流式响应
@param chat_id: 会话id
@param chat_record_id: 对话记录id
@param node_variable: 节点数据
@param workflow_variable: 工作流数据
@param node: 节点
@param workflow: 工作流管理器
@param post_handler: 后置处理器 输出结果后执行
@return: 流式响应
"""
response = node_variable.get('result')
_write_context = get_to_response_write_context(node_variable, node)
return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)


def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
post_handler):
"""
将结果转换
@param chat_id: 会话id
@param chat_record_id: 对话记录id
@param node_variable: 节点数据
@param workflow_variable: 工作流数据
@param node: 节点
@param workflow: 工作流管理器
@param post_handler: 后置处理器
@return: 响应
"""
response = node_variable.get('result')
_write_context = get_to_response_write_context(node_variable, node)
return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
_write_context(node_variable, workflow_variable, node, workflow, answer)


class BaseChatNode(IChatNode):
Expand All @@ -132,13 +75,12 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
r = chat_model.stream(message_list)
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
'history_message': history_message, 'question': question.content}, {},
_write_context=write_context_stream,
_to_response=to_stream_response)
_write_context=write_context_stream)
else:
r = chat_model.invoke(message_list)
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
'history_message': history_message, 'question': question.content}, {},
_write_context=write_context, _to_response=to_response)
_write_context=write_context)

@staticmethod
def get_history_message(history_chat_record, dialogue_number):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class ReplyNodeParamsSerializer(serializers.Serializer):
fields = serializers.ListField(required=False, error_messages=ErrMessage.list("引用字段"))
content = serializers.CharField(required=False, allow_blank=True, allow_null=True,
error_messages=ErrMessage.char("直接回答内容"))
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))

def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,69 +6,19 @@
@date:2024/6/11 17:25
@desc:
"""
from typing import List, Dict
from typing import List

from langchain_core.messages import AIMessage, AIMessageChunk

from application.flow import tools
from application.flow.i_step_node import NodeResult, INode
from application.flow.i_step_node import NodeResult
from application.flow.step_node.direct_reply_node.i_reply_node import IReplyNode


def get_to_response_write_context(node_variable: Dict, node: INode):
def _write_context(answer, status=200):
node.context['answer'] = answer

return _write_context


def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
post_handler):
"""
将流式数据 转换为 流式响应
@param chat_id: 会话id
@param chat_record_id: 对话记录id
@param node_variable: 节点数据
@param workflow_variable: 工作流数据
@param node: 节点
@param workflow: 工作流管理器
@param post_handler: 后置处理器 输出结果后执行
@return: 流式响应
"""
response = node_variable.get('result')
_write_context = get_to_response_write_context(node_variable, node)
return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)


def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
post_handler):
"""
将结果转换
@param chat_id: 会话id
@param chat_record_id: 对话记录id
@param node_variable: 节点数据
@param workflow_variable: 工作流数据
@param node: 节点
@param workflow: 工作流管理器
@param post_handler: 后置处理器
@return: 响应
"""
response = node_variable.get('result')
_write_context = get_to_response_write_context(node_variable, node)
return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)


class BaseReplyNode(IReplyNode):
def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
if reply_type == 'referencing':
result = self.get_reference_content(fields)
else:
result = self.generate_reply_content(content)
if stream:
return NodeResult({'result': iter([AIMessageChunk(content=result)]), 'answer': result}, {},
_to_response=to_stream_response)
else:
return NodeResult({'result': AIMessage(content=result), 'answer': result}, {}, _to_response=to_response)
return NodeResult({'answer': result}, {})

def generate_reply_content(self, prompt):
return self.workflow_manage.generate_prompt(prompt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ class QuestionNodeSerializer(serializers.Serializer):
# 多轮对话数量
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))

is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))


class IQuestionNode(INode):
type = 'question-node'
Expand Down
Loading
Loading