From 1f825a3f1b124fca95f0f652a9f57f6758f55233 Mon Sep 17 00:00:00 2001
From: Daniel Bolya
Date: Thu, 20 Jul 2023 21:17:05 -0700
Subject: [PATCH] [v0.1.2] Added the full model zoo and video MAE models

---
 CHANGELOG.md         |  4 ++
 README.md            | 50 ++++++++++++++--------
 hiera/__init__.py    |  5 +++
 hiera/hiera.py       | 23 +++++++++--
 hiera/hiera_mae.py   | 98 +++++++++++++++++++++++++++++++++++++++++---
 hiera/hiera_utils.py |  4 ++
 setup.py             |  2 +-
 7 files changed, 159 insertions(+), 27 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 59e2ed2..803c256 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,7 @@
+### **[2023.07.20]** v0.1.2
+ - Released the full model zoo.
+ - Added MAE functionality to the video models.
+
 ### **[2023.06.12]** v0.1.1
  - Added the ability to specify multiple pretrained checkpoints per architecture (specify with `checkpoint=`).
  - Added the ability to pass `strict=False` to a pretrained model so that you can use a different number of classes. **Note:** when changing the number of classes, the head layer will be reset.

diff --git a/README.md b/README.md
index 559f278..51b40b0 100644
--- a/README.md
+++ b/README.md
@@ -63,35 +63,37 @@ python setup.py build develop
 
 ## Model Zoo
 
-Here we provide model checkpoints for Hiera. Each model listed is accessible on [torch hub](https://pytorch.org/docs/stable/hub.html), e.g.:
+Here we provide model checkpoints for Hiera. Each model listed is accessible on [torch hub](https://pytorch.org/docs/stable/hub.html) even without the `hiera-transformer` package installed, e.g., the following initializes a base model pretrained and finetuned on ImageNet-1k:
 ```py
 model = torch.hub.load("facebookresearch/hiera", model="hiera_base_224", pretrained=True, checkpoint="mae_in1k_ft_in1k")
 ```
-For model names and corresponding checkpoint names see below.
-**Note:** the speeds listed here were benchmarked _without_ PyTorch's optimized [scaled dot product attention](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html). If using PyTorch 2.0 or above, your inference speed will probably be faster than what's listed here.
+If you want a model with MAE pretraining only, you can replace the checkpoint with `"mae_in1k"`. Additionally, if you'd like to load the MAE decoder as well (e.g., to continue pretraining), add `mae_` to the start of the model name, e.g.:
+```py
+model = torch.hub.load("facebookresearch/hiera", model="mae_hiera_base_224", pretrained=True, checkpoint="mae_in1k")
+```
+**Note:** Our MAE models were trained with a _normalized pixel loss_. That means that the patches were normalized before the network had to predict them. If you want to visualize the predictions, you'll have to unnormalize them using the visible patches (which might work but wouldn't be perfect) or using the ground truth. For more model names and corresponding checkpoint names, see the tables below.
 
-#### Coming Soon
-As of now, base finetuned models are available. The rest are coming soon.
+**Note:** the speeds listed here were benchmarked _without_ PyTorch's optimized [scaled dot product attention](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html). If using PyTorch 2.0 or above, your inference speed will probably be faster than what's listed here.
 
 ### Image Models
 
 | Model    | Model Name            | Pretrained Models<br>(IN-1K MAE) | Finetuned Models<br>(IN-1K Supervised) | IN-1K<br>Top-1 (%) | A100 fp16<br>Speed (im/s) |
 |----------|-----------------------|----------------------------------|----------------------------------------|:------------------:|:-------------------------:|
-| Hiera-T  | `hiera_tiny_224`      | Coming Soon | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_tiny_224.pth) | 82.8 | 2758 |
-| Hiera-S  | `hiera_small_224`     | Coming Soon | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_small_224.pth) | 83.8 | 2211 |
-| Hiera-B  | `hiera_base_224`      | Coming Soon | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_base_224.pth) | 84.5 | 1556 |
-| Hiera-B+ | `hiera_base_plus_224` | Coming Soon | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_base_plus_224.pth) | 85.2 | 1247 |
-| Hiera-L  | `hiera_large_224`     | Coming Soon | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_large_224.pth) | 86.1 | 531 |
-| Hiera-H  | `hiera_huge_224`      | Coming Soon | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_huge_224.pth) | 86.9 | 274 |
+| Hiera-T  | `hiera_tiny_224`      | [mae_in1k](https://dl.fbaipublicfiles.com/hiera/mae_hiera_tiny_224.pth) | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_tiny_224.pth) | 82.8 | 2758 |
+| Hiera-S  | `hiera_small_224`     | [mae_in1k](https://dl.fbaipublicfiles.com/hiera/mae_hiera_small_224.pth) | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_small_224.pth) | 83.8 | 2211 |
+| Hiera-B  | `hiera_base_224`      | [mae_in1k](https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_224.pth) | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_base_224.pth) | 84.5 | 1556 |
+| Hiera-B+ | `hiera_base_plus_224` | [mae_in1k](https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_224.pth) | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_base_plus_224.pth) | 85.2 | 1247 |
+| Hiera-L  | `hiera_large_224`     | [mae_in1k](https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_224.pth) | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_large_224.pth) | 86.1 | 531 |
+| Hiera-H  | `hiera_huge_224`      | [mae_in1k](https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_224.pth) | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_huge_224.pth) | 86.9 | 274 |
 
 Each model inputs a 224x224 image.
 
 ### Video Models
 
 | Model    | Model Name               | Pretrained Models<br>(K400 MAE) | Finetuned Models<br>(K400) | K400 (3x5 views)<br>Top-1 (%) | A100 fp16<br>Speed (clip/s) |
 |----------|--------------------------|---------------------------------|----------------------------|:-----------------------------:|:---------------------------:|
-| Hiera-B  | `hiera_base_16x224`      | Coming Soon | [mae_k400_ft_k400](https://dl.fbaipublicfiles.com/hiera/hiera_base_16x224.pth) | 84.0 | 133.6 |
-| Hiera-B+ | `hiera_base_plus_16x224` | Coming Soon | Coming Soon | 85.0 | 84.1 |
-| Hiera-L  | `hiera_large_16x224`     | Coming Soon | Coming Soon | 87.3 | 40.8 |
-| Hiera-H  | `hiera_huge_16x224`      | Coming Soon | Coming Soon | 87.8 | 20.9 |
+| Hiera-B  | `hiera_base_16x224`      | [mae_k400](https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_16x224.pth) | [mae_k400_ft_k400](https://dl.fbaipublicfiles.com/hiera/hiera_base_16x224.pth) | 84.0 | 133.6 |
+| Hiera-B+ | `hiera_base_plus_16x224` | [mae_k400](https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_16x224.pth) | [mae_k400_ft_k400](https://dl.fbaipublicfiles.com/hiera/hiera_base_plus_16x224.pth) | 85.0 | 84.1 |
+| Hiera-L  | `hiera_large_16x224`     | [mae_k400](https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_16x224.pth) | [mae_k400_ft_k400](https://dl.fbaipublicfiles.com/hiera/hiera_large_16x224.pth) | 87.3 | 40.8 |
+| Hiera-H  | `hiera_huge_16x224`      | [mae_k400](https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_16x224.pth) | [mae_k400_ft_k400](https://dl.fbaipublicfiles.com/hiera/hiera_huge_16x224.pth) | 87.8 | 20.9 |
 
 Each model inputs 16 224x224 frames with a temporal stride of 4.
 
@@ -103,9 +105,9 @@ This repo implements the code to run Hiera models for inference. This repository
   - [x] Image Inference
     - [x] MAE implementation
   - [x] Video Inference
-    - [ ] MAE implementation
+    - [x] MAE implementation
+  - [x] Full Model Zoo
   - [ ] Training scripts
-  - [ ] Full Model Zoo
 
 See [examples](https://github.com/facebookresearch/hiera/tree/main/examples) for examples of how to use Hiera.
 
@@ -130,6 +132,20 @@ Video inference works the same way, just use a `16x224` model instead.
 output, intermediates = model(x, return_intermediates=True)
 ```
 
+#### MAE Inference
+By default, the models do not include the MAE decoder. If you would like to use the decoder or compute the MAE loss, you can instantiate an MAE version by running:
+```py
+import hiera
+model = hiera.mae_hiera_base_224(pretrained=True, checkpoint="mae_in1k")
+```
+Then when you run inference on the model, it will return a 4-tuple of `(loss, predictions, labels, mask)`, where `predictions` and `labels` are for the _deleted tokens_ only. The returned mask will be `True` if the token is visible and `False` if it's deleted. You can change the masking ratio by passing it during inference:
+```py
+loss, preds, labels, mask = model(x, mask_ratio=0.6)
+```
+The default mask ratio is `0.6` for images, but you should pass in `0.9` for video. See the paper for details.
+
+**Note:** We use _normalized pixel targets_ for MAE pretraining, meaning the patches are each individually normalized before the model has to predict them. Thus, you have to unnormalize them using the ground truth before visualizing them. See `get_pixel_label_2d` in `hiera_mae.py` for details.
+
 ### Benchmarking
 We provide a script for easy benchmarking. See [examples/benchmark](https://github.com/facebookresearch/hiera/blob/main/examples/benchmark.ipynb) for how to use it.
 
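As an aside (not part of the patch itself), the MAE inference steps above carry over directly to the new video models. The following is a minimal, illustrative sketch that loads the K400-pretrained video MAE model and runs it with the recommended `0.9` mask ratio; the `(batch, channels, frames, height, width)` clip layout is an assumption here, not something stated in the patch.

```py
# Illustrative sketch: video MAE inference at the recommended 0.9 mask ratio.
import torch
import hiera

model = hiera.mae_hiera_base_16x224(pretrained=True, checkpoint="mae_k400")
model.eval()

# One random 16-frame, 224x224 clip (assumed layout: batch, channels, frames, height, width).
clip = torch.rand(1, 3, 16, 224, 224)

with torch.no_grad():
    loss, preds, labels, mask = model(clip, mask_ratio=0.9)

# preds and labels cover only the deleted tokens; mask is True for visible tokens.
print(loss.item(), preds.shape, labels.shape, mask.shape)
```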
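Similarly, because of the normalized pixel targets, predictions have to be mapped back to raw pixel space before visualization. Below is a rough sketch of that un-normalization using ground-truth patch statistics; `gt_patches` is a hypothetical tensor of raw ground-truth pixels for the same tokens as `preds` (e.g., labels computed with `norm=False`), and the `1e-6` epsilon mirrors the normalization in `get_pixel_label_2d` / `get_pixel_label_3d`.

```py
# Illustrative sketch: undo the per-patch normalization for visualization.
# gt_patches: hypothetical (num_masked_tokens, pixels_per_token) tensor holding
# the raw ground-truth pixels for the same tokens as `preds`.
mean = gt_patches.mean(dim=-1, keepdim=True)
std = (gt_patches.var(dim=-1, keepdim=True) + 1.0e-6) ** 0.5
preds_pixels = preds * std + mean  # predictions back in raw pixel space
```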
diff --git a/hiera/__init__.py b/hiera/__init__.py index eecefed..76e454c 100644 --- a/hiera/__init__.py +++ b/hiera/__init__.py @@ -34,5 +34,10 @@ mae_hiera_large_224, mae_hiera_huge_224, + mae_hiera_base_16x224, + mae_hiera_base_plus_16x224, + mae_hiera_large_16x224, + mae_hiera_huge_16x224, + MaskedAutoencoderHiera, ) \ No newline at end of file diff --git a/hiera/hiera.py b/hiera/hiera.py index 523c8df..35e8c93 100644 --- a/hiera/hiera.py +++ b/hiera/hiera.py @@ -233,6 +233,7 @@ def __init__( super().__init__() depth = sum(stages) + self.patch_stride = patch_stride self.tokens_spatial_shape = [i // s for i, s in zip(input_size, patch_stride)] num_tokens = math.prod(self.tokens_spatial_shape) flat_mu_size = math.prod(mask_unit_size) @@ -438,6 +439,7 @@ def forward( @pretrained_model({ "mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_tiny_224.pth", + "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_tiny_224.pth", }, default="mae_in1k_ft_in1k") def hiera_tiny_224(**kwdargs): return Hiera(embed_dim=96, num_heads=1, stages=(1, 2, 7, 2), **kwdargs) @@ -445,6 +447,7 @@ def hiera_tiny_224(**kwdargs): @pretrained_model({ "mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_small_224.pth", + "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_small_224.pth", }, default="mae_in1k_ft_in1k") def hiera_small_224(**kwdargs): return Hiera(embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), **kwdargs) @@ -452,6 +455,7 @@ def hiera_small_224(**kwdargs): @pretrained_model({ "mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_base_224.pth", + "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_224.pth", }, default="mae_in1k_ft_in1k") def hiera_base_224(**kwdargs): return Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), **kwdargs) @@ -459,6 +463,7 @@ def hiera_base_224(**kwdargs): @pretrained_model({ "mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_base_plus_224.pth", + "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_224.pth", }, default="mae_in1k_ft_in1k") def hiera_base_plus_224(**kwdargs): return Hiera(embed_dim=112, num_heads=2, stages=(2, 3, 16, 3), **kwdargs) @@ -466,6 +471,7 @@ def hiera_base_plus_224(**kwdargs): @pretrained_model({ "mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_large_224.pth", + "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_224.pth", }, default="mae_in1k_ft_in1k") def hiera_large_224(**kwdargs): return Hiera(embed_dim=144, num_heads=2, stages=(2, 6, 36, 4), **kwdargs) @@ -473,6 +479,7 @@ def hiera_large_224(**kwdargs): @pretrained_model({ "mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_huge_224.pth", + "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_224.pth", }, default="mae_in1k_ft_in1k") def hiera_huge_224(**kwdargs): return Hiera(embed_dim=256, num_heads=4, stages=(2, 6, 36, 4), **kwdargs) @@ -482,6 +489,7 @@ def hiera_huge_224(**kwdargs): @pretrained_model({ "mae_k400_ft_k400": "https://dl.fbaipublicfiles.com/hiera/hiera_base_16x224.pth", + "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_16x224.pth", }, default="mae_k400_ft_k400") def hiera_base_16x224(num_classes: int = 400, **kwdargs): return Hiera( @@ -497,21 +505,30 @@ def hiera_base_16x224(num_classes: int = 400, **kwdargs): ) -@pretrained_model(None) +@pretrained_model({ + "mae_k400_ft_k400": "https://dl.fbaipublicfiles.com/hiera/hiera_base_plus_16x224.pth", + "mae_k400": 
"https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_16x224.pth", +}, default="mae_k400_ft_k400") def hiera_base_plus_16x224(**kwdargs): return hiera_base_16x224( embed_dim=112, num_heads=2, stages=(2, 3, 16, 3), **kwdargs ) -@pretrained_model(None) +@pretrained_model({ + "mae_k400_ft_k400": "https://dl.fbaipublicfiles.com/hiera/hiera_large_16x224.pth", + "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_16x224.pth", +}, default="mae_k400_ft_k400") def hiera_large_16x224(**kwdargs): return hiera_base_16x224( embed_dim=144, num_heads=2, stages=(2, 6, 36, 4), **kwdargs ) -@pretrained_model(None) +@pretrained_model({ + "mae_k400_ft_k400": "https://dl.fbaipublicfiles.com/hiera/hiera_huge_16x224.pth", + "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_16x224.pth", +}, default="mae_k400_ft_k400") def hiera_huge_16x224(**kwdargs): return hiera_base_16x224( embed_dim=256, num_heads=4, stages=(2, 6, 36, 4), **kwdargs diff --git a/hiera/hiera_mae.py b/hiera/hiera_mae.py index 650c4ae..64c69cc 100644 --- a/hiera/hiera_mae.py +++ b/hiera/hiera_mae.py @@ -161,6 +161,28 @@ def get_pixel_label_2d( return label + def get_pixel_label_3d( + self, input_vid: torch.Tensor, mask: torch.Tensor, norm: bool = True + ) -> torch.Tensor: + # mask (boolean tensor): True must correspond to *masked* + + # We use time strided loss, only take the first frame from each token + input_vid = input_vid[:, :, ::self.patch_stride[0], :, :] + + size = self.pred_stride + label = input_vid.unfold(3, size, size).unfold(4, size, size) + label = label.permute(0, 2, 3, 4, 5, 6, 1) # Different from 2d, mistake during training lol + label = label.flatten(1, 3).flatten(2) + label = label[mask] + + if norm: + mean = label.mean(dim=-1, keepdim=True) + var = label.var(dim=-1, keepdim=True) + label = (label - mean) / (var + 1.0e-6) ** 0.5 + + return label + + def forward_encoder( self, x: torch.Tensor, mask_ratio: float, mask: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -242,6 +264,8 @@ def forward_loss( """ if len(self.q_stride) == 2: label = self.get_pixel_label_2d(x, mask) + elif len(self.q_stride) == 3: + label = self.get_pixel_label_3d(x, mask) else: raise NotImplementedError @@ -270,43 +294,105 @@ def forward( # Image Models -@pretrained_model(None) +@pretrained_model({ + "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_tiny_224.pth", +}, default="mae_in1k") def mae_hiera_tiny_224(**kwargs): return MaskedAutoencoderHiera( embed_dim=96, num_heads=1, stages=(1, 2, 7, 2), q_pool=2, **kwargs, ) -@pretrained_model(None) +@pretrained_model({ + "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_small_224.pth", +}, default="mae_in1k") def mae_hiera_small_224(**kwargs): return MaskedAutoencoderHiera( embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), q_pool=2, **kwargs, ) -@pretrained_model(None) +@pretrained_model({ + "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_224.pth", +}, default="mae_in1k") def mae_hiera_base_224(**kwargs): return MaskedAutoencoderHiera( embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), q_pool=2, **kwargs, ) -@pretrained_model(None) +@pretrained_model({ + "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_224.pth", +}, default="mae_in1k") def mae_hiera_base_plus_224(**kwargs): return MaskedAutoencoderHiera( embed_dim=112, num_heads=2, stages=(2, 3, 16, 3), q_pool=2, **kwargs, ) -@pretrained_model(None) +@pretrained_model({ + "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_224.pth", 
+}, default="mae_in1k")
 def mae_hiera_large_224(**kwargs):
     return MaskedAutoencoderHiera(
         embed_dim=144, num_heads=2, stages=(2, 6, 36, 4), q_pool=2, **kwargs,
     )
 
 
-@pretrained_model(None)
+@pretrained_model({
+    "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_224.pth",
+}, default="mae_in1k")
 def mae_hiera_huge_224(**kwargs):
     return MaskedAutoencoderHiera(
         embed_dim=256, num_heads=4, stages=(2, 6, 36, 4), q_pool=2, **kwargs,
     )
+
+
+
+# Video Models
+
+@pretrained_model({
+    "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_16x224.pth",
+}, default="mae_k400")
+def mae_hiera_base_16x224(num_classes: int = 400, **kwdargs):
+    return MaskedAutoencoderHiera(
+        num_classes=num_classes,  # K400 has 400 classes
+        input_size=(16, 224, 224),
+        q_stride=(1, 2, 2),
+        mask_unit_size=(1, 8, 8),
+        patch_kernel=(3, 7, 7),
+        patch_stride=(2, 4, 4),
+        patch_padding=(1, 3, 3),
+        sep_pos_embed=True,
+        q_pool=2,
+        **kwdargs
+    )
+
+
+@pretrained_model({
+    "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_16x224.pth",
+}, default="mae_k400")
+def mae_hiera_base_plus_16x224(**kwdargs):
+    return mae_hiera_base_16x224(
+        embed_dim=112, num_heads=2, stages=(2, 3, 16, 3), **kwdargs
+    )
+
+
+@pretrained_model({
+    "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_16x224.pth",
+}, default="mae_k400")
+def mae_hiera_large_16x224(**kwdargs):
+    return mae_hiera_base_16x224(
+        embed_dim=144, num_heads=2, stages=(2, 6, 36, 4), **kwdargs
+    )
+
+
+@pretrained_model({
+    "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_16x224.pth",
+}, default="mae_k400")
+def mae_hiera_huge_16x224(**kwdargs):
+    return mae_hiera_base_16x224(
+        embed_dim=256, num_heads=4, stages=(2, 6, 36, 4), **kwdargs
+    )
diff --git a/hiera/hiera_utils.py b/hiera/hiera_utils.py
index 49b1821..62d087b 100644
--- a/hiera/hiera_utils.py
+++ b/hiera/hiera_utils.py
@@ -52,6 +52,10 @@ def model_def(pretrained: bool = False, checkpoint: str = default, strict: bool
         model = model_func(**kwdargs)
 
         if pretrained:
+            # Disable strict loading when trying to load an encoder-decoder model into an encoder-only model
+            if "decoder_pos_embed" in state_dict["model_state"] and not hasattr(model, "decoder_pos_embed"):
+                strict = False
+
             model.load_state_dict(state_dict["model_state"], strict=strict)
 
         return model
diff --git a/setup.py b/setup.py
index 83866c2..f51318b 100644
--- a/setup.py
+++ b/setup.py
@@ -9,7 +9,7 @@
 
 setup(
     name="hiera-transformer",
-    version="0.1.1",
+    version="0.1.2",
     author="Chaitanya Ryali, Daniel Bolya",
     url="https://github.com/facebookresearch/hiera",
     description="A fast, powerful, and simple hierarchical vision transformer",
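For context on the `hiera_utils.py` change above: it is what allows the new `mae_in1k` / `mae_k400` checkpoints to be loaded into the plain encoder-only models without passing `strict=False` by hand. A minimal, illustrative sketch (assuming the `hiera-transformer` package is installed):

```py
# Illustrative sketch: load an MAE (encoder-decoder) checkpoint into an
# encoder-only Hiera. Strict loading is dropped automatically because the
# checkpoint contains "decoder_pos_embed" while the model has no decoder,
# so the decoder weights are simply ignored.
import hiera

encoder_only = hiera.hiera_base_224(pretrained=True, checkpoint="mae_in1k")
encoder_only.eval()
```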
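Finally, the checkpoint plumbing used throughout this patch is just the `pretrained_model` decorator from `hiera_utils.py`, which maps checkpoint names to download URLs. The sketch below applies the same pattern to a hypothetical custom config; the URL and import paths here are assumptions for illustration, not part of the release.

```py
# Illustrative sketch: register a checkpoint-name -> URL mapping for a custom
# config using the same decorator pattern as hiera.py / hiera_mae.py above.
from hiera.hiera import Hiera
from hiera.hiera_utils import pretrained_model


@pretrained_model({
    "my_checkpoint": "https://example.com/my_hiera_weights.pth",  # hypothetical URL
}, default="my_checkpoint")
def hiera_base_custom(**kwdargs):
    return Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), **kwdargs)


# Loads the registered checkpoint by name, exactly like the built-in models:
# model = hiera_base_custom(pretrained=True, checkpoint="my_checkpoint")
```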