
Commit 8bb6606

[Distributed] Use custom overlapping method for backward chunks (#74891)
1 parent: daf6fcd

File tree: 1 file changed (+27, -3 lines)

python/paddle/distributed/fleet/meta_parallel/dualpipev.py

Lines changed: 27 additions & 3 deletions
@@ -37,6 +37,7 @@
     PipelineParallel,
 )
 from .pp_utils.batch_comm_helper import BatchCommHelper
+from .pp_utils.forward_backward_overlap_utils import ScheduleChunk
 from .zero_bubble_utils import EventStore, WeightGradStore
 
 __all__ = []
@@ -225,9 +226,20 @@ def _backward_compute(self, phase: int, enable_zb: bool = False) -> None:
             loss = self.loss_tensors[acc_id]
             if self.overlapped_forward_backward:
                 loss_fn_node = self.loss_fn_chunks[acc_id]
-                input_grads = loss_fn_node.backward(scaler=self.scaler)
                 backward_chunk = self.schedule_chunks[phase][acc_id]
-                input_grads = backward_chunk.backward(input_grads)
+                _, _, input_grads = (
+                    self._layers.overlapped_forward_backward(
+                        ScheduleChunk([]),  # forward_chunk
+                        None,  # forward_inputs
+                        None,  # forward_loss_fn_node
+                        backward_chunk,
+                        loss_fn_node,
+                        None,  # input_grads
+                        self.scaler,
+                        combine_bw_event_to_wait=None,
+                        pp_stream=None,
+                    )
+                )
                 self.loss_fn_chunks[acc_id] = None
                 self.schedule_chunks[phase][acc_id] = None
             else:
@@ -239,7 +251,19 @@ def _backward_compute(self, phase: int, enable_zb: bool = False) -> None:
             outputs, output_grads = self._get_backward_inputs(phase, acc_id)
             if self.overlapped_forward_backward:
                 backward_chunk = self.schedule_chunks[phase][acc_id]
-                input_grads = backward_chunk.backward(output_grads)
+                _, _, input_grads = (
+                    self._layers.overlapped_forward_backward(
+                        ScheduleChunk([]),  # forward_chunk
+                        None,  # forward_inputs
+                        None,  # forward_loss_fn_node
+                        backward_chunk,
+                        None,  # backward_loss_fn_node
+                        output_grads,
+                        None,  # scaler
+                        combine_bw_event_to_wait=None,
+                        pp_stream=None,
+                    )
+                )
                 self.schedule_chunks[phase][acc_id] = None
             else:
                 if len(outputs) > 0:
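
For context, the change routes both backward-only paths (one starting from the loss node, one starting from upstream output gradients) through the layer's `overlapped_forward_backward` hook with an empty forward chunk, instead of calling `backward` on the loss node and schedule chunk directly, so a custom layer implementation can control how backward work is scheduled. The sketch below is only an illustration of how such a hook can reduce to the old behavior in the backward-only case: the toy chunk/node classes, the parameter names, and the hook body are assumptions for illustration, not Paddle's actual implementation; only the argument layout mirrors what the diff passes.

```python
# Illustrative sketch (not Paddle's implementation): an
# overlapped_forward_backward-style hook, with the argument layout used in the
# diff, reducing to a plain backward pass when the forward side is empty.


class _ToyNode:
    """Stand-in for a schedule node with a backward(grads) method."""

    def __init__(self, scale):
        self.scale = scale

    def backward(self, grads):
        # Pretend backward: scale the incoming gradients.
        return [g * self.scale for g in grads]


class _ToyChunk:
    """Stand-in for a ScheduleChunk: runs its nodes' backward in reverse order."""

    def __init__(self, nodes):
        self.nodes = nodes

    def backward(self, output_grads):
        grads = output_grads
        for node in reversed(self.nodes):
            grads = node.backward(grads)
        return grads


def overlapped_forward_backward(
    forward_chunk,            # ScheduleChunk([]) in the backward-only calls
    forward_inputs,           # None: nothing to run forward
    forward_loss_fn_node,     # None: no loss computed on the forward side
    backward_chunk,
    backward_loss_fn_node,    # loss node when backpropagating from the loss, else None
    backward_input_grads,     # upstream grads when not starting from the loss
    scaler,                   # loss scaler, only meaningful with the loss node
    combine_bw_event_to_wait=None,
    pp_stream=None,
):
    # Backward-only case: optionally seed the gradients from the loss node,
    # then run the backward chunk. The first two return slots would carry the
    # forward outputs and loss when a real forward chunk is interleaved.
    if backward_loss_fn_node is not None:
        backward_input_grads = backward_loss_fn_node.backward(scaler=scaler)
    input_grads = backward_chunk.backward(backward_input_grads)
    return None, None, input_grads


if __name__ == "__main__":
    chunk = _ToyChunk([_ToyNode(2.0), _ToyNode(0.5)])
    _, _, grads = overlapped_forward_backward(
        _ToyChunk([]), None, None, chunk, None, [1.0, 1.0], None
    )
    print(grads)  # [1.0, 1.0]: scaled by 0.5, then by 2.0
```

Under these assumptions, the net effect of the two calls in the diff matches the removed `loss_fn_node.backward(scaler=...)` / `backward_chunk.backward(...)` sequence, while leaving room for a custom layer to overlap the backward chunk with other work.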
