Skip to content

Commit 8ea07f3

Browse files
skurovecAny-Winter-4079
authored andcommitted
reintroduce fix for m1 from PR#579 missing after merge
Make results reproducible (so runs with the same seed produce the same result). Implements fix by @wbowling referenced in #397 (comment)
1 parent 79e79b7 commit 8ea07f3

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

ldm/generate.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,24 @@
3434
from ldm.invoke.devices import choose_torch_device, choose_precision
3535
from ldm.invoke.conditioning import get_uc_and_c
3636

37+
def fix_func(orig):
38+
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
39+
def new_func(*args, **kw):
40+
device = kw.get("device", "mps")
41+
kw["device"]="cpu"
42+
return orig(*args, **kw).to(device)
43+
return new_func
44+
return orig
45+
46+
torch.rand = fix_func(torch.rand)
47+
torch.rand_like = fix_func(torch.rand_like)
48+
torch.randn = fix_func(torch.randn)
49+
torch.randn_like = fix_func(torch.randn_like)
50+
torch.randint = fix_func(torch.randint)
51+
torch.randint_like = fix_func(torch.randint_like)
52+
torch.bernoulli = fix_func(torch.bernoulli)
53+
torch.multinomial = fix_func(torch.multinomial)
54+
3755
"""Simplified text to image API for stable diffusion/latent diffusion
3856
3957
Example Usage:

0 commit comments

Comments
 (0)