Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion csrc/cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
const std::string& kv_cache_dtype,
torch::Tensor& scale);

void cp_fused_concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
torch::Tensor& cp_local_token_select_indices,
torch::Tensor& kv_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
torch::Tensor& scale);

// Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype);
Expand All @@ -47,4 +54,12 @@ void gather_and_maybe_dequant_cache(
torch::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size, const std::string& kv_cache_dtype,
torch::Tensor const& scale,
std::optional<torch::Tensor> seq_starts = std::nullopt);
std::optional<torch::Tensor> seq_starts = std::nullopt);

// TODO(hc): cp_gather_cache need support scaled kvcahe in the future.
void cp_gather_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
247 changes: 247 additions & 0 deletions csrc/cache_kernels.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAException.h>

#include "cuda_utils.h"
#include "cuda_compat.h"
Expand Down Expand Up @@ -395,6 +396,51 @@ __global__ void concat_and_cache_mla_kernel(
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
}

template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void cp_fused_concat_and_cache_mla_kernel(
const scalar_t* __restrict__ kv_c, // [num_full_tokens, kv_lora_rank]
const scalar_t* __restrict__ k_pe, // [num_full_tokens, pe_dim]
const int64_t* __restrict__ cp_local_token_select_indices, // [num_tokens]
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
// + pe_dim)]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, //
const int entry_stride, //
const int kv_c_stride, //
const int k_pe_stride, //
const int kv_lora_rank, //
const int pe_dim, //
const int block_size, //
const float* scale //
) {
const int64_t token_idx = cp_local_token_select_indices[blockIdx.x];
const int64_t slot_idx = slot_mapping[blockIdx.x];
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;

auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst,
int src_stride, int dst_stride, int size, int offset) {
for (int i = threadIdx.x; i < size; i += blockDim.x) {
const int64_t src_idx = token_idx * src_stride + i;
const int64_t dst_idx =
block_idx * block_stride + block_offset * entry_stride + i + offset;
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
dst[dst_idx] = src[src_idx];
} else {
dst[dst_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(src[src_idx], *scale);
}
}
};

copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
}

} // namespace vllm

// KV_T is the data type of key and value tensors.
Expand Down Expand Up @@ -508,6 +554,20 @@ void reshape_and_cache_flash(
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));

// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_CP_FUSED_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
vllm::cp_fused_concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
cp_local_token_select_indices.data_ptr<int64_t>(), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));

void concat_and_cache_mla(
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
torch::Tensor& k_pe, // [num_tokens, pe_dim]
Expand Down Expand Up @@ -546,6 +606,50 @@ void concat_and_cache_mla(
CALL_CONCAT_AND_CACHE_MLA);
}

// Note(hc): cp_fused_concat_and_cache_mla fuses the following three kernel
// calls into one:
// k_c_normed.index_select(0, cp_local_token_select_indices) + \
// k_pe.squeeze(1).index_select(0, cp_local_token_select_indices) + \
// concat_and_cache_mla.
void cp_fused_concat_and_cache_mla(
torch::Tensor& kv_c, // [num_total_tokens, kv_lora_rank]
torch::Tensor& k_pe, // [num_total_tokens, pe_dim]
torch::Tensor& cp_local_token_select_indices, // [num_tokens]
torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
// pe_dim)]
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
const std::string& kv_cache_dtype, torch::Tensor& scale) {
// NOTE(woosuk): In vLLM V1, key.size(0) can be different from
// slot_mapping.size(0) because of padding for CUDA graphs.
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
// both include padding.
// In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
// since key includes padding for CUDA graphs, while slot_mapping does not.
// In this case, slot_mapping.size(0) represents the actual number of tokens
// before padding.
// For compatibility with both cases, we use slot_mapping.size(0) as the
// number of tokens.
int num_tokens = slot_mapping.size(0);
int kv_lora_rank = kv_c.size(1);
int pe_dim = k_pe.size(1);
int block_size = kv_cache.size(1);

TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);

int kv_c_stride = kv_c.stride(0);
int k_pe_stride = k_pe.stride(0);
int block_stride = kv_cache.stride(0);
int entry_stride = kv_cache.stride(1);

dim3 grid(num_tokens);
dim3 block(std::min(kv_lora_rank, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
CALL_CP_FUSED_CONCAT_AND_CACHE_MLA);
}

namespace vllm {

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
Expand Down Expand Up @@ -779,3 +883,146 @@ void gather_and_maybe_dequant_cache(

DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE);
}

namespace vllm {
template <typename scalar_t>
// Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by
// block_size.
__global__ void cp_gather_cache(
const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
// ENTRY_SIZE]
scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRY_SIZE]
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
const int32_t* __restrict__ cu_seq_lens, // [BATCH+1]
const int32_t block_size, const int32_t entry_size,
const int64_t block_table_stride, const int64_t cache_block_stride,
const int64_t cache_entry_stride, const int64_t dst_entry_stride,
const int32_t* __restrict__ seq_starts // Optional: starting offsets per
// batch
) {
const int64_t bid = blockIdx.x; // Batch ID
const int32_t num_splits = gridDim.y;
const int32_t split = blockIdx.y;
const int32_t seq_start = cu_seq_lens[bid];
const int32_t seq_end = cu_seq_lens[bid + 1];
const int32_t seq_len = seq_end - seq_start;
const int32_t tot_slots = seq_len;
const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits);

const int32_t split_start = split * split_slots;
const int32_t split_end = min((split + 1) * split_slots, tot_slots);

const bool is_active_split = (split_start < tot_slots);
const bool is_last_split = (split_end == tot_slots);

if (!is_active_split) return;

// Adjust the pointer for the block_table for this batch.
// If seq_starts is provided, compute an offset based on it
const int32_t batch_offset = bid * block_table_stride;
int32_t offset = split_start;
if (seq_starts != nullptr) {
offset += seq_starts[bid];
}
int32_t offset_div = offset / block_size;
offset = offset % block_size;
const int32_t* batch_block_table = block_table + batch_offset;

// Adjust dst pointer based on the cumulative sequence lengths.
dst += seq_start * dst_entry_stride;

auto copy_entry = [&](const scalar_t* __restrict__ _src,
scalar_t* __restrict__ _dst) {
for (int i = threadIdx.x; i < entry_size; i += blockDim.x)
_dst[i] = _src[i];
};

for (int pid = split_start; pid < split_end; ++pid) {
auto block_id = batch_block_table[offset_div];
auto block_start_ptr = src_cache + block_id * cache_block_stride;
auto block_dst_ptr = dst + pid * dst_entry_stride;
copy_entry(block_start_ptr + offset * cache_entry_stride, block_dst_ptr);
offset += 1;
// bump to next block
if (offset == block_size) {
offset_div += 1;
offset = 0;
}
}
}
} // namespace vllm

// Macro to dispatch the kernel based on the data type.
#define CALL_CP_GATHER_CACHE(CPY_DTYPE) \
vllm::cp_gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()), \
reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()), \
block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
block_size, entry_size, block_table_stride, cache_block_stride, \
cache_entry_stride, dst_entry_stride, seq_starts_ptr);

// Gather sequences from the cache into the destination tensor.
// - cu_seq_lens contains the cumulative sequence lengths for each batch
// - block_table contains the cache block indices for each sequence
// - Optionally, seq_starts (if provided) offsets the starting slot index by
// seq_starts[bid]
void cp_gather_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size,
std::optional<torch::Tensor> seq_starts = std::nullopt) {
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

int32_t block_size = src_cache.size(1);
int32_t entry_size = src_cache.flatten(2, -1).size(2);

TORCH_CHECK(block_table.dtype() == torch::kInt32,
"block_table must be int32");
TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32,
"cu_seq_lens must be int32");
if (seq_starts.has_value()) {
TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
"seq_starts must be int32");
}

TORCH_CHECK(src_cache.device() == dst.device(),
"src_cache and dst must be on the same device");
TORCH_CHECK(src_cache.device() == block_table.device(),
"src_cache and block_table must be on the same device");
TORCH_CHECK(src_cache.device() == cu_seq_lens.device(),
"src_cache and cu_seq_lens must be on the same device");
if (seq_starts.has_value()) {
TORCH_CHECK(src_cache.device() == seq_starts.value().device(),
"src_cache and seq_starts must be on the same device");
}

int64_t block_table_stride = block_table.stride(0);
int64_t cache_block_stride = src_cache.stride(0);
int64_t cache_entry_stride = src_cache.stride(1);
int64_t dst_entry_stride = dst.stride(0);

// Decide on the number of splits based on the batch size.
int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
dim3 grid(batch_size, num_splits);
dim3 block(1024);

TORCH_CHECK(src_cache.dtype() == dst.dtype(),
"src_cache and dst must have the same dtype");

const int dtype_bits = src_cache.element_size() * 8;
const int32_t* seq_starts_ptr =
seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;

if (dtype_bits == 32) {
CALL_CP_GATHER_CACHE(uint32_t);
} else if (dtype_bits == 16) {
CALL_CP_GATHER_CACHE(uint16_t);
} else if (dtype_bits == 8) {
CALL_CP_GATHER_CACHE(uint8_t);
} else {
TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
}
}
15 changes: 15 additions & 0 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor scale) -> ()");
cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);

cache_ops.def(
"cp_fused_concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
" Tensor cp_local_token_select_indices,"
" Tensor! kv_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor scale) -> ()");
cache_ops.impl("cp_fused_concat_and_cache_mla", torch::kCUDA,
&cp_fused_concat_and_cache_mla);

// Convert the key and value cache to fp8 data type.
cache_ops.def(
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
Expand All @@ -702,6 +712,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor scale, Tensor? seq_starts) -> ()");
cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA,
&gather_and_maybe_dequant_cache);

cache_ops.def(
"cp_gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache);
}

TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
Expand Down
Loading