@@ -1375,7 +1375,6 @@ def _pool(
13751375 num_scheduled_tokens_np: np.ndarray,
13761376 finished_sending: Optional[set[str]],
13771377 finished_recving: Optional[set[str]],
1378- finished_loading_dict: Optional[dict[str, int]],
13791378 ) -> ModelRunnerOutput:
13801379 assert self.input_batch.num_reqs ==\
13811380 len(self.input_batch.pooling_params), \
@@ -1412,7 +1411,6 @@ def _pool(
14121411 pooler_output=pooler_output,
14131412 finished_sending=finished_sending,
14141413 finished_recving=finished_recving,
1415- finished_loading_dict=finished_loading_dict,
14161414 )
14171415
14181416 @torch.inference_mode()
@@ -1532,7 +1530,6 @@ def execute_model(
15321530 self.maybe_wait_for_kv_save()
15331531 finished_sending, finished_recving = (
15341532 self.get_finished_kv_transfers(scheduler_output))
1535- finished_loading_dict = self.get_finished_loading(scheduler_output)
15361533
15371534 if self.use_aux_hidden_state_outputs:
15381535 hidden_states, aux_hidden_states = model_output
@@ -1550,11 +1547,9 @@ def execute_model(
15501547 if not get_pp_group().is_last_rank:
15511548 # For mid-pipeline stages, return the hidden states.
15521549 if not broadcast_pp_output:
1553- if (finished_sending or finished_recving
1554- or finished_loading_dict):
1550+ if finished_sending or finished_recving:
15551551 hidden_states.finished_sending = finished_sending
15561552 hidden_states.finished_recving = finished_recving
1557- hidden_states.finished_loading_dict = finished_loading_dict
15581553 return hidden_states
15591554 assert isinstance(hidden_states, IntermediateTensors)
15601555 get_pp_group().send_tensor_dict(hidden_states.tensors,
@@ -1564,7 +1559,7 @@ def execute_model(
15641559 if self.input_batch.pooling_params:
15651560 return self._pool(hidden_states, num_scheduled_tokens,
15661561 num_scheduled_tokens_np, finished_sending,
1567- finished_recving, finished_loading_dict )
1562+ finished_recving)
15681563
15691564 sample_hidden_states = hidden_states[logits_indices]
15701565 logits = self.model.compute_logits(sample_hidden_states, None)
@@ -1716,7 +1711,6 @@ def execute_model(
17161711 pooler_output=[],
17171712 finished_sending=finished_sending,
17181713 finished_recving=finished_recving,
1719- finished_loading_dict=finished_loading_dict,
17201714 num_nans_in_logits=num_nans_in_logits,
17211715 )
17221716
0 commit comments