diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..89cd51b --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,13 @@ +# Change Log + +List of changes between versions + +## 0.2.0 + +Restructured the code to make it pluginable. +No change should be noticeable from a user experience point of view, but now it should be much easier to contribute to the code (new functionalities can be introduced by writing a plugin without having to modify this codebase). + +- The model entries in the database now require an `entrypoint` field to identify which model should be used to load it. +- The functionalities related to `easyocr`, `tesseract` and `huggingface` models have been moved to the `ocr_translate/plugins` folder, and are now plugins (kept in the main codebase to leave an example on how a plugin can work). + +## 0.1.4 diff --git a/README.md b/README.md index 698e88f..7ff252d 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,10 @@ Unzip the file and from inside the folder, run the `run_server-XXX.exe` file (XX The server will run with sensible defaults. Most notably the models files and database will be downloaded/created under `%userprofile%/.ocr_translate`. Also the gpu version will attempt to run on GPU by default, and fall-back to CPU if the former is not available. -For customization, you can set the [environment variable](#environment-variables) yourself, either via powershell or by searching for *environment variable* in the settings menu. +For customization, you can set the [environment variable](#environment-variables) yourself: + +- Linux: use `export ENV_VAR_NAME=XXX` before launching the code from terminal. +- Windows: either via powershell or by searching for *environment variable* in the settings menu. ### From Github installation @@ -147,6 +150,32 @@ before installing the python package. 
- Hugging Face [Seq2Seq](https://huggingface.co/learn/nlp-course/chapter1/7) models +## Writing plugins for the server + +Since version 0.2.0 the server has been made pluginable. +You can write a plugin for a model/web-service that has not yet been implemented, by subclassing the following models + +- `ocr_translate.models.OCRBoxModel`: Must define the following methods (see the base models and the plugins under `ocr_translate.plugins` for examples of function signatures and expected inputs/outputs): + - `load`: function to load the model into memory. Can be defined to do nothing if not needed (e.g. another library that loads the model on import or a plugin for a web-service) + - `unload`: function to unload the model from memory. + - `_box_detection`: Function that takes a PIL image as input and returns a list of bounding boxes. +- `ocr_translate.models.OCRModel`: Must define the following methods (see the base models and the plugins under `ocr_translate.plugins` for examples of function signatures and expected inputs/outputs): + - `load`: function to load the model into memory. Can be defined to do nothing if not needed (e.g. another library that loads the model on import or a plugin for a web-service) + - `unload`: function to unload the model from memory. + - `_ocr`: Function that takes an image as input and returns the OCRed text (the image is the content of the bounding box generated by a BOX model run) +- `ocr_translate.models.TSLModel`: Must define the following methods (see the base models and the plugins under `ocr_translate.plugins` for examples of function signatures and expected inputs/outputs): + - `load`: function to load the model into memory. Can be defined to do nothing if not needed (e.g. another library that loads the model on import or a plugin for a web-service) + - `unload`: function to unload the model from memory. 
+ - `_translate`: Function that takes a list of tokens or a list(list(tokens)) as inputs and returns the translated text as output either as a `str` or `list[str]` (this is needed to work efficiently with AI models that can perform multiple translations simultaneously) + +NOTE: +When subclassing, the following has to be set inside the class (see [django models doc](https://docs.djangoproject.com/en/4.2/topics/db/models/#proxy-models)) + + class Meta: + proxy = True + +Until there is a registry service, do contact me if you write a plugin for the server so I can add a link to it in this README + ## Endpoints This is not a REST API. As of now the communication between the server and a front-end is stateful and depend on the languages and models currently loaded on the server. diff --git a/build.sh b/build.sh index 78fb196..c24fc6f 100755 --- a/build.sh +++ b/build.sh @@ -8,12 +8,14 @@ pyinstaller \ --icon icon.ico \ --add-data "ocr_translate/ocr_tsl/languages.json:ocr_translate/ocr_tsl" \ --add-data "ocr_translate/ocr_tsl/models.json:ocr_translate/ocr_tsl" \ + --collect-all django-ocr_translate \ --collect-all torch \ --collect-all torchvision \ --collect-all transformers \ --collect-all unidic_lite \ --collect-all sacremoses \ --collect-all sentencepiece \ + --recursive-copy-metadata django-ocr_translate \ --recursive-copy-metadata torch \ --recursive-copy-metadata torchvision \ --recursive-copy-metadata transformers \ diff --git a/ocr_translate/__init__.py b/ocr_translate/__init__.py index af7954f..63c2dbf 100644 --- a/ocr_translate/__init__.py +++ b/ocr_translate/__init__.py @@ -18,4 +18,4 @@ ################################################################################### """OCR and translation of images.""" -__version__ = '0.1.4' +__version__ = '0.2.0' diff --git a/ocr_translate/migrations/0006_tslmodel_entrypoint.py b/ocr_translate/migrations/0006_tslmodel_entrypoint.py new file mode 100644 index 0000000..ea9d5a5 --- /dev/null +++ 
b/ocr_translate/migrations/0006_tslmodel_entrypoint.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.4 on 2023-08-03 13:30 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('ocr_translate', '0002_ocrboxmodel_default_options_ocrmodel_default_options_and_more_squashed_0005_alter_ocrboxmodel_default_options_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='tslmodel', + name='entrypoint', + field=models.CharField(max_length=128, null=True), + ), + ] diff --git a/ocr_translate/migrations/0007_ocrmodel_entrypoint_alter_ocrmodel_default_options_and_more.py b/ocr_translate/migrations/0007_ocrmodel_entrypoint_alter_ocrmodel_default_options_and_more.py new file mode 100644 index 0000000..7c644b0 --- /dev/null +++ b/ocr_translate/migrations/0007_ocrmodel_entrypoint_alter_ocrmodel_default_options_and_more.py @@ -0,0 +1,29 @@ +# Generated by Django 4.2.4 on 2023-08-03 14:48 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('ocr_translate', '0006_tslmodel_entrypoint'), + ] + + operations = [ + migrations.AddField( + model_name='ocrmodel', + name='entrypoint', + field=models.CharField(max_length=128, null=True), + ), + migrations.AlterField( + model_name='ocrmodel', + name='default_options', + field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='used_by_%(class)s', to='ocr_translate.optiondict'), + ), + migrations.AlterField( + model_name='tslmodel', + name='default_options', + field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='used_by_%(class)s', to='ocr_translate.optiondict'), + ), + ] diff --git a/ocr_translate/migrations/0008_ocrboxmodel_entrypoint_and_more.py b/ocr_translate/migrations/0008_ocrboxmodel_entrypoint_and_more.py new file mode 100644 index 0000000..916feef --- /dev/null +++ 
b/ocr_translate/migrations/0008_ocrboxmodel_entrypoint_and_more.py @@ -0,0 +1,24 @@ +# Generated by Django 4.2.4 on 2023-08-03 16:27 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('ocr_translate', '0007_ocrmodel_entrypoint_alter_ocrmodel_default_options_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='ocrboxmodel', + name='entrypoint', + field=models.CharField(max_length=128, null=True), + ), + migrations.AlterField( + model_name='ocrboxmodel', + name='default_options', + field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='used_by_%(class)s', to='ocr_translate.optiondict'), + ), + ] diff --git a/ocr_translate/models.py b/ocr_translate/models.py index af120c5..ccd8910 100644 --- a/ocr_translate/models.py +++ b/ocr_translate/models.py @@ -17,10 +17,21 @@ # Home: https://github.com/Crivella/ocr_translate # ################################################################################### """Django models for the ocr_translate app.""" +import logging +import re +from typing import Generator, Type, Union + +import pkg_resources from django.db import models +from PIL.Image import Image as PILImage + +from . 
import queues +from .messaging import Message LANG_LENGTH = 32 +logger = logging.getLogger('ocr.general') + class OptionDict(models.Model): """Dictionary of options for OCR and translation""" options = models.JSONField(unique=True) @@ -47,48 +58,375 @@ class Language(models.Model): def __str__(self): return str(self.iso1) -class OCRModel(models.Model): - """OCR model using hugging space naming convention""" +class BaseModel(models.Model): + """Mixin class for loading entrypoint models""" + class Meta: + abstract = True + + entrypoint_namespace = None + name = models.CharField(max_length=128) - languages = models.ManyToManyField(Language, related_name='ocr_models') + + entrypoint = models.CharField(max_length=128, null=True) default_options = models.ForeignKey( - OptionDict, on_delete=models.SET_NULL, related_name='ocr_default_options', null=True + OptionDict, on_delete=models.SET_NULL, related_name='used_by_%(class)s', null=True ) - language_format = models.CharField(max_length=32, null=True) - def __str__(self): return str(self.name) -class OCRBoxModel(models.Model): - """OCR model for bounding boxes""" - name = models.CharField(max_length=128) - languages = models.ManyToManyField(Language, related_name='box_models') + def __del__(self): + try: + self.unload() + except NotImplementedError: + pass - default_options = models.ForeignKey( - OptionDict, on_delete=models.SET_NULL, related_name='box_default_options', null=True - ) + @classmethod + def from_entrypoint(cls, name: str) -> Type['models.Model']: + """Get the entrypoint specific TSL model class from the entrypoint name""" + if cls.entrypoint_namespace is None: + raise ValueError('Cannot load base model class from entrypoint.') + + obj = cls.objects.get(name=name) + ept = obj.entrypoint + + logger.debug(f'Loading model {name} from entrypoint {cls.entrypoint_namespace}:{ept}') + for entrypoint in pkg_resources.iter_entry_points(cls.entrypoint_namespace, name=ept): + new_cls = entrypoint.load() + break + else: + 
raise ValueError(f'Missing plugin: Entrypoint "{ept}" not found.') + + return new_cls.objects.get(name=name) + + def load(self) -> None: + """Placeholder method for loading the model. To be implemented via entrypoint""" + raise NotImplementedError('The base model class does not implement this method.') + + def unload(self) -> None: + """Placeholder method for unloading the model. To be implemented via entrypoint""" + raise NotImplementedError('The base model class does not implement this method.') + +class OCRModel(BaseModel): + """OCR model.""" + entrypoint_namespace = 'ocr_translate.ocr_models' + # iso1 code for languages that do not use spaces to separate words + NO_SPACE_LANGUAGES = ['ja', 'zh', 'lo', 'my'] + + languages = models.ManyToManyField(Language, related_name='ocr_models') language_format = models.CharField(max_length=32, null=True) - def __str__(self): - return str(self.name) + def prepare_image( + self, + img: PILImage, bbox: tuple[int, int, int, int] = None + ) -> PILImage: + """Standard operation to be performed on image before OCR. E.G color scale and crop to bbox""" + if not isinstance(img, PILImage): + raise TypeError(f'img should be PIL Image, but got {type(img)}') + img = img.convert('RGB') + + if bbox: + img = img.crop(bbox) + + return img + + def _ocr( + self, + img: PILImage, lang: str = None, options: dict = None + ) -> str: + """Placeholder method for performing OCR. To be implemented via entrypoint""" + raise NotImplementedError('The base model class does not implement this method.') + + def ocr( + self, + bbox_obj: 'BBox', lang: 'Language', image: PILImage = None, options: 'OptionDict' = None, + force: bool = False, block: bool = True, + ) -> Generator[Union[Message, 'Text'], None, None]: + """High level function to perform OCR on an image. + + Args: + bbox_obj (m.BBox): The BBox object from the database. + lang (m.Language): The Language object from the database. + image (Image.Image, optional): The image on which to perform OCR. 
Needed if no previous OCR run exists, or + force is True. + options (m.OptionDict, optional): The OptionDict object from the database + containing the options for the OCR. + force (bool, optional): Whether to force the OCR to run again even if a previous run exists. + Defaults to False. + block (bool, optional): Whether to block until the task is complete. Defaults to True. + + Raises: + ValueError: ValueError is raised if at any step of the pipeline an image is required but not provided. + + Yields: + Generator[Union[Message, m.Text], None, None]: + If block is True, yields a Message object for the OCR run first and the resulting Text object second. + If block is False, yields the resulting Text object. + """ + options_obj = options + if options_obj is None: + options_obj = OptionDict.objects.get(options={}) + params = { + 'bbox': bbox_obj, + 'model': self, + 'lang_src': lang, + 'options': options_obj, + } + ocr_run_obj = OCRRun.objects.filter(**params).first() + if ocr_run_obj is None or force: + if image is None: + raise ValueError('Image is required for OCR') + logger.info('Running OCR') + + id_ = (bbox_obj.id, self.id, lang.id) + mlang = getattr(lang, self.language_format or 'iso1') + opt_dct = options_obj.options + text = queues.ocr_queue.put( + id_=id_, + handler=self._ocr, + msg={ + 'args': (self.prepare_image(image, bbox_obj.lbrt),), + 'kwargs': { + 'lang': mlang, + 'options': opt_dct + }, + }, + ) + if not block: + yield text + text = text.response() + if lang.iso1 in self.NO_SPACE_LANGUAGES: + text = text.replace(' ', '') + text_obj, _ = Text.objects.get_or_create( + text=text, + ) + params['result'] = text_obj + ocr_run_obj = OCRRun.objects.create(**params) + else: + if not block: + # Both branches should have the same number of yields + yield None + logger.info(f'Reusing OCR <{ocr_run_obj.id}>') + text_obj = ocr_run_obj.result + # text = ocr_run.result.text + + yield text_obj + + +class OCRBoxModel(BaseModel): + """OCR model for bounding boxes""" + 
#pylint: disable=abstract-method + entrypoint_namespace = 'ocr_translate.box_models' + + languages = models.ManyToManyField(Language, related_name='box_models') -class TSLModel(models.Model): + language_format = models.CharField(max_length=32, null=True) + + def _box_detection( + self, + image: PILImage, options: dict = None + ) -> list[tuple[int, int, int, int]]: + """Placeholder method for performing box detection. To be implemented via entrypoint""" + raise NotImplementedError('The base model class does not implement this method.') + + def box_detection( + self, + img_obj: 'Image', lang: 'Language', image: PILImage = None, + force: bool = False, options: 'OptionDict' = None + ) -> list['BBox']: + """High level function to perform box OCR on an image. Will attempt to reuse a previous run if possible. + + Args: + img_obj (m.Image): An Image object from the database. + lang (m.Language): A Language object from the database (not every model is gonna use this). + image (Image.Image, optional): The Pillow image to be used for OCR if a previous result is not found. + Defaults to None. + force (bool, optional): If true, re-run the OCR even if a previous result is found. Defaults to False. + options (m.OptionDict, optional): An OptionDict object from the database. Defaults to None. + + Raises: + ValueError: ValueError is raised if at any step of the pipeline an image is required but not provided. + + Returns: + list[m.BBox]: A list of BBox objects containing the resulting bounding boxes. 
+ """ + options_obj = options or OptionDict.objects.get(options={}) + params = { + 'image': img_obj, + 'model': self, + 'options': options_obj, + 'lang_src': lang, + } + + bbox_run = OCRBoxRun.objects.filter(**params).first() + if bbox_run is None or force: + if image is None: + raise ValueError('Image is required for BBox OCR') + logger.info('Running BBox OCR') + opt_dct = options_obj.options + id_ = (img_obj.id, self.id, lang.id) + bboxes = queues.box_queue.put( + id_=id_, + handler=self._box_detection, + msg={ + 'args': (image,), + 'kwargs': {'options': opt_dct}, + }, + ) + bboxes = bboxes.response() + # Create it here to avoid having a failed entry in DB + bbox_run = OCRBoxRun.objects.create(**params) + for bbox in bboxes: + l,b,r,t = bbox + BBox.objects.create( + l=l, + b=b, + r=r, + t=t, + image=img_obj, + from_ocr=bbox_run, + ) + else: + logger.info(f'Reusing BBox OCR <{bbox_run.id}>') + logger.info(f'BBox OCR result: {len(bbox_run.result.all())} boxes') + + return list(bbox_run.result.all()) + + +class TSLModel(BaseModel): """Translation models using hugging space naming convention""" - name = models.CharField(max_length=128) + entrypoint_namespace = 'ocr_translate.tsl_models' + src_languages = models.ManyToManyField(Language, related_name='tsl_models_src') dst_languages = models.ManyToManyField(Language, related_name='tsl_models_dst') - default_options = models.ForeignKey( - OptionDict, on_delete=models.SET_NULL, related_name='tsl_default_options', null=True - ) - language_format = models.CharField(max_length=32, null=True) - def __str__(self): - return str(self.name) + @staticmethod + def pre_tokenize( + text: str, + ignore_chars: str = None, break_chars: str = None, break_newlines: bool = False, + restore_dash_newlines: bool = False, + **kwargs + ) -> list[str]: + """Pre-tokenize a text string. + + Args: + text (str): Text to tokenize. + ignore_chars (str, optional): String of characters to ignore. Defaults to None. 
+ break_chars (str, optional): String of characters to break on. Defaults to None. + break_newlines (bool, optional): Whether to break on newlines. Defaults to True. + restore_dash_newlines (bool, optional): Whether to restore dash-newlines (word broken with a -newline). + Defaults to False. + + Returns: + list[str]: List of string tokens. + """ + if restore_dash_newlines: + text = re.sub(r'(? 0: + tokens = re.split(f'[{break_chars}+]', text) + + if isinstance(tokens, str): + tokens = [text] + + res = list(filter(None, tokens)) + return res if len(res) > 0 else [' '] + + + def _translate(self, tokens: list, src_lang: str, dst_lang: str, options: dict = None) -> str | list[str]: + """Placeholder method for translating a text. To be implemented via entrypoint""" + raise NotImplementedError('The base model class does not implement this method.') + + def translate( + self, + text_obj: 'Text', src: 'Language', dst: 'Language', options: 'OptionDict' = None, + force: bool = False, + block: bool = True, + lazy: bool = False + ) -> Generator[Union[Message, 'Text'], None, None]: + """High level translate call generating a TranslationRun entry. + Args: + text_obj (m.Text): Text object from the database to translate. + src (m.Language): Source language object from the database. + dst (m.Language): Destination language object from the database. + options (m.OptionDict, optional): OptionDict object from the database. Defaults to None. + force (bool, optional): Whether to force a new TSL run. Defaults to False. + block (bool, optional): Whether to block until the task is complete. Defaults to True. + lazy (bool, optional): Whether to raise an error if the TSL run is not found. Defaults to False. + + Raises: + ValueError: If lazy and force are both True or if lazy is True and the TSL run is not found. + + Yields: + Generator[Union[Message, m.Text], None, None]: + If block is True, yields a Message object for the TSL run first and the resulting Text object second. 
+ If block is False, yields the resulting Text object. + """ + if lazy and force: + raise ValueError('Cannot force + lazy TSL run') + options_obj = options or OptionDict.objects.get(options={}) + params = { + 'options': options_obj, + 'text': text_obj, + 'model': self, + 'lang_src': src, + 'lang_dst': dst, + } + tsl_run_obj = TranslationRun.objects.filter(**params).first() + if tsl_run_obj is None or force: + if lazy: + raise ValueError('Value not found for lazy TSL run') + logger.info('Running TSL') + # Generate a unique id for a message + id_ = (text_obj.id, self.id, options_obj.id, src.id, dst.id) + batch_id = (self.id, options_obj.id, src.id, dst.id) + lang_dct = getattr(src.default_options, 'options', {}) + model_dct = getattr(self.default_options, 'options', {}) + opt_dct = {**lang_dct, **model_dct, **options_obj.options} + + tokens = self.pre_tokenize(text_obj.text, **opt_dct) + new = queues.tsl_queue.put( + id_=id_, + batch_id=batch_id, + handler=self._translate, + msg={ + 'args': ( + tokens, + getattr(src, self.language_format), + getattr(dst, self.language_format) + ), + 'kwargs': {'options': opt_dct}, + }, + ) + if not block: + yield new + new = new.response() + text_obj, _ = Text.objects.get_or_create( + text = new, + ) + params['result'] = text_obj + tsl_run_obj = TranslationRun.objects.create(**params) + else: + if not block: + # Both branches should have the same number of yields + yield None + logger.info(f'Reusing TSL <{tsl_run_obj.id}>') + + yield tsl_run_obj.result class Image(models.Model): """Image registered as the md5 of the uploaded file""" diff --git a/ocr_translate/ocr_tsl/__init__.py b/ocr_translate/ocr_tsl/__init__.py index 587bfd4..70d1142 100644 --- a/ocr_translate/ocr_tsl/__init__.py +++ b/ocr_translate/ocr_tsl/__init__.py @@ -23,11 +23,11 @@ from .initializers import (auto_create_languages, auto_create_models, init_most_used) -if os.environ.get('LOAD_ON_START', 'false').lower() == 'true': - init_most_used() - if 
os.environ.get('AUTOCREATE_LANGUAGES', 'false').lower() == 'true': auto_create_languages() if os.environ.get('AUTOCREATE_VALIDATED_MODELS', 'false').lower() == 'true': auto_create_models() + +if os.environ.get('LOAD_ON_START', 'false').lower() == 'true': + init_most_used() diff --git a/ocr_translate/ocr_tsl/box.py b/ocr_translate/ocr_tsl/box.py index edb337b..1eaa6f4 100644 --- a/ocr_translate/ocr_tsl/box.py +++ b/ocr_translate/ocr_tsl/box.py @@ -18,241 +18,45 @@ ################################################################################### """Functions and piplines to perform Box OCR on an image.""" import logging -from typing import Hashable, Iterable - -import easyocr -import numpy as np -import torch -from PIL import Image from .. import models as m -from ..queues import box_queue as q -from .huggingface import dev -READER = None +dev = 'cpu' #pylint: disable=invalid-name logger = logging.getLogger('ocr.general') BOX_MODEL_ID = None -BBOX_MODEL_OBJ = None +BOX_MODEL_OBJ: m.OCRBoxModel = None def unload_box_model(): """Remove the current box model from memory.""" - global BBOX_MODEL_OBJ, READER, BOX_MODEL_ID + global BOX_MODEL_OBJ, BOX_MODEL_ID logger.info(f'Unloading BOX model: {BOX_MODEL_ID}') - if BOX_MODEL_ID == 'easyocr': - pass - READER = None - BBOX_MODEL_OBJ = None + del BOX_MODEL_OBJ + BOX_MODEL_OBJ = None BOX_MODEL_ID = None - if dev == 'cuda': - torch.cuda.empty_cache() - - def load_box_model(model_id: str): """Load a box model into memory.""" - global BBOX_MODEL_OBJ, READER, BOX_MODEL_ID + global BOX_MODEL_OBJ, BOX_MODEL_ID if BOX_MODEL_ID == model_id: return + if BOX_MODEL_OBJ is not None: + BOX_MODEL_OBJ.unload() + logger.info(f'Loading BOX model: {model_id}') - if model_id == 'easyocr': - READER = easyocr.Reader([], gpu=(dev == 'cuda'), recognizer=False) - BBOX_MODEL_OBJ, _ = m.OCRBoxModel.objects.get_or_create(name=model_id) - BOX_MODEL_ID = model_id - else: - raise NotImplementedError + model = m.OCRBoxModel.from_entrypoint(model_id) + 
model.load() + + BOX_MODEL_OBJ = model + BOX_MODEL_ID = model_id logger.debug(f'OCR model loaded: {model_id}') - logger.debug(f'OCR model object: {BBOX_MODEL_OBJ}') + logger.debug(f'OCR model object: {BOX_MODEL_OBJ}') def get_box_model() -> m.OCRBoxModel: """Get the current box model.""" - return BBOX_MODEL_OBJ - -def intersections(bboxes: Iterable[tuple[int, int, int, int]], margin: int = 5) -> list[set[int]]: - """Determine the intersections between a list of bounding boxes. - - Args: - bboxes (Iterable[tuple[int, int, int, int]]): List of bounding boxes in lrbt format. - margin (int, optional): Number of extra pixels outside of the boxes that define an intersection. Defaults to 5. - - Returns: - list[set[int]]: List of sets of indexes of the boxes that intersect. - """ - res = [] - - for i,(l1,r1,b1,t1) in enumerate(bboxes): - l1 -= margin - r1 += margin - b1 -= margin - t1 += margin - - for j,(l2,r2,b2,t2) in enumerate(bboxes): - if i == j: - continue - - if l1 >= r2 or r1 <= l2 or b1 >= t2 or t1 <= b2: - continue - - for ptr in res: - if i in ptr or j in ptr: - break - else: - ptr = set() - res.append(ptr) - - ptr.add(i) - ptr.add(j) - - return res - -def merge_bboxes(bboxes: Iterable[tuple[int, int, int, int]]) -> list[tuple[int, int, int, int]]: - """Merge a list of intersecting bounding boxes. All intersecting boxes are merged into a single box. - - Args: - bboxes (Iterable[Iterable[int]]): Iterable of bounding boxes in lrbt format. - - Returns: - list[tuple[int]]: List of merged bounding boxes in lrbt format. 
- """ - res = [] - bboxes = np.array(bboxes) - inters = intersections(bboxes) - - lst = list(range(len(bboxes))) - - torm = set() - for app in inters: - app = list(app) - data = bboxes[app].reshape(-1,4) - l = data[:,0].min() - r = data[:,1].max() - b = data[:,2].min() - t = data[:,3].max() - - res.append([l,b,r,t]) - - torm = torm.union(app) - - for i in lst: - if i in torm: - continue - l,r,b,t = bboxes[i] - res.append([l,b,r,t]) - - return res - -def _box_pipeline(image: Image.Image, options: dict = None) -> list[tuple[int, int, int, int]]: - """Perform box OCR on an image. - - Args: - image (Image.Image): A Pillow image on which to perform OCR. - options (dict, optional): A dictionary of options. - - Raises: - NotImplementedError: The type of model specified is not implemented. - - Returns: - list[tuple[int, int, int, int]]: A list of bounding boxes in lrbt format. - """ - - if options is None: - options = {} - - # reader.recognize(image) - if BOX_MODEL_ID == 'easyocr': - image = image.convert('RGB') - results = READER.detect(np.array(image)) - - # Axis rectangles - bboxes = results[0][0] - - # Free (NOT IMPLEMENTED) - # ... - - bboxes = merge_bboxes(bboxes) - else: - raise NotImplementedError - - return bboxes - -def box_pipeline(*args, id_: Hashable, block: bool = True, **kwargs): - """Queue a box OCR pipeline. - - Args: - id_ (Hashable): A unique identifier for the OCR task. - block (bool, optional): Whether to block until the task is complete. Defaults to True. - - Returns: - Union[str, Message]: The text extracted from the image (block=True) or a Message object (block=False). - """ - msg = q.put( - id_ = id_, - msg = {'args': args, 'kwargs': kwargs}, - handler = _box_pipeline, - ) - - if block: - return msg.response() - return msg - -def box_run( - img_obj: m.Image, lang: m.Language, image: Image.Image = None, - force: bool = False, options: m.OptionDict = None - ) -> list[m.BBox]: - """High level function to perform box OCR on an image. 
Will attempt to reuse a previous run if possible. - - Args: - img_obj (m.Image): An Image object from the database. - lang (m.Language): A Language object from the database (not every model is gonna use this). - image (Image.Image, optional): The Pillow image to be used for OCR if a previous result is not found. - Defaults to None. - force (bool, optional): If true, re-run the OCR even if a previous result is found. Defaults to False. - options (m.OptionDict, optional): An OptionDict object from the database. Defaults to None. - - Raises: - ValueError: ValueError is raised if at any step of the pipeline an image is required but not provided. - - Returns: - list[m.BBox]: A list of BBox objects containing the resulting bounding boxes. - """ - options_obj = options or m.OptionDict.objects.get(options={}) - params = { - 'image': img_obj, - 'model': BBOX_MODEL_OBJ, - 'options': options_obj, - 'lang_src': lang, - } - - bbox_run = m.OCRBoxRun.objects.filter(**params).first() - if bbox_run is None or force: - if image is None: - raise ValueError('Image is required for BBox OCR') - logger.info('Running BBox OCR') - opt_dct = options_obj.options - bboxes = box_pipeline( - image, - id_=img_obj.md5, - options=opt_dct, - ) - # Create it here to avoid having a failed entry in DB - bbox_run = m.OCRBoxRun.objects.create(**params) - for bbox in bboxes: - l,b,r,t = bbox - m.BBox.objects.create( - l=l, - b=b, - r=r, - t=t, - image=img_obj, - from_ocr=bbox_run, - ) - else: - logger.info(f'Reusing BBox OCR <{bbox_run.id}>') - logger.info(f'BBox OCR result: {len(bbox_run.result.all())} boxes') - - return list(bbox_run.result.all()) + return BOX_MODEL_OBJ diff --git a/ocr_translate/ocr_tsl/full.py b/ocr_translate/ocr_tsl/full.py index 5ffb213..5dd2b00 100644 --- a/ocr_translate/ocr_tsl/full.py +++ b/ocr_translate/ocr_tsl/full.py @@ -22,10 +22,10 @@ from PIL import Image from .. 
import models as m -from .box import box_run +from .box import get_box_model from .lang import get_lang_dst, get_lang_src -from .ocr import ocr_run -from .tsl import tsl_run +from .ocr import get_ocr_model +from .tsl import get_tsl_model logger = logging.getLogger('ocr.general') @@ -34,6 +34,10 @@ def ocr_tsl_pipeline_lazy(md5: str, options: dict = None) -> list[dict]: Try to lazily generate reponse from md5. Should raise a ValueError if the operation is not possible (fails at any step). """ + box_model = get_box_model() + ocr_model = get_ocr_model() + tsl_model = get_tsl_model() + if options is None: options = {} logger.debug(f'LAZY: START {md5}') @@ -42,11 +46,12 @@ def ocr_tsl_pipeline_lazy(md5: str, options: dict = None) -> list[dict]: img_obj= m.Image.objects.get(md5=md5) except m.Image.DoesNotExist as exc: raise ValueError(f'Image with md5 {md5} does not exist') from exc - bbox_obj_list = box_run(img_obj, get_lang_src()) + bbox_obj_list = box_model.box_detection(img_obj, get_lang_src()) for bbox_obj in bbox_obj_list: - text_obj = ocr_run(bbox_obj, get_lang_src()) + text_obj = ocr_model.ocr(bbox_obj, get_lang_src()) text_obj = next(text_obj) - tsl_obj = tsl_run(text_obj, get_lang_src(), get_lang_dst(), lazy=True) + + tsl_obj = tsl_model.translate(text_obj, get_lang_src(), get_lang_dst(), lazy=True) tsl_obj = next(tsl_obj) text = text_obj.text @@ -69,19 +74,23 @@ def ocr_tsl_pipeline_work(img: Image.Image, md5: str, force: bool = False, optio Generate response from md5 and binary. Will attempt to behave lazily at every step unless force is True. 
""" + box_model = get_box_model() + ocr_model = get_ocr_model() + tsl_model = get_tsl_model() + if options is None: options = {} logger.debug(f'WORK: START {md5}') res = [] img_obj, _ = m.Image.objects.get_or_create(md5=md5) - bbox_obj_list = box_run(img_obj, get_lang_src() ,image=img) + bbox_obj_list = box_model.box_detection(img_obj, get_lang_src() ,image=img) texts = [] for bbox_obj in bbox_obj_list: logger.debug(str(bbox_obj)) - text_obj = ocr_run(bbox_obj, get_lang_src(), image=img, force=force, block=False) + text_obj = ocr_model.ocr(bbox_obj, get_lang_src(), image=img, force=force, block=False) next(text_obj) texts.append(text_obj) @@ -90,7 +99,7 @@ def ocr_tsl_pipeline_work(img: Image.Image, md5: str, force: bool = False, optio trans = [] for text_obj in texts: - tsl_obj = tsl_run(text_obj, get_lang_src(), get_lang_dst(), force=force, block=False) + tsl_obj = tsl_model.translate(text_obj, get_lang_src(), get_lang_dst(), force=force, block=False) next(tsl_obj) trans.append(tsl_obj) diff --git a/ocr_translate/ocr_tsl/huggingface.py b/ocr_translate/ocr_tsl/huggingface.py deleted file mode 100644 index cdbaee7..0000000 --- a/ocr_translate/ocr_tsl/huggingface.py +++ /dev/null @@ -1,83 +0,0 @@ -################################################################################### -# ocr_translate - a django app to perform OCR and translation of images. # -# Copyright (C) 2023-present Davide Grassano # -# # -# This program is free software: you can redistribute it and/or modify # -# it under the terms of the GNU General Public License as published by # -# the Free Software Foundation, either version 3 of the License. # -# # -# This program is distributed in the hope that it will be useful, # -# but WITHOUT ANY WARRANTY; without even the implied warranty of # -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # -# GNU General Public License for more details. 
# -# # -# You should have received a copy of the GNU General Public License # -# along with this program. If not, see {http://www.gnu.org/licenses/}. # -# # -# Home: https://github.com/Crivella/ocr_translate # -################################################################################### -"""Base utility functions to load models and project-wise environment variables.""" -import logging -import os -from pathlib import Path - -from transformers import (AutoImageProcessor, AutoModel, AutoModelForSeq2SeqLM, - AutoTokenizer, VisionEncoderDecoderModel) - -logger = logging.getLogger('ocr.general') - -root = Path(os.environ.get('TRANSFORMERS_CACHE', '.')) -logger.debug(f'Cache dir: {root}') -dev = os.environ.get('DEVICE', 'cpu') - -def load(loader, model_id: str): - """Use the specified loader to load a transformers specific Class.""" - try: - mid = root / model_id - logger.debug(f'Attempt loading from store: "{loader}" "{mid}"') - res = loader.from_pretrained(mid) - except Exception: - # Needed to catch some weird exception from transformers - # eg: huggingface_hub.utils._validators.HFValidationError: Repo id must use alphanumeric chars or - # '-', '_', '.', '--' and '..' are forbidden, '-' and '.' cannot start or end the name, max length is 96: ... - logger.debug(f'Attempt loading from cache: "{loader}" "{model_id}" "{root}"') - res = loader.from_pretrained(model_id, cache_dir=root) - return res - -mapping = { - 'tokenizer': AutoTokenizer, - 'ved_model': VisionEncoderDecoderModel, - 'model': AutoModel, - 'image_processor': AutoImageProcessor, - 'seq2seq': AutoModelForSeq2SeqLM -} - -accept_device = ['ved_model', 'seq2seq', 'model'] - -def load_hugginface_model(model_id: str, request: list[str]) -> list: - """Load the requested HuggingFace's Classes for the model into the memory of the globally specified device. - - Args: - model_id (str): The HuggingFace model id to load, or a path to a local model. - request (list[str]): A list of HuggingFace's Classes to load. 
- - Raises: - ValueError: If the model_id is not found or if the requested Class is not supported. - - Returns: - _type_: A list of the requested Classes. - """ """""" - res = {} - for r in request: - if r not in mapping: - raise ValueError(f'Unknown request: {r}') - cls = load(mapping[r], model_id) - if cls is None: - raise ValueError(f'Could not load model: {model_id}') - - if r in accept_device: - cls = cls.to(dev) - - res[r] = cls - - return res diff --git a/ocr_translate/ocr_tsl/initializers.py b/ocr_translate/ocr_tsl/initializers.py index 1e2d057..30e28ca 100644 --- a/ocr_translate/ocr_tsl/initializers.py +++ b/ocr_translate/ocr_tsl/initializers.py @@ -83,10 +83,12 @@ def auto_create_models(): logger.debug(f'Creating box model: {box}') lang = box.pop('lang') lcode = box.pop('lang_code') + entrypoint = box.pop('entrypoint') def_opt = box.pop('default_options', {}) opt_obj, _ = m.OptionDict.objects.get_or_create(options=def_opt) model, _ = m.OCRBoxModel.objects.get_or_create(**box) model.default_options = opt_obj + model.entrypoint = entrypoint model.language_format = lcode for l in lang: model.languages.add(m.Language.objects.get(iso1=l)) @@ -96,11 +98,13 @@ def auto_create_models(): logger.debug(f'Creating ocr model: {ocr}') lang = ocr.pop('lang') lcode = ocr.pop('lang_code') + entrypoint = ocr.pop('entrypoint') def_opt = ocr.pop('default_options', {}) opt_obj, _ = m.OptionDict.objects.get_or_create(options=def_opt) model, _ = m.OCRModel.objects.get_or_create(**ocr) model.default_options = opt_obj model.language_format = lcode + model.entrypoint = entrypoint for l in lang: model.languages.add(m.Language.objects.get(iso1=l)) model.save() @@ -110,11 +114,13 @@ def auto_create_models(): src = tsl.pop('lang_src') dst = tsl.pop('lang_dst') lcode = tsl.pop('lang_code', None) + entrypoint = tsl.pop('entrypoint') def_opt = tsl.pop('default_options', {}) opt_obj, _ = m.OptionDict.objects.get_or_create(options=def_opt) model, _ = m.TSLModel.objects.get_or_create(**tsl) 
model.default_options = opt_obj model.language_format = lcode + model.entrypoint = entrypoint for l in src: logger.debug(f'Adding src language: {l}') kwargs = {lcode: l} diff --git a/ocr_translate/ocr_tsl/models.json b/ocr_translate/ocr_tsl/models.json index 8973193..0d53fb2 100644 --- a/ocr_translate/ocr_tsl/models.json +++ b/ocr_translate/ocr_tsl/models.json @@ -3,19 +3,22 @@ { "name": "easyocr", "lang": ["en", "ja", "zh", "ko"], - "lang_code": "easyocr" + "lang_code": "easyocr", + "entrypoint": "easyocr.box" } ], "ocr": [ { "name": "kha-white/manga-ocr-base", "lang": ["ja"], - "lang_code": "iso1" + "lang_code": "iso1", + "entrypoint": "hugginface.ved" }, { "name": "tesseract", "lang": ["en", "ja", "zh", "ko"], - "lang_code": "tesseract" + "lang_code": "tesseract", + "entrypoint": "tesseract.ocr" } ], "tsl": [ @@ -26,7 +29,8 @@ "lang_code": "iso1", "default_options": { "break_newlines": false - } + }, + "entrypoint": "hugginface.seq2seq" }, { "name": "Helsinki-NLP/opus-mt-ja-en", @@ -35,7 +39,8 @@ "lang_code": "iso1", "default_options": { "break_newlines": true - } + }, + "entrypoint": "hugginface.seq2seq" }, { "name": "staka/fugumt-ja-en", @@ -44,7 +49,8 @@ "lang_code": "iso1", "default_options": { "break_newlines": true - } + }, + "entrypoint": "hugginface.seq2seq" }, { "name": "Helsinki-NLP/opus-mt-ko-en", @@ -53,7 +59,8 @@ "lang_code": "iso1", "default_options": { "break_newlines": false - } + }, + "entrypoint": "hugginface.seq2seq" }, { "name": "facebook/m2m100_418M", @@ -62,7 +69,8 @@ "lang_code": "facebookM2M", "default_options": { "break_newlines": false - } + }, + "entrypoint": "hugginface.seq2seq" }, { "name": "facebook/m2m100_1.2B", @@ -71,7 +79,8 @@ "lang_code": "facebookM2M", "default_options": { "break_newlines": false - } + }, + "entrypoint": "hugginface.seq2seq" } ] } \ No newline at end of file diff --git a/ocr_translate/ocr_tsl/ocr.py b/ocr_translate/ocr_tsl/ocr.py index 2317354..7c8b36d 100644 --- a/ocr_translate/ocr_tsl/ocr.py +++ 
b/ocr_translate/ocr_tsl/ocr.py @@ -18,62 +18,37 @@ ################################################################################### """Functions and piplines to perform OCR on an image.""" import logging -from typing import Generator, Hashable, Union - -import torch -from PIL import Image from .. import models as m -from ..messaging import Message -from ..queues import ocr_queue as q -from .huggingface import dev, load_hugginface_model -from .tesseract import tesseract_pipeline logger = logging.getLogger('ocr.general') OBJ_MODEL_ID: str = None -OCR_MODEL = None -OCR_TOKENIZER = None -OCR_IMAGE_PROCESSOR = None - OCR_MODEL_OBJ: m.OCRModel = None -NO_SPACE_LANGUAGES = [ - 'ja', 'zh', 'lo', 'my' -] - def unload_ocr_model(): """Remove the current OCR model from memory.""" - global OCR_MODEL_OBJ, OCR_MODEL, OCR_TOKENIZER, OCR_IMAGE_PROCESSOR, OBJ_MODEL_ID + global OCR_MODEL_OBJ, OBJ_MODEL_ID logger.info(f'Unloading OCR model: {OBJ_MODEL_ID}') - OCR_MODEL = None - OCR_TOKENIZER = None - OCR_IMAGE_PROCESSOR = None + del OCR_MODEL_OBJ OCR_MODEL_OBJ = None OBJ_MODEL_ID = None - if dev == 'cuda': - torch.cuda.empty_cache() - def load_ocr_model(model_id: str): """Load an OCR model into memory.""" - global OCR_MODEL_OBJ, OCR_MODEL, OCR_TOKENIZER, OCR_IMAGE_PROCESSOR, OBJ_MODEL_ID + global OCR_MODEL_OBJ, OBJ_MODEL_ID if OBJ_MODEL_ID == model_id: return - # mid = root / model_id - logger.info(f'Loading OCR model: {model_id}') - if model_id == 'tesseract': - pass - else: - res = load_hugginface_model(model_id, request=['ved_model', 'tokenizer', 'image_processor']) - OCR_MODEL = res['ved_model'] - OCR_TOKENIZER = res['tokenizer'] - OCR_IMAGE_PROCESSOR = res['image_processor'] + if OCR_MODEL_OBJ is not None: + OCR_MODEL_OBJ.unload() + + model = m.OCRModel.from_entrypoint(model_id) + model.load() - OCR_MODEL_OBJ, _ = m.OCRModel.objects.get_or_create(name=model_id) + OCR_MODEL_OBJ = model OBJ_MODEL_ID = model_id logger.debug(f'OCR model loaded: {model_id}') @@ -82,130 +57,3 @@ def 
load_ocr_model(model_id: str): def get_ocr_model() -> m.OCRModel: """Return the current OCR model.""" return OCR_MODEL_OBJ - -def _ocr(img: Image.Image, lang: str = None, bbox: tuple[int, int, int, int] = None, options: dict = None) -> str: - """Perform OCR on an image. - - Args: - img (Image.Image): A Pillow image on which to perform OCR. - lang (str, optional): The language to use for OCR. (Not every model will use this) - bbox (tuple[int, int, int, int], optional): The bounding box of the text on the image in lbrt format. - options (dict, optional): A dictionary of options to pass to the OCR model. - - Raises: - TypeError: If img is not a Pillow image. - - Returns: - str: The text extracted from the image. - """ - if options is None: - options = {} - if not isinstance(img, Image.Image): - raise TypeError(f'img should be PIL Image, but got {type(img)}') - img = img.convert('RGB') - - if bbox: - img = img.crop(bbox) - - if OBJ_MODEL_ID == 'tesseract': - generated_text = tesseract_pipeline(img, lang) - else: - pixel_values = OCR_IMAGE_PROCESSOR(img, return_tensors='pt').pixel_values - if dev == 'cuda': - pixel_values = pixel_values.cuda() - generated_ids = OCR_MODEL.generate(pixel_values) - generated_text = OCR_TOKENIZER.batch_decode(generated_ids, skip_special_tokens=True)[0] - - if dev == 'cuda': - torch.cuda.empty_cache() - - return generated_text - -def ocr(*args, id_: Hashable, block: bool = True, **kwargs) -> Union[str, Message]: - """Queue a text OCR pipeline. - - Args: - id_ (Hashable): A unique identifier for the OCR task. - block (bool, optional): Whether to block until the task is complete. Defaults to True. - - Returns: - Union[str, Message]: The text extracted from the image (block=True) or a Message object (block=False). 
- """ - msg = q.put( - id_ = id_, - msg = {'args': args, 'kwargs': kwargs}, - handler = _ocr, - ) - - if block: - return msg.response() - return msg - -def ocr_run( - bbox_obj: m.BBox, lang: m.Language, image: Image.Image = None, options: m.OptionDict = None, - force: bool = False, block: bool = True, - ) -> Generator[Union[Message, m.Text], None, None]: - """High level function to perform OCR on an image. - - Args: - bbox_obj (m.BBox): The BBox object from the database. - lang (m.Language): The Language object from the database. - image (Image.Image, optional): The image on which to perform OCR. Needed if no previous OCR run exists, or - force is True. - options (m.OptionDict, optional): The OptionDict object from the database containing the options for the OCR. - force (bool, optional): Whether to force the OCR to run again even if a previous run exists. Defaults to False. - block (bool, optional): Whether to block until the task is complete. Defaults to True. - - Raises: - ValueError: ValueError is raised if at any step of the pipeline an image is required but not provided. - - Yields: - Generator[Union[Message, m.Text], None, None]: - If block is True, yields a Message object for the OCR run first and the resulting Text object second. - If block is False, yields the resulting Text object. 
- """ - options_obj = options - if options_obj is None: - options_obj = m.OptionDict.objects.get(options={}) - params = { - 'bbox': bbox_obj, - 'model': OCR_MODEL_OBJ, - 'lang_src': lang, - 'options': options_obj, - } - ocr_run_obj = m.OCRRun.objects.filter(**params).first() - if ocr_run_obj is None or force: - if image is None: - raise ValueError('Image is required for OCR') - logger.info('Running OCR') - - id_ = (bbox_obj.id, OCR_MODEL_OBJ.id, lang.id) - mlang = getattr(lang, OCR_MODEL_OBJ.language_format or 'iso1') - opt_dct = options_obj.options - text = ocr( - image, - lang=mlang, - bbox=bbox_obj.lbrt, - options=opt_dct, - id_=id_, - block=block, - ) - if not block: - yield text - text = text.response() - if lang.iso1 in NO_SPACE_LANGUAGES: - text = text.replace(' ', '') - text_obj, _ = m.Text.objects.get_or_create( - text=text, - ) - params['result'] = text_obj - ocr_run_obj = m.OCRRun.objects.create(**params) - else: - if not block: - # Both branches should have the same number of yields - yield None - logger.info(f'Reusing OCR <{ocr_run_obj.id}>') - text_obj = ocr_run_obj.result - # text = ocr_run.result.text - - yield text_obj diff --git a/ocr_translate/ocr_tsl/tesseract.py b/ocr_translate/ocr_tsl/tesseract.py deleted file mode 100644 index 98869f1..0000000 --- a/ocr_translate/ocr_tsl/tesseract.py +++ /dev/null @@ -1,132 +0,0 @@ -################################################################################### -# ocr_translate - a django app to perform OCR and translation of images. # -# Copyright (C) 2023-present Davide Grassano # -# # -# This program is free software: you can redistribute it and/or modify # -# it under the terms of the GNU General Public License as published by # -# the Free Software Foundation, either version 3 of the License. # -# # -# This program is distributed in the hope that it will be useful, # -# but WITHOUT ANY WARRANTY; without even the implied warranty of # -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the # -# GNU General Public License for more details. # -# # -# You should have received a copy of the GNU General Public License # -# along with this program. If not, see {http://www.gnu.org/licenses/}. # -# # -# Home: https://github.com/Crivella/ocr_translate # -################################################################################### -"""Functions and piplines to perform OCR on images using tesseract.""" -import logging -import os -from pathlib import Path - -import requests -from PIL import Image -from pytesseract import Output, image_to_string - -from .huggingface import root - -logger = logging.getLogger('ocr.general') - -MODEL_URL = 'https://github.com/tesseract-ocr/tessdata_best/raw/main/{}.traineddata' - -DATA_DIR = Path(os.getenv('TESSERACT_PREFIX', root / 'tesseract')) - -VERTICAL_LANGS = ['jpn', 'chi_tra', 'chi_sim', 'kor'] - -DOWNLOAD = os.getenv('TESSERACT_ALLOW_DOWNLOAD', 'false').lower() == 'true' -CONFIG = False - -def download_model(lang: str): - """Download a tesseract model for a given language. - - Args: - lang (str): A language code for tesseract. - - Raises: - ValueError: If the model could not be downloaded. - """ - if not DOWNLOAD: - raise ValueError('TESSERACT_ALLOW_DOWNLOAD is false. Downloading models is not allowed') - create_config() - - logger.info(f'Downloading tesseract model for language {lang}') - dst = DATA_DIR / f'{lang}.traineddata' - if dst.exists(): - return - res = requests.get(MODEL_URL.format(lang), timeout=5) - if res.status_code != 200: - raise ValueError(f'Could not download model for language {lang}') - - with open(DATA_DIR / f'{lang}.traineddata', 'wb') as f: - f.write(res.content) - - if lang in VERTICAL_LANGS: - download_model(lang + '_vert') - -def create_config(): - """Create a tesseract config file. 
Run only once""" - global CONFIG - if CONFIG: - return - CONFIG = True - - logger.info('Creating tesseract tsv config') - cfg = DATA_DIR / 'configs' - cfg.mkdir(exist_ok=True, parents=True) - - dst = cfg / 'tsv' - if dst.exists(): - return - with dst.open('w') as f: - f.write('tessedit_create_tsv 1') - -# Page segmentation modes: -# 0 Orientation and script detection (OSD) only. -# 1 Automatic page segmentation with OSD. -# 2 Automatic page segmentation, but no OSD, or OCR. -# 3 Fully automatic page segmentation, but no OSD. (Default) -# 4 Assume a single column of text of variable sizes. -# 5 Assume a single uniform block of vertically aligned text. -# 6 Assume a single uniform block of text. -# 7 Treat the image as a single text line. -# 8 Treat the image as a single word. -# 9 Treat the image as a single word in a circle. -# 10 Treat the image as a single character. -# 11 Sparse text. Find as much text as possible in no particular order. -# 12 Sparse text with OSD. -# 13 Raw line. Treat the image as a single text line, -# bypassing hacks that are Tesseract-specific. -def tesseract_pipeline(img: Image.Image, lang: str, favor_vertical: bool = True) -> str: - """Run tesseract on an image. - - Args: - img (Image.Image): An image to run tesseract on. - lang (str): A language code for tesseract. - favor_vertical (bool, optional): Wether to favor vertical or horizontal configuration for languages that - can be written vertically. Defaults to True. - - Returns: - str: The text extracted from the image. 
- """ """""" - create_config() - if not (DATA_DIR / f'{lang}.traineddata').exists(): - download_model(lang) - logger.info(f'Running tesseract for language {lang}') - - psm = 6 - if lang in VERTICAL_LANGS: - exp = 1 if favor_vertical else -1 - if img.height * 1.5**exp > img.width: - psm = 5 - - # Using image_to_string will atleast preserve spaces - res = image_to_string( - img, - lang=lang, - config=f'--tessdata-dir {DATA_DIR.as_posix()} --psm {psm}', - output_type=Output.DICT - ) - - return res['text'] diff --git a/ocr_translate/ocr_tsl/tsl.py b/ocr_translate/ocr_tsl/tsl.py index ea3633e..ba02f9c 100644 --- a/ocr_translate/ocr_tsl/tsl.py +++ b/ocr_translate/ocr_tsl/tsl.py @@ -18,277 +18,39 @@ ################################################################################### """Functions and piplines to perform translation on text.""" import logging -import re -from typing import Generator, Hashable, Union - -import torch -from transformers import M2M100Tokenizer from .. import models as m -from ..messaging import Message -from ..queues import tsl_queue as q -from .huggingface import dev, load_hugginface_model logger = logging.getLogger('ocr.general') TSL_MODEL_ID = None -TSL_MODEL = None -TSL_TOKENIZER = None -TSL_MODEL_OBJ = None +TSL_MODEL_OBJ: m.TSLModel = None def unload_tsl_model(): """Remove the current TSL model from memory.""" - global TSL_MODEL_OBJ, TSL_MODEL, TSL_TOKENIZER, TSL_MODEL_ID + global TSL_MODEL_OBJ, TSL_MODEL_ID logger.info(f'Unloading TSL model: {TSL_MODEL_ID}') - TSL_MODEL = None - TSL_TOKENIZER = None + del TSL_MODEL_OBJ TSL_MODEL_OBJ = None TSL_MODEL_ID = None - if dev == 'cuda': - torch.cuda.empty_cache() - def load_tsl_model(model_id): """Load a TSL model into memory.""" - global TSL_MODEL_OBJ, TSL_MODEL, TSL_TOKENIZER, TSL_MODEL_ID + global TSL_MODEL_OBJ, TSL_MODEL_ID if TSL_MODEL_ID == model_id: return - logger.info(f'Loading TSL model: {model_id}') - res = load_hugginface_model(model_id, request=['seq2seq', 'tokenizer']) - TSL_MODEL = 
res['seq2seq'] - TSL_TOKENIZER = res['tokenizer'] + if TSL_MODEL_OBJ is not None: + TSL_MODEL_OBJ.unload() + + model = m.TSLModel.from_entrypoint(model_id) + model.load() - TSL_MODEL_OBJ, _ = m.TSLModel.objects.get_or_create(name=model_id) + TSL_MODEL_OBJ = model TSL_MODEL_ID = model_id def get_tsl_model() -> m.TSLModel: """Get the current TSL model.""" return TSL_MODEL_OBJ - -def pre_tokenize( - text: str, - ignore_chars: str = None, break_chars: str = None, break_newlines: bool = False, - restore_dash_newlines: bool = False - ) -> list[str]: - """Pre-tokenize a text string. - - Args: - text (str): Text to tokenize. - ignore_chars (str, optional): String of characters to ignore. Defaults to None. - break_chars (str, optional): String of characters to break on. Defaults to None. - break_newlines (bool, optional): Whether to break on newlines. Defaults to True. - restore_dash_newlines (bool, optional): Whether to restore dash-newlines (word broken with a -newline). - Defaults to False. - - Returns: - list[str]: List of string tokens. - """ - if restore_dash_newlines: - text = re.sub(r'(? 
0: - tokens = re.split(f'[{break_chars}+]', text) - - if isinstance(tokens, str): - tokens = [text] - - res = list(filter(None, tokens)) - return res if len(res) > 0 else [' '] - -def get_mnt(ntok: int, options: dict) -> int: - """Get the maximum number of new tokens to generate.""" - min_max_new_tokens = options.get('min_max_new_tokens', 20) - max_max_new_tokens = options.get('max_max_new_tokens', 512) - max_new_tokens = options.get('max_new_tokens', 20) - max_new_tokens_ratio = options.get('max_new_tokens_ratio', 3) - - if min_max_new_tokens > max_max_new_tokens: - raise ValueError('min_max_new_tokens must be less than max_max_new_tokens') - - mnt = min( - max_max_new_tokens, - max( - min_max_new_tokens, - max_new_tokens, - max_new_tokens_ratio * ntok - ) - ) - return mnt - -def _tsl_pipeline( - text: Union[str,list[str]], - lang_src: str, lang_dst: str, - options: dict = None - ) -> Union[str,list[str]]: - """Translate a text using a TSL model. - - Args: - text (Union[str,list[str]]): Text to translate. Can be batched to a list of strings. - lang_src (str): Source language. - lang_dst (str): Destination language. - options (dict, optional): Options for the translation. Defaults to {}. - - Raises: - TypeError: If text is not a string or a list of strings. - - Returns: - Union[str,list[str]]: Translated text. If text is a list, returns a list of translated strings. 
- """ - if options is None: - options = {} - TSL_TOKENIZER.src_lang = lang_src - - pre_keys = ['ignore_chars', 'break_chars', 'break_newlines', 'restore_dash_newlines'] - pre_dct = {k: options[k] for k in pre_keys if k in options} - - if isinstance(text, list): - tokens = [pre_tokenize(t, **pre_dct) for t in text] - elif isinstance(text, str): - tokens = pre_tokenize(text, **pre_dct) - else: - raise TypeError(f'Unsupported type for text: {type(text)}') - - logger.debug(f'TSL: {tokens}') - if len(tokens) == 0: - return '' - encoded = TSL_TOKENIZER( - tokens, - return_tensors='pt', - padding=True, - truncation=True, - is_split_into_words=True - ) - ntok = encoded['input_ids'].shape[1] - encoded.to(dev) - - mnt = get_mnt(ntok, options) - - kwargs = { - 'max_new_tokens': mnt, - } - if isinstance(TSL_TOKENIZER, M2M100Tokenizer): - kwargs['forced_bos_token_id'] = TSL_TOKENIZER.get_lang_id(lang_dst) - - logger.debug(f'TSL ENCODED: {encoded}') - logger.debug(f'TSL KWARGS: {kwargs}') - generated_tokens = TSL_MODEL.generate( - **encoded, - **kwargs, - ) - - tsl = TSL_TOKENIZER.batch_decode(generated_tokens, skip_special_tokens=True) - logger.debug(f'TSL: {tsl}') - - if isinstance(text, str): - tsl = tsl[0] - - if dev == 'cuda': - torch.cuda.empty_cache() - - return tsl - -def tsl_pipeline(*args, id_: Hashable, batch_id: Hashable = None, block: bool = True, **kwargs): - """Queue a text translation pipeline. - - Args: - id_ (Hashable): A unique identifier for the OCR task. - block (bool, optional): Whether to block until the task is complete. Defaults to True. - - Returns: - Union[str, Message]: The text extracted from the image (block=True) or a Message object (block=False). 
- """ - msg = q.put( - id_ = id_, - batch_id = batch_id, - msg = {'args': args, 'kwargs': kwargs}, - handler = _tsl_pipeline, - ) - - if block: - return msg.response() - return msg - -def tsl_run( - text_obj: m.Text, src: m.Language, dst: m.Language, options: m.OptionDict = None, - force: bool = False, - block: bool = True, - lazy: bool = False - ) -> Generator[Union[Message, m.Text], None, None]: - """Run a TSL pipeline on a text object. - - Args: - text_obj (m.Text): Text object from the database to translate. - src (m.Language): Source language object from the database. - dst (m.Language): Destination language object from the database. - options (m.OptionDict, optional): OptionDict object from the database. Defaults to None. - force (bool, optional): Whether to force a new TSL run. Defaults to False. - block (bool, optional): Whether to block until the task is complete. Defaults to True. - lazy (bool, optional): Whether to raise an error if the TSL run is not found. Defaults to False. - - Raises: - ValueError: If lazy and force are both True or if lazy is True and the TSL run is not found. - - Yields: - Generator[Union[Message, m.Text], None, None]: - If block is True, yields a Message object for the TSL run first and the resulting Text object second. - If block is False, yields the resulting Text object. 
- """ - if lazy and force: - raise ValueError('Cannot force + lazy TSL run') - model_obj = get_tsl_model() - options_obj = options or m.OptionDict.objects.get(options={}) - params = { - 'options': options_obj, - 'text': text_obj, - 'model': model_obj, - 'lang_src': src, - 'lang_dst': dst, - } - tsl_run_obj = m.TranslationRun.objects.filter(**params).first() - if tsl_run_obj is None or force: - if lazy: - raise ValueError('Value not found for lazy TSL run') - logger.info('Running TSL') - id_ = (text_obj.id, model_obj.id, options_obj.id, src.id, dst.id) - batch_id = (model_obj.id, options_obj.id, src.id, dst.id) - lang_dct = getattr(src.default_options, 'options', {}) - model_dct = getattr(model_obj.default_options, 'options', {}) - opt_dct = {**lang_dct, **model_dct, **options_obj.options} - new = tsl_pipeline( - text_obj.text, - getattr(src, model_obj.language_format), - getattr(dst, model_obj.language_format), - options=opt_dct, - id_=id_, - batch_id=batch_id, - block=block, - ) - if not block: - yield new - new = new.response() - text_obj, _ = m.Text.objects.get_or_create( - text = new, - ) - params['result'] = text_obj - tsl_run_obj = m.TranslationRun.objects.create(**params) - else: - if not block: - # Both branches should have the same number of yields - yield None - logger.info(f'Reusing TSL <{tsl_run_obj.id}>') - # new = tsl_run_obj.result.text - - yield tsl_run_obj.result diff --git a/ocr_translate/plugins/__init__.py b/ocr_translate/plugins/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ocr_translate/plugins/easyocr.py b/ocr_translate/plugins/easyocr.py new file mode 100644 index 0000000..25af84c --- /dev/null +++ b/ocr_translate/plugins/easyocr.py @@ -0,0 +1,167 @@ +################################################################################### +# ocr_translate - a django app to perform OCR and translation of images. 
# +# Copyright (C) 2023-present Davide Grassano # +# # +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the GNU General Public License as published by # +# the Free Software Foundation, either version 3 of the License. # +# # +# This program is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # +# GNU General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see {http://www.gnu.org/licenses/}. # +# # +# Home: https://github.com/Crivella/ocr_translate # +################################################################################### +"""OCRtranslate plugin to allow usage of easyocr.""" +import logging +import os +from typing import Iterable + +import easyocr +import numpy as np +import torch +from PIL.Image import Image as PILImage + +from ocr_translate import models as m + +logger = logging.getLogger('plugin') + +class EasyOCRBoxModel(m.OCRBoxModel): + """OCRtranslate plugin to allow usage of easyocr for box detection.""" + class Meta: + proxy = True + + def __init__(self, *args, **kwargs): + """Initialize the model.""" + super().__init__(*args, **kwargs) + + self.reader = None + self.dev = os.environ.get('DEVICE', 'cpu') + + def load(self): + """Load the model into memory.""" + logger.info(f'Loading BOX model: {self.name}') + self.reader = easyocr.Reader([], gpu=(self.dev == 'cuda'), recognizer=False) + + def unload(self) -> None: + """Unload the model from memory.""" + if self.reader is not None: + del self.reader + self.reader = None + + if self.dev == 'cuda': + torch.cuda.empty_cache() + + @staticmethod + def intersections(bboxes: Iterable[tuple[int, int, int, int]], margin: int = 5) -> list[set[int]]: + """Determine the intersections between a list of bounding boxes. 
+ + Args: + bboxes (Iterable[tuple[int, int, int, int]]): List of bounding boxes in lrbt format. + margin (int, optional): Number of extra pixels outside of the boxes that define an intersection. + Defaults to 5. + + Returns: + list[set[int]]: List of sets of indexes of the boxes that intersect. + """ + res = [] + + for i,(l1,r1,b1,t1) in enumerate(bboxes): + l1 -= margin + r1 += margin + b1 -= margin + t1 += margin + + for j,(l2,r2,b2,t2) in enumerate(bboxes): + if i == j: + continue + + if l1 >= r2 or r1 <= l2 or b1 >= t2 or t1 <= b2: + continue + + for ptr in res: + if i in ptr or j in ptr: + break + else: + ptr = set() + res.append(ptr) + + ptr.add(i) + ptr.add(j) + + return res + + @staticmethod + def merge_bboxes(bboxes: Iterable[tuple[int, int, int, int]]) -> list[tuple[int, int, int, int]]: + """Merge a list of intersecting bounding boxes. All intersecting boxes are merged into a single box. + + Args: + bboxes (Iterable[Iterable[int]]): Iterable of bounding boxes in lrbt format. + + Returns: + list[tuple[int]]: List of merged bounding boxes in lrbt format. + """ + res = [] + bboxes = np.array(bboxes) + inters = EasyOCRBoxModel.intersections(bboxes) + + lst = list(range(len(bboxes))) + + torm = set() + for app in inters: + app = list(app) + data = bboxes[app].reshape(-1,4) + l = data[:,0].min() + r = data[:,1].max() + b = data[:,2].min() + t = data[:,3].max() + + res.append([l,b,r,t]) + + torm = torm.union(app) + + for i in lst: + if i in torm: + continue + l,r,b,t = bboxes[i] + res.append([l,b,r,t]) + + return res + + def _box_detection( + self, + image: PILImage, options: dict = None + ) -> list[tuple[int, int, int, int]]: + """Perform box OCR on an image. + + Args: + image (Image.Image): A Pillow image on which to perform OCR. + options (dict, optional): A dictionary of options. + + Raises: + NotImplementedError: The type of model specified is not implemented. + + Returns: + list[tuple[int, int, int, int]]: A list of bounding boxes in lrbt format. 
+ """ + + if options is None: + options = {} + + # reader.recognize(image) + image = image.convert('RGB') + results = self.reader.detect(np.array(image)) + + # Axis rectangles + bboxes = results[0][0] + + # Free (NOT IMPLEMENTED) + # ... + + bboxes = self.merge_bboxes(bboxes) + + return bboxes diff --git a/ocr_translate/plugins/hugginface.py b/ocr_translate/plugins/hugginface.py new file mode 100644 index 0000000..5dbbf5f --- /dev/null +++ b/ocr_translate/plugins/hugginface.py @@ -0,0 +1,289 @@ +################################################################################### +# ocr_translate - a django app to perform OCR and translation of images. # +# Copyright (C) 2023-present Davide Grassano # +# # +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the GNU General Public License as published by # +# the Free Software Foundation, either version 3 of the License. # +# # +# This program is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # +# GNU General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see {http://www.gnu.org/licenses/}. 
# +# # +# Home: https://github.com/Crivella/ocr_translate # +################################################################################### +"""OCRtranslate plugin to allow loading of hugginface models.""" +import logging +import os +from pathlib import Path + +import torch +from PIL import Image +from transformers import (AutoImageProcessor, AutoModel, AutoModelForSeq2SeqLM, + AutoTokenizer, M2M100Tokenizer, + VisionEncoderDecoderModel) + +from ocr_translate import models as m + +logger = logging.getLogger('plugin') + +class Loaders(): + """Generic functions to load HuggingFace's Classes.""" + accept_device = ['ved_model', 'seq2seq', 'model'] + + mapping = { + 'tokenizer': AutoTokenizer, + 'ved_model': VisionEncoderDecoderModel, + 'model': AutoModel, + 'image_processor': AutoImageProcessor, + 'seq2seq': AutoModelForSeq2SeqLM + } + + @staticmethod + def _load(loader, model_id: str, root: Path): + """Use the specified loader to load a transformers specific Class.""" + try: + mid = root / model_id + logger.debug(f'Attempt loading from store: "{loader}" "{mid}"') + res = loader.from_pretrained(mid) + except Exception: + # Needed to catch some weird exception from transformers + # eg: huggingface_hub.utils._validators.HFValidationError: Repo id must use alphanumeric chars or + # '-', '_', '.', '--' and '..' are forbidden, '-' and '.' + # cannot start or end the name, max length is 96: ... + logger.debug(f'Attempt loading from cache: "{loader}" "{model_id}" "{root}"') + res = loader.from_pretrained(model_id, cache_dir=root) + return res + + @staticmethod + def load(model_id: str, request: list[str], root: Path, dev: str = 'cpu') -> list: + """Load the requested HuggingFace's Classes for the model into the memory of the globally specified device. + + Args: + model_id (str): The HuggingFace model id to load, or a path to a local model. + request (list[str]): A list of HuggingFace's Classes to load. + root (Path): The root path to use for the cache. 
+ + Raises: + ValueError: If the model_id is not found or if the requested Class is not supported. + + Returns: + _type_: A list of the requested Classes. + """ """""" + res = {} + for r in request: + if r not in Loaders.mapping: + raise ValueError(f'Unknown request: {r}') + cls = Loaders._load(Loaders.mapping[r], model_id, root) + if cls is None: + raise ValueError(f'Could not load model: {model_id}') + + if r in Loaders.accept_device: + cls = cls.to(dev) + + res[r] = cls + + return res + + +def get_mnt(ntok: int, options: dict) -> int: + """Get the maximum number of new tokens to generate.""" + min_max_new_tokens = options.get('min_max_new_tokens', 20) + max_max_new_tokens = options.get('max_max_new_tokens', 512) + max_new_tokens = options.get('max_new_tokens', 20) + max_new_tokens_ratio = options.get('max_new_tokens_ratio', 3) + + if min_max_new_tokens > max_max_new_tokens: + raise ValueError('min_max_new_tokens must be less than max_max_new_tokens') + + mnt = min( + max_max_new_tokens, + max( + min_max_new_tokens, + max_new_tokens, + max_new_tokens_ratio * ntok + ) + ) + return mnt + +class EnvMixin(): + """Mixin to allow usage of environment variables.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.dev = os.environ.get('DEVICE', 'cpu') + self.root = Path(os.environ.get('TRANSFORMERS_CACHE', '.')) + logger.debug(f'Cache dir: {self.root}') + +class HugginfaceSeq2SeqModel(m.TSLModel, EnvMixin): + """OCRtranslate plugin to allow loading of hugginface seq2seq model as translator.""" + + class Meta: + proxy = True + + def __init__(self, *args, **kwargs): + """Initialize the model.""" + super().__init__(*args, **kwargs) + self.tokenizer = None + self.model = None + + def load(self): + """Load the model into memory.""" + logger.info(f'Loading TSL model: {self.name}') + res = Loaders.load(self.name, request=['seq2seq', 'tokenizer'], root=self.root, dev=self.dev) + self.model = res['seq2seq'] + self.tokenizer = res['tokenizer'] + + 
def unload(self) -> None: + """Unload the model from memory.""" + if self.model is not None: + del self.model + self.model = None + if self.tokenizer is not None: + del self.tokenizer + self.tokenizer = None + + if self.dev == 'cuda': + torch.cuda.empty_cache() + + + def _translate(self, tokens: list, src_lang: str, dst_lang: str, options: dict = None) -> str | list[str]: + """Translate a text using a the loaded model. + + Args: + tokens (list): list or list[list] of string tokens to be translated. + lang_src (str): Source language. + lang_dst (str): Destination language. + options (dict, optional): Options for the translation. Defaults to {}. + + Raises: + TypeError: If text is not a string or a list of strings. + + Returns: + Union[str,list[str]]: Translated text. If text is a list, returns a list of translated strings. + """ + if self.model is None or self.tokenizer is None: + raise RuntimeError('Model not loaded') + if options is None: + options = {} + + logger.debug(f'TSL: {tokens}') + if len(tokens) == 0: + return '' + + self.tokenizer.src_lang = src_lang + encoded = self.tokenizer( + tokens, + return_tensors='pt', + padding=True, + truncation=True, + is_split_into_words=True + ) + ntok = encoded['input_ids'].shape[1] + encoded.to(self.dev) + + mnt = get_mnt(ntok, options) + + kwargs = { + 'max_new_tokens': mnt, + } + if isinstance(self.tokenizer, M2M100Tokenizer): + kwargs['forced_bos_token_id'] = self.tokenizer.get_lang_id(dst_lang) + + logger.debug(f'TSL ENCODED: {encoded}') + logger.debug(f'TSL KWARGS: {kwargs}') + generated_tokens = self.model.generate( + **encoded, + **kwargs, + ) + + tsl = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) + logger.debug(f'TSL: {tsl}') + + if isinstance(tokens[0], str): + tsl = tsl[0] + + if self.dev == 'cuda': + torch.cuda.empty_cache() + + return tsl + + # def translate_batch(self, texts): + # """Translate a batch of texts.""" + # raise NotImplementedError + +class HugginfaceVEDModel(m.OCRModel, 
EnvMixin): + """OCRtranslate plugin to allow loading of hugginface VisionEncoderDecoder model as text OCR.""" + class Meta: + proxy = True + + def __init__(self, *args, **kwargs): + """Initialize the model.""" + super().__init__(*args, **kwargs) + self.tokenizer = None + self.model = None + self.image_processor = None + + def load(self): + """Load the model into memory.""" + logger.info(f'Loading OCR VED model: {self.name}') + res = Loaders.load( + self.name, request=['ved_model', 'tokenizer', 'image_processor'], + root=self.root, dev=self.dev + ) + self.model = res['ved_model'] + self.tokenizer = res['tokenizer'] + self.image_processor = res['image_processor'] + + def unload(self) -> None: + """Unload the model from memory.""" + if self.model is not None: + del self.model + self.model = None + if self.tokenizer is not None: + del self.tokenizer + self.tokenizer = None + if self.image_processor is not None: + del self.image_processor + self.image_processor = None + + if self.dev == 'cuda': + torch.cuda.empty_cache() + + def _ocr( + self, + img: Image.Image, lang: str = None, options: dict = None + ) -> str: + """Perform OCR on an image. + + Args: + img (Image.Image): A Pillow image on which to perform OCR. + lang (str, optional): The language to use for OCR. (Not every model will use this) + bbox (tuple[int, int, int, int], optional): The bounding box of the text on the image in lbrt format. + options (dict, optional): A dictionary of options to pass to the OCR model. + + Raises: + TypeError: If img is not a Pillow image. + + Returns: + str: The text extracted from the image. 
+ """ + if self.model is None or self.tokenizer is None or self.image_processor is None: + raise RuntimeError('Model not loaded') + + if options is None: + options = {} + + pixel_values = self.image_processor(img, return_tensors='pt').pixel_values + if self.dev == 'cuda': + pixel_values = pixel_values.cuda() + generated_ids = self.model.generate(pixel_values) + generated_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] + + if self.dev == 'cuda': + torch.cuda.empty_cache() + + return generated_text diff --git a/ocr_translate/plugins/tesseract.py b/ocr_translate/plugins/tesseract.py new file mode 100644 index 0000000..ffde7cd --- /dev/null +++ b/ocr_translate/plugins/tesseract.py @@ -0,0 +1,160 @@ +################################################################################### +# ocr_translate - a django app to perform OCR and translation of images. # +# Copyright (C) 2023-present Davide Grassano # +# # +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the GNU General Public License as published by # +# the Free Software Foundation, either version 3 of the License. # +# # +# This program is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # +# GNU General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see {http://www.gnu.org/licenses/}. 
# +# # +# Home: https://github.com/Crivella/ocr_translate # +################################################################################### +"""OCRtranslate plugin to allow loading of tesseract models.""" + +import logging +import os +from pathlib import Path + +import requests +from PIL import Image +from pytesseract import Output, image_to_string + +from ocr_translate import models as m + +logger = logging.getLogger('plugin') + +MODEL_URL = 'https://github.com/tesseract-ocr/tessdata_best/raw/main/{}.traineddata' + +# root = Path(os.environ.get('TRANSFORMERS_CACHE', '.')) +# DATA_DIR = Path(os.getenv('TESSERACT_PREFIX', root / 'tesseract')) + +# DOWNLOAD = os.getenv('TESSERACT_ALLOW_DOWNLOAD', 'false').lower() == 'true' + +class TesseractOCRModel(m.OCRModel): + """OCRtranslate plugin to allow usage of Tesseract models.""" + VERTICAL_LANGS = ['jpn', 'chi_tra', 'chi_sim', 'kor'] + config = False + + class Meta: + proxy = True + + def __init__(self, *args, **kwargs): + """Initialize the model.""" + super().__init__(*args, **kwargs) + + root = Path(os.environ.get('TRANSFORMERS_CACHE', '.')) + self.data_dir = Path(os.getenv('TESSERACT_PREFIX', root / 'tesseract')) + self.download = os.getenv('TESSERACT_ALLOW_DOWNLOAD', 'false').lower() == 'true' + + def download_model(self, lang: str): + """Download a tesseract model for a given language. + + Args: + lang (str): A language code for tesseract. + + Raises: + ValueError: If the model could not be downloaded. + """ + if not self.download: + raise ValueError('TESSERACT_ALLOW_DOWNLOAD is false. 
Downloading models is not allowed') + self.create_config() + + logger.info(f'Downloading tesseract model for language {lang}') + dst = self.data_dir / f'{lang}.traineddata' + if dst.exists(): + return + res = requests.get(MODEL_URL.format(lang), timeout=5) + if res.status_code != 200: + raise ValueError(f'Could not download model for language {lang}') + + with open(self.data_dir / f'{lang}.traineddata', 'wb') as f: + f.write(res.content) + + if lang in self.VERTICAL_LANGS: + self.download_model(lang + '_vert') + + def load(self): + """Mock load, not needed for tesseract. Every call done via CLI.""" + + def unload(self) -> None: + """Mock unload, not needed for tesseract. Every call done via CLI.""" + + def create_config(self): + """Create a tesseract config file. Run only once""" + if self.config: + return + self.config = True + + logger.info('Creating tesseract tsv config') + cfg = self.data_dir / 'configs' + cfg.mkdir(exist_ok=True, parents=True) + + dst = cfg / 'tsv' + if dst.exists(): + return + with dst.open('w') as f: + f.write('tessedit_create_tsv 1') + + # Page segmentation modes: + # 0 Orientation and script detection (OSD) only. + # 1 Automatic page segmentation with OSD. + # 2 Automatic page segmentation, but no OSD, or OCR. + # 3 Fully automatic page segmentation, but no OSD. (Default) + # 4 Assume a single column of text of variable sizes. + # 5 Assume a single uniform block of vertically aligned text. + # 6 Assume a single uniform block of text. + # 7 Treat the image as a single text line. + # 8 Treat the image as a single word. + # 9 Treat the image as a single word in a circle. + # 10 Treat the image as a single character. + # 11 Sparse text. Find as much text as possible in no particular order. + # 12 Sparse text with OSD. + # 13 Raw line. Treat the image as a single text line, + # bypassing hacks that are Tesseract-specific. + def _ocr( + self, + img: Image.Image, lang: str = None, options: dict = None + ) -> str: + """Run tesseract on an image. 
+ + Args: + img (Image.Image): An image to run tesseract on. + lang (str): A language code for tesseract. + favor_vertical (bool, optional): Wether to favor vertical or horizontal configuration for languages that + can be written vertically. Defaults to True. + + Returns: + str: The text extracted from the image. + """ + if options is None: + options = {} + + self.create_config() + if not (self.data_dir / f'{lang}.traineddata').exists(): + self.download_model(lang) + logger.info(f'Running tesseract for language {lang}') + + favor_vertical = options.get('favor_vertical', True) + + psm = 6 + if lang in self.VERTICAL_LANGS: + exp = 1 if favor_vertical else -1 + if img.height * 1.5**exp > img.width: + psm = 5 + + # Using image_to_string will atleast preserve spaces + res = image_to_string( + img, + lang=lang, + config=f'--tessdata-dir {self.data_dir.as_posix()} --psm {psm}', + output_type=Output.DICT + ) + + return res['text'] diff --git a/ocr_translate/views.py b/ocr_translate/views.py index 914e32e..4e6ae58 100644 --- a/ocr_translate/views.py +++ b/ocr_translate/views.py @@ -37,8 +37,7 @@ from .ocr_tsl.lang import (get_lang_dst, get_lang_src, load_lang_dst, load_lang_src) from .ocr_tsl.ocr import get_ocr_model, load_ocr_model, unload_ocr_model -from .ocr_tsl.tsl import (get_tsl_model, load_tsl_model, tsl_run, - unload_tsl_model) +from .ocr_tsl.tsl import get_tsl_model, load_tsl_model, unload_tsl_model from .queues import main_queue as q logger = logging.getLogger('ocr.general') @@ -217,9 +216,11 @@ def run_tsl(request: HttpRequest) -> JsonResponse: if len(data) > 0: return JsonResponse({'error': f'invalid data: {data}'}, status=400) + tsl_model = get_tsl_model() + src_obj, _ = m.Text.objects.get_or_create(text=text) - dst_obj = tsl_run(src_obj, get_lang_src(), get_lang_dst()) + dst_obj = tsl_model.translate(src_obj, get_lang_src(), get_lang_dst()) dst_obj = next(dst_obj) return JsonResponse({ diff --git a/pyproject.toml b/pyproject.toml index c1dc656..bc988cc 100644 
--- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,16 @@ release = [ "pyinstaller" ] +[project.entry-points."ocr_translate.box_models"] +"easyocr.box" = "ocr_translate.plugins.easyocr:EasyOCRBoxModel" + +[project.entry-points."ocr_translate.ocr_models"] +"tesseract.ocr" = "ocr_translate.plugins.tesseract:TesseractOCRModel" +"hugginface.ved" = "ocr_translate.plugins.hugginface:HugginfaceVEDModel" + +[project.entry-points."ocr_translate.tsl_models"] +"hugginface.seq2seq" = "ocr_translate.plugins.hugginface:HugginfaceSeq2SeqModel" + [tool.flit.module] name = "ocr_translate" diff --git a/tests/conftest.py b/tests/conftest.py index 6630dee..b579878 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,6 +27,33 @@ from ocr_translate import ocr_tsl, views from ocr_translate.ocr_tsl import box, lang, ocr, tsl +strings = [ + 'This is a test string.', + 'This is a test string.\nWith a newline.', + 'This is a test string.\nWith a newline.\nAnd another.', + 'This is a test string.? With a special break character.', + 'This is a test string.? With a special break character.\nAnd a newline.', + 'String with a dash-newline brok-\nen word.' 
+] +ids = [ + 'simple', + 'newline', + 'newlines', + 'breakchar', + 'breakchar_newline', + 'dash_newline' +] + +@pytest.fixture(params=strings, ids=ids) +def string(request): + """String to perform TSL on.""" + return request.param + +@pytest.fixture() +def batch_string(string): + """Batched string to perform TSL on.""" + return [string, string, string] + @pytest.fixture() def language_dict(): @@ -43,24 +70,27 @@ def language_dict(): def box_model_dict(): """Dict defiining an OCRBoxModel""" return { - 'name': 'test_model/id', - 'language_format': 'iso1' + 'name': 'test_box_model/id', + 'language_format': 'iso1', + 'entrypoint': 'test_entrypoint.box' } @pytest.fixture() def ocr_model_dict(): """Dict defiining an OCRModel""" return { - 'name': 'test_model/id', - 'language_format': 'iso1' + 'name': 'test_ocr_model/id', + 'language_format': 'iso1', + 'entrypoint': 'test_entrypoint.ocr' } @pytest.fixture() def tsl_model_dict(): """Dict defiining a TSLModel""" return { - 'name': 'test_model/id', - 'language_format': 'iso1' + 'name': 'test_tsl_model/id', + 'language_format': 'iso1', + 'entrypoint': 'test_entrypoint.tsl' } @pytest.fixture() @@ -147,7 +177,7 @@ def tsl_run(language, text, tsl_model, option_dict): def mock_loaded(monkeypatch, language, box_model, ocr_model, tsl_model): """Mock models being loaded""" monkeypatch.setattr(box, 'BOX_MODEL_ID', box_model.name) - monkeypatch.setattr(box, 'BBOX_MODEL_OBJ', box_model) + monkeypatch.setattr(box, 'BOX_MODEL_OBJ', box_model) monkeypatch.setattr(ocr, 'OBJ_MODEL_ID', ocr_model.name) monkeypatch.setattr(ocr, 'OCR_MODEL_OBJ', ocr_model) monkeypatch.setattr(tsl, 'TSL_MODEL_ID', tsl_model.name) @@ -164,7 +194,7 @@ def mock_load_lang_dst(name): monkeypatch.setattr(lang, 'LANG_DST', m.Language.objects.get(iso1=name)) def mock_load_box_model(name): monkeypatch.setattr(box, 'BOX_MODEL_ID', name) - monkeypatch.setattr(box, 'BBOX_MODEL_OBJ', m.OCRBoxModel.objects.get(name=name)) + monkeypatch.setattr(box, 'BOX_MODEL_OBJ', 
m.OCRBoxModel.objects.get(name=name)) def mock_load_ocr_model(name): monkeypatch.setattr(ocr, 'OBJ_MODEL_ID', name) monkeypatch.setattr(ocr, 'OCR_MODEL_OBJ', m.OCRModel.objects.get(name=name)) diff --git a/tests/ocr_tsl/conftest.py b/tests/ocr_tsl/conftest.py index 0e8ff8c..31e983a 100644 --- a/tests/ocr_tsl/conftest.py +++ b/tests/ocr_tsl/conftest.py @@ -20,188 +20,18 @@ import pytest -strings = [ - 'This is a test string.', - 'This is a test string.\nWith a newline.', - 'This is a test string.\nWith a newline.\nAnd another.', - 'This is a test string.? With a special break character.', - 'This is a test string.? With a special break character.\nAnd a newline.', - 'String with a dash-newline brok-\nen word.' -] -ids = [ - 'simple', - 'newline', - 'newlines', - 'breakchar', - 'breakchar_newline', - 'dash_newline' -] - -@pytest.fixture(params=strings, ids=ids) -def string(request): - """String to perform TSL on.""" - return request.param - -@pytest.fixture() -def batch_string(string): - """Batched string to perform TSL on.""" - return [string, string, string] - -@pytest.fixture() -def mock_tsl_tokenizer(): - """Mock tokenizer for TSL.""" - import torch # pylint: disable=import-outside-toplevel - class _MockTokenizer(): - def __init__(self, model_id): - self.model_id = model_id - self.other_options = {} - self.tok_to_word = {0: 0} - self.word_to_tok = {0: 0} - self.ntoks = 1 - self.called_get_lang_id = False - - def __call__(self, text, **options): - issplit = options.pop('is_split_into_words', False) - padding = options.pop('padding', False) - truncation = options.pop('truncation', False) # pylint: disable=unused-variable - - self.other_options = options - - if isinstance(text, list): - if isinstance(text[0], str): - text = [text] - if issplit: - app = [] - for line in text: - app2 = [] - for seg in line: - app2.extend(seg.split(' ')) - app.append(app2) - else: - app = [_.split(' ') for _ in text] - - if padding: - app2 = [] - for lst in app: - app3 = [] - for word 
in lst: - if word not in self.word_to_tok: - self.word_to_tok[word] = self.ntoks - self.tok_to_word[self.ntoks] = word - self.ntoks += 1 - app3.append(self.word_to_tok[word]) - app2.append(app3) - - max_len = max(len(_) for _ in app2) - res = [(_ + [0] * max_len)[:max_len] for _ in app2] - else: - res = app - class Dict(dict): - """Dict class with added .to method""" - def to(self, device): # pylint: disable=unused-argument,invalid-name - """Move the dict to a device.""" - return None - - dct = Dict([('input_ids', torch.Tensor(res))]) - return dct - - raise TypeError(f'Expected list of strings, but got {type(text)}') - - def batch_decode(self, tokens, **options): # pylint: disable=unused-argument - """Decode a batch of tokens.""" - res = [' '.join(filter(None, [self.tok_to_word[int(_)] for _ in lst])) for lst in tokens] - return res - - def get_lang_id(self, lang): # pylint: disable=unused-argument - """Get the language id.""" - self.called_get_lang_id = True - return 0 - - return _MockTokenizer - -@pytest.fixture() -def mock_tsl_model(): - """Mock model for TSL.""" - class _MockModel(): - def __init__(self, model_id): - self.model_id = model_id - self.options = {} - - def generate(self, input_ids=None, **options): - """Mock generate translated tokens.""" - self.options = options - return input_ids - - return _MockModel @pytest.fixture() -def mock_ocr_preprocessor(): - """Mock preprocessor for OCR.""" - class RES(): - """Mock result""" +def mock_base_model(): + """Mock BaseModel class.""" + class MockModel(): # pylint: disable=invalid-name + """Mocked BaseModel class.""" def __init__(self): - class PV(list): - """Mock pixel values""" - def cuda(self): - """Mock cuda""" - self.cuda_called = True # pylint: disable=attribute-defined-outside-init - return self - self.pixel_values = PV([1,2,3,4,5]) - - class _MockPreprocessor(): - def __init__(self, model_id): - self.model_id = model_id - self.options = {} - - def __call__(self, img, **options): - self.options = options 
- res = RES() - return res - - return _MockPreprocessor - -@pytest.fixture() -def mock_ocr_tokenizer(): - """Mock tokenizer for OCR.""" - class _MockTokenizer(): - def __init__(self, model_id): - self.model_id = model_id - self.options = {} - - def batch_decode(self, tokens, **options): - """Mock batch decode.""" - self.options = options - offset = ord('a') - 1 - return [''.join(chr(int(_)+offset) for _ in tokens)] - - return _MockTokenizer - -@pytest.fixture() -def mock_ocr_model(): - """Mock model for OCR.""" - class _MockModel(): - def __init__(self, model_id): - self.model_id = model_id - self.options = {} - - def generate(self, pixel_values=None, **options): - """Mock generate.""" - self.options = options - return pixel_values - - return _MockModel - -@pytest.fixture() -def mock_box_reader(): - """Mock box reader.""" - class _MockReader(): - def __init__(self, model_id): - self.model_id = model_id - self.options = {} - - def detect(self, img, **options): # pylint: disable=unused-argument - """Mock recognize.""" - self.options = options - return (([(10,10,30,30), (40,40,50,50)],),) - - return _MockReader + self.load_called = False + self.unload_called = False + def load(self): # pylint: disable=missing-function-docstring + self.load_called = True + def unload(self): # pylint: disable=missing-function-docstring + self.unload_called = True + + return MockModel diff --git a/tests/ocr_tsl/test_box.py b/tests/ocr_tsl/test_box.py index b23cc96..a023eb2 100644 --- a/tests/ocr_tsl/test_box.py +++ b/tests/ocr_tsl/test_box.py @@ -21,188 +21,115 @@ import pytest from ocr_translate import models as m -from ocr_translate.messaging import Message +# from ocr_translate.messaging import Message from ocr_translate.ocr_tsl import box -boxes = [ - ((10,10,30,30), (15,15,20,20)), # b2 inside b1 - ((15,15,20,20), (10,10,30,30)), # b1 inside b2 - - ((30,30,50,50), (10,10,20,35)), # l1 > r2 - ((30,30,50,50), (55,10,75,35)), # r1 < l2 - ((30,30,50,50), (10,10,35,20)), # b1 > t2 - 
((30,30,50,50), (10,55,35,75)), # t1 < b2 - - ((30,30,50,50), (10,10,35,35)), # b2-tr inside b1 - ((30,30,50,50), (45,10,75,35)), # b2-tl inside b1 - ((30,30,50,50), (40,45,75,75)), # b2-bl inside b1 - ((30,30,50,50), (10,45,35,75)), # b2-br inside b1 - - ((10,50,70,60), (50,10,60,70)), # intersection, but cornder not inside - - ((10,10,30,30), (29,29,51,40), (50,10,60,30)), # 3x intersection - ((10,10,30,30), (29,29,51,40), (60,10,70,30)), # 2x intersection + 1 -] -ids = [ - 'b2_inside_b1', - 'b1_inside_b2', - 'l1_>_r2', - 'r1_<_l2', - 'b1_>_t2', - 't1_<_b2', - 'b2-tr_inside_b1', - 'b2-tl_inside_b1', - 'b2-bl_inside_b1', - 'b2-br_inside_b1', - 'int_nocorners', - '3x_intersection', - '2x_intersection_+_1', -] - -def test_intersection_merge(data_regression): - """Test intersections function.""" - - res = [] - for boxes_lbrt,idx in zip(boxes,ids): - ptr = {} - ptr['idx'] = idx - boxes_lrbt = [] - for l,b,r,t in boxes_lbrt: - boxes_lrbt.append((l,r,b,t)) - ptr['box_lst'] = boxes_lrbt - ptr['intersection'] = box.intersections(boxes_lrbt) - merge = box.merge_bboxes(boxes_lrbt) - merge = [[int(_) for _ in el] for el in merge] - ptr['merge'] = merge - res.append(ptr) - - data_regression.check({'res': res}) - -def test_load_box_model_notimplemented(): - """Test load box model. With not implemented model.""" - model_id = 'notimplemented' - with pytest.raises(NotImplementedError): - box.load_box_model(model_id) - -def test_load_box_model_already_loaded(monkeypatch, mock_called): - """Test load box model. With already loaded model.""" - model_id = 'easyocr' - monkeypatch.setattr(box.easyocr, 'Reader', mock_called) - monkeypatch.setattr(box, 'BOX_MODEL_ID', model_id) - box.load_box_model(model_id) - - assert not hasattr(mock_called, 'called') +# def test_load_box_model_notimplemented(): +# """Test load box model. 
With not implemented model.""" +# model_id = 'notimplemented' +# with pytest.raises(NotImplementedError): +# box.load_box_model(model_id) @pytest.mark.django_db -def test_load_box_model_easyocr(monkeypatch): - """Test load box model. Success""" - model_id = 'easyocr' - monkeypatch.setattr(box.easyocr, 'Reader', lambda *args, **kwargs: 'mocked') - - # Needed to make sure that changes doen by `load_box_model` are not persisted +def test_load_box_model(monkeypatch, mock_called, box_model: m.OCRBoxModel): + """Test load box model from scratch.""" + model_id = box_model.id + def mock_fromentrypoint(*args, **kwargs): + mock_fromentrypoint.called = True + return box_model + # Required to avoid setting global variables for future `from clean` tests monkeypatch.setattr(box, 'BOX_MODEL_ID', None) - monkeypatch.setattr(box, 'READER', None) - monkeypatch.setattr(box, 'BBOX_MODEL_OBJ', None) - - assert m.OCRBoxModel.objects.count() == 0 + monkeypatch.setattr(box, 'BOX_MODEL_OBJ', None) + monkeypatch.setattr(m.OCRBoxModel, 'from_entrypoint', mock_fromentrypoint) + box_model.load = mock_called box.load_box_model(model_id) - assert m.OCRBoxModel.objects.count() == 1 - assert box.BOX_MODEL_ID == model_id - assert box.READER == 'mocked' # Check that the mocked function was called and READER was set by loader + assert hasattr(mock_fromentrypoint, 'called') + assert hasattr(mock_called, 'called') -def test_unload_box_model(monkeypatch): - """Test unload box model.""" +def test_load_box_model_already_loaded(monkeypatch, mock_called): + """Test load box model. 
With already loaded model.""" model_id = 'easyocr' + monkeypatch.setattr(m.OCRBoxModel, 'from_entrypoint', mock_called) monkeypatch.setattr(box, 'BOX_MODEL_ID', model_id) - monkeypatch.setattr(box, 'READER', 'mocked') - monkeypatch.setattr(box, 'BBOX_MODEL_OBJ', 'test') - - box.unload_box_model() - - assert box.BBOX_MODEL_OBJ is None - assert box.BOX_MODEL_ID is None - assert box.READER is None - -def test_unload_box_model_cpu(monkeypatch, mock_called): - """Test unload box model with cpu.""" - monkeypatch.setattr(box.torch.cuda, 'empty_cache', mock_called) - monkeypatch.setattr(box, 'dev', 'cpu') + box.load_box_model(model_id) - box.unload_box_model() assert not hasattr(mock_called, 'called') -def test_unload_box_model_cuda(monkeypatch, mock_called): - """Test unload box model with cuda.""" - monkeypatch.setattr(box.torch.cuda, 'empty_cache', mock_called) - monkeypatch.setattr(box, 'dev', 'cuda') - - box.unload_box_model() - assert hasattr(mock_called, 'called') - def test_get_box_model(monkeypatch): """Test get box model function.""" - monkeypatch.setattr(box, 'BBOX_MODEL_OBJ', 'test') + monkeypatch.setattr(box, 'BOX_MODEL_OBJ', 'test') assert box.get_box_model() == 'test' -def test_box_pipeline_notimplemented(monkeypatch): - """Test box pipeline. 
With not implemented model.""" - model_id = 'notimplemented' - monkeypatch.setattr(box, 'BOX_MODEL_ID', model_id) - - with pytest.raises(NotImplementedError): - box._box_pipeline(None) # pylint: disable=protected-access - -def test_box_pipeline_easyocr(image_pillow, monkeypatch, mock_box_reader): - """Test box pipeline.""" - model_id = 'easyocr' - - monkeypatch.setattr(box, 'BOX_MODEL_ID', model_id) - monkeypatch.setattr(box, 'READER', mock_box_reader(model_id)) - - res = box._box_pipeline(image_pillow) # pylint: disable=protected-access - - assert res == [[10,30,10,30], [40,50,40,50]] - -def test_queue_placer_handler(monkeypatch, mock_called): - """Test queue_placer is setting _box_pipeline as handler, and that it is called.""" - monkeypatch.setattr(box, '_box_pipeline', mock_called) - monkeypatch.setattr(box.q.msg_queue, 'reuse_msg', False) - box.box_pipeline(id_=1, block=True) - assert hasattr(mock_called, 'called') +def test_unload_box_model(monkeypatch): + """Test unload box model function.""" + monkeypatch.setattr(box, 'BOX_MODEL_OBJ', 'test') + monkeypatch.setattr(box, 'BOX_MODEL_ID', 'test') + box.unload_box_model() + assert box.BOX_MODEL_OBJ is None + assert box.BOX_MODEL_ID is None -@pytest.mark.parametrize('mock_called', ['test_return'], indirect=True) -def test_queue_placer_blocking(monkeypatch, mock_called): - """Test queue_placer with blocking""" - monkeypatch.setattr(box, '_box_pipeline', mock_called) - monkeypatch.setattr(box.q.msg_queue, 'reuse_msg', False) - res = box.box_pipeline(id_=1, block=True) - assert hasattr(mock_called, 'called') - assert res == mock_called.expected - -@pytest.mark.parametrize('mock_called', ['test_return'], indirect=True) -def test_queue_placer_nonblocking(monkeypatch, mock_called): - """Test queue_placer with blocking""" - monkeypatch.setattr(box, '_box_pipeline', mock_called) - monkeypatch.setattr(box.q.msg_queue, 'reuse_msg', False) - box.q.stop_workers() - res = box.box_pipeline(id_=1, block=False) - assert 
isinstance(res, Message) - - assert not hasattr(mock_called, 'called') # Before resolving the message the handler is not called - box.q.start_workers() - assert res.response() == mock_called.expected - assert hasattr(mock_called, 'called') # After resolving the message the handler is called - -def test_box_pipeline_worker(): - """Test tsl pipeline with worker""" - placeholder = 'placeholder' - box.q.stop_workers() - - messages = [box.box_pipeline(placeholder, id_=i, block=False) for i in range(3)] - assert all(isinstance(_, Message) for _ in messages) - def gen(): - while not box.q.msg_queue.empty(): - yield box.q.msg_queue.get() - res = list(gen()) - assert len(res) == len(messages) +def test_unload_box_model_if_loaded(monkeypatch): + """Test unload box model is called if load with an already loaded model.""" + class A(): # pylint: disable=missing-class-docstring,invalid-name + def __init__(self): + self.load_called = False + self.unload_called = False + def load(self): # pylint: disable=missing-function-docstring + self.load_called = True + def unload(self): # pylint: disable=missing-function-docstring + self.unload_called = True + a = A() # pylint: disable=invalid-name + b = A() # pylint: disable=invalid-name + monkeypatch.setattr(box, 'BOX_MODEL_OBJ', a) + monkeypatch.setattr(box, 'BOX_MODEL_ID', 'test') + monkeypatch.setattr(m.OCRBoxModel, 'from_entrypoint', lambda *args, **kwargs: b) + box.load_box_model('test2') + + assert not a.load_called + assert a.unload_called + assert b.load_called + assert not b.unload_called + +# def test_queue_placer_handler(monkeypatch, mock_called): +# """Test queue_placer is setting _box_pipeline as handler, and that it is called.""" +# monkeypatch.setattr(box, '_box_pipeline', mock_called) +# monkeypatch.setattr(box.q.msg_queue, 'reuse_msg', False) +# box.box_pipeline(id_=1, block=True) +# assert hasattr(mock_called, 'called') + +# @pytest.mark.parametrize('mock_called', ['test_return'], indirect=True) +# def 
test_queue_placer_blocking(monkeypatch, mock_called): +# """Test queue_placer with blocking""" +# monkeypatch.setattr(box, '_box_pipeline', mock_called) +# monkeypatch.setattr(box.q.msg_queue, 'reuse_msg', False) +# res = box.box_pipeline(id_=1, block=True) +# assert hasattr(mock_called, 'called') +# assert res == mock_called.expected + +# @pytest.mark.parametrize('mock_called', ['test_return'], indirect=True) +# def test_queue_placer_nonblocking(monkeypatch, mock_called): +# """Test queue_placer with blocking""" +# monkeypatch.setattr(box, '_box_pipeline', mock_called) +# monkeypatch.setattr(box.q.msg_queue, 'reuse_msg', False) +# box.q.stop_workers() +# res = box.box_pipeline(id_=1, block=False) +# assert isinstance(res, Message) + +# assert not hasattr(mock_called, 'called') # Before resolving the message the handler is not called +# box.q.start_workers() +# assert res.response() == mock_called.expected +# assert hasattr(mock_called, 'called') # After resolving the message the handler is called + +# def test_box_pipeline_worker(): +# """Test tsl pipeline with worker""" +# placeholder = 'placeholder' +# box.q.stop_workers() + +# messages = [box.box_pipeline(placeholder, id_=i, block=False) for i in range(3)] +# assert all(isinstance(_, Message) for _ in messages) +# def gen(): +# while not box.q.msg_queue.empty(): +# yield box.q.msg_queue.get() +# res = list(gen()) +# assert len(res) == len(messages) diff --git a/tests/ocr_tsl/test_huggingface.py b/tests/ocr_tsl/test_huggingface.py deleted file mode 100644 index 012a31c..0000000 --- a/tests/ocr_tsl/test_huggingface.py +++ /dev/null @@ -1,124 +0,0 @@ -################################################################################### -# ocr_translate - a django app to perform OCR and translation of images. 
# -# Copyright (C) 2023-present Davide Grassano # -# # -# This program is free software: you can redistribute it and/or modify # -# it under the terms of the GNU General Public License as published by # -# the Free Software Foundation, either version 3 of the License. # -# # -# This program is distributed in the hope that it will be useful, # -# but WITHOUT ANY WARRANTY; without even the implied warranty of # -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # -# GNU General Public License for more details. # -# # -# You should have received a copy of the GNU General Public License # -# along with this program. If not, see {http://www.gnu.org/licenses/}. # -# # -# Home: https://github.com/Crivella/ocr_translate # -################################################################################### -"""Test base.py from ocr_tsl.""" -# pylint: disable=redefined-outer-name - -import importlib -from pathlib import Path - -import pytest - -from ocr_translate.ocr_tsl import huggingface - - -@pytest.fixture -def mock_loader(): - """Mock hugging face class with `from_pretrained` method.""" - class Loader(): - """Mocked class.""" - def from_pretrained(self, model_id: Path | str, cache_dir=None): - """Mocked method.""" - if isinstance(model_id, Path): - if not model_id.is_dir(): - raise FileNotFoundError('Not in dir') - elif isinstance(model_id, str): - if cache_dir is None: - cache_dir = huggingface.root - if not (cache_dir / f'models--{model_id.replace("/", "--")}').is_dir(): - raise FileNotFoundError('Not in cache') - - return Loader() - -def test_env_transformers_cache(monkeypatch): - """Test that the TRANSFORMERS_CACHE environment variable is set.""" - monkeypatch.setenv('TRANSFORMERS_CACHE', 'test') - importlib.reload(huggingface) - assert huggingface.root == Path('test') - - - -def test_load_from_storage_dir_fail(monkeypatch, mock_loader, tmpdir): - """Test low-level loading a huggingface model from storage (missing file).""" - 
monkeypatch.setenv('TRANSFORMERS_CACHE', str(tmpdir)) - importlib.reload(huggingface) - - # Load is supposed to test direcotry first and than fallnack to cache - # Exception should always be from not found in cache first - with pytest.raises(FileNotFoundError, match='Not in cache'): - huggingface.load(mock_loader, 'test/id') - -def test_load_from_storage_dir_success(monkeypatch, mock_loader, tmpdir): - """Test low-level loading a huggingface model from storage (success).""" - monkeypatch.setenv('TRANSFORMERS_CACHE', str(tmpdir)) - importlib.reload(huggingface) - - ptr = tmpdir - for pth in Path('test/id').parts: - ptr = ptr.mkdir(pth) - huggingface.load(mock_loader, 'test/id') - -def test_load_from_storage_cache_success(monkeypatch, mock_loader, tmpdir): - """Test low-level loading a huggingface model from storage (success).""" - monkeypatch.setenv('TRANSFORMERS_CACHE', str(tmpdir)) - importlib.reload(huggingface) - - tmpdir.mkdir('models--test--id') - huggingface.load(mock_loader, 'test/id') - -def test_load_hugginface_model_invalide_type(): - """Test high-level loading a huggingface model. Request unkown entity.""" - with pytest.raises(ValueError, match=r'^Unknown request: .*'): - huggingface.load_hugginface_model('test', ['invalid']) - -def test_load_hugginface_model_return_none(monkeypatch): - """Test high-level loading a huggingface model. 
Return None from load.""" - def mock_load(*args): - """Mocked load function.""" - return None - monkeypatch.setattr(huggingface, 'load', mock_load) - - with pytest.raises(ValueError, match=r'^Could not load model: .*'): - huggingface.load_hugginface_model('test', ['model']) - - -@pytest.mark.parametrize('model_type', [ - 'tokenizer', - 'ved_model', - 'model', - 'image_processor', - 'seq2seq' -]) -def test_load_hugginface_model_success(monkeypatch, model_type): - """Test high-level loading a huggingface model.""" - def mock_load(loader, *args): - """Mocked load function.""" - assert loader == huggingface.mapping[model_type] - class App(): - """Mocked huggingface class with `to` method.""" - def to(self, x): # pylint: disable=invalid-name,unused-argument - """Mocked method.""" - return None - return App() - monkeypatch.setattr(huggingface, 'load', mock_load) - - loaded = huggingface.load_hugginface_model('test', [model_type]) - - assert isinstance(loaded, dict) - assert len(loaded) == 1 - assert model_type in loaded diff --git a/tests/ocr_tsl/test_ocr.py b/tests/ocr_tsl/test_ocr.py index 1ca3461..3561b3c 100644 --- a/tests/ocr_tsl/test_ocr.py +++ b/tests/ocr_tsl/test_ocr.py @@ -21,58 +21,36 @@ import pytest from ocr_translate import models as m -from ocr_translate.messaging import Message from ocr_translate.ocr_tsl import ocr -ocr_globals = ['OCR_MODEL', 'OCR_TOKENIZER', 'OCR_IMAGE_PROCESSOR', 'OCR_MODEL_OBJ', 'OBJ_MODEL_ID'] +ocr_globals = ['OCR_MODEL_OBJ', 'OBJ_MODEL_ID'] + +@pytest.mark.django_db +def test_load_ocr_model(monkeypatch, mock_called, ocr_model: m.OCRModel): + """Test load ocr model from scratch.""" + model_id = ocr_model.id + def mock_fromentrypoint(*args, **kwargs): + mock_fromentrypoint.called = True + return ocr_model + # Required to avoid setting global variables for future `from clean` tests + monkeypatch.setattr(ocr, 'OCR_MODEL_OBJ', None) + monkeypatch.setattr(ocr, 'OBJ_MODEL_ID', None) + monkeypatch.setattr(m.OCRModel, 'from_entrypoint', 
mock_fromentrypoint) + ocr_model.load = mock_called + ocr.load_ocr_model(model_id) + + assert hasattr(mock_fromentrypoint, 'called') + assert hasattr(mock_called, 'called') def test_load_ocr_model_already_loaded(monkeypatch, mock_called): """Test load box model. With already loaded model.""" model_id = 'test/id' - monkeypatch.setattr(ocr, 'load_hugginface_model', mock_called) + monkeypatch.setattr(m.OCRModel, 'from_entrypoint', mock_called) monkeypatch.setattr(ocr, 'OBJ_MODEL_ID', model_id) ocr.load_ocr_model(model_id) assert not hasattr(mock_called, 'called') -@pytest.mark.django_db -def test_load_ocr_model_tesseract(monkeypatch, mock_called): - """Test load box model. Success""" - model_id = 'tesseract' - - monkeypatch.setattr(ocr, 'load_hugginface_model', mock_called) - # Needed to make sure that changes doen by `load_ocr_model` are not persisted - for key in ocr_globals: - monkeypatch.setattr(ocr, key, None) - - ocr.load_ocr_model(model_id) - assert not hasattr(mock_called, 'called') - -@pytest.mark.django_db -def test_load_ocr_model_test(monkeypatch): - """Test load box model. 
Success""" - model_id = 'easyocr' - res = { - 'ved_model': 'mocked_ved', - 'tokenizer': 'mocked_tokenizer', - 'image_processor': 'mocked_image_processor', - } - monkeypatch.setattr(ocr, 'load_hugginface_model', lambda *args, **kwargs: res) - - # Needed to make sure that changes doen by `load_ocr_model` are not persisted - for key in ocr_globals: - monkeypatch.setattr(ocr, key, None) - - assert m.OCRModel.objects.count() == 0 - ocr.load_ocr_model(model_id) - assert m.OCRModel.objects.count() == 1 - - assert ocr.OBJ_MODEL_ID == model_id - # Check that the mocked function was called and that globals were set by loader - assert ocr.OCR_MODEL == 'mocked_ved' - assert ocr.OCR_TOKENIZER == 'mocked_tokenizer' - assert ocr.OCR_IMAGE_PROCESSOR == 'mocked_image_processor' - def test_unload_ocr_model(monkeypatch): """Test unload box model.""" for key in ocr_globals: @@ -83,121 +61,21 @@ def test_unload_ocr_model(monkeypatch): for key in ocr_globals: assert getattr(ocr, key) is None -def test_unload_ocr_model_cpu(monkeypatch, mock_called): - """Test unload box model with cpu.""" - monkeypatch.setattr(ocr.torch.cuda, 'empty_cache', mock_called) - monkeypatch.setattr(ocr, 'dev', 'cpu') - - ocr.unload_ocr_model() - assert not hasattr(mock_called, 'called') - -def test_unload_ocr_model_cuda(monkeypatch, mock_called): - """Test unload box model with cuda.""" - monkeypatch.setattr(ocr.torch.cuda, 'empty_cache', mock_called) - monkeypatch.setattr(ocr, 'dev', 'cuda') - - ocr.unload_ocr_model() - assert hasattr(mock_called, 'called') - def test_get_ocr_model(monkeypatch): """Test get ocr model function.""" monkeypatch.setattr(ocr, 'OCR_MODEL_OBJ', 'test') assert ocr.get_ocr_model() == 'test' -def test_pipeline_invalide_image(): - """Test ocr pipeline with invalid image.""" - with pytest.raises(TypeError, match=r'^img should be PIL Image.*'): - ocr._ocr('invalid_image', 'ja') # pylint: disable=protected-access - -def test_pipeline_with_bbox(monkeypatch, mock_called, image_pillow): - 
"""Test ocr pipeline with bbox. Has to call the crop method of image.""" - model_id = 'tesseract' - bbox = (1,2,8,9) - monkeypatch.setattr(ocr, 'OBJ_MODEL_ID', model_id) - monkeypatch.setattr(ocr, 'tesseract_pipeline', lambda *args, **kwargs: None) - monkeypatch.setattr(ocr.Image.Image, 'crop', mock_called) - - ocr._ocr(image_pillow, '', bbox=bbox) # pylint: disable=protected-access - - assert hasattr(mock_called, 'called') - assert mock_called.args[1] == bbox # 0 is self - -def test_pipeline_tesseract(monkeypatch, mock_called, image_pillow): - """Test ocr pipeline with tesseract model.""" - model_id = 'tesseract' - monkeypatch.setattr(ocr, 'OBJ_MODEL_ID', model_id) - monkeypatch.setattr(ocr, 'tesseract_pipeline', mock_called) - - ocr._ocr(image_pillow, '') # pylint: disable=protected-access - - assert hasattr(mock_called, 'called') - -def test_pipeline_hugginface(image_pillow, mock_ocr_preprocessor, mock_ocr_tokenizer, mock_ocr_model, monkeypatch): - """Test ocr pipeline with hugginface model.""" - model_id = 'test_model' - lang = 'ja' - - monkeypatch.setattr(ocr, 'OCR_IMAGE_PROCESSOR', mock_ocr_preprocessor(model_id)) - monkeypatch.setattr(ocr, 'OCR_TOKENIZER', mock_ocr_tokenizer(model_id)) - monkeypatch.setattr(ocr, 'OCR_MODEL', mock_ocr_model(model_id)) - - res = ocr._ocr(image_pillow, lang) # pylint: disable=protected-access - - assert res == 'abcde' - -def test_pipeline_hugginface_cuda(image_pillow, mock_ocr_preprocessor, mock_ocr_tokenizer, mock_ocr_model, monkeypatch): - """Test ocr pipeline with hugginface model and cuda.""" - model_id = 'test_model' - lang = 'ja' - - monkeypatch.setattr(ocr, 'dev', 'cuda') - monkeypatch.setattr(ocr, 'OCR_IMAGE_PROCESSOR', mock_ocr_preprocessor(model_id)) - monkeypatch.setattr(ocr, 'OCR_TOKENIZER', mock_ocr_tokenizer(model_id)) - monkeypatch.setattr(ocr, 'OCR_MODEL', mock_ocr_model(model_id)) - - res = ocr._ocr(image_pillow, lang) # pylint: disable=protected-access - - assert res == 'abcde' - -def 
test_queue_placer_handler(monkeypatch, mock_called): - """Test queue_placer is setting _ocr as handler, and that it is called.""" - monkeypatch.setattr(ocr, '_ocr', mock_called) - monkeypatch.setattr(ocr.q.msg_queue, 'reuse_msg', False) - ocr.ocr(id_=1, block=True) - assert hasattr(mock_called, 'called') - -@pytest.mark.parametrize('mock_called', ['test_return'], indirect=True) -def test_queue_placer_blocking(monkeypatch, mock_called): - """Test queue_placer with blocking""" - monkeypatch.setattr(ocr, '_ocr', mock_called) - monkeypatch.setattr(ocr.q.msg_queue, 'reuse_msg', False) - res = ocr.ocr(id_=1, block=True) - assert hasattr(mock_called, 'called') - assert res == mock_called.expected - -@pytest.mark.parametrize('mock_called', ['test_return'], indirect=True) -def test_queue_placer_nonblocking(monkeypatch, mock_called): - """Test queue_placer with blocking""" - monkeypatch.setattr(ocr, '_ocr', mock_called) - monkeypatch.setattr(ocr.q.msg_queue, 'reuse_msg', False) - ocr.q.stop_workers() - res = ocr.ocr(id_=1, block=False) - assert isinstance(res, Message) - - assert not hasattr(mock_called, 'called') # Before resolving the message the handler is not called - ocr.q.start_workers() - assert res.response() == mock_called.expected - assert hasattr(mock_called, 'called') # After resolving the message the handler is called - -def test_pipeline_worker(): - """Test tsl pipeline with worker""" - placeholder = 'placeholder' - ocr.q.stop_workers() - - messages = [ocr.ocr(placeholder, 'ja', 'en', id_=i, block=False) for i in range(3)] - assert all(isinstance(_, Message) for _ in messages) - def gen(): - while not ocr.q.msg_queue.empty(): - yield ocr.q.msg_queue.get() - res = list(gen()) - assert len(res) == len(messages) +def test_unload_ocr_model_if_loaded(monkeypatch, mock_base_model): + """Test unload ocr model is called if load with an already loaded model.""" + base1 = mock_base_model() + base2 = mock_base_model() + monkeypatch.setattr(ocr, 'OCR_MODEL_OBJ', base1) + 
monkeypatch.setattr(ocr, 'OBJ_MODEL_ID', 'test') + monkeypatch.setattr(m.OCRModel, 'from_entrypoint', lambda *args, **kwargs: base2) + ocr.load_ocr_model('test2') + + assert not base1.load_called + assert base1.unload_called + assert base2.load_called + assert not base2.unload_called diff --git a/tests/ocr_tsl/test_tsl.py b/tests/ocr_tsl/test_tsl.py index 4419582..64e6467 100644 --- a/tests/ocr_tsl/test_tsl.py +++ b/tests/ocr_tsl/test_tsl.py @@ -21,73 +21,38 @@ import pytest from ocr_translate import models as m -from ocr_translate.messaging import Message from ocr_translate.ocr_tsl import tsl -tsl_globals = ['TSL_MODEL', 'TSL_TOKENIZER', 'TSL_MODEL_OBJ'] +tsl_globals = ['TSL_MODEL_OBJ'] -def test_get_mnt_wrong_options(): - """Test get_mnt with wrong options.""" - with pytest.raises(ValueError, match=r'^min_max_new_tokens must be less than max_max_new_tokens$'): - tsl.get_mnt(10, {'min_max_new_tokens': 20, 'max_max_new_tokens': 10}) -def test_pre_tokenize(string, data_regression): - """Test tsl module.""" - options = [ - {}, - {'break_newlines': True}, - {'break_newlines': False}, - {'break_chars': '?.!'}, - {'ignore_chars': '?.!'}, - {'break_newlines': False, 'break_chars': '?.!'}, - {'break_newlines': False, 'ignore_chars': '?.!'}, - {'restore_dash_newlines': True}, - ] - - res = [] - for option in options: - dct = { - 'string': string, - 'options': option, - 'tokens': tsl.pre_tokenize(string, **option) - } - res.append(dct) +@pytest.mark.django_db +def test_load_tsl_model(monkeypatch, mock_called, tsl_model: m.TSLModel): + """Test load tsl model from scratch.""" + model_id = tsl_model.id + def mock_fromentrypoint(*args, **kwargs): + mock_fromentrypoint.called = True + return tsl_model + # Required to avoid setting global variables for future `from clean` tests + monkeypatch.setattr(tsl, 'TSL_MODEL_OBJ', None) + monkeypatch.setattr(tsl, 'TSL_MODEL_ID', None) + monkeypatch.setattr(m.TSLModel, 'from_entrypoint', mock_fromentrypoint) + tsl_model.load = mock_called + 
tsl.load_tsl_model(model_id) - data_regression.check({'res': res}) + assert hasattr(mock_fromentrypoint, 'called') + assert hasattr(mock_called, 'called') def test_load_tsl_model_already_loaded(monkeypatch, mock_called): """Test load box model. With already loaded model.""" model_id = 'test/id' - monkeypatch.setattr(tsl, 'load_hugginface_model', mock_called) + monkeypatch.setattr(m.TSLModel, 'from_entrypoint', mock_called) monkeypatch.setattr(tsl, 'TSL_MODEL_ID', model_id) tsl.load_tsl_model(model_id) assert not hasattr(mock_called, 'called') -@pytest.mark.django_db -def test_load_tsl_model_test(monkeypatch): - """Test load box model. Success""" - model_id = 'test/id' - res = { - 'seq2seq': 'mocked_seq2seq', - 'tokenizer': 'mocked_tokenizer', - } - monkeypatch.setattr(tsl, 'load_hugginface_model', lambda *args, **kwargs: res) - - # Needed to make sure that changes doen by `load_tsl_model` are not persisted - for key in tsl_globals: - monkeypatch.setattr(tsl, key, None) - - assert m.TSLModel.objects.count() == 0 - tsl.load_tsl_model(model_id) - assert m.TSLModel.objects.count() == 1 - - assert tsl.TSL_MODEL_ID == model_id - # Check that the mocked function was called and that globals were set by loader - assert tsl.TSL_MODEL == 'mocked_seq2seq' - assert tsl.TSL_TOKENIZER == 'mocked_tokenizer' - def test_unload_tsl_model(monkeypatch): """Test unload box model.""" for key in tsl_globals: @@ -98,185 +63,21 @@ def test_unload_tsl_model(monkeypatch): for key in tsl_globals: assert getattr(tsl, key) is None -def test_unload_tsl_model_cpu(monkeypatch, mock_called): - """Test unload box model with cpu.""" - monkeypatch.setattr(tsl.torch.cuda, 'empty_cache', mock_called) - monkeypatch.setattr(tsl, 'dev', 'cpu') - - tsl.unload_tsl_model() - assert not hasattr(mock_called, 'called') - -def test_unload_tsl_model_cuda(monkeypatch, mock_called): - """Test unload box model with cuda.""" - monkeypatch.setattr(tsl.torch.cuda, 'empty_cache', mock_called) - monkeypatch.setattr(tsl, 
'dev', 'cuda') - - tsl.unload_tsl_model() - assert hasattr(mock_called, 'called') - def test_get_tsl_model(monkeypatch): - """Test get ocr model function.""" + """Test get tsl model function.""" monkeypatch.setattr(tsl, 'TSL_MODEL_OBJ', 'test') assert tsl.get_tsl_model() == 'test' -def test_pipeline_wrong_type(monkeypatch, mock_tsl_tokenizer): - """Test tsl pipeline with wrong type.""" - monkeypatch.setattr(tsl, 'TSL_TOKENIZER', mock_tsl_tokenizer('test/id')) - with pytest.raises(TypeError, match=r'^Unsupported type for text:.*'): - tsl._tsl_pipeline(1, 'ja', 'en') # pylint: disable=protected-access - -def test_pipeline_no_tokens(monkeypatch, mock_tsl_tokenizer): - """Test tsl pipeline with no tokens generated from pre_tokenize.""" - monkeypatch.setattr(tsl, 'pre_tokenize', lambda *args, **kwargs: []) - monkeypatch.setattr(tsl, 'TSL_TOKENIZER', mock_tsl_tokenizer('test/id')) - - res = tsl._tsl_pipeline('', 'ja', 'en') # pylint: disable=protected-access - - assert res == '' - -def test_pipeline_m2m(monkeypatch, mock_tsl_tokenizer, mock_tsl_model): - """Test tsl pipeline with m2m model.""" - model_id = 'test/id' - monkeypatch.setattr(tsl, 'M2M100Tokenizer', mock_tsl_tokenizer) - monkeypatch.setattr(tsl, 'TSL_MODEL', mock_tsl_model(model_id)) - monkeypatch.setattr(tsl, 'TSL_TOKENIZER', mock_tsl_tokenizer(model_id)) - - tsl._tsl_pipeline('', 'ja', 'en') # pylint: disable=protected-access - - assert tsl.TSL_TOKENIZER.called_get_lang_id is True - - -def test_pipeline(string, monkeypatch, mock_tsl_tokenizer, mock_tsl_model, mock_called): - """Test tsl pipeline (also check that cache is not cleared in CPU mode).""" - model_id = 'test_model' - lang_src = 'ja' - lang_dst = 'en' - - monkeypatch.setattr(tsl, 'TSL_MODEL', mock_tsl_model(model_id)) - monkeypatch.setattr(tsl, 'TSL_TOKENIZER', mock_tsl_tokenizer(model_id)) - monkeypatch.setattr(tsl.torch.cuda, 'empty_cache', mock_called) - monkeypatch.setattr(tsl, 'dev', 'cpu') - - res = tsl._tsl_pipeline(string, lang_src, 
lang_dst) # pylint: disable=protected-access - - assert res == string.replace('\n', ' ') - assert tsl.TSL_TOKENIZER.model_id == model_id - assert tsl.TSL_TOKENIZER.src_lang == lang_src - - assert not hasattr(mock_called, 'called') - -def test_pipeline_clear_cache(monkeypatch, mock_tsl_tokenizer, mock_tsl_model, mock_called): - """Test tsl pipeline with cuda should clear_cache.""" - model_id = 'test_model' - lang_src = 'ja' - lang_dst = 'en' - - monkeypatch.setattr(tsl, 'TSL_MODEL', mock_tsl_model(model_id)) - monkeypatch.setattr(tsl, 'TSL_TOKENIZER', mock_tsl_tokenizer(model_id)) - monkeypatch.setattr(tsl.torch.cuda, 'empty_cache', mock_called) - monkeypatch.setattr(tsl, 'dev', 'cuda') - - tsl._tsl_pipeline('test', lang_src, lang_dst) # pylint: disable=protected-access - - assert hasattr(mock_called, 'called') - - - -def test_pipeline_batch(batch_string, monkeypatch, mock_tsl_tokenizer, mock_tsl_model): - """Test tsl pipeline with batched string.""" - model_id = 'test_model' - lang_src = 'ja' - lang_dst = 'en' - - monkeypatch.setattr(tsl, 'TSL_MODEL', mock_tsl_model(model_id)) - monkeypatch.setattr(tsl, 'TSL_TOKENIZER', mock_tsl_tokenizer(model_id)) - - res = tsl._tsl_pipeline(batch_string, lang_src, lang_dst) # pylint: disable=protected-access - - assert res == [_.replace('\n', ' ') for _ in batch_string] - assert tsl.TSL_TOKENIZER.model_id == model_id - assert tsl.TSL_TOKENIZER.src_lang == lang_src - -@pytest.mark.parametrize( - 'options', - [ - {}, - {'min_max_new_tokens': 30}, - {'max_max_new_tokens': 22}, - {'max_new_tokens': 15}, - {'max_new_tokens_ratio': 2} - ], - ids=[ - 'default', - 'min_max_new_tokens', - 'max_max_new_tokens', - 'max_new_tokens', - 'max_new_tokens_ratio' - ] -) -def test_pipeline_options(options, string, monkeypatch, mock_tsl_tokenizer, mock_tsl_model): - """Test tsl pipeline with options.""" - model_id = 'test_model' - lang_src = 'ja' - lang_dst = 'en' - - monkeypatch.setattr(tsl, 'TSL_MODEL', mock_tsl_model(model_id)) - 
monkeypatch.setattr(tsl, 'TSL_TOKENIZER', mock_tsl_tokenizer(model_id)) - - min_max_new_tokens = options.get('min_max_new_tokens', 20) - max_max_new_tokens = options.get('max_max_new_tokens', 512) - ntok = string.replace('\n', ' ').count(' ') + 1 - - if min_max_new_tokens > max_max_new_tokens: - with pytest.raises(ValueError): - tsl._tsl_pipeline(string, lang_src, lang_dst, options=options) # pylint: disable=protected-access - else: - tsl._tsl_pipeline(string, lang_src, lang_dst, options=options) # pylint: disable=protected-access - - mnt = tsl.get_mnt(ntok, options) - - model = tsl.TSL_MODEL - - assert model.options['max_new_tokens'] == mnt - -def test_queue_placer_handler(monkeypatch, mock_called): - """Test queue_placer is setting _tsl_pipeline as handler, and that it is called.""" - monkeypatch.setattr(tsl, '_tsl_pipeline', mock_called) - monkeypatch.setattr(tsl.q.msg_queue, 'reuse_msg', False) - tsl.tsl_pipeline(id_=1, block=True) - assert hasattr(mock_called, 'called') - -@pytest.mark.parametrize('mock_called', ['test_return'], indirect=True) -def test_queue_placer_blocking(monkeypatch, mock_called): - """Test queue_placer with blocking""" - monkeypatch.setattr(tsl, '_tsl_pipeline', mock_called) - monkeypatch.setattr(tsl.q.msg_queue, 'reuse_msg', False) - res = tsl.tsl_pipeline(id_=1, block=True) - assert hasattr(mock_called, 'called') - assert res == mock_called.expected - -@pytest.mark.parametrize('mock_called', ['test_return'], indirect=True) -def test_queue_placer_nonblocking(monkeypatch, mock_called): - """Test queue_placer with blocking""" - monkeypatch.setattr(tsl, '_tsl_pipeline', mock_called) - monkeypatch.setattr(tsl.q.msg_queue, 'reuse_msg', False) - tsl.q.stop_workers() - res = tsl.tsl_pipeline(id_=1, block=False) - assert isinstance(res, Message) - - assert not hasattr(mock_called, 'called') # Before resolving the message the handler is not called - tsl.q.start_workers() - assert res.response() == mock_called.expected - assert 
hasattr(mock_called, 'called') # After resolving the message the handler is called - - -def test_pipeline_worker(): - """Test tsl pipeline with worker""" - placeholder = 'placeholder' - tsl.q.stop_workers() - - messages = [tsl.tsl_pipeline(placeholder, 'ja', 'en', id_=i, batch_id=0, block=False) for i in range(3)] - assert all(isinstance(_, Message) for _ in messages) - # Makes sure that batching is enabled for tsl queue (retrieve all messages withone `get` call) - res = tsl.q.get() - assert len(res) == len(messages) +def test_unload_tsl_model_if_loaded(monkeypatch, mock_base_model): + """Test unload tsl model is called if load with an already loaded model.""" + base1 = mock_base_model() # pylint: disable=invalid-name + base2 = mock_base_model() # pylint: disable=invalid-name + monkeypatch.setattr(tsl, 'TSL_MODEL_OBJ', base1) + monkeypatch.setattr(tsl, 'TSL_MODEL_ID', 'test') + monkeypatch.setattr(m.TSLModel, 'from_entrypoint', lambda *args, **kwargs: base2) + tsl.load_tsl_model('test2') + + assert not base1.load_called + assert base1.unload_called + assert base2.load_called + assert not base2.unload_called diff --git a/tests/plugin/conftest.py b/tests/plugin/conftest.py new file mode 100644 index 0000000..56d5df1 --- /dev/null +++ b/tests/plugin/conftest.py @@ -0,0 +1,181 @@ +################################################################################### +# ocr_translate - a django app to perform OCR and translation of images. # +# Copyright (C) 2023-present Davide Grassano # +# # +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the GNU General Public License as published by # +# the Free Software Foundation, either version 3 of the License. # +# # +# This program is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # +# GNU General Public License for more details. 
# +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see {http://www.gnu.org/licenses/}. # +# # +# Home: https://github.com/Crivella/ocr_translate # +################################################################################### +"""Fixtures for ocr_tsl tests""" + +import pytest + + +@pytest.fixture() +def mock_tsl_tokenizer(): + """Mock tokenizer for TSL.""" + import torch # pylint: disable=import-outside-toplevel + class _MockTokenizer(): + def __init__(self, model_id): + self.model_id = model_id + self.other_options = {} + self.tok_to_word = {0: 0} + self.word_to_tok = {0: 0} + self.ntoks = 1 + self.called_get_lang_id = False + + def __call__(self, text, **options): + issplit = options.pop('is_split_into_words', False) + padding = options.pop('padding', False) + truncation = options.pop('truncation', False) # pylint: disable=unused-variable + + self.other_options = options + + if isinstance(text, list): + if isinstance(text[0], str): + text = [text] + if issplit: + app = [] + for line in text: + app2 = [] + for seg in line: + app2.extend(seg.split(' ')) + app.append(app2) + else: + app = [_.split(' ') for _ in text] + + if padding: + app2 = [] + for lst in app: + app3 = [] + for word in lst: + if word not in self.word_to_tok: + self.word_to_tok[word] = self.ntoks + self.tok_to_word[self.ntoks] = word + self.ntoks += 1 + app3.append(self.word_to_tok[word]) + app2.append(app3) + + max_len = max(len(_) for _ in app2) + res = [(_ + [0] * max_len)[:max_len] for _ in app2] + else: + res = app + class Dict(dict): + """Dict class with added .to method""" + def to(self, device): # pylint: disable=unused-argument,invalid-name + """Move the dict to a device.""" + return None + + dct = Dict([('input_ids', torch.Tensor(res))]) + return dct + + raise TypeError(f'Expected list of strings, but got {type(text)}') + + def batch_decode(self, tokens, **options): # pylint: disable=unused-argument + """Decode a batch 
of tokens.""" + res = [' '.join(filter(None, [self.tok_to_word[int(_)] for _ in lst])) for lst in tokens] + return res + + def get_lang_id(self, lang): # pylint: disable=unused-argument + """Get the language id.""" + self.called_get_lang_id = True + return 0 + + return _MockTokenizer + +@pytest.fixture() +def mock_tsl_model(): + """Mock model for TSL.""" + class _MockModel(): + def __init__(self, model_id): + self.model_id = model_id + self.options = {} + + def generate(self, input_ids=None, **options): + """Mock generate translated tokens.""" + self.options = options + return input_ids + + return _MockModel + +@pytest.fixture() +def mock_ocr_preprocessor(): + """Mock preprocessor for OCR.""" + class RES(): + """Mock result""" + def __init__(self): + class PV(list): + """Mock pixel values""" + def cuda(self): + """Mock cuda""" + self.cuda_called = True # pylint: disable=attribute-defined-outside-init + return self + self.pixel_values = PV([1,2,3,4,5]) + + class _MockPreprocessor(): + def __init__(self, model_id): + self.model_id = model_id + self.options = {} + + def __call__(self, img, **options): + self.options = options + res = RES() + return res + + return _MockPreprocessor + +@pytest.fixture() +def mock_ocr_tokenizer(): + """Mock tokenizer for OCR.""" + class _MockTokenizer(): + def __init__(self, model_id): + self.model_id = model_id + self.options = {} + + def batch_decode(self, tokens, **options): + """Mock batch decode.""" + self.options = options + offset = ord('a') - 1 + return [''.join(chr(int(_)+offset) for _ in tokens)] + + return _MockTokenizer + +@pytest.fixture() +def mock_ocr_model(): + """Mock model for OCR.""" + class _MockModel(): + def __init__(self, model_id): + self.model_id = model_id + self.options = {} + + def generate(self, pixel_values=None, **options): + """Mock generate.""" + self.options = options + return pixel_values + + return _MockModel + +@pytest.fixture() +def mock_box_reader(): + """Mock box reader.""" + class _MockReader(): + 
def __init__(self, model_id): + self.model_id = model_id + self.options = {} + + def detect(self, img, **options): # pylint: disable=unused-argument + """Mock recognize.""" + self.options = options + return (([(10,10,30,30), (40,40,50,50)],),) + + return _MockReader diff --git a/tests/plugin/test_base.py b/tests/plugin/test_base.py new file mode 100644 index 0000000..a3cf017 --- /dev/null +++ b/tests/plugin/test_base.py @@ -0,0 +1,120 @@ +################################################################################### +# ocr_translate - a django app to perform OCR and translation of images. # +# Copyright (C) 2023-present Davide Grassano # +# # +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the GNU General Public License as published by # +# the Free Software Foundation, either version 3 of the License. # +# # +# This program is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # +# GNU General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see {http://www.gnu.org/licenses/}. 
"""Tests for base plugin facility."""

import pytest

from ocr_translate import models as m


def test_base_from_entrypoint():
    """Test from_entrypoint method of BaseModel should `raise ValueError`."""
    with pytest.raises(ValueError):
        m.BaseModel.from_entrypoint('test_model_id')

@pytest.mark.django_db
def test_box_model_from_entrypoint_unknown(box_model: m.OCRBoxModel):
    """Test from_entrypoint method of OCRBoxModel should `raise ValueError` if entrypoint is unknown."""
    with pytest.raises(ValueError, match=r'^Missing plugin: Entrypoint "test_entrypoint.box" not found.$'):
        m.OCRBoxModel.from_entrypoint(box_model.name)

@pytest.mark.django_db
def test_ocr_model_from_entrypoint_unknown(ocr_model: m.OCRModel):
    """Test from_entrypoint method of OCRModel should `raise ValueError` if entrypoint is unknown."""
    with pytest.raises(ValueError, match=r'^Missing plugin: Entrypoint "test_entrypoint.ocr" not found.$'):
        m.OCRModel.from_entrypoint(ocr_model.name)

@pytest.mark.django_db
def test_tsl_model_from_entrypoint_unknown(tsl_model: m.TSLModel):
    """Test from_entrypoint method of TSLModel should `raise ValueError` if entrypoint is unknown."""
    with pytest.raises(ValueError, match=r'^Missing plugin: Entrypoint "test_entrypoint.tsl" not found.$'):
        m.TSLModel.from_entrypoint(tsl_model.name)

@pytest.mark.django_db
def test_valid_entrypoint(monkeypatch, box_model: m.OCRBoxModel):
    """Test that a model can be loaded through a valid (mocked) entrypoint."""
    import pkg_resources as pr  # pylint: disable=import-outside-toplevel

    class Obj():
        """Mock entrypoint: `load` returns a model class whose manager records `get` calls."""
        def __init__(self):
            self.called = False
            self.called_name = None

        @property
        def objects(self):
            """Mimic a django model manager exposing only `get`."""
            class A():  # pylint: disable=missing-class-docstring,invalid-name
                pass
            new = A()
            new.get = self.get  # pylint: disable=attribute-defined-outside-init
            return new

        def get(self, name):
            """Record the name the manager was queried with."""
            self.called = True
            self.called_name = name
            return name

        def load(self):
            """Entrypoint `load` returns the mocked class itself."""
            return self

    o = Obj()  # pylint: disable=invalid-name

    monkeypatch.setattr(
        pr, 'iter_entry_points',
        lambda x, name='': [o,],
    )

    m.OCRBoxModel.from_entrypoint(box_model.name)

    assert o.called
    assert o.called_name == box_model.name

@pytest.mark.django_db
def test_box_load_not_implemented(box_model: m.OCRBoxModel):
    """Test that load method raises NotImplementedError."""
    with pytest.raises(NotImplementedError):
        box_model.load()

@pytest.mark.django_db
def test_box_unload_not_implemented(box_model: m.OCRBoxModel):
    """Test that unload method raises NotImplementedError."""
    with pytest.raises(NotImplementedError):
        box_model.unload()

@pytest.mark.django_db
def test_ocr_non_pil_image(ocr_model: m.OCRModel):
    """Test that prepare_image raises TypeError if image is not a PIL.Image."""
    with pytest.raises(TypeError, match=r'^img should be PIL Image, but got $'):
        ocr_model.prepare_image('test_image')

@pytest.mark.django_db
def test_box_main_method_notimplemented(box_model: m.OCRBoxModel):
    """Test that the base _box_detection method raises NotImplementedError."""
    with pytest.raises(NotImplementedError):
        box_model._box_detection('test_image')  # pylint: disable=protected-access

@pytest.mark.django_db
def test_ocr_main_method_notimplemented(ocr_model: m.OCRModel):
    """Test that the base _ocr method raises NotImplementedError."""
    with pytest.raises(NotImplementedError):
        ocr_model._ocr('test_image')  # pylint: disable=protected-access

@pytest.mark.django_db
def test_tsl_main_method_notimplemented(tsl_model: m.TSLModel):
    """Test that the base _translate method raises NotImplementedError."""
    with pytest.raises(NotImplementedError):
        tsl_model._translate('src_text', 'src_lang', 'dst_lang')  # pylint: disable=protected-access
"""Tests for easyocr plugin."""
# NOTE(review): the original file also had a top-level `import easyocr` (the library);
# it was immediately shadowed by the plugin import below (pylint W0404) and every test
# reaches the library through the plugin module (`easyocr.easyocr`), so it is removed.

import pytest

from ocr_translate import models as m
from ocr_translate.plugins import easyocr

# Bounding boxes in (l,b,r,t) form exercising every relative placement of two
# (or three) boxes handled by the intersection/merge algorithm.
boxes = [
    ((10,10,30,30), (15,15,20,20)), # b2 inside b1
    ((15,15,20,20), (10,10,30,30)), # b1 inside b2

    ((30,30,50,50), (10,10,20,35)), # l1 > r2
    ((30,30,50,50), (55,10,75,35)), # r1 < l2
    ((30,30,50,50), (10,10,35,20)), # b1 > t2
    ((30,30,50,50), (10,55,35,75)), # t1 < b2

    ((30,30,50,50), (10,10,35,35)), # b2-tr inside b1
    ((30,30,50,50), (45,10,75,35)), # b2-tl inside b1
    ((30,30,50,50), (40,45,75,75)), # b2-bl inside b1
    ((30,30,50,50), (10,45,35,75)), # b2-br inside b1

    ((10,50,70,60), (50,10,60,70)), # intersection, but corner not inside

    ((10,10,30,30), (29,29,51,40), (50,10,60,30)), # 3x intersection
    ((10,10,30,30), (29,29,51,40), (60,10,70,30)), # 2x intersection + 1
]
ids = [
    'b2_inside_b1',
    'b1_inside_b2',
    'l1_>_r2',
    'r1_<_l2',
    'b1_>_t2',
    't1_<_b2',
    'b2-tr_inside_b1',
    'b2-tl_inside_b1',
    'b2-bl_inside_b1',
    'b2-br_inside_b1',
    'int_nocorners',
    '3x_intersection',
    '2x_intersection_+_1',
]

pytestmark = pytest.mark.django_db

@pytest.fixture()
def easyocr_model(language):
    """OCRBoxModel database object."""
    easyocr_model_dict = {
        'name': 'easyocr',
        'language_format': 'iso1',
        'entrypoint': 'easyocr.box'
    }
    entrypoint = easyocr_model_dict.pop('entrypoint')
    res = m.OCRBoxModel.objects.create(**easyocr_model_dict)
    res.entrypoint = entrypoint
    res.languages.add(language)
    res.save()

    # Re-fetch through the plugin class so the returned object is an EasyOCRBoxModel
    return easyocr.EasyOCRBoxModel.objects.get(name = res.name)

def test_intersection_merge(data_regression):
    """Test intersections function."""
    res = []
    for boxes_lbrt, idx in zip(boxes, ids):
        ptr = {}
        ptr['idx'] = idx
        # Convert from (l,b,r,t) to the (l,r,b,t) order expected by the plugin
        boxes_lrbt = []
        for l,b,r,t in boxes_lbrt:
            boxes_lrbt.append((l,r,b,t))
        ptr['box_lst'] = boxes_lrbt
        ptr['intersection'] = easyocr.EasyOCRBoxModel.intersections(boxes_lrbt)
        merge = easyocr.EasyOCRBoxModel.merge_bboxes(boxes_lrbt)
        merge = [[int(_) for _ in el] for el in merge]
        ptr['merge'] = merge
        res.append(ptr)

    data_regression.check({'res': res})

def test_load_box_model_easyocr(monkeypatch, easyocr_model: m.OCRBoxModel):
    """Test load box model. Success"""
    monkeypatch.setattr(easyocr.easyocr, 'Reader', lambda *args, **kwargs: 'mocked')
    easyocr_model.load()
    assert easyocr_model.reader == 'mocked'

def test_unload_box_model_easyocr_cpu(monkeypatch, mock_called, easyocr_model: m.OCRBoxModel):
    """Test unload box model with cpu (cuda cache must NOT be cleared)."""
    easyocr_model.dev = 'cpu'
    monkeypatch.setattr(easyocr.torch.cuda, 'empty_cache', mock_called)

    easyocr_model.unload()
    assert not hasattr(mock_called, 'called')

def test_unload_box_model_easyocr_gpu(monkeypatch, mock_called, easyocr_model: m.OCRBoxModel):
    """Test unload box model with gpu (cuda cache must be cleared)."""
    easyocr_model.dev = 'cuda'
    monkeypatch.setattr(easyocr.torch.cuda, 'empty_cache', mock_called)

    easyocr_model.unload()
    assert hasattr(mock_called, 'called')

@pytest.mark.django_db
def test_easyocr_box_detection(monkeypatch, mock_called, image_pillow, easyocr_model):
    """Test easyocr box detection."""
    class MockReader():  # pylint: disable=missing-class-docstring
        def __init__(self):
            self.called = False
            self.res = None
        def detect(self, *args, **kwargs):  # pylint: disable=missing-function-docstring
            self.called = True
            self.res = (('TARGET', 'other1'), 'other2')
            return self.res

    reader = MockReader()
    easyocr_model.reader = reader

    monkeypatch.setattr(easyocr_model, 'merge_bboxes', mock_called)

    easyocr_model._box_detection(image_pillow)  # pylint: disable=protected-access

    assert reader.called

    # The first element of the first detect result must be forwarded to merge_bboxes
    assert hasattr(mock_called, 'called')
    assert mock_called.args[0] == 'TARGET'
b/tests/plugin/test_easyocr/test_intersection_merge.yml similarity index 100% rename from tests/ocr_tsl/test_box/test_intersection_merge.yml rename to tests/plugin/test_easyocr/test_intersection_merge.yml diff --git a/tests/plugin/test_huggingface.py b/tests/plugin/test_huggingface.py new file mode 100644 index 0000000..c2ec6fe --- /dev/null +++ b/tests/plugin/test_huggingface.py @@ -0,0 +1,442 @@ +################################################################################### +# ocr_translate - a django app to perform OCR and translation of images. # +# Copyright (C) 2023-present Davide Grassano # +# # +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the GNU General Public License as published by # +# the Free Software Foundation, either version 3 of the License. # +# # +# This program is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # +# GNU General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see {http://www.gnu.org/licenses/}. 
"""Tests for the hugginface plugin (VED OCR model and Seq2Seq TSL model)."""
# pylint: disable=redefined-outer-name

from pathlib import Path

import pytest

from ocr_translate import models as m
from ocr_translate.plugins import hugginface

pytestmark = pytest.mark.django_db

@pytest.fixture()
def ved_model(language):
    """OCRModel database object."""
    model_dict = {
        'name': 'test_ved',
        'language_format': 'iso1',
        'entrypoint': 'hugginface.ved'
    }

    entrypoint = model_dict.pop('entrypoint')
    res = m.OCRModel.objects.create(**model_dict)
    res.entrypoint = entrypoint
    res.languages.add(language)
    res.save()

    return hugginface.HugginfaceVEDModel.objects.get(name = res.name)

@pytest.fixture()
def s2s_model(language):
    """TSLModel database object."""
    model_dict = {
        'name': 'test_seq2seq',
        'language_format': 'iso1',
        'entrypoint': 'hugginface.seq2seq'
    }

    entrypoint = model_dict.pop('entrypoint')
    res = m.TSLModel.objects.create(**model_dict)
    res.entrypoint = entrypoint
    res.src_languages.add(language)
    res.dst_languages.add(language)
    res.save()

    return hugginface.HugginfaceSeq2SeqModel.objects.get(name = res.name)



@pytest.fixture
def mock_loader(monkeypatch):
    """Mock hugging face class with `from_pretrained` method."""
    class Loader():
        """Mocked class."""
        def from_pretrained(self, model_id: Path | str, cache_dir=None):
            """Mocked method: raise FileNotFoundError unless dir/cache exists."""
            if isinstance(model_id, Path):
                if not model_id.is_dir():
                    raise FileNotFoundError('Not in dir')
            elif isinstance(model_id, str):
                if cache_dir is None:
                    cache_dir = hugginface.EnvMixin().root
                if not (cache_dir / f'models--{model_id.replace("/", "--")}').is_dir():
                    raise FileNotFoundError('Not in cache')

            class A():  # pylint: disable=invalid-name
                """Mocked huggingface class with `to` method."""
                def to(self, dev):  # pylint: disable=invalid-name,unused-argument,missing-function-docstring
                    pass
            return A()

    monkeypatch.setattr(hugginface.Loaders, 'mapping', {
        'tokenizer': Loader(),
        'seq2seq': Loader(),
        'model': Loader(),
        'ved_model': Loader(),
        'image_processor': Loader(),
    })

def test_env_transformers_cache(monkeypatch):
    """Test that the TRANSFORMERS_CACHE environment variable is set."""
    monkeypatch.setenv('TRANSFORMERS_CACHE', 'test')
    mixin = hugginface.EnvMixin()
    assert mixin.root == Path('test')

def test_env_transformers_cpu(monkeypatch):
    """Test that the DEVICE environment variable is cpu."""
    monkeypatch.setenv('DEVICE', 'cpu')
    mixin = hugginface.EnvMixin()
    assert mixin.dev == 'cpu'

def test_env_transformers_cuda(monkeypatch):
    """Test that the DEVICE environment variable is cuda."""
    monkeypatch.setenv('DEVICE', 'cuda')
    mixin = hugginface.EnvMixin()
    assert mixin.dev == 'cuda'


def test_load_hugginface_model_invalid_type():
    """Test high-level loading a huggingface model. Request unknown entity."""
    with pytest.raises(ValueError, match=r'^Unknown request: .*'):
        hugginface.Loaders.load('test', ['invalid'], 'root')

def test_load_hugginface_model_return_none(monkeypatch):
    """Test high-level loading a huggingface model. Return None from load."""
    def mock_load(*args):
        """Mocked load function."""
        return None
    monkeypatch.setattr(hugginface.Loaders, '_load', mock_load)

    with pytest.raises(ValueError, match=r'^Could not load model: .*'):
        hugginface.Loaders.load('test', ['model'], 'root')

@pytest.mark.parametrize('model_type', [
    'tokenizer',
    'ved_model',
    'model',
    'image_processor',
    'seq2seq'
])
def test_load_hugginface_model_success(monkeypatch, model_type):
    """Test high-level loading a huggingface model."""
    def mock_load(loader, *args):
        """Mocked load function."""
        assert loader == hugginface.Loaders.mapping[model_type]
        class App():
            """Mocked huggingface class with `to` method."""
            def to(self, x):  # pylint: disable=invalid-name,unused-argument
                """Mocked method."""
                return None
        return App()
    monkeypatch.setattr(hugginface.Loaders, '_load', mock_load)

    loaded = hugginface.Loaders.load('test', [model_type], 'root')

    assert isinstance(loaded, dict)
    assert len(loaded) == 1
    assert model_type in loaded

####################################################################################
def test_load_from_storage_dir_fail(monkeypatch, mock_loader, tmpdir, ved_model):
    """Test low-level loading a huggingface model from storage (missing file)."""
    monkeypatch.setenv('TRANSFORMERS_CACHE', str(tmpdir))
    # Reload to make ENV effective
    ved_model = hugginface.HugginfaceVEDModel.objects.get(name = ved_model.name)

    # Load is supposed to try the directory first and then fall back to the cache
    # Exception should always be from not found in cache first
    with pytest.raises(FileNotFoundError, match='Not in cache'):
        ved_model.load()

def test_load_from_storage_dir_success(monkeypatch, mock_loader, tmpdir, ved_model):
    """Test low-level loading a huggingface model from storage (success)."""
    monkeypatch.setenv('TRANSFORMERS_CACHE', str(tmpdir))
    # Reload to make ENV effective
    ved_model = hugginface.HugginfaceVEDModel.objects.get(name = ved_model.name)

    ptr = tmpdir
    for pth in Path(ved_model.name).parts:
        ptr = ptr.mkdir(pth)
    ved_model.load()

def test_load_from_storage_cache_success(monkeypatch, mock_loader, tmpdir, ved_model):
    """Test low-level loading a huggingface model from the cache (success)."""
    monkeypatch.setenv('TRANSFORMERS_CACHE', str(tmpdir))
    # Reload to make ENV effective
    ved_model = hugginface.HugginfaceVEDModel.objects.get(name = ved_model.name)

    tmpdir.mkdir('models--' + ved_model.name.replace('/', '--'))
    ved_model.load()

def test_unload_from_loaded_ved(monkeypatch, tmpdir, ved_model):
    """Test that unloading a loaded VED model clears model and tokenizer."""
    monkeypatch.setattr(ved_model, 'model', '1')
    monkeypatch.setattr(ved_model, 'tokenizer', '1')

    ved_model.unload()
    assert ved_model.model is None
    assert ved_model.tokenizer is None

def test_unload_cpu(monkeypatch, mock_called, ved_model):
    """Test unload VED model with cpu (cuda cache must NOT be cleared)."""
    monkeypatch.setattr(hugginface.torch.cuda, 'empty_cache', mock_called)
    monkeypatch.setattr(ved_model, 'dev', 'cpu')

    ved_model.unload()
    assert not hasattr(mock_called, 'called')

def test_unload_cuda(monkeypatch, mock_called, ved_model):
    """Test unload VED model with cuda (cuda cache must be cleared)."""
    monkeypatch.setattr(hugginface.torch.cuda, 'empty_cache', mock_called)
    monkeypatch.setattr(ved_model, 'dev', 'cuda')

    ved_model.unload()
    assert hasattr(mock_called, 'called')

def test_pipeline_notinit_ved(ved_model):
    """Test ocr pipeline with not initialized model."""
    with pytest.raises(RuntimeError, match=r'^Model not loaded$'):
        ved_model._ocr('image')  # pylint: disable=protected-access

def test_pipeline_hugginface(
        image_pillow, mock_ocr_preprocessor, mock_ocr_tokenizer, mock_ocr_model, monkeypatch, ved_model):
    """Test ocr pipeline with hugginface model."""
    lang = 'ja'

    monkeypatch.setattr(ved_model, 'image_processor', mock_ocr_preprocessor(ved_model.name))
    monkeypatch.setattr(ved_model, 'tokenizer', mock_ocr_tokenizer(ved_model.name))
    monkeypatch.setattr(ved_model, 'model', mock_ocr_model(ved_model.name))

    res = ved_model._ocr(image_pillow, lang)  # pylint: disable=protected-access

    assert res == 'abcde'

def test_pipeline_hugginface_cuda(
        image_pillow, mock_ocr_preprocessor, mock_ocr_tokenizer, mock_ocr_model, monkeypatch, ved_model):
    """Test ocr pipeline with hugginface model and cuda."""
    lang = 'ja'

    monkeypatch.setattr(ved_model, 'dev', 'cuda')
    monkeypatch.setattr(ved_model, 'image_processor', mock_ocr_preprocessor(ved_model.name))
    monkeypatch.setattr(ved_model, 'tokenizer', mock_ocr_tokenizer(ved_model.name))
    monkeypatch.setattr(ved_model, 'model', mock_ocr_model(ved_model.name))

    res = ved_model._ocr(image_pillow, lang)  # pylint: disable=protected-access

    assert res == 'abcde'

####################################################################################
def test_get_mnt_wrong_options():
    """Test get_mnt with wrong options."""
    with pytest.raises(ValueError, match=r'^min_max_new_tokens must be less than max_max_new_tokens$'):
        hugginface.get_mnt(10, {'min_max_new_tokens': 20, 'max_max_new_tokens': 10})

def test_load_from_storage_dir_fail_s2s(monkeypatch, mock_loader, tmpdir, s2s_model):
    """Test low-level loading a huggingface model from storage (missing file)."""
    monkeypatch.setenv('TRANSFORMERS_CACHE', str(tmpdir))
    # Reload to make ENV effective
    s2s_model = hugginface.HugginfaceSeq2SeqModel.objects.get(name = s2s_model.name)

    # Load is supposed to try the directory first and then fall back to the cache
    # Exception should always be from not found in cache first
    with pytest.raises(FileNotFoundError, match='Not in cache'):
        s2s_model.load()

def test_load_from_storage_dir_success_s2s(monkeypatch, mock_loader, tmpdir, s2s_model):
    """Test low-level loading a huggingface model from storage (success)."""
    monkeypatch.setenv('TRANSFORMERS_CACHE', str(tmpdir))
    # Reload to make ENV effective
    s2s_model = hugginface.HugginfaceSeq2SeqModel.objects.get(name = s2s_model.name)

    ptr = tmpdir
    for pth in Path(s2s_model.name).parts:
        ptr = ptr.mkdir(pth)
    s2s_model.load()

def test_load_from_storage_cache_success_s2s(monkeypatch, mock_loader, tmpdir, s2s_model):
    """Test low-level loading a huggingface model from the cache (success)."""
    monkeypatch.setenv('TRANSFORMERS_CACHE', str(tmpdir))
    # Reload to make ENV effective
    s2s_model = hugginface.HugginfaceSeq2SeqModel.objects.get(name = s2s_model.name)

    tmpdir.mkdir('models--' + s2s_model.name.replace('/', '--'))
    s2s_model.load()

def test_unload_from_loaded_s2s(monkeypatch, tmpdir, s2s_model):
    """Test that unloading a loaded Seq2Seq model clears model and tokenizer."""
    monkeypatch.setattr(s2s_model, 'model', '1')
    monkeypatch.setattr(s2s_model, 'tokenizer', '1')

    s2s_model.unload()
    assert s2s_model.model is None
    assert s2s_model.tokenizer is None

def test_unload_cpu_s2s(monkeypatch, mock_called, s2s_model):
    """Test unload Seq2Seq model with cpu (cuda cache must NOT be cleared)."""
    monkeypatch.setattr(hugginface.torch.cuda, 'empty_cache', mock_called)
    monkeypatch.setattr(s2s_model, 'dev', 'cpu')

    s2s_model.unload()
    assert not hasattr(mock_called, 'called')

def test_unload_cuda_s2s(monkeypatch, mock_called, s2s_model):
    """Test unload Seq2Seq model with cuda (cuda cache must be cleared)."""
    monkeypatch.setattr(hugginface.torch.cuda, 'empty_cache', mock_called)
    monkeypatch.setattr(s2s_model, 'dev', 'cuda')

    s2s_model.unload()
    assert hasattr(mock_called, 'called')

def test_pipeline_notinit_s2s(s2s_model):
    """Test tsl pipeline with not initialized model."""
    with pytest.raises(RuntimeError, match=r'^Model not loaded$'):
        s2s_model._translate('test', 'ja', 'en')  # pylint: disable=protected-access

def test_pipeline_no_tokens(monkeypatch, mock_tsl_tokenizer, s2s_model):
    """Test tsl pipeline with no tokens generated from pre_tokenize."""
    monkeypatch.setattr(s2s_model, 'model', '1')
    monkeypatch.setattr(s2s_model, 'tokenizer', mock_tsl_tokenizer('test/id'))

    res = s2s_model._translate('', 'ja', 'en')  # pylint: disable=protected-access

    assert res == ''

def test_pipeline_m2m(monkeypatch, mock_tsl_tokenizer, mock_tsl_model, s2s_model):
    """Test tsl pipeline with m2m model (language ids must be requested)."""
    monkeypatch.setattr(hugginface, 'M2M100Tokenizer', mock_tsl_tokenizer)
    # Reload to make ENV effective
    s2s_model = hugginface.HugginfaceSeq2SeqModel.objects.get(name = s2s_model.name)
    monkeypatch.setattr(s2s_model, 'model', mock_tsl_model(s2s_model.name))
    monkeypatch.setattr(s2s_model, 'tokenizer', mock_tsl_tokenizer(s2s_model.name))

    s2s_model._translate(['1',], 'ja', 'en')  # pylint: disable=protected-access

    assert s2s_model.tokenizer.called_get_lang_id is True


def test_pipeline(string, monkeypatch, mock_tsl_tokenizer, mock_tsl_model, mock_called, s2s_model):
    """Test tsl pipeline (also check that cache is not cleared in CPU mode)."""
    lang_src = 'ja'
    lang_dst = 'en'

    monkeypatch.setattr(s2s_model, 'model', mock_tsl_model(s2s_model.name))
    monkeypatch.setattr(s2s_model, 'tokenizer', mock_tsl_tokenizer(s2s_model.name))
    monkeypatch.setattr(hugginface.torch.cuda, 'empty_cache', mock_called)
    monkeypatch.setattr(s2s_model, 'dev', 'cpu')

    res = s2s_model._translate([string,], lang_src, lang_dst)  # pylint: disable=protected-access

    assert res == string
    assert s2s_model.tokenizer.model_id == s2s_model.name
    assert s2s_model.tokenizer.src_lang == lang_src

    assert not hasattr(mock_called, 'called')

def test_pipeline_clear_cache(monkeypatch, mock_tsl_tokenizer, mock_tsl_model, mock_called, s2s_model):
    """Test tsl pipeline with cuda should clear_cache."""
    lang_src = 'ja'
    lang_dst = 'en'

    monkeypatch.setattr(s2s_model, 'model', mock_tsl_model(s2s_model.name))
    monkeypatch.setattr(s2s_model, 'tokenizer', mock_tsl_tokenizer(s2s_model.name))
    monkeypatch.setattr(hugginface.torch.cuda, 'empty_cache', mock_called)
    monkeypatch.setattr(s2s_model, 'dev', 'cuda')

    s2s_model._translate(['test',], lang_src, lang_dst)  # pylint: disable=protected-access

    assert hasattr(mock_called, 'called')



def test_pipeline_batch(batch_string, monkeypatch, mock_tsl_tokenizer, mock_tsl_model, s2s_model):
    """Test tsl pipeline with batched string."""
    lang_src = 'ja'
    lang_dst = 'en'

    monkeypatch.setattr(s2s_model, 'model', mock_tsl_model(s2s_model.name))
    monkeypatch.setattr(s2s_model, 'tokenizer', mock_tsl_tokenizer(s2s_model.name))

    batch_string = [[_] for _ in batch_string]
    res = s2s_model._translate(batch_string, lang_src, lang_dst)  # pylint: disable=protected-access

    assert res == [_[0] for _ in batch_string]
    assert s2s_model.tokenizer.model_id == s2s_model.name
    assert s2s_model.tokenizer.src_lang == lang_src

@pytest.mark.parametrize(
    'options',
    [
        {},
        {'min_max_new_tokens': 30},
        {'max_max_new_tokens': 22},
        {'max_new_tokens': 15},
        {'max_new_tokens_ratio': 2}
    ],
    ids=[
        'default',
        'min_max_new_tokens',
        'max_max_new_tokens',
        'max_new_tokens',
        'max_new_tokens_ratio'
    ]
)
def test_pipeline_options(options, string, monkeypatch, mock_tsl_tokenizer, mock_tsl_model, s2s_model):
    """Test tsl pipeline with options."""
    lang_src = 'ja'
    lang_dst = 'en'

    monkeypatch.setattr(s2s_model, 'model', mock_tsl_model(s2s_model.name))
    monkeypatch.setattr(s2s_model, 'tokenizer', mock_tsl_tokenizer(s2s_model.name))

    min_max_new_tokens = options.get('min_max_new_tokens', 20)
    max_max_new_tokens = options.get('max_max_new_tokens', 512)
    ntok = string.replace('\n', ' ').count(' ') + 1

    string = m.TSLModel.pre_tokenize(string)
    if min_max_new_tokens > max_max_new_tokens:
        with pytest.raises(ValueError):
            s2s_model._translate(string, lang_src, lang_dst, options=options)  # pylint: disable=protected-access
    else:
        s2s_model._translate(string, lang_src, lang_dst, options=options)  # pylint: disable=protected-access

        mnt = hugginface.get_mnt(ntok, options)

        model = s2s_model.model

        assert model.options['max_new_tokens'] == mnt

# --- tests/plugin/test_tesseract.py (beginning; remainder continues past this window) ---
"""Tests for ocr module."""
# pylint: disable=redefined-outer-name

from pathlib import Path

import pytest
import requests
from PIL import Image

from ocr_translate import models as m
from ocr_translate.plugins import tesseract

pytestmark = pytest.mark.django_db

@pytest.fixture()
def tesseract_model(language):
    """OCRModel database object."""
    tesseract_model_dict = {
        'name': 'tesseract',
        'language_format': 'iso1',
        'entrypoint': 'tesseract.ocr'
    }

    entrypoint = tesseract_model_dict.pop('entrypoint')
    res = m.OCRModel.objects.create(**tesseract_model_dict)
    res.entrypoint = entrypoint
    res.languages.add(language)
    res.save()

    return tesseract.TesseractOCRModel.objects.get(name = res.name)
-49,151 +67,152 @@ def mock_get(*args, **kwargs): monkeypatch.setattr(tesseract.requests, 'get', mock_get) - -def test_download_model_env_disabled(monkeypatch): +def test_download_model_env_disabled(monkeypatch, tesseract_model): """Test the download of a model from the environment variable.""" monkeypatch.setenv('TESSERACT_ALLOW_DOWNLOAD', 'false') - importlib.reload(tesseract) + # Reload to make ENV effective + tesseract_model = tesseract.TesseractOCRModel.objects.get(name = 'tesseract') with pytest.raises(ValueError, match=r'^TESSERACT_ALLOW_DOWNLOAD is false\. Downloading models is not allowed$'): - tesseract.download_model('eng') + tesseract_model.download_model('eng') -def test_download_model_env_enabled(monkeypatch, tmpdir, mock_content): +def test_download_model_env_enabled(monkeypatch, tmpdir, mock_content, tesseract_model): """Test the download of a model from the environment variable.""" monkeypatch.setenv('TESSERACT_ALLOW_DOWNLOAD', 'true') monkeypatch.setenv('TESSERACT_PREFIX', str(tmpdir)) - importlib.reload(tesseract) + # Reload to make ENV effective + tesseract_model = tesseract.TesseractOCRModel.objects.get(name = 'tesseract') model = 'test' - tesseract.download_model(model) + tesseract_model.download_model(model) tmpfile = tmpdir / f'{model}.traineddata' assert tmpfile.exists() with open(tmpfile, 'rb') as f: assert f.read() == mock_content -def test_download_already_exists(monkeypatch, tmpdir, mock_called): +def test_download_already_exists(monkeypatch, tmpdir, mock_called, tesseract_model): """Test the download of a model from the environment variable.""" - monkeypatch.setattr(tesseract, 'DOWNLOAD', True) - monkeypatch.setattr(tesseract, 'DATA_DIR', Path(tmpdir)) + monkeypatch.setattr(tesseract_model, 'download', True) + monkeypatch.setattr(tesseract_model, 'data_dir', Path(tmpdir)) monkeypatch.setattr(tesseract.requests, 'get', mock_called) model = 'test' tmpfile = tmpdir / f'{model}.traineddata' with tmpfile.open('w') as f: f.write('test') - 
tesseract.download_model(model) + tesseract_model.download_model(model) assert not hasattr(mock_called, 'called') @pytest.mark.parametrize('mock_get', [{'status_code': 404}], indirect=True) -def test_download_fail_request(monkeypatch, tmpdir): +def test_download_fail_request(monkeypatch, tmpdir, tesseract_model): """Test the download of a language with a normal+vertical model.""" - monkeypatch.setattr(tesseract, 'DOWNLOAD', True) - monkeypatch.setattr(tesseract, 'DATA_DIR', Path(tmpdir)) + monkeypatch.setattr(tesseract_model, 'download', True) + monkeypatch.setattr(tesseract_model, 'data_dir', Path(tmpdir)) model = 'test' with pytest.raises(ValueError, match=r'^Could not download model for language.*'): - tesseract.download_model(model) + tesseract_model.download_model(model) -def test_download_vertical(monkeypatch, tmpdir): +def test_download_vertical(monkeypatch, tmpdir, tesseract_model): """Test the download of a language with a normal+vertical model.""" - monkeypatch.setattr(tesseract, 'DOWNLOAD', True) - monkeypatch.setattr(tesseract, 'DATA_DIR', Path(tmpdir)) + monkeypatch.setattr(tesseract_model, 'download', True) + monkeypatch.setattr(tesseract_model, 'data_dir', Path(tmpdir)) - model = tesseract.VERTICAL_LANGS[0] - tesseract.download_model(model) + model = tesseract_model.VERTICAL_LANGS[0] + tesseract_model.download_model(model) tmpfile_h = tmpdir / f'{model}.traineddata' tmpfile_v = tmpdir / f'{model}_vert.traineddata' assert tmpfile_h.exists() assert tmpfile_v.exists() -def test_create_config(monkeypatch, tmpdir): +def test_create_config(monkeypatch, tmpdir, tesseract_model): """Test the creation of the tesseract config file.""" - monkeypatch.setattr(tesseract, 'CONFIG', False) - monkeypatch.setattr(tesseract, 'DATA_DIR', Path(tmpdir)) + monkeypatch.setattr(tesseract_model, 'config', False) + monkeypatch.setattr(tesseract_model, 'data_dir', Path(tmpdir)) - tesseract.create_config() + tesseract_model.create_config() - assert tesseract.CONFIG is True + 
assert tesseract_model.config is True pth = Path(tmpdir) assert (pth / 'configs').is_dir() assert (pth / 'configs' / 'tsv').is_file() with open(pth / 'configs' / 'tsv', encoding='utf-8') as f: assert f.read() == 'tessedit_create_tsv 1' -def test_create_config_many(monkeypatch, mock_called): +def test_create_config_many(monkeypatch, mock_called, tesseract_model): """Test that the creation of the tesseract config file happens only once.""" - monkeypatch.setattr(tesseract, 'CONFIG', False) + monkeypatch.setattr(tesseract_model, 'config', False) monkeypatch.setattr(tesseract.Path, 'exists', lambda *args, **kwargs: True) monkeypatch.setattr(tesseract.Path, 'mkdir', lambda *args, **kwargs: None) - tesseract.create_config() + tesseract_model.create_config() monkeypatch.setattr(tesseract.Path, 'mkdir', mock_called) - tesseract.create_config() + tesseract_model.create_config() assert not hasattr(mock_called, 'called') -def test_tesseract_pipeline_nomodel(monkeypatch, mock_called, tmpdir): +def test_tesseract_pipeline_nomodel(monkeypatch, mock_called, tmpdir, tesseract_model): """Test the tesseract pipeline.""" mock_result = 'mock_ocr_result' def mock_tesseract(*args, **kwargs): return {'text': mock_result} - monkeypatch.setattr(tesseract, 'CONFIG', True) - monkeypatch.setattr(tesseract, 'DOWNLOAD', True) - monkeypatch.setattr(tesseract, 'DATA_DIR', Path(tmpdir)) + monkeypatch.setattr(tesseract_model, 'config', True) + monkeypatch.setattr(tesseract_model, 'download', True) + monkeypatch.setattr(tesseract_model, 'data_dir', Path(tmpdir)) - monkeypatch.setattr(tesseract, 'download_model', mock_called) + monkeypatch.setattr(tesseract_model, 'download_model', mock_called) monkeypatch.setattr(tesseract, 'image_to_string', mock_tesseract) - res = tesseract.tesseract_pipeline('image', 'lang') + res = tesseract_model._ocr('image', 'lang') # pylint: disable=protected-access assert hasattr(mock_called, 'called') assert res == mock_result assert len(tmpdir.listdir()) == 0 # No config 
should be written and download is mocked -def test_tesseract_pipeline_noconfig(monkeypatch, mock_called, tmpdir): +def test_tesseract_pipeline_noconfig(monkeypatch, mock_called, tmpdir, tesseract_model): """Test the tesseract pipeline.""" mock_result = 'mock_ocr_result' def mock_tesseract(*args, **kwargs): return {'text': mock_result} - monkeypatch.setattr(tesseract, 'CONFIG', False) - monkeypatch.setattr(tesseract, 'DOWNLOAD', True) - monkeypatch.setattr(tesseract, 'DATA_DIR', Path(tmpdir)) + monkeypatch.setattr(tesseract_model, 'config', False) + monkeypatch.setattr(tesseract_model, 'download', True) + monkeypatch.setattr(tesseract_model, 'data_dir', Path(tmpdir)) - monkeypatch.setattr(tesseract, 'create_config', mock_called) - monkeypatch.setattr(tesseract, 'download_model', lambda *args, **kwargs: None) + monkeypatch.setattr(tesseract_model, 'create_config', mock_called) + monkeypatch.setattr(tesseract_model, 'download_model', lambda *args, **kwargs: None) monkeypatch.setattr(tesseract, 'image_to_string', mock_tesseract) - res = tesseract.tesseract_pipeline('image', 'lang') + res = tesseract_model._ocr('image', 'lang') # pylint: disable=protected-access assert hasattr(mock_called, 'called') assert res == mock_result assert len(tmpdir.listdir()) == 0 # No file should be downloaded (lambda mocked) and config is mocked @pytest.mark.parametrize('mock_called', [{'text': 0}], indirect=True) -def test_tesseract_pipeline_psm_horiz(monkeypatch, mock_called): +def test_tesseract_pipeline_psm_horiz(monkeypatch, mock_called, tesseract_model): """Test the tesseract pipeline.""" - monkeypatch.setattr(tesseract, 'create_config', lambda *args, **kwargs: None) - monkeypatch.setattr(tesseract, 'download_model', lambda *args, **kwargs: None) + monkeypatch.setattr(tesseract_model, 'create_config', lambda *args, **kwargs: None) + monkeypatch.setattr(tesseract_model, 'download_model', lambda *args, **kwargs: None) monkeypatch.setattr(tesseract, 'image_to_string', mock_called) - 
tesseract.tesseract_pipeline('image', 'lang') + tesseract_model._ocr('image', 'lang') # pylint: disable=protected-access assert hasattr(mock_called, 'called') assert '--psm 6' in mock_called.kwargs['config'] @pytest.mark.parametrize('mock_called', [{'text': 0}], indirect=True) -def test_tesseract_pipeline_psm_vert(monkeypatch, mock_called): +def test_tesseract_pipeline_psm_vert(monkeypatch, mock_called, tesseract_model): """Test the tesseract pipeline.""" - monkeypatch.setattr(tesseract, 'create_config', lambda *args, **kwargs: None) - monkeypatch.setattr(tesseract, 'download_model', lambda *args, **kwargs: None) + monkeypatch.setattr(tesseract_model, 'create_config', lambda *args, **kwargs: None) + monkeypatch.setattr(tesseract_model, 'download_model', lambda *args, **kwargs: None) monkeypatch.setattr(tesseract, 'image_to_string', mock_called) image = Image.new('RGB', (100, 100)) - lang = tesseract.VERTICAL_LANGS[0] - tesseract.tesseract_pipeline(image, lang) + lang = tesseract_model.VERTICAL_LANGS[0] + tesseract_model._ocr(image, lang) # pylint: disable=protected-access assert hasattr(mock_called, 'called') assert '--psm 5' in mock_called.kwargs['config'] diff --git a/tests/test_init.py b/tests/test_init.py index e388788..62f4237 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -31,7 +31,7 @@ def test_init_most_used_clean(mock_loaders): """Test init_most_used with empty database.""" ocr_tsl.init_most_used() - assert box.BBOX_MODEL_OBJ is None + assert box.BOX_MODEL_OBJ is None assert ocr.OCR_MODEL_OBJ is None assert tsl.TSL_MODEL_OBJ is None assert lang.LANG_SRC is None @@ -40,7 +40,7 @@ def test_init_most_used_clean(mock_loaders): def test_init_most_used_content(mock_loaders, language, box_model, ocr_model, tsl_model): """Test init_most_used with content in the database.""" ocr_tsl.init_most_used() - assert box.BBOX_MODEL_OBJ == box_model + assert box.BOX_MODEL_OBJ == box_model assert ocr.OCR_MODEL_OBJ == ocr_model assert tsl.TSL_MODEL_OBJ == 
tsl_model assert lang.LANG_SRC == language @@ -102,7 +102,7 @@ def test_init_most_used_more_content(mock_loaders, language_dict, image, option_ assert lang.LANG_SRC == lang2 assert lang.LANG_DST == lang3 - assert box.BBOX_MODEL_OBJ == ocr_box_model2 + assert box.BOX_MODEL_OBJ == ocr_box_model2 assert ocr.OCR_MODEL_OBJ == ocr_model2 assert tsl.TSL_MODEL_OBJ == tsl_model1 diff --git a/tests/test_models.py b/tests/test_models.py index 3b6016a..45927e2 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -17,9 +17,11 @@ # Home: https://github.com/Crivella/ocr_translate # ################################################################################### """Tests for the database models.""" +#pylint: disable=protected-access import django import pytest +from PIL.Image import Image as PILImage from ocr_translate import models as m from ocr_translate.messaging import Message @@ -27,82 +29,89 @@ pytestmark = pytest.mark.django_db -def test_add_language(language_dict, language): +def test_add_language(language_dict: dict, language: m.Language): """Test adding a language.""" query = m.Language.objects.filter(**language_dict) assert query.exists() assert str(query.first()) == language_dict['iso1'] -def test_add_language_existing(language_dict, language): +def test_add_language_existing(language_dict: dict, language: m.Language): """Test adding a language.""" with pytest.raises(django.db.utils.IntegrityError): m.Language.objects.create(**language_dict) -def test_add_ocr_box_model(box_model_dict, box_model): +def test_add_ocr_box_model(box_model_dict: dict, box_model: m.OCRBoxModel): """Test adding a new OCRBoxModel""" query = m.OCRBoxModel.objects.filter(**box_model_dict) assert query.exists() assert str(query.first()) == box_model_dict['name'] -def test_add_ocr_model(ocr_model_dict, ocr_model): +def test_add_ocr_model(ocr_model_dict: dict, ocr_model: m.OCRModel): """Test adding a new OCRModel""" query = m.OCRModel.objects.filter(**ocr_model_dict) assert 
query.exists() assert str(query.first()) == ocr_model_dict['name'] -def test_add_tsl_model(tsl_model_dict, tsl_model): +def test_add_tsl_model(tsl_model_dict: dict, tsl_model: m.TSLModel): """Test adding a new TSLModel""" query = m.TSLModel.objects.filter(**tsl_model_dict) assert query.exists() assert str(query.first()) == tsl_model_dict['name'] -def test_add_option_dict(option_dict): +def test_add_option_dict(option_dict: m.OptionDict): """Test adding a new OptionDict""" query = m.OptionDict.objects.filter(options={}) assert query.exists() assert str(query.first()) == str({}) -def test_box_run(image, language, box_model, option_dict, monkeypatch): +def test_box_run( + monkeypatch, image: m.Image, language: m.Language, box_model: m.OCRBoxModel, option_dict: m.OptionDict + ): """Test adding a new BoxRun""" lbrt = (1,2,3,4) def mock_pipeline(*args, **kwargs): return [lbrt] - monkeypatch.setattr(box, 'BBOX_MODEL_OBJ', box_model) - monkeypatch.setattr(box, 'box_pipeline', mock_pipeline) + monkeypatch.setattr(box, 'BOX_MODEL_OBJ', box_model) + box_model._box_detection = mock_pipeline - res = box.box_run(image, language, image=1, options=option_dict) + res = box_model.box_detection(image, language, image=1, options=option_dict) assert isinstance(res, list) assert isinstance(res[0], m.BBox) assert res[0].lbrt == lbrt -def test_box_run_reuse(image, language, box_model, option_dict, monkeypatch): +def test_box_run_reuse( + monkeypatch, image: m.Image, language: m.Language, box_model: m.OCRBoxModel, option_dict: m.OptionDict + ): """Test adding a new BoxRun""" lbrt = (1,2,3,4) def mock_pipeline(*args, **kwargs): return [lbrt] - monkeypatch.setattr(box, 'BBOX_MODEL_OBJ', box_model) - monkeypatch.setattr(box, 'box_pipeline', mock_pipeline) + monkeypatch.setattr(box, 'BOX_MODEL_OBJ', box_model) + box_model._box_detection = mock_pipeline assert m.OCRBoxRun.objects.count() == 0 - box.box_run(image, language, image=1, options=option_dict) + box_model.box_detection(image, language, 
image=1, options=option_dict) assert m.OCRBoxRun.objects.count() == 1 - box.box_run(image, language, image=1, options=option_dict) + box_model.box_detection(image, language, image=1, options=option_dict) assert m.OCRBoxRun.objects.count() == 1 -def test_ocr_run_nooption(bbox, language, ocr_model, option_dict, monkeypatch): +def test_ocr_run_nooption( + monkeypatch, image_pillow: PILImage, + bbox: m.BBox, language: m.Language, ocr_model: m.OCRModel, option_dict: m.OptionDict + ): """Test performin an ocr_run blocking""" text = 'test_text' def mock_ocr(*args, **kwargs): return text monkeypatch.setattr(ocr, 'OCR_MODEL_OBJ', ocr_model) - monkeypatch.setattr(ocr, 'ocr', mock_ocr) + ocr_model._ocr = mock_ocr - gen = ocr.ocr_run(bbox, language, image=1) + gen = ocr_model.ocr(bbox, language, image=image_pillow) res = next(gen) @@ -110,133 +119,169 @@ def mock_ocr(*args, **kwargs): assert res.text == text assert res.from_ocr.first().options.options == {} -def test_ocr_run_noimage(bbox, language, ocr_model, option_dict, monkeypatch): +def test_ocr_run_noimage( + monkeypatch, + bbox: m.BBox, language: m.Language, ocr_model: m.OCRModel, option_dict: m.OptionDict + ): """Test performin an ocr_run blocking""" text = 'test_text' def mock_ocr(*args, **kwargs): return text monkeypatch.setattr(ocr, 'OCR_MODEL_OBJ', ocr_model) - monkeypatch.setattr(ocr, 'ocr', mock_ocr) + ocr_model._ocr = mock_ocr - gen = ocr.ocr_run(bbox, language, options=option_dict) + gen = ocr_model.ocr(bbox, language, options=option_dict) with pytest.raises(ValueError, match=r'^Image is required for OCR$'): next(gen) -def test_ocr_run(bbox, language, ocr_model, option_dict, monkeypatch, mock_called): +def test_ocr_run( + monkeypatch, mock_called, image_pillow: PILImage, + bbox: m.BBox, language: m.Language, ocr_model: m.OCRModel, option_dict: m.OptionDict + ): """Test performin an ocr_run blocking + same pipeline (has to run lazily by refetching previous result)""" text = 'test_text' def mock_ocr(*args, 
**kwargs): return text monkeypatch.setattr(ocr, 'OCR_MODEL_OBJ', ocr_model) - monkeypatch.setattr(ocr, 'ocr', mock_ocr) + ocr_model._ocr = mock_ocr - gen = ocr.ocr_run(bbox, language, image=1, options=option_dict) + gen = ocr_model.ocr(bbox, language, image=image_pillow, options=option_dict) res = next(gen) assert isinstance(res, m.Text) assert res.text == text - monkeypatch.setattr(ocr, 'ocr', mock_called) # Should not be called as it should be lazy - gen_lazy = ocr.ocr_run(bbox, language, image=1, options=option_dict) + ocr_model._ocr = mock_called # Should not be called as it should be lazy + gen_lazy = ocr_model.ocr(bbox, language, image=image_pillow, options=option_dict) assert not hasattr(mock_called, 'called') assert next(gen_lazy) == res -def test_ocr_run_nonblock(bbox, language, ocr_model, option_dict, monkeypatch, mock_called): +def test_ocr_run_nonblock( + monkeypatch, mock_called, image_pillow: PILImage, + bbox: m.BBox, language: m.Language, ocr_model: m.OCRModel, option_dict: m.OptionDict + ): """Test performin an ocr_run non-blocking + same pipeline (has to run lazily by refetching previous result)""" text = 'test_text' def mock_ocr(*args, **kwargs): - def _handler(text): - return text - return Message(id_=0, msg={'args':(text,)}, handler=_handler) + return text monkeypatch.setattr(ocr, 'OCR_MODEL_OBJ', ocr_model) - monkeypatch.setattr(ocr, 'ocr', mock_ocr) + ocr_model._ocr = mock_ocr - gen = ocr.ocr_run(bbox, language, image=1, options=option_dict, block=False) + gen = ocr_model.ocr(bbox, language, image=image_pillow, options=option_dict, block=False) msg = next(gen) - msg.resolve() + # msg.resolve() res = next(gen) assert isinstance(msg, Message) assert isinstance(res, m.Text) assert res.text == text - monkeypatch.setattr(ocr, 'ocr', mock_called) # Should not be called as it should be lazy - gen_lazy = ocr.ocr_run(bbox, language, image=1, options=option_dict, block=False) + ocr_model._ocr = mock_called # Should not be called as it should be lazy + 
gen_lazy = ocr_model.ocr(bbox, language, image=image_pillow, options=option_dict, block=False) assert not hasattr(mock_called, 'called') assert next(gen_lazy) is None assert next(gen_lazy) == res -def test_tsl_run(text, language, tsl_model, option_dict, monkeypatch, mock_called): +def test_tsl_pre_tokenize(data_regression, string: str): + """Test the TSLModel.pre_tokenize method with various option combinations (regression check).""" + options = [ + {}, + {'break_newlines': True}, + {'break_newlines': False}, + {'break_chars': '?.!'}, + {'ignore_chars': '?.!'}, + {'break_newlines': False, 'break_chars': '?.!'}, + {'break_newlines': False, 'ignore_chars': '?.!'}, + {'restore_dash_newlines': True}, + ] + + res = [] + for option in options: + dct = { + 'string': string, + 'options': option, + 'tokens': m.TSLModel.pre_tokenize(string, **option) + } + res.append(dct) + + data_regression.check({'res': res}) + +def test_tsl_run( + monkeypatch, mock_called, + text: m.Text, language: m.Language, tsl_model: m.TSLModel, option_dict: m.OptionDict + ): """Test performin an tsl_run blocking""" def mock_tsl_pipeline(*args, **kwargs): return text.text monkeypatch.setattr(tsl, 'TSL_MODEL_OBJ', tsl_model) - monkeypatch.setattr(tsl, 'tsl_pipeline', mock_tsl_pipeline) + tsl_model._translate = mock_tsl_pipeline - gen = tsl.tsl_run(text, src=language, dst=language, options=option_dict) + gen = tsl_model.translate(text, src=language, dst=language, options=option_dict) res = next(gen) assert isinstance(res, m.Text) assert res.text == text.text - monkeypatch.setattr(tsl, 'tsl_pipeline', mock_called) # Should not be called as it should be lazy - gen_lazy = tsl.tsl_run(text, src=language, dst=language, options=option_dict) + tsl_model._translate = mock_called # Should not be called as it should be lazy + gen_lazy = tsl_model.translate(text, src=language, dst=language, options=option_dict) assert not hasattr(mock_called, 'called') assert next(gen_lazy) == res -def test_tsl_run_nonblock(text, language, tsl_model, option_dict, monkeypatch, mock_called): +def 
test_tsl_run_nonblock( + monkeypatch, mock_called, + text: m.Text, language: m.Language, tsl_model: m.TSLModel, option_dict: m.OptionDict + ): """Test performin an tsl_run non-blocking""" def mock_tsl_pipeline(*args, **kwargs): - def _handler(text): - return text - return Message(id_=0, msg={'args':(text.text,)}, handler=_handler) + return text.text monkeypatch.setattr(tsl, 'TSL_MODEL_OBJ', tsl_model) - monkeypatch.setattr(tsl, 'tsl_pipeline', mock_tsl_pipeline) + tsl_model._translate = mock_tsl_pipeline - gen = tsl.tsl_run(text, src=language, dst=language, options=option_dict, block=False) + gen = tsl_model.translate(text, src=language, dst=language, options=option_dict, block=False) msg = next(gen) - msg.resolve() + # msg.resolve() res = next(gen) assert isinstance(msg, Message) assert isinstance(res, m.Text) assert res.text == text.text - monkeypatch.setattr(tsl, 'tsl_pipeline', mock_called) # Should not be called as it should be lazy - gen_lazy = tsl.tsl_run(text, src=language, dst=language, options=option_dict, block=False) + tsl_model._translate = mock_called # Should not be called as it should be lazy + gen_lazy = tsl_model.translate(text, src=language, dst=language, options=option_dict, block=False) assert not hasattr(mock_called, 'called') assert next(gen_lazy) is None assert next(gen_lazy) == res -def test_tsl_run_lazy(text, language, option_dict): +def test_tsl_run_lazy(text: m.Text, language: m.Language, tsl_model: m.TSLModel, option_dict: m.OptionDict): """Test tsl pipeline with worker""" + # Force and lazy should not be used together with pytest.raises(ValueError): - gen = tsl.tsl_run(text, language, language, force=True, lazy=True) + gen = tsl_model.translate(text, language, language, force=True, lazy=True) next(gen) # Nothing in the DB, so should rise ValueError (no previous found and lazy=True) with pytest.raises(ValueError): - gen = tsl.tsl_run(text, language, language, lazy=True) + gen = tsl_model.translate(text, language, language, lazy=True) 
next(gen) def test_ocr_tsl_work_plus_lazy( - image, image_pillow, text, bbox, - language, ocr_model, tsl_model, option_dict, - monkeypatch + monkeypatch, image_pillow: PILImage, + image: m.Image, text: m.Text, bbox: m.BBox, language: m.Language, + box_model: m.OCRBoxModel, ocr_model: m.OCRModel, tsl_model: m.TSLModel, option_dict: m.OptionDict ): """Test performin an ocr_tsl_run non-lazy""" def mock_box_run(*args, **kwargs): @@ -252,9 +297,13 @@ def mock_tsl_run(obj, *args, block=True, **kwargs): res, _ = m.Text.objects.get_or_create(text = obj.text + '_translated') yield res - monkeypatch.setattr(full, 'box_run', mock_box_run) - monkeypatch.setattr(full, 'ocr_run', mock_ocr_run) - monkeypatch.setattr(full, 'tsl_run', mock_tsl_run) + box_model.box_detection = mock_box_run + ocr_model.ocr = mock_ocr_run + tsl_model.translate = mock_tsl_run + + monkeypatch.setattr(box, 'BOX_MODEL_OBJ', box_model) + monkeypatch.setattr(ocr, 'OCR_MODEL_OBJ', ocr_model) + monkeypatch.setattr(tsl, 'TSL_MODEL_OBJ', tsl_model) res = full.ocr_tsl_pipeline_work(image_pillow, image.md5) @@ -274,7 +323,14 @@ def test_ocr_tsl_lazy(): with pytest.raises(ValueError, match=r'^Image with md5 .* does not exist$'): full.ocr_tsl_pipeline_lazy('') -def test_ocr_tsl_lazy_image(image, option_dict): +def test_ocr_tsl_lazy_image( + monkeypatch, image: m.Image, + box_model: m.OCRBoxModel, ocr_model: m.OCRModel, tsl_model: m.TSLModel, option_dict: m.OptionDict + ): """Test performin an ocr_tsl_run lazy (with image but missing ocr-tsl steps)""" + monkeypatch.setattr(box, 'BOX_MODEL_OBJ', box_model) + monkeypatch.setattr(ocr, 'OCR_MODEL_OBJ', ocr_model) + monkeypatch.setattr(tsl, 'TSL_MODEL_OBJ', tsl_model) + with pytest.raises(ValueError): full.ocr_tsl_pipeline_lazy(image.md5) diff --git a/tests/ocr_tsl/test_tsl/test_pre_tokenize_breakchar_.yml b/tests/test_models/test_tsl_pre_tokenize_breakchar_.yml similarity index 100% rename from tests/ocr_tsl/test_tsl/test_pre_tokenize_breakchar_.yml rename to 
tests/test_models/test_tsl_pre_tokenize_breakchar_.yml diff --git a/tests/ocr_tsl/test_tsl/test_pre_tokenize_breakchar_newline_.yml b/tests/test_models/test_tsl_pre_tokenize_breakchar_newline_.yml similarity index 100% rename from tests/ocr_tsl/test_tsl/test_pre_tokenize_breakchar_newline_.yml rename to tests/test_models/test_tsl_pre_tokenize_breakchar_newline_.yml diff --git a/tests/ocr_tsl/test_tsl/test_pre_tokenize_dash_newline_.yml b/tests/test_models/test_tsl_pre_tokenize_dash_newline_.yml similarity index 100% rename from tests/ocr_tsl/test_tsl/test_pre_tokenize_dash_newline_.yml rename to tests/test_models/test_tsl_pre_tokenize_dash_newline_.yml diff --git a/tests/ocr_tsl/test_tsl/test_pre_tokenize_newline_.yml b/tests/test_models/test_tsl_pre_tokenize_newline_.yml similarity index 100% rename from tests/ocr_tsl/test_tsl/test_pre_tokenize_newline_.yml rename to tests/test_models/test_tsl_pre_tokenize_newline_.yml diff --git a/tests/ocr_tsl/test_tsl/test_pre_tokenize_newlines_.yml b/tests/test_models/test_tsl_pre_tokenize_newlines_.yml similarity index 100% rename from tests/ocr_tsl/test_tsl/test_pre_tokenize_newlines_.yml rename to tests/test_models/test_tsl_pre_tokenize_newlines_.yml diff --git a/tests/ocr_tsl/test_tsl/test_pre_tokenize_simple_.yml b/tests/test_models/test_tsl_pre_tokenize_simple_.yml similarity index 100% rename from tests/ocr_tsl/test_tsl/test_pre_tokenize_simple_.yml rename to tests/test_models/test_tsl_pre_tokenize_simple_.yml diff --git a/tests/views/test_handshake.py b/tests/views/test_handshake.py index e3a4fc3..2665da9 100644 --- a/tests/views/test_handshake.py +++ b/tests/views/test_handshake.py @@ -72,7 +72,7 @@ def test_handshake_clean_initialized(client, monkeypatch, language, box_model, t """Test handshake with content + init.""" monkeypatch.setattr(lang, 'LANG_SRC', language) monkeypatch.setattr(lang, 'LANG_DST', language) - monkeypatch.setattr(box, 'BBOX_MODEL_OBJ', box_model) + monkeypatch.setattr(box, 'BOX_MODEL_OBJ', 
box_model) monkeypatch.setattr(ocr, 'OCR_MODEL_OBJ', ocr_model) monkeypatch.setattr(tsl, 'TSL_MODEL_OBJ', tsl_model) diff --git a/tests/views/test_run_tsl.py b/tests/views/test_run_tsl.py index cbacca2..8d3a2cd 100644 --- a/tests/views/test_run_tsl.py +++ b/tests/views/test_run_tsl.py @@ -23,7 +23,7 @@ import pytest from django.urls import reverse -from ocr_translate import views +from ocr_translate.ocr_tsl import tsl pytestmark = pytest.mark.django_db @@ -67,12 +67,13 @@ def test_run_tsl_post_invalid_data(client, post_kwargs): assert response.status_code == 400 -def test_run_tsl_post_valid(client, monkeypatch, post_kwargs): +def test_run_tsl_post_valid(client, monkeypatch, post_kwargs, tsl_model): """Test run_tsl with POST request with valid data.""" def mock_tsl_run(text, *args, **kwargs): """Mock translate.""" yield text - monkeypatch.setattr(views, 'tsl_run', mock_tsl_run) + monkeypatch.setattr(tsl, 'TSL_MODEL_OBJ', tsl_model) + tsl_model.translate = mock_tsl_run url = reverse('ocr_translate:run_tsl') response = client.post(url, **post_kwargs) diff --git a/tests/views/test_set_lang.py b/tests/views/test_set_lang.py index 648a2eb..1d614d9 100644 --- a/tests/views/test_set_lang.py +++ b/tests/views/test_set_lang.py @@ -173,7 +173,7 @@ def test_set_lang_post_no_unload_box( """Test set_lang with POST valid request. Switching lang_src has to cause unloading of box model if the model does not support that language.""" monkeypatch.setattr(lang, 'LANG_SRC', language) - monkeypatch.setattr(box, 'BBOX_MODEL_OBJ', box_model) + monkeypatch.setattr(box, 'BOX_MODEL_OBJ', box_model) monkeypatch.setattr(views, 'unload_box_model', mock_called) box_model.languages.add(language2) @@ -193,7 +193,7 @@ def test_set_lang_post_unload_box( """Test set_lang with POST valid request. 
Switching lang_src has to cause unloading of box model if the model does not support that language.""" monkeypatch.setattr(lang, 'LANG_SRC', language) - monkeypatch.setattr(box, 'BBOX_MODEL_OBJ', box_model) + monkeypatch.setattr(box, 'BOX_MODEL_OBJ', box_model) monkeypatch.setattr(views, 'unload_box_model', mock_called) post_kwargs['data']['lang_src'] = language2.iso1