Commit 6fb51f5

doc
Signed-off-by: vgrau98 <victor.grau93@gmail.com>
1 parent 6d61b22 commit 6fb51f5

File tree: 2 files changed, +38 −5 lines

docs/source/networks.rst

Lines changed: 6 additions & 0 deletions

@@ -248,6 +248,12 @@ Blocks
 .. autoclass:: monai.apps.reconstruction.networks.blocks.varnetblock.VarNetBlock
     :members:
 
+`Attention utilities`
+~~~~~~~~~~~~~~~~~~~~~
+.. automodule:: monai.networks.blocks.attention_utils
+.. autofunction:: monai.networks.blocks.attention_utils.get_rel_pos
+.. autofunction:: monai.networks.blocks.attention_utils.add_decomposed_rel_pos
+
 N-Dim Fourier Transform
 ~~~~~~~~~~~~~~~~~~~~~~~~
 .. automodule:: monai.networks.blocks.fft_utils_t
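
Below, a minimal sketch of the newly documented get_rel_pos (the sizes d=4 and channels=8 are arbitrary; the (q_size, k_size, C) return shape is inferred from how add_decomposed_rel_pos consumes the table in the next file):

import torch

from monai.networks.blocks.attention_utils import get_rel_pos

d, channels = 4, 8  # arbitrary illustration sizes
# One embedding row per relative offset in [-(d-1), d-1], hence 2*d - 1 rows.
rel_pos = torch.randn(2 * d - 1, channels)

table = get_rel_pos(d, d, rel_pos)
print(table.shape)  # expected: torch.Size([4, 4, 8]), i.e. (q_size, k_size, channels)

# Token pairs at the same relative distance share an embedding row (unlike
# absolute positional embeddings): pairs (2, 0) and (3, 1) both sit at offset +2.
assert torch.equal(table[2, 0], table[3, 1])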

monai/networks/blocks/attention_utils.py

Lines changed: 32 additions & 5 deletions

@@ -19,7 +19,8 @@
 def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
     """
     Get relative positional embeddings according to the relative positions of
-    query and key sizes.
+    query and key sizes.
+
     Args:
         q_size (int): size of query q.
         k_size (int): size of key k.
@@ -51,10 +52,36 @@ def add_decomposed_rel_pos(
 def add_decomposed_rel_pos(
     attn: torch.Tensor, q: torch.Tensor, rel_pos_lst: nn.ParameterList, q_size: Tuple, k_size: Tuple
 ) -> torch.Tensor:
-    """
-    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
-    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
+    r"""
+    Calculate decomposed Relative Positional Embeddings from the mvitv2 implementation:
+    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
+
     Only 2D and 3D are supported.
+
+    Encodes the relative position of tokens in the attention matrix: tokens spaced a distance
+    `d` apart have the same embedding value (unlike absolute positional embedding).
+
+    .. math::
+        Attn_{logits}(Q, K) = (QK^{T} + E_{rel}) * scale
+
+    where
+
+    .. math::
+        E_{ij}^{(rel)} = Q_{i} \cdot R_{p(i), p(j)}
+
+    with :math:`R_{p(i), p(j)} \in \mathbb{R}^{dim}` and :math:`p(i), p(j)` the spatial
+    positions of elements :math:`i` and :math:`j`, respectively.
+
+    With "decomposed" relative positional embedding, the positional embedding is defined ("decomposed") as follows:
+
+    .. math::
+        R_{p(i), p(j)} = R^{d1}_{d1(i), d1(j)} + ... + R^{dn}_{dn(i), dn(j)}
+
+    with :math:`n = 1...dim`.
+
+    Decomposed relative positional embedding reduces the complexity from :math:`\mathcal{O}(d1 * ... * dn)` to
+    :math:`\mathcal{O}(d1 + ... + dn)` compared with classical relative positional embedding.
+
     Args:
         attn (Tensor): attention map.
         q (Tensor): query q in the attention layer with shape (B, s_dim_1 * ... * s_dim_n, C).
@@ -63,7 +90,7 @@ def add_decomposed_rel_pos(
         k_size (Tuple): spatial sequence size of key k with (k_dim_1, ..., k_dim_n).
 
     Returns:
-        attn (Tensor): attention map with added relative positional embeddings.
+        attn (Tensor): attention logits with added relative positional embeddings.
     """
     rh = get_rel_pos(q_size[0], k_size[0], rel_pos_lst[0])
     rw = get_rel_pos(q_size[1], k_size[1], rel_pos_lst[1])
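
A minimal 2D usage sketch of add_decomposed_rel_pos as documented above (the batch size, 8x8 grid, and head_dim=32 are arbitrary; each per-axis table is assumed to hold 2*d - 1 rows of width C, matching the mvitv2 scheme linked in the docstring):

import torch
import torch.nn as nn

from monai.networks.blocks.attention_utils import add_decomposed_rel_pos

batch, q_h, q_w, head_dim = 1, 8, 8, 32  # arbitrary illustration sizes
k_h, k_w = q_h, q_w  # self-attention: query and key grids match

q = torch.randn(batch, q_h * q_w, head_dim)  # (B, s_dim_1 * s_dim_2, C)
attn = q @ q.transpose(-2, -1)  # raw attention logits, shape (B, q_h*q_w, k_h*k_w)

# One learnable table per spatial axis, one row per relative offset.
rel_pos_lst = nn.ParameterList(
    [nn.Parameter(torch.zeros(2 * q_h - 1, head_dim)), nn.Parameter(torch.zeros(2 * q_w - 1, head_dim))]
)

attn = add_decomposed_rel_pos(attn, q, rel_pos_lst, (q_h, q_w), (k_h, k_w))
print(attn.shape)  # torch.Size([1, 64, 64])

To make the docstring's complexity claim concrete: for a 16x16x16 3D grid, the decomposed form needs three tables of 2*16 - 1 = 31 rows each (93 rows total), whereas a full relative table would need one row per 31^3 = 29,791 joint offsets.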
