Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pr@main@dataset qa #523

Merged
merged 6 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions apps/common/handle/base_parse_qa_handle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: base_parse_qa_handle.py
@date:2024/5/21 14:56
@desc:
"""
from abc import ABC, abstractmethod


class BaseParseQAHandle(ABC):
@abstractmethod
def support(self, file, get_buffer):
pass

@abstractmethod
def handle(self, file, get_buffer):
pass
56 changes: 56 additions & 0 deletions apps/common/handle/impl/qa/csv_parse_qa_handle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: csv_parse_qa_handle.py
@date:2024/5/21 14:59
@desc:
"""
import csv
import io

from charset_normalizer import detect

from common.handle.base_parse_qa_handle import BaseParseQAHandle


def read_csv_standard(file_path):
data = []
with open(file_path, 'r') as file:
reader = csv.reader(file)
for row in reader:
data.append(row)
return data


class CsvParseQAHandle(BaseParseQAHandle):
def support(self, file, get_buffer):
file_name: str = file.name.lower()
if file_name.endswith(".csv"):
return True
return False

def handle(self, file, get_buffer):
buffer = get_buffer(file)
reader = csv.reader(io.TextIOWrapper(io.BytesIO(buffer), encoding=detect(buffer)['encoding']))
try:
title_row_list = reader.__next__()
except Exception as e:
return []
title_row_index_dict = {'title': 0, 'content': 1, 'problem_list': 2}
for index in range(len(title_row_list)):
title_row = title_row_list[index]
if title_row.startswith('分段标题'):
title_row_index_dict['title'] = index
if title_row.startswith('分段内容'):
title_row_index_dict['content'] = index
if title_row.startswith('问题'):
title_row_index_dict['problem_list'] = index
paragraph_list = []
for row in reader:
problem = row[title_row_index_dict.get('problem_list')]
problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0]
paragraph_list.append({'title': row[title_row_index_dict.get('title')][0:255],
'content': row[title_row_index_dict.get('content')][0:4096],
'problem_list': problem_list})
return [{'name': file.name, 'paragraphs': paragraph_list}]
55 changes: 55 additions & 0 deletions apps/common/handle/impl/qa/xls_parse_qa_handle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: xls_parse_qa_handle.py
@date:2024/5/21 14:59
@desc:
"""

import xlrd

from common.handle.base_parse_qa_handle import BaseParseQAHandle


def handle_sheet(file_name, sheet):
rows = iter([sheet.row_values(i) for i in range(sheet.nrows)])
try:
title_row_list = next(rows)
except Exception as e:
return None
title_row_index_dict = {'title': 0, 'content': 1, 'problem_list': 2}
for index in range(len(title_row_list)):
title_row = str(title_row_list[index])
if title_row.startswith('分段标题'):
title_row_index_dict['title'] = index
if title_row.startswith('分段内容'):
title_row_index_dict['content'] = index
if title_row.startswith('问题'):
title_row_index_dict['problem_list'] = index
paragraph_list = []
for row in rows:
problem = str(row[title_row_index_dict.get('problem_list')])
problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0]
paragraph_list.append({'title': str(row[title_row_index_dict.get('title')])[0:255],
'content': str(row[title_row_index_dict.get('content')])[0:4096],
'problem_list': problem_list})
return {'name': file_name, 'paragraphs': paragraph_list}


class XlsParseQAHandle(BaseParseQAHandle):
def support(self, file, get_buffer):
file_name: str = file.name.lower()
if file_name.endswith(".xls"):
return True
return False

def handle(self, file, get_buffer):
buffer = get_buffer(file)
workbook = xlrd.open_workbook(file_contents=buffer)
worksheets = workbook.sheets()
worksheets_size = len(worksheets)
return [row for row in
[handle_sheet(file.name, sheet) if worksheets_size == 1 and sheet.name == 'Sheet1' else handle_sheet(
sheet.name, sheet) for sheet
in worksheets] if row is not None]
56 changes: 56 additions & 0 deletions apps/common/handle/impl/qa/xlsx_parse_qa_handle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: xlsx_parse_qa_handle.py
@date:2024/5/21 14:59
@desc:
"""
import io

import openpyxl

from common.handle.base_parse_qa_handle import BaseParseQAHandle


def handle_sheet(file_name, sheet):
rows = sheet.rows
try:
title_row_list = next(rows)
except Exception as e:
return None
title_row_index_dict = {}
for index in range(len(title_row_list)):
title_row = str(title_row_list[index].value)
if title_row.startswith('分段标题'):
title_row_index_dict['title'] = index
if title_row.startswith('分段内容'):
title_row_index_dict['content'] = index
if title_row.startswith('问题'):
title_row_index_dict['problem_list'] = index
paragraph_list = []
for row in rows:
problem = str(row[title_row_index_dict.get('problem_list')].value)
problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0]
paragraph_list.append({'title': str(row[title_row_index_dict.get('title')].value)[0:255],
'content': str(row[title_row_index_dict.get('content')].value)[0:4096],
'problem_list': problem_list})
return {'name': file_name, 'paragraphs': paragraph_list}


class XlsxParseQAHandle(BaseParseQAHandle):
def support(self, file, get_buffer):
file_name: str = file.name.lower()
if file_name.endswith(".xlsx"):
return True
return False

def handle(self, file, get_buffer):
buffer = get_buffer(file)
workbook = openpyxl.load_workbook(io.BytesIO(buffer))
worksheets = workbook.worksheets
worksheets_size = len(worksheets)
return [row for row in
[handle_sheet(file.name, sheet) if worksheets_size == 1 and sheet.title == 'Sheet1' else handle_sheet(
sheet.title, sheet) for sheet
in worksheets] if row is not None]
12 changes: 12 additions & 0 deletions apps/common/util/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,18 @@ def get_exec_method(clazz_: str, method_: str):
return getattr(getattr(package_model, clazz_name), method_)


def flat_map(array: List[List]):
"""
将二位数组转为一维数组
:param array: 二维数组
:return: 一维数组
"""
result = []
for e in array:
result += e
return result


def post(post_function):
def inner(func):
def run(*args, **kwargs):
Expand Down
10 changes: 10 additions & 0 deletions apps/common/util/field_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,13 @@ def image(field: str):
'invalid_image': gettext_lazy('【%s】上载有效的图像。您上载的文件不是图像或图像已损坏。' % field),
'max_length': gettext_lazy('请确保此文件名最多包含 {max_length} 个字符(长度为 {length})。')
}

@staticmethod
def file(field: str):
return {
'required': gettext_lazy('【%s】此字段必填。' % field),
'empty': gettext_lazy('【%s】提交的文件为空。' % field),
'invalid': gettext_lazy('【%s】提交的数据不是文件。请检查表单上的编码类型。' % field),
'no_name': gettext_lazy('【%s】无法确定任何文件名。' % field),
'max_length': gettext_lazy('请确保此文件名最多包含 {max_length} 个字符(长度为 {length})。')
}
82 changes: 80 additions & 2 deletions apps/dataset/serializers/dataset_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from django.contrib.postgres.fields import ArrayField
from django.core import validators
from django.db import transaction, models
from django.db.models import QuerySet, Q
from django.db.models import QuerySet
from drf_yasg import openapi
from rest_framework import serializers

Expand All @@ -29,7 +29,7 @@
from common.event import ListenerManagement, SyncWebDatasetArgs
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.common import post, flat_map
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.fork import ChildLink, Fork
Expand Down Expand Up @@ -210,6 +210,75 @@ def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
return True

class CreateQASerializers(serializers.Serializer):
"""
创建web站点序列化对象
"""
name = serializers.CharField(required=True,
error_messages=ErrMessage.char("知识库名称"),
max_length=64,
min_length=1)

desc = serializers.CharField(required=True,
error_messages=ErrMessage.char("知识库描述"),
max_length=256,
min_length=1)

file_list = serializers.ListSerializer(required=True,
error_messages=ErrMessage.list("文件列表"),
child=serializers.FileField(required=True,
error_messages=ErrMessage.file("文件")))

@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='file',
in_=openapi.IN_FORM,
type=openapi.TYPE_ARRAY,
items=openapi.Items(type=openapi.TYPE_FILE),
required=True,
description='上传文件'),
openapi.Parameter(name='name',
in_=openapi.IN_FORM,
required=True,
type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
openapi.Parameter(name='desc',
in_=openapi.IN_FORM,
required=True,
type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
]

@staticmethod
def get_response_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['id', 'name', 'desc', 'user_id', 'char_length', 'document_count',
'update_time', 'create_time', 'document_list'],
properties={
'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
description="id", default="xx"),
'name': openapi.Schema(type=openapi.TYPE_STRING, title="名称",
description="名称", default="测试知识库"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="描述",
description="描述", default="测试知识库描述"),
'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="所属用户id",
description="所属用户id", default="user_xxxx"),
'char_length': openapi.Schema(type=openapi.TYPE_STRING, title="字符数",
description="字符数", default=10),
'document_count': openapi.Schema(type=openapi.TYPE_STRING, title="文档数量",
description="文档数量", default=1),
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
description="修改时间",
default="1970-01-01 00:00:00"),
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
description="创建时间",
default="1970-01-01 00:00:00"
),
'document_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档列表",
description="文档列表",
items=DocumentSerializers.Operate.get_response_body_api())
}
)

class CreateWebSerializers(serializers.Serializer):
"""
创建web站点序列化对象
Expand Down Expand Up @@ -288,6 +357,15 @@ def post_embedding_dataset(document_list, dataset_id):
ListenerManagement.embedding_by_dataset_signal.send(dataset_id)
return document_list

def save_qa(self, instance: Dict, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
self.CreateQASerializers(data=instance).is_valid()
file_list = instance.get('file_list')
document_list = flat_map([DocumentSerializers.Create.parse_qa_file(file) for file in file_list])
dataset_instance = {'name': instance.get('name'), 'desc': instance.get('desc'), 'documents': document_list}
return self.save(dataset_instance, with_valid=True)

@post(post_function=post_embedding_dataset)
@transaction.atomic
def save(self, instance: Dict, with_valid=True):
Expand Down
Loading
Loading