Skip to content

Commit 9b5084f

Browse files
orozeryxuebwang-amd
authored andcommitted
[KV offload][2/N] Introduce LRU-based CPU offloading management (vllm-project#20075)
Signed-off-by: Or Ozeri <oro@il.ibm.com> Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 8db947d commit 9b5084f

File tree

4 files changed

+464
-0
lines changed

4 files changed

+464
-0
lines changed

tests/v1/kv_offload/test_cpu.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
from collections.abc import Iterable
4+
from dataclasses import dataclass
5+
from typing import Optional
6+
7+
import numpy as np
8+
9+
from vllm.v1.core.kv_cache_utils import BlockHash
10+
from vllm.v1.kv_offload.abstract import (LoadStoreSpec, OffloadingEvent,
11+
PrepareStoreOutput)
12+
from vllm.v1.kv_offload.backends.cpu import CPUBackend
13+
from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager
14+
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec
15+
16+
17+
@dataclass
18+
class ExpectedPrepareStoreOutput:
19+
block_hashes_to_store: list[int]
20+
store_block_ids: list[int]
21+
block_hashes_evicted: list[int]
22+
23+
24+
def to_hashes(int_hashes: list[int]) -> list[BlockHash]:
25+
return [BlockHash(str(i).encode()) for i in int_hashes]
26+
27+
28+
def verify_store_output(
29+
prepare_store_output: Optional[PrepareStoreOutput],
30+
expected_prepare_store_output: ExpectedPrepareStoreOutput):
31+
assert prepare_store_output is not None
32+
assert (prepare_store_output.block_hashes_to_store == to_hashes(
33+
expected_prepare_store_output.block_hashes_to_store))
34+
assert (prepare_store_output.block_hashes_evicted == to_hashes(
35+
expected_prepare_store_output.block_hashes_evicted))
36+
store_spec = prepare_store_output.store_spec
37+
assert isinstance(store_spec, CPULoadStoreSpec)
38+
expected_array = np.array(expected_prepare_store_output.store_block_ids,
39+
dtype=np.int64)
40+
assert np.array_equal(expected_array, store_spec.block_ids)
41+
42+
43+
def verify_load_output(prepare_load_output: LoadStoreSpec,
44+
expected_prepare_load_output: list[int]):
45+
assert isinstance(prepare_load_output, CPULoadStoreSpec)
46+
expected_array = np.array(expected_prepare_load_output, dtype=np.int64)
47+
assert np.array_equal(expected_array, prepare_load_output.block_ids)
48+
49+
50+
def verify_events(events: Iterable[OffloadingEvent],
51+
block_size: int,
52+
expected_stores: tuple[set[int], ...] = (),
53+
expected_evictions: tuple[set[int], ...] = ()):
54+
stores: list[set[BlockHash]] = []
55+
evictions: list[set[BlockHash]] = []
56+
for event in events:
57+
assert event.medium == CPULoadStoreSpec.medium()
58+
assert event.block_size == block_size
59+
if event.removed:
60+
evictions.append(set(event.block_hashes))
61+
else:
62+
stores.append(set(event.block_hashes))
63+
64+
def to_hash_sets(
65+
int_sets: tuple[set[int], ...]) -> tuple[set[BlockHash], ...]:
66+
return tuple([set(to_hashes(list(int_set))) for int_set in int_sets])
67+
68+
assert tuple(evictions) == to_hash_sets(expected_evictions)
69+
assert tuple(stores) == to_hash_sets(expected_stores)
70+
71+
72+
def test_cpu_manager():
73+
"""
74+
Tests LRUOffloadingManager with a CPUBackend.
75+
"""
76+
# initialize a CPU backend with a capacity of 4 blocks
77+
block_size = 256
78+
cpu_backend = CPUBackend(block_size=block_size, num_blocks=4)
79+
cpu_manager = LRUOffloadingManager(cpu_backend, enable_events=True)
80+
81+
# prepare store [1, 2]
82+
prepare_store_output = cpu_manager.prepare_store(to_hashes([1, 2]))
83+
verify_store_output(
84+
prepare_store_output,
85+
ExpectedPrepareStoreOutput(
86+
block_hashes_to_store=[1, 2],
87+
store_block_ids=[0, 1],
88+
block_hashes_evicted=[],
89+
))
90+
91+
# lookup [1, 2] -> not ready
92+
assert cpu_manager.lookup(to_hashes([1, 2])) == 0
93+
94+
# no events so far
95+
assert list(cpu_manager.take_events()) == []
96+
97+
# complete store [1, 2]
98+
cpu_manager.complete_store(to_hashes([1, 2]))
99+
verify_events(cpu_manager.take_events(),
100+
block_size=block_size,
101+
expected_stores=({1, 2}, ))
102+
103+
# lookup [1, 2]
104+
assert cpu_manager.lookup(to_hashes([1])) == 1
105+
assert cpu_manager.lookup(to_hashes([1, 2])) == 2
106+
assert cpu_manager.lookup(to_hashes([1, 2, 3])) == 2
107+
108+
# prepare store [2, 3, 4, 5] -> evicts [1]
109+
prepare_store_output = cpu_manager.prepare_store(to_hashes([2, 3, 4, 5]))
110+
verify_store_output(
111+
prepare_store_output,
112+
ExpectedPrepareStoreOutput(
113+
block_hashes_to_store=[3, 4, 5],
114+
store_block_ids=[2, 3, 0],
115+
block_hashes_evicted=[1],
116+
))
117+
118+
# verify eviction event
119+
verify_events(cpu_manager.take_events(),
120+
block_size=block_size,
121+
expected_evictions=({1}, ))
122+
123+
# prepare store with no space
124+
assert cpu_manager.prepare_store(to_hashes([1, 6])) is None
125+
126+
# complete store [2, 3, 4, 5]
127+
cpu_manager.complete_store(to_hashes([2, 3, 4, 5]))
128+
129+
# prepare load [2, 3]
130+
prepare_load_output = cpu_manager.prepare_load(to_hashes([2, 3]))
131+
verify_load_output(prepare_load_output, [1, 2])
132+
133+
# prepare store with no space ([2, 3] is being loaded)
134+
assert cpu_manager.prepare_store(to_hashes([6, 7, 8])) is None
135+
136+
# complete load [2, 3]
137+
cpu_manager.complete_load(to_hashes([2, 3]))
138+
139+
# prepare store [6, 7, 8] -> evicts [2, 3, 4] (oldest)
140+
prepare_store_output = cpu_manager.prepare_store(to_hashes([6, 7, 8]))
141+
verify_store_output(
142+
prepare_store_output,
143+
ExpectedPrepareStoreOutput(
144+
block_hashes_to_store=[6, 7, 8],
145+
store_block_ids=[3, 2, 1],
146+
block_hashes_evicted=[2, 3, 4],
147+
))
148+
149+
# complete store [6, 7, 8]
150+
cpu_manager.complete_store(to_hashes([6, 7, 8]))
151+
152+
# touch [5, 6, 7] (move to end of LRU order)
153+
cpu_manager.touch(to_hashes([5, 6, 7]))
154+
155+
# prepare store [7, 9] -> evicts [8] (oldest following previous touch)
156+
prepare_store_output = cpu_manager.prepare_store(to_hashes([9]))
157+
verify_store_output(
158+
prepare_store_output,
159+
ExpectedPrepareStoreOutput(
160+
block_hashes_to_store=[9],
161+
store_block_ids=[1],
162+
block_hashes_evicted=[8],
163+
))
164+
165+
# complete store [7, 9] with failure
166+
cpu_manager.complete_store(to_hashes([7, 9]), success=False)
167+
168+
# assert [7] is still stored, but [9] is not
169+
assert cpu_manager.lookup(to_hashes([7])) == 1
170+
assert cpu_manager.lookup(to_hashes([9])) == 0
171+
172+
verify_events(cpu_manager.take_events(),
173+
block_size=block_size,
174+
expected_stores=({3, 4, 5}, {6, 7, 8}),
175+
expected_evictions=({2, 3, 4}, {8}))

vllm/v1/kv_offload/backend.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import ctypes
4+
from abc import ABC, abstractmethod
5+
from collections.abc import Iterable
6+
7+
from vllm.v1.core.kv_cache_utils import BlockHash
8+
from vllm.v1.kv_offload.abstract import LoadStoreSpec
9+
10+
11+
class BlockStatus(ctypes.Structure):
12+
"""
13+
Offloading status for a single block of KV data.
14+
Holds the following information:
15+
16+
ref_cnt - the current number of transfers using this block as a source.
17+
A value of -1 indicates the block is not yet ready to be read.
18+
load_store_spec - backend-specific information on how to actually
19+
read/write the block.
20+
"""
21+
_fields_ = [("ref_cnt", ctypes.c_int32)]
22+
23+
def __init__(self):
24+
super().__init__()
25+
# initialize block as "not ready" (ref_cnt = -1)
26+
self.ref_cnt = -1
27+
28+
@property
29+
def is_ready(self) -> bool:
30+
"""
31+
Returns whether the block is ready to be read.
32+
"""
33+
return self.ref_cnt >= 0
34+
35+
36+
class Backend(ABC):
37+
"""
38+
An abstract class for allocating and returning specs for writing
39+
KV blocks to some backend.
40+
"""
41+
42+
def __init__(self, block_size: int, medium: str):
43+
self.block_size = block_size
44+
self.medium = medium
45+
46+
@abstractmethod
47+
def get_num_free_blocks(self):
48+
"""
49+
Returns the number of current number of blocks that can be allocated.
50+
"""
51+
pass
52+
53+
@abstractmethod
54+
def allocate_blocks(self,
55+
block_hashes: list[BlockHash]) -> list[BlockStatus]:
56+
"""
57+
Allocate space for writing blocks.
58+
This method assumes there is enough space for allocation.
59+
It is unsafe to use without checking get_num_free_blocks beforehand.
60+
61+
Args:
62+
block_hashes: the hashes identifying the blocks to be written.
63+
64+
Returns:
65+
A list of BlockStatus for the allocated blocks.
66+
The ref_cnt of each returned item will be -1, meaning the block
67+
is not yet ready to be read.
68+
"""
69+
pass
70+
71+
@abstractmethod
72+
def free(self, block: BlockStatus):
73+
"""
74+
Free a previously allocated block.
75+
You should only call this function with blocks returned by
76+
allocate_blocks, and only once per each block.
77+
78+
Args:
79+
block: The block to be freed.
80+
"""
81+
pass
82+
83+
def get_load_store_spec(self, block_hashes: Iterable[BlockHash],
84+
blocks: Iterable[BlockStatus]) -> LoadStoreSpec:
85+
"""
86+
Get backend-specific information on how to read/write blocks.
87+
88+
Args:
89+
block_hashes: the list of block hashes identifying the blocks.
90+
blocks: the list of blocks.
91+
92+
Returns:
93+
A LoadStoreSpec that can be used by a worker
94+
to read/write the blocks.
95+
"""
96+
raise NotImplementedError

vllm/v1/kv_offload/backends/cpu.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import ctypes
4+
from collections.abc import Iterable
5+
6+
from vllm.v1.core.kv_cache_utils import BlockHash
7+
from vllm.v1.kv_offload.abstract import LoadStoreSpec
8+
from vllm.v1.kv_offload.backend import Backend, BlockStatus
9+
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec
10+
11+
12+
class CPUBlockStatus(BlockStatus):
13+
_fields_ = BlockStatus._fields_ + [("block_id", ctypes.c_int64)
14+
] # type: ignore
15+
16+
def __init__(self, block_id: int):
17+
super().__init__()
18+
self.block_id = block_id
19+
20+
21+
class CPUBackend(Backend):
22+
23+
def __init__(self, block_size: int, num_blocks: int):
24+
super().__init__(block_size=block_size,
25+
medium=CPULoadStoreSpec.medium())
26+
27+
self.num_blocks: int = num_blocks
28+
self.num_allocated_blocks: int = 0
29+
self.allocated_blocks_free_list: list[int] = []
30+
31+
def get_num_free_blocks(self):
32+
return (len(self.allocated_blocks_free_list) + self.num_blocks -
33+
self.num_allocated_blocks)
34+
35+
def allocate_blocks(self,
36+
block_hashes: list[BlockHash]) -> list[BlockStatus]:
37+
num_fresh_blocks = min(len(block_hashes),
38+
self.num_blocks - self.num_allocated_blocks)
39+
num_reused_blocks = len(block_hashes) - num_fresh_blocks
40+
assert len(self.allocated_blocks_free_list) >= num_reused_blocks
41+
42+
# allocate fresh blocks
43+
blocks: list[BlockStatus] = []
44+
for _ in range(num_fresh_blocks):
45+
blocks.append(CPUBlockStatus(self.num_allocated_blocks))
46+
self.num_allocated_blocks += 1
47+
48+
# allocate reused blocks
49+
for _ in range(num_reused_blocks):
50+
block_id = self.allocated_blocks_free_list.pop()
51+
blocks.append(CPUBlockStatus(block_id))
52+
53+
return blocks
54+
55+
def free(self, block: BlockStatus):
56+
assert isinstance(block, CPUBlockStatus)
57+
self.allocated_blocks_free_list.append(block.block_id)
58+
59+
def get_load_store_spec(self, block_hashes: Iterable[BlockHash],
60+
blocks: Iterable[BlockStatus]) -> LoadStoreSpec:
61+
return CPULoadStoreSpec([block.block_id for block in blocks])

0 commit comments

Comments
 (0)