Skip to content

Commit 83cd69e

Browse files
authored
feat: Problem generation now supports generating for unfinished paragraphs #2174 (#2299)
1 parent f45855c commit 83cd69e

File tree

5 files changed

+88
-16
lines changed

5 files changed

+88
-16
lines changed

apps/dataset/serializers/document_serializers.py

+18-8
Original file line numberDiff line numberDiff line change
@@ -243,13 +243,16 @@ def export(self, with_valid=True):
243243
self.is_valid(raise_exception=True)
244244
language = get_language()
245245
if self.data.get('type') == 'csv':
246-
file = open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template', f'csv_template_{to_locale(language)}.csv'), "rb")
246+
file = open(
247+
os.path.join(PROJECT_DIR, "apps", "dataset", 'template', f'csv_template_{to_locale(language)}.csv'),
248+
"rb")
247249
content = file.read()
248250
file.close()
249251
return HttpResponse(content, status=200, headers={'Content-Type': 'text/csv',
250252
'Content-Disposition': 'attachment; filename="csv_template.csv"'})
251253
elif self.data.get('type') == 'excel':
252-
file = open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template', f'excel_template_{to_locale(language)}.xlsx'), "rb")
254+
file = open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template',
255+
f'excel_template_{to_locale(language)}.xlsx'), "rb")
253256
content = file.read()
254257
file.close()
255258
return HttpResponse(content, status=200, headers={'Content-Type': 'application/vnd.ms-excel',
@@ -261,7 +264,8 @@ def table_export(self, with_valid=True):
261264
language = get_language()
262265
if self.data.get('type') == 'csv':
263266
file = open(
264-
os.path.join(PROJECT_DIR, "apps", "dataset", 'template', f'table_template_{to_locale(language)}.csv'),
267+
os.path.join(PROJECT_DIR, "apps", "dataset", 'template',
268+
f'table_template_{to_locale(language)}.csv'),
265269
"rb")
266270
content = file.read()
267271
file.close()
@@ -1180,7 +1184,7 @@ def is_valid(self, *, raise_exception=False):
11801184
if not QuerySet(Document).filter(id=document_id).exists():
11811185
raise AppApiException(500, _('document id not exist'))
11821186

1183-
def generate_related(self, model_id, prompt, with_valid=True):
1187+
def generate_related(self, model_id, prompt, state_list=None, with_valid=True):
11841188
if with_valid:
11851189
self.is_valid(raise_exception=True)
11861190
document_id = self.data.get('document_id')
@@ -1192,7 +1196,7 @@ def generate_related(self, model_id, prompt, with_valid=True):
11921196
State.PENDING)
11931197
ListenerManagement.get_aggregation_document_status(document_id)()
11941198
try:
1195-
generate_related_by_document_id.delay(document_id, model_id, prompt)
1199+
generate_related_by_document_id.delay(document_id, model_id, prompt, state_list)
11961200
except AlreadyQueued as e:
11971201
raise AppApiException(500, _('The task is being executed, please do not send it again.'))
11981202

@@ -1205,17 +1209,23 @@ def batch_generate_related(self, instance: Dict, with_valid=True):
12051209
document_id_list = instance.get("document_id_list")
12061210
model_id = instance.get("model_id")
12071211
prompt = instance.get("prompt")
1212+
state_list = instance.get('state_list')
12081213
ListenerManagement.update_status(QuerySet(Document).filter(id__in=document_id_list),
12091214
TaskType.GENERATE_PROBLEM,
12101215
State.PENDING)
1211-
ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id__in=document_id_list),
1212-
TaskType.GENERATE_PROBLEM,
1216+
ListenerManagement.update_status(QuerySet(Paragraph).annotate(
1217+
reversed_status=Reverse('status'),
1218+
task_type_status=Substr('reversed_status', TaskType.GENERATE_PROBLEM.value,
1219+
1),
1220+
).filter(task_type_status__in=state_list, document_id__in=document_id_list)
1221+
.values('id'),
1222+
TaskType.EMBEDDING,
12131223
State.PENDING)
12141224
ListenerManagement.get_aggregation_document_status_by_query_set(
12151225
QuerySet(Document).filter(id__in=document_id_list))()
12161226
try:
12171227
for document_id in document_id_list:
1218-
generate_related_by_document_id.delay(document_id, model_id, prompt)
1228+
generate_related_by_document_id.delay(document_id, model_id, prompt, state_list)
12191229
except AlreadyQueued as e:
12201230
pass
12211231

apps/dataset/task/generate.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33

44
from celery_once import QueueOnce
55
from django.db.models import QuerySet
6+
from django.db.models.functions import Reverse, Substr
67
from langchain_core.messages import HumanMessage
78

89
from common.config.embedding_config import ModelManage
910
from common.event import ListenerManagement
10-
from common.util.page_utils import page
11+
from common.util.page_utils import page, page_desc
1112
from dataset.models import Paragraph, Document, Status, TaskType, State
1213
from dataset.task.tools import save_problem
1314
from ops import celery_app
@@ -64,7 +65,11 @@ def is_the_task_interrupted():
6465

6566
@celery_app.task(base=QueueOnce, once={'keys': ['document_id']},
6667
name='celery:generate_related_by_document')
67-
def generate_related_by_document_id(document_id, model_id, prompt):
68+
def generate_related_by_document_id(document_id, model_id, prompt, state_list=None):
69+
if state_list is None:
70+
state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value,
71+
State.REVOKE.value,
72+
State.REVOKED.value, State.IGNORED.value]
6873
try:
6974
is_the_task_interrupted = get_is_the_task_interrupted(document_id)
7075
if is_the_task_interrupted():
@@ -78,7 +83,12 @@ def generate_related_by_document_id(document_id, model_id, prompt):
7883
generate_problem = get_generate_problem(llm_model, prompt,
7984
ListenerManagement.get_aggregation_document_status(
8085
document_id), is_the_task_interrupted)
81-
page(QuerySet(Paragraph).filter(document_id=document_id), 10, generate_problem, is_the_task_interrupted)
86+
query_set = QuerySet(Paragraph).annotate(
87+
reversed_status=Reverse('status'),
88+
task_type_status=Substr('reversed_status', TaskType.GENERATE_PROBLEM.value,
89+
1),
90+
).filter(task_type_status__in=state_list, document_id=document_id)
91+
page_desc(query_set, 10, generate_problem, is_the_task_interrupted)
8292
except Exception as e:
8393
max_kb_error.error(f'根据文档生成问题:{document_id}出现错误{str(e)}{traceback.format_exc()}')
8494
max_kb_error.error(_('Generate issue based on document: {document_id} error {error}{traceback}').format(

apps/setting/models_provider/impl/xinference_model_provider/model/image.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1-
from typing import Dict
1+
from typing import Dict, List
22

3+
from langchain_core.messages import BaseMessage, get_buffer_string
4+
5+
from common.config.tokenizer_manage_config import TokenizerManage
36
from setting.models_provider.base_model_provider import MaxKBBaseModel
47
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
58

@@ -18,3 +21,15 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
1821
stream_usage=True,
1922
**optional_params,
2023
)
24+
25+
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
26+
if self.usage_metadata is None or self.usage_metadata == {}:
27+
tokenizer = TokenizerManage.get_tokenizer()
28+
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
29+
return self.usage_metadata.get('input_tokens', 0)
30+
31+
def get_num_tokens(self, text: str) -> int:
32+
if self.usage_metadata is None or self.usage_metadata == {}:
33+
tokenizer = TokenizerManage.get_tokenizer()
34+
return len(tokenizer.encode(text))
35+
return self.get_last_generation_info().get('output_tokens', 0)

apps/setting/models_provider/impl/xinference_model_provider/model/llm.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
# coding=utf-8
22

3-
from typing import Dict
3+
from typing import Dict, List
44
from urllib.parse import urlparse, ParseResult
55

6+
from langchain_core.messages import BaseMessage, get_buffer_string
7+
8+
from common.config.tokenizer_manage_config import TokenizerManage
69
from setting.models_provider.base_model_provider import MaxKBBaseModel
710
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
811

@@ -33,3 +36,15 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
3336
openai_api_key=model_credential.get('api_key'),
3437
**optional_params
3538
)
39+
40+
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
41+
if self.usage_metadata is None or self.usage_metadata == {}:
42+
tokenizer = TokenizerManage.get_tokenizer()
43+
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
44+
return self.usage_metadata.get('input_tokens', 0)
45+
46+
def get_num_tokens(self, text: str) -> int:
47+
if self.usage_metadata is None or self.usage_metadata == {}:
48+
tokenizer = TokenizerManage.get_tokenizer()
49+
return len(tokenizer.encode(text))
50+
return self.get_last_generation_info().get('output_tokens', 0)

ui/src/components/generate-related-dialog/index.vue

+25-3
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,16 @@
4848
type="textarea"
4949
/>
5050
</el-form-item>
51+
<el-form-item :label="$t('views.problem.relateParagraph.selectParagraph')" prop="state">
52+
<el-radio-group v-model="state" class="radio-block">
53+
<el-radio value="error" size="large" class="mb-16">{{
54+
$t('views.document.form.selectVectorization.error')
55+
}}</el-radio>
56+
<el-radio value="all" size="large">{{
57+
$t('views.document.form.selectVectorization.all')
58+
}}</el-radio>
59+
</el-radio-group>
60+
</el-form-item>
5161
</el-form>
5262
</div>
5363
<template #footer>
@@ -87,7 +97,11 @@ const dialogVisible = ref<boolean>(false)
8797
const modelOptions = ref<any>(null)
8898
const idList = ref<string[]>([])
8999
const apiType = ref('') // 文档document或段落paragraph
90-
100+
const state = ref<'all' | 'error'>('error')
101+
const stateMap = {
102+
all: ['0', '1', '2', '3', '4', '5', 'n'],
103+
error: ['0', '1', '3', '4', '5', 'n']
104+
}
91105
const FormRef = ref()
92106
const userId = user.userInfo?.id as string
93107
const form = ref(prompt.get(userId))
@@ -131,14 +145,22 @@ const submitHandle = async (formEl: FormInstance) => {
131145
// 保存提示词
132146
prompt.save(user.userInfo?.id as string, form.value)
133147
if (apiType.value === 'paragraph') {
134-
const data = { ...form.value, paragraph_id_list: idList.value }
148+
const data = {
149+
...form.value,
150+
paragraph_id_list: idList.value,
151+
state_list: stateMap[state.value]
152+
}
135153
paragraphApi.batchGenerateRelated(id, documentId, data, loading).then(() => {
136154
MsgSuccess(t('views.document.generateQuestion.successMessage'))
137155
emit('refresh')
138156
dialogVisible.value = false
139157
})
140158
} else if (apiType.value === 'document') {
141-
const data = { ...form.value, document_id_list: idList.value }
159+
const data = {
160+
...form.value,
161+
document_id_list: idList.value,
162+
state_list: stateMap[state.value]
163+
}
142164
documentApi.batchGenerateRelated(id, data, loading).then(() => {
143165
MsgSuccess(t('views.document.generateQuestion.successMessage'))
144166
emit('refresh')

0 commit comments

Comments
 (0)