Skip to content

Commit

Permalink
Cast a to v's dtype in neighborhood attention blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
crowsonkb committed Jan 24, 2024
1 parent 8ace775 commit 6ab5146
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion k_diffusion/models/image_transformer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def forward(self, x, pos, cond):
raise ModuleNotFoundError("natten is required for neighborhood attention")
flops.op(flops.op_natten, q.shape, k.shape, v.shape, self.kernel_size)
qk = natten.functional.natten2dqk(q, k, self.kernel_size, 1)
a = torch.softmax(qk, dim=-1)
a = torch.softmax(qk, dim=-1).to(v.dtype)
x = natten.functional.natten2dav(a, v, self.kernel_size, 1)
x = rearrange(x, "n nh h w e -> n h w (nh e)")
x = self.dropout(x)
Expand Down

0 comments on commit 6ab5146

Please sign in to comment.