
Commit ace0936

Format
1 parent 89d96cc commit ace0936

File tree: 2 files changed, +101 −111 lines

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 95 additions & 107 deletions
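The diff below is whitespace-only: multi-line calls are re-wrapped so continuation arguments line up under the opening parenthesis. That is the output style of a yapf-like formatter; as a hedged illustration only (assuming yapf is the formatter behind this "Format" commit, which the commit message alone does not confirm), the same rewrapping can be reproduced programmatically:

# Hypothetical reproduction of a formatting pass, assuming yapf.
from yapf.yapflib.yapf_api import FormatCode

snippet = (
    "self.input_ids = torch.zeros(\n"
    "    self.max_num_tokens, dtype=torch.int32, device=device\n"
    ")\n"
)

# The default 'pep8' style typically wraps long calls with arguments aligned
# under the opening parenthesis, matching the "+" lines in this diff.
formatted, changed = FormatCode(snippet, style_config="pep8")
print(changed)    # True if yapf rewrote the snippet
print(formatted)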
@@ -30,8 +30,7 @@
 from vllm.logger import init_logger
 from vllm.v1.utils import CpuGpuBuffer
 from vllm.utils import (
-    is_pin_memory_available,
-)
+    is_pin_memory_available, )
 import numpy as np
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
@@ -79,55 +78,53 @@ def __init__(
         if compilation_config.mode == CompilationMode.VLLM_COMPILE:
             cudagraph_mode = compilation_config.cudagraph_mode
             if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode(
-                CUDAGraphMode.PIECEWISE
-            ):
+                    CUDAGraphMode.PIECEWISE):
                 logger.warning(
                     "Currently the eagle proposer only supports cudagraph_mode "
                     "PIECEWISE, if you want the drafter to use cuda graphs, "
                     "please set compilation_config.cudagraph_mode to PIECEWISE "
-                    "or FULL_AND_PIECEWISE"
-                )
-            self.use_aclgraph = (
-                cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE)
-                and not self.speculative_config.enforce_eager
-            )
-
-        self.cudagraph_batch_sizes = (
-            list(reversed(self.vllm_config.compilation_config.cudagraph_capture_sizes))
-            if self.use_aclgraph
-            else []
-        )
+                    "or FULL_AND_PIECEWISE")
+            self.use_aclgraph = (cudagraph_mode.has_mode(
+                CUDAGraphMode.PIECEWISE)
+                                 and not self.speculative_config.enforce_eager)
+
+        self.cudagraph_batch_sizes = (list(
+            reversed(
+                self.vllm_config.compilation_config.cudagraph_capture_sizes))
+                                      if self.use_aclgraph else [])
 
         # persistent buffers for aclgraph graph
-        self.input_ids = torch.zeros(
-            self.max_num_tokens, dtype=torch.int32, device=device
-        )
+        self.input_ids = torch.zeros(self.max_num_tokens,
+                                     dtype=torch.int32,
+                                     device=device)
         self.uses_mrope = self.vllm_config.model_config.uses_mrope
         if self.uses_mrope:
             # M-RoPE need (3, max_num_tokens)
-            self.mrope_positions = torch.zeros(
-                (3, self.max_num_tokens), dtype=torch.int64, device=device
-            )
+            self.mrope_positions = torch.zeros((3, self.max_num_tokens),
+                                               dtype=torch.int64,
+                                               device=device)
         else:
             # RoPE need (max_num_tokens,)
-            self.positions = torch.zeros(
-                self.max_num_tokens, dtype=torch.int64, device=device
-            )
+            self.positions = torch.zeros(self.max_num_tokens,
+                                         dtype=torch.int64,
+                                         device=device)
         self.hidden_states = torch.zeros(
-            (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
-        )
+            (self.max_num_tokens, self.hidden_size),
+            dtype=self.dtype,
+            device=device)
 
         # We need +1 here because the arange is used to set query_start_loc,
         # which has one more element than batch_size.
         max_batch_size = vllm_config.scheduler_config.max_num_seqs
         max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
-        self.arange = torch.arange(
-            max_num_slots_for_arange, device=device, dtype=torch.int32
-        )
+        self.arange = torch.arange(max_num_slots_for_arange,
+                                   device=device,
+                                   dtype=torch.int32)
 
         self.inputs_embeds = torch.zeros(
-            (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
-        )
+            (self.max_num_tokens, self.hidden_size),
+            dtype=self.dtype,
+            device=device)
 
         self.backup_next_token_ids = CpuGpuBuffer(
             max_batch_size,
@@ -221,8 +218,8 @@ def generate_token_ids(self,
                            hidden_states: torch.Tensor = None,
                            attn_metadata=None,
                            aux_hidden_states: torch.Tensor = None,
-                           common_attn_metadata: AscendCommonAttentionMetadata = None
-                           ):
+                           common_attn_metadata: AscendCommonAttentionMetadata
+                           | None = None):
         if attn_metadata is not None and isinstance(attn_metadata, dict):
             attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
         next_token_ids: list[int] = []
@@ -299,12 +296,12 @@ def generate_token_ids(self,
             # common_attn_metadata
             # )
             if self.speculative_config.disable_padded_drafter_batch:
-                    token_indices_to_sample = None
-                    common_attn_metadata, token_indices =\
-                        self._prepare_inputs(
-                            common_attn_metadata,
-                            sampled_token_ids,
-                            spec_decode_metadata.num_draft_tokens)
+                token_indices_to_sample = None
+                common_attn_metadata, token_indices =\
+                    self._prepare_inputs(
+                        common_attn_metadata,
+                        sampled_token_ids,
+                        spec_decode_metadata.num_draft_tokens)
             else:
                 common_attn_metadata, token_indices, \
                 token_indices_to_sample =\
@@ -317,15 +314,15 @@ def generate_token_ids(self,
         target_hidden_states = hidden_states[:token_indices]
 
         draft_token_ids = self._propose(
-                target_token_ids=target_token_ids,
-                target_positions=target_positions,
-                target_hidden_states=target_hidden_states,
-                next_token_ids=next_token_ids,
-                last_token_indices=token_indices_to_sample,
-                common_attn_metadata=common_attn_metadata,
-                sampling_metadata=sampling_metadata,
-            )
-
+            target_token_ids=target_token_ids,
+            target_positions=target_positions,
+            target_hidden_states=target_hidden_states,
+            next_token_ids=next_token_ids,
+            last_token_indices=token_indices_to_sample,
+            common_attn_metadata=common_attn_metadata,
+            sampling_metadata=sampling_metadata,
+        )
+
         return draft_token_ids
 
     def _prepare_inputs(
@@ -360,14 +357,16 @@ def _prepare_inputs(
             n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
             for i, n in enumerate(num_draft_tokens)
         ]
-        num_rejected_tokens = torch.tensor(num_rejected_tokens, dtype=torch.int32)
+        num_rejected_tokens = torch.tensor(num_rejected_tokens,
+                                           dtype=torch.int32)
 
         device = common_attn_metadata.query_start_loc.device
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens
 
         # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
-        new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+        new_query_len_per_req = query_start_loc_cpu[
+            1:] - query_start_loc_cpu[:-1]
         # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
         new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens
         new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()
@@ -388,36 +387,36 @@ def _prepare_inputs(
         # [0, 2, 6, 9] ->
         # [0, 0, 2, 2, 2, 2, 6, 6, 6]
         #  _r1_  ____r2____  ___r3__
-        new_query_start_locs_expanded = np.repeat(
-            new_query_start_loc_np[:-1], new_num_tokens_per_req_np
-        )
+        new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
+                                                  new_num_tokens_per_req_np)
         # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
         # [0, 1, 0, 1, 2, 3, 0, 1, 2]
         #  _r1_  ____r2____  ___r3__
-        token_offests = (
-            self.token_arange_np[:total_num_tokens] - new_query_start_locs_expanded
-        )
+        token_offests = (self.token_arange_np[:total_num_tokens] -
+                         new_query_start_locs_expanded)
 
         # Expand starting positions to match token pattern
         # [0, q1, q1 + q2] ->
         # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
         #  _r1_  _____r2_______  ___________r3____________
         old_query_start_locs_expanded = np.repeat(
-            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np
-        )
+            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
         # Final token indices are:
         # [0, 1,                                   // req 1
         #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,         // req 2
         #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2]  // req 3
         token_indices_np = token_offests + old_query_start_locs_expanded
-        token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True)
+        token_indices = torch.from_numpy(token_indices_np).to(
+            device, non_blocking=True)
 
         spec_common_attn_metadata = AscendCommonAttentionMetadata(
-            query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
+            query_start_loc=new_query_start_loc_cpu.to(device,
+                                                       non_blocking=True),
             query_start_loc_cpu=new_query_start_loc_cpu,
             seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
             seq_lens_cpu=new_seq_lens_cpu,
-            num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
+            num_computed_tokens_cpu=common_attn_metadata.
+            num_computed_tokens_cpu,
             num_reqs=common_attn_metadata.num_reqs,
             num_actual_tokens=total_num_tokens,
             max_query_len=new_query_len_per_req.max().item(),
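The rewrapped lines in this hunk implement the token-index arithmetic described by the inline comments: after rejection sampling, each request keeps only its accepted tokens, and the flat indices into the original target batch are rebuilt with np.repeat. A standalone sketch of the same arithmetic with made-up numbers (the names mirror the diff, but nothing below is taken from the real metadata objects):

import numpy as np

# Hypothetical example: 3 requests with query lengths [2, 4, 3] in the target
# batch, of which [0, 2, 1] trailing draft tokens were rejected.
query_start_loc = np.array([0, 2, 6, 9])          # prefix sums of [2, 4, 3]
num_rejected_tokens = np.array([0, 2, 1])

new_query_len_per_req = query_start_loc[1:] - query_start_loc[:-1]     # [2, 4, 3]
new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens   # [2, 2, 2]
new_query_start_loc = np.concatenate(
    ([0], np.cumsum(new_num_tokens_per_req)))                          # [0, 2, 4, 6]
total_num_tokens = new_query_start_loc[-1]

# Offset of every kept token inside its own request: [0, 1, 0, 1, 0, 1]
token_offsets = (np.arange(total_num_tokens) -
                 np.repeat(new_query_start_loc[:-1], new_num_tokens_per_req))

# Add each request's old starting position to get indices into the old batch:
# [0, 1, 2, 3, 6, 7]
token_indices = token_offsets + np.repeat(query_start_loc[:-1],
                                          new_num_tokens_per_req)
print(token_indices.tolist())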
@@ -432,9 +431,6 @@ def _prepare_inputs(
             decode_token_per_req=self.runner.decode_token_per_req,
         )
         return spec_common_attn_metadata, token_indices
-
-
-
 
     def _propose(
         self,
@@ -460,8 +456,7 @@ def _propose(
         if self.method == "eagle3":
             assert isinstance(self.model, Eagle3LlamaForCausalLM)
             target_hidden_states = self.model.combine_hidden_states(
-                target_hidden_states
-            )
+                target_hidden_states)
             assert target_hidden_states.shape[-1] == self.hidden_size
 
         # Shift the input ids by one token.
@@ -506,10 +501,9 @@ def _propose(
         if aclgraph_runtime_mode != CUDAGraphMode.NONE:
             # Fallback to piecewise graph, when acl full graph is enabled
             logger.warning(
-                f"Currently the eagle proposer only supports cudagraph_mode "
-                "PIECEWISE, and is forced to set graph mode from {aclgraph_runtime_mode} "
-                "to CUDAGraphMode.PIECEWISE"
-            )
+                f"Currently the eagle proposer only supports cudagraph_mode "
+                "PIECEWISE, and is forced to set graph mode from {aclgraph_runtime_mode} "
+                "to CUDAGraphMode.PIECEWISE")
             aclgraph_runtime_mode = CUDAGraphMode.PIECEWISE
 
         for step in range(self.num_speculative_tokens):
@@ -692,14 +686,15 @@ def prepare_next_token_ids_cpu(
                 # Get the next token id from the request state.
                 req_id = req_ids[i]
                 req_state = requests[req_id]
-                seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id]
+                seq_len = req_state.num_computed_tokens + num_scheduled_tokens[
+                    req_id]
                 next_token_id = req_state.get_token_id(seq_len)
                 next_token_ids.append(next_token_id)
-        next_token_ids = torch.tensor(
-            next_token_ids, dtype=torch.int32, device=self.input_ids.device
-        )
+        next_token_ids = torch.tensor(next_token_ids,
+                                      dtype=torch.int32,
+                                      device=self.input_ids.device)
         return next_token_ids
-
+
     def prepare_next_token_ids_padded(
         self,
         common_attn_metadata: CommonAttentionMetadata,
722717

723718
# Precompute get_token_id for when there is no valid next token
724719
num_reqs = gpu_input_batch.num_reqs
725-
self.backup_next_token_ids.np[:num_reqs] = np.array(
726-
[
727-
requests[gpu_input_batch.req_ids[i]].get_token_id(
728-
common_attn_metadata.seq_lens_cpu[i].item()
729-
)
730-
for i in range(num_reqs)
731-
]
732-
)
720+
self.backup_next_token_ids.np[:num_reqs] = np.array([
721+
requests[gpu_input_batch.req_ids[i]].get_token_id(
722+
common_attn_metadata.seq_lens_cpu[i].item())
723+
for i in range(num_reqs)
724+
])
733725
self.backup_next_token_ids.copy_to_gpu(num_reqs)
734726

735727
# Mask out the sampled tokens indices that should not be sampled.
736-
discard_sampled_tokens_req_indices = discard_request_indices[
737-
:num_discarded_requests
738-
]
728+
discard_sampled_tokens_req_indices = discard_request_indices[:
729+
num_discarded_requests]
739730

740731
valid_sampled_token_ids_gpu = sampled_token_ids.clone()
741732
valid_sampled_token_ids_gpu.index_fill_(
742-
0, discard_sampled_tokens_req_indices, -1
743-
)
733+
0, discard_sampled_tokens_req_indices, -1)
744734

745735
# Generate a mask for all valid tokens within those requests
746736
valid_mask = (valid_sampled_token_ids_gpu != -1) & (
747-
valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
748-
)
737+
valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size)
749738

750739
# Count the number of valid tokens in each request
751740
valid_sampled_tokens_count = valid_mask.sum(dim=1)
@@ -757,8 +746,8 @@ def prepare_next_token_ids_padded(
         # Get last valid token from each row
         # (assume undefined state where there is no valid token)
         selected_tokens = torch.gather(
-            valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1)
-        ).squeeze(1)
+            valid_sampled_token_ids_gpu, 1,
+            last_valid_indices_safe.unsqueeze(1)).squeeze(1)
 
         # Use last token if valid, pre-computed backup if not
         batch_size = valid_sampled_token_ids_gpu.shape[0]
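The gather above picks, for every request, the last sampled token that survived rejection; together with the valid-token mask from the previous hunk this forms the "last valid token, else backup" selection. A small self-contained sketch of that pattern (illustrative values only; last_valid_indices_safe and the final torch.where are recomputed here rather than copied from the real code):

import torch

# Hypothetical sampled tokens for 3 requests, padded with -1 where a token
# was rejected or discarded (vocab_size assumed to be 100).
vocab_size = 100
valid_sampled_token_ids = torch.tensor([
    [11, 12, -1],   # 2 valid tokens
    [21, -1, -1],   # 1 valid token
    [-1, -1, -1],   # no valid token -> falls back to a backup id
])
backup_next_token_ids = torch.tensor([91, 92, 93])

valid_mask = (valid_sampled_token_ids != -1) & (valid_sampled_token_ids < vocab_size)
valid_sampled_tokens_count = valid_mask.sum(dim=1)             # [2, 1, 0]

# Index of the last valid token, clamped so gather stays in range even when
# a row has no valid token at all.
last_valid_indices_safe = (valid_sampled_tokens_count - 1).clamp(min=0)
selected_tokens = torch.gather(
    valid_sampled_token_ids, 1,
    last_valid_indices_safe.unsqueeze(1)).squeeze(1)            # [12, 21, -1]

# Use the gathered token where the row had any valid token, else the backup.
next_token_ids = torch.where(valid_sampled_tokens_count > 0,
                             selected_tokens, backup_next_token_ids)
print(next_token_ids.tolist())   # [12, 21, 93]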
@@ -769,7 +758,7 @@
         )
 
         return next_token_ids, valid_sampled_tokens_count
-
+
     def prepare_inputs_padded(
         self,
         common_attn_metadata: CommonAttentionMetadata,
@@ -784,13 +773,11 @@ def prepare_inputs_padded(
         used as padding and filtered out later by `token_indices_to_sample`.
         No blocking CPU operations should be introduced in this function.
         """
-        num_draft_tokens_gpu = torch.cat(
-            [
-                spec_decode_metadata.cu_num_draft_tokens[0:1],
-                spec_decode_metadata.cu_num_draft_tokens[1:]
-                - spec_decode_metadata.cu_num_draft_tokens[:-1],
-            ]
-        )
+        num_draft_tokens_gpu = torch.cat([
+            spec_decode_metadata.cu_num_draft_tokens[0:1],
+            spec_decode_metadata.cu_num_draft_tokens[1:] -
+            spec_decode_metadata.cu_num_draft_tokens[:-1],
+        ])
 
         num_rejected_tokens_gpu = torch.where(
             num_draft_tokens_gpu > 0,
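The torch.cat above turns cumulative draft-token counts into per-request counts without leaving the device; the torch.where that follows (only its first lines are visible in this hunk) then derives the number of rejected tokens per request. A minimal sketch with invented numbers, assuming the remaining torch.where arguments follow the usual "draft + 1 - accepted, else zero" pattern, which this hunk does not show in full:

import torch

# Hypothetical cumulative draft-token counts for 3 requests that drafted
# [2, 0, 3] tokens, plus the number of tokens each request actually kept
# (accepted draft tokens plus the bonus token).
cu_num_draft_tokens = torch.tensor([2, 2, 5])
valid_sampled_tokens_count = torch.tensor([1, 1, 4])

# [2, 2, 5] -> [2, 0, 3]: first element, then adjacent differences.
num_draft_tokens = torch.cat([
    cu_num_draft_tokens[0:1],
    cu_num_draft_tokens[1:] - cu_num_draft_tokens[:-1],
])

# A request that drafted n tokens and kept k of them rejected n + 1 - k;
# requests that drafted nothing reject nothing.
num_rejected_tokens = torch.where(
    num_draft_tokens > 0,
    num_draft_tokens + 1 - valid_sampled_tokens_count,
    torch.zeros_like(num_draft_tokens))
print(num_draft_tokens.tolist())     # [2, 0, 3]
print(num_rejected_tokens.tolist())  # [2, 0, 0]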
@@ -800,7 +787,8 @@
 
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
 
-        new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+        new_query_len_per_req = query_start_loc_cpu[
+            1:] - query_start_loc_cpu[:-1]
 
         total_num_tokens = query_start_loc_cpu[-1].item()
         token_indices = self.arange[:total_num_tokens]
@@ -821,11 +809,11 @@
             attn_state=self.runner.attn_state,
             graph_pad_size=self.runner.graph_pad_size,
             decode_token_per_req=self.runner.decode_token_per_req,
-            num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
+            num_computed_tokens_cpu=common_attn_metadata.
+            num_computed_tokens_cpu,
             seq_lens=common_attn_metadata.seq_lens)
 
-        token_indices_to_sample = (
-            common_attn_metadata.query_start_loc[1:] - 1 - num_rejected_tokens_gpu
-        )
+        token_indices_to_sample = (common_attn_metadata.query_start_loc[1:] -
+                                   1 - num_rejected_tokens_gpu)
 
-        return spec_common_attn_metadata, token_indices, token_indices_to_sample
+        return spec_common_attn_metadata, token_indices, token_indices_to_sample
