Description
🚀 RFC
Background Info
To access pre-trained models in TorchVision, one needs to pass pretrained=True to the model builders. Example:
from torchvision.models import resnet50
# With weights:
model = resnet50(pretrained=True)
# Without weights:
model = resnet50(pretrained=False)
Unfortunately, the above API does not allow us to support multiple pre-trained weights. This feature is necessary when we want to provide improved weights on the same dataset (for example, better Acc@1 on ImageNet) or additional weights trained on a different dataset (for example, in Object Detection, using VOC instead of COCO). With the completion of the Multi-weight support prototype, the TorchVision model builders can now support more than one set of weights:
from torchvision.prototype.models import resnet50, ResNet50_Weights
# Old weights:
model = resnet50(weights=ResNet50_Weights.ImageNet1K_V1)
# New weights:
model = resnet50(weights=ResNet50_Weights.ImageNet1K_V2)
# No weights:
model = resnet50(weights=None)
The above prototype API is now available in the nightly builds, where users can test it and we can gather feedback. Once the feedback has been gathered and acted upon, we will consider releasing the new API in the main torchvision.models area.
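One way to picture why an enum like ResNet50_Weights scales to several weight sets is to let each enum member carry its own metadata. The following is a minimal, self-contained sketch under that assumption; it does not use TorchVision, and the names HypotheticalResNet50Weights, WeightsMeta, acc_at_1 and describe are hypothetical illustrations, not the prototype's actual internals. The accuracy figures are the ones quoted in this RFC.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


@dataclass(frozen=True)
class WeightsMeta:
    # Hypothetical metadata record; the field names are illustrative only.
    dataset: str
    acc_at_1: float


class HypotheticalResNet50Weights(Enum):
    # Each member is one set of pre-trained weights for the same architecture.
    # The Acc@1 values are the figures quoted in this RFC.
    ImageNet1K_V1 = WeightsMeta(dataset="ImageNet-1K", acc_at_1=76.130)
    ImageNet1K_V2 = WeightsMeta(dataset="ImageNet-1K", acc_at_1=80.674)


def describe(weights: Optional[HypotheticalResNet50Weights]) -> str:
    # A model builder would inspect `weights` to decide what (if anything) to load.
    if weights is None:
        return "random initialization"
    meta = weights.value
    return f"{weights.name}: {meta.dataset}, Acc@1={meta.acc_at_1}"
```

For example, describe(HypotheticalResNet50Weights.ImageNet1K_V2) returns "ImageNet1K_V2: ImageNet-1K, Acc@1=80.674", while describe(None) stands for the weights=None case. Adding a third set of weights is just one more enum member, which a boolean flag cannot express.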
What should be the behaviour of pretrained=True?
Upon release, the legacy pretrained=True parameter will be deprecated, and it will be removed in a future version of TorchVision (TBD when). The question this RFC asks is what the behaviour of pretrained=True should be until its removal. There are currently two obvious candidates:
Option 1: Using the Legacy weights
When pretrained=True is passed, the new API should return the same legacy weights as the current API.
This is how the prototype is currently implemented. The following calls are all equivalent:
# Legacy weights with accuracy 76.130%
model = resnet50(weights=ResNet50_Weights.ImageNet1K_V1)
model = resnet50(pretrained=True)
model = resnet50(True)
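The translation Option 1 implies can be sketched as a small helper that maps the legacy flag onto the V1 weights and emits a deprecation warning. This is a minimal sketch assuming a stand-in enum; HypotheticalResNet50Weights and resolve_weights_option1 are hypothetical names, not the prototype's actual implementation, and the positional resnet50(True) form is not modelled here.

```python
import warnings
from enum import Enum
from typing import Optional


class HypotheticalResNet50Weights(Enum):
    # Stand-in for ResNet50_Weights; each value is the Acc@1 quoted in this RFC.
    ImageNet1K_V1 = 76.130
    ImageNet1K_V2 = 80.674


def resolve_weights_option1(
    weights: Optional[HypotheticalResNet50Weights] = None,
    pretrained: Optional[bool] = None,
) -> Optional[HypotheticalResNet50Weights]:
    """Option 1: map the legacy pretrained flag onto the legacy (V1) weights."""
    if pretrained is None:
        # New-style call: use whatever the caller asked for (possibly None).
        return weights
    if weights is not None:
        raise ValueError("Pass either 'weights' or 'pretrained', not both.")
    warnings.warn(
        "'pretrained' is deprecated, please use 'weights' instead.",
        DeprecationWarning,
    )
    # Backwards compatible: pretrained=True keeps returning the old weights.
    return HypotheticalResNet50Weights.ImageNet1K_V1 if pretrained else None
```

Under this sketch, users only get ImageNet1K_V2 if they ask for it explicitly through weights=..., which is the "manual opt-in" described below.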
Why select this option:
- It is aligned with TorchVision's strong Backwards Compatibility guarantees
- It requires a "manual opt-in" from users to switch to the new weights
- It's the safest option
Option 2: Using the Best available weights
When pretrained=True is passed, the new API should return the best available weights.
The following calls will be made equivalent:
# New weights with accuracy 80.674%
model = resnet50(weights=ResNet50_Weights.ImageNet1K_V2)
model = resnet50(weights=ResNet50_Weights.default)
model = resnet50(pretrained=True)
model = resnet50(True)
Why select this option:
- The users will benefit automatically from the major accuracy improvement.
- In practice, TorchVision didn't actually offer BC guarantees on the weights. There are several instances where we modified the weights in place [1, 2, 3, 4]. Because of this, one could argue that pretrained=True has always meant "give me the current best weights".
- In-place modification of weights is commonplace in other libraries [1, 2].
- It emphasises that ResNet50 and other older architectures achieve very high accuracies when trained with modern approaches, which can have a positive influence on research [1].
To address some of the cons of adopting this option we will:
- Raise warnings to inform users that they are accessing new weights, and explain within the warning how to switch back to the old behaviour.
- Inform downstream libraries and users about the upcoming change via blog posts, social media and even by opening PRs to their projects (especially for Meta-backed projects).
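Option 2, together with the warning mitigation above, could look roughly like the following sketch. It assumes a stand-in enum in which default is declared with the same value as ImageNet1K_V2, so Python's Enum makes it an alias of that member, mirroring ResNet50_Weights.default. Using the accuracy as the enum value is only a toy choice for this sketch, and HypotheticalResNet50Weights and resolve_weights_option2 are hypothetical names, not the prototype's code.

```python
import warnings
from enum import Enum
from typing import Optional


class HypotheticalResNet50Weights(Enum):
    # Stand-in for ResNet50_Weights; each value is the Acc@1 quoted in this RFC.
    ImageNet1K_V1 = 76.130
    ImageNet1K_V2 = 80.674
    # Same value as ImageNet1K_V2, so Python makes `default` an alias of it.
    default = 80.674


def resolve_weights_option2(
    weights: Optional[HypotheticalResNet50Weights] = None,
    pretrained: Optional[bool] = None,
) -> Optional[HypotheticalResNet50Weights]:
    """Option 2: map pretrained=True onto the best available weights."""
    if pretrained is None:
        return weights
    if weights is not None:
        raise ValueError("Pass either 'weights' or 'pretrained', not both.")
    if not pretrained:
        return None
    best = HypotheticalResNet50Weights.default
    old = HypotheticalResNet50Weights.ImageNet1K_V1
    # Tell the user the weights changed and how to get the old ones back.
    warnings.warn(
        f"pretrained=True now loads {best.name} (Acc@1 {best.value}). "
        f"To keep the previous behaviour, pass weights={old.name}.",
        UserWarning,
    )
    return best
```

Because default is an alias, updating the best weights later only means moving the alias to a new member; callers using pretrained=True or default follow it automatically, which is exactly the BC trade-off this option accepts.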
Feedback
We would love to hear your thoughts on the matter. You can vote on this topic on Twitter or explain your position in the comments.