diff --git a/src/app.py b/src/app.py
index 2920a3e..3e807ed 100755
--- a/src/app.py
+++ b/src/app.py
@@ -9,6 +9,7 @@
 from pdf_layout_analysis.get_xml import get_xml
 from pdf_layout_analysis.run_pdf_layout_analysis import analyze_pdf
 from pdf_layout_analysis.run_pdf_layout_analysis_fast import analyze_pdf_fast
+from text_extraction.get_text_extraction import get_text_extraction
 from toc.get_toc import get_toc
 
 service_logger.info(f"Is PyTorch using GPU: {torch.cuda.is_available()}")
@@ -51,3 +52,9 @@ async def get_xml_by_name(xml_file_name: str):
 @catch_exceptions
 async def get_toc_endpoint(file: UploadFile = File(...), fast: bool = Form(False)):
     return await run_in_threadpool(get_toc, file, fast)
+
+
+@app.post("/text")
+@catch_exceptions
+async def get_text_endpoint(file: UploadFile = File(...), fast: bool = Form(False), types: str = Form("all")):
+    return await run_in_threadpool(get_text_extraction, file, fast, types)
diff --git a/src/test_end_to_end.py b/src/test_end_to_end.py
index f2e61c1..eb3e24c 100644
--- a/src/test_end_to_end.py
+++ b/src/test_end_to_end.py
@@ -178,3 +178,28 @@ def test_toc_fast(self):
         self.assertEqual(response_json[0]["indentation"], 0)
         self.assertEqual(response_json[-1]["label"], "C. TITLE LONGER")
         self.assertEqual(response_json[-1]["indentation"], 2)
+
+    def test_text_extraction(self):
+        with open(f"{ROOT_PATH}/test_pdfs/test.pdf", "rb") as stream:
+            files = {"file": stream}
+
+            response = requests.post(f"{self.service_url}/text", files=files)
+
+        response_json = response.json()
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(response_json.split()[0], "Document")
+        self.assertEqual(response_json.split()[1], "Big")
+        self.assertEqual(response_json.split()[-1], "TEXT")
+
+    def test_text_extraction_fast(self):
+        with open(f"{ROOT_PATH}/test_pdfs/test.pdf", "rb") as stream:
+            files = {"file": stream}
+            data = {"fast": "True"}
+
+            response = requests.post(f"{self.service_url}/text", files=files, data=data)
+
+        response_json = response.json()
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(response_json.split()[0], "Document")
+        self.assertEqual(response_json.split()[1], "Big")
+        self.assertEqual(response_json.split()[-1], "TEXT")
diff --git a/src/text_extraction/extract_text.py b/src/text_extraction/extract_text.py
new file mode 100644
index 0000000..28c9fba
--- /dev/null
+++ b/src/text_extraction/extract_text.py
@@ -0,0 +1,14 @@
+from configuration import service_logger
+from pdf_token_type_labels.TokenType import TokenType
+
+
+def extract_text(segment_boxes: list[dict], types: list[TokenType]):
+    service_logger.info(f"Extracted types: {[t.name for t in types]}")
+    text = "\n".join(
+        [
+            segment_box["text"]
+            for segment_box in segment_boxes
+            if TokenType.from_text(segment_box["type"].replace(" ", "_")) in types
+        ]
+    )
+    return text
diff --git a/src/text_extraction/get_text_extraction.py b/src/text_extraction/get_text_extraction.py
new file mode 100644
index 0000000..c9dab23
--- /dev/null
+++ b/src/text_extraction/get_text_extraction.py
@@ -0,0 +1,16 @@
+from fastapi import UploadFile
+from pdf_token_type_labels.TokenType import TokenType
+from pdf_layout_analysis.run_pdf_layout_analysis import analyze_pdf
+from pdf_layout_analysis.run_pdf_layout_analysis_fast import analyze_pdf_fast
+from text_extraction.extract_text import extract_text
+
+
+def get_text_extraction(file: UploadFile, fast: bool, types: str):
+    file_content = file.file.read()
+    if types == "all":
+        token_types: list[TokenType] = [t for t in TokenType]
+    else:
+        token_types = list(set([TokenType.from_text(t.strip().replace(" ", "_")) for t in types.split(",")]))
+    if fast:
+        return extract_text(analyze_pdf_fast(file_content), token_types)
+    return extract_text(analyze_pdf(file_content, ""), token_types)