Skip to content

Commit

Permalink
MaxVit model (#6342)
Browse files Browse the repository at this point in the history
* Added maxvit architecture and tests

* rebased + addresed comments

* Revert "rebased + addresed comments"

This reverts commit c5b2839.

* Re-added model changes after revert

* aligned with partial original implementation

* removed submitit script fixed lint

* mypy fix for too many arguments

* updated old tests

* removed per batch lr scheduler and seed setting

* removed ontap

* added docs, validated weights

* fixed test expect, moved shape assertions to the beginning for torch.fx compatibility

* mypy fix

* lint fix

* added legacy interface

* added weight link

* updated docs

* Update references/classification/train.py

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* Update torchvision/models/maxvit.py

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* adressed comments

* update ra_magnitude and augmix_severity default values

* adressed some comments

* remove input_channels parameter

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
  • Loading branch information
TeodorPoncu and datumbox authored Sep 23, 2022
1 parent d65e286 commit 6b1646c
Show file tree
Hide file tree
Showing 9 changed files with 940 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ weights:
models/efficientnetv2
models/googlenet
models/inception
models/maxvit
models/mnasnet
models/mobilenetv2
models/mobilenetv3
Expand Down
23 changes: 23 additions & 0 deletions docs/source/models/maxvit.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
MaxVit
===============

.. currentmodule:: torchvision.models

The MaxVit transformer models are based on the `MaxViT: Multi-Axis Vision Transformer <https://arxiv.org/abs/2204.01697>`__
paper.


Model builders
--------------

The following model builders can be used to instantiate a MaxVit model, with or without pre-trained weights.
All the model builders internally rely on the ``torchvision.models.maxvit.MaxVit``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/maxvit.py>`_ for
more details about this class.

.. autosummary::
    :toctree: generated/
    :template: function.rst

    maxvit_t
8 changes: 8 additions & 0 deletions references/classification/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,14 @@ Here `$MODEL` is one of `swin_v2_t`, `swin_v2_s` or `swin_v2_b`.
Note that `--val-resize-size` was optimized in a post-training step, see their `Weights` entry for the exact value.


### MaxViT
```
torchrun --nproc_per_node=8 --nnodes=4 train.py \
--model $MODEL --epochs 400 --batch-size 128 --opt adamw --lr 3e-3 --weight-decay 0.05 --lr-scheduler cosineannealinglr --lr-min 1e-5 --lr-warmup-method linear --lr-warmup-epochs 32 --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 1.0 --interpolation bicubic --auto-augment ta_wide --policy-magnitude 15 --train-center-crop --model-ema --val-resize-size 224 \
--val-crop-size 224 --train-crop-size 224 --amp --model-ema-steps 32 --transformer-embedding-decay 0 --sync-bn
```
Here `$MODEL` is `maxvit_t`.
Note that `--val-resize-size` was not optimized in a post-training step.


### ShuffleNet V2
Expand Down
13 changes: 10 additions & 3 deletions references/classification/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,25 @@ def __init__(
interpolation=InterpolationMode.BILINEAR,
hflip_prob=0.5,
auto_augment_policy=None,
ra_magnitude=9,
augmix_severity=3,
random_erase_prob=0.0,
center_crop=False,
):
trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
# Pick the base spatial transform. ``center_crop`` is opt-in (default False),
# so the default path must remain the random-resized-crop augmentation and
# only an explicit ``center_crop=True`` switches to a plain center crop.
# The branches were previously swapped, which made the default behave as
# a center crop and the flag re-enable random cropping.
trans = (
    [transforms.CenterCrop(crop_size)]
    if center_crop
    else [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
)
if hflip_prob > 0:
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
if auto_augment_policy is not None:
if auto_augment_policy == "ra":
trans.append(autoaugment.RandAugment(interpolation=interpolation))
trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
elif auto_augment_policy == "ta_wide":
trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
elif auto_augment_policy == "augmix":
trans.append(autoaugment.AugMix(interpolation=interpolation))
trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity))
else:
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
Expand Down
25 changes: 22 additions & 3 deletions references/classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,12 @@ def _get_cache_path(filepath):
def load_data(traindir, valdir, args):
# Data loading code
print("Loading data")
val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
val_resize_size, val_crop_size, train_crop_size, center_crop = (
args.val_resize_size,
args.val_crop_size,
args.train_crop_size,
args.train_center_crop,
)
interpolation = InterpolationMode(args.interpolation)

print("Loading training data")
Expand All @@ -126,13 +131,18 @@ def load_data(traindir, valdir, args):
else:
auto_augment_policy = getattr(args, "auto_augment", None)
random_erase_prob = getattr(args, "random_erase", 0.0)
ra_magnitude = args.ra_magnitude
augmix_severity = args.augmix_severity
dataset = torchvision.datasets.ImageFolder(
traindir,
presets.ClassificationPresetTrain(
center_crop=center_crop,
crop_size=train_crop_size,
interpolation=interpolation,
auto_augment_policy=auto_augment_policy,
random_erase_prob=random_erase_prob,
ra_magnitude=ra_magnitude,
augmix_severity=augmix_severity,
),
)
if args.cache_dataset:
Expand Down Expand Up @@ -207,7 +217,10 @@ def main(args):
mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
if mixup_transforms:
mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
collate_fn = lambda batch: mixupcutmix(*default_collate(batch)) # noqa: E731

def collate_fn(batch):
    """Collate ``batch`` with ``default_collate`` and pass the result to ``mixupcutmix``.

    The collated output is unpacked into positional arguments for the
    ``mixupcutmix`` transform chosen in the enclosing scope.
    """
    collated = default_collate(batch)
    return mixupcutmix(*collated)

data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=args.batch_size,
Expand Down Expand Up @@ -448,6 +461,8 @@ def get_args_parser(add_help=True):
action="store_true",
)
parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy")
parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy")
parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")

# Mixed precision training parameters
Expand Down Expand Up @@ -486,13 +501,17 @@ def get_args_parser(add_help=True):
parser.add_argument(
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
)
parser.add_argument(
"--train-center-crop",
action="store_true",
help="use center crop instead of random crop for training (default: False)",
)
parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
parser.add_argument(
"--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

return parser


Expand Down
Binary file added test/expect/ModelTester.test_maxvit_t_expect.pkl
Binary file not shown.
46 changes: 46 additions & 0 deletions test/test_architecture_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import unittest

import pytest
import torch

from torchvision.models.maxvit import SwapAxes, WindowDepartition, WindowPartition


class MaxvitTester(unittest.TestCase):
    """Round-trip checks for the MaxVit partitioning helper modules."""

    def test_maxvit_window_partition(self):
        """Partitioning followed by departitioning must reproduce the input."""
        shape = (1, 3, 224, 224)
        window = 7
        windows_per_side = shape[3] // window

        original = torch.randn(shape)

        to_windows = WindowPartition()
        from_windows = WindowDepartition()

        windows = to_windows(original, window)
        restored = from_windows(windows, window, windows_per_side, windows_per_side)

        assert torch.allclose(original, restored)

    def test_maxvit_grid_partition(self):
        """Grid-style round trip: partition by the partition count, swap axes twice, departition."""
        shape = (1, 3, 224, 224)
        window = 7
        windows_per_side = shape[3] // window

        original = torch.randn(shape)
        swap_in = SwapAxes(-2, -3)
        swap_out = SwapAxes(-2, -3)

        to_windows = WindowPartition()
        from_windows = WindowDepartition()

        # The roles of window size and window count are exchanged relative to
        # the window-partition test above.
        grid = to_windows(original, windows_per_side)
        # Swapping the same pair of axes twice returns to the original layout.
        grid = swap_out(swap_in(grid))
        restored = from_windows(grid, windows_per_side, window, window)

        assert torch.allclose(original, restored)


# Allow running this test module directly as a script by handing it to pytest.
if __name__ == "__main__":
    pytest.main([__file__])
1 change: 1 addition & 0 deletions torchvision/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@
from .vgg import *
from .vision_transformer import *
from .swin_transformer import *
from .maxvit import *
from . import detection, optical_flow, quantization, segmentation, video
from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models
Loading

0 comments on commit 6b1646c

Please sign in to comment.