diff --git a/docs/source/models.rst b/docs/source/models.rst
index 16825d2b8b2..f84d9c7fd1a 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -92,6 +92,7 @@ You can construct a model with random weights by calling its constructor:
     vit_b_32 = models.vit_b_32()
     vit_l_16 = models.vit_l_16()
     vit_l_32 = models.vit_l_32()
+    vit_h_14 = models.vit_h_14()
     convnext_tiny = models.convnext_tiny()
     convnext_small = models.convnext_small()
     convnext_base = models.convnext_base()
@@ -213,6 +214,7 @@ vit_b_16                         81.072     95.318
 vit_b_32                         75.912     92.466
 vit_l_16                         79.662     94.638
 vit_l_32                         76.972     93.070
+vit_h_14                         88.552     98.694
 convnext_tiny                    82.520     96.146
 convnext_small                   83.616     96.650
 convnext_base                    84.062     96.870
@@ -434,6 +436,7 @@ VisionTransformer
     vit_b_32
     vit_l_16
     vit_l_32
+    vit_h_14
 
 ConvNeXt
 --------
diff --git a/hubconf.py b/hubconf.py
index c3de4f2da9a..bbd5da52b13 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -67,4 +67,5 @@
     vit_b_32,
     vit_l_16,
     vit_l_32,
+    vit_h_14,
 )
diff --git a/test/expect/ModelTester.test_vit_h_14_expect.pkl b/test/expect/ModelTester.test_vit_h_14_expect.pkl
new file mode 100644
index 00000000000..1f846beb6a0
Binary files /dev/null and b/test/expect/ModelTester.test_vit_h_14_expect.pkl differ
diff --git a/test/test_models.py b/test/test_models.py
index 0fbf45b9750..5e0cc742d84 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -280,6 +280,10 @@ def _check_input_backprop(model, inputs):
         "rpn_pre_nms_top_n_test": 1000,
         "rpn_post_nms_top_n_test": 1000,
     },
+    "vit_h_14": {
+        "image_size": 56,
+        "input_shape": (1, 3, 56, 56),
+    },
 }
 # speeding up slow models:
 slow_models = [
diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py
index 59da51c1bd9..de2e61c440a 100644
--- a/torchvision/models/vision_transformer.py
+++ b/torchvision/models/vision_transformer.py
@@ -20,10 +20,12 @@
     "ViT_B_32_Weights",
     "ViT_L_16_Weights",
     "ViT_L_32_Weights",
+    "ViT_H_14_Weights",
     "vit_b_16",
     "vit_b_32",
     "vit_l_16",
     "vit_l_32",
+    "vit_h_14",
 ]
 
 
@@ -435,6 +437,27 @@ class ViT_L_32_Weights(WeightsEnum):
     DEFAULT = IMAGENET1K_V1
 
 
+class ViT_H_14_Weights(WeightsEnum):
+    IMAGENET1K_SWAG_V1 = Weights(
+        url="https://download.pytorch.org/models/vit_h_14_swag-80465313.pth",
+        transforms=partial(
+            ImageClassification,
+            crop_size=518,
+            resize_size=518,
+            interpolation=InterpolationMode.BICUBIC,
+        ),
+        meta={
+            **_COMMON_SWAG_META,
+            "num_params": 633470440,
+            "size": (518, 518),
+            "min_size": (518, 518),
+            "acc@1": 88.552,
+            "acc@5": 98.694,
+        },
+    )
+    DEFAULT = IMAGENET1K_SWAG_V1
+
+
 @handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1))
 def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
     """
@@ -531,6 +554,29 @@ def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = Tru
     )
 
 
+def vit_h_14(*, weights: Optional[ViT_H_14_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
+    """
+    Constructs a vit_h_14 architecture from
+    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
+
+    Args:
+        weights (ViT_H_14_Weights, optional): The pretrained weights for the model
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    weights = ViT_H_14_Weights.verify(weights)
+
+    return _vision_transformer(
+        patch_size=14,
+        num_layers=32,
+        num_heads=16,
+        hidden_dim=1280,
+        mlp_dim=5120,
+        weights=weights,
+        progress=progress,
+        **kwargs,
+    )
+
+
 def interpolate_embeddings(
     image_size: int,
     patch_size: int,
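
For reference, a minimal usage sketch of the new builder follows. It is not part of the diff itself: it assumes the multi-weight API's weights.transforms() accessor is available for the bundled 518x518 preprocessing, and the random input and printed class index are purely illustrative.

    # Usage sketch (illustrative only, not included in this change).
    import torch
    from torchvision.models import vit_h_14
    from torchvision.models.vision_transformer import ViT_H_14_Weights

    # Instantiate ViT-H/14 with the SWAG weights added above; the checkpoint
    # is downloaded on first use.
    weights = ViT_H_14_Weights.IMAGENET1K_SWAG_V1
    model = vit_h_14(weights=weights)
    model.eval()

    # The weights entry bundles its own preprocessing (518x518 crop/resize,
    # bicubic interpolation), assuming the weights.transforms() accessor.
    preprocess = weights.transforms()

    # Dummy RGB tensor in [0, 1]; substitute a real image of shape (3, H, W).
    img = torch.rand(3, 518, 518)
    batch = preprocess(img).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)
    print("predicted ImageNet class index:", logits.squeeze(0).argmax().item())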