Add pretrained MAE weights, option to load checkpoints in ViT builder (#479)

Summary:
In MAE fine-tuning, only the encoder (the ViT) is fine-tuned. This change allows MAE pretrained weights to be loaded directly into our ViT class.
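
For illustration, a minimal usage sketch of the new encoder builders (assumes a build of torchmultimodal that includes this change and network access to download.pytorch.org):

```
from torchmultimodal.models.masked_auto_encoder.model import (
    vit_b_16_image_mae_encoder,
    vit_l_16_image_mae_encoder,
)

# Each builder returns a plain VisionTransformer. With pretrained=True, the
# matching MAE checkpoint from MAE_MODEL_MAPPING is downloaded and loaded
# into the model before it is returned.
vit_b = vit_b_16_image_mae_encoder(pretrained=True)
vit_l = vit_l_16_image_mae_encoder(pretrained=True)
```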

Pull Request resolved: #479

Test Plan:
```
python -m pytest -v tests/models/*
...
========== 207 passed, 25 warnings in 424.67s (0:07:04) ===========================

python -m pytest -v tests/modules/*
...
======================== 192 passed, 2 skipped, 22 warnings in 10.75s ==========================
```

Test instantiating ViT using MAE pretrained weights for each of the 3 checkpoints:

![Screenshot: instantiating ViT from MAE pretrained weights for each of the 3 checkpoints](https://github.com/facebookresearch/multimodal/assets/24319399/c159b2dd-0b04-4572-85b9-d3024eee9a53)

Reviewed By: kartikayk

Differential Revision: D50015711

Pulled By: ebsmothers

fbshipit-source-id: e09fd02560b31574427b9f66373f12e7fd663f06
ebsmothers authored and facebook-github-bot committed Oct 6, 2023
1 parent 0de91e1 commit 6f32ca1
Showing 2 changed files with 36 additions and 0 deletions.
32 changes: 32 additions & 0 deletions torchmultimodal/models/masked_auto_encoder/model.py
@@ -13,13 +13,25 @@
get_2d_sin_cos_embeddings,
)
from torchmultimodal.models.masked_auto_encoder.swin_decoder import SwinTransformer
from torchmultimodal.modules.encoders.vision_transformer import (
VisionTransformer,
vit_b_16,
vit_l_16,
)
from torchmultimodal.modules.layers.patch_embedding import PatchEmbeddings
from torchmultimodal.modules.layers.transformer import (
TransformerEncoder,
TransformerOutput,
)


MAE_MODEL_MAPPING = {
"vit_b16_image": "https://download.pytorch.org/models/multimodal/mae/mae_pretrained_vit_base.pth",
"vit_l16_image": "https://download.pytorch.org/models/multimodal/mae/mae_pretrained_vit_large.pth",
"vit_b16_audio": "https://download.pytorch.org/models/multimodal/audio_mae/audio_mae_pretrained_vit_base.pth",
}


class MAEOutput(NamedTuple):
encoder_output: Union[TransformerOutput, Tensor]
decoder_pred: Optional[Tensor] = None
@@ -324,6 +336,16 @@ def vit_l_16_image_mae() -> MaskedAutoEncoder:
)


def vit_b_16_image_mae_encoder(pretrained: bool = False) -> VisionTransformer:
ckpt_path = MAE_MODEL_MAPPING["vit_b16_image"] if pretrained else None
return vit_b_16(final_layer_norm_eps=None, ckpt_path=ckpt_path)


def vit_l_16_image_mae_encoder(pretrained: bool = False) -> VisionTransformer:
ckpt_path = MAE_MODEL_MAPPING["vit_l16_image"] if pretrained else None
return vit_l_16(final_layer_norm_eps=None, ckpt_path=ckpt_path)


def audio_mae(
*,
# patch embedding
@@ -449,3 +471,13 @@ def vit_l_16_audio_mae() -> MaskedAutoEncoder:
decoder_heads=16,
decoder_dim_feedforward=2048,
)


def vit_b_16_audio_mae_encoder(pretrained: bool = False) -> VisionTransformer:
ckpt_path = MAE_MODEL_MAPPING["vit_b16_audio"] if pretrained else None
return vit_b_16(
final_layer_norm_eps=None,
num_channels=1,
image_size=(1024, 128),
ckpt_path=ckpt_path,
)
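
As a sketch of how the audio variant differs: it configures a single-channel ViT over (1024, 128) spectrogram-shaped inputs. A forward pass might look like the following, assuming the encoder accepts a standard (batch, channels, height, width) tensor; variable names here are illustrative:

```
import torch

from torchmultimodal.models.masked_auto_encoder.model import (
    vit_b_16_audio_mae_encoder,
)

# The audio builder configures vit_b_16 for 1-channel inputs of size
# (1024, 128), matching patchified audio spectrograms.
encoder = vit_b_16_audio_mae_encoder(pretrained=False)
spectrogram = torch.randn(2, 1, 1024, 128)  # (batch, channels, time, freq)
output = encoder(spectrogram)  # expected to be a TransformerOutput-style object
```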
4 changes: 4 additions & 0 deletions torchmultimodal/modules/encoders/vision_transformer.py
@@ -14,6 +14,7 @@
TransformerEncoder,
TransformerOutput,
)
from torchmultimodal.utils.common import load_module_from_url


class VisionTransformer(nn.Module):
@@ -148,6 +149,7 @@ def vision_transformer(
drop_path_rate: Optional[float] = None,
patch_drop_rate: Optional[Union[float, Tuple[float, float]]] = None,
pooler: Optional[nn.Module] = None,
ckpt_path: Optional[str] = None,
) -> VisionTransformer:
"""
Args:
@@ -198,6 +200,8 @@
vit = VisionTransformer(
embeddings=image_embedding, encoder=transformer_encoder, pooler=pooler
)
if ckpt_path:
load_module_from_url(vit, ckpt_path)
return vit


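For context, a hedged sketch of what a loader like load_module_from_url typically does (the actual torchmultimodal implementation may differ, e.g. in caching, key remapping, or strictness); the helper name below is illustrative, not the library's code:

```
import torch
from torch import nn


def load_module_from_url_sketch(module: nn.Module, url: str) -> None:
    # Download the checkpoint (cached under the torch hub directory) and
    # copy its weights into the module in place.
    state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
    module.load_state_dict(state_dict)
```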
