
Commit 9607d5e

[Hybrid Allocator] Support full attention with different hidden size (#25101)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
1 parent c60e613 commit 9607d5e

File tree

6 files changed (+325 −93 lines)

tests/v1/core/test_kv_cache_utils.py

Lines changed: 103 additions & 15 deletions
@@ -18,12 +18,14 @@
 from vllm.v1.core.kv_cache_utils import (
     BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
     estimate_max_model_len, generate_block_hash_extra_keys,
-    get_kv_cache_configs, get_max_concurrency_for_kv_cache_config,
-    get_request_block_hasher, hash_block_tokens, init_none_hash,
-    is_kv_cache_type_uniform, make_block_hash_with_group_id)
+    generate_scheduler_kv_cache_config, get_kv_cache_configs,
+    get_max_concurrency_for_kv_cache_config, get_request_block_hasher,
+    hash_block_tokens, init_none_hash, is_kv_cache_spec_uniform,
+    make_block_hash_with_group_id)
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheSpec,
-                                        KVCacheTensor, SlidingWindowSpec)
+                                        KVCacheTensor, SlidingWindowSpec,
+                                        UniformTypeKVCacheSpecs)
 from vllm.v1.metrics.stats import PrefixCacheStats
 from vllm.v1.request import Request

@@ -927,36 +929,36 @@ def test_merge_kv_cache_spec():
     assert merged_layer_spec.sliding_window == 1


-def test_is_kv_cache_type_uniform():
+def test_is_kv_cache_spec_uniform():
     kv_cache_spec = {
         "layer_1": new_kv_cache_spec(num_kv_heads=32),
         "layer_2": new_kv_cache_spec(num_kv_heads=32),
     }
-    assert is_kv_cache_type_uniform(kv_cache_spec)
+    assert is_kv_cache_spec_uniform(kv_cache_spec)

     kv_cache_spec = {
         "layer_1": new_kv_cache_spec(num_kv_heads=32),
         "layer_2": new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
     }
-    assert is_kv_cache_type_uniform(kv_cache_spec)
+    assert is_kv_cache_spec_uniform(kv_cache_spec)

     kv_cache_spec = {
         "layer_1": new_kv_cache_spec(num_kv_heads=32),
         "layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
     }
-    assert not is_kv_cache_type_uniform(kv_cache_spec)
+    assert not is_kv_cache_spec_uniform(kv_cache_spec)

     kv_cache_spec = {
         "layer_1": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
         "layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
     }
-    assert is_kv_cache_type_uniform(kv_cache_spec)
+    assert is_kv_cache_spec_uniform(kv_cache_spec)

     kv_cache_spec = {
         "layer_1": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
         "layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=2),
     }
-    assert not is_kv_cache_type_uniform(kv_cache_spec)
+    assert not is_kv_cache_spec_uniform(kv_cache_spec)


 @pytest.mark.parametrize(
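Note: the rename from is_kv_cache_type_uniform to is_kv_cache_spec_uniform tracks the behavior these assertions pin down: layers count as uniform when they share the same spec type and sliding-window size, even if other attributes differ. Below is a minimal, self-contained sketch of that rule; the toy classes and helper are illustrative stand-ins, not the actual vLLM implementation.

# Toy stand-ins for FullAttentionSpec / SlidingWindowSpec; names and fields
# are assumptions for illustration only.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyFullAttentionSpec:
    block_size: int = 16
    num_kv_heads: int = 32
    head_size: int = 64
    sliding_window: Optional[int] = None  # a full-attention layer may still record a window


@dataclass
class ToySlidingWindowSpec(ToyFullAttentionSpec):
    sliding_window: int = 1


def toy_is_spec_uniform(kv_cache_spec: dict) -> bool:
    """Uniform iff every layer uses the same attention type and, for
    sliding-window layers, the same window size; head sizes may differ."""
    specs = list(kv_cache_spec.values())
    first = specs[0]
    if any(type(spec) is not type(first) for spec in specs):
        return False
    if isinstance(first, ToySlidingWindowSpec):
        return all(spec.sliding_window == first.sliding_window for spec in specs)
    return True


# Mirrors the assertions above: full-attention layers merge even when one of
# them records a sliding_window attribute ...
assert toy_is_spec_uniform({
    "layer_1": ToyFullAttentionSpec(),
    "layer_2": ToyFullAttentionSpec(sliding_window=1),
})
# ... but mixing full-attention and sliding-window specs is not uniform.
assert not toy_is_spec_uniform({
    "layer_1": ToyFullAttentionSpec(),
    "layer_2": ToySlidingWindowSpec(sliding_window=1),
})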
@@ -1286,14 +1288,28 @@ def test_get_kv_cache_config_one_worker():
         ],
     )

-    # different hidden size, unimplemented
+    # different hidden size
     kv_cache_specs_hybrid = {
         'layer_1': new_kv_cache_spec(head_size=128),
-        'layer_2': new_kv_cache_spec(),
+        'layer_2': new_kv_cache_spec(head_size=64),
     }
-    with pytest.raises(NotImplementedError):
-        get_kv_cache_configs(vllm_config, [kv_cache_specs_hybrid],
-                             [mem_per_block_per_layer * 2 * 32])[0]
+    kv_cache_config_hybrid = get_kv_cache_configs(
+        vllm_config, [kv_cache_specs_hybrid],
+        [mem_per_block_per_layer * 3 * 32])[0]
+    assert kv_cache_config_hybrid == KVCacheConfig(
+        num_blocks=32,
+        kv_cache_tensors=[
+            KVCacheTensor(size=mem_per_block_per_layer * 32 * 2,
+                          shared_by=["layer_1"]),
+            KVCacheTensor(size=mem_per_block_per_layer * 32,
+                          shared_by=["layer_2"]),
+        ],
+        kv_cache_groups=[
+            KVCacheGroupSpec(["layer_1", "layer_2"],
+                             UniformTypeKVCacheSpecs(
+                                 block_size=16,
+                                 kv_cache_specs=kv_cache_specs_hybrid))
+        ])

     # Test num_gpu_blocks_override
     vllm_config.cache_config.num_gpu_blocks_override = 16
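Rough bookkeeping behind the new hybrid-size assertion, assuming (as the other cases in this test suggest) that new_kv_cache_spec defaults to head_size=64, so layer_1's page is twice the size of layer_2's; the exact value of mem_per_block_per_layer cancels out.

# Illustrative arithmetic only; mem_per_block_per_layer is symbolic here.
mem_per_block_per_layer = 1                    # any positive value works, it cancels out
available = mem_per_block_per_layer * 3 * 32   # budget passed to get_kv_cache_configs
per_block = 2 * mem_per_block_per_layer        # layer_1: head_size=128 -> 2x page
per_block += 1 * mem_per_block_per_layer       # layer_2: head_size=64  -> 1x page
num_blocks = available // per_block
assert num_blocks == 32                        # matches num_blocks in the assertion above
# Tensor sizes follow: layer_1 gets 2 * 32 pages' worth of memory, layer_2 gets 1 * 32.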
@@ -1324,3 +1340,75 @@ def test_get_kv_cache_configs_attention_free():
             kv_cache_groups=[],
         )
     ]
+
+
+def test_generate_uniform_type_kv_cache_specs():
+    # All layers are full attention, can be merged
+    kv_cache_specs = {
+        'layer_1': new_kv_cache_spec(),
+        'layer_2': new_kv_cache_spec(head_size=128),
+    }
+    uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs)
+    assert uniform_spec == UniformTypeKVCacheSpecs(
+        block_size=16, kv_cache_specs=kv_cache_specs)
+
+    # Full attention + sliding window, cannot be merged
+    kv_cache_specs = {
+        'layer_1': new_kv_cache_spec(),
+        'layer_2': new_sliding_window_spec(sliding_window=1),
+    }
+    uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs)
+    assert uniform_spec is None
+
+    # different order of full attention + sliding window, cannot be merged
+    kv_cache_specs = {
+        'layer_1': new_sliding_window_spec(sliding_window=1),
+        'layer_2': new_kv_cache_spec(),
+    }
+    uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs)
+    assert uniform_spec is None
+
+    # Same-size sliding window, can be merged
+    kv_cache_specs = {
+        'layer_1': new_sliding_window_spec(sliding_window=1),
+        'layer_2': new_sliding_window_spec(sliding_window=1, head_size=128),
+    }
+    uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs)
+    assert uniform_spec == UniformTypeKVCacheSpecs(
+        block_size=16, kv_cache_specs=kv_cache_specs)
+
+    # different block sizes, cannot be merged
+    kv_cache_specs = {
+        'layer_1': new_kv_cache_spec(block_size=16),
+        'layer_2': new_kv_cache_spec(block_size=32),
+    }
+    uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs)
+    assert uniform_spec is None
+
+
+def test_generate_scheduler_kv_cache_config():
+    kv_cache_specs = {
+        'layer_1': new_kv_cache_spec(),
+        'layer_2': new_kv_cache_spec(head_size=128),
+    }
+    kv_cache_configs = [
+        KVCacheConfig(
+            num_blocks=10,
+            kv_cache_tensors=[],
+            kv_cache_groups=[
+                KVCacheGroupSpec(['layer_1', 'layer_2'],
+                                 UniformTypeKVCacheSpecs(
+                                     block_size=16,
+                                     kv_cache_specs=kv_cache_specs)),
+            ],
+        )
+    ]
+    scheduler_kv_cache_config = generate_scheduler_kv_cache_config(
+        kv_cache_configs)
+    assert scheduler_kv_cache_config == KVCacheConfig(
+        num_blocks=10,
+        kv_cache_tensors=[],
+        kv_cache_groups=[
+            KVCacheGroupSpec(['layer_1', 'layer_2'], new_kv_cache_spec())
+        ],
+    )
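Taken together, the two new tests pin down the end-to-end behavior: UniformTypeKVCacheSpecs.from_specs merges per-layer specs only when every layer has the same attention type and block size (head/hidden sizes are free to differ), and generate_scheduler_kv_cache_config then hands the scheduler a single representative per-layer spec for such a group, since the scheduler only tracks block accounting, not per-layer tensor sizes. Below is a minimal sketch of that collapse, using hypothetical stand-in types rather than the real vLLM classes.

# Hypothetical stand-ins; the real generate_scheduler_kv_cache_config and
# UniformTypeKVCacheSpecs live in vLLM and may differ in detail.
from dataclasses import dataclass
from typing import Dict, List, Union


@dataclass
class ToyLayerSpec:
    block_size: int = 16
    head_size: int = 64


@dataclass
class ToyUniformTypeSpecs:
    block_size: int
    kv_cache_specs: Dict[str, ToyLayerSpec]


@dataclass
class ToyGroup:
    layer_names: List[str]
    spec: Union[ToyLayerSpec, ToyUniformTypeSpecs]


def toy_scheduler_groups(groups: List[ToyGroup]) -> List[ToyGroup]:
    """Collapse each merged multi-layer spec to one representative layer spec
    (the first layer's spec in this toy); the scheduler only needs a single
    block_size and attention type per group."""
    out = []
    for group in groups:
        spec = group.spec
        if isinstance(spec, ToyUniformTypeSpecs):
            spec = next(iter(spec.kv_cache_specs.values()))
        out.append(ToyGroup(group.layer_names, spec))
    return out


layer_specs = {
    "layer_1": ToyLayerSpec(),                 # head_size=64
    "layer_2": ToyLayerSpec(head_size=128),    # different hidden size, same type
}
groups = [ToyGroup(["layer_1", "layer_2"],
                   ToyUniformTypeSpecs(block_size=16, kv_cache_specs=layer_specs))]
assert toy_scheduler_groups(groups)[0].spec == ToyLayerSpec()  # collapsed to one spec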
