Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
… autodiscover
  • Loading branch information
choinek committed Jan 12, 2025
1 parent bc9a9d3 commit d1425e5
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 45 deletions.
4 changes: 4 additions & 0 deletions text_extract_api/celery_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from celery import Celery
from dotenv import load_dotenv

from text_extract_api.extract.ocr_strategies.ocr_strategy import OCRStrategy

sys.path.insert(0, str(pathlib.Path(__file__).parent.resolve()))

load_dotenv(".env")
Expand All @@ -22,4 +24,6 @@
"worker_max_memory_per_child": 8200000
})


OCRStrategy.autodiscover_strategies()
app.autodiscover_tasks(["text_extract_api.extract"], 'tasks', True)
71 changes: 63 additions & 8 deletions text_extract_api/extract/ocr_strategies/ocr_strategy.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,40 @@
from __future__ import annotations

import importlib
import pkgutil
from typing import Type, Dict

from text_extract_api.files.file_formats.file_format import FileFormat
from text_extract_api.files.file_formats.pdf_file_format import PdfFileFormat


def discover_strategies() -> Dict[str, Type]:
    """Scan the ``text_extract_api`` packages for OCR strategy classes.

    Every top-level module whose name starts with ``text_extract_api`` is
    imported and, if it is a package, walked. Each submodule whose dotted
    name contains ``.ocr_strategies.`` is imported and inspected for
    subclasses of ``OCRStrategy`` (the base class itself is excluded).

    Modules that fail to import are skipped, and so are subclasses whose
    ``name()`` still raises ``NotImplementedError``, so that one unnamed
    class cannot abort discovery of all the others.

    Returns:
        A mapping from each strategy's ``name()`` to the strategy class
        itself (the class, not an instance).
    """
    strategies: Dict[str, Type] = {}

    for module_info in pkgutil.iter_modules():
        if not module_info.name.startswith("text_extract_api"):
            continue

        try:
            module = importlib.import_module(module_info.name)
        except ImportError:
            continue

        # Only packages expose __path__; a plain module has nothing to walk.
        if not hasattr(module, "__path__"):
            continue

        for submodule_info in pkgutil.walk_packages(module.__path__, module_info.name + "."):
            if ".ocr_strategies." not in submodule_info.name:
                continue

            try:
                ocr_module = importlib.import_module(submodule_info.name)
            except ImportError:
                continue

            for attr_name in dir(ocr_module):
                attr = getattr(ocr_module, attr_name)
                if not (isinstance(attr, type) and issubclass(attr, OCRStrategy)):
                    continue
                if attr is OCRStrategy:
                    continue

                try:
                    strategy_name = attr.name()
                except NotImplementedError:
                    # Subclass inherits the base name(); it cannot be
                    # registered, but must not break the other strategies.
                    continue

                strategies[strategy_name] = attr

    return strategies


class OCRStrategy:
_strategies: Dict[str, Type] = {}

def __init__(self):
print("a")
self.update_state_callback = None

def set_update_state_callback(self, callback):
Expand All @@ -15,11 +44,37 @@ def update_state(self, state, meta):
if self.update_state_callback:
self.update_state_callback(state, meta)

def extract_text_from_pdf(self, pdf_bytes):
    """Backward-compatible wrapper that runs ``extract_text`` on raw PDF bytes.

    Args:
        pdf_bytes: Binary content of the PDF; wrapped via
            ``PdfFileFormat.from_binary`` before extraction.

    Returns:
        Whatever ``extract_text`` returns for the wrapped file.
    """
    # Leave for backward compatibility.
    # The result must be returned: callers assign the value of this call.
    return self.extract_text(PdfFileFormat.from_binary(pdf_bytes, None, None))
@classmethod
def name(cls) -> str:
raise NotImplementedError("Strategy subclasses must implement name")

@classmethod
def extract_text(cls, file_format: Type["FileFormat"]):
raise NotImplementedError("Strategy subclasses must implement extract_text method")

@classmethod
def get_strategy(cls, name: str) -> Type["OCRStrategy"]:
"""
Fetches and returns a registered strategy class based on the given name.
Args:
name: The name of the strategy to fetch.
Returns:
The strategy class corresponding to the provided name.
Raises:
ValueError: If the specified strategy name does not exist among the registered strategies.
"""

if name not in cls._strategies:
cls.autodiscover_strategies()
if name not in cls._strategies:
available = ', '.join(cls._strategies.keys())
raise ValueError(f"Unknown strategy '{name}'. Available: {available}")

"""Base OCR Strategy Interface"""
return cls._strategies[name]

def extract_text(self, file_format: FileFormat):
raise NotImplementedError("Subclasses must implement this method")
@classmethod
def autodiscover_strategies(cls):
    """Rebuild the strategy registry from a fresh discovery pass.

    The previous ``_strategies`` mapping is replaced wholesale by the
    result of ``discover_strategies()``.
    """
    # NOTE: the assignment targets ``cls``, so calling this through a
    # subclass stores the mapping on that subclass.
    discovered = discover_strategies()
    cls._strategies = discovered
10 changes: 6 additions & 4 deletions text_extract_api/extract/ocr_strategies/tesseract.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,23 @@
import numpy as np
import pytesseract



from text_extract_api.extract.ocr_strategies.ocr_strategy import OCRStrategy
from text_extract_api.files.file_formats.file_format import FileFormat
from text_extract_api.files.file_formats.image_file_format import ImageFileFormat


class TesseractOCRStrategy(OCRStrategy):
"""Tesseract OCR Strategy"""

@classmethod
def name(cls) -> str:
return "tesseract"

def extract_text(self, file_format: FileFormat):

if file_format.convertible_to(ImageFileFormat):
raise Exception(f"TesseractOCRStrategy does not handle files of mime type: {file_format.mime_type}")

images = list(FileFormat.convert_to(file_format, ImageFileFormat));
images = FileFormat.convert_to(file_format, ImageFileFormat);
extracted_text = ""

for i, image in enumerate(images):
Expand Down
19 changes: 5 additions & 14 deletions text_extract_api/extract/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,10 @@
import redis

from text_extract_api.celery_app import app as celery_app
from text_extract_api.extract.ocr_strategies.llama_vision import LlamaVisionOCRStrategy
from text_extract_api.extract.ocr_strategies.marker import MarkerOCRStrategy
from text_extract_api.extract.ocr_strategies.tesseract import TesseractOCRStrategy
from text_extract_api.extract.ocr_strategies.ocr_strategy import OCRStrategy
from text_extract_api.files.file_formats.file_format import FileFormat
from text_extract_api.files.storage_manager import StorageManager

OCR_STRATEGIES = {
'marker': MarkerOCRStrategy(),
'tesseract': TesseractOCRStrategy(),
'llama_vision': LlamaVisionOCRStrategy()
}

# Connect to Redis
redis_url = os.getenv('REDIS_CACHE_URL', 'redis://redis:6379/1')
redis_client = redis.StrictRedis.from_url(redis_url)
Expand All @@ -25,7 +18,7 @@
@celery_app.task(bind=True)
def ocr_task(
self,
byes: bytes,
binary_content: bytes,
strategy_name: str,
filename: str,
file_hash: str,
Expand All @@ -39,10 +32,8 @@ def ocr_task(
Celery task to perform OCR processing on a PDF/Office/image file.
"""
start_time = time.time()
if strategy_name not in OCR_STRATEGIES:
raise ValueError(f"Unknown strategy '{strategy_name}'. Available: marker, tesseract, llama_vision")

ocr_strategy = OCR_STRATEGIES[strategy_name]
ocr_strategy = OCRStrategy.get_strategy(strategy_name)
ocr_strategy.set_update_state_callback(self.update_state)

self.update_state(state='PROGRESS', status="File uploaded successfully",
Expand All @@ -61,7 +52,7 @@ def ocr_task(
self.update_state(state='PROGRESS',
meta={'progress': 30, 'status': 'Extracting text from PDF', 'start_time': start_time,
'elapsed_time': time.time() - start_time}) # Example progress update
extracted_text = ocr_strategy.extract_text_from_pdf(byes)
extracted_text = ocr_strategy.extract_text(FileFormat.from_binary(binary_content))
else:
print("Using cached result...")

Expand Down
12 changes: 7 additions & 5 deletions text_extract_api/files/file_formats/file_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import base64
from hashlib import md5
from typing import Type, Iterator, Optional, Dict, Callable, Union, TypeVar
from typing import Type, Iterator, Optional, Dict, Callable, Union, TypeVar, List

from text_extract_api.files.utils import filetype

Expand Down Expand Up @@ -108,15 +108,15 @@ def can_convert_to(self, target_format: Type["FileFormat"]) -> bool:
convertible_keys = self.convertible_to().keys()
return any(target_format is key for key in convertible_keys)

def convert_to(self, target_format: Type["FileFormat"]) -> Iterator["FileFormat"]:
def convert_to(self, target_format: Type["FileFormat"]) -> List["FileFormat"]:
if isinstance(self, target_format):
yield self
# @todo check if this compare is ok
return [self];

converters = self.convertible_to()
if target_format not in converters:
raise ValueError(f"Cannot convert to {target_format}. Conversion not supported.")

return converters[target_format](self)
return list(converters[target_format](self))

@staticmethod
def convertible_to() -> Dict[Type["FileFormat"], Callable[[Type["FileFormat"]], Iterator[Type["Converter"]]]]:
Expand Down Expand Up @@ -155,6 +155,8 @@ def unify(self) -> Type["FileFormat"]:

@staticmethod
def _get_file_format_class(mime_type: str) -> Type["FileFormat"]:
import text_extract_api.files.file_formats.pdf_file_format # noqa - its not unused import @todo autodiscover
import text_extract_api.files.file_formats.image_file_format # noqa - its not unused import @todo autodiscover
for subclass in FileFormat.__subclasses__():
if mime_type in subclass.accepted_mime_types():
return subclass
Expand Down
22 changes: 8 additions & 14 deletions text_extract_api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from text_extract_api.files.file_formats.file_format import FileFormat
from text_extract_api.files.storage_manager import StorageManager
from text_extract_api.celery_app import app as celery_app
from text_extract_api.extract.tasks import ocr_task, OCR_STRATEGIES
from text_extract_api.extract.tasks import ocr_task
from text_extract_api.extract.ocr_strategies.ocr_strategy import OCRStrategy

# Define base path as text_extract_api - required for keeping absolute namespaces
sys.path.insert(0, str(pathlib.Path(__file__).parent.resolve()))
Expand Down Expand Up @@ -53,20 +54,15 @@ async def ocr_endpoint(
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))

if file.content_type is not None and file.content_type != 'application/pdf':
raise HTTPException(status_code=400, detail="Invalid file type. Only PDFs are supported.")

pdf_bytes = await file.read()

# Generate a hash of the document content for caching
pdf_hash = md5(pdf_bytes).hexdigest()
file_binary = await file.read()
file_format = FileFormat.from_binary(file_binary)

print(
f"Processing Document {file.filename} with strategy: {strategy}, ocr_cache: {ocr_cache}, model: {model}, storage_profile: {storage_profile}, storage_filename: {storage_filename}")

# Asynchronous processing using Celery
task = ocr_task.apply_async(
args=[pdf_bytes, strategy, file.filename, pdf_hash, ocr_cache, prompt, model, storage_profile,
args=[file_format.binary, strategy, file.filename, file_format.hash, ocr_cache, prompt, model, storage_profile,
storage_filename])
return {"task_id": task.id}

Expand Down Expand Up @@ -109,8 +105,7 @@ class OcrRequest(BaseModel):

@field_validator('strategy')
def validate_strategy(cls, v):
if v not in OCR_STRATEGIES:
raise ValueError(f"Unknown strategy '{v}'. Available: marker, tesseract")
OCRStrategy.get_strategy(v)
return v

@field_validator('file')
Expand Down Expand Up @@ -140,8 +135,7 @@ class OcrFormRequest(BaseModel):

@field_validator('strategy')
def validate_strategy(cls, v):
if v not in OCR_STRATEGIES:
raise ValueError(f"Unknown strategy '{v}'. Available: marker, tesseract")
OCRStrategy.get_strategy(v)
return v

@field_validator('storage_profile')
Expand All @@ -166,7 +160,7 @@ async def ocr_request_endpoint(request: OcrRequest):
raise HTTPException(status_code=400, detail=str(e))

print(
f"Processing PDF with strategy: {request.strategy}, ocr_cache: {request.ocr_cache}, model: {request.model}, storage_profile: {request.storage_profile}, storage_filename: {request.storage_filename}")
f"Processing {file.mime_type} with strategy: {request.strategy}, ocr_cache: {request.ocr_cache}, model: {request.model}, storage_profile: {request.storage_profile}, storage_filename: {request.storage_filename}")

# Asynchronous processing using Celery
task = ocr_task.apply_async(
Expand Down

0 comments on commit d1425e5

Please sign in to comment.