-
Notifications
You must be signed in to change notification settings - Fork 7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Added maxvit architecture and tests * rebased + addresed comments * Revert "rebased + addresed comments" This reverts commit c5b2839. * Re-added model changes after revert * aligned with partial original implementation * removed submitit script fixed lint * mypy fix for too many arguments * updated old tests * removed per batch lr scheduler and seed setting * removed ontap * added docs, validated weights * fixed test expect, moved shape assertions in the begging for torch.fx compatibility * mypy fix * lint fix * added legacy interface * added weight link * updated docs * Update references/classification/train.py Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com> * Update torchvision/models/maxvit.py Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com> * adressed comments * update ra_maginuted and augmix_severity default values * adressed some comments * remove input_channels parameter Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
- Loading branch information
1 parent
d65e286
commit 6b1646c
Showing
9 changed files
with
940 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
MaxVit | ||
=============== | ||
|
||
.. currentmodule:: torchvision.models | ||
|
||
The MaxVit transformer models are based on the `MaxViT: Multi-Axis Vision Transformer <https://arxiv.org/abs/2204.01697>`__ | ||
paper. | ||
|
||
|
||
Model builders | ||
-------------- | ||
|
||
The following model builders can be used to instantiate an MaxVit model with and without pre-trained weights. | ||
All the model builders internally rely on the ``torchvision.models.maxvit.MaxVit`` | ||
base class. Please refer to the `source code | ||
<https://github.com/pytorch/vision/blob/main/torchvision/models/maxvit.py>`_ for | ||
more details about this class. | ||
|
||
.. autosummary:: | ||
:toctree: generated/ | ||
:template: function.rst | ||
|
||
maxvit_t |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import unittest | ||
|
||
import pytest | ||
import torch | ||
|
||
from torchvision.models.maxvit import SwapAxes, WindowDepartition, WindowPartition | ||
|
||
|
||
class MaxvitTester(unittest.TestCase): | ||
def test_maxvit_window_partition(self): | ||
input_shape = (1, 3, 224, 224) | ||
partition_size = 7 | ||
n_partitions = input_shape[3] // partition_size | ||
|
||
x = torch.randn(input_shape) | ||
|
||
partition = WindowPartition() | ||
departition = WindowDepartition() | ||
|
||
x_hat = partition(x, partition_size) | ||
x_hat = departition(x_hat, partition_size, n_partitions, n_partitions) | ||
|
||
assert torch.allclose(x, x_hat) | ||
|
||
def test_maxvit_grid_partition(self): | ||
input_shape = (1, 3, 224, 224) | ||
partition_size = 7 | ||
n_partitions = input_shape[3] // partition_size | ||
|
||
x = torch.randn(input_shape) | ||
pre_swap = SwapAxes(-2, -3) | ||
post_swap = SwapAxes(-2, -3) | ||
|
||
partition = WindowPartition() | ||
departition = WindowDepartition() | ||
|
||
x_hat = partition(x, n_partitions) | ||
x_hat = pre_swap(x_hat) | ||
x_hat = post_swap(x_hat) | ||
x_hat = departition(x_hat, n_partitions, partition_size, partition_size) | ||
|
||
assert torch.allclose(x, x_hat) | ||
|
||
|
||
if __name__ == "__main__": | ||
pytest.main([__file__]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.