Skip to content

Commit 52aaa3f

Browse files
sbucaille and qubvel authored
[EfficientLoFTR] dynamic image size support (#40329)
* fix: reverted efficientloftr embeddings computation to inference time with lru cache
* fix: added dtype and device for torch ones and zeros creation
* fix: fixed embed height and width computation with aggregation
* fix: make style
* fix error message
* fix fa2 tests

---------

Co-authored-by: qubvel <qubvel@gmail.com>
1 parent ed5dd29 commit 52aaa3f

File tree

3 files changed

+23
-17
lines changed

3 files changed

+23
-17
lines changed

src/transformers/models/efficientloftr/configuration_efficientloftr.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,6 @@ class EfficientLoFTRConfig(PretrainedConfig):
6868
Kernel size used for the fine feature matching
6969
batch_norm_eps (`float`, *optional*, defaults to 1e-05):
7070
The epsilon used by the batch normalization layers.
71-
embedding_size (`List`, *optional*, defaults to [15, 20]):
72-
The size (height, width) of the embedding for the position embeddings.
7371
rope_theta (`float`, *optional*, defaults to 10000.0):
7472
The base period of the RoPE embeddings.
7573
partial_rotary_factor (`float`, *optional*, defaults to 4.0):
@@ -130,7 +128,6 @@ def __init__(
130128
coarse_matching_border_removal: int = 2,
131129
fine_kernel_size: int = 8,
132130
batch_norm_eps: float = 1e-5,
133-
embedding_size: Optional[list[int]] = None,
134131
rope_theta: float = 10000.0,
135132
partial_rotary_factor: float = 4.0,
136133
rope_scaling: Optional[dict] = None,
@@ -163,7 +160,7 @@ def __init__(
163160
self.hidden_size = hidden_size
164161
if self.hidden_size != self.out_features[-1]:
165162
raise ValueError(
166-
f"hidden_size should be equal to the last value in out_features. hidden_size = {self.hidden_size}, out_features = {self.stage_out_channels}"
163+
f"hidden_size should be equal to the last value in out_features. hidden_size = {self.hidden_size}, out_features = {self.out_features[-1]}"
167164
)
168165

169166
self.activation_function = activation_function
@@ -187,7 +184,6 @@ def __init__(
187184
self.fine_matching_regress_temperature = fine_matching_regress_temperature
188185

189186
self.num_key_value_heads = num_attention_heads
190-
self.embedding_size = embedding_size if embedding_size is not None else [15, 20]
191187
self.rope_theta = rope_theta
192188
self.rope_scaling = rope_scaling if rope_scaling is not None else {"rope_type": "default"}
193189

src/transformers/models/efficientloftr/modeling_efficientloftr.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
2424
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
2525
from ...processing_utils import Unpack
26+
from ...pytorch_utils import compile_compatible_method_lru_cache
2627
from ...utils import (
2728
ModelOutput,
2829
TransformersKwargs,
@@ -68,6 +69,20 @@ class KeypointMatchingOutput(ModelOutput):
6869
attentions: Optional[tuple[torch.FloatTensor]] = None
6970

7071

72+
@compile_compatible_method_lru_cache(maxsize=32)
73+
def compute_embeddings(inv_freq: torch.Tensor, embed_height: int, embed_width: int, hidden_size: int) -> torch.Tensor:
74+
i_indices = torch.ones(embed_height, embed_width, dtype=inv_freq.dtype, device=inv_freq.device)
75+
j_indices = torch.ones(embed_height, embed_width, dtype=inv_freq.dtype, device=inv_freq.device)
76+
i_indices = i_indices.cumsum(0).unsqueeze(-1)
77+
j_indices = j_indices.cumsum(1).unsqueeze(-1)
78+
79+
emb = torch.zeros(1, embed_height, embed_width, hidden_size // 2, dtype=inv_freq.dtype, device=inv_freq.device)
80+
emb[:, :, :, 0::2] = i_indices * inv_freq
81+
emb[:, :, :, 1::2] = j_indices * inv_freq
82+
83+
return emb
84+
85+
7186
class EfficientLoFTRRotaryEmbedding(nn.Module):
7287
inv_freq: torch.Tensor # fix linting for `register_buffer`
7388

@@ -80,23 +95,18 @@ def __init__(self, config: EfficientLoFTRConfig, device=None):
8095
inv_freq, _ = self.rope_init_fn(self.config, device)
8196
inv_freq_expanded = inv_freq[None, None, None, :].float().expand(1, 1, 1, -1)
8297

83-
embed_height, embed_width = config.embedding_size
84-
i_indices = torch.ones(embed_height, embed_width).cumsum(0).float().unsqueeze(-1)
85-
j_indices = torch.ones(embed_height, embed_width).cumsum(1).float().unsqueeze(-1)
86-
87-
emb = torch.zeros(1, embed_height, embed_width, self.config.hidden_size // 2)
88-
emb[:, :, :, 0::2] = i_indices * inv_freq_expanded
89-
emb[:, :, :, 1::2] = j_indices * inv_freq_expanded
90-
91-
self.register_buffer("inv_freq", emb, persistent=False)
98+
self.register_buffer("inv_freq", inv_freq_expanded, persistent=False)
9299

93100
@torch.no_grad()
94101
def forward(
95102
self, x: torch.Tensor, position_ids: Optional[tuple[torch.LongTensor, torch.LongTensor]] = None
96103
) -> tuple[torch.Tensor, torch.Tensor]:
104+
feats_height, feats_width = x.shape[-2:]
105+
embed_height = (feats_height - self.config.q_aggregation_kernel_size) // self.config.q_aggregation_stride + 1
106+
embed_width = (feats_width - self.config.q_aggregation_kernel_size) // self.config.q_aggregation_stride + 1
97107
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
98108
with torch.autocast(device_type=device_type, enabled=False): # Force float32
99-
emb = self.inv_freq
109+
emb = compute_embeddings(self.inv_freq, embed_height, embed_width, self.config.hidden_size)
100110
sin = emb.sin()
101111
cos = emb.cos()
102112

tests/models/efficientloftr/test_modeling_efficientloftr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,15 @@ def __init__(
5050
image_width=80,
5151
image_height=60,
5252
stage_num_blocks: list[int] = [1, 1, 1],
53-
out_features: list[int] = [32, 32, 64],
53+
out_features: list[int] = [32, 32, 128],
5454
stage_stride: list[int] = [2, 1, 2],
5555
q_aggregation_kernel_size: int = 1,
5656
kv_aggregation_kernel_size: int = 1,
5757
q_aggregation_stride: int = 1,
5858
kv_aggregation_stride: int = 1,
5959
num_attention_layers: int = 2,
6060
num_attention_heads: int = 8,
61-
hidden_size: int = 64,
61+
hidden_size: int = 128,
6262
coarse_matching_threshold: float = 0.0,
6363
fine_kernel_size: int = 2,
6464
coarse_matching_border_removal: int = 0,

0 commit comments

Comments
 (0)