
Access self-attention matrix of vision transformer #6032

Closed
sophmrtn opened this issue Feb 21, 2023 · 5 comments · Fixed by #6271 or #6308

Comments


sophmrtn commented Feb 21, 2023

Is your feature request related to a problem? Please describe.
I am training a vision transformer using monai and would like to carry out some interpretability analysis. However, the current model code does not save the self-attention matrix during training, and it is not straightforward to pass it from the self-attention block to the model output.

def forward(self, x):
    output = self.input_rearrange(self.qkv(x))
    q, k, v = output[0], output[1], output[2]
    # scaled dot-product attention weights (softmax over the key dimension)
    att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
    att_mat = self.drop_weights(att_mat)
    # att_mat is used only here and then discarded, so it never reaches the caller
    x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
    x = self.out_rearrange(x)
    x = self.out_proj(x)
    x = self.drop_output(x)
    return x

Describe the solution you'd like
An option to output the att_mat from the self-attention block in the model forward pass (before it is multiplied with the values), or to access it after training as a class attribute.
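For illustration, a minimal sketch of the requested behaviour, written as a toy single-head block rather than MONAI's actual SABlock; the `save_attn` flag and `att_mat` attribute are placeholder names, not an existing API:

```python
import torch
import torch.nn as nn


class AttentionWithOptionalSave(nn.Module):
    """Toy single-head self-attention that can optionally keep its attention matrix."""

    def __init__(self, dim: int, save_attn: bool = False):
        super().__init__()
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.out_proj = nn.Linear(dim, dim)
        self.scale = dim**-0.5
        self.save_attn = save_attn
        self.att_mat = torch.Tensor()  # populated only when save_attn is True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        att_mat = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
        if self.save_attn:
            # detach so the stored copy does not keep the autograd graph alive
            self.att_mat = att_mat.detach()
        return self.out_proj(att_mat @ v)
```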

@a-parida12
Contributor

@wyli I would like to help with this issue.

I was thinking the easiest way to achieve this without changing the API would be to store att_mat (before dropout) as a class attribute, so it could be accessed via SABlock().att_mat. We could also add a parameter such as store_attn: bool to __init__ to avoid the memory overhead when the feature is not required.

Let me know your thoughts.
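As a usage example with the toy block sketched above (again, the flag and attribute names are illustrative, not the final API):

```python
import torch

block = AttentionWithOptionalSave(dim=64, save_attn=True)
tokens = torch.randn(2, 16, 64)  # (batch, sequence length, embedding dim)

_ = block(tokens)
print(block.att_mat.shape)  # torch.Size([2, 16, 16]): one weight per token pair
```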

@wyli
Contributor

wyli commented Apr 2, 2023

sounds good, please make sure it's backward compatible, e.g. previously saved checkpoints can still be loaded by default.

@wyli wyli closed this as completed in #6271 Apr 3, 2023
wyli added a commit that referenced this issue Apr 3, 2023
Fixes #6032.

### Description

A few sentences describing the changes proposed in this pull request.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [x] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Ben Murray <ben.murray@gmail.com>
Signed-off-by: a-parida12 <abhijeet.parida@tum.de>
Signed-off-by: YanxuanLiu <yanxuanl@nvidia.com>
Signed-off-by: monai-bot <monai.miccai2019@gmail.com>
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
Co-authored-by: Ben Murray <ben.murray@gmail.com>
Co-authored-by: YanxuanLiu <104543031+YanxuanLiu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Wenqi Li <831580+wyli@users.noreply.github.com>
Co-authored-by: monai-bot <monai.miccai2019@gmail.com>
Co-authored-by: Wenqi Li <wenqil@nvidia.com>
@wyli wyli reopened this Apr 4, 2023
@a-parida12
Contributor

@wyli so #6271 is merged? Should I open a new pull request with a potential solution?

wyli pushed a commit that referenced this issue Apr 11, 2023
Fixes #6032.

### Description

Specified the data type of the `att_matrix` attribute to be compliant with
the `torch.jit` scripting requirements for conditionals.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: a-parida12 <abhijeet.parida@tum.de>
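For context, the conditional storage pattern that TorchScript accepts looks roughly like this (a sketch, not the exact MONAI code): initialising the attribute as an empty tensor rather than None gives it a fixed type, so the `if` branch in forward can be scripted.

```python
import torch
import torch.nn as nn


class SaveAttnExample(nn.Module):
    """Illustrative only: a jit-friendly pattern for optionally storing a tensor."""

    def __init__(self, save_attn: bool = False):
        super().__init__()
        self.save_attn = save_attn
        # an empty tensor (not None) keeps the attribute type fixed for TorchScript
        self.att_mat = torch.Tensor()

    def forward(self, att_mat: torch.Tensor) -> torch.Tensor:
        if self.save_attn:
            self.att_mat = att_mat.detach()
        return att_mat


scripted = torch.jit.script(SaveAttnExample(save_attn=True))  # scripts without a type error
```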
@a-parida12
Contributor

@wyli I just realized that the ViT backbone is used by UNETR and ViTAutoEnc; ideally they should also expose the option to access the att_mat, otherwise it will always be left at its default value with no way for a user of UNETR or ViTAutoEnc to change it. What do you think?
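As a possible interim workaround (a sketch under the assumption that the SABlock change from #6271 exposes a save_attn flag and an att_mat attribute; the module path below follows MONAI's UNETR/ViT layout and may differ between versions), a user could enable saving on an inner block directly:

```python
import torch
from monai.networks.nets import UNETR

model = UNETR(in_channels=1, out_channels=2, img_size=(96, 96, 96))

# reach into the ViT backbone and turn on attention saving for one block
attn_block = model.vit.blocks[0].attn  # assumed path: UNETR -> ViT -> TransformerBlock -> SABlock
attn_block.save_attn = True

_ = model(torch.randn(1, 1, 96, 96, 96))
print(attn_block.att_mat.shape)  # attention weights of the first transformer block
```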

@wyli
Contributor

wyli commented Apr 18, 2023

Sure, please help create another feature request to follow up.
