Description
🐛 Bug
When seed_everything(workers=True) is called, it sets the environment variable PL_SEED_WORKERS=1. Consequently, the Trainer sets the worker_init_fn of the dataloaders to pl_worker_init_function. However, it seems that worker_init_fn is not set when using DDP. The reason is that DDPPlugin.setup_environment() eventually runs reset_seed(), which reads the value of the PL_GLOBAL_SEED environment variable and calls seed_everything() with the default argument workers=False.
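For reference, reset_seed() is roughly equivalent to the sketch below (simplified, based on my reading of pytorch_lightning.utilities.seed, so the exact code may differ between versions). It restores PL_GLOBAL_SEED but never consults PL_SEED_WORKERS, so the re-seeding during DDP setup falls back to workers=False:

import os

from pytorch_lightning import seed_everything


def reset_seed():
    # Simplified sketch of the current behavior: only the global seed is
    # restored; seed_everything() then uses its default workers=False and
    # overwrites PL_SEED_WORKERS with "0".
    seed = os.environ.get("PL_GLOBAL_SEED")
    if seed is not None:
        seed_everything(int(seed))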
Please reproduce using the BoringModel
It's not possible to reproduce the issue in Colab, since it doesn't support DDP.
To Reproduce
Here's a simple program that demonstrates the issue:
import os
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning import seed_everything


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    seed_everything(1234, workers=True)
    # Sets PL_SEED_WORKERS=1
    print('PL_SEED_WORKERS=' + os.environ['PL_SEED_WORKERS'])

    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        weights_summary=None,
        gpus=2,
        accelerator='ddp'  # Using accelerator='dp' works
    )
    trainer.fit(model, train_dataloader=train_data, val_dataloaders=val_data)

    # Trainer.accelerator.setup_environment() calls DDPPlugin.setup_environment(),
    # which eventually runs seed_everything(workers=False) and sets PL_SEED_WORKERS=0.
    # Consequently dataloader.worker_init_fn is not set.
    print('PL_SEED_WORKERS=' + os.environ['PL_SEED_WORKERS'])


if __name__ == '__main__':
    run()

Expected behavior
I would expect pl_worker_init_function to be called. By adding a print statement to that function, I can see that it is called with the dp accelerator, but not with ddp. I can also see that the environment variable PL_SEED_WORKERS is reset to 0 during the Trainer.fit() call, whereas I would expect it to still have the value 1 afterwards.
I think the correct fix would be to make reset_seed() read the PL_SEED_WORKERS environment variable too and pass the corresponding workers argument to seed_everything(). However, I'm not familiar enough with the code to be sure that this is correct.
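A minimal sketch of that idea (untested, just to illustrate the proposal) would have reset_seed() restore both environment variables:

import os

from pytorch_lightning import seed_everything


def reset_seed():
    # Proposed behavior: restore the workers flag as well, so that the
    # worker_init_fn logic survives the re-seeding during DDP setup.
    seed = os.environ.get("PL_GLOBAL_SEED")
    workers = os.environ.get("PL_SEED_WORKERS", "0")
    if seed is not None:
        seed_everything(int(seed), workers=bool(int(workers)))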
Preferably pl_worker_init_function would also display a log message that confirms that the workers are seeded correctly.
Environment
- CUDA:
  - GPU:
    - NVIDIA Tesla V100-SXM2-16GB
    - NVIDIA Tesla V100-SXM2-16GB
  - available: True
  - version: 11.0
- Packages:
  - numpy: 1.19.2
  - pyTorch_debug: True
  - pyTorch_version: 1.7.0
  - pytorch-lightning: 1.4.0dev
  - tqdm: 4.51.0
- System:
  - OS: Linux
  - architecture:
    - 64bit
    - ELF
  - processor: x86_64
  - python: 3.8.8
  - version: #47-Ubuntu SMP Tue May 11 15:51:42 UTC 2021
Additional context
Recently there was discussion about a data-loading issue where the same NumPy random seed is used across different workers. This causes the workers to use the same random numbers for data transforms. A fix was quickly introduced in PyTorch Lightning that seeds the dataloader workers correctly by automatically setting the worker_init_fn for the dataloaders.
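Until that worker_init_fn is applied reliably under DDP, the usual manual workaround is the standard PyTorch recipe of passing a worker_init_fn that derives a per-worker seed, along the lines of the sketch below (a generic version of the technique, not Lightning's exact pl_worker_init_function):

import random

import numpy as np
import torch
from torch.utils.data import DataLoader


def seed_worker(worker_id):
    # torch.initial_seed() already differs per worker (base seed + worker id),
    # so derive the NumPy and stdlib `random` seeds from it.
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


# e.g. with the RandomDataset from the reproduction script above
train_data = DataLoader(RandomDataset(32, 64), batch_size=2, num_workers=2, worker_init_fn=seed_worker)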