Skip to content

fix: 修复【知识库】知识库上传 有关联问题的会阻塞 #676

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion apps/dataset/serializers/common_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
@desc:
"""
import os
import uuid
from typing import List

from django.db.models import QuerySet
Expand All @@ -20,7 +21,7 @@
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.fork import Fork
from dataset.models import Paragraph
from dataset.models import Paragraph, Problem, ProblemParagraphMapping
from smartdoc.conf import PROJECT_DIR


Expand Down Expand Up @@ -79,3 +80,53 @@ def get_request_body_api():
description="主键id列表")
}
)


class ProblemParagraphObject:
    """Lightweight value object pairing one problem text with the dataset,
    document and paragraph it is attached to."""

    def __init__(self, dataset_id: str, document_id: str, paragraph_id: str, problem_content: str):
        (self.dataset_id,
         self.document_id,
         self.paragraph_id,
         self.problem_content) = dataset_id, document_id, paragraph_id, problem_content


def or_get(exists_problem_list, content, dataset_id, document_id, paragraph_id, problem_content_dict):
    """Resolve ``content`` to a Problem instance, reusing earlier results.

    Lookup order:
      1. ``problem_content_dict`` — problems already resolved in this batch;
      2. ``exists_problem_list`` — problems already persisted for the dataset;
      3. otherwise a new (unsaved) ``Problem`` model instance is created.

    ``problem_content_dict`` is updated in place; its values are
    ``(problem, is_create)`` pairs where ``is_create`` marks problems that
    still need to be inserted into the database.

    :return: tuple ``(problem, document_id, paragraph_id)``
    """
    if content in problem_content_dict:
        return problem_content_dict[content][0], document_id, paragraph_id
    # Stop at the first match instead of materializing the whole filtered list.
    existing = next((row for row in exists_problem_list if row.content == content), None)
    if existing is not None:
        problem_content_dict[content] = existing, False
        return existing, document_id, paragraph_id
    # uuid1 kept for consistency with the ids generated elsewhere in this module.
    problem = Problem(id=uuid.uuid1(), content=content, dataset_id=dataset_id)
    problem_content_dict[content] = problem, True
    return problem, document_id, paragraph_id


class ProblemParagraphManage:
    """Resolves a batch of (paragraph, problem text) pairs into model lists.

    Collects the problems attached to paragraphs during an import,
    de-duplicates them against each other and against problems already stored
    for the dataset, and produces the Problem rows that must be inserted plus
    the problem→paragraph mapping rows.
    """

    # Fix: the original annotation ``[ProblemParagraphObject]`` was a list
    # literal, not a type; ``List[...]`` (already imported at module level)
    # is the correct spelling.
    def __init__(self, problemParagraphObjectList: List['ProblemParagraphObject'], dataset_id):
        self.dataset_id = dataset_id
        self.problemParagraphObjectList = problemParagraphObjectList

    def to_problem_model_list(self):
        """Build the model lists for bulk insertion.

        :return: ``(problem_model_list, problem_paragraph_mapping_list)`` —
                 ``problem_model_list`` contains only problems that do not yet
                 exist in the database (the ones to be created).
        """
        problem_contents = [item.problem_content for item in self.problemParagraphObjectList]
        exists_problem_list = []
        if self.problemParagraphObjectList:
            # Problems already persisted for this dataset are reused, not recreated.
            exists_problem_list = QuerySet(Problem).filter(dataset_id=self.dataset_id,
                                                           content__in=problem_contents).all()
        # content -> (problem, is_create); shared cache so duplicate texts map to one Problem.
        problem_content_dict = {}
        problem_model_list = [
            or_get(exists_problem_list, obj.problem_content, obj.dataset_id,
                   obj.document_id, obj.paragraph_id, problem_content_dict)
            for obj in self.problemParagraphObjectList]

        problem_paragraph_mapping_list = [
            ProblemParagraphMapping(id=uuid.uuid1(), document_id=document_id, problem_id=problem_model.id,
                                    paragraph_id=paragraph_id,
                                    dataset_id=self.dataset_id)
            for problem_model, document_id, paragraph_id in problem_model_list]

        new_problems = [problem for problem, is_create in problem_content_dict.values() if is_create]
        return new_problems, problem_paragraph_mapping_list
17 changes: 8 additions & 9 deletions apps/dataset/serializers/dataset_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from common.util.fork import ChildLink, Fork
from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
from embedding.models import SearchMode
from setting.models import AuthOperate
Expand Down Expand Up @@ -383,21 +383,20 @@ def save(self, instance: Dict, with_valid=True):

document_model_list = []
paragraph_model_list = []
problem_model_list = []
problem_paragraph_mapping_list = []
problem_paragraph_object_list = []
# 插入文档
for document in instance.get('documents') if 'documents' in instance else []:
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
document)
document_model_list.append(document_paragraph_dict_model.get('document'))
for paragraph in document_paragraph_dict_model.get('paragraph_model_list'):
paragraph_model_list.append(paragraph)
for problem in document_paragraph_dict_model.get('problem_model_list'):
problem_model_list.append(problem)
for problem_paragraph_mapping in document_paragraph_dict_model.get('problem_paragraph_mapping_list'):
problem_paragraph_mapping_list.append(problem_paragraph_mapping)
problem_model_list, problem_paragraph_mapping_list = DocumentSerializers.Create.reset_problem_model(
problem_model_list, problem_paragraph_mapping_list)
for problem_paragraph_object in document_paragraph_dict_model.get('problem_paragraph_object_list'):
problem_paragraph_object_list.append(problem_paragraph_object)

problem_model_list, problem_paragraph_mapping_list = (ProblemParagraphManage(problem_paragraph_object_list,
dataset_id)
.to_problem_model_list())
# 插入知识库
dataset.save()
# 插入文档
Expand Down
55 changes: 19 additions & 36 deletions apps/dataset/serializers/document_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from common.util.fork import Fork
from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from smartdoc.conf import PROJECT_DIR

Expand Down Expand Up @@ -380,8 +380,9 @@ def sync(self, with_valid=True, with_embedding=True):
document_paragraph_model = DocumentSerializers.Create.get_paragraph_model(document, paragraphs)

paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
problem_model_list = document_paragraph_model.get('problem_model_list')
problem_paragraph_mapping_list = document_paragraph_model.get('problem_paragraph_mapping_list')
problem_paragraph_object_list = document_paragraph_model.get('problem_paragraph_object_list')
problem_model_list, problem_paragraph_mapping_list = ProblemParagraphManage(
problem_paragraph_object_list, document.dataset_id).to_problem_model_list()
# 批量插入段落
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
# 批量插入问题
Expand Down Expand Up @@ -626,11 +627,13 @@ def save(self, instance: Dict, with_valid=False, **kwargs):
self.is_valid(raise_exception=True)
dataset_id = self.data.get('dataset_id')
document_paragraph_model = self.get_document_paragraph_model(dataset_id, instance)

document_model = document_paragraph_model.get('document')
paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
problem_model_list = document_paragraph_model.get('problem_model_list')
problem_paragraph_mapping_list = document_paragraph_model.get('problem_paragraph_mapping_list')

problem_paragraph_object_list = document_paragraph_model.get('problem_paragraph_object_list')
problem_model_list, problem_paragraph_mapping_list = (ProblemParagraphManage(problem_paragraph_object_list,
dataset_id)
.to_problem_model_list())
# 插入文档
document_model.save()
# 批量插入段落
Expand Down Expand Up @@ -685,35 +688,15 @@ def get_paragraph_model(document_model, paragraph_list: List):
dataset_id, document_model.id, paragraph) for paragraph in paragraph_list]

paragraph_model_list = []
problem_model_list = []
problem_paragraph_mapping_list = []
problem_paragraph_object_list = []
for paragraphs in paragraph_model_dict_list:
paragraph = paragraphs.get('paragraph')
for problem_model in paragraphs.get('problem_model_list'):
problem_model_list.append(problem_model)
for problem_paragraph_mapping in paragraphs.get('problem_paragraph_mapping_list'):
problem_paragraph_mapping_list.append(problem_paragraph_mapping)
for problem_model in paragraphs.get('problem_paragraph_object_list'):
problem_paragraph_object_list.append(problem_model)
paragraph_model_list.append(paragraph)

problem_model_list, problem_paragraph_mapping_list = DocumentSerializers.Create.reset_problem_model(
problem_model_list, problem_paragraph_mapping_list)

return {'document': document_model, 'paragraph_model_list': paragraph_model_list,
'problem_model_list': problem_model_list,
'problem_paragraph_mapping_list': problem_paragraph_mapping_list}

@staticmethod
def reset_problem_model(problem_model_list, problem_paragraph_mapping_list):
new_problem_model_list = [x for i, x in enumerate(problem_model_list) if
len([item for item in problem_model_list[:i] if item.content == x.content]) <= 0]

for new_problem_model in new_problem_model_list:
old_model_list = [problem.id for problem in problem_model_list if
problem.content == new_problem_model.content]
for problem_paragraph_mapping in problem_paragraph_mapping_list:
if old_model_list.__contains__(problem_paragraph_mapping.problem_id):
problem_paragraph_mapping.problem_id = new_problem_model.id
return new_problem_model_list, problem_paragraph_mapping_list
'problem_paragraph_object_list': problem_paragraph_object_list}

@staticmethod
def get_document_paragraph_model(dataset_id, instance: Dict):
Expand Down Expand Up @@ -834,20 +817,20 @@ def batch_save(self, instance_list: List[Dict], with_valid=True):
dataset_id = self.data.get("dataset_id")
document_model_list = []
paragraph_model_list = []
problem_model_list = []
problem_paragraph_mapping_list = []
problem_paragraph_object_list = []
# 插入文档
for document in instance_list:
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
document)
document_model_list.append(document_paragraph_dict_model.get('document'))
for paragraph in document_paragraph_dict_model.get('paragraph_model_list'):
paragraph_model_list.append(paragraph)
for problem in document_paragraph_dict_model.get('problem_model_list'):
problem_model_list.append(problem)
for problem_paragraph_mapping in document_paragraph_dict_model.get('problem_paragraph_mapping_list'):
problem_paragraph_mapping_list.append(problem_paragraph_mapping)
for problem_paragraph_object in document_paragraph_dict_model.get('problem_paragraph_object_list'):
problem_paragraph_object_list.append(problem_paragraph_object)

problem_model_list, problem_paragraph_mapping_list = (ProblemParagraphManage(problem_paragraph_object_list,
dataset_id)
.to_problem_model_list())
# 插入文档
QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None
# 批量插入段落
Expand Down
37 changes: 11 additions & 26 deletions apps/dataset/serializers/paragraph_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from common.util.common import post
from common.util.field_message import ErrMessage
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping
from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer
from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \
ProblemParagraphManage
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
from embedding.models import SourceType

Expand Down Expand Up @@ -567,8 +568,10 @@ def save(self, instance: Dict, with_valid=True, with_embedding=True):
document_id = self.data.get('document_id')
paragraph_problem_model = self.get_paragraph_problem_model(dataset_id, document_id, instance)
paragraph = paragraph_problem_model.get('paragraph')
problem_model_list = paragraph_problem_model.get('problem_model_list')
problem_paragraph_mapping_list = paragraph_problem_model.get('problem_paragraph_mapping_list')
problem_paragraph_object_list = paragraph_problem_model.get('problem_paragraph_object_list')
problem_model_list, problem_paragraph_mapping_list = (ProblemParagraphManage(problem_paragraph_object_list,
dataset_id).
to_problem_model_list())
# 插入段落
paragraph_problem_model.get('paragraph').save()
# 插入問題
Expand All @@ -591,30 +594,12 @@ def get_paragraph_problem_model(dataset_id: str, document_id: str, instance: Dic
content=instance.get("content"),
dataset_id=dataset_id,
title=instance.get("title") if 'title' in instance else '')
problem_list = instance.get('problem_list')
exists_problem_list = []
if 'problem_list' in instance and len(problem_list) > 0:
exists_problem_list = QuerySet(Problem).filter(dataset_id=dataset_id,
content__in=[p.get('content') for p in
problem_list]).all()

problem_model_list = [
ParagraphSerializers.Create.or_get(exists_problem_list, problem.get('content'), dataset_id) for
problem in (
instance.get('problem_list') if 'problem_list' in instance else [])]
# 问题去重
problem_model_list = [x for i, x in enumerate(problem_model_list) if
len([item for item in problem_model_list[:i] if item.content == x.content]) <= 0]

problem_paragraph_mapping_list = [
ProblemParagraphMapping(id=uuid.uuid1(), document_id=document_id, problem_id=problem_model.id,
paragraph_id=paragraph.id,
dataset_id=dataset_id) for
problem_model in problem_model_list]
problem_paragraph_object_list = [
ProblemParagraphObject(dataset_id, document_id, paragraph.id, problem.get('content')) for problem in
(instance.get('problem_list') if 'problem_list' in instance else [])]

return {'paragraph': paragraph,
'problem_model_list': [problem_model for problem_model in problem_model_list if
not list(exists_problem_list).__contains__(problem_model)],
'problem_paragraph_mapping_list': problem_paragraph_mapping_list}
'problem_paragraph_object_list': problem_paragraph_object_list}

@staticmethod
def or_get(exists_problem_list, content, dataset_id):
Expand Down
Loading