allreduce_gradients breaks in the narrow operation #863

@tchaton

Dear people from DeepSpeed,

Associated PR: https://github.com/PyTorchLightning/pytorch-lightning/pull/6546/files

Here is the associated test that triggers this error:
https://github.com/PyTorchLightning/pytorch-lightning/blob/d273393ece9c207b02f02c38e3eb8ddb71b32605/tests/plugins/test_deepspeed_plugin.py#L413

Traceback (most recent call last):
  File "/home/ubuntu/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 589, in run_train
    self.train_loop.run_training_epoch()
  File "/home/ubuntu/pytorch-lightning/pytorch_lightning/trainer/training_loop.py", line 491, in run_training_epoch
    batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
  File "/home/ubuntu/pytorch-lightning/pytorch_lightning/trainer/training_loop.py", line 650, in run_training_batch
    self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
  File "/home/ubuntu/pytorch-lightning/pytorch_lightning/trainer/training_loop.py", line 425, in optimizer_step
    model_ref.optimizer_step(
  File "/home/ubuntu/pytorch-lightning/pytorch_lightning/core/lightning.py", line 1397, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/ubuntu/pytorch-lightning/pytorch_lightning/core/optimizer.py", line 214, in step
    self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
  File "/home/ubuntu/pytorch-lightning/pytorch_lightning/core/optimizer.py", line 134, in __optimizer_step
    trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
  File "/home/ubuntu/pytorch-lightning/pytorch_lightning/accelerators/accelerator.py", line 285, in optimizer_step
    make_optimizer_step = self.precision_plugin.pre_optimizer_step(
  File "/home/ubuntu/pytorch-lightning/pytorch_lightning/plugins/precision/deepspeed_precision.py", line 46, in pre_optimizer_step
    lambda_closure()
  File "/home/ubuntu/pytorch-lightning/pytorch_lightning/trainer/training_loop.py", line 644, in train_step_and_backward_closure
    result = self.training_step_and_backward(
  File "/home/ubuntu/pytorch-lightning/pytorch_lightning/trainer/training_loop.py", line 751, in training_step_and_backward
    self.backward(result, optimizer, opt_idx)
  File "/home/ubuntu/pytorch-lightning/pytorch_lightning/trainer/training_loop.py", line 780, in backward
    result.closure_loss = self.trainer.accelerator.backward(
  File "/home/ubuntu/pytorch-lightning/pytorch_lightning/accelerators/accelerator.py", line 268, in backward
    output = self.precision_plugin.backward(
  File "/home/ubuntu/pytorch-lightning/pytorch_lightning/plugins/precision/deepspeed_precision.py", line 72, in backward
    deepspeed_engine.backward(closure_loss, *args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1004, in backward
    self.allreduce_gradients()
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 924, in allreduce_gradients
    self.optimizer.overlapping_partition_gradients_reduce_epilogue()
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/zero/stage3.py", line 1378, in overlapping_partition_gradients_reduce_epilogue
    self.independent_gradient_partition_epilogue()
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/zero/stage3.py", line 1291, in independent_gradient_partition_epilogue
    self.averaged_gradients[i] = [
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/zero/stage3.py", line 1293, in <listcomp>
    param.grad.data.narrow(0,
RuntimeError: start (0) + length (2) exceeds dimension size (1).
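
For reference, the final frame can be reproduced in isolation. This is a minimal sketch, not DeepSpeed's code: `grad` and `partition_len` are hypothetical stand-ins, assuming the gradient tensor has a single element along dimension 0 while a slice of length 2 is requested, as the error message suggests.

```python
import torch

# Hypothetical stand-in for param.grad.data: one element along dim 0.
grad = torch.zeros(1)

# Hypothetical partition length; the traceback shows a length of 2 requested.
partition_len = 2

message = None
try:
    # narrow(dim, start, length) requires start + length <= grad.size(dim).
    grad.narrow(0, 0, partition_len)
except RuntimeError as err:
    message = str(err)
    print(message)

# A slice that stays within the tensor's bounds succeeds.
ok = grad.narrow(0, 0, min(partition_len, grad.size(0)))
print(ok.shape)
```

The clamped call only illustrates the bounds check; it is not a proposed fix. It suggests that the gradient seen at this point is smaller than the partition size the optimizer expects for that parameter.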

Any ideas on how to solve this?

Best,
T.C
