diff --git a/models/swin_transformer_v2.py b/models/swin_transformer_v2.py index a429d0a2c..e5697d877 100644 --- a/models/swin_transformer_v2.py +++ b/models/swin_transformer_v2.py @@ -153,7 +153,7 @@ def forward(self, x, mask=None): # cosine attention attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) - logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp() + logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp() attn = attn * logit_scale relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)