Commit 75e9430
[Perf] Mem align KV caches for CUDA devices (MLA perf improvement) (#12676)
Signed-off-by: simon-mo <xmo@berkeley.edu>
Signed-off-by: Lucas Wilkinson <lcwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
1 parent 233df6f commit 75e9430

File tree: 10 files changed, +429 -34 lines

csrc/cache.h

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,9 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
                  std::vector<torch::Tensor> const& value_caches,
                  const torch::Tensor& block_mapping);
 
+void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
+                     const torch::Tensor& block_mapping);
+
 void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                        torch::Tensor& key_cache, torch::Tensor& value_cache,
                        torch::Tensor& slot_mapping,

csrc/cache_kernels.cu

Lines changed: 70 additions & 12 deletions
@@ -46,7 +46,10 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
   char* src_ptr = static_cast<char*>(src.data_ptr());
   char* dst_ptr = static_cast<char*>(dst.data_ptr());
 
-  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
+  // We use the stride instead of numel in case the cache is padded for memory
+  // alignment reasons, we assume the blocks data (inclusive of any padding)
+  // is contiguous in memory
+  const int64_t block_size_in_bytes = src.element_size() * src.stride(0);
   const at::cuda::OptionalCUDAGuard device_guard(
       src_device.is_cuda() ? src_device : dst_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
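For context on the stride-vs-numel change above: when the cache allocation pads each block's footprint for alignment, src.stride(0) counts that padding while src[0].numel() does not, so copying stride(0) elements per block moves the whole (padded) block. A minimal libtorch sketch with made-up sizes and padding, not the allocation code from this commit:

// Illustrative only: a KV cache over-allocated so each block's footprint is
// padded by 64 extra elements. stride(0) reflects the padded footprint that
// swap_blocks must copy; the numel of one block view does not.
#include <torch/torch.h>
#include <iostream>

int main() {
  const int64_t num_blocks = 4, block_size = 16, entry_dim = 576;  // assumed sizes
  const int64_t padded_block = block_size * entry_dim + 64;        // assumed padding
  auto storage = torch::zeros({num_blocks * padded_block}, torch::kFloat16);
  auto kv_cache = storage.as_strided({num_blocks, block_size, entry_dim},
                                     {padded_block, entry_dim, 1});
  std::cout << "stride(0) = " << kv_cache.stride(0)       // 9280, padding included
            << ", block numel = " << kv_cache[0].numel()  // 9216
            << std::endl;
  return 0;
}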
@@ -93,6 +96,24 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
   }
 }
 
+// Kernel for MLA, which works on a single joint kv_cache
+// Grid: (num_layers, num_pairs)
+template <typename scalar_t>
+__global__ void copy_blocks_mla_kernel(
+    int64_t* cache_ptrs, const int64_t* __restrict__ block_mapping,
+    const int mem_footprint_per_block) {
+  const int layer_idx = blockIdx.x;
+  const int pair_idx = blockIdx.y;
+  scalar_t* cache = reinterpret_cast<scalar_t*>(cache_ptrs[layer_idx]);
+  int64_t src_block = block_mapping[2 * pair_idx];
+  int64_t dst_block = block_mapping[2 * pair_idx + 1];
+  int64_t src_offset = src_block * mem_footprint_per_block;
+  int64_t dst_offset = dst_block * mem_footprint_per_block;
+  for (int i = threadIdx.x; i < mem_footprint_per_block; i += blockDim.x) {
+    cache[dst_offset + i] = cache[src_offset + i];
+  }
+}
+
 }  // namespace vllm
 
 // Note: the key_caches and value_caches vectors are constant but
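To make the kernel's work decomposition concrete: each CUDA block handles one (layer, src-to-dst pair), and its threads stride over the block's element footprint. A small host-side sketch of the resulting launch shape, with all sizes assumed for illustration; it mirrors the grid/block math of the copy_blocks_mla launcher added further down:

// Assumed sizes for illustration only.
#include <algorithm>
#include <cstdio>

int main() {
  int num_layers = 32, num_pairs = 4;      // hypothetical copy request
  int mem_footprint_per_block = 16 * 576;  // block_size * (kv_lora_rank + pe_dim), assumed
  int threads = std::min(1024, mem_footprint_per_block);
  int elems_per_thread = (mem_footprint_per_block + threads - 1) / threads;
  // grid=(32,4): 128 CUDA blocks; each thread copies ceil(9216/1024) = 9 elements.
  std::printf("grid=(%d,%d) block=%d elems/thread=%d\n", num_layers, num_pairs,
              threads, elems_per_thread);
  return 0;
}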
@@ -147,6 +168,42 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
       }));
 }
 
+// copy blocks kernel for MLA (assumes a joint KV-cache)
+void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
+                     const torch::Tensor& block_mapping) {
+  int num_layers = kv_caches.size();
+  if (num_layers == 0) {
+    return;
+  }
+  torch::Device cache_device = kv_caches[0].device();
+  TORCH_CHECK(cache_device.is_cuda(), "kv_cache must be on CUDA");
+
+  std::vector<int64_t> cache_ptrs(num_layers);
+  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
+    cache_ptrs[layer_idx] =
+        reinterpret_cast<int64_t>(kv_caches[layer_idx].data_ptr());
+  }
+  torch::Tensor cache_ptrs_tensor =
+      torch::from_blob(cache_ptrs.data(), {num_layers}, torch::kInt64)
+          .to(cache_device);
+
+  int num_pairs = block_mapping.size(0);
+  // We use the stride instead of numel in case the cache is padded for memory
+  // alignment reasons, we assume the blocks data (inclusive of any padding)
+  // is contiguous in memory
+  int mem_footprint_per_block = kv_caches[0].stride(0);
+  dim3 grid(num_layers, num_pairs);
+  dim3 block(std::min(1024, mem_footprint_per_block));
+  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
+      kv_caches[0].scalar_type(), "copy_blocks_mla_kernel", ([&] {
+        vllm::copy_blocks_mla_kernel<scalar_t><<<grid, block, 0, stream>>>(
+            cache_ptrs_tensor.data_ptr<int64_t>(),
+            block_mapping.data_ptr<int64_t>(), mem_footprint_per_block);
+      }));
+}
+
 namespace vllm {
 
 template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
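A hedged usage sketch of the host entry point added above. It assumes each layer's joint MLA cache has shape [num_blocks, block_size, kv_lora_rank + pe_dim] and that block_mapping is an int64 tensor of shape [num_pairs, 2] holding (src, dst) block indices on the same CUDA device (the kernel dereferences it on-device); the concrete sizes and indices below are made up:

// Sketch only: copies block 0 -> 8 and block 3 -> 9 in every layer's cache.
#include <torch/torch.h>
#include "cache.h"  // declares copy_blocks_mla (see csrc/cache.h above)

int main() {
  auto opts = torch::dtype(torch::kFloat16).device(torch::kCUDA);
  std::vector<torch::Tensor> kv_caches;
  for (int layer = 0; layer < 2; ++layer) {
    // Joint kv_c + k_pe cache: [num_blocks, block_size, kv_lora_rank + pe_dim].
    kv_caches.push_back(torch::zeros({64, 16, 576}, opts));
  }
  auto block_mapping = torch::tensor(
      {{0, 8}, {3, 9}}, torch::dtype(torch::kInt64).device(torch::kCUDA));
  copy_blocks_mla(kv_caches, block_mapping);
  return 0;
}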
@@ -254,6 +311,7 @@ __global__ void concat_and_cache_mla_kernel(
                                              // + pe_dim)]
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
     const int block_stride,                    //
+    const int entry_stride,                    //
     const int kv_c_stride,                     //
     const int k_pe_stride,                     //
     const int kv_lora_rank,                    //
@@ -274,9 +332,8 @@ __global__ void concat_and_cache_mla_kernel(
       int src_stride, int dst_stride, int size, int offset) {
     for (int i = threadIdx.x; i < size; i += blockDim.x) {
       const int64_t src_idx = token_idx * src_stride + i;
-      const int64_t dst_idx = block_idx * block_stride +
-                              block_offset * (kv_lora_rank + pe_dim) + i +
-                              offset;
+      const int64_t dst_idx =
+          block_idx * block_stride + block_offset * entry_stride + i + offset;
       if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
         dst[dst_idx] = src[src_idx];
       } else {
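The effect of the dst_idx change, worked through: the destination index is now derived from the cache's actual strides rather than from kv_lora_rank + pe_dim, so a padded layout still lands writes in the right slot. A minimal sketch with assumed sizes; offset selects where within the joint entry the copy lands (0 for the kv_c part, kv_lora_rank for the k_pe part):

// Assumed, self-contained arithmetic; mirrors only the new dst_idx formula.
#include <torch/torch.h>
#include <cstdio>

int main() {
  const int64_t block_size = 16, kv_lora_rank = 512, pe_dim = 64;
  auto kv_cache = torch::zeros({8, block_size, kv_lora_rank + pe_dim});
  const int64_t block_stride = kv_cache.stride(0);  // padding-inclusive if padded
  const int64_t entry_stride = kv_cache.stride(1);  // previously hard-coded as kv_lora_rank + pe_dim
  const int64_t slot = 37;                          // a slot_mapping entry for one token
  const int64_t block_idx = slot / block_size, block_offset = slot % block_size;
  const int64_t i = 3, offset = kv_lora_rank;       // element 3 of the k_pe half
  const int64_t dst_idx =
      block_idx * block_stride + block_offset * entry_stride + i + offset;
  std::printf("dst_idx = %lld\n", static_cast<long long>(dst_idx));
  return 0;
}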
@@ -391,14 +448,14 @@ void reshape_and_cache_flash(
 // KV_T is the stored data type of kv-cache.
 // CACHE_T is the data type of key and value tensors.
 // KV_DTYPE is the real data type of kv-cache.
-#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE)              \
-  vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE>            \
-      <<<grid, block, 0, stream>>>(                                     \
-          reinterpret_cast<KV_T*>(kv_c.data_ptr()),                     \
-          reinterpret_cast<KV_T*>(k_pe.data_ptr()),                     \
-          reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),              \
-          slot_mapping.data_ptr<int64_t>(), block_stride, kv_c_stride,  \
-          k_pe_stride, kv_lora_rank, pe_dim, block_size,                \
+#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE)               \
+  vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE>             \
+      <<<grid, block, 0, stream>>>(                                      \
+          reinterpret_cast<KV_T*>(kv_c.data_ptr()),                      \
+          reinterpret_cast<KV_T*>(k_pe.data_ptr()),                      \
+          reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),               \
+          slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride,  \
+          kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size,    \
           reinterpret_cast<const float*>(scale.data_ptr()));
 
 void concat_and_cache_mla(
@@ -428,6 +485,7 @@ void concat_and_cache_mla(
   int kv_c_stride = kv_c.stride(0);
   int k_pe_stride = k_pe.stride(0);
   int block_stride = kv_cache.stride(0);
+  int entry_stride = kv_cache.stride(1);
 
   dim3 grid(num_tokens);
   dim3 block(std::min(kv_lora_rank, 512));

csrc/torch_bindings.cpp

Lines changed: 4 additions & 0 deletions
@@ -450,6 +450,10 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       "Tensor block_mapping) -> ()");
   cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);
 
+  cache_ops.def(
+      "copy_blocks_mla(Tensor(a!)[] kv_caches, Tensor block_mapping) -> ()");
+  cache_ops.impl("copy_blocks_mla", torch::kCUDA, &copy_blocks_mla);
+
   // Reshape the key and value tensors and cache them.
   cache_ops.def(
       "reshape_and_cache(Tensor key, Tensor value,"
