
Support AMD ROCm on FlashAttention 2 #1010

Merged
merged 46 commits into from
Jul 23, 2024

Conversation

rocking5566
Contributor

@rocking5566 rocking5566 commented Jun 26, 2024

  • This PR implements the AMD / ROCm version of the C++ flash API:
    1. mha_fwd
    2. mha_varlen_fwd
    3. mha_bwd
    4. mha_varlen_bwd
  • The kernel implementation comes from Composable Kernel.
  • The C++ API is the same as the original version, so the existing Python interface can be used unchanged (see the sketch below).
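
Since the C++ entry points keep the same signatures, existing callers of the Python wrapper should work as-is on a ROCm build. A minimal sketch (the shapes, dtype, and device below are illustrative assumptions, not part of this PR):

```python
# Minimal sketch of calling the unchanged Python interface on a ROCm build.
# Shapes, dtype, and device here are illustrative assumptions.
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 8, 128
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)

# Same call as on the CUDA backend; the ROCm kernels are selected at build time.
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
print(out.shape)  # torch.Size([2, 1024, 8, 128])
```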

@minzhezhou

minzhezhou commented Jul 12, 2024

I got this error while building fmha_bwd_d128_fp16_batch for gfx1100:
/root/code/flash-attention/csrc/composable_kernel/include/ck_tile/core/arch/generic_memory_space_atomic_hip.hpp:66:19: error: static assertion failed due to requirement '(std::is_same<_Float16, int>::value && (4 == 1)) || (std::is_same<_Float16, unsigned int>::value && (4 == 1)) || (std::is_same<_Float16, float>::value && (4 == 1 || 4 == 2)) || (std::is_same<_Float16, double>::value && (4 == 1 || 4 == 2)) || (std::is_same<_Float16, unsigned short>::value && (4 == 2 || 4 == 4))': wrong! not implemented,

It was due to _Float16 not being implemented in the instantiation of:
/root/code/flash-attention/csrc/composable_kernel/include/ck_tile/core/tensor/buffer_view_hip.hpp:413:28: note: in instantiation of function template specialization 'ck_tile::buffer_view<ck_tile::address_space_enum::global, _Float16, long, true, ck_tile::amd_buffer_coherence_enum::coherence_default>::atomic_add<ck_tile::thread_buffer<_Float16, 4>, false>' requested here
413 | this->template atomic_add(i, is_valid_element, x);

The template doesn't accept T = float16, but accepts T = bf16 with N == 2 or 4:
template <typename T, index_t N>
CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)

Is this expected? I saw the allowed_archs = ["native", "gfx90a", "gfx940", "gfx941", "gfx942"]

@hackey

hackey commented Jul 14, 2024

I also have an error when compiling for the 7900xtx (gfx1100). Tell me, does flash-attention support this card?

@minzhezhou

minzhezhou commented Jul 14, 2024

I also have an error when compiling for the 7900xtx (gfx1100). Tell me, does flash-attention support this card?

They don't support it yet; "native" should mean for the CPU. I dug into it, and here are my findings.
The first problem I met was float16 not being implemented in atomic_add in csrc/composable_kernel/include/ck_tile/core/arch/generic_memory_space_atomic.hpp. I still don't understand why MI300 doesn't need it, but we can stop generate.py from generating kernels for fp16. There are 3 Python files in csrc/composable_kernel/example/ck_tile/01_fmha/codegen/ops/; we can modify them to skip fp16, as sketched below.
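
As a purely hypothetical illustration of that workaround (the real codegen scripts under codegen/ops/ have their own structure and names), the idea is simply to filter fp16 out of whatever data-type list drives kernel generation:

```python
# Hypothetical sketch only: skip fp16 when emitting backward-kernel instances.
# The actual codegen files have their own structure; this just shows the filtering idea.
SKIP_DTYPES = {"fp16"}  # assumed label for _Float16 kernels in the codegen

def filtered_dtypes(dtypes):
    """Drop data types that cannot be built for this target (e.g. gfx1100)."""
    return [dt for dt in dtypes if dt not in SKIP_DTYPES]

# e.g. inside the (hypothetical) generation loop:
# for dtype in filtered_dtypes(ALL_DTYPES):
#     emit_kernel(dtype, ...)
```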
The second problem I'm dealing with is a configuration check, static_assert(kKPack % K3 == 0), in csrc/composable_kernel/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp.
The key is that the warp size of the 7900xtx is 32, whereas MI200 and MI300 have 64.

kBlockSize = NumWarps * warp_size = (1 * 4 * 1) * 32 = 128
kMPerBlock (M0) = 128
kNPerBlock (N0) = 128
kQKHeaddim = 32
BiasDataType = unsigned short (16 bits)
total_pixels = kMPerBlock * kNPerBlock / kBlockSize = 128 * 128 / 128 = 128
Since total_pixels > 32, N1 = 8
N0 = kNPerBlock / N1 = 128 / 8 = 16
M3 = total_pixels / N1 = 128 / 8 = 16
kKPack = 16 / sizeof(BiasDataType) = 16 / 2 = 8
8 % 16 != 0, so the static_assert fails
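
The same arithmetic can be reproduced in a few lines; this is just a re-derivation of the numbers above (the names mirror the CK pipeline policy, but the snippet is not taken from it):

```python
# Re-derivation of the failing configuration check for warp size 32 (gfx1100).
num_warps = 1 * 4 * 1                 # product of Gemm0BlockWarps
warp_size = 32                        # RDNA3; MI200/MI300 use 64
k_block_size = num_warps * warp_size  # 128

k_m_per_block, k_n_per_block = 128, 128
sizeof_bias = 2                       # BiasDataType is 16-bit

total_pixels = k_m_per_block * k_n_per_block // k_block_size  # 128
n1 = 8                                # since total_pixels > 32
m3 = total_pixels // n1               # 16
k_kpack = 16 // sizeof_bias           # 8

# 8 % 16 != 0, so the pipeline's static_assert rejects this configuration.
print(k_kpack % m3 == 0)              # False
```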

Here is the problematic template instantiation I figured out; I hope this helps:


ck_tile::BlockFmhaBwdPipelineProblem<
        unsigned short, unsigned short, unsigned short, unsigned short, float,          float,        float,        unsigned short, unsigned char,          unsigned short, unsigned short, unsigned short, unsigned short, unsigned short, unsigned short, 
        # QDataType,    KDataType,      VDataType,      GemmDataType,   LSEDataType,    AccDataType,  DDataType,    BiasDataType,   RandValOutputDataType,  ODataType,      OGradDataType,  QGradDataType,  KGradDataType,  VGradDataType,  BiasGradDataType    
        ck_tile::TileFmhaBwdShape<  # BlockFmhaShape
            ck_tile::sequence<128, 128, 32, 32, 32, 32, 32, 32, 32>,    # BlockTile: [kM0, KN0, KK0, KK1, KK2, KK3, KK4, kQKHeaddim, kVHeaddim]
            ck_tile::sequence<1, 4, 1>,         # Gemm0BlockWarps
            ck_tile::sequence<32, 32, 16>,      # Gemm0WarpTile
            ck_tile::sequence<4, 1, 1>,         # Gemm1BlockWarps
            ck_tile::sequence<32, 32, 16>,      # Gemm1WarpTile
            ck_tile::sequence<1, 4, 1>,         # Gemm2BlockWarps
            ck_tile::sequence<32, 32, 16>,      # Gemm2WarpTile
            ck_tile::sequence<4, 1, 1>,         # Gemm3BlockWarps
            ck_tile::sequence<32, 32, 16>,      # Gemm3WarpTile
            ck_tile::sequence<4, 1, 1>,         # Gemm4BlockWarps
            ck_tile::sequence<32, 32, 16>       # Gemm4WarpTile
        >, # NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}), kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(), warp_size = 32 
        false, #kIsGroupMode
        ck_tile::SimplifiedGenericAttentionMask<>,  # FmhaMask
        ck_tile::TileFmhaTraits<false, false, false, false, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, 1>    #Traits
    >

I'll keep digging into it. My plan is to generate a group of kernels that fit the 7900xtx; I'll post my results if I have any findings.

@ehartford

ehartford commented Jul 16, 2024

I tried this on a known good configuration, using TRL

I am able to run it without flash attention, and I am able to run it with the ROCm version of flash attention.

But using this PR, I get an error.

I also repro this same error in axolotl, so it's not a specific problem with TRL.

The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
Traceback (most recent call last):
  File "/home/erichartford/flash-attention/./thingy.py", line 107, in <module>
    trainer.train()
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/trl/trainer/sft_trainer.py", line 451, in train
    output = super().train(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/trainer.py", line 1932, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/trainer.py", line 2268, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/trainer.py", line 3307, in training_step
    loss = self.compute_loss(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/trainer.py", line 3338, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/accelerate/utils/operations.py", line 819, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/accelerate/utils/operations.py", line 807, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/peft/peft_model.py", line 1430, in forward
    return self.base_model(
           ^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/peft/tuners/tuners_utils.py", line 179, in forward
    return self.model.forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1174, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 967, in forward
    layer_outputs = self._gradient_checkpointing_func(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/_compile.py", line 31, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 481, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 255, in forward
    outputs = run_function(*args)
              ^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 718, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                          ^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 467, in forward
    attn_output = self._flash_attention_forward(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 532, in _flash_attention_forward
    attn_output = flash_attn_func(
                  ^^^^^^^^^^^^^^^^
  File "/home/erichartford/flash-attention/flash_attn/flash_attn_interface.py", line 882, in flash_attn_func
    return FlashAttnFunc.apply(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/miniconda3/envs/flashattn/lib/python3.11/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/flash-attention/flash_attn/flash_attn_interface.py", line 548, in forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
                                                                ^^^^^^^^^^^^^^^^^^^^
  File "/home/erichartford/flash-attention/flash_attn/flash_attn_interface.py", line 51, in _flash_attn_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
                                                                ^^^^^^^^^^^^^^^^^^^^
TypeError: fwd(): incompatible function arguments. The following argument types are supported:
    1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: torch.Tensor, arg3: Optional[torch.Tensor], arg4: Optional[torch.Tensor], arg5: float, arg6: float, arg7: bool, arg8: int, arg9: int, arg10: bool, arg11: Optional[torch.Generator]) -> list[torch.Tensor]

Invoked with: tensor([[[[ 1.3281e-01,  2.2461e-01, -3.9551e-02,  ...,  3.5156e-02,
            5.8838e-02,  1.4648e-01],
          [-1.3733e-02,  3.4332e-03, -4.8633e-01,  ...,  4.6631e-02,
            2.4805e-01,  4.3701e-02],
          ....
          [-2.1362e-02, -4.7607e-02, -2.6611e-02,  ..., -1.8677e-02,
           -3.3447e-02, -7.6660e-02]]]], device='cuda:0', dtype=torch.bfloat16), None, None, 0.0, 0.08838834764831845, True, -1, -1, 0.0, False, None

Also please note that we were able to test successfully when using this branch in the ROCm repo.

It is only when we merge it into the main branch of the Dao-AILab repo that this error presents itself.

We have independently reproduced this error on several machines, and several environments.
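
One diagnostic that may help narrow this down (my own suggestion, not something established in this thread) is to confirm which compiled extension the wrapper is actually loading and what argument list its fwd() exposes, since a stale or mismatched build would produce exactly this kind of signature error:

```python
# Diagnostic sketch (a suggestion, not a confirmed fix): check which extension
# module the Python wrapper picks up and what fwd() signature it was built with.
import flash_attn_2_cuda as flash_attn_cuda  # the extension imported by flash_attn_interface.py

print(flash_attn_cuda.__file__)      # path of the .so actually being loaded
print(flash_attn_cuda.fwd.__doc__)   # pybind11 docstring with the expected arguments
```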

@rocking5566 we would be happy to have a call and troubleshoot this with you, if you would like to reach out. eric@tensorwave.com

@poyenc
Contributor

poyenc commented Jul 17, 2024

@minzhezhou Thanks for your time. We only support MI200 & MI300 at this time, which is why we put gfx90a/gfx94x in the allowed_archs list. Other targets should be blocked anyway.

@minzhezhou

@minzhezhou Thanks for your time. We only support MI200 & MI300 at this time, which is why we put gfx90a/gfx94x in the allowed_archs list. Other targets should be blocked anyway.

Hi @poyenc, thanks for the reminder. Do you mean it is technically impossible to make it work for Navi, or is it just not on the official roadmap yet?
How about gfx908?

@poyenc
Contributor

poyenc commented Jul 18, 2024

Hi @poyenc, thanks for the reminder. Do you mean it is technically impossible to make it work for Navi, or is it just not on the official roadmap yet?
How about gfx908?

@minzhezhou Thanks for your attention. Targets other than gfx90a & gfx94x are not on the roadmap yet; currently we are focusing on MI300 platforms.

@deke997

deke997 commented Jul 23, 2024

Now that the conflicts are resolved, can we merge this?

Thanks!

@tridao tridao merged commit d8f104e into Dao-AILab:main Jul 23, 2024
@linchen111

linchen111 commented Jul 24, 2024

I have an MI50 and an MI100, looking forward to gfx906 and gfx908 support; they are so much cheaper.

@yiakwy-xpu-ml-framework-team

yiakwy-xpu-ml-framework-team commented Jul 24, 2024

Hi @poyenc, thanks for the reminder. Do you mean it is technically impossible to make it work for Navi, or is it just not on the official roadmap yet?
How about gfx908?

@minzhezhou Thanks for your attention. Targets other than gfx90a & gfx94x are not on the roadmap yet; currently we are focusing on MI300 platforms.

You can use this snippet to query the arch and exclude any archs that have not been well tested (only gfx942 is tested):

target_amdarch=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*')
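
A rough Python equivalent of the same gating idea (the allowed_archs values come from this thread; the subprocess call and the exact gating policy are my assumptions):

```python
# Sketch: detect the GPU arch via rocminfo and refuse archs that are not on the
# tested list. Parsing mirrors the shell one-liner above; the policy is an assumption.
import re
import subprocess

ALLOWED_ARCHS = {"gfx90a", "gfx940", "gfx941", "gfx942"}  # from allowed_archs in this thread

def detect_amd_arch() -> str:
    out = subprocess.run(["/opt/rocm/bin/rocminfo"], capture_output=True, text=True).stdout
    match = re.search(r"gfx\w+", out)
    if match is None:
        raise RuntimeError("no gfx target found in rocminfo output")
    return match.group(0)

arch = detect_amd_arch()
if arch not in ALLOWED_ARCHS:
    raise SystemExit(f"{arch} is not an officially supported target for this port")
```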

@minzhezhou

I have an MI50 and an MI100, looking forward to gfx906 and gfx908 support; they are so much cheaper.

gfx906 and gfx908 can at least compile; they have warp_size = 64.

@rocking5566
Contributor Author

I have an MI50 and an MI100, looking forward to gfx906 and gfx908 support; they are so much cheaper.

gfx906 and gfx908 can at least compile; they have warp_size = 64.

The current FA compiles successfully on MI100.
However, I found that some test cases might fail.
We may fix this in the future.

@ehartford

That's excellent news! I can't wait to try it on my MI100s.

@linchen111

I have an MI50 and an MI100, looking forward to gfx906 and gfx908 support; they are so much cheaper.

gfx906 and gfx908 can at least compile; they have warp_size = 64.

The current FA compiles successfully on MI100. However, I found that some test cases might fail. We may fix this in the future.

MI100 worked, but MI50 failed all tests.

@rocking5566
Contributor Author

rocking5566 commented Aug 1, 2024

I have an MI50 and an MI100, looking forward to gfx906 and gfx908 support; they are so much cheaper.

gfx906 and gfx908 can at least compile; they have warp_size = 64.

The current FA compiles successfully on MI100. However, I found that some test cases might fail. We may fix this in the future.

MI100 worked, but MI50 failed all tests.

That is why we only claim MI200 & MI300 as officially supported.
Other platforms might fail some test cases on some versions of ROCm.

@linchen111

linchen111 commented Aug 1, 2024

I have an MI50 and an MI100, looking forward to gfx906 and gfx908 support; they are so much cheaper.

gfx906 and gfx908 can at least compile; they have warp_size = 64.

The current FA compiles successfully on MI100. However, I found that some test cases might fail. We may fix this in the future.

MI100 worked, but MI50 failed all tests.

That is why we only claim MI200 & MI300 as officially supported. Other platforms might fail some test cases on some versions of ROCm.

I installed ROCm 6.1.x on MI50; all tests failed -.-

Hoping for MI50 support~

@foreverlms

I am wondering why the KV cache / paged attention API fwd_kvcache is not supported.

@rocking5566
Contributor Author

I am wondering why the KV cache / paged attention API fwd_kvcache is not supported.

fwd_kvcache is on our roadmap. It will be coming soon.

@foreverlms

I am wondering why the KV cache / paged attention API fwd_kvcache is not supported.

fwd_kvcache is on our roadmap. It will be coming soon.

Thanks for the kind reply!
Also, it seems you are a developer of CK. I am learning how to program with CK, but there are not many tutorials I can refer to. Could you please make a recommendation?
