Skip to content

Commit

Permalink
Fixed a bug when trying to use multi-pass for OCR engines that don't …
Browse files Browse the repository at this point in the history
…support it

The system now checks if the OCR engine supports multiple samples before attempting to use multi-pass. If the engine doesn't support it, a warning message is displayed and the system falls back to single-pass. This change was necessary to prevent errors when trying to use multi-pass with unsupported OCR engines.
  • Loading branch information
Paethon committed Nov 13, 2023
1 parent dfdb037 commit f7cec03
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 6 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ The version numbers are according to [Semantic Versioning](http://semver.org/).
### Fixed
- Adds forced conversion to RGB in pillow before sending data to OpenCV to fix a possible bug in Studio
- Fixes a rare bug where self-intersecting bounding boxes caused the OCR system to crash when using multi-pass OCR

- Fixed a problem whene trying to use multi-pass with OCR engines that don't support it yet. Now the system will return a warning message and use the single-pass option instead. (Currently only GoogleOCR is supported for multi-pass)
### Removed


Expand Down
15 changes: 11 additions & 4 deletions ocr_wrapper/google_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def get_word_and_language_codes(response):
return word_and_language_codes


def _get_words_bboxes_confidences(response):
def _get_words_bboxes_confidences(response: vision.AnnotateImageResponse):
"""Given an ocr response, returns a list of tuples of word bounding boxes and confidences"""
words, bboxes, confidences = [], [], []

Expand Down Expand Up @@ -165,7 +165,12 @@ def __init__(
verbose: bool = False,
):
super().__init__(
cache_file=cache_file, max_size=max_size, auto_rotate=auto_rotate, ocr_samples=ocr_samples, verbose=verbose
cache_file=cache_file,
max_size=max_size,
auto_rotate=auto_rotate,
ocr_samples=ocr_samples,
supports_multi_samples=True,
verbose=verbose,
)
# Get credentials from environment variable of the offered default locations
if not os.environ.get("GOOGLE_APPLICATION_CREDENTIALS"):
Expand All @@ -178,7 +183,7 @@ def __init__(
self.client = vision.ImageAnnotatorClient(client_options={"api_endpoint": endpoint})

@requires_gcloud
def _get_ocr_response(self, img: Image.Image):
def _get_ocr_response(self, img: Image.Image) -> vision.AnnotateImageResponse:
"""Gets the OCR response from the Google cloud. Uses cached response if a cache file has been specified and the
document has been OCRed already"""
# Pack image in correct format
Expand All @@ -203,7 +208,9 @@ def _get_ocr_response(self, img: Image.Image):
return response

@requires_gcloud
def _convert_ocr_response(self, img, response) -> list[dict[str, Union[BBox, str, float]]]:
def _convert_ocr_response(
self, img: Image.Image, response: vision.AnnotateImageResponse
) -> list[dict[str, Union[BBox, str, float]]]:
"""Converts the response given by Google OCR to a list of BBox"""
# Iterate over all responses except the first. The first is for the whole document -> ignore
words, bboxes, confidences = _get_words_bboxes_confidences(response)
Expand Down
10 changes: 9 additions & 1 deletion ocr_wrapper/ocr_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
max_size: Optional[int] = 1024,
auto_rotate: bool = False,
ocr_samples: int = 2,
supports_multi_samples: bool = False,
verbose: bool = False,
):
if cache_file is None:
Expand All @@ -52,6 +53,7 @@ def __init__(
self.max_size = max_size
self.auto_rotate = auto_rotate
self.ocr_samples = ocr_samples
self.supports_multi_samples = supports_multi_samples
self.verbose = verbose
self.extra = {} # Extra information to be returned by ocr()

Expand All @@ -78,7 +80,13 @@ def ocr(
if self.max_size is not None:
img = self._resize_image(img, self.max_size)
# Get response from an OCR engine
result = self._get_multi_response(img)
if self.ocr_samples == 1 or not self.supports_multi_samples:
if self.ocr_samples > 1 and self.verbose:
print("Warning: This OCR engine does not support multiple samples. Using only one sample.")
ocr = self._get_ocr_response(img)
result = self._convert_ocr_response(img, ocr)
else:
result = self._get_multi_response(img)

if self.auto_rotate and "document_rotation" in self.extra:
angle = self.extra["document_rotation"]
Expand Down
22 changes: 22 additions & 0 deletions tests/test_googleocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ def ocr_with_auto_rotate():
return GoogleOCR(auto_rotate=True, ocr_samples=2)


@pytest.fixture
def ocr_forced_single_response():
ocr = GoogleOCR(auto_rotate=True, ocr_samples=2)
ocr.supports_multi_samples = False
return ocr


# Fixture for unrotated bboxes
@pytest.fixture
def unrotated_bboxes(ocr):
Expand All @@ -51,6 +58,21 @@ def test_google_ocr(ocr):
assert all([r["bbox"].original_size == img.size for r in res])


def test_google_ocr_forced_single_response(ocr_forced_single_response, mocker):
single_response_spy = mocker.spy(ocr_forced_single_response, "_get_ocr_response")
multi_response_spy = mocker.spy(ocr_forced_single_response, "_get_multi_response")

img_path = os.path.join(DATA_DIR, "ocr_test_big.png")
with Image.open(img_path) as img:
res = ocr_forced_single_response.ocr(img)
text = " ".join([r["text"] for r in res])
assert text == "This is a test ."
assert all([r["bbox"].original_size == img.size for r in res])

single_response_spy.assert_called_once()
multi_response_spy.assert_not_called()


def test_google_orc_single_sample():
img = Image.open(os.path.join(DATA_DIR, "ocr_test_big.png"))
ocr = GoogleOCR(auto_rotate=True, ocr_samples=1)
Expand Down

0 comments on commit f7cec03

Please sign in to comment.