@@ -869,72 +869,9 @@ def _get_spec_token_ids(
             spec_token_ids = self._generate_draft_token_ids(
                 valid_sampled_token_ids, sampling_metadata)
         elif self.speculative_config.method == "eagle":
-            raise NotImplementedError("eagle method for spec decode doesn't work on vllm-ascend currently")
-            assert isinstance(self.drafter, EagleProposer)
-            # TODO(woosuk): Refactor the loop.
-            next_token_ids: list[int] = []
-            for i, token_ids in enumerate(valid_sampled_token_ids):
-                if token_ids:
-                    # Common case.
-                    next_token_id = token_ids[-1]
-                else:
-                    # Partial prefill (rare case).
-                    # Get the next token id from the request state.
-                    req_id = self.input_batch.req_ids[i]
-                    req_state = self.requests[req_id]
-                    seq_len = (req_state.num_computed_tokens +
-                               scheduler_output.num_scheduled_tokens[req_id])
-                    next_token_id = req_state.get_token_id(seq_len)
-                next_token_ids.append(next_token_id)
-            next_token_ids = torch.tensor(next_token_ids,
-                                          dtype=torch.int32,
-                                          device=self.device)
-
-            if spec_decode_metadata is None:
-                # input_ids can be None for multimodal models.
-                # We need to slice token_ids, positions, and hidden_states
-                # because the eagle head does not use cuda graph and should
-                # not include padding.
-                target_token_ids = self.input_ids[:num_scheduled_tokens]
-                target_positions = positions[:num_scheduled_tokens]
-                target_hidden_states = hidden_states[:num_scheduled_tokens]
-                target_slot_mapping = attn_metadata.slot_mapping
-                cu_num_tokens = attn_metadata.query_start_loc
-            else:
-                # TODO(woosuk): Refactor this.
-                num_draft_tokens = spec_decode_metadata.num_draft_tokens
-                num_rejected_tokens = [
-                    n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
-                    for i, n in enumerate(num_draft_tokens)
-                ]
-                num_rejected_tokens = torch.tensor(
-                    num_rejected_tokens,
-                    dtype=torch.int32,
-                    device=self.device,
-                )
-                cu_num_tokens, token_indices = self.drafter.prepare_inputs(
-                    attn_metadata.query_start_loc,
-                    num_rejected_tokens,
-                )
-                target_token_ids = self.input_ids[token_indices]
-                target_positions = positions[token_indices]
-                target_hidden_states = hidden_states[token_indices]
-                target_slot_mapping = attn_metadata.slot_mapping[token_indices]
-
-            draft_token_ids, draft_probs = self.drafter.propose(
-                target_token_ids=target_token_ids,
-                target_positions=target_positions,
-                target_hidden_states=target_hidden_states,
-                target_slot_mapping=target_slot_mapping,
-                next_token_ids=next_token_ids,
-                cu_num_tokens=cu_num_tokens,
-                block_table=attn_metadata.block_tables,
-                sampling_metadata=sampling_metadata,
+            raise NotImplementedError(
+                "eagle method for spec decode doesn't work on vllm-ascend currently"
             )
-            spec_token_ids = draft_token_ids.tolist()
-            # TODO(woosuk): Cache draft_probs and use it for rejection sampling
-            # in the next step.
-            del draft_probs
         return spec_token_ids
 
     @torch.inference_mode()