Skip to content

Commit

Permalink
[Mochi-1] ensuring to compute the fourier features in FP32 in Mochi e…
Browse files Browse the repository at this point in the history
…ncoder (huggingface#10031)

compute fourier features in FP32.
  • Loading branch information
sayakpaul authored Nov 29, 2024
1 parent 6b288ec commit c96bfa5
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,8 @@ def __init__(self, start: int = 6, stop: int = 8, step: int = 1):

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
r"""Forward method of the `FourierFeatures` class."""

original_dtype = inputs.dtype
inputs = inputs.to(torch.float32)
num_channels = inputs.shape[1]
num_freqs = (self.stop - self.start) // self.step

Expand All @@ -450,7 +451,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
# Scale channels by frequency.
h = w * h

return torch.cat([inputs, torch.sin(h), torch.cos(h)], dim=1)
return torch.cat([inputs, torch.sin(h), torch.cos(h)], dim=1).to(original_dtype)


class MochiEncoder3D(nn.Module):
Expand Down

0 comments on commit c96bfa5

Please sign in to comment.