Skip to content

Commit 3bb81e5

Browse files
xuzhao9facebook-github-bot
authored andcommitted
Enable package import from torchbenchmark and torchbenchmark.models. (#2200)
Summary: More intuitive module import when installed as a package. This PR adds LazyImport to all torchbenchmark models so that they will be loaded only when being used. Pull Request resolved: #2200 Test Plan: Run the following commands: ``` python install.py pip install -e . $ python -c "import torchbenchmark.models; print(torchbenchmark.models.densenet121.Model)" <class 'torchbenchmark.models.densenet121.Model'> $ python -c "import torchbenchmark; print(torchbenchmark.models.densenet121.Model)" <class 'torchbenchmark.models.densenet121.Model'> ``` Fixes #2193 Reviewed By: drisspg Differential Revision: D55024068 Pulled By: xuzhao9 fbshipit-source-id: c9572308d202afaf8e528adb6585b5db80936cbd
1 parent 3b76cd4 commit 3bb81e5

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

torchbenchmark/models/__init__.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,30 @@
1+
import os
2+
import sys
3+
import importlib
4+
import importlib.util
5+
from pathlib import Path
6+
7+
def _list_models_without_import():
8+
def _is_non_empty(dirpath):
9+
init_file_path = dirpath.joinpath("__init__.py")
10+
return init_file_path.exists() and init_file_path.stat().st_size > 0
11+
current_dir = Path(__file__).parent
12+
subdirs = [entry for entry in current_dir.iterdir() if entry.is_dir()]
13+
non_empty_subdirs = list(map(lambda x: x.name, filter(_is_non_empty, subdirs)))
14+
return non_empty_subdirs
15+
16+
17+
class LazyImport:
18+
def __init__(self, module_name: str):
19+
self.module_name = module_name
20+
self._module = None
21+
22+
def __getattr__(self, attr: str):
23+
if self._module is None:
24+
self._module = importlib.import_module(self.module_name, package=__name__)
25+
return getattr(self._module, attr)
26+
27+
28+
for _model_name in _list_models_without_import():
29+
globals()[_model_name] = LazyImport(f".{_model_name}")
130

0 commit comments

Comments
 (0)