 def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
     """
     Get relative positional embeddings according to the relative positions of
-    query and key sizes.
+    query and key sizes.
+
     Args:
         q_size (int): size of query q.
         k_size (int): size of key k.
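For context on what `get_rel_pos` produces, a minimal, self-contained sketch of how such a per-axis relative position table is typically gathered is shown below. The indexing mirrors the detectron2 / SAM-style ViT code; the function name and the assumption that `rel_pos` already has `2 * max(q_size, k_size) - 1` rows are illustrative, not necessarily this PR's exact implementation.

import torch

def get_rel_pos_sketch(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    # rel_pos: one learned embedding per possible relative offset,
    # shape (2 * max(q_size, k_size) - 1, head_dim), assumed already resized.
    # Relative coordinate of every (query, key) pair, scaled when q_size != k_size
    # and shifted to start at 0 so it can index into rel_pos.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
    # One embedding vector per (query position, key position) pair: (q_size, k_size, head_dim).
    return rel_pos[relative_coords.long()]

# Example: get_rel_pos_sketch(14, 14, torch.randn(27, 64)) returns a (14, 14, 64) tensor.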
@@ -51,10 +52,36 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor
 def add_decomposed_rel_pos(
     attn: torch.Tensor, q: torch.Tensor, rel_pos_lst: nn.ParameterList, q_size: Tuple, k_size: Tuple
 ) -> torch.Tensor:
-    """
-    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
-    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py  # noqa B950
+    r"""
+    Calculate decomposed Relative Positional Embeddings from the mvitv2 implementation:
+    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
+
     Only 2D and 3D are supported.
+
+    Relative positional embeddings encode the relative position of tokens in the attention matrix:
+    tokens spaced a distance :math:`d` apart share the same embedding value (unlike absolute positional embeddings).
+
+    .. math::
+        Attn_{logits}(Q, K) = (QK^{T} + E_{rel}) * scale
+
+    where
+
+    .. math::
+        E_{ij}^{(rel)} = Q_{i} \cdot R_{p(i), p(j)}
+
+    with :math:`R_{p(i), p(j)} \in \mathbb{R}^{dim}` and :math:`p(i), p(j)` the spatial positions
+    of elements :math:`i` and :math:`j`, respectively.
+
+    When using "decomposed" relative positional embedding, the positional embedding is defined ("decomposed") as follows:
+
+    .. math::
+        R_{p(i), p(j)} = R^{d1}_{d1(i), d1(j)} + ... + R^{dn}_{dn(i), dn(j)}
+
+    where :math:`d1, ..., dn` are the :math:`n` spatial dimensions.
+
+    Decomposed relative positional embedding reduces the complexity from :math:`\mathcal{O}(d1 * ... * dn)` to
+    :math:`\mathcal{O}(d1 + ... + dn)` compared with classical relative positional embedding.
+
     Args:
         attn (Tensor): attention map.
         q (Tensor): query q in the attention layer with shape (B, s_dim_1 * ... * s_dim_n, C).
@@ -63,7 +90,7 @@ def add_decomposed_rel_pos(
         k_size (Tuple): spatial sequence size of key k with (k_dim_1, ..., k_dim_n).
 
     Returns:
-        attn (Tensor): attention map with added relative positional embeddings.
+        attn (Tensor): attention logits with added relative positional embeddings.
     """
     rh = get_rel_pos(q_size[0], k_size[0], rel_pos_lst[0])
     rw = get_rel_pos(q_size[1], k_size[1], rel_pos_lst[1])
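To make the decomposed-bias math in the docstring concrete, here is a 2D-only sketch of how the per-axis terms are typically folded into the attention logits, following the mvitv2 / detectron2 reference linked above. The argument names, the flattened (B, q_h*q_w, k_h*k_w) layout of `attn`, and the use of per-axis tables like those returned by `get_rel_pos` are assumptions for illustration, not necessarily this PR's exact code.

import torch

def add_decomposed_rel_pos_2d_sketch(attn, q, rh, rw, q_size, k_size):
    # attn: (B, q_h*q_w, k_h*k_w) attention logits, q: (B, q_h*q_w, C),
    # rh: (q_h, k_h, C) and rw: (q_w, k_w, C) per-axis tables (e.g. from get_rel_pos).
    q_h, q_w = q_size
    k_h, k_w = k_size
    B, _, C = q.shape
    r_q = q.reshape(B, q_h, q_w, C)
    # E^{(rel)}_{ij} = q_i . R_{p(i), p(j)}, computed separately for each spatial axis.
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, rh)  # height contribution
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, rw)  # width contribution
    # Broadcast-add the two axis terms, then flatten back to the attention layout.
    attn = (
        attn.view(B, q_h, q_w, k_h, k_w)
        + rel_h[:, :, :, :, None]
        + rel_w[:, :, :, None, :]
    ).view(B, q_h * q_w, k_h * k_w)
    return attn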