@@ -1320,14 +1320,14 @@ def _process_reqs(
13201320 if self .use_aux_hidden_state_outputs :
13211321 hidden_states , aux_hidden_states = hidden_states
13221322
1323- if _enable_lmhead_tp (): #
1323+ if _enable_lmhead_tp ():
13241324 if not with_prefill :
13251325 max_num_reqs_across_dp = padded_num_tokens_across_dp
13261326 else :
13271327 max_num_reqs_across_dp = self .max_num_reqs
1328- sample_indices = nn .functional .pad (
1329- sample_indices ,
1330- (0 , max_num_reqs_across_dp - sample_indices .shape [0 ]))
1328+ logits_indices = nn .functional .pad (
1329+ logits_indices ,
1330+ (0 , max_num_reqs_across_dp - logits_indices .shape [0 ]))
13311331
13321332 return (attn_metadata , hidden_states , spec_decode_metadata , positions ,
13331333 total_num_scheduled_tokens , logits_indices , aux_hidden_states ,
@@ -1656,14 +1656,14 @@ def execute_model(
16561656 # Sample the next token and get logprobs if needed.
16571657 sampling_metadata = self .input_batch .sampling_metadata
16581658 if spec_decode_metadata is None :
1659- if _enable_lmhead_tp ():
1659+ if _enable_lmhead_tp () and logits is not None :
16601660 logits = logits [:self .input_batch .num_reqs ]
16611661 sampler_output = self .sampler (
16621662 logits = logits ,
16631663 sampling_metadata = sampling_metadata ,
16641664 )
16651665 else :
1666- if _enable_lmhead_tp ():
1666+ if _enable_lmhead_tp () and logits is not None :
16671667 logits = logits [:len (spec_decode_metadata .logits_indices )]
16681668 # When indexing with a tensor (bonus_logits_indices), PyTorch
16691669 # creates a new tensor with separate storage from the original
@@ -1952,16 +1952,16 @@ def _dummy_run(
19521952 with_prefill , is_torchair_compile , input_ids , positions ,
19531953 attn_metadata , num_tokens , intermediate_tensors ,
19541954 inputs_embeds )
1955-
1955+
19561956 if _enable_lmhead_tp () and not self .in_profile_run :
1957- if not with_prefill :
1958- max_num_reqs_across_dp = num_reqs
1959- else :
1960- max_num_reqs_across_dp = max_num_reqs
1961- dummy_indices = torch .zeros (max_num_reqs_across_dp ,
1962- device = hidden_states .device ,
1963- dtype = torch .int32 )
1964- model .compute_logits (hidden_states [dummy_indices ], None )
1957+ if not with_prefill :
1958+ max_num_reqs_across_dp = num_reqs
1959+ else :
1960+ max_num_reqs_across_dp = max_num_reqs
1961+ dummy_indices = torch .zeros (max_num_reqs_across_dp ,
1962+ device = hidden_states .device ,
1963+ dtype = torch .int32 )
1964+ self . model .compute_logits (hidden_states [dummy_indices ], None )
19651965
19661966 if self .speculative_config and self .speculative_config .method == "deepseek_mtp" :
19671967 assert isinstance (self .drafter , MtpProposer )
@@ -1979,7 +1979,8 @@ def _dummy_run(
19791979 dummy_indices = torch .zeros (max_num_reqs_across_dp ,
19801980 device = hidden_states .device ,
19811981 dtype = torch .int32 )
1982- model .compute_logits (hidden_states [dummy_indices ], None )
1982+ self .model .compute_logits (hidden_states [dummy_indices ],
1983+ None )
19831984
19841985 return hidden_states
19851986
0 commit comments