 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
+from ...pytorch_utils import compile_compatible_method_lru_cache
 from ...utils import (
     ModelOutput,
     TransformersKwargs,
@@ -68,6 +69,18 @@ class KeypointMatchingOutput(ModelOutput):
     attentions: Optional[tuple[torch.FloatTensor]] = None
 
 
+@compile_compatible_method_lru_cache(maxsize=32)
+def compute_embeddings(inv_freq: torch.Tensor, embed_height: int, embed_width: int, hidden_size: int) -> torch.Tensor:
+    i_indices = torch.ones(embed_height, embed_width).cumsum(0).float().unsqueeze(-1)
+    j_indices = torch.ones(embed_height, embed_width).cumsum(1).float().unsqueeze(-1)
+
+    emb = torch.zeros(1, embed_height, embed_width, hidden_size // 2)
+    emb[:, :, :, 0::2] = i_indices * inv_freq
+    emb[:, :, :, 1::2] = j_indices * inv_freq
+
+    return emb
+
+
 class EfficientLoFTRRotaryEmbedding(nn.Module):
     inv_freq: torch.Tensor  # fix linting for `register_buffer`
 
@@ -80,23 +93,16 @@ def __init__(self, config: EfficientLoFTRConfig, device=None):
         inv_freq, _ = self.rope_init_fn(self.config, device)
         inv_freq_expanded = inv_freq[None, None, None, :].float().expand(1, 1, 1, -1)
 
-        embed_height, embed_width = config.embedding_size
-        i_indices = torch.ones(embed_height, embed_width).cumsum(0).float().unsqueeze(-1)
-        j_indices = torch.ones(embed_height, embed_width).cumsum(1).float().unsqueeze(-1)
-
-        emb = torch.zeros(1, embed_height, embed_width, self.config.hidden_size // 2)
-        emb[:, :, :, 0::2] = i_indices * inv_freq_expanded
-        emb[:, :, :, 1::2] = j_indices * inv_freq_expanded
-
-        self.register_buffer("inv_freq", emb, persistent=False)
+        self.register_buffer("inv_freq", inv_freq_expanded, persistent=False)
 
     @torch.no_grad()
     def forward(
         self, x: torch.Tensor, position_ids: Optional[tuple[torch.LongTensor, torch.LongTensor]] = None
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        features_height, features_width = x.shape[-2:]
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
-            emb = self.inv_freq
+            emb = compute_embeddings(self.inv_freq, features_height, features_width, self.config.hidden_size)
             sin = emb.sin()
             cos = emb.cos()
 
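
The sketch below is a minimal, standalone illustration (not the library code) of why the refactor is safe: the sin/cos table depends only on the frozen inv_freq buffer and the feature-map size, so it can be rebuilt lazily in forward() and memoized per resolution instead of being fixed to config.embedding_size at construction time. functools.lru_cache stands in for compile_compatible_method_lru_cache, and the inv_freq initialization and hidden size below are placeholder assumptions, not values taken from EfficientLoFTRConfig.

from functools import lru_cache

import torch

HIDDEN_SIZE = 256  # placeholder; the real value comes from EfficientLoFTRConfig


@lru_cache(maxsize=32)  # stand-in for compile_compatible_method_lru_cache
def compute_embeddings(inv_freq: torch.Tensor, embed_height: int, embed_width: int, hidden_size: int) -> torch.Tensor:
    # Row (i) and column (j) positions of every cell in the feature map, shape (H, W, 1)
    i_indices = torch.ones(embed_height, embed_width).cumsum(0).float().unsqueeze(-1)
    j_indices = torch.ones(embed_height, embed_width).cumsum(1).float().unsqueeze(-1)

    # Interleave row- and column-frequency products along the channel axis
    emb = torch.zeros(1, embed_height, embed_width, hidden_size // 2)
    emb[:, :, :, 0::2] = i_indices * inv_freq
    emb[:, :, :, 1::2] = j_indices * inv_freq
    return emb


# Illustrative inv_freq, shaped to broadcast like inv_freq_expanded in __init__
inv_freq = 1.0 / (10000.0 ** (torch.arange(HIDDEN_SIZE // 4).float() / (HIDDEN_SIZE // 4)))
inv_freq = inv_freq[None, None, None, :]

# Tensors hash by object identity, so passing the same registered buffer on every
# call keeps the cache key stable; only a new (height, width) pair recomputes.
for height, width in [(60, 80), (30, 40), (60, 80)]:
    emb = compute_embeddings(inv_freq, height, width, HIDDEN_SIZE)
    print(emb.shape, compute_embeddings.cache_info())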