Skip to content

Commit

Permalink
Enable package import from torchbenchmark and torchbenchmark.models. (#…
Browse files Browse the repository at this point in the history
…2200)

Summary:
More intuitive module import when installed as a package.
This PR adds LazyImport to all torchbenchmark models so that they will be loaded only when being used.

Pull Request resolved: #2200

Test Plan:
Run the following commands:
```
python install.py
pip install -e .
$ python -c "import torchbenchmark.models; print(torchbenchmark.models.densenet121.Model)"
<class 'torchbenchmark.models.densenet121.Model'>
$ python -c "import torchbenchmark; print(torchbenchmark.models.densenet121.Model)"
<class 'torchbenchmark.models.densenet121.Model'>
```

Fixes #2193

Reviewed By: drisspg

Differential Revision: D55024068

Pulled By: xuzhao9

fbshipit-source-id: c9572308d202afaf8e528adb6585b5db80936cbd
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Mar 18, 2024
1 parent 3b76cd4 commit 3bb81e5
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions torchbenchmark/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,30 @@
import os
import sys
import importlib
import importlib.util
from pathlib import Path

def _list_models_without_import():
def _is_non_empty(dirpath):
init_file_path = dirpath.joinpath("__init__.py")
return init_file_path.exists() and init_file_path.stat().st_size > 0
current_dir = Path(__file__).parent
subdirs = [entry for entry in current_dir.iterdir() if entry.is_dir()]
non_empty_subdirs = list(map(lambda x: x.name, filter(_is_non_empty, subdirs)))
return non_empty_subdirs


class LazyImport:
def __init__(self, module_name: str):
self.module_name = module_name
self._module = None

def __getattr__(self, attr: str):
if self._module is None:
self._module = importlib.import_module(self.module_name, package=__name__)
return getattr(self._module, attr)


for _model_name in _list_models_without_import():
globals()[_model_name] = LazyImport(f".{_model_name}")

0 comments on commit 3bb81e5

Please sign in to comment.