@@ -46,7 +46,10 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
   char* src_ptr = static_cast<char*>(src.data_ptr());
   char* dst_ptr = static_cast<char*>(dst.data_ptr());
 
-  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
+  // We use the stride instead of numel in case the cache is padded for memory
+  // alignment reasons; we assume each block's data (inclusive of any padding)
+  // is contiguous in memory.
+  const int64_t block_size_in_bytes = src.element_size() * src.stride(0);
   const at::cuda::OptionalCUDAGuard device_guard(
       src_device.is_cuda() ? src_device : dst_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
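The stride-versus-numel change above only matters when the cache tensor is padded past its logical per-block payload. A minimal libtorch sketch of that situation, where the shapes and the padded width of 72 are invented purely for illustration:

// Illustrative only: a per-block view that is narrower than its backing
// allocation, as can happen when the cache is padded for alignment.
#include <torch/torch.h>
#include <iostream>

int main() {
  const int64_t num_blocks = 4, block_size = 16;
  // Allocate 72 elements per entry but expose only 64.
  auto backing = torch::zeros({num_blocks, block_size, 72});
  auto cache = backing.slice(/*dim=*/2, /*start=*/0, /*end=*/64);

  std::cout << cache[0].numel() << "\n";  // 1024: logical payload of a block
  std::cout << cache.stride(0) << "\n";   // 1152: footprint, padding included
  // swap_blocks now copies element_size() * stride(0) bytes per block:
  std::cout << cache.element_size() * cache.stride(0) << "\n";  // 4608 (fp32)
  return 0;
}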
@@ -93,6 +96,24 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
   }
 }
 
+// Kernel for MLA, which works on a single joint kv_cache
+// Grid: (num_layers, num_pairs)
+template <typename scalar_t>
+__global__ void copy_blocks_mla_kernel(
+    int64_t* cache_ptrs, const int64_t* __restrict__ block_mapping,
+    const int mem_footprint_per_block) {
+  const int layer_idx = blockIdx.x;
+  const int pair_idx = blockIdx.y;
+  scalar_t* cache = reinterpret_cast<scalar_t*>(cache_ptrs[layer_idx]);
+  int64_t src_block = block_mapping[2 * pair_idx];
+  int64_t dst_block = block_mapping[2 * pair_idx + 1];
+  int64_t src_offset = src_block * mem_footprint_per_block;
+  int64_t dst_offset = dst_block * mem_footprint_per_block;
+  for (int i = threadIdx.x; i < mem_footprint_per_block; i += blockDim.x) {
+    cache[dst_offset + i] = cache[src_offset + i];
+  }
+}
+
 }  // namespace vllm
 
 // Note: the key_caches and value_caches vectors are constant but
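For reference, the copy performed by copy_blocks_mla_kernel is equivalent to the host-side loop below, a sketch intended only to make the indexing explicit; the flat [src0, dst0, src1, dst1, ...] layout of block_mapping is inferred from the 2 * pair_idx indexing, and the reference function itself is not part of the change:

#include <cstdint>

// Host-side reference for one layer's joint cache: block_mapping holds
// num_pairs (src, dst) block indices flattened pairwise.
template <typename scalar_t>
void copy_blocks_mla_reference(scalar_t* cache, const int64_t* block_mapping,
                               int num_pairs, int mem_footprint_per_block) {
  for (int pair_idx = 0; pair_idx < num_pairs; ++pair_idx) {
    const int64_t src_block = block_mapping[2 * pair_idx];
    const int64_t dst_block = block_mapping[2 * pair_idx + 1];
    const int64_t src_offset = src_block * mem_footprint_per_block;
    const int64_t dst_offset = dst_block * mem_footprint_per_block;
    // The CUDA kernel spreads this inner loop across threadIdx.x and runs
    // one (layer, pair) combination per CUDA block.
    for (int i = 0; i < mem_footprint_per_block; ++i) {
      cache[dst_offset + i] = cache[src_offset + i];
    }
  }
}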
@@ -147,6 +168,42 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
       }));
 }
 
+// Copy blocks for MLA (assumes a single joint KV-cache per layer)
+void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
+                     const torch::Tensor& block_mapping) {
+  int num_layers = kv_caches.size();
+  if (num_layers == 0) {
+    return;
+  }
+  torch::Device cache_device = kv_caches[0].device();
+  TORCH_CHECK(cache_device.is_cuda(), "kv_cache must be on CUDA");
+
+  std::vector<int64_t> cache_ptrs(num_layers);
+  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
+    cache_ptrs[layer_idx] =
+        reinterpret_cast<int64_t>(kv_caches[layer_idx].data_ptr());
+  }
+  torch::Tensor cache_ptrs_tensor =
+      torch::from_blob(cache_ptrs.data(), {num_layers}, torch::kInt64)
+          .to(cache_device);
+
+  int num_pairs = block_mapping.size(0);
+  // We use the stride instead of numel in case the cache is padded for memory
+  // alignment reasons; we assume each block's data (inclusive of any padding)
+  // is contiguous in memory.
+  int mem_footprint_per_block = kv_caches[0].stride(0);
+  dim3 grid(num_layers, num_pairs);
+  dim3 block(std::min(1024, mem_footprint_per_block));
+  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
+      kv_caches[0].scalar_type(), "copy_blocks_mla_kernel", ([&] {
+        vllm::copy_blocks_mla_kernel<scalar_t><<<grid, block, 0, stream>>>(
+            cache_ptrs_tensor.data_ptr<int64_t>(),
+            block_mapping.data_ptr<int64_t>(), mem_footprint_per_block);
+      }));
+}
+
 namespace vllm {
 
 template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
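A possible call site for copy_blocks_mla, sketched under the assumptions that each layer's joint cache has shape [num_blocks, block_size, kv_lora_rank + pe_dim] and that block_mapping is an int64 tensor of shape [num_pairs, 2] holding (src, dst) rows on the same CUDA device as the caches; the sizes below are illustrative only:

#include <torch/torch.h>
#include <vector>

// Assumes the copy_blocks_mla declaration above is visible and a CUDA
// device is available.
void example_copy_blocks_mla() {
  const int64_t num_layers = 2, num_blocks = 8, block_size = 16;
  const int64_t kv_lora_rank = 512, pe_dim = 64;  // illustrative MLA dims

  // One joint KV-cache tensor per layer.
  std::vector<torch::Tensor> kv_caches;
  for (int64_t layer = 0; layer < num_layers; ++layer) {
    kv_caches.push_back(
        torch::zeros({num_blocks, block_size, kv_lora_rank + pe_dim},
                     torch::dtype(torch::kFloat16).device(torch::kCUDA)));
  }

  // Copy block 0 -> 4 and block 1 -> 5 in every layer.
  torch::Tensor block_mapping =
      torch::tensor({0, 4, 1, 5},
                    torch::dtype(torch::kInt64).device(torch::kCUDA))
          .view({2, 2});

  copy_blocks_mla(kv_caches, block_mapping);
}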
@@ -254,6 +311,7 @@ __global__ void concat_and_cache_mla_kernel(
                                      // + pe_dim)]
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
     const int block_stride,                    //
+    const int entry_stride,                    //
     const int kv_c_stride,                     //
     const int k_pe_stride,                     //
     const int kv_lora_rank,                    //
@@ -274,9 +332,8 @@ __global__ void concat_and_cache_mla_kernel(
                   int src_stride, int dst_stride, int size, int offset) {
     for (int i = threadIdx.x; i < size; i += blockDim.x) {
       const int64_t src_idx = token_idx * src_stride + i;
-      const int64_t dst_idx = block_idx * block_stride +
-                              block_offset * (kv_lora_rank + pe_dim) + i +
-                              offset;
+      const int64_t dst_idx =
+          block_idx * block_stride + block_offset * entry_stride + i + offset;
       if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
         dst[dst_idx] = src[src_idx];
       } else {
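Switching the entry offset from (kv_lora_rank + pe_dim) to entry_stride only changes dst_idx when the cache's last dimension is padded. A small standalone example, using 512/64 as representative MLA dims and an invented padded entry width of 640:

#include <cstdint>
#include <cstdio>

int main() {
  const int kv_lora_rank = 512, pe_dim = 64;
  const int block_size = 16;
  const int entry_stride = 640;  // padded entry width; would be 576 unpadded
  const int block_stride = block_size * entry_stride;

  const int64_t block_idx = 3, block_offset = 5;  // a token's slot in the cache
  const int i = 0, offset = kv_lora_rank;         // first k_pe element

  const int64_t old_dst = block_idx * block_stride +
                          block_offset * (kv_lora_rank + pe_dim) + i + offset;
  const int64_t new_dst =
      block_idx * block_stride + block_offset * entry_stride + i + offset;

  // With padding, the old formula lands 5 * (640 - 576) = 320 elements short
  // of the entry it should write: old=34112, new=34432.
  std::printf("old=%lld new=%lld\n", (long long)old_dst, (long long)new_dst);
  return 0;
}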
@@ -391,14 +448,14 @@ void reshape_and_cache_flash(
 // KV_T is the stored data type of kv-cache.
 // CACHE_T is the data type of key and value tensors.
 // KV_DTYPE is the real data type of kv-cache.
-#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE)              \
-  vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE>            \
-      <<<grid, block, 0, stream>>>(                                     \
-          reinterpret_cast<KV_T*>(kv_c.data_ptr()),                     \
-          reinterpret_cast<KV_T*>(k_pe.data_ptr()),                     \
-          reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),              \
-          slot_mapping.data_ptr<int64_t>(), block_stride, kv_c_stride,  \
-          k_pe_stride, kv_lora_rank, pe_dim, block_size,                \
+#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE)                \
+  vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE>              \
+      <<<grid, block, 0, stream>>>(                                       \
+          reinterpret_cast<KV_T*>(kv_c.data_ptr()),                       \
+          reinterpret_cast<KV_T*>(k_pe.data_ptr()),                       \
+          reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),                \
+          slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride,   \
+          kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size,     \
           reinterpret_cast<const float*>(scale.data_ptr()));
 
 void concat_and_cache_mla(
@@ -428,6 +485,7 @@ void concat_and_cache_mla(
   int kv_c_stride = kv_c.stride(0);
   int k_pe_stride = k_pe.stride(0);
   int block_stride = kv_cache.stride(0);
+  int entry_stride = kv_cache.stride(1);
 
   dim3 grid(num_tokens);
   dim3 block(std::min(kv_lora_rank, 512));