Skip to content

Commit 3fa6b94

Browse files
committed
fix: added dtype and device for torch ones and zeros creation
1 parent 20cf94a commit 3fa6b94

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

src/transformers/models/efficientloftr/modeling_efficientloftr.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,12 @@ class KeypointMatchingOutput(ModelOutput):
7171

7272
@compile_compatible_method_lru_cache(maxsize=32)
7373
def compute_embeddings(inv_freq: torch.Tensor, embed_height: int, embed_width: int, hidden_size: int) -> torch.Tensor:
74-
i_indices = torch.ones(embed_height, embed_width).cumsum(0).float().unsqueeze(-1)
75-
j_indices = torch.ones(embed_height, embed_width).cumsum(1).float().unsqueeze(-1)
74+
i_indices = torch.ones(embed_height, embed_width, dtype=inv_freq.dtype, device=inv_freq.device)
75+
j_indices = torch.ones(embed_height, embed_width, dtype=inv_freq.dtype, device=inv_freq.device)
76+
i_indices = i_indices.cumsum(0).unsqueeze(-1)
77+
j_indices = j_indices.cumsum(1).unsqueeze(-1)
7678

77-
emb = torch.zeros(1, embed_height, embed_width, hidden_size // 2)
79+
emb = torch.zeros(1, embed_height, embed_width, hidden_size // 2, dtype=inv_freq.dtype, device=inv_freq.device)
7880
emb[:, :, :, 0::2] = i_indices * inv_freq
7981
emb[:, :, :, 1::2] = j_indices * inv_freq
8082

0 commit comments

Comments
 (0)