Skip to content

Commit

Permalink
Adds autoselect functionality for OCR engine based on environment var…
Browse files Browse the repository at this point in the history
…iable
  • Loading branch information
Paethon committed Nov 15, 2023
1 parent 38887b1 commit a2b0c93
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 5 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ The version numbers are according to [Semantic Versioning](http://semver.org/).
### Added
- Added the output of confidence scores to GoogleOCR
- Added multiple OCR passes to improve OCR reliability
- Adds an environment variable `OCR_WRAPPER_CACHE_FILE` to specify an ocr cache file globally

- Added an environment variable `OCR_WRAPPER_CACHE_FILE` to specify an ocr cache file globally
- Added an `autoselect_ocr_engine` function that selects the correct engine depending on the `OCR_PROVIDER` environment variable
### Changed
- Changed GoogleOCR to use WebP instead of PNG to transfer images to the cloud (reduces amount of transferred data by ~ 1/2)
### Fixed
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ Specific OCR engines may add other keys to return additional information about a

To easily visualize bounding boxes, the library also offers the method `draw_bboxes`. See `tryme.ipynb` for a minimal code example.

### Autoselect
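The function `autoselect_ocr_engine()` returns the class (not an instance) of the OCR engine selected by the `OCR_PROVIDER` environment variable. Valid settings are `google`, `azure`, `aws`, `easy`, and `paddle`; the legacy aliases `easyocr` and `paddleocr` are also accepted, and the value of the environment variable is matched case-insensitively. If `OCR_PROVIDER` is not set, Google OCR is chosen by default. A provider name can also be passed directly via the optional `name` argument, in which case the environment variable is ignored.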
In case an invalid OCR provider is specified, an `InvalidOcrProviderException` will be raised.

### GoogleOCR
The credentials for Google OCR will be obtained from one of the following sources, in this order:
- The environment variable `GOOGLE_APPLICATION_CREDENTIALS`
Expand Down
10 changes: 7 additions & 3 deletions ocr_wrapper/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from .aws import AwsOCR
from .azure import AzureOCR
from .google_ocr import GoogleOCR
from .paddleocr import PaddleOCR
from .easy_ocr import EasyOCR
from .bbox import BBox, draw_bboxes, get_label2color_dict
from .easy_ocr import EasyOCR
from .google_ocr import GoogleOCR
from .ocr_wrapper import OcrWrapper
from .paddleocr import PaddleOCR

# Important as last import, because it depends on the other modules
from .autoselect import autoselect_ocr_engine # isort:skip

__all__ = [
"AwsOCR",
Expand All @@ -15,5 +18,6 @@
"BBox",
"draw_bboxes",
"get_label2color_dict",
"autoselect_ocr_engine",
"OcrWrapper",
]
42 changes: 42 additions & 0 deletions ocr_wrapper/autoselect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Implements functionality to automatically select the correct OCR engine"""
from __future__ import annotations

import os
from typing import Optional

from ocr_wrapper import AwsOCR, AzureOCR, EasyOCR, GoogleOCR, OcrWrapper, PaddleOCR


class InvalidOcrProviderException(Exception):
    """Signals that the requested OCR provider name is not a known engine."""


# Lookup table from provider name (lower-case) to the OCR engine class implementing it.
name2engine: dict[str, type[OcrWrapper]] = {
    "google": GoogleOCR,
    "azure": AzureOCR,
    "aws": AwsOCR,
    "easy": EasyOCR,
    "paddle": PaddleOCR,
    # For backwards compatibility
    "easyocr": EasyOCR,
    "paddleocr": PaddleOCR,
}


def autoselect_ocr_engine(name: Optional[str] = None) -> type[OcrWrapper]:
    """Select the OCR engine class for a provider name.

    The provider is taken from ``name`` when it is given, otherwise from the
    ``OCR_PROVIDER`` environment variable, falling back to ``"google"`` when the
    variable is not set. Provider names are matched case-insensitively,
    regardless of which of the two sources they come from.

    Args:
        name: Provider name to use instead of the ``OCR_PROVIDER`` environment
            variable. ``None`` (the default) means "read the environment".

    Returns:
        The OCR engine class (not an instance) registered in ``name2engine``
        for the provider.

    Raises:
        InvalidOcrProviderException: If the provider is not a key of ``name2engine``.
    """
    requested = name if name is not None else os.environ.get("OCR_PROVIDER", "google")
    # Normalise both sources identically, so an explicit "Google" behaves like OCR_PROVIDER=Google.
    provider = requested.lower()

    provider_cls = name2engine.get(provider)
    if provider_cls is None:
        # Join the keys ourselves; interpolating .keys() would render as "dict_keys([...])".
        valid_names = ", ".join(sorted(name2engine))
        raise InvalidOcrProviderException(f"Invalid OCR provider {requested!r}. Select one of: {valid_names}")

    return provider_cls
32 changes: 32 additions & 0 deletions tests/test_autoselect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest
from ocr_wrapper import AwsOCR, AzureOCR, EasyOCR, GoogleOCR, PaddleOCR
from ocr_wrapper.autoselect import InvalidOcrProviderException, autoselect_ocr_engine


def test_default_ocr_engine(monkeypatch):
    """Without OCR_PROVIDER in the environment, GoogleOCR is the fallback engine."""
    # raising=False: it is fine if the variable was never set in the first place
    monkeypatch.delenv("OCR_PROVIDER", raising=False)

    selected_engine = autoselect_ocr_engine()

    assert selected_engine is GoogleOCR


@pytest.mark.parametrize(
    "provider, ocr_class",
    [
        ("google", GoogleOCR),
        ("azure", AzureOCR),
        ("aws", AwsOCR),
        ("easy", EasyOCR),
        ("paddle", PaddleOCR),
        # Aliases kept for backwards compatibility
        ("easyocr", EasyOCR),
        ("paddleocr", PaddleOCR),
        # The environment variable value is lower-cased before the lookup
        ("GOOGLE", GoogleOCR),
        ("Azure", AzureOCR),
    ],
)
def test_valid_ocr_provider(monkeypatch, provider, ocr_class):
    """Each supported OCR_PROVIDER value (any casing, including aliases) selects its engine class."""
    # Set the OCR_PROVIDER environment variable to a valid provider
    monkeypatch.setenv("OCR_PROVIDER", provider)

    # Check if the correct OCR engine is returned
    assert autoselect_ocr_engine() is ocr_class


def test_invalid_ocr_provider(monkeypatch):
    """An unknown OCR_PROVIDER value makes the selection fail with a descriptive error."""
    monkeypatch.setenv("OCR_PROVIDER", "invalid_provider")

    # The exception type and the beginning of its message are both part of the contract
    expect_failure = pytest.raises(InvalidOcrProviderException, match="Invalid OCR provider")
    with expect_failure:
        autoselect_ocr_engine()

0 comments on commit a2b0c93

Please sign in to comment.