Unused cls_token in PatchEmbeddingBlock #3454

Closed
night-gale opened this issue Dec 8, 2021 · 6 comments · Fixed by #3475

@night-gale

Describe the bug
While training ViT with torch DistributedDataParallel, torch raises an error during the backward pass and reports:

Parameters which did not receive grad for rank 0: vit.patch_embedding.cls_token

which means that the cls_token did not participate in the backward process.

I checked the implementation of ViT and PatchEmbeddingBlock and found the unused cls_token in monai.networks.blocks.patchembedding.py: PatchEmbeddingBlock.

To Reproduce
Steps to reproduce the behavior:

  1. Set the environment variable TORCH_DISTRIBUTED_DEBUG=INFO in the shell
  2. Train ViT with torch DistributedDataParallel
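
The symptom can be checked without a distributed setup. The following is a minimal single-process sketch, not MONAI code: `ToyPatchEmbedding` and `params_without_grad` are hypothetical names made up for illustration. It registers a `cls_token` that `forward` never reads, then lists the trainable parameters whose `.grad` is still `None` after `backward()`, which is roughly the condition the DDP error message reports.

```python
from typing import List

import torch
import torch.nn as nn


class ToyPatchEmbedding(nn.Module):
    """Hypothetical stand-in for illustration; not MONAI's PatchEmbeddingBlock."""

    def __init__(self, in_dim: int = 8, hidden: int = 16) -> None:
        super().__init__()
        self.proj = nn.Linear(in_dim, hidden)
        # Registered as a trainable parameter but never read in forward().
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


def params_without_grad(model: nn.Module) -> List[str]:
    """Names of trainable parameters whose .grad is still None after backward()."""
    return [
        name
        for name, param in model.named_parameters()
        if param.requires_grad and param.grad is None
    ]


model = ToyPatchEmbedding()
model(torch.randn(2, 4, 8)).sum().backward()
unused = params_without_grad(model)
print(unused)
```

Running the same helper on a real model after one backward pass should list any parameter that took no part in the loss, under the assumption that the forward pass exercised the same code path as training.
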
@Nic-Ma
Contributor

Nic-Ma commented Dec 8, 2021

Thanks for raising the issue.
Hi @ahatamiz,

Could you please help confirm the issue?
If we really don't need the cls_token, please remove it.

Thanks in advance.

@Nic-Ma added the enhancement (New feature or request) label on Dec 8, 2021
@ahatamiz
Contributor

ahatamiz commented Dec 8, 2021

Hi @night-gale

Thanks for your comments. If you use ViT for classification only, the classification flag needs to be enabled. Doing so enables the use of cls_token, as shown here.

Originally, ViT was used as the segmentation backbone for UNETR, hence the application needs to be specified.

Lastly, cls_token plays an important role in ViT for classification, as its output representation is used to predict the class. Hence, removing it would go against the original ViT design. I recommend reading the paper here:
https://arxiv.org/pdf/2010.11929.pdf
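
The role of the classification flag described above can be sketched roughly as follows. This is an illustrative toy under stated assumptions, not the MONAI implementation: `ToyClsEmbedding` is a hypothetical module that creates a learnable class token only when `classification=True` and prepends it to the patch tokens, in the spirit of the ViT paper.

```python
import torch
import torch.nn as nn


class ToyClsEmbedding(nn.Module):
    """Hypothetical illustration of a class token; not MONAI code."""

    def __init__(self, hidden: int = 16, classification: bool = True) -> None:
        super().__init__()
        self.classification = classification
        if classification:
            # Create the token only when it will actually be used.
            self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden))

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        if self.classification:
            # Broadcast the single token across the batch and prepend it.
            cls = self.cls_token.expand(tokens.shape[0], -1, -1)
            tokens = torch.cat((cls, tokens), dim=1)
        return tokens


patches = torch.randn(2, 4, 16)  # (batch, num_patches, hidden)
with_cls = ToyClsEmbedding(classification=True)(patches)
without_cls = ToyClsEmbedding(classification=False)(patches)
print(with_cls.shape, without_cls.shape)
```

With the flag off, no token parameter is registered at all, so nothing is left without a gradient in a segmentation-only setup.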

Thanks

@night-gale
Author

Hi! @ahatamiz
Thanks for your reply!

I understand that the cls_token is an essential component of ViT and can be toggled off by passing classification as False.

However, the redundant cls_token I found is in the PatchEmbeddingBlock. It is not referenced in the forward method and cannot be turned off by passing an argument.

For now, I have removed the cls_token in my local copy of MONAI, and everything works fine.
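
As an alternative to editing a local copy of the library, one possible workaround is to freeze the unused parameter before wrapping the model in DistributedDataParallel, since DDP only tracks parameters that require gradients. The sketch below is an assumption-laden illustration rather than a recommended fix: `ToyBlock` and `freeze_by_suffix` are hypothetical names, and it runs in a single process without DDP. PyTorch's DDP constructor also accepts `find_unused_parameters=True`, which tolerates unused parameters at some extra cost per iteration.

```python
from typing import List

import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Hypothetical module with a parameter that forward() never uses."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 16)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 16))  # unused in forward()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


def freeze_by_suffix(model: nn.Module, suffix: str) -> List[str]:
    """Disable gradients for parameters whose name ends with `suffix`."""
    frozen = []
    for name, param in model.named_parameters():
        if name.endswith(suffix):
            param.requires_grad_(False)
            frozen.append(name)
    return frozen


block = ToyBlock()
frozen = freeze_by_suffix(block, "cls_token")  # do this before wrapping in DDP
block(torch.randn(2, 4, 8)).sum().backward()

# No trainable parameter is left without a gradient after backward().
missing = [
    name
    for name, param in block.named_parameters()
    if param.requires_grad and param.grad is None
]
print(frozen, missing)
```

Freezing is only appropriate when the parameter is genuinely unused, as with the token in PatchEmbeddingBlock discussed here; a token that the classification path reads must stay trainable.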

It would be great if you could double check the implementation of PatchEmbeddingBlock.

Thanks!

@ahatamiz
Contributor

ahatamiz commented Dec 8, 2021

Hi @night-gale

Thanks for pointing out the issue. I see that there is an unused cls_token here.
I will address this in a new PR.

Thanks

@ahatamiz
Contributor

Hi @Nic-Ma

Thanks for the efforts. I would appreciate it if this could be addressed in a future PR.

Thanks.

@Nic-Ma
Contributor

Nic-Ma commented Dec 10, 2021

Hi @ahatamiz,

OK, sure, I will fix it in a PR soon.

Thanks.
