[MPS] einsum returns incorrect matmul result on first invocation on nightly builds #85224
At least locally, #85689 fixes this one for me.
Looks like the expectation in that code was that `.clone` will return a contiguous tensor, so explicitly specify the memory format.

Fixes #85675 and #85224

Pull Request resolved: #85689
Approved by: https://github.com/kulinseth
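For context, a minimal sketch of the behavior the fix guards against (an illustration, not the actual PR diff): `.clone()` defaults to `torch.preserve_format`, so cloning a non-contiguous view yields a non-contiguous clone unless a format is requested explicitly.

```python
import torch

x = torch.rand(4, 5).t()   # a transposed view: dense, but not contiguous
y = x.clone()              # preserve_format (the default): the clone keeps x's strides
z = x.clone(memory_format=torch.contiguous_format)  # force a contiguous result
print(x.is_contiguous(), y.is_contiguous(), z.is_contiguous())  # False False True
```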
@Birch-san can I ask you to test the fix (if you are building PyTorch locally), or do you need to wait until nightlies are available?
hey @malfet, thanks for this. tested just now: it certainly seems deterministic now, but performance has regressed massively compared with the earlier nightlies.
Well, einsum is expected to be much slower than matmul, but we can probably add a shortcut for this particular codepath (which should be unrelated to MPS).
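For illustration, a rough sketch of what such a shortcut could look like at the Python level; this is a guess at the idea, not PyTorch's actual dispatch, and `einsum_with_bmm_shortcut` is a hypothetical helper:

```python
import torch

def einsum_with_bmm_shortcut(spec: str, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # special-case the two attention einsums as plain batched matmuls
    if spec == 'b i d, b j d -> b i j':
        return torch.bmm(a, b.transpose(1, 2))  # q @ k^T
    if spec == 'b i j, b j d -> b i d':
        return torch.bmm(a, b)                  # attn @ v
    return torch.einsum(spec, a, b)             # everything else: general einsum
```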
I tried a few nightly releases from September:
so, given you merged your fix for einsum determinism on the 27th, I think the perf regression must be something different, introduced on the 25th...
Hmm, among the changes on 6916826... @Birch-san, can I ask you to...
we found the opposite: einsum is 36% faster than matmul, at least on 1.13.0.dev20220826.
Hmm, what are the tensor sizes you are using?
hmm, does it use opt_einsum then?
```python
# q.shape
# torch.Size([16, 4096, 40])
# k.shape
# torch.Size([16, 4096, 40])
einsum('b i d, b j d -> b i j', q, k)
# attn.shape
# torch.Size([16, 4096, 4096])
# v.shape
# torch.Size([16, 4096, 40])
einsum('b i j, b j d -> b i d', attn, v)
```
which nightly do you want this measured on? latest?
In that case, can you please try uninstalling opt_einsum in your environment and measuring performance again?
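(For reference, presumably that means `pip uninstall opt-einsum`; note the PyPI package name uses a hyphen, while the importable module is `opt_einsum`.)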
hmm, not as conclusive as my first measurement. do we need to try with a bigger tensor?
Hey @Birch-san, thanks for the reply on invoke-ai/InvokeAI#814 (comment)
measured this again: generating a stable-diffusion image (8 steps, Heun) takes 68.8 secs in the latest nightly. @kulinseth is it planned that this change will be promoted to stable? I think it will prevent Mac users from upgrading.
@Birch-san thanks for raising this concern. could you please test wrapping your code with the line below to see if disabling opt_einsum would change the perf? I see your prev attempt was not so conclusive, and this would be an easier way to disable using opt_einsum (vs uninstalling).
(EDIT: I forgot the torch.backends prefix last night, but that should definitely be there; thanks @Birch-san for guinea-pigging below)
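For reference, a guess at the line being suggested, reconstructed from the `torch.backends` prefix mentioned in the edit and the `enabled` flag named in the #86985 test plan below (treat it as an assumption, not a quote):

```python
import torch

torch.backends.opt_einsum.enabled = False  # keep opt_einsum installed, but skip its path optimization
```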
thanks for checking in. there's no such method in opt_einsum 3.3.0, and if I uninstall it, I get this at import time: `Exception has occurred: ModuleNotFoundError: No module named 'opt_einsum'`. with opt_einsum uninstalled via the above method: image generation took 71.6 seconds (i.e. still slow).
Ah, thanks @Birch-san: I meant that instead of installing/uninstalling opt-einsum, you can keep it installed and then use the backends setting to toggle it on or off. However, this is definitely an unintended error message; a fix is coming up: #86985
@Birch-san your results suggest that it is NOT the path computation cost that is causing the performance regression, which makes sense, since that should have been skipped anyway: your einsum call only had 2 arguments. When I benchmarked my change, the C++ changes added negligible overhead on Linux. Combined with your later inconclusive results, this makes me think the regression is not due to my einsum change, though I need to set up an mps machine to verify for sure. @Homemaderobot, in the meantime, were you able to repro?
Thanks for improving the defaults. it's looking like opt_einsum isn't the source of the problem, since even when it's uninstalled and the torch backend reports opt_einsum as not enabled, the perf regression still reproduces. Any idea where to go from here?
The perf tests were on `randn(100,100)`. Is that big enough? The original tensors were each `(16, 4096, 40)`
Yea, I too suspect the issue lies elsewhere. Would you mind running a benchmark on the bigger tensors then (no need to do anything about opt-einsum)?
sure; got a benchmark handy?
side-note: so, time to run 8 steps of Heun (fastest on left):
I think something similar to your prev code would be helpful to compare the two releases:
But seeing your more recent messages does raise my eyebrow... I will test whether this is MPS-only.
okay, a possible reason for the previous benchmarking's being inconclusive: we didn't wait for asynchronous operations to synchronize, so we were probably only timing how long they took to be enqueued. I think the following methodology is a bit fairer:

```python
import torch
import time
from torch import einsum

# MPS backend has no synchronization API comparable to CUDA's:
# https://pytorch.org/docs/master/notes/cuda.html#asynchronous-execution
# so we have to get creative.
# to force Python to wait for computation to complete: introduce a data-dependency on every element in the tensor.
baseline_start = time.perf_counter()
torch.rand(16, 4096, 4096, dtype=torch.float, device="mps").max().item()
print('for reference: our sync operation -- on a tensor of our target size -- is responsible for only %.4f seconds of overhead' % (time.perf_counter()-baseline_start))

repeats = 10

ein0_batch_duration = 0
for ix in range(repeats):
    q = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
    k = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
    ein0_start = time.perf_counter()
    einsum('b i d, b j d -> b i j', q, k).max().item()
    ein0_duration = time.perf_counter()-ein0_start
    print('einsum 0 iteration %d took %.4f seconds' % (ix, ein0_duration))
    ein0_batch_duration += ein0_duration
print('%d iterations of einsum 0 took %.4f seconds; avg %.4f secs' % (repeats, ein0_batch_duration, ein0_batch_duration/repeats))

ein1_batch_duration = 0
for ix in range(repeats):
    attn = torch.rand(16, 4096, 4096, dtype=torch.float, device="mps")
    v = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
    ein1_start = time.perf_counter()
    einsum('b i j, b j d -> b i d', attn, v).max().item()
    ein1_duration = time.perf_counter()-ein1_start
    print('einsum 1 iteration %d took %.4f seconds' % (ix, ein1_duration))
    ein1_batch_duration += ein1_duration
print('%d iterations of einsum 1 took %.4f seconds; avg %.4f secs' % (repeats, ein1_batch_duration, ein1_batch_duration/repeats))
```

1.13.0:

Torch 1.12.1:
einsum 0 is 42x slower in the 1.13.0 release candidate. interesting asymmetry. to put that into context: Heun samples from the model many times per image. curiously, 'which einsum is the more intensive' swaps when we use the newer release.
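One way to narrow this down further (a sketch under the same assumptions as the benchmark above, reusing its shapes and the `.max().item()` sync trick; not something run in the thread) would be to time the equivalent `bmm` call, to see whether the regression lives in einsum's setup or in the underlying batched matmul:

```python
import time
import torch

q = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
k = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")

start = time.perf_counter()
torch.bmm(q, k.transpose(1, 2)).max().item()  # the same computation as einsum 0
print('bmm equivalent of einsum 0 took %.4f seconds' % (time.perf_counter() - start))
```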
the interesting thing about how einsum 0 is formulated is that it incurs a transpose:
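For context, a small sketch of what that transpose means for memory layout (an illustration, not from the thread): the right-hand operand of einsum 0 is effectively a stride-swapped, non-contiguous view by the time the matmul runs.

```python
import torch

k = torch.rand(16, 4096, 40)
kt = k.transpose(1, 2)      # a view: strides are swapped, no data is moved
print(kt.shape)             # torch.Size([16, 40, 4096])
print(kt.is_contiguous())   # False: the matmul path has to handle this layout
```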
I tried transposing the key and moving it to contiguous memory before entering the einsum, which allows me to formulate einsum 0 the same way as einsum 1:

```python
import torch
import time
from torch import einsum

baseline_start = time.perf_counter()
torch.rand(16, 4096, 4096, dtype=torch.float, device="mps").max().item()
print('for reference: our sync operation -- on a tensor of our target size -- is responsible for only %.4f seconds of overhead' % (time.perf_counter()-baseline_start))

repeats = 10

k = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
start = time.perf_counter()
k.transpose(1, 2).contiguous().max().item()
print('transposing a key such as ours takes no more than %.4f seconds' % (time.perf_counter()-start))

ein0_var1_batch_duration = 0
for ix in range(repeats):
    q = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
    k = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
    ein0_start = time.perf_counter()
    einsum('b i d, b d j -> b i j', q, k.transpose(1, 2).contiguous()).max().item()
    ein0_duration = time.perf_counter()-ein0_start
    print('einsum 0 variation 1 iteration %d took %.4f seconds' % (ix, ein0_duration))
    ein0_var1_batch_duration += ein0_duration
print('%d iterations of einsum 0 variation 1 took %.4f seconds; avg %.4f secs' % (repeats, ein0_var1_batch_duration, ein0_var1_batch_duration/repeats))
```

1.13.0:

1.12.1:
no improvement; exactly the same perf outcomes as the original formulation of einsum 0.
@Birch-san, @janeyx99 do you mind filing a separate issue for einsum perf on MPS, and probably closing this one, as the output is now correct, though perf indeed needs to be improved?
sure thing @malfet. closing this on the basis that the original issue (correctness) is confirmed solved since at least the nightlies tested above. opened a new ticket to follow the perf regression.
Fixes the confusing situation mentioned here (#85224 (comment)) by
- setting better OG defaults
- changing warnings to errors now that we have better defaults

Test plan:
- Ran einsum tests locally + CI
- Uninstalled opt-einsum and ran through setting:
  - `enabled` to False (doesn't throw error)
  - `strategy` to anything that's not None (errors)
  - `strategy` to None (noops)
- Installed opt-einsum and ran through setting:
  - `enabled` to False (doesn't throw error)
  - `enabled` to True (doesn't throw error, no ops + defaults to 'auto')
  - `strategy` to random string (errors)
  - `strategy` to None (noops, still is 'auto')
  - `strategy` to 'greedy' (is set to 'greedy')

Pull Request resolved: #86985
Approved by: https://github.com/soulitzer
🐛 Describe the bug
the context for this is that it's part of cross-attention in stable-diffusion:
https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/attention.py#L180
it means that if we want to produce n identical images in one python run, the first will be wrong but subsequent images will be correct. this makes it hard to generate transitions (e.g. animations or latent walks), where you always want to start from the same image before you make a tweak:
works fine on 1.12.1. broken on 1.13.0.dev20220917. I believe it was broken at least as far back as 1.13.0.dev20220826 (from which I upgraded today to see if this was fixed). this also explains why I got different images from `einsum()` than I did via `matmul()`: huggingface/diffusers#452 (comment)
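A minimal repro sketch of the reported behavior, assuming the shapes from the thread (the assertion style is illustrative, not the reporter's actual script):

```python
import torch
from torch import einsum

q = torch.rand(16, 4096, 40, device="mps")
k = torch.rand(16, 4096, 40, device="mps")

ref = torch.bmm(q, k.transpose(1, 2))           # matmul reference
first = einsum('b i d, b j d -> b i j', q, k)   # reportedly wrong on the first call
second = einsum('b i d, b j d -> b i j', q, k)  # subsequent calls were correct
print(torch.allclose(first, ref), torch.allclose(second, ref))
```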
Versions
cc @kulinseth @albanD