Skip to content

Commit

Permalink
try out opt_einsum. it's 30x slower on MPS. https://github.com/lstein…
Browse files Browse the repository at this point in the history
  • Loading branch information
Birch-san committed Sep 14, 2022
1 parent 18bb5f8 commit b7357a7
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions ldm/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from opt_einsum import contract

from ldm.modules.diffusionmodules.util import checkpoint

Expand Down Expand Up @@ -179,7 +180,7 @@ def forward(self, x, context=None, mask=None):

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
sim = contract('b i d, b j d -> b i j', q, k) * self.scale
del q, k

if exists(mask):
Expand All @@ -193,7 +194,7 @@ def forward(self, x, context=None, mask=None):
attn = sim.softmax(dim=-1)
del sim

out = einsum('b i j, b j d -> b i d', attn, v)
out = contract('b i j, b j d -> b i d', attn, v)
del attn, v
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
del h
Expand Down

0 comments on commit b7357a7

Please sign in to comment.