Skip to content

Improve models structure #53

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions src/layoutparser/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1 @@
from . import catalog as _UNUSED
# A trick learned from
# https://github.com/facebookresearch/detectron2/blob/62cf3a2b6840734d2717abdf96e2dd57ed6612a6/detectron2/checkpoint/__init__.py#L6
from .layoutmodel import Detectron2LayoutModel
from .detectron2.layoutmodel import Detectron2LayoutModel
20 changes: 20 additions & 0 deletions src/layoutparser/models/base_catalog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from iopath.common.file_io import HTTPURLHandler
from iopath.common.file_io import PathManager as PathManagerBase

# A trick learned from https://github.com/facebookresearch/detectron2/blob/65faeb4779e4c142484deeece18dc958c5c9ad18/detectron2/utils/file_io.py#L3


class DropboxHandler(HTTPURLHandler):
"""
Supports download and file check for dropbox links
"""

def _get_supported_prefixes(self):
return ["https://www.dropbox.com"]

def _isfile(self, path):
return path in self.cache_map


PathManager = PathManagerBase()
PathManager.register_handler(DropboxHandler())
50 changes: 50 additions & 0 deletions src/layoutparser/models/base_layoutmodel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from abc import ABC, abstractmethod
import os
import importlib


class BaseLayoutModel(ABC):

@property
@abstractmethod
def DETECTOR_NAME(self):
pass

@abstractmethod
def detect(self):
pass

# Add lazy loading mechanisms for layout models, refer to
# layoutparser.ocr.BaseOCRAgent
# TODO: Build a metaclass for lazy module loader
@property
@abstractmethod
def DEPENDENCIES(self):
"""DEPENDENCIES lists all necessary dependencies for the class."""
pass

@property
@abstractmethod
def MODULES(self):
"""MODULES instructs how to import these necessary libraries."""
pass

@classmethod
def _import_module(cls):
for m in cls.MODULES:
if importlib.util.find_spec(m["module_path"]):
setattr(
cls, m["import_name"], importlib.import_module(m["module_path"])
)
else:
raise ModuleNotFoundError(
f"\n "
f"\nPlease install the following libraries to support the class {cls.__name__}:"
f"\n pip install {' '.join(cls.DEPENDENCIES)}"
f"\n "
)

def __new__(cls, *args, **kwargs):

cls._import_module()
return super().__new__(cls)
4 changes: 4 additions & 0 deletions src/layoutparser/models/detectron2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from . import catalog as _UNUSED
# A trick learned from
# https://github.com/facebookresearch/detectron2/blob/62cf3a2b6840734d2717abdf96e2dd57ed6612a6/detectron2/checkpoint/__init__.py#L6
from .layoutmodel import Detectron2LayoutModel
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from iopath.common.file_io import PathHandler, PathManager, HTTPURLHandler
from iopath.common.file_io import PathManager as PathManagerBase
from iopath.common.file_io import PathHandler

# A trick learned from https://github.com/facebookresearch/detectron2/blob/65faeb4779e4c142484deeece18dc958c5c9ad18/detectron2/utils/file_io.py#L3
from ..base_catalog import PathManager

MODEL_CATALOG = {
"HJDataset": {
Expand Down Expand Up @@ -49,6 +48,7 @@
},
}

# fmt: off
LABEL_MAP_CATALOG = {
"HJDataset": {
1: "Page Frame",
Expand All @@ -59,7 +59,12 @@
6: "Subtitle",
7: "Other",
},
"PubLayNet": {0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"},
"PubLayNet": {
0: "Text",
1: "Title",
2: "List",
3: "Table",
4: "Figure"},
"PrimaLayout": {
1: "TextRegion",
2: "ImageRegion",
Expand All @@ -77,34 +82,26 @@
5: "Headline",
6: "Advertisement",
},
"TableBank": {0: "Table"},
"TableBank": {
0: "Table"
},
}
# fmt: on


class DropboxHandler(HTTPURLHandler):
"""
Supports download and file check for dropbox links
"""

def _get_supported_prefixes(self):
return ["https://www.dropbox.com"]

def _isfile(self, path):
return path in self.cache_map


class LayoutParserHandler(PathHandler):
class LayoutParserDetectron2ModelHandler(PathHandler):
"""
Resolve anything that's in LayoutParser model zoo.
"""

PREFIX = "lp://"
PREFIX = "lp://detectron2/"

def _get_supported_prefixes(self):
return [self.PREFIX]

def _get_local_path(self, path, **kwargs):
model_name = path[len(self.PREFIX) :]

dataset_name, *model_name, data_type = model_name.split("/")

if data_type == "weight":
Expand All @@ -119,6 +116,4 @@ def _open(self, path, mode="r", **kwargs):
return PathManager.open(self._get_local_path(path), mode, **kwargs)


PathManager = PathManagerBase()
PathManager.register_handler(DropboxHandler())
PathManager.register_handler(LayoutParserHandler())
PathManager.register_handler(LayoutParserDetectron2ModelHandler())
Original file line number Diff line number Diff line change
@@ -1,58 +1,14 @@
from abc import ABC, abstractmethod
import os
import importlib

from PIL import Image
import numpy as np
import torch

from .catalog import PathManager, LABEL_MAP_CATALOG
from ..elements import *
from ..base_layoutmodel import BaseLayoutModel
from ...elements import Rectangle, TextBlock, Layout

__all__ = ["Detectron2LayoutModel"]


class BaseLayoutModel(ABC):
@abstractmethod
def detect(self):
pass

# Add lazy loading mechanisms for layout models, refer to
# layoutparser.ocr.BaseOCRAgent
# TODO: Build a metaclass for lazy module loader
@property
@abstractmethod
def DEPENDENCIES(self):
"""DEPENDENCIES lists all necessary dependencies for the class."""
pass

@property
@abstractmethod
def MODULES(self):
"""MODULES instructs how to import these necessary libraries."""
pass

@classmethod
def _import_module(cls):
for m in cls.MODULES:
if importlib.util.find_spec(m["module_path"]):
setattr(
cls, m["import_name"], importlib.import_module(m["module_path"])
)
else:
raise ModuleNotFoundError(
f"\n "
f"\nPlease install the following libraries to support the class {cls.__name__}:"
f"\n pip install {' '.join(cls.DEPENDENCIES)}"
f"\n "
)

def __new__(cls, *args, **kwargs):

cls._import_module()
return super().__new__(cls)


class Detectron2LayoutModel(BaseLayoutModel):
"""Create a Detectron2-based Layout Detection Model

Expand Down Expand Up @@ -93,6 +49,7 @@ class Detectron2LayoutModel(BaseLayoutModel):
},
{"import_name": "_config", "module_path": "detectron2.config"},
]
DETECTOR_NAME = "detectron2"

def __init__(
self,
Expand All @@ -111,18 +68,47 @@ def __init__(
extra_config.extend(["MODEL.DEVICE", "cpu"])

cfg = self._config.get_cfg()
config_path = self._reconstruct_path_with_detector_name(config_path)
config_path = PathManager.get_local_path(config_path)
cfg.merge_from_file(config_path)
cfg.merge_from_list(extra_config)

if model_path is not None:
model_path = self._reconstruct_path_with_detector_name(model_path)
cfg.MODEL.WEIGHTS = model_path
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
self.cfg = cfg

self.label_map = label_map
self._create_model()

def _reconstruct_path_with_detector_name(self, path: str) -> str:
"""This function will add the detector name (detectron2) into the
lp model config path to get the "canonical" model name.

For example, for a given config_path `lp://HJDataset/faster_rcnn_R_50_FPN_3x/config`,
it will transform it into `lp://detectron2/HJDataset/faster_rcnn_R_50_FPN_3x/config`.
However, if the config_path already contains the detector name, we won't change it.

This function is a general step to support multiple backends in the layout-parser
library.

Args:
path (str): The given input path that might or might not contain the detector name.

Returns:
str: a modified path that contains the detector name.
"""
if path.startswith("lp://"): # TODO: Move "lp://" to a constant
model_name = path[len("lp://") :]
model_name_segments = model_name.split("/")
if (
len(model_name_segments) == 3
and "detectron2" not in model_name_segments
):
return "lp://" + self.DETECTOR_NAME + "/" + path[len("lp://") :]
return path

def gather_output(self, outputs):

instance_pred = outputs["instances"].to("cpu")
Expand Down