16 changes: 15 additions & 1 deletion vllm/config.py
@@ -1010,6 +1010,18 @@ def is_encoder_decoder(self) -> bool:
"""Extract the HF encoder/decoder model flag."""
return is_encoder_decoder(self.hf_config)

@property
def requires_multi_step_decode(self) -> bool:
return getattr(self.hf_config, "model_type", "") == "deepseek_mtp" and \
    getattr(self.hf_config, "num_nextn_predict_layers", 0) > 1

@property
def num_decode_modules(self) -> int:
if getattr(self.hf_config, "model_type", "") == "deepseek_mtp":
return getattr(self.hf_config, "num_nextn_predict_layers", 0)
else:
return 1

@property
def uses_mrope(self) -> bool:
return uses_mrope(self.hf_config)
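
For reference, a minimal standalone sketch of how the two properties added above evaluate; the `SimpleNamespace` stands in for `hf_config`, and the attribute values are hypothetical:

```python
# Sketch only: mirrors the property logic above against a toy config object.
from types import SimpleNamespace

hf_config = SimpleNamespace(model_type="deepseek_mtp",
                            num_nextn_predict_layers=2)  # hypothetical values

requires_multi_step_decode = (
    getattr(hf_config, "model_type", "") == "deepseek_mtp"
    and getattr(hf_config, "num_nextn_predict_layers", 0) > 1
)
num_decode_modules = (
    getattr(hf_config, "num_nextn_predict_layers", 0)
    if getattr(hf_config, "model_type", "") == "deepseek_mtp"
    else 1
)

print(requires_multi_step_decode)  # True
print(num_decode_modules)          # 2
```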
@@ -3468,7 +3480,8 @@ def _set_cudagraph_sizes(self):
# which then becomes the max_batchsize_to_capture
larger_sizes = [
x for x in possible_sizes
if x >= self.scheduler_config.max_num_seqs
if x >= self.scheduler_config.max_num_seqs *
self.model_config.num_decode_modules
]
if larger_sizes:
max_batchsize_to_capture = larger_sizes[0]
@@ -3481,6 +3494,7 @@ def _set_cudagraph_sizes(self):
size for size in possible_sizes
if size <= max_batchsize_to_capture
]
# print(f"{batch_size_capture_list=}")
else:
batch_size_capture_list = []
if self.model_config is not None and \
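
A small numeric sketch of the capture-size change above. The `possible_sizes` list here is illustrative, not vLLM's actual list, and the two config values are assumed; the point is that the threshold for the largest captured batch size is multiplied by `num_decode_modules`:

```python
# Sketch only: capture-size selection with assumed values.
possible_sizes = [1, 2, 4, 8] + [16 * i for i in range(1, 65)]  # illustrative
max_num_seqs = 256          # assumed scheduler_config.max_num_seqs
num_decode_modules = 2      # assumed model_config.num_decode_modules

larger_sizes = [x for x in possible_sizes
                if x >= max_num_seqs * num_decode_modules]
max_batchsize_to_capture = (larger_sizes[0] if larger_sizes
                            else possible_sizes[-1])
batch_size_capture_list = [s for s in possible_sizes
                           if s <= max_batchsize_to_capture]

print(max_batchsize_to_capture)  # 512 instead of 256 without the scaling
```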
2 changes: 2 additions & 0 deletions vllm/engine/output_processor/multi_step.py
@@ -185,6 +185,8 @@ def _process_seq_outputs(self, seq: Sequence,
is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0
# Incrementally append tokens to the sequence, as if we had only one new
# token.
# TODO: add a dedicated reset attribute here; it could be set by the
# output processor
seq.data.reset_new_appended_tokens()
for output_token_id, output_logprob in zip(output_token_ids,
output_logprobs):
seq.append_token_id(
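
A toy illustration of why `_new_appended_tokens` is reset before the loop above when several output tokens are appended in one call; this is a stand-in class, not vLLM's `SequenceData`:

```python
# Toy stand-in: _new_appended_tokens accumulates tokens appended since the
# last delta; resetting it first keeps only this step's tokens.
class ToySeqData:
    def __init__(self):
        self._output_token_ids = []
        self._new_appended_tokens = []

    def reset_new_appended_tokens(self):
        self._new_appended_tokens = []

    def append_token_id(self, token_id):
        self._output_token_ids.append(token_id)
        self._new_appended_tokens.append(token_id)

seq = ToySeqData()
seq.append_token_id(5)            # from an earlier step
seq.reset_new_appended_tokens()
for tok in (7, 8, 9):             # multi-step outputs appended at once
    seq.append_token_id(tok)
print(seq._new_appended_tokens)   # [7, 8, 9]
print(seq._output_token_ids)      # [5, 7, 8, 9]
```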
53 changes: 53 additions & 0 deletions vllm/model_executor/models/deepseek_mtp.py
@@ -175,6 +175,47 @@
return self.model.compute_logits(hidden_states, sampling_metadata,
spec_step_idx)

def generate_proposals(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],

Check failure (GitHub Actions / pre-commit), Ruff F821: vllm/model_executor/models/deepseek_mtp.py:182:20: Undefined name `List`
attn_metadata: AttentionMetadata,

Check failure (GitHub Actions / pre-commit), Ruff F821: vllm/model_executor/models/deepseek_mtp.py:183:24: Undefined name `AttentionMetadata`
previous_hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> List[SamplerOutput]:

Check failure (GitHub Actions / pre-commit), Ruff F821: vllm/model_executor/models/deepseek_mtp.py:186:10: Undefined name `List`
hidden_states = previous_hidden_states
cur_input_ids = input_ids
outputs = []
for i in range(self.model.num_mtp_layers):
hidden_states = self.forward(cur_input_ids,
positions,
kv_caches,
attn_metadata,
hidden_states,
spec_step_idx=i)
logits = self.compute_logits(hidden_states=hidden_states,
sampling_metadata=sampling_metadata,
spec_step_idx=i)
output = self.sample(
logits=logits,
sampling_metadata=sampling_metadata,
)
outputs.append(output)
cur_input_ids = self.get_next_layer_input(input_ids, attn_metadata,
output)
return outputs

def get_next_layer_input(
self, input_ids: torch.Tensor, attn_metadata: AttentionMetadata,

Check failure (GitHub Actions / pre-commit), Ruff F821: vllm/model_executor/models/deepseek_mtp.py:210:59: Undefined name `AttentionMetadata`
outputs: SamplerOutput) -> Tuple[torch.Tensor, SamplerOutput]:
assert outputs.sampled_token_ids is not None
assert attn_metadata.query_start_loc is not None
input_ids = input_ids.roll(shifts=-1, dims=0)
query_end_loc = attn_metadata.query_start_loc[1:] - 1
input_ids[query_end_loc] = outputs.sampled_token_ids[:, 0]
return input_ids
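
A toy demonstration of the index arithmetic in `get_next_layer_input`, assuming two sequences of query lengths 3 and 4 flattened into one batch (plain tensors here, not vLLM metadata objects):

```python
# Sketch only: roll the flattened input left by one and drop each sequence's
# newly sampled token into its last query position.
import torch

input_ids = torch.tensor([10, 11, 12, 20, 21, 22, 23])
query_start_loc = torch.tensor([0, 3, 7])         # per-sequence boundaries
sampled_token_ids = torch.tensor([[100], [200]])  # one sample per sequence

next_input = input_ids.roll(shifts=-1, dims=0)    # each slot now holds the
                                                  # following token id
query_end_loc = query_start_loc[1:] - 1           # last position per sequence
next_input[query_end_loc] = sampled_token_ids[:, 0]

print(next_input.tolist())  # [11, 12, 100, 21, 22, 23, 200]
```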

def sample(
self,
logits: torch.Tensor,
@@ -183,6 +224,18 @@
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def get_last_sample_output(
self,
output: SamplerOutput,
attn_metadata: AttentionMetadata,

Check failure (GitHub Actions / pre-commit), Ruff F821: vllm/model_executor/models/deepseek_mtp.py:230:24: Undefined name `AttentionMetadata`
) -> SamplerOutput:
query_end_loc = attn_metadata.query_start_loc[1:] - 1
output.sampled_token_ids = output.sampled_token_ids[query_end_loc]
if output.sampled_token_probs is not None:
output.sampled_token_probs = output.sampled_token_probs[
query_end_loc]
return output
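
Similarly, a small sketch of what `get_last_sample_output` keeps from a flattened batch: only the entry at each sequence's last query position, for both token ids and, when present, probabilities. Shapes are assumptions for illustration:

```python
# Sketch only: trim per-position sampler results down to one row per sequence.
import torch

query_start_loc = torch.tensor([0, 3, 7])          # 2 sequences: lengths 3, 4
sampled_token_ids = torch.arange(7).reshape(7, 1)  # one id per position
sampled_token_probs = torch.rand(7, 32)            # assumed vocab size of 32

query_end_loc = query_start_loc[1:] - 1            # tensor([2, 6])
last_ids = sampled_token_ids[query_end_loc]        # shape (2, 1)
last_probs = sampled_token_probs[query_end_loc]    # shape (2, 32)

print(last_ids.squeeze(1).tolist())  # [2, 6]
```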

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
37 changes: 30 additions & 7 deletions vllm/sequence.py
@@ -365,6 +365,9 @@ def get_delta_and_reset(self) -> SequenceDataDelta:
self._new_appended_tokens = []
return delta

def reset_new_appended_tokens(self) -> None:
self._new_appended_tokens = []

def apply_delta(self, delta: SequenceDataDelta):
self._num_computed_tokens = delta.new_num_computed_tokens
self._cumulative_logprob = delta.new_cumulative_logprob
@@ -1212,12 +1215,13 @@ class HiddenStates(msgspec.Struct, array_like=True,
# last proposed token is accepted (i.e., in case of bonus tokens). For the
# case of no bonus tokens, these are ignored.
second_last_token_hidden_states: Optional[torch.Tensor] = None

# for variable-length sequences (varseq)
hidden_states_seq_indices: Optional[torch.Tensor] = None
_seq_ids: List[int] = msgspec.field(default_factory=list)

def __post_init__(self):
if self.seq_group_metadata_list is not None:
assert len(self.seq_group_metadata_list) == len(self.hidden_states)
# TODO: add assertion for the group metadata list with var seqs
self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)

@property
@@ -1231,8 +1235,20 @@ def update(self,
"""Update hidden states from target model invocation. Only used for
decode steps"""
assert len(seq_group_metadata_list) == len(hidden_states)
self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
last_seq_indice = len(self._seq_ids)
new_seq_ids = get_all_seq_ids(seq_group_metadata_list)
self._seq_ids.extend(new_seq_ids)
self.hidden_states = torch.cat([self.hidden_states, hidden_states])
if self.hidden_states_seq_indices is not None:
updated_indices = list(range(last_seq_indice, len(self._seq_ids)))
# assume the newly appended entries are hidden states from
# prefill, which contributes one hidden state per sequence
new_seq_indices = torch.tensor(
updated_indices, device=self.hidden_states_seq_indices.device)
self.hidden_states_seq_indices = torch.concat([
self.hidden_states_seq_indices,
new_seq_indices,
])

if self.second_last_token_hidden_states is not None:
# Adding dummy hidden_states to this to maintain same shape
@@ -1255,10 +1271,17 @@ def prune(self,
if seq_ids != self._seq_ids:
# Batch contents changed - prune removed sequences.
index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
self.hidden_states = self.hidden_states[index]
if self.second_last_token_hidden_states is not None:
self.second_last_token_hidden_states = self\
.second_last_token_hidden_states[index]
if self.hidden_states_seq_indices is not None:
target_indices_tensor = torch.tensor(
index, device=self.hidden_states_seq_indices.device)
index = (self.hidden_states_seq_indices[..., None] ==
target_indices_tensor).any(dim=-1)
self.hidden_states = self.hidden_states[index]
else:
self.hidden_states = self.hidden_states[index]
if self.second_last_token_hidden_states is not None:
self.second_last_token_hidden_states = self\
.second_last_token_hidden_states[index]
self._seq_ids = seq_ids

def expand_with_bonus_tokens(
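
A toy walk-through of the variable-length prune path added to `HiddenStates.prune` above. Here `hidden_states_seq_indices` is assumed to map each hidden-state row to a position in `_seq_ids`, so pruning keeps every row whose sequence survives; all data below is made up for illustration:

```python
# Sketch only: boolean-mask pruning when sequences own multiple rows.
import torch

stored_seq_ids = [7, 8, 9]                      # _seq_ids
hidden_states = torch.arange(5 * 4, dtype=torch.float32).reshape(5, 4)
hidden_states_seq_indices = torch.tensor([0, 0, 1, 2, 2])  # row -> seq index

kept_seq_ids = [7, 9]                           # seq 8 left the batch
index = [stored_seq_ids.index(s) for s in kept_seq_ids]     # [0, 2]
target_indices_tensor = torch.tensor(index)
mask = (hidden_states_seq_indices[..., None]
        == target_indices_tensor).any(dim=-1)

print(mask.tolist())                  # [True, True, False, True, True]
print(hidden_states[mask].shape)      # torch.Size([4, 4])
```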
65 changes: 51 additions & 14 deletions vllm/spec_decode/draft_model_runner.py
@@ -14,6 +14,11 @@
# vllm_flash_attn is not installed, try the ROCm FA metadata
from vllm.attention.backends.rocm_flash_attn import (
ROCmFlashAttentionMetadata as FlashAttentionMetadata)
try:
from vllm.attention.backends.triton_mla import TritonMLAMetadata
except (ModuleNotFoundError, ImportError):
TritonMLAMetadata = FlashAttentionMetadata

except (ModuleNotFoundError, ImportError) as err:
raise RuntimeError(
"Draft model speculative decoding currently only supports "
@@ -57,7 +62,7 @@
"return_hidden_states is not supported for TP1DraftModelRunner."
)
super().__init__(model_runner)

self.mtp = False
self.indices_of_seq_with_bonus_tokens = None

def _update_sampling_metadata(self, sampling_metadata, num_seqs,
@@ -92,7 +97,8 @@

# Update attn_metadata
attn_metadata = model_input.attn_metadata
assert isinstance(attn_metadata, FlashAttentionMetadata)
assert isinstance(attn_metadata,
(FlashAttentionMetadata, TritonMLAMetadata))

attn_metadata.advance_step(model_input, sampled_token_ids,
self.block_size, num_seqs, num_queries)
@@ -193,6 +199,7 @@
# iteration invokes this function only once
# (Look at multi-step-worker code)
is_fallback = num_steps == 1
self.mtp = self.model.config.model_type == "deepseek_mtp"
if not is_fallback:
# Since we do not broadcast data inside execute_model anymore,
# we need to figure out the best way to support TP > 1 in this
@@ -269,6 +276,9 @@
hidden_states = previous_hidden_states

outputs: List[SamplerOutput] = []
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
attn_metadata = model_input.attn_metadata
for step in range(num_steps):
multi_modal_kwargs = model_input.multi_modal_kwargs or {}

@@ -277,37 +287,64 @@

compute_logits_kwargs = {}
# Run model
if hasattr(self.model.config, "num_nextn_predict_layers"):
spec_step_idx = kwargs.get("spec_step_idx", 0)
if self.model_config.requires_multi_step_decode:
# DeepSeek MTP only: use the MTP layer that corresponds to
# the current step
spec_step_idx = kwargs.get("spec_step_idx", step)
model_execute_kwargs["spec_step_idx"] = spec_step_idx
compute_logits_kwargs["spec_step_idx"] = spec_step_idx
with set_forward_context(model_input.attn_metadata,
self.vllm_config):
if spec_step_idx >= 0:
model_execute_kwargs["spec_step_idx"] = spec_step_idx
compute_logits_kwargs["spec_step_idx"] = spec_step_idx

graph_batch_size = model_input.input_tokens.shape[0]
graph_idx = self.parallel_config.pipeline_parallel_size * spec_step_idx + model_input.virtual_engine

Check failure (GitHub Actions / pre-commit), Ruff E501: vllm/spec_decode/draft_model_runner.py:300:81: Line too long (120 > 80)
model_executable = self.graph_runners[graph_idx][graph_batch_size]
elif not use_cuda_graph:
# for single step prefill
with set_forward_context(attn_metadata, self.vllm_config):
return model_executable.generate_proposals(
input_ids=input_tokens,
positions=input_positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
sampling_metadata=model_input.sampling_metadata,
**model_execute_kwargs,
)
# model_execute_kwargs["spec_step_idx"] = spec_step_idx
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
input_ids=input_tokens,
positions=input_positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
**model_execute_kwargs,
)

# Compute the logits.
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata,
**compute_logits_kwargs)
logits = self.model.compute_logits(
hidden_states, # do not sample for the previous tokens
model_input.sampling_metadata,
**compute_logits_kwargs)
if not self.is_driver_worker:
return []
# Sample the next token.
output = self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
# TODO: do sampling/compute logits for the last token only
if self.mtp:
# return last token only for each step for MTP
output = self.model.get_last_sample_output(
output, attn_metadata)
input_tokens = self.model.get_next_layer_input(
input_tokens, attn_metadata, output)
outputs.append(output)

if model_input.attn_metadata.num_prefills == 0 \
if not self.mtp and model_input.attn_metadata.num_prefills == 0 \
Reviewer comment (Collaborator): Why is this block skipped?

and self.indices_of_seq_with_bonus_tokens is not None:
assert output.sampled_token_ids is not None
# output.sampled_token_ids should be of shape (num_seqs, 1)
Expand All @@ -327,7 +364,7 @@
count += 1

# Prepare inputs for the next step
if step != num_steps - 1:
if step != num_steps - 1 and not self.mtp:
Reviewer comment (Collaborator): Is the multi-step logic omitted here, and self.mtp is just using TP1DraftModelRunner in is_fallback mode?

model_input = self._gpu_advance_step(model_input, outputs[-1])

return outputs
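
The graph lookup in the block above flattens `(spec_step_idx, virtual_engine)` into a single index into `self.graph_runners`. A minimal sketch with hypothetical sizes:

```python
# Sketch only: one captured graph per (MTP step, virtual engine) pair.
pipeline_parallel_size = 2   # assumed parallel_config.pipeline_parallel_size
num_mtp_layers = 3           # assumed number of MTP steps

def graph_index(spec_step_idx: int, virtual_engine: int) -> int:
    return pipeline_parallel_size * spec_step_idx + virtual_engine

for step in range(num_mtp_layers):
    for ve in range(pipeline_parallel_size):
        print((step, ve), "->", graph_index(step, ve))
# (0, 0) -> 0, (0, 1) -> 1, (1, 0) -> 2, (1, 1) -> 3, (2, 0) -> 4, (2, 1) -> 5
```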
27 changes: 19 additions & 8 deletions vllm/spec_decode/multi_step_worker.py
@@ -56,6 +56,10 @@ def set_should_modify_greedy_probs_inplace(self) -> None:
self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
True)

@property
def has_mtp_runner(self) -> bool:
return getattr(self.model_runner, "mtp", False)

@torch.inference_mode()
def sampler_output(
self,
@@ -74,10 +78,13 @@ def sampler_output(
# Expand the batch for sequences with a bonus token.
# Perform a forward pass on the expanded batch and filter the
# response to retain only the original sequences' responses.
expanded_request, indices_of_seq_with_bonus_tokens =\
self._expand_execute_model_request(
execute_model_req, seq_ids_with_bonus_token_in_last_step)

if self.has_mtp_runner:
expanded_request, indices_of_seq_with_bonus_tokens =\
execute_model_req, []
else:
expanded_request, indices_of_seq_with_bonus_tokens =\
self._expand_execute_model_request(
execute_model_req, seq_ids_with_bonus_token_in_last_step)
# Run model sample_len times.
model_outputs: List[SamplerOutput] = []
if current_platform.is_cuda_alike() and isinstance(
@@ -109,10 +116,14 @@ def sampler_output(
model_outputs.append(model_output)

# move indices to device to avoid stream sync
indices_of_seq_with_bonus_tokens = torch.tensor(
indices_of_seq_with_bonus_tokens, device=self.device)
filtered_model_outputs = self._filter_model_output(
model_outputs, indices_of_seq_with_bonus_tokens)
if self.has_mtp_runner:
filtered_model_outputs = model_outputs
else:
indices_of_seq_with_bonus_tokens = torch.tensor(
indices_of_seq_with_bonus_tokens, device=self.device)
filtered_model_outputs = self._filter_model_output(
model_outputs, indices_of_seq_with_bonus_tokens)

return filtered_model_outputs, True

@staticmethod
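
Finally, a sketch of the control flow that `has_mtp_runner` introduces in `sampler_output`: the MTP path skips both the bonus-token expansion and the output filtering. The helpers are passed in as callables because this is an illustration of the branching, not the worker's real API:

```python
# Illustration only: the MTP path returns raw outputs; the default path
# expands the request for bonus tokens and filters the result afterwards.
def sampler_output_flow(has_mtp_runner, execute_model_req, bonus_seq_ids,
                        expand, run_model, filter_output):
    if has_mtp_runner:
        expanded_request, bonus_indices = execute_model_req, []
    else:
        expanded_request, bonus_indices = expand(execute_model_req,
                                                 bonus_seq_ids)
    model_outputs = run_model(expanded_request)
    if has_mtp_runner:
        return model_outputs, True
    return filter_output(model_outputs, bonus_indices), True

# Tiny smoke test with stand-in callables:
print(sampler_output_flow(True, "req", set(),
                          expand=lambda r, b: (r + "+bonus", [0]),
                          run_model=lambda r: [r + "->out"],
                          filter_output=lambda o, i: o))
# (['req->out'], True)
```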