[FIX] Bug in FluxPosEmbed (huggingface#10115)
* Fix get_1d_rotary_pos_embed in embeddings.py

* Update embeddings.py

---------

Co-authored-by: hlky <hlky@hlky.ac>
SahilCarterr and hlky authored Dec 5, 2024
1 parent 65ab105 commit 3335e22
Showing 1 changed file with 6 additions and 1 deletion.
src/diffusers/models/embeddings.py
@@ -959,7 +959,12 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         freqs_dtype = torch.float32 if is_mps else torch.float64
         for i in range(n_axes):
             cos, sin = get_1d_rotary_pos_embed(
-                self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
+                self.axes_dim[i],
+                pos[:, i],
+                theta=self.theta,
+                repeat_interleave_real=True,
+                use_real=True,
+                freqs_dtype=freqs_dtype,
             )
             cos_out.append(cos)
             sin_out.append(sin)
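
For context: `get_1d_rotary_pos_embed` defaults `theta` to 10000.0, so before this commit `FluxPosEmbed.forward` never forwarded the `theta` the module was constructed with and silently used the default. A minimal sketch of the effect, assuming the public `get_1d_rotary_pos_embed` in `diffusers.models.embeddings` and illustrative values (`dim=16`, `theta=100.0`):

import torch
from diffusers.models.embeddings import get_1d_rotary_pos_embed

pos = torch.arange(4).float()  # illustrative 1D axis positions

# Without an explicit theta the function falls back to its default (10000.0),
# which is what FluxPosEmbed effectively did before this fix.
cos_default, sin_default = get_1d_rotary_pos_embed(
    16, pos, repeat_interleave_real=True, use_real=True
)

# Forwarding a non-default theta (as the fixed code now does via self.theta)
# changes the rotary frequencies.
cos_custom, sin_custom = get_1d_rotary_pos_embed(
    16, pos, theta=100.0, repeat_interleave_real=True, use_real=True
)

# The embeddings differ for non-zero positions, confirming that the
# configured theta was previously being ignored.
print(torch.allclose(cos_default, cos_custom))  # False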
