Commit 20cf94a

fix: reverted efficientloftr embeddings computation to inference time with lru cache
1 parent: 3b72301

2 files changed: +16 additions, -14 deletions

src/transformers/models/efficientloftr/configuration_efficientloftr.py
Lines changed: 0 additions & 4 deletions

@@ -68,8 +68,6 @@ class EfficientLoFTRConfig(PretrainedConfig):
             Kernel size used for the fine feature matching
         batch_norm_eps (`float`, *optional*, defaults to 1e-05):
             The epsilon used by the batch normalization layers.
-        embedding_size (`List`, *optional*, defaults to [15, 20]):
-            The size (height, width) of the embedding for the position embeddings.
         rope_theta (`float`, *optional*, defaults to 10000.0):
             The base period of the RoPE embeddings.
         partial_rotary_factor (`float`, *optional*, defaults to 4.0):
@@ -130,7 +128,6 @@ def __init__(
         coarse_matching_border_removal: int = 2,
         fine_kernel_size: int = 8,
         batch_norm_eps: float = 1e-5,
-        embedding_size: Optional[list[int]] = None,
         rope_theta: float = 10000.0,
         partial_rotary_factor: float = 4.0,
         rope_scaling: Optional[dict] = None,
@@ -187,7 +184,6 @@ def __init__(
         self.fine_matching_regress_temperature = fine_matching_regress_temperature

         self.num_key_value_heads = num_attention_heads
-        self.embedding_size = embedding_size if embedding_size is not None else [15, 20]
         self.rope_theta = rope_theta
         self.rope_scaling = rope_scaling if rope_scaling is not None else {"rope_type": "default"}

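With `embedding_size` removed from the configuration, the position-embedding grid is no longer fixed to `[15, 20]` at construction time; it is derived from the feature map seen at inference (see the modeling change below). A minimal sketch of what this means for config usage, assuming only the module path and class name shown in the diff:

```python
# Sketch only: after this commit, EfficientLoFTRConfig neither accepts nor stores
# embedding_size; the embedding grid follows the runtime feature-map size instead.
from transformers.models.efficientloftr.configuration_efficientloftr import EfficientLoFTRConfig

config = EfficientLoFTRConfig()               # no embedding_size argument anymore
assert not hasattr(config, "embedding_size")  # the attribute was removed in this commit
```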
src/transformers/models/efficientloftr/modeling_efficientloftr.py
Lines changed: 16 additions & 10 deletions

@@ -23,6 +23,7 @@
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
+from ...pytorch_utils import compile_compatible_method_lru_cache
 from ...utils import (
     ModelOutput,
     TransformersKwargs,
@@ -68,6 +69,18 @@ class KeypointMatchingOutput(ModelOutput):
     attentions: Optional[tuple[torch.FloatTensor]] = None


+@compile_compatible_method_lru_cache(maxsize=32)
+def compute_embeddings(inv_freq: torch.Tensor, embed_height: int, embed_width: int, hidden_size: int) -> torch.Tensor:
+    i_indices = torch.ones(embed_height, embed_width).cumsum(0).float().unsqueeze(-1)
+    j_indices = torch.ones(embed_height, embed_width).cumsum(1).float().unsqueeze(-1)
+
+    emb = torch.zeros(1, embed_height, embed_width, hidden_size // 2)
+    emb[:, :, :, 0::2] = i_indices * inv_freq
+    emb[:, :, :, 1::2] = j_indices * inv_freq
+
+    return emb
+
+
 class EfficientLoFTRRotaryEmbedding(nn.Module):
     inv_freq: torch.Tensor  # fix linting for `register_buffer`

@@ -80,23 +93,16 @@ def __init__(self, config: EfficientLoFTRConfig, device=None):
         inv_freq, _ = self.rope_init_fn(self.config, device)
         inv_freq_expanded = inv_freq[None, None, None, :].float().expand(1, 1, 1, -1)

-        embed_height, embed_width = config.embedding_size
-        i_indices = torch.ones(embed_height, embed_width).cumsum(0).float().unsqueeze(-1)
-        j_indices = torch.ones(embed_height, embed_width).cumsum(1).float().unsqueeze(-1)
-
-        emb = torch.zeros(1, embed_height, embed_width, self.config.hidden_size // 2)
-        emb[:, :, :, 0::2] = i_indices * inv_freq_expanded
-        emb[:, :, :, 1::2] = j_indices * inv_freq_expanded
-
-        self.register_buffer("inv_freq", emb, persistent=False)
+        self.register_buffer("inv_freq", inv_freq_expanded, persistent=False)

     @torch.no_grad()
     def forward(
         self, x: torch.Tensor, position_ids: Optional[tuple[torch.LongTensor, torch.LongTensor]] = None
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        features_height, features_width = x.shape[-2:]
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
-            emb = self.inv_freq
+            emb = compute_embeddings(self.inv_freq, features_height, features_width, self.config.hidden_size)
             sin = emb.sin()
             cos = emb.cos()
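The embedding computation itself is unchanged; what changes is when it runs. `compute_embeddings` is now a module-level function wrapped in `compile_compatible_method_lru_cache(maxsize=32)`, so the grid is rebuilt only when the `(inv_freq, height, width, hidden_size)` key changes. Below is a self-contained sketch of that caching behavior using plain `functools.lru_cache` as a stand-in for the repo's decorator (which additionally keeps caching compatible with `torch.compile`); the `hidden_size` value and `inv_freq` shape are illustrative assumptions, not values taken from the model:

```python
# Minimal sketch of the caching pattern. Assumptions: functools.lru_cache stands in
# for transformers' compile_compatible_method_lru_cache; hidden_size and the inv_freq
# shape are illustrative (the real buffer comes from the RoPE init function).
from functools import lru_cache

import torch


@lru_cache(maxsize=32)
def compute_embeddings(inv_freq: torch.Tensor, embed_height: int, embed_width: int, hidden_size: int) -> torch.Tensor:
    # Row (i) and column (j) positions over the feature grid, shape (H, W, 1).
    i_indices = torch.ones(embed_height, embed_width).cumsum(0).float().unsqueeze(-1)
    j_indices = torch.ones(embed_height, embed_width).cumsum(1).float().unsqueeze(-1)
    # Interleave row- and column-based rotary angles along the channel axis.
    emb = torch.zeros(1, embed_height, embed_width, hidden_size // 2)
    emb[:, :, :, 0::2] = i_indices * inv_freq
    emb[:, :, :, 1::2] = j_indices * inv_freq
    return emb


hidden_size = 256
inv_freq = torch.rand(1, 1, 1, hidden_size // 4)  # illustrative stand-in for the registered buffer

first = compute_embeddings(inv_freq, 15, 20, hidden_size)
second = compute_embeddings(inv_freq, 15, 20, hidden_size)  # same size -> cache hit
third = compute_embeddings(inv_freq, 30, 40, hidden_size)   # new size -> recomputed once

assert first is second and first is not third
```

Since `torch.Tensor` hashes by object identity, the cache key is effectively the registered `inv_freq` buffer plus the integer grid dimensions, which is exactly what varies between input images of different resolutions.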