
[MPS] einsum returns incorrect matmul result on first invocation on nightly builds #85224

Closed
Birch-san opened this issue Sep 17, 2022 · 36 comments
Assignees: abhudev
Labels:
module: correctness (silent) - issue that returns an incorrect result silently
module: mps - Related to Apple Metal Performance Shaders framework
module: regression - It used to work, and now it doesn't
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@Birch-san commented Sep 17, 2022

🐛 Describe the bug

The context for this is that it's part of cross-attention in stable-diffusion:
https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/attention.py#L180
It means that if we want to produce n identical images in one Python run, the first will be wrong but subsequent images will be correct. That makes it hard to generate transitions (e.g. animations or latent walks), where you always want to start from the same image before you make a tweak.
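For context, the linked attention code computes the attention scores with einsums of this shape (a rough sketch of the pattern, not the exact CompVis source; the tensor sizes are the ones reported later in this thread):

import torch
from torch import einsum

# shapes as reported later in this thread: (batch*heads, tokens, dim_head)
q = torch.rand(16, 4096, 40, device='mps')
k = torch.rand(16, 4096, 40, device='mps')
v = torch.rand(16, 4096, 40, device='mps')
scale = 40 ** -0.5  # illustrative softmax scaling, assumed here

sim = einsum('b i d, b j d -> b i j', q, k) * scale  # the einsum that misbehaves on its first call
attn = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', attn, v)

The repro below strips this down to a single einsum on a tiny tensor: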

from torch import einsum, tensor, matmul
t = tensor([[[0., 1.],
             [2., 3.]]], device='mps')

# result from CPU is correct:
einsum('b i d, b j d -> b i j', t.cpu(), t.cpu())
# tensor([[[ 1.,  3.],
#          [ 3., 13.]]])

# first result from MPS is wrong:
einsum('b i d, b j d -> b i j', t, t)
# tensor([[[ 2.,  3.],
#          [ 6., 11.]]], device='mps:0')

# subsequent results from MPS are correct:
einsum('b i d, b j d -> b i j', t, t)
# tensor([[[ 1.,  3.],
#          [ 3., 13.]]], device='mps:0')

einsum('b i d, b j d -> b i j', t, t)
# tensor([[[ 1.,  3.],
#          [ 3., 13.]]], device='mps:0')

# btw this einsum is equivalent to the following matmul:
matmul(t, t.transpose(1, 2))
# tensor([[[ 1.,  3.],
#          [ 3., 13.]]], device='mps:0')
# in other words a matmul over these:
# tensor([[[0., 1.],
#          [2., 3.]]]) *
# tensor([[[0., 2.],
#          [1., 3.]]]) =
# tensor([[[0*0+1*1, 0*2+1*3],
#          [2*0+3*1, 2*2+3*3]]])

Works fine on 1.12.1.
Broken on 1.13.0.dev20220917.
I believe it was broken at least as far back as 1.13.0.dev20220826 (from which I upgraded today to see whether this had been fixed).

This also explains why I got different images from einsum() than I did via matmul():
huggingface/diffusers#452 (comment)

Versions

PyTorch version: 1.13.0.dev20220917
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 12.5 (arm64)
GCC version: Could not collect
Clang version: 13.0.0 (clang-1300.0.29.30)
CMake version: version 3.22.1
Libc version: N/A

Python version: 3.10.4 (main, Mar 31 2022, 03:37:37) [Clang 12.0.0 ] (64-bit runtime)
Python platform: macOS-12.5-arm64-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.22.4
[pip3] pytorch-lightning==1.4.2
[pip3] torch==1.13.0.dev20220917
[pip3] torch-fidelity==0.3.0
[pip3] torchdiffeq==0.2.3
[pip3] torchmetrics==0.6.0
[pip3] torchtyping==0.1.4
[pip3] torchvision==0.14.0.dev20220917
[conda] numpy                     1.22.4                   pypi_0    pypi
[conda] pytorch-lightning         1.4.2                    pypi_0    pypi
[conda] torch                     1.13.0.dev20220917          pypi_0    pypi
[conda] torch-fidelity            0.3.0                    pypi_0    pypi
[conda] torchdiffeq               0.2.3                    pypi_0    pypi
[conda] torchmetrics              0.6.0                    pypi_0    pypi
[conda] torchtyping               0.1.4                    pypi_0    pypi
[conda] torchvision               0.14.0.dev20220917          pypi_0    pypi

cc @kulinseth @albanD

@tmm1 commented Sep 17, 2022

cc huggingface/diffusers#372

@malfet added the module: correctness (silent), module: mps, and module: regression labels on Sep 20, 2022
@soulitzer added the triaged label on Sep 20, 2022
@abhudev self-assigned this on Sep 21, 2022
@malfet (Contributor) commented Sep 27, 2022

At least locally, #85689 fixes this one for me.

pytorchmergebot pushed a commit that referenced this issue Sep 27, 2022
Looks like the expectation in that code was that `.clone` would return a contiguous tensor, so explicitly specify the memory format

Fixes #85675 and #85224

Pull Request resolved: #85689
Approved by: https://github.com/kulinseth
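For illustration, the contiguity assumption the commit message refers to looks roughly like this at the Python level (a sketch only, not the actual MPS kernel change):

import torch

t = torch.rand(3, 4).transpose(0, 1)                 # a non-contiguous view
print(t.is_contiguous())                             # False

a = t.clone()                                        # default memory_format preserves the layout
print(a.is_contiguous())                             # False

b = t.clone(memory_format=torch.contiguous_format)   # what the fix requests explicitly
print(b.is_contiguous())                             # True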
@malfet (Contributor) commented Sep 27, 2022

@Birch-san can I ask you to test the fix (if you are building PyTorch locally), or do you need to wait until nightlies are available?

@Birch-san (Author)

Hey @malfet, thanks for this. Tested just now on 1.13.0.dev20220928.

Certainly seems deterministic now, but performance has regressed massively since 1.13.0.dev20220917 (the last nightly I tried out).
An 8-step image generation (which used to take 10.4 secs) now takes 67 secs.
For reference: PyTorch stable 1.12.1 takes 9.8 secs.

@malfet (Contributor) commented Sep 28, 2022

Well, einsum is expected to be much slower than mat-mul, but we can probably add a shortcut for this particular codepath (which should be unrelated to MPS)
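At the user level, the kind of shortcut being described amounts to the fact that this two-operand einsum is just a batched matmul (a sketch of the equivalence, not a proposed PyTorch patch):

import torch

q = torch.rand(16, 4096, 40)
k = torch.rand(16, 4096, 40)

a = torch.einsum('b i d, b j d -> b i j', q, k)
b = torch.bmm(q, k.transpose(1, 2))   # the same contraction expressed as a batched matmul

print(torch.allclose(a, b))           # True, up to floating-point tolerance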

@Birch-san (Author) commented Sep 28, 2022

I tried a few nightly releases from September:

28 slow
27 slow
26 slow
25 slow
24 fast
22 fast
18 fast
17 fast

@Birch-san (Author)

So, given that you merged your fix for einsum determinism on the 27th, I think the perf regression must be something different, introduced on the 25th...

@malfet (Contributor) commented Sep 28, 2022

Hmm, among the changes in 6916826 there were no MPS-specific changes, but there was one for einsum; see #84890 (cc: @janeyx99)

@Birch-san can I ask you to timeit einsum('b i d, b j d -> b i j') on, say, 100x100 matrices?

@Birch-san (Author)

einsum is expected to be much slower than mat-mul

we found the opposite:
huggingface/diffusers#452 (comment)

einsum is 36% faster than matmul, at least on 1.13.0.dev20220826.

@malfet (Contributor) commented Sep 28, 2022

Hmm, what are the tensor sizes you are using?

% python -c "import torch; import timeit; x=torch.rand(100, 100, device='mps');y=x.unsqueeze(0);print(timeit.timeit(lambda: torch.einsum('b i d, b j d -> b i j', y, y), number=100), timeit.timeit(lambda: torch.mm(x, x), number=100))"
0.06760179199999994 0.006913500000000017

@Birch-san (Author)

add python opt_einsum path passthrough

Hmm, does it use opt_einsum then?
invoke-ai/InvokeAI#517 (comment)
I found opt_einsum to be 30x slower on MPS than regular einsum.

@Birch-san (Author)

# q.shape
# torch.Size([16, 4096, 40])
# k.shape
# torch.Size([16, 4096, 40])
einsum('b i d, b j d -> b i j', q, k)

# attn.shape
# torch.Size([16, 4096, 4096])
# v.shape
# torch.Size([16, 4096, 40])
einsum('b i j, b j d -> b i d', attn, v)

@Birch-san (Author)

@Birch-san can I ask you to timeit einsum('b i d, b j d -> b i j') on, say, 100x100 matrices?

Which nightly do you want this measured on? Latest?

@Birch-san (Author) commented Sep 28, 2022

# 1.13.0.dev20220928
python -c "import torch; import timeit; x=torch.rand(100, 100, device='mps');y=x.unsqueeze(0);print(timeit.timeit(lambda: torch.einsum('b i d, b j d -> b i j', y, y), number=100), timeit.timeit(lambda: torch.mm(x, x), number=100))"
0.3051431248895824 0.008158416952937841

# 1.13.0.dev20220925
python -c "import torch; import timeit; x=torch.rand(100, 100, device='mps');y=x.unsqueeze(0);print(timeit.timeit(lambda: torch.einsum('b i d, b j d -> b i j', y, y), number=100), timeit.timeit(lambda: torch.mm(x, x), number=100))"
0.09183137491345406 0.007564791943877935

# 1.13.0.dev20220924
python -c "import torch; import timeit; x=torch.rand(100, 100, device='mps');y=x.unsqueeze(0);print(timeit.timeit(lambda: torch.einsum('b i d, b j d -> b i j', y, y), number=100), timeit.timeit(lambda: torch.mm(x, x), number=100))"
0.053649208042770624 0.006840792018920183

# 1.12.1
python -c "import torch; import timeit; x=torch.rand(100, 100, device='mps');y=x.unsqueeze(0);print(timeit.timeit(lambda: torch.einsum('b i d, b j d -> b i j', y, y), number=100), timeit.timeit(lambda: torch.mm(x, x), number=100))"
0.047988582868129015 0.010437542106956244

@malfet (Contributor) commented Sep 28, 2022

I found opt_einsum to be 30x slower on MPS than regular einsum

In that case, can you please try uninstalling opt_einsum in your environment and measuring performance again?

@Birch-san (Author) commented Sep 28, 2022

# 1.13.0.dev20220928 (new conda env; opt_einsum not installed; confirmed by attempting and failing to import it)
python -c "import torch; import timeit; x=torch.rand(100, 100, device='mps');y=x.unsqueeze(0);print(timeit.timeit(lambda: torch.einsum('b i d, b j d -> b i j', y, y), number=100), timeit.timeit(lambda: torch.mm(x, x), number=100))"
0.09281750023365021 0.00685783289372921

pip install opt_einsum

# 1.13.0.dev20220928 (opt_einsum installed)
python -c "import torch; import timeit; x=torch.rand(100, 100, device='mps');y=x.unsqueeze(0);print(timeit.timeit(lambda: torch.einsum('b i d, b j d -> b i j', y, y), number=100), timeit.timeit(lambda: torch.mm(x, x), number=100))"
0.10102700022980571 0.007209625095129013

Hmm, not as conclusive as my first measurement. Do we need to try with a bigger tensor?

@Homemaderobot

Hey @Birch-san, thanks for the reply on invoke-ai/InvokeAI#814 (comment).
Regarding the 6-times slowdown with PyTorch nightly: is there a way to install a previous version that runs at full speed?

mehtanirav pushed a commit that referenced this issue Oct 4, 2022 (the same #85689 change as above)
@Birch-san (Author)

Measured this again in the 1.13.0 release candidate.

Generating a stable-diffusion image (8 steps, Heun) takes 68.8 secs in 1.13.0, a 6-fold perf regression from 1.12.1 stable's 10.4 secs.

@kulinseth is it planned that this change will be promoted to stable? I think it will prevent Mac users from upgrading.

@janeyx99 (Contributor) commented Oct 14, 2022

@Birch-san thanks for raising this concern. Could you please test wrapping your code with the snippet below to see if disabling opt_einsum changes the perf? I see your previous attempt was not so conclusive, and this is an easier way to disable opt_einsum (vs. uninstalling it).

with torch.backends.opt_einsum.flags(enabled=False):
    # your code

(Edited: I forgot the torch.backends prefix last night, but it should definitely be there; thanks @Birch-san for guinea-pigging below.)
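For example, wrapped around the earlier one-liner's einsum it would look like this (assuming a nightly that ships torch.backends.opt_einsum):

import torch

y = torch.rand(1, 100, 100, device='mps')

with torch.backends.opt_einsum.flags(enabled=False):
    # opt_einsum path optimization is disabled inside this block
    out = torch.einsum('b i d, b j d -> b i j', y, y)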

@Birch-san (Author) commented Oct 14, 2022

Thanks for checking in. There's no such method in the standalone opt_einsum 3.3.0 package.

and if I uninstall opt_einsum:

(ldmwaifu) ➜  stable-diffusion git:(birch-mps-waifu) ✗ pip uninstall opt_einsum
Found existing installation: opt-einsum 3.3.0
Uninstalling opt-einsum-3.3.0:
  Would remove:
    /Users/birch/anaconda3/envs/ldmwaifu/lib/python3.10/site-packages/opt_einsum-3.3.0.dist-info/*
    /Users/birch/anaconda3/envs/ldmwaifu/lib/python3.10/site-packages/opt_einsum/*
Proceed (Y/n)? Y
  Successfully uninstalled opt-einsum-3.3.0

I get this at the time of import opt_einsum:

Exception has occurred: ModuleNotFoundError
No module named 'opt_einsum'

With opt_einsum uninstalled via the above method, image generation took 71.6 seconds (i.e. still slow).

@Birch-san (Author) commented Oct 14, 2022

Ah, you mean from torch.backends import opt_einsum?
That method exists, but it cannot set the flag because opt_einsum is not installed (see enabled: False).
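For reference, the state in question can also be read directly from the backend module (a quick check, assuming the same nightly):

from torch.backends import opt_einsum

print(opt_einsum.enabled)    # False here, because the opt_einsum package is not installed
print(opt_einsum.strategy)   # the contraction-path strategy ('auto', 'greedy', ...) used when enabled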

@janeyx99 (Contributor)

Ah, thanks @Birch-san. I meant that instead of installing/uninstalling opt-einsum, you can keep it installed and use the backend flag to toggle it on or off. However, this is definitely an unintended error message; a fix is coming up: #86985

@janeyx99 (Contributor)

@Birch-san your results suggest that it is NOT the path-computation cost that is causing the performance regression, which makes sense, since that step should have been skipped anyway because your einsum call only had 2 operands.
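To illustrate the point about operand count (a generic example, unrelated to MPS): with two operands there is only one contraction to perform, so there is no path to optimize; with three or more, the contraction order matters.

import torch

a = torch.rand(8, 1024)
b = torch.rand(1024, 8)
c = torch.rand(8, 1024)

# Two operands: a single contraction, nothing for opt_einsum to reorder.
torch.einsum('ij,jk->ik', a, b)

# Three operands: contracting (a @ b) first yields an 8x8 intermediate,
# while (b @ c) first yields a 1024x1024 intermediate; choosing between
# such orderings is what the opt_einsum path search is for.
torch.einsum('ij,jk,kl->il', a, b, c)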

When I did the benchmarking for my change, the C++ changes added negligible overhead on Linux. This, combined with your later inconclusive results, makes me think the regression is not due to my einsum change, though I need to set up an MPS machine to verify for sure. @Homemaderobot in the meantime, were you able to repro?

@Birch-san (Author) commented Oct 14, 2022

Thanks for improving the defaults.

It's looking like opt_einsum isn't the source of the problem: even when it's uninstalled and the torch backend reports opt_einsum as not enabled, the perf regression reproduces.

Any idea where to go from here?

@Birch-san (Author)

The perf tests were on rand(100, 100). Is that big enough? The original tensors were each (16, 4096, 40).

@janeyx99 (Contributor)

Yea, I too suspect the issue lies elsewhere. Would you mind running a benchmark on the bigger tensors then (no need to do anything about opt-einsum)?

@Birch-san (Author)

sure; got a benchmark handy?

@Birch-san (Author) commented Oct 14, 2022

Side-note:
If I reformulate the einsums as matmuls (see the sketch below), image generation takes 14.6 secs on the 1.13.0 release candidate, comparable to what I get on 1.12.1.

So, time to run 8 steps of Heun (fastest on left):
1.12.1 einsum (10.7s) > 1.12.1 & 1.13.0 matmul (14.6s) > 1.13.0 einsum (71.6s)
This is measuring the whole forward pass of latent diffusion of course, but the choice of einsum vs matmul during cross-attention seems to dominate.
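A sketch of the reformulation meant here: the same two contractions expressed with matmul (not necessarily line-for-line what the actual branch does):

import torch

q = torch.rand(16, 4096, 40, device='mps')
k = torch.rand(16, 4096, 40, device='mps')
v = torch.rand(16, 4096, 40, device='mps')

# einsum('b i d, b j d -> b i j', q, k) becomes:
sim = torch.matmul(q, k.transpose(1, 2))

attn = sim.softmax(dim=-1)

# einsum('b i j, b j d -> b i d', attn, v) becomes:
out = torch.matmul(attn, v)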

@janeyx99 (Contributor)

I think something similar to your previous code would be helpful to compare the two releases:

import torch
from torch import einsum

q = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
k = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
einsum('b i d, b j d -> b i j', q, k)
attn = torch.rand(16, 4096, 4096, dtype=torch.float, device="mps")
v = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
einsum('b i j, b j d -> b i d', attn, v)

But seeing your more recent messages does raise my eyebrow... I will test whether this is MPS-only.

@Birch-san (Author) commented Oct 14, 2022

Okay, a possible reason the previous benchmarking was inconclusive: we didn't wait for the asynchronous operations to complete, so we were probably only measuring how long they took to be enqueued. I think the following methodology is a bit fairer:

import torch
import time
from torch import einsum

# MPS backend has no synchronization API comparable to CUDA's:
# https://pytorch.org/docs/master/notes/cuda.html#asynchronous-execution
# so we have to get creative.
# to force Python to wait for computation to complete: introduce a data-dependency on every element in the tensor.
baseline_start = time.perf_counter()
torch.rand(16, 4096, 4096, dtype=torch.float, device="mps").max().item()
print('for reference: our sync operation -- on a tensor of our target size -- is responsible for only %.4f seconds of overhead' % (time.perf_counter()-baseline_start))

repeats = 10

ein0_batch_duration = 0
for ix in range(repeats):
  q = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
  k = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
  ein0_start = time.perf_counter()
  einsum('b i d, b j d -> b i j', q, k).max().item()
  ein0_duration = time.perf_counter()-ein0_start
  print('einsum 0 iteration %d took %.4f seconds' % (ix, ein0_duration))
  ein0_batch_duration += ein0_duration
print('%d iterations of einsum 0 took %.4f seconds; avg %.4f secs' % (repeats, ein0_batch_duration, ein0_batch_duration/repeats))

ein1_batch_duration = 0
for ix in range(repeats):
  attn = torch.rand(16, 4096, 4096, dtype=torch.float, device="mps")
  v = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
  ein1_start = time.perf_counter()
  einsum('b i j, b j d -> b i d', attn, v).max().item()
  ein1_duration = time.perf_counter()-ein1_start
  print('einsum 1 iteration %d took %.4f seconds' % (ix, ein1_duration))
  ein1_batch_duration += ein1_duration
print('%d iterations of einsum 1 took %.4f seconds; avg %.4f secs' % (repeats, ein1_batch_duration, ein1_batch_duration/repeats))

1.13.0:

for reference: our sync operation -- on a tensor of our target size -- is responsible for only 0.1172 seconds of overhead
einsum 0 iteration 0 took 1.0375 seconds
einsum 0 iteration 1 took 0.6801 seconds
einsum 0 iteration 2 took 0.6819 seconds
einsum 0 iteration 3 took 0.6785 seconds
einsum 0 iteration 4 took 0.6775 seconds
einsum 0 iteration 5 took 0.6762 seconds
einsum 0 iteration 6 took 0.6772 seconds
einsum 0 iteration 7 took 0.6744 seconds
einsum 0 iteration 8 took 0.6788 seconds
einsum 0 iteration 9 took 0.6774 seconds
10 iterations of einsum 0 took 7.1397 seconds; avg 0.7140 secs
einsum 1 iteration 0 took 0.0504 seconds
einsum 1 iteration 1 took 0.0499 seconds
einsum 1 iteration 2 took 0.0498 seconds
einsum 1 iteration 3 took 0.0491 seconds
einsum 1 iteration 4 took 0.0498 seconds
einsum 1 iteration 5 took 0.0500 seconds
einsum 1 iteration 6 took 0.0500 seconds
einsum 1 iteration 7 took 0.0486 seconds
einsum 1 iteration 8 took 0.0619 seconds
einsum 1 iteration 9 took 0.0488 seconds
10 iterations of einsum 1 took 0.5081 seconds; avg 0.0508 secs

Torch 1.12.1:

for reference: our sync operation -- on a tensor of our target size -- is responsible for only 0.1467 seconds of overhead
einsum 0 iteration 0 took 0.0856 seconds
einsum 0 iteration 1 took 0.0095 seconds
einsum 0 iteration 2 took 0.0096 seconds
einsum 0 iteration 3 took 0.0092 seconds
einsum 0 iteration 4 took 0.0093 seconds
einsum 0 iteration 5 took 0.0093 seconds
einsum 0 iteration 6 took 0.0093 seconds
einsum 0 iteration 7 took 0.0094 seconds
einsum 0 iteration 8 took 0.0092 seconds
einsum 0 iteration 9 took 0.0093 seconds
10 iterations of einsum 0 took 0.1696 seconds; avg 0.0170 secs
einsum 1 iteration 0 took 0.0397 seconds
einsum 1 iteration 1 took 0.0892 seconds
einsum 1 iteration 2 took 0.0648 seconds
einsum 1 iteration 3 took 0.0409 seconds
einsum 1 iteration 4 took 0.0387 seconds
einsum 1 iteration 5 took 0.0403 seconds
einsum 1 iteration 6 took 0.0405 seconds
einsum 1 iteration 7 took 0.0403 seconds
einsum 1 iteration 8 took 0.0398 seconds
einsum 1 iteration 9 took 0.0409 seconds
10 iterations of einsum 1 took 0.4750 seconds; avg 0.0475 secs

einsum 0 is 42x slower in the 1.13.0 release candidate.
einsum 1 is 1.07x slower in the 1.13.0 release candidate.

Interesting asymmetry.

To put that into context: Heun samples from the model 2n-1 times, so for an 8-step draft image we would hit this roughly-0.7-second bottleneck 15 times, and about three times as often for a properly-sampled image.

Curiously, which einsum is the more intensive one swaps when we use the newer release.

@Birch-san (Author)

The interesting thing about how einsum 0 is formulated is that it incurs a transpose:
Birch-san/stable-diffusion@d2d533d

@Birch-san (Author)

I tried to transpose the key and move it to contiguous memory before entering the einsum, which allows me to formulate einsum 0 the same way as einsum 1:

import torch
import time
from torch import einsum

baseline_start = time.perf_counter()
torch.rand(16, 4096, 4096, dtype=torch.float, device="mps").max().item()
print('for reference: our sync operation -- on a tensor of our target size -- is responsible for only %.4f seconds of overhead' % (time.perf_counter()-baseline_start))

repeats = 10

k = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
start = time.perf_counter()
k.transpose(1, 2).contiguous().max().item()
print('transposing a key such as ours takes no more than %.4f seconds' % (time.perf_counter()-start))

ein0_var1_batch_duration = 0
for ix in range(repeats):
  q = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
  k = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
  ein0_start = time.perf_counter()
  einsum('b i d, b d j -> b i j', q, k.transpose(1, 2).contiguous()).max().item()
  ein0_duration = time.perf_counter()-ein0_start
  print('einsum 0 variation 1 iteration %d took %.4f seconds' % (ix, ein0_duration))
  ein0_var1_batch_duration += ein0_duration
print('%d iterations of einsum 0 variation 1 took %.4f seconds; avg %.4f secs' % (repeats, ein0_var1_batch_duration, ein0_var1_batch_duration/repeats))

1.13.0:

for reference: our sync operation -- on a tensor of our target size -- is responsible for only 0.1012 seconds of overhead
transposing a key such as ours takes no more than 0.0335 seconds
einsum 0 variation 1 iteration 0 took 1.0355 seconds
einsum 0 variation 1 iteration 1 took 0.6850 seconds
einsum 0 variation 1 iteration 2 took 0.6746 seconds
einsum 0 variation 1 iteration 3 took 0.6740 seconds
einsum 0 variation 1 iteration 4 took 0.6757 seconds
einsum 0 variation 1 iteration 5 took 0.6737 seconds
einsum 0 variation 1 iteration 6 took 0.6745 seconds
einsum 0 variation 1 iteration 7 took 0.6810 seconds
einsum 0 variation 1 iteration 8 took 0.6759 seconds
einsum 0 variation 1 iteration 9 took 0.6745 seconds
10 iterations of einsum 0 variation 1 took 7.1243 seconds; avg 0.7124 secs

1.12.1:

for reference: our sync operation -- on a tensor of our target size -- is responsible for only 0.1409 seconds of overhead
transposing a key such as ours takes no more than 0.0278 seconds
einsum 0 variation 1 iteration 0 took 0.0677 seconds
einsum 0 variation 1 iteration 1 took 0.0093 seconds
einsum 0 variation 1 iteration 2 took 0.0095 seconds
einsum 0 variation 1 iteration 3 took 0.0094 seconds
einsum 0 variation 1 iteration 4 took 0.0092 seconds
einsum 0 variation 1 iteration 5 took 0.0094 seconds
einsum 0 variation 1 iteration 6 took 0.0094 seconds
einsum 0 variation 1 iteration 7 took 0.0093 seconds
einsum 0 variation 1 iteration 8 took 0.0096 seconds
einsum 0 variation 1 iteration 9 took 0.0095 seconds
10 iterations of einsum 0 variation 1 took 0.1523 seconds; avg 0.0152 secs

No improvement; exactly the same perf outcomes as the original formulation of einsum 0.
So, although einsum 0 is notable for doing a transpose, perhaps that's not why it's slower. Maybe it's the fact that its input tensors are different sizes from those submitted to einsum 1?

@malfet (Contributor) commented Oct 14, 2022

@Birch-san, @janeyx99 do you mind filing a separate issue for einsum perf on MPS, and probably closing this one, as the output is now correct, though perf indeed needs to be improved?

@Birch-san (Author)

Sure thing @malfet. Closing this on the basis that the original issue (correctness) is confirmed solved since at least 1.13.0.dev20220928 (#85224 (comment)).

Opened a new ticket to follow the perf regression:
#87010

pytorchmergebot pushed a commit that referenced this issue Oct 15, 2022
Fixes the confusing situation mentioned here #85224 (comment) by

- setting better OG defaults
- changing warnings to errors now that we have better defaults

Test plan:
- Ran einsum tests locally + CI
- Uninstalled opt-einsum and ran through setting
     - `enabled` to False (doesn't throw error)
     - `strategy` to anything that's not None (errors)
     - `strategy` to None (noops)
- Installed opt-einsum and ran through setting
     - `enabled` to False (doesn't throw error)
     - `enabled` to True (doesn't throw error, no ops + defaults to 'auto')
     - `strategy` to random string (errors)
     - `strategy` to None (noops, still is 'auto')
     - `strategy` to 'greedy' (is set to 'greedy')
Pull Request resolved: #86985
Approved by: https://github.com/soulitzer
janeyx99 added a commit that referenced this issue Oct 17, 2022 (the same #86985 change as above)
pytorchmergebot pushed a commit that referenced this issue Oct 18, 2022 (the same #86985 change as above)
atalman pushed a commit that referenced this issue Oct 19, 2022 (the same #86985 change as above)