MaxVit model #6342
Changes from 27 commits
f15fd92
c5b2839
5e8a222
aa95139
1fddecc
b7f0e97
872f40f
f561edf
314b82a
a4863e9
c4406e4
2111680
cc51c2b
d2dfe71
328f9b6
b334b7f
ebb8c16
e281371
20422bc
9ad86fe
775990c
a24e549
bb42548
ed21d3d
09e4ced
521d6d5
79cb004
97cbcd8
9fc6a5b
6b00ca8
45d3966
2aca920
cab35c1
@@ -0,0 +1,23 @@
MaxVit
===============

.. currentmodule:: torchvision.models

The MaxVit transformer models are based on the `MaxViT: Multi-Axis Vision Transformer <https://arxiv.org/abs/2204.01697>`__
paper.


Model builders
--------------

The following model builders can be used to instantiate a MaxVit model with and without pre-trained weights.
All the model builders internally rely on the ``torchvision.models.maxvit.MaxVit``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/maxvit.py>`_ for
more details about this class.

.. autosummary::
    :toctree: generated/
    :template: function.rst

    maxvit_t
@@ -0,0 +1,46 @@
import unittest

import pytest
import torch

from torchvision.models.maxvit import SwapAxes, WindowDepartition, WindowPartition


class MaxvitTester(unittest.TestCase):
I understand that here you are testing specific layers from MaxViT. This is not something we did previously, so perhaps it does need to be on a separate file. @YosuaMichael any thoughts here?
Sorry for a pretty late response!
    def test_maxvit_window_partition(self):
        input_shape = (1, 3, 224, 224)
        partition_size = 7
        n_partitions = input_shape[3] // partition_size

        x = torch.randn(input_shape)

        partition = WindowPartition()
        departition = WindowDepartition()

        x_hat = partition(x, partition_size)
        x_hat = departition(x_hat, partition_size, n_partitions, n_partitions)

        assert torch.allclose(x, x_hat)

    def test_maxvit_grid_partition(self):
        input_shape = (1, 3, 224, 224)
        partition_size = 7
        n_partitions = input_shape[3] // partition_size

        x = torch.randn(input_shape)
        pre_swap = SwapAxes(-2, -3)
        post_swap = SwapAxes(-2, -3)

        partition = WindowPartition()
        departition = WindowDepartition()

        x_hat = partition(x, n_partitions)
        x_hat = pre_swap(x_hat)
        x_hat = post_swap(x_hat)
        x_hat = departition(x_hat, n_partitions, partition_size, partition_size)

        assert torch.allclose(x, x_hat)


if __name__ == "__main__":
    pytest.main([__file__])
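For readers who want to see what the round trip in these tests exercises, here is a minimal standalone sketch of window partitioning and its inverse in plain PyTorch. The function names `window_partition` and `window_departition` and the `(B, windows, P*P, C)` layout are assumptions made for this illustration; they are not claimed to match the `WindowPartition` and `WindowDepartition` modules in the PR line for line.

```python
import torch


def window_partition(x: torch.Tensor, p: int) -> torch.Tensor:
    """Split (B, C, H, W) into non-overlapping p x p windows.

    Returns a tensor of shape (B, (H // p) * (W // p), p * p, C).
    Hypothetical helper written for this example.
    """
    b, c, h, w = x.shape
    x = x.reshape(b, c, h // p, p, w // p, p)
    # Move the two window-index axes forward and the channels last.
    x = x.permute(0, 2, 4, 3, 5, 1)
    return x.reshape(b, (h // p) * (w // p), p * p, c)


def window_departition(x: torch.Tensor, p: int, n_h: int, n_w: int) -> torch.Tensor:
    """Inverse of window_partition: rebuild the (B, C, H, W) tensor."""
    b, _, _, c = x.shape
    x = x.reshape(b, n_h, n_w, p, p, c)
    # Restore the (B, C, n_h, p, n_w, p) ordering before merging axes.
    x = x.permute(0, 5, 1, 3, 2, 4)
    return x.reshape(b, c, n_h * p, n_w * p)


# Small shapes keep the example fast; the tests above use 224x224 inputs.
x = torch.randn(2, 3, 8, 8)
windows = window_partition(x, 4)
restored = window_departition(windows, 4, 2, 2)
print(windows.shape, torch.equal(x, restored))
```

Because both functions only reshape and permute, the round trip reproduces the input exactly, which is the property the two tests assert with `torch.allclose`.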
@TeodorPoncu I think this is a bug. I believe you meant to write:
Could you please confirm?
Edit: I issued a fix at #6642
Yes, sorry for that. You've correctly guessed what I wanted to write. Thanks for catching it out. I think the `--train-center-crop` flag should be removed from the training command docs as well to reflect the way the weights were trained.
@TeodorPoncu thanks for coming back to me. Does this mean that you didn't actually use the flag during training? Can we remove it?
Yes, the flag can be removed.
That's not what I see on the training log of the trained model. I see that `train_center_crop=True`. Do we have the right model available on the checkpoint area on AWS?
Yes, but given the fix, in order to replicate one will have to run with `train_center_crop=False` in order to have the same preprocessing behavior during training as the AWS weights had.
OK so you suspect that this bug was introduced way early right? You don't happen to know more or less which githash you used to train this? I can have a look for you if you give me a rough estimation or band of githashes.
Yes, the bug was introduced and used when performing the training as in 1fddecc
Agreed. I checked all commits prior to f561edf (date before training) and all of them use RandomCrop. I'll remove the flag.
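To make concrete why the choice between a center crop and a random crop matters for reproducing the weights, the sketch below computes the crop origin under each policy. The helper `crop_origin` and its parameters are hypothetical and exist only for this illustration; it does not reproduce the reference training script or the exact rounding used by torchvision's transforms.

```python
import random


def crop_origin(height, width, crop_size, center, rng=None):
    """Return the (top, left) corner of a crop_size x crop_size crop.

    Hypothetical helper: center=True mimics a deterministic center crop,
    center=False mimics a random crop drawn uniformly over valid positions.
    """
    if crop_size > height or crop_size > width:
        raise ValueError("crop_size must fit inside the image")
    if center:
        # Always the same window for a given image size.
        return (height - crop_size) // 2, (width - crop_size) // 2
    rng = rng or random
    # A different window on each call, which acts as data augmentation.
    top = rng.randint(0, height - crop_size)
    left = rng.randint(0, width - crop_size)
    return top, left


center_box = crop_origin(256, 256, 224, center=True)
random_box = crop_origin(256, 256, 224, center=False, rng=random.Random(0))
print(center_box, random_box)
```

A model trained with random crops sees shifted views of each image across epochs, while a center crop always yields the same view, so the two settings are not interchangeable when trying to match a published checkpoint.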