Cannot install xformers on linux server #390

Closed
fedshyvana opened this issue Sep 14, 2022 · 14 comments · Fixed by #529

Comments

@fedshyvana

fedshyvana commented Sep 14, 2022

❓ Questions and Help

Whether I use pip install or build from source, I get this error:

 × python setup.py egg_info did not run successfully.
  │ exit code: 1
  ╰─> [18 lines of output]
      Traceback (most recent call last):
        File "<string>", line 2, in <module>
        File "<pip-setuptools-caller>", line 34, in <module>
        File "/home/username/xformers/setup.py", line 239, in <module>
          ext_modules=get_extensions(),
        File "/home/username/xformers/setup.py", line 187, in get_extensions
          cuda_version = get_cuda_version(CUDA_HOME)
        File "/home/username/xformers/setup.py", line 51, in get_cuda_version
          raw_output = subprocess.check_output([nvcc_bin, "-V"], universal_newlines=True)
        File "/home/username/anaconda3/envs/test_env/lib/python3.9/subprocess.py", line 424, in check_output
          return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
        File "/home/username/anaconda3/envs/test_env/lib/python3.9/subprocess.py", line 505, in run
          with Popen(*popenargs, **kwargs) as process:
        File "/home/username/anaconda3/envs/test_env/lib/python3.9/subprocess.py", line 951, in __init__
          self._execute_child(args, executable, preexec_fn, close_fds,
        File "/home/username/anaconda3/envs/test_env/lib/python3.9/subprocess.py", line 1821, in _execute_child
          raise child_exception_type(errno_num, err_msg, err_filename)
      FileNotFoundError: [Errno 2] No such file or directory: '/home/username/anaconda3/envs/test_env/bin/nvcc'
      [end of output]

Here's the output of nvcc --version:

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Tue_Mar__8_18:18:20_PST_2022
Cuda compilation tools, release 11.6, V11.6.124
Build cuda_11.6.r11.6/compiler.31057947_0

As additional information, I was able to install PyTorch the usual way and verify that CUDA is available.

@blefaudeux
Contributor

You can check where the binary resides with which nvcc. It looks like, for some reason, the setup script looks for it at /home/username/anaconda3/envs/test_env/bin/nvcc and it's not there. Something that works well for me is to let conda handle both torch and the CUDA toolchain; this way the versions are in sync and the paths are usually reliable. I hope that helps.
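
The same check can be done from Python. The following is a minimal sketch and not part of xformers: the helper name locate_nvcc is hypothetical, and it assumes the build script looks for nvcc under <dir>/bin/nvcc, as the traceback suggests.

```python
import os
import shutil


def locate_nvcc(expected_dir):
    """Compare the nvcc found on PATH with the one a build script expects.

    expected_dir is the directory the build treats as the CUDA root
    (an assumption here; in the report it was the conda env prefix).
    """
    expected = os.path.join(expected_dir, "bin", "nvcc")
    return {
        "on_path": shutil.which("nvcc"),  # None if nvcc is not on PATH
        "expected": expected,
        "expected_exists": os.path.isfile(expected),
    }
```

Calling locate_nvcc with the environment prefix from the traceback would presumably report expected_exists as False while on_path points at the system-wide nvcc, which matches the mismatch described above.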

@fmassa
Contributor

fmassa commented Sep 16, 2022

Hi,

Thanks for reporting this issue!

We should get this fixed. The following part

xformers/setup.py

Lines 49 to 59 in 51dd119

def get_cuda_version(cuda_dir) -> int:
    nvcc_bin = "nvcc" if cuda_dir is None else cuda_dir + "/bin/nvcc"
    raw_output = subprocess.check_output([nvcc_bin, "-V"], universal_newlines=True)
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = int(release[0])
    bare_metal_minor = int(release[1][0])
    assert bare_metal_minor < 100
    return bare_metal_major * 100 + bare_metal_minor

is not robust across all systems.
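
One possible way to make this lookup more tolerant is sketched below. It is only an illustration under stated assumptions, not the change that eventually landed in xformers. It falls back to whatever nvcc is on PATH when cuda_dir + "/bin/nvcc" is missing, and it returns None instead of raising when no usable nvcc is found. The helper name parse_nvcc_release and the Optional[int] return type are hypothetical. It also parses the whole minor version, whereas the snippet above reads only its first character.

```python
import subprocess
from typing import Optional


def parse_nvcc_release(raw_output: str) -> int:
    """Turn `nvcc -V` output into major * 100 + minor (11.6 -> 1106)."""
    tokens = raw_output.split()
    # The token after "release" looks like "11.6," so drop the trailing comma.
    release = tokens[tokens.index("release") + 1].rstrip(",").split(".")
    return int(release[0]) * 100 + int(release[1])


def get_cuda_version(cuda_dir: Optional[str]) -> Optional[int]:
    """Try cuda_dir/bin/nvcc first, then nvcc on PATH; None if both fail."""
    candidates = []
    if cuda_dir is not None:
        candidates.append(cuda_dir + "/bin/nvcc")
    candidates.append("nvcc")  # fall back to a PATH lookup
    for nvcc_bin in candidates:
        try:
            raw_output = subprocess.check_output(
                [nvcc_bin, "-V"], universal_newlines=True
            )
            return parse_nvcc_release(raw_output)
        except (OSError, subprocess.CalledProcessError, ValueError, IndexError):
            # Binary missing, not executable, failed, or printed something
            # unexpected: move on to the next candidate.
            continue
    return None
```

A caller would then need to decide what a None result means, for example skipping the CUDA extensions or failing with a clearer message than a raw FileNotFoundError.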

@trufty

trufty commented Oct 9, 2022

Just chiming in that I ran into the same error using an NVIDIA CUDA Docker container. Switching from nvidia/cuda:11.7.x-base-ubuntu22.04 to nvidia/cuda:11.7.x-devel-ubuntu22.04 resolved the issue, since the devel image has nvcc pre-installed (although it is a much larger image).

I also had to drop back to xformers==0.0.12 because of a cutlass submodule issue, but either way it works now!

@fmassa
Contributor

fmassa commented Oct 10, 2022

Hi,

We have recently added a conda package for xformers in https://anaconda.org/xformers/xformers, could you maybe try using it instead?

@cornpo

cornpo commented Oct 19, 2022

Ubuntu 22, Python 3.9, ROCm. I ran conda install -c "xformers/label/dev" xformers and got:

Collecting package metadata (current_repodata.json): done
Solving environment: failed with initial frozen solve. Retrying with flexible solve.
Solving environment: failed with repodata from current_repodata.json, will retry with next repodata source.
Collecting package metadata (repodata.json): done
Solving environment: failed with initial frozen solve. Retrying with flexible solve.
Solving environment: | 
Found conflicts! Looking for incompatible packages.
This can take several minutes.  Press CTRL-C to abort.
failed                                                                                                                        

UnsatisfiableError: The following specifications were found to be incompatible with each other:

Output in format: Requested package -> Available versionsThe following specifications were found to be incompatible with your system:

  - feature:/linux-64::__glibc==2.35=0
  - feature:|@/linux-64::__glibc==2.35=0

Your installed version is: 2.35

@0xdevalias

0xdevalias commented Nov 15, 2022

We have recently added a conda package for xformers in https://anaconda.org/xformers/xformers, could you maybe try using it instead?

Originally posted by @fmassa in #390 (comment)

I'm seeing a similar error to the one above:

⇒ conda install -n "$CONDA_ENV_NAME" xformers -c xformers/label/dev

Collecting package metadata (current_repodata.json): done
Solving environment: failed with initial frozen solve. Retrying with flexible solve.
Solving environment: failed with repodata from current_repodata.json, will retry with next repodata source.
Collecting package metadata (repodata.json): done
Solving environment: failed with initial frozen solve. Retrying with flexible solve.
Solving environment: - 
Found conflicts! Looking for incompatible packages.
This can take several minutes.  Press CTRL-C to abort.
                                                                               failed

UnsatisfiableError: The following specifications were found to be incompatible with each other:

Output in format: Requested package -> Available versionsThe following specifications were found to be incompatible with your system:

  - feature:/linux-64::__glibc==2.27=0
  - feature:|@/linux-64::__glibc==2.27=0

Your installed version is: 2.27
⇒ ldd --version
ldd (Ubuntu GLIBC 2.27-3ubuntu1.6) 2.27
⇒ uname -a
Linux 010e900fca52 5.15.0-25-generic #25-Ubuntu SMP Wed Mar 30 15:54:22 UTC 2022 x86_64 x86_64 x86_64 GNU/Linux
⇒ conda --version
conda 4.13.0
⇒ python --version
Python 3.10.8

A few potentially related issues I found when googling:


@fmassa Any thoughts on what might be going wrong here?

@0xdevalias

0xdevalias commented Nov 15, 2022

Digging into this a bit more:

I tried looking at the specific versions available with conda search:

⇒ conda search xformers -c xformers/label/dev

Loading channels: done
# Name                       Version           Build  Channel
..snip..
xformers             0.0.15.dev337+git.fd21b40 py310_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.15.dev337+git.fd21b40 py310_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.15.dev337+git.fd21b40 py310_cu11.6_pyt1.13  xformers/label/dev
xformers             0.0.15.dev337+git.fd21b40 py310_cu11.7_pyt1.13  xformers/label/dev
xformers             0.0.15.dev337+git.fd21b40 py39_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.15.dev337+git.fd21b40 py39_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.15.dev337+git.fd21b40 py39_cu11.6_pyt1.13  xformers/label/dev
xformers             0.0.15.dev337+git.fd21b40 py39_cu11.7_pyt1.13  xformers/label/dev
Full output
⇒ conda search xformers -c xformers/label/dev

Loading channels: done
# Name                       Version           Build  Channel
xformers             0.0.14.dev309+git.2236ed0 py310_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev309+git.2236ed0 py310_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev309+git.2236ed0 py39_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev309+git.2236ed0 py39_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev310+git.e31c571 py310_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev310+git.e31c571 py310_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev310+git.e31c571 py39_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev310+git.e31c571 py39_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev311+git.ba93c50 py310_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev311+git.ba93c50 py310_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev311+git.ba93c50 py39_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev311+git.ba93c50 py39_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev312+git.3633e1a py310_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev312+git.3633e1a py310_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev312+git.3633e1a py39_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev312+git.3633e1a py39_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev313+git.265eb03 py310_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev313+git.265eb03 py310_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev313+git.265eb03 py39_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev313+git.265eb03 py39_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev314+git.faa88b1 py310_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev314+git.faa88b1 py310_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev314+git.faa88b1 py39_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev314+git.faa88b1 py39_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev315+git.e23b369 py310_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev315+git.e23b369 py310_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev315+git.e23b369 py39_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev315+git.e23b369 py39_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev316+git.95ad2fc py310_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev316+git.95ad2fc py310_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev316+git.95ad2fc py39_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev316+git.95ad2fc py39_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev317+git.42e5c27 py310_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev317+git.42e5c27 py310_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev317+git.42e5c27 py39_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev317+git.42e5c27 py39_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev318+git.44c560d py310_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev318+git.44c560d py310_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev318+git.44c560d py39_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.14.dev318+git.44c560d py39_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.15.dev332+git.8367685 py310_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.15.dev332+git.8367685 py310_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.15.dev332+git.8367685 py39_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.15.dev332+git.8367685 py39_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.15.dev336+git.cb79827 py310_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.15.dev336+git.cb79827 py310_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.15.dev336+git.cb79827 py39_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.15.dev336+git.cb79827 py39_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.15.dev337+git.fd21b40 py310_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.15.dev337+git.fd21b40 py310_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.15.dev337+git.fd21b40 py310_cu11.6_pyt1.13  xformers/label/dev
xformers             0.0.15.dev337+git.fd21b40 py310_cu11.7_pyt1.13  xformers/label/dev
xformers             0.0.15.dev337+git.fd21b40 py39_cu11.3_pyt1.12.1  xformers/label/dev
xformers             0.0.15.dev337+git.fd21b40 py39_cu11.6_pyt1.12.1  xformers/label/dev
xformers             0.0.15.dev337+git.fd21b40 py39_cu11.6_pyt1.13  xformers/label/dev
xformers             0.0.15.dev337+git.fd21b40 py39_cu11.7_pyt1.13  xformers/label/dev
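
The build strings in the listing above appear to encode the Python, CUDA and PyTorch versions as py<python>_cu<cuda>_pyt<pytorch>. That pattern is inferred from the listing, not taken from any documented format. Under that assumption, a small sketch like the following could pick out the builds matching a given setup; the names Build, parse_build and matching_builds are hypothetical.

```python
import re
from typing import List, NamedTuple, Optional


class Build(NamedTuple):
    python: str   # e.g. "3.10"
    cuda: str     # e.g. "11.7"
    pytorch: str  # e.g. "1.13"


# Assumed layout: py310_cu11.7_pyt1.13 (single-digit Python major version).
_BUILD_RE = re.compile(r"^py(\d)(\d+)_cu(\d+\.\d+)_pyt(\d+(?:\.\d+)*)$")


def parse_build(build: str) -> Optional[Build]:
    """Split a build string into its versions, or None if it does not match."""
    match = _BUILD_RE.match(build)
    if match is None:
        return None
    major, minor, cuda, pytorch = match.groups()
    return Build(f"{major}.{minor}", cuda, pytorch)


def matching_builds(builds: List[str], python: str, cuda: str) -> List[str]:
    """Keep only the build strings for the given Python and CUDA versions."""
    selected = []
    for build in builds:
        parsed = parse_build(build)
        if parsed is not None and parsed.python == python and parsed.cuda == cuda:
            selected.append(build)
    return selected
```

For the 0.0.15.dev337 builds listed above, filtering for Python 3.10 and CUDA 11.7 leaves only py310_cu11.7_pyt1.13.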

I then tried installing one that should be compatible with my system, though it just gave me a different error:

⇒ conda install -n "$CONDA_ENV_NAME" -c xformers/label/dev xformers=*=py310_cu11.7_pyt1.13

Collecting package metadata (current_repodata.json): done
Solving environment: failed with initial frozen solve. Retrying with flexible solve.
Solving environment: failed with repodata from current_repodata.json, will retry with next repodata source.
Collecting package metadata (repodata.json): done
Solving environment: failed with initial frozen solve. Retrying with flexible solve.
Solving environment: / 
Found conflicts! Looking for incompatible packages.
This can take several minutes.  Press CTRL-C to abort.
                                                                               failed

UnsatisfiableError: The following specifications were found to be incompatible with each other:

Output in format: Requested package -> Available versionsThe following specifications were found to be incompatible with your system:

  - feature:/linux-64::__glibc==2.27=0
  - python=3.10 -> libgcc-ng[version='>=11.2.0'] -> __glibc[version='>=2.17']

Your installed version is: 2.27
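
One detail in this output may be worth checking: the only version bound shown is __glibc >=2.17, and the installed 2.27 satisfies it, so the glibc lines may not be the real cause of the failure. The sketch below only illustrates that comparison. It is not how conda's solver works, and the helper names are hypothetical. Versions are compared as integer tuples because comparing them as plain strings would order "2.9" after "2.17".

```python
def parse_version(text):
    """Convert a dotted version such as "2.27" into a tuple of ints."""
    return tuple(int(part) for part in text.split("."))


def satisfies_minimum(installed, minimum):
    """True if the installed version is at least the required minimum."""
    return parse_version(installed) >= parse_version(minimum)


# Values taken from the conda output above.
print(satisfies_minimum("2.27", "2.17"))
```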

Trying a different build:

⇒ conda install -n "$CONDA_ENV_NAME" -c xformers/label/dev xformers=*=py310_cu11.6_pyt1.12.1

Collecting package metadata (current_repodata.json): done
Solving environment: failed with initial frozen solve. Retrying with flexible solve.
Solving environment: failed with repodata from current_repodata.json, will retry with next repodata source.
Collecting package metadata (repodata.json): done
Solving environment: failed with initial frozen solve. Retrying with flexible solve.
Solving environment: - 
Found conflicts! Looking for incompatible packages.
This can take several minutes.  Press CTRL-C to abort.
                                                                               failed

UnsatisfiableError: The following specifications were found to be incompatible with each other:

Output in format: Requested package -> Available versionsThe following specifications were found to be incompatible with your system:

  - feature:/linux-64::__glibc==2.27=0
  - feature:|@/linux-64::__glibc==2.27=0

Your installed version is: 2.27

@0xdevalias

0xdevalias commented Nov 15, 2022

This StackOverflow post gave me an idea:

Looking at the xformers versions online:

We can see that, for example, linux-64/xformers-0.0.15.dev337+git.fd21b40-py310_cu11.7_pyt1.13 has the following 'depends':

python >=3.10,<3.11.0a0, pytorch 1.13, pytorch-cuda >=11.7,<11.8

Yet the conda env I had activated didn't have any version of pytorch or pytorch-cuda installed in it. I assumed that shouldn't matter, particularly since I was explicitly requesting the build I wanted and presumed conda would just install the dependencies for me... but maybe not?

Trying a few other things:

⇒ conda install -n "$CONDA_ENV_NAME" -c pytorch pytorch=1.13 pytorch-cuda=11.7 (error)

⇒ conda install -n "$CONDA_ENV_NAME" -c pytorch pytorch=1.13 pytorch-cuda=11.7

Collecting package metadata (current_repodata.json): done
Solving environment: failed with initial frozen solve. Retrying with flexible solve.
Solving environment: failed with repodata from current_repodata.json, will retry with next repodata source.
Collecting package metadata (repodata.json): done
Solving environment: failed with initial frozen solve. Retrying with flexible solve.
Solving environment: | 
Found conflicts! Looking for incompatible packages.
This can take several minutes.  Press CTRL-C to abort.
                                                                               failed

UnsatisfiableError: The following specifications were found to be incompatible with each other:

Output in format: Requested package -> Available versions

Package pytorch-cuda conflicts for:
pytorch-cuda=11.7
pytorch=1.13 -> pytorch-cuda[version='>=11.6,<11.7|>=11.7,<11.8']The following specifications were found to be incompatible with your system:

  - feature:/linux-64::__glibc==2.27=0
  - feature:|@/linux-64::__glibc==2.27=0
  - python=3.10 -> libgcc-ng[version='>=11.2.0'] -> __glibc[version='>=2.17']

Your installed version is: 2.27

⇒ conda install -n "$CONDA_ENV_NAME" -c pytorch pytorch-cuda=11.7 (error)

⇒ conda install -n "$CONDA_ENV_NAME" -c pytorch pytorch-cuda=11.7

Collecting package metadata (current_repodata.json): done
Solving environment: failed with initial frozen solve. Retrying with flexible solve.
Solving environment: failed with repodata from current_repodata.json, will retry with next repodata source.
Collecting package metadata (repodata.json): done
Solving environment: failed with initial frozen solve. Retrying with flexible solve.
Solving environment: | 
Found conflicts! Looking for incompatible packages.
This can take several minutes.  Press CTRL-C to abort.
                                                                               failed

UnsatisfiableError: The following specifications were found to be incompatible with your system:

  - feature:/linux-64::__glibc==2.27=0
  - python=3.10 -> libgcc-ng[version='>=11.2.0'] -> __glibc[version='>=2.17']

Your installed version is: 2.27

⇒ conda install -n "$CONDA_ENV_NAME" -c pytorch pytorch=1.13 (seems like it will work)

⇒ conda install -n "$CONDA_ENV_NAME" -c pytorch pytorch=1.13

Collecting package metadata (current_repodata.json): done
Solving environment: done

## Package Plan ##

  environment location: /opt/conda/envs/xformers

  added / updated specs:
    - pytorch=1.13


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    blas-1.0                   |              mkl           6 KB
    intel-openmp-2022.1.0      |    h9e868ea_3769         4.5 MB
    mkl-2022.1.0               |     hc2b9512_224       129.7 MB
    pytorch-1.13.0             |     py3.10_cpu_0        84.3 MB  pytorch
    pytorch-mutex-1.0          |              cpu           3 KB  pytorch
    typing_extensions-4.3.0    |  py310h06a4308_0          42 KB
    ------------------------------------------------------------
                                           Total:       218.6 MB

The following NEW packages will be INSTALLED:

  blas               pkgs/main/linux-64::blas-1.0-mkl None
  intel-openmp       pkgs/main/linux-64::intel-openmp-2022.1.0-h9e868ea_3769 None
  mkl                pkgs/main/linux-64::mkl-2022.1.0-hc2b9512_224 None
  pytorch            pytorch/linux-64::pytorch-1.13.0-py3.10_cpu_0 None
  pytorch-mutex      pytorch/noarch::pytorch-mutex-1.0-cpu None
  typing_extensions  pkgs/main/linux-64::typing_extensions-4.3.0-py310h06a4308_0 None


Proceed ([y]/n)? 

Having seemingly narrowed the issue down to pytorch-cuda, I looked at its 'depends':

pytorch-cuda-11.7-h67b0de4_0

cuda-nvcc >=11.7,<11.8, cuda-command-line-tools >=11.7,<11.8, cuda-driver-dev >=11.7,<11.8, cuda-cuobjdump >=11.7,<11.8, cuda-cudart >=11.7,<11.8, cuda-toolkit >=11.7,<11.8, cuda-nvrtc >=11.7,<11.8, cuda-cupti >=11.7,<11.8, cuda-runtime >=11.7,<11.8, cuda-cudart-dev >=11.7,<11.8, cuda-nvtx >=11.7,<11.8, cuda-nvprune >=11.7,<11.8, cuda-nvrtc-dev >=11.7,<11.8, cuda-libraries >=11.7,<11.8, cuda-libraries-dev >=11.7,<11.8, cuda-cudaart-dev >=11.7,<11.8, cuda-compiler >=11.7,<11.8, cuda-cccl >=11.7,<11.8, cuda-nvml-dev >=11.7,<11.8, cuda-cuxxfilt >=11.7,<11.8, cuda-tools >=11.7,<11.8

I figured cuda-toolkit should be able to provide all of these:

So I decided to try installing that, which looked like it would work:

⇒ conda install -n "$CONDA_ENV_NAME" -c "nvidia/label/cuda-11.7.1" cuda-toolkit (seems like it will work)

⇒ conda install -n "$CONDA_ENV_NAME" -c "nvidia/label/cuda-11.7.1" cuda-toolkit

Collecting package metadata (current_repodata.json): done
Solving environment: done

## Package Plan ##

  environment location: /opt/conda/envs/xformers

  added / updated specs:
    - cuda-toolkit


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    cuda-cccl-11.7.91          |                0         1.2 MB  nvidia/label/cuda-11.7.1
    cuda-command-line-tools-11.7.1|                0           1 KB  nvidia/label/cuda-11.7.1
    cuda-compiler-11.7.1       |                0           1 KB  nvidia/label/cuda-11.7.1
    cuda-cudart-11.7.99        |                0         194 KB  nvidia/label/cuda-11.7.1
    cuda-cudart-dev-11.7.99    |                0         1.1 MB  nvidia/label/cuda-11.7.1
    cuda-cuobjdump-11.7.91     |                0         158 KB  nvidia/label/cuda-11.7.1
    cuda-cupti-11.7.101        |                0        22.9 MB  nvidia/label/cuda-11.7.1
    cuda-cuxxfilt-11.7.91      |                0         293 KB  nvidia/label/cuda-11.7.1
    cuda-documentation-11.7.91 |                0          88 KB  nvidia/label/cuda-11.7.1
    cuda-driver-dev-11.7.99    |                0          16 KB  nvidia/label/cuda-11.7.1
    cuda-gdb-11.7.91           |                0         4.8 MB  nvidia/label/cuda-11.7.1
    cuda-libraries-11.7.1      |                0           1 KB  nvidia/label/cuda-11.7.1
    cuda-libraries-dev-11.7.1  |                0           2 KB  nvidia/label/cuda-11.7.1
    cuda-memcheck-11.7.91      |                0         168 KB  nvidia/label/cuda-11.7.1
    cuda-nsight-11.7.91        |                0       113.6 MB  nvidia/label/cuda-11.7.1
    cuda-nsight-compute-11.7.1 |                0           1 KB  nvidia/label/cuda-11.7.1
    cuda-nvcc-11.7.99          |                0        42.7 MB  nvidia/label/cuda-11.7.1
    cuda-nvdisasm-11.7.91      |                0        31.5 MB  nvidia/label/cuda-11.7.1
    cuda-nvml-dev-11.7.91      |                0          80 KB  nvidia/label/cuda-11.7.1
    cuda-nvprof-11.7.101       |                0         4.3 MB  nvidia/label/cuda-11.7.1
    cuda-nvprune-11.7.91       |                0          64 KB  nvidia/label/cuda-11.7.1
    cuda-nvrtc-11.7.99         |                0        17.3 MB  nvidia/label/cuda-11.7.1
    cuda-nvrtc-dev-11.7.99     |                0        16.9 MB  nvidia/label/cuda-11.7.1
    cuda-nvtx-11.7.91          |                0          57 KB  nvidia/label/cuda-11.7.1
    cuda-nvvp-11.7.101         |                0       114.3 MB  nvidia/label/cuda-11.7.1
    cuda-sanitizer-api-11.7.91 |                0        16.8 MB  nvidia/label/cuda-11.7.1
    cuda-toolkit-11.7.1        |                0           1 KB  nvidia/label/cuda-11.7.1
    cuda-tools-11.7.1          |                0           1 KB  nvidia/label/cuda-11.7.1
    cuda-visual-tools-11.7.1   |                0           1 KB  nvidia/label/cuda-11.7.1
    gds-tools-1.3.1.18         |                0           2 KB  nvidia/label/cuda-11.7.1
    libcublas-11.10.3.66       |                0       286.1 MB  nvidia/label/cuda-11.7.1
    libcublas-dev-11.10.3.66   |                0       296.4 MB  nvidia/label/cuda-11.7.1
    libcufft-10.7.2.91         |                0        93.6 MB  nvidia/label/cuda-11.7.1
    libcufft-dev-10.7.2.91     |                0       196.4 MB  nvidia/label/cuda-11.7.1
    libcufile-1.3.1.18         |                0         545 KB  nvidia/label/cuda-11.7.1
    libcufile-dev-1.3.1.18     |                0        12.4 MB  nvidia/label/cuda-11.7.1
    libcurand-10.2.10.91       |                0        50.3 MB  nvidia/label/cuda-11.7.1
    libcurand-dev-10.2.10.91   |                0        50.7 MB  nvidia/label/cuda-11.7.1
    libcusolver-11.4.0.1       |                0        78.7 MB  nvidia/label/cuda-11.7.1
    libcusolver-dev-11.4.0.1   |                0        55.9 MB  nvidia/label/cuda-11.7.1
    libcusparse-11.7.4.91      |                0       151.1 MB  nvidia/label/cuda-11.7.1
    libcusparse-dev-11.7.4.91  |                0       309.5 MB  nvidia/label/cuda-11.7.1
    libnpp-11.7.4.75           |                0       129.3 MB  nvidia/label/cuda-11.7.1
    libnpp-dev-11.7.4.75       |                0       126.6 MB  nvidia/label/cuda-11.7.1
    libnvjpeg-11.8.0.2         |                0         2.2 MB  nvidia/label/cuda-11.7.1
    libnvjpeg-dev-11.8.0.2     |                0         1.9 MB  nvidia/label/cuda-11.7.1
    nsight-compute-2022.2.1.3  |                0       461.4 MB  nvidia/label/cuda-11.7.1
    ------------------------------------------------------------
                                           Total:        2.63 GB

The following NEW packages will be INSTALLED:

  cuda-cccl          nvidia/label/cuda-11.7.1/linux-64::cuda-cccl-11.7.91-0 None
  cuda-command-line~ nvidia/label/cuda-11.7.1/linux-64::cuda-command-line-tools-11.7.1-0 None
  cuda-compiler      nvidia/label/cuda-11.7.1/linux-64::cuda-compiler-11.7.1-0 None
  cuda-cudart        nvidia/label/cuda-11.7.1/linux-64::cuda-cudart-11.7.99-0 None
  cuda-cudart-dev    nvidia/label/cuda-11.7.1/linux-64::cuda-cudart-dev-11.7.99-0 None
  cuda-cuobjdump     nvidia/label/cuda-11.7.1/linux-64::cuda-cuobjdump-11.7.91-0 None
  cuda-cupti         nvidia/label/cuda-11.7.1/linux-64::cuda-cupti-11.7.101-0 None
  cuda-cuxxfilt      nvidia/label/cuda-11.7.1/linux-64::cuda-cuxxfilt-11.7.91-0 None
  cuda-documentation nvidia/label/cuda-11.7.1/linux-64::cuda-documentation-11.7.91-0 None
  cuda-driver-dev    nvidia/label/cuda-11.7.1/linux-64::cuda-driver-dev-11.7.99-0 None
  cuda-gdb           nvidia/label/cuda-11.7.1/linux-64::cuda-gdb-11.7.91-0 None
  cuda-libraries     nvidia/label/cuda-11.7.1/linux-64::cuda-libraries-11.7.1-0 None
  cuda-libraries-dev nvidia/label/cuda-11.7.1/linux-64::cuda-libraries-dev-11.7.1-0 None
  cuda-memcheck      nvidia/label/cuda-11.7.1/linux-64::cuda-memcheck-11.7.91-0 None
  cuda-nsight        nvidia/label/cuda-11.7.1/linux-64::cuda-nsight-11.7.91-0 None
  cuda-nsight-compu~ nvidia/label/cuda-11.7.1/linux-64::cuda-nsight-compute-11.7.1-0 None
  cuda-nvcc          nvidia/label/cuda-11.7.1/linux-64::cuda-nvcc-11.7.99-0 None
  cuda-nvdisasm      nvidia/label/cuda-11.7.1/linux-64::cuda-nvdisasm-11.7.91-0 None
  cuda-nvml-dev      nvidia/label/cuda-11.7.1/linux-64::cuda-nvml-dev-11.7.91-0 None
  cuda-nvprof        nvidia/label/cuda-11.7.1/linux-64::cuda-nvprof-11.7.101-0 None
  cuda-nvprune       nvidia/label/cuda-11.7.1/linux-64::cuda-nvprune-11.7.91-0 None
  cuda-nvrtc         nvidia/label/cuda-11.7.1/linux-64::cuda-nvrtc-11.7.99-0 None
  cuda-nvrtc-dev     nvidia/label/cuda-11.7.1/linux-64::cuda-nvrtc-dev-11.7.99-0 None
  cuda-nvtx          nvidia/label/cuda-11.7.1/linux-64::cuda-nvtx-11.7.91-0 None
  cuda-nvvp          nvidia/label/cuda-11.7.1/linux-64::cuda-nvvp-11.7.101-0 None
  cuda-sanitizer-api nvidia/label/cuda-11.7.1/linux-64::cuda-sanitizer-api-11.7.91-0 None
  cuda-toolkit       nvidia/label/cuda-11.7.1/linux-64::cuda-toolkit-11.7.1-0 None
  cuda-tools         nvidia/label/cuda-11.7.1/linux-64::cuda-tools-11.7.1-0 None
  cuda-visual-tools  nvidia/label/cuda-11.7.1/linux-64::cuda-visual-tools-11.7.1-0 None
  gds-tools          nvidia/label/cuda-11.7.1/linux-64::gds-tools-1.3.1.18-0 None
  libcublas          nvidia/label/cuda-11.7.1/linux-64::libcublas-11.10.3.66-0 None
  libcublas-dev      nvidia/label/cuda-11.7.1/linux-64::libcublas-dev-11.10.3.66-0 None
  libcufft           nvidia/label/cuda-11.7.1/linux-64::libcufft-10.7.2.91-0 None
  libcufft-dev       nvidia/label/cuda-11.7.1/linux-64::libcufft-dev-10.7.2.91-0 None
  libcufile          nvidia/label/cuda-11.7.1/linux-64::libcufile-1.3.1.18-0 None
  libcufile-dev      nvidia/label/cuda-11.7.1/linux-64::libcufile-dev-1.3.1.18-0 None
  libcurand          nvidia/label/cuda-11.7.1/linux-64::libcurand-10.2.10.91-0 None
  libcurand-dev      nvidia/label/cuda-11.7.1/linux-64::libcurand-dev-10.2.10.91-0 None
  libcusolver        nvidia/label/cuda-11.7.1/linux-64::libcusolver-11.4.0.1-0 None
  libcusolver-dev    nvidia/label/cuda-11.7.1/linux-64::libcusolver-dev-11.4.0.1-0 None
  libcusparse        nvidia/label/cuda-11.7.1/linux-64::libcusparse-11.7.4.91-0 None
  libcusparse-dev    nvidia/label/cuda-11.7.1/linux-64::libcusparse-dev-11.7.4.91-0 None
  libnpp             nvidia/label/cuda-11.7.1/linux-64::libnpp-11.7.4.75-0 None
  libnpp-dev         nvidia/label/cuda-11.7.1/linux-64::libnpp-dev-11.7.4.75-0 None
  libnvjpeg          nvidia/label/cuda-11.7.1/linux-64::libnvjpeg-11.8.0.2-0 None
  libnvjpeg-dev      nvidia/label/cuda-11.7.1/linux-64::libnvjpeg-dev-11.8.0.2-0 None
  nsight-compute     nvidia/label/cuda-11.7.1/linux-64::nsight-compute-2022.2.1.3-0 None


Proceed ([y]/n)? 

That made me think the earlier conda install commands might actually have been failing because the -c pytorch and -c "nvidia/label/cuda-11.7.1" channels weren't specified, so conda had no 'package repo' in which to find the packages that satisfy the CUDA dependencies. So I then tried the following:

⇒ conda install -n "$CONDA_ENV_NAME" -c xformers/label/dev -c pytorch -c nvidia/label/cuda-11.7.1 xformers=*=py310_cu11.7_pyt1.13 (seems like it will work)

⇒ conda install -n "$CONDA_ENV_NAME" -c xformers/label/dev -c pytorch -c nvidia/label/cuda-11.7.1 xformers=*=py310_cu11.7_pyt1.13

Collecting package metadata (current_repodata.json): done
Solving environment: done

## Package Plan ##

  environment location: /opt/conda/envs/xformers

  added / updated specs:
    - xformers[build=py310_cu11.7_pyt1.13]


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    blas-1.0                   |              mkl           6 KB
    cuda-11.7.1                |                0           1 KB  nvidia/label/cuda-11.7.1
    cuda-cccl-11.7.91          |                0         1.2 MB  nvidia/label/cuda-11.7.1
    cuda-command-line-tools-11.7.1|                0           1 KB  nvidia/label/cuda-11.7.1
    cuda-compiler-11.7.1       |                0           1 KB  nvidia/label/cuda-11.7.1
    cuda-cudart-11.7.99        |                0         194 KB  nvidia/label/cuda-11.7.1
    cuda-cudart-dev-11.7.99    |                0         1.1 MB  nvidia/label/cuda-11.7.1
    cuda-cuobjdump-11.7.91     |                0         158 KB  nvidia/label/cuda-11.7.1
    cuda-cupti-11.7.101        |                0        22.9 MB  nvidia/label/cuda-11.7.1
    cuda-cuxxfilt-11.7.91      |                0         293 KB  nvidia/label/cuda-11.7.1
    cuda-demo-suite-11.7.91    |                0         4.9 MB  nvidia/label/cuda-11.7.1
    cuda-documentation-11.7.91 |                0          88 KB  nvidia/label/cuda-11.7.1
    cuda-driver-dev-11.7.99    |                0          16 KB  nvidia/label/cuda-11.7.1
    cuda-gdb-11.7.91           |                0         4.8 MB  nvidia/label/cuda-11.7.1
    cuda-libraries-11.7.1      |                0           1 KB  nvidia/label/cuda-11.7.1
    cuda-libraries-dev-11.7.1  |                0           2 KB  nvidia/label/cuda-11.7.1
    cuda-memcheck-11.7.91      |                0         168 KB  nvidia/label/cuda-11.7.1
    cuda-nsight-11.7.91        |                0       113.6 MB  nvidia/label/cuda-11.7.1
    cuda-nsight-compute-11.7.1 |                0           1 KB  nvidia/label/cuda-11.7.1
    cuda-nvcc-11.7.99          |                0        42.7 MB  nvidia/label/cuda-11.7.1
    cuda-nvdisasm-11.7.91      |                0        31.5 MB  nvidia/label/cuda-11.7.1
    cuda-nvml-dev-11.7.91      |                0          80 KB  nvidia/label/cuda-11.7.1
    cuda-nvprof-11.7.101       |                0         4.3 MB  nvidia/label/cuda-11.7.1
    cuda-nvprune-11.7.91       |                0          64 KB  nvidia/label/cuda-11.7.1
    cuda-nvrtc-11.7.99         |                0        17.3 MB  nvidia/label/cuda-11.7.1
    cuda-nvrtc-dev-11.7.99     |                0        16.9 MB  nvidia/label/cuda-11.7.1
    cuda-nvtx-11.7.91          |                0          57 KB  nvidia/label/cuda-11.7.1
    cuda-nvvp-11.7.101         |                0       114.3 MB  nvidia/label/cuda-11.7.1
    cuda-runtime-11.7.1        |                0           1 KB  nvidia/label/cuda-11.7.1
    cuda-sanitizer-api-11.7.91 |                0        16.8 MB  nvidia/label/cuda-11.7.1
    cuda-toolkit-11.7.1        |                0           1 KB  nvidia/label/cuda-11.7.1
    cuda-tools-11.7.1          |                0           1 KB  nvidia/label/cuda-11.7.1
    cuda-visual-tools-11.7.1   |                0           1 KB  nvidia/label/cuda-11.7.1
    gds-tools-1.3.1.18         |                0           2 KB  nvidia/label/cuda-11.7.1
    intel-openmp-2022.1.0      |    h9e868ea_3769         4.5 MB
    libcublas-11.10.3.66       |                0       286.1 MB  nvidia/label/cuda-11.7.1
    libcublas-dev-11.10.3.66   |                0       296.4 MB  nvidia/label/cuda-11.7.1
    libcufft-10.7.2.91         |                0        93.6 MB  nvidia/label/cuda-11.7.1
    libcufft-dev-10.7.2.91     |                0       196.4 MB  nvidia/label/cuda-11.7.1
    libcufile-1.3.1.18         |                0         545 KB  nvidia/label/cuda-11.7.1
    libcufile-dev-1.3.1.18     |                0        12.4 MB  nvidia/label/cuda-11.7.1
    libcurand-10.2.10.91       |                0        50.3 MB  nvidia/label/cuda-11.7.1
    libcurand-dev-10.2.10.91   |                0        50.7 MB  nvidia/label/cuda-11.7.1
    libcusolver-11.4.0.1       |                0        78.7 MB  nvidia/label/cuda-11.7.1
    libcusolver-dev-11.4.0.1   |                0        55.9 MB  nvidia/label/cuda-11.7.1
    libcusparse-11.7.4.91      |                0       151.1 MB  nvidia/label/cuda-11.7.1
    libcusparse-dev-11.7.4.91  |                0       309.5 MB  nvidia/label/cuda-11.7.1
    libnpp-11.7.4.75           |                0       129.3 MB  nvidia/label/cuda-11.7.1
    libnpp-dev-11.7.4.75       |                0       126.6 MB  nvidia/label/cuda-11.7.1
    libnvjpeg-11.8.0.2         |                0         2.2 MB  nvidia/label/cuda-11.7.1
    libnvjpeg-dev-11.8.0.2     |                0         1.9 MB  nvidia/label/cuda-11.7.1
    mkl-2022.1.0               |     hc2b9512_224       129.7 MB
    nsight-compute-2022.2.1.3  |                0       461.4 MB  nvidia/label/cuda-11.7.1
    pytorch-1.13.0             |py3.10_cuda11.7_cudnn8.5.0_0        1.15 GB  pytorch
    pytorch-cuda-11.7          |       h67b0de4_0           7 KB  pytorch
    pytorch-mutex-1.0          |             cuda           3 KB  pytorch
    typing_extensions-4.3.0    |  py310h06a4308_0          42 KB
    xformers-0.0.15.dev337+git.fd21b40|py310_cu11.7_pyt1.13       105.3 MB  xformers/label/dev
    ------------------------------------------------------------
                                           Total:        4.01 GB

The following NEW packages will be INSTALLED:

  blas               pkgs/main/linux-64::blas-1.0-mkl None
  cuda               nvidia/label/cuda-11.7.1/linux-64::cuda-11.7.1-0 None
  cuda-cccl          nvidia/label/cuda-11.7.1/linux-64::cuda-cccl-11.7.91-0 None
  cuda-command-line~ nvidia/label/cuda-11.7.1/linux-64::cuda-command-line-tools-11.7.1-0 None
  cuda-compiler      nvidia/label/cuda-11.7.1/linux-64::cuda-compiler-11.7.1-0 None
  cuda-cudart        nvidia/label/cuda-11.7.1/linux-64::cuda-cudart-11.7.99-0 None
  cuda-cudart-dev    nvidia/label/cuda-11.7.1/linux-64::cuda-cudart-dev-11.7.99-0 None
  cuda-cuobjdump     nvidia/label/cuda-11.7.1/linux-64::cuda-cuobjdump-11.7.91-0 None
  cuda-cupti         nvidia/label/cuda-11.7.1/linux-64::cuda-cupti-11.7.101-0 None
  cuda-cuxxfilt      nvidia/label/cuda-11.7.1/linux-64::cuda-cuxxfilt-11.7.91-0 None
  cuda-demo-suite    nvidia/label/cuda-11.7.1/linux-64::cuda-demo-suite-11.7.91-0 None
  cuda-documentation nvidia/label/cuda-11.7.1/linux-64::cuda-documentation-11.7.91-0 None
  cuda-driver-dev    nvidia/label/cuda-11.7.1/linux-64::cuda-driver-dev-11.7.99-0 None
  cuda-gdb           nvidia/label/cuda-11.7.1/linux-64::cuda-gdb-11.7.91-0 None
  cuda-libraries     nvidia/label/cuda-11.7.1/linux-64::cuda-libraries-11.7.1-0 None
  cuda-libraries-dev nvidia/label/cuda-11.7.1/linux-64::cuda-libraries-dev-11.7.1-0 None
  cuda-memcheck      nvidia/label/cuda-11.7.1/linux-64::cuda-memcheck-11.7.91-0 None
  cuda-nsight        nvidia/label/cuda-11.7.1/linux-64::cuda-nsight-11.7.91-0 None
  cuda-nsight-compu~ nvidia/label/cuda-11.7.1/linux-64::cuda-nsight-compute-11.7.1-0 None
  cuda-nvcc          nvidia/label/cuda-11.7.1/linux-64::cuda-nvcc-11.7.99-0 None
  cuda-nvdisasm      nvidia/label/cuda-11.7.1/linux-64::cuda-nvdisasm-11.7.91-0 None
  cuda-nvml-dev      nvidia/label/cuda-11.7.1/linux-64::cuda-nvml-dev-11.7.91-0 None
  cuda-nvprof        nvidia/label/cuda-11.7.1/linux-64::cuda-nvprof-11.7.101-0 None
  cuda-nvprune       nvidia/label/cuda-11.7.1/linux-64::cuda-nvprune-11.7.91-0 None
  cuda-nvrtc         nvidia/label/cuda-11.7.1/linux-64::cuda-nvrtc-11.7.99-0 None
  cuda-nvrtc-dev     nvidia/label/cuda-11.7.1/linux-64::cuda-nvrtc-dev-11.7.99-0 None
  cuda-nvtx          nvidia/label/cuda-11.7.1/linux-64::cuda-nvtx-11.7.91-0 None
  cuda-nvvp          nvidia/label/cuda-11.7.1/linux-64::cuda-nvvp-11.7.101-0 None
  cuda-runtime       nvidia/label/cuda-11.7.1/linux-64::cuda-runtime-11.7.1-0 None
  cuda-sanitizer-api nvidia/label/cuda-11.7.1/linux-64::cuda-sanitizer-api-11.7.91-0 None
  cuda-toolkit       nvidia/label/cuda-11.7.1/linux-64::cuda-toolkit-11.7.1-0 None
  cuda-tools         nvidia/label/cuda-11.7.1/linux-64::cuda-tools-11.7.1-0 None
  cuda-visual-tools  nvidia/label/cuda-11.7.1/linux-64::cuda-visual-tools-11.7.1-0 None
  gds-tools          nvidia/label/cuda-11.7.1/linux-64::gds-tools-1.3.1.18-0 None
  intel-openmp       pkgs/main/linux-64::intel-openmp-2022.1.0-h9e868ea_3769 None
  libcublas          nvidia/label/cuda-11.7.1/linux-64::libcublas-11.10.3.66-0 None
  libcublas-dev      nvidia/label/cuda-11.7.1/linux-64::libcublas-dev-11.10.3.66-0 None
  libcufft           nvidia/label/cuda-11.7.1/linux-64::libcufft-10.7.2.91-0 None
  libcufft-dev       nvidia/label/cuda-11.7.1/linux-64::libcufft-dev-10.7.2.91-0 None
  libcufile          nvidia/label/cuda-11.7.1/linux-64::libcufile-1.3.1.18-0 None
  libcufile-dev      nvidia/label/cuda-11.7.1/linux-64::libcufile-dev-1.3.1.18-0 None
  libcurand          nvidia/label/cuda-11.7.1/linux-64::libcurand-10.2.10.91-0 None
  libcurand-dev      nvidia/label/cuda-11.7.1/linux-64::libcurand-dev-10.2.10.91-0 None
  libcusolver        nvidia/label/cuda-11.7.1/linux-64::libcusolver-11.4.0.1-0 None
  libcusolver-dev    nvidia/label/cuda-11.7.1/linux-64::libcusolver-dev-11.4.0.1-0 None
  libcusparse        nvidia/label/cuda-11.7.1/linux-64::libcusparse-11.7.4.91-0 None
  libcusparse-dev    nvidia/label/cuda-11.7.1/linux-64::libcusparse-dev-11.7.4.91-0 None
  libnpp             nvidia/label/cuda-11.7.1/linux-64::libnpp-11.7.4.75-0 None
  libnpp-dev         nvidia/label/cuda-11.7.1/linux-64::libnpp-dev-11.7.4.75-0 None
  libnvjpeg          nvidia/label/cuda-11.7.1/linux-64::libnvjpeg-11.8.0.2-0 None
  libnvjpeg-dev      nvidia/label/cuda-11.7.1/linux-64::libnvjpeg-dev-11.8.0.2-0 None
  mkl                pkgs/main/linux-64::mkl-2022.1.0-hc2b9512_224 None
  nsight-compute     nvidia/label/cuda-11.7.1/linux-64::nsight-compute-2022.2.1.3-0 None
  pytorch            pytorch/linux-64::pytorch-1.13.0-py3.10_cuda11.7_cudnn8.5.0_0 None
  pytorch-cuda       pytorch/noarch::pytorch-cuda-11.7-h67b0de4_0 None
  pytorch-mutex      pytorch/noarch::pytorch-mutex-1.0-cuda None
  typing_extensions  pkgs/main/linux-64::typing_extensions-4.3.0-py310h06a4308_0 None
  xformers           xformers/label/dev/linux-64::xformers-0.0.15.dev337+git.fd21b40-py310_cu11.7_pyt1.13 None


Proceed ([y]/n)? 

And the simpler form, without pinning a build string:

⇒ conda install -n "$CONDA_ENV_NAME" -c xformers/label/dev -c pytorch -c nvidia/label/cuda-11.7.1 xformers (seems like it will work)

⇒ conda install -n "$CONDA_ENV_NAME" -c xformers/label/dev -c pytorch -c nvidia/label/cuda-11.7.1 xformers

Collecting package metadata (current_repodata.json): done
Solving environment: done

## Package Plan ##

  environment location: /opt/conda/envs/xformers

  added / updated specs:
    - xformers


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    blas-1.0                   |              mkl           6 KB
    cuda-11.7.1                |                0           1 KB  nvidia/label/cuda-11.7.1
    cuda-cccl-11.7.91          |                0         1.2 MB  nvidia/label/cuda-11.7.1
    cuda-command-line-tools-11.7.1|                0           1 KB  nvidia/label/cuda-11.7.1
    cuda-compiler-11.7.1       |                0           1 KB  nvidia/label/cuda-11.7.1
    cuda-cudart-11.7.99        |                0         194 KB  nvidia/label/cuda-11.7.1
    cuda-cudart-dev-11.7.99    |                0         1.1 MB  nvidia/label/cuda-11.7.1
    cuda-cuobjdump-11.7.91     |                0         158 KB  nvidia/label/cuda-11.7.1
    cuda-cupti-11.7.101        |                0        22.9 MB  nvidia/label/cuda-11.7.1
    cuda-cuxxfilt-11.7.91      |                0         293 KB  nvidia/label/cuda-11.7.1
    cuda-demo-suite-11.7.91    |                0         4.9 MB  nvidia/label/cuda-11.7.1
    cuda-documentation-11.7.91 |                0          88 KB  nvidia/label/cuda-11.7.1
    cuda-driver-dev-11.7.99    |                0          16 KB  nvidia/label/cuda-11.7.1
    cuda-gdb-11.7.91           |                0         4.8 MB  nvidia/label/cuda-11.7.1
    cuda-libraries-11.7.1      |                0           1 KB  nvidia/label/cuda-11.7.1
    cuda-libraries-dev-11.7.1  |                0           2 KB  nvidia/label/cuda-11.7.1
    cuda-memcheck-11.7.91      |                0         168 KB  nvidia/label/cuda-11.7.1
    cuda-nsight-11.7.91        |                0       113.6 MB  nvidia/label/cuda-11.7.1
    cuda-nsight-compute-11.7.1 |                0           1 KB  nvidia/label/cuda-11.7.1
    cuda-nvcc-11.7.99          |                0        42.7 MB  nvidia/label/cuda-11.7.1
    cuda-nvdisasm-11.7.91      |                0        31.5 MB  nvidia/label/cuda-11.7.1
    cuda-nvml-dev-11.7.91      |                0          80 KB  nvidia/label/cuda-11.7.1
    cuda-nvprof-11.7.101       |                0         4.3 MB  nvidia/label/cuda-11.7.1
    cuda-nvprune-11.7.91       |                0          64 KB  nvidia/label/cuda-11.7.1
    cuda-nvrtc-11.7.99         |                0        17.3 MB  nvidia/label/cuda-11.7.1
    cuda-nvrtc-dev-11.7.99     |                0        16.9 MB  nvidia/label/cuda-11.7.1
    cuda-nvtx-11.7.91          |                0          57 KB  nvidia/label/cuda-11.7.1
    cuda-nvvp-11.7.101         |                0       114.3 MB  nvidia/label/cuda-11.7.1
    cuda-runtime-11.7.1        |                0           1 KB  nvidia/label/cuda-11.7.1
    cuda-sanitizer-api-11.7.91 |                0        16.8 MB  nvidia/label/cuda-11.7.1
    cuda-toolkit-11.7.1        |                0           1 KB  nvidia/label/cuda-11.7.1
    cuda-tools-11.7.1          |                0           1 KB  nvidia/label/cuda-11.7.1
    cuda-visual-tools-11.7.1   |                0           1 KB  nvidia/label/cuda-11.7.1
    gds-tools-1.3.1.18         |                0           2 KB  nvidia/label/cuda-11.7.1
    intel-openmp-2022.1.0      |    h9e868ea_3769         4.5 MB
    libcublas-11.10.3.66       |                0       286.1 MB  nvidia/label/cuda-11.7.1
    libcublas-dev-11.10.3.66   |                0       296.4 MB  nvidia/label/cuda-11.7.1
    libcufft-10.7.2.91         |                0        93.6 MB  nvidia/label/cuda-11.7.1
    libcufft-dev-10.7.2.91     |                0       196.4 MB  nvidia/label/cuda-11.7.1
    libcufile-1.3.1.18         |                0         545 KB  nvidia/label/cuda-11.7.1
    libcufile-dev-1.3.1.18     |                0        12.4 MB  nvidia/label/cuda-11.7.1
    libcurand-10.2.10.91       |                0        50.3 MB  nvidia/label/cuda-11.7.1
    libcurand-dev-10.2.10.91   |                0        50.7 MB  nvidia/label/cuda-11.7.1
    libcusolver-11.4.0.1       |                0        78.7 MB  nvidia/label/cuda-11.7.1
    libcusolver-dev-11.4.0.1   |                0        55.9 MB  nvidia/label/cuda-11.7.1
    libcusparse-11.7.4.91      |                0       151.1 MB  nvidia/label/cuda-11.7.1
    libcusparse-dev-11.7.4.91  |                0       309.5 MB  nvidia/label/cuda-11.7.1
    libnpp-11.7.4.75           |                0       129.3 MB  nvidia/label/cuda-11.7.1
    libnpp-dev-11.7.4.75       |                0       126.6 MB  nvidia/label/cuda-11.7.1
    libnvjpeg-11.8.0.2         |                0         2.2 MB  nvidia/label/cuda-11.7.1
    libnvjpeg-dev-11.8.0.2     |                0         1.9 MB  nvidia/label/cuda-11.7.1
    mkl-2022.1.0               |     hc2b9512_224       129.7 MB
    nsight-compute-2022.2.1.3  |                0       461.4 MB  nvidia/label/cuda-11.7.1
    pytorch-1.13.0             |py3.10_cuda11.7_cudnn8.5.0_0        1.15 GB  pytorch
    pytorch-cuda-11.7          |       h67b0de4_0           7 KB  pytorch
    pytorch-mutex-1.0          |             cuda           3 KB  pytorch
    typing_extensions-4.3.0    |  py310h06a4308_0          42 KB
    xformers-0.0.15.dev337+git.fd21b40|py310_cu11.7_pyt1.13       105.3 MB  xformers/label/dev
    ------------------------------------------------------------
                                           Total:        4.01 GB

The following NEW packages will be INSTALLED:

  blas               pkgs/main/linux-64::blas-1.0-mkl None
  cuda               nvidia/label/cuda-11.7.1/linux-64::cuda-11.7.1-0 None
  cuda-cccl          nvidia/label/cuda-11.7.1/linux-64::cuda-cccl-11.7.91-0 None
  cuda-command-line~ nvidia/label/cuda-11.7.1/linux-64::cuda-command-line-tools-11.7.1-0 None
  cuda-compiler      nvidia/label/cuda-11.7.1/linux-64::cuda-compiler-11.7.1-0 None
  cuda-cudart        nvidia/label/cuda-11.7.1/linux-64::cuda-cudart-11.7.99-0 None
  cuda-cudart-dev    nvidia/label/cuda-11.7.1/linux-64::cuda-cudart-dev-11.7.99-0 None
  cuda-cuobjdump     nvidia/label/cuda-11.7.1/linux-64::cuda-cuobjdump-11.7.91-0 None
  cuda-cupti         nvidia/label/cuda-11.7.1/linux-64::cuda-cupti-11.7.101-0 None
  cuda-cuxxfilt      nvidia/label/cuda-11.7.1/linux-64::cuda-cuxxfilt-11.7.91-0 None
  cuda-demo-suite    nvidia/label/cuda-11.7.1/linux-64::cuda-demo-suite-11.7.91-0 None
  cuda-documentation nvidia/label/cuda-11.7.1/linux-64::cuda-documentation-11.7.91-0 None
  cuda-driver-dev    nvidia/label/cuda-11.7.1/linux-64::cuda-driver-dev-11.7.99-0 None
  cuda-gdb           nvidia/label/cuda-11.7.1/linux-64::cuda-gdb-11.7.91-0 None
  cuda-libraries     nvidia/label/cuda-11.7.1/linux-64::cuda-libraries-11.7.1-0 None
  cuda-libraries-dev nvidia/label/cuda-11.7.1/linux-64::cuda-libraries-dev-11.7.1-0 None
  cuda-memcheck      nvidia/label/cuda-11.7.1/linux-64::cuda-memcheck-11.7.91-0 None
  cuda-nsight        nvidia/label/cuda-11.7.1/linux-64::cuda-nsight-11.7.91-0 None
  cuda-nsight-compu~ nvidia/label/cuda-11.7.1/linux-64::cuda-nsight-compute-11.7.1-0 None
  cuda-nvcc          nvidia/label/cuda-11.7.1/linux-64::cuda-nvcc-11.7.99-0 None
  cuda-nvdisasm      nvidia/label/cuda-11.7.1/linux-64::cuda-nvdisasm-11.7.91-0 None
  cuda-nvml-dev      nvidia/label/cuda-11.7.1/linux-64::cuda-nvml-dev-11.7.91-0 None
  cuda-nvprof        nvidia/label/cuda-11.7.1/linux-64::cuda-nvprof-11.7.101-0 None
  cuda-nvprune       nvidia/label/cuda-11.7.1/linux-64::cuda-nvprune-11.7.91-0 None
  cuda-nvrtc         nvidia/label/cuda-11.7.1/linux-64::cuda-nvrtc-11.7.99-0 None
  cuda-nvrtc-dev     nvidia/label/cuda-11.7.1/linux-64::cuda-nvrtc-dev-11.7.99-0 None
  cuda-nvtx          nvidia/label/cuda-11.7.1/linux-64::cuda-nvtx-11.7.91-0 None
  cuda-nvvp          nvidia/label/cuda-11.7.1/linux-64::cuda-nvvp-11.7.101-0 None
  cuda-runtime       nvidia/label/cuda-11.7.1/linux-64::cuda-runtime-11.7.1-0 None
  cuda-sanitizer-api nvidia/label/cuda-11.7.1/linux-64::cuda-sanitizer-api-11.7.91-0 None
  cuda-toolkit       nvidia/label/cuda-11.7.1/linux-64::cuda-toolkit-11.7.1-0 None
  cuda-tools         nvidia/label/cuda-11.7.1/linux-64::cuda-tools-11.7.1-0 None
  cuda-visual-tools  nvidia/label/cuda-11.7.1/linux-64::cuda-visual-tools-11.7.1-0 None
  gds-tools          nvidia/label/cuda-11.7.1/linux-64::gds-tools-1.3.1.18-0 None
  intel-openmp       pkgs/main/linux-64::intel-openmp-2022.1.0-h9e868ea_3769 None
  libcublas          nvidia/label/cuda-11.7.1/linux-64::libcublas-11.10.3.66-0 None
  libcublas-dev      nvidia/label/cuda-11.7.1/linux-64::libcublas-dev-11.10.3.66-0 None
  libcufft           nvidia/label/cuda-11.7.1/linux-64::libcufft-10.7.2.91-0 None
  libcufft-dev       nvidia/label/cuda-11.7.1/linux-64::libcufft-dev-10.7.2.91-0 None
  libcufile          nvidia/label/cuda-11.7.1/linux-64::libcufile-1.3.1.18-0 None
  libcufile-dev      nvidia/label/cuda-11.7.1/linux-64::libcufile-dev-1.3.1.18-0 None
  libcurand          nvidia/label/cuda-11.7.1/linux-64::libcurand-10.2.10.91-0 None
  libcurand-dev      nvidia/label/cuda-11.7.1/linux-64::libcurand-dev-10.2.10.91-0 None
  libcusolver        nvidia/label/cuda-11.7.1/linux-64::libcusolver-11.4.0.1-0 None
  libcusolver-dev    nvidia/label/cuda-11.7.1/linux-64::libcusolver-dev-11.4.0.1-0 None
  libcusparse        nvidia/label/cuda-11.7.1/linux-64::libcusparse-11.7.4.91-0 None
  libcusparse-dev    nvidia/label/cuda-11.7.1/linux-64::libcusparse-dev-11.7.4.91-0 None
  libnpp             nvidia/label/cuda-11.7.1/linux-64::libnpp-11.7.4.75-0 None
  libnpp-dev         nvidia/label/cuda-11.7.1/linux-64::libnpp-dev-11.7.4.75-0 None
  libnvjpeg          nvidia/label/cuda-11.7.1/linux-64::libnvjpeg-11.8.0.2-0 None
  libnvjpeg-dev      nvidia/label/cuda-11.7.1/linux-64::libnvjpeg-dev-11.8.0.2-0 None
  mkl                pkgs/main/linux-64::mkl-2022.1.0-hc2b9512_224 None
  nsight-compute     nvidia/label/cuda-11.7.1/linux-64::nsight-compute-2022.2.1.3-0 None
  pytorch            pytorch/linux-64::pytorch-1.13.0-py3.10_cuda11.7_cudnn8.5.0_0 None
  pytorch-cuda       pytorch/noarch::pytorch-cuda-11.7-h67b0de4_0 None
  pytorch-mutex      pytorch/noarch::pytorch-mutex-1.0-cuda None
  typing_extensions  pkgs/main/linux-64::typing_extensions-4.3.0-py310h06a4308_0 None
  xformers           xformers/label/dev/linux-64::xformers-0.0.15.dev337+git.fd21b40-py310_cu11.7_pyt1.13 None


Proceed ([y]/n)? 

Note that I'm explicitly using -c nvidia/label/cuda-11.7.1 rather than just -c nvidia with a version constraint of cuda-toolkit=11.7.1. With the latter, conda appears to pull in CUDA 11.8.x versions of the dependencies (rather than 11.7.1), which will no doubt cause things to break:

⇒ conda install -n "$CONDA_ENV_NAME" -c nvidia cuda-toolkit=11.7.1 (will seemingly incorrectly try and install 11.8.x things)

⇒ conda install -n "$CONDA_ENV_NAME" -c nvidia cuda-toolkit=11.7.1

Collecting package metadata (current_repodata.json): done
Solving environment: done

## Package Plan ##

  environment location: /opt/conda/envs/xformers

  added / updated specs:
    - cuda-toolkit=11.7.1


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    cuda-cccl-11.8.89          |                0         1.2 MB  nvidia
    cuda-command-line-tools-11.8.0|                0           1 KB  nvidia
    cuda-compiler-11.8.0       |                0           1 KB  nvidia
    cuda-cudart-11.8.89        |                0         197 KB  nvidia
    cuda-cudart-dev-11.8.89    |                0         1.1 MB  nvidia
    cuda-cuobjdump-11.8.86     |                0         229 KB  nvidia
    cuda-cupti-11.8.87         |                0        25.3 MB  nvidia
    cuda-cuxxfilt-11.8.86      |                0         291 KB  nvidia
    cuda-documentation-11.8.86 |                0          89 KB  nvidia
    cuda-driver-dev-11.8.89    |                0          16 KB  nvidia
    cuda-gdb-11.8.86           |                0         4.8 MB  nvidia
    cuda-libraries-11.8.0      |                0           1 KB  nvidia
    cuda-libraries-dev-11.8.0  |                0           2 KB  nvidia
    cuda-memcheck-11.8.86      |                0         168 KB  nvidia
    cuda-nsight-11.8.86        |                0       113.6 MB  nvidia
    cuda-nsight-compute-11.8.0 |                0           1 KB  nvidia
    cuda-nvcc-11.8.89          |                0        50.8 MB  nvidia
    cuda-nvdisasm-11.8.86      |                0        48.7 MB  nvidia
    cuda-nvml-dev-11.8.86      |                0          83 KB  nvidia
    cuda-nvprof-11.8.87        |                0         4.4 MB  nvidia
    cuda-nvprune-11.8.86       |                0          65 KB  nvidia
    cuda-nvrtc-11.8.89         |                0        19.1 MB  nvidia
    cuda-nvrtc-dev-11.8.89     |                0        17.0 MB  nvidia
    cuda-nvtx-11.8.86          |                0          57 KB  nvidia
    cuda-nvvp-11.8.87          |                0       114.4 MB  nvidia
    cuda-profiler-api-11.8.86  |                0          18 KB  nvidia
    cuda-sanitizer-api-11.8.86 |                0        16.6 MB  nvidia
    cuda-toolkit-11.7.1        |                0           1 KB  nvidia
    cuda-tools-11.8.0          |                0           1 KB  nvidia
    cuda-visual-tools-11.8.0   |                0           1 KB  nvidia
    gds-tools-1.4.0.31         |                0           2 KB  nvidia
    libcublas-11.11.3.6        |                0       364.0 MB  nvidia
    libcublas-dev-11.11.3.6    |                0       394.1 MB  nvidia
    libcufft-10.9.0.58         |                0       142.8 MB  nvidia
    libcufft-dev-10.9.0.58     |                0       275.8 MB  nvidia
    libcufile-1.4.0.31         |                0         548 KB  nvidia
    libcufile-dev-1.4.0.31     |                0         1.6 MB  nvidia
    libcurand-10.3.0.86        |                0        53.2 MB  nvidia
    libcurand-dev-10.3.0.86    |                0        53.7 MB  nvidia
    libcusolver-11.4.1.48      |                0        96.5 MB  nvidia
    libcusolver-dev-11.4.1.48  |                0        66.3 MB  nvidia
    libcusparse-11.7.5.86      |                0       176.3 MB  nvidia
    libcusparse-dev-11.7.5.86  |                0       359.7 MB  nvidia
    libnpp-11.8.0.86           |                0       147.8 MB  nvidia
    libnpp-dev-11.8.0.86       |                0       144.5 MB  nvidia
    libnvjpeg-11.9.0.86        |                0         2.4 MB  nvidia
    libnvjpeg-dev-11.9.0.86    |                0         2.1 MB  nvidia
    nsight-compute-2022.3.0.22 |                0       610.0 MB  nvidia
    ------------------------------------------------------------
                                           Total:        3.23 GB

The following NEW packages will be INSTALLED:

  cuda-cccl          nvidia/linux-64::cuda-cccl-11.8.89-0 None
  cuda-command-line~ nvidia/linux-64::cuda-command-line-tools-11.8.0-0 None
  cuda-compiler      nvidia/linux-64::cuda-compiler-11.8.0-0 None
  cuda-cudart        nvidia/linux-64::cuda-cudart-11.8.89-0 None
  cuda-cudart-dev    nvidia/linux-64::cuda-cudart-dev-11.8.89-0 None
  cuda-cuobjdump     nvidia/linux-64::cuda-cuobjdump-11.8.86-0 None
  cuda-cupti         nvidia/linux-64::cuda-cupti-11.8.87-0 None
  cuda-cuxxfilt      nvidia/linux-64::cuda-cuxxfilt-11.8.86-0 None
  cuda-documentation nvidia/linux-64::cuda-documentation-11.8.86-0 None
  cuda-driver-dev    nvidia/linux-64::cuda-driver-dev-11.8.89-0 None
  cuda-gdb           nvidia/linux-64::cuda-gdb-11.8.86-0 None
  cuda-libraries     nvidia/linux-64::cuda-libraries-11.8.0-0 None
  cuda-libraries-dev nvidia/linux-64::cuda-libraries-dev-11.8.0-0 None
  cuda-memcheck      nvidia/linux-64::cuda-memcheck-11.8.86-0 None
  cuda-nsight        nvidia/linux-64::cuda-nsight-11.8.86-0 None
  cuda-nsight-compu~ nvidia/linux-64::cuda-nsight-compute-11.8.0-0 None
  cuda-nvcc          nvidia/linux-64::cuda-nvcc-11.8.89-0 None
  cuda-nvdisasm      nvidia/linux-64::cuda-nvdisasm-11.8.86-0 None
  cuda-nvml-dev      nvidia/linux-64::cuda-nvml-dev-11.8.86-0 None
  cuda-nvprof        nvidia/linux-64::cuda-nvprof-11.8.87-0 None
  cuda-nvprune       nvidia/linux-64::cuda-nvprune-11.8.86-0 None
  cuda-nvrtc         nvidia/linux-64::cuda-nvrtc-11.8.89-0 None
  cuda-nvrtc-dev     nvidia/linux-64::cuda-nvrtc-dev-11.8.89-0 None
  cuda-nvtx          nvidia/linux-64::cuda-nvtx-11.8.86-0 None
  cuda-nvvp          nvidia/linux-64::cuda-nvvp-11.8.87-0 None
  cuda-profiler-api  nvidia/linux-64::cuda-profiler-api-11.8.86-0 None
  cuda-sanitizer-api nvidia/linux-64::cuda-sanitizer-api-11.8.86-0 None
  cuda-toolkit       nvidia/linux-64::cuda-toolkit-11.7.1-0 None
  cuda-tools         nvidia/linux-64::cuda-tools-11.8.0-0 None
  cuda-visual-tools  nvidia/linux-64::cuda-visual-tools-11.8.0-0 None
  gds-tools          nvidia/linux-64::gds-tools-1.4.0.31-0 None
  libcublas          nvidia/linux-64::libcublas-11.11.3.6-0 None
  libcublas-dev      nvidia/linux-64::libcublas-dev-11.11.3.6-0 None
  libcufft           nvidia/linux-64::libcufft-10.9.0.58-0 None
  libcufft-dev       nvidia/linux-64::libcufft-dev-10.9.0.58-0 None
  libcufile          nvidia/linux-64::libcufile-1.4.0.31-0 None
  libcufile-dev      nvidia/linux-64::libcufile-dev-1.4.0.31-0 None
  libcurand          nvidia/linux-64::libcurand-10.3.0.86-0 None
  libcurand-dev      nvidia/linux-64::libcurand-dev-10.3.0.86-0 None
  libcusolver        nvidia/linux-64::libcusolver-11.4.1.48-0 None
  libcusolver-dev    nvidia/linux-64::libcusolver-dev-11.4.1.48-0 None
  libcusparse        nvidia/linux-64::libcusparse-11.7.5.86-0 None
  libcusparse-dev    nvidia/linux-64::libcusparse-dev-11.7.5.86-0 None
  libnpp             nvidia/linux-64::libnpp-11.8.0.86-0 None
  libnpp-dev         nvidia/linux-64::libnpp-dev-11.8.0.86-0 None
  libnvjpeg          nvidia/linux-64::libnvjpeg-11.9.0.86-0 None
  libnvjpeg-dev      nvidia/linux-64::libnvjpeg-dev-11.9.0.86-0 None
  nsight-compute     nvidia/linux-64::nsight-compute-2022.3.0.22-0 None


Proceed ([y]/n)? 
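One way to catch this kind of mix-up before answering the prompt is to scan the package plan for cuda-* packages that fall outside the requested series. The following is a minimal Python sketch, not part of conda: the function name `mismatched_cuda_packages` and its `wanted` parameter are hypothetical, and it assumes the table layout shown in the output above (package spec to the left of the first `|`). It only checks packages named `cuda` or `cuda-*`, because libraries such as libcublas follow their own version numbering.

```python
def mismatched_cuda_packages(plan_text, wanted="11.7"):
    """Return (name, version) pairs for cuda packages outside the wanted series.

    Hypothetical helper: assumes rows shaped like
    'cuda-nvcc-11.8.89 | 0  50.8 MB  nvidia' as in the plan output above.
    """
    mismatches = []
    for line in plan_text.splitlines():
        if "|" not in line:
            continue  # not a row of the "will be downloaded" table
        spec = line.split("|", 1)[0].strip()
        # Split 'cuda-nvcc-11.8.89' into name 'cuda-nvcc' and version '11.8.89'.
        name, sep, version = spec.rpartition("-")
        if not sep or not (name == "cuda" or name.startswith("cuda-")):
            continue  # header, separator row, or a non-CUDA package
        if not version.startswith(wanted + "."):
            mismatches.append((name, version))
    return mismatches


if __name__ == "__main__":
    sample = """
    package                    |            build
    ---------------------------|-----------------
    cuda-cccl-11.8.89          |                0         1.2 MB  nvidia
    cuda-toolkit-11.7.1        |                0           1 KB  nvidia
    libcublas-11.11.3.6        |                0       364.0 MB  nvidia
    """
    for name, version in mismatched_cuda_packages(sample, wanted="11.7"):
        print(f"{name} resolved to {version}, expected 11.7.x")
```

An empty result does not prove the environment is consistent; it only means no cuda-* row in the download table disagrees with the requested series.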

@0xdevalias

So after that giant deep dive, it seems that the TL;DR for resolving the following error:

UnsatisfiableError: The following specifications were found to be incompatible with each other:

Output in format: Requested package -> Available versionsThe following specifications were found to be incompatible with your system:

  - feature:/linux-64::__glibc==2.27=0
  - feature:|@/linux-64::__glibc==2.27=0

Is to ensure that all of the required conda channels (-c) are enabled when running conda install. I suspect any of these should work, depending on your needs:

PyTorch 1.13, Cuda 11.7.1, xformers:

conda install -n "$CONDA_ENV_NAME" -c xformers/label/dev -c pytorch -c nvidia/label/cuda-11.7.1 xformers=*=py310_cu11.7_pyt1.13
conda install -n "$CONDA_ENV_NAME" -c xformers/label/dev -c pytorch -c nvidia/label/cuda-11.7.1 xformers
conda install -n "$CONDA_ENV_NAME" -c xformers/label/dev -c pytorch -c nvidia/label/cuda-11.7.1 pytorch=1.13 pytorch-cuda=11.7 xformers

PyTorch 1.12, Cuda 11.7.1, xformers:

conda install -n "$CONDA_ENV_NAME" -c xformers/label/dev -c pytorch -c nvidia/label/cuda-11.7.1 pytorch=1.12 pytorch-cuda=11.7 xformers

PyTorch 1.13, Cuda 11.6.2, xformers:

conda install -n "$CONDA_ENV_NAME" -c xformers/label/dev -c pytorch -c nvidia/label/cuda-11.6.2 pytorch=1.13 pytorch-cuda=11.6 xformers

PyTorch 1.12, Cuda 11.6.2, xformers:

conda install -n "$CONDA_ENV_NAME" -c xformers/label/dev -c pytorch -c nvidia/label/cuda-11.6.2 pytorch=1.12 pytorch-cuda=11.6 xformers

Etc
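The commands above all follow one pattern, which could be generated rather than typed by hand. This is only a sketch under stated assumptions: the helper name `xformers_install_command` and the `CUDA_LABELS` table are hypothetical, and the table contains just the two CUDA series and channel labels that appear in the commands above. Other combinations are deliberately rejected rather than guessed.

```python
import shlex

# CUDA series -> nvidia channel label, copied from the commands listed above.
# Hypothetical table: extend it only with labels you have verified yourself.
CUDA_LABELS = {"11.7": "cuda-11.7.1", "11.6": "cuda-11.6.2"}


def xformers_install_command(env_name, pytorch, cuda):
    """Build the single-step conda install command as one shell-safe string."""
    label = CUDA_LABELS[cuda]  # raises KeyError for series not covered above
    args = [
        "conda", "install", "-n", env_name,
        "-c", "xformers/label/dev",
        "-c", "pytorch",
        "-c", f"nvidia/label/{label}",
        f"pytorch={pytorch}", f"pytorch-cuda={cuda}", "xformers",
    ]
    return shlex.join(args)


if __name__ == "__main__":
    print(xformers_install_command("xformers", "1.13", "11.7"))
```

The function only builds the string; running it is left to the caller, so the resolved package plan can still be reviewed at the conda prompt.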

@danthe3rd
Contributor

Thanks for reporting those!
The correct command to install pytorch with conda should be on the pytorch website and look something like:

conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia

After that, you should be able to install xformers with

conda install xformers -c xformers/label/dev
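Since the second step depends on the first having succeeded, a quick check that PyTorch is importable in the target environment may save some confusion. The snippet below is a minimal sketch, not an official xformers or PyTorch tool; `installed_version` is a hypothetical helper name, and it relies only on the standard library's `importlib`.

```python
import importlib


def installed_version(module_name):
    """Return the module's __version__, 'unknown' if it has none, or None if the import fails."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None  # missing, or present but broken
    return getattr(module, "__version__", "unknown")


if __name__ == "__main__":
    torch_version = installed_version("torch")
    if torch_version is None:
        print("PyTorch is not importable here; install it first, then xformers.")
    else:
        print(f"PyTorch {torch_version} found; xformers can be installed next.")
```

Run it with the Python interpreter of the conda environment you intend to install into, otherwise it reports on a different environment.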

@0xdevalias

0xdevalias commented Nov 15, 2022

@danthe3rd Yup, that makes sense after the fact, though conda gives super misleading error messages if PyTorch isn't already installed.

It would be good to explicitly mention in the installation docs that PyTorch needs to be set up properly first, as the current docs don't make that clear:

(RECOMMENDED) Using binaries: We provide binaries for Linux and recent PyTorch versions. Install xFormers with conda:

conda install xformers -c xformers/label/dev

It would probably also help to add a note to the installation troubleshooting section explaining what those misleading conda errors actually mean (e.g. "You didn't install PyTorch properly"), to save people from getting stuck or wasting time down deep rabbit holes, as a few of us on this issue have.

@danthe3rd
Contributor

Great feedback - will update the README in #529

@0xdevalias

0xdevalias commented Nov 15, 2022

After that, you should be able to install xformers with

conda install xformers -c xformers/label/dev

I'll need to repeat the steps I figured out today in a clean environment to be sure (I won't have a chance for about a week at the earliest). However, even after using the above command, which should have installed the binary version in a matching environment, the python -m xformers.info test script seemed to indicate that things weren't installed correctly. It wasn't until I built from source that the script indicated it was working.

I did the PyTorch + CUDA + xformers installs all in a single conda install, rather than one after the other, so I'm not sure whether that makes any difference.

Just wanted to note that here while I remembered it.

@lucasjinreal

FileNotFoundError: [Errno 2] No such file or directory: '/opt/conda/lib/python3.10/site-packages/xformers/cpp_lib.json'

I got this after installing from source. How can this happen?
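To narrow down an error like this, it can help to confirm which xformers directory Python actually resolves and whether the file named in the error exists there. The following is a minimal sketch using only the standard library; `package_file_status` is a hypothetical helper, and it makes no claim about what cpp_lib.json contains or how the build produces it.

```python
import importlib.util
from pathlib import Path


def package_file_status(package, filename):
    """Return (path, exists) for a file inside an importable package, or None if the package is not found."""
    spec = importlib.util.find_spec(package)
    if spec is None or not spec.submodule_search_locations:
        return None  # not installed, or a plain module rather than a package
    package_dir = Path(list(spec.submodule_search_locations)[0])
    path = package_dir / filename
    return path, path.is_file()


if __name__ == "__main__":
    status = package_file_status("xformers", "cpp_lib.json")
    if status is None:
        print("xformers is not importable in this environment")
    else:
        path, exists = status
        print(f"{path}: {'present' if exists else 'missing'}")
```

If this is run from inside a source checkout, Python may resolve the local xformers directory instead of the one under site-packages, so the printed path is worth comparing with the path in the error message.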

bertmaher pushed a commit to bertmaher/xformers that referenced this issue Dec 20, 2024
* Add new baseline attention

It's so far 192x slower than the baseline

* Optimize by 5x

ldg on query brought 2% speedup, use of shared memory on key and value brought a 80% speedup and parallelizing over k brought another 2.8x speedup

* lint

* Use vector loads and stores

3.5x speedup

We are still 11x slower than baseline

* Reduce global reads/writes by tiling over key loop

Brings 70% speedup

* clang-format

* First naive implementation of sparsity

It is not 100% correct and might give different results sometimes

* Get inspiration from sputnik kernels

Seems to be 2x faster than sputnik, but needs further testing. Also doesn't have the bound checks

* Cleanup commented code and add back correctness checks

This brings slowdown but I'll fix this later

* Add tests

* Fix bug for K>32

* Load indices in parallel and store in shared memory

Also simplified the boundary condition checks. This makes the code almost as fast as without the bound checks, and without the indices reads

* Cleanups

* clang-format

* Cleanups

* Handle fully-masked rows

* Add scaling coefficient
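
As a reference for what the two items above have to get right, here is a minimal pure-Python sketch of a single softmax row with a scaling coefficient and a mask. It is not the kernel code. In particular, returning all zeros for a fully-masked row is an assumption made for illustration, and the function name and arguments are hypothetical.

```python
import math


def masked_softmax_row(scores, mask, scale=1.0):
    """Softmax over the unmasked entries of one attention row.

    `mask[i]` is True when entry i is kept. A fully-masked row returns
    zeros (assumed convention) instead of dividing by zero.
    """
    kept = [s * scale for s, keep in zip(scores, mask) if keep]
    if not kept:
        return [0.0] * len(scores)
    row_max = max(kept)  # subtract the max for numerical stability
    exps = [
        math.exp(s * scale - row_max) if keep else 0.0
        for s, keep in zip(scores, mask)
    ]
    total = sum(exps)
    return [e / total for e in exps]
```

Without the explicit empty-row branch, a fully-masked row would end up computing 0 / 0, which is why it needs separate handling.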

* Add benchmark scripts

Need to fix utils

* Add checks and improve tests

* Minor cleanups

* Rename some variables for better clarity

* Make code 5% faster by adding stride inside indices

This will only help if dimension K is the same for keys and values

* Initialise m_i with m_prime for 3% speedup

Also do some variable renamings
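
For readers unfamiliar with the names, the sketch below shows the general streaming (online) softmax technique that tiling over the key loop relies on, for a single query row with scalar values. It is an illustration under assumptions, not the CUDA kernel: here m_prime is taken to be the running maximum over previous tiles and m_i the maximum for the current tile, so starting m_i from m_prime avoids a separate max step. The function name and the scalar `values` are hypothetical simplifications.

```python
import math


def streaming_attention_row(scores, values, tile=2):
    """Softmax-weighted sum of `values`, processed tile by tile.

    Requires at least one finite score.
    """
    m_prime = float("-inf")  # running max over the tiles seen so far
    s_prime = 0.0            # running sum of exp(score - m_prime)
    acc = 0.0                # running weighted sum of values
    for start in range(0, len(scores), tile):
        tile_scores = scores[start:start + tile]
        tile_values = values[start:start + tile]
        m_i = m_prime  # start from the running max
        for s in tile_scores:
            m_i = max(m_i, s)
        # Rescale previous partial results to the new maximum.
        correction = math.exp(m_prime - m_i)
        s_prime *= correction
        acc *= correction
        for s, v in zip(tile_scores, tile_values):
            p = math.exp(s - m_i)
            s_prime += p
            acc += p * v
        m_prime = m_i
    return acc / s_prime
```

The result matches a softmax computed over the whole row at once, but only one tile of scores has to be held at a time.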

* Minor cleanups

* Add implementation for backward

grad_q is fast, grad_k and grad_v are slow so far

* Add benchmark for backward

* [WIP] trying csr2csc

* Fix csr2csc

It's reasonably fast now and supports broadcast

* Remove legacy code

* Return permutation in csr_transpose

* Return the inverse permutation

More useful for us
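
To make the csr_transpose items concrete, here is a minimal sequential sketch in pure Python of transposing the index structure of a CSR matrix while also returning the permutation and its inverse. It only illustrates the idea. The actual CUDA implementation, its argument names and its exact return values may differ, and everything named here is an assumption.

```python
def csr_transpose(row_offsets, col_indices, n_cols):
    """Transpose CSR indices via a counting sort over columns.

    Returns (col_offsets, row_indices, perm, inverse_perm), where
    transposed_values[i] == values[perm[i]] and inverse_perm[perm[i]] == i.
    """
    n_rows = len(row_offsets) - 1
    nnz = len(col_indices)

    # Count the entries in each column, then prefix-sum into offsets.
    col_offsets = [0] * (n_cols + 1)
    for c in col_indices:
        col_offsets[c + 1] += 1
    for c in range(n_cols):
        col_offsets[c + 1] += col_offsets[c]

    next_slot = col_offsets[:-1]  # slicing copies: next free slot per column
    row_indices = [0] * nnz
    perm = [0] * nnz
    for r in range(n_rows):
        for k in range(row_offsets[r], row_offsets[r + 1]):
            dst = next_slot[col_indices[k]]
            row_indices[dst] = r
            perm[dst] = k
            next_slot[col_indices[k]] += 1

    inverse_perm = [0] * nnz
    for dst, src in enumerate(perm):
        inverse_perm[src] = dst
    return col_offsets, row_indices, perm, inverse_perm
```

With perm, a consumer gathers from the original value order. With inverse_perm, it can scatter into the transposed order instead.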

* Optimize backward by up to 15x for larger sparsities

Uses temporary memory, and so far doesn't work for expanded index
tensors. Also needs to use vector reads / writes

Previous implementation

====== {'shape': (256, 1024, 32), 'num_threads': 1, 'sparsity': 0.0} ======
optimized: memory used: 261.00439453125 MB
vanilla: memory used: 4384.0 MB
sputnik: memory used: 3441.0166015625 MB
====== {'shape': (256, 1024, 32), 'num_threads': 1, 'sparsity': 0.5} ======
optimized: memory used: 259.00439453125 MB
vanilla: memory used: 6688.0 MB
sputnik: memory used: 1896.5009765625 MB
====== {'shape': (256, 1024, 32), 'num_threads': 1, 'sparsity': 0.7} ======
optimized: memory used: 258.2041015625 MB
vanilla: memory used: 6688.0 MB
sputnik: memory used: 1278.46728515625 MB
====== {'shape': (256, 1024, 32), 'num_threads': 1, 'sparsity': 0.8} ======
optimized: memory used: 257.80517578125 MB
vanilla: memory used: 6688.0 MB
sputnik: memory used: 970.384765625 MB
====== {'shape': (256, 1024, 32), 'num_threads': 1, 'sparsity': 0.9} ======
optimized: memory used: 257.404296875 MB
vanilla: memory used: 6688.0 MB
sputnik: memory used: 660.82958984375 MB
====== {'shape': (256, 1024, 32), 'num_threads': 1, 'sparsity': 0.95} ======
optimized: memory used: 257.205078125 MB
vanilla: memory used: 6688.0 MB
sputnik: memory used: 506.86572265625 MB
====== {'shape': (256, 1024, 32), 'num_threads': 1, 'sparsity': 0.99} ======
optimized: memory used: 257.0439453125 MB
vanilla: memory used: 6688.0 MB
sputnik: memory used: 382.53662109375 MB
defaultdict(<class 'dict'>,
            {'optimized': {'B=256, M=1024, K=32, sparsity=0.0': 261.00439453125,
                           'B=256, M=1024, K=32, sparsity=0.5': 259.00439453125,
                           'B=256, M=1024, K=32, sparsity=0.7': 258.2041015625,
                           'B=256, M=1024, K=32, sparsity=0.8': 257.80517578125,
                           'B=256, M=1024, K=32, sparsity=0.9': 257.404296875,
                           'B=256, M=1024, K=32, sparsity=0.95': 257.205078125,
                           'B=256, M=1024, K=32, sparsity=0.99': 257.0439453125},
             'sputnik': {'B=256, M=1024, K=32, sparsity=0.0': 3441.0166015625,
                         'B=256, M=1024, K=32, sparsity=0.5': 1896.5009765625,
                         'B=256, M=1024, K=32, sparsity=0.7': 1278.46728515625,
                         'B=256, M=1024, K=32, sparsity=0.8': 970.384765625,
                         'B=256, M=1024, K=32, sparsity=0.9': 660.82958984375,
                         'B=256, M=1024, K=32, sparsity=0.95': 506.86572265625,
                         'B=256, M=1024, K=32, sparsity=0.99': 382.53662109375},
             'vanilla': {'B=256, M=1024, K=32, sparsity=0.0': 4384.0,
                         'B=256, M=1024, K=32, sparsity=0.5': 6688.0,
                         'B=256, M=1024, K=32, sparsity=0.7': 6688.0,
                         'B=256, M=1024, K=32, sparsity=0.8': 6688.0,
                         'B=256, M=1024, K=32, sparsity=0.9': 6688.0,
                         'B=256, M=1024, K=32, sparsity=0.95': 6688.0,
                         'B=256, M=1024, K=32, sparsity=0.99': 6688.0}})
[---------------------------- attention backward ----------------------------]
                                          |  optimized  |  vanilla  |  sputnik
1 threads: -------------------------------------------------------------------
      B=256, M=1024, K=32, sparsity=0.0   |    475.9    |    33.3   |    91.6
      B=256, M=1024, K=32, sparsity=0.5   |    255.4    |    41.0   |    50.3
      B=256, M=1024, K=32, sparsity=0.7   |    183.5    |    41.5   |    31.6
      B=256, M=1024, K=32, sparsity=0.8   |    144.6    |    41.4   |    23.0
      B=256, M=1024, K=32, sparsity=0.9   |     97.9    |    41.6   |    12.7
      B=256, M=1024, K=32, sparsity=0.95  |     77.3    |    41.7   |     8.1
      B=256, M=1024, K=32, sparsity=0.99  |     86.5    |    42.1   |     2.3

Times are in milliseconds (ms).

New implementation

====== {'shape': (256, 1024, 32), 'num_threads': 1, 'sparsity': 0.0} ======
optimized: memory used: 5379.001953125 MB
vanilla: memory used: 4384.0 MB
sputnik: memory used: 3441.0166015625 MB
====== {'shape': (256, 1024, 32), 'num_threads': 1, 'sparsity': 0.5} ======
optimized: memory used: 2818.845703125 MB
vanilla: memory used: 6688.0 MB
sputnik: memory used: 1896.3583984375 MB
====== {'shape': (256, 1024, 32), 'num_threads': 1, 'sparsity': 0.7} ======
optimized: memory used: 1798.7041015625 MB
vanilla: memory used: 6688.0 MB
sputnik: memory used: 1280.94189453125 MB
====== {'shape': (256, 1024, 32), 'num_threads': 1, 'sparsity': 0.8} ======
optimized: memory used: 1282.830078125 MB
vanilla: memory used: 6688.0 MB
sputnik: memory used: 969.71435546875 MB
====== {'shape': (256, 1024, 32), 'num_threads': 1, 'sparsity': 0.9} ======
optimized: memory used: 768.1767578125 MB
vanilla: memory used: 6688.0 MB
sputnik: memory used: 659.2041015625 MB
====== {'shape': (256, 1024, 32), 'num_threads': 1, 'sparsity': 0.95} ======
optimized: memory used: 513.1962890625 MB
vanilla: memory used: 6688.0 MB
sputnik: memory used: 505.3720703125 MB
====== {'shape': (256, 1024, 32), 'num_threads': 1, 'sparsity': 0.99} ======
optimized: memory used: 310.06640625 MB
vanilla: memory used: 6688.0 MB
sputnik: memory used: 382.81982421875 MB
[------------------------------------ attention backward ----------------------]
                                          |  optimized  |  vanilla  |  sputnik
1 threads: --------------------------------------------------------------------
      B=256, M=1024, K=32, sparsity=0.0   |    315.8    |    33.3   |    91.8
      B=256, M=1024, K=32, sparsity=0.5   |    265.6    |    41.4   |    50.6
      B=256, M=1024, K=32, sparsity=0.7   |    150.7    |    41.9   |    31.7
      B=256, M=1024, K=32, sparsity=0.8   |     98.2    |    41.6   |    22.6
      B=256, M=1024, K=32, sparsity=0.9   |     43.9    |    42.0   |    12.6
      B=256, M=1024, K=32, sparsity=0.95  |     19.3    |    41.8   |     8.0
      B=256, M=1024, K=32, sparsity=0.99  |      5.2    |    42.2   |     2.3

Times are in milliseconds (ms).

* Fix backward for expanded strides

Now uses much less memory, and is also faster

====== {'shape': (256, 1024, 32), 'num_threads': 1, 'sparsity': 0.0} ======
optimized: memory used: 2317.0087890625 MB
vanilla: memory used: 4384.0 MB
sputnik: memory used: 3441.0166015625 MB
====== {'shape': (256, 1024, 32), 'num_threads': 1, 'sparsity': 0.5} ======
optimized: memory used: 1287.4765625 MB
vanilla: memory used: 6688.0 MB
sputnik: memory used: 1897.21240234375 MB
====== {'shape': (256, 1024, 32), 'num_threads': 1, 'sparsity': 0.7} ======
optimized: memory used: 875.3642578125 MB
vanilla: memory used: 6688.0 MB
sputnik: memory used: 1279.24462890625 MB
====== {'shape': (256, 1024, 32), 'num_threads': 1, 'sparsity': 0.8} ======
optimized: memory used: 668.93701171875 MB
vanilla: memory used: 6688.0 MB
sputnik: memory used: 969.70263671875 MB
====== {'shape': (256, 1024, 32), 'num_threads': 1, 'sparsity': 0.9} ======
optimized: memory used: 462.31494140625 MB
vanilla: memory used: 6688.0 MB
sputnik: memory used: 659.8759765625 MB
====== {'shape': (256, 1024, 32), 'num_threads': 1, 'sparsity': 0.95} ======
optimized: memory used: 359.31787109375 MB
vanilla: memory used: 6688.0 MB
sputnik: memory used: 505.427734375 MB
====== {'shape': (256, 1024, 32), 'num_threads': 1, 'sparsity': 0.99} ======
optimized: memory used: 277.9248046875 MB
vanilla: memory used: 6688.0 MB
sputnik: memory used: 384.1904296875 MB
defaultdict(<class 'dict'>,
            {'optimized': {'B=256, M=1024, K=32, sparsity=0.0': 2317.0087890625,
                           'B=256, M=1024, K=32, sparsity=0.5': 1287.4765625,
                           'B=256, M=1024, K=32, sparsity=0.7': 875.3642578125,
                           'B=256, M=1024, K=32, sparsity=0.8': 668.93701171875,
                           'B=256, M=1024, K=32, sparsity=0.9': 462.31494140625,
                           'B=256, M=1024, K=32, sparsity=0.95': 359.31787109375,
                           'B=256, M=1024, K=32, sparsity=0.99': 277.9248046875},
             'sputnik': {'B=256, M=1024, K=32, sparsity=0.0': 3441.0166015625,
                         'B=256, M=1024, K=32, sparsity=0.5': 1897.21240234375,
                         'B=256, M=1024, K=32, sparsity=0.7': 1279.24462890625,
                         'B=256, M=1024, K=32, sparsity=0.8': 969.70263671875,
                         'B=256, M=1024, K=32, sparsity=0.9': 659.8759765625,
                         'B=256, M=1024, K=32, sparsity=0.95': 505.427734375,
                         'B=256, M=1024, K=32, sparsity=0.99': 384.1904296875},
             'vanilla': {'B=256, M=1024, K=32, sparsity=0.0': 4384.0,
                         'B=256, M=1024, K=32, sparsity=0.5': 6688.0,
                         'B=256, M=1024, K=32, sparsity=0.7': 6688.0,
                         'B=256, M=1024, K=32, sparsity=0.8': 6688.0,
                         'B=256, M=1024, K=32, sparsity=0.9': 6688.0,
                         'B=256, M=1024, K=32, sparsity=0.95': 6688.0,
                         'B=256, M=1024, K=32, sparsity=0.99': 6688.0}})
[---------------------------- attention backward ----------------------------]
                                          |  optimized  |  vanilla  |  sputnik
1 threads: -------------------------------------------------------------------
      B=256, M=1024, K=32, sparsity=0.0   |    205.8    |    32.9   |    91.6
      B=256, M=1024, K=32, sparsity=0.5   |    159.6    |    41.4   |    51.0
      B=256, M=1024, K=32, sparsity=0.7   |     90.6    |    41.7   |    31.6
      B=256, M=1024, K=32, sparsity=0.8   |     57.3    |    41.4   |    22.7
      B=256, M=1024, K=32, sparsity=0.9   |     26.2    |    41.6   |    12.5
      B=256, M=1024, K=32, sparsity=0.95  |     12.3    |    42.0   |     8.0
      B=256, M=1024, K=32, sparsity=0.99  |      4.6    |    42.2   |     2.3

Times are in milliseconds (ms).

* Load matrices to shared memory

* Delete unused code

* Avoid local reads from attn and g_attn

Also puts indices in cache. This brings up to 2x speedup, as
before we were stalled reading from local (i.e., slow) memory for attn
and g_attn, which was already available in registers just before.

Before

[------------------------------------ attention backward ------------------------------------]
                                          |  optimized  |  vanilla  |  sputnik  |  permutation
1 threads: -----------------------------------------------------------------------------------
      B=256, M=1024, K=32, sparsity=0.0   |    97312.9  |  16790.6  |  38410.3  |     9096.5
      B=256, M=1024, K=32, sparsity=0.5   |   106913.2  |  22151.2  |  22155.6  |     3447.4
      B=256, M=1024, K=32, sparsity=0.7   |    50923.9  |  22093.2  |  15086.7  |     1917.2
      B=256, M=1024, K=32, sparsity=0.8   |    30387.6  |  22106.9  |  10875.4  |     1166.4
      B=256, M=1024, K=32, sparsity=0.9   |    13409.7  |  22171.4  |   7530.2  |      551.8
      B=256, M=1024, K=32, sparsity=0.95  |     6470.7  |  22122.1  |   5724.2  |      272.8
      B=256, M=1024, K=32, sparsity=0.99  |     2502.6  |  22139.1  |   1568.1  |       54.8

After
[------------------------------------ attention backward ------------------------------------]
                                          |  optimized  |  vanilla  |  sputnik  |  permutation
1 threads: -----------------------------------------------------------------------------------
      B=256, M=1024, K=32, sparsity=0.0   |   43171.3   |  16791.7  |  38693.6  |     9063.6
      B=256, M=1024, K=32, sparsity=0.5   |   48063.5   |  22213.2  |  22419.0  |     3459.6
      B=256, M=1024, K=32, sparsity=0.7   |   20605.6   |  22197.0  |  15003.4  |     1860.7
      B=256, M=1024, K=32, sparsity=0.8   |   14178.5   |  22201.5  |  10859.1  |     1144.8
      B=256, M=1024, K=32, sparsity=0.9   |    7677.4   |  22236.0  |   7555.1  |      551.0
      B=256, M=1024, K=32, sparsity=0.95  |    4239.9   |  22277.9  |   5707.1  |      273.3
      B=256, M=1024, K=32, sparsity=0.99  |    1681.8   |  22210.4  |   1569.8  |       55.0

* Use AllReduce in attn

Provides some extra speedup by avoiding repeated computation

Before:
[------------------------------------ attention backward ------------------------------------]
                                          |  optimized  |  vanilla  |  sputnik  |  permutation
1 threads: -----------------------------------------------------------------------------------
      B=256, M=1024, K=32, sparsity=0.0   |   43171.3   |  16791.7  |  38693.6  |     9063.6
      B=256, M=1024, K=32, sparsity=0.5   |   48063.5   |  22213.2  |  22419.0  |     3459.6
      B=256, M=1024, K=32, sparsity=0.7   |   20605.6   |  22197.0  |  15003.4  |     1860.7
      B=256, M=1024, K=32, sparsity=0.8   |   14178.5   |  22201.5  |  10859.1  |     1144.8
      B=256, M=1024, K=32, sparsity=0.9   |    7677.4   |  22236.0  |   7555.1  |      551.0
      B=256, M=1024, K=32, sparsity=0.95  |    4239.9   |  22277.9  |   5707.1  |      273.3
      B=256, M=1024, K=32, sparsity=0.99  |    1681.8   |  22210.4  |   1569.8  |       55.0

After:
[------------------------------------ attention backward ------------------------------------]
                                          |  optimized  |  vanilla  |  sputnik  |  permutation
1 threads: -----------------------------------------------------------------------------------
      B=256, M=1024, K=32, sparsity=0.0   |   41700.6   |  16818.5  |  38704.0  |     9075.2
      B=256, M=1024, K=32, sparsity=0.5   |   42627.3   |  22190.8  |  22432.3  |     3470.6
      B=256, M=1024, K=32, sparsity=0.7   |   19311.2   |  22238.3  |  15008.4  |     1871.5
      B=256, M=1024, K=32, sparsity=0.8   |   13390.8   |  22210.8  |  10890.7  |     1155.4
      B=256, M=1024, K=32, sparsity=0.9   |    7283.9   |  22167.4  |   7519.6  |      547.1
      B=256, M=1024, K=32, sparsity=0.95  |    4029.3   |  22172.9  |   5737.9  |      276.8
      B=256, M=1024, K=32, sparsity=0.99  |    1571.9   |  22182.8  |   1578.2  |       55.9

* Use AllReduce in gattn as well

Provides further speedup on almost all cases (except for sparsity=50%,
but need to investigate)

Before
[------------------------------------ attention backward ------------------------------------]
                                          |  optimized  |  vanilla  |  sputnik  |  permutation
1 threads: -----------------------------------------------------------------------------------
      B=256, M=1024, K=32, sparsity=0.0   |   41700.6   |  16818.5  |  38704.0  |     9075.2
      B=256, M=1024, K=32, sparsity=0.5   |   42627.3   |  22190.8  |  22432.3  |     3470.6
      B=256, M=1024, K=32, sparsity=0.7   |   19311.2   |  22238.3  |  15008.4  |     1871.5
      B=256, M=1024, K=32, sparsity=0.8   |   13390.8   |  22210.8  |  10890.7  |     1155.4
      B=256, M=1024, K=32, sparsity=0.9   |    7283.9   |  22167.4  |   7519.6  |      547.1
      B=256, M=1024, K=32, sparsity=0.95  |    4029.3   |  22172.9  |   5737.9  |      276.8
      B=256, M=1024, K=32, sparsity=0.99  |    1571.9   |  22182.8  |   1578.2  |       55.9

After
[------------------------------------ attention backward ------------------------------------]
                                          |  optimized  |  vanilla  |  sputnik  |  permutation
1 threads: -----------------------------------------------------------------------------------
      B=256, M=1024, K=32, sparsity=0.0   |   35632.9   |  16806.3  |  38725.3  |     9073.3
      B=256, M=1024, K=32, sparsity=0.5   |   57879.6   |  22237.0  |  22411.2  |     3469.7
      B=256, M=1024, K=32, sparsity=0.7   |   18370.7   |  22203.4  |  14988.8  |     1826.1
      B=256, M=1024, K=32, sparsity=0.8   |   12375.0   |  22272.6  |  10879.8  |     1149.1
      B=256, M=1024, K=32, sparsity=0.9   |    6626.8   |  22218.9  |   7530.6  |      547.6
      B=256, M=1024, K=32, sparsity=0.95  |    3543.7   |  22237.3  |   5697.3  |      273.0
      B=256, M=1024, K=32, sparsity=0.99  |    1301.4   |  22257.6  |   1565.6  |       54.6

* Up to 20% speedup with better hyperparameters

Before
[------------------------------------ attention backward ------------------------------------]
                                          |  optimized  |  vanilla  |  sputnik  |  permutation
1 threads: -----------------------------------------------------------------------------------
      B=256, M=1024, K=32, sparsity=0.0   |   35632.9   |  16806.3  |  38725.3  |     9073.3
      B=256, M=1024, K=32, sparsity=0.5   |   57879.6   |  22237.0  |  22411.2  |     3469.7
      B=256, M=1024, K=32, sparsity=0.7   |   18370.7   |  22203.4  |  14988.8  |     1826.1
      B=256, M=1024, K=32, sparsity=0.8   |   12375.0   |  22272.6  |  10879.8  |     1149.1
      B=256, M=1024, K=32, sparsity=0.9   |    6626.8   |  22218.9  |   7530.6  |      547.6
      B=256, M=1024, K=32, sparsity=0.95  |    3543.7   |  22237.3  |   5697.3  |      273.0
      B=256, M=1024, K=32, sparsity=0.99  |    1301.4   |  22257.6  |   1565.6  |       54.6

After
[------------------------------------ attention backward ------------------------------------]
                                          |  optimized  |  vanilla  |  sputnik  |  permutation
1 threads: -----------------------------------------------------------------------------------
      B=256, M=1024, K=32, sparsity=0.0   |   34498.9   |  16805.3  |  38477.0  |     9110.1
      B=256, M=1024, K=32, sparsity=0.5   |   45049.4   |  22157.0  |  22546.3  |     3484.0
      B=256, M=1024, K=32, sparsity=0.7   |   16594.9   |  22189.6  |  15107.5  |     1868.0
      B=256, M=1024, K=32, sparsity=0.8   |   11207.1   |  22203.4  |  10878.6  |     1147.5
      B=256, M=1024, K=32, sparsity=0.9   |    6229.1   |  22205.4  |   7538.9  |      550.8
      B=256, M=1024, K=32, sparsity=0.95  |    3450.3   |  22225.4  |   5712.1  |      272.7
      B=256, M=1024, K=32, sparsity=0.99  |    1300.5   |  22145.1  |   1575.2  |       55.9

* Optimize csr_transpose with 2d grid

This brings up to 2x+ speedup when the indices are not expanded

Before
[------------------------------------ attention backward ------------------------------------]
                                          |  optimized  |  vanilla  |  sputnik  |  permutation
1 threads: -----------------------------------------------------------------------------------
      B=256, M=1024, K=32, sparsity=0.0   |   120820.7  |  16812.4  |  38074.4  |     9098.3
      B=256, M=1024, K=32, sparsity=0.5   |   111207.0  |  22165.4  |  22288.1  |     3424.7
      B=256, M=1024, K=32, sparsity=0.7   |    56921.9  |  22133.9  |  15027.0  |     1910.0
      B=256, M=1024, K=32, sparsity=0.8   |    37642.5  |  22164.7  |  10869.8  |     1137.2
      B=256, M=1024, K=32, sparsity=0.9   |    18317.1  |  22238.1  |   7524.9  |      547.5
      B=256, M=1024, K=32, sparsity=0.95  |     8361.0  |  22176.8  |   5713.9  |      272.4
      B=256, M=1024, K=32, sparsity=0.99  |     1633.5  |  22187.4  |   1578.0  |       55.7

After
[------------------------------------ attention backward ------------------------------------]
                                          |  optimized  |  vanilla  |  sputnik  |  permutation
1 threads: -----------------------------------------------------------------------------------
      B=256, M=1024, K=32, sparsity=0.0   |   58574.9   |  16826.3  |  37989.3  |     9108.3
      B=256, M=1024, K=32, sparsity=0.5   |   53964.1   |  22133.1  |  22271.3  |     3468.4
      B=256, M=1024, K=32, sparsity=0.7   |   26181.9   |  22164.0  |  15032.5  |     1900.4
      B=256, M=1024, K=32, sparsity=0.8   |   16494.2   |  22201.7  |  10833.9  |     1158.2
      B=256, M=1024, K=32, sparsity=0.9   |    8691.6   |  22156.6  |   7504.8  |      543.3
      B=256, M=1024, K=32, sparsity=0.95  |    4801.8   |  22179.5  |   5730.0  |      273.4
      B=256, M=1024, K=32, sparsity=0.99  |    1563.4   |  22147.2  |   1566.1  |       55.7

* Bugfixes

* Add tests

* Add tests for expanded indices

* Minor cleanups

* Optimize forward by 20% with AllReduce

This better splits the work between threads

On V100

Before
[-------------------------------- attention ---------------------------------]
                                          |  optimized  |  vanilla  |  sputnik
1 threads: -------------------------------------------------------------------
      B=256, M=1024, K=32, sparsity=0.0   |   10490.0   |   7663.6  |  16984.8
      B=256, M=1024, K=32, sparsity=0.5   |    6539.2   |  23591.3  |   9597.2
      B=256, M=1024, K=32, sparsity=0.7   |    4370.7   |  23584.3  |   7026.9
      B=256, M=1024, K=32, sparsity=0.8   |    3166.8   |  23586.1  |   5232.3
      B=256, M=1024, K=32, sparsity=0.9   |    1896.1   |  23606.3  |   4254.4
      B=256, M=1024, K=32, sparsity=0.95  |    1124.2   |  23600.3  |   3768.1
      B=256, M=1024, K=32, sparsity=0.99  |     495.9   |  23588.3  |    823.9

After
[-------------------------------- attention ---------------------------------]
                                          |  optimized  |  vanilla  |  sputnik
1 threads: -------------------------------------------------------------------
      B=256, M=1024, K=32, sparsity=0.0   |    8476.5   |   7735.0  |  16995.9
      B=256, M=1024, K=32, sparsity=0.5   |    5336.1   |  23553.4  |   9547.1
      B=256, M=1024, K=32, sparsity=0.7   |    3541.8   |  23536.2  |   7030.6
      B=256, M=1024, K=32, sparsity=0.8   |    2595.9   |  23626.3  |   5221.0
      B=256, M=1024, K=32, sparsity=0.9   |    1531.5   |  23644.4  |   4248.3
      B=256, M=1024, K=32, sparsity=0.95  |     995.1   |  23585.0  |   3766.8
      B=256, M=1024, K=32, sparsity=0.99  |     413.8   |  23634.7  |    815.0

On P100

Before
[-------------------------------- attention ---------------------------------]
                                          |  optimized  |  vanilla  |  sputnik
1 threads: -------------------------------------------------------------------
      B=256, M=1024, K=32, sparsity=0.0   |     28.9    |    14.0   |    35.5
      B=256, M=1024, K=32, sparsity=0.5   |     16.6    |    38.6   |    20.2
      B=256, M=1024, K=32, sparsity=0.7   |     10.5    |    38.6   |    13.3
      B=256, M=1024, K=32, sparsity=0.8   |      7.3    |    38.6   |     9.9
      B=256, M=1024, K=32, sparsity=0.9   |      4.1    |    38.6   |     6.3
      B=256, M=1024, K=32, sparsity=0.95  |      2.3    |    38.5   |     4.7
      B=256, M=1024, K=32, sparsity=0.99  |      1.1    |    38.3   |     1.2

Times are in milliseconds (ms).

After
[-------------------------------- attention ---------------------------------]
                                          |  optimized  |  vanilla  |  sputnik
1 threads: -------------------------------------------------------------------
      B=256, M=1024, K=32, sparsity=0.0   |   26319.6   |  13992.5  |  35386.8
      B=256, M=1024, K=32, sparsity=0.5   |   14612.6   |  38586.0  |  20171.3
      B=256, M=1024, K=32, sparsity=0.7   |    9413.2   |  38582.5  |  13197.9
      B=256, M=1024, K=32, sparsity=0.8   |    6652.5   |  38583.7  |   9939.2
      B=256, M=1024, K=32, sparsity=0.9   |    3606.4   |  38566.4  |   6272.9
      B=256, M=1024, K=32, sparsity=0.95  |    1952.6   |  38478.3  |   4667.2
      B=256, M=1024, K=32, sparsity=0.99  |     969.6   |  38297.5  |   1163.9

Times are in microseconds (us).

* 10% speedup with better hyperparameters on forward

On V100

Before
[-------------------------------- attention ---------------------------------]
                                          |  optimized  |  vanilla  |  sputnik
1 threads: -------------------------------------------------------------------
      B=256, M=1024, K=32, sparsity=0.0   |    8476.5   |   7735.0  |  16995.9
      B=256, M=1024, K=32, sparsity=0.5   |    5336.1   |  23553.4  |   9547.1
      B=256, M=1024, K=32, sparsity=0.7   |    3541.8   |  23536.2  |   7030.6
      B=256, M=1024, K=32, sparsity=0.8   |    2595.9   |  23626.3  |   5221.0
      B=256, M=1024, K=32, sparsity=0.9   |    1531.5   |  23644.4  |   4248.3
      B=256, M=1024, K=32, sparsity=0.95  |     995.1   |  23585.0  |   3766.8
      B=256, M=1024, K=32, sparsity=0.99  |     413.8   |  23634.7  |    815.0

After
[-------------------------------- attention ---------------------------------]
                                          |  optimized  |  vanilla  |  sputnik
1 threads: -------------------------------------------------------------------
      B=256, M=1024, K=32, sparsity=0.0   |    8347.5   |   7700.6  |  16969.6
      B=256, M=1024, K=32, sparsity=0.5   |    4883.9   |  23566.4  |   9562.1
      B=256, M=1024, K=32, sparsity=0.7   |    3199.1   |  23631.1  |   7021.5
      B=256, M=1024, K=32, sparsity=0.8   |    2342.8   |  23577.2  |   5247.7
      B=256, M=1024, K=32, sparsity=0.9   |    1457.6   |  23527.2  |   4259.0
      B=256, M=1024, K=32, sparsity=0.95  |     918.2   |  23595.1  |   3766.9
      B=256, M=1024, K=32, sparsity=0.99  |     413.0   |  23625.9  |    829.6

* Optimize backward by up to 2x for large sequences

It is cheaper to perform uncoalesced loads than stores.

This is particularly important for very large sequences

On V100

Before:

Contiguous indices
[---------------------- attention backward ----------------------]
                                          |  optimized  |  sputnik
1 threads: -------------------------------------------------------
      B=16, M=16384, K=32, sparsity=0.96  |    206.3    |    89.3

Times are in milliseconds (ms).

Expanded indices
[---------------------- attention backward ----------------------]
                                          |  optimized  |  sputnik
1 threads: -------------------------------------------------------
      B=16, M=16384, K=32, sparsity=0.96  |    187.5    |    89.4

Times are in milliseconds (ms).

After

Contiguous indices
[---------------------- attention backward ----------------------]
                                          |  optimized  |  sputnik
1 threads: -------------------------------------------------------
      B=16, M=16384, K=32, sparsity=0.96  |    121.3    |    89.1

Times are in milliseconds (ms).

Expanded indices
[---------------------- attention backward ----------------------]
                                          |  optimized  |  sputnik
1 threads: -------------------------------------------------------
      B=16, M=16384, K=32, sparsity=0.96  |     70.6    |    89.2

Times are in milliseconds (ms).

Performance on the previous benchmarks is still good

Contiguous indices
[---------------------- attention backward ----------------------]
                                          |  optimized  |  sputnik
1 threads: -------------------------------------------------------
      B=256, M=1024, K=32, sparsity=0.0   |     60.1    |    38.1
      B=256, M=1024, K=32, sparsity=0.5   |     38.6    |    22.0
      B=256, M=1024, K=32, sparsity=0.7   |     25.0    |    14.8
      B=256, M=1024, K=32, sparsity=0.8   |     16.3    |    10.8
      B=256, M=1024, K=32, sparsity=0.9   |      8.9    |     7.6
      B=256, M=1024, K=32, sparsity=0.95  |      5.0    |     5.7
      B=256, M=1024, K=32, sparsity=0.99  |      1.6    |     1.6

Times are in milliseconds (ms).

Expanded indices

[---------------------- attention backward ----------------------]
                                          |  optimized  |  sputnik
1 threads: -------------------------------------------------------
      B=256, M=1024, K=32, sparsity=0.0   |     32.0    |    38.3
      B=256, M=1024, K=32, sparsity=0.5   |     23.0    |    22.0
      B=256, M=1024, K=32, sparsity=0.7   |     13.9    |    15.0
      B=256, M=1024, K=32, sparsity=0.8   |      9.7    |    10.8
      B=256, M=1024, K=32, sparsity=0.9   |      5.5    |     7.5
      B=256, M=1024, K=32, sparsity=0.95  |      3.2    |     5.7
      B=256, M=1024, K=32, sparsity=0.99  |      1.2    |     1.6

Times are in milliseconds (ms).

* Optimize backward by 20% more on large cases

Needs cleanup. Saves on uncoalesced writes and reads

V100

Contiguous indices
Before
[---------------------- attention backward ----------------------]
                                          |  optimized  |  sputnik
1 threads: -------------------------------------------------------
      B=16, M=16384, K=32, sparsity=0.96  |    121.3    |    89.1

Times are in milliseconds (ms).

After
[---------------------- attention backward ----------------------]
                                          |  optimized  |  sputnik
1 threads: -------------------------------------------------------
      B=16, M=16384, K=32, sparsity=0.96  |     99.8    |    89.2

Times are in milliseconds (ms).

Expanded indices
Before
[---------------------- attention backward ----------------------]
                                          |  optimized  |  sputnik
1 threads: -------------------------------------------------------
      B=16, M=16384, K=32, sparsity=0.96  |     70.6    |    89.2

Times are in milliseconds (ms).

After
[---------------------- attention backward ----------------------]
                                          |  optimized  |  sputnik
1 threads: -------------------------------------------------------
      B=16, M=16384, K=32, sparsity=0.96  |     68.7    |    89.4

Times are in milliseconds (ms).

* Optimize backward by up to 26% more for large cases

Performs vector reads, saves on uncoalesced reads and writes

V100

Contiguous indices
Before
[---------------------- attention backward ----------------------]
                                          |  optimized  |  sputnik
1 threads: -------------------------------------------------------
      B=16, M=16384, K=32, sparsity=0.96  |     99.8    |    89.2

Times are in milliseconds (ms).

After
[---------------------- attention backward ----------------------]
                                          |  optimized  |  sputnik
1 threads: -------------------------------------------------------
      B=16, M=16384, K=32, sparsity=0.96  |     81.0    |    89.1

Times are in milliseconds (ms).

Expanded indices
Before
[---------------------- attention backward ----------------------]
                                          |  optimized  |  sputnik
1 threads: -------------------------------------------------------
      B=16, M=16384, K=32, sparsity=0.96  |     68.7    |    89.4

Times are in milliseconds (ms).

After
[---------------------- attention backward ----------------------]
                                          |  optimized  |  sputnik
1 threads: -------------------------------------------------------
      B=16, M=16384, K=32, sparsity=0.96  |     51.0    |    89.3

Times are in milliseconds (ms).
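
The "up to 26%" in the title appears to match the expanded-indices case. As a quick check of the arithmetic, here is a small sketch that assumes the figure is the relative reduction in runtime of the "optimized" column; the names `timings_ms` and `relative_reduction` are illustrative only and do not come from the benchmark script.

```python
# Timings (ms) copied from the V100 "attention backward" tables above,
# "optimized" column, B=16, M=16384, K=32, sparsity=0.96.
timings_ms = {
    "contiguous indices": (99.8, 81.0),  # (before, after)
    "expanded indices": (68.7, 51.0),
}


def relative_reduction(before: float, after: float) -> float:
    """Fraction of the original runtime that was saved."""
    return (before - after) / before


reductions = {
    name: relative_reduction(before, after)
    for name, (before, after) in timings_ms.items()
}

for name, value in reductions.items():
    print(f"{name}: {value:.1%} less time")
# contiguous indices: 18.8% less time
# expanded indices: 25.8% less time
```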

* Add attention bias to sparse kernels
For now, the attention bias is assumed to be a dense matrix with the
same shape as the attention matrix, although it may be expanded along
every dimension except the last
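
To make the shape convention concrete, here is a minimal dense reference in plain Python for a single batch element. It is a sketch under assumptions, not the sparse CUDA kernel: it assumes the bias is added to the scaled logits before the softmax (the usual convention), and `attention_with_bias` and `bias_row` are hypothetical names. The bias is stored as one row of length M and reused for every query row, which is what "expanded on all but the last dimension" would look like for an M x M attention matrix.

```python
import math


def softmax(row):
    """Numerically stable softmax over one row of logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_with_bias(q, k, v, bias_row):
    """Dense reference attention for one batch element.

    q, k, v: lists of M vectors of length K.
    bias_row: M floats, logically expanded to an M x M bias
    (every query row sees the same values).
    Assumption: bias is added to the scaled logits before softmax.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        logits = [
            scale * sum(a * b for a, b in zip(qi, kj)) + bias_row[j]
            for j, kj in enumerate(k)
        ]
        probs = softmax(logits)
        out.append([
            sum(p * vj[d] for p, vj in zip(probs, v))
            for d in range(len(v[0]))
        ])
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 0.0], [0.0, 1.0]]
# A bias of -inf on the second key removes it for every query.
result = attention_with_bias(q, k, v, [0.0, float("-inf")])
print(result)  # [[1.0, 0.0], [1.0, 0.0]]
```

Storing only the last dimension and reusing it across rows avoids materializing the full bias.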

Before

[-------------------------- attention --------------------------]
                                         |  optimized  |  sputnik
1 threads: ------------------------------------------------------
      B=16, M=1024, K=32, sparsity=0.0   |    587.7    |   1070.6
      B=16, M=1024, K=32, sparsity=0.5   |    339.8    |    642.4
      B=16, M=1024, K=32, sparsity=0.7   |    231.0    |    497.0
      B=16, M=1024, K=32, sparsity=0.8   |    171.8    |    489.9
      B=16, M=1024, K=32, sparsity=0.9   |    113.5    |    511.4
      B=16, M=1024, K=32, sparsity=0.95  |     71.5    |    504.0
      B=16, M=1024, K=32, sparsity=0.99  |     30.4    |    508.9

Times are in microseconds (us).

[---------------------- attention backward ---------------------]
                                         |  optimized  |  sputnik
1 threads: ------------------------------------------------------
      B=16, M=1024, K=32, sparsity=0.0   |    3957.2   |   2484.1
      B=16, M=1024, K=32, sparsity=0.5   |    2311.9   |   1532.6
      B=16, M=1024, K=32, sparsity=0.7   |    1520.0   |   1055.6
      B=16, M=1024, K=32, sparsity=0.8   |    1069.1   |    763.8
      B=16, M=1024, K=32, sparsity=0.9   |     620.7   |    523.1
      B=16, M=1024, K=32, sparsity=0.95  |     358.0   |    464.9
      B=16, M=1024, K=32, sparsity=0.99  |     132.1   |    414.7

Times are in microseconds (us).

After

attn_bias = None
[-------------------------- attention --------------------------]
                                         |  optimized  |  sputnik
1 threads: ------------------------------------------------------
      B=16, M=1024, K=32, sparsity=0.0   |    608.0    |   1071.6
      B=16, M=1024, K=32, sparsity=0.5   |    341.4    |    640.6
      B=16, M=1024, K=32, sparsity=0.7   |    225.6    |    518.5
      B=16, M=1024, K=32, sparsity=0.8   |    164.4    |    495.7
      B=16, M=1024, K=32, sparsity=0.9   |    105.4    |    512.3
      B=16, M=1024, K=32, sparsity=0.95  |     67.4    |    509.4
      B=16, M=1024, K=32, sparsity=0.99  |     30.2    |    496.4

Times are in microseconds (us).

[---------------------- attention backward ---------------------]
                                         |  optimized  |  sputnik
1 threads: ------------------------------------------------------
      B=16, M=1024, K=32, sparsity=0.0   |    3761.6   |   2505.4
      B=16, M=1024, K=32, sparsity=0.5   |    2256.0   |   1512.8
      B=16, M=1024, K=32, sparsity=0.7   |    1458.7   |   1062.2
      B=16, M=1024, K=32, sparsity=0.8   |    1025.2   |    767.7
      B=16, M=1024, K=32, sparsity=0.9   |     588.5   |    523.8
      B=16, M=1024, K=32, sparsity=0.95  |     342.3   |    445.8
      B=16, M=1024, K=32, sparsity=0.99  |     143.5   |    437.8

Times are in microseconds (us).

attn_bias = Tensor
[-------------------------- attention --------------------------]
                                         |  optimized  |  sputnik
1 threads: ------------------------------------------------------
      B=16, M=1024, K=32, sparsity=0.0   |    691.3    |   1071.4
      B=16, M=1024, K=32, sparsity=0.5   |    382.0    |    644.3
      B=16, M=1024, K=32, sparsity=0.7   |    247.2    |    566.9
      B=16, M=1024, K=32, sparsity=0.8   |    182.4    |    518.3
      B=16, M=1024, K=32, sparsity=0.9   |    118.4    |    516.7
      B=16, M=1024, K=32, sparsity=0.95  |     76.1    |    514.9
      B=16, M=1024, K=32, sparsity=0.99  |     32.2    |    508.9

Times are in microseconds (us).

[---------------------- attention backward ---------------------]
                                         |  optimized  |  sputnik
1 threads: ------------------------------------------------------
      B=16, M=1024, K=32, sparsity=0.0   |    3939.1   |   2507.7
      B=16, M=1024, K=32, sparsity=0.5   |    2305.3   |   1509.3
      B=16, M=1024, K=32, sparsity=0.7   |    1496.1   |   1053.2
      B=16, M=1024, K=32, sparsity=0.8   |    1049.7   |    768.8
      B=16, M=1024, K=32, sparsity=0.9   |     608.2   |    525.7
      B=16, M=1024, K=32, sparsity=0.95  |     355.2   |    424.9
      B=16, M=1024, K=32, sparsity=0.99  |     133.4   |    428.1

Times are in microseconds (us).

* Add CUDA_CHECK calls

* Fix sparse backward when a full row is masked

This seems to have affected only cases with masking
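
The commit message does not say what went wrong, so the following is only an illustration of why a fully masked row is a classic edge case, not a description of the actual fix. In a dense reference, a row whose logits are all `-inf` makes a naive softmax produce NaN, which then propagates through the backward pass. One common guard, assumed here, is to define the output of such a row as all zeros. All function names are hypothetical.

```python
import math

NEG_INF = float("-inf")


def naive_softmax(row):
    """Stable softmax that still breaks when every entry is -inf."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]  # -inf - (-inf) is nan
    total = sum(exps)
    return [e / total for e in exps]


def masked_softmax(row):
    """Softmax that returns all zeros for a fully masked row."""
    m = max(row)
    if m == NEG_INF:
        return [0.0] * len(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_backward(probs, grad_out):
    """Gradient w.r.t. the logits: p * (g - sum(p * g))."""
    dot = sum(p * g for p, g in zip(probs, grad_out))
    return [p * (g - dot) for p, g in zip(probs, grad_out)]


fully_masked = [NEG_INF, NEG_INF, NEG_INF]
naive = naive_softmax(fully_masked)   # every entry is nan
safe = masked_softmax(fully_masked)   # [0.0, 0.0, 0.0]
grad = softmax_backward(safe, [1.0, 2.0, 3.0])  # [0.0, 0.0, 0.0]
```

With the guard in place, the gradient for the masked row is zero instead of NaN, so it cannot contaminate the other rows.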

* Move to fairinternal

* Expose sparse attention to ops

* Try fix python lint

* Add license

* Remove comments in cuda file

* clang-format

* Fix and cleanup benchmarks following rebase

* Bump tolerance

* Add memory-efficient attention in benchmarks