From b0083f8d50170e5d8c5bd7892d8e57da93d2c05a Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Fri, 9 Sep 2022 16:56:07 +0530
Subject: [PATCH 1/2] use torch.matmul instead of einsum

---
 src/diffusers/models/attention.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index de9c92691af6..ef93963fea75 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -275,11 +275,9 @@ def _attention(self, query, key, value, sequence_length, dim):
         for i in range(hidden_states.shape[0] // slice_size):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
-            attn_slice = (
-                torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale
-            )
-            attn_slice = attn_slice.softmax(dim=-1)
-            attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
+            attn_slice = torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
+            attn_slice = attn_slice.softmax()
+            attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
 
             hidden_states[start_idx:end_idx] = attn_slice
 

From 38addfb6792044d00e85f1c146ef1d33724e3a31 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Fri, 9 Sep 2022 17:00:31 +0530
Subject: [PATCH 2/2] fix softmax

---
 src/diffusers/models/attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index ef93963fea75..094ca0fb2299 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -276,7 +276,7 @@ def _attention(self, query, key, value, sequence_length, dim):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
             attn_slice = torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
-            attn_slice = attn_slice.softmax()
+            attn_slice = attn_slice.softmax(dim=-1)
             attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
 
             hidden_states[start_idx:end_idx] = attn_slice