Adding the huge vision transformer from SWAG (#5721)
* Add vit_b_16_swag

* Better handling idiom for image_size; edit test_extended_model to handle the case where the number of params differs from the default due to a different image-size input

* Update the accuracy to the experiment result on the torchvision model

* Fix typo: missing underscore

* Raise an exception instead of torch._assert; add back the publication year (accidentally deleted)

* Add license information to meta and readme

* Improve wording and fix typo for the pretrained model license in readme

* Add vit_l_16 weight

* Update README.rst

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* Update the accuracy meta on the vit_l_16_swag model to the result from our experiment

* Add vit_h_14_swag model

* Add accuracy from experiments

* Add vit_h_14 model to hubconf.py

* Add docs and expected pkl file for test

* Remove legacy compatibility for ViT_H_14 model

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* Test vit_h_14 with a smaller image_size to speed up the test

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
YosuaMichael and datumbox committed Apr 5, 2022
1 parent d0c92dc commit 63576c9
Showing 5 changed files with 54 additions and 0 deletions.
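
For orientation, here is a minimal usage sketch of the model added in this commit (a sketch only, assuming the ViT_H_14_Weights enum and the weights.transforms() helper of the multi-weight API shown in the diff below; loading the SWAG checkpoint downloads a multi-gigabyte file):

import torch
from torchvision.models import ViT_H_14_Weights, vit_h_14

# Build ViT-H/14 with the SWAG ImageNet-1K weights added in this commit.
weights = ViT_H_14_Weights.IMAGENET1K_SWAG_V1
model = vit_h_14(weights=weights).eval()

# The weights carry their own preprocessing: 518x518 bicubic resize + center crop.
preprocess = weights.transforms()

with torch.inference_mode():
    batch = preprocess(torch.rand(3, 1024, 1024)).unsqueeze(0)  # stand-in for a real image
    logits = model(batch)
print(logits.shape)  # torch.Size([1, 1000])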
3 changes: 3 additions & 0 deletions docs/source/models.rst
@@ -92,6 +92,7 @@ You can construct a model with random weights by calling its constructor:
vit_b_32 = models.vit_b_32()
vit_l_16 = models.vit_l_16()
vit_l_32 = models.vit_l_32()
vit_h_14 = models.vit_h_14()
convnext_tiny = models.convnext_tiny()
convnext_small = models.convnext_small()
convnext_base = models.convnext_base()
@@ -213,6 +214,7 @@ vit_b_16 81.072 95.318
vit_b_32 75.912 92.466
vit_l_16 79.662 94.638
vit_l_32 76.972 93.070
vit_h_14 88.552 98.694
convnext_tiny 82.520 96.146
convnext_small 83.616 96.650
convnext_base 84.062 96.870
@@ -434,6 +436,7 @@ VisionTransformer
vit_b_32
vit_l_16
vit_l_32
vit_h_14

ConvNeXt
--------
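
The new table row for vit_h_14 matches the acc@1/acc@5 values recorded on the SWAG weights enum later in this diff; a quick sketch to read those numbers back from the weight metadata (assuming the DEFAULT alias defined below):

from torchvision.models import ViT_H_14_Weights

w = ViT_H_14_Weights.DEFAULT  # alias for IMAGENET1K_SWAG_V1
print(w.meta["acc@1"], w.meta["acc@5"])  # 88.552 98.694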
1 change: 1 addition & 0 deletions hubconf.py
@@ -67,4 +67,5 @@
vit_b_32,
vit_l_16,
vit_l_32,
vit_h_14,
)
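
Since hubconf.py now re-exports vit_h_14, the architecture is also reachable through torch.hub. A hedged sketch (the repository reference and keyword argument are assumptions about the hub entrypoint, not part of this diff):

import torch

# Instantiate the architecture via torch.hub; weights=None builds a randomly
# initialized ViT-H/14 without downloading the SWAG checkpoint.
model = torch.hub.load("pytorch/vision", "vit_h_14", weights=None)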
Binary file added test/expect/ModelTester.test_vit_h_14_expect.pkl
Binary file not shown.
4 changes: 4 additions & 0 deletions test/test_models.py
@@ -280,6 +280,10 @@ def _check_input_backprop(model, inputs):
"rpn_pre_nms_top_n_test": 1000,
"rpn_post_nms_top_n_test": 1000,
},
"vit_h_14": {
"image_size": 56,
"input_shape": (1, 3, 56, 56),
},
}
# speeding up slow models:
slow_models = [
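
The override above keeps the test cheap: building the model at a reduced, patch-aligned resolution exercises the same code path as the full 518x518 model with a much smaller forward pass. A small sketch of the same idea outside the suite (image_size must stay divisible by the 14-pixel patch size):

import torch
from torchvision.models import vit_h_14

# 56 = 4 * 14, so the positional embedding is built for a 4x4 patch grid.
model = vit_h_14(image_size=56).eval()
with torch.inference_mode():
    out = model(torch.rand(1, 3, 56, 56))
print(out.shape)  # torch.Size([1, 1000])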
46 changes: 46 additions & 0 deletions torchvision/models/vision_transformer.py
@@ -20,10 +20,12 @@
"ViT_B_32_Weights",
"ViT_L_16_Weights",
"ViT_L_32_Weights",
"ViT_H_14_Weights",
"vit_b_16",
"vit_b_32",
"vit_l_16",
"vit_l_32",
"vit_h_14",
]


@@ -435,6 +437,27 @@ class ViT_L_32_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1


class ViT_H_14_Weights(WeightsEnum):
IMAGENET1K_SWAG_V1 = Weights(
url="https://download.pytorch.org/models/vit_h_14_swag-80465313.pth",
transforms=partial(
ImageClassification,
crop_size=518,
resize_size=518,
interpolation=InterpolationMode.BICUBIC,
),
meta={
**_COMMON_SWAG_META,
"num_params": 633470440,
"size": (518, 518),
"min_size": (518, 518),
"acc@1": 88.552,
"acc@5": 98.694,
},
)
DEFAULT = IMAGENET1K_SWAG_V1


@handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1))
def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
@@ -531,6 +554,29 @@ def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
)


def vit_h_14(*, weights: Optional[ViT_H_14_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a vit_h_14 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args:
weights (ViT_H_14_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
"""
weights = ViT_H_14_Weights.verify(weights)

return _vision_transformer(
patch_size=14,
num_layers=32,
num_heads=16,
hidden_dim=1280,
mlp_dim=5120,
weights=weights,
progress=progress,
**kwargs,
)


def interpolate_embeddings(
image_size: int,
patch_size: int,
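
As a sanity check on the new meta, the num_params value (633,470,440) corresponds to a model built at the SWAG resolution of 518x518; at the default 224x224 the positional embedding is smaller, which is why the extended-model test was adjusted to tolerate image-size-dependent parameter counts. A sketch (random initialization, no checkpoint download):

from torchvision.models import vit_h_14

model = vit_h_14(image_size=518)
print(sum(p.numel() for p in model.parameters()))  # expected: 633470440, as recorded in the meta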
