Skip to content

Commit

Permalink
[fbsync] Add filter parameters to list_models() (#7718)
Browse files Browse the repository at this point in the history
Reviewed By: matteobettini

Differential Revision: D48642263

fbshipit-source-id: 7dd986c91115b47383dfa69af070626a85b8bf07

Co-authored-by: Mateusz Guzek <matguzek@meta.com>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
  • Loading branch information
4 people authored and facebook-github-bot committed Aug 25, 2023
1 parent de45a8b commit 20d6ed1
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 11 deletions.
74 changes: 67 additions & 7 deletions test/test_extended_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,24 +103,84 @@ def test_weights_deserializable(name):
assert pickle.loads(pickle.dumps(weights)) is weights


def get_models_from_module(module):
return [
v.__name__
for k, v in module.__dict__.items()
if callable(v) and k[0].islower() and k[0] != "_" and k not in models._api.__all__
]


@pytest.mark.parametrize(
"module", [models, models.detection, models.quantization, models.segmentation, models.video, models.optical_flow]
)
def test_list_models(module):
def get_models_from_module(module):
return [
v.__name__
for k, v in module.__dict__.items()
if callable(v) and k[0].islower() and k[0] != "_" and k not in models._api.__all__
]

a = set(get_models_from_module(module))
b = set(x.replace("quantized_", "") for x in models.list_models(module))

assert len(b) > 0
assert a == b


@pytest.mark.parametrize(
"include_filters",
[
None,
[],
(),
"",
"*resnet*",
["*alexnet*"],
"*not-existing-model-for-test?",
["*resnet*", "*alexnet*"],
["*resnet*", "*alexnet*", "*not-existing-model-for-test?"],
("*resnet*", "*alexnet*"),
set(["*resnet*", "*alexnet*"]),
],
)
@pytest.mark.parametrize(
"exclude_filters",
[
None,
[],
(),
"",
"*resnet*",
["*alexnet*"],
["*not-existing-model-for-test?"],
["resnet34", "*not-existing-model-for-test?"],
["resnet34", "*resnet1*"],
("resnet34", "*resnet1*"),
set(["resnet34", "*resnet1*"]),
],
)
def test_list_models_filters(include_filters, exclude_filters):
actual = set(models.list_models(models, include=include_filters, exclude=exclude_filters))
classification_models = set(get_models_from_module(models))

if isinstance(include_filters, str):
include_filters = [include_filters]
if isinstance(exclude_filters, str):
exclude_filters = [exclude_filters]

if include_filters:
expected = set()
for include_f in include_filters:
include_f = include_f.strip("*?")
expected = expected | set(x for x in classification_models if include_f in x)
else:
expected = classification_models

if exclude_filters:
for exclude_f in exclude_filters:
exclude_f = exclude_f.strip("*?")
if exclude_f != "":
a_exclude = set(x for x in classification_models if exclude_f in x)
expected = expected - a_exclude

assert expected == actual


@pytest.mark.parametrize(
"name, weight",
[
Expand Down
33 changes: 29 additions & 4 deletions torchvision/models/_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import fnmatch
import importlib
import inspect
import sys
Expand All @@ -6,7 +7,7 @@
from functools import partial
from inspect import signature
from types import ModuleType
from typing import Any, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union

from torch import nn

Expand Down Expand Up @@ -203,19 +204,43 @@ def wrapper(fn: Callable[..., M]) -> Callable[..., M]:
return wrapper


def list_models(module: Optional[ModuleType] = None) -> List[str]:
def list_models(
module: Optional[ModuleType] = None,
include: Union[Iterable[str], str, None] = None,
exclude: Union[Iterable[str], str, None] = None,
) -> List[str]:
"""
Returns a list with the names of registered models.
Args:
module (ModuleType, optional): The module from which we want to extract the available models.
include (str or Iterable[str], optional): Filter(s) for including the models from the set of all models.
Filters are passed to `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__ to match Unix shell-style
wildcards. In case of many filters, the results is the union of individual filters.
exclude (str or Iterable[str], optional): Filter(s) applied after include_filters to remove models.
Filter are passed to `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__ to match Unix shell-style
wildcards. In case of many filters, the results is removal of all the models that match any individual filter.
Returns:
models (list): A list with the names of available models.
"""
models = [
all_models = {
k for k, v in BUILTIN_MODELS.items() if module is None or v.__module__.rsplit(".", 1)[0] == module.__name__
]
}
if include:
models: Set[str] = set()
if isinstance(include, str):
include = [include]
for include_filter in include:
models = models | set(fnmatch.filter(all_models, include_filter))
else:
models = all_models

if exclude:
if isinstance(exclude, str):
exclude = [exclude]
for exclude_filter in exclude:
models = models - set(fnmatch.filter(all_models, exclude_filter))
return sorted(models)


Expand Down

0 comments on commit 20d6ed1

Please sign in to comment.