diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index e70536e0199..77079c7763c 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -17,7 +17,12 @@ get_logprob_from_pp_outputs, ) from sglang.srt.model_executor.forward_batch_info import PPProxyTensors -from sglang.srt.utils import DynamicGradMode, broadcast_pyobj, point_to_point_pyobj, require_mlp_sync +from sglang.srt.utils import ( + DynamicGradMode, + broadcast_pyobj, + point_to_point_pyobj, + require_mlp_sync, +) logger = logging.getLogger(__name__) @@ -161,7 +166,9 @@ def _pp_send_output_to_next_stage( pp_outputs: PPProxyTensors | None, ) -> List[P2PWork]: send_output_work = [] - if self.pp_group.is_last_rank: + if self.pp_group.is_last_rank and ( + self.server_args.pp_async_batch_depth == 0 or self.server_args.pp_size <= 2 + ): # send ready PP output to rank 0 if mbs[next_first_rank_mb_id] is not None: q_event, pp_outputs_to_send = last_rank_comm_queue.popleft() @@ -173,7 +180,10 @@ def _pp_send_output_to_next_stage( ) # send the outputs from the last round to let the next stage worker run post processing if not self.pp_group.is_last_rank: - if pp_outputs: + if pp_outputs and ( + self.server_args.pp_size <= 2 + or self.server_args.pp_async_batch_depth > 0 + ): with torch.profiler.record_function("send_res_dict_to_next_stage"): send_output_work = self._pp_send_dict_to_next_stage( pp_outputs.tensors, @@ -354,6 +364,34 @@ def event_loop_pp(self: Scheduler): result.pp_hidden_states_proxy_tensors.tensors, async_send=True, ) + elif ( + self.server_args.pp_async_batch_depth > 0 + and mbs[mb_id] is not None + and self.server_args.pp_size > 2 + ): + # send ready PP output to rank 0 + q_event, pp_outputs_to_send = last_rank_comm_queue.popleft() + torch.cuda.current_stream().wait_event(q_event) + with torch.profiler.record_function("send_res_dict_to_next_stage"): + send_output_work = self._pp_send_dict_to_next_stage( + pp_outputs_to_send.tensors, + async_send=True, + ) + # send the outputs from the last round to let the next stage worker run post processing + if not self.pp_group.is_last_rank: + if next_pp_outputs and ( + self.server_args.pp_size > 2 + and self.server_args.pp_async_batch_depth == 0 + ): + self._pp_commit_comm_work(work=send_output_work) + with torch.profiler.record_function( + "send_res_dict_to_next_stage" + ): + send_output_work = self._pp_send_dict_to_next_stage( + next_pp_outputs.tensors, + async_send=True, + ) + next_pp_outputs = None # if self.delayed_weight_sync_fn: # self.delayed_weight_sync_fn()