30 changes: 27 additions & 3 deletions python/paddle/distributed/fleet/meta_parallel/dualpipev.py
@@ -37,6 +37,7 @@
     PipelineParallel,
 )
 from .pp_utils.batch_comm_helper import BatchCommHelper
+from .pp_utils.forward_backward_overlap_utils import ScheduleChunk
 from .zero_bubble_utils import EventStore, WeightGradStore
 
 __all__ = []
@@ -225,9 +226,20 @@ def _backward_compute(self, phase: int, enable_zb: bool = False) -> None:
             loss = self.loss_tensors[acc_id]
             if self.overlapped_forward_backward:
                 loss_fn_node = self.loss_fn_chunks[acc_id]
-                input_grads = loss_fn_node.backward(scaler=self.scaler)
                 backward_chunk = self.schedule_chunks[phase][acc_id]
-                input_grads = backward_chunk.backward(input_grads)
+                _, _, input_grads = (
+                    self._layers.overlapped_forward_backward(
+                        ScheduleChunk([]),  # forward_chunk
+                        None,  # forward_inputs
+                        None,  # forward_loss_fn_node
+                        backward_chunk,
+                        loss_fn_node,
+                        None,  # input_grads
+                        self.scaler,
+                        combine_bw_event_to_wait=None,
+                        pp_stream=None,
+                    )
+                )
                 self.loss_fn_chunks[acc_id] = None
                 self.schedule_chunks[phase][acc_id] = None
             else:
@@ -239,7 +251,19 @@ def _backward_compute(self, phase: int, enable_zb: bool = False) -> None:
             outputs, output_grads = self._get_backward_inputs(phase, acc_id)
             if self.overlapped_forward_backward:
                 backward_chunk = self.schedule_chunks[phase][acc_id]
-                input_grads = backward_chunk.backward(output_grads)
+                _, _, input_grads = (
+                    self._layers.overlapped_forward_backward(
+                        ScheduleChunk([]),  # forward_chunk
+                        None,  # forward_inputs
+                        None,  # forward_loss_fn_node
+                        backward_chunk,
+                        None,  # backward_loss_fn_node
+                        output_grads,
+                        None,  # scaler
+                        combine_bw_event_to_wait=None,
+                        pp_stream=None,
+                    )
+                )
                 self.schedule_chunks[phase][acc_id] = None
             else:
                 if len(outputs) > 0:
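Both backward call sites above previously drove the chunk's backward pass directly (loss_fn_node.backward(...) followed by backward_chunk.backward(...)); after this change they route through the model-provided overlapped_forward_backward hook, passing an empty ScheduleChunk([]) as the forward chunk so the hook runs only the backward half. Below is a minimal sketch, not the actual Paddle implementation, of how such a hook could handle this backward-only case given the positional argument order visible in the diff; the names forward_outputs, forward_loss, and backward_grads are illustrative only.

# Minimal sketch (an assumption, not the PaddlePaddle implementation) of how a
# model's overlapped_forward_backward hook could handle the backward-only case
# introduced by this diff, where the scheduler passes ScheduleChunk([]) as the
# forward chunk. In Paddle this is a method on the model layers object
# (self._layers); it is shown as a free function here for brevity.
def overlapped_forward_backward(
    forward_chunk,               # ScheduleChunk([]) here: nothing to run forward
    forward_inputs,              # None in the backward-only case
    forward_loss_fn_node,        # None in the backward-only case
    backward_chunk,              # the schedule chunk whose gradients are needed
    backward_loss_fn_node,       # loss node on the loss-computing stage, else None
    backward_grads,              # upstream output grads, or None if the loss node seeds them
    scaler,                      # loss scaler, only consulted together with a loss node
    combine_bw_event_to_wait=None,
    pp_stream=None,
):
    # Backward-only path: seed gradients from the loss node when it exists
    # (first call site in the diff), otherwise start from the grads received
    # from the next pipeline stage (second call site).
    if backward_loss_fn_node is not None:
        backward_grads = backward_loss_fn_node.backward(scaler=scaler)
    input_grads = backward_chunk.backward(backward_grads)

    # An empty forward chunk produces no forward outputs and no loss.
    forward_outputs, forward_loss = None, None
    return forward_outputs, forward_loss, input_grads

Funnelling the backward-only case through the same hook keeps a single code path, so the scheduler can presumably hand it a real forward chunk (plus combine_bw_event_to_wait and pp_stream) when forward and backward work are meant to overlap.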