24 changes: 21 additions & 3 deletions deepspeed/runtime/pipe/engine.py
@@ -1060,7 +1060,7 @@ def _exec_recv_grads(self, buffer_id):
local_part=outputs[1],
group=self.grid.get_slice_parallel_group())
outputs[0].data = part_output.full()
-                outputs = ([outputs[0], *outputs[2:]])
+                outputs = (outputs[0], *outputs[2:])
# save for backward
self.pipe_buffers['outputs'][buffer_id] = outputs

@@ -1072,8 +1072,26 @@ def _exec_recv_grads(self, buffer_id):
dtype=outputs.dtype,
num_buffers=1)[0]
else:
-                sizes_and_dtypes = [(list(t.size()),
-                                     t.dtype) for t in outputs if t.is_floating_point()]
+                # XXX This is a HACK
+                # When we exchange activations/gradients, the two pipe stages
+                # need to issue the send/recv with the same buffer sizes or
+                # else there is a deadlock. The is_floating_point() filter is
+                # used to avoid sending gradients for tensors that do not
+                # produce gradients. When TP>1, we partition the first
+                # activations/gradients across TP ranks to save communication
+                # volume and memory. That partitioned tensor is represented as
+                # two tensors: a 1/TPth chunk of the original data and also a
+                # small LongTensor storing the metadata used to reconstruct on
+                # the other side. When combined, the floating point filter also
+                # filtered out the metadata tensor. A quick (hacky) fix just
+                # branches on is_grad_partitioned so we don't filter out the
+                # metadata tensor.
+                if self.is_grad_partitioned:
+                    sizes_and_dtypes = [(list(t.size()), t.dtype) for t in outputs]
+                else:
+                    sizes_and_dtypes = [(list(t.size()),
+                                         t.dtype) for t in outputs
+                                        if t.is_floating_point()]
self.grad_layer = self._allocate_buffers(sizes_and_dtypes,
num_buffers=1)[0]
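
For context, here is a minimal sketch of the behavior this hunk fixes. It is not DeepSpeed's actual code: `example_outputs` and the standalone `is_grad_partitioned` flag are illustrative stand-ins for the engine's pipe buffers and `self.is_grad_partitioned`. The point is that a partitioned activation/gradient travels as a floating-point chunk plus a LongTensor of reconstruction metadata, so filtering by `is_floating_point()` silently drops the metadata entry, leaving the receiving stage with fewer recv buffers than the sender posts sends for, which deadlocks the two stages.

```python
import torch

# Illustrative stand-ins; names are not from DeepSpeed's engine.
is_grad_partitioned = True

# When TP > 1 and gradients are partitioned, the first activation is exchanged
# as two tensors: a small LongTensor of metadata used to reconstruct the full
# tensor on the other side, plus a 1/TPth floating-point chunk of the data.
example_outputs = (
    torch.empty(3, dtype=torch.long),   # partition metadata (LongTensor)
    torch.randn(4, 1024),               # 1/TPth chunk of the activation
    torch.randn(4, 512),                # another activation that needs a grad
)

if is_grad_partitioned:
    # Keep every tensor so both stages allocate the same number of buffers;
    # dropping the LongTensor here would mismatch the sender and deadlock.
    sizes_and_dtypes = [(list(t.size()), t.dtype) for t in example_outputs]
else:
    # Original behaviour: only floating-point tensors produce gradients,
    # so integer tensors are skipped.
    sizes_and_dtypes = [(list(t.size()), t.dtype)
                        for t in example_outputs
                        if t.is_floating_point()]

print(sizes_and_dtypes)
# With the branch taken:
#   [([3], torch.int64), ([4, 1024], torch.float32), ([4, 512], torch.float32)]
# Without it, the ([3], torch.int64) metadata entry would be missing.
```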

Expand Down