@@ -51,7 +51,7 @@ def propose(
         # [batch_size, max_num_blocks_per_req]
         block_table: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
         last_token_indices = cu_num_tokens[1:] - 1
@@ -94,17 +94,15 @@ def propose(
         )
         sample_hidden_states = hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
-        draft_token_ids, draft_probs = compute_probs_and_sample_next_token(
-            logits, sampling_metadata)
+        draft_token_ids = logits.argmax(dim=-1)

         # Early exit if there is only one draft token to be generated.
         if self.num_speculative_tokens == 1:
-            # [batch_size, 1] and [batch_size, 1, vocab_size]
-            return draft_token_ids.view(-1, 1), draft_probs.unsqueeze(dim=1)
+            # [batch_size, 1]
+            return draft_token_ids.view(-1, 1)

         # Generate the remaining draft tokens.
         draft_token_ids_list = [draft_token_ids]
-        draft_probs_list = [draft_probs]

         positions = target_positions[last_token_indices]
         hidden_states = sample_hidden_states
@@ -159,16 +157,12 @@ def propose(
             positions=clamped_positions,
         )
         logits = self.model.compute_logits(hidden_states, None)
-        draft_token_ids, probs = compute_probs_and_sample_next_token(
-            logits, sampling_metadata)
+        draft_token_ids = logits.argmax(dim=-1)
         draft_token_ids_list.append(draft_token_ids)
-        draft_probs_list.append(probs)

         # [batch_size, num_speculative_tokens]
         draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
-        # [batch_size, num_speculative_tokens, vocab_size]
-        draft_probs = torch.stack(draft_probs_list, dim=1)
-        return draft_token_ids, draft_probs
+        return draft_token_ids

     @staticmethod
     def prepare_inputs(
@@ -238,6 +232,10 @@ def load_model(self, target_model: nn.Module) -> None:
         self.model.lm_head = target_model.lm_head


+# NOTE(woosuk): Currently, the code below is not used; we always use argmax
+# to sample the draft tokens. We will switch back to sampling once we find a
+# way to manage the draft prob tensor.
+# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
 # FIXME(woosuk): The logic here is duplicated with the main sampling code.
 # We should refactor this to reuse the same sampling implementation.
 def compute_probs_and_sample_next_token(
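For a sense of the trade-off behind the NOTE above: the sampled path has to retain a `[batch_size, num_speculative_tokens, vocab_size]` probability tensor for the verifier (as in standard speculative decoding), while the argmax path keeps only token ids. Below is a minimal standalone sketch of the two strategies; it uses plain PyTorch with hypothetical sizes, not vLLM's API.

```python
import torch

# Hypothetical sizes, for illustration only.
batch_size, num_spec_tokens, vocab_size = 64, 4, 128_000

logits = torch.randn(batch_size, vocab_size)

# Argmax proposal (what this diff switches to): only int64 token ids survive.
draft_token_ids = logits.argmax(dim=-1)  # [batch_size]

# Sampled proposal (the disabled path): the probabilities themselves must be
# kept around for the verifier, one [batch_size, vocab_size] slice per
# speculative step.
probs = torch.softmax(logits, dim=-1)
sampled_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)

# Stacked over all steps, this is the tensor the NOTE calls hard to manage:
draft_probs_bytes = batch_size * num_spec_tokens * vocab_size * 4  # fp32
print(f"draft_probs: ~{draft_probs_bytes / 2**30:.2f} GiB")  # ~0.12 GiB here
```

At realistic batch sizes and vocab sizes the probability tensor grows to multiple GiB per decoding step, which is why the diff drops it and returns only `draft_token_ids`.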