 #include "cuda_utils.h"
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
+#include "quantization/vectorization_utils.cuh"

 #ifdef USE_ROCM
   #include "quantization/fp8/amd/quant_utils.cuh"
@@ -261,14 +262,26 @@ __global__ void reshape_and_cache_kernel(
   }
 }

+// Used by vectorization_utils to copy/convert one element
+template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
+struct CopyWithScaleOp {
+  float scale;
+
+  __device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
+    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+      dst = static_cast<OutT>(src);
+    } else {
+      dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
+    }
+  }
+};
+
 template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_flash_kernel(
     const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
     const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
-    cache_t* __restrict__ key_cache,    // [num_blocks, block_size, num_heads,
-                                        //  head_size]
-    cache_t* __restrict__ value_cache,  // [num_blocks, block_size, num_heads,
-                                        //  head_size]
+    cache_t* __restrict__ key_cache,    // NHD or HND layout; see comments below
+    cache_t* __restrict__ value_cache,  // same layout as key_cache
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
     const int64_t block_stride, const int64_t page_stride,
     const int64_t head_stride, const int64_t key_stride,
@@ -282,25 +295,58 @@ __global__ void reshape_and_cache_flash_kernel(
   }
   const int64_t block_idx = slot_idx / block_size;
   const int64_t block_offset = slot_idx % block_size;
-  const int n = num_heads * head_size;
-  for (int i = threadIdx.x; i < n; i += blockDim.x) {
-    const int64_t src_key_idx = token_idx * key_stride + i;
-    const int64_t src_value_idx = token_idx * value_stride + i;
-    const int head_idx = i / head_size;
-    const int head_offset = i % head_size;
-    const int64_t tgt_key_value_idx = block_idx * block_stride +
-                                      block_offset * page_stride +
-                                      head_idx * head_stride + head_offset;
-    scalar_t tgt_key = key[src_key_idx];
-    scalar_t tgt_value = value[src_value_idx];
-    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
-      key_cache[tgt_key_value_idx] = tgt_key;
-      value_cache[tgt_key_value_idx] = tgt_value;
-    } else {
-      key_cache[tgt_key_value_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
-      value_cache[tgt_key_value_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
+  const int n_elems = num_heads * head_size;
+
+  // pointers to the beginning of the source row for this token.
+  const scalar_t* __restrict__ key_src = key + token_idx * key_stride;
+  const scalar_t* __restrict__ value_src = value + token_idx * value_stride;
+
+  // find the start position inside the kv-cache for this token.
+  cache_t* __restrict__ key_dst =
+      key_cache + block_idx * block_stride + block_offset * page_stride;
+  cache_t* __restrict__ value_dst =
+      value_cache + block_idx * block_stride + block_offset * page_stride;
+
+  // this is true for the NHD layout where `head_stride == head_size`
+  const bool is_contiguous_heads = (head_stride == head_size);
+
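+  // the scales are only dereferenced when writing an fp8 cache; for the
+  // kAuto (non-quantized) path 0.f is just an unused placeholder.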
+  float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
+  float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;
+  constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4;
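+  // i.e. 16 bytes of source data per vector: 8 elements for 2-byte scalar_t
+  // (fp16/bf16), 4 elements for 4-byte scalar_t.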
+  CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
+  CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
+  if (is_contiguous_heads) {
+    // NHD layout
+    // kv cache: [num_blocks, block_size, num_heads, head_size]
+    vectorize_with_alignment<VEC_SIZE>(key_src, key_dst, n_elems, threadIdx.x,
+                                       blockDim.x, k_op);
+
+    vectorize_with_alignment<VEC_SIZE>(value_src, value_dst, n_elems,
+                                       threadIdx.x, blockDim.x, v_op);
+
+  } else {
+    // HND layout: heads are strided, but each head_size segment is contiguous
+    // kv cache: [num_blocks, num_heads, block_size, head_size]
+    const int lane = threadIdx.x & 31;            // 0..31 within the warp
+    const int warp_id = threadIdx.x >> 5;         // warp index within the block
+    const int warps_per_block = blockDim.x >> 5;
+
+    for (int head = warp_id; head < num_heads; head += warps_per_block) {
+      const scalar_t* __restrict__ k_src_h = key_src + head * head_size;
+      const scalar_t* __restrict__ v_src_h = value_src + head * head_size;
+
+      cache_t* __restrict__ k_dst_h =
+          key_dst + static_cast<int64_t>(head) * head_stride;
+      cache_t* __restrict__ v_dst_h =
+          value_dst + static_cast<int64_t>(head) * head_stride;
+
+      // within each head, let the 32 threads of the warp perform the vector
+      // copy
+      vectorize_with_alignment<VEC_SIZE>(k_src_h, k_dst_h, head_size, lane, 32,
+                                         k_op);
+
+      vectorize_with_alignment<VEC_SIZE>(v_src_h, v_dst_h, head_size, lane, 32,
+                                         v_op);
     }
   }
 }
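
For readers unfamiliar with the helper, below is a minimal scalar model of the per-element contract this kernel assumes vectorize_with_alignment (from quantization/vectorization_utils.cuh) provides: the `len` elements are strided across the cooperating threads in VEC_SIZE-wide chunks and `op(dst, src)` is applied to each element. The real helper is additionally expected to issue aligned vector loads/stores and handle misaligned head/tail elements, which this sketch omits; the name vectorize_with_alignment_model is invented purely for illustration.

// Illustration only: a scalar model of the assumed element-wise contract.
// Each of the `stride` cooperating threads (identified by `tid`) takes a
// VEC_SIZE-wide chunk per iteration and applies op(dst, src) to every element
// in its chunk, so all `len` elements are visited exactly once.
template <int VEC_SIZE, typename InT, typename OutT, typename Op>
__device__ inline void vectorize_with_alignment_model(const InT* in, OutT* out,
                                                      int len, int tid,
                                                      int stride, Op op) {
  for (int base = tid * VEC_SIZE; base < len; base += stride * VEC_SIZE) {
#pragma unroll
    for (int j = 0; j < VEC_SIZE; ++j) {
      const int i = base + j;
      if (i < len) {
        op(out[i], in[i]);  // e.g. CopyWithScaleOp: plain cast or fp8 convert
      }
    }
  }
}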