flash-attention-v2-RDNA3-minimal

A simple Flash Attention v2 implementation with ROCm (RDNA3 GPUs, rocWMMA), mainly used for Stable Diffusion (ComfyUI) in Windows ZLUDA environments.

Build and Test

Minimal integration:

──rocwmma_fattn
   │  FlashAttn.py
   │  host.cpp
   │  kernel_bf16.cu
   │  kernel_fp16.cu
   └─ zluda_hijack_torch_hip_ext.py
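
For orientation, the Python wrapper presumably JIT-compiles host.cpp (which drives the fp16/bf16 kernels) through torch.utils.cpp_extension and exposes a forward call. Below is a minimal sketch of that pattern; the module and function names ("rocwmma_fattn_ext", "fwd") are placeholders, not the repository's actual API.

```python
# Illustrative only: the real binding lives in rocwmma_fattn/FlashAttn.py, and the
# extension/function names below ("rocwmma_fattn_ext", "fwd") are placeholders.
import os
import torch
from torch.utils.cpp_extension import load

src = "rocwmma_fattn"
ext = load(
    name="rocwmma_fattn_ext",
    sources=[os.path.join(src, "host.cpp")],   # host.cpp drives the fp16/bf16 kernels
    verbose=True,
)

# q, k, v in [batch, heads, seqlen, head_dim], fp16, on the GPU
q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
o = ext.fwd(q, k, v)   # placeholder entry point
```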

Linux with ROCm

Run the test: python bench_with_sdpa.py
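
The script name suggests it checks the kernel against PyTorch's scaled_dot_product_attention and times both. A rough sketch of that kind of comparison, using only stock PyTorch calls and a placeholder for the kernel binding:

```python
# Sketch of an SDPA comparison in the spirit of bench_with_sdpa.py.
# "flash_attn_fwd" is a placeholder for the rocWMMA kernel binding, not its real name.
import torch
import torch.nn.functional as F

def sdpa_ref(q, k, v):
    return F.scaled_dot_product_attention(q, k, v, is_causal=False)

def time_ms(fn, *args, iters=20):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(3):                 # warm-up
        fn(*args)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

q = torch.randn(1, 16, 4096, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

ref = sdpa_ref(q, k, v)
print("SDPA:", time_ms(sdpa_ref, q, k, v), "ms/iter")
# out = flash_attn_fwd(q, k, v)                       # kernel under test (placeholder)
# print("max abs diff vs SDPA:", (out - ref).abs().max().item())
```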

Windows with ZLUDA

Requires the MSVC compiler, the AMD HIP SDK, and the rocWMMA library.

Install the rocWMMA library: https://github.com/ROCm/rocWMMA

Clone it and copy library/include/rocwmma into the include folder of the HIP SDK installation path.

In cmd.exe, run vcvars64.bat to activate the MSVC environment, then run zluda -- python bench_with_sdpa.py
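
The bundled zluda_hijack_torch_hip_ext.py implies that PyTorch's extension build machinery is patched for ZLUDA before the kernel is compiled. A minimal sketch of the assumed import order (the hijack module's behaviour is repository-specific and not shown here):

```python
# Assumed import order on Windows + ZLUDA: patch torch's HIP extension build first,
# then load the wrapper. Details of the hijack module are repository-specific.
import sys
sys.path.insert(0, "rocwmma_fattn")

import zluda_hijack_torch_hip_ext  # assumed: imported before the extension is built
import torch
import FlashAttn                   # wrapper that builds/loads the rocWMMA kernels
```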

Pre-built Extension

Tested to work with PyTorch 2.2.1 + cu118 under Windows ZLUDA on a gfx1100 GPU.

ComfyUI: https://github.com/Repeerc/ComfyUI-flash-attention-rdna3-win-zluda

WebUI: https://github.com/Repeerc/sd-webui-flash-attention-zluda-win

To do

  • backward pass
  • causal mask (needs more optimization)
  • padding optimization for sequence lengths not aligned to a multiple of 32
  • load tiles into LDS
  • attention bias
  • matrix multiplication optimization
  • fix poor BF16 performance
  • ...

Benchmark

OS: Windows 11

GPU: 7900xtx (gfx1100)

PyTorch 2.2.1 + CU118 ZLUDA, Python 3.10, HIP 5.7.1

FP16, causal = False

Triton built from: https://github.com/triton-lang/triton

git hash: 47fc046ff29c9ea2ee90e987c39628a540603c8f

Tests use the Triton Windows pre-built version: https://github.com/Repeerc/triton-windows-amdgpu

The official Triton version uses the 06-fused-attention.py tutorial.

The CK-based (Composable Kernel) flash attention version was compiled from: https://github.com/ROCm/flash-attention/tree/howiejay/navi_support

CK-based flash attention Windows port: https://github.com/Repeerc/flash-attn-composable-kernel-gfx110x-windows-port

seqlen aligned to a multiple of 32

[benchmark plots]

[B N H D] format rearrange and contiguous to [B H N D]

[benchmark plots]

seqlen not aligned to a multiple of 32

[benchmark plot]

[B N H D] format rearrange and contiguous to [B H N D]

[benchmark plot]

fwd+bwd

[benchmark plot]
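
The "[B N H D] format rearrange and contiguous" rows above measure the extra cost of converting activations from a [batch, seqlen, heads, head_dim] layout into the [batch, heads, seqlen, head_dim] layout before the kernel runs. In PyTorch this is a permute plus a contiguous copy, roughly:

```python
# Layout conversion measured in the "[B N H D] rearrange and contiguous" benchmarks.
import torch

B, N, H, D = 2, 4096, 8, 64
x_bnhd = torch.randn(B, N, H, D, device="cuda", dtype=torch.float16)

x_bhnd = x_bnhd.permute(0, 2, 1, 3).contiguous()   # [B, H, N, D], materialises a copy
# ... attention kernel consumes x_bhnd ...
y_bnhd = x_bhnd.permute(0, 2, 1, 3).contiguous()   # back to [B, N, H, D]
```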

FP16, causal = True

[benchmark plots: fwd_scan_N, fwd_bwd_scan_N, fwd_scan_D]

Performance in Stable Diffusion (ComfyUI)

OS: Windows 11

GPU: 7900xtx (gfx1100)

PyTorch 2.2.1 + CU118 ZLUDA, Python 3.10

Sampler: Euler

| SD 1.5 | PyTorch SDPA | Flash Attn minimal | Speedup |
| -- | -- | -- | -- |
| 512x512x1 | 17.32 it/s | 19.20 it/s | +10% |
| VRAM | 3.2 GB | 2.3 GB | |
| 512x512x4 | 4.96 it/s | 5.47 it/s | +10% |
| VRAM | 5.4 GB | 2.5 GB | |
| 1024x1024x1 | 2.52 it/s | 3.53 it/s | +40% |
| VRAM | 10.7 GB | 2.9 GB | |

| SDXL | PyTorch SDPA | Flash Attn minimal | Speedup |
| -- | -- | -- | -- |
| 1536x1024x1 | 2.03 it/s | 2.35 it/s | +16% |
| VRAM | 7.4 GB | 6.8 GB | |
| 1024x1024x1 | 3.30 it/s | 3.60 it/s | +9% |
| VRAM | 6.5 GB | 6.4 GB | |

SDXL U-Net LoRA training

unet_lr = 0.0001
lr_scheduler = "constant"
lr_warmup_steps = 0
optimizer_type = "AdamW"
network_dim = 32
network_alpha = 32
seed = 1337
mixed_precision = "fp16"
full_fp16 = false
full_bf16 = false
fp8_base = true
no_half_vae = false
| SDXL | PyTorch SDPA | Flash Attn minimal | Speedup |
| -- | -- | -- | -- |
| 1024x1024x1 | 1.27 it/s | 1.76 it/s | +39% |
| VRAM | 21.5 GB | 16.8 GB | |
