Description
🐛 Describe the bug
The context for this is that it's part of the cross-attention implementation in stable-diffusion:
https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/attention.py#L180
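For reference, the line at that link computes scaled dot-product attention scores between queries and keys. A paraphrased sketch (variable names are illustrative, not copied verbatim from the repo):

```python
from torch import einsum

def attention_scores(q, k, scale):
    # q: (batch*heads, q_len, head_dim), k: (batch*heads, k_len, head_dim)
    # this is the einsum pattern whose first MPS invocation returns wrong values
    return einsum('b i d, b j d -> b i j', q, k) * scale
```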
It means that if we want to produce n identical images in one Python run, the first will be wrong but subsequent images will be correct. This makes it hard to generate transitions (e.g. animations or latent walks), where you always want to start from the same image before making a tweak:
```python
from torch import einsum, tensor, matmul

t = tensor([[[0., 1.],
             [2., 3.]]], device='mps')

# result from CPU is correct:
einsum('b i d, b j d -> b i j', t.cpu(), t.cpu())
# tensor([[[ 1.,  3.],
#          [ 3., 13.]]])

# first result from MPS is wrong:
einsum('b i d, b j d -> b i j', t, t)
# tensor([[[ 2.,  3.],
#          [ 6., 11.]]], device='mps:0')

# subsequent results from MPS are correct:
einsum('b i d, b j d -> b i j', t, t)
# tensor([[[ 1.,  3.],
#          [ 3., 13.]]], device='mps:0')
einsum('b i d, b j d -> b i j', t, t)
# tensor([[[ 1.,  3.],
#          [ 3., 13.]]], device='mps:0')

# btw, this einsum is equivalent to the following matmul:
matmul(t, t.transpose(1, 2))
# tensor([[[ 1.,  3.],
#          [ 3., 13.]]], device='mps:0')

# in other words, a matmul over these:
# tensor([[[0., 1.],
#          [2., 3.]]]) @
# tensor([[[0., 2.],
#          [1., 3.]]]) =
# tensor([[[0*0 + 1*1, 0*2 + 1*3],
#          [2*0 + 3*1, 2*2 + 3*3]]])
```
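Until this is fixed, two possible workarounds (sketches based on the behaviour above, not an official recommendation): express the contraction as a batched matmul so the einsum path is never hit, or run a throwaway einsum on MPS once at startup so the first "real" call is no longer the wrong first invocation.

```python
import torch
from torch import einsum

def scores_via_matmul(q, k):
    # same contraction as einsum('b i d, b j d -> b i j', q, k),
    # written as a batched matmul to avoid the einsum path entirely
    return torch.matmul(q, k.transpose(1, 2))

# alternative: warm up the MPS einsum once at import time, so the first
# real call is not the (incorrect) first invocation
_dummy = torch.zeros(1, 1, 1, device='mps')
einsum('b i d, b j d -> b i j', _dummy, _dummy)
```

The warm-up approach keeps the original einsum code untouched, but it relies on the "first call only" behaviour holding in general, which I have not verified beyond this repro.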
Works fine on 1.12.1; broken on 1.13.0.dev20220917. I believe it was broken at least as far back as 1.13.0.dev20220826 (from which I upgraded today to see whether this had been fixed).

This also explains why I got different images from einsum() than I did via matmul(): huggingface/diffusers#452 (comment)
Versions
```
PyTorch version: 1.13.0.dev20220917
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 12.5 (arm64)
GCC version: Could not collect
Clang version: 13.0.0 (clang-1300.0.29.30)
CMake version: version 3.22.1
Libc version: N/A
Python version: 3.10.4 (main, Mar 31 2022, 03:37:37) [Clang 12.0.0 ] (64-bit runtime)
Python platform: macOS-12.5-arm64-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.22.4
[pip3] pytorch-lightning==1.4.2
[pip3] torch==1.13.0.dev20220917
[pip3] torch-fidelity==0.3.0
[pip3] torchdiffeq==0.2.3
[pip3] torchmetrics==0.6.0
[pip3] torchtyping==0.1.4
[pip3] torchvision==0.14.0.dev20220917
[conda] numpy 1.22.4 pypi_0 pypi
[conda] pytorch-lightning 1.4.2 pypi_0 pypi
[conda] torch 1.13.0.dev20220917 pypi_0 pypi
[conda] torch-fidelity 0.3.0 pypi_0 pypi
[conda] torchdiffeq 0.2.3 pypi_0 pypi
[conda] torchmetrics 0.6.0 pypi_0 pypi
[conda] torchtyping 0.1.4 pypi_0 pypi
[conda] torchvision 0.14.0.dev20220917 pypi_0 pypi
```