Skip to content

Commit

Permalink
Handle concurrency
Browse files Browse the repository at this point in the history
  • Loading branch information
gabriel-piles committed Jul 16, 2024
1 parent 4624511 commit b47e90c
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 16 deletions.
34 changes: 19 additions & 15 deletions src/app.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import os
import sys
from os.path import join
from pathlib import Path
import torch
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import PlainTextResponse
from starlette.concurrency import run_in_threadpool

from catch_exceptions import catch_exceptions
from configuration import service_logger, XMLS_PATH
from configuration import service_logger
from pdf_layout_analysis.get_xml import get_xml
from pdf_layout_analysis.run_pdf_layout_analysis import analyze_pdf
from pdf_layout_analysis.run_pdf_layout_analysis_fast import analyze_pdf_fast
from toc.extract_table_of_contents import extract_table_of_contents
Expand All @@ -19,31 +18,36 @@

@app.get("/")
async def info():
return sys.version
return sys.version + " Using GPU: " + str(torch.cuda.is_available())


@app.get("/error")
async def error():
    """Always raise a FileNotFoundError whose message marks it as a test error."""
    raise FileNotFoundError("This is a test error from the error endpoint")


@app.post("/")
@catch_exceptions
async def run(file: UploadFile = File(...), fast: bool = Form(False)):
service_logger.info(f"Processing file: {file.filename}")
return analyze_pdf_fast(file.file.read()) if fast else analyze_pdf(file.file.read())
async def run(file: UploadFile = File(...)):
return await run_in_threadpool(analyze_pdf, file.file.read(), "")


@app.post("/save_xml/{xml_file_name}")
@catch_exceptions
async def analyze_and_save_xml(file: UploadFile = File(...), xml_file_name: str | None = None):
return analyze_pdf(file.file.read(), xml_file_name)
return await run_in_threadpool(analyze_pdf, file.file.read(), xml_file_name)


@app.get("/get_xml/{xml_file_name}", response_class=PlainTextResponse)
@catch_exceptions
async def get_xml(xml_file_name: str):
xml_file_path = Path(join(XMLS_PATH, xml_file_name))
async def get_xml_by_name(xml_file_name: str):
return await run_in_threadpool(get_xml, xml_file_name)

with open(xml_file_path, mode="r") as file:
content = file.read()
os.remove(xml_file_path)
return content

@app.post("/fast")
@catch_exceptions
async def run_fast(file: UploadFile = File(...)):
    """Analyze an uploaded PDF with ``analyze_pdf_fast``.

    Args:
        file: The uploaded PDF.

    Returns:
        Whatever ``analyze_pdf_fast`` returns for the uploaded bytes.
    """
    # Await UploadFile.read() rather than calling file.file.read() directly:
    # the direct call is synchronous and would read the upload on the event loop.
    pdf_content = await file.read()
    # analyze_pdf_fast is a plain (non-async) callable, so it is handed to the
    # threadpool instead of being called on the event loop.
    return await run_in_threadpool(analyze_pdf_fast, pdf_content)


@app.post("/toc")
Expand Down
14 changes: 14 additions & 0 deletions src/pdf_layout_analysis/get_xml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import os
from os.path import join
from pathlib import Path

from configuration import XMLS_PATH


def get_xml(xml_file_name: str) -> str:
    """Return the content of a stored XML file and delete the file.

    The file is looked up under ``XMLS_PATH`` and removed once it has been
    read, so a given stored XML can be fetched only once.

    Args:
        xml_file_name: Name of the file, relative to ``XMLS_PATH``.

    Returns:
        The text content of the file.

    Raises:
        FileNotFoundError: If the file does not exist, or if the name points
            outside ``XMLS_PATH``.
    """
    base_dir = Path(os.path.abspath(XMLS_PATH))
    xml_file_path = Path(os.path.abspath(base_dir / xml_file_name))

    # This function deletes what it reads, so refuse any name (absolute path,
    # "..", ...) that would escape the storage directory.
    if not xml_file_path.is_relative_to(base_dir):
        raise FileNotFoundError(f"No such XML file: {xml_file_name}")

    # NOTE(review): the file is read with the platform default encoding, as
    # before; confirm how the XML is written before pinning an encoding here.
    with open(xml_file_path, mode="r") as file:
        content = file.read()

    # Delete only after the handle is closed; removing a file that is still
    # open fails on Windows.
    os.remove(xml_file_path)
    return content
2 changes: 1 addition & 1 deletion src/pdf_layout_analysis/run_pdf_layout_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def predict_doclaynet():
VGTTrainer.test(configuration, model)


def analyze_pdf(file: AnyStr, xml_file_name: str = "") -> list[dict]:
def analyze_pdf(file: AnyStr, xml_file_name: str) -> list[dict]:
pdf_path = pdf_content_to_pdf_path(file)
service_logger.info(f"Creating PDF images")
pdf_images_list: list[PdfImages] = [PdfImages.from_pdf_path(pdf_path, "", xml_file_name)]
Expand Down

0 comments on commit b47e90c

Please sign in to comment.