 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
+from ...pytorch_utils import compile_compatible_method_lru_cache
 from ...utils import (
     ModelOutput,
     TransformersKwargs,
@@ -68,6 +69,20 @@ class KeypointMatchingOutput(ModelOutput):
     attentions: Optional[tuple[torch.FloatTensor]] = None


+@compile_compatible_method_lru_cache(maxsize=32)
+def compute_embeddings(inv_freq: torch.Tensor, embed_height: int, embed_width: int, hidden_size: int) -> torch.Tensor:
+    i_indices = torch.ones(embed_height, embed_width, dtype=inv_freq.dtype, device=inv_freq.device)
+    j_indices = torch.ones(embed_height, embed_width, dtype=inv_freq.dtype, device=inv_freq.device)
+    i_indices = i_indices.cumsum(0).unsqueeze(-1)
+    j_indices = j_indices.cumsum(1).unsqueeze(-1)
+
+    emb = torch.zeros(1, embed_height, embed_width, hidden_size // 2, dtype=inv_freq.dtype, device=inv_freq.device)
+    emb[:, :, :, 0::2] = i_indices * inv_freq
+    emb[:, :, :, 1::2] = j_indices * inv_freq
+
+    return emb
+
+
 class EfficientLoFTRRotaryEmbedding(nn.Module):
     inv_freq: torch.Tensor  # fix linting for `register_buffer`

@@ -80,23 +95,18 @@ def __init__(self, config: EfficientLoFTRConfig, device=None):
         inv_freq, _ = self.rope_init_fn(self.config, device)
         inv_freq_expanded = inv_freq[None, None, None, :].float().expand(1, 1, 1, -1)

-        embed_height, embed_width = config.embedding_size
-        i_indices = torch.ones(embed_height, embed_width).cumsum(0).float().unsqueeze(-1)
-        j_indices = torch.ones(embed_height, embed_width).cumsum(1).float().unsqueeze(-1)
-
-        emb = torch.zeros(1, embed_height, embed_width, self.config.hidden_size // 2)
-        emb[:, :, :, 0::2] = i_indices * inv_freq_expanded
-        emb[:, :, :, 1::2] = j_indices * inv_freq_expanded
-
-        self.register_buffer("inv_freq", emb, persistent=False)
+        self.register_buffer("inv_freq", inv_freq_expanded, persistent=False)

     @torch.no_grad()
     def forward(
         self, x: torch.Tensor, position_ids: Optional[tuple[torch.LongTensor, torch.LongTensor]] = None
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        feats_height, feats_width = x.shape[-2:]
+        embed_height = (feats_height - self.config.q_aggregation_kernel_size) // self.config.q_aggregation_stride + 1
+        embed_width = (feats_width - self.config.q_aggregation_kernel_size) // self.config.q_aggregation_stride + 1
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
-            emb = self.inv_freq
+            emb = compute_embeddings(self.inv_freq, embed_height, embed_width, self.config.hidden_size)
             sin = emb.sin()
             cos = emb.cos()

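For context: the diff replaces a fixed-size positional grid precomputed in `__init__` (from `config.embedding_size`) with a grid computed at forward time, sized from the incoming feature map and cached via `compile_compatible_method_lru_cache`. Below is a stand-alone sketch of that idea, not the library code: `functools.lru_cache` stands in for `compile_compatible_method_lru_cache`, and `hidden_size`, the aggregation kernel/stride, and the `inv_freq` initialisation are assumed values.

# Illustration only; names and values here are assumptions, not the transformers API.
from functools import lru_cache

import torch


@lru_cache(maxsize=32)
def cached_embeddings(inv_freq: torch.Tensor, embed_height: int, embed_width: int, hidden_size: int) -> torch.Tensor:
    # Row/column indices for every cell of the coarse grid.
    i_indices = torch.ones(embed_height, embed_width, dtype=inv_freq.dtype).cumsum(0).unsqueeze(-1)
    j_indices = torch.ones(embed_height, embed_width, dtype=inv_freq.dtype).cumsum(1).unsqueeze(-1)

    # Interleave row and column frequencies along the channel dimension.
    emb = torch.zeros(1, embed_height, embed_width, hidden_size // 2, dtype=inv_freq.dtype)
    emb[:, :, :, 0::2] = i_indices * inv_freq
    emb[:, :, :, 1::2] = j_indices * inv_freq
    return emb


hidden_size, kernel, stride = 256, 4, 4  # assumed config values
inv_freq = 1.0 / (10000 ** (torch.arange(0, hidden_size // 4).float() / (hidden_size // 4)))

feats = torch.randn(1, hidden_size, 60, 60)        # coarse feature map of arbitrary size
height, width = feats.shape[-2:]
embed_height = (height - kernel) // stride + 1     # standard pooling output-size formula: 15
embed_width = (width - kernel) // stride + 1

emb = cached_embeddings(inv_freq, embed_height, embed_width, hidden_size)
print(emb.shape)                                   # torch.Size([1, 15, 15, 128])

# The same inv_freq object is passed on every call, so a second call with the
# same spatial size is a cache hit and returns the stored tensor unchanged.
assert cached_embeddings(inv_freq, embed_height, embed_width, hidden_size) is emb

In the sketch, because the same `inv_freq` object is reused across calls, the cache effectively keys on the integer grid size: repeated calls at the same resolution reuse the precomputed grid, while inputs of other resolutions (which the old buffer-time precomputation could not handle) simply produce and cache a new entry.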