Skip to content

Commit

Permalink
Add --ignore-embedded-text (#47)
Browse files Browse the repository at this point in the history
This PR adds the `--ignore-embedded-text` flag to the `scan` command so
that it is possible to ignore embedded text in document types that
support it (e.g. PDFs). The motivation for this feature is that some
PDFs have OCR results embedded that are low quality and should be
ignored.

The addition to `load_document` is tested via:

```
pytest -sv tests/test_end_to_end.py::test_run_with_ignore_embedded_text
```

Notably, the answer with Tesseract only is actually incorrect, because
it _seems_ Tesseract is missing entire columns.
  • Loading branch information
rstebbing authored Sep 28, 2022
1 parent 13c127b commit 8fadca8
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 5 deletions.
8 changes: 7 additions & 1 deletion src/docquery/cmd/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ def build_parser(subparsers, parent_parser):
parser.add_argument(
"--ocr", choices=list(OCR_MAPPING.keys()), default=None, help="The OCR engine you would like to use"
)
parser.add_argument(
"--ignore-embedded-text",
dest="use_embedded_text",
action="store_false",
help="Do not try and extract embedded text from document types that might provide it (e.g. PDFs)",
)
parser.add_argument(
"--classify",
default=False,
Expand Down Expand Up @@ -58,7 +64,7 @@ def main(args):
for p in paths:
try:
log.info(f"Loading {p}")
docs.append((p, load_document(str(p), ocr_reader=args.ocr)))
docs.append((p, load_document(str(p), ocr_reader=args.ocr, use_embedded_text=args.use_embedded_text)))
except UnsupportedDocument as e:
log.warning(f"Cannot load {p}: {e}. Skipping...")

Expand Down
9 changes: 5 additions & 4 deletions src/docquery/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,10 @@ def _generate_document_output(


class PDFDocument(Document):
def __init__(self, b, ocr_reader, **kwargs):
def __init__(self, b, ocr_reader, use_embedded_text, **kwargs):
self.b = b
self.ocr_reader = ocr_reader
self.use_embedded_text = use_embedded_text

super().__init__(**kwargs)

Expand All @@ -125,7 +126,7 @@ def context(self) -> Dict[str, List[Tuple["Image.Image", List[Any]]]]:
boxes_by_page = []
dimensions_by_page = []
for i, page in enumerate(pdf.pages):
extracted_words = page.extract_words()
extracted_words = page.extract_words() if self.use_embedded_text else []

if len(extracted_words) == 0:
words, boxes = self.ocr_reader.apply_ocr(images[i])
Expand Down Expand Up @@ -234,7 +235,7 @@ def context(self) -> Dict[str, List[Tuple["Image.Image", List[Any]]]]:


@validate_arguments
def load_document(fpath: str, ocr_reader: Optional[Union[str, OCRReader]] = None):
def load_document(fpath: str, ocr_reader: Optional[Union[str, OCRReader]] = None, use_embedded_text=True):
base_path = os.path.basename(fpath).split("?")[0].strip()
doc_type = mimetypes.guess_type(base_path)[0]
if fpath.startswith("http://") or fpath.startswith("https://"):
Expand All @@ -255,7 +256,7 @@ def load_document(fpath: str, ocr_reader: Optional[Union[str, OCRReader]] = None
raise NoOCRReaderFound(f"{ocr_reader} is not a supported OCRReader class")

if doc_type == "application/pdf":
return PDFDocument(b.read(), ocr_reader=ocr_reader)
return PDFDocument(b.read(), ocr_reader=ocr_reader, use_embedded_text=use_embedded_text)
elif doc_type == "text/html":
return WebDocument(fpath)
else:
Expand Down
14 changes: 14 additions & 0 deletions tests/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ class Example(BaseModel):
"question": "What are net sales for 2020?",
"answers": {
"LayoutLMv1": [{"score": 0.9429, "answer": "$ 3,750\n", "word_ids": [15, 16], "page": 0}],
# (The answer with `use_embedded_text=False` relies entirely on Tesseract, and it is incorrect because it
# misses 3,750 altogether.)
"LayoutLMv1__use_embedded_text=False": [
{"score": 0.3078, "answer": "$ 3,980", "word_ids": [11, 12], "page": 0}
],
"LayoutLMv1-Invoices": [{"score": 0.9956, "answer": "$ 3,750\n", "word_ids": [15, 16], "page": 0}],
"Donut": [{"answer": "$ 3,750"}],
},
Expand Down Expand Up @@ -132,3 +137,12 @@ def test_run_with_choosen_OCR_instance():
for qa in example.qa_pairs:
resp = pipe(question=qa.question, **document.context, top_k=1)
assert nested_simplify(resp, decimals=4) == qa.answers["LayoutLMv1"]


def test_run_with_ignore_embedded_text():
example = EXAMPLES[2]
document = load_document(example.path, use_embedded_text=False)
pipe = pipeline("document-question-answering", model=CHECKPOINTS["LayoutLMv1"])
for qa in example.qa_pairs:
resp = pipe(question=qa.question, **document.context, top_k=1)
assert nested_simplify(resp, decimals=4) == qa.answers["LayoutLMv1__use_embedded_text=False"]

0 comments on commit 8fadca8

Please sign in to comment.