From 68c123ee718b6b68d10a8e7cc0e7a7bfd53dead3 Mon Sep 17 00:00:00 2001
From: Sebastian Stabinger
Date: Thu, 9 Nov 2023 15:41:35 +0100
Subject: [PATCH] Adds autoselect functionality for OCR engine based on environment variable

---
 README.md                 |  4 ++++
 ocr_wrapper/__init__.py   |  8 +++++---
 ocr_wrapper/autoselect.py | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 tests/test_autoselect.py  | 32 ++++++++++++++++++++++++++++++++
 4 files changed, 93 insertions(+), 3 deletions(-)
 create mode 100644 ocr_wrapper/autoselect.py
 create mode 100644 tests/test_autoselect.py

diff --git a/README.md b/README.md
index cc71914..908fff1 100644
--- a/README.md
+++ b/README.md
@@ -30,6 +30,10 @@ The result will be a list of `BBox` instances. Each `BBox` contains the coordina
 
 To easily visualize bounding boxes, the library also offers the method `draw_bboxes`.
 
+### Autoselect
+The function `autoselect_ocr_engine()` can be used to automatically return the class for the needed OCR engine, using the `OCR_PROVIDER` environment variable. `google`, `azure`, `aws`, `easy`, and `paddle` are valid settings. If no provider is explicitly set, Google OCR is chosen by default.
+In case an invalid OCR provider is specified, an `InvalidOcrProviderException` will be raised.
+
 ### GoogleOCR
 The credentials for Google OCR will be obtained from one of the following:
 - The environment variable `GOOGLE_APPLICATION_CREDENTIALS`
diff --git a/ocr_wrapper/__init__.py b/ocr_wrapper/__init__.py
index c94b372..4d3af70 100644
--- a/ocr_wrapper/__init__.py
+++ b/ocr_wrapper/__init__.py
@@ -1,10 +1,11 @@
+from .autoselect import autoselect_ocr_engine
 from .aws import AwsOCR
 from .azure import AzureOCR
-from .google_ocr import GoogleOCR
-from .paddleocr import PaddleOCR
-from .easy_ocr import EasyOCR
 from .bbox import BBox, draw_bboxes, get_label2color_dict
+from .easy_ocr import EasyOCR
+from .google_ocr import GoogleOCR
 from .ocr_wrapper import OcrWrapper
+from .paddleocr import PaddleOCR
 
 __all__ = [
     "AwsOCR",
@@ -15,5 +16,6 @@
     "BBox",
     "draw_bboxes",
     "get_label2color_dict",
+    "autoselect_ocr_engine",
     "OcrWrapper",
 ]
diff --git a/ocr_wrapper/autoselect.py b/ocr_wrapper/autoselect.py
new file mode 100644
index 0000000..b0520a5
--- /dev/null
+++ b/ocr_wrapper/autoselect.py
@@ -0,0 +1,52 @@
+"""Implements functionality to automatically select the correct OCR engine"""
+
+import os
+
+# Import from the defining submodules rather than from the `ocr_wrapper` package:
+# the package `__init__` imports this module first, so the engine classes are not
+# yet bound on the package when this module is executed.
+from .aws import AwsOCR
+from .azure import AzureOCR
+from .easy_ocr import EasyOCR
+from .google_ocr import GoogleOCR
+from .ocr_wrapper import OcrWrapper
+from .paddleocr import PaddleOCR
+
+
+class InvalidOcrProviderException(Exception):
+    """Raised when an invalid OCR provider is selected"""
+
+    pass
+
+
+# Maps the lower-cased value of OCR_PROVIDER to the OCR engine class
+name2engine = dict[str, type[OcrWrapper]](
+    google=GoogleOCR,
+    azure=AzureOCR,
+    aws=AwsOCR,
+    easy=EasyOCR,
+    paddle=PaddleOCR,
+    # For backwards compatibility
+    easyocr=EasyOCR,
+    paddleocr=PaddleOCR,
+)
+
+
+def autoselect_ocr_engine() -> type[OcrWrapper]:
+    """Automatically select the correct OCR engine based on the environment variable OCR_PROVIDER
+
+    The value of OCR_PROVIDER is lower-cased before the lookup, so it is matched case-insensitively.
+
+    Returns:
+        The OCR engine class (default if environment variable is not set: GoogleOCR)
+
+    Raises:
+        InvalidOcrProviderException: If OCR_PROVIDER is set to a value that is not a key of `name2engine`
+    """
+    provider = os.environ.get("OCR_PROVIDER", "google").lower()
+    provider_cls = name2engine.get(provider)
+    if provider_cls is None:
+        valid_providers = ", ".join(name2engine)
+        raise InvalidOcrProviderException(f"Invalid OCR provider {provider}. Select one of {valid_providers}")
+
+    return provider_cls
diff --git a/tests/test_autoselect.py b/tests/test_autoselect.py
new file mode 100644
index 0000000..2693df7
--- /dev/null
+++ b/tests/test_autoselect.py
@@ -0,0 +1,32 @@
+import pytest
+from ocr_wrapper import AwsOCR, AzureOCR, EasyOCR, GoogleOCR, PaddleOCR
+from ocr_wrapper.autoselect import InvalidOcrProviderException, autoselect_ocr_engine
+
+
+def test_default_ocr_engine(monkeypatch):
+    # Unset the OCR_PROVIDER environment variable if set
+    monkeypatch.delenv("OCR_PROVIDER", raising=False)
+
+    # When OCR_PROVIDER is not set, should default to GoogleOCR
+    assert autoselect_ocr_engine() is GoogleOCR
+
+
+@pytest.mark.parametrize(
+    "provider, ocr_class",
+    [("google", GoogleOCR), ("azure", AzureOCR), ("aws", AwsOCR), ("easy", EasyOCR), ("paddle", PaddleOCR)],
+)
+def test_valid_ocr_provider(monkeypatch, provider, ocr_class):
+    # Set the OCR_PROVIDER environment variable to a valid provider
+    monkeypatch.setenv("OCR_PROVIDER", provider)
+
+    # Check if the correct OCR engine is returned
+    assert autoselect_ocr_engine() is ocr_class
+
+
+def test_invalid_ocr_provider(monkeypatch):
+    # Set the OCR_PROVIDER environment variable to an invalid provider
+    monkeypatch.setenv("OCR_PROVIDER", "invalid_provider")
+
+    # Expect InvalidOcrProviderException to be raised with an unknown provider and check the message
+    with pytest.raises(InvalidOcrProviderException, match="Invalid OCR provider"):
+        autoselect_ocr_engine()