Merged
```diff
@@ -68,8 +68,6 @@ class EfficientLoFTRConfig(PretrainedConfig):
             Kernel size used for the fine feature matching
         batch_norm_eps (`float`, *optional*, defaults to 1e-05):
             The epsilon used by the batch normalization layers.
-        embedding_size (`List`, *optional*, defaults to [15, 20]):
-            The size (height, width) of the embedding for the position embeddings.
         rope_theta (`float`, *optional*, defaults to 10000.0):
             The base period of the RoPE embeddings.
         partial_rotary_factor (`float`, *optional*, defaults to 4.0):
@@ -130,7 +128,6 @@ def __init__(
         coarse_matching_border_removal: int = 2,
         fine_kernel_size: int = 8,
         batch_norm_eps: float = 1e-5,
-        embedding_size: Optional[list[int]] = None,
         rope_theta: float = 10000.0,
         partial_rotary_factor: float = 4.0,
         rope_scaling: Optional[dict] = None,
@@ -163,7 +160,7 @@ def __init__(
         self.hidden_size = hidden_size
         if self.hidden_size != self.out_features[-1]:
             raise ValueError(
-                f"hidden_size should be equal to the last value in out_features. hidden_size = {self.hidden_size}, out_features = {self.stage_out_channels}"
+                f"hidden_size should be equal to the last value in out_features. hidden_size = {self.hidden_size}, out_features = {self.out_features[-1]}"
             )

         self.activation_function = activation_function
@@ -187,7 +184,6 @@ def __init__(
         self.fine_matching_regress_temperature = fine_matching_regress_temperature

         self.num_key_value_heads = num_attention_heads
-        self.embedding_size = embedding_size if embedding_size is not None else [15, 20]
         self.rope_theta = rope_theta
         self.rope_scaling = rope_scaling if rope_scaling is not None else {"rope_type": "default"}
```

Contributor, on the corrected `ValueError` message: fix attribute name
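The upshot of the configuration diff: the position-embedding grid is no longer pinned by an `embedding_size` attribute, and the RoPE table is instead sized from the feature map at forward time (see the modeling changes below). A minimal sketch of what that means on the config side, assuming a transformers build that already contains this change:

```python
# Minimal sketch, assuming EfficientLoFTRConfig is exported by a transformers
# version that includes this PR.
from transformers import EfficientLoFTRConfig

config = EfficientLoFTRConfig()  # embedding_size is no longer an __init__ argument
print(hasattr(config, "embedding_size"))  # expected: False, the grid size is derived at runtime
```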
src/transformers/models/efficientloftr/modeling_efficientloftr.py (20 additions, 10 deletions)
```diff
@@ -23,6 +23,7 @@
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
+from ...pytorch_utils import compile_compatible_method_lru_cache
 from ...utils import (
     ModelOutput,
     TransformersKwargs,
@@ -68,6 +69,20 @@ class KeypointMatchingOutput(ModelOutput):
     attentions: Optional[tuple[torch.FloatTensor]] = None


+@compile_compatible_method_lru_cache(maxsize=32)
+def compute_embeddings(inv_freq: torch.Tensor, embed_height: int, embed_width: int, hidden_size: int) -> torch.Tensor:
+    i_indices = torch.ones(embed_height, embed_width, dtype=inv_freq.dtype, device=inv_freq.device)
+    j_indices = torch.ones(embed_height, embed_width, dtype=inv_freq.dtype, device=inv_freq.device)
+    i_indices = i_indices.cumsum(0).unsqueeze(-1)
+    j_indices = j_indices.cumsum(1).unsqueeze(-1)
+
+    emb = torch.zeros(1, embed_height, embed_width, hidden_size // 2, dtype=inv_freq.dtype, device=inv_freq.device)
+    emb[:, :, :, 0::2] = i_indices * inv_freq
+    emb[:, :, :, 1::2] = j_indices * inv_freq
+
+    return emb
+
+
 class EfficientLoFTRRotaryEmbedding(nn.Module):
     inv_freq: torch.Tensor  # fix linting for `register_buffer`
```
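For readers unfamiliar with the pattern: `compute_embeddings` builds a 2-D RoPE table whose even channels encode the row index and odd channels the column index, and the LRU cache keys on `(inv_freq, embed_height, embed_width, hidden_size)`, so repeated forward passes at the same resolution reuse the same tensor. A standalone sketch of the same construction, using plain `functools.lru_cache` as a stand-in for the transformers-internal `compile_compatible_method_lru_cache`, with an illustrative `inv_freq` of length `hidden_size // 4` so the interleave fills `hidden_size // 2` channels exactly:

```python
from functools import lru_cache  # stand-in for compile_compatible_method_lru_cache

import torch


@lru_cache(maxsize=32)
def compute_embeddings_sketch(inv_freq: torch.Tensor, embed_height: int, embed_width: int, hidden_size: int) -> torch.Tensor:
    # Row (i) and column (j) coordinates of every grid cell, starting at 1.
    i_indices = torch.ones(embed_height, embed_width, dtype=inv_freq.dtype).cumsum(0).unsqueeze(-1)
    j_indices = torch.ones(embed_height, embed_width, dtype=inv_freq.dtype).cumsum(1).unsqueeze(-1)

    emb = torch.zeros(1, embed_height, embed_width, hidden_size // 2, dtype=inv_freq.dtype)
    emb[:, :, :, 0::2] = i_indices * inv_freq  # even channels: row coordinate times frequencies
    emb[:, :, :, 1::2] = j_indices * inv_freq  # odd channels: column coordinate times frequencies
    return emb


hidden_size = 128
# Illustrative frequencies only; the model gets its inv_freq from its rope_init_fn.
inv_freq = 1.0 / (10000.0 ** (torch.arange(hidden_size // 4) / (hidden_size // 4)))
emb = compute_embeddings_sketch(inv_freq, 15, 20, hidden_size)
print(emb.shape)  # torch.Size([1, 15, 20, 64])
```

Because the cache is keyed on the grid size, inputs at a new resolution simply add a new entry (up to `maxsize=32`) rather than invalidating the previous one.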

```diff
@@ -80,23 +95,18 @@ def __init__(self, config: EfficientLoFTRConfig, device=None):
         inv_freq, _ = self.rope_init_fn(self.config, device)
         inv_freq_expanded = inv_freq[None, None, None, :].float().expand(1, 1, 1, -1)

-        embed_height, embed_width = config.embedding_size
-        i_indices = torch.ones(embed_height, embed_width).cumsum(0).float().unsqueeze(-1)
-        j_indices = torch.ones(embed_height, embed_width).cumsum(1).float().unsqueeze(-1)
-
-        emb = torch.zeros(1, embed_height, embed_width, self.config.hidden_size // 2)
-        emb[:, :, :, 0::2] = i_indices * inv_freq_expanded
-        emb[:, :, :, 1::2] = j_indices * inv_freq_expanded
-
-        self.register_buffer("inv_freq", emb, persistent=False)
+        self.register_buffer("inv_freq", inv_freq_expanded, persistent=False)

     @torch.no_grad()
     def forward(
         self, x: torch.Tensor, position_ids: Optional[tuple[torch.LongTensor, torch.LongTensor]] = None
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        feats_height, feats_width = x.shape[-2:]
+        embed_height = (feats_height - self.config.q_aggregation_kernel_size) // self.config.q_aggregation_stride + 1
+        embed_width = (feats_width - self.config.q_aggregation_kernel_size) // self.config.q_aggregation_stride + 1
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
-            emb = self.inv_freq
+            emb = compute_embeddings(self.inv_freq, embed_height, embed_width, self.config.hidden_size)
             sin = emb.sin()
             cos = emb.cos()
```
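The grid dimensions computed in `forward` follow the usual sliding-window output-size formula, `(input - kernel) // stride + 1`, applied with the query-aggregation kernel and stride. A quick numerical check with hypothetical values (kernel 4, stride 4 on a 60x80 coarse feature map; the real numbers come from `q_aggregation_kernel_size` and `q_aggregation_stride` in the config), compared against a pooling layer with the same geometry purely to sanity-check the arithmetic:

```python
import torch

feats_height, feats_width = 60, 80  # hypothetical coarse feature map size
kernel, stride = 4, 4               # hypothetical aggregation kernel/stride

embed_height = (feats_height - kernel) // stride + 1  # 15
embed_width = (feats_width - kernel) // stride + 1    # 20

# Any window op with the same kernel/stride yields the same spatial shape;
# max pooling is used here only to verify the formula, not to mirror the model.
pooled = torch.nn.functional.max_pool2d(torch.zeros(1, 1, feats_height, feats_width), kernel, stride)
print(embed_height, embed_width, tuple(pooled.shape[-2:]))  # 15 20 (15, 20)
```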
tests/models/efficientloftr/test_modeling_efficientloftr.py (2 additions, 2 deletions)
```diff
@@ -50,15 +50,15 @@ def __init__(
         image_width=80,
         image_height=60,
         stage_num_blocks: list[int] = [1, 1, 1],
-        out_features: list[int] = [32, 32, 64],
+        out_features: list[int] = [32, 32, 128],
         stage_stride: list[int] = [2, 1, 2],
         q_aggregation_kernel_size: int = 1,
         kv_aggregation_kernel_size: int = 1,
         q_aggregation_stride: int = 1,
         kv_aggregation_stride: int = 1,
         num_attention_layers: int = 2,
         num_attention_heads: int = 8,
-        hidden_size: int = 64,
+        hidden_size: int = 128,
         coarse_matching_threshold: float = 0.0,
         fine_kernel_size: int = 2,
         coarse_matching_border_removal: int = 0,
```

Comment on lines +53 to +61:

Contributor: that's for FA2 tests to pass

Contributor (PR author): How didn't I catch this one before? 🤔 But thanks for taking care of it!

Contributor: That's skipped locally in case you don't have FA2 installed 😄
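On the review exchange above: with `num_attention_heads = 8`, the bump in `hidden_size` moves the per-head dimension from 8 to 16, which is presumably what the flash-attention kernels exercised by the FA2 tests require; the thread itself only records that the change is needed "for FA2 tests to pass". The arithmetic, assuming the usual head dim of `hidden_size // num_attention_heads`:

```python
# Per-head dimension before and after the test-config change.
num_attention_heads = 8
for hidden_size in (64, 128):
    print(f"hidden_size={hidden_size} -> head_dim={hidden_size // num_attention_heads}")
# hidden_size=64 -> head_dim=8
# hidden_size=128 -> head_dim=16
```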