Skip to content

Commit

Permalink
Add multi-image OCR processing and update GoogleAzureOCR max_size
Browse files Browse the repository at this point in the history
• Implement multi_img_ocr method in OCR wrappers to handle multiple images
• Increase max_size parameter in GoogleAzureOCR from 1024 to 4096
• Make interpolate_point function private and update its usage in bbox_utils.py
• Add documentation and type hints to several methods and classes
• Adjust default max_workers in multi_img_ocr and allow customization
• Fix endpoint assignment in GoogleOCR constructor
• Add and update docstrings for clarity and consistency across modules
  • Loading branch information
Paethon committed Jan 29, 2024
1 parent 796a562 commit 5a7a624
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The version numbers are according to [Semantic Versioning](http://semver.org/).
## Next Release
### Added
- Added new OCR wrapper that combines Google OCR and Azure OCR to compensate shortcomings of Google OCR
- Added new method `multi_img_ocr` to all OCR wrappers to be able to process multiple images at the same time
### Changed

### Fixed
Expand Down
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,6 @@ Credentials etc. for Azure OCR will be obtained from one of the following (in th
- From the `endpoint` and `key` arguments when creating AzureOCR
- From the environment variables `AZURE_OCR_ENDPOINT` and `AZURE_OCR_KEY`
- From the credentials file `~/.config/azure/ocr_credentials.json` that contains the keys `endpoint` and `key`

### GoogleAzureOCR
Credentials for GoogleOCR as well as AzureOCR have to be set
22 changes: 17 additions & 5 deletions ocr_wrapper/bbox_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,32 @@
import math


def interpolate_point(A: tuple[float, float], B: tuple[float, float], ratio: float) -> tuple[float, float]:
def _interpolate_point(A: tuple[float, float], B: tuple[float, float], ratio: float) -> tuple[float, float]:
"""Returns the point along the line AB at the given ratio"""
return (A[0] + ratio * (B[0] - A[0]), A[1] + ratio * (B[1] - A[1]))


def split_bbox(bbox: BBox, ratio: float) -> tuple[BBox, BBox]:
"""
Splits a bounding box along the longer edge at the given ratio.
Args:
bbox: The bounding box.
ratio: The ratio at which to split the bounding box.
Returns:
A tuple containing the two resulting bounding boxes. Text, label, and is_pixels are copied
from the original bounding box.
"""
# Calculate lengths of top and side edges
top_length = math.sqrt((bbox.TRx - bbox.TLx) ** 2 + (bbox.TRy - bbox.TLy) ** 2)
side_length = math.sqrt((bbox.BLx - bbox.TLx) ** 2 + (bbox.BLy - bbox.TLy) ** 2)

# Determine longer edge and split points
if top_length >= side_length:
# Splitting along the top edge
new_top_point = interpolate_point((bbox.TLx, bbox.TLy), (bbox.TRx, bbox.TRy), ratio)
new_bottom_point = interpolate_point((bbox.BLx, bbox.BLy), (bbox.BRx, bbox.BRy), ratio)
new_top_point = _interpolate_point((bbox.TLx, bbox.TLy), (bbox.TRx, bbox.TRy), ratio)
new_bottom_point = _interpolate_point((bbox.BLx, bbox.BLy), (bbox.BRx, bbox.BRy), ratio)
bbox1 = BBox(
bbox.TLx,
bbox.TLy,
Expand All @@ -47,8 +59,8 @@ def split_bbox(bbox: BBox, ratio: float) -> tuple[BBox, BBox]:
)
else:
# Splitting along the side edge
new_left_point = interpolate_point((bbox.TLx, bbox.TLy), (bbox.BLx, bbox.BLy), ratio)
new_right_point = interpolate_point((bbox.TRx, bbox.TRy), (bbox.BRx, bbox.BRy), ratio)
new_left_point = _interpolate_point((bbox.TLx, bbox.TLy), (bbox.BLx, bbox.BLy), ratio)
new_right_point = _interpolate_point((bbox.TRx, bbox.TRy), (bbox.BRx, bbox.BRy), ratio)
bbox1 = BBox(
bbox.TLx,
bbox.TLy,
Expand Down
23 changes: 17 additions & 6 deletions ocr_wrapper/google_azure_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
cache_file: Optional[str] = None,
ocr_samples: Optional[int] = None,
supports_multi_samples: bool = False,
max_size: Optional[int] = 1024,
max_size: Optional[int] = 4096,
auto_rotate: Optional[bool] = None,
correct_tilt: Optional[bool] = None,
verbose: bool = False,
Expand Down Expand Up @@ -128,16 +128,17 @@ def ocr(self, img: Image.Image, return_extra: bool = False) -> Union[list[BBox],
return result

def multi_img_ocr(
self, imgs: list[Image.Image], return_extra: bool = False
self, imgs: list[Image.Image], return_extra: bool = False, max_workers: int = 32
) -> Union[list[list[BBox]], tuple[list[list[BBox]], list[dict]]]:
"""Runs OCR in parallel on multiple images using both Google and Azure OCR, and combines the results.
Args:
img (list[Image.Image]): The pages to run OCR on.
return_extra (bool, optional): Whether to return extra information. Defaults to False.
max_workers (int, optional): The maximum number of threads to use. Defaults to 32.
"""
# Execute self.ocr in parallel on all images
with ThreadPoolExecutor(max_workers=32) as executor:
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [executor.submit(self.ocr, img, return_extra) for img in imgs]
results = [future.result() for future in futures]

Expand Down Expand Up @@ -168,6 +169,15 @@ def _put_on_shelf(self, img_hash: str, return_extra: bool, response):


class BBoxOverlapChecker:
"""
Class to check whether a bbox overlaps with any of a list of bboxes.
Uses an RTree to quickly find overlapping bboxes.
Args:
bboxes (list[BBox]): The bboxes that will be checked for overlap against
"""

def __init__(self, bboxes: list[BBox]):
self.bboxes = bboxes
self.rtree = rtree.index.Index()
Expand All @@ -178,7 +188,7 @@ def get_overlapping_bboxes(self, bbox: BBox, threshold: float = 0.01) -> list[BB
"""Returns the bboxes that overlap with the given bbox.
Args:
bbox (BBox): The bbox to check for overlapping bboxes.
bbox (BBox): The bbox to check for overlap.
threshold (float, optional): The minimum overlap that is required for a bbox to be returned (0.0 to 1.0).
Defaults to 0.01. Overlap is checked in both directions.
Expand Down Expand Up @@ -230,7 +240,7 @@ def _filter_date_boxes(bboxes: list[BBox], max_boxes_range: int = 10) -> list[BB
Args:
bboxes (list[BBox]): The bboxes to filter.
max_boxes_range (int, optional): The maximum number of bboxes to consider for a match. Defaults to 15.
max_boxes_range (int, optional): The maximum number of bboxes to consider for a match. Defaults to 10.
"""
max_boxes_range = min(max_boxes_range, len(bboxes))
date_range_pattern = r"^\s*\d{1,2}\s*/\s*\d{1,2}\s*/\s*\d{4}\s*-\s*\d{1,2}\s*/\s*\d{1,2}\s*/\s*\d{4}\s*$"
Expand Down Expand Up @@ -266,6 +276,7 @@ def _filter_unwanted_google_bboxes(bboxes: list[BBox], width_height_ratio: float
Currently does the following filtering:
- Removes bboxes with an area that is bigger than the mean area of all bboxes in the list and that are vertically aligned
- Filters out bounding boxes that, concatenated, match patterns like "dd/mm/yyyy - dd/mm/yyyy".
Args:
bboxes (list[BBox]): The bboxes to filter.
Expand All @@ -285,7 +296,7 @@ def _filter_unwanted_google_bboxes(bboxes: list[BBox], width_height_ratio: float

def _split_azure_date_boxes(bboxes: list[BBox]) -> list[BBox]:
"""
Splits date boxes that contain a date range of the format "dd/mm/yyyy - dd/mm/yyyy" into two separate boxes.
Splits date boxes that contain a date range of the format "dd/mm/yyyy - dd/mm/yyyy" into three separate boxes.
Args:
bboxes (list[BBox]): The bboxes to filter.
Expand Down
2 changes: 1 addition & 1 deletion ocr_wrapper/google_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def __init__(
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.expanduser(credentials_path)
# Create the client with the specified endpoint
self.endpoint = endpoint
self.client = vision.ImageAnnotatorClient(client_options={"api_endpoint": endpoint})
self.client = vision.ImageAnnotatorClient(client_options={"api_endpoint": self.endpoint})

@requires_gcloud
def _get_ocr_response(self, img: Image.Image):
Expand Down
3 changes: 2 additions & 1 deletion ocr_wrapper/ocr_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def ocr(self, img: Image.Image, return_extra: bool = False) -> Union[list[BBox],
return bboxes

def multi_img_ocr(
self, imgs: list[Image.Image], return_extra: bool = False
self, imgs: list[Image.Image], return_extra: bool = False, max_workers: int = 32
) -> Union[list[list[BBox]], tuple[list[list[BBox]], list[dict]]]:
"""Returns OCR result for a list of images instead of a single image.
Expand All @@ -120,6 +120,7 @@ def multi_img_ocr(
Args:
imgs: Images to be processed
return_extra: If True, returns a tuple of (bboxes, extra) where extra is a list of dicts containing extra information
max_workers: Maximum number of threads to use for parallel processing
"""
results = []
for img in imgs:
Expand Down

0 comments on commit 5a7a624

Please sign in to comment.