diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index de9c92691af6..094ca0fb2299 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -275,11 +275,9 @@ def _attention(self, query, key, value, sequence_length, dim):
         for i in range(hidden_states.shape[0] // slice_size):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
-            attn_slice = (
-                torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale
-            )
+            attn_slice = torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
             attn_slice = attn_slice.softmax(dim=-1)
-            attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
+            attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
             hidden_states[start_idx:end_idx] = attn_slice
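
The patch swaps the two `torch.einsum` calls in the sliced-attention path for equivalent batched `torch.matmul` calls. A minimal sanity check of that equivalence is sketched below; the shapes and the `scale` value are illustrative assumptions, not taken from the patch.

```python
import torch

# Illustrative shapes and scale (self.scale is typically dim ** -0.5 in attention code).
b, i, j, d = 2, 4, 4, 8
scale = d ** -0.5

query = torch.randn(b, i, d)
key = torch.randn(b, j, d)
value = torch.randn(b, j, d)

# Old form: einsum over the batch/query/key dimensions.
old = torch.einsum("b i d, b j d -> b i j", query, key) * scale
old = old.softmax(dim=-1)
old = torch.einsum("b i j, b j d -> b i d", old, value)

# New form: plain batched matmul with the key transposed on its last two dims.
new = torch.matmul(query, key.transpose(1, 2)) * scale
new = new.softmax(dim=-1)
new = torch.matmul(new, value)

# Both paths produce the same attention output up to floating-point tolerance.
assert torch.allclose(old, new, atol=1e-6)
```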