1 change: 1 addition & 0 deletions aiter/mla.py
@@ -287,6 +287,7 @@ def mla_decode_fwd(
q = q.view(total_s, nhead, -1)
o = o.view(total_s, nhead, -1)
io_transformed = True
max_seqlen_q = 1
else:
assert False, f"{nhead=} and {max_seqlen_q=} not supported"

155 changes: 87 additions & 68 deletions csrc/kernels/mla/metadata/v1_2_device.cuh
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.

#include "v1_comm.cuh"

@@ -153,37 +153,42 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
// If the current CU partition is able to handle this batch of sequences
if(remain_payload >= (remain_kv_blocks + Traits::kFixedOverheadNumBlocks))
{
const int32_t num_splits = curr_n_split_idx + 1;

auto fill_work_info = [&](const int32_t split_idx) {
const int32_t global_qo_tile_idx = tot_qo_tiles;

MlaWorkInfo work_info{};
work_info.batch_idx = curr_batch;
work_info.qo_start =
qo_state.get_begin(curr_batch) + curr_qo_tile_idx * qo_tile_size;
work_info.qo_end = ck_tile::min(work_info.qo_start + qo_tile_size,
qo_state.get_end(curr_batch));
work_info.kv_start = curr_kv_begin + (curr_kv_block * params.kv_granularity);
int32_t batch_tail = (num_qo_tiles - 1 - curr_qo_tile_idx);
if constexpr(!Traits::kIsSparse)
int32_t num_splits = curr_n_split_idx + 1;

const int32_t global_qo_tile_idx = tot_qo_tiles;

MlaWorkInfo work_info{};
work_info.batch_idx = curr_batch;
work_info.qo_start =
qo_state.get_begin(curr_batch) + curr_qo_tile_idx * qo_tile_size;
work_info.qo_end = ck_tile::min(work_info.qo_start + qo_tile_size,
qo_state.get_end(curr_batch));
work_info.kv_start = curr_kv_begin + (curr_kv_block * params.kv_granularity);
int32_t batch_tail = (num_qo_tiles - 1 - curr_qo_tile_idx);
if constexpr(!Traits::kIsSparse)
{
if (params.qk_batch_ratio != 1)
{
if (params.qk_batch_ratio != 1)
{
batch_tail = num_qo_tiles - (work_info.qo_start / params.qk_batch_ratio) % ori_seqlen_qo - 1;
}
batch_tail = num_qo_tiles - (work_info.qo_start / params.qk_batch_ratio) % ori_seqlen_qo - 1;
}
work_info.kv_end = ck_tile::min(
work_info.kv_start + (remain_kv_blocks * params.kv_granularity),
curr_kv_end - batch_tail);
work_info.kv_offset = curr_kv_end - work_info.kv_end;
}
work_info.kv_end = ck_tile::min(
work_info.kv_start + (remain_kv_blocks * params.kv_granularity),
curr_kv_end - batch_tail);
work_info.kv_offset = curr_kv_end - work_info.kv_end;

// Fix MTP accuracy in the non-natively-supported case, where clamping
// can yield an empty KV range, e.g. kv_start=5, kv_end=4
if(work_info.kv_start >= work_info.kv_end)
{
--curr_n_split_idx;
--num_splits;
}

auto fill_work_info = [&](const int32_t split_idx) {
// split-related info
if(curr_n_split_idx > 0)
{
// set work info
work_info.partial_qo_loc = partial_idx;

// set reduce info
params.p_reduce_indptr[global_qo_tile_idx + 1] =
last_reduce_indptr + num_splits;
@@ -194,31 +199,36 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
}
else
{
work_info.partial_qo_loc = -1;
params.p_reduce_indptr[global_qo_tile_idx + 1] = last_reduce_indptr;
}

p_work_info_set[num_works] = work_info;
};

// record a work item in work_info_set
if(curr_n_split_idx > 0)
{
work_info.partial_qo_loc = partial_idx;
for(int32_t idx = lane_idx; idx < num_splits; idx += ck_tile::get_warp_size())
{
fill_work_info(idx);
}

partial_idx += qo_tile_size;
last_reduce_indptr += num_splits;

}
else
{
work_info.partial_qo_loc = -1;
fill_work_info(0);
}
if(work_info.kv_start < work_info.kv_end)
{
p_work_info_set[num_works] = work_info;
num_works += 1;
}

tot_qo_tiles += 1;
num_works += 1;

remain_payload -= (remain_kv_blocks + Traits::kFixedOverheadNumBlocks);

@@ -272,40 +282,44 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
{
const int32_t consuming_blks = remain_payload - Traits::kFixedOverheadNumBlocks;

auto fill_work_info = [&]() {
MlaWorkInfo work_info{};
work_info.batch_idx = curr_batch;
work_info.qo_start =
qo_state.get_begin(curr_batch) + curr_qo_tile_idx * qo_tile_size;
work_info.qo_end = ck_tile::min(work_info.qo_start + qo_tile_size,
qo_state.get_end(curr_batch));
work_info.kv_start =
curr_kv_begin + (curr_kv_block * params.kv_granularity);
int32_t batch_tail = (num_qo_tiles - 1 - curr_qo_tile_idx);
if constexpr(!Traits::kIsSparse)
MlaWorkInfo work_info{};
work_info.batch_idx = curr_batch;
work_info.qo_start =
qo_state.get_begin(curr_batch) + curr_qo_tile_idx * qo_tile_size;
work_info.qo_end = ck_tile::min(work_info.qo_start + qo_tile_size,
qo_state.get_end(curr_batch));
work_info.kv_start =
curr_kv_begin + (curr_kv_block * params.kv_granularity);
int32_t batch_tail = (num_qo_tiles - 1 - curr_qo_tile_idx);
if constexpr(!Traits::kIsSparse)
{
if (params.qk_batch_ratio != 1)
{
if (params.qk_batch_ratio != 1)
{
batch_tail = num_qo_tiles - (work_info.qo_start / params.qk_batch_ratio) % ori_seqlen_qo - 1;
}
batch_tail = num_qo_tiles - (work_info.qo_start / params.qk_batch_ratio) % ori_seqlen_qo - 1;
}
work_info.kv_end = ck_tile::min(
work_info.kv_start + (consuming_blks * params.kv_granularity),
curr_kv_end - batch_tail);
work_info.kv_offset = curr_kv_end - work_info.kv_end;
work_info.partial_qo_loc = partial_idx;
p_work_info_set[num_works] = work_info;
};
}
work_info.kv_end = ck_tile::min(
work_info.kv_start + (consuming_blks * params.kv_granularity),
curr_kv_end - batch_tail);
work_info.kv_offset = curr_kv_end - work_info.kv_end;

// record a work item in work_info_set
fill_work_info();
work_info.partial_qo_loc = (curr_n_split_idx == 0 && batch_tail == work_info.kv_offset) ?
-1 : partial_idx;

partial_idx += qo_tile_size;
num_works += 1;
// Fix MTP accuracy in the non-natively-supported case: only record splits with a non-empty KV range
if(work_info.kv_start < work_info.kv_end)
{
p_work_info_set[num_works] = work_info;
num_works += 1;
++curr_n_split_idx;
}
if (batch_tail != work_info.kv_offset)
{
partial_idx += qo_tile_size;
}

// update state
curr_kv_block += consuming_blks;
++curr_n_split_idx;
}
break;
}
@@ -371,7 +385,6 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
torch::Tensor& reduce_final_map,
torch::Tensor& reduce_partial_map)
{
constexpr int32_t kPackedQoLenPerWg = 128;
const hipStream_t stream = at::hip::getCurrentHIPStream();

hipDevice_t dev;
@@ -394,6 +407,9 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
const bool kv_is_fp8 =
(kv_dtype == at::ScalarType::Float8_e4m3fnuz || kv_dtype == at::ScalarType::Float8_e4m3fn);

const bool q_is_bf16 = q_dtype == at::ScalarType::BFloat16;
const bool kv_is_bf16 = kv_dtype == at::ScalarType::BFloat16;

const bool natively_supported = (num_heads == 16) ||
((num_heads == 128) && q_is_fp8 && kv_is_fp8);

@@ -436,15 +452,18 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
params.qk_batch_ratio = qk_batch_ratio;

// launch kernel
MLA_METADATA_DISPATCHER(
max_seqlen_qo * num_heads_per_head_k,
kPackedQoLenPerWg,
params.uni_seqlen_qo,
topk,
dispatch_mla_metadata_v1_2_device<kPackedQoLenPerWg, kQoSplits, kUniSeqlenQo, kIsSparse>(
params,
stream,
max_seqlen_qo,
dev_prop.warpSize,
dev_prop.maxSharedMemoryPerMultiProcessor));
MLA_NUM_HEADS_DISPATCHER(
num_heads_per_head_k,
MLA_METADATA_DISPATCHER(
max_seqlen_qo * num_heads_per_head_k,
kPackedQoLenPerWg,
params.uni_seqlen_qo,
topk,
dispatch_mla_metadata_v1_2_device<kPackedQoLenPerWg, kQoSplits, kUniSeqlenQo, kIsSparse>(
params,
stream,
max_seqlen_qo,
dev_prop.warpSize,
dev_prop.maxSharedMemoryPerMultiProcessor)));

}
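
What this file's hunks actually change: when a head count is not natively supported, mla.py folds heads into the batch dimension (qk_batch_ratio != 1), and the batch-tail correction can then clamp kv_end below kv_start; the old code still recorded that empty split as a work item and counted it in num_splits. A standalone C++ sketch of the corrected accounting — simplified types, and granularity / payload / batch_tail values that are illustrative only, not the kernel's real configuration:

// Sketch of the split accounting this patch fixes. All names and numbers
// here are illustrative; the real kernel works on MlaWorkInfo and params.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

struct WorkInfo { int32_t kv_start; int32_t kv_end; };

int main()
{
    const int32_t kv_begin = 0, kv_end_total = 5;
    const int32_t granularity = 1;       // KV elements per block
    const int32_t blocks_per_split = 2;  // payload one split consumes
    const int32_t batch_tail = 1;        // tail correction for the folded-batch case

    std::vector<WorkInfo> works;
    int32_t num_splits = 0;
    for (int32_t blk = 0; kv_begin + blk * granularity < kv_end_total; blk += blocks_per_split)
    {
        WorkInfo w{};
        w.kv_start = kv_begin + blk * granularity;
        w.kv_end   = std::min(w.kv_start + blocks_per_split * granularity,
                              kv_end_total - batch_tail);
        // The fix: a clamped split can come out empty (kv_start >= kv_end);
        // record it neither as a work item nor as a split.
        if (w.kv_start < w.kv_end)
        {
            works.push_back(w);
            ++num_splits;
        }
    }
    for (const WorkInfo& w : works)
        std::printf("split [%d, %d)\n", w.kv_start, w.kv_end);
    std::printf("num_splits=%d\n", num_splits);
    return 0;
}

With these numbers the last candidate split comes out as [4, 4) and is dropped, so num_splits stays equal to the number of work items actually written — the invariant the p_reduce_indptr bookkeeping above depends on.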
21 changes: 21 additions & 0 deletions csrc/kernels/mla/metadata/v1_comm.cuh
@@ -384,3 +384,24 @@ private:
MLA_UNI_SEQLEN_DISPATCHER((UNI_SEQLEN_QO), __VA_ARGS__); \
} \
}

#define MLA_NUM_HEADS_CASE(C_NUM_HEADS, ...) \
case C_NUM_HEADS: \
{ \
constexpr int32_t kPackedQoLenPerWg = C_NUM_HEADS; \
__VA_ARGS__; \
break; \
}

#define MLA_NUM_HEADS_DISPATCHER(NUM_HEADS, ...) \
switch (NUM_HEADS) \
{ \
MLA_NUM_HEADS_CASE(32, __VA_ARGS__); \
MLA_NUM_HEADS_CASE(64, __VA_ARGS__); \
default: \
{ \
constexpr int32_t kPackedQoLenPerWg = 128; \
__VA_ARGS__; \
break; \
} \
}
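
The new dispatcher replaces the constexpr kPackedQoLenPerWg = 128 that was deleted from get_mla_metadata_v1_2_device above: head counts 32 and 64 now get a matching compile-time tile width, and everything else falls through to the previous 128. A self-contained usage sketch (the macros are copied from the patch; the printf body is a hypothetical stand-in for the real dispatch_mla_metadata_v1_2_device call):

// Demonstrates how MLA_NUM_HEADS_DISPATCHER binds kPackedQoLenPerWg.
// Macros copied verbatim from v1_comm.cuh; the driver loop is illustrative.
#include <cstdint>
#include <cstdio>
#include <initializer_list>

#define MLA_NUM_HEADS_CASE(C_NUM_HEADS, ...)                   \
    case C_NUM_HEADS:                                          \
    {                                                          \
        constexpr int32_t kPackedQoLenPerWg = C_NUM_HEADS;     \
        __VA_ARGS__;                                           \
        break;                                                 \
    }

#define MLA_NUM_HEADS_DISPATCHER(NUM_HEADS, ...)               \
    switch (NUM_HEADS)                                         \
    {                                                          \
        MLA_NUM_HEADS_CASE(32, __VA_ARGS__);                   \
        MLA_NUM_HEADS_CASE(64, __VA_ARGS__);                   \
        default:                                               \
        {                                                      \
            constexpr int32_t kPackedQoLenPerWg = 128;         \
            __VA_ARGS__;                                       \
            break;                                             \
        }                                                      \
    }

int main()
{
    for (int32_t num_heads : {16, 32, 64, 128})
    {
        // Inside the body, kPackedQoLenPerWg is a constexpr that the real
        // caller forwards as a template argument to the metadata kernel.
        MLA_NUM_HEADS_DISPATCHER(
            num_heads,
            std::printf("num_heads=%d -> kPackedQoLenPerWg=%d\n",
                        (int)num_heads, (int)kPackedQoLenPerWg));
    }
    return 0;
}

Note the default branch also catches head counts like 16, which keep the old 128-wide packing rather than getting a dedicated case.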