Merge pull request #74 from ROCm/ck_tile/kvcache
Ck tile/kvcache
rocking5566 authored Aug 28, 2024
2 parents d38c59b + 007ae03 commit 7259227
Showing 9 changed files with 950 additions and 33 deletions.
2 changes: 1 addition & 1 deletion csrc/composable_kernel
Submodule composable_kernel updated 77 files
+5 −6 CMakeLists.txt
+37 −11 Jenkinsfile
+20 −4 client_example/24_grouped_conv_activation/CMakeLists.txt
+834 −0 client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/common.hpp
+58 −0 ...nt_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/conv3d_fwd_convscale_amax_fp8.cpp
+58 −0 ...ample/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/conv3d_fwd_convscale_relu_amax_fp8.cpp
+2 −2 codegen/CMakeLists.txt
+16 −12 codegen/test/CMakeLists.txt
+0 −2 codegen/test/rtc/CMakeLists.txt
+1 −1 docs/sphinx/requirements.in
+1 −1 docs/sphinx/requirements.txt
+2 −2 example/01_gemm/gemm_xdl_fp8.cpp
+5 −5 example/01_gemm/run_gemm_example.inc
+1 −0 example/62_convnd_activ/CMakeLists.txt
+14 −0 example/62_convnd_activ/convscale_reduce/CMakeLists.txt
+502 −0 example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp
+82 −0 example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_amax_fp8.cpp
+82 −0 example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_relu_amax_fp8.cpp
+98 −0 example/62_convnd_activ/convscale_reduce/run_convnd_fwd_example.inc
+37 −3 example/ck_tile/01_fmha/CMakeLists.txt
+13 −1 example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
+355 −0 example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py
+127 −61 example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
+604 −164 example/ck_tile/01_fmha/fmha_fwd.cpp
+293 −30 example/ck_tile/01_fmha/fmha_fwd.hpp
+16 −11 example/ck_tile/01_fmha/generate.py
+84 −0 example/ck_tile/01_fmha/rotary.hpp
+2 −3 example/ck_tile/01_fmha/script/benchmark_bwd.sh
+2 −3 example/ck_tile/01_fmha/script/benchmark_fwd.sh
+2 −3 example/ck_tile/01_fmha/script/smoke_test_bwd.sh
+94 −41 example/ck_tile/01_fmha/script/smoke_test_fwd.sh
+94 −21 example/ck_tile/01_fmha/utils.hpp
+3 −3 include/ck/ck.hpp
+3 −1 include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp
+10 −3 include/ck_tile/core/numeric/math.hpp
+50 −0 include/ck_tile/core/tensor/tile_window.hpp
+17 −0 include/ck_tile/core/utility/type_traits.hpp
+1 −0 include/ck_tile/host.hpp
+9 −0 include/ck_tile/host/host_tensor.hpp
+5 −5 include/ck_tile/host/kernel_launch.hpp
+73 −0 include/ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp
+6 −2 include/ck_tile/ops/fmha.hpp
+19 −3 include/ck_tile/ops/fmha/block/block_position_encoding.hpp
+108 −0 include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp
+279 −0 include/ck_tile/ops/fmha/block/page_block_navigator.hpp
+679 −0 include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp
+42 −0 include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp
+182 −177 include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
+277 −0 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp
+288 −0 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp
+98 −80 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
+0 −770 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp
+0 −19 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp
+83 −31 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
+13 −3 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+37 −21 include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
+37 −0 ...ibrary/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp
+83 −3 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp
+83 −1 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp
+13 −0 library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp
+52 −6 library/include/ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp
+18 −9 ...clude/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.hpp
+3 −2 library/include/ck/library/utility/check_err.hpp
+2 −1 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt
+61 −0 ...d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp
+2 −1 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/CMakeLists.txt
+61 −0 ...scale_relu/xdl/device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp
+1 −4 ...3d_fwd_convscale_relu/xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
+3 −2 library/src/tensor_operation_instance/gpu/permute_scale/CMakeLists.txt
+28 −0 library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_fp32_fp8_instances.cpp
+18 −9 library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.cpp
+88 −75 profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp
+1 −2 profiler/src/profile_grouped_conv_bwd_weight.cpp
+386 −0 script/convert_miopen_driver_to_profiler.py
+10 −0 test/data_type/CMakeLists.txt
+23 −19 test/data_type/test_bf8.cpp
+23 −19 test/data_type/test_fp8.cpp
41 changes: 32 additions & 9 deletions csrc/flash_attn_ck/flash_api.cpp
@@ -20,16 +20,16 @@ mha_fwd(at::Tensor &q,
c10::optional<at::Generator> gen_);

std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
c10::optional<const at::Tensor> &leftpad_k_, // batch_size
c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
@@ -89,11 +89,34 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state);

std::vector<at::Tensor>
mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
c10::optional<const at::Tensor> &seqlens_k_, // batch_size
c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
c10::optional<const at::Tensor> &leftpad_k_, // batch_size
c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
int num_splits);
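
The is_rotary_interleaved flag above selects between two rotary-embedding layouts: interleaved pairs dimension 0 with 1, 2 with 3, and so on, while the non-interleaved form pairs dimension i with i + rotary_dim / 2. The snippet below is a minimal, self-contained sketch of that pairing convention only; it is not the ck_tile rotary kernel, and the helper name apply_rotary_reference is hypothetical. The cos/sin vectors have rotary_dim / 2 entries, matching the seqlen_ro x (rotary_dim / 2) shape of rotary_cos_ and rotary_sin_ above.

#include <cmath>
#include <cstdio>
#include <vector>

// Hypothetical reference helper: applies rotary embedding to one head vector
// x[0..rotary_dim) in place. cos_v/sin_v hold rotary_dim / 2 angles.
void apply_rotary_reference(std::vector<float>& x,
                            const std::vector<float>& cos_v,
                            const std::vector<float>& sin_v,
                            int rotary_dim,
                            bool interleaved)
{
    for (int i = 0; i < rotary_dim / 2; ++i) {
        // interleaved pairs (2i, 2i+1); otherwise pairs (i, i + rotary_dim/2)
        const int i0 = interleaved ? 2 * i : i;
        const int i1 = interleaved ? 2 * i + 1 : i + rotary_dim / 2;
        const float x0 = x[i0];
        const float x1 = x[i1];
        x[i0] = x0 * cos_v[i] - x1 * sin_v[i];
        x[i1] = x0 * sin_v[i] + x1 * cos_v[i];
    }
}

int main()
{
    const int rotary_dim = 8;
    std::vector<float> x = {1, 2, 3, 4, 5, 6, 7, 8};
    std::vector<float> cos_v(rotary_dim / 2), sin_v(rotary_dim / 2);
    for (int i = 0; i < rotary_dim / 2; ++i) {
        // illustrative angles; a real cache is built from position and theta
        cos_v[i] = std::cos(0.1f * (i + 1));
        sin_v[i] = std::sin(0.1f * (i + 1));
    }
    apply_rotary_reference(x, cos_v, sin_v, rotary_dim, /*interleaved=*/true);
    for (float v : x) std::printf("%.3f ", v);
    std::printf("\n");
    return 0;
}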

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.doc() = "FlashAttention";
m.def("fwd", &mha_fwd, "Forward pass");
m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
m.def("bwd", &mha_bwd, "Backward pass");
m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
}
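
The kcache/vcache comments above describe the paged layout used when a block_table is supplied: the cache is stored as num_blocks x page_block_size x num_heads_k x head_size, and block_table (batch_size x max_num_blocks_per_seq) maps each sequence's logical token positions onto physical blocks. Below is a minimal sketch of that index arithmetic, assuming row-major contiguous tensors; the function and variable names are illustrative and not part of this extension.

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Illustrative index math for a paged KV cache laid out as
// [num_blocks, page_block_size, num_heads_k, head_size] (row-major, contiguous).
// block_table_row points at one batch entry's row of the
// [batch_size, max_num_blocks_per_seq] block table.
std::size_t paged_kv_offset(const int32_t* block_table_row,
                            int page_block_size,
                            int num_heads_k,
                            int head_size,
                            int token_pos,   // logical position within the sequence
                            int head_idx,
                            int dim_idx)
{
    const int logical_block  = token_pos / page_block_size;
    const int in_block_pos   = token_pos % page_block_size;
    const int physical_block = block_table_row[logical_block];
    return (((static_cast<std::size_t>(physical_block) * page_block_size
              + in_block_pos) * num_heads_k
             + head_idx) * head_size
            + dim_idx);
}

int main()
{
    // Example: with page_block_size = 16, logical token 37 sits at offset 5
    // inside whichever physical block the table stores at slot 2 (here, 11).
    const int32_t block_table_row[] = {7, 3, 11, 0};
    std::printf("%zu\n", paged_kv_offset(block_table_row, 16, 4, 128,
                                         /*token_pos=*/37, /*head_idx=*/0, /*dim_idx=*/0));
    return 0;
}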
34 changes: 34 additions & 0 deletions csrc/flash_attn_ck/flash_common.cpp
@@ -0,0 +1,34 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/

#include "flash_common.hpp"

namespace flash {
int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
{
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess)
return num_splits;

hipDeviceProp_t props{};
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess)
return num_splits;

// TODO - tile size should match the TileFmhaShape, hardcode for now
const int kM0 = 128;
const int kN1 = hdim_v;

const int num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
const int num_n_blocks = (hdim_v + kN1 - 1) / kN1;

if(num_splits < 1 && p_drop == 0.0f)
return num_splits_heuristic_ck(
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);

return num_splits;
}

} // namespace flash
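
As a rough illustration of the block arithmetic above (the shapes and CU count are illustrative, not taken from this PR): for a single-token decode step, kM0 = 128 gives num_m_blocks = 1, and because kN1 equals hdim_v, num_n_blocks is also 1, so the heuristic is driven by batch * nhead relative to twice the CU count. A standalone sketch of that arithmetic:

#include <cstdio>

int main()
{
    // Illustrative shapes only: a single-token decode step.
    const int batch = 8, nhead = 32, max_seqlen_q = 1, hdim_v = 128;
    const int multiProcessorCount = 80;  // hypothetical CU count

    const int kM0 = 128;   // matches the hard-coded tile size above
    const int kN1 = hdim_v;

    const int num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;  // = 1
    const int num_n_blocks = (hdim_v + kN1 - 1) / kN1;        // = 1

    std::printf("workgroups without splits: %d, target: %d (num_n_blocks = %d)\n",
                batch * nhead * num_m_blocks, multiProcessorCount * 2, num_n_blocks);
    return 0;
}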
40 changes: 39 additions & 1 deletion csrc/flash_attn_ck/flash_common.hpp
@@ -24,7 +24,7 @@
namespace flash {
// Copy from PyTorch
// https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
static std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
inline std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
if (arg.captured_) {
// static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
// *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
@@ -35,4 +35,42 @@ static std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
}
}

inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
// If we have enough to almost fill the SMs, then just use 1 split
if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
max_splits = std::min({max_splits, num_SMs, num_n_blocks});
float max_efficiency = 0.f;
std::vector<float> efficiency;
efficiency.reserve(max_splits);
auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
// Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
// we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
// (i.e. it's 11 splits anyway).
// So we check if the number of blocks per split is the same as the previous num_splits.
auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
};
for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
if (!is_split_eligible(num_splits)) {
efficiency.push_back(0.f);
} else {
float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
float eff = n_waves / ceil(n_waves);
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
if (eff > max_efficiency) { max_efficiency = eff; }
efficiency.push_back(eff);
}
}
for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
if (!is_split_eligible(num_splits)) { continue; }
if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
// printf("num_splits chosen = %d\n", num_splits);
return num_splits;
}
}
return 1;
}

int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits);

} // namespace flash
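
The eligibility rule in num_splits_heuristic_ck can be checked in isolation. The standalone sketch below (not part of the extension) reproduces the 64-block example from the comment in the function: ceildiv(64, 11) and ceildiv(64, 12) both equal 6, so 12 splits is rejected while 11 is kept.

#include <cstdio>

int main()
{
    const int num_n_blocks = 64;
    auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
    // Same rule as in num_splits_heuristic_ck: a split count is eligible only
    // if it changes the number of blocks handled per split.
    auto is_split_eligible = [&](int num_splits) {
        return num_splits == 1 ||
               ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
    };
    for (int s = 10; s <= 13; ++s)
        std::printf("num_splits = %d: blocks/split = %d, eligible = %d\n",
                    s, ceildiv(num_n_blocks, s), is_split_eligible(s) ? 1 : 0);
    return 0;
}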
10 changes: 0 additions & 10 deletions csrc/flash_attn_ck/mha_fwd.cpp
@@ -96,8 +96,6 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
v.data_ptr(),
alibi_slopes_ptr, // bias
has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
nullptr, // lse_acc
nullptr, // o_acc
has_lse ? softmax_lse.data_ptr() : nullptr,
out.data_ptr(),
nullptr, // seqstart_q
Expand All @@ -111,7 +109,6 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
d, // hdim_v
h, // nhead
h_k, // nhead_k
1, // num_splits
softmax_scale, // scale_s
1, // scale_p
1, // scale_o
@@ -120,28 +117,21 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
stride_v,
stride_alibi_slopes,
stride_randval,
0, // stride_o_acc,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
0, // nhead_stride_bias, FA without bias
nhead_stride_randval,
nhead_stride_lse,
0, // nhead_stride_lse_acc
0, // nhead_stride_o_acc
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
0, // batch_stride_bias, FA without bias
batch_stride_randval,
batch_stride_lse,
0, // batch_stride_lse_acc
0, // batch_stride_o_acc
batch_stride_o,
0, // split_stride_lse_acc
0, // split_stride_o_acc
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),