24 changes: 23 additions & 1 deletion benchmarks/benchmark_throughput.py
@@ -87,6 +87,9 @@ def run_vllm(
download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False,
swap_space: int = 4,
external_swapper: str = "",
external_swapper_space: int = 0,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
@@ -112,6 +115,9 @@ def run_vllm(
num_scheduler_steps=num_scheduler_steps,
use_v2_block_manager=use_v2_block_manager,
disable_async_output_proc=disable_async_output_proc,
swap_space=swap_space,
external_swapper=external_swapper,
external_swapper_space=external_swapper_space,
)

# Add the requests to the engine.
@@ -240,7 +246,8 @@ def main(args: argparse.Namespace):
args.max_num_batched_tokens, args.distributed_executor_backend,
args.gpu_memory_utilization, args.num_scheduler_steps,
args.use_v2_block_manager, args.download_dir, args.load_format,
args.disable_async_output_proc)
args.disable_async_output_proc, args.swap_space,
args.external_swapper, args.external_swapper_space)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -426,6 +433,21 @@ def main(args: argparse.Namespace):
action='store_true',
default=False,
help="Disable async output processor for vLLM backend.")
parser.add_argument('--swap-space',
type=float,
default=EngineArgs.swap_space,
help='CPU swap space size (GiB) per GPU.')

parser.add_argument(
'--external-swapper',
type=str,
default="",
help="External storage kv cache medium, supports local file currently."
)
parser.add_argument('--external-swapper-space',
type=int,
default=0,
help="External swapper space size (GiB) per GPU.")
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
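For orientation, here is a minimal sketch, not part of the PR, of how the new benchmark flags map onto the `swap_space`, `external_swapper`, and `external_swapper_space` keyword arguments passed to `LLM(...)` above. The model name and the `"local_file"` medium string are illustrative assumptions; the help text only says a local file is supported, not what the flag value should be.

```python
# Sketch only: exercising the kwargs added to LLM(...) in this PR.
# "facebook/opt-125m" and "local_file" are assumed values for illustration.
from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",       # hypothetical model choice
    swap_space=4,                    # CPU swap space, GiB per GPU
    external_swapper="local_file",   # assumed medium string for the local-file backend
    external_swapper_space=16,       # external swapper space, GiB per GPU
)
```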
96 changes: 96 additions & 0 deletions benchmarks/kernels/benchmark_swap_blocks.py
@@ -0,0 +1,96 @@
import time

import torch

from vllm import _custom_ops as ops
from vllm.utils import FlexibleArgumentParser


def benchmark_swap_blocks(src_shape, num_blocks):
src = torch.randn(src_shape, dtype=torch.float16).cuda()
dst = torch.zeros_like(src).cpu()
block_mapping = [(i, i) for i in range(num_blocks)]
blocks_to_swap = torch.tensor(block_mapping,
device="cpu",
dtype=torch.int64).view(-1, 2)

num_iterations = 100
total_time = 0
for _ in range(num_iterations):
start_time = time.time()
ops.swap_blocks(src, dst, blocks_to_swap)
torch.cuda.synchronize()
end_time = time.time()
total_time += end_time - start_time

average_time = total_time / num_iterations
print(
f"Avg. GPU->CPU time taken for swapping blocks: {average_time} seconds"
)


def benchmark_swap_out_to_local_file(src_shape, dst, num_blocks):
src = torch.randn(src_shape, dtype=torch.float16).cuda()
num_elements = src.numel()
element_size = src.element_size()
total_bytes = num_elements * element_size
with open(dst, 'wb') as file:
file.write(b'0' * total_bytes)

block_mapping = [(i, i) for i in range(num_blocks)]
blocks_to_swap = torch.tensor(block_mapping,
device="cpu",
dtype=torch.int64).view(-1, 2)
num_iterations = 100
total_time = 0
for _ in range(num_iterations):
start_time = time.time()
ops.swap_out_to_local_file(src, dst, blocks_to_swap)
torch.cuda.synchronize()
end_time = time.time()
total_time += end_time - start_time

average_time = total_time / num_iterations
print(
f"Avg. GPU->File time taken for swapping blocks: {average_time} seconds"
)


def benchmark_swap_in_from_local_file(src_shape, src, num_blocks):
dst = torch.zeros(src_shape, dtype=torch.float16).cuda()
block_mapping = [(i, i) for i in range(num_blocks)]
blocks_to_swap = torch.tensor(block_mapping,
device="cpu",
dtype=torch.int64).view(-1, 2)
num_iterations = 100
total_time = 0
for _ in range(num_iterations):
start_time = time.time()
ops.swap_in_from_local_file(src, dst, blocks_to_swap)
torch.cuda.synchronize()
end_time = time.time()
total_time += end_time - start_time

average_time = total_time / num_iterations
print(
f"Avg. File->GPU time taken for swapping blocks: {average_time} seconds"
)


if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument("--num-blocks", type=int, default="1024")
parser.add_argument("--block-size", type=int, default=16)
parser.add_argument("--num-kv-heads", type=int, default=32)
parser.add_argument("--head-size", type=int, default=32)
parser.add_argument("--filename", type=str, default="./test.txt")
args = parser.parse_args()
print(args)

src_shape = (args.num_blocks, args.block_size, args.num_kv_heads,
args.head_size)

benchmark_swap_blocks(src_shape, args.num_blocks)
benchmark_swap_out_to_local_file(src_shape, args.filename, args.num_blocks)
benchmark_swap_in_from_local_file(src_shape, args.filename,
args.num_blocks)
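To put the reported averages in context, a quick sketch of how much data each full swap moves with the benchmark's defaults (shapes taken from the argument defaults above, float16 elements):

```python
# Data volume moved per full swap with the benchmark defaults above
# (1024 blocks x 16 tokens x 32 KV heads x head size 32, float16).
num_blocks, block_size, num_kv_heads, head_size = 1024, 16, 32, 32
bytes_per_element = 2  # torch.float16
total_bytes = num_blocks * block_size * num_kv_heads * head_size * bytes_per_element
print(f"{total_bytes / 2**20:.1f} MiB per swap")  # 32.0 MiB
```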
22 changes: 22 additions & 0 deletions csrc/cache.h
@@ -4,10 +4,32 @@

#include <map>
#include <vector>
#include <string>

class FileSwapperParam {
public:
FileSwapperParam(char* cache_ptr, const std::string& file_name,
int64_t file_offset, int64_t size)
: cache_ptr(cache_ptr),
file_name(file_name),
file_offset(file_offset),
size(size) {}

char* cache_ptr;
std::string file_name;
int64_t file_offset;
int64_t size;
};

void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping);

void swap_out_to_local_file(torch::Tensor& src, std::string file_name,
const torch::Tensor& block_mapping);

void swap_in_from_local_file(std::string src, torch::Tensor& dst,
const torch::Tensor& block_mapping);

// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
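A small sketch of the `block_mapping` these declarations expect, inferred from the kernels below: one `(src_block, dst_block)` pair per row, `int64`, resident on the CPU. The file-side block number is multiplied by the per-block byte size to obtain the file offset.

```python
# Sketch: a CPU-resident int64 tensor of (src_block, dst_block) pairs,
# as consumed by swap_out_to_local_file / swap_in_from_local_file.
import torch

block_mapping = torch.tensor([[0, 0], [1, 1], [2, 5]],
                             dtype=torch.int64,
                             device="cpu")
```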
127 changes: 126 additions & 1 deletion csrc/cache_kernels.cu
@@ -1,9 +1,11 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cassert>

#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "cache.h"

#ifdef USE_ROCM
#include "quantization/fp8/amd/quant_utils.cuh"
@@ -15,7 +17,10 @@
#include <cassert>
#include <map>
#include <vector>

#include <string>
#include <pthread.h>
#include <fcntl.h>
#include <unistd.h>
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16;
@@ -62,6 +67,126 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
}
}

void* copy_blocks_to_file(void* args) {
FileSwapperParam* param = static_cast<FileSwapperParam*>(args);
char* gpu_ptr = param->cache_ptr;
int64_t file_offset = param->file_offset;
int64_t size = param->size;
std::string file_name = param->file_name;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

char* cpu_ptr = new char[size];
cudaError_t cudaStatus =
cudaMemcpyAsync(cpu_ptr, gpu_ptr, size, cudaMemcpyDeviceToHost, stream);

// Error handling is not currently supported in this kernel path,
// so asserts are used instead.
// Please refer to https://github.com/vllm-project/vllm/issues/7577
assert(cudaStatus == cudaSuccess && "cudaMemcpyAsync failed");

int fd = open(file_name.c_str(), O_WRONLY);
assert(fd != -1 && "failed to open the file");

ssize_t bytesWritten = pwrite(fd, cpu_ptr, size, file_offset);
assert(bytesWritten == size && "failed to write the file");
close(fd);

delete[] cpu_ptr;
delete param;
return nullptr;
}

void swap_out_to_local_file(torch::Tensor& src, std::string dst_file,
const torch::Tensor& block_mapping) {
TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");

const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
char* src_ptr = static_cast<char*>(src.data_ptr());

const int64_t num_blocks = block_mapping.size(0);
pthread_t* threads = new pthread_t[num_blocks];
for (size_t i = 0; i < num_blocks; i++) {
int64_t src_block_number = block_mapping[i][0].item<int64_t>();
int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
int64_t src_offset = src_block_number * block_size_in_bytes;
int64_t dst_offset = dst_block_number * block_size_in_bytes;

char* gpu_ptr = src_ptr + src_offset;
FileSwapperParam* param = new FileSwapperParam(
gpu_ptr, dst_file, dst_offset, block_size_in_bytes);

// In order to increase the cache copy speed,
// multi-threading is used for copying.
int ret = pthread_create(&threads[i], nullptr, copy_blocks_to_file, param);
assert(ret == 0 && "create thread failed");
}

for (size_t i = 0; i < num_blocks; i++) {
pthread_join(threads[i], nullptr);
}

delete[] threads;
}

void* copy_blocks_from_file(void* args) {
FileSwapperParam* param = static_cast<FileSwapperParam*>(args);
char* gpu_ptr = param->cache_ptr;
int64_t size = param->size;
int64_t file_offset = param->file_offset;
std::string file_name = param->file_name;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
char* cpu_ptr = new char[size];

int fd = open(file_name.c_str(), O_RDONLY);
assert(fd != -1 && "failed to open the file");

ssize_t bytesRead = pread(fd, cpu_ptr, size, file_offset);
assert(bytesRead == size && "failed to read the file");
close(fd);

cudaError_t cudaStatus =
cudaMemcpyAsync(gpu_ptr, cpu_ptr, size, cudaMemcpyHostToDevice, stream);
assert(cudaStatus == cudaSuccess && "cudaMemcpyAsync failed");

delete[] cpu_ptr;
delete param;
return nullptr;
}

void swap_in_from_local_file(std::string src_file, torch::Tensor& dst,
const torch::Tensor& block_mapping) {
TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");

const int64_t block_size_in_bytes = dst.element_size() * dst[0].numel();
char* dst_ptr = static_cast<char*>(dst.data_ptr());

const int64_t num_blocks = block_mapping.size(0);
pthread_t* threads = new pthread_t[num_blocks];
for (size_t i = 0; i < num_blocks; i++) {
int64_t src_block_number = block_mapping[i][0].item<int64_t>();
int64_t dst_block_number = block_mapping[i][1].item<int64_t>();

int64_t src_offset = src_block_number * block_size_in_bytes;
int64_t dst_offset = dst_block_number * block_size_in_bytes;

char* gpu_ptr = dst_ptr + dst_offset;
FileSwapperParam* param = new FileSwapperParam(
gpu_ptr, src_file, src_offset, block_size_in_bytes);

// In order to increase the cache copy speed,
// multi-threading is used for copying.
int ret =
pthread_create(&threads[i], nullptr, copy_blocks_from_file, param);
assert(ret == 0 && "create thread failed");
}

for (size_t i = 0; i < num_blocks; i++) {
pthread_join(threads[i], nullptr);
}

delete[] threads;
}

namespace vllm {

// Grid: (num_layers, num_pairs)
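One practical point about the swap-out path above: `open()` is called with `O_WRONLY` and no `O_CREAT`, so the backing file must already exist before `swap_out_to_local_file` runs; `benchmark_swap_blocks.py` additionally pre-sizes it to the full cache size. A hedged helper sketch (the function name is not part of the PR):

```python
# Sketch: create and pre-size the backing file before swap_out_to_local_file.
# pwrite() lands at dst_block * block_size_in_bytes; sizing the file to
# num_blocks * block_size_in_bytes mirrors what the benchmark does.
def preallocate_swap_file(path: str, num_blocks: int, block_size_in_bytes: int) -> None:
    with open(path, "wb") as f:
        f.truncate(num_blocks * block_size_in_bytes)
```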
12 changes: 12 additions & 0 deletions csrc/torch_bindings.cpp
@@ -292,6 +292,18 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
"swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);

cache_ops.def(
"swap_out_to_local_file(Tensor src, str file_name, Tensor block_mapping) "
"-> ()");
cache_ops.impl("swap_out_to_local_file", torch::kCUDA,
&swap_out_to_local_file);

cache_ops.def(
"swap_in_from_local_file(str src, Tensor dst, Tensor block_mapping) -> "
"()");
cache_ops.impl("swap_in_from_local_file", torch::kCUDA,
&swap_in_from_local_file);

// Copy the cache blocks from src to dst.
cache_ops.def(
"copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
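Assuming the Python wrappers in `vllm._custom_ops` expose these two ops the same way `benchmarks/kernels/benchmark_swap_blocks.py` calls them, a minimal round-trip sketch (path and shapes are arbitrary):

```python
# Round-trip sketch: GPU cache -> local file -> GPU cache via the ops
# registered above. Assumes the _custom_ops wrappers used by the benchmark.
import torch
from vllm import _custom_ops as ops

src = torch.randn(4, 16, 32, 32, dtype=torch.float16, device="cuda")
dst = torch.zeros_like(src)
mapping = torch.tensor([[i, i] for i in range(4)], dtype=torch.int64, device="cpu")

# The swap-out path opens the file with O_WRONLY, so create it up front.
with open("/tmp/kv_swap.bin", "wb") as f:
    f.truncate(src.numel() * src.element_size())

ops.swap_out_to_local_file(src, "/tmp/kv_swap.bin", mapping)   # GPU -> file
ops.swap_in_from_local_file("/tmp/kv_swap.bin", dst, mapping)  # file -> GPU
torch.cuda.synchronize()
assert torch.equal(src, dst)
```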
6 changes: 3 additions & 3 deletions tests/core/block/test_block_manager_v2.py
@@ -290,7 +290,7 @@ def test_swap(block_size, num_cpu_blocks, num_gpu_blocks, num_lookahead_slots,
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
mapping = block_manager.swap_out(seq_group)
mapping_keys = [key for key, _ in mapping]
mapping_keys = [key.block_id for key, _ in mapping]
assert mapping_keys == gpu_blocks
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
@@ -304,7 +304,7 @@ def test_swap(block_size, num_cpu_blocks, num_gpu_blocks, num_lookahead_slots,
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
mapping = block_manager.swap_in(seq_group)
cpu_blocks = block_manager.get_block_table(prompt)
mapping_keys = [key for key, _ in mapping]
mapping_keys = [key.block_id for key, _ in mapping]
assert mapping_keys == [cpu_blocks[0]]
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
@@ -338,7 +338,7 @@ def test_can_swap(block_size, num_gpu_blocks, num_lookahead_slots,
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
mapping = block_manager.swap_out(seq_group)
mapping_keys = [key for key, _ in mapping]
mapping_keys = [key.block_id for key, _ in mapping]
assert mapping_keys == gpu_blocks
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
after_gpu_blocks = block_manager.get_num_free_gpu_blocks()