Getting strange validation loss/metric values when multiple data-loaders are used #9683
Hey @raman-rajarathinam, I apologise for the confusion. Let me explain what was wrong.
```python
def aggregate_validation_metrics(self, val_outputs, loss_name):
    tot_loss: torch.FloatTensor = torch.tensor(0.0, device=self.device)
    tot_loss += sum(val_outputs) / len(val_outputs)
    tot_loss = self.accelerator.reduce(tot_loss)
    if self.trainer.is_global_rank_zero:
        self.log(
            f"tot_{loss_name}",
            tot_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
            rank_zero_only=True,
        )
```

Here the loss is reduced across processes by hand with `self.accelerator.reduce`, and `self.log` is then called only on rank zero while still being asked to synchronize across ranks (`sync_dist=True`). `self.log` should instead be called identically on every rank and left to do any cross-process reduction itself.
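To see why a rank-guarded collective is the core mistake, here is a toy, self-contained illustration outside of Lightning (a sketch, not the library's internals: it uses raw `torch.distributed` with the gloo backend on CPU). Collectives such as the all-reduce behind `sync_dist=True` must be entered by every rank:

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Pretend each rank computed a different validation loss.
    loss = torch.tensor(float(rank + 1))

    # Correct: every rank enters the collective, so every rank gets the mean.
    # If this call were guarded with `if rank == 0:`, rank 0 would block
    # forever waiting for the other ranks to join -- the same failure mode as
    # calling self.log(..., sync_dist=True) from rank zero only.
    dist.all_reduce(loss, op=dist.ReduceOp.SUM)
    loss /= world_size
    print(f"rank {rank}: mean loss = {loss.item()}")

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```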
The correct script looks like this:

```python
import pytorch_lightning as pl
import torch
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.trainer.trainer import Trainer
from torch.nn import functional as F

from mnist_datamodule import MNISTDataModule

pl.seed_everything(42)


class LitClassifier(pl.LightningModule):
    def __init__(self, hidden_dim: int = 128, learning_rate: float = 0.0001):
        super().__init__()
        self.save_hyperparameters()
        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log(
            "train_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        return loss

    def validation_step(self, batch, batch_idx, dataset_idx=None):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log(
            "val_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def test_step(self, batch, batch_idx, dataset_idx=None):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log(
            "test_loss",
            loss,
            prog_bar=True,
            logger=True,
        )
        return loss

    def aggregate_validation_metrics(self, val_outputs, loss_name):
        tot_loss: torch.FloatTensor = torch.tensor(0.0, device=self.device)
        # multi data loader
        if isinstance(val_outputs[0], list):
            for loss in val_outputs:
                tot_loss += sum(loss) / len(loss)
            tot_loss = tot_loss / len(val_outputs)
        # single data loader
        else:
            tot_loss += sum(val_outputs) / len(val_outputs)
        self.log(
            f"tot_{loss_name}",
            tot_loss,
            prog_bar=True,
            logger=True,
        )

    def validation_epoch_end(self, val_outputs):
        self.aggregate_validation_metrics(val_outputs, loss_name="val_loss")

    def test_epoch_end(self, val_outputs):
        self.aggregate_validation_metrics(val_outputs, loss_name="test_loss")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer


def main():
    model = LitClassifier()
    data_module = MNISTDataModule()
    trainer = Trainer(
        gpus=2,
        max_epochs=5,
        num_sanity_val_steps=0,
        logger=TensorBoardLogger("mnist_logs", name="mnist"),
        accelerator="ddp",
    )
    trainer.fit(model, data_module)
    trainer.test(ckpt_path="best")


if __name__ == "__main__":
    main()
```

Best,
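For completeness: `MNISTDataModule` comes from a local `mnist_datamodule.py` that is not shown in the thread. A minimal sketch of what such a module might look like (hypothetical; the actual file may differ), with two validation dataloaders to match the issue:

```python
# Hypothetical sketch of the `mnist_datamodule` module imported above; names
# and details are assumptions, not the issue author's code.
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModule(LightningDataModule):
    def __init__(self, data_dir: str = "./data", batch_size: int = 64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.ToTensor()

    def prepare_data(self):
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        full = MNIST(self.data_dir, train=True, transform=self.transform)
        # split the training set into train + two validation subsets
        self.train_set, self.val_set1, self.val_set2 = random_split(
            full, [50000, 5000, 5000], generator=torch.Generator().manual_seed(42)
        )
        self.test_set = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        # two dataloaders, producing dataloader_idx_0 / dataloader_idx_1 metrics
        return [
            DataLoader(self.val_set1, batch_size=self.batch_size),
            DataLoader(self.val_set2, batch_size=self.batch_size),
        ]

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size)
```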
I will be closing the issue, as everything is working as expected. Best,
I get the same results when I use …
Here is a smaller, synthetic repro example:

```python
import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import TensorBoardLogger


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class Dataset1(Dataset):
    def __getitem__(self, item):
        return [1, 2, 3, 4, 5][item]

    def __len__(self):
        return 5


class Dataset2(Dataset):
    def __getitem__(self, item):
        return [2, 4, 6, 8, 10][item]

    def __len__(self):
        return 5


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx, dataset_idx):
        if self.current_epoch == 0:
            self.log("val_loss", batch, on_step=False, on_epoch=True, prog_bar=True, logger=True)
            return batch.item()
        else:
            self.log("val_loss", batch * 10, on_step=False, on_epoch=True, prog_bar=True, logger=True)
            return batch.item() * 10

    def validation_epoch_end(self, outputs):
        if self.current_epoch == 0:
            assert sum(outputs[0]) / 5 == 3
            assert sum(outputs[1]) / 5 == 6
        else:
            assert sum(outputs[0]) / 5 == 30
            assert sum(outputs[1]) / 5 == 60
        tot_loss = torch.tensor(0.0)
        for loss in outputs:
            tot_loss += sum(loss) / len(loss)
        tot_loss = tot_loss / len(outputs)
        if self.current_epoch == 0:
            assert tot_loss == (3 + 6) / 2
        else:
            assert tot_loss == (30 + 60) / 2
        self.log("tot_val_loss", tot_loss, prog_bar=True, logger=True)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data1 = DataLoader(Dataset1())
    val_data2 = DataLoader(Dataset2())

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        num_sanity_val_steps=0,
        max_epochs=3,
        log_every_n_steps=1,
        weights_summary=None,
        logger=TensorBoardLogger("mnist_logs", name="mnist2"),
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=[val_data1, val_data2])


if __name__ == "__main__":
    run()
```

The assertions hold, but the progress bar logging (and TensorBoard logging) does not show the same values. It is clear that the tracking of the loss for the individual `dataloader_idx` parts does not get reset from one epoch to the next. Instead, the values keep aggregating.
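If the built-in epoch accumulation misbehaves like this, one workaround is to bypass `self.log`'s accumulation entirely and aggregate with `torchmetrics`, resetting explicitly at the end of each epoch. A hedged sketch, not from the thread (it assumes `torchmetrics` >= 0.5 is installed, and the model and dataloader shapes mirror the repro above):

```python
import torch
import torchmetrics
from pytorch_lightning import LightningModule


class ManuallyAggregatingModel(LightningModule):
    def __init__(self, num_val_loaders: int = 2):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        # One running mean per validation dataloader; registering the metrics
        # as submodules moves them to the right device automatically.
        self.val_means = torch.nn.ModuleList(
            [torchmetrics.MeanMetric() for _ in range(num_val_loaders)]
        )

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        loss = batch.float().mean()
        self.val_means[dataloader_idx].update(loss)
        return loss

    def validation_epoch_end(self, outputs):
        # average within each loader, then across loaders
        per_loader = torch.stack([m.compute() for m in self.val_means])
        self.log("tot_val_loss", per_loader.mean(), prog_bar=True, logger=True)
        for m in self.val_means:
            m.reset()  # explicit reset: nothing can carry over between epochs

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)
```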
🐛 Bug

Hi, I have been using PL 1.3.x all along; when I updated to 1.4.x (I have tried from 1.4.0 to 1.4.8) I started getting weird values for the validation loss/metric. Training uses 2 GPUs, `ddp`, and 2 dataloaders for validation. At `validation_epoch_end` I aggregate (average) the results of `dataloader_idx_0` and `dataloader_idx_1`, but when I check the values printed by `self.log`, they don't add up.

[attachments: the aggregate method used, and its results]

Expected behavior

Correct aggregated (averaged) values at `validation_epoch_end`.

To Reproduce

I have used an MNIST model and have attached the code; run `python simple_classifier.py`.

Environment

pytorch-lightning installed via pip.