From 135a0f9ea9841b6324b4fe8974e2543cbb95709a Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 2 Feb 2023 10:46:13 +0100 Subject: [PATCH] Make WeightEnum and Weights public + cleanups (#7100) --- test/test_extended_models.py | 2 +- torchvision/models/__init__.py | 7 ++++++- torchvision/models/_api.py | 31 ++++++++++++++++--------------- 3 files changed, 23 insertions(+), 17 deletions(-) diff --git a/test/test_extended_models.py b/test/test_extended_models.py index 068d3e238f9..ded0ecf63fe 100644 --- a/test/test_extended_models.py +++ b/test/test_extended_models.py @@ -7,7 +7,7 @@ import torch from common_extended_utils import get_file_size_mb, get_ops from torchvision import models -from torchvision.models._api import get_model_weights, Weights, WeightsEnum +from torchvision.models import get_model_weights, Weights, WeightsEnum from torchvision.models._utils import handle_legacy_interface run_if_test_with_extended = pytest.mark.skipif( diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 93d96112ba1..6ea0a1f7178 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -15,4 +15,9 @@ from .swin_transformer import * from .maxvit import * from . import detection, optical_flow, quantization, segmentation, video -from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models + +# The Weights and WeightsEnum are developer-facing utils that we make public for +# downstream libs like torchgeo https://github.com/pytorch/vision/issues/7094 +# TODO: we could / should document them publicly, but it's not clear where, as +# they're not intended for end users. +from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models, Weights, WeightsEnum diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index d888ac262b0..3915547ebba 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -1,7 +1,8 @@ import importlib import inspect import sys -from dataclasses import dataclass, fields +from dataclasses import dataclass +from enum import Enum from functools import partial from inspect import signature from types import ModuleType @@ -9,8 +10,6 @@ from torch import nn -from torchvision._utils import StrEnum - from .._internally_replaced_utils import load_state_dict_from_url @@ -65,7 +64,7 @@ def __eq__(self, other: Any) -> bool: return self.transforms == other.transforms -class WeightsEnum(StrEnum): +class WeightsEnum(Enum): """ This class is the parent class of all model weights. Each model building method receives an optional `weights` parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type @@ -75,14 +74,11 @@ class WeightsEnum(StrEnum): value (Weights): The data class entry with the weight information. """ - def __init__(self, value: Weights): - self._value_ = value - @classmethod def verify(cls, obj: Any) -> Any: if obj is not None: if type(obj) is str: - obj = cls.from_str(obj.replace(cls.__name__ + ".", "")) + obj = cls[obj.replace(cls.__name__ + ".", "")] elif not isinstance(obj, cls): raise TypeError( f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}." @@ -95,12 +91,17 @@ def get_state_dict(self, progress: bool) -> Mapping[str, Any]: def __repr__(self) -> str: return f"{self.__class__.__name__}.{self._name_}" - def __getattr__(self, name): - # Be able to fetch Weights attributes directly - for f in fields(Weights): - if f.name == name: - return object.__getattribute__(self.value, name) - return super().__getattr__(name) + @property + def url(self): + return self.value.url + + @property + def transforms(self): + return self.value.transforms + + @property + def meta(self): + return self.value.meta def get_weight(name: str) -> WeightsEnum: @@ -134,7 +135,7 @@ def get_weight(name: str) -> WeightsEnum: if weights_enum is None: raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.") - return weights_enum.from_str(value_name) + return weights_enum[value_name] def get_model_weights(name: Union[Callable, str]) -> WeightsEnum: