
Cannot replicate training results with seed_everything and strategy=ddp_spawn #17399

Closed
YUHUINI1995 opened this issue Apr 17, 2023 · 0 comments · Fixed by #18238
Labels: bug (Something isn't working), reproducibility, strategy: ddp (DistributedDataParallel), ver: 1.9.x

Comments


YUHUINI1995 commented Apr 17, 2023

Bug description

I noticed that strategy=ddp_spawn cannot give deterministic results: running the same seeded training twice produces different final weights. If I change strategy=ddp_spawn to strategy=ddp_fork, the training results are deterministic. I am using pytorch_lightning==1.9.5 and torch==1.11.0.
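
For reference, the only change between the failing and the passing setup is the strategy argument. A minimal sketch of the working variant (fork-based launching is not available on Windows):

trainer = pl.Trainer(
    max_epochs=1,
    accelerator="cpu",
    devices=2,
    deterministic=True,
    strategy="ddp_fork",  # instead of "ddp_spawn"; this variant gave deterministic results as described above
)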

Below is the code that reproduces the bug.

import copy
import os

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train, val = random_split(dataset, [55000, 5000])

def compare(strategy):
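    # Train the same seeded model twice with the given strategy and assert that
    # both runs end with bit-identical weights.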
    pl.seed_everything(5)
    torch.set_num_threads(1)
    autoencoder = LitAutoEncoder()
    trainer = pl.Trainer(
        max_epochs=1,
        accelerator="cpu",
        devices=2,
        enable_progress_bar=False,
        enable_model_summary=False,
        deterministic=True,
        strategy=strategy,
    )
    trainer.fit(
        autoencoder, DataLoader(train, batch_size=500), DataLoader(val, batch_size=500)
    )

    pl.seed_everything(5)
    torch.set_num_threads(1)
    copy_autoencoder = LitAutoEncoder()
    copy_trainer = pl.Trainer(
        max_epochs=1,
        accelerator="cpu",
        devices=2,
        enable_progress_bar=False,
        enable_model_summary=False,
        deterministic=True,
        strategy=strategy,
    )
    copy_trainer.fit(
        copy_autoencoder, DataLoader(train, batch_size=500), DataLoader(val, batch_size=500)
    )

    def assert_state_dict_equal(state_dict_a: dict, state_dict_b: dict):
        np.testing.assert_equal(list(state_dict_a.keys()), list(state_dict_b.keys()))
        for key in state_dict_a:
            np.testing.assert_array_equal(state_dict_a[key], state_dict_b[key])

    assert_state_dict_equal(
        copy.deepcopy(autoencoder).state_dict(),
        copy.deepcopy(copy_autoencoder).state_dict(),
    )

if __name__ == '__main__':
    strategy = "ddp_spawn"
    compare(strategy)
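
For contrast, swapping "ddp_spawn" for "ddp_fork" in the block above makes the final state-dict comparison pass, while the spawn variant fails with an AssertionError from np.testing.assert_array_equal, which is the non-determinism described above.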

What version are you seeing the problem on?

1.9.x

How to reproduce the bug

No response

Error messages and logs

# Error messages and logs here please

Environment

Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

No response

cc @awaelchli @justusschock

YUHUINI1995 added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels on Apr 17, 2023
carmocca added the reproducibility and strategy: ddp spawn labels and removed the needs triage (Waiting to be triaged by maintainers) label on Apr 17, 2023
awaelchli added the strategy: ddp (DistributedDataParallel) label and removed the strategy: ddp spawn label on Nov 4, 2023