Commit

Add endpoint to save xml
gabriel-piles committed Jul 2, 2024
1 parent 0ac9430 commit fea5c60
Showing 5 changed files with 49 additions and 9 deletions.
13 changes: 10 additions & 3 deletions src/PdfImages.py
@@ -1,3 +1,4 @@
+import os
 import shutil
 
 import cv2
@@ -9,7 +10,7 @@
 from pdf2image import convert_from_path
 from pdf_features.PdfFeatures import PdfFeatures
 
-from src.configuration import IMAGES_ROOT_PATH
+from src.configuration import IMAGES_ROOT_PATH, XMLS_PATH
 
 
 class PdfImages:
@@ -36,8 +37,14 @@ def remove_images():
         shutil.rmtree(IMAGES_ROOT_PATH)
 
     @staticmethod
-    def from_pdf_path(pdf_path: str, pdf_name: str = ""):
-        pdf_features: PdfFeatures = PdfFeatures.from_pdf_path(pdf_path)
+    def from_pdf_path(pdf_path: str | Path, pdf_name: str = "", xml_name: str = ""):
+        xml_path = Path(join(XMLS_PATH, xml_name)) if xml_name else None
+
+        if xml_path and not xml_path.parent.exists():
+            os.makedirs(xml_path.parent, exist_ok=True)
+
+        pdf_features: PdfFeatures = PdfFeatures.from_pdf_path(pdf_path, str(xml_path) if xml_path else None)
+
         if pdf_name:
             pdf_features.file_name = pdf_name
         else:
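With this change, from_pdf_path can optionally keep the PDF token XML while it builds the page images. A minimal sketch of how the new signature might be called; the file names below are placeholders, and the second argument to PdfFeatures.from_pdf_path is presumed to be the XML output path, as the diff suggests:

# Hypothetical usage; "document.pdf" and "document.xml" are placeholder names.
pdf_images = PdfImages.from_pdf_path("document.pdf", pdf_name="document", xml_name="document.xml")
# Passing xml_name makes PdfFeatures.from_pdf_path receive XMLS_PATH/document.xml,
# so the token XML is presumably written there instead of being discarded.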
4 changes: 2 additions & 2 deletions src/analyze_pdf.py
@@ -49,10 +49,10 @@ def predict_doclaynet():
     VGTTrainer.test(configuration, model)
 
 
-def analyze_pdf(file: AnyStr):
+def analyze_pdf(file: AnyStr, xml_name: str = "") -> list[dict]:
     pdf_path = pdf_content_to_pdf_path(file)
     service_logger.info(f"Creating PDF images")
-    pdf_images_list: list[PdfImages] = [PdfImages.from_pdf_path(pdf_path)]
+    pdf_images_list: list[PdfImages] = [PdfImages.from_pdf_path(pdf_path, "", xml_name)]
     create_word_grid([pdf_images.pdf_features for pdf_images in pdf_images_list])
     get_annotations(pdf_images_list)
     predict_doclaynet()
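analyze_pdf now just threads the optional XML name through to PdfImages.from_pdf_path. A sketch of a direct call, assuming the PDF has been read as bytes ("sample.pdf" and "sample.xml" are placeholder names):

# Hypothetical direct call; returns the segment dictionaries and, because an xml_name
# is given, presumably leaves the token XML under XMLS_PATH/sample.xml.
with open("sample.pdf", "rb") as pdf_file:
    segments = analyze_pdf(pdf_file.read(), "sample.xml")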
33 changes: 32 additions & 1 deletion src/app.py
@@ -1,10 +1,14 @@
+import os
 import sys
+from os.path import join
+from pathlib import Path
 
 import torch
 from fastapi import FastAPI, HTTPException, UploadFile, File
+from fastapi.responses import PlainTextResponse
 
 from analyze_pdf_fast import analyze_pdf_fast
-from configuration import service_logger
+from configuration import service_logger, XMLS_PATH
 from src.analyze_pdf import analyze_pdf
 
 service_logger.info(f"Is PyTorch using GPU: {torch.cuda.is_available()}")
@@ -27,6 +31,33 @@ async def run(file: UploadFile = File(...)):
         raise HTTPException(status_code=422, detail="Error extracting paragraphs")
 
 
+@app.post("/save_xml/{xml_file_name}")
+async def analyze_and_save_xml(file: UploadFile = File(...), xml_file_name: str = None):
+    try:
+        service_logger.info(f"Processing file: {file.filename}")
+        service_logger.info(f"Saving xml: {xml_file_name}")
+        return analyze_pdf(file.file.read(), xml_file_name)
+    except Exception:
+        service_logger.error("Error", exc_info=1)
+        raise HTTPException(status_code=422, detail="Error extracting paragraphs")
+
+
+@app.get("/get_xml/{xml_file_name}", response_class=PlainTextResponse)
+async def get_xml(xml_file_name: str):
+    try:
+        xml_file_path = Path(join(XMLS_PATH, xml_file_name))
+
+        with open(xml_file_path, mode="r") as file:
+            content = file.read()
+            os.remove(xml_file_path)
+            return content
+    except FileNotFoundError:
+        raise HTTPException(status_code=404, detail="No xml file")
+    except Exception:
+        service_logger.error("Error", exc_info=1)
+        raise HTTPException(status_code=422, detail="An error has occurred. Check graylog for more info")
+
+
 @app.post("/fast")
 async def run_fast(file: UploadFile = File(...)):
     try:
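A client-side sketch of the two new endpoints using the requests library; the base URL below is an assumption, not part of this commit, so substitute whatever host and port the service actually runs on:

import requests

BASE_URL = "http://localhost:5060"  # assumed host/port

# Analyze a PDF and ask the service to keep the intermediate XML under the given name.
with open("sample.pdf", "rb") as pdf_file:
    segments = requests.post(f"{BASE_URL}/save_xml/sample.xml", files={"file": pdf_file}).json()

# Fetch the stored XML as plain text. get_xml deletes the file after reading it,
# so a second request for the same name should return 404.
xml_content = requests.get(f"{BASE_URL}/get_xml/sample.xml").text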
1 change: 1 addition & 0 deletions src/configuration.py
@@ -16,6 +16,7 @@
 JSONS_ROOT_PATH = Path(join(ROOT_PATH, "jsons"))
 JSON_TEST_FILE_PATH = Path(join(JSONS_ROOT_PATH, "test.json"))
 MODELS_PATH = Path(join(ROOT_PATH, "models"))
+XMLS_PATH = Path(join(ROOT_PATH, "xmls"))
 
 DOCLAYNET_TYPE_BY_ID = {
     1: "Caption",
7 changes: 4 additions & 3 deletions src/data_model/SegmentBox.py
@@ -1,4 +1,5 @@
 from paragraph_extraction_trainer.PdfSegment import PdfSegment
+from pdf_features.Rectangle import Rectangle
 from pdf_token_type_labels.TokenType import TokenType
 from pydantic import BaseModel
 
@@ -12,7 +13,7 @@ class SegmentBox(BaseModel):
     height: float
     page_number: int
     text: str = ""
-    type: int = 0
+    type: TokenType = TokenType.TEXT
 
     def to_dict(self):
         return {
@@ -22,7 +23,7 @@ def to_dict(self):
             "height": self.height,
             "page_number": self.page_number,
             "text": self.text,
-            "type": TokenType.from_index(self.type).name,
+            "type": self.type.name,
         }
 
     @staticmethod
@staticmethod
@@ -34,5 +35,5 @@ def from_pdf_segment(pdf_segment: PdfSegment):
             height=pdf_segment.bounding_box.height,
             page_number=pdf_segment.page_number,
             text=pdf_segment.text_content,
-            type=pdf_segment.segment_type if type(pdf_segment.segment_type) is int else pdf_segment.segment_type.get_index(),
+            type=pdf_segment.segment_type,
         )
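Since type is now a TokenType rather than a bare int, to_dict can use the enum name directly. A small illustrative sketch; the left, top and width field names are assumptions, as they are not visible in this hunk:

# Illustrative only; left/top/width are assumed to exist on the model.
segment_box = SegmentBox(left=10.0, top=20.0, width=100.0, height=15.0,
                         page_number=1, text="Example paragraph", type=TokenType.TEXT)
print(segment_box.to_dict()["type"])  # -> "TEXT"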
