-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds autoselect functionality for OCR engine based on environment var…
…iable
- Loading branch information
Showing
5 changed files
with
87 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
"""Implements functionality to automatically select the correct OCR engine""" | ||
from __future__ import annotations | ||
|
||
import os | ||
from typing import Optional | ||
|
||
from ocr_wrapper import AwsOCR, AzureOCR, EasyOCR, GoogleOCR, OcrWrapper, PaddleOCR | ||
|
||
|
||
class InvalidOcrProviderException(Exception): | ||
"""Raised when an invalid OCR provider is selected""" | ||
|
||
pass | ||
|
||
|
||
name2engine = dict[str, type[OcrWrapper]]( | ||
google=GoogleOCR, | ||
azure=AzureOCR, | ||
aws=AwsOCR, | ||
easy=EasyOCR, | ||
paddle=PaddleOCR, | ||
# For backwards compatibility | ||
easyocr=EasyOCR, | ||
paddleocr=PaddleOCR, | ||
) | ||
|
||
|
||
def autoselect_ocr_engine(name: Optional[str] = None) -> type[OcrWrapper]: | ||
"""Automatically select the correct OCR engine based on the environment variable OCR_PROVIDER | ||
Returns: | ||
The OCR engine class (default if environment variable is not set: GoogleOCR) | ||
""" | ||
if name is not None: | ||
provider = name | ||
else: | ||
provider = os.environ.get("OCR_PROVIDER", "google").lower() | ||
provider_cls = name2engine.get(provider) | ||
if provider_cls is None: | ||
raise InvalidOcrProviderException(f"Invalid OCR provider {provider}. Select one of {name2engine.keys()}") | ||
|
||
return provider_cls |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import pytest | ||
from ocr_wrapper import AwsOCR, AzureOCR, EasyOCR, GoogleOCR, PaddleOCR | ||
from ocr_wrapper.autoselect import InvalidOcrProviderException, autoselect_ocr_engine | ||
|
||
|
||
def test_default_ocr_engine(monkeypatch): | ||
# Unset the OCR_PROVIDER environment variable if set | ||
monkeypatch.delenv("OCR_PROVIDER", raising=False) | ||
|
||
# When OCR_PROVIDER is not set, should default to GoogleOCR | ||
assert autoselect_ocr_engine() is GoogleOCR | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"provider, ocr_class", | ||
[("google", GoogleOCR), ("azure", AzureOCR), ("aws", AwsOCR), ("easy", EasyOCR), ("paddle", PaddleOCR)], | ||
) | ||
def test_valid_ocr_provider(monkeypatch, provider, ocr_class): | ||
# Set the OCR_PROVIDER environment variable to a valid provider | ||
monkeypatch.setenv("OCR_PROVIDER", provider) | ||
|
||
# Check if the correct OCR engine is returned | ||
assert autoselect_ocr_engine() is ocr_class | ||
|
||
|
||
def test_invalid_ocr_provider(monkeypatch): | ||
# Set the OCR_PROVIDER environment variable to an invalid provider | ||
monkeypatch.setenv("OCR_PROVIDER", "invalid_provider") | ||
|
||
# Expect InvalidOcrProviderException to be raised with an unknown provider and check the message | ||
with pytest.raises(InvalidOcrProviderException, match="Invalid OCR provider"): | ||
autoselect_ocr_engine() |