Skip to content

feat: add speech_to_text node and text_to_speech node #1827

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion apps/application/flow/step_node/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@
from .image_generate_step_node import *

from .search_dataset_node import *
from .speech_to_text_step_node import BaseSpeechToTextNode
from .start_node import *
from .text_to_speech_step_node.impl.base_text_to_speech_node import BaseTextToSpeechNode

node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode,
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode,
BaseDocumentExtractNode,
BaseImageUnderstandNode, BaseImageGenerateNode, BaseFormNode]
BaseImageUnderstandNode, BaseFormNode, BaseSpeechToTextNode, BaseTextToSpeechNode,BaseImageGenerateNode]


def get_node(node_type):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class ApplicationNodeSerializer(serializers.Serializer):
user_input_field_list = serializers.ListField(required=False, error_messages=ErrMessage.uuid("用户输入字段"))
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片"))
document_list = serializers.ListField(required=False, error_messages=ErrMessage.list("文档"))
audio_list = serializers.ListField(required=False, error_messages=ErrMessage.list("音频"))
child_node = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict("子节点"))
node_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict("表单数据"))

Expand Down Expand Up @@ -43,19 +44,30 @@ def _run(self):
app_document_list[1:])
for document in app_document_list:
if 'file_id' not in document:
raise ValueError("参数值错误: 上传的文档中缺少file_id")
raise ValueError("参数值错误: 上传的文档中缺少file_id,文档上传失败")
app_image_list = self.node_params_serializer.data.get('image_list', [])
if app_image_list and len(app_image_list) > 0:
app_image_list = self.workflow_manage.get_reference_field(
app_image_list[0],
app_image_list[1:])
for image in app_image_list:
if 'file_id' not in image:
raise ValueError("参数值错误: 上传的图片中缺少file_id")
raise ValueError("参数值错误: 上传的图片中缺少file_id,图片上传失败")

app_audio_list = self.node_params_serializer.data.get('audio_list', [])
if app_audio_list and len(app_audio_list) > 0:
app_audio_list = self.workflow_manage.get_reference_field(
app_audio_list[0],
app_audio_list[1:])
for audio in app_audio_list:
if 'file_id' not in audio:
raise ValueError("参数值错误: 上传的图片中缺少file_id,音频上传失败")
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data,
app_document_list=app_document_list, app_image_list=app_image_list,
app_audio_list=app_audio_list,
message=str(question), **kwargs)

def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type,
app_document_list=None, app_image_list=None, child_node=None, node_data=None, **kwargs) -> NodeResult:
app_document_list=None, app_image_list=None, app_audio_list=None, child_node=None, node_data=None,
**kwargs) -> NodeResult:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def save_context(self, details, workflow_manage):
self.answer_text = details.get('answer')

def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type,
app_document_list=None, app_image_list=None, child_node=None, node_data=None,
app_document_list=None, app_image_list=None, app_audio_list=None, child_node=None, node_data=None,
**kwargs) -> NodeResult:
from application.serializers.chat_message_serializers import ChatMessageSerializer
# 生成嵌入应用的chat_id
Expand All @@ -167,6 +167,8 @@ def execute(self, application_id, message, chat_id, chat_record_id, stream, re_c
app_document_list = []
if app_image_list is None:
app_image_list = []
if app_audio_list is None:
app_audio_list = []
runtime_node_id = None
record_id = None
child_node_value = None
Expand All @@ -186,6 +188,7 @@ def execute(self, application_id, message, chat_id, chat_record_id, stream, re_c
'client_type': client_type,
'document_list': app_document_list,
'image_list': app_image_list,
'audio_list': app_audio_list,
'runtime_node_id': runtime_node_id,
'chat_record_id': record_id,
'child_node': child_node_value,
Expand Down Expand Up @@ -234,5 +237,6 @@ def get_details(self, index: int, **kwargs):
'global_fields': global_fields,
'document_list': self.workflow_manage.document_list,
'image_list': self.workflow_manage.image_list,
'audio_list': self.workflow_manage.audio_list,
'application_node_dict': self.context.get('application_node_dict')
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# coding=utf-8

from .impl import *
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# coding=utf-8

from typing import Type

from rest_framework import serializers

from application.flow.i_step_node import INode, NodeResult
from common.util.field_message import ErrMessage


class SpeechToTextNodeSerializer(serializers.Serializer):
    """Validates the parameters configured on a speech-to-text workflow node."""

    # Identifier of the speech-to-text model to use (mandatory).
    stt_model_id = serializers.CharField(
        required=True,
        error_messages=ErrMessage.char("模型id"),
    )

    # Optional flag; the label reads "whether to return the content".
    is_result = serializers.BooleanField(
        required=False,
        error_messages=ErrMessage.boolean('是否返回内容'),
    )

    # Optional list describing where the audio comes from; the node's _run
    # passes its first element and the remainder to get_reference_field.
    audio_list = serializers.ListField(
        required=False,
        error_messages=ErrMessage.list("音频"),
    )


class ISpeechToTextNode(INode):
    """Interface of the workflow node that turns referenced audio files into text.

    Subclasses implement ``execute``; ``_run`` resolves and validates the
    audio reference before delegating to it.
    """
    type = 'speech-to-text-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Return the serializer class used to validate this node's parameters."""
        return SpeechToTextNodeSerializer

    def _run(self):
        """Resolve the configured audio reference, validate it and call ``execute``.

        Returns:
            Whatever ``execute`` returns.

        Raises:
            ValueError: if no audio reference is configured, or if a resolved
                audio entry has no ``file_id`` key.
        """
        audio_reference = self.node_params_serializer.data.get('audio_list')
        # ``audio_list`` is optional in the serializer, so it may be missing or
        # empty; fail with a clear message instead of an opaque TypeError /
        # IndexError when indexing it below.
        if not audio_reference:
            raise ValueError("参数值错误: 未配置音频文件")
        # NOTE(review): presumably element 0 is the referenced node id and the
        # rest is the field path - confirm against get_reference_field.
        res = self.workflow_manage.get_reference_field(audio_reference[0], audio_reference[1:])
        for audio in res:
            if 'file_id' not in audio:
                # The message previously said "图片" (image) although this node
                # validates audio uploads.
                raise ValueError("参数值错误: 上传的音频中缺少file_id,音频上传失败")

        return self.execute(audio=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)

    def execute(self, stt_model_id, chat_id,
                audio,
                **kwargs) -> NodeResult:
        """Transcribe ``audio`` with the model ``stt_model_id``; implemented by subclasses.

        Args:
            stt_model_id: identifier of the speech-to-text model.
            chat_id: identifier of the current chat (taken from the flow params).
            audio: the resolved audio entries, each carrying a ``file_id``.
            **kwargs: remaining node and flow parameters.
        """
        pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# coding=utf-8

from .base_speech_to_text_node import BaseSpeechToTextNode
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# coding=utf-8
import os
import tempfile
import time
import io
from typing import List, Dict

from django.db.models import QuerySet
from pydub import AudioSegment
from concurrent.futures import ThreadPoolExecutor
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.speech_to_text_step_node.i_speech_to_text_node import ISpeechToTextNode
from common.util.common import split_and_transcribe
from dataset.models import File
from setting.models_provider.tools import get_model_instance_by_model_user_id


class BaseSpeechToTextNode(ISpeechToTextNode):
    """Default speech-to-text node: transcribes each referenced audio file and joins the texts."""

    def save_context(self, details, workflow_manage):
        """Restore the node's answer from previously stored ``details``."""
        answer = details.get('answer')
        self.context['answer'] = answer
        self.answer_text = answer

    def execute(self, stt_model_id, chat_id, audio, **kwargs) -> NodeResult:
        """Transcribe every entry of ``audio`` and return the joined text.

        Args:
            stt_model_id: identifier of the speech-to-text model to load.
            chat_id: identifier of the current chat (unused here).
            audio: iterable of dicts, each with a ``file_id`` key.
            **kwargs: remaining node and flow parameters (unused here).

        Returns:
            NodeResult whose ``answer`` and ``result`` hold the transcriptions
            joined by a blank line, in the order of ``audio``.

        Raises:
            ValueError: if a ``file_id`` does not match any stored file.
        """
        stt_model = get_model_instance_by_model_user_id(stt_model_id, self.flow_params_serializer.data.get('user_id'))
        self.context['audio_list'] = audio

        def process_audio_item(audio_item):
            # Transcribe a single stored audio file via a temporary file on disk.
            file = QuerySet(File).filter(id=audio_item['file_id']).first()
            if file is None:
                # ``first()`` returns None when nothing matches; report it
                # explicitly instead of failing with an AttributeError below.
                raise ValueError(f"参数值错误: 音频文件不存在,file_id: {audio_item['file_id']}")
            temp_file_path = None
            try:
                # delete=False: the file must outlive the ``with`` block because
                # split_and_transcribe is handed the path, not the open handle.
                with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as temp_file:
                    temp_file_path = temp_file.name
                    temp_file.write(file.get_byte().tobytes())
                return split_and_transcribe(temp_file_path, stt_model)
            finally:
                # Also runs when the write itself fails, so the temp file
                # never leaks.
                if temp_file_path is not None and os.path.exists(temp_file_path):
                    os.remove(temp_file_path)

        # executor.map yields results in input order, so the joined text keeps
        # the order of the uploaded audio files.
        with ThreadPoolExecutor(max_workers=5) as executor:
            results = list(executor.map(process_audio_item, audio))
        result = '\n\n'.join(results)
        return NodeResult({'answer': result, 'result': result}, {})

    def get_details(self, index: int, **kwargs):
        """Return the execution details of this node as a dict; ``index`` is echoed back."""
        return {
            'name': self.node.properties.get('stepName'),
            "index": index,
            'run_time': self.context.get('run_time'),
            'answer': self.context.get('answer'),
            'type': self.node.type,
            'status': self.status,
            'err_message': self.err_message,
            'audio_list': self.context.get('audio_list'),
        }
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def save_context(self, details, workflow_manage):
self.context['run_time'] = details.get('run_time')
self.context['document'] = details.get('document_list')
self.context['image'] = details.get('image_list')
self.context['audio'] = details.get('audio_list')
self.status = details.get('status')
self.err_message = details.get('err_message')
for key, value in workflow_variable.items():
Expand All @@ -57,7 +58,8 @@ def execute(self, question, **kwargs) -> NodeResult:
node_variable = {
'question': question,
'image': self.workflow_manage.image_list,
'document': self.workflow_manage.document_list
'document': self.workflow_manage.document_list,
'audio': self.workflow_manage.audio_list
}
return NodeResult(node_variable, workflow_variable)

Expand All @@ -80,5 +82,6 @@ def get_details(self, index: int, **kwargs):
'err_message': self.err_message,
'image_list': self.context.get('image'),
'document_list': self.context.get('document'),
'audio_list': self.context.get('audio'),
'global_fields': global_fields
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# coding=utf-8

from .impl import *
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# coding=utf-8

from typing import Type

from rest_framework import serializers

from application.flow.i_step_node import INode, NodeResult
from common.util.field_message import ErrMessage


class TextToSpeechNodeSerializer(serializers.Serializer):
    """Validates the parameters configured on a text-to-speech workflow node."""

    # Identifier of the text-to-speech model to use (mandatory).
    tts_model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))

    # Optional flag; the label reads "whether to return the content".
    is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))

    # Optional list describing where the text comes from; the node's _run
    # passes its first element and the remainder to get_reference_field.
    content_list = serializers.ListField(required=False, error_messages=ErrMessage.list("文本内容"))
    # Optional model settings. This is a DictField, so it uses the dict error
    # messages (it previously used the integer ones).
    model_params_setting = serializers.DictField(required=False,
                                                 error_messages=ErrMessage.dict("模型参数相关设置"))


class ITextToSpeechNode(INode):
    """Interface of the workflow node that turns referenced text into speech.

    Subclasses implement ``execute``; ``_run`` resolves the text reference
    before delegating to it.
    """
    type = 'text-to-speech-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Return the serializer class used to validate this node's parameters."""
        return TextToSpeechNodeSerializer

    def _run(self):
        """Resolve the configured text reference and call ``execute`` with it.

        Returns:
            Whatever ``execute`` returns.

        Raises:
            ValueError: if no text reference is configured on the node.
        """
        content_reference = self.node_params_serializer.data.get('content_list')
        # ``content_list`` is optional in the serializer, so it may be missing
        # or empty; fail with a clear message instead of an opaque TypeError /
        # IndexError when indexing it below.
        if not content_reference:
            raise ValueError("参数值错误: 未配置文本内容")
        # NOTE(review): presumably element 0 is the referenced node id and the
        # rest is the field path - confirm against get_reference_field.
        content = self.workflow_manage.get_reference_field(content_reference[0], content_reference[1:])
        return self.execute(content=content, **self.node_params_serializer.data, **self.flow_params_serializer.data)

    def execute(self, tts_model_id, chat_id,
                content, model_params_setting=None,
                **kwargs) -> NodeResult:
        """Synthesize speech for ``content``; implemented by subclasses.

        Args:
            tts_model_id: identifier of the text-to-speech model.
            chat_id: identifier of the current chat (taken from the flow params).
            content: the resolved text to convert.
            model_params_setting: optional dict of model settings.
            **kwargs: remaining node and flow parameters.
        """
        pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# coding=utf-8

from .base_text_to_speech_node import BaseTextToSpeechNode
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# coding=utf-8
import io
import mimetypes

from django.core.files.uploadedfile import InMemoryUploadedFile

from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode
from application.flow.step_node.text_to_speech_step_node.i_text_to_speech_node import ITextToSpeechNode
from dataset.models import File
from dataset.serializers.file_serializers import FileSerializer
from setting.models_provider.tools import get_model_instance_by_model_user_id


def bytes_to_uploaded_file(file_bytes, file_name="generated_audio.mp3"):
    """Wrap raw bytes in an ``InMemoryUploadedFile`` named ``file_name``.

    The content type is guessed from ``file_name``; when it cannot be guessed
    the generic ``application/octet-stream`` is used. The reported size is the
    length of ``file_bytes``.
    """
    guessed_type = mimetypes.guess_type(file_name)[0]
    # Fall back to the default binary type when the name gives no hint.
    mime_type = guessed_type if guessed_type is not None else "application/octet-stream"
    return InMemoryUploadedFile(
        file=io.BytesIO(file_bytes),  # in-memory byte stream over the payload
        field_name=None,
        name=file_name,
        content_type=mime_type,
        size=len(file_bytes),
        charset=None,
    )


class BaseTextToSpeechNode(ITextToSpeechNode):
    """Default text-to-speech node: synthesizes audio, stores it and answers with an ``<audio>`` tag."""

    def save_context(self, details, workflow_manage):
        """Restore the node's answer from previously stored ``details``."""
        answer = details.get('answer')
        self.context['answer'] = answer
        self.answer_text = answer

    def execute(self, tts_model_id, chat_id,
                content, model_params_setting=None,
                **kwargs) -> NodeResult:
        """Convert ``content`` to speech, upload the audio and return an ``<audio>`` tag.

        Args:
            tts_model_id: identifier of the text-to-speech model to load.
            chat_id: identifier of the current chat; stored in the file meta.
            content: text to synthesize.
            model_params_setting: optional dict of model settings forwarded as
                keyword arguments when loading the model; ``None`` means none.
            **kwargs: remaining node and flow parameters (unused here).

        Returns:
            NodeResult whose ``answer`` and ``result`` hold the HTML audio tag.
        """
        self.context['content'] = content
        # The parameter defaults to None and ``**None`` raises TypeError, so
        # fall back to an empty mapping.
        model = get_model_instance_by_model_user_id(tts_model_id, self.flow_params_serializer.data.get('user_id'),
                                                    **(model_params_setting or {}))
        audio_byte = model.text_to_speech(content)
        # The generated audio has to be stored in the database.
        file_name = 'generated_audio.mp3'
        file = bytes_to_uploaded_file(audio_byte, file_name)
        application = self.workflow_manage.work_flow_post_handler.chat_info.application
        meta = {
            # An application without an id is flagged as a debug run.
            'debug': not application.id,
            'chat_id': chat_id,
            'application_id': str(application.id) if application.id else None,
        }
        # NOTE(review): upload() is assumed to return the URL of the stored
        # file - confirm against FileSerializer.
        file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
        # Build an audio tag whose src points at the stored file.
        audio_label = f'<audio src="{file_url}" controls style = "width: 300px; height: 43px" class ="border-r-4"/>'
        return NodeResult({'answer': audio_label, 'result': audio_label}, {})

    def get_details(self, index: int, **kwargs):
        """Return the execution details of this node as a dict; ``index`` is echoed back."""
        return {
            'name': self.node.properties.get('stepName'),
            "index": index,
            'run_time': self.context.get('run_time'),
            'type': self.node.type,
            'status': self.status,
            'content': self.context.get('content'),
            'err_message': self.err_message,
            'answer': self.context.get('answer'),
        }
6 changes: 5 additions & 1 deletion apps/application/flow/workflow_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwa


end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node',
'image-understand-node', 'image-generate-node']
'image-understand-node', 'speech-to-text-node', 'text-to-speech-node', 'image-generate-node']


class Flow:
Expand Down Expand Up @@ -244,6 +244,7 @@ class WorkflowManage:
def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler,
base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None,
document_list=None,
audio_list=None,
start_node_id=None,
start_node_data=None, chat_record=None, child_node=None):
if form_data is None:
Expand All @@ -252,11 +253,14 @@ def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandl
image_list = []
if document_list is None:
document_list = []
if audio_list is None:
audio_list = []
self.start_node_id = start_node_id
self.start_node = None
self.form_data = form_data
self.image_list = image_list
self.document_list = document_list
self.audio_list = audio_list
self.params = params
self.flow = flow
self.lock = threading.Lock()
Expand Down
4 changes: 3 additions & 1 deletion apps/application/serializers/chat_message_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ class ChatMessageSerializer(serializers.Serializer):
form_data = serializers.DictField(required=False, error_messages=ErrMessage.char("全局变量"))
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片"))
document_list = serializers.ListField(required=False, error_messages=ErrMessage.list("文档"))
audio_list = serializers.ListField(required=False, error_messages=ErrMessage.list("音频"))
child_node = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict("子节点"))

def is_valid_application_workflow(self, *, raise_exception=False):
Expand Down Expand Up @@ -338,6 +339,7 @@ def chat_work_flow(self, chat_info: ChatInfo, base_to_response):
form_data = self.data.get('form_data')
image_list = self.data.get('image_list')
document_list = self.data.get('document_list')
audio_list = self.data.get('audio_list')
user_id = chat_info.application.user_id
chat_record_id = self.data.get('chat_record_id')
chat_record = None
Expand All @@ -354,7 +356,7 @@ def chat_work_flow(self, chat_info: ChatInfo, base_to_response):
'client_id': client_id,
'client_type': client_type,
'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type),
base_to_response, form_data, image_list, document_list,
base_to_response, form_data, image_list, document_list, audio_list,
self.data.get('runtime_node_id'),
self.data.get('node_data'), chat_record, self.data.get('child_node'))
r = work_flow_manage.run()
Expand Down
Loading
Loading