From 20d6ed1168107f12e168369dd86ba3774e1910a4 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 25 Aug 2023 02:48:13 -0700 Subject: [PATCH] [fbsync] Add filter parameters to `list_models()` (#7718) Reviewed By: matteobettini Differential Revision: D48642263 fbshipit-source-id: 7dd986c91115b47383dfa69af070626a85b8bf07 Co-authored-by: Mateusz Guzek Co-authored-by: Nicolas Hug Co-authored-by: Philip Meier --- test/test_extended_models.py | 74 ++++++++++++++++++++++++++++++++---- torchvision/models/_api.py | 33 ++++++++++++++-- 2 files changed, 96 insertions(+), 11 deletions(-) diff --git a/test/test_extended_models.py b/test/test_extended_models.py index 0866cc0f8a3..96a3fc5f8ed 100644 --- a/test/test_extended_models.py +++ b/test/test_extended_models.py @@ -103,17 +103,18 @@ def test_weights_deserializable(name): assert pickle.loads(pickle.dumps(weights)) is weights +def get_models_from_module(module): + return [ + v.__name__ + for k, v in module.__dict__.items() + if callable(v) and k[0].islower() and k[0] != "_" and k not in models._api.__all__ + ] + + @pytest.mark.parametrize( "module", [models, models.detection, models.quantization, models.segmentation, models.video, models.optical_flow] ) def test_list_models(module): - def get_models_from_module(module): - return [ - v.__name__ - for k, v in module.__dict__.items() - if callable(v) and k[0].islower() and k[0] != "_" and k not in models._api.__all__ - ] - a = set(get_models_from_module(module)) b = set(x.replace("quantized_", "") for x in models.list_models(module)) @@ -121,6 +122,65 @@ def get_models_from_module(module): assert a == b +@pytest.mark.parametrize( + "include_filters", + [ + None, + [], + (), + "", + "*resnet*", + ["*alexnet*"], + "*not-existing-model-for-test?", + ["*resnet*", "*alexnet*"], + ["*resnet*", "*alexnet*", "*not-existing-model-for-test?"], + ("*resnet*", "*alexnet*"), + set(["*resnet*", "*alexnet*"]), + ], +) +@pytest.mark.parametrize( + "exclude_filters", + [ + None, + [], + (), + "", + "*resnet*", + ["*alexnet*"], + ["*not-existing-model-for-test?"], + ["resnet34", "*not-existing-model-for-test?"], + ["resnet34", "*resnet1*"], + ("resnet34", "*resnet1*"), + set(["resnet34", "*resnet1*"]), + ], +) +def test_list_models_filters(include_filters, exclude_filters): + actual = set(models.list_models(models, include=include_filters, exclude=exclude_filters)) + classification_models = set(get_models_from_module(models)) + + if isinstance(include_filters, str): + include_filters = [include_filters] + if isinstance(exclude_filters, str): + exclude_filters = [exclude_filters] + + if include_filters: + expected = set() + for include_f in include_filters: + include_f = include_f.strip("*?") + expected = expected | set(x for x in classification_models if include_f in x) + else: + expected = classification_models + + if exclude_filters: + for exclude_f in exclude_filters: + exclude_f = exclude_f.strip("*?") + if exclude_f != "": + a_exclude = set(x for x in classification_models if exclude_f in x) + expected = expected - a_exclude + + assert expected == actual + + @pytest.mark.parametrize( "name, weight", [ diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index e244207a8ed..0999bf7ba6b 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -1,3 +1,4 @@ +import fnmatch import importlib import inspect import sys @@ -6,7 +7,7 @@ from functools import partial from inspect import signature from types import ModuleType -from typing import Any, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union from torch import nn @@ -203,19 +204,43 @@ def wrapper(fn: Callable[..., M]) -> Callable[..., M]: return wrapper -def list_models(module: Optional[ModuleType] = None) -> List[str]: +def list_models( + module: Optional[ModuleType] = None, + include: Union[Iterable[str], str, None] = None, + exclude: Union[Iterable[str], str, None] = None, +) -> List[str]: """ Returns a list with the names of registered models. Args: module (ModuleType, optional): The module from which we want to extract the available models. + include (str or Iterable[str], optional): Filter(s) for including the models from the set of all models. + Filters are passed to `fnmatch `__ to match Unix shell-style + wildcards. In case of many filters, the results is the union of individual filters. + exclude (str or Iterable[str], optional): Filter(s) applied after include_filters to remove models. + Filter are passed to `fnmatch `__ to match Unix shell-style + wildcards. In case of many filters, the results is removal of all the models that match any individual filter. Returns: models (list): A list with the names of available models. """ - models = [ + all_models = { k for k, v in BUILTIN_MODELS.items() if module is None or v.__module__.rsplit(".", 1)[0] == module.__name__ - ] + } + if include: + models: Set[str] = set() + if isinstance(include, str): + include = [include] + for include_filter in include: + models = models | set(fnmatch.filter(all_models, include_filter)) + else: + models = all_models + + if exclude: + if isinstance(exclude, str): + exclude = [exclude] + for exclude_filter in exclude: + models = models - set(fnmatch.filter(all_models, exclude_filter)) return sorted(models)