Skip to content

Commit

Permalink
Merge pull request #44 from huridocs/text_extraction
Browse files (browse the repository at this point in the history)
Add text extraction
  • Loading branch information
ali6parmak authored Jul 18, 2024
2 parents 53b7e17 + 81b1264 commit 3cb0648
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pdf_layout_analysis.get_xml import get_xml
from pdf_layout_analysis.run_pdf_layout_analysis import analyze_pdf
from pdf_layout_analysis.run_pdf_layout_analysis_fast import analyze_pdf_fast
from text_extraction.get_text_extraction import get_text_extraction
from toc.get_toc import get_toc

service_logger.info(f"Is PyTorch using GPU: {torch.cuda.is_available()}")
Expand Down Expand Up @@ -51,3 +52,9 @@ async def get_xml_by_name(xml_file_name: str):
@catch_exceptions
async def get_toc_endpoint(file: UploadFile = File(...), fast: bool = Form(False)):
return await run_in_threadpool(get_toc, file, fast)


@app.post("/text")
@catch_exceptions
async def get_text_endpoint(
    file: UploadFile = File(...),
    fast: bool = Form(False),
    types: str = Form("all"),
):
    # POST /text: takes an uploaded file plus two optional form fields
    # ("fast", default False, and "types", default "all") and forwards them
    # unchanged to get_text_extraction.
    # get_text_extraction is a plain (non-async) function, so it is handed to
    # run_in_threadpool and awaited rather than being called directly here.
    extracted_text = await run_in_threadpool(get_text_extraction, file, fast, types)
    return extracted_text
25 changes: 25 additions & 0 deletions src/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,28 @@ def test_toc_fast(self):
self.assertEqual(response_json[0]["indentation"], 0)
self.assertEqual(response_json[-1]["label"], "C. TITLE LONGER")
self.assertEqual(response_json[-1]["indentation"], 2)

def test_text_extraction(self):
    # Posts the sample PDF to the /text endpoint without any extra form
    # fields and checks the first two words and the last word of the reply.
    pdf_path = f"{ROOT_PATH}/test_pdfs/test.pdf"
    with open(pdf_path, "rb") as pdf_stream:
        response = requests.post(f"{self.service_url}/text", files={"file": pdf_stream})

        # The body is decoded before the status check, as in the other assertions below
        # it is expected to be a single JSON string.
        extracted_text = response.json()
        self.assertEqual(response.status_code, 200)

        words = extracted_text.split()
        self.assertEqual(words[0], "Document")
        self.assertEqual(words[1], "Big")
        self.assertEqual(words[-1], "TEXT")

def test_text_extraction_fast(self):
    # Same check as the regular text-extraction test, but the request also
    # carries the form field fast="True".
    pdf_path = f"{ROOT_PATH}/test_pdfs/test.pdf"
    form_fields = {"fast": "True"}
    with open(pdf_path, "rb") as pdf_stream:
        response = requests.post(
            f"{self.service_url}/text",
            files={"file": pdf_stream},
            data=form_fields,
        )

        # The body is decoded before the status check; it is expected to be
        # a single JSON string that can be split into words.
        extracted_text = response.json()
        self.assertEqual(response.status_code, 200)

        words = extracted_text.split()
        self.assertEqual(words[0], "Document")
        self.assertEqual(words[1], "Big")
        self.assertEqual(words[-1], "TEXT")
14 changes: 14 additions & 0 deletions src/text_extraction/extract_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from configuration import service_logger
from pdf_token_type_labels.TokenType import TokenType


def extract_text(segment_boxes: list[dict], types: list[TokenType]):
    """Join the text of every segment whose type is among ``types``.

    Each segment box is read through its ``"type"`` and ``"text"`` keys.
    Spaces in the type label are replaced with underscores before the label
    is passed to ``TokenType.from_text``. Matching segments keep their input
    order and are joined with newlines; the requested type names are logged
    once before filtering.
    """
    type_names = [token_type.name for token_type in types]
    service_logger.info(f"Extracted types: {type_names}")

    selected_texts = []
    for box in segment_boxes:
        box_type = TokenType.from_text(box["type"].replace(" ", "_"))
        if box_type in types:
            selected_texts.append(box["text"])
    return "\n".join(selected_texts)
16 changes: 16 additions & 0 deletions src/text_extraction/get_text_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from fastapi import UploadFile
from pdf_token_type_labels.TokenType import TokenType
from pdf_layout_analysis.run_pdf_layout_analysis import analyze_pdf
from pdf_layout_analysis.run_pdf_layout_analysis_fast import analyze_pdf_fast
from text_extraction.extract_text import extract_text


def get_text_extraction(file: UploadFile, fast: bool, types: str):
    """Read an uploaded file, analyze it and return the text of the chosen types.

    ``types`` is either the literal ``"all"`` (every ``TokenType`` member is
    selected) or a comma-separated list of type names. Each name is stripped,
    has its spaces replaced with underscores, and is resolved with
    ``TokenType.from_text``; duplicates are dropped.

    When ``fast`` is truthy the content goes through ``analyze_pdf_fast``;
    otherwise through ``analyze_pdf`` with an empty string as its second
    argument. The result is passed to ``extract_text`` with the selected types.
    """
    pdf_bytes = file.file.read()

    if types == "all":
        token_types: list[TokenType] = list(TokenType)
    else:
        requested = {TokenType.from_text(name.strip().replace(" ", "_")) for name in types.split(",")}
        token_types = list(requested)

    segment_boxes = analyze_pdf_fast(pdf_bytes) if fast else analyze_pdf(pdf_bytes, "")
    return extract_text(segment_boxes, token_types)

0 comments on commit 3cb0648

Please sign in to comment.