Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add filter parameters to list_models() #7718

Merged
merged 14 commits into from
Jul 5, 2023
74 changes: 67 additions & 7 deletions test/test_extended_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,24 +103,84 @@ def test_weights_deserializable(name):
assert pickle.loads(pickle.dumps(weights)) is weights


def get_models_from_module(module):
return [
v.__name__
for k, v in module.__dict__.items()
if callable(v) and k[0].islower() and k[0] != "_" and k not in models._api.__all__
]


@pytest.mark.parametrize(
"module", [models, models.detection, models.quantization, models.segmentation, models.video, models.optical_flow]
)
def test_list_models(module):
def get_models_from_module(module):
return [
v.__name__
for k, v in module.__dict__.items()
if callable(v) and k[0].islower() and k[0] != "_" and k not in models._api.__all__
]

a = set(get_models_from_module(module))
b = set(x.replace("quantized_", "") for x in models.list_models(module))

assert len(b) > 0
assert a == b


@pytest.mark.parametrize(
"include_filters",
[
None,
[],
(),
"",
"*resnet*",
["*alexnet*"],
"*not-existing-model-for-test?",
["*resnet*", "*alexnet*"],
["*resnet*", "*alexnet*", "*not-existing-model-for-test?"],
("*resnet*", "*alexnet*"),
set(["*resnet*", "*alexnet*"]),
],
)
@pytest.mark.parametrize(
"exclude_filters",
[
None,
[],
(),
"",
"*resnet*",
["*alexnet*"],
["*not-existing-model-for-test?"],
["resnet34", "*not-existing-model-for-test?"],
["resnet34", "*resnet1*"],
("resnet34", "*resnet1*"),
set(["resnet34", "*resnet1*"]),
],
)
def test_list_models_filters(include_filters, exclude_filters):
actual = set(models.list_models(models, include=include_filters, exclude=exclude_filters))
classification_models = set(get_models_from_module(models))

if isinstance(include_filters, str):
include_filters = [include_filters]
if isinstance(exclude_filters, str):
exclude_filters = [exclude_filters]

if include_filters:
expected = set()
for include_f in include_filters:
include_f = include_f.strip("*?")
expected = expected | set(x for x in classification_models if include_f in x)
else:
expected = classification_models

if exclude_filters:
for exclude_f in exclude_filters:
exclude_f = exclude_f.strip("*?")
if exclude_f != "":
a_exclude = set(x for x in classification_models if exclude_f in x)
expected = expected - a_exclude

assert expected == actual


@pytest.mark.parametrize(
"name, weight",
[
Expand Down
33 changes: 29 additions & 4 deletions torchvision/models/_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import fnmatch
import importlib
import inspect
import sys
Expand All @@ -6,7 +7,7 @@
from functools import partial
from inspect import signature
from types import ModuleType
from typing import Any, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union

from torch import nn

Expand Down Expand Up @@ -203,19 +204,43 @@ def wrapper(fn: Callable[..., M]) -> Callable[..., M]:
return wrapper


def list_models(module: Optional[ModuleType] = None) -> List[str]:
def list_models(
module: Optional[ModuleType] = None,
include: Union[Iterable[str], str, None] = None,
exclude: Union[Iterable[str], str, None] = None,
) -> List[str]:
"""
Returns a list with the names of registered models.

Args:
module (ModuleType, optional): The module from which we want to extract the available models.
include (str or Iterable[str], optional): Filter(s) for including the models from the set of all models.
Filters are passed to `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__ to match Unix shell-style
wildcards. In case of many filters, the results is the union of individual filters.
exclude (str or Iterable[str], optional): Filter(s) applied after include_filters to remove models.
Filter are passed to `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__ to match Unix shell-style
wildcards. In case of many filters, the results is removal of all the models that match any individual filter.

Returns:
models (list): A list with the names of available models.
"""
models = [
all_models = {
k for k, v in BUILTIN_MODELS.items() if module is None or v.__module__.rsplit(".", 1)[0] == module.__name__
]
}
if include:
models: Set[str] = set()
if isinstance(include, str):
include = [include]
for include_filter in include:
models = models | set(fnmatch.filter(all_models, include_filter))
else:
models = all_models

if exclude:
if isinstance(exclude, str):
exclude = [exclude]
for exclude_filter in exclude:
models = models - set(fnmatch.filter(all_models, exclude_filter))
return sorted(models)


Expand Down