@@ -1191,6 +1191,7 @@ def _get_prompt_logprobs_dict(
         if not num_prompt_logprobs_dict:
             return {}
 
+        in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
         prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
 
         # Since prompt logprobs are a rare feature, prioritize simple,
@@ -1206,16 +1207,36 @@ def _get_prompt_logprobs_dict(
             prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
                 self.device, non_blocking=True)
 
+            # Set up target LogprobsTensors object.
+            logprobs_tensors = in_progress_dict.get(req_id)
+            if not logprobs_tensors:
+                # Create empty logprobs CPU tensors for the entire prompt.
+                # If chunked, we'll copy in slice by slice.
+                logprobs_tensors = LogprobsTensors.empty_cpu(
+                    num_prompt_tokens - 1, num_prompt_logprobs + 1)
+                in_progress_dict[req_id] = logprobs_tensors
+
             # Determine number of logits to retrieve.
-            start_tok = request.num_computed_tokens + 1
+            start_idx = request.num_computed_tokens
+            start_tok = start_idx + 1
             num_remaining_tokens = num_prompt_tokens - start_tok
-            if num_tokens < num_remaining_tokens:
+            if num_tokens <= num_remaining_tokens:
                 # This is a chunk, more tokens remain.
+                # In the == case, there are no more prompt logprobs to produce
+                # but we want to defer returning them to the next step where we
+                # have new generated tokens to return.
                 num_logits = num_tokens
             else:
                 # This is the last chunk of prompt tokens to return.
                 num_logits = num_remaining_tokens
                 completed_prefill_reqs.append(req_id)
+                prompt_logprobs_dict[req_id] = logprobs_tensors
+
+            if num_logits <= 0:
+                # This can happen for the final chunk if we prefilled exactly
+                # (num_prompt_tokens - 1) tokens for this request in the prior
+                # step. There are no more prompt logprobs to produce.
+                continue
 
             # Get the logits corresponding to this req's prompt tokens.
             # If this is a partial request (i.e. chunked prefill),
@@ -1236,19 +1257,23 @@ def _get_prompt_logprobs_dict(
                 logprobs, num_prompt_logprobs, tgt_token_ids)
 
             # Transfer GPU->CPU async.
-            prompt_logprobs_dict[req_id] = LogprobsTensors(
-                token_ids.to("cpu", non_blocking=True),
-                logprobs.to("cpu", non_blocking=True),
-                ranks.to("cpu", non_blocking=True),
-            )
+            chunk_slice = slice(start_idx, start_idx + num_logits)
+            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
+                token_ids, non_blocking=True)
+            logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
+                                                         non_blocking=True)
+            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
+                ranks, non_blocking=True)
 
         # Remove requests that have completed prefill from the batch
         # num_prompt_logprobs_dict.
         for req_id in completed_prefill_reqs:
             del num_prompt_logprobs_dict[req_id]
+            del in_progress_dict[req_id]
 
         # Must synchronize the non-blocking GPU->CPU transfers.
-        torch.cuda.synchronize()
+        if prompt_logprobs_dict:
+            torch.cuda.synchronize()
 
         return prompt_logprobs_dict
 
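For readers outside the vLLM codebase, the pattern this change relies on is: preallocate CPU tensors covering every prompt position once, have each chunked-prefill step copy only its own slice over with non-blocking GPU->CPU copies, and synchronize once per step. The sketch below is a self-contained approximation of that idea, not vLLM's actual `LogprobsTensors`/`empty_cpu` implementation; `LogprobsTensorsSketch` and `copy_chunk` are hypothetical names introduced here for illustration, and only the field names and the slice-wise `copy_` plus single `synchronize` mirror the diff.

```python
# Minimal sketch of the chunked, buffer-reuse GPU->CPU copy pattern.
# LogprobsTensorsSketch / copy_chunk are illustrative stand-ins only.
from typing import NamedTuple

import torch


class LogprobsTensorsSketch(NamedTuple):
    logprob_token_ids: torch.Tensor     # (num_positions, num_logprobs + 1)
    logprobs: torch.Tensor              # (num_positions, num_logprobs + 1)
    selected_token_ranks: torch.Tensor  # (num_positions,)

    @classmethod
    def empty_cpu(cls, num_positions: int, num_cols: int):
        # Preallocate CPU buffers covering the whole prompt once, so each
        # chunked-prefill step only fills in its own slice. Pinned memory
        # is what lets the later non-blocking copies overlap with compute.
        pin = torch.cuda.is_available()
        return cls(
            logprob_token_ids=torch.empty(
                num_positions, num_cols, dtype=torch.int64, pin_memory=pin),
            logprobs=torch.empty(
                num_positions, num_cols, dtype=torch.float32, pin_memory=pin),
            selected_token_ranks=torch.empty(
                num_positions, dtype=torch.int64, pin_memory=pin),
        )


def copy_chunk(dst: LogprobsTensorsSketch, start_idx: int,
               token_ids: torch.Tensor, logprobs: torch.Tensor,
               ranks: torch.Tensor) -> None:
    # Copy one chunk's results into the persistent CPU buffers without
    # blocking; the caller synchronizes once after all copies are issued.
    chunk_slice = slice(start_idx, start_idx + token_ids.shape[0])
    dst.logprob_token_ids[chunk_slice].copy_(token_ids, non_blocking=True)
    dst.logprobs[chunk_slice].copy_(logprobs, non_blocking=True)
    dst.selected_token_ranks[chunk_slice].copy_(ranks, non_blocking=True)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    num_prompt_tokens, num_prompt_logprobs = 9, 4
    buf = LogprobsTensorsSketch.empty_cpu(num_prompt_tokens - 1,
                                          num_prompt_logprobs + 1)
    # Pretend the 8 logprob positions arrive as two chunked-prefill steps.
    for start_idx in (0, 4):
        n = 4
        copy_chunk(
            buf, start_idx,
            torch.randint(0, 50, (n, num_prompt_logprobs + 1), device=device),
            torch.randn(n, num_prompt_logprobs + 1, device=device),
            torch.randint(0, 50, (n,), device=device),
        )
    if device == "cuda":
        # Required before the CPU side reads the non-blocking copies.
        torch.cuda.synchronize()
    print(buf.logprobs.shape)  # torch.Size([8, 5])
```

Allocating the destination once and batching the synchronize keeps host-device round trips per step constant regardless of how many chunks a long prompt is split into, which is the motivation for the persistent `in_progress_prompt_logprobs_cpu` buffer in the diff.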