Skip to content

Commit 3faba75

Browse files
committed
feat: 支持重排模型
1 parent fcbfd8a commit 3faba75

File tree

20 files changed

+983
-6
lines changed

20 files changed

+983
-6
lines changed

apps/application/flow/step_node/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414
from .direct_reply_node import *
1515
from .function_lib_node import *
1616
from .function_node import *
17+
from .reranker_node import *
1718

1819
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode,
19-
BaseFunctionNodeNode, BaseFunctionLibNodeNode]
20+
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode]
2021

2122

2223
def get_node(node_type):
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: __init__.py
6+
@date:2024/9/4 11:37
7+
@desc:
8+
"""
9+
from .impl import *
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: i_reranker_node.py
6+
@date:2024/9/4 10:40
7+
@desc:
8+
"""
9+
from typing import Type
10+
11+
from rest_framework import serializers
12+
13+
from application.flow.i_step_node import INode, NodeResult
14+
from common.util.field_message import ErrMessage
15+
16+
17+
class RerankerSettingSerializer(serializers.Serializer):
18+
# 需要查询的条数
19+
top_n = serializers.IntegerField(required=True,
20+
error_messages=ErrMessage.integer("引用分段数"))
21+
# 相似度 0-1之间
22+
similarity = serializers.FloatField(required=True, max_value=2, min_value=0,
23+
error_messages=ErrMessage.float("引用分段数"))
24+
max_paragraph_char_number = serializers.IntegerField(required=True,
25+
error_messages=ErrMessage.float("最大引用分段字数"))
26+
27+
28+
class RerankerStepNodeSerializer(serializers.Serializer):
29+
reranker_setting = RerankerSettingSerializer(required=True)
30+
31+
question_reference_address = serializers.ListField(required=True)
32+
reranker_model_id = serializers.UUIDField(required=True)
33+
reranker_reference_list = serializers.ListField(required=True, child=serializers.ListField(required=True))
34+
35+
def is_valid(self, *, raise_exception=False):
36+
super().is_valid(raise_exception=True)
37+
38+
39+
class IRerankerNode(INode):
40+
type = 'reranker-node'
41+
42+
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
43+
return RerankerStepNodeSerializer
44+
45+
def _run(self):
46+
question = self.workflow_manage.get_reference_field(
47+
self.node_params_serializer.data.get('question_reference_address')[0],
48+
self.node_params_serializer.data.get('question_reference_address')[1:])
49+
reranker_list = [self.workflow_manage.get_reference_field(
50+
reference[0],
51+
reference[1:]) for reference in
52+
self.node_params_serializer.data.get('reranker_reference_list')]
53+
return self.execute(**self.node_params_serializer.data, question=str(question),
54+
55+
reranker_list=reranker_list)
56+
57+
def execute(self, question, reranker_setting, reranker_list, reranker_model_id,
58+
**kwargs) -> NodeResult:
59+
pass
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: __init__.py
6+
@date:2024/9/4 11:39
7+
@desc:
8+
"""
9+
from .base_reranker_node import *
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: base_reranker_node.py
6+
@date:2024/9/4 11:41
7+
@desc:
8+
"""
9+
from typing import List
10+
11+
from langchain_core.documents import Document
12+
13+
from application.flow.i_step_node import NodeResult
14+
from application.flow.step_node.reranker_node.i_reranker_node import IRerankerNode
15+
from setting.models_provider.tools import get_model_instance_by_model_user_id
16+
17+
18+
def merge_reranker_list(reranker_list, result=None):
19+
if result is None:
20+
result = []
21+
for document in reranker_list:
22+
if isinstance(document, list):
23+
merge_reranker_list(document, result)
24+
elif isinstance(document, dict):
25+
content = document.get('title', '') + document.get('content', '')
26+
result.append(str(document) if len(content) == 0 else content)
27+
else:
28+
result.append(str(document))
29+
return result
30+
31+
32+
def filter_result(document_list: List[Document], max_paragraph_char_number, top_n, similarity):
33+
use_len = 0
34+
result = []
35+
for index in range(len(document_list)):
36+
document = document_list[index]
37+
if use_len >= max_paragraph_char_number or index >= top_n or document.metadata.get(
38+
'relevance_score') < similarity:
39+
break
40+
content = document.page_content[0:max_paragraph_char_number - use_len]
41+
use_len = use_len + len(content)
42+
result.append({'page_content': content, 'metadata': document.metadata})
43+
return result
44+
45+
46+
class BaseRerankerNode(IRerankerNode):
47+
def execute(self, question, reranker_setting, reranker_list, reranker_model_id,
48+
**kwargs) -> NodeResult:
49+
documents = merge_reranker_list(reranker_list)
50+
reranker_model = get_model_instance_by_model_user_id(reranker_model_id,
51+
self.flow_params_serializer.data.get('user_id'))
52+
result = reranker_model.compress_documents(
53+
[Document(page_content=document) for document in documents if document is not None and len(document) > 0],
54+
question)
55+
top_n = reranker_setting.get('top_n', 3)
56+
similarity = reranker_setting.get('similarity', 0.6)
57+
max_paragraph_char_number = reranker_setting.get('max_paragraph_char_number', 5000)
58+
r = filter_result(result, max_paragraph_char_number, top_n, similarity)
59+
return NodeResult({'result_list': r, 'result': ''.join([item.get('page_content') for item in r])}, {})
60+
61+
def get_details(self, index: int, **kwargs):
62+
return {
63+
'name': self.node.properties.get('stepName'),
64+
"index": index,
65+
"question": self.node_params_serializer.data.get('question'),
66+
'run_time': self.context.get('run_time'),
67+
'type': self.node.type,
68+
'reranker_setting': self.node_params_serializer.data.get('reranker_setting'),
69+
'result_list': self.context.get('result_list'),
70+
'result': self.context.get('result'),
71+
'status': self.status,
72+
'err_message': self.err_message
73+
}

apps/setting/models_provider/base_model_provider.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ class ModelTypeConst(Enum):
141141
EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'}
142142
STT = {'code': 'STT', 'message': '语音识别'}
143143
TTS = {'code': 'TTS', 'message': '语音合成'}
144+
RERANKER = {'code': 'RERANKER', 'message': '重排模型'}
144145

145146

146147
class ModelInfo:
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: reranker.py
6+
@date:2024/9/3 14:33
7+
@desc:
8+
"""
9+
from typing import Dict
10+
11+
from langchain_core.documents import Document
12+
13+
from common import forms
14+
from common.exception.app_exception import AppApiException
15+
from common.forms import BaseForm
16+
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
17+
from setting.models_provider.impl.local_model_provider.model.reranker import LocalBaseReranker
18+
19+
20+
class LocalRerankerCredential(BaseForm, BaseModelCredential):
21+
22+
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
23+
raise_exception=False):
24+
if not model_type == 'RERANKER':
25+
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
26+
for key in ['cache_dir']:
27+
if key not in model_credential:
28+
if raise_exception:
29+
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
30+
else:
31+
return False
32+
try:
33+
model: LocalBaseReranker = provider.get_model(model_type, model_name, model_credential)
34+
model.compress_documents([Document(page_content='你好')], '你好')
35+
except Exception as e:
36+
if isinstance(e, AppApiException):
37+
raise e
38+
if raise_exception:
39+
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
40+
else:
41+
return False
42+
return True
43+
44+
def encryption_dict(self, model: Dict[str, object]):
45+
return model
46+
47+
cache_dir = forms.TextInputField('模型目录', required=True)

apps/setting/models_provider/impl/local_model_provider/local_model_provider.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,20 @@
1616
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
1717
ModelInfoManage
1818
from setting.models_provider.impl.local_model_provider.credential.embedding import LocalEmbeddingCredential
19+
from setting.models_provider.impl.local_model_provider.credential.reranker import LocalRerankerCredential
1920
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
21+
from setting.models_provider.impl.local_model_provider.model.reranker import LocalReranker
2022
from smartdoc.conf import PROJECT_DIR
2123

2224
embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', '', ModelTypeConst.EMBEDDING,
2325
LocalEmbeddingCredential(), LocalEmbedding)
26+
bge_reranker_v2_m3 = ModelInfo('BAAI/bge-reranker-v2-m3', '', ModelTypeConst.RERANKER,
27+
LocalRerankerCredential(), LocalReranker)
2428

2529
model_info_manage = (ModelInfoManage.builder().append_model_info(embedding_text2vec_base_chinese)
2630
.append_default_model_info(embedding_text2vec_base_chinese)
31+
.append_model_info(bge_reranker_v2_m3)
32+
.append_default_model_info(bge_reranker_v2_m3)
2733
.build())
2834

2935

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: reranker.py.py
6+
@date:2024/9/2 16:42
7+
@desc:
8+
"""
9+
from typing import Sequence, Optional, Dict, Any
10+
11+
import requests
12+
import torch
13+
from langchain_core.callbacks import Callbacks
14+
from langchain_core.documents import BaseDocumentCompressor, Document
15+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
16+
17+
from setting.models_provider.base_model_provider import MaxKBBaseModel
18+
from smartdoc.const import CONFIG
19+
20+
21+
class LocalReranker(MaxKBBaseModel):
22+
def __init__(self, model_name, top_n=3, cache_dir=None):
23+
super().__init__()
24+
self.model_name = model_name
25+
self.cache_dir = cache_dir
26+
self.top_n = top_n
27+
28+
@staticmethod
29+
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
30+
if model_kwargs.get('use_local', True):
31+
return LocalBaseReranker(model_name=model_name, cache_dir=model_credential.get('cache_dir'),
32+
model_kwargs={'device': model_credential.get('device', 'cpu')}
33+
34+
)
35+
return WebLocalBaseReranker(model_name=model_name, cache_dir=model_credential.get('cache_dir'),
36+
model_kwargs={'device': model_credential.get('device')},
37+
**model_kwargs)
38+
39+
40+
class WebLocalBaseReranker(MaxKBBaseModel, BaseDocumentCompressor):
41+
@staticmethod
42+
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
43+
pass
44+
45+
model_id: str = None
46+
47+
def __init__(self, **kwargs):
48+
super().__init__(**kwargs)
49+
self.model_id = kwargs.get('model_id', None)
50+
51+
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
52+
Sequence[Document]:
53+
bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
54+
res = requests.post(
55+
f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}/api/model/{self.model_id}/compress_documents',
56+
json={'documents': [{'page_content': document.page_content, 'metadata': document.metadata} for document in
57+
documents], 'query': query}, headers={'Content-Type': 'application/json'})
58+
result = res.json()
59+
if result.get('code', 500) == 200:
60+
return [Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document
61+
in result.get('data')]
62+
raise Exception(result.get('msg'))
63+
64+
65+
class LocalBaseReranker(MaxKBBaseModel, BaseDocumentCompressor):
66+
client: Any = None
67+
tokenizer: Any = None
68+
model: Optional[str] = None
69+
cache_dir: Optional[str] = None
70+
model_kwargs = {}
71+
72+
def __init__(self, model_name, cache_dir=None, **model_kwargs):
73+
super().__init__()
74+
self.model = model_name
75+
self.cache_dir = cache_dir
76+
self.model_kwargs = model_kwargs
77+
self.client = AutoModelForSequenceClassification.from_pretrained(self.model, cache_dir=self.cache_dir)
78+
self.tokenizer = AutoTokenizer.from_pretrained(self.model, cache_dir=self.cache_dir)
79+
self.client = self.client.to(self.model_kwargs.get('device', 'cpu'))
80+
self.client.eval()
81+
82+
@staticmethod
83+
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
84+
return LocalBaseReranker(model_name, cache_dir=model_credential.get('cache_dir'), **model_kwargs)
85+
86+
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
87+
Sequence[Document]:
88+
with torch.no_grad():
89+
inputs = self.tokenizer([[query, document.page_content] for document in documents], padding=True,
90+
truncation=True, return_tensors='pt', max_length=512)
91+
scores = [torch.sigmoid(s).float().item() for s in
92+
self.client(**inputs, return_dict=True).logits.view(-1, ).float()]
93+
result = [Document(page_content=documents[index].page_content, metadata={'relevance_score': scores[index]})
94+
for index
95+
in range(len(documents))]
96+
result.sort(key=lambda row: row.metadata.get('relevance_score'), reverse=True)
97+
return result

apps/setting/serializers/model_apply_serializers.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
@desc:
88
"""
99
from django.db.models import QuerySet
10+
from langchain_core.documents import Document
1011
from rest_framework import serializers
1112

1213
from common.config.embedding_config import ModelManage
@@ -33,6 +34,16 @@ class EmbedQuery(serializers.Serializer):
3334
text = serializers.CharField(required=True, error_messages=ErrMessage.char("向量文本"))
3435

3536

37+
class CompressDocument(serializers.Serializer):
38+
page_content = serializers.CharField(required=True, error_messages=ErrMessage.char("文本"))
39+
metadata = serializers.DictField(required=False, error_messages=ErrMessage.dict("元数据"))
40+
41+
42+
class CompressDocuments(serializers.Serializer):
43+
documents = CompressDocument(required=True, many=True)
44+
query = serializers.CharField(required=True, error_messages=ErrMessage.char("查询query"))
45+
46+
3647
class ModelApplySerializers(serializers.Serializer):
3748
model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))
3849

@@ -51,3 +62,12 @@ def embed_query(self, instance, with_valid=True):
5162

5263
model = get_embedding_model(self.data.get('model_id'))
5364
return model.embed_query(instance.get('text'))
65+
66+
def compress_documents(self, instance, with_valid=True):
67+
if with_valid:
68+
self.is_valid(raise_exception=True)
69+
CompressDocuments(data=instance).is_valid(raise_exception=True)
70+
model = get_embedding_model(self.data.get('model_id'))
71+
return [{'page_content': d.page_content, 'metadata': d.metadata} for d in model.compress_documents(
72+
[Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document in
73+
instance.get('documents')], instance.get('query'))]

apps/setting/urls.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
path('provider/model_form', views.Provide.ModelForm.as_view(),
1818
name="provider/model_form"),
1919
path('model', views.Model.as_view(), name='model'),
20-
path('model/<str:model_id>/model_params_form', views.Model.ModelParamsForm.as_view(), name='model/model_params_form'),
20+
path('model/<str:model_id>/model_params_form', views.Model.ModelParamsForm.as_view(),
21+
name='model/model_params_form'),
2122
path('model/<str:model_id>', views.Model.Operate.as_view(), name='model/operate'),
2223
path('model/<str:model_id>/pause_download', views.Model.PauseDownload.as_view(), name='model/operate'),
2324
path('model/<str:model_id>/meta', views.Model.ModelMeta.as_view(), name='model/operate/meta'),
@@ -31,4 +32,6 @@
3132
name='model/embed_documents'),
3233
path('model/<str:model_id>/embed_query', views.ModelApply.EmbedQuery.as_view(),
3334
name='model/embed_query'),
35+
path('model/<str:model_id>/compress_documents', views.ModelApply.CompressDocuments.as_view(),
36+
name='model/embed_query'),
3437
]

0 commit comments

Comments
 (0)