17 changes: 3 additions & 14 deletions tpu_inference/runner/speculative_decoding_manager.py
@@ -11,7 +11,6 @@

 from tpu_inference.runner import utils as runner_utils
 from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
-from tpu_inference.utils import device_array

 if TYPE_CHECKING:
     from tpu_inference.layers.common.attention_metadata import \
@@ -109,8 +108,7 @@ def propose_eagle3_draft_token_ids(
         assert pad_len >= 0
         next_token_ids += [0] * pad_len

-        next_token_ids = device_array(
-            self.runner.mesh, np.array(next_token_ids, dtype=jnp.int32))
+        next_token_ids = np.array(next_token_ids, dtype=jnp.int32)

         if spec_decode_metadata is None:
             num_rejected_tokens = None
@@ -123,9 +121,8 @@

             pad_len = self.runner.max_num_reqs - len(num_rejected_tokens)
             num_rejected_tokens += [0] * pad_len
-            num_rejected_tokens = device_array(
-                self.runner.mesh, np.array(num_rejected_tokens,
-                                           dtype=jnp.int32))
+            num_rejected_tokens = np.array(num_rejected_tokens,
+                                           dtype=jnp.int32)

         target_hidden_states, input_ids, last_token_indices, attn_metadata = self.runner.drafter.prepare_inputs(
             attn_metadata,
@@ -228,14 +225,6 @@ def get_spec_decode_metadata(
         ])

         padded_num_draft_tokens_cpu = padded_num_draft_tokens
-        # CPU -> TPU copy.
-        (padded_num_draft_tokens, padded_draft_token_ids,
-         padded_logits_indices, padded_target_logits_indices,
-         padded_bonus_logits_indices) = device_array(
-             self.runner.mesh,
-             (padded_num_draft_tokens, padded_draft_token_ids,
-              padded_logits_indices, padded_target_logits_indices,
-              padded_bonus_logits_indices))

         metadata = SpecDecodeMetadata(
             draft_token_ids=padded_draft_token_ids,
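Net effect in this file: `propose_eagle3_draft_token_ids` and `get_spec_decode_metadata` now build padded host-side numpy arrays and leave the host-to-device copy to a single batched transfer later in the runner. A minimal sketch of that pattern, assuming `device_array` behaves roughly like a mesh-aware `jax.device_put`; `pad_to_int32` and the literal values below are illustrative, not the repo's API:

```python
import jax
import numpy as np


def pad_to_int32(values, target_len, pad_value=0):
    """Pad a Python list to a fixed length and return an int32 numpy array."""
    assert target_len >= len(values)
    return np.array(values + [pad_value] * (target_len - len(values)),
                    dtype=np.int32)


def device_array(mesh, arrays):
    """Assumed stand-in: move an array or a tuple of arrays to device in one call."""
    del mesh  # a real implementation would shard according to the mesh
    return jax.device_put(arrays)


# Build padded host arrays first; transfer them together in one call later.
next_token_ids = pad_to_int32([11, 42, 7], target_len=8)
num_rejected_tokens = pad_to_int32([0, 1, 0], target_len=8)
next_token_ids, num_rejected_tokens = device_array(
    None, (next_token_ids, num_rejected_tokens))
```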
26 changes: 21 additions & 5 deletions tpu_inference/runner/tpu_jax_runner.py
@@ -683,7 +683,6 @@ def _prepare_inputs(self, scheduler_output: "VllmSchedulerOutput"):
             spec_decode_metadata = self.speculative_decoding_manager.get_spec_decode_metadata(
                 num_draft_tokens, self.query_start_loc_cpu[1:num_reqs + 1],
                 padded_num_reqs)
-            logits_indices = spec_decode_metadata.final_logits_indices

         # Put to device
         sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
@@ -696,10 +695,27 @@

         query_start_loc_cpu = query_start_loc
         seq_lens_cpu = seq_lens
-        (input_ids, positions, block_tables, query_start_loc, seq_lens,
-         logits_indices, request_distribution) = device_array(
-             self.mesh, (input_ids, positions, block_tables, query_start_loc,
-                         seq_lens, logits_indices, request_distribution))
+        if not spec_decode_metadata:
+            (input_ids, positions, block_tables, query_start_loc, seq_lens,
+             logits_indices, request_distribution) = device_array(
+                 self.mesh,
+                 (input_ids, positions, block_tables, query_start_loc,
+                  seq_lens, logits_indices, request_distribution))
+        else:
+            (input_ids, positions, block_tables, query_start_loc, seq_lens,
+             request_distribution, spec_decode_metadata.draft_token_ids,
+             spec_decode_metadata.draft_lengths,
+             spec_decode_metadata.target_logits_indices,
+             spec_decode_metadata.bonus_logits_indices,
+             spec_decode_metadata.final_logits_indices) = device_array(
+                 self.mesh, (input_ids, positions, block_tables,
+                             query_start_loc, seq_lens, request_distribution,
+                             spec_decode_metadata.draft_token_ids,
+                             spec_decode_metadata.draft_lengths,
+                             spec_decode_metadata.target_logits_indices,
+                             spec_decode_metadata.bonus_logits_indices,
+                             spec_decode_metadata.final_logits_indices))
+            logits_indices = spec_decode_metadata.final_logits_indices

         if self.lora_config is not None:
             self.lora_utils.set_active_loras(
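The runner-side counterpart: `_prepare_inputs` now folds the spec-decode metadata arrays into the same `device_array` call as the regular inputs and reads `logits_indices` off the already-transferred metadata, instead of issuing a separate copy per tensor. A toy sketch of that batching, with `jax.device_put` standing in for the mesh-aware `device_array` and a cut-down metadata class carrying only two of the fields; all names here are illustrative:

```python
from dataclasses import dataclass
from typing import Optional

import jax
import numpy as np


@dataclass
class ToySpecDecodeMetadata:
    draft_token_ids: np.ndarray
    final_logits_indices: np.ndarray


def put_prepared_inputs(input_ids, positions,
                        metadata: Optional[ToySpecDecodeMetadata] = None):
    """One host -> device transfer covering the base inputs plus any metadata arrays."""
    if metadata is None:
        return jax.device_put((input_ids, positions)), None
    # Batch every array into a single transfer, then rebind the metadata fields.
    (input_ids, positions, metadata.draft_token_ids,
     metadata.final_logits_indices) = jax.device_put(
         (input_ids, positions, metadata.draft_token_ids,
          metadata.final_logits_indices))
    return (input_ids, positions), metadata
```

Batching the transfer trades one dispatch per tensor for the bulkier unpacking seen in the else branch above.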
22 changes: 10 additions & 12 deletions tpu_inference/spec_decode/jax/eagle3.py
@@ -142,8 +142,8 @@ def prepare_inputs(
         attn_metadata: AttentionMetadata,
         input_ids: jax.Array,
         aux_hidden_states: tuple[jax.Array, ...],
-        next_token_ids: jax.Array,
-        num_rejected_tokens: Optional[jax.Array] = None,
+        next_token_ids: np.array,
+        num_rejected_tokens: Optional[np.array] = None,
     ) -> tuple[jax.Array, jax.Array, jax.Array, AttentionMetadata]:
         """Prepare drafter inputs based on target forward outputs.

@@ -169,8 +169,9 @@
         num_reqs = self.runner.input_batch.num_reqs

         if num_rejected_tokens is None:
-            num_reqs = device_array(self.mesh,
-                                    np.asarray([num_reqs], dtype=jnp.int32))
+            num_reqs, next_token_ids = device_array(
+                self.mesh,
+                (np.asarray([num_reqs], dtype=jnp.int32), next_token_ids))
             # block_tables = device_array(self.mesh, block_tables)
             attn_metadata = replace(attn_metadata,
                                     block_tables=device_array(
@@ -185,10 +186,6 @@
         seq_lens_cpu = attn_metadata.seq_lens_cpu
         assert query_start_loc_cpu is not None and seq_lens_cpu is not None

-        # Rejection-aware path: compute new per-request lengths and token indices.
-        # Convert to host numpy for efficient prefix-sum and repeat ops.
-        nrt_cpu = jax.device_get(num_rejected_tokens).astype("int32")
-
         # query_len_per_req = [q1, q2, ...]
         query_len_per_req = (query_start_loc_cpu[1:] -
                              query_start_loc_cpu[:-1])
@@ -197,7 +194,7 @@
         # For padded requests, the query length should be 0.
         query_len_per_req[num_reqs:] = 1
         # num_tokens_per_req = [q1 - n1, q2 - n2, ...]
-        num_tokens_per_req = (query_len_per_req - nrt_cpu)
+        num_tokens_per_req = (query_len_per_req - num_rejected_tokens)
         assert (num_tokens_per_req
                 >= 0).all(), ("num_tokens_per_req must be non-negative")

@@ -232,12 +229,13 @@
                                          "constant",
                                          constant_values=0)
         # Update seq_lens for active requests only: new_seq_lens = s - n.
-        new_seq_lens_cpu = seq_lens_cpu - nrt_cpu
+        new_seq_lens_cpu = seq_lens_cpu - num_rejected_tokens

-        query_start_loc, seq_lens, token_indices, num_reqs, block_tables = device_array(
+        query_start_loc, seq_lens, token_indices, num_reqs, block_tables, next_token_ids = device_array(
             self.mesh,
             (new_query_start_loc_cpu, new_seq_lens_cpu, token_indices_cpu,
-             np.asarray([num_reqs], dtype=jnp.int32), block_tables))
+             np.asarray([num_reqs],
+                        dtype=jnp.int32), block_tables, next_token_ids))

         attn_metadata = replace(attn_metadata, block_tables=block_tables)
         return self._filter_token_and_prepare_initial_inputs(
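Because `num_rejected_tokens` now arrives in `prepare_inputs` as a host numpy array, the rejection-aware bookkeeping above stays in plain numpy with no `jax.device_get` round-trip. A self-contained sketch of that arithmetic; the function name and sample numbers are illustrative only:

```python
import numpy as np


def rejection_aware_lengths(query_start_loc_cpu: np.ndarray,
                            seq_lens_cpu: np.ndarray,
                            num_rejected_tokens: np.ndarray):
    """Per-request kept-token counts, new cumulative offsets, and new seq_lens."""
    query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    num_tokens_per_req = query_len_per_req - num_rejected_tokens
    assert (num_tokens_per_req >= 0).all()
    new_query_start_loc = np.zeros_like(query_start_loc_cpu)
    new_query_start_loc[1:] = np.cumsum(num_tokens_per_req)
    new_seq_lens = seq_lens_cpu - num_rejected_tokens
    return num_tokens_per_req, new_query_start_loc, new_seq_lens


# Three requests with query lengths [4, 3, 5] and [1, 0, 2] rejected tokens:
qsl = np.array([0, 4, 7, 12], dtype=np.int32)
lens = np.array([10, 6, 9], dtype=np.int32)
rej = np.array([1, 0, 2], dtype=np.int32)
kept, new_qsl, new_lens = rejection_aware_lengths(qsl, lens, rej)
# kept == [3, 3, 3], new_qsl == [0, 3, 6, 9], new_lens == [9, 6, 7]
```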