From b47e90c34d2f1b9d62e231f24a127f2e12a68ca8 Mon Sep 17 00:00:00 2001 From: Gabo Date: Tue, 16 Jul 2024 11:49:40 +0200 Subject: [PATCH] Handle concurrency --- src/app.py | 34 +++++++++++-------- src/pdf_layout_analysis/get_xml.py | 14 ++++++++ .../run_pdf_layout_analysis.py | 2 +- 3 files changed, 34 insertions(+), 16 deletions(-) create mode 100644 src/pdf_layout_analysis/get_xml.py diff --git a/src/app.py b/src/app.py index 02f3a70..d0db634 100755 --- a/src/app.py +++ b/src/app.py @@ -1,13 +1,12 @@ -import os import sys -from os.path import join -from pathlib import Path import torch from fastapi import FastAPI, UploadFile, File, Form from fastapi.responses import PlainTextResponse +from starlette.concurrency import run_in_threadpool from catch_exceptions import catch_exceptions -from configuration import service_logger, XMLS_PATH +from configuration import service_logger +from pdf_layout_analysis.get_xml import get_xml from pdf_layout_analysis.run_pdf_layout_analysis import analyze_pdf from pdf_layout_analysis.run_pdf_layout_analysis_fast import analyze_pdf_fast from toc.extract_table_of_contents import extract_table_of_contents @@ -19,31 +18,36 @@ @app.get("/") async def info(): - return sys.version + return sys.version + " Using GPU: " + str(torch.cuda.is_available()) + + +@app.get("/error") +async def error(): + raise FileNotFoundError("This is a test error from the error endpoint") @app.post("/") @catch_exceptions -async def run(file: UploadFile = File(...), fast: bool = Form(False)): - service_logger.info(f"Processing file: {file.filename}") - return analyze_pdf_fast(file.file.read()) if fast else analyze_pdf(file.file.read()) +async def run(file: UploadFile = File(...)): + return await run_in_threadpool(analyze_pdf, file.file.read(), "") @app.post("/save_xml/{xml_file_name}") @catch_exceptions async def analyze_and_save_xml(file: UploadFile = File(...), xml_file_name: str | None = None): - return analyze_pdf(file.file.read(), xml_file_name) + return await run_in_threadpool(analyze_pdf, file.file.read(), xml_file_name) @app.get("/get_xml/{xml_file_name}", response_class=PlainTextResponse) @catch_exceptions -async def get_xml(xml_file_name: str): - xml_file_path = Path(join(XMLS_PATH, xml_file_name)) +async def get_xml_by_name(xml_file_name: str): + return await run_in_threadpool(get_xml, xml_file_name) - with open(xml_file_path, mode="r") as file: - content = file.read() - os.remove(xml_file_path) - return content + +@app.post("/fast") +@catch_exceptions +async def run_fast(file: UploadFile = File(...)): + return await run_in_threadpool(analyze_pdf_fast, file.file.read()) @app.post("/toc") diff --git a/src/pdf_layout_analysis/get_xml.py b/src/pdf_layout_analysis/get_xml.py new file mode 100644 index 0000000..78ed55f --- /dev/null +++ b/src/pdf_layout_analysis/get_xml.py @@ -0,0 +1,14 @@ +import os +from os.path import join +from pathlib import Path + +from configuration import XMLS_PATH + + +def get_xml(xml_file_name: str) -> str: + xml_file_path = Path(join(XMLS_PATH, xml_file_name)) + + with open(xml_file_path, mode="r") as file: + content = file.read() + os.remove(xml_file_path) + return content diff --git a/src/pdf_layout_analysis/run_pdf_layout_analysis.py b/src/pdf_layout_analysis/run_pdf_layout_analysis.py index afbde7b..8195868 100644 --- a/src/pdf_layout_analysis/run_pdf_layout_analysis.py +++ b/src/pdf_layout_analysis/run_pdf_layout_analysis.py @@ -49,7 +49,7 @@ def predict_doclaynet(): VGTTrainer.test(configuration, model) -def analyze_pdf(file: AnyStr, xml_file_name: str = "") -> list[dict]: +def analyze_pdf(file: AnyStr, xml_file_name: str) -> list[dict]: pdf_path = pdf_content_to_pdf_path(file) service_logger.info(f"Creating PDF images") pdf_images_list: list[PdfImages] = [PdfImages.from_pdf_path(pdf_path, "", xml_file_name)]