simplification of the sincos positional encoding in patchembedding.py (#7605)

Fixes #7564.

### Description

As discussed, this is a small simplification of the creation of the sincos
positional encoding: we no longer need the `torch.no_grad()` context, nor do
we need to copy the tensor with `copy_`, which does not preserve the
`requires_grad` attribute here anyway.
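
For context, here is a small standalone sketch (not part of the diff below)
illustrating that `copy_` transfers values but leaves the destination's
`requires_grad` flag untouched, which is why the old code had to reset the
flag explicitly:

```python
import torch
import torch.nn as nn

# nn.Parameter requires gradients by default.
p = nn.Parameter(torch.zeros(4))
src = torch.ones(4)  # plain tensor with requires_grad=False

# copy_ writes src's values into p's storage; autograd flags are untouched.
with torch.no_grad():
    p.copy_(src)

print(p.requires_grad)   # True: the flag did not follow the copied data
p.requires_grad = False  # hence the explicit reset in the old code
```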

The changes are small and follow the discussion in #7564: the output of
`build_sincos_position_embedding` is already in float32, so the explicit
conversion previously applied does not seem necessary.
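
For illustration, here is a minimal, self-contained 2D sketch of a sincos
builder (hypothetical helper name and signature; the actual
`build_sincos_position_embedding` in MONAI also handles 3D). The key point is
that such a builder can return a frozen float32 `nn.Parameter`, so the caller
can assign it directly without `torch.no_grad()`, `copy_`, or a dtype
conversion:

```python
import torch
import torch.nn as nn


def sincos_pos_embedding_2d(grid_size, embed_dim, temperature=10000.0):
    """Hypothetical 2D sincos builder returning a frozen float32 parameter."""
    h, w = grid_size
    grid_h, grid_w = torch.meshgrid(
        torch.arange(h, dtype=torch.float32),
        torch.arange(w, dtype=torch.float32),
        indexing="ij",
    )
    pos_dim = embed_dim // 4  # sin and cos halves for each of the two axes
    omega = 1.0 / temperature ** (torch.arange(pos_dim, dtype=torch.float32) / pos_dim)
    out_h = grid_h.flatten()[:, None] * omega[None, :]
    out_w = grid_w.flatten()[:, None] * omega[None, :]
    pos_emb = torch.cat(
        [torch.sin(out_h), torch.cos(out_h), torch.sin(out_w), torch.cos(out_w)], dim=1
    )[None]  # shape (1, h * w, embed_dim), already float32
    # A frozen Parameter can be assigned directly to self.position_embeddings.
    return nn.Parameter(pos_emb, requires_grad=False)


pe = sincos_pos_embedding_2d((4, 4), 64)
print(pe.shape, pe.dtype, pe.requires_grad)  # (1, 16, 64) torch.float32 False
```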

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

Signed-off-by: Lucas Robinet <robinet.lucas@iuct-oncopole.fr>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Lucas-rbnt and KumoLiu authored Apr 4, 2024
1 parent 763347d commit 195d7dd
Showing 1 changed file with 1 addition and 4 deletions.
monai/networks/blocks/patchembedding.py
```diff
@@ -120,10 +120,7 @@ def __init__(
             for in_size, pa_size in zip(img_size, patch_size):
                 grid_size.append(in_size // pa_size)
 
-            with torch.no_grad():
-                pos_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims)
-                self.position_embeddings.data.copy_(pos_embeddings.float())
-                self.position_embeddings.requires_grad = False
+            self.position_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims)
         else:
             raise ValueError(f"pos_embed_type {self.pos_embed_type} not supported.")
```

