diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 91451fa9aac2..8f8f1073da74 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -959,7 +959,12 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: freqs_dtype = torch.float32 if is_mps else torch.float64 for i in range(n_axes): cos, sin = get_1d_rotary_pos_embed( - self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype + self.axes_dim[i], + pos[:, i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=freqs_dtype, ) cos_out.append(cos) sin_out.append(sin)