A simple Flash Attention v2 implementation for ROCm (RDNA3 GPUs, rocWMMA), mainly used for Stable Diffusion (ComfyUI) in Windows ZLUDA environments.
Minimum integration:
──rocwmma_fattn
│ FlashAttn.py
│ host.cpp
│ kernel_bf16.cu
│ kernel_fp16.cu
└─ zluda_hijack_torch_hip_ext.py
Run the test: python bench_with_sdpa.py
Requires the MSVC compiler, the AMD HIP SDK, and the rocWMMA library.
Install the rocWMMA library: https://github.com/ROCm/rocWMMA
Clone it and copy library/include/rocwmma into the include folder of the HIP SDK installation path.
In cmd.exe, run vcvars64.bat to activate the MSVC environment, then run: zluda -- python bench_with_sdpa.py
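
For reference, a rough sketch of the kind of comparison bench_with_sdpa.py runs: checking the custom kernel's output against torch.nn.functional.scaled_dot_product_attention and timing both. The import path and the flash_attn_wmma name below are assumptions for illustration, not the actual API exposed by FlashAttn.py.

```python
# Hypothetical sketch: compare the rocWMMA kernel against PyTorch SDPA.
# `flash_attn_wmma` and its import path are assumed names, not the real API.
import time
import torch
import torch.nn.functional as F

from rocwmma_fattn.FlashAttn import flash_attn_wmma  # assumed import path

B, H, L, D = 2, 8, 4096, 64
q, k, v = (torch.randn(B, H, L, D, dtype=torch.float16, device="cuda") for _ in range(3))

ref = F.scaled_dot_product_attention(q, k, v)           # PyTorch SDPA baseline
out = flash_attn_wmma(q, k, v)                          # custom rocWMMA kernel (assumed signature)
print("max abs diff:", (out - ref).abs().max().item())  # should stay within fp16 tolerance

def bench(fn, iters=20):
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.time() - t0) / iters

print("SDPA       :", bench(lambda: F.scaled_dot_product_attention(q, k, v)))
print("rocWMMA FA :", bench(lambda: flash_attn_wmma(q, k, v)))
```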
Tested and working with PyTorch 2.2.1 + cu118 under Windows ZLUDA, on a gfx1100 GPU.
ComfyUI: https://github.com/Repeerc/ComfyUI-flash-attention-rdna3-win-zluda
WebUI: https://github.com/Repeerc/sd-webui-flash-attention-zluda-win
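
The plugins above route attention calls in ComfyUI/WebUI to this kernel. Below is a minimal sketch of the general idea, assuming an SDPA-compatible wrapper; the names, signature, and fallback conditions are illustrative assumptions, not the plugins' actual code.

```python
# Illustrative only: route SDPA-style calls to the custom kernel and fall back
# to PyTorch SDPA for cases the kernel does not handle (masks, dropout, other dtypes).
import torch
import torch.nn.functional as F

from rocwmma_fattn.FlashAttn import flash_attn_wmma  # assumed import path

_sdpa = F.scaled_dot_product_attention  # keep a reference to the original

def sdpa_with_rocwmma(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, **kw):
    # Only handle the simple case; otherwise use the stock implementation.
    if attn_mask is None and dropout_p == 0.0 and q.dtype in (torch.float16, torch.bfloat16):
        return flash_attn_wmma(q, k, v, is_causal=is_causal)  # assumed signature
    return _sdpa(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kw)

# Monkey-patch for the whole process (the general idea, not the plugins' exact mechanism).
F.scaled_dot_product_attention = sdpa_with_rocwmma
```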
- backward pass
- causal mask (needs more optimization; see the forward-pass sketch after this list)
- padding optimization for sequence lengths not aligned to 32
- load tiles into LDS
- attention bias
- matrix multiplication optimization
- fix poor BF16 performance
- ...
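
For readers unfamiliar with the algorithm behind these items, here is a minimal, unoptimized reference sketch of the FlashAttention-2 forward tiling (online softmax with an optional causal mask) that the HIP kernels implement on WMMA tiles. It shows the math only, not the actual kernel code or its tile sizes.

```python
# Reference-only FlashAttention-2 forward pass: tiled online softmax over K/V blocks.
# Shapes: q, k, v are [L, D] for a single head; Br/Bc are the query/key tile sizes.
import math
import torch

def flash_attn_forward_ref(q, k, v, Br=64, Bc=64, causal=True):
    L, D = q.shape
    scale = 1.0 / math.sqrt(D)
    out = torch.zeros_like(q, dtype=torch.float32)
    for i0 in range(0, L, Br):                        # loop over query tiles
        qi = q[i0:i0 + Br].float() * scale
        m = torch.full((qi.shape[0],), float("-inf"), device=q.device)  # running row max
        denom = torch.zeros(qi.shape[0], device=q.device)               # running softmax denominator
        acc = torch.zeros(qi.shape[0], D, device=q.device)              # running unnormalized output
        for j0 in range(0, L, Bc):                    # loop over key/value tiles
            if causal and j0 > i0 + Br - 1:
                break                                 # tiles entirely above the diagonal contribute nothing
            s = qi @ k[j0:j0 + Bc].float().T          # [Br, Bc] attention scores
            if causal:
                qi_idx = torch.arange(i0, i0 + qi.shape[0], device=q.device).unsqueeze(1)
                kj_idx = torch.arange(j0, j0 + s.shape[1], device=q.device).unsqueeze(0)
                s = s.masked_fill(kj_idx > qi_idx, float("-inf"))
            m_new = torch.maximum(m, s.max(dim=1).values)
            p = torch.exp(s - m_new.unsqueeze(1))     # tile softmax numerator
            alpha = torch.exp(m - m_new)              # rescale factor for previously accumulated tiles
            denom = denom * alpha + p.sum(dim=1)
            acc = acc * alpha.unsqueeze(1) + p @ v[j0:j0 + Bc].float()
            m = m_new
        out[i0:i0 + Br] = acc / denom.unsqueeze(1)
    return out.to(q.dtype)
```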
OS: Windows 11
GPU: 7900xtx (gfx1100)
PyTorch 2.2.1 + CU118 ZLUDA, Python 3.10, HIP 5.7.1
Triton built from: https://github.com/triton-lang/triton
git hash: 47fc046ff29c9ea2ee90e987c39628a540603c8f
The test uses a pre-built Triton for Windows: https://github.com/Repeerc/triton-windows-amdgpu
The official Triton version uses the 06-fused-attention.py tutorial kernel.
The CK-based (Composable Kernel) Flash Attention version was compiled from: https://github.com/ROCm/flash-attention/tree/howiejay/navi_support
CK-based Flash Attention Windows port: https://github.com/Repeerc/flash-attn-composable-kernel-gfx110x-windows-port
OS: Windows 11
GPU: 7900xtx (gfx1100)
PyTorch 2.2.1 + CU118 ZLUDA, Python 3.10
Sampler: Euler
SD 1.5 | PyTorch SDPA | Flash Attn minimal | Speedup |
---|---|---|---|
512x512x1 | 17.32 it/s | 19.20 it/s | +10% |
VRAM | 3.2 GB | 2.3 GB | |
512x512x4 | 4.96 it/s | 5.47 it/s | +10% |
VRAM | 5.4 GB | 2.5 GB | |
1024x1024x1 | 2.52 it/s | 3.53 it/s | +40% |
VRAM | 10.7 GB | 2.9 GB | |
SDXL | PyTorch SDPA | Flash Attn minimal | Speedup |
---|---|---|---|
1536x1024x1 | 2.03 it/s | 2.35 it/s | +16% |
VRAM | 7.4 GB | 6.8 GB | |
1024x1024x1 | 3.30 it/s | 3.60 it/s | +9% |
VRAM | 6.5 GB | 6.4 GB | |
Training settings:
unet_lr = 0.0001
lr_scheduler = "constant"
lr_warmup_steps = 0
optimizer_type = "AdamW"
network_dim = 32
network_alpha = 32
seed = 1337
mixed_precision = "fp16"
full_fp16 = false
full_bf16 = false
fp8_base = true
no_half_vae = false
SDXL | PyTorch SDPA | Flash Attn minimal | Speedup |
---|---|---|---|
1024x1024x1 | 1.27 it/s | 1.76 it/s | +39% |
VRAM | 21.5 GB | 16.8 GB | |