feat: 生成关联问题 (generate related questions) #1339

Merged: 1 commit, Oct 9, 2024
1 change: 1 addition & 0 deletions apps/dataset/models/data_set.py
@@ -22,6 +22,7 @@ class Status(models.TextChoices):
success = 1, '已完成'
error = 2, '导入失败'
queue_up = 3, '排队中'
generating = 4, '生成问题中'


class Type(models.TextChoices):
33 changes: 32 additions & 1 deletion apps/dataset/serializers/document_serializers.py
@@ -47,7 +47,7 @@
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage, \
get_embedding_model_id_by_dataset_id
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from dataset.task import sync_web_document
from dataset.task import sync_web_document, generate_related_by_document_id
from embedding.task.embedding import embedding_by_document, delete_embedding_by_document_list, \
delete_embedding_by_document, update_embedding_dataset_id, delete_embedding_by_paragraph_ids, \
embedding_by_document_list
@@ -960,6 +960,37 @@ def batch_refresh(self, instance: Dict, with_valid=True):
except AlreadyQueued as e:
raise AppApiException(500, "任务正在执行中,请勿重复下发")

class GenerateRelated(ApiMixin, serializers.Serializer):
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))

def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
document_id = self.data.get('document_id')
if not QuerySet(Document).filter(id=document_id).exists():
raise AppApiException(500, "文档id不存在")

def generate_related(self, model_id, prompt, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
document_id = self.data.get('document_id')
QuerySet(Document).filter(id=document_id).update(status=Status.queue_up)
generate_related_by_document_id.delay(document_id, model_id, prompt)



class BatchGenerateRelated(ApiMixin, serializers.Serializer):
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))

@transaction.atomic
def batch_generate_related(self, instance: Dict, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
document_id_list = instance.get("document_id_list")
model_id = instance.get("model_id")
prompt = instance.get("prompt")
for document_id in document_id_list:
DocumentSerializers.GenerateRelated(data={'document_id': document_id}).generate_related(model_id, prompt)


class FileBufferHandle:
buffer = None
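Reviewer note: `BatchGenerateRelated.batch_generate_related` reads `document_id_list`, `model_id`, and `prompt` straight from the raw request body (`instance`), while the declared serializer field only validates `dataset_id`. A minimal sketch of the payload shape it expects; the key names come from the code above, the values are placeholders:

```python
# Hypothetical payload for DocumentSerializers.BatchGenerateRelated.
# Only dataset_id is a declared serializer field; the keys below are
# read from the request body by batch_generate_related() itself.
payload = {
    "document_id_list": [
        "11111111-1111-1111-1111-111111111111",  # placeholder UUIDs
        "22222222-2222-2222-2222-222222222222",
    ],
    "model_id": "33333333-3333-3333-3333-333333333333",
    "prompt": (
        "Write questions this content can answer, one per line, "
        "each wrapped in <question></question> tags:\n{data}"
    ),
}

# Mirrors the call in views/document.py further down in this diff:
# DocumentSerializers.BatchGenerateRelated(
#     data={"dataset_id": dataset_id}
# ).batch_generate_related(payload)
```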
18 changes: 18 additions & 0 deletions apps/dataset/serializers/paragraph_serializers.py
@@ -27,6 +27,7 @@
from embedding.task.embedding import embedding_by_problem as embedding_by_problem_task, embedding_by_problem, \
delete_embedding_by_source, enable_embedding_by_paragraph, disable_embedding_by_paragraph, embedding_by_paragraph, \
delete_embedding_by_paragraph, delete_embedding_by_paragraph_ids, update_embedding_document_id
from dataset.task import generate_related_by_paragraph_id_list


class ParagraphSerializer(serializers.ModelSerializer):
@@ -719,3 +720,20 @@ def get_response_body_api():
)
}
)


class BatchGenerateRelated(ApiMixin, serializers.Serializer):
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))

@transaction.atomic
def batch_generate_related(self, instance: Dict, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
paragraph_id_list = instance.get("paragraph_id_list")
model_id = instance.get("model_id")
prompt = instance.get("prompt")
generate_related_by_paragraph_id_list.delay(paragraph_id_list, model_id, prompt)



1 change: 1 addition & 0 deletions apps/dataset/task/__init__.py
@@ -7,3 +7,4 @@
@desc:
"""
from .sync import *
from .generate import *
64 changes: 64 additions & 0 deletions apps/dataset/task/generate.py
@@ -0,0 +1,64 @@
import logging
from math import ceil

from celery_once import QueueOnce
from django.db.models import QuerySet
from langchain_core.messages import HumanMessage

from common.config.embedding_config import ModelManage
from dataset.models import Paragraph, Document, Status
from dataset.task.tools import save_problem
from ops import celery_app
from setting.models import Model
from setting.models_provider import get_model

max_kb_error = logging.getLogger("max_kb_error")
max_kb = logging.getLogger("max_kb")


def get_llm_model(model_id):
model = QuerySet(Model).filter(id=model_id).first()
return ModelManage.get_model(model_id, lambda _id: get_model(model))


@celery_app.task(base=QueueOnce, once={'keys': ['document_id']},
name='celery:generate_related_by_document')
def generate_related_by_document_id(document_id, model_id, prompt):
llm_model = get_llm_model(model_id)
offset = 0
page_size = 10
QuerySet(Document).filter(id=document_id).update(status=Status.generating)

count = QuerySet(Paragraph).filter(document_id=document_id).count()
for i in range(0, ceil(count / page_size)):
paragraph_list = QuerySet(Paragraph).filter(document_id=document_id).all()[offset:offset + page_size]
offset += page_size
for paragraph in paragraph_list:
res = llm_model.invoke([HumanMessage(content=prompt.replace('{data}', paragraph.content))])
if (res.content is None) or (len(res.content) == 0):
continue
problems = res.content.split('\n')
for problem in problems:
save_problem(paragraph.dataset_id, paragraph.document_id, paragraph.id, problem)

QuerySet(Document).filter(id=document_id).update(status=Status.success)



@celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']},
name='celery:generate_related_by_paragraph_list')
def generate_related_by_paragraph_id_list(paragraph_id_list, model_id, prompt):
llm_model = get_llm_model(model_id)
offset = 0
page_size = 10
count = QuerySet(Paragraph).filter(id__in=paragraph_id_list).count()
for i in range(0, ceil(count / page_size)):
paragraph_list = QuerySet(Paragraph).filter(id__in=paragraph_id_list).all()[offset:offset + page_size]
offset += page_size
for paragraph in paragraph_list:
res = llm_model.invoke([HumanMessage(content=prompt.replace('{data}', paragraph.content))])
if (res.content is None) or (len(res.content) == 0):
continue
problems = res.content.split('\n')
for problem in problems:
save_problem(paragraph.dataset_id, paragraph.document_id, paragraph.id, problem)
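Both tasks substitute the paragraph body with a plain `prompt.replace('{data}', ...)` and split the model reply on newlines, so the prompt must contain the literal `{data}` token and ask for one question per line in the `<question>` format that `save_problem` (next file) extracts. An illustrative template, assuming that contract; the wording is not part of the PR:

```python
# Illustrative prompt template for the generation tasks above.
# '{data}' is the literal placeholder replaced with paragraph.content;
# the <question></question> wrapper matches the regex in save_problem().
PROMPT_TEMPLATE = (
    "Read the content below and write up to 5 questions it can answer.\n"
    "Return one question per line, each wrapped in <question></question>.\n"
    "Content:\n"
    "{data}"
)
```

Since `str.replace` is a plain substitution, a prompt without the `{data}` token is sent to the model unchanged and the paragraph content is silently dropped.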
21 changes: 21 additions & 0 deletions apps/dataset/task/tools.py
@@ -8,6 +8,7 @@
"""

import logging
import re
import traceback

from common.util.fork import ChildLink, Fork
@@ -60,3 +61,23 @@ def handler(source_url: str, selector, response: Fork.Response):
status=Status.error).save()

return handler


def save_problem(dataset_id, document_id, paragraph_id, problem):
from dataset.serializers.paragraph_serializers import ParagraphSerializers
# print(f"dataset_id: {dataset_id}")
# print(f"document_id: {document_id}")
# print(f"paragraph_id: {paragraph_id}")
# print(f"problem: {problem}")
problem = re.sub(r"^\d+\.\s*", "", problem)
pattern = r"<question>(.*?)</question>"
match = re.search(pattern, problem)
problem = match.group(1) if match else None
if problem is None or len(problem) == 0:
return
try:
ParagraphSerializers.Problem(
data={"dataset_id": dataset_id, 'document_id': document_id,
'paragraph_id': paragraph_id}).save(instance={"content": problem}, with_valid=True)
except Exception as e:
max_kb_error.error(f'关联问题失败: {e}')
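The cleanup here is two regex passes: strip a leading list index such as `1. `, then keep only the body of the first `<question>` pair; any line that does not match is skipped. A self-contained sketch of that extraction, runnable outside Django:

```python
import re

def extract_question(line: str):
    """Replicate save_problem's cleanup: drop a leading '1. '-style
    index, then return the first <question> body, or None."""
    line = re.sub(r"^\d+\.\s*", "", line)
    match = re.search(r"<question>(.*?)</question>", line)
    return match.group(1) if match else None

assert extract_question(
    "3. <question>What is a knowledge base?</question>"
) == "What is a knowledge base?"
assert extract_question("free-form line without tags") is None  # skipped
```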
4 changes: 4 additions & 0 deletions apps/dataset/urls.py
@@ -15,6 +15,7 @@
path('dataset/<str:dataset_id>/sync_web', views.Dataset.SyncWeb.as_view()),
path('dataset/<str:dataset_id>/hit_test', views.Dataset.HitTest.as_view()),
path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document'),
path('dataset/<str:dataset_id>/model', views.Dataset.Model.as_view()),
path('dataset/document/template/export', views.Template.as_view()),
path('dataset/document/table_template/export', views.TableTemplate.as_view()),
path('dataset/<str:dataset_id>/document/web', views.WebDocument.as_view()),
@@ -24,6 +25,7 @@
path('dataset/<str:dataset_id>/document/batch_hit_handling', views.Document.BatchEditHitHandling.as_view()),
path('dataset/<str:dataset_id>/document/<int:current_page>/<int:page_size>', views.Document.Page.as_view()),
path('dataset/<str:dataset_id>/document/batch_refresh', views.Document.BatchRefresh.as_view()),
path('dataset/<str:dataset_id>/document/batch_generate_related', views.Document.BatchGenerateRelated.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>', views.Document.Operate.as_view(),
name="document_operate"),
path('dataset/document/split', views.Document.Split.as_view(),
@@ -36,12 +38,14 @@
path('dataset/<str:dataset_id>/document/<str:document_id>/sync', views.Document.SyncWeb.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/refresh', views.Document.Refresh.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph', views.Paragraph.as_view()),
path('dataset/<str:dataset_id>/document/batch_generate_related', views.Document.BatchGenerateRelated.as_view()),
path(
'dataset/<str:dataset_id>/document/<str:document_id>/paragraph/migrate/dataset/<str:target_dataset_id>/document/<str:target_document_id>',
views.Paragraph.BatchMigrate.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/_batch', views.Paragraph.Batch.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<int:current_page>/<int:page_size>',
views.Paragraph.Page.as_view(), name='paragraph_page'),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/batch_generate_related', views.Paragraph.BatchGenerateRelated.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<paragraph_id>',
views.Paragraph.Operate.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<str:paragraph_id>/problem',
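Taken together, the new routes can be exercised over plain HTTP. A hedged sketch using `requests`; the base URL, the auth header shape, and every ID are assumptions for illustration, not part of this PR:

```python
import requests

BASE = "http://localhost:8080/api"         # assumed deployment address
HEADERS = {"AUTHORIZATION": "<api-token>"}  # TokenAuth credential (shape assumed)
dataset_id, document_id = "<dataset-uuid>", "<document-uuid>"

prompt = ("Write questions this content can answer, one per line, "
          "wrapped in <question></question> tags:\n{data}")

# List the LLMs the current user may pick for generation.
models = requests.get(f"{BASE}/dataset/{dataset_id}/model", headers=HEADERS).json()

# Queue generation for whole documents...
requests.put(
    f"{BASE}/dataset/{dataset_id}/document/batch_generate_related",
    json={"document_id_list": [document_id],
          "model_id": "<llm-model-uuid>", "prompt": prompt},
    headers=HEADERS,
)

# ...or for selected paragraphs of a single document.
requests.put(
    f"{BASE}/dataset/{dataset_id}/document/{document_id}"
    "/paragraph/batch_generate_related",
    json={"paragraph_id_list": ["<paragraph-uuid>"],
          "model_id": "<llm-model-uuid>", "prompt": prompt},
    headers=HEADERS,
)
```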
18 changes: 18 additions & 0 deletions apps/dataset/views/dataset.py
@@ -20,6 +20,7 @@
from common.response.result import get_page_request_params, get_page_api_response, get_api_response
from common.swagger_api.common_api import CommonApi
from dataset.serializers.dataset_serializers import DataSetSerializers
from setting.serializers.provider_serializers import ModelSerializer


class Dataset(APIView):
@@ -223,3 +224,20 @@ def get(self, request: Request, current_page, page_size):
'user_id': str(request.user.id)})
d.is_valid()
return result.success(d.page(current_page, page_size))

class Model(APIView):
authentication_classes = [TokenAuth]

@action(methods=["GET"], detail=False)
@has_permissions(ViewPermission(
[RoleConstants.ADMIN, RoleConstants.USER],
[lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=keywords.get('dataset_id'))],
compare=CompareConstants.AND))
def get(self, request: Request, dataset_id: str):
return result.success(
ModelSerializer.Query(
data={'user_id': request.user.id, 'model_type': 'LLM'}).list(
with_valid=True)
)
11 changes: 11 additions & 0 deletions apps/dataset/views/document.py
@@ -393,3 +393,14 @@ def get(self, request: Request, dataset_id: str, current_page, page_size):
data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id})
d.is_valid(raise_exception=True)
return result.success(d.page(current_page, page_size))

class BatchGenerateRelated(APIView):
authentication_classes = [TokenAuth]

@action(methods=['PUT'], detail=False)
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')))
def put(self, request: Request, dataset_id: str):
return result.success(DocumentSerializers.BatchGenerateRelated(data={'dataset_id': dataset_id})
.batch_generate_related(request.data))
12 changes: 12 additions & 0 deletions apps/dataset/views/paragraph.py
@@ -232,3 +232,15 @@ def get(self, request: Request, dataset_id: str, document_id: str, current_page,
'document_id': document_id})
d.is_valid(raise_exception=True)
return result.success(d.page(current_page, page_size))

class BatchGenerateRelated(APIView):
authentication_classes = [TokenAuth]

@action(methods=['PUT'], detail=False)
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')))
def put(self, request: Request, dataset_id: str, document_id: str):
return result.success(
ParagraphSerializers.BatchGenerateRelated(data={'dataset_id': dataset_id, 'document_id': document_id})
.batch_generate_related(request.data))
19 changes: 18 additions & 1 deletion ui/src/api/dataset.ts
@@ -202,6 +202,22 @@ const exportDataset: (
return exportExcel(dataset_name + '.xlsx', `dataset/${dataset_id}/export`, undefined, loading)
}


/**
 * Get the list of models available to the current user
 * @param dataset_id
 * @param loading
 * @returns
 */
const getDatasetModel: (
dataset_id: string,
loading?: Ref<boolean>
) => Promise<Result<Array<any>>> = (dataset_id, loading) => {
return get(`${prefix}/${dataset_id}/model`, loading)
}


export default {
getDataset,
getAllDataset,
@@ -215,5 +231,6 @@
putSyncWebDataset,
putReEmbeddingDataset,
postQADataset,
exportDataset
exportDataset,
getDatasetModel
}
16 changes: 15 additions & 1 deletion ui/src/api/document.ts
@@ -317,6 +317,19 @@ const exportDocument: (
)
}

const batchGenerateRelated: (
dataset_id: string,
data: any,
loading?: Ref<boolean>
) => Promise<Result<boolean>> = (dataset_id, data, loading) => {
return put(
`${prefix}/${dataset_id}/document/batch_generate_related`,
data,
undefined,
loading
)
}

export default {
postSplitDocument,
getDocument,
@@ -338,5 +351,6 @@
postQADocument,
postTableDocument,
exportDocument,
batchRefresh
batchRefresh,
batchGenerateRelated
}
18 changes: 17 additions & 1 deletion ui/src/api/paragraph.ts
@@ -226,6 +226,21 @@ const disassociationProblem: (
)
}

const batchGenerateRelated: (
dataset_id: string,
document_id: string,
data: any,
loading?: Ref<boolean>
) => Promise<Result<boolean>> = (dataset_id, document_id, data, loading) => {
return put(
`${prefix}/${dataset_id}/document/${document_id}/paragraph/batch_generate_related`,
data,
undefined,
loading
)
}


export default {
getParagraph,
delParagraph,
@@ -236,5 +251,6 @@
disassociationProblem,
associationProblem,
delMulParagraph,
putMigrateMulParagraph
putMigrateMulParagraph,
batchGenerateRelated
}