diff --git a/src/layoutparser/models/__init__.py b/src/layoutparser/models/__init__.py index d7f1d60..6f07193 100644 --- a/src/layoutparser/models/__init__.py +++ b/src/layoutparser/models/__init__.py @@ -1,4 +1 @@ -from . import catalog as _UNUSED -# A trick learned from -# https://github.com/facebookresearch/detectron2/blob/62cf3a2b6840734d2717abdf96e2dd57ed6612a6/detectron2/checkpoint/__init__.py#L6 -from .layoutmodel import Detectron2LayoutModel +from .detectron2.layoutmodel import Detectron2LayoutModel \ No newline at end of file diff --git a/src/layoutparser/models/base_catalog.py b/src/layoutparser/models/base_catalog.py new file mode 100644 index 0000000..3eb6ba5 --- /dev/null +++ b/src/layoutparser/models/base_catalog.py @@ -0,0 +1,20 @@ +from iopath.common.file_io import HTTPURLHandler +from iopath.common.file_io import PathManager as PathManagerBase + +# A trick learned from https://github.com/facebookresearch/detectron2/blob/65faeb4779e4c142484deeece18dc958c5c9ad18/detectron2/utils/file_io.py#L3 + + +class DropboxHandler(HTTPURLHandler): + """ + Supports download and file check for dropbox links + """ + + def _get_supported_prefixes(self): + return ["https://www.dropbox.com"] + + def _isfile(self, path): + return path in self.cache_map + + +PathManager = PathManagerBase() +PathManager.register_handler(DropboxHandler()) \ No newline at end of file diff --git a/src/layoutparser/models/base_layoutmodel.py b/src/layoutparser/models/base_layoutmodel.py new file mode 100644 index 0000000..33c45fa --- /dev/null +++ b/src/layoutparser/models/base_layoutmodel.py @@ -0,0 +1,50 @@ +from abc import ABC, abstractmethod +import os +import importlib + + +class BaseLayoutModel(ABC): + + @property + @abstractmethod + def DETECTOR_NAME(self): + pass + + @abstractmethod + def detect(self): + pass + + # Add lazy loading mechanisms for layout models, refer to + # layoutparser.ocr.BaseOCRAgent + # TODO: Build a metaclass for lazy module loader + @property + @abstractmethod + def DEPENDENCIES(self): + """DEPENDENCIES lists all necessary dependencies for the class.""" + pass + + @property + @abstractmethod + def MODULES(self): + """MODULES instructs how to import these necessary libraries.""" + pass + + @classmethod + def _import_module(cls): + for m in cls.MODULES: + if importlib.util.find_spec(m["module_path"]): + setattr( + cls, m["import_name"], importlib.import_module(m["module_path"]) + ) + else: + raise ModuleNotFoundError( + f"\n " + f"\nPlease install the following libraries to support the class {cls.__name__}:" + f"\n pip install {' '.join(cls.DEPENDENCIES)}" + f"\n " + ) + + def __new__(cls, *args, **kwargs): + + cls._import_module() + return super().__new__(cls) \ No newline at end of file diff --git a/src/layoutparser/models/detectron2/__init__.py b/src/layoutparser/models/detectron2/__init__.py new file mode 100644 index 0000000..d7f1d60 --- /dev/null +++ b/src/layoutparser/models/detectron2/__init__.py @@ -0,0 +1,4 @@ +from . import catalog as _UNUSED +# A trick learned from +# https://github.com/facebookresearch/detectron2/blob/62cf3a2b6840734d2717abdf96e2dd57ed6612a6/detectron2/checkpoint/__init__.py#L6 +from .layoutmodel import Detectron2LayoutModel diff --git a/src/layoutparser/models/catalog.py b/src/layoutparser/models/detectron2/catalog.py similarity index 81% rename from src/layoutparser/models/catalog.py rename to src/layoutparser/models/detectron2/catalog.py index 80dd317..693f831 100644 --- a/src/layoutparser/models/catalog.py +++ b/src/layoutparser/models/detectron2/catalog.py @@ -1,7 +1,6 @@ -from iopath.common.file_io import PathHandler, PathManager, HTTPURLHandler -from iopath.common.file_io import PathManager as PathManagerBase +from iopath.common.file_io import PathHandler -# A trick learned from https://github.com/facebookresearch/detectron2/blob/65faeb4779e4c142484deeece18dc958c5c9ad18/detectron2/utils/file_io.py#L3 +from ..base_catalog import PathManager MODEL_CATALOG = { "HJDataset": { @@ -49,6 +48,7 @@ }, } +# fmt: off LABEL_MAP_CATALOG = { "HJDataset": { 1: "Page Frame", @@ -59,7 +59,12 @@ 6: "Subtitle", 7: "Other", }, - "PubLayNet": {0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"}, + "PubLayNet": { + 0: "Text", + 1: "Title", + 2: "List", + 3: "Table", + 4: "Figure"}, "PrimaLayout": { 1: "TextRegion", 2: "ImageRegion", @@ -77,34 +82,26 @@ 5: "Headline", 6: "Advertisement", }, - "TableBank": {0: "Table"}, + "TableBank": { + 0: "Table" + }, } +# fmt: on -class DropboxHandler(HTTPURLHandler): - """ - Supports download and file check for dropbox links - """ - - def _get_supported_prefixes(self): - return ["https://www.dropbox.com"] - - def _isfile(self, path): - return path in self.cache_map - - -class LayoutParserHandler(PathHandler): +class LayoutParserDetectron2ModelHandler(PathHandler): """ Resolve anything that's in LayoutParser model zoo. """ - PREFIX = "lp://" + PREFIX = "lp://detectron2/" def _get_supported_prefixes(self): return [self.PREFIX] def _get_local_path(self, path, **kwargs): model_name = path[len(self.PREFIX) :] + dataset_name, *model_name, data_type = model_name.split("/") if data_type == "weight": @@ -119,6 +116,4 @@ def _open(self, path, mode="r", **kwargs): return PathManager.open(self._get_local_path(path), mode, **kwargs) -PathManager = PathManagerBase() -PathManager.register_handler(DropboxHandler()) -PathManager.register_handler(LayoutParserHandler()) +PathManager.register_handler(LayoutParserDetectron2ModelHandler()) diff --git a/src/layoutparser/models/layoutmodel.py b/src/layoutparser/models/detectron2/layoutmodel.py similarity index 72% rename from src/layoutparser/models/layoutmodel.py rename to src/layoutparser/models/detectron2/layoutmodel.py index 5544653..287c1e2 100644 --- a/src/layoutparser/models/layoutmodel.py +++ b/src/layoutparser/models/detectron2/layoutmodel.py @@ -1,58 +1,14 @@ -from abc import ABC, abstractmethod -import os -import importlib - from PIL import Image import numpy as np import torch from .catalog import PathManager, LABEL_MAP_CATALOG -from ..elements import * +from ..base_layoutmodel import BaseLayoutModel +from ...elements import Rectangle, TextBlock, Layout __all__ = ["Detectron2LayoutModel"] -class BaseLayoutModel(ABC): - @abstractmethod - def detect(self): - pass - - # Add lazy loading mechanisms for layout models, refer to - # layoutparser.ocr.BaseOCRAgent - # TODO: Build a metaclass for lazy module loader - @property - @abstractmethod - def DEPENDENCIES(self): - """DEPENDENCIES lists all necessary dependencies for the class.""" - pass - - @property - @abstractmethod - def MODULES(self): - """MODULES instructs how to import these necessary libraries.""" - pass - - @classmethod - def _import_module(cls): - for m in cls.MODULES: - if importlib.util.find_spec(m["module_path"]): - setattr( - cls, m["import_name"], importlib.import_module(m["module_path"]) - ) - else: - raise ModuleNotFoundError( - f"\n " - f"\nPlease install the following libraries to support the class {cls.__name__}:" - f"\n pip install {' '.join(cls.DEPENDENCIES)}" - f"\n " - ) - - def __new__(cls, *args, **kwargs): - - cls._import_module() - return super().__new__(cls) - - class Detectron2LayoutModel(BaseLayoutModel): """Create a Detectron2-based Layout Detection Model @@ -93,6 +49,7 @@ class Detectron2LayoutModel(BaseLayoutModel): }, {"import_name": "_config", "module_path": "detectron2.config"}, ] + DETECTOR_NAME = "detectron2" def __init__( self, @@ -111,11 +68,13 @@ def __init__( extra_config.extend(["MODEL.DEVICE", "cpu"]) cfg = self._config.get_cfg() + config_path = self._reconstruct_path_with_detector_name(config_path) config_path = PathManager.get_local_path(config_path) cfg.merge_from_file(config_path) cfg.merge_from_list(extra_config) if model_path is not None: + model_path = self._reconstruct_path_with_detector_name(model_path) cfg.MODEL.WEIGHTS = model_path cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu" self.cfg = cfg @@ -123,6 +82,33 @@ def __init__( self.label_map = label_map self._create_model() + def _reconstruct_path_with_detector_name(self, path: str) -> str: + """This function will add the detector name (detectron2) into the + lp model config path to get the "canonical" model name. + + For example, for a given config_path `lp://HJDataset/faster_rcnn_R_50_FPN_3x/config`, + it will transform it into `lp://detectron2/HJDataset/faster_rcnn_R_50_FPN_3x/config`. + However, if the config_path already contains the detector name, we won't change it. + + This function is a general step to support multiple backends in the layout-parser + library. + + Args: + path (str): The given input path that might or might not contain the detector name. + + Returns: + str: a modified path that contains the detector name. + """ + if path.startswith("lp://"): # TODO: Move "lp://" to a constant + model_name = path[len("lp://") :] + model_name_segments = model_name.split("/") + if ( + len(model_name_segments) == 3 + and "detectron2" not in model_name_segments + ): + return "lp://" + self.DETECTOR_NAME + "/" + path[len("lp://") :] + return path + def gather_output(self, outputs): instance_pred = outputs["instances"].to("cpu")