diff --git a/aiter/ops/fused_mrope_rms.py b/aiter/ops/fused_mrope_rms.py index a337d032ed..9289ce2ec6 100644 --- a/aiter/ops/fused_mrope_rms.py +++ b/aiter/ops/fused_mrope_rms.py @@ -23,3 +23,20 @@ def fused_mrope_3d_rms( is_interleaved: bool, eps: float, ) -> None: ... + + +@compile_ops("module_fused_mrope_rms") +def fused_rope_rms( + qkv: Tensor, + qw: Tensor, + kw: Tensor, + cos_sin: Tensor, + positions: Tensor, + num_tokens: int, + num_heads_q: int, + num_heads_k: int, + num_heads_v: int, + head_size: int, + is_neox_style: bool, + eps: float, +) -> None: ... diff --git a/csrc/include/fused_mrope_rms.h b/csrc/include/fused_mrope_rms.h index cca257efc7..ca430930c2 100644 --- a/csrc/include/fused_mrope_rms.h +++ b/csrc/include/fused_mrope_rms.h @@ -21,3 +21,16 @@ void fused_mrope_3d_rms(Tensor& qkv, std::vector mrope_section_, bool is_interleaved, double eps); + +void fused_rope_rms(Tensor& qkv, + Tensor& qw, + Tensor& kw, + Tensor& cos_sin, + Tensor& positions, + int64_t num_tokens, + int64_t num_heads_q, + int64_t num_heads_k, + int64_t num_heads_v, + int64_t head_size, + bool is_neox_style, + double eps); diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 17ed8ff958..2357d741c6 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -1315,7 +1315,9 @@ namespace py = pybind11; m.def("rope_cached_positions_offsets_fwd_impl", &rope_cached_positions_offsets_fwd_impl); \ m.def("rope_cached_positions_offsets_2c_fwd_impl", &rope_cached_positions_offsets_2c_fwd_impl); -#define FUSED_MROPE_RMS_PYBIND m.def("fused_mrope_3d_rms", &fused_mrope_3d_rms); +#define FUSED_MROPE_RMS_PYBIND \ + m.def("fused_mrope_3d_rms", &fused_mrope_3d_rms); \ + m.def("fused_rope_rms", &fused_rope_rms); #define SMOOTHQUANT_PYBIND \ m.def("smoothquant_fwd", &smoothquant_fwd); \ @@ -1435,30 +1437,30 @@ namespace py = pybind11; py::arg("stride0"), \ py::arg("stride1")); -#define MLA_METADATA_PYBIND \ - m.def("get_mla_metadata_v1", \ - &get_mla_metadata_v1, \ - "get_mla_metadata_v1", \ - py::arg("seqlens_qo_indptr"), \ - py::arg("seqlens_kv_indptr"), \ - py::arg("num_heads_per_head_k"), \ - py::arg("num_heads_k"), \ - py::arg("is_causal"), \ - py::arg("work_metadata_ptrs"), \ - py::arg("work_info_set"), \ - py::arg("work_indptr"), \ - py::arg("reduce_indptr"), \ - py::arg("reduce_final_map"), \ - py::arg("reduce_partial_map"), \ - py::arg("kv_granularity") = 16, \ - py::arg("max_seqlen_qo") = -1, \ - py::arg("uni_seqlen_qo") = -1, \ - py::arg("fast_mode") = true, \ - py::arg("topk") = -1, \ - py::arg("max_split_per_batch") = -1, \ - py::arg("intra_batch_mode") = false, \ - py::arg("dtype_q") = std::nullopt, \ - py::arg("dtype_kv") = std::nullopt); \ +#define MLA_METADATA_PYBIND \ + m.def("get_mla_metadata_v1", \ + &get_mla_metadata_v1, \ + "get_mla_metadata_v1", \ + py::arg("seqlens_qo_indptr"), \ + py::arg("seqlens_kv_indptr"), \ + py::arg("num_heads_per_head_k"), \ + py::arg("num_heads_k"), \ + py::arg("is_causal"), \ + py::arg("work_metadata_ptrs"), \ + py::arg("work_info_set"), \ + py::arg("work_indptr"), \ + py::arg("reduce_indptr"), \ + py::arg("reduce_final_map"), \ + py::arg("reduce_partial_map"), \ + py::arg("kv_granularity") = 16, \ + py::arg("max_seqlen_qo") = -1, \ + py::arg("uni_seqlen_qo") = -1, \ + py::arg("fast_mode") = true, \ + py::arg("topk") = -1, \ + py::arg("max_split_per_batch") = -1, \ + py::arg("intra_batch_mode") = false, \ + py::arg("dtype_q") = std::nullopt, \ + py::arg("dtype_kv") = std::nullopt); \ m.def("get_mla_metadata_v1_no_redundant", &get_mla_metadata_v1_no_redundant); #define PA_METADATA_PYBIND \ diff --git a/csrc/kernels/fused_mrope_rms.cu b/csrc/kernels/fused_mrope_rms.cu index 080d91a689..ab3703a276 100644 --- a/csrc/kernels/fused_mrope_rms.cu +++ b/csrc/kernels/fused_mrope_rms.cu @@ -23,7 +23,7 @@ __inline__ __device__ T warp_reduce_sum(T val) { template __inline__ __device__ T warp_shfl_sync(T val, int src_id) { - return __shfl_sync(__activemask(), val, src_id, 32); + return __shfl(val, src_id, 32); } } // namespace block_utils @@ -106,7 +106,6 @@ __device__ __forceinline__ void warp_rms_norm_( int warp_t_id = threadIdx.x % 32; acc = block_utils::warp_reduce_sum(acc); acc = block_utils::warp_shfl_sync(acc, 0); - __syncwarp(); auto s_val = rsqrtf(acc / rms_dim + rms_eps); #pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { @@ -259,7 +258,7 @@ __global__ void fused_mrope_rms_noneox_kernel( template void fused_rope_rms( - T *qkv, const T *q_w, const T *k_w, const T *cos_sin, const int64_t *positions, + T *qkv, const T *q_w, const T *k_w, const T *cos_sin, const int64_t *positions, int64_t ps0, int64_t ps1, int64_t num_tokens, int64_t num_heads_q, int64_t num_heads_k, int64_t num_heads_v, int64_t head_size, bool is_neox_style, double eps, hipStream_t stream) { TORCH_CHECK(head_size == 64 || head_size == 128 || head_size == 256); @@ -270,13 +269,13 @@ void fused_rope_rms( dim3 numBlocks((total_warps + num_warps_per_block - 1) / num_warps_per_block); std::array mrope_section = {0}; -#define DISPATCH_NEOX(HEAD_SIZE) \ - if (is_neox_style) { \ - fused_mrope_rms_neox_kernel<<>>( \ - qkv, q_w, k_w, cos_sin, positions, num_heads_q, num_heads_k, num_heads_v, eps, mrope_section, num_tokens, total_warps); \ - } else { \ - fused_mrope_rms_noneox_kernel<<>>( \ - qkv, q_w, k_w, cos_sin, positions, num_heads_q, num_heads_k, num_heads_v, eps, mrope_section, num_tokens, total_warps); \ +#define DISPATCH_NEOX(HEAD_SIZE) \ + if (is_neox_style) { \ + fused_mrope_rms_neox_kernel<<>>( \ + qkv, q_w, k_w, cos_sin, positions, ps0, ps1, num_heads_q, num_heads_k, num_heads_v, eps, mrope_section, num_tokens, total_warps); \ + } else { \ + fused_mrope_rms_noneox_kernel<<>>( \ + qkv, q_w, k_w, cos_sin, positions, ps0, ps1, num_heads_q, num_heads_k, num_heads_v, eps, mrope_section, num_tokens, total_warps); \ } switch (head_size) { @@ -402,3 +401,36 @@ void fused_mrope_3d_rms(Tensor &qkv, Tensor &qw, Tensor &kw, Tensor &cos_sin, Te stream); }); } + +void fused_rope_rms(Tensor &qkv, Tensor &qw, Tensor &kw, Tensor &cos_sin, Tensor &positions, + int64_t num_tokens, int64_t num_heads_q, int64_t num_heads_k, int64_t num_heads_v, int64_t head_size, + bool is_neox_style, double eps) { + TORCH_CHECK(qkv.is_contiguous() && qw.is_contiguous() && kw.is_contiguous() && cos_sin.is_contiguous()); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(qkv)); + auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + auto pos_strides = positions.strides(); + TORCH_CHECK(pos_strides.size() == 1); + AT_DISPATCH_FLOATING_TYPES_AND2( + kBFloat16, + kHalf, + qkv.scalar_type(), + "fused_rope_rms", [&] { + using T = KernelElementType::type; + rope_rms::fused_rope_rms( + (T*)qkv.data_ptr(), + (T*)qw.data_ptr(), + (T*)kw.data_ptr(), + (T*)cos_sin.data_ptr(), + positions.data_ptr(), + 0, + pos_strides[0], + num_tokens, + num_heads_q, + num_heads_k, + num_heads_v, + head_size, + is_neox_style, + eps, + stream); + }); +} diff --git a/op_tests/test_fused_mrope_rms.py b/op_tests/test_fused_mrope_rms.py index 4e212b864d..7dcc975b4e 100644 --- a/op_tests/test_fused_mrope_rms.py +++ b/op_tests/test_fused_mrope_rms.py @@ -68,7 +68,7 @@ def run_torch_mrope_3d_rms( qw: Tensor, # contiguous (head_size) kw: Tensor, # contiguous (head_size) cos_sin: Tensor, # contiguous (max_positions * head_size) - positions: Tensor, # contiguous (3 * num_tokens) + positions: Tensor, # contiguous (3 * num_tokens) or (num_tokens) num_tokens: int, num_heads_q: int, num_heads_k: int, @@ -78,6 +78,7 @@ def run_torch_mrope_3d_rms( mrope_section: List[int], is_interleaved: bool, eps: float, + is_mrope: bool, ): q_size = num_heads_q * head_size k_size = num_heads_k * head_size @@ -94,22 +95,24 @@ def run_torch_mrope_3d_rms( k = k_by_head.view(k.shape) cos_sin = cos_sin.view(max_positions, head_size) - positions = positions.view(3, num_tokens) + if is_mrope: + positions = positions.view(3, num_tokens) cos_sin = cos_sin[positions] cos, sin = cos_sin.chunk(2, dim=-1) - if is_interleaved: - cos = apply_interleaved_rope(cos, mrope_section) - sin = apply_interleaved_rope(sin, mrope_section) - else: - cos = torch.cat( - [m[i] for i, m in enumerate(cos.split(mrope_section, dim=-1))], - dim=-1, - ) - sin = torch.cat( - [m[i] for i, m in enumerate(sin.split(mrope_section, dim=-1))], - dim=-1, - ) + if is_mrope: + if is_interleaved: + cos = apply_interleaved_rope(cos, mrope_section) + sin = apply_interleaved_rope(sin, mrope_section) + else: + cos = torch.cat( + [m[i] for i, m in enumerate(cos.split(mrope_section, dim=-1))], + dim=-1, + ) + sin = torch.cat( + [m[i] for i, m in enumerate(sin.split(mrope_section, dim=-1))], + dim=-1, + ) q_shape = q.shape q = q.view(num_tokens, -1, head_size) @@ -140,24 +143,42 @@ def run_aiter_mrope_3d_rms( mrope_section: List[int], is_interleaved: bool, eps: float, + is_mrope: bool, ): qkv = qkv.clone() # inplace op - aiter.fused_mrope_3d_rms( - qkv, - qw, - kw, - cos_sin, - positions, - num_tokens, - num_heads_q, - num_heads_k, - num_heads_v, - head_size, - is_neox_style, - mrope_section, - is_interleaved, - eps, - ) + + if is_mrope: + aiter.fused_mrope_3d_rms( + qkv, + qw, + kw, + cos_sin, + positions, + num_tokens, + num_heads_q, + num_heads_k, + num_heads_v, + head_size, + is_neox_style, + mrope_section, + is_interleaved, + eps, + ) + else: + aiter.fused_rope_rms( + qkv, + qw, + kw, + cos_sin, + positions, + num_tokens, + num_heads_q, + num_heads_k, + num_heads_v, + head_size, + is_neox_style, + eps, + ) q_size = num_heads_q * head_size k_size = num_heads_k * head_size @@ -179,7 +200,8 @@ def test_mrope_3d_rms( is_neox_style, mrope_section, is_interleaved, - eps=1e-6, + eps, + is_mrope, ): qkv = torch.randn( (num_tokens, num_heads_q + num_heads_k + num_heads_v, head_size), @@ -189,8 +211,12 @@ def test_mrope_3d_rms( qw = torch.randn(head_size, dtype=dtype, device="cuda") kw = torch.randn(head_size, dtype=dtype, device="cuda") cos_sin = torch.randn((max_positions, head_size), dtype=dtype, device="cuda") + if is_mrope: + pos_shape = (3, num_tokens) + else: + pos_shape = (num_tokens,) positions = torch.randint( - 0, max_positions, (3, num_tokens), dtype=torch.int64, device="cuda" + 0, max_positions, pos_shape, dtype=torch.int64, device="cuda" ) (q_ref, k_ref, v_ref), avg_torch = run_torch_mrope_3d_rms( @@ -208,6 +234,7 @@ def test_mrope_3d_rms( mrope_section, is_interleaved, eps, + is_mrope, ) (q, k, v), avg_cu = run_aiter_mrope_3d_rms( qkv, @@ -224,12 +251,12 @@ def test_mrope_3d_rms( mrope_section, is_interleaved, eps, + is_mrope, ) info = f"dtype:{dtype}, num_tokens:{num_tokens}, num_heads_q:{num_heads_q}, num_heads_k:{num_heads_k}, num_heads_v:{num_heads_v}, head_size:{head_size}, is_neox_style:{is_neox_style}" - info += ( - f", mrope_section:{mrope_section}, is_interleaved:{is_interleaved}, eps:{eps}" - ) + if is_mrope: + info += f", mrope_section:{mrope_section}, is_interleaved:{is_interleaved}, eps:{eps}" msg = f"[perf] === {info} === torch avg: {avg_torch:<8.2f} us, cu avg: {avg_cu:<8.2f} us, uplift: {avg_torch/avg_cu-1:<5.1%}" checkAllclose(q_ref, q, msg="q", rtol=1e-2, atol=0.05) checkAllclose(k_ref, k, msg="k", rtol=1e-2, atol=0.05) @@ -237,6 +264,32 @@ def test_mrope_3d_rms( if __name__ == "__main__": + # rope + is_neox_styles = [True, False] + num_tokens = [513, 1257, 127, 778, 10024, 3] + num_heads = [32, 64] + head_sizes = [64, 128, 256] + max_positions = 10000 + dtype = torch.bfloat16 + for is_neox_style in is_neox_styles: + for num_token in num_tokens: + for num_head in num_heads: + for i, head_size in enumerate(head_sizes): + test_mrope_3d_rms( + dtype, + num_token, + num_head, + num_head, + num_head, + head_size, + is_neox_style, + None, + None, + eps=1e-6, + is_mrope=False, + ) + + # mrope is_neox_styles = [True, False] num_tokens = [513, 1257, 127, 778, 10024, 3] num_heads = [32, 64] @@ -262,5 +315,6 @@ def test_mrope_3d_rms( ms, is_interleaved, eps=1e-6, + is_mrope=True, ) print("done")