Skip to content

Commit

Permalink
Merge branch 'main' into improvement/remove_torch
Browse files Browse the repository at this point in the history
  • Loading branch information
Paethon committed Oct 29, 2024
2 parents 2224367 + 9693f49 commit 58de9a9
Show file tree
Hide file tree
Showing 10 changed files with 85 additions and 155 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ The version numbers are according to [Semantic Versioning](http://semver.org/).

### Removed


## Release v0.2.3 (2024-10-29)
### Changed
- OCR-provider-specific dependencies are now loaded lazily to reduce import time

## Release v0.2.2 (2024-10-03)
### Added
- Added OpenTelemetry to `GoogleAzureOCR`, `GoogleOCR`, `AzureOCR`, and `OcrWrapper` to enable tracing of the OCR process
Expand Down
32 changes: 11 additions & 21 deletions ocr_wrapper/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,8 @@
from .bbox import BBox
from .ocr_wrapper import OcrCacheDisabled, OcrWrapper

try:
import boto3
except ImportError:
_has_boto3 = False
else:
_has_boto3 = True


def requires_boto(fn):
    """Decorate ``fn`` so that it only runs when ``boto3`` was importable.

    The check reads the module-level ``_has_boto3`` flag on every call
    (not at decoration time).

    Args:
        fn: The callable to protect.

    Returns:
        A wrapper that carries ``fn``'s metadata (via ``functools.wraps``)
        and forwards all positional and keyword arguments to ``fn``.

    Raises:
        ImportError: When the wrapper is called while ``_has_boto3`` is falsy.
    """

    @functools.wraps(fn)
    def guarded(*args, **kwargs):
        # Happy path first: the dependency is present, so just delegate.
        if _has_boto3:
            return fn(*args, **kwargs)
        raise ImportError('AWS Textract requires missing "boto3" package.')

    return guarded


class AwsOCR(OcrWrapper):
@requires_boto
def __init__(
self,
*,
Expand All @@ -34,10 +16,19 @@ def __init__(
add_checkboxes: bool = False,
verbose: bool = False
):
super().__init__(cache_file=cache_file, max_size=max_size, add_checkboxes=add_checkboxes, verbose=verbose)
try:
import boto3
except ImportError:
raise ImportError('AwsOCR requires missing "boto3" package.')

super().__init__(
cache_file=cache_file,
max_size=max_size,
add_checkboxes=add_checkboxes,
verbose=verbose,
)
self.client = boto3.client("textract", region_name="eu-central-1")

@requires_boto
def _get_ocr_response(self, img: Image.Image):
"""Gets the OCR response from AWS. Uses cached response if a cache file has been specified and the
document has been OCRed already"""
Expand All @@ -52,7 +43,6 @@ def _get_ocr_response(self, img: Image.Image):
self._put_on_shelf(img, response)
return response

@requires_boto
def _convert_ocr_response(self, response) -> tuple[List[BBox], dict[str, Any]]:
"""Converts the response given by Google OCR to a list of BBox"""
bboxes = []
Expand Down
41 changes: 14 additions & 27 deletions ocr_wrapper/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,9 @@
from .ocr_wrapper import OcrCacheDisabled, OcrWrapper, Union
from .utils import set_image_attributes

try:
from msrest.authentication import CognitiveServicesCredentials
from msrest.exceptions import ClientRequestError

from azure.cognitiveservices.vision.computervision import ComputerVisionClient
from azure.cognitiveservices.vision.computervision.models import (
ComputerVisionOcrErrorException,
OperationStatusCodes,
)
except ImportError:
_has_azure = False
else:
_has_azure = True

tracer = trace.get_tracer(__name__)


def requires_azure(fn):
    """Decorate ``fn`` so that it only runs when the Azure SDK was importable.

    The module-level ``_has_azure`` flag is consulted each time the wrapped
    callable is invoked.

    Args:
        fn: The callable to protect.

    Returns:
        A wrapper with ``fn``'s metadata that forwards every argument to ``fn``.

    Raises:
        ImportError: When the wrapper is called while ``_has_azure`` is falsy.
    """

    @functools.wraps(fn)
    def guarded(*args, **kwargs):
        # Delegate straight away when the optional dependency is available.
        if _has_azure:
            return fn(*args, **kwargs)
        raise ImportError('Azure Read requires missing "azure-cognitiveservices-vision-computervision" package.')

    return guarded


def _discretize_angle_to_90_deg(rotation: float) -> int:
"""Discretize an angle to the nearest 90 degrees"""
return int(((rotation + 45) // 90 * 90) % 360)
Expand Down Expand Up @@ -70,7 +46,6 @@ def _determine_endpoint_and_key(endpoint: Optional[str], key: Optional[str]) ->


class AzureOCR(OcrWrapper):
@requires_azure
def __init__(
self,
*,
Expand All @@ -85,6 +60,11 @@ def __init__(
add_qr_barcodes: bool = False,
verbose: bool = False,
):
try:
from msrest.authentication import CognitiveServicesCredentials
from azure.cognitiveservices.vision.computervision import ComputerVisionClient
except ImportError:
raise ImportError('AzureOCR requires missing "azure-cognitiveservices-vision-computervision" package.')
super().__init__(
cache_file=cache_file,
max_size=max_size,
Expand All @@ -99,11 +79,19 @@ def __init__(
endpoint, key = _determine_endpoint_and_key(endpoint, key)
self.client = ComputerVisionClient(endpoint, CognitiveServicesCredentials(key))

@requires_azure
@tracer.start_as_current_span(name="AzureOCR.get_ocr_response")
def _get_ocr_response(self, img: Image.Image):
"""Gets the OCR response from the Azure. Uses cached response if a cache file has been specified and the
document has been OCRed already"""
try:
from msrest.exceptions import ClientRequestError
from azure.cognitiveservices.vision.computervision.models import (
ComputerVisionOcrErrorException,
OperationStatusCodes,
)
except ImportError:
raise ImportError('AzureOCR requires missing "azure-cognitiveservices-vision-computervision" package.')

span = trace.get_current_span()
set_image_attributes(span, img)

Expand Down Expand Up @@ -158,7 +146,6 @@ def _get_ocr_response(self, img: Image.Image):
self._put_on_shelf(img, read_result)
return read_result

@requires_azure
@tracer.start_as_current_span(name="AzureOCR._convert_ocr_response")
def _convert_ocr_response(self, response) -> tuple[List[BBox], dict[str, Any]]:
"""Converts the response given by Azure Read to a list of BBox"""
Expand Down
25 changes: 5 additions & 20 deletions ocr_wrapper/easy_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,8 @@
from .bbox import BBox
from .ocr_wrapper import OcrWrapper, OcrCacheDisabled

try:
import easyocr
except ImportError:
_has_easyocr = False
else:
_has_easyocr = True


def requires_easyocr(fn):
    """Decorate ``fn`` so that it only runs when ``easyocr`` was importable.

    The module-level ``_has_easyocr`` flag is checked on every invocation of
    the wrapped callable.

    Args:
        fn: The callable to protect.

    Returns:
        A wrapper with ``fn``'s metadata that forwards every argument to ``fn``.

    Raises:
        ImportError: When the wrapper is called while ``_has_easyocr`` is falsy.
    """

    @functools.wraps(fn)
    def guarded(*args, **kwargs):
        # Delegate straight away when the optional dependency is available.
        if _has_easyocr:
            return fn(*args, **kwargs)
        raise ImportError('Easy OCR requires missing "easyocr" package.')

    return guarded


class EasyOCR(OcrWrapper):
@requires_easyocr
def __init__(
self,
*,
Expand All @@ -46,6 +28,11 @@ def __init__(
languages: A string or a list of languages to use for OCR from the list here: https://www.jaided.ai/easyocr/
width_thr: Distance where bounding boxes are still getting merged into one
"""
try:
import easyocr
except ImportError:
raise ImportError('EasyOCR requires missing "easyocr" package.')

super().__init__(
cache_file=cache_file,
max_size=max_size,
Expand All @@ -58,7 +45,6 @@ def __init__(

self.client = easyocr.Reader(self.languages, **kwargs)

@requires_easyocr
def _get_ocr_response(self, img: Image.Image):
"""Gets the OCR response from EasyOCR. Uses a cached response if a cache file has been specified and the
document has been OCRed already"""
Expand All @@ -69,7 +55,6 @@ def _get_ocr_response(self, img: Image.Image):
self._put_on_shelf(img, response)
return response

@requires_easyocr
def _convert_ocr_response(self, response) -> tuple[List[BBox], dict[str, Any]]:
"""Converts the response given by EasyOCR to a list of BBox"""
bboxes, confidences = [], []
Expand Down
2 changes: 0 additions & 2 deletions ocr_wrapper/google_azure_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,6 @@ def __init__(self, bboxes: list[BBox]):
for i, bbox in enumerate(bboxes):
self.rtree.insert(i, bbox.get_shapely_polygon().bounds)

@tracer.start_as_current_span(name="BBoxOverlapChecker.get_overlapping_bboxes")
def get_overlapping_bboxes(self, bbox: BBox, threshold: float = 0.01) -> list[BBox]:
"""Returns the bboxes that overlap with the given bbox.
Expand All @@ -318,7 +317,6 @@ def get_overlapping_bboxes(self, bbox: BBox, threshold: float = 0.01) -> list[BB
):
overlapping_bboxes.append(self.bboxes[i])

span.set_attribute("nb_overlapping_bboxes", len(overlapping_bboxes))
return overlapping_bboxes


Expand Down
36 changes: 14 additions & 22 deletions ocr_wrapper/google_document_ocr_checkbox_detector.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from __future__ import annotations

import functools
import os
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union

from PIL import Image

Expand All @@ -11,25 +10,9 @@

from .utils import resize_image

try:
from google.api_core.client_options import ClientOptions
# Load the Google Cloud Document AI client library globally only for type checking (needed for argument types)
if TYPE_CHECKING:
from google.cloud import documentai
except ImportError:
_has_gcloud = False
else:
_has_gcloud = True


def requires_gcloud(fn):
    """Decorate ``fn`` so that it only runs when the Google Cloud imports succeeded.

    The module-level ``_has_gcloud`` flag is checked on every invocation of
    the wrapped callable.

    Args:
        fn: The callable to protect.

    Returns:
        A wrapper with ``fn``'s metadata that forwards every argument to ``fn``.

    Raises:
        ImportError: When the wrapper is called while ``_has_gcloud`` is falsy.
    """

    @functools.wraps(fn)
    def guarded(*args, **kwargs):
        # Delegate straight away when the optional dependency is available.
        if _has_gcloud:
            return fn(*args, **kwargs)
        message = "GoogleDocumentOcrCheckboxDetector OCR requires missing 'google-cloud-documentai' package."
        raise ImportError(message)

    return guarded


def _val_or_env(val: Optional[str], env: str, default: Optional[str] = None) -> Optional[str]:
Expand Down Expand Up @@ -62,7 +45,6 @@ def _visual_element_to_bbox(visual_element) -> tuple[BBox, float]:


class GoogleDocumentOcrCheckboxDetector:
@requires_gcloud
def __init__(
self,
project_id: Optional[str] = None,
Expand All @@ -71,6 +53,12 @@ def __init__(
processor_version: Optional[str] = None,
max_size: Optional[int] = 4096,
):
try:
from google.api_core.client_options import ClientOptions
from google.cloud import documentai
except ImportError:
raise ImportError("GoogleDocumentOcrCheckboxDetector requires missing 'google-cloud-documentai' package.")

self.project_id = _val_or_env(project_id, "GOOGLE_DOC_OCR_PROJECT_ID")
self.location = _val_or_env(location, "GOOGLE_DOC_OCR_LOCATION", default="eu")
self.processor_id = _val_or_env(processor_id, "GOOGLE_DOC_OCR_PROCESSOR_ID")
Expand Down Expand Up @@ -105,8 +93,12 @@ def __init__(
self.project_id, self.location, self.processor_id, self.processor_version
)

@requires_gcloud
def detect_checkboxes(self, page: Union[Image.Image, documentai.RawDocument]) -> tuple[list[BBox], list[float]]:
try:
from google.cloud import documentai
except ImportError:
raise ImportError("GoogleDocumentOcrCheckboxDetector requires missing 'google-cloud-documentai' package.")

if isinstance(page, Image.Image):
if self.max_size is not None:
page = resize_image(img=page, max_size=self.max_size)
Expand Down
32 changes: 10 additions & 22 deletions ocr_wrapper/google_ocr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import functools
import os
import time
from time import sleep
Expand All @@ -15,24 +14,6 @@

tracer = trace.get_tracer(__name__)

try:
from google.cloud import vision
except ImportError:
_has_gcloud = False
else:
_has_gcloud = True


def requires_gcloud(fn):
    """Decorate ``fn`` so that it only runs when ``google.cloud.vision`` was importable.

    The module-level ``_has_gcloud`` flag is checked on every invocation of
    the wrapped callable.

    Args:
        fn: The callable to protect.

    Returns:
        A wrapper with ``fn``'s metadata that forwards every argument to ``fn``.

    Raises:
        ImportError: When the wrapper is called while ``_has_gcloud`` is falsy.
    """

    @functools.wraps(fn)
    def guarded(*args, **kwargs):
        # Delegate straight away when the optional dependency is available.
        if _has_gcloud:
            return fn(*args, **kwargs)
        raise ImportError('Google OCR requires missing "google-cloud-vision" package.')

    return guarded


# Define a list of languages which are written from right to left. This is needed to determine the rotation of the document
rtl_languages = [
"ar",
Expand Down Expand Up @@ -183,7 +164,6 @@ class GoogleOCR(OcrWrapper):
Google Cloud Vision API.
"""

@requires_gcloud
def __init__(
self,
*,
Expand All @@ -197,6 +177,11 @@ def __init__(
add_qr_barcodes: bool = False,
verbose: bool = False,
):
try:
from google.cloud import vision
except ImportError:
raise ImportError('GoogleOCR requires missing "google-cloud-vision" package.')

super().__init__(
cache_file=cache_file,
max_size=max_size,
Expand All @@ -219,11 +204,15 @@ def __init__(
self.endpoint = endpoint
self.client = vision.ImageAnnotatorClient(client_options={"api_endpoint": self.endpoint})

@requires_gcloud
@tracer.start_as_current_span(name="GoogleOCR._get_ocr")
def _get_ocr_response(self, img: Image.Image):
"""Gets the OCR response from the Google cloud. Uses cached response if a cache file has been specified and the
document has been OCRed already"""
try:
from google.cloud import vision
except ImportError:
raise ImportError('GoogleOCR requires missing "google-cloud-vision" package.')

span = trace.get_current_span()
set_image_attributes(span, img)

Expand Down Expand Up @@ -257,7 +246,6 @@ def _get_ocr_response(self, img: Image.Image):
self._put_on_shelf(img, response)
return response

@requires_gcloud
@tracer.start_as_current_span(name="GoogleOCR._convert_ocr_response")
def _convert_ocr_response(self, response) -> tuple[List[BBox], dict[str, Any]]:
"""Converts the response given by Google OCR to a list of BBox"""
Expand Down
Loading

0 comments on commit 58de9a9

Please sign in to comment.