
Commit f39bd30

MengqingCao authored
[Hybrid KV] Follow up UniformTypeKVCacheSpecs (#3070)
### What this PR does / why we need it?
Follow up the `UniformTypeKVCacheSpecs` changes introduced by vllm-project/vllm#25101, which support different hidden sizes within a uniform-type KV cache spec.

This also fixes the CI failure `TypeError: AttentionGroup.__init__() missing 1 required positional argument: 'kv_cache_spec'`.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
Tests passed with existing e2e tests.

- vLLM version: v0.10.2
- vLLM main: vllm-project/vllm@c60e613

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent f1f2c8f commit f39bd30
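
For context on the spec handling this follow-up adopts: when a KV cache group's spec is a `UniformTypeKVCacheSpecs`, the per-layer spec is looked up from its `kv_cache_specs` mapping, so layers with different hidden sizes can share one group. Below is a minimal, self-contained sketch of that resolution step; `LayerSpec`, `UniformGroup`, and the layer names are illustrative stand-ins, not the real vLLM types.

```python
from dataclasses import dataclass, field
from typing import Dict, Union


@dataclass(frozen=True)
class LayerSpec:
    """Stand-in for a per-layer KV cache spec (illustrative only)."""
    hidden_size: int
    block_size: int = 128


@dataclass
class UniformGroup:
    """Stand-in for UniformTypeKVCacheSpecs: one page type, per-layer specs."""
    kv_cache_specs: Dict[str, LayerSpec] = field(default_factory=dict)


def resolve_layer_spec(group_spec: Union[LayerSpec, UniformGroup],
                       layer_name: str) -> LayerSpec:
    # Mirrors the PR's logic: if the group-level spec is a uniform-type
    # wrapper, look up the individual layer's spec; otherwise use it as-is.
    if isinstance(group_spec, UniformGroup):
        return group_spec.kv_cache_specs[layer_name]
    return group_spec


group = UniformGroup(kv_cache_specs={
    "model.layers.0.attn": LayerSpec(hidden_size=4096),
    "model.layers.1.attn": LayerSpec(hidden_size=2048),  # different hidden size
})
assert resolve_layer_spec(group, "model.layers.1.attn").hidden_size == 2048
```

This per-layer lookup is also why `AttentionGroup` now carries its own `kv_cache_spec`, which is the missing constructor argument behind the CI `TypeError` this PR fixes.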

File tree: 4 files changed, +101 −35 lines


.github/workflows/format_pr_body.yaml

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ jobs:
 
       - name: Get vLLM version
         run: |
-          VLLM_COMMIT=c60e6137f0bf2034853919b3a9d705d7e06b93cf
+          VLLM_COMMIT=9607d5eb449711b349d4c2bee0a9c94afcc7ed14
           echo "VLLM_COMMIT=https://github.com/vllm-project/vllm/commit/$VLLM_COMMIT" >> $GITHUB_ENV
 
       - name: Checkout repository

.github/workflows/vllm_ascend_test.yaml

Lines changed: 3 additions & 3 deletions
@@ -42,7 +42,7 @@ jobs:
   lint:
     uses: ./.github/workflows/pre-commit.yml
     with:
-      vllm: c60e6137f0bf2034853919b3a9d705d7e06b93cf
+      vllm: 9607d5eb449711b349d4c2bee0a9c94afcc7ed14
 
   changes:
     runs-on: ubuntu-latest
@@ -83,7 +83,7 @@ jobs:
       VLLM_USE_MODELSCOPE: True
     strategy:
       matrix:
-        vllm_version: [c60e6137f0bf2034853919b3a9d705d7e06b93cf, v0.10.2]
+        vllm_version: [9607d5eb449711b349d4c2bee0a9c94afcc7ed14, v0.10.2]
     steps:
       - name: Install packages
        run: |
@@ -138,7 +138,7 @@
     name: e2e-light
     strategy:
       matrix:
-        vllm_version: [c60e6137f0bf2034853919b3a9d705d7e06b93cf, v0.10.2]
+        vllm_version: [9607d5eb449711b349d4c2bee0a9c94afcc7ed14, v0.10.2]
     # Note (yikun): If CI resource are limited we can split job into two chain jobs
     needs: [lint, changes]
     # only trigger e2e test after lint passed and the change is e2e related with pull request.

.github/workflows/vllm_ascend_test_full.yaml

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ jobs:
     name: e2e-full
     strategy:
       matrix:
-        vllm_version: [c60e6137f0bf2034853919b3a9d705d7e06b93cf, v0.10.2]
+        vllm_version: [9607d5eb449711b349d4c2bee0a9c94afcc7ed14, v0.10.2]
     needs: [changes]
     if: ${{ needs.changes.outputs.e2e_tracker == 'true' }}
     uses: ./.github/workflows/_e2e_test.yaml

vllm_ascend/worker/model_runner_v1.py

Lines changed: 96 additions & 30 deletions
@@ -27,7 +27,8 @@
 from copy import deepcopy
 from dataclasses import dataclass
 from multiprocessing import Manager
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
+from typing import (TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional,
+                    Union, cast)
 
 import numpy as np
 import numpy.typing as npt
@@ -72,8 +73,12 @@
 from vllm.v1.attention.backends.utils import \
     reorder_batch_to_split_decodes_and_prefills
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
-                                        KVCacheConfig, KVCacheSpec, MambaSpec)
+                                        KVCacheConfig, KVCacheGroupSpec,
+                                        KVCacheSpec, MambaSpec)
+# yapf: enable
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                              DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
 from vllm.v1.pool.metadata import PoolingMetadata
@@ -134,6 +139,11 @@
 else:
     ACL_FORMAT = ACL_FORMAT_FRACTAL_ND
 
+if not vllm_version_is("0.10.2"):
+    from vllm.v1.kv_cache_interface import UniformTypeKVCacheSpecs
+else:
+    UniformTypeKVCacheSpecs = None
+
 
 @dataclass
 class GraphCaptureContext:
@@ -2584,10 +2594,13 @@ def initialize_kv_cache_tensors_deepseek(
             kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
 
         kv_caches: Dict[str, torch.Tensor] = {}
-        for kv_cache_spec, kv_cache_group in self._kv_cache_spec_attn_group_iterator(
-        ):
-            attn_backend = kv_cache_group.backend
-            for layer_name in kv_cache_group.layer_names:
+        for group in self._kv_cache_spec_attn_group_iterator_dispatcher():
+            if vllm_version_is("0.10.2"):
+                kv_cache_spec, group = group
+            else:
+                kv_cache_spec = group.kv_cache_spec
+            attn_backend = group.backend
+            for layer_name in group.layer_names:
                 if layer_name in self.runner_only_attn_layers:
                     continue
                 tensor_size = kv_cache_sizes[layer_name]
@@ -2729,10 +2742,13 @@ def initialize_kv_cache_tensors(
         )), "Some layers are not correctly initialized"
 
         kv_caches: Dict[str, torch.Tensor] = {}
-        for kv_cache_spec, kv_cache_group in self._kv_cache_spec_attn_group_iterator(
-        ):
-            attn_backend = kv_cache_group.backend
-            for layer_name in kv_cache_group.layer_names:
+        for group in self._kv_cache_spec_attn_group_iterator_dispatcher():
+            if vllm_version_is("0.10.2"):
+                kv_cache_spec, group = group
+            else:
+                kv_cache_spec = group.kv_cache_spec
+            attn_backend = group.backend
+            for layer_name in group.layer_names:
                 if layer_name in self.runner_only_attn_layers:
                     continue
 
@@ -2829,15 +2845,6 @@ def initialize_kv_cache_tensors(
 
         return kv_caches
 
-    def _kv_cache_spec_attn_group_iterator(
-            self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]:
-        if not self.kv_cache_config.kv_cache_groups:
-            return
-        for kv_cache_spec_id, attn_groups in enumerate(self.attn_groups):
-            for attn_group in attn_groups:
-                yield self.kv_cache_config.kv_cache_groups[
-                    kv_cache_spec_id].kv_cache_spec, attn_group
-
     def may_reinitialize_input_batch(self,
                                      kv_cache_config: KVCacheConfig) -> None:
         """
@@ -2917,9 +2924,45 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
         assert len(self.attn_groups) == 0, \
             "Attention backends are already initialized"
 
+        class AttentionGroupKey(NamedTuple):
+            attn_backend: type[AttentionBackend]
+            kv_cache_spec: KVCacheSpec
+
+        def get_attn_backends_for_group(
+            kv_cache_group_spec: KVCacheGroupSpec,
+        ) -> dict[AttentionGroupKey, list[str]]:
+            layers = get_layers_from_vllm_config(
+                self.vllm_config, AttentionLayerBase,
+                kv_cache_group_spec.layer_names)
+            attn_backends = {}
+            attn_backend_layers = defaultdict(list)
+            # Dedupe based on full class name; this is a bit safer than
+            # using the class itself as the key because when we create dynamic
+            # attention backend subclasses (e.g. ChunkedLocalAttention) unless
+            # they are cached correctly, there will be different objects per
+            # layer.
+            for layer_name in kv_cache_group_spec.layer_names:
+                attn_backend = layers[layer_name].get_attn_backend()
+                full_cls_name = attn_backend.full_cls_name()
+                layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec
+                if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
+                    layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[
+                        layer_name]
+                key = (full_cls_name, layer_kv_cache_spec)
+                attn_backends[key] = AttentionGroupKey(attn_backend,
+                                                       layer_kv_cache_spec)
+                attn_backend_layers[key].append(layer_name)
+            return {
+                attn_backends[k]: v
+                for k, v in attn_backend_layers.items()
+            }
+
         def get_attn_backends_for_layers(
                 layer_names: list[str]
         ) -> dict[type[AttentionBackend], list[str]]:
+            """Get attention_backend for all attention layers
+            TODO: Only used in v0.10.2, drop me when 0.10.2 is dropped
+            """
             layers = get_layers_from_vllm_config(self.vllm_config,
                                                  AttentionLayerBase,
                                                  layer_names)
@@ -2960,10 +3003,10 @@ def create_attn_groups_v0102(
 
         def create_attn_groups(
             attn_backends_map: dict[AttentionBackend, list[str]],
-            kv_cache_spec: KVCacheSpec,
        ) -> list[AttentionGroup]:
             attn_groups: list[AttentionGroup] = []
-            for attn_backend, layer_names in attn_backends_map.items():
+            for (attn_backend,
+                 kv_cache_spec), layer_names in attn_backends_map.items():
                 attn_metadata_builders = []
                 attn_metadata_builders.append(attn_backend.get_builder_cls()(
                     kv_cache_spec,
@@ -2973,27 +3016,50 @@
                 ))
                 attn_group = AttentionGroup(attn_backend,
                                             attn_metadata_builders,
-                                            layer_names)
+                                            layer_names, kv_cache_spec)
                 attn_groups.append(attn_group)
             return attn_groups
 
-        for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
-            kv_cache_spec = kv_cache_group_spec.kv_cache_spec
-            attn_backends = get_attn_backends_for_layers(
-                kv_cache_group_spec.layer_names)
-            if vllm_version_is("0.10.2"):
+        if vllm_version_is("0.10.2"):
+            for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
+                kv_cache_spec = kv_cache_group_spec.kv_cache_spec
+                attn_backends = get_attn_backends_for_layers(
+                    kv_cache_group_spec.layer_names)
                 self.attn_groups.append(
                     create_attn_groups_v0102(attn_backends, kv_cache_spec))
-            else:
-                self.attn_groups.append(
-                    create_attn_groups(attn_backends, kv_cache_spec))
+        else:
+            for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
+                attn_backends = get_attn_backends_for_group(  # type: ignore
+                    kv_cache_group_spec)
+                self.attn_groups.append(create_attn_groups(attn_backends))
 
         # Calculate reorder batch threshold (if needed)
         self.calculate_reorder_batch_threshold()
 
     def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
         return itertools.chain.from_iterable(self.attn_groups)
 
+    def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]:
+        if not self.kv_cache_config.kv_cache_groups:
+            return
+        for attn_groups in self.attn_groups:
+            yield from attn_groups
+
+    def _kv_cache_spec_attn_group_iterator_v0102(
+            self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]:
+        if not self.kv_cache_config.kv_cache_groups:
+            return
+        for kv_cache_spec_id, attn_groups in enumerate(self.attn_groups):
+            for attn_group in attn_groups:
+                yield self.kv_cache_config.kv_cache_groups[
+                    kv_cache_spec_id].kv_cache_spec, attn_group
+
+    def _kv_cache_spec_attn_group_iterator_dispatcher(self):
+        if vllm_version_is("0.10.2"):
+            return self._kv_cache_spec_attn_group_iterator_v0102()
+        else:
+            return self._kv_cache_spec_attn_group_iterator()
+
     def calculate_reorder_batch_threshold(self) -> None:
         """
         Check that if any backends reorder batches; that the reordering
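
For readers tracing the iterator change above: callers now go through `_kv_cache_spec_attn_group_iterator_dispatcher()`, which yields `(kv_cache_spec, attn_group)` tuples on vLLM v0.10.2 but plain attention groups (each carrying its own `kv_cache_spec`) on vLLM main. A simplified sketch of that dual-shape consumption, using toy stand-ins rather than the real `AttentionGroup` and `vllm_version_is()` API:

```python
from dataclasses import dataclass
from typing import Iterator, List, Tuple, Union

# Toy stand-in; the real code checks vllm_version_is("0.10.2") at runtime.
RUNNING_V0102 = False


@dataclass
class Group:
    backend: str
    layer_names: List[str]
    kv_cache_spec: str


def iterate(groups: List[Group]) -> Iterator[Union[Group, Tuple[str, Group]]]:
    # On 0.10.2 the spec lives on the KV cache group config, so the dispatcher
    # yields it next to the group; on main the group already carries its spec.
    for group in groups:
        yield (group.kv_cache_spec, group) if RUNNING_V0102 else group


def consume(groups: List[Group]) -> None:
    for item in iterate(groups):
        if RUNNING_V0102:
            kv_cache_spec, group = item          # old tuple shape
        else:
            group = item
            kv_cache_spec = group.kv_cache_spec  # new: spec lives on the group
        print(group.backend, kv_cache_spec, group.layer_names)


consume([Group("ascend_attn", ["model.layers.0.attn"], "full_attention_spec")])
```

The dispatcher keeps both code paths alive until v0.10.2 support is dropped, as the TODO in the diff notes.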
