Skip to content

feat: 【知识库】docx支持图片上传 #69 #267

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apps/common/handle/base_split_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ def support(self, file, get_buffer):
pass

@abstractmethod
def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer):
def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
pass
91 changes: 80 additions & 11 deletions apps/common/handle/impl/doc_split_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@
"""
import io
import re
import traceback
import uuid
from typing import List

from docx import Document
from docx import Document, ImagePart
from docx.table import Table
from docx.text.paragraph import Paragraph

from common.handle.base_split_handle import BaseSplitHandle
from common.util.split_model import SplitModel
from dataset.models import Image

default_pattern_list = [re.compile('(?<=^)# .*|(?<=\\n)# .*'),
re.compile('(?<=\\n)(?<!#)## (?!#).*|(?<=^)(?<!#)## (?!#).*'),
Expand All @@ -25,28 +28,86 @@
re.compile("(?<=\\n)(?<!#)###### (?!#).*|(?<=^)(?<!#)###### (?!#).*")]


def image_to_mode(image, doc: Document, images_list, get_image_id):
    """Convert a picture element into a Markdown image reference.

    Walks the ``r:embed`` relationship ids found under ``image`` and, for the
    first one whose related part is an ``ImagePart``, returns
    ``![](/api/image/<uuid>)``. An ``Image`` record for that uuid is appended
    to ``images_list`` unless one with the same id is already present.

    :param image: element exposing ``xpath`` (a ``pic:pic`` node at the call site)
    :param doc: the python-docx document whose relationships are consulted
    :param images_list: list of ``Image`` records collected so far; mutated in place
    :param get_image_id: callable mapping a relationship id to the image uuid
    :return: the Markdown reference, or ``None`` when no embedded id resolves
             to an ``ImagePart``
    """
    for embed_id in image.xpath('.//a:blip/@r:embed'):  # relationship id of the picture
        part = doc.part.related_parts[embed_id]  # part the relationship id points at
        if not isinstance(part, ImagePart):
            continue
        image_uuid = get_image_id(embed_id)
        # Only register each uuid once, even if the same picture is referenced repeatedly.
        already_collected = any(existing.id == image_uuid for existing in images_list)
        if not already_collected:
            images_list.append(Image(id=image_uuid, image=part.blob, image_name=part.filename))
        return f'![](/api/image/{image_uuid})'
    return None


def get_paragraph_element_txt(paragraph_element, doc: Document, images_list, get_image_id):
    """Render one child element of a paragraph as Markdown text.

    If the element contains ``pic:pic`` nodes, the result is the concatenation
    of the Markdown image references produced by ``image_to_mode`` (``None``
    results are skipped). Otherwise the element's ``text`` is returned, or an
    empty string when it has none.

    Any exception is printed and turned into an empty string, so a single bad
    element does not abort the conversion.

    :param paragraph_element: child element of a paragraph's XML element
    :param doc: the python-docx document being converted
    :param images_list: list of collected ``Image`` records; passed through
    :param get_image_id: callable mapping a relationship id to the image uuid
    :return: Markdown/text for the element, ``""`` on failure
    """
    try:
        pictures = paragraph_element.xpath(".//pic:pic")
        if len(pictures) == 0:
            # No pictures: fall back to the element's plain text, if any.
            text = paragraph_element.text
            return text if text is not None else ""
        fragments = []
        for picture in pictures:
            markdown = image_to_mode(picture, doc, images_list, get_image_id)
            if markdown is not None:
                fragments.append(markdown)
        return "".join(fragments)
    except Exception as e:
        print(e)
        return ""


def get_paragraph_txt(paragraph: Paragraph, doc: Document, images_list, get_image_id):
    """Render a paragraph as Markdown by joining the text of its child elements.

    Each child of ``paragraph._element`` is converted with
    ``get_paragraph_element_txt`` and the pieces are concatenated without a
    separator. Any exception yields an empty string.

    :param paragraph: the python-docx paragraph to convert
    :param doc: the python-docx document being converted
    :param images_list: list of collected ``Image`` records; passed through
    :param get_image_id: callable mapping a relationship id to the image uuid
    :return: the paragraph's Markdown/text, ``""`` on failure
    """
    try:
        pieces = []
        for child in paragraph._element:
            pieces.append(get_paragraph_element_txt(child, doc, images_list, get_image_id))
        return "".join(pieces)
    except Exception:
        return ""


def get_cell_text(cell, doc: Document, images_list, get_image_id):
    """Render a table cell as a single-line Markdown fragment.

    The text of every paragraph in ``cell.paragraphs`` is concatenated, and
    newlines are replaced with ``</br>`` so the content stays on one table row.
    Any exception yields an empty string.

    :param cell: table cell exposing ``paragraphs``
    :param doc: the python-docx document being converted
    :param images_list: list of collected ``Image`` records; passed through
    :param get_image_id: callable mapping a relationship id to the image uuid
    :return: the cell's Markdown/text, ``""`` on failure
    """
    try:
        cell_text = "".join(
            get_paragraph_txt(paragraph, doc, images_list, get_image_id)
            for paragraph in cell.paragraphs
        )
        return cell_text.replace("\n", '</br>')
    except Exception:
        return ""


def get_image_id_func():
    """Build a resolver that assigns one stable uuid per image id.

    The returned callable keeps its own private mapping: the first time an id
    is seen a new ``uuid.uuid1()`` is generated for it, and every later call
    with the same id returns that same uuid. Separate resolvers do not share
    state.

    :return: ``get_image_id(image_id) -> uuid.UUID``
    """
    assigned_ids = {}

    def get_image_id(image_id):
        # Allocate lazily so only ids that are actually requested get a uuid.
        if image_id not in assigned_ids:
            assigned_ids[image_id] = uuid.uuid1()
        return assigned_ids[image_id]

    return get_image_id


class DocSplitHandle(BaseSplitHandle):
@staticmethod
def paragraph_to_md(paragraph):
def paragraph_to_md(paragraph: Paragraph, doc: Document, images_list, get_image_id):
try:
psn = paragraph.style.name
if psn.startswith('Heading'):
return "".join(["#" for i in range(int(psn.replace("Heading ", '')))]) + " " + paragraph.text
except Exception as e:
return paragraph.text
return paragraph.text
return get_paragraph_txt(paragraph, doc, images_list, get_image_id)

@staticmethod
def table_to_md(table):
def table_to_md(table, doc: Document, images_list, get_image_id):
rows = table.rows

# 创建 Markdown 格式的表格
md_table = '| ' + ' | '.join([cell.text.replace("\n", '</br>') for cell in rows[0].cells]) + ' |\n'
md_table = '| ' + ' | '.join(
[get_cell_text(cell, doc, images_list, get_image_id) for cell in rows[0].cells]) + ' |\n'
md_table += '| ' + ' | '.join(['---' for i in range(len(rows[0].cells))]) + ' |\n'
for row in rows[1:]:
md_table += '| ' + ' | '.join([cell.text.replace("\n", '</br>') for cell in row.cells]) + ' |\n'
md_table += '| ' + ' | '.join(
[get_cell_text(cell, doc, images_list, get_image_id) for cell in row.cells]) + ' |\n'
return md_table

def to_md(self, doc):
def to_md(self, doc, images_list, get_image_id):
elements = []
for element in doc.element.body:
if element.tag.endswith('tbl'):
Expand All @@ -57,21 +118,29 @@ def to_md(self, doc):
# 处理段落
paragraph = Paragraph(element, doc)
elements.append(paragraph)

return "\n".join(
[self.paragraph_to_md(element) if isinstance(element, Paragraph) else self.table_to_md(element) for element
[self.paragraph_to_md(element, doc, images_list, get_image_id) if isinstance(element,
Paragraph) else self.table_to_md(
element,
doc,
images_list, get_image_id)
for element
in elements])

def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer):
def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
try:
image_list = []
buffer = get_buffer(file)
doc = Document(io.BytesIO(buffer))
content = self.to_md(doc)
content = self.to_md(doc, image_list, get_image_id_func())
if len(image_list) > 0:
save_image(image_list)
if pattern_list is not None and len(pattern_list) > 0:
split_model = SplitModel(pattern_list, with_filter, limit)
else:
split_model = SplitModel(default_pattern_list, with_filter=with_filter, limit=limit)
except BaseException as e:
traceback.print_exception(e)
return {'name': file.name,
'content': []}
return {'name': file.name,
Expand Down
2 changes: 1 addition & 1 deletion apps/common/handle/impl/pdf_split_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def number_to_text(pdf_document, page_number):


class PdfSplitHandle(BaseSplitHandle):
def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer):
def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer,save_image):
try:
buffer = get_buffer(file)
pdf_document = fitz.open(file.name, buffer)
Expand Down
2 changes: 1 addition & 1 deletion apps/common/handle/impl/text_split_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def support(self, file, get_buffer):
return True
return False

def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer):
def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
buffer = get_buffer(file)
if pattern_list is not None and len(pattern_list) > 0:
split_model = SplitModel(pattern_list, with_filter, limit)
Expand Down
10 changes: 7 additions & 3 deletions apps/dataset/serializers/document_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from common.util.file_util import get_file_content
from common.util.fork import Fork
from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from smartdoc.conf import PROJECT_DIR
Expand Down Expand Up @@ -627,9 +627,13 @@ def get_buffer(self, file):
split_handles = [DocSplitHandle(), PdfSplitHandle(), default_split_handle]


def save_image(image_list):
    """Persist the collected ``Image`` records via ``QuerySet.bulk_create``.

    :param image_list: ``Image`` instances to store
    """
    image_query_set = QuerySet(Image)
    image_query_set.bulk_create(image_list)


def file_to_paragraph(file, pattern_list: List, with_filter: bool, limit: int):
get_buffer = FileBufferHandle().get_buffer
for split_handle in split_handles:
if split_handle.support(file, get_buffer):
return split_handle.handle(file, pattern_list, with_filter, limit, get_buffer)
return default_split_handle.handle(file, pattern_list, with_filter, limit, get_buffer)
return split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, save_image)
return default_split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, save_image)