document migrate #268

Merged 3 commits on Apr 26, 2024
11 changes: 11 additions & 0 deletions apps/common/event/listener_manage.py
@@ -50,6 +50,12 @@ def __init__(self, problem_id: str, problem_content: str):
self.problem_content = problem_content


class UpdateEmbeddingDatasetIdArgs:
def __init__(self, source_id_list: List[str], target_dataset_id: str):
self.source_id_list = source_id_list
self.target_dataset_id = target_dataset_id


class ListenerManagement:
embedding_by_problem_signal = signal("embedding_by_problem")
embedding_by_paragraph_signal = signal("embedding_by_paragraph")
@@ -205,6 +211,11 @@ def update_problem(args: UpdateProblemArgs):
VectorStore.get_embedding_vector().update_by_source_ids([v.id for v in problem_paragraph_mapping_list],
{'embedding': embed_value})

@staticmethod
def update_embedding_dataset_id(args: UpdateEmbeddingDatasetIdArgs):
VectorStore.get_embedding_vector().update_by_source_ids(args.source_id_list,
{'dataset_id': args.target_dataset_id})

@staticmethod
def delete_embedding_by_source_ids(source_ids: List[str]):
VectorStore.get_embedding_vector().delete_by_source_ids(source_ids, SourceType.PROBLEM)
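A minimal sketch of how the new handler is expected to be invoked (it mirrors the call added in DocumentSerializers.Migrate below); the ids are placeholders, and a configured vector store backend is assumed:

from common.event.listener_manage import ListenerManagement, UpdateEmbeddingDatasetIdArgs

# Placeholders for illustration only: the real call passes the ProblemParagraphMapping
# ids of the documents being moved and the id of the destination dataset.
mapping_ids = ['f0b13b2e-0000-0000-0000-000000000001']
target_dataset_id = 'f0b13b2e-0000-0000-0000-000000000002'

# Re-points the dataset_id stored on the matching embedding rows in the vector store.
ListenerManagement.update_embedding_dataset_id(
    UpdateEmbeddingDatasetIdArgs(source_id_list=mapping_ids, target_dataset_id=target_dataset_id))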
102 changes: 101 additions & 1 deletion apps/dataset/serializers/document_serializers.py
@@ -22,7 +22,7 @@

from common.db.search import native_search, native_page_search
from common.event.common import work_thread_pool
from common.event.listener_manage import ListenerManagement, SyncWebDocumentArgs
from common.event.listener_manage import ListenerManagement, SyncWebDocumentArgs, UpdateEmbeddingDatasetIdArgs
from common.exception.app_exception import AppApiException
from common.handle.impl.doc_split_handle import DocSplitHandle
from common.handle.impl.pdf_split_handle import PdfSplitHandle
@@ -114,6 +114,106 @@ def get_request_body_api():


class DocumentSerializers(ApiMixin, serializers.Serializer):
class Migrate(ApiMixin, serializers.Serializer):
dataset_id = serializers.UUIDField(required=True,
error_messages=ErrMessage.char(
"知识库id"))
target_dataset_id = serializers.UUIDField(required=True,
error_messages=ErrMessage.char(
"目标知识库id"))
document_id_list = serializers.ListField(required=True, error_messages=ErrMessage.char("文档列表"),
child=serializers.UUIDField(required=True,
error_messages=ErrMessage.uuid("文档id")))

@transaction.atomic
def migrate(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
dataset_id = self.data.get('dataset_id')
target_dataset_id = self.data.get('target_dataset_id')
dataset = QuerySet(DataSet).filter(id=dataset_id).first()
target_dataset = QuerySet(DataSet).filter(id=target_dataset_id).first()
document_id_list = self.data.get('document_id_list')
document_list = QuerySet(Document).filter(dataset_id=dataset_id, id__in=document_id_list)
paragraph_list = QuerySet(Paragraph).filter(dataset_id=dataset_id, document_id__in=document_id_list)

problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(paragraph__in=paragraph_list)
problem_list = QuerySet(Problem).filter(
id__in=[problem_paragraph_mapping.problem_id for problem_paragraph_mapping in
problem_paragraph_mapping_list])
target_problem_list = list(
QuerySet(Problem).filter(content__in=[problem.content for problem in problem_list],
dataset_id=target_dataset_id))
target_handle_problem_list = [
self.get_target_dataset_problem(target_dataset_id, problem_paragraph_mapping,
problem_list, target_problem_list) for
problem_paragraph_mapping
in
problem_paragraph_mapping_list]

create_problem_list = [problem for problem, is_create in target_handle_problem_list if
is_create is not None and is_create]
# insert problems
QuerySet(Problem).bulk_create(create_problem_list)
# update mappings
QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list, ['problem_id', 'dataset_id'])
# update documents
if dataset.type == Type.base.value and target_dataset.type == Type.web.value:
document_list.update(dataset_id=target_dataset_id, type=Type.web,
meta={'source_url': '', 'selector': ''})
elif target_dataset.type == Type.base.value and dataset.type == Type.web.value:
document_list.update(dataset_id=target_dataset_id, type=Type.base,
meta={})
paragraph_list.update(dataset_id=target_dataset_id)
ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs(
[problem_paragraph_mapping.id for problem_paragraph_mapping in problem_paragraph_mapping_list],
target_dataset_id))

@staticmethod
def get_target_dataset_problem(target_dataset_id: str,
problem_paragraph_mapping,
source_problem_list,
target_problem_list):
source_problem_list = [source_problem for source_problem in source_problem_list if
source_problem.id == problem_paragraph_mapping.problem_id]
problem_paragraph_mapping.dataset_id = target_dataset_id
if len(source_problem_list) > 0:
problem_content = source_problem_list[-1].content
problem_list = [problem for problem in target_problem_list if problem.content == problem_content]
if len(problem_list) > 0:
problem = problem_list[-1]
problem_paragraph_mapping.problem_id = problem.id
return problem, False
else:
problem = Problem(id=uuid.uuid1(), dataset_id=target_dataset_id, content=problem_content)
target_problem_list.append(problem)
problem_paragraph_mapping.problem_id = problem.id
return problem, True
return None, None

@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='dataset_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='知识库id'),
openapi.Parameter(name='target_dataset_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='目标知识库id')
]

@staticmethod
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_ARRAY,
items=openapi.Schema(type=openapi.TYPE_STRING),
title='文档id列表',
description="文档id列表"
)

class Query(ApiMixin, serializers.Serializer):
# knowledge base id
dataset_id = serializers.UUIDField(required=True,
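A minimal sketch of driving the new serializer directly (e.g. from a Django shell), assuming both datasets and the documents already exist; the id values below are placeholders, not taken from this PR:

from dataset.serializers.document_serializers import DocumentSerializers

# All three values are placeholders for illustration; they must reference existing rows.
source_dataset_id = 'f0b13b2e-0000-0000-0000-000000000001'
target_dataset_id = 'f0b13b2e-0000-0000-0000-000000000002'
document_id_list = ['f0b13b2e-0000-0000-0000-000000000003']

DocumentSerializers.Migrate(data={'dataset_id': source_dataset_id,
                                  'target_dataset_id': target_dataset_id,
                                  'document_id_list': document_id_list}).migrate()
# migrate() validates the input first, then moves the documents, paragraphs and
# problem mappings, and finally re-points the embedding rows via ListenerManagement.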
1 change: 1 addition & 0 deletions apps/dataset/urls.py
@@ -21,6 +21,7 @@
name="document_operate"),
path('dataset/document/split_pattern', views.Document.SplitPattern.as_view(),
name="document_operate"),
path('dataset/<str:dataset_id>/document/migrate/<str:target_dataset_id>', views.Document.Migrate.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/refresh', views.Document.Refresh.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph', views.Paragraph.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<int:current_page>/<int:page_size>',
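For reference, a hedged sketch of what a raw client call to the new route could look like; the host, the '/api' prefix and the authorization header format are assumptions here, not something this PR defines:

import requests

# Placeholders: substitute real ids and a valid token issued by the application.
dataset_id = 'f0b13b2e-0000-0000-0000-000000000001'
target_dataset_id = 'f0b13b2e-0000-0000-0000-000000000002'
document_id_list = ['f0b13b2e-0000-0000-0000-000000000003']

resp = requests.put(
    f'http://localhost:8080/api/dataset/{dataset_id}/document/migrate/{target_dataset_id}',
    json=document_id_list,                     # request body is a plain JSON array of document ids
    headers={'AUTHORIZATION': 'user_token'})   # assumed TokenAuth header; format may differ
resp.raise_for_status()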
28 changes: 27 additions & 1 deletion apps/dataset/views/document.py
@@ -14,7 +14,7 @@
from rest_framework.views import Request

from common.auth import TokenAuth, has_permissions
from common.constants.permission_constants import Permission, Group, Operate
from common.constants.permission_constants import Permission, Group, Operate, CompareConstants
from common.response import result
from common.util.common import query_params_to_single_dict
from dataset.serializers.common_serializers import BatchSerializer
@@ -135,6 +135,32 @@ def put(self, request: Request, dataset_id: str, document_id: str):
DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).refresh(
))

class Migrate(APIView):
authentication_classes = [TokenAuth]

@action(methods=['PUT'], detail=False)
@swagger_auto_schema(operation_summary="批量迁移文档",
operation_id="批量迁移文档",
manual_parameters=DocumentSerializers.Migrate.get_request_params_api(),
request_body=DocumentSerializers.Migrate.get_request_body_api(),
responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()),
tags=["知识库/文档"]
)
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')),
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('target_dataset_id')),
compare=CompareConstants.AND
)
def put(self, request: Request, dataset_id: str, target_dataset_id: str):
return result.success(
DocumentSerializers.Migrate(
data={'dataset_id': dataset_id, 'target_dataset_id': target_dataset_id,
'document_id_list': request.data}).migrate(

))

class Operate(APIView):
authentication_classes = [TokenAuth]

21 changes: 20 additions & 1 deletion ui/src/api/document.ts
@@ -188,6 +188,24 @@ const postWebDocument: (
return post(`${prefix}/${dataset_id}/document/web`, data, undefined, loading)
}

/**
* Migrate documents in batch
* @param dataset_id, target_dataset_id, data
*/
const putMigrateMulDocument: (
dataset_id: string,
target_dataset_id: string,
data: any,
loading?: Ref<boolean>
) => Promise<Result<boolean>> = (dataset_id, target_dataset_id, data, loading) => {
return put(
`${prefix}/${dataset_id}/document/migrate/${target_dataset_id}`,
data,
undefined,
loading
)
}

export default {
postSplitDocument,
getDocument,
@@ -200,5 +218,6 @@ export default {
listSplitPattern,
putDocumentRefresh,
delMulSyncDocument,
postWebDocument
postWebDocument,
putMigrateMulDocument
}
25 changes: 25 additions & 0 deletions ui/src/components/icons/index.ts
@@ -774,5 +774,30 @@ export const iconMap: any = {
)
])
}
},
'app-migrate': {
iconReader: () => {
return h('i', [
h(
'svg',
{
style: { height: '100%', width: '100%' },
viewBox: '0 0 1024 1024',
version: '1.1',
xmlns: 'http://www.w3.org/2000/svg'
},
[
h('path', {
d: 'M537.6 665.6c-12.8 12.8-12.8 32 0 44.8 6.4 6.4 12.8 6.4 25.6 6.4 6.4 0 19.2 0 25.6-6.4l128-134.4s6.4-6.4 6.4-12.8v-19.2-6.4c0-6.4-6.4-12.8-6.4-12.8l-134.4-128c-12.8-12.8-32-12.8-44.8 0-12.8 12.8-12.8 38.4 0 51.2l76.8 76.8H96c-19.2 0-32 12.8-32 32s12.8 32 32 32h524.8l-83.2 76.8z',
fill: 'currentColor'
}),
h('path', {
d: 'M960 384c0-6.4-6.4-12.8-6.4-19.2L704 128c-6.4-6.4-6.4-6.4-12.8-6.4h-6.4-371.2c-76.8 0-140.8 64-140.8 140.8v172.8c0 19.2 12.8 32 32 32s25.6-19.2 25.6-38.4V262.4c0-44.8 38.4-76.8 76.8-76.8h339.2v211.2c0 19.2 12.8 32 32 32H896V768c0 44.8-38.4 76.8-76.8 76.8H313.6c-44.8 0-76.8-38.4-76.8-76.8v-89.6c0-19.2-12.8-32-32-32s-32 12.8-32 32V768c0 76.8 64 140.8 140.8 140.8h505.6c76.8 0 140.8-64 140.8-140.8V384c0 6.4 0 6.4 0 0z m-243.2-25.6V224l134.4 134.4h-134.4z',
fill: 'currentColor'
})
]
)
])
}
}
}
125 changes: 125 additions & 0 deletions ui/src/views/document/component/SelectDatasetDialog.vue
@@ -0,0 +1,125 @@
<template>
<el-dialog title="选择知识库" v-model="dialogVisible" width="600" class="select-dataset-dialog">
<template #header="{ titleId, titleClass }">
<div class="my-header flex">
<h4 :id="titleId" :class="titleClass">选择知识库</h4>
<el-button link class="ml-16" @click="refresh">
<el-icon class="mr-4"><Refresh /></el-icon>刷新
</el-button>
</div>
</template>
<div class="content-height">
<el-radio-group v-model="selectDataset" class="card__radio">
<el-scrollbar height="500">
<div class="p-16">
<el-row :gutter="12" v-loading="loading">
<el-col :span="12" v-for="(item, index) in datasetList" :key="index" class="mb-16">
<el-card
shadow="never"
class="mb-8"
:class="item.id === selectDataset ? 'active' : ''"
>
<el-radio :value="item.id" size="large">
<div class="flex align-center">
<AppAvatar v-if="item?.type === '0'" class="mr-8" shape="square" :size="32">
<img src="@/assets/icon_document.svg" style="width: 58%" alt="" />
</AppAvatar>
<AppAvatar
v-if="item?.type === '1'"
class="mr-8 avatar-purple"
shape="square"
:size="32"
>
<img src="@/assets/icon_web.svg" style="width: 58%" alt="" />
</AppAvatar>
<span class="ellipsis">
{{ item.name }}
</span>
</div>
</el-radio>
</el-card>
</el-col>
</el-row>
</div>
</el-scrollbar>
</el-radio-group>
</div>
<template #footer>
<span class="dialog-footer">
<el-button @click.prevent="dialogVisible = false"> 取消 </el-button>
<el-button type="primary" @click="submitHandle" :disabled="!selectDataset || loading">
确认
</el-button>
</span>
</template>
</el-dialog>
</template>
<script setup lang="ts">
import { ref, watch } from 'vue'
import { useRoute } from 'vue-router'
import documentApi from '@/api/document'

import useStore from '@/stores'
const { dataset } = useStore()
const route = useRoute()
const {
params: { id } // id is the dataset id
} = route as any

const emit = defineEmits(['refresh'])

const loading = ref<boolean>(false)

const dialogVisible = ref<boolean>(false)
const selectDataset = ref('')
const datasetList = ref<any>([])
const documentList = ref<any>([])

watch(dialogVisible, (bool) => {
if (!bool) {
selectDataset.value = ''
datasetList.value = []
documentList.value = []
}
})

const open = (list: any) => {
documentList.value = list
getDataset()
dialogVisible.value = true
}
const submitHandle = () => {
documentApi
.putMigrateMulDocument(id, selectDataset.value, documentList.value, loading)
.then((res) => {
emit('refresh')
dialogVisible.value = false
})
}

function getDataset() {
dataset.asyncGetAllDataset(loading).then((res: any) => {
datasetList.value = res.data
})
}

const refresh = () => {
getDataset()
}

defineExpose({ open })
</script>
<style lang="scss" scope>
.select-dataset-dialog {
padding: 0;
.el-dialog__header {
padding: 24px 24px 0 24px;
}
.el-dialog__body {
padding: 8px !important;
}
.el-dialog__footer {
padding: 0 24px 24px;
}
}
</style>