@@ -37,6 +37,7 @@
     PipelineParallel,
 )
 from .pp_utils.batch_comm_helper import BatchCommHelper
+from .pp_utils.forward_backward_overlap_utils import ScheduleChunk
 from .zero_bubble_utils import EventStore, WeightGradStore

 __all__ = []
@@ -225,9 +226,20 @@ def _backward_compute(self, phase: int, enable_zb: bool = False) -> None:
                 loss = self.loss_tensors[acc_id]
                 if self.overlapped_forward_backward:
                     loss_fn_node = self.loss_fn_chunks[acc_id]
-                    input_grads = loss_fn_node.backward(scaler=self.scaler)
                     backward_chunk = self.schedule_chunks[phase][acc_id]
-                    input_grads = backward_chunk.backward(input_grads)
+                    _, _, input_grads = (
+                        self._layers.overlapped_forward_backward(
+                            ScheduleChunk([]),  # forward_chunk
+                            None,  # forward_inputs
+                            None,  # forward_loss_fn_node
+                            backward_chunk,
+                            loss_fn_node,
+                            None,  # input_grads
+                            self.scaler,
+                            combine_bw_event_to_wait=None,
+                            pp_stream=None,
+                        )
+                    )
                     self.loss_fn_chunks[acc_id] = None
                     self.schedule_chunks[phase][acc_id] = None
                 else:
@@ -239,7 +251,19 @@ def _backward_compute(self, phase: int, enable_zb: bool = False) -> None:
                 outputs, output_grads = self._get_backward_inputs(phase, acc_id)
                 if self.overlapped_forward_backward:
                     backward_chunk = self.schedule_chunks[phase][acc_id]
-                    input_grads = backward_chunk.backward(output_grads)
+                    _, _, input_grads = (
+                        self._layers.overlapped_forward_backward(
+                            ScheduleChunk([]),  # forward_chunk
+                            None,  # forward_inputs
+                            None,  # forward_loss_fn_node
+                            backward_chunk,
+                            None,  # backward_loss_fn_node
+                            output_grads,
+                            None,  # scaler
+                            combine_bw_event_to_wait=None,
+                            pp_stream=None,
+                        )
+                    )
                     self.schedule_chunks[phase][acc_id] = None
                 else:
                     if len(outputs) > 0:
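
Note (not part of the diff): both backward branches now route through the single self._layers.overlapped_forward_backward(...) entry point instead of calling loss_fn_node.backward(...) and backward_chunk.backward(...) directly, passing an empty ScheduleChunk so no forward work is scheduled and only the backward half of the overlapped path runs. The sketch below is a minimal, self-contained illustration of that calling convention; FakeScheduleChunk and FakeLayers are hypothetical stubs, and only the argument order and the 3-tuple return (whose last element is the input gradients) are taken from the diff itself.

# Illustrative sketch only -- stub classes, not Paddle APIs.


class FakeScheduleChunk:
    """Hypothetical stand-in for ScheduleChunk; an empty node list means no forward work."""

    def __init__(self, nodes):
        self.nodes = nodes


class FakeLayers:
    """Hypothetical stand-in for self._layers exposing the unified entry point."""

    def overlapped_forward_backward(
        self,
        forward_chunk,
        forward_inputs,
        forward_loss_fn_node,
        backward_chunk,
        backward_loss_fn_node,
        grads,
        scaler,
        combine_bw_event_to_wait=None,
        pp_stream=None,
    ):
        # No forward work runs when the forward chunk is empty.
        forward_outputs = forward_loss = None
        if backward_loss_fn_node is not None:
            # Last-stage path: gradients originate from the loss node.
            grads = ("grads_from_loss_fn", scaler)
        # Backward through the chunk yields the input gradients (placeholder).
        input_grads = ("input_grads_of", backward_chunk, grads)
        return forward_outputs, forward_loss, input_grads


layers = FakeLayers()
# Backward-only call on a non-last stage, mirroring the second hunk above.
_, _, input_grads = layers.overlapped_forward_backward(
    FakeScheduleChunk([]),  # forward_chunk: empty, nothing to run forward
    None,                   # forward_inputs
    None,                   # forward_loss_fn_node
    "backward_chunk",       # chunk taken from schedule_chunks[phase][acc_id]
    None,                   # backward_loss_fn_node (only set on the last stage)
    "output_grads",         # grads received from the downstream stage
    None,                   # scaler (only needed together with a loss node)
)
print(input_grads)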