Skip to content

Commit

Permalink
Merge pull request #2 from lizhenyun01/append_attn_headdim
Browse files Browse the repository at this point in the history
Support append_attn C16 for DeepSeek-V3
  • Loading branch information
yuanlehome authored Jan 14, 2025
2 parents dbfd416 + 09f94b9 commit f1894f2
Show file tree
Hide file tree
Showing 6 changed files with 617 additions and 86 deletions.
68 changes: 34 additions & 34 deletions csrc/gpu/append_attn/decoder_write_cache_with_rope_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -79,29 +79,29 @@ __global__ void append_decode_cache_T_rope_kernel(

const int bias_idx = hi * head_size + h_bias;
Load<T, VecSize>(&quant_qkv[ori_idx], &src_vec);
if (hi < num_heads + kv_num_heads) {
// q k rope
const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
}
// if (hi < num_heads + kv_num_heads) {
// // q k rope
// const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
// Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
// Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
// }
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// dequant + add_bias + rope
float input_left = static_cast<float>(src_vec[2 * i]);
float input_right = static_cast<float>(src_vec[2 * i + 1]);

if (hi < num_heads + kv_num_heads) {
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
out_vec[2 * i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
out_vec[2 * i + 1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
} else {
out_vec[2 * i] = src_vec[2 * i];
out_vec[2 * i + 1] = src_vec[2 * i + 1];
}
// if (hi < num_heads + kv_num_heads) {
// const float cos_tmp = cos_emb_vec[i];
// const float sin_tmp = sin_emb_vec[i];
// out_vec[2 * i] =
// static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
// out_vec[2 * i + 1] =
// static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
// } else {
out_vec[2 * i] = src_vec[2 * i];
out_vec[2 * i + 1] = src_vec[2 * i + 1];
// }
}
if (hi < num_heads) {
// write q
Expand Down Expand Up @@ -307,28 +307,28 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
Load<T, VecSize>(&qkv[ori_idx_left], &left_vec);
Load<T, VecSize>(&qkv[ori_idx_right], &right_vec);

if (hi < num_heads + kv_num_heads) {
// q k rope
const uint32_t emb_idx = write_seq_id * head_size + h_bias;
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
}
// if (hi < num_heads + kv_num_heads) {
// // q k rope
// const uint32_t emb_idx = write_seq_id * head_size + h_bias;
// Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
// Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
// }
#pragma unroll
for (int i = 0; i < VecSize; i++) {
// rope
float input_left = static_cast<float>(left_vec[i]);
float input_right = static_cast<float>(right_vec[i]);
if (hi < num_heads + kv_num_heads) {
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_bias_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
right_bias_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
} else {
left_bias_vec[i] = static_cast<T>(input_left);
right_bias_vec[i] = static_cast<T>(input_right);
}
// if (hi < num_heads + kv_num_heads) {
// const float cos_tmp = cos_emb_vec[i];
// const float sin_tmp = sin_emb_vec[i];
// left_bias_vec[i] =
// static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
// right_bias_vec[i] =
// static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
// } else {
left_bias_vec[i] = static_cast<T>(input_left);
right_bias_vec[i] = static_cast<T>(input_right);
// }
}
if (hi < num_heads) {
// write q
Expand Down
36 changes: 20 additions & 16 deletions csrc/gpu/append_attn/encoder_write_cache_with_rope_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -736,8 +736,9 @@ __global__ void append_write_cache_kv_c8_qkv(
batch_id * max_seq_len - cum_offsets[batch_id];
const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM;
const uint32_t kv_h_stride = HEAD_DIM;
__shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM];
__shared__ T v_smem_ori[num_rows_per_block * HEAD_DIM];
extern __shared__ uint8_t smem[];
T *k_smem_ori = (T*)smem; // [num_rows_per_block * HEAD_DIM];
T *v_smem_ori = (T*)(smem + num_rows_per_block * HEAD_DIM * sizeof(T)); // [num_rows_per_block * HEAD_DIM];

smem_t k_smem(k_smem_ori);
smem_t v_smem(v_smem_ori);
Expand Down Expand Up @@ -983,12 +984,13 @@ __global__ void append_write_cache_kv_c4_qkv(
batch_id * max_seq_len - cum_offsets[batch_id];
const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM;
const uint32_t kv_h_stride = HEAD_DIM;
__shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM];
__shared__ T v_smem_ori[num_rows_per_block * HEAD_DIM];
__shared__ T k_scale_smem[HEAD_DIM];
__shared__ T v_scale_smem[HEAD_DIM];
__shared__ T k_zero_point_smem[HEAD_DIM];
__shared__ T v_zero_point_smem[HEAD_DIM];
extern __shared__ uint8_t smem[];
T *k_smem_ori = (T*)smem; // [num_rows_per_block * HEAD_DIM];
T *v_smem_ori = (T*)(smem + num_rows_per_block * HEAD_DIM * sizeof(T)); // [num_rows_per_block * HEAD_DIM];
T *k_scale_smem = (T*)(smem + num_rows_per_block * HEAD_DIM * 2 * sizeof(T)); // [HEAD_DIM];
T *v_scale_smem = (T*)(smem + (num_rows_per_block * HEAD_DIM * 2 + HEAD_DIM) * sizeof(T)); // [HEAD_DIM];
T *k_zero_point_smem = (T*)(smem + (num_rows_per_block * HEAD_DIM * 2 + HEAD_DIM * 2) * sizeof(T)); // [HEAD_DIM];
T *v_zero_point_smem = (T*)(smem + (num_rows_per_block * HEAD_DIM * 2 + HEAD_DIM * 3) * sizeof(T)); // [HEAD_DIM];
const T *cache_k_scale_now = cache_k_scales + kv_head_idx * HEAD_DIM;
const T *cache_k_zp_now = cache_k_zero_points + kv_head_idx * HEAD_DIM;
const T *cache_v_scale_now = cache_v_scales + kv_head_idx * HEAD_DIM;
Expand Down Expand Up @@ -1511,7 +1513,6 @@ void CascadeAppendWriteCacheKVC8QKV(
auto num_tokens = meta_data.token_nums;
auto num_heads = meta_data.q_num_heads;
auto kv_num_heads = meta_data.kv_num_heads;
auto head_dim = meta_data.head_dims;

const uint32_t pad_len = BLOCK_SIZE;

Expand All @@ -1530,9 +1531,11 @@ void CascadeAppendWriteCacheKVC8QKV(
HEAD_DIM,
BLOCK_SIZE,
num_warps>;
cudaFuncSetAttribute(
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
kernel_fn<<<grids, blocks, 0, stream>>>(cache_k_out->data<uint8_t>(),
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
kernel_fn<<<grids, blocks, smem_size, stream>>>(cache_k_out->data<uint8_t>(),
cache_v_out->data<uint8_t>(),
qkv.data<T>(),
cache_k_scale.data<T>(),
Expand Down Expand Up @@ -1578,7 +1581,6 @@ void CascadeAppendWriteCacheKVC4QKV(
auto num_tokens = meta_data.token_nums;
auto num_heads = meta_data.q_num_heads;
auto kv_num_heads = meta_data.kv_num_heads;
auto head_dim = meta_data.head_dims;

const uint32_t pad_len = BLOCK_SIZE;

Expand All @@ -1598,9 +1600,11 @@ void CascadeAppendWriteCacheKVC4QKV(
HEAD_DIM,
BLOCK_SIZE,
num_warps>;
cudaFuncSetAttribute(
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
kernel_fn<<<grids, blocks, 0, stream>>>(cache_k_out->data<uint8_t>(),
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
kernel_fn<<<grids, blocks, smem_size, stream>>>(cache_k_out->data<uint8_t>(),
cache_v_out->data<uint8_t>(),
qkv.data<T>(),
cache_k_scale.data<T>(),
Expand Down
72 changes: 36 additions & 36 deletions csrc/gpu/append_attn/encoder_write_cache_with_rope_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,42 +49,42 @@ void EncoderWriteCacheWithRopeKernel(
auto kv_num_heads = meta_data.kv_num_heads;
auto head_dim = meta_data.head_dims;

if (num_heads == kv_num_heads) {
rotary_qk_variable(
qkv_out->data<T>(),
qkv.data<QKV_TYPE>(),
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
rotary_embs.get().data<float>(),
padding_offsets.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
token_num,
num_heads,
max_seq_len,
rotary_embs.get().dims()[2],
head_dim,
stream,
use_neox_style);
} else {
gqa_rotary_qk_variable(
qkv_out->data<T>(),
qkv.data<QKV_TYPE>(),
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
rotary_embs.get().data<float>(),
padding_offsets.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
token_num,
num_heads,
kv_num_heads,
max_seq_len,
rotary_embs.get().dims()[2],
head_dim,
stream,
use_neox_style);
}
// if (num_heads == kv_num_heads) {
// rotary_qk_variable(
// qkv_out->data<T>(),
// qkv.data<QKV_TYPE>(),
// qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
// qkv_biases ? qkv_biases.get().data<T>() : nullptr,
// rotary_embs.get().data<float>(),
// padding_offsets.data<int>(),
// seq_lens_encoder.data<int>(),
// seq_lens_decoder.data<int>(),
// token_num,
// num_heads,
// max_seq_len,
// rotary_embs.get().dims()[2],
// head_dim,
// stream,
// use_neox_style);
// } else {
// gqa_rotary_qk_variable(
// qkv_out->data<T>(),
// qkv.data<QKV_TYPE>(),
// qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
// qkv_biases ? qkv_biases.get().data<T>() : nullptr,
// rotary_embs.get().data<float>(),
// padding_offsets.data<int>(),
// seq_lens_encoder.data<int>(),
// seq_lens_decoder.data<int>(),
// token_num,
// num_heads,
// kv_num_heads,
// max_seq_len,
// rotary_embs.get().dims()[2],
// head_dim,
// stream,
// use_neox_style);
// }
const uint32_t block_size = meta_data.block_size;
if (cache_quant_type_str == "none") {
CascadeAppendWriteCacheKVQKV<T>(meta_data,
Expand Down
10 changes: 10 additions & 0 deletions csrc/gpu/append_attn/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,16 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
__VA_ARGS__ \
break; \
} \
case 192: { \
constexpr size_t HEAD_DIM = 192; \
__VA_ARGS__ \
break; \
} \
case 256: { \
constexpr size_t HEAD_DIM = 256; \
__VA_ARGS__ \
break; \
} \
default: { \
PD_THROW("not support the head_dim: ", head_dim); \
} \
Expand Down
Loading

0 comments on commit f1894f2

Please sign in to comment.