Commit 56191dc

orozery authored and charlifu committed

[KV offload][3/N] Add worker-side CPU support (vllm-project#21448)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Signed-off-by: charlifu <charlifu@amd.com>
1 parent e434176 commit 56191dc

File tree

2 files changed, 348 insertions(+), 0 deletions(-)

Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import time

import pytest
import torch

from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.attention.backends.flashinfer import FlashInferBackend
from vllm.v1.attention.backends.mla.flashattn_mla import FlashAttnMLABackend
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandler

NUM_GPU_BLOCKS = [64]
NUM_CPU_BLOCKS = [256]
GPU_BLOCK_SIZES = [16]
GPU_BLOCKS_PER_CPU_BLOCK = [1, 3]
HEAD_SIZES = [64]
NUM_HEADS = [8]
NUM_LAYERS = [4]
DTYPES = [torch.bfloat16]
SEEDS = [0]
CUDA_DEVICES = ['cuda:0']
NUM_MAPPINGS = [3]


@pytest.mark.parametrize("gpu_to_cpu", [True, False])
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("gpu_block_size", GPU_BLOCK_SIZES)
@pytest.mark.parametrize("gpu_blocks_per_cpu_block", GPU_BLOCKS_PER_CPU_BLOCK)
@pytest.mark.parametrize("num_gpu_blocks", NUM_GPU_BLOCKS)
@pytest.mark.parametrize("num_cpu_blocks", NUM_CPU_BLOCKS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_transfer(
    gpu_to_cpu: bool,
    num_mappings: int,
    head_size: int,
    num_heads: int,
    gpu_block_size: int,
    gpu_blocks_per_cpu_block: int,
    num_gpu_blocks: int,
    num_cpu_blocks: int,
    num_layers: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    current_platform.seed_everything(seed)

    # create per-layer GPU KV caches
    attn_backends_list = [
        FlashAttentionBackend, FlashInferBackend, FlashAttnMLABackend
    ]

    gpu_caches = {}
    attn_backends = {}
    for i in range(num_layers):
        layer_name = f'layer {i}'

        attn_backend = attn_backends_list[i % len(attn_backends_list)]
        attn_backends[layer_name] = attn_backend

        gpu_cache_shape = attn_backend.get_kv_cache_shape(
            num_gpu_blocks, gpu_block_size, num_heads, head_size)
        gpu_caches[layer_name] = torch.rand(gpu_cache_shape,
                                            dtype=dtype,
                                            device=device)

    # create handler
    cpu_block_size = gpu_blocks_per_cpu_block * gpu_block_size
    handler = CpuGpuOffloadingHandler(attn_backends=attn_backends,
                                      gpu_block_size=gpu_block_size,
                                      cpu_block_size=cpu_block_size,
                                      num_cpu_blocks=num_cpu_blocks,
                                      gpu_caches=gpu_caches)

    # select block mappings
    gpu_blocks = random.sample(range(num_gpu_blocks),
                               num_mappings * gpu_blocks_per_cpu_block)
    cpu_blocks = random.sample(range(num_cpu_blocks), num_mappings)

    # convert cpu blocks to gpu block size
    cpu_blocks_in_gpu_block_size = []
    for cpu_block in cpu_blocks:
        base_block_id = cpu_block * gpu_blocks_per_cpu_block
        for i in range(gpu_blocks_per_cpu_block):
            cpu_blocks_in_gpu_block_size.append(i + base_block_id)

    # maybe skip a GPU block to test writing to the middle of a CPU block
    if gpu_to_cpu:
        gpu_blocks = gpu_blocks[gpu_blocks_per_cpu_block - 1:]
        cpu_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size[
            gpu_blocks_per_cpu_block - 1:]

    # set transfer direction
    if gpu_to_cpu:
        src_kv_caches = handler.gpu_tensors
        dst_kv_caches = handler.cpu_tensors
        src_spec_class = GPULoadStoreSpec
        dst_spec_class = CPULoadStoreSpec
        src_blocks = gpu_blocks
        dst_blocks = cpu_blocks
        src_blocks_in_gpu_block_size = gpu_blocks
        dst_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size
        dst_size_in_gpu_blocks = num_cpu_blocks * gpu_blocks_per_cpu_block
    else:
        src_kv_caches = handler.cpu_tensors
        dst_kv_caches = handler.gpu_tensors
        src_spec_class = CPULoadStoreSpec
        dst_spec_class = GPULoadStoreSpec
        src_blocks = cpu_blocks
        dst_blocks = gpu_blocks
        src_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size
        dst_blocks_in_gpu_block_size = gpu_blocks
        dst_size_in_gpu_blocks = num_gpu_blocks

    # build dst -> src mapping
    dst_to_src = {}
    for src_block, dst_block in zip(src_blocks_in_gpu_block_size,
                                    dst_blocks_in_gpu_block_size):
        dst_to_src[dst_block] = src_block

    # build transfer specs
    src_spec = src_spec_class(src_blocks)
    dst_spec = dst_spec_class(dst_blocks)

    # clone src and dst tensors before transfer
    orig_src_caches = [x.clone() for x in src_kv_caches]
    orig_dst_caches = [x.clone() for x in dst_kv_caches]

    # call transfer function
    assert handler.transfer_async(1, (src_spec, dst_spec))
    assert set(handler.transfer_events.keys()) == {1}

    # wait for transfer to complete
    end_time = time.time() + 10
    while time.time() < end_time:
        finished = handler.get_finished()
        if finished:
            assert finished == [(1, True)]
            break
        time.sleep(0.1)

    # verify src tensors did not change
    for orig_tensor, tensor in zip(orig_src_caches, src_kv_caches):
        assert torch.equal(orig_tensor, tensor)

    # verify dst tensors
    for dst_block in range(dst_size_in_gpu_blocks):
        src_block_candidate = dst_to_src.get(dst_block)
        for src_cache, dst_cache, orig_dst_cache, kv_dim in zip(
                src_kv_caches, dst_kv_caches, orig_dst_caches,
                handler.kv_dim_before_num_blocks):
            if kv_dim:
                # iterate over key, value
                for i in range(2):
                    if src_block_candidate is not None:
                        expected_value = src_cache[i][src_block_candidate]
                    else:
                        expected_value = orig_dst_cache[i][dst_block]
                    torch.testing.assert_close(dst_cache[i][dst_block].cpu(),
                                               expected_value.cpu())
            else:
                if src_block_candidate is not None:
                    expected_value = src_cache[src_block_candidate]
                else:
                    expected_value = orig_dst_cache[dst_block]
                torch.testing.assert_close(dst_cache[dst_block].cpu(),
                                           expected_value.cpu())
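
For readers following the test, the "convert cpu blocks to gpu block size" step above is a fixed-stride index expansion: each CPU block covers gpu_blocks_per_cpu_block consecutive GPU-sized sub-blocks. A minimal standalone sketch (plain Python, not part of the diff; the concrete numbers are illustrative only):

# With gpu_blocks_per_cpu_block = 3, CPU block 5 covers GPU-sized
# sub-blocks 15, 16 and 17, so CPU blocks [5, 2] expand to
# [15, 16, 17, 6, 7, 8].
gpu_blocks_per_cpu_block = 3
cpu_blocks = [5, 2]

cpu_blocks_in_gpu_block_size = []
for cpu_block in cpu_blocks:
    base_block_id = cpu_block * gpu_blocks_per_cpu_block
    for i in range(gpu_blocks_per_cpu_block):
        cpu_blocks_in_gpu_block_size.append(base_block_id + i)

assert cpu_blocks_in_gpu_block_size == [15, 16, 17, 6, 7, 8]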
Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import numpy as np
import torch

from vllm import _custom_ops as ops
from vllm.attention import AttentionBackend
from vllm.logger import init_logger
from vllm.utils import is_pin_memory_available
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.kv_offload.worker.worker import (OffloadingHandler,
                                              TransferResult, TransferSpec)

logger = init_logger(__name__)


def expand_block_ids(block_ids: np.ndarray,
                     block_size_factor: int,
                     output: np.ndarray,
                     skip_count: int = 0):
    """
    Convert a list of block IDs to a list of matching block ids,
    assuming each block is composed of actual block_size_factor blocks.
    Outputs to output tensor.
    The first skip_count blocks will be skipped.
    Note that skip_count must be less than block_size_factor.

    For example, if block_ids = [0, 1, 3] and block_size_factor = 4,
    then it yields [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15]
    since 0 maps to [0, 1, 2, 3]
    1 maps to [4, 5, 6, 7]
    and 3 maps to [12, 13, 14, 15]
    """
    assert skip_count < block_size_factor

    first_range = np.arange(skip_count, block_size_factor)
    full_range = np.arange(0, block_size_factor)

    output_idx = 0
    for i, block_id in enumerate(block_ids):
        base_block_id = block_id * block_size_factor
        indices = first_range if i == 0 else full_range
        output_end_idx = output_idx + len(indices)
        output[output_idx:output_end_idx] = base_block_id + indices
        output_idx = output_end_idx


class CpuGpuOffloadingHandler(OffloadingHandler):

    def __init__(self, gpu_block_size: int, cpu_block_size: int,
                 num_cpu_blocks: int, gpu_caches: dict[str, torch.Tensor],
                 attn_backends: dict[str, type[AttentionBackend]]):
        assert cpu_block_size % gpu_block_size == 0
        self.block_size_factor = cpu_block_size // gpu_block_size

        # cuda streams for gpu->cpu and cpu->gpu
        self.d2h_stream = torch.cuda.Stream()
        self.h2d_stream = torch.cuda.Stream()

        # job_id -> transfer cuda event
        self.transfer_events: dict[int, torch.cuda.Event] = {}
        # list of cuda events available for re-use
        self.events_pool: list[torch.cuda.Event] = []

        pin_memory = is_pin_memory_available()

        # allocate cpu tensors
        logger.info("Allocating %d CPU tensors...", len(gpu_caches))
        self.gpu_tensors: list[torch.Tensor] = []
        self.cpu_tensors: list[torch.Tensor] = []
        self.kv_dim_before_num_blocks: list[bool] = []
        for layer_name, gpu_tensor in gpu_caches.items():
            self.gpu_tensors.append(gpu_tensor)

            gpu_shape = gpu_tensor.shape
            test_shape = attn_backends[layer_name].get_kv_cache_shape(
                num_blocks=1234, block_size=16, num_kv_heads=8, head_size=256)
            if test_shape[0] == 1234:
                # shape is (num_blocks, ...)
                num_blocks_idx = 0
                self.kv_dim_before_num_blocks.append(False)
            else:
                # shape should be (2, num_blocks, ...)
                assert test_shape[0] == 2
                assert test_shape[1] == 1234
                assert gpu_shape[0] == 2

                num_blocks_idx = 1
                self.kv_dim_before_num_blocks.append(True)

            cpu_shape = list(gpu_shape)
            cpu_shape[num_blocks_idx] = num_cpu_blocks * self.block_size_factor

            logger.debug("Allocating CPU tensor of shape %r", cpu_shape)
            self.cpu_tensors.append(
                torch.zeros(cpu_shape,
                            dtype=gpu_tensor.dtype,
                            device="cpu",
                            pin_memory=pin_memory))

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        src_spec, dst_spec = spec
        if isinstance(src_spec, CPULoadStoreSpec):
            assert isinstance(dst_spec, GPULoadStoreSpec)
            stream = self.h2d_stream
            src_tensors = self.cpu_tensors
            dst_tensors = self.gpu_tensors
            src_block_size_factor = self.block_size_factor
            dst_block_size_factor = 1
        else:
            assert isinstance(src_spec, GPULoadStoreSpec)
            assert isinstance(dst_spec, CPULoadStoreSpec)
            stream = self.d2h_stream
            src_tensors = self.gpu_tensors
            dst_tensors = self.cpu_tensors
            src_block_size_factor = 1
            dst_block_size_factor = self.block_size_factor

        src_blocks = src_spec.block_ids
        dst_blocks = dst_spec.block_ids
        assert src_blocks.ndim == 1
        assert dst_blocks.ndim == 1

        dst_sub_blocks_to_skip = (-src_blocks.size % dst_block_size_factor)
        src_sub_block_count = src_blocks.size * src_block_size_factor

        assert (
            src_sub_block_count == dst_blocks.size * dst_block_size_factor -
            dst_sub_blocks_to_skip)

        src_to_dst = np.empty((src_sub_block_count, 2), dtype=np.int64)
        expand_block_ids(src_blocks, src_block_size_factor, src_to_dst[:, 0])
        expand_block_ids(dst_blocks,
                         dst_block_size_factor,
                         src_to_dst[:, 1],
                         skip_count=dst_sub_blocks_to_skip)
        src_to_dst_tensor = torch.from_numpy(src_to_dst)

        event = self.events_pool.pop() if self.events_pool \
            else torch.cuda.Event()
        with torch.cuda.stream(stream):
            for src_tensor, dst_tensor, kv_dim in zip(
                    src_tensors, dst_tensors, self.kv_dim_before_num_blocks):
                if kv_dim:
                    src_key_cache = src_tensor[0]
                    dst_key_cache = dst_tensor[0]
                    ops.swap_blocks(src_key_cache, dst_key_cache,
                                    src_to_dst_tensor)
                    src_value_cache = src_tensor[1]
                    dst_value_cache = dst_tensor[1]
                    ops.swap_blocks(src_value_cache, dst_value_cache,
                                    src_to_dst_tensor)
                else:
                    ops.swap_blocks(src_tensor, dst_tensor, src_to_dst_tensor)
            event.record(stream)

        self.transfer_events[job_id] = event

        # success
        return True

    def get_finished(self) -> list[TransferResult]:
        results: list[TransferResult] = []
        for job_id, event in self.transfer_events.items():
            if event.query():
                results.append((job_id, True))
                self.events_pool.append(event)
        for job_id, _ in results:
            del self.transfer_events[job_id]
        return results
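
To make the block-granularity conversion concrete, here is a minimal usage sketch of expand_block_ids matching its docstring example. It assumes the file above is importable as vllm.v1.kv_offload.worker.cpu_gpu (inferred from the test's import of CpuGpuOffloadingHandler); this is an illustration, not part of the commit.

import numpy as np

# Assumed module path, inferred from the test file's import.
from vllm.v1.kv_offload.worker.cpu_gpu import expand_block_ids

block_ids = np.array([0, 1, 3])
factor = 4  # each coarse (CPU) block spans 4 GPU-sized sub-blocks

out = np.empty(block_ids.size * factor, dtype=np.int64)
expand_block_ids(block_ids, factor, out)
assert out.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15]

# skip_count drops leading sub-blocks of the first block, which is how
# transfer_async starts writing into the middle of a CPU block.
out_skip = np.empty(block_ids.size * factor - 1, dtype=np.int64)
expand_block_ids(block_ids, factor, out_skip, skip_count=1)
assert out_skip.tolist() == [1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15]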
