Skip to content

Commit d2c305b

Browse files
authored
[CB] refactoring spyre model runner (#172)
This PR does some refactoring primarily on spyre_model_runner. This changes tries to reduce code deduplication between static batching and continuous batching. However, the intention of this work will not be complete until a next PR has as goal remove kv cache manager from the spyre model runner. Summary of changes: - Reduce code deduplication in spyre model runner, some methods are common in `SpyreMoldeRunner` class, while `StaticBatchingSpyreModelRunner` and `ContinuousBatchingSpyreModelRunner` override few of them to do their specific logic - Changed `ContinuousBatchingFmsModel` class to get the attention metadata via forward context, and changed the model runner to pass to use the `with set_forward_context` to pass the attention metadata. This is the way vLLM does to support multiple attention backends [[REF](vllm-project/vllm#10558)] - Moved the left pads to the CachedRequestState. - Bugfix: The `execute_model` in CB model runner was inconsistent with the data of input batch when it outputs the resul in `CBSpyreModelRunnerOutput`. Changed it with prepare_prompt to use the data of input batch. - Misc: few renamed variables, more comments, and TODOs --------- Signed-off-by: Wallas Santos <wallashss@ibm.com>
1 parent 94cee66 commit d2c305b

File tree

6 files changed

+276
-299
lines changed

6 files changed

+276
-299
lines changed

tests/e2e/test_spyre_basic.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,14 +128,11 @@ def test_output_sendnn_decoder(
128128
@pytest.mark.parametrize("cb",
129129
[pytest.param(1, marks=pytest.mark.cb, id="cb"), 0])
130130
def test_batch_handling(model: str, backend: str, cb: int,
131-
monkeypatch: pytest.MonkeyPatch, runtime_xfail):
131+
monkeypatch: pytest.MonkeyPatch):
132132
"""Test that the spyre worker correctly handles
133133
continuous batches of requests that
134134
finish after different numbers of forward passes"""
135135

136-
if cb == 1:
137-
runtime_xfail("Batch handling bug with continuous batching")
138-
139136
prompts = get_chicken_soup_prompts(4)
140137

141138
sampling_params1 = SamplingParams(max_tokens=5,

tests/v1/worker/test_spyre_input_batch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ def _construct_cached_request_state(req_id_suffix: int):
163163
sampling_params=_create_sampling_params(),
164164
generator=None,
165165
output_token_ids=output_token_ids,
166+
left_padding=0,
166167
)
167168

168169

vllm_spyre/model_executor/model_loader/spyre.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Utilities for selecting and loading Spyre models."""
22
import os
3-
from typing import Any, Optional
3+
from dataclasses import dataclass
4+
from typing import Any, Optional, cast
45

56
import torch
67
import torch._inductor.config
@@ -9,6 +10,7 @@
910
from fms.models import get_model
1011
from transformers import PretrainedConfig
1112
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
13+
from vllm.forward_context import get_forward_context
1214
from vllm.logger import init_logger
1315
from vllm.model_executor.layers.logits_processor import LogitsProcessor
1416
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -29,6 +31,14 @@
2931
logger = init_logger(__name__)
3032

3133

34+
@dataclass
35+
class SpyreAttentionMetadata:
36+
slot_mapping: torch.Tensor = None
37+
current_tkv_mask: torch.Tensor = None
38+
left_padded_prompt_mask: torch.Tensor = None
39+
block_table: torch.Tensor = None
40+
41+
3242
class SpyreCausalLM(nn.Module):
3343

3444
def __init__(
@@ -73,10 +83,6 @@ def forward(
7383
positions: torch.Tensor,
7484
masks: torch.Tensor,
7585
is_prompt: bool,
76-
current_tkv_mask: Optional[torch.Tensor] = None,
77-
left_padded_prompt_mask: Optional[torch.Tensor] = None,
78-
block_table: Optional[torch.Tensor] = None,
79-
slot_mapping: Optional[torch.Tensor] = None,
8086
) -> torch.Tensor:
8187

8288
if is_prompt and not envs_spyre.VLLM_SPYRE_USE_CB:
@@ -88,12 +94,6 @@ def forward(
8894
# cpu impl when padding too much
8995
extra_kwargs["attn_algorithm"] = "math"
9096

91-
if envs_spyre.VLLM_SPYRE_USE_CB:
92-
extra_kwargs["current_tkv_mask"] = current_tkv_mask
93-
extra_kwargs["left_padded_prompt_mask"] = left_padded_prompt_mask
94-
extra_kwargs["block_table"] = block_table
95-
extra_kwargs["slot_mapping"] = slot_mapping
96-
9797
# normal prefill or decoding step
9898
logits = self.model(
9999
input_ids,
@@ -353,32 +353,30 @@ def forward(
353353
mask: torch.Tensor,
354354
use_cache: bool,
355355
only_last_token: bool,
356-
current_tkv_mask: torch.Tensor,
357-
left_padded_prompt_mask: torch.Tensor,
358-
block_table: torch.Tensor,
359-
slot_mapping: torch.Tensor,
360356
**extra_kwargs,
361357
) -> torch.Tensor:
362358

359+
forward_context = get_forward_context()
360+
361+
attn_metadata = cast(SpyreAttentionMetadata,
362+
forward_context.attn_metadata)
363363
# import will be not be needed/ handled by FMS soon
364364
import fms.utils.spyre.paged # noqa # pylint: disable=unused-import
365365

366366
# specify attention type for continuous batching
367367
extra_kwargs['attn_name'] = "spyre_paged_attn"
368368

369-
# additional (paged) attention arguments
370-
extra_kwargs['current_tkv_mask'] = current_tkv_mask
371-
extra_kwargs['left_padded_prompt_mask'] = left_padded_prompt_mask
372-
extra_kwargs['block_table'] = block_table
373-
extra_kwargs['slot_mapping'] = slot_mapping
374-
375369
output = self.model(
376370
input_ids,
377371
position_ids=position_ids,
378372
mask=mask,
379373
past_key_value_states=self.past_key_value_states,
380374
use_cache=use_cache,
381375
only_last_token=only_last_token,
376+
current_tkv_mask=attn_metadata.current_tkv_mask,
377+
left_padded_prompt_mask=attn_metadata.left_padded_prompt_mask,
378+
block_table=attn_metadata.block_table,
379+
slot_mapping=attn_metadata.slot_mapping,
382380
**extra_kwargs,
383381
)
384382

vllm_spyre/v1/worker/spyre_input_batch.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ class CachedRequestState:
2323
generator: Optional[torch.Generator]
2424

2525
output_token_ids: list[int]
26+
left_padding: int = 0 # Defaults to 0, i. e. not padding
2627

2728
@property
2829
def num_tokens(self) -> int:
@@ -565,3 +566,8 @@ def no_allowed_token_ids(self) -> bool:
565566
@property
566567
def requests_ids(self) -> list[str]:
567568
return list(self.req_id_to_index.keys())
569+
570+
@property
571+
def sorted_requests_ids(self) -> list[str]:
572+
return sorted(self.req_id_to_index,
573+
key=self.req_id_to_index.get) # type: ignore

0 commit comments

Comments
 (0)