Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions torchvision/prototype/models/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def verify(cls, obj: Any) -> Any:
if obj is not None:
if type(obj) is str:
obj = cls.from_str(obj)
elif not isinstance(obj, cls) and not isinstance(obj, WeightEntry):
elif not isinstance(obj, cls):
raise TypeError(
f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
)
Expand All @@ -63,7 +63,7 @@ def from_str(cls, value: str) -> "Weights":
return v
raise ValueError(f"Invalid value {value} for enum {cls.__name__}.")

def state_dict(self, progress: bool) -> OrderedDict:
def get_state_dict(self, progress: bool) -> OrderedDict:
return load_state_dict_from_url(self.url, progress=progress)

def __repr__(self):
Expand All @@ -90,7 +90,7 @@ def get_weight(fn: Callable, weight_name: str) -> Weights:
"""
sig = signature(fn)
if "weights" not in sig.parameters:
raise ValueError("The method is missing the 'weights' argument.")
raise ValueError("The method is missing the 'weights' parameter.")

ann = signature(fn).parameters["weights"].annotation
weights_class = None
Expand Down
4 changes: 2 additions & 2 deletions torchvision/prototype/models/alexnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class AlexNetWeights(Weights):

def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = AlexNetWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = AlexNetWeights.verify(weights)
if weights is not None:
Expand All @@ -39,6 +39,6 @@ def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **k
model = AlexNet(**kwargs)

if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))

return model
22 changes: 11 additions & 11 deletions torchvision/prototype/models/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _load_state_dict(model: nn.Module, weights: Weights, progress: bool) -> None
r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)

state_dict = weights.state_dict(progress=progress)
state_dict = weights.get_state_dict(progress=progress)
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
Expand Down Expand Up @@ -63,11 +63,11 @@ def _densenet(
return model


_common_meta = {
_COMMON_META = {
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": None, # weights ported from LuaTorch
"recipe": None, # TODO: add here a URL to documentation stating that the weights were ported from LuaTorch
}


Expand All @@ -76,7 +76,7 @@ class DenseNet121Weights(Weights):
url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 74.434,
"acc@5": 91.972,
},
Expand All @@ -88,7 +88,7 @@ class DenseNet161Weights(Weights):
url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 77.138,
"acc@5": 93.560,
},
Expand All @@ -100,7 +100,7 @@ class DenseNet169Weights(Weights):
url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 75.600,
"acc@5": 92.806,
},
Expand All @@ -112,7 +112,7 @@ class DenseNet201Weights(Weights):
url="https://download.pytorch.org/models/densenet201-c1103571.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 76.896,
"acc@5": 93.370,
},
Expand All @@ -121,7 +121,7 @@ class DenseNet201Weights(Weights):

def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DenseNet121Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = DenseNet121Weights.verify(weights)

Expand All @@ -130,7 +130,7 @@ def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = T

def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DenseNet161Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = DenseNet161Weights.verify(weights)

Expand All @@ -139,7 +139,7 @@ def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = T

def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DenseNet169Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = DenseNet169Weights.verify(weights)

Expand All @@ -148,7 +148,7 @@ def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = T

def densenet201(weights: Optional[DenseNet201Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DenseNet201Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = DenseNet201Weights.verify(weights)

Expand Down
24 changes: 12 additions & 12 deletions torchvision/prototype/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
]


_common_meta = {
_COMMON_META = {
"categories": _COCO_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}
Expand All @@ -41,7 +41,7 @@ class FasterRCNNResNet50FPNWeights(Weights):
url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
transforms=CocoEval,
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
"map": 37.0,
},
Expand All @@ -53,7 +53,7 @@ class FasterRCNNMobileNetV3LargeFPNWeights(Weights):
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
transforms=CocoEval,
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
"map": 32.8,
},
Expand All @@ -65,7 +65,7 @@ class FasterRCNNMobileNetV3Large320FPNWeights(Weights):
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
transforms=CocoEval,
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
"map": 22.8,
},
Expand All @@ -81,11 +81,11 @@ def fasterrcnn_resnet50_fpn(
**kwargs: Any,
) -> FasterRCNN:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = FasterRCNNResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = FasterRCNNResNet50FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = ResNet50Weights.verify(weights_backbone)

Expand All @@ -102,7 +102,7 @@ def fasterrcnn_resnet50_fpn(
model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)

if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == FasterRCNNResNet50FPNWeights.Coco_RefV1:
overwrite_eps(model, 0.0)

Expand Down Expand Up @@ -142,7 +142,7 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
)

if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))

return model

Expand All @@ -156,11 +156,11 @@ def fasterrcnn_mobilenet_v3_large_fpn(
**kwargs: Any,
) -> FasterRCNN:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = FasterRCNNMobileNetV3LargeFPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = FasterRCNNMobileNetV3LargeFPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)

Expand Down Expand Up @@ -188,11 +188,11 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
**kwargs: Any,
) -> FasterRCNN:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = FasterRCNNMobileNetV3Large320FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = FasterRCNNMobileNetV3Large320FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)

Expand Down
12 changes: 6 additions & 6 deletions torchvision/prototype/models/detection/keypoint_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@
]


_common_meta = {"categories": _COCO_PERSON_CATEGORIES, "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES}
_COMMON_META = {"categories": _COCO_PERSON_CATEGORIES, "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES}


class KeypointRCNNResNet50FPNWeights(Weights):
Coco_RefV1_Legacy = WeightEntry(
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
transforms=CocoEval,
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/issues/1606",
"box_map": 50.6,
"kp_map": 61.1,
Expand All @@ -40,7 +40,7 @@ class KeypointRCNNResNet50FPNWeights(Weights):
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
transforms=CocoEval,
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
"box_map": 54.6,
"kp_map": 65.0,
Expand All @@ -58,7 +58,7 @@ def keypointrcnn_resnet50_fpn(
**kwargs: Any,
) -> KeypointRCNN:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
pretrained = kwargs.pop("pretrained")
if type(pretrained) == str and pretrained == "legacy":
weights = KeypointRCNNResNet50FPNWeights.Coco_RefV1_Legacy
Expand All @@ -68,7 +68,7 @@ def keypointrcnn_resnet50_fpn(
weights = None
weights = KeypointRCNNResNet50FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = ResNet50Weights.verify(weights_backbone)

Expand All @@ -86,7 +86,7 @@ def keypointrcnn_resnet50_fpn(
model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)

if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == KeypointRCNNResNet50FPNWeights.Coco_RefV1:
overwrite_eps(model, 0.0)

Expand Down
6 changes: 3 additions & 3 deletions torchvision/prototype/models/detection/mask_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ def maskrcnn_resnet50_fpn(
**kwargs: Any,
) -> MaskRCNN:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = MaskRCNNResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = MaskRCNNResNet50FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = ResNet50Weights.verify(weights_backbone)

Expand All @@ -67,7 +67,7 @@ def maskrcnn_resnet50_fpn(
model = MaskRCNN(backbone, num_classes=num_classes, **kwargs)

if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == MaskRCNNResNet50FPNWeights.Coco_RefV1:
overwrite_eps(model, 0.0)

Expand Down
6 changes: 3 additions & 3 deletions torchvision/prototype/models/detection/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ def retinanet_resnet50_fpn(
**kwargs: Any,
) -> RetinaNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RetinaNetResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = RetinaNetResNet50FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = ResNet50Weights.verify(weights_backbone)

Expand All @@ -70,7 +70,7 @@ def retinanet_resnet50_fpn(
model = RetinaNet(backbone, num_classes, **kwargs)

if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == RetinaNetResNet50FPNWeights.Coco_RefV1:
overwrite_eps(model, 0.0)

Expand Down
8 changes: 4 additions & 4 deletions torchvision/prototype/models/detection/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,16 @@ def ssd300_vgg16(
**kwargs: Any,
) -> SSD:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = SSD300VGG16Weights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = SSD300VGG16Weights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = VGG16Weights.ImageNet1K_Features if kwargs.pop("pretrained_backbone") else None
weights_backbone = VGG16Weights.verify(weights_backbone)

if "size" in kwargs:
warnings.warn("The size of the model is already fixed; ignoring the argument.")
warnings.warn("The size of the model is already fixed; ignoring the parameter.")

if weights is not None:
weights_backbone = None
Expand Down Expand Up @@ -81,6 +81,6 @@ def ssd300_vgg16(
model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs)

if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))

return model
8 changes: 4 additions & 4 deletions torchvision/prototype/models/detection/ssdlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,16 @@ def ssdlite320_mobilenet_v3_large(
**kwargs: Any,
) -> SSD:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = SSDlite320MobileNetV3LargeFPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = SSDlite320MobileNetV3LargeFPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)

if "size" in kwargs:
warnings.warn("The size of the model is already fixed; ignoring the argument.")
warnings.warn("The size of the model is already fixed; ignoring the parameter.")

if weights is not None:
weights_backbone = None
Expand Down Expand Up @@ -114,6 +114,6 @@ def ssdlite320_mobilenet_v3_large(
)

if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))

return model
Loading