
Commit 2bba577

py4 and Pooya Moradi authored
[Runner] Separate execute_model and sample_tokens to adapt upstream change. (#1003)
Signed-off-by: Pooya Moradi <pooyam@google.com>
Co-authored-by: Pooya Moradi <pooyam@google.com>
1 parent 6c68c29 commit 2bba577

File tree

3 files changed: +139 -61 lines changed

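Before the per-file diffs, the shape of the change: the commit splits the runner's single-call flow into a two-phase contract. execute_model() runs the forward pass, caches intermediate state in an ExecuteModelState, and returns None; a follow-up sample_tokens() consumes that state (optionally applying a grammar bitmask) and produces the ModelRunnerOutput. The sketch below illustrates just that handoff pattern; the method names mirror the diff, but the bodies are simplified stand-ins, not the actual runner code.

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class _State:
    # Stand-in for ExecuteModelState: whatever the sampling phase
    # needs to carry over from the forward pass.
    logits: Any


class TwoPhaseRunner:
    """Illustrative two-phase runner: forward pass, then sampling."""

    def __init__(self) -> None:
        self._state: Optional[_State] = None

    def execute_model(self, scheduler_output: Any) -> None:
        # Phase 1: run the model and cache what sampling will need.
        if self._state is not None:
            raise RuntimeError("sample_tokens() must be called first.")
        self._state = _State(logits=f"logits({scheduler_output})")
        return None  # Tells the caller to follow up with sample_tokens().

    def sample_tokens(self, grammar_output: Optional[Any]) -> str:
        # Phase 2: optionally mask the logits, then sample.
        if self._state is None:
            return "EMPTY_OUTPUT"  # e.g. nothing was scheduled on this rank
        state, self._state = self._state, None
        logits = state.logits
        if grammar_output is not None:
            logits = f"masked({logits})"
        return f"sampled({logits})"


runner = TwoPhaseRunner()
assert runner.execute_model("batch-0") is None
print(runner.sample_tokens(grammar_output=None))  # sampled(logits(batch-0))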

tpu_inference/runner/structured_decoding_manager.py

Lines changed: 4 additions & 5 deletions
@@ -7,8 +7,7 @@
 from tpu_inference.utils import device_array
 
 if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import \
-        SchedulerOutput as VllmSchedulerOutput
+    from vllm.v1.core.sched.output import GrammarOutput
 
     from tpu_inference.runner.tpu_jax_runner import TPUModelRunner
 
@@ -51,9 +50,9 @@ def _apply_grammar_bitmask_kernel(self, logits: jax.Array,
         return jnp.where(require_struct_decoding, masked_logits, logits)
 
     def prepare_structured_decoding_input(
-            self, logits: jax.Array, scheduler_output: "VllmSchedulerOutput"
+            self, logits: jax.Array, grammar_output: "GrammarOutput"
     ) -> Tuple[jax.Array, jax.Array, jax.Array]:
-        grammar_bitmask = scheduler_output.grammar_bitmask
+        grammar_bitmask = grammar_output.grammar_bitmask
         assert grammar_bitmask is not None
         num_reqs, _ = logits.shape
 
@@ -62,7 +61,7 @@ def prepare_structured_decoding_input(
         self.runner.require_structured_out_cpu.fill(0)
 
         sorted_struct_requests = sorted(
-            scheduler_output.structured_output_request_ids.items(),
+            grammar_output.structured_output_request_ids.items(),
             key=lambda item: item[1])
 
         cumulative_mask_idx = 0
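The net effect of this file's change is a narrower dependency: prepare_structured_decoding_input now needs only a GrammarOutput rather than the whole SchedulerOutput, and per the diff it reads exactly two fields, grammar_bitmask and structured_output_request_ids. A minimal sketch of the request-ordering step, using a hypothetical stand-in for GrammarOutput (only those two attributes are assumed):

from dataclasses import dataclass, field


@dataclass
class FakeGrammarOutput:
    # Hypothetical stand-in exposing only the two fields the manager reads.
    grammar_bitmask: object = None
    structured_output_request_ids: dict = field(default_factory=dict)


out = FakeGrammarOutput(
    grammar_bitmask=object(),
    structured_output_request_ids={"req-b": 2, "req-a": 0},
)

# As in the diff: sort structured requests by batch index (the dict
# values) so bitmask rows can be consumed in batch order.
sorted_struct_requests = sorted(
    out.structured_output_request_ids.items(), key=lambda item: item[1])
print(sorted_struct_requests)  # [('req-a', 0), ('req-b', 2)]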

tpu_inference/runner/tpu_jax_runner.py

Lines changed: 129 additions & 55 deletions
@@ -21,10 +21,12 @@
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import SupportedTask
 from vllm.utils.math_utils import cdiv
+from vllm.v1.core.sched.output import GrammarOutput
 from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
-                             DraftTokenIds, ModelRunnerOutput)
+                             DraftTokenIds, KVConnectorOutput,
+                             ModelRunnerOutput)
 from vllm.v1.request import Request
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.worker.kv_connector_model_runner_mixin import \
@@ -51,8 +53,8 @@
 from tpu_inference.runner.multimodal_manager import MultiModalManager
 from tpu_inference.runner.persistent_batch_manager import \
     PersistentBatchManager
-from tpu_inference.runner.speculative_decoding_manager import \
-    SpeculativeDecodingManager
+from tpu_inference.runner.speculative_decoding_manager import (
+    SpecDecodeMetadata, SpeculativeDecodingManager)
 from tpu_inference.runner.structured_decoding_manager import \
     StructuredDecodingManager
 from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
@@ -126,6 +128,21 @@ class AsyncPreResults:
     placeholder_req_id_to_index: dict[str, int]
 
 
+@dataclass
+class ExecuteModelState:
+    """Ephemeral cached state transferred between execute_model() and
+    sample_tokens(), after execute_model() returns None."""
+
+    scheduler_output: "VllmSchedulerOutput"
+    attn_metadata: AttentionMetadata
+    input_ids: Optional[jax.Array]
+    hidden_states: jax.Array
+    logits: jax.Array
+    aux_hidden_states: Optional[jax.Array]
+    spec_decode_metadata: Optional[SpecDecodeMetadata]
+    kv_connector_output: Optional[KVConnectorOutput]
+
+
 @functools.partial(jax.jit, donate_argnums=(0, 1, 2))
 def _substitute_placeholder_token(
     input_ids: jax.Array, token_in_tpu_cur_input_indices: jax.Array,
@@ -215,6 +232,7 @@ def __init__(
 
         self._pre_async_results: AsyncPreResults | None = None
         self._substitute_placeholder_token_fn = _substitute_placeholder_token
+        self.execute_model_state: ExecuteModelState | None = None
 
     def _init_random(self):
         if self.model_config.seed is None:
@@ -430,9 +448,49 @@ def execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> ModelRunnerOutput | None:
+        if self.execute_model_state is not None:
+            raise RuntimeError("State error: sample_tokens() must be called "
+                               "after execute_model() returns None.")
+        _, output = self._execute_model(scheduler_output)
+        return output
+
+    def sample_tokens(
+        self,
+        grammar_output: "GrammarOutput | None",
     ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
-
-        return self._execute_model(scheduler_output)[1]
+        if self.execute_model_state is None:
+            # This can happen in pipeline parallel case.
+            return EMPTY_MODEL_RUNNER_OUTPUT
+
+        (scheduler_output, attn_metadata, input_ids, hidden_states, logits,
+         aux_hidden_states, spec_decode_metadata,
+         kv_connector_output) = (self.execute_model_state.scheduler_output,
+                                 self.execute_model_state.attn_metadata,
+                                 self.execute_model_state.input_ids,
+                                 self.execute_model_state.hidden_states,
+                                 self.execute_model_state.logits,
+                                 self.execute_model_state.aux_hidden_states,
+                                 self.execute_model_state.spec_decode_metadata,
+                                 self.execute_model_state.kv_connector_output)
+        self.execute_model_state = None
+
+        if grammar_output is not None:
+            (
+                require_struct_decoding, grammar_bitmask_padded, arange
+            ) = self.structured_decoding_manager.prepare_structured_decoding_input(
+                logits, grammar_output)
+            logits = self.structured_decoding_manager.structured_decode_fn(
+                require_struct_decoding,
+                grammar_bitmask_padded,
+                logits,
+                arange,
+            )
+        return self._sample_from_logits(scheduler_output, attn_metadata,
+                                        input_ids, hidden_states, logits,
+                                        aux_hidden_states,
+                                        spec_decode_metadata,
+                                        kv_connector_output)
 
     def _modify_prev_results(self):
         # If copy to host has not been done, we just wait.
@@ -510,8 +568,7 @@ def _update_placeholder(self, discard_sampled_tokens_req_indices,
     def _execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
-    ) -> tuple[AttentionMetadata, ModelRunnerOutput
-               | AsyncTPUModelRunnerOutput]:
+    ) -> tuple[AttentionMetadata, ModelRunnerOutput | None]:
         self.persistent_batch_manager.update_states(
             scheduler_output, self.get_mrope_input_positions_fn)
         if not scheduler_output.total_num_scheduled_tokens:
@@ -529,9 +586,10 @@ def _execute_model(
                     "Should not schedule a request that does nothing!")
             # raise Exception(
             #     "Should not schedule a request that does nothing!")
-            return DUMMY_METADATA, EMPTY_MODEL_RUNNER_OUTPUT,
+            return DUMMY_METADATA, EMPTY_MODEL_RUNNER_OUTPUT
 
-        (input_ids, attn_metadata, sampling_metadata, logits_indices,
+        # TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b
+        (input_ids, attn_metadata, _, logits_indices,
         spec_decode_metadata) = self._prepare_inputs(scheduler_output)
 
         # multi-modal support
@@ -584,51 +642,67 @@ def _execute_model(
             hidden_states,
             lora_metadata,
         )
-        if scheduler_output.grammar_bitmask is not None:
-            (
-                require_struct_decoding, grammar_bitmask_padded, arange
-            ) = self.structured_decoding_manager.prepare_structured_decoding_input(
-                logits, scheduler_output)
-            logits = self.structured_decoding_manager.structured_decode_fn(
-                require_struct_decoding,
-                grammar_bitmask_padded,
-                logits,
-                arange,
-            )
-        tpu_sampling_metadata = sampling_metadata
-        if spec_decode_metadata is None:
-            next_tokens = sample(
-                self.rng_params_for_sampling,
-                self.mesh,
-                logits,
-                tpu_sampling_metadata,
-            )
-        else:
-            bonus_logits = self._select_from_array_fn(
-                logits, spec_decode_metadata.bonus_logits_indices)
-            bonus_token_ids = sample(
-                self.rng_params_for_sampling,
-                self.mesh,
-                bonus_logits,
-                tpu_sampling_metadata,
-            )
-            target_logits = self._select_from_array_fn(
-                logits, spec_decode_metadata.target_logits_indices)
-            next_tokens = self.rejection_sampler(
-                draft_token_ids=spec_decode_metadata.draft_token_ids,
-                num_draft_tokens=spec_decode_metadata.draft_lengths,
-                draft_probs=None,
-                target_logits=target_logits,
-                bonus_token_ids=bonus_token_ids,
-                sampling_metadata=tpu_sampling_metadata,
-                key=self.rng_params_for_sampling,
-            )
 
-        if tpu_sampling_metadata.logprobs:
-            logprobs = self._compute_and_gather_logprobs(
-                logits, next_tokens, self.model_config.max_logprobs)
-        else:
-            logprobs = None
+        self.execute_model_state = ExecuteModelState(
+            scheduler_output=scheduler_output,
+            attn_metadata=attn_metadata,
+            input_ids=input_ids,
+            hidden_states=hidden_states,
+            logits=logits,
+            aux_hidden_states=aux_hidden_states,
+            spec_decode_metadata=spec_decode_metadata,
+            kv_connector_output=kv_connector_output,
+        )
+        return attn_metadata, None
+
+    def _sample_from_logits(
+        self,
+        scheduler_output: "VllmSchedulerOutput",
+        attn_metadata: AttentionMetadata,
+        input_ids: Optional[jax.Array],
+        hidden_states: jax.Array,
+        logits: jax.Array,
+        aux_hidden_states: Optional[jax.Array],
+        spec_decode_metadata: Optional[SpecDecodeMetadata],
+        kv_connector_output: Optional[KVConnectorOutput],
+    ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
+        padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
+            self.input_batch.num_reqs, self.max_num_reqs)
+        tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
+            self.mesh, self.input_batch, padded_num_reqs)
+        if spec_decode_metadata is None:
+            next_tokens = sample(
+                self.rng_params_for_sampling,
+                self.mesh,
+                logits,
+                tpu_sampling_metadata,
+            )
+        else:
+            bonus_logits = self._select_from_array_fn(
+                logits, spec_decode_metadata.bonus_logits_indices)
+            bonus_token_ids = sample(
+                self.rng_params_for_sampling,
+                self.mesh,
+                bonus_logits,
+                tpu_sampling_metadata,
+            )
+            target_logits = self._select_from_array_fn(
+                logits, spec_decode_metadata.target_logits_indices)
+            next_tokens = self.rejection_sampler(
+                draft_token_ids=spec_decode_metadata.draft_token_ids,
+                num_draft_tokens=spec_decode_metadata.draft_lengths,
+                draft_probs=None,
+                target_logits=target_logits,
+                bonus_token_ids=bonus_token_ids,
+                sampling_metadata=tpu_sampling_metadata,
+                key=self.rng_params_for_sampling,
+            )
+
+        if tpu_sampling_metadata.logprobs:
+            logprobs = self._compute_and_gather_logprobs(
+                logits, next_tokens, self.model_config.max_logprobs)
+        else:
+            logprobs = None
 
         num_reqs = self.input_batch.num_reqs
 
@@ -707,7 +781,7 @@ def _execute_model(
             async_model_runner_output = AsyncTPUModelRunnerOutput(
                 model_runner_output, next_tokens, num_reqs,
                 discard_sampled_tokens_req_indices)
-            return attn_metadata, async_model_runner_output
+            return async_model_runner_output
 
         if spec_decode_metadata is None:
             next_tokens = np.asarray(jax.device_get(next_tokens))
@@ -766,7 +840,7 @@ def _execute_model(
             pooler_output=[],
             kv_connector_output=kv_connector_output,
         )
-        return attn_metadata, model_runner_output
+        return model_runner_output
 
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _select_from_array_fn(self, array, indices_to_select):
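One detail worth calling out in the new _sample_from_logits: in the speculative-decoding branch the flat logits are split by two index sets before sampling. bonus_logits_indices selects the rows sampled as bonus tokens, while target_logits_indices selects the rows handed to the rejection sampler together with the draft tokens. Below is a small JAX sketch of that row-gather step; the shapes and index layout are illustrative assumptions, and select_from_array only approximates the runner's _select_from_array_fn.

import jax
import jax.numpy as jnp

# Illustrative shapes: 5 scheduled token positions over an 11-token vocab.
logits = jnp.arange(5 * 11, dtype=jnp.float32).reshape(5, 11)

# Assumed layout: rows 0-3 hold logits for verifying draft tokens and
# row 4 holds the bonus-token logits.
target_logits_indices = jnp.array([0, 1, 2, 3])
bonus_logits_indices = jnp.array([4])


@jax.jit
def select_from_array(array, indices_to_select):
    # Row gather, approximating the runner's _select_from_array_fn.
    return jnp.take(array, indices_to_select, axis=0)


target_logits = select_from_array(logits, target_logits_indices)
bonus_logits = select_from_array(logits, bonus_logits_indices)
print(target_logits.shape, bonus_logits.shape)  # (4, 11) (1, 11)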

tpu_inference/worker/tpu_worker_jax.py

Lines changed: 6 additions & 1 deletion
@@ -18,7 +18,7 @@
 from vllm.tasks import SupportedTask
 from vllm.v1 import utils as vllm_utils
 from vllm.v1.core.kv_cache_utils import get_num_blocks, get_uniform_page_size
-from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 
@@ -200,11 +200,16 @@ def execute_model(
         output = self.model_runner.execute_model(scheduler_output)
 
         # With a connector, the scheduler expects output from all workers
+        # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
         if has_kv_transfer_group():
             return output
 
         return output if self.is_driver_worker else None
 
+    def sample_tokens(self,
+                      grammar_output: GrammarOutput) -> ModelRunnerOutput:
+        return self.model_runner.sample_tokens(grammar_output)
+
     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
         return self.model_runner.take_draft_token_ids()
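At the worker level the split is a thin delegation, but it changes when grammar state must be ready: grammar handling now happens in sample_tokens(), after the forward pass has already been launched by execute_model(). A toy sketch of one driver step under the new contract; FakeWorker and the call sequence are illustrative assumptions, not the upstream executor.

from typing import Any, Optional


class FakeWorker:
    # Hypothetical stand-in for the worker's new two-call surface.
    def __init__(self) -> None:
        self._pending: Optional[str] = None

    def execute_model(self, scheduler_output: str) -> None:
        self._pending = scheduler_output  # forward pass; output deferred
        return None

    def sample_tokens(self, grammar_output: Any) -> str:
        batch, self._pending = self._pending, None
        masked = grammar_output is not None
        return f"output(batch={batch}, grammar_masked={masked})"


worker = FakeWorker()
# Phase 1: launch the forward pass for this step.
assert worker.execute_model("step-0") is None
# ...GrammarOutput can be computed between the two calls...
# Phase 2: apply the (optional) grammar mask and sample.
print(worker.sample_tokens(grammar_output=None))
# output(batch=step-0, grammar_masked=False)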
