diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index eaf256f7cb8c..138ae69f4507 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -87,6 +87,9 @@ def run_vllm( download_dir: Optional[str] = None, load_format: str = EngineArgs.load_format, disable_async_output_proc: bool = False, + swap_space: int = 4, + external_swapper: str = "", + external_swapper_space: int = 0, ) -> float: from vllm import LLM, SamplingParams llm = LLM( @@ -112,6 +115,9 @@ def run_vllm( num_scheduler_steps=num_scheduler_steps, use_v2_block_manager=use_v2_block_manager, disable_async_output_proc=disable_async_output_proc, + swap_space=swap_space, + external_swapper=external_swapper, + external_swapper_space=external_swapper_space, ) # Add the requests to the engine. @@ -240,7 +246,8 @@ def main(args: argparse.Namespace): args.max_num_batched_tokens, args.distributed_executor_backend, args.gpu_memory_utilization, args.num_scheduler_steps, args.use_v2_block_manager, args.download_dir, args.load_format, - args.disable_async_output_proc) + args.disable_async_output_proc, args.swap_space, + args.external_swapper, args.external_swapper_space) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -426,6 +433,21 @@ def main(args: argparse.Namespace): action='store_true', default=False, help="Disable async output processor for vLLM backend.") + parser.add_argument('--swap-space', + type=float, + default=EngineArgs.swap_space, + help='CPU swap space size (GiB) per GPU.') + + parser.add_argument( + '--external-swapper', + type=str, + default="", + help="External storage kv cache medium, supports local file currently." + ) + parser.add_argument('--external-swapper-space', + type=int, + default=0, + help="External swapper space size (GiB) per GPU.") args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model diff --git a/benchmarks/kernels/benchmark_swap_blocks.py b/benchmarks/kernels/benchmark_swap_blocks.py new file mode 100644 index 000000000000..cd5aebd57a14 --- /dev/null +++ b/benchmarks/kernels/benchmark_swap_blocks.py @@ -0,0 +1,96 @@ +import time + +import torch + +from vllm import _custom_ops as ops +from vllm.utils import FlexibleArgumentParser + + +def benchmark_swap_blocks(src_shape, num_blocks): + src = torch.randn(src_shape, dtype=torch.float16).cuda() + dst = torch.zeros_like(src).cpu() + block_mapping = [(i, i) for i in range(num_blocks)] + blocks_to_swap = torch.tensor(block_mapping, + device="cpu", + dtype=torch.int64).view(-1, 2) + + num_iterations = 100 + total_time = 0 + for _ in range(num_iterations): + start_time = time.time() + ops.swap_blocks(src, dst, blocks_to_swap) + torch.cuda.synchronize() + end_time = time.time() + total_time += end_time - start_time + + average_time = total_time / num_iterations + print( + f"Avg. GPU->CPU time taken for swapping blocks: {average_time} seconds" + ) + + +def benchmark_swap_out_to_local_file(src_shape, dst, num_blocks): + src = torch.randn(src_shape, dtype=torch.float16).cuda() + num_elements = src.numel() + element_size = src.element_size() + total_bytes = num_elements * element_size + with open(dst, 'wb') as file: + file.write(b'0' * total_bytes) + + block_mapping = [(i, i) for i in range(num_blocks)] + blocks_to_swap = torch.tensor(block_mapping, + device="cpu", + dtype=torch.int64).view(-1, 2) + num_iterations = 100 + total_time = 0 + for _ in range(num_iterations): + start_time = time.time() + ops.swap_out_to_local_file(src, dst, blocks_to_swap) + torch.cuda.synchronize() + end_time = time.time() + total_time += end_time - start_time + + average_time = total_time / num_iterations + print( + f"Avg. GPU->File time taken for swapping blocks: {average_time} seconds" + ) + + +def benchmark_swap_in_from_local_file(src_shape, src, num_blocks): + dst = torch.zeros(src_shape, dtype=torch.float16).cuda() + block_mapping = [(i, i) for i in range(num_blocks)] + blocks_to_swap = torch.tensor(block_mapping, + device="cpu", + dtype=torch.int64).view(-1, 2) + num_iterations = 100 + total_time = 0 + for _ in range(num_iterations): + start_time = time.time() + ops.swap_in_from_local_file(src, dst, blocks_to_swap) + torch.cuda.synchronize() + end_time = time.time() + total_time += end_time - start_time + + average_time = total_time / num_iterations + print( + f"Avg. File->GPU time taken for swapping blocks: {average_time} seconds" + ) + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + parser.add_argument("--num-blocks", type=int, default="1024") + parser.add_argument("--block-size", type=int, default=16) + parser.add_argument("--num-kv-heads", type=int, default=32) + parser.add_argument("--head-size", type=int, default=32) + parser.add_argument("--filename", type=str, default="./test.txt") + args = parser.parse_args() + print(args) + + src_shape = (args.num_blocks, args.block_size, args.num_kv_heads, + args.head_size) + + benchmark_swap_blocks(src_shape, args.num_blocks) + benchmark_swap_out_to_local_file(src_shape, args.filename, args.num_blocks) + benchmark_swap_in_from_local_file(src_shape, args.filename, + args.num_blocks) diff --git a/csrc/cache.h b/csrc/cache.h index 11c4c5001daa..7b9a4efd88ba 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -4,10 +4,32 @@ #include #include +#include + +class FileSwapperParam { + public: + FileSwapperParam(char* cache_ptr, const std::string& file_name, + int64_t file_offset, int64_t size) + : cache_ptr(cache_ptr), + file_name(file_name), + file_offset(file_offset), + size(size) {} + + char* cache_ptr; + std::string file_name; + int64_t file_offset; + int64_t size; +}; void swap_blocks(torch::Tensor& src, torch::Tensor& dst, const torch::Tensor& block_mapping); +void swap_out_to_local_file(torch::Tensor& src, std::string file_name, + const torch::Tensor& block_mapping); + +void swap_in_from_local_file(std::string src, torch::Tensor& dst, + const torch::Tensor& block_mapping); + // Note: the key_caches and value_caches vectors are constant but // not the Tensors they contain. The vectors need to be const refs // in order to satisfy pytorch's C++ operator registration code. diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 1be806bbfa43..c9ec124dc8ec 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1,9 +1,11 @@ #include #include #include +#include #include "cuda_compat.h" #include "dispatch_utils.h" +#include "cache.h" #ifdef USE_ROCM #include "quantization/fp8/amd/quant_utils.cuh" @@ -15,7 +17,10 @@ #include #include #include - +#include +#include +#include +#include #ifdef USE_ROCM #include typedef __hip_bfloat16 __nv_bfloat16; @@ -62,6 +67,126 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst, } } +void* copy_blocks_to_file(void* args) { + FileSwapperParam* param = static_cast(args); + char* gpu_ptr = param->cache_ptr; + int64_t file_offset = param->file_offset; + int64_t size = param->size; + std::string file_name = param->file_name; + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + char* cpu_ptr = new char[size]; + cudaError_t cudaStatus = + cudaMemcpyAsync(cpu_ptr, gpu_ptr, size, cudaMemcpyDeviceToHost, stream); + + // Currently, error handling is not supported in the kernel, + // so use assert to handle it. + // Please refer to https://github.com/vllm-project/vllm/issues/7577 + assert(cudaStatus == cudaSuccess && "cudaMemcpyAsync failed"); + + int fd = open(file_name.c_str(), O_WRONLY); + assert(fd != -1 && "failed to open the file: " + file_name); + + ssize_t bytesWritten = pwrite(fd, cpu_ptr, size, file_offset); + assert(bytesWritten == size && "failed to write the file: " + file_name); + close(fd); + + delete[] cpu_ptr; + delete param; + return nullptr; +} + +void swap_out_to_local_file(torch::Tensor& src, std::string dst_file, + const torch::Tensor& block_mapping) { + TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU"); + + const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); + char* src_ptr = static_cast(src.data_ptr()); + + const int64_t num_blocks = block_mapping.size(0); + pthread_t* threads = new pthread_t[num_blocks]; + for (size_t i = 0; i < num_blocks; i++) { + int64_t src_block_number = block_mapping[i][0].item(); + int64_t dst_block_number = block_mapping[i][1].item(); + int64_t src_offset = src_block_number * block_size_in_bytes; + int64_t dst_offset = dst_block_number * block_size_in_bytes; + + char* gpu_ptr = src_ptr + src_offset; + FileSwapperParam* param = new FileSwapperParam( + gpu_ptr, dst_file, dst_offset, block_size_in_bytes); + + // In order to increase the cache copy speed, + // multi-threading is used for copying. + int ret = pthread_create(&threads[i], nullptr, copy_blocks_to_file, param); + assert(ret == 0 && "create thread failed"); + } + + for (size_t i = 0; i < num_blocks; i++) { + pthread_join(threads[i], nullptr); + } + + delete[] threads; +} + +void* copy_blocks_from_file(void* args) { + FileSwapperParam* param = static_cast(args); + char* gpu_ptr = param->cache_ptr; + int64_t size = param->size; + int64_t file_offset = param->file_offset; + std::string file_name = param->file_name; + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + char* cpu_ptr = new char[size]; + + int fd = open(file_name.c_str(), O_RDONLY); + assert(fd != -1 && "failed to open the file: " + file_name); + + int ret = pread(fd, cpu_ptr, size, file_offset); + assert(ret != -1 && "failed to read the file: " + file_name); + close(fd); + + cudaError_t cudaStatus = + cudaMemcpyAsync(gpu_ptr, cpu_ptr, size, cudaMemcpyHostToDevice, stream); + assert(cudaStatus == cudaSuccess && "cudaMemcpyAsync failed"); + + delete[] cpu_ptr; + delete param; + return nullptr; +} + +void swap_in_from_local_file(std::string src_file, torch::Tensor& dst, + const torch::Tensor& block_mapping) { + TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU"); + + const int64_t block_size_in_bytes = dst.element_size() * dst[0].numel(); + char* dst_ptr = static_cast(dst.data_ptr()); + + const int64_t num_blocks = block_mapping.size(0); + pthread_t* threads = new pthread_t[num_blocks]; + for (size_t i = 0; i < num_blocks; i++) { + int64_t src_block_number = block_mapping[i][0].item(); + int64_t dst_block_number = block_mapping[i][1].item(); + + int64_t src_offset = src_block_number * block_size_in_bytes; + int64_t dst_offset = dst_block_number * block_size_in_bytes; + + char* gpu_ptr = dst_ptr + dst_offset; + FileSwapperParam* param = new FileSwapperParam( + gpu_ptr, src_file, src_offset, block_size_in_bytes); + + // In order to increase the cache copy speed, + // multi-threading is used for copying. + int ret = + pthread_create(&threads[i], nullptr, copy_blocks_from_file, param); + assert(ret == 0 && "create thread failed"); + } + + for (size_t i = 0; i < num_blocks; i++) { + pthread_join(threads[i], nullptr); + } + + delete[] threads; +} + namespace vllm { // Grid: (num_layers, num_pairs) diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 7783acd741f5..6bb6894cc15a 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -292,6 +292,18 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()"); cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks); + cache_ops.def( + "swap_out_to_local_file(Tensor src, str file_name, Tensor block_mapping) " + "-> ()"); + cache_ops.impl("swap_out_to_local_file", torch::kCUDA, + &swap_out_to_local_file); + + cache_ops.def( + "swap_in_from_local_file(str src, Tensor dst, Tensor block_mapping) -> " + "()"); + cache_ops.impl("swap_in_from_local_file", torch::kCUDA, + &swap_in_from_local_file); + // Copy the cache blocks from src to dst. cache_ops.def( "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor " diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index 30efe4437741..a9785abea17c 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -290,7 +290,7 @@ def test_swap(block_size, num_cpu_blocks, num_gpu_blocks, num_lookahead_slots, before_cpu_blocks = block_manager.get_num_free_cpu_blocks() before_gpu_blocks = block_manager.get_num_free_gpu_blocks() mapping = block_manager.swap_out(seq_group) - mapping_keys = [key for key, _ in mapping] + mapping_keys = [key.block_id for key, _ in mapping] assert mapping_keys == gpu_blocks after_cpu_blocks = block_manager.get_num_free_cpu_blocks() after_gpu_blocks = block_manager.get_num_free_gpu_blocks() @@ -304,7 +304,7 @@ def test_swap(block_size, num_cpu_blocks, num_gpu_blocks, num_lookahead_slots, before_gpu_blocks = block_manager.get_num_free_gpu_blocks() mapping = block_manager.swap_in(seq_group) cpu_blocks = block_manager.get_block_table(prompt) - mapping_keys = [key for key, _ in mapping] + mapping_keys = [key.block_id for key, _ in mapping] assert mapping_keys == [cpu_blocks[0]] after_cpu_blocks = block_manager.get_num_free_cpu_blocks() after_gpu_blocks = block_manager.get_num_free_gpu_blocks() @@ -338,7 +338,7 @@ def test_can_swap(block_size, num_gpu_blocks, num_lookahead_slots, before_cpu_blocks = block_manager.get_num_free_cpu_blocks() before_gpu_blocks = block_manager.get_num_free_gpu_blocks() mapping = block_manager.swap_out(seq_group) - mapping_keys = [key for key, _ in mapping] + mapping_keys = [key.block_id for key, _ in mapping] assert mapping_keys == gpu_blocks after_cpu_blocks = block_manager.get_num_free_cpu_blocks() after_gpu_blocks = block_manager.get_num_free_gpu_blocks() diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 2ee9f20824f2..7296aa93cb7d 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -321,7 +321,7 @@ def test_swap(): before_cpu_blocks = block_manager.get_num_free_cpu_blocks() before_gpu_blocks = block_manager.get_num_free_gpu_blocks() mapping = block_manager.swap_out(seq_group) - assert [x[0] for x in mapping] == gpu_blocks + assert [x[0].block_id for x in mapping] == gpu_blocks after_cpu_blocks = block_manager.get_num_free_cpu_blocks() after_gpu_blocks = block_manager.get_num_free_gpu_blocks() assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks) @@ -334,7 +334,7 @@ def test_swap(): before_cpu_blocks = block_manager.get_num_free_cpu_blocks() before_gpu_blocks = block_manager.get_num_free_gpu_blocks() mapping = block_manager.swap_in(seq_group) - assert [x[0] for x in mapping] == cpu_blocks + assert [x[0].block_id for x in mapping] == cpu_blocks after_cpu_blocks = block_manager.get_num_free_cpu_blocks() after_gpu_blocks = block_manager.get_num_free_gpu_blocks() assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks @@ -374,7 +374,7 @@ def test_swap_encoder_decoder(): before_cpu_blocks = block_manager.get_num_free_cpu_blocks() before_gpu_blocks = block_manager.get_num_free_gpu_blocks() mapping = block_manager.swap_out(seq_group) - assert [x[0] for x in mapping] == gpu_blocks + assert [x[0].block_id for x in mapping] == gpu_blocks #assert list(mapping.keys()) == gpu_blocks after_cpu_blocks = block_manager.get_num_free_cpu_blocks() after_gpu_blocks = block_manager.get_num_free_gpu_blocks() @@ -390,7 +390,7 @@ def test_swap_encoder_decoder(): before_cpu_blocks = block_manager.get_num_free_cpu_blocks() before_gpu_blocks = block_manager.get_num_free_gpu_blocks() mapping = block_manager.swap_in(seq_group) - assert [x[0] for x in mapping] == cpu_blocks + assert [x[0].block_id for x in mapping] == cpu_blocks after_cpu_blocks = block_manager.get_num_free_cpu_blocks() after_gpu_blocks = block_manager.get_num_free_gpu_blocks() assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 11168d2423b0..4d105ecd0640 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -10,6 +10,7 @@ from vllm.core.scheduler import Scheduler, SchedulingBudget from vllm.lora.request import LoRARequest from vllm.sequence import SequenceGroup, SequenceStatus +from vllm.utils import BlockSwapParam from .utils import (append_new_token, append_new_token_seq_group, create_dummy_prompt, get_sequence_groups, @@ -622,7 +623,7 @@ def test_schedule_decode_blocks_to_copy_update(): def test_schedule_swapped_simple(): scheduler = initialize_scheduler() curr_loras = None - blocks_to_swap_out: List[Tuple[int, int]] = [] + blocks_to_swap_out: List[Tuple[BlockSwapParam, BlockSwapParam]] = [] _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) @@ -641,7 +642,12 @@ def test_schedule_swapped_simple(): blocks_to_swap_in_reverse = [] for swapin, swapout in output.blocks_to_swap_in: blocks_to_swap_in_reverse.append((swapout, swapin)) - assert blocks_to_swap_out == blocks_to_swap_in_reverse + + for i in range(len(blocks_to_swap_out)): + assert blocks_to_swap_out[i][0].block_id == blocks_to_swap_in_reverse[ + i][0].block_id + assert blocks_to_swap_out[i][1].block_id == blocks_to_swap_in_reverse[ + i][1].block_id def test_schedule_swapped_max_token_budget(): diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 71d18359164b..6cf1941569a5 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -27,6 +27,9 @@ # We assume fp8 is always enabled for testing. KV_CACHE_DTYPE = ["auto", "fp8"] +# Local file storing kv cache +LOCAL_FILE = [("./test_key.txt", "./test_value.txt")] + @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) @pytest.mark.parametrize("num_layers", NUM_LAYERS) @@ -303,6 +306,88 @@ def test_reshape_and_cache_flash( torch.testing.assert_close(value_cache, cloned_value_cache) +@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) +@pytest.mark.parametrize("local_file", LOCAL_FILE) +@torch.inference_mode() +def test_file_swapper( + kv_cache_factory, + num_mappings: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, + device: str, + kv_cache_dtype: str, + local_file: Tuple[str, str], +) -> None: + if kv_cache_dtype == "fp8": + pytest.skip() + if kv_cache_dtype == "fp8" and head_size % 16: + pytest.skip() + random.seed(seed) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + src_blocks = random.sample(range(num_blocks), num_mappings) + dst_blocks = random.sample(range(num_blocks), num_mappings) + block_mapping = list(zip(src_blocks, dst_blocks)) + block_mapping_tensor = torch.tensor(block_mapping, + dtype=torch.int64, + device="cpu").view(-1, 2) + + # Create the KV caches on the cuda device. + src_key_caches, src_value_caches = kv_cache_factory( + num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype, + seed, "cuda:0") + src_key_test = torch.zeros(src_key_caches[0].shape, dtype=dtype).cuda() + src_value_test = torch.zeros(src_value_caches[0].shape, dtype=dtype).cuda() + + # Create local file. + num_elements_key = src_key_caches[0].numel() + element_size_key = src_key_caches[0].element_size() + total_bytes_key = num_elements_key * element_size_key + + num_elements_value = src_value_caches[0].numel() + element_size_value = src_value_caches[0].element_size() + total_bytes_value = num_elements_value * element_size_value + with open(local_file[0], 'wb') as file: + file.write(b'0' * total_bytes_key) + + with open(local_file[1], 'wb') as file: + file.write(b'0' * total_bytes_value) + + # Call the swap_out_to_local_file kernel. + ops.swap_out_to_local_file(src_key_caches[0], local_file[0], + block_mapping_tensor) + ops.swap_out_to_local_file(src_value_caches[0], local_file[1], + block_mapping_tensor) + torch.cuda.synchronize() + block_mapping_tensor[:, [0, 1]] = block_mapping_tensor[:, [1, 0]] + + # Call the swap_in_from_local_file kernel. + ops.swap_in_from_local_file(local_file[0], src_key_test, + block_mapping_tensor) + ops.swap_in_from_local_file(local_file[1], src_value_test, + block_mapping_tensor) + torch.cuda.synchronize() + + for src, _ in block_mapping: + torch.testing.assert_close(src_key_caches[0][src], src_key_test[src]) + torch.testing.assert_close(src_value_caches[0][src], + src_value_test[src]) + + @pytest.mark.parametrize("direction", COPYING_DIRECTION) @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) @pytest.mark.parametrize("num_heads", NUM_HEADS) diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index 1e7f560fc68c..9783adc3a80b 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -60,6 +60,22 @@ def copy_blocks( ) -> None: pass + @staticmethod + def swap_out_to_local_file( + src_kv_cache: torch.Tensor, + dst_kv_cache: Tuple[str, str], + src_to_dst: torch.Tensor, + ) -> None: + pass + + @staticmethod + def swap_in_form_local_file( + src_kv_cache: Tuple[str, str], + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + pass + def test_model_runner_input(): sampling_metadata = SamplingMetadata( diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py index 7aa439ba0a15..5c6d74803024 100644 --- a/tests/worker/test_swap.py +++ b/tests/worker/test_swap.py @@ -2,7 +2,8 @@ from vllm.engine.arg_utils import EngineArgs from vllm.sequence import ExecuteModelRequest -from vllm.utils import get_distributed_init_method, get_ip, get_open_port +from vllm.utils import (BlockSwapParam, Device, get_distributed_init_method, + get_ip, get_open_port) from vllm.worker.worker import Worker @@ -54,7 +55,11 @@ def test_swap() -> None: a.cuda(), b.cuda(), rtol=0.0, atol=0.0) # Test swap out. - blocks_to_swap_out = [(3, 72), (56, 35), (84, 34)] + blocks_to_swap_out = [ + (BlockSwapParam(3, Device.GPU), BlockSwapParam(72, Device.CPU)), + (BlockSwapParam(56, Device.GPU), BlockSwapParam(35, Device.CPU)), + (BlockSwapParam(84, Device.GPU), BlockSwapParam(34, Device.CPU)), + ] execute_model_req = ExecuteModelRequest( seq_group_metadata_list=[], blocks_to_swap_in=[], @@ -67,17 +72,19 @@ def test_swap() -> None: gpu_key_cache, gpu_value_cache = gpu_cache[i] cpu_key_cache, cpu_value_cache = cpu_cache[i] for src, dst in blocks_to_swap_out: - assert allclose(gpu_key_cache[src], cpu_key_cache[dst]) - assert allclose(gpu_value_cache[src], cpu_value_cache[dst]) + assert allclose(gpu_key_cache[src.block_id], + cpu_key_cache[dst.block_id]) + assert allclose(gpu_value_cache[src.block_id], + cpu_value_cache[dst.block_id]) # Test swap in. execute_model_req.blocks_to_swap_out = [] execute_model_req.blocks_to_swap_in = [ - (19, 45), - (67, 23), - (12, 78), - (40, 99), - (1, 71), + (BlockSwapParam(19, Device.CPU), BlockSwapParam(45, Device.GPU)), + (BlockSwapParam(67, Device.CPU), BlockSwapParam(23, Device.GPU)), + (BlockSwapParam(12, Device.CPU), BlockSwapParam(78, Device.GPU)), + (BlockSwapParam(40, Device.CPU), BlockSwapParam(99, Device.GPU)), + (BlockSwapParam(1, Device.CPU), BlockSwapParam(71, Device.GPU)), ] worker.execute_model(execute_model_req=execute_model_req) @@ -85,5 +92,7 @@ def test_swap() -> None: gpu_key_cache, gpu_value_cache = gpu_cache[i] cpu_key_cache, cpu_value_cache = cpu_cache[i] for src, dst in execute_model_req.blocks_to_swap_in: - assert allclose(gpu_key_cache[dst], cpu_key_cache[src]) - assert allclose(gpu_value_cache[dst], cpu_value_cache[src]) + assert allclose(gpu_key_cache[dst.block_id], + cpu_key_cache[src.block_id]) + assert allclose(gpu_value_cache[dst.block_id], + cpu_value_cache[src.block_id]) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index fe254732e730..faa0cf07d33c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -584,6 +584,16 @@ def copy_blocks(key_caches: List[torch.Tensor], torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) +def swap_out_to_local_file(src: torch.Tensor, dst: str, + block_mapping: torch.Tensor) -> None: + torch.ops._C_cache_ops.swap_out_to_local_file(src, dst, block_mapping) + + +def swap_in_from_local_file(src: str, dst: torch.Tensor, + block_mapping: torch.Tensor) -> None: + torch.ops._C_cache_ops.swap_in_from_local_file(src, dst, block_mapping) + + def swap_blocks(src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor) -> None: torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index ccfc6b254c1e..dfb7431fc0cd 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -86,6 +86,24 @@ def copy_blocks( def advance_step(self, num_seqs: int, num_queries: int): raise NotImplementedError + @staticmethod + @abstractmethod + def swap_out_to_local_file( + src_kv_cache: torch.Tensor, + dst_kv_cache: Tuple[str, str], + src_to_dst: torch.Tensor, + ) -> None: + raise NotImplementedError + + @staticmethod + @abstractmethod + def swap_in_form_local_file( + src_kv_cache: Tuple[str, str], + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + raise NotImplementedError + @dataclass class AttentionMetadata: diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 30ce715d5d05..fdf49df7b485 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -181,6 +181,35 @@ def copy_blocks( value_caches = [kv_cache[1] for kv_cache in kv_caches] ops.copy_blocks(key_caches, value_caches, src_to_dists) + @staticmethod + def swap_out_to_local_file( + src_kv_cache: torch.Tensor, + dst_kv_cache: Tuple[str, str], + src_to_dst: torch.Tensor, + ) -> None: + src_key_cache = src_kv_cache[0] + dst_key_addr = dst_kv_cache[0] + ops.swap_out_to_local_file(src_key_cache, dst_key_addr, src_to_dst) + + src_value_cache = src_kv_cache[1] + dst_value_addr = dst_kv_cache[1] + ops.swap_out_to_local_file(src_value_cache, dst_value_addr, src_to_dst) + + @staticmethod + def swap_in_form_local_file( + src_kv_cache: Tuple[str, str], + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + src_key_cache = src_kv_cache[0] + dst_key_addr = dst_kv_cache[0] + ops.swap_in_from_local_file(src_key_cache, dst_key_addr, src_to_dst) + + src_value_cache = src_kv_cache[1] + dst_value_addr = dst_kv_cache[1] + ops.swap_in_from_local_file(src_value_cache, dst_value_addr, + src_to_dst) + @dataclass class FlashAttentionMetadata(AttentionMetadata): diff --git a/vllm/config.py b/vllm/config.py index b84d91d40237..c2ea3d549b07 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -596,6 +596,8 @@ def __init__( gpu_memory_utilization: float, swap_space: float, cache_dtype: str, + external_swapper: str = "", + external_swapper_space: int = 0, num_gpu_blocks_override: Optional[int] = None, sliding_window: Optional[int] = None, enable_prefix_caching: bool = False, @@ -609,6 +611,9 @@ def __init__( self.sliding_window = sliding_window self.enable_prefix_caching = enable_prefix_caching self.cpu_offload_gb = cpu_offload_gb + self.external_swapper = external_swapper + self.external_swapper_space_bytes = \ + external_swapper_space * GiB_bytes self._verify_args() self._verify_cache_dtype() self._verify_prefix_caching() @@ -616,6 +621,7 @@ def __init__( # Will be set after profiling. self.num_gpu_blocks = None self.num_cpu_blocks = None + self.num_external_blocks = None def metrics_info(self): # convert cache_config to dict(key: str, value: str) for prometheus diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index c87246c1c6d6..75757c4b16d2 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -4,7 +4,7 @@ DeviceAwareBlockAllocator) from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator -from vllm.utils import Device +from vllm.utils import BlockSwapParam, Device, get_external_device class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): @@ -25,6 +25,8 @@ def create( num_gpu_blocks: int, num_cpu_blocks: int, block_size: int, + num_external_blocks: int = 0, + external_swapper: str = "", ) -> DeviceAwareBlockAllocator: """Creates a CpuGpuBlockAllocator instance with the specified configuration. @@ -102,7 +104,7 @@ def __init__(self, cpu_block_allocator: BlockAllocator, Device.GPU: gpu_block_allocator, } - self._swap_mapping: Dict[int, int] = {} + self._swap_mapping: Dict[BlockSwapParam, BlockSwapParam] = {} self._null_block: Optional[Block] = None self._block_ids_to_allocator: Dict[int, BlockAllocator] = {} @@ -232,7 +234,7 @@ def get_physical_block_id(self, device: Device, absolute_id: int) -> int: return self._allocators[device].get_physical_block_id(absolute_id) def swap(self, blocks: List[Block], src_device: Device, - dst_device: Device) -> Dict[int, int]: + dst_device: Device) -> Dict[BlockSwapParam, BlockSwapParam]: """Execute the swap for the given blocks from source_device on to dest_device, save the current swap mapping and append them to the accumulated `self._swap_mapping` for each @@ -247,16 +249,26 @@ def swap(self, blocks: List[Block], src_device: Device, Dict[int, int]: Swap mapping from source_device on to dest_device. """ - src_block_ids = [block.block_id for block in blocks] + src_blocks = [ + BlockSwapParam(block.block_id, src_device) for block in blocks + ] self._allocators[src_device].swap_out(blocks) self._allocators[dst_device].swap_in(blocks) - dst_block_ids = [block.block_id for block in blocks] - - current_swap_mapping: Dict[int, int] = {} - for src_block_id, dst_block_id in zip(src_block_ids, dst_block_ids): - if src_block_id is not None and dst_block_id is not None: - self._swap_mapping[src_block_id] = dst_block_id - current_swap_mapping[src_block_id] = dst_block_id + dst_blocks = [] + for block in blocks: + + # because block" to allow reusing the + # existing "block" object + # so change block.allocator to the dst allocator + block.set_current_allocator(self._allocators[dst_device]) + dst_blocks.append(BlockSwapParam(block.block_id, dst_device)) + + current_swap_mapping: Dict[BlockSwapParam, BlockSwapParam] = {} + for src_block, dst_block in zip(src_blocks, dst_blocks): + if src_block.block_id is not None \ + and dst_block.block_id is not None: + self._swap_mapping[src_block] = dst_block + current_swap_mapping[src_block] = dst_block return current_swap_mapping def get_num_blocks_touched(self, @@ -328,18 +340,147 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float: assert device in self._allocators return self._allocators[device].get_prefix_cache_hit_rate() - def get_and_reset_swaps(self) -> List[Tuple[int, int]]: + def get_and_reset_swaps( + self) -> List[Tuple[BlockSwapParam, BlockSwapParam]]: """Returns and clears the mapping of source to destination block IDs. Will be called after every swapping operations for now, and after every schedule when BlockManagerV2 become default. Currently not useful. Returns: - List[Tuple[int, int]]: A mapping of source to destination block IDs. + List[Tuple[BlockSwapParam, BlockSwapParam]]: A mapping of source to + destination block IDs. """ mapping = self._swap_mapping.copy() self._swap_mapping.clear() return list(mapping.items()) + def get_block_device(self, block: Block) -> Device: + for device, allocator in self._allocators.items(): + if block.get_current_allocator == allocator: + return device + + raise ValueError(f"Unknown block {block}") + + +class CpuGpuExternalBlockAllocator(CpuGpuBlockAllocator): + + @staticmethod + def create( + allocator_type: str, + num_gpu_blocks: int, + num_cpu_blocks: int, + block_size: int, + num_external_blocks: int = 0, + external_swapper: str = "", + ) -> DeviceAwareBlockAllocator: + """Creates a CpuGpuBlockAllocator instance with the specified + configuration. + + This static method creates and returns a CpuGpuBlockAllocator instance + based on the provided parameters. It initializes the CPU and GPU block + allocators with the specified number of blocks, block size, and + allocator type. + + Args: + allocator_type (str): The type of block allocator to use for CPU + and GPU blocks. Currently supported values are "naive" and + "prefix_caching". + num_gpu_blocks (int): The number of blocks to allocate for GPU + memory. + num_cpu_blocks (int): The number of blocks to allocate for CPU + memory. + block_size (int): The size of each block in number of tokens. + + Returns: + DeviceAwareBlockAllocator: A CpuGpuBlockAllocator instance with the + specified configuration. + + Notes: + - The block IDs are assigned contiguously, with GPU block IDs coming + before CPU block IDs. + """ + block_ids = list( + range(num_gpu_blocks + num_cpu_blocks + num_external_blocks)) + gpu_block_ids = block_ids[:num_gpu_blocks] + cpu_block_ids = block_ids[num_gpu_blocks:num_gpu_blocks + + num_cpu_blocks] + external_block_ids = block_ids[num_gpu_blocks + num_cpu_blocks:] + + if allocator_type == "naive": + gpu_allocator: BlockAllocator = NaiveBlockAllocator( + create_block=NaiveBlock, # type: ignore + num_blocks=num_gpu_blocks, + block_size=block_size, + block_ids=gpu_block_ids, + ) + + cpu_allocator: BlockAllocator = NaiveBlockAllocator( + create_block=NaiveBlock, # type: ignore + num_blocks=num_cpu_blocks, + block_size=block_size, + block_ids=cpu_block_ids, + ) + + external_allocator: BlockAllocator = NaiveBlockAllocator( + create_block=NaiveBlock, # type: ignore + num_blocks=num_external_blocks, + block_size=block_size, + block_ids=external_block_ids, + ) + elif allocator_type == "prefix_caching": + gpu_allocator = PrefixCachingBlockAllocator( + num_blocks=num_gpu_blocks, + block_size=block_size, + block_ids=gpu_block_ids, + ) + + cpu_allocator = PrefixCachingBlockAllocator( + num_blocks=num_cpu_blocks, + block_size=block_size, + block_ids=cpu_block_ids, + ) + + external_allocator = PrefixCachingBlockAllocator( + num_blocks=num_external_blocks, + block_size=block_size, + block_ids=external_block_ids, + ) + else: + raise ValueError(f"Unknown allocator type {allocator_type=}") + + return CpuGpuExternalBlockAllocator( + cpu_block_allocator=cpu_allocator, + gpu_block_allocator=gpu_allocator, + external_block_allocator=external_allocator, + external_swapper=external_swapper, + ) + + def __init__(self, cpu_block_allocator: BlockAllocator, + gpu_block_allocator: BlockAllocator, + external_block_allocator: BlockAllocator, + external_swapper: str): + assert not (cpu_block_allocator.all_block_ids + & gpu_block_allocator.all_block_ids + & external_block_allocator.all_block_ids + ), "cpu, gpu and external block allocators can't \ + have intersection of block ids" + + self._external_device = get_external_device(external_swapper) + + self._allocators = { + Device.CPU: cpu_block_allocator, + Device.GPU: gpu_block_allocator, + self._external_device: external_block_allocator, + } + + self._swap_mapping: Dict[BlockSwapParam, BlockSwapParam] = {} + self._null_block: Optional[Block] = None + + self._block_ids_to_allocator: Dict[int, BlockAllocator] = {} + for _, allocator in self._allocators.items(): + for block_id in allocator.all_block_ids: + self._block_ids_to_allocator[block_id] = allocator + class NullBlock(Block): """ @@ -405,3 +546,12 @@ def last_accessed(self, last_accessed_ts: float): @property def content_hash(self): return self._proxy.content_hash + + @property + def get_current_allocator(self) -> BlockAllocator: + raise NotImplementedError( + "get_current_allocator is not used for null block") + + def set_current_allocator(self, allocator: BlockAllocator) -> None: + raise NotImplementedError( + "set_current_allocator is not used for null block") diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index f26bc761c996..ed5229e9e966 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Dict, FrozenSet, List, Optional, Protocol, Tuple -from vllm.utils import Device +from vllm.utils import BlockSwapParam, Device BlockId = int @@ -71,6 +71,15 @@ def last_accessed(self) -> float: def last_accessed(self, last_accessed_ts: float): raise NotImplementedError + @property + @abstractmethod + def get_current_allocator(self) -> "BlockAllocator": + raise NotImplementedError + + @abstractmethod + def set_current_allocator(self, allocator: "BlockAllocator") -> None: + raise NotImplementedError + class Factory(Protocol): @abstractmethod @@ -268,7 +277,7 @@ def get_num_blocks_touched(self, @abstractmethod def swap(self, blocks: List[Block], src_device: Device, - dst_device: Device) -> Dict[int, int]: + dst_device: Device) -> Dict[BlockSwapParam, BlockSwapParam]: pass @abstractmethod @@ -288,3 +297,8 @@ def allocate_or_get_null_block(self) -> Block: def get_prefix_cache_hit_rate(self, device: Device) -> float: """Prefix cache hit rate. -1 means not supported or disabled.""" pass + + @abstractmethod + def get_block_device(self, block: Block) -> Device: + """Indicate the device of the block.""" + pass diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index 1643fd69c58a..fea7c523d164 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -462,3 +462,10 @@ def prev_block(self) -> Optional["Block"]: @property def content_hash(self) -> Optional[int]: return None + + @property + def get_current_allocator(self) -> BlockAllocator: + return self._allocator + + def set_current_allocator(self, allocator: BlockAllocator) -> None: + self._allocator = allocator diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index a87e814cfb04..58a392b230c3 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -854,6 +854,13 @@ def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int], assert (prev_block_hash is None) == is_first_block return hash((is_first_block, prev_block_hash, *cur_block_token_ids)) + @property + def get_current_allocator(self) -> BlockAllocator: + return self._allocator + + def set_current_allocator(self, allocator: BlockAllocator) -> None: + self._allocator = allocator # type: ignore + class ComputedBlocksTracker: """Handles caching of per-sequence computed block ids. diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 24ab9eb66194..a9cfee4bbfd3 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -14,7 +14,7 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.logger import init_logger from vllm.sequence import Sequence, SequenceGroup, SequenceStatus -from vllm.utils import Device +from vllm.utils import BlockSwapParam, Device, get_external_device logger = init_logger(__name__) @@ -234,6 +234,8 @@ def __init__( block_size: int, num_gpu_blocks: int, num_cpu_blocks: int, + num_external_blocks: int = 0, + external_swapper: str = "", watermark: float = 0.01, sliding_window: Optional[int] = None, enable_caching: bool = False, @@ -241,6 +243,7 @@ def __init__( self.block_size = block_size self.num_total_gpu_blocks = num_gpu_blocks self.num_total_cpu_blocks = num_cpu_blocks + self.num_total_external_blocks = num_external_blocks if enable_caching and sliding_window is not None: raise NotImplementedError( @@ -258,6 +261,10 @@ def __init__( self.enable_caching = enable_caching self.watermark_blocks = int(watermark * num_gpu_blocks) + self.enable_external_swapper = (external_swapper != "" + and external_swapper is not None) + if self.enable_external_swapper: + self._external_device = get_external_device(external_swapper) if self.enable_caching: logger.info("Automatic prefix caching is enabled.") @@ -265,11 +272,21 @@ def __init__( Device.GPU, block_size, num_gpu_blocks) self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator( Device.CPU, block_size, num_cpu_blocks) + if self.enable_external_swapper: + self.external_allocator: BlockAllocatorBase = \ + CachedBlockAllocator( + self._external_device, block_size, + num_external_blocks) else: self.gpu_allocator = UncachedBlockAllocator( Device.GPU, block_size, num_gpu_blocks) self.cpu_allocator = UncachedBlockAllocator( Device.CPU, block_size, num_cpu_blocks) + if self.enable_external_swapper: + self.external_allocator = \ + UncachedBlockAllocator( + self._external_device, block_size, + num_external_blocks) # Mapping: seq_id -> BlockTable. self.block_tables: Dict[int, BlockTable] = {} @@ -566,7 +583,9 @@ def _swap_block_table( return new_block_table - def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + def swap_in( + self, seq_group: SequenceGroup + ) -> List[Tuple[BlockSwapParam, BlockSwapParam]]: request_id = seq_group.request_id @@ -586,14 +605,63 @@ def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: self.gpu_allocator, mapping) - return [(cpu_block.block_number, gpu_block.block_number) + return [(BlockSwapParam(cpu_block.block_number, Device.CPU), + BlockSwapParam(gpu_block.block_number, Device.GPU)) for cpu_block, gpu_block in mapping.items()] + def swap_in_from_external( + self, seq_group: SequenceGroup + ) -> List[Tuple[BlockSwapParam, BlockSwapParam]]: + + request_id = seq_group.request_id + + # External block -> GPU block. + # dict is efficient in lookup `if cpu_block in mapping` + + mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} + for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): + self.block_tables[seq.seq_id] = \ + self._swap_block_table(self.block_tables[seq.seq_id], + self.external_allocator, + self.gpu_allocator, + mapping) + + if seq_group.is_encoder_decoder(): + self.cross_block_tables[request_id] = \ + self._swap_block_table(self.cross_block_tables[request_id], + self.external_allocator, + self.gpu_allocator, + mapping) + + return [(BlockSwapParam(external_block.block_number, + self._external_device), + BlockSwapParam(gpu_block.block_number, Device.GPU)) + for external_block, gpu_block in mapping.items()] + + def is_swap_in_from_external( + self, + seq_group: SequenceGroup, + ) -> bool: + for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): + if len(self.block_tables[seq.seq_id]) == 0: + continue + + return self.block_tables[ + seq.seq_id][0].device == self._external_device + + return True + def can_swap_out(self, seq_group: SequenceGroup) -> bool: blocks = self._get_physical_blocks(seq_group) return len(blocks) <= self.cpu_allocator.get_num_free_blocks() - def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + def can_swap_out_to_external(self, seq_group: SequenceGroup) -> bool: + blocks = self._get_physical_blocks(seq_group) + return len(blocks) <= self.external_allocator.get_num_free_blocks() + + def swap_out( + self, seq_group: SequenceGroup + ) -> List[Tuple[BlockSwapParam, BlockSwapParam]]: request_id = seq_group.request_id # GPU block -> CPU block. @@ -612,8 +680,36 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: self.cpu_allocator, mapping) - return [(cpu_block.block_number, gpu_block.block_number) - for cpu_block, gpu_block in mapping.items()] + return [(BlockSwapParam(gpu_block.block_number, Device.GPU), + BlockSwapParam(cpu_block.block_number, Device.CPU)) + for gpu_block, cpu_block in mapping.items()] + + def swap_out_to_external( + self, seq_group: SequenceGroup + ) -> List[Tuple[BlockSwapParam, BlockSwapParam]]: + request_id = seq_group.request_id + + # GPU block -> External block. + # dict is efficient in lookup `if gpu_block in mapping` + mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + self.block_tables[seq.seq_id] = \ + self._swap_block_table(self.block_tables[seq.seq_id], + self.gpu_allocator, + self.external_allocator, + mapping) + + if seq_group.is_encoder_decoder(): + self.cross_block_tables[request_id] = \ + self._swap_block_table(self.cross_block_tables[request_id], + self.gpu_allocator, + self.external_allocator, + mapping) + + return [(BlockSwapParam(gpu_block.block_number, Device.GPU), + BlockSwapParam(external_block.block_number, + self._external_device)) + for gpu_block, external_block in mapping.items()] def _free_block_table(self, block_table: BlockTable) -> None: # when using a sliding window, each seq will only use up @@ -669,6 +765,9 @@ def get_num_free_gpu_blocks(self) -> int: def get_num_free_cpu_blocks(self) -> int: return self.cpu_allocator.get_num_free_blocks() + def get_num_free_external_blocks(self) -> int: + return self.external_allocator.get_num_free_blocks() + def access_all_blocks_in_seq( self, seq: Sequence, diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index b06385b062e8..1a109868b162 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -5,14 +5,15 @@ from typing import Tuple from vllm.core.block.block_table import BlockTable -from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator +from vllm.core.block.cpu_gpu_block_allocator import ( + CpuGpuBlockAllocator, CpuGpuExternalBlockAllocator) from vllm.core.block.interfaces import Block from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, LastAccessBlocksTracker) from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.sequence import Sequence, SequenceGroup, SequenceStatus -from vllm.utils import Device +from vllm.utils import BlockSwapParam, Device, get_external_device SeqId = int EncoderSeqId = str @@ -64,6 +65,8 @@ def __init__( block_size: int, num_gpu_blocks: int, num_cpu_blocks: int, + num_external_blocks: int = 0, + external_swapper: str = "", watermark: float = 0.01, sliding_window: Optional[int] = None, enable_caching: bool = False, @@ -71,6 +74,9 @@ def __init__( self.block_size = block_size self.num_total_gpu_blocks = num_gpu_blocks self.num_total_cpu_blocks = num_cpu_blocks + self.num_total_external_blocks = num_external_blocks + self.enable_external_swapper = (external_swapper != "" + and external_swapper is not None) self.sliding_window = sliding_window # max_block_sliding_window is the max number of blocks that need to be @@ -92,12 +98,23 @@ def __init__( self.watermark_blocks = int(watermark * num_gpu_blocks) - self.block_allocator = CpuGpuBlockAllocator.create( - allocator_type="prefix_caching" if enable_caching else "naive", - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks, - block_size=block_size, - ) + if not self.enable_external_swapper: + self.block_allocator = CpuGpuBlockAllocator.create( + allocator_type="prefix_caching" if enable_caching else "naive", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, + block_size=block_size, + ) + else: + self._external_device = get_external_device(external_swapper) + self.block_allocator = CpuGpuExternalBlockAllocator.create( + allocator_type="prefix_caching" if enable_caching else "naive", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, + num_external_blocks=num_external_blocks, + external_swapper=external_swapper, + block_size=block_size, + ) self.block_tables: Dict[SeqId, BlockTable] = {} self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {} @@ -339,7 +356,7 @@ def can_swap_in(self, seq_group: SequenceGroup, Args: sequence_group (SequenceGroup): The sequence group to swap in. - num_lookahead_slots (int): Number of lookahead slots used in + num_lookahead_slots (int): Number of lookahead slots used in speculative decoding, default to 0. Returns: @@ -348,7 +365,9 @@ def can_swap_in(self, seq_group: SequenceGroup, return self._can_swap(seq_group, Device.GPU, SequenceStatus.SWAPPED, num_lookahead_slots) - def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + def swap_in( + self, seq_group: SequenceGroup + ) -> List[Tuple[BlockSwapParam, BlockSwapParam]]: """Returns the block id mapping (from CPU to GPU) generated by swapping in the given seq_group with num_lookahead_slots. @@ -356,8 +375,8 @@ def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: seq_group (SequenceGroup): The sequence group to swap in. Returns: - List[Tuple[int, int]]: The mapping of swapping block from CPU - to GPU. + List[Tuple[BlockSwapParam, BlockSwapParam]]: The mapping of + swapping block from CPU to GPU. """ physical_block_id_mapping = [] for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): @@ -373,11 +392,14 @@ def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: self.block_tables[seq.seq_id].update(blocks) seq_physical_block_id_mapping = { - self.block_allocator.get_physical_block_id( - Device.CPU, cpu_block_id): - self.block_allocator.get_physical_block_id( - Device.GPU, gpu_block_id) - for cpu_block_id, gpu_block_id in seq_swap_mapping.items() + BlockSwapParam( + self.block_allocator.get_physical_block_id( + Device.CPU, cpu_block.block_id), # type: ignore + Device.CPU): BlockSwapParam( + self.block_allocator.get_physical_block_id( + Device.GPU, gpu_block.block_id), # type: ignore + Device.GPU) + for cpu_block, gpu_block in seq_swap_mapping.items() } physical_block_id_mapping.extend( @@ -385,13 +407,72 @@ def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: return physical_block_id_mapping + def swap_in_from_external( + self, seq_group: SequenceGroup + ) -> List[Tuple[BlockSwapParam, BlockSwapParam]]: + """Returns the block id mapping (from External Swapper to GPU) + generated by swapping in the given seq_group with + num_lookahead_slots. + + Args: + seq_group (SequenceGroup): The sequence group to swap in. + + Returns: + List[Tuple[BlockSwapParam, BlockSwapParam]]: The mapping of + swapping block from External Swapper to GPU. + """ + physical_block_id_mapping = [] + for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): + blocks = self.block_tables[seq.seq_id].blocks + if len(blocks) == 0: + continue + + seq_swap_mapping = self.block_allocator.swap( + blocks=blocks, + src_device=self._external_device, + dst_device=Device.GPU) + + # Refresh the block ids of the table (post-swap) + self.block_tables[seq.seq_id].update(blocks) + + seq_physical_block_id_mapping = { + BlockSwapParam( + self.block_allocator.get_physical_block_id( + self._external_device, + external_block.block_id), # type: ignore + self._external_device): BlockSwapParam( + self.block_allocator.get_physical_block_id( + Device.GPU, gpu_block.block_id), # type: ignore + Device.GPU) + for external_block, gpu_block in seq_swap_mapping.items() + } + + physical_block_id_mapping.extend( + list(seq_physical_block_id_mapping.items())) + + return physical_block_id_mapping + + def is_swap_in_from_external( + self, + seq_group: SequenceGroup, + ) -> bool: + for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): + if len(self.block_tables[seq.seq_id].blocks) == 0: + continue + + block = self.block_tables[seq.seq_id].blocks[0] + return self.block_allocator.get_block_device( + block) == self._external_device + + return True + def can_swap_out(self, seq_group: SequenceGroup) -> bool: """Returns whether we can swap out the given sequence_group with num_lookahead_slots. Args: seq_group (SequenceGroup): The sequence group to swap in. - num_lookahead_slots (int): Number of lookahead slots used in + num_lookahead_slots (int): Number of lookahead slots used in speculative decoding, default to 0. Returns: @@ -403,7 +484,27 @@ def can_swap_out(self, seq_group: SequenceGroup) -> bool: return True return False - def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + def can_swap_out_to_external(self, seq_group: SequenceGroup) -> bool: + """Returns whether we can swap out the given external swapper + with num_lookahead_slots. + + Args: + seq_group (SequenceGroup): The sequence group to swap in. + num_lookahead_slots (int): Number of lookahead slots used in + speculative decoding, default to 0. + + Returns: + bool: Whether it's possible to swap out current sequence group. + """ + alloc_status = self._can_swap(seq_group, self._external_device, + SequenceStatus.RUNNING) + if alloc_status == AllocStatus.OK: + return True + return False + + def swap_out( + self, seq_group: SequenceGroup + ) -> List[Tuple[BlockSwapParam, BlockSwapParam]]: """Returns the block id mapping (from GPU to CPU) generated by swapping out the given sequence_group with num_lookahead_slots. @@ -411,8 +512,8 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: sequence_group (SequenceGroup): The sequence group to swap in. Returns: - List[Tuple[int, int]]: The mapping of swapping block from - GPU to CPU. + List[Tuple[BlockSwapParam, BlockSwapParam]]: The mapping of + swapping block from GPU to CPU. """ physical_block_id_mapping = [] for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): @@ -428,11 +529,59 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: self.block_tables[seq.seq_id].update(blocks) seq_physical_block_id_mapping = { - self.block_allocator.get_physical_block_id( - Device.GPU, gpu_block_id): - self.block_allocator.get_physical_block_id( - Device.CPU, cpu_block_id) - for gpu_block_id, cpu_block_id in seq_swap_mapping.items() + BlockSwapParam( + self.block_allocator.get_physical_block_id( + Device.GPU, gpu_block.block_id), # type: ignore + Device.GPU): BlockSwapParam( + self.block_allocator.get_physical_block_id( + Device.CPU, cpu_block.block_id), # type: ignore + Device.CPU) + for gpu_block, cpu_block in seq_swap_mapping.items() + } + + physical_block_id_mapping.extend( + list(seq_physical_block_id_mapping.items())) + + return physical_block_id_mapping + + def swap_out_to_external( + self, seq_group: SequenceGroup + ) -> List[Tuple[BlockSwapParam, BlockSwapParam]]: + """Returns the block id mapping (from GPU to External swapper) + generated by swapping out the given sequence_group with + num_lookahead_slots. + + Args: + sequence_group (SequenceGroup): The sequence group to swap in. + + Returns: + List[Tuple[BlockSwapParam, BlockSwapParam]]: The mapping of + swapping block from GPU to External Swapper. + """ + physical_block_id_mapping = [] + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + blocks = self.block_tables[seq.seq_id].blocks + if len(blocks) == 0: + continue + + seq_swap_mapping = self.block_allocator.swap( + blocks=blocks, + src_device=Device.GPU, + dst_device=self._external_device) + + # Refresh the block ids of the table (post-swap) + self.block_tables[seq.seq_id].update(blocks) + + seq_physical_block_id_mapping = { + BlockSwapParam( + self.block_allocator.get_physical_block_id( + Device.GPU, gpu_block.block_id), # type: ignore + Device.GPU): BlockSwapParam( + self.block_allocator.get_physical_block_id( + self._external_device, + external_block.block_id), # type: ignore + self._external_device) + for gpu_block, external_block in seq_swap_mapping.items() } physical_block_id_mapping.extend( @@ -446,6 +595,9 @@ def get_num_free_gpu_blocks(self) -> int: def get_num_free_cpu_blocks(self) -> int: return self.block_allocator.get_num_free_blocks(Device.CPU) + def get_num_free_external_blocks(self) -> int: + return self.block_allocator.get_num_free_blocks(self._external_device) + def get_prefix_cache_hit_rate(self, device: Device) -> float: return self.block_allocator.get_prefix_cache_hit_rate(device) @@ -462,7 +614,7 @@ def _can_swap(self, device (Device): device to swap the 'seq_group' on. status (SequenceStatus): The status of sequence which is needed for action. RUNNING for swap out and SWAPPED for swap in - num_lookahead_slots (int): Number of lookahead slots used in + num_lookahead_slots (int): Number of lookahead slots used in speculative decoding, default to 0. Returns: diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/embedding_model_block_manager.py index c47d7d8dfb07..86253783d700 100644 --- a/vllm/core/embedding_model_block_manager.py +++ b/vllm/core/embedding_model_block_manager.py @@ -2,7 +2,7 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.sequence import Sequence, SequenceGroup -from vllm.utils import Device +from vllm.utils import BlockSwapParam, Device class EmbeddingModelBlockSpaceManager(BlockSpaceManager): @@ -47,13 +47,36 @@ def can_swap_in(self, seq_group: SequenceGroup, num_lookahead_slots: int) -> AllocStatus: return AllocStatus.OK - def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + def swap_in( + self, seq_group: SequenceGroup + ) -> List[Tuple[BlockSwapParam, BlockSwapParam]]: return None # type: ignore + def swap_in_from_external( + self, seq_group: SequenceGroup + ) -> List[Tuple[BlockSwapParam, BlockSwapParam]]: + return None # type: ignore + + def is_swap_in_from_external( + self, + seq_group: SequenceGroup, + ) -> bool: + return False + def can_swap_out(self, seq_group: SequenceGroup) -> bool: return True - def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + def can_swap_out_to_external(self, seq_group: SequenceGroup) -> bool: + return True + + def swap_out( + self, seq_group: SequenceGroup + ) -> List[Tuple[BlockSwapParam, BlockSwapParam]]: + return None # type: ignore + + def swap_out_to_external( + self, seq_group: SequenceGroup + ) -> List[Tuple[BlockSwapParam, BlockSwapParam]]: return None # type: ignore def free(self, seq: Sequence) -> None: diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 96f8dd851b2f..ed22856d73e0 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -5,7 +5,7 @@ from typing import Tuple from vllm.sequence import Sequence, SequenceGroup -from vllm.utils import Device +from vllm.utils import BlockSwapParam, Device class AllocStatus(enum.Enum): @@ -74,7 +74,22 @@ def can_swap_in(self, seq_group: SequenceGroup, pass @abstractmethod - def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + def swap_in( + self, seq_group: SequenceGroup + ) -> List[Tuple[BlockSwapParam, BlockSwapParam]]: + pass + + @abstractmethod + def is_swap_in_from_external( + self, + seq_group: SequenceGroup, + ) -> bool: + pass + + @abstractmethod + def swap_in_from_external( + self, seq_group: SequenceGroup + ) -> List[Tuple[BlockSwapParam, BlockSwapParam]]: pass @abstractmethod @@ -82,7 +97,19 @@ def can_swap_out(self, seq_group: SequenceGroup) -> bool: pass @abstractmethod - def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + def can_swap_out_to_external(self, seq_group: SequenceGroup) -> bool: + pass + + @abstractmethod + def swap_out( + self, seq_group: SequenceGroup + ) -> List[Tuple[BlockSwapParam, BlockSwapParam]]: + pass + + @abstractmethod + def swap_out_to_external( + self, seq_group: SequenceGroup + ) -> List[Tuple[BlockSwapParam, BlockSwapParam]]: pass @abstractmethod diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 4c2f71582031..fc723737a984 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -15,7 +15,7 @@ from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceGroupMetadataDelta, SequenceStatus) -from vllm.utils import Device, PyObjectCache +from vllm.utils import BlockSwapParam, Device, PyObjectCache logger = init_logger(__name__) @@ -120,10 +120,10 @@ class SchedulerOutputs: num_prefill_groups: int # Total number of batched tokens. num_batched_tokens: int - # Blocks to swap in. List of CPU -> GPU block number. - blocks_to_swap_in: List[Tuple[int, int]] - # Blocks to swap out. List of GPU -> CPU block number. - blocks_to_swap_out: List[Tuple[int, int]] + # Blocks to swap in. List of CPU/External -> GPU block number. + blocks_to_swap_in: List[Tuple[BlockSwapParam, BlockSwapParam]] + # Blocks to swap out. List of GPU -> CPU/External block number. + blocks_to_swap_out: List[Tuple[BlockSwapParam, BlockSwapParam]] # Blocks to copy. Source to dest block. blocks_to_copy: List[Tuple[int, int]] # Sequence groups that are going to be ignored. @@ -188,7 +188,7 @@ class SchedulerRunningOutputs: # Sequences that are swapped out. swapped_out: List[SequenceGroup] # The blocks to swap out. - blocks_to_swap_out: List[Tuple[int, int]] + blocks_to_swap_out: List[Tuple[BlockSwapParam, BlockSwapParam]] # The blocks to copy. blocks_to_copy: List[Tuple[int, int]] # The number of slots for lookahead decoding. @@ -226,7 +226,7 @@ class SchedulerSwappedInOutputs: # phase. I.e., it means the prefill has been chunked. prefill_seq_groups: List[ScheduledSequenceGroup] # The blocks to swap in. - blocks_to_swap_in: List[Tuple[int, int]] + blocks_to_swap_in: List[Tuple[BlockSwapParam, BlockSwapParam]] # The blocks to copy. blocks_to_copy: List[Tuple[int, int]] # The number of slots for lookahead decoding. @@ -328,11 +328,17 @@ def __init__( if num_cpu_blocks: num_cpu_blocks //= pipeline_parallel_size + num_external_blocks = cache_config.num_external_blocks + if num_external_blocks: + num_external_blocks //= pipeline_parallel_size + # Create the block space manager. self.block_manager = BlockSpaceManagerImpl( block_size=self.cache_config.block_size, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, + num_external_blocks=num_external_blocks, + external_swapper=self.cache_config.external_swapper, sliding_window=self.cache_config.sliding_window, enable_caching=self.cache_config.enable_prefix_caching) @@ -380,6 +386,10 @@ def __init__( self.use_async_output_proc = self.output_proc_callback is not None self.num_cache_iters = 2 if self.use_async_output_proc else 1 + self.enable_external_swapper = ( + self.cache_config.external_swapper != "" + and self.cache_config.external_swapper is not None) + self.cache_id = 0 for i in range(self.num_cache_iters): self._seq_group_metadata_cache.append( @@ -528,7 +538,8 @@ def _schedule_running( ret.prefill_seq_groups_list.clear() # Blocks that need to be swapped or copied before model execution. - blocks_to_swap_out: List[Tuple[int, int]] = ret.blocks_to_swap_out + blocks_to_swap_out: List[Tuple[ + BlockSwapParam, BlockSwapParam]] = ret.blocks_to_swap_out blocks_to_copy: List[Tuple[int, int]] = ret.blocks_to_copy decode_seq_groups: List[ScheduledSequenceGroup] = ret.decode_seq_groups @@ -666,7 +677,7 @@ def _schedule_swapped( SchedulerSwappedInOutputs. """ # Blocks that need to be swapped or copied before model execution. - blocks_to_swap_in: List[Tuple[int, int]] = [] + blocks_to_swap_in: List[Tuple[BlockSwapParam, BlockSwapParam]] = [] blocks_to_copy: List[Tuple[int, int]] = [] decode_seq_groups: List[ScheduledSequenceGroup] = [] prefill_seq_groups: List[ScheduledSequenceGroup] = [] @@ -1321,7 +1332,7 @@ def _append_slots( def _preempt( self, seq_group: SequenceGroup, - blocks_to_swap_out: List[Tuple[int, int]], + blocks_to_swap_out: List[Tuple[BlockSwapParam, BlockSwapParam]], preemption_mode: Optional[PreemptionMode] = None, ) -> PreemptionMode: # If preemption mode is not specified, we determine the mode as follows: @@ -1378,16 +1389,26 @@ def _preempt_by_recompute( def _preempt_by_swap( self, seq_group: SequenceGroup, - blocks_to_swap_out: List[Tuple[int, int]], + blocks_to_swap_out: List[Tuple[BlockSwapParam, BlockSwapParam]], ) -> None: self._swap_out(seq_group, blocks_to_swap_out) + def _is_swap_in_from_external( + self, + seq_group: SequenceGroup, + ) -> bool: + return self.block_manager.is_swap_in_from_external(seq_group) + def _swap_in( self, seq_group: SequenceGroup, - blocks_to_swap_in: List[Tuple[int, int]], + blocks_to_swap_in: List[Tuple[BlockSwapParam, BlockSwapParam]], ) -> None: - mapping = self.block_manager.swap_in(seq_group) + if self.enable_external_swapper and self._is_swap_in_from_external( + seq_group): + mapping = self.block_manager.swap_in_from_external(seq_group) + else: + mapping = self.block_manager.swap_in(seq_group) blocks_to_swap_in.extend(mapping) for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): seq.status = SequenceStatus.RUNNING @@ -1395,15 +1416,24 @@ def _swap_in( def _swap_out( self, seq_group: SequenceGroup, - blocks_to_swap_out: List[Tuple[int, int]], + blocks_to_swap_out: List[Tuple[BlockSwapParam, BlockSwapParam]], ) -> None: - if not self.block_manager.can_swap_out(seq_group): + swap_to_cpu = self.block_manager.can_swap_out(seq_group) + swap_to_external = ( + self.enable_external_swapper + and self.block_manager.can_swap_out_to_external(seq_group)) + + if not swap_to_cpu and not swap_to_external: # FIXME(woosuk): Abort the sequence group instead of aborting the # entire engine. raise RuntimeError( "Aborted due to the lack of CPU swap space. Please increase " "the swap space to avoid this error.") - mapping = self.block_manager.swap_out(seq_group) + if swap_to_cpu: + mapping = self.block_manager.swap_out(seq_group) + else: + mapping = self.block_manager.swap_out_to_external(seq_group) + blocks_to_swap_out.extend(mapping) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq.status = SequenceStatus.SWAPPED diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d98f57bc2d35..815d1553cc77 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -85,6 +85,8 @@ class EngineArgs: swap_space: float = 4 # GiB cpu_offload_gb: float = 0 # GiB gpu_memory_utilization: float = 0.90 + external_swapper: str = "" # type: ignore + external_swapper_space: int = 0 #GiB, max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 max_logprobs: int = 20 # Default value for OpenAI Chat Completions API @@ -741,6 +743,21 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.disable_async_output_proc, help="Disable async output processing. This may result in " "lower performance.") + # The external storage medium of the kv cache. + # You can customize the format of the string. + # For example, currently only local file is supported. + # Local file format is: "file://path/to/directory" + parser.add_argument( + '--external-swapper', + type=str, + default="", + help="The external storage medium of the kv cache. Currently only " + "local file is supported. Local file format is: " + "'file://path/to/directory'.") + parser.add_argument('--external-swapper-space', + type=int, + default=0, + help="External swapper space size (GiB) per GPU") return parser @classmethod @@ -812,6 +829,8 @@ def create_engine_config(self) -> EngineConfig: sliding_window=model_config.get_sliding_window(), enable_prefix_caching=self.enable_prefix_caching, cpu_offload_gb=self.cpu_offload_gb, + external_swapper=self.external_swapper, + external_swapper_space=self.external_swapper_space, ) parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1eab83f3b988..11b579960123 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -223,7 +223,8 @@ def __init__( "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " "num_scheduler_steps=%d, enable_prefix_caching=%s, " - "use_async_output_proc=%s)", + "use_async_output_proc=%s, external_swapper=%s, " + "external_swapper_space=%d)", VLLM_VERSION, model_config.model, speculative_config, @@ -255,6 +256,8 @@ def __init__( scheduler_config.num_scheduler_steps, cache_config.enable_prefix_caching, model_config.use_async_output_proc, + cache_config.external_swapper, + cache_config.external_swapper_space_bytes, ) # TODO(woosuk): Print more configs in debug mode. from vllm.plugins import load_general_plugins @@ -312,6 +315,9 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: observability_config=self.observability_config, ) + self.enable_external_swapper = ( + self.cache_config.external_swapper != "" + and self.cache_config.external_swapper is not None) if not self.model_config.embedding_mode: self._initialize_kv_caches() @@ -458,11 +464,19 @@ def _initialize_kv_caches(self) -> None: num_gpu_blocks_override) num_gpu_blocks = num_gpu_blocks_override + if self.enable_external_swapper: + self._initialize_external_caches() + self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks) + def _initialize_external_caches(self) -> None: + blocks = self.model_executor.determine_num_external_available_blocks() + self.cache_config.num_external_blocks = blocks + self.model_executor.initialize_external_cache(blocks) + @classmethod def _get_executor_cls(cls, engine_config: EngineConfig) -> Type[ExecutorBase]: @@ -1741,6 +1755,8 @@ def _get_stats(self, len(scheduler.swapped) for scheduler in self.scheduler) num_waiting_sys = sum( len(scheduler.waiting) for scheduler in self.scheduler) + num_cumulative_preemption = sum(scheduler.num_cumulative_preemption + for scheduler in self.scheduler) # KV Cache Usage in % num_total_gpu = self.cache_config.num_gpu_blocks @@ -1759,6 +1775,15 @@ def _get_stats(self, for scheduler in self.scheduler) cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) + num_total_external = self.cache_config.num_external_blocks + external_cache_usage_sys = 0. + if num_total_external is not None and num_total_external > 0: + num_free_external = sum( + scheduler.block_manager.get_num_free_external_blocks() + for scheduler in self.scheduler) + external_cache_usage_sys = 1.0 - (num_free_external / + num_total_external) + # Prefix Cache Hit Rate. Note that we always use # the cache hit rate of the first virtual engine. cpu_prefix_cache_hit_rate = self.scheduler[ @@ -1879,9 +1904,11 @@ def _get_stats(self, num_running_sys=num_running_sys, num_swapped_sys=num_swapped_sys, num_waiting_sys=num_waiting_sys, + num_cumulative_preemption=num_cumulative_preemption, # KV Cache Usage in % gpu_cache_usage_sys=gpu_cache_usage_sys, cpu_cache_usage_sys=cpu_cache_usage_sys, + external_cache_usage_sys=external_cache_usage_sys, # Prefix Cache Hit Rate cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate, gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate, diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 74277cae7c8e..08ec6a261018 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -352,15 +352,19 @@ def log(self, stats: Stats) -> None: "Avg prompt throughput: %.1f tokens/s, " "Avg generation throughput: %.1f tokens/s, " "Running: %d reqs, Swapped: %d reqs, " - "Pending: %d reqs, GPU KV cache usage: %.1f%%, " - "CPU KV cache usage: %.1f%%.", + "Pending: %d reqs, Num Cumulative Preemption: %d, " + "GPU KV cache usage: %.1f%%, " + "CPU KV cache usage: %.1f%%, " + "External KV cache usage: %.1f%%.", prompt_throughput, generation_throughput, stats.num_running_sys, stats.num_swapped_sys, stats.num_waiting_sys, + stats.num_cumulative_preemption, stats.gpu_cache_usage_sys * 100, stats.cpu_cache_usage_sys * 100, + stats.external_cache_usage_sys * 100, ) if (stats.cpu_prefix_cache_hit_rate >= 0 or stats.gpu_prefix_cache_hit_rate >= 0): diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py index 1eccb2359340..6054041d1f20 100644 --- a/vllm/engine/metrics_types.py +++ b/vllm/engine/metrics_types.py @@ -29,9 +29,11 @@ class Stats: num_running_sys: int num_waiting_sys: int num_swapped_sys: int + num_cumulative_preemption: int # KV Cache Usage in % gpu_cache_usage_sys: float cpu_cache_usage_sys: float + external_cache_usage_sys: float # Prefix caching block hit rate cpu_prefix_cache_hit_rate: float gpu_prefix_cache_hit_rate: float diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 0edd4bfaecd6..8a04b8bb9312 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -125,6 +125,8 @@ def __init__( gpu_memory_utilization: float = 0.9, swap_space: float = 4, cpu_offload_gb: float = 0, + external_swapper: str = "", + external_swapper_space: int = 0, # GiB, enforce_eager: Optional[bool] = None, max_context_len_to_capture: Optional[int] = None, max_seq_len_to_capture: int = 8192, @@ -167,6 +169,8 @@ def __init__( gpu_memory_utilization=gpu_memory_utilization, swap_space=swap_space, cpu_offload_gb=cpu_offload_gb, + external_swapper=external_swapper, + external_swapper_space=external_swapper_space, enforce_eager=enforce_eager, max_context_len_to_capture=max_context_len_to_capture, max_seq_len_to_capture=max_seq_len_to_capture, diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 21ad43f64168..ccfaae303ff8 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -195,6 +195,9 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: return self.driver_method_invoker(self.driver_worker, "determine_num_available_blocks") + def determine_num_external_available_blocks(self) -> int: + return 0 + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: """Initialize the KV cache by invoking the underlying worker. @@ -211,6 +214,9 @@ def initialize_cache(self, num_gpu_blocks: int, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks) + def initialize_external_cache(self, num_external_blocks: int) -> None: + pass + def execute_model( self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py index ad84422ee212..0bf6f9d9eafc 100644 --- a/vllm/executor/distributed_gpu_executor.py +++ b/vllm/executor/distributed_gpu_executor.py @@ -46,6 +46,18 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: return num_gpu_blocks, num_cpu_blocks + def determine_num_external_available_blocks(self) -> int: + # Get the maximum number of blocks that can be + # allocated on external swapper. + num_blocks = self._run_workers( + "determine_num_external_available_blocks", ) + # Since we use a shared centralized controller, we take the minimum + # number of blocks across all workers to make sure all the memory + # operators can be applied to all workers. + num_external_blocks = min(b for b in num_blocks) + + return num_external_blocks + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: """Initialize the KV cache in all workers. @@ -64,6 +76,13 @@ def initialize_cache(self, num_gpu_blocks: int, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks) + def initialize_external_cache(self, num_external_blocks: int) -> None: + logger.info("# External blocks: %d", num_external_blocks) + + self.cache_config.num_gpu_blocks = num_external_blocks + self._run_workers("initialize_external_cache", + num_external_blocks=num_external_blocks) + def execute_model( self, execute_model_req: ExecuteModelRequest, diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index c96cb0f2c298..97dc07149f80 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -66,6 +66,10 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: """ raise NotImplementedError + @abstractmethod + def determine_num_external_available_blocks(self) -> int: + raise NotImplementedError + @abstractmethod def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: @@ -73,6 +77,10 @@ def initialize_cache(self, num_gpu_blocks: int, """ raise NotImplementedError + @abstractmethod + def initialize_external_cache(self, num_external_blocks: int) -> None: + raise NotImplementedError + @abstractmethod def execute_model( self, execute_model_req: ExecuteModelRequest diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 947776e5d6ef..90b5a268977d 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -113,6 +113,9 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: """ return self.driver_worker.determine_num_available_blocks() + def determine_num_external_available_blocks(self) -> int: + return self.driver_worker.determine_num_external_available_blocks() + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: """Initialize the KV cache by invoking the underlying worker. """ @@ -124,6 +127,11 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + def initialize_external_cache(self, num_external_blocks: int) -> None: + logger.info("# External blocks: %d", num_external_blocks) + + self.driver_worker.initialize_external_cache(num_external_blocks) + def execute_model( self, execute_model_req: ExecuteModelRequest ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index f2fcfa58b26e..eccb37a1fc9b 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -46,12 +46,18 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: """ return self.driver_worker.determine_num_available_blocks() + def determine_num_external_available_blocks(self) -> int: + return 0 + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: """Initialize the KV cache by invoking the underlying worker. """ self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + def initialize_external_cache(self, num_external_blocks: int) -> None: + pass + def execute_model( self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: diff --git a/vllm/executor/openvino_executor.py b/vllm/executor/openvino_executor.py index 78606e223aa7..65a4b4739a42 100644 --- a/vllm/executor/openvino_executor.py +++ b/vllm/executor/openvino_executor.py @@ -62,6 +62,9 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: """ return self.driver_worker.determine_num_available_blocks() + def determine_num_external_available_blocks(self) -> int: + return 0 + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: """Initialize the KV cache by invoking the underlying worker.""" @@ -74,6 +77,9 @@ def initialize_cache(self, num_gpu_blocks: int, logger.info("# CPU blocks: %d", num_gpu_blocks) self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + def initialize_external_cache(self, num_external_blocks: int) -> None: + pass + def execute_model( self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: diff --git a/vllm/executor/tpu_executor.py b/vllm/executor/tpu_executor.py index 0af8ba41e24d..87968436d747 100644 --- a/vllm/executor/tpu_executor.py +++ b/vllm/executor/tpu_executor.py @@ -86,6 +86,12 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: underlying worker.""" return self.driver_worker.determine_num_available_blocks() + def determine_num_external_available_blocks(self) -> int: + return 0 + + def initialize_external_cache(self, num_external_blocks: int) -> None: + pass + def execute_model( self, execute_model_req: ExecuteModelRequest, diff --git a/vllm/sequence.py b/vllm/sequence.py index 87b3d21fa7ae..7ea9d2b547c8 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -17,6 +17,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics +from vllm.utils import BlockSwapParam if TYPE_CHECKING: from vllm.inputs import LLMInputs @@ -1202,11 +1203,13 @@ class ExecuteModelRequest( seq_group_metadata_list: List[Union[SequenceGroupMetadata, SequenceGroupMetadataDelta]] # Blocks to swap in. List of CPU -> GPU block number. - blocks_to_swap_in: List[Tuple[int, - int]] = msgspec.field(default_factory=list) + blocks_to_swap_in: List[Tuple[BlockSwapParam, + BlockSwapParam]] = msgspec.field( + default_factory=list) # Blocks to swap out. List of GPU -> CPU block number. - blocks_to_swap_out: List[Tuple[int, - int]] = msgspec.field(default_factory=list) + blocks_to_swap_out: List[Tuple[BlockSwapParam, + BlockSwapParam]] = msgspec.field( + default_factory=list) # Blocks to copy. Source to dest block. blocks_to_copy: List[Tuple[int, int]] = msgspec.field(default_factory=list) # Virtual engine ID for pipeline parallel. diff --git a/vllm/spec_decode/proposer_worker_base.py b/vllm/spec_decode/proposer_worker_base.py index 28a537593f26..e1b282bda0f2 100644 --- a/vllm/spec_decode/proposer_worker_base.py +++ b/vllm/spec_decode/proposer_worker_base.py @@ -47,10 +47,3 @@ def execute_model( def determine_num_available_blocks(self) -> Tuple[int, int]: """This is never called on the proposer, only the target model""" raise NotImplementedError - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - pass - - def get_cache_block_size_bytes(self) -> int: - return 0 diff --git a/vllm/utils.py b/vllm/utils.py index 657a3ecef696..ab15a60164f2 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -152,6 +152,15 @@ class _Sentinel: class Device(enum.Enum): GPU = enum.auto() CPU = enum.auto() + # External swapper device + File = enum.auto() + + +def get_external_device(external_swapper: str) -> Device: + if external_swapper.startswith("file://"): + return Device.File + else: + raise ValueError(f"Invalid external swapper name: {external_swapper}") class Counter: @@ -1224,3 +1233,12 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, def supports_dynamo() -> bool: base_torch_version = Version(Version(torch.__version__).base_version) return base_torch_version >= Version("2.4.0") + + +class BlockSwapParam: + """A utility class to save the block id and device information for swap. + """ + + def __init__(self, block_id: Optional[int], device: Device): + self.block_id = block_id + self.device = device diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 252440c7b7e0..ba12771525b4 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -8,6 +8,7 @@ from vllm.logger import init_logger from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, is_pin_memory_available) +from vllm.worker.external_swapper import ExternalSwapperBase logger = init_logger(__name__) @@ -26,6 +27,8 @@ def __init__( model_config: ModelConfig, parallel_config: ParallelConfig, device_config: DeviceConfig, + rank: int = 0, + pipeline_parallel_id: int = 0, ) -> None: self.cache_config = cache_config self.model_config = model_config @@ -66,6 +69,37 @@ def __init__( self.gpu_cache = self._allocate_kv_cache( self.num_gpu_blocks, self.device_config.device_type) self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu") + self.enable_external_swapper = ( + self.cache_config.external_swapper != "" + and self.cache_config.external_swapper is not None) + + if self.enable_external_swapper: + self.num_external_blocks = cache_config.num_external_blocks + if self.num_external_blocks: + self.num_external_blocks //= \ + parallel_config.pipeline_parallel_size + + self.rank = rank + self.pipeline_parallel_id = pipeline_parallel_id + self.identifier = self._get_cache_engine_identifier() + + external_swapper_impl = ExternalSwapperBase. \ + get_external_swapper_class( + self.cache_config.external_swapper) + self.external_swapper = external_swapper_impl( + cache_config=self.cache_config, + model_config=self.model_config, + parallel_config=self.parallel_config, + dtype=self.dtype, + attn_backend=self.attn_backend, + gpu_cache=self.gpu_cache, + cache_engine_identifier=self.identifier, + ) + + def _get_cache_engine_identifier(self) -> str: + """Returns a unique identifier for the cache engine.""" + + return f"cache_engine_rank_{self.rank}_pp_{self.pipeline_parallel_id}" def _allocate_kv_cache( self, @@ -93,11 +127,17 @@ def swap_in(self, src_to_dst: torch.Tensor) -> None: self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i], src_to_dst) + def swap_in_from_external(self, src_to_dst: torch.Tensor) -> None: + self.external_swapper.swap_in(src_to_dst) + def swap_out(self, src_to_dst: torch.Tensor) -> None: for i in range(self.num_attention_layers): self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i], src_to_dst) + def swap_out_to_external(self, src_to_dst: torch.Tensor) -> None: + self.external_swapper.swap_out(src_to_dst) + def copy(self, src_to_dsts: torch.Tensor) -> None: self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts) diff --git a/vllm/worker/external_swapper.py b/vllm/worker/external_swapper.py new file mode 100644 index 000000000000..bca262311bb6 --- /dev/null +++ b/vllm/worker/external_swapper.py @@ -0,0 +1,120 @@ +import operator +from abc import ABC, abstractmethod +from functools import reduce +from typing import List, Tuple + +import torch + +from vllm.attention.backends.abstract import AttentionBackend +from vllm.config import CacheConfig, ModelConfig, ParallelConfig +from vllm.utils import get_dtype_size + + +class ExternalSwapperBase(ABC): + """Base class for external swapper.""" + + @staticmethod + def get_external_swapper_class(external_swapper: str): + external_swapper = external_swapper.lower() + + if external_swapper.startswith("file://"): + return LocalFileSwapper + + raise ValueError(f"Unknown external_swapper_type {external_swapper=}") + + @abstractmethod + def _allocate_kv_cache(self) -> List[Tuple[str, str]]: + """Allocate KV cache.""" + raise NotImplementedError + + @abstractmethod + def swap_out(self, src_to_dst: torch.Tensor) -> None: + """Swap out blocks from GPU -> NVMf.""" + raise NotImplementedError + + @abstractmethod + def swap_in(self, src_to_dst: torch.Tensor) -> None: + """Swap in blocks from NVMf -> GPU.""" + raise NotImplementedError + + +class LocalFileSwapper(ExternalSwapperBase): + """External swapper for local file.""" + + def __init__( + self, + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, + dtype: torch.dtype, + attn_backend: AttentionBackend, + gpu_cache: List[torch.Tensor], + cache_engine_identifier: str, + ) -> None: + self.head_size = model_config.get_head_size() + self.num_attention_layers = model_config.get_num_attention_layers( + parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) + + self.block_size = cache_config.block_size + self.external_swapper_space_bytes = \ + cache_config.external_swapper_space_bytes + self.num_external_blocks = cache_config.num_external_blocks + self.external_swapper = cache_config.external_swapper + + self.dtype = dtype + self.attn_backend = attn_backend + self.gpu_cache = gpu_cache + + self.cache_engine_identifier = cache_engine_identifier + self.kv_cache_shape = self.attn_backend.get_kv_cache_shape( + self.num_external_blocks, self.block_size, self.num_kv_heads, + self.head_size) + + self.directory = self._parse_swapper_parameters(self.external_swapper) + self.kv_cache = self._allocate_kv_cache() + + def _parse_swapper_parameters(self, + external_swapper_parameters: str) -> str: + if not external_swapper_parameters.startswith("file://"): + raise ValueError( + f"Invalid external swapper name: {external_swapper_parameters}" + ) + + directory = external_swapper_parameters[len("file://"):].strip() + return directory + + def _allocate_kv_cache(self) -> List[Tuple[str, str]]: + dtype_size = get_dtype_size(self.dtype) + kv_attention_layer_bytes = reduce(operator.mul, self.kv_cache_shape, + 1) * dtype_size + if kv_attention_layer_bytes % 2 != 0: + raise ValueError( + f"Invalid kv bytes size: {kv_attention_layer_bytes}") + key_attention_layer_bytes = int(kv_attention_layer_bytes / 2) + + kv_cache: List[Tuple[str, str]] = [] + for i in range(self.num_attention_layers): + key_file_name = \ + f"{self.directory}/external_{self.cache_engine_identifier}_layer_{i}_key" + val_file_name = \ + f"{self.directory}/external_{self.cache_engine_identifier}_layer_{i}_val" + + with open(key_file_name, 'wb') as f: + f.truncate(key_attention_layer_bytes) + with open(val_file_name, 'wb') as f: + f.truncate(key_attention_layer_bytes) + kv_cache.append((key_file_name, val_file_name)) + return kv_cache + + def swap_out(self, src_to_dst: torch.Tensor) -> None: + for i in range(self.num_attention_layers): + self.attn_backend.swap_out_to_local_file(self.gpu_cache[i], + self.kv_cache[i], + src_to_dst) + + def swap_in(self, src_to_dst: torch.Tensor) -> None: + for i in range(self.num_attention_layers): + self.attn_backend.swap_in_form_local_file(self.kv_cache[i], + self.gpu_cache[i], + src_to_dst) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 0ff559a9af53..89c6f851cbbe 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -23,6 +23,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SequenceGroupMetadata, SequenceGroupMetadataDelta) +from vllm.utils import Device from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner @@ -248,6 +249,24 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: torch.cuda.empty_cache() return num_gpu_blocks, num_cpu_blocks + def determine_num_external_available_blocks(self) -> int: + cache_block_size = self.get_cache_block_size_bytes() + num_external_blocks = int( + self.cache_config.external_swapper_space_bytes // cache_block_size) + num_external_blocks = max(num_external_blocks, 0) + return num_external_blocks + + def initialize_external_cache(self, num_external_blocks: int) -> None: + """This function sets the number of external blocks in cache config. + + The actual cache allocate is in the initialize_cache function. + """ + raise_if_cache_size_invalid(num_external_blocks, + self.cache_config.block_size, + self.model_config.max_model_len) + + self.cache_config.num_external_blocks = num_external_blocks + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: """Allocate GPU and CPU KV cache with the specified number of blocks. @@ -267,9 +286,13 @@ def initialize_cache(self, num_gpu_blocks: int, def _init_cache_engine(self): assert self.cache_config.num_gpu_blocks is not None self.cache_engine = [ + + # When using external storage media, + # different cache engines need to be distinguished, + # so rank id and pipeline parallel id are used to distinguish them. CacheEngine(self.cache_config, self.model_config, - self.parallel_config, self.device_config) - for _ in range(self.parallel_config.pipeline_parallel_size) + self.parallel_config, self.device_config, self.rank, i) + for i in range(self.parallel_config.pipeline_parallel_size) ] self.gpu_cache = [ self.cache_engine[ve].gpu_cache @@ -297,14 +320,41 @@ def prepare_worker_input( virtual_engine = execute_model_req.virtual_engine num_steps = execute_model_req.num_steps num_seq_groups = len(execute_model_req.seq_group_metadata_list) - # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. - # they contain parameters to launch cudamemcpyasync. - blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in, - device="cpu", - dtype=torch.int64).view(-1, 2) - blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out, - device="cpu", - dtype=torch.int64).view(-1, 2) + block_to_swap_in_cpu_list = [] + block_to_swap_in_external_list = [] + block_to_swap_out_cpu_list = [] + block_to_swap_out_external_list = [] + + for tuple_item in execute_model_req.blocks_to_swap_in: + first_block = tuple_item[0] + if first_block.device == Device.CPU: + block_to_swap_in_cpu_list.append( + (tuple_item[0].block_id, tuple_item[1].block_id)) + else: + block_to_swap_in_external_list.append( + (tuple_item[0].block_id, tuple_item[1].block_id)) + + for tuple_item in execute_model_req.blocks_to_swap_out: + second_block = tuple_item[1] + if second_block.device == Device.CPU: + block_to_swap_out_cpu_list.append( + (tuple_item[0].block_id, tuple_item[1].block_id)) + else: + block_to_swap_out_external_list.append( + (tuple_item[0].block_id, tuple_item[1].block_id)) + + blocks_to_swap_in_cpu = torch.tensor(block_to_swap_in_cpu_list, + device="cpu", + dtype=torch.int64).view(-1, 2) + blocks_to_swap_out_cpu = torch.tensor(block_to_swap_out_cpu_list, + device="cpu", + dtype=torch.int64).view(-1, 2) + blocks_to_swap_in_external = torch.tensor( + block_to_swap_in_external_list, device="cpu", + dtype=torch.int64).view(-1, 2) + blocks_to_swap_out_external = torch.tensor( + block_to_swap_out_external_list, device="cpu", + dtype=torch.int64).view(-1, 2) # `blocks_to_copy` is a gpu tensor. The src and tgt of # blocks to copy are in the same device, and `blocks_to_copy` # can be used directly within cuda kernels. @@ -314,8 +364,10 @@ def prepare_worker_input( return WorkerInput( num_seq_groups=num_seq_groups, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, + blocks_to_swap_in=blocks_to_swap_in_cpu, + blocks_to_swap_in_external=blocks_to_swap_in_external, + blocks_to_swap_out=blocks_to_swap_out_cpu, + blocks_to_swap_out_external=blocks_to_swap_out_external, blocks_to_copy=blocks_to_copy, virtual_engine=virtual_engine, num_steps=num_steps, @@ -333,6 +385,14 @@ def execute_worker(self, worker_input: WorkerInput) -> None: and worker_input.blocks_to_swap_out.numel() > 0): self.cache_engine[virtual_engine].swap_out( worker_input.blocks_to_swap_out) + if (worker_input.blocks_to_swap_in_external is not None + and worker_input.blocks_to_swap_in_external.numel() > 0): + self.cache_engine[virtual_engine].swap_in_from_external( + worker_input.blocks_to_swap_in_external) + if (worker_input.blocks_to_swap_out_external is not None + and worker_input.blocks_to_swap_out_external.numel() > 0): + self.cache_engine[virtual_engine].swap_out_to_external( + worker_input.blocks_to_swap_out_external) if (worker_input.blocks_to_copy is not None and worker_input.blocks_to_copy.numel() > 0): self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 6ba4f272315c..8cb020ac3634 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -14,7 +14,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.platforms import current_platform from vllm.sequence import ExecuteModelRequest, IntermediateTensors -from vllm.utils import (enable_trace_function_call_for_thread, +from vllm.utils import (Device, enable_trace_function_call_for_thread, update_environment_variables) from vllm.worker.model_runner_base import (BroadcastableModelInput, ModelRunnerBase, @@ -51,6 +51,14 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: """ raise NotImplementedError + @abstractmethod + def determine_num_external_available_blocks(self) -> int: + raise NotImplementedError + + @abstractmethod + def initialize_external_cache(self, num_external_blocks: int) -> None: + raise NotImplementedError + @abstractmethod def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: @@ -119,6 +127,12 @@ def pin_lora(self, lora_id: int) -> bool: def list_loras(self) -> Set[int]: raise ValueError(f"{type(self)} does not support LoRA") + def determine_num_external_available_blocks(self) -> int: + return 0 + + def initialize_external_cache(self, num_external_blocks: int) -> None: + pass + @dataclasses.dataclass(frozen=True) class WorkerInput: @@ -128,7 +142,11 @@ class WorkerInput: num_seq_groups: Optional[int] = None blocks_to_swap_in: Optional[torch.Tensor] = None + blocks_to_swap_in_external: Optional[torch.Tensor] = None + blocks_to_swap_in_external_device: Optional[Device] = None blocks_to_swap_out: Optional[torch.Tensor] = None + blocks_to_swap_out_external: Optional[torch.Tensor] = None + blocks_to_swap_out_external_device: Optional[Device] = None blocks_to_copy: Optional[torch.Tensor] = None virtual_engine: int = 0 num_steps: int = 1