From 1283c18541b437ca092e8d4df070afccf8a1c5c2 Mon Sep 17 00:00:00 2001 From: SahilCarterr <110806554+SahilCarterr@users.noreply.github.com> Date: Wed, 4 Dec 2024 17:33:42 +0530 Subject: [PATCH 1/2] Fix get_1d_rotary_pos_embed in embedding.py --- src/diffusers/models/embeddings.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 91451fa9aac2..412502d8aedc 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -959,7 +959,12 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: freqs_dtype = torch.float32 if is_mps else torch.float64 for i in range(n_axes): cos, sin = get_1d_rotary_pos_embed( - self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype + self.axes_dim[i], + pos[:, i], + self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=freqs_dtype, ) cos_out.append(cos) sin_out.append(sin) From c52f72c3b331b2408d4f1471d7f9d65ddd5591c3 Mon Sep 17 00:00:00 2001 From: SahilCarterr <110806554+SahilCarterr@users.noreply.github.com> Date: Thu, 5 Dec 2024 07:20:58 +0530 Subject: [PATCH 2/2] Update embeddings.py --- src/diffusers/models/embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 412502d8aedc..8f8f1073da74 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -961,7 +961,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: cos, sin = get_1d_rotary_pos_embed( self.axes_dim[i], pos[:, i], - self.theta, + theta=self.theta, repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype,