diff --git a/CHANGELOG.md b/CHANGELOG.md index 14b9865..8322bea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ The version numbers are according to [Semantic Versioning](http://semver.org/). ### Fixed - Adds forced conversion to RGB in pillow before sending data to OpenCV to fix a possible bug in Studio - Fixes a rare bug where self-intersecting bounding boxes caused the OCR system to crash when using multi-pass OCR - +- Fixed a problem whene trying to use multi-pass with OCR engines that don't support it yet. Now the system will return a warning message and use the single-pass option instead. (Currently only GoogleOCR is supported for multi-pass) ### Removed diff --git a/ocr_wrapper/google_ocr.py b/ocr_wrapper/google_ocr.py index 17d7dbb..1deedc4 100644 --- a/ocr_wrapper/google_ocr.py +++ b/ocr_wrapper/google_ocr.py @@ -104,7 +104,7 @@ def get_word_and_language_codes(response): return word_and_language_codes -def _get_words_bboxes_confidences(response): +def _get_words_bboxes_confidences(response: vision.AnnotateImageResponse): """Given an ocr response, returns a list of tuples of word bounding boxes and confidences""" words, bboxes, confidences = [], [], [] @@ -165,7 +165,12 @@ def __init__( verbose: bool = False, ): super().__init__( - cache_file=cache_file, max_size=max_size, auto_rotate=auto_rotate, ocr_samples=ocr_samples, verbose=verbose + cache_file=cache_file, + max_size=max_size, + auto_rotate=auto_rotate, + ocr_samples=ocr_samples, + supports_multi_samples=True, + verbose=verbose, ) # Get credentials from environment variable of the offered default locations if not os.environ.get("GOOGLE_APPLICATION_CREDENTIALS"): @@ -178,7 +183,7 @@ def __init__( self.client = vision.ImageAnnotatorClient(client_options={"api_endpoint": endpoint}) @requires_gcloud - def _get_ocr_response(self, img: Image.Image): + def _get_ocr_response(self, img: Image.Image) -> vision.AnnotateImageResponse: """Gets the OCR response from the Google cloud. Uses cached response if a cache file has been specified and the document has been OCRed already""" # Pack image in correct format @@ -203,7 +208,9 @@ def _get_ocr_response(self, img: Image.Image): return response @requires_gcloud - def _convert_ocr_response(self, img, response) -> list[dict[str, Union[BBox, str, float]]]: + def _convert_ocr_response( + self, img: Image.Image, response: vision.AnnotateImageResponse + ) -> list[dict[str, Union[BBox, str, float]]]: """Converts the response given by Google OCR to a list of BBox""" # Iterate over all responses except the first. The first is for the whole document -> ignore words, bboxes, confidences = _get_words_bboxes_confidences(response) diff --git a/ocr_wrapper/ocr_wrapper.py b/ocr_wrapper/ocr_wrapper.py index e22347d..b1348f5 100644 --- a/ocr_wrapper/ocr_wrapper.py +++ b/ocr_wrapper/ocr_wrapper.py @@ -44,6 +44,7 @@ def __init__( max_size: Optional[int] = 1024, auto_rotate: bool = False, ocr_samples: int = 2, + supports_multi_samples: bool = False, verbose: bool = False, ): if cache_file is None: @@ -52,6 +53,7 @@ def __init__( self.max_size = max_size self.auto_rotate = auto_rotate self.ocr_samples = ocr_samples + self.supports_multi_samples = supports_multi_samples self.verbose = verbose self.extra = {} # Extra information to be returned by ocr() @@ -78,7 +80,13 @@ def ocr( if self.max_size is not None: img = self._resize_image(img, self.max_size) # Get response from an OCR engine - result = self._get_multi_response(img) + if self.ocr_samples == 1 or not self.supports_multi_samples: + if self.ocr_samples > 1 and self.verbose: + print("Warning: This OCR engine does not support multiple samples. Using only one sample.") + ocr = self._get_ocr_response(img) + result = self._convert_ocr_response(img, ocr) + else: + result = self._get_multi_response(img) if self.auto_rotate and "document_rotation" in self.extra: angle = self.extra["document_rotation"] diff --git a/tests/test_googleocr.py b/tests/test_googleocr.py index 120b84a..055e441 100644 --- a/tests/test_googleocr.py +++ b/tests/test_googleocr.py @@ -36,6 +36,13 @@ def ocr_with_auto_rotate(): return GoogleOCR(auto_rotate=True, ocr_samples=2) +@pytest.fixture +def ocr_forced_single_response(): + ocr = GoogleOCR(auto_rotate=True, ocr_samples=2) + ocr.supports_multi_samples = False + return ocr + + # Fixture for unrotated bboxes @pytest.fixture def unrotated_bboxes(ocr): @@ -51,6 +58,21 @@ def test_google_ocr(ocr): assert all([r["bbox"].original_size == img.size for r in res]) +def test_google_ocr_forced_single_response(ocr_forced_single_response, mocker): + single_response_spy = mocker.spy(ocr_forced_single_response, "_get_ocr_response") + multi_response_spy = mocker.spy(ocr_forced_single_response, "_get_multi_response") + + img_path = os.path.join(DATA_DIR, "ocr_test_big.png") + with Image.open(img_path) as img: + res = ocr_forced_single_response.ocr(img) + text = " ".join([r["text"] for r in res]) + assert text == "This is a test ." + assert all([r["bbox"].original_size == img.size for r in res]) + + single_response_spy.assert_called_once() + multi_response_spy.assert_not_called() + + def test_google_orc_single_sample(): img = Image.open(os.path.join(DATA_DIR, "ocr_test_big.png")) ocr = GoogleOCR(auto_rotate=True, ocr_samples=1)