Getting strange validation loss/metric values when multiple data-loaders are used #9683

Closed
raman-r-4978 opened this issue Sep 24, 2021 · 7 comments · Fixed by #9717

Labels: bug (Something isn't working) · help wanted (Open to be worked on) · priority: 0 (High priority task)

@raman-r-4978

Environment

  • PyTorch Lightning Version: 1.4.8 (This happens in all 1.4.x)
  • PyTorch Version: 1.9.1
  • Python version: 3.8
  • OS: Linux
  • CUDA/cuDNN version: 11.4
  • GPU models and configuration: MNIST, ddp
  • How you installed PyTorch: pip (installed together with pytorch-lightning)
  • Any other relevant information: everything worked up to 1.3.8

🐛 Bug

Hi, I have been using PL 1.3.x all along. When I updated to 1.4.x (I have tried 1.4.0 through 1.4.8), I started getting weird values for the validation loss/metric. Training uses 2 GPUs with DDP and 2 dataloaders for validation.

At validation_epoch_end I aggregate (average) the results of dataloader_idx_0 and dataloader_idx_1, but the values printed by self.log don't add up.

The aggregation method used:

def aggregate_validation_metrics(self, val_outputs, loss_name):
    tot_loss: torch.FloatTensor = torch.tensor(0.0, device=self.device)
    # multi data loader
    if isinstance(val_outputs[0], list):
        for loss in val_outputs:
            tot_loss += sum(loss) / len(loss)
        tot_loss = tot_loss / len(val_outputs)
    # single data loader
    else:
        tot_loss += sum(val_outputs) / len(val_outputs)

    self.log(
        f"tot_{loss_name}",
        tot_loss,
        on_step=False,
        on_epoch=True,
        prog_bar=True,
        logger=True,
        sync_dist=True,
        rank_zero_only=True,
    )
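
For concreteness, here is what this method computes on made-up numbers (illustrative only, not values from the actual run):

# Illustrative only: two validation dataloaders with invented per-batch losses.
val_outputs = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
per_loader_means = [sum(losses) / len(losses) for losses in val_outputs]  # [2.0, 5.0]
tot_loss = sum(per_loader_means) / len(per_loader_means)  # (2.0 + 5.0) / 2 = 3.5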

and its results:

[Screenshot: PL 1.4.8 validation results]

Expected behavior

Correctly aggregated (averaged) values at validation_epoch_end.
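
Concretely, with two validation dataloaders the expected value is the mean of the per-dataloader averages:

tot_val_loss = (val_loss/dataloader_idx_0 + val_loss/dataloader_idx_1) / 2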

To Reproduce

I have used an MNIST model and attached the code:

  1. simple_classifier.py
  2. mnist_datamodule.py

Run: python simple_classifier.py

@raman-r-4978 raman-r-4978 added bug Something isn't working help wanted Open to be worked on labels Sep 24, 2021
@tchaton tchaton added the priority: 0 High priority task label Sep 24, 2021
@tchaton tchaton self-assigned this Sep 24, 2021
@tchaton
Contributor

tchaton commented Sep 24, 2021

Hey @raman-rajarathinam,

I apologise that the self.log behaviour was confusing. After checking your script, everything seems fine.

Let me explain what was wrong.

  • sync_dist should be used only when you want to perform an on_step reduction across processes. This is quite expensive.
  • rank_zero_only=True should be used ONLY within epoch_end hooks, and only if you perform the cross-process reduction yourself and log on a single process, as below.
def aggregate_validation_metrics(self, val_outputs, loss_name):
    tot_loss: torch.FloatTensor = torch.tensor(0.0, device=self.device)
    tot_loss += sum(val_outputs) / len(val_outputs)
    # perform the cross-process reduction manually
    tot_loss = self.accelerator.reduce(tot_loss)

    if self.trainer.is_global_zero:
        self.log(
            f"tot_{loss_name}",
            tot_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            rank_zero_only=True,
        )

The corrected script looks like this:

import pytorch_lightning as pl
import torch
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.trainer.trainer import Trainer
from torch.nn import functional as F

from mnist_datamodule import MNISTDataModule

pl.seed_everything(42)


class LitClassifier(pl.LightningModule):
    def __init__(self, hidden_dim: int = 128, learning_rate: float = 0.0001):
        super().__init__()
        self.save_hyperparameters()
        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log(
            f"train_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        return loss

    def validation_step(self, batch, batch_idx, dataset_idx=None):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log(
            f"val_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def test_step(self, batch, batch_idx, dataset_idx=None):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log(
            f"test_loss",
            loss,
            prog_bar=True,
            logger=True,
        )
        return loss

    def aggregate_validation_metrics(self, val_outputs, loss_name):
        tot_loss: torch.FloatTensor = torch.tensor(0.0, device=self.device)
        # multi data loader
        if isinstance(val_outputs[0], list):
            for loss in val_outputs:
                tot_loss += sum(loss) / len(loss)
            tot_loss = tot_loss / len(val_outputs)
        # single data loader
        else:
            tot_loss += sum(val_outputs) / len(val_outputs)

        self.log(
            f"tot_{loss_name}",
            tot_loss,
            prog_bar=True,
            logger=True,
        )

    def validation_epoch_end(self, val_outputs):
        self.aggregate_validation_metrics(val_outputs, loss_name="val_loss")

    def test_epoch_end(self, val_outputs):
        self.aggregate_validation_metrics(val_outputs, loss_name="test_loss")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer


def main():
    model = LitClassifier()
    data_module = MNISTDataModule()
    trainer = Trainer(
        gpus=2,
        max_epochs=5,
        num_sanity_val_steps=0,
        logger=TensorBoardLogger("mnist_logs", name="mnist"),
        accelerator="ddp",
    )
    trainer.fit(model, data_module)
    trainer.test(ckpt_path="best")


if __name__ == "__main__":
    main()
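
(Note: with multiple validation dataloaders, validation_epoch_end receives the outputs as a list of per-dataloader lists, i.e. outputs[dataloader_idx][batch_idx], which is why aggregate_validation_metrics branches on isinstance(val_outputs[0], list).)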

[Screenshot: 2021-09-24 at 12:39:15]

Best,
T.C

@tchaton
Contributor

tchaton commented Sep 24, 2021

I will be closing the issue, as everything is working as expected.

Best,
T.C

@tchaton tchaton closed this as completed Sep 24, 2021
@raman-r-4978
Author

raman-r-4978 commented Sep 24, 2021

Thanks for the corrections to self.log, but I don't think this answers the question. You haven't tried with 2 dataloaders; you would get a totally different graph. I generated this graph with the code you suggested above:

[Screenshot: 2021-09-24 at 6:31:02 PM]

@raman-r-4978
Author

This is what happens when PL 1.3.8 is used for the same code. May I know why there is such a difference?

[Screenshot: 2021-09-24 at 6:39:19 PM]

@raman-r-4978
Author

raman-r-4978 commented Sep 24, 2021

FYI: these are the results when PL 1.3.8 is used. You can easily see that the aggregated average values (tot_val_loss) are correct, i.e. tot_val_loss = (val_loss/dataloader_idx_0 + val_loss/dataloader_idx_1) / 2.

[Screenshot: 2021-09-24 at 11:21:09 PM]

@tchaton tchaton reopened this Sep 25, 2021
@raman-r-4978
Author

I get the same results when I use gpus=1, so it is not an issue with DDP?

@awaelchli
Contributor

Here is a smaller, synthetic repro example:

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import TensorBoardLogger


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class Dataset1(Dataset):
    def __getitem__(self, item):
        return [1, 2, 3, 4, 5][item]

    def __len__(self):
        return 5


class Dataset2(Dataset):
    def __getitem__(self, item):
        return [2, 4, 6, 8, 10][item]

    def __len__(self):
        return 5


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx, dataset_idx):
        if self.current_epoch == 0:
            self.log("val_loss", batch, on_step=False, on_epoch=True, prog_bar=True, logger=True)
            return batch.item()
        else:
            self.log("val_loss", batch * 10, on_step=False, on_epoch=True, prog_bar=True, logger=True)
            return batch.item() * 10

    def validation_epoch_end(self, outputs):
        if self.current_epoch == 0:
            assert sum(outputs[0]) / 5 == 3
            assert sum(outputs[1]) / 5 == 6
        else:
            assert sum(outputs[0]) / 5 == 30
            assert sum(outputs[1]) / 5 == 60

        tot_loss = torch.tensor(0.0)
        for loss in outputs:
            tot_loss += sum(loss) / len(loss)
        tot_loss = tot_loss / len(outputs)
        if self.current_epoch == 0:
            assert tot_loss == (3 + 6) / 2
        else:
            assert tot_loss == (30 + 60) / 2
        self.log("tot_val_loss", tot_loss, prog_bar=True, logger=True)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data1 = DataLoader(Dataset1())
    val_data2 = DataLoader(Dataset2())

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        num_sanity_val_steps=0,
        max_epochs=3,
        log_every_n_steps=1,
        weights_summary=None,
        logger=TensorBoardLogger("mnist_logs", name="mnist2"),
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=[val_data1, val_data2])


if __name__ == "__main__":
    run()

The assertions hold, but the progress bar logging (and TensorBoard logging) does not show the same values. It is clear that the tracked loss for the individual dataloader_idx parts does not get reset from one epoch to the next; instead, it keeps aggregating.
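
A minimal sketch of that failure mode (illustrative only: a generic running-mean accumulator, not PL's internals):

# Illustrative only: a running mean that is never reset keeps folding
# new epochs into the old average, which is the symptom above.
class RunningMean:
    def __init__(self):
        self.total, self.count = 0.0, 0

    def update(self, value):
        self.total += value
        self.count += 1

    def compute(self):
        return self.total / self.count

    def reset(self):
        self.total, self.count = 0.0, 0


metric = RunningMean()
for v in (1, 2, 3, 4, 5):  # epoch 0, true mean 3
    metric.update(v)
assert metric.compute() == 3.0

# Skipping metric.reset() here reproduces the symptom: epoch 1's values
# get folded into epoch 0's state instead of starting fresh.
for v in (10, 20, 30, 40, 50):  # epoch 1, true mean 30
    metric.update(v)
print(metric.compute())  # 16.5, not 30.0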
