Skip to content

Commit

Permalink
Adds new bbox merging for GoogleAzureOCR that mostly preserves the ex…
Browse files Browse the repository at this point in the history
…isting Google OCR order
  • Loading branch information
Paethon committed Jan 29, 2024
1 parent 508800c commit 195ecb0
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 6 deletions.
85 changes: 80 additions & 5 deletions ocr_wrapper/google_azure_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
from ocr_wrapper.ocr_wrapper import rotate_image
from ocr_wrapper.tilt_correction import correct_tilt

from .utils import get_img_hash
from .bbox_order import get_ordered_bboxes_idxs
from .bbox_utils import split_bbox
from .utils import get_img_hash


class GoogleAzureOCR:
Expand Down Expand Up @@ -105,10 +106,16 @@ def ocr(self, img: Image.Image, return_extra: bool = False) -> Union[list[BBox],

# Combine the bboxes from Google and Azure
bbox_overlap_checker = BBoxOverlapChecker(google_bboxes)
combined_bboxes = google_bboxes.copy()
azure_bboxes_to_add = []

for bbox in azure_bboxes:
if len(bbox_overlap_checker.get_overlapping_bboxes(bbox)) == 0:
combined_bboxes.append(bbox)
azure_bboxes_to_add.append(bbox)

document_width, document_height = img.size
combined_bboxes = _merge_bboxes(
google_bboxes, azure_bboxes_to_add, document_width=document_width, document_height=document_height
)

# Build extra information dict
extra = {
Expand Down Expand Up @@ -205,6 +212,76 @@ def get_overlapping_bboxes(self, bbox: BBox, threshold: float = 0.01) -> list[BB
return overlapping_bboxes


def _merge_bboxes(
google_bboxes: list[BBox], azure_bboxes: list[BBox], document_width: int, document_height: int
) -> list[BBox]:
"""
Given the list of google_bboxes as well as azure_bboxes, inserts the azure_bboxes into the google_bboxes list at the correct position.
For this, the order of google_bboxes are used as the reference. The position of the azure bboxes are determined by
merging the two lists and sorting them using the order_bboxes function, which returns indexes of a fully sorted list.
This sorting is not used to sort the bboxes, but to determine the position of the azure bboxes in the google_bboxes list.
"""
google_bboxes_idxs = [i for i in range(len(google_bboxes))]
azure_bboxes_idxs = [i + len(google_bboxes) for i in range(len(azure_bboxes))]
merged_bboxes = google_bboxes + azure_bboxes
sorted_idxs = get_ordered_bboxes_idxs(
merged_bboxes, document_width=document_width, document_height=document_height
)
merged_bbox_idxs = merge_idx_lists(google_bboxes_idxs, azure_bboxes_idxs, sorted_idxs)
merged_bboxes = [merged_bboxes[i] for i in merged_bbox_idxs]

return merged_bboxes


def merge_idx_lists(raw_a, raw_b, sorted_ab):
"""
We merge two lists of indexes, raw_a and raw_b, into a single list. The order of the indexes in raw_a follow the
order given in raw_a, but elements from raw_b can be inserted in between the elements of raw_a. The order of the
elements in raw_b is determined by the order of the elements in sorted_ab.
"""
assert len(raw_a) + len(raw_b) == len(sorted_ab)

if len(sorted_ab) == 0:
return []

result = []
raw_a_set = set(raw_a)
raw_b_set = set(raw_b)
raw_a_left = raw_a.copy()
raw_a_left.reverse()

# Create a map of each element in sorted_ab to the one following it
# e.g. [1, 2, 3, 4] -> {1: 2, 2: 3, 3: 4}
next_sorted_map = {sorted_ab[i]: sorted_ab[i + 1] for i in range(len(sorted_ab) - 1)}

# Select the first element to add
if sorted_ab[0] in raw_b_set: # If the first element in sorted_ab is in raw_b, we start with that
last_added = sorted_ab[0]
raw_b_set.remove(last_added)
else: # Otherwise, we start with the first element in raw_a
last_added = raw_a[0]
raw_a_set.remove(last_added)
raw_a_left.pop()
result.append(last_added)

# Add all the other elements
while len(raw_a_set) != 0 or len(raw_b_set) != 0:
next_in_sorted = next_sorted_map.get(last_added, -1)
if next_in_sorted in raw_b_set: # If the next element in sorted_ab is in raw_b, we follow the sorted order ...
last_added = next_in_sorted
raw_b_set.remove(last_added)
else: # ... otherwise we keep the order given in raw_a
last_added = raw_a_left.pop()
raw_a_set.remove(last_added)

result.append(last_added)

assert len(result) == len(raw_a) + len(raw_b)

return result


def _get_mean_bbox_area(bboxes: list[BBox]) -> float:
"""Returns the mean area of the bboxes in the list
Expand Down Expand Up @@ -254,8 +331,6 @@ def consecutive_elements(lst, n):
def is_match(combination):
concatenated = "".join(c.text for c in combination).replace(" ", "")
_match = re.match(date_range_pattern, concatenated)
if _match:
print(f"Match: {concatenated}")
return _match

# Generate all combinations of strings in the list of different lengths
Expand Down
11 changes: 10 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@
"Programming Language :: Python :: 3",
"Operating System :: OS Independent",
],
install_requires=["Pillow", "Shapely", "pdf2image", "rtree", "opencv-python-headless", "torch", "torchvision"],
install_requires=[
"Pillow",
"Shapely",
"pdf2image",
"rtree",
"opencv-python-headless",
"torch",
"torchvision",
"numpy",
],
zip_safe=False,
)
16 changes: 16 additions & 0 deletions tests/test_google_azure_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from ocr_wrapper import GoogleAzureOCR
from PIL import Image
from ocr_wrapper.google_azure_ocr import merge_idx_lists

filedir = os.path.dirname(__file__)
DATA_DIR = os.path.join(filedir, "data")
Expand Down Expand Up @@ -120,3 +121,18 @@ def test_azure_date_range_split(ocr):
for expected_date in expected_dates:
assert expected_date in texts
assert "-" in texts


@pytest.mark.parametrize(
"raw_a, raw_b, sorted_ab, expected",
[
([1, 2, 3, 4, 5], [6, 7, 8], [4, 2, 5, 6, 7, 3, 1, 8], [1, 8, 2, 3, 4, 5, 6, 7]),
([], [], [], []),
([], [2, 3, 4], [4, 2, 3], [4, 2, 3]),
([4, 5, 6], [], [5, 6, 4], [4, 5, 6]),
([1], [2, 3, 4, 5], [3, 4, 1, 5, 2], [3, 4, 1, 5, 2]),
],
)
def test_merge_idx_lists(raw_a, raw_b, sorted_ab, expected):
res = merge_idx_lists(raw_a, raw_b, sorted_ab)
assert res == expected

0 comments on commit 195ecb0

Please sign in to comment.