
Training hangs with DeepSpeed when DDP workers have different number of training batches #13498

Closed
@xinyangz

Description


🐛 Bug

My use case involves streaming a large dataset for distributed training. During this process, each distributed worker may get a different number of training batches. Please see the BoringModel example below for an equivalent case.

When the DeepSpeed integration is turned on, the code hangs after one full epoch: all GPUs sit at 100% utilization while their power draw stays low. I cannot pinpoint the error because a keyboard interrupt has no effect and I have to kill the processes.

Training does not hang when DeepSpeed is turned off, so I am not sure whether this is a Lightning bug or a DeepSpeed bug.
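
For context, my real input pipeline streams the data through an IterableDataset roughly like the simplified sketch below (the shard sizes and class name are made up for illustration). Each rank consumes its own shards, and the shards are not equally sized, so the per-rank batch counts differ:

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, IterableDataset


class StreamingShards(IterableDataset):
    """Each rank lazily reads its own shards; shard sizes are unequal, so
    ranks see different numbers of samples (and therefore batches) per epoch."""

    def __init__(self, shard_sizes, feature_dim=32):
        self.shard_sizes = shard_sizes  # stand-ins for real files read lazily
        self.feature_dim = feature_dim

    def __iter__(self):
        if dist.is_available() and dist.is_initialized():
            rank, world_size = dist.get_rank(), dist.get_world_size()
        else:
            rank, world_size = 0, 1
        # round-robin shard assignment: rank k reads shards k, k + world_size, ...
        for size in self.shard_sizes[rank::world_size]:
            for _ in range(size):
                yield torch.randn(self.feature_dim)


# e.g. with 2 ranks: rank 0 streams 40 + 24 = 64 samples, rank 1 streams 70
train_data = DataLoader(StreamingShards([40, 70, 24]), batch_size=2)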

To Reproduce

import os

import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        num_sanity_val_steps=0,
        max_epochs=10,
        enable_model_summary=False,
        strategy="deepspeed_stage_1",     # DeepSpeed turned on
        accelerator="gpu",
        devices=2,    # bug happens when n_gpu > 1
    )
    train_data = DataLoader(RandomDataset(32, 64 + trainer.local_rank * 4), batch_size=2)   # each DDP worker gets a different number of batches
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    # trainer.test(model, dataloaders=test_data)


if __name__ == "__main__":
    run()
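
Not sure if this helps, but my guess at the mechanism: the rank that ends up with more batches issues a gradient all-reduce that the other rank never joins, so it blocks inside the collective. The standalone sketch below (plain torch.distributed with the gloo backend on CPU, made-up step counts, nothing from Lightning or DeepSpeed) shows a mismatched collective getting stuck; I set a short timeout so it fails with an error instead of blocking indefinitely, whereas with NCCL I assume it just spins, which would match the 100% utilization at low power draw.

import datetime
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    # short timeout so the demo fails loudly instead of blocking forever
    dist.init_process_group(
        "gloo", rank=rank, world_size=world_size,
        timeout=datetime.timedelta(seconds=15),
    )
    num_steps = 3 + rank  # rank 1 runs one extra "training step"
    for step in range(num_steps):
        grad = torch.ones(1)
        dist.all_reduce(grad)  # stand-in for the per-step gradient reduction
        print(f"rank {rank} finished step {step}")
    # rank 1's extra all_reduce has no partner on rank 0, so rank 1 never gets
    # here cleanly: its last collective times out (or hangs without a timeout)
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)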

Expected behavior

Training finishes without hanging.

Environment

  • CUDA:
    - GPU:
      - A100-SXM4-40GB
      - A100-SXM4-40GB
    - available: True
    - version: 11.3
  • Packages:
    - numpy: 1.21.6
    - pyTorch_debug: False
    - pyTorch_version: 1.11.0+cu113
    - pytorch-lightning: 1.6.4
    - tqdm: 4.64.0
  • System:
    - OS: Linux
    - architecture: 64bit, ELF
    - processor: x86_64
    - python: 3.7.12
  • Any other relevant information: DeepSpeed version 0.6.5

Additional context
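
The workaround I am experimenting with is to force every rank to the same number of samples per epoch by cycling or truncating its rank-local stream. This is purely a wrapper on my side, not a Lightning or DeepSpeed feature; a rough sketch:

import itertools

from torch.utils.data import IterableDataset


class FixedQuotaStream(IterableDataset):
    """Yield exactly `samples_per_epoch` samples per epoch on every rank.

    Cycles the rank-local stream when it is too short and truncates it when it
    is too long, so all ranks run the same number of optimizer steps. Note that
    itertools.cycle buffers the items it has already seen.
    """

    def __init__(self, stream: IterableDataset, samples_per_epoch: int):
        self.stream = stream                        # the uneven, rank-local stream
        self.samples_per_epoch = samples_per_epoch  # identical constant on every rank

    def __iter__(self):
        return itertools.islice(itertools.cycle(self.stream), self.samples_per_epoch)

With the same samples_per_epoch, batch_size, and drop_last on every rank, each rank should see the same number of batches per epoch, but this changes what an epoch means, so I would still appreciate a proper fix or guidance on the intended behaviour.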

cc @justusschock @awaelchli @ninginthecloud @rohitgr7 @otaj @SeanNaren @akihironitta
