memory_efficient_attention: torch.compile compatibility #920

Open
achalddave opened this issue Nov 9, 2023 · 3 comments
Labels
bug Something isn't working

Comments

@achalddave

🐛 Bug

Using xformers.memory_efficient_attention with FSDP and torch.compile fails when using bfloat16, but works when using float32. It's unclear to me whether this is an xformers bug, an FSDP bug, or a torch.compile bug. It might be related to pytorch/pytorch#112164, and it came up in our codebase where we use xformers: mlfoundations/open_lm#72.

Command

```shell
torchrun --nproc_per_node 2 script.py
```

To Reproduce

Steps to reproduce the behavior:

  1. Save code sample below as script.py
  2. Run torchrun --nproc_per_node 2 script.py
```python
# script.py
import torch
import torch.nn as nn

from torch.distributed.fsdp import MixedPrecision, FullyShardedDataParallel as FSDP
from xformers.ops import memory_efficient_attention
import xformers.ops as xops


class Layer(nn.Module):
    """Causal self-attention (via xFormers) followed by a linear projection."""

    def __init__(self, n_feat):
        super().__init__()
        self.linear_out = nn.Linear(n_feat, n_feat)

    def forward(self, x):
        B, N, C = x.shape
        # Self-attention with a causal (lower-triangular) mask; q = k = v = x.
        x = memory_efficient_attention(x, x, x, attn_bias=xops.LowerTriangularMask())
        return self.linear_out(x.reshape([B, N, C]))

###
dtype = torch.bfloat16  # Setting this to torch.float32 makes this code work.
###

# One process per GPU, launched by torchrun.
torch.distributed.init_process_group(backend="nccl")
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
torch.cuda.set_device(device)
FEAT_SIZE = 128
MAX_LEN = 100
BATCH_SIZE = 8

batch = torch.zeros(BATCH_SIZE, MAX_LEN, FEAT_SIZE).to(device)
mha = Layer(FEAT_SIZE).to(device)

# Parameters, gradient reduction, and buffers all use `dtype`.
mp_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
mha_fsdp = FSDP(mha, use_orig_params=True, device_id=device, mixed_precision=mp_policy)

# Compile the FSDP-wrapped module, then run one forward and backward pass.
compile_mha = torch.compile(mha_fsdp).to(device)
output = compile_mha(batch)
output.mean().backward()
```
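For readers unfamiliar with the operator, the attention call in the repro computes causal self-attention with q = k = v = x. The pure-Python sketch below spells out that computation for a single sequence, assuming the semantics xFormers documents for `LowerTriangularMask` (query i attends only to keys 0..i) and its default softmax scale of 1/sqrt(head dim). The function name `causal_attention` is hypothetical, and the sketch ignores dtype, batching, and performance: it describes what is computed, not how xFormers computes it.

```python
import math


def causal_attention(q, k, v):
    """Hypothetical reference for lower-triangular (causal) attention.

    q, k, v are lists of N vectors (lists of floats), each of length C.
    Query i only attends to keys 0..i, as a lower-triangular mask implies
    (assumed semantics); scores are scaled by 1/sqrt(C).
    """
    n, c = len(q), len(q[0])
    scale = 1.0 / math.sqrt(c)
    out = []
    for i in range(n):
        # Scores for allowed keys j <= i; keys j > i are masked out entirely.
        scores = [scale * sum(a * b for a, b in zip(q[i], k[j])) for j in range(i + 1)]
        m = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        # Output row i is the softmax-weighted average of value rows 0..i.
        out.append([
            sum(w * v[j][d] for j, w in enumerate(weights)) / total
            for d in range(c)
        ])
    return out


# Same call pattern as the repro: q = k = v = x.
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
y = causal_attention(x, x, x)
```

Because row i never reads keys or values after position i, changing a later position leaves earlier output rows unchanged.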

Expected behavior

Code runs without error.

Environment

Please copy and paste the output from the environment collection script from PyTorch (or fill out the checklist below manually).

You can run the script with:

```shell
# For security purposes, please check the contents of collect_env.py before running it.
python -m torch.utils.collect_env
```
PyTorch version: 2.0.1+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.26.3
Libc version: glibc-2.31

Python version: 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:36:39) [GCC 10.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-1028-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB

Nvidia driver version: 515.65.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy==1.5.1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.3
[pip3] pytorch-ranger==0.1.1
[pip3] st-moe-pytorch==0.1.1
[pip3] torch==2.0.1+cu118
[pip3] torch-optimizer==0.3.0
[pip3] torchdata==0.6.0
[pip3] torchmetrics==0.11.3
[pip3] torchtext==0.15.1
[pip3] torchvision==0.15.2+cu118
[conda] numpy                     1.25.2                   pypi_0    pypi

Additional context

xformers version: 0.0.22.

@danthe3rd
Contributor

Hi,
Thanks for reporting this! A lot of operators in xFormers don't support torch.compile at the moment. This is on our roadmap, but it might take on the order of months to get there (we might also need to fix some bugs in PyTorch...)

@achalddave
Author

Ah, okay, thanks! Is there an issue that tracks this that we could follow? We'd love to support torch.compile+xformers attention in our repo.

@danthe3rd danthe3rd pinned this issue Nov 15, 2023
@danthe3rd danthe3rd added the bug Something isn't working label Nov 15, 2023
@danthe3rd danthe3rd changed the title xformers.memory_efficient_attention: Compatibility with torch.compile, FSDP, and bfloat16 xformers.memory_efficient_attention: Compatibility with torch.compile Nov 15, 2023
@danthe3rd danthe3rd changed the title xformers.memory_efficient_attention: Compatibility with torch.compile memory_efficient_attention: torch.compile compatibility Nov 15, 2023
@danthe3rd
Contributor

We can use this issue to track it. However, this particular error might be related to FSDP...
The xFormers operator will most likely incur a graph break (which will make performance worse), but it shouldn't cause an exception or error.
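If the attention operator is going to cause a graph break anyway, one option is to make that break explicit so that dynamo does not try to trace into it. The sketch below assumes `torch._dynamo.disable` (present in the PyTorch 2.0.x used here; newer releases also expose `torch.compiler.disable`) and swaps in `torch.nn.functional.scaled_dot_product_attention` with `is_causal=True` as a stand-in for `memory_efficient_attention`, so it runs on CPU without xFormers. It uses neither FSDP nor bfloat16, so it has not been shown to avoid the error reported above; the helper name `causal_attention` is hypothetical.

```python
import torch
import torch._dynamo
import torch.nn as nn
import torch.nn.functional as F


# Hypothetical helper: dynamo skips this function, so the attention call runs
# eagerly and the compiled graph is split (a graph break) around it.
# F.scaled_dot_product_attention stands in for xformers'
# memory_efficient_attention so this sketch runs without xformers or a GPU.
@torch._dynamo.disable
def causal_attention(q, k, v):
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)


class Layer(nn.Module):
    def __init__(self, n_feat):
        super().__init__()
        self.linear_out = nn.Linear(n_feat, n_feat)

    def forward(self, x):
        x = causal_attention(x, x, x)  # runs outside the compiled graph
        return self.linear_out(x)


torch.manual_seed(0)
layer = Layer(16)
x = torch.randn(2, 5, 16)

# backend="eager" runs dynamo's tracing without inductor code generation,
# which is enough to exercise the graph break on CPU.
compiled = torch.compile(layer, backend="eager")
out_compiled = compiled(x)
out_eager = layer(x)
```

The compiled module shares parameters with `layer`, so both paths should produce the same output; the trade-off is that everything around the disabled call is compiled in separate graphs.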
