@@ -1375,7 +1375,6 @@ def _pool(
13751375 num_scheduled_tokens_np : np .ndarray ,
13761376 finished_sending : Optional [set [str ]],
13771377 finished_recving : Optional [set [str ]],
1378- finished_loading_dict : Optional [dict [str , int ]],
13791378 ) -> ModelRunnerOutput :
13801379 assert self .input_batch .num_reqs == \
13811380 len (self .input_batch .pooling_params ), \
@@ -1412,7 +1411,6 @@ def _pool(
14121411 pooler_output = pooler_output ,
14131412 finished_sending = finished_sending ,
14141413 finished_recving = finished_recving ,
1415- finished_loading_dict = finished_loading_dict ,
14161414 )
14171415
14181416 @torch .inference_mode ()
@@ -1532,7 +1530,6 @@ def execute_model(
15321530 self .maybe_wait_for_kv_save ()
15331531 finished_sending , finished_recving = (
15341532 self .get_finished_kv_transfers (scheduler_output ))
1535- finished_loading_dict = self .get_finished_loading (scheduler_output )
15361533
15371534 if self .use_aux_hidden_state_outputs :
15381535 hidden_states , aux_hidden_states = model_output
@@ -1550,11 +1547,9 @@ def execute_model(
15501547 if not get_pp_group ().is_last_rank :
15511548 # For mid-pipeline stages, return the hidden states.
15521549 if not broadcast_pp_output :
1553- if (finished_sending or finished_recving
1554- or finished_loading_dict ):
1550+ if finished_sending or finished_recving :
15551551 hidden_states .finished_sending = finished_sending
15561552 hidden_states .finished_recving = finished_recving
1557- hidden_states .finished_loading_dict = finished_loading_dict
15581553 return hidden_states
15591554 assert isinstance (hidden_states , IntermediateTensors )
15601555 get_pp_group ().send_tensor_dict (hidden_states .tensors ,
@@ -1564,7 +1559,7 @@ def execute_model(
15641559 if self .input_batch .pooling_params :
15651560 return self ._pool (hidden_states , num_scheduled_tokens ,
15661561 num_scheduled_tokens_np , finished_sending ,
1567- finished_recving , finished_loading_dict )
1562+ finished_recving )
15681563
15691564 sample_hidden_states = hidden_states [logits_indices ]
15701565 logits = self .model .compute_logits (sample_hidden_states , None )
@@ -1716,7 +1711,6 @@ def execute_model(
17161711 pooler_output = [],
17171712 finished_sending = finished_sending ,
17181713 finished_recving = finished_recving ,
1719- finished_loading_dict = finished_loading_dict ,
17201714 num_nans_in_logits = num_nans_in_logits ,
17211715 )
17221716
0 commit comments