Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions aiter/ops/fused_mrope_rms.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,20 @@ def fused_mrope_3d_rms(
is_interleaved: bool,
eps: float,
) -> None: ...


# Fused 1-D RoPE + RMS-norm, applied in place to the packed qkv tensor.
# Counterpart of fused_mrope_3d_rms for plain (non-multimodal) rotary
# embeddings: positions is a flat (num_tokens,) int64 tensor instead of the
# (3, num_tokens) mrope layout, and no mrope_section/is_interleaved arguments.
@compile_ops("module_fused_mrope_rms")
def fused_rope_rms(
    qkv: Tensor,  # (num_tokens, num_heads_q + num_heads_k + num_heads_v, head_size); modified in place
    qw: Tensor,  # RMS-norm weight for Q heads, shape (head_size,)
    kw: Tensor,  # RMS-norm weight for K heads, shape (head_size,)
    cos_sin: Tensor,  # precomputed cos/sin table, shape (max_positions, head_size)
    positions: Tensor,  # int64 position ids, shape (num_tokens,)
    num_tokens: int,
    num_heads_q: int,
    num_heads_k: int,
    num_heads_v: int,
    head_size: int,
    is_neox_style: bool,  # True: NeoX split-half rotation; False: interleaved (GPT-J) rotation
    eps: float,  # RMS-norm epsilon
) -> None: ...
13 changes: 13 additions & 0 deletions csrc/include/fused_mrope_rms.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,16 @@ void fused_mrope_3d_rms(Tensor& qkv,
std::vector<int64_t> mrope_section_,
bool is_interleaved,
double eps);

// Fused 1-D rotary position embedding + RMS normalization over the packed
// qkv tensor (num_tokens, num_heads_q + num_heads_k + num_heads_v, head_size),
// updated in place. qw/kw are the RMS-norm weights for the Q and K heads
// (shape: head_size). cos_sin is the precomputed (max_positions, head_size)
// cos/sin table; positions holds one int64 position id per token
// (shape: num_tokens). is_neox_style selects NeoX split-half rotation versus
// interleaved rotation; eps is the RMS-norm epsilon. qkv, qw, kw and cos_sin
// must be contiguous; positions may be strided (its stride is forwarded to
// the kernel by the implementation).
void fused_rope_rms(Tensor& qkv,
                    Tensor& qw,
                    Tensor& kw,
                    Tensor& cos_sin,
                    Tensor& positions,
                    int64_t num_tokens,
                    int64_t num_heads_q,
                    int64_t num_heads_k,
                    int64_t num_heads_v,
                    int64_t head_size,
                    bool is_neox_style,
                    double eps);
52 changes: 27 additions & 25 deletions csrc/include/rocm_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1315,7 +1315,9 @@ namespace py = pybind11;
m.def("rope_cached_positions_offsets_fwd_impl", &rope_cached_positions_offsets_fwd_impl); \
m.def("rope_cached_positions_offsets_2c_fwd_impl", &rope_cached_positions_offsets_2c_fwd_impl);

#define FUSED_MROPE_RMS_PYBIND m.def("fused_mrope_3d_rms", &fused_mrope_3d_rms);
#define FUSED_MROPE_RMS_PYBIND \
m.def("fused_mrope_3d_rms", &fused_mrope_3d_rms); \
m.def("fused_rope_rms", &fused_rope_rms);

#define SMOOTHQUANT_PYBIND \
m.def("smoothquant_fwd", &smoothquant_fwd); \
Expand Down Expand Up @@ -1435,30 +1437,30 @@ namespace py = pybind11;
py::arg("stride0"), \
py::arg("stride1"));

#define MLA_METADATA_PYBIND \
m.def("get_mla_metadata_v1", \
&get_mla_metadata_v1, \
"get_mla_metadata_v1", \
py::arg("seqlens_qo_indptr"), \
py::arg("seqlens_kv_indptr"), \
py::arg("num_heads_per_head_k"), \
py::arg("num_heads_k"), \
py::arg("is_causal"), \
py::arg("work_metadata_ptrs"), \
py::arg("work_info_set"), \
py::arg("work_indptr"), \
py::arg("reduce_indptr"), \
py::arg("reduce_final_map"), \
py::arg("reduce_partial_map"), \
py::arg("kv_granularity") = 16, \
py::arg("max_seqlen_qo") = -1, \
py::arg("uni_seqlen_qo") = -1, \
py::arg("fast_mode") = true, \
py::arg("topk") = -1, \
py::arg("max_split_per_batch") = -1, \
py::arg("intra_batch_mode") = false, \
py::arg("dtype_q") = std::nullopt, \
py::arg("dtype_kv") = std::nullopt); \
#define MLA_METADATA_PYBIND \
m.def("get_mla_metadata_v1", \
&get_mla_metadata_v1, \
"get_mla_metadata_v1", \
py::arg("seqlens_qo_indptr"), \
py::arg("seqlens_kv_indptr"), \
py::arg("num_heads_per_head_k"), \
py::arg("num_heads_k"), \
py::arg("is_causal"), \
py::arg("work_metadata_ptrs"), \
py::arg("work_info_set"), \
py::arg("work_indptr"), \
py::arg("reduce_indptr"), \
py::arg("reduce_final_map"), \
py::arg("reduce_partial_map"), \
py::arg("kv_granularity") = 16, \
py::arg("max_seqlen_qo") = -1, \
py::arg("uni_seqlen_qo") = -1, \
py::arg("fast_mode") = true, \
py::arg("topk") = -1, \
py::arg("max_split_per_batch") = -1, \
py::arg("intra_batch_mode") = false, \
py::arg("dtype_q") = std::nullopt, \
py::arg("dtype_kv") = std::nullopt); \
m.def("get_mla_metadata_v1_no_redundant", &get_mla_metadata_v1_no_redundant);

#define PA_METADATA_PYBIND \
Expand Down
52 changes: 42 additions & 10 deletions csrc/kernels/fused_mrope_rms.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ __inline__ __device__ T warp_reduce_sum(T val) {

template <typename T>
__inline__ __device__ T warp_shfl_sync(T val, int src_id) {
return __shfl_sync(__activemask(), val, src_id, 32);
return __shfl(val, src_id, 32);
}

} // namespace block_utils
Expand Down Expand Up @@ -106,7 +106,6 @@ __device__ __forceinline__ void warp_rms_norm_(
int warp_t_id = threadIdx.x % 32;
acc = block_utils::warp_reduce_sum<float>(acc);
acc = block_utils::warp_shfl_sync<float>(acc, 0);
__syncwarp();
auto s_val = rsqrtf(acc / rms_dim + rms_eps);
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
Expand Down Expand Up @@ -259,7 +258,7 @@ __global__ void fused_mrope_rms_noneox_kernel(

template <typename T>
void fused_rope_rms(
T *qkv, const T *q_w, const T *k_w, const T *cos_sin, const int64_t *positions,
T *qkv, const T *q_w, const T *k_w, const T *cos_sin, const int64_t *positions, int64_t ps0, int64_t ps1,
int64_t num_tokens, int64_t num_heads_q, int64_t num_heads_k, int64_t num_heads_v, int64_t head_size,
bool is_neox_style, double eps, hipStream_t stream) {
TORCH_CHECK(head_size == 64 || head_size == 128 || head_size == 256);
Expand All @@ -270,13 +269,13 @@ void fused_rope_rms(
dim3 numBlocks((total_warps + num_warps_per_block - 1) / num_warps_per_block);
std::array<int64_t, 1> mrope_section = {0};

#define DISPATCH_NEOX(HEAD_SIZE) \
if (is_neox_style) { \
fused_mrope_rms_neox_kernel<T, HEAD_SIZE, false, false, 1><<<numBlocks, threadsPerBlock, 0, stream>>>( \
qkv, q_w, k_w, cos_sin, positions, num_heads_q, num_heads_k, num_heads_v, eps, mrope_section, num_tokens, total_warps); \
} else { \
fused_mrope_rms_noneox_kernel<T, HEAD_SIZE, false, false, 1><<<numBlocks, threadsPerBlock, 0, stream>>>( \
qkv, q_w, k_w, cos_sin, positions, num_heads_q, num_heads_k, num_heads_v, eps, mrope_section, num_tokens, total_warps); \
#define DISPATCH_NEOX(HEAD_SIZE) \
if (is_neox_style) { \
fused_mrope_rms_neox_kernel<T, HEAD_SIZE, false, false, 1><<<numBlocks, threadsPerBlock, 0, stream>>>( \
qkv, q_w, k_w, cos_sin, positions, ps0, ps1, num_heads_q, num_heads_k, num_heads_v, eps, mrope_section, num_tokens, total_warps); \
} else { \
fused_mrope_rms_noneox_kernel<T, HEAD_SIZE, false, false, 1><<<numBlocks, threadsPerBlock, 0, stream>>>( \
qkv, q_w, k_w, cos_sin, positions, ps0, ps1, num_heads_q, num_heads_k, num_heads_v, eps, mrope_section, num_tokens, total_warps); \
}

switch (head_size) {
Expand Down Expand Up @@ -402,3 +401,36 @@ void fused_mrope_3d_rms(Tensor &qkv, Tensor &qw, Tensor &kw, Tensor &cos_sin, Te
stream);
});
}

// Host-side entry point for the fused 1-D RoPE + RMS-norm op.
// Validates tensor layout, installs the device guard, resolves the current
// HIP stream, dispatches on qkv's floating element type (fp32/fp16/bf16),
// and launches the templated rope_rms::fused_rope_rms kernel driver.
// qkv is modified in place.
void fused_rope_rms(Tensor &qkv, Tensor &qw, Tensor &kw, Tensor &cos_sin, Tensor &positions,
                    int64_t num_tokens, int64_t num_heads_q, int64_t num_heads_k, int64_t num_heads_v, int64_t head_size,
                    bool is_neox_style, double eps) {
    // The kernel addresses these buffers with flat offsets, so they must be
    // dense. positions is deliberately exempt: its stride is passed through
    // to the kernel instead.
    TORCH_CHECK(qkv.is_contiguous() && qw.is_contiguous() && kw.is_contiguous() && cos_sin.is_contiguous());
    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(qkv));
    auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
    auto pos_strides = positions.strides();
    // Plain RoPE takes 1-D positions (one id per token), unlike the
    // (3, num_tokens) layout used by the mrope path.
    TORCH_CHECK(pos_strides.size() == 1);
    AT_DISPATCH_FLOATING_TYPES_AND2(
        kBFloat16,
        kHalf,
        qkv.scalar_type(),
        "fused_rope_rms", [&] {
            // KernelElementType maps the ATen scalar type to the raw kernel
            // element type (e.g. __hip_bfloat16) for the data_ptr casts below.
            using T = KernelElementType<scalar_t>::type;
            rope_rms::fused_rope_rms<T>(
                (T*)qkv.data_ptr<scalar_t>(),
                (T*)qw.data_ptr<scalar_t>(),
                (T*)kw.data_ptr<scalar_t>(),
                (T*)cos_sin.data_ptr<scalar_t>(),
                positions.data_ptr<int64_t>(),
                // ps0 = 0: presumably the outer (section) stride used by the
                // (3, num_tokens) mrope positions layout; zero here since
                // positions is 1-D — TODO confirm against the kernel.
                0,
                // ps1: per-token stride of the positions tensor.
                pos_strides[0],
                num_tokens,
                num_heads_q,
                num_heads_k,
                num_heads_v,
                head_size,
                is_neox_style,
                eps,
                stream);
        });
}
124 changes: 89 additions & 35 deletions op_tests/test_fused_mrope_rms.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def run_torch_mrope_3d_rms(
qw: Tensor, # contiguous (head_size)
kw: Tensor, # contiguous (head_size)
cos_sin: Tensor, # contiguous (max_positions * head_size)
positions: Tensor, # contiguous (3 * num_tokens)
positions: Tensor, # contiguous (3 * num_tokens) or (num_tokens)
num_tokens: int,
num_heads_q: int,
num_heads_k: int,
Expand All @@ -78,6 +78,7 @@ def run_torch_mrope_3d_rms(
mrope_section: List[int],
is_interleaved: bool,
eps: float,
is_mrope: bool,
):
q_size = num_heads_q * head_size
k_size = num_heads_k * head_size
Expand All @@ -94,22 +95,24 @@ def run_torch_mrope_3d_rms(
k = k_by_head.view(k.shape)

cos_sin = cos_sin.view(max_positions, head_size)
positions = positions.view(3, num_tokens)
if is_mrope:
positions = positions.view(3, num_tokens)
cos_sin = cos_sin[positions]
cos, sin = cos_sin.chunk(2, dim=-1)

if is_interleaved:
cos = apply_interleaved_rope(cos, mrope_section)
sin = apply_interleaved_rope(sin, mrope_section)
else:
cos = torch.cat(
[m[i] for i, m in enumerate(cos.split(mrope_section, dim=-1))],
dim=-1,
)
sin = torch.cat(
[m[i] for i, m in enumerate(sin.split(mrope_section, dim=-1))],
dim=-1,
)
if is_mrope:
if is_interleaved:
cos = apply_interleaved_rope(cos, mrope_section)
sin = apply_interleaved_rope(sin, mrope_section)
else:
cos = torch.cat(
[m[i] for i, m in enumerate(cos.split(mrope_section, dim=-1))],
dim=-1,
)
sin = torch.cat(
[m[i] for i, m in enumerate(sin.split(mrope_section, dim=-1))],
dim=-1,
)

q_shape = q.shape
q = q.view(num_tokens, -1, head_size)
Expand Down Expand Up @@ -140,24 +143,42 @@ def run_aiter_mrope_3d_rms(
mrope_section: List[int],
is_interleaved: bool,
eps: float,
is_mrope: bool,
):
qkv = qkv.clone() # inplace op
aiter.fused_mrope_3d_rms(
qkv,
qw,
kw,
cos_sin,
positions,
num_tokens,
num_heads_q,
num_heads_k,
num_heads_v,
head_size,
is_neox_style,
mrope_section,
is_interleaved,
eps,
)

if is_mrope:
aiter.fused_mrope_3d_rms(
qkv,
qw,
kw,
cos_sin,
positions,
num_tokens,
num_heads_q,
num_heads_k,
num_heads_v,
head_size,
is_neox_style,
mrope_section,
is_interleaved,
eps,
)
else:
aiter.fused_rope_rms(
qkv,
qw,
kw,
cos_sin,
positions,
num_tokens,
num_heads_q,
num_heads_k,
num_heads_v,
head_size,
is_neox_style,
eps,
)

q_size = num_heads_q * head_size
k_size = num_heads_k * head_size
Expand All @@ -179,7 +200,8 @@ def test_mrope_3d_rms(
is_neox_style,
mrope_section,
is_interleaved,
eps=1e-6,
eps,
is_mrope,
):
qkv = torch.randn(
(num_tokens, num_heads_q + num_heads_k + num_heads_v, head_size),
Expand All @@ -189,8 +211,12 @@ def test_mrope_3d_rms(
qw = torch.randn(head_size, dtype=dtype, device="cuda")
kw = torch.randn(head_size, dtype=dtype, device="cuda")
cos_sin = torch.randn((max_positions, head_size), dtype=dtype, device="cuda")
if is_mrope:
pos_shape = (3, num_tokens)
else:
pos_shape = (num_tokens,)
positions = torch.randint(
0, max_positions, (3, num_tokens), dtype=torch.int64, device="cuda"
0, max_positions, pos_shape, dtype=torch.int64, device="cuda"
)

(q_ref, k_ref, v_ref), avg_torch = run_torch_mrope_3d_rms(
Expand All @@ -208,6 +234,7 @@ def test_mrope_3d_rms(
mrope_section,
is_interleaved,
eps,
is_mrope,
)
(q, k, v), avg_cu = run_aiter_mrope_3d_rms(
qkv,
Expand All @@ -224,19 +251,45 @@ def test_mrope_3d_rms(
mrope_section,
is_interleaved,
eps,
is_mrope,
)

info = f"dtype:{dtype}, num_tokens:{num_tokens}, num_heads_q:{num_heads_q}, num_heads_k:{num_heads_k}, num_heads_v:{num_heads_v}, head_size:{head_size}, is_neox_style:{is_neox_style}"
info += (
f", mrope_section:{mrope_section}, is_interleaved:{is_interleaved}, eps:{eps}"
)
if is_mrope:
info += f", mrope_section:{mrope_section}, is_interleaved:{is_interleaved}, eps:{eps}"
msg = f"[perf] === {info} === torch avg: {avg_torch:<8.2f} us, cu avg: {avg_cu:<8.2f} us, uplift: {avg_torch/avg_cu-1:<5.1%}"
checkAllclose(q_ref, q, msg="q", rtol=1e-2, atol=0.05)
checkAllclose(k_ref, k, msg="k", rtol=1e-2, atol=0.05)
checkAllclose(v_ref, v, msg=msg, rtol=1e-2, atol=0.05)


if __name__ == "__main__":
# rope
is_neox_styles = [True, False]
num_tokens = [513, 1257, 127, 778, 10024, 3]
num_heads = [32, 64]
head_sizes = [64, 128, 256]
max_positions = 10000
dtype = torch.bfloat16
for is_neox_style in is_neox_styles:
for num_token in num_tokens:
for num_head in num_heads:
for i, head_size in enumerate(head_sizes):
test_mrope_3d_rms(
dtype,
num_token,
num_head,
num_head,
num_head,
head_size,
is_neox_style,
None,
None,
eps=1e-6,
is_mrope=False,
)

# mrope
is_neox_styles = [True, False]
num_tokens = [513, 1257, 127, 778, 10024, 3]
num_heads = [32, 64]
Expand All @@ -262,5 +315,6 @@ def test_mrope_3d_rms(
ms,
is_interleaved,
eps=1e-6,
is_mrope=True,
)
print("done")