Skip to content

Commit

Permalink
Adds autoselect functionality for OCR engine based on environment var…
Browse files Browse the repository at this point in the history
…iable
  • Loading branch information
Paethon committed Nov 9, 2023
1 parent f7132e3 commit 68c123e
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 3 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ The result will be a list of `BBox` instances. Each `BBox` contains the coordina

To easily visualize bounding boxes, the library also offers the method `draw_bboxes`.

### Autoselect
The function `autoselect_ocr_engine()` returns the OCR engine class selected by the `OCR_PROVIDER` environment variable. Valid values are `google`, `azure`, `aws`, `easy`, and `paddle`; they are matched case-insensitively, and `easyocr` and `paddleocr` are also accepted for backwards compatibility. If `OCR_PROVIDER` is not set, Google OCR is chosen by default.
If an invalid OCR provider is specified, an `InvalidOcrProviderException` (importable from `ocr_wrapper.autoselect`) is raised.

### GoogleOCR
The credentials for Google OCR will be obtained from one of the following:
- The environment variable `GOOGLE_APPLICATION_CREDENTIALS`
Expand Down
8 changes: 5 additions & 3 deletions ocr_wrapper/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from .autoselect import autoselect_ocr_engine
from .aws import AwsOCR
from .azure import AzureOCR
from .google_ocr import GoogleOCR
from .paddleocr import PaddleOCR
from .easy_ocr import EasyOCR
from .bbox import BBox, draw_bboxes, get_label2color_dict
from .easy_ocr import EasyOCR
from .google_ocr import GoogleOCR
from .ocr_wrapper import OcrWrapper
from .paddleocr import PaddleOCR

__all__ = [
"AwsOCR",
Expand All @@ -15,5 +16,6 @@
"BBox",
"draw_bboxes",
"get_label2color_dict",
"autoselect_ocr_engine",
"OcrWrapper",
]
37 changes: 37 additions & 0 deletions ocr_wrapper/autoselect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Implements functionality to automatically select the correct OCR engine"""

import os

from ocr_wrapper import AwsOCR, AzureOCR, EasyOCR, GoogleOCR, OcrWrapper, PaddleOCR


class InvalidOcrProviderException(Exception):
    """Error signalling that the requested OCR provider name is not recognised."""


# Lookup table from provider name (lower-case) to the OCR engine class.
name2engine: dict[str, type[OcrWrapper]] = {
    "google": GoogleOCR,
    "azure": AzureOCR,
    "aws": AwsOCR,
    "easy": EasyOCR,
    "paddle": PaddleOCR,
    # Older spellings, kept for backwards compatibility
    "easyocr": EasyOCR,
    "paddleocr": PaddleOCR,
}


def autoselect_ocr_engine() -> type[OcrWrapper]:
    """Select the OCR engine class named by the ``OCR_PROVIDER`` environment variable.

    The value is matched case-insensitively and surrounding whitespace is
    ignored. If the variable is unset, empty or whitespace-only, the default
    provider name ``google`` is used.

    Returns:
        The class registered in ``name2engine`` for the requested provider.
        The class itself is returned; it is not instantiated here.

    Raises:
        InvalidOcrProviderException: If the requested provider is not a key
            of ``name2engine``.
    """
    # Normalise before the lookup so values such as " Azure " resolve; the
    # trailing `or` maps an empty value (e.g. from `OCR_PROVIDER=`) to the default.
    provider = os.environ.get("OCR_PROVIDER", "").strip().lower() or "google"
    provider_cls = name2engine.get(provider)
    if provider_cls is None:
        # Join the names explicitly; interpolating `.keys()` would render as
        # "dict_keys([...])" in the message.
        valid_names = ", ".join(sorted(name2engine))
        raise InvalidOcrProviderException(
            f"Invalid OCR provider {provider!r}. Select one of: {valid_names}"
        )

    return provider_cls
32 changes: 32 additions & 0 deletions tests/test_autoselect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest
from ocr_wrapper import AwsOCR, AzureOCR, EasyOCR, GoogleOCR, PaddleOCR
from ocr_wrapper.autoselect import InvalidOcrProviderException, autoselect_ocr_engine


def test_default_ocr_engine(monkeypatch):
    """Without OCR_PROVIDER in the environment, GoogleOCR is selected."""
    # raising=False: removing the variable must not fail when it is already absent
    monkeypatch.delenv("OCR_PROVIDER", raising=False)

    selected = autoselect_ocr_engine()

    assert selected is GoogleOCR


@pytest.mark.parametrize(
    "provider, ocr_class",
    [
        ("google", GoogleOCR),
        ("azure", AzureOCR),
        ("aws", AwsOCR),
        ("easy", EasyOCR),
        ("paddle", PaddleOCR),
        # Aliases kept for backwards compatibility
        ("easyocr", EasyOCR),
        ("paddleocr", PaddleOCR),
        # The provider name is lower-cased before the lookup
        ("GOOGLE", GoogleOCR),
        ("Azure", AzureOCR),
    ],
)
def test_valid_ocr_provider(monkeypatch, provider, ocr_class):
    """Every supported OCR_PROVIDER value resolves to its engine class."""
    # Set the OCR_PROVIDER environment variable to a valid provider
    monkeypatch.setenv("OCR_PROVIDER", provider)

    # Check that the matching OCR engine class is returned
    assert autoselect_ocr_engine() is ocr_class


def test_invalid_ocr_provider(monkeypatch):
    """An unknown OCR_PROVIDER value raises InvalidOcrProviderException."""
    monkeypatch.setenv("OCR_PROVIDER", "invalid_provider")

    # The exception type and the start of its message are both checked
    expected_message = "Invalid OCR provider"
    with pytest.raises(InvalidOcrProviderException, match=expected_message):
        autoselect_ocr_engine()

0 comments on commit 68c123e

Please sign in to comment.