Skip to content

Commit

Permalink
Fix resnet_fpn_backbone(pretrained=True) (#7172)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored Feb 6, 2023
1 parent 135a0f9 commit 2cd25c1
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 7 deletions.
7 changes: 6 additions & 1 deletion test/test_extended_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torchvision import models
from torchvision.models import get_model_weights, Weights, WeightsEnum
from torchvision.models._utils import handle_legacy_interface
from torchvision.models.detection.backbone_utils import mobilenet_backbone, resnet_fpn_backbone

run_if_test_with_extended = pytest.mark.skipif(
os.getenv("PYTORCH_TEST_WITH_EXTENDED", "0") != "1",
Expand Down Expand Up @@ -425,7 +426,11 @@ def builder(*, weights=None, flag):
+ TM.list_model_fns(models.quantization)
+ TM.list_model_fns(models.segmentation)
+ TM.list_model_fns(models.video)
+ TM.list_model_fns(models.optical_flow),
+ TM.list_model_fns(models.optical_flow)
+ [
lambda pretrained: resnet_fpn_backbone(backbone_name="resnet50", pretrained=pretrained),
lambda pretrained: mobilenet_backbone(backbone_name="mobilenet_v2", fpn=False, pretrained=pretrained),
],
)
@run_if_test_with_extended
def test_pretrained_deprecation(self, model_fn):
Expand Down
8 changes: 4 additions & 4 deletions torchvision/models/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from functools import partial
from inspect import signature
from types import ModuleType
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, TypeVar, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union

from torch import nn

Expand Down Expand Up @@ -138,7 +138,7 @@ def get_weight(name: str) -> WeightsEnum:
return weights_enum[value_name]


def get_model_weights(name: Union[Callable, str]) -> WeightsEnum:
def get_model_weights(name: Union[Callable, str]) -> Type[WeightsEnum]:
"""
Returns the weights enum class associated to the given model.
Expand All @@ -152,7 +152,7 @@ def get_model_weights(name: Union[Callable, str]) -> WeightsEnum:
return _get_enum_from_fn(model)


def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
def _get_enum_from_fn(fn: Callable) -> Type[WeightsEnum]:
"""
Internal method that gets the weight enum of a specific model builder method.
Expand Down Expand Up @@ -182,7 +182,7 @@ def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
"The WeightsEnum class for the specific method couldn't be retrieved. Make sure the typing info is correct."
)

return cast(WeightsEnum, weights_enum)
return weights_enum


M = TypeVar("M", bound=nn.Module)
Expand Down
4 changes: 2 additions & 2 deletions torchvision/models/detection/backbone_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"),
lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
),
)
def resnet_fpn_backbone(
Expand Down Expand Up @@ -177,7 +177,7 @@ def _validate_trainable_layers(
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: _get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"),
lambda kwargs: _get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
),
)
def mobilenet_backbone(
Expand Down

0 comments on commit 2cd25c1

Please sign in to comment.