
MaxVit model #6342

Merged: 33 commits, Sep 23, 2022
Conversation


@TeodorPoncu TeodorPoncu commented Aug 1, 2022

This PR adds the MaxVit architecture as per the Batteries Phase 3 proposal. It is still a work in progress, as the model has yet to be trained.

One caveat regarding the way we would expose this model API to users: the architecture is bound to the specific input size it was trained on, due to the usage of relative positional encodings.
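A back-of-envelope sketch of that caveat: the learned relative-bias table has a size fixed at construction time, so a checkpoint trained for one window size cannot be loaded for another without interpolating the table. (The window size `w` below is illustrative, not a parameter name from the PR.)

```python
# For a w x w attention window, relative offsets along each axis range over
# 2*w - 1 distinct values, so the learned bias table has (2*w - 1)**2 entries.
# This count is baked into the parameter shapes at construction time.
def bias_table_entries(w: int) -> int:
    return (2 * w - 1) ** 2

print(bias_table_entries(7))  # 169 entries for a 7x7 window
```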

Running the command: torchrun --nproc_per_node=1 train.py --test-only --prototype --weights MaxVit_T_Weights.IMAGENET1K_V1 --model maxvit_t -b 1 yields the following results:

Test: Acc@1 83.700 Acc@5 96.722


@datumbox datumbox left a comment


@TeodorPoncu That was an unexpected surprise, thanks a lot for contributing this architecture.

I know that your PR is a draft, but I've added a few remarks, mostly related to our coding styles and practices. I haven't verified the ML side of things. Feel free to ignore this if you think my input is premature.

(8 resolved review threads on test/test_models.py and torchvision/models/maxvit.py)

datumbox commented Aug 5, 2022

@TeodorPoncu It seems that in a recent commit, you accidentally updated all the expected files for all models. Could you please revert that?

@TeodorPoncu

@datumbox Sorry about that, everything should be fine now.



vadimkantorov commented Aug 19, 2022

Related discussion and pointers on generalizing fixed resolution for Swin: #6227

Also, I wonder if more relative-attention related modules can be reused from Swin

self.register_buffer("relative_position_index", get_relative_position_index(self.size, self.size))

# initialize with truncated normal the bias
self.positional_bias.data.normal_(mean=0, std=0.02)
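For context, here is a pure-Python sketch of what a `get_relative_position_index`-style helper computes, as I understand the intent from Swin-style relative attention (an assumption, not the PR's exact code): each pair of tokens in a w x w window maps to an index into the bias table.

```python
def relative_position_index(w: int):
    # Enumerate the w*w token coordinates of the window.
    coords = [(y, x) for y in range(w) for x in range(w)]
    n = len(coords)
    index = [[0] * n for _ in range(n)]
    for i, (yi, xi) in enumerate(coords):
        for j, (yj, xj) in enumerate(coords):
            # Shift relative offsets into [0, 2*w - 2] and flatten to a
            # single index into the (2*w - 1)**2-entry bias table.
            dy = yi - yj + w - 1
            dx = xi - xj + w - 1
            index[i][j] = dy * (2 * w - 1) + dx
    return index

idx = relative_position_index(2)  # 4 tokens, indices into a 9-entry table
```

Registering this as a buffer (as above) makes sense because it is a fixed lookup, not a learnable parameter.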


self.scale_factor = feat_dim**-0.5

self.merge = nn.Linear(self.head_dim * self.n_heads, feat_dim)
self.positional_bias = nn.parameter.Parameter(


If it's initialized with a normal distribution just below, shouldn't torch.empty be used here?


def get_relative_positional_bias(self) -> torch.Tensor:
bias_index = self.relative_position_index.view(-1) # type: ignore
relative_bias = self.positional_bias[bias_index].view(self.max_seq_len, self.max_seq_len, -1) # type: ignore


If it's just for flattening some trailing dimensions, .flatten(start_dim=2) can be used here.
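The reviewer's point, stated in shape terms (plain Python rather than torch): flattening from start_dim onward collapses all trailing dimensions into one, which is what the -1 in the view accomplishes.

```python
from math import prod

def flatten_shape(shape, start_dim):
    # Collapse all dimensions from start_dim onward into a single dimension,
    # mirroring what Tensor.flatten(start_dim=...) does to a tensor's shape.
    return shape[:start_dim] + (prod(shape[start_dim:]),)

# e.g. a (seq, seq, heads, 1) bias flattened from dim 2 onward:
print(flatten_shape((49, 49, 8, 1), 2))  # (49, 49, 8)
```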

self.b = b

def forward(self, x: torch.Tensor) -> torch.Tensor:
res = torch.swapaxes(x, self.a, self.b)

@vadimkantorov vadimkantorov Aug 19, 2022


swapaxes is a NumPy-compatibility alias for torch.transpose (https://pytorch.org/docs/stable/generated/torch.swapaxes.html#torch.swapaxes). If the rest is in Torch lingo, shouldn't it be the following, with arg names matching https://pytorch.org/docs/stable/generated/torch.transpose.html#torch.transpose:

class SwapAxes(nn.Module):
    def __init__(self, dim0: int, dim1: int) -> None:
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.transpose(self.dim0, self.dim1)


@datumbox datumbox left a comment


Thanks for the PR @TeodorPoncu. I've added some initial comments on the reference scripts updates. Happy to chat more.

(5 resolved review threads on references/classification/run_with_submitit.py, train.py, and utils.py)
@TeodorPoncu TeodorPoncu changed the title [WIP] MaxVit model MaxVit model Sep 21, 2022
@TeodorPoncu

Running the deployed weights with the following command:
torchrun --nproc_per_node=1 train.py --model maxvit_t --interpolation bicubic --batch-size 1 --test-only --weights MaxVit_T_Weights.IMAGENET1K_V1

Yields the following results:
Test: Acc@1 83.700 Acc@5 96.722


@datumbox datumbox left a comment


@TeodorPoncu pretty awesome work and top quality code.

I've left a few nits here and there, just to make sure that the implementation aligns with the idioms of the rest of TorchVision. The only comment worth considering a bit more is the one concerning the number of parameters you pass to the constructor. This is mostly to align MaxViT with all other models and make it easier to change all of them in the future. In some cases you expose parameters that are indeed quite useful (for example, the number of input channels), but such API changes would be best introduced across all models, not just MaxViT.

Other than the above, I didn't validate the architecture much but focused mainly on idioms and code style. I know that you've already reproduced the accuracy of the paper, which is great, but it's worth doing one scan with @jdsgomes prior to merging to confirm that there are no deviations from the original implementation (for example on padding of inputs whose spatial dimensions are not divisible by p).

All CI failures seem unrelated. Other than these final validations and nits, I think the implementation is in awesome shape and we should be able to merge it soon.
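On the padding point raised above, the check amounts to rounding each spatial dimension up to the next multiple of the partition size p. A minimal sketch with a hypothetical helper (not code from the PR):

```python
def pad_to_multiple(h: int, w: int, p: int):
    # Amount of bottom/right padding needed so both spatial dims divide
    # evenly by the partition size p; no-op when already divisible.
    pad_h = (p - h % p) % p
    pad_w = (p - w % p) % p
    return h + pad_h, w + pad_w

print(pad_to_multiple(7, 9, 4))  # (8, 12)
```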

(resolved review threads on docs/source/models.rst, references/classification/presets.py, and references/classification/train.py)
from torchvision.models.maxvit import SwapAxes, WindowDepartition, WindowPartition


class MaxvitTester(unittest.TestCase):

I understand that here you are testing specific layers from MaxViT. This is not something we did previously, so perhaps it does need to be on a separate file.

@YosuaMichael any thoughts here?


Sorry for the pretty late response! Currently I don't have an opinion on how we should test specific layers of the model, and I think this is okay. (I need more time to think and discuss whether we should do more of this kind of test.)

(6 review threads on torchvision/models/maxvit.py, mostly resolved)
TeodorPoncu and others added 2 commits September 21, 2022 17:26
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
@datumbox

Two more requests:

  • Could you please upload the weights on manifold (see internal guide)
  • Could you update the PR description to show-case the output accuracy of the following command?
torchrun --nproc_per_node=1 train.py --test-only --prototype --weights MaxVit_T_Weights.IMAGENET1K_V1 --model maxvit_t -b 1


@datumbox datumbox left a comment


Thanks for all the changes, a few follow-ups:

(resolved review threads on torchvision/models/maxvit.py and references/classification/train.py)

@jdsgomes jdsgomes left a comment


Did a final pass to review the ML side and it looks good. Happy to approve once the pending requested changes are done. Really nice work.

@TeodorPoncu TeodorPoncu merged commit 6b1646c into main Sep 23, 2022
@github-actions

Hey @TeodorPoncu!

You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

@datumbox datumbox deleted the BATERIES]-add-max-vit branch September 23, 2022 14:56
Comment on lines +22 to +24
[transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
if center_crop
else [transforms.CenterCrop(crop_size)]

@datumbox datumbox Sep 25, 2022


@TeodorPoncu I think this is a bug. I believe you meant to write:

            [transforms.CenterCrop(crop_size)]
            if center_crop
            else [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]

Could you please confirm?

Edit: I issued a fix at #6642
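In miniature, the bug discussed above is an inverted conditional: the flag selected the opposite branch of the ternary. A toy sketch with strings standing in for the torchvision transforms:

```python
def pick_crop(center_crop: bool, fixed: bool = True) -> str:
    if fixed:
        # Fixed version: the flag selects the transform it names.
        return "CenterCrop" if center_crop else "RandomResizedCrop"
    # Buggy original: the two branches were swapped.
    return "RandomResizedCrop" if center_crop else "CenterCrop"

print(pick_crop(True))  # CenterCrop
```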


@TeodorPoncu TeodorPoncu Sep 26, 2022


Yes, sorry about that. You've correctly guessed what I meant to write. Thanks for catching it. I think the --train-center-crop flag should also be removed from the training command docs, to reflect the way the weights were trained.


@TeodorPoncu thanks for coming back to me. Does this mean that you didn't actually use the flag during training? Can we remove it?


Yes, the flag can be removed.


@datumbox datumbox Sep 26, 2022


That's not what I see on the training log of the trained model. I see that train_center_crop=True. Do we have the right model available on the checkpoint area on AWS?


@TeodorPoncu TeodorPoncu Sep 26, 2022


Yes, but given the fix, anyone replicating the run will have to use train_center_crop=False in order to get the same preprocessing behavior during training as the AWS weights had.


OK, so you suspect that this bug was introduced early on, right? You don't happen to know roughly which git hash you used for training? I can have a look for you if you give me a rough range of git hashes.


Yes, the bug was introduced and used when performing the training, as in 1fddecc.


Agreed. I checked all commits prior to f561edf (dated before the training) and all of them use RandomCrop. I'll remove the flag.

facebook-github-bot pushed a commit that referenced this pull request Sep 29, 2022
Summary:
* Added maxvit architecture and tests

* rebased + addressed comments

* Revert "rebased + addressed comments"

This reverts commit c5b2839.

* Re-added model changes after revert

* aligned with partial original implementation

* removed submitit script fixed lint

* mypy fix for too many arguments

* updated old tests

* removed per batch lr scheduler and seed setting

* removed ontap

* added docs, validated weights

* fixed test expect, moved shape assertions to the beginning for torch.fx compatibility

* mypy fix

* lint fix

* added legacy interface

* added weight link

* updated docs

* Update references/classification/train.py

* Update torchvision/models/maxvit.py

* addressed comments

* updated ra_magnitude and augmix_severity default values

* addressed some comments

* remove input_channels parameter

Reviewed By: YosuaMichael

Differential Revision: D39885422

fbshipit-source-id: c51942974bf17f6473c3b3b08a4d16aad5812dc3

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>