MaxVit model #6342
Conversation
@TeodorPoncu That was an unexpected surprise, thanks a lot for contributing this architecture.
I know that your PR is a draft, but I've added a few remarks, mostly related to our coding styles and practices. I haven't verified the ML side of things. Feel free to ignore this if you think my input is premature.
@TeodorPoncu It seems that in a recent commit, you accidentally updated all the expected files for all models. Could you please revert that?
This reverts commit c5b2839.
@datumbox Sorry about that, everything should be fine now.
Related discussion and pointers on generalizing fixed resolution for Swin: #6227. Also, I wonder whether more relative-attention-related modules can be reused from Swin.
torchvision/models/maxvit.py (outdated)

self.register_buffer("relative_position_index", get_relative_position_index(self.size, self.size))

# initialize with truncated normal the bias
self.positional_bias.data.normal_(mean=0, std=0.02)
shouldn't https://pytorch.org/docs/stable/nn.init.html?highlight=nn%20init#torch.nn.init.normal_ be used here?
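For reference, a minimal sketch of what that would look like (the shape below is an arbitrary placeholder, not the one used in the PR):

```python
import torch
import torch.nn as nn

bias = nn.Parameter(torch.zeros(169, 4))  # placeholder shape, for illustration only

# Same effect as bias.data.normal_(mean=0, std=0.02), but via the public init API:
nn.init.normal_(bias, mean=0.0, std=0.02)

# The inline comment above says "truncated normal"; if that is the intent,
# this is the closer match:
nn.init.trunc_normal_(bias, mean=0.0, std=0.02)
```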
torchvision/models/maxvit.py (outdated)

self.scale_factor = feat_dim**-0.5

self.merge = nn.Linear(self.head_dim * self.n_heads, feat_dim)
self.positional_bias = nn.parameter.Parameter(
If it's initialized with a normal distribution just below, shouldn't torch.empty be used here?
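Something along these lines, I think (names and shape are placeholders, not the PR's actual values):

```python
import torch
import torch.nn as nn

# torch.empty only allocates memory; since the values are overwritten by the
# normal initialization immediately afterwards, zero-filling first is wasted work.
positional_bias = nn.parameter.Parameter(torch.empty(169, 4))  # placeholder shape
nn.init.normal_(positional_bias, mean=0.0, std=0.02)
```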
torchvision/models/maxvit.py (outdated)

def get_relative_positional_bias(self) -> torch.Tensor:
    bias_index = self.relative_position_index.view(-1)  # type: ignore
    relative_bias = self.positional_bias[bias_index].view(self.max_seq_len, self.max_seq_len, -1)  # type: ignore
If it's just for flattening some end dimensions, .flatten(start_dim=2) can be used here.
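A quick standalone check of the equivalence (arbitrary example shape):

```python
import torch

x = torch.randn(49, 49, 2, 2)  # arbitrary example shape

# Both collapse every dimension from index 2 onwards into one;
# flatten(start_dim=2) just avoids spelling out the target shape by hand.
a = x.reshape(49, 49, -1)
b = x.flatten(start_dim=2)
assert torch.equal(a, b)
```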
self.b = b

def forward(self, x: torch.Tensor) -> torch.Tensor:
    res = torch.swapaxes(x, self.a, self.b)
swapaxes is a NumPy-compatibility alias for torch.transpose (https://pytorch.org/docs/stable/generated/torch.swapaxes.html#torch.swapaxes). If the rest is in Torch lingo, shouldn't it be the following (arg names chosen to match https://pytorch.org/docs/stable/generated/torch.transpose.html#torch.transpose):
import torch
from torch import nn

class SwapAxes(nn.Module):
    def __init__(self, dim0: int, dim1: int) -> None:
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.transpose(self.dim0, self.dim1)
Thanks for the PR @TeodorPoncu. I've added some initial comments on the reference script updates. Happy to chat more.
Running the deployed weights with the following command:
Yields the following results:
@TeodorPoncu pretty awesome work and top-quality code.
I've left a few nits here and there, just to make sure that the implementation code aligns with the idioms of the rest of TorchVision. The only comment worth considering a bit more is the one concerning the number of parameters you pass to the constructor. This comment is mostly about aligning MaxViT with all other models and making it easier in the future to make changes to all of them. In some cases you expose parameters that are indeed quite useful (number of input channels), but such API changes would be best introduced in all models, not just MaxViT.
Other than the above, I didn't validate the architecture part much but focused mainly on idioms and code style. I know that you've already reproduced the accuracy of the paper, which is great, but it's worth doing one scan with @jdsgomes prior to merging to confirm that there are no deviations from the original implementation (for example, on padding of inputs whose spatial dimensions are not divisible by p).
All CI failures seem unrelated. Other than these final validations and nits, I think the implementation is in awesome shape and we should be able to merge it soon.
from torchvision.models.maxvit import SwapAxes, WindowDepartition, WindowPartition


class MaxvitTester(unittest.TestCase):
I understand that here you are testing specific layers from MaxViT. This is not something we did previously, so perhaps it does need to be in a separate file.
@YosuaMichael any thoughts here?
Sorry for the pretty late response!
Currently I don't have any opinion on how we should test specific layers of the model, and I think this is okay. (I need more time to think and discuss whether we should do more of this kind of test or not.)
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Two more requests:
Thanks for all the changes, a few follow-ups:
Did a final pass to review the ML side and it looks good. Happy to approve after the pending requested changes are done. Really nice work!
…sion into BATERIES]-add-max-vit
Hey @TeodorPoncu! You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py
[transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
if center_crop
else [transforms.CenterCrop(crop_size)]
@TeodorPoncu I think this is a bug. I believe you meant to write:
[transforms.CenterCrop(crop_size)]
if center_crop
else [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
Could you please confirm?
Edit: I issued a fix at #6642
Yes, sorry for that. You've correctly guessed what I wanted to write. Thanks for catching it. I think the --train-center-crop flag should be removed from the training command docs as well, to reflect the way the weights were trained.
@TeodorPoncu thanks for coming back to me. Does this mean that you didn't actually use the flag during training? Can we remove it?
Yes, the flag can be removed.
That's not what I see in the training log of the trained model. I see that train_center_crop=True. Do we have the right model available in the checkpoint area on AWS?
Yes, but given the fix, in order to replicate it one will have to run with train_center_crop=False, so as to have the same preprocessing behavior during training as the AWS weights had.
OK, so you suspect that this bug was introduced quite early, right? Do you happen to know roughly which git hash you used to train this? I can have a look for you if you give me a rough estimate or range of git hashes.
Yes, the bug was introduced and used when performing the training as in 1fddecc
Agreed. I checked all commits prior to f561edf (dated before the training) and all of them use RandomCrop. I'll remove the flag.
Summary:
* Added maxvit architecture and tests
* rebased + addresed comments
* Revert "rebased + addresed comments". This reverts commit c5b2839.
* Re-added model changes after revert
* aligned with partial original implementation
* removed submitit script fixed lint
* mypy fix for too many arguments
* updated old tests
* removed per batch lr scheduler and seed setting
* removed ontap
* added docs, validated weights
* fixed test expect, moved shape assertions in the begging for torch.fx compatibility
* mypy fix
* lint fix
* added legacy interface
* added weight link
* updated docs
* Update references/classification/train.py
* Update torchvision/models/maxvit.py
* adressed comments
* update ra_maginuted and augmix_severity default values
* adressed some comments
* remove input_channels parameter

Reviewed By: YosuaMichael
Differential Revision: D39885422
fbshipit-source-id: c51942974bf17f6473c3b3b08a4d16aad5812dc3
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
This PR is w.r.t. the Batteries Phase 3 proposal to add the MaxVit architecture. It is still a work in progress, as it has yet to be trained.
One caveat w.r.t. the way we would be exposing this model API to users is that the architecture is bound to the specific input size it was trained on (due to the usage of relative positional encodings).
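To illustrate the caveat, here is a rough standalone sketch (not the PR's actual code; names, shapes, and the single attention head are placeholders) of how a relative position bias table is sized from the window size used at training time, which is what ties the weights to one resolution:

```python
import torch

window_size = 7  # example value; everything below is sized from it

# One bias entry per possible relative offset (and per attention head; 1 head here).
num_rel_positions = (2 * window_size - 1) ** 2            # 13 * 13 = 169
bias_table = torch.zeros(num_rel_positions, 1)

# Pairwise relative coordinates between all positions inside one window.
coords = torch.stack(
    torch.meshgrid(torch.arange(window_size), torch.arange(window_size), indexing="ij")
).flatten(1)                                              # (2, 49)
relative = coords[:, :, None] - coords[:, None, :]        # (2, 49, 49)
relative = relative + (window_size - 1)                   # shift offsets to be >= 0
index = relative[0] * (2 * window_size - 1) + relative[1] # (49, 49) lookup indices

bias = bias_table[index.flatten()].view(window_size**2, window_size**2, -1)
print(bias.shape)  # torch.Size([49, 49, 1]) -- fixed by the training-time window size
```

A different input resolution changes this partition geometry, so the stored bias tables and index buffers no longer line up with the new sequence lengths.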
Running the command:
torchrun --nproc_per_node=1 train.py --test-only --prototype --weights MaxVit_T_Weights.IMAGENET1K_V1 --model maxvit_t -b 1
yields the following results:
Test: Acc@1 83.700 Acc@5 96.722