Adding the huge vision transformer from SWAG #5721

Merged (24 commits, Apr 5, 2022)

Commits
c801bf0
Add vit_b_16_swag
YosuaMichael Mar 31, 2022
9e13f79
Better handling idiom for image_size, edit test_extended_model to han…
YosuaMichael Mar 31, 2022
1707171
Update the accuracy to the experiment result on torchvision model
YosuaMichael Mar 31, 2022
bd8b1a8
Fix typo missing underscore
YosuaMichael Mar 31, 2022
6c765a5
raise exception instead of torch._assert, add back publication year (…
YosuaMichael Mar 31, 2022
3326e88
Merge branch 'main' into add-swag-weight
YosuaMichael Apr 1, 2022
e444c5a
Add license information on meta and readme
YosuaMichael Apr 1, 2022
a6ee605
Merge branch 'add-swag-weight' of github.com:pytorch/vision into add-…
YosuaMichael Apr 1, 2022
54aa8cf
Improve wording and fix typo for pretrained model license in readme
YosuaMichael Apr 1, 2022
f9c32eb
Add vit_l_16 weight
YosuaMichael Apr 1, 2022
4cf4eff
Update README.rst
YosuaMichael Apr 1, 2022
9230f40
Update the accuracy meta on vit_l_16_swag model to result from our ex…
YosuaMichael Apr 1, 2022
ce6eb3e
Add vit_h_14_swag model
YosuaMichael Apr 1, 2022
ff76a53
Add accuracy from experiments
YosuaMichael Apr 1, 2022
e874548
Add to vit_h_16 model to hubconf.py
YosuaMichael Apr 1, 2022
2ca4ac4
Add docs and expected pkl file for test
YosuaMichael Apr 1, 2022
c806fb1
Merge branch 'main' into add-vit-swag-huge
YosuaMichael Apr 1, 2022
9ff5a76
Merge branch 'main' into add-vit-swag-huge
datumbox Apr 1, 2022
9f603d6
Remove legacy compatibility for ViT_H_14 model
YosuaMichael Apr 4, 2022
dd21912
Merge branch 'main' into add-vit-swag-huge
YosuaMichael Apr 4, 2022
e4062f4
Merge branch 'main' into add-vit-swag-huge
YosuaMichael Apr 4, 2022
02be296
Test vit_h_14 with smaller image_size to speedup the test
YosuaMichael Apr 4, 2022
87e6c2a
Merge branch 'main' into add-vit-swag-huge
YosuaMichael Apr 5, 2022
696201f
Merge branch 'main' into add-vit-swag-huge
YosuaMichael Apr 5, 2022
3 changes: 3 additions & 0 deletions docs/source/models.rst
@@ -92,6 +92,7 @@ You can construct a model with random weights by calling its constructor:
vit_b_32 = models.vit_b_32()
vit_l_16 = models.vit_l_16()
vit_l_32 = models.vit_l_32()
vit_h_14 = models.vit_h_14()
convnext_tiny = models.convnext_tiny()
convnext_small = models.convnext_small()
convnext_base = models.convnext_base()
@@ -213,6 +214,7 @@ vit_b_16 81.072 95.318
vit_b_32 75.912 92.466
vit_l_16 79.662 94.638
vit_l_32 76.972 93.070
vit_h_14 88.552 98.694
convnext_tiny 82.520 96.146
convnext_small 83.616 96.650
convnext_base 84.062 96.870
@@ -434,6 +436,7 @@ VisionTransformer
vit_b_32
vit_l_16
vit_l_32
vit_h_14

ConvNeXt
--------
1 change: 1 addition & 0 deletions hubconf.py
@@ -67,4 +67,5 @@
vit_b_32,
vit_l_16,
vit_l_32,
vit_h_14,
)
Binary file added test/expect/ModelTester.test_vit_h_14_expect.pkl
Binary file not shown.
4 changes: 4 additions & 0 deletions test/test_models.py
@@ -280,6 +280,10 @@ def _check_input_backprop(model, inputs):
"rpn_pre_nms_top_n_test": 1000,
"rpn_post_nms_top_n_test": 1000,
},
"vit_h_14": {
"image_size": 56,
"input_shape": (1, 3, 56, 56),
},
YosuaMichael (Contributor, Author) commented on Apr 4, 2022:

@datumbox this is according to your suggestion to change the input image_size for the vit_h_14 test in order to speed it up. The image_size needs to be a multiple of the patch_size, which is 14, hence we use an image_size of 56.
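As a quick illustration of the constraint, here is a minimal sketch (not part of this diff; it assumes kwargs such as image_size are forwarded to the VisionTransformer constructor, as in this PR):

import torch
from torchvision.models import vit_h_14

# Build the model without weights at the reduced test resolution.
# 56 is used because the input size must be divisible by the patch_size
# of 14 (56 = 4 * 14); a size like 60 would be rejected.
model = vit_h_14(weights=None, image_size=56).eval()

x = torch.rand(1, 3, 56, 56)  # matches the test's input_shape
with torch.no_grad():
    out = model(x)
print(out.shape)  # torch.Size([1, 1000])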

YosuaMichael (Contributor, Author) commented:

After reducing the image_size there is a speedup, although not a large one: I observed roughly 1.5 to 2 seconds saved on each of the model's GPU and CPU tests.

datumbox (Contributor) commented:

Very interesting. Does this mean that the majority of the time is spent on the model initialization or on the JIT-script parsing?

YosuaMichael (Contributor, Author) commented on Apr 5, 2022:

@datumbox I did a bit of profiling locally and here are the results:

[2022-04-04 16:48:41.768663] Before building the model
[2022-04-04 16:48:45.000341] After building model
[2022-04-04 16:48:45.002153] After model.eval().to(device=dev)
[2022-04-04 16:48:45.207033] After doing model(x)
[2022-04-04 16:48:45.208375] After assert expected
[2022-04-04 16:48:45.208385] After assert shape num_classes
[2022-04-04 16:48:50.452526] After check_jit_scripttable
[2022-04-04 16:48:50.667378] After check_fx_compatible
[2022-04-04 16:48:51.256744] Finish

It seems like around 35% of the time goes into building the model, and another 45% into check_jit_scriptable.
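For context, a rough sketch of how a breakdown like the one above can be collected; the log helper is hypothetical and not part of torchvision's test suite:

from datetime import datetime

import torch
from torchvision.models import vit_h_14

def log(msg: str) -> None:
    # Print a wall-clock timestamp per stage, as in the output quoted above.
    print(f"[{datetime.now()}] {msg}")

log("Before building the model")
model = vit_h_14(weights=None, image_size=56)
log("After building model")
model = model.eval()
log("After model.eval()")
out = model(torch.rand(1, 3, 56, 56))
log("After doing model(x)")
scripted = torch.jit.script(model)  # roughly what check_jit_scriptable exercises
log("After torch.jit.script(model)")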

}
# speeding up slow models:
slow_models = [
46 changes: 46 additions & 0 deletions torchvision/models/vision_transformer.py
@@ -20,10 +20,12 @@
"ViT_B_32_Weights",
"ViT_L_16_Weights",
"ViT_L_32_Weights",
"ViT_H_14_Weights",
"vit_b_16",
"vit_b_32",
"vit_l_16",
"vit_l_32",
"vit_h_14",
]


@@ -435,6 +437,27 @@ class ViT_L_32_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1


class ViT_H_14_Weights(WeightsEnum):
    IMAGENET1K_SWAG_V1 = Weights(
        url="https://download.pytorch.org/models/vit_h_14_swag-80465313.pth",
        transforms=partial(
            ImageClassification,
            crop_size=518,
            resize_size=518,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 633470440,
            "size": (518, 518),
            "min_size": (518, 518),
            "acc@1": 88.552,
            "acc@5": 98.694,
        },
    )
    DEFAULT = IMAGENET1K_SWAG_V1


@handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1))
def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
@@ -531,6 +554,29 @@ def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = Tru
)


def vit_h_14(*, weights: Optional[ViT_H_14_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_h_14 architecture from
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

    Args:
        weights (ViT_H_14_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = ViT_H_14_Weights.verify(weights)

    return _vision_transformer(
        patch_size=14,
        num_layers=32,
        num_heads=16,
        hidden_dim=1280,
        mlp_dim=5120,
        weights=weights,
        progress=progress,
        **kwargs,
    )


def interpolate_embeddings(
image_size: int,
patch_size: int,
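For reference, a minimal usage sketch of the new weights once this PR is merged, following torchvision's multi-weight API (the random tensor stands in for a real decoded image):

import torch
from torchvision.models import vit_h_14, ViT_H_14_Weights

# Load the SWAG checkpoint together with its bundled 518x518
# preprocessing (bicubic resize, center crop, normalization).
weights = ViT_H_14_Weights.IMAGENET1K_SWAG_V1
model = vit_h_14(weights=weights).eval()
preprocess = weights.transforms()

img = torch.rand(3, 600, 600)           # stand-in for a decoded image
batch = preprocess(img).unsqueeze(0)    # -> shape (1, 3, 518, 518)
with torch.no_grad():
    logits = model(batch)
print(logits.shape)                     # torch.Size([1, 1000])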