The Lightning 2.0 Trainer issues a length-zero DataLoader / CombinedLoader warning when num_sanity_val_steps=0 #17193

Closed
mishooax opened this issue Mar 25, 2023 · 5 comments · Fixed by #17218
Labels
bug (Something isn't working) · data handling (Generic data-related topic)

Comments

@mishooax

mishooax commented Mar 25, 2023

Bug description

Hello, after upgrading to PyTorch Lightning 2.0, my trainer.fit started issuing the following warnings:

(...)pytorch_lightning/utilities/data.py:105: UserWarning: Total length of `CombinedLoader` across ranks is zero. Please make sure this was your intention.
  rank_zero_warn(
(...)pytorch_lightning/utilities/data.py:105: UserWarning: Total length of `DataLoader` across ranks is zero. Please make sure this was your intention.
  rank_zero_warn(

This is a single-GPU run with strategy="auto", devices=1, accelerator="gpu", and num_sanity_val_steps=0. I suspect that disabling the validation sanity check may be triggering this warning. My val_dataloader definitely has length >> 1 (I am waiting for the validation epoch to run through at this very moment...).

How to reproduce the bug

No response

Error messages and logs

(...)pytorch_lightning/utilities/data.py:105: UserWarning: Total length of `CombinedLoader` across ranks is zero. Please make sure this was your intention.
  rank_zero_warn(
(...)pytorch_lightning/utilities/data.py:105: UserWarning: Total length of `DataLoader` across ranks is zero. Please make sure this was your intention.
  rank_zero_warn(

Environment

Current environment
- Lightning Component: Trainer
- PyTorch Lightning Version: 2.0
- Lightning App Version: n/a
- PyTorch Version: 2.0
- Python version: 3.9
- OS: Linux
- CUDA/cuDNN version: 11.8
- GPU models and configuration: A100
- How you installed Lightning (`conda`, `pip`, source): conda
- Running environment of LightningApp: n/a

More info

No response

cc @justusschock @awaelchli

@mishooax mishooax added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Mar 25, 2023
@awaelchli awaelchli added data handling Generic data-related topic and removed needs triage Waiting to be triaged by maintainers labels Mar 26, 2023
@awaelchli (Contributor)

I don't see the warning when setting num_sanity_val_steps=0. Could you include your data loading code in the snippet below so we can reproduce it?

import os

import torch
from torch.utils.data import DataLoader, Dataset

from lightning.pytorch import LightningModule, Trainer


# Map-style dataset: defines __len__, so the DataLoader's length is known up front.
class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)



if __name__ == "__main__":
    run()

@mishooax (Author)

Thanks @awaelchli. I'll try to come up with a minimal example built on top of the code you shared above. It won't be easy, as the dataset code is rather long and coupled to other components. One thing worth noting: my dataloader is built around an IterableDataset. Maybe that's what is causing the issue, i.e. the Lightning code can't figure out the total length of the dataloader?
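
A quick way to test that hypothesis is a minimal sketch (StreamDataset below is a hypothetical stand-in for my real dataset): a DataLoader built around an IterableDataset that defines no __len__ raises TypeError when asked for its length, so the total length can't be known up front.

import torch
from torch.utils.data import DataLoader, IterableDataset


class StreamDataset(IterableDataset):
    # hypothetical stand-in: yields samples but defines no __len__
    def __iter__(self):
        for _ in range(64):
            yield torch.randn(32)


loader = DataLoader(StreamDataset(), batch_size=2)

try:
    print(len(loader))
except TypeError as err:
    # the underlying IterableDataset defines no __len__, so len(loader) fails
    print(f"len() unavailable: {err}")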

@yosefahab

Getting a similar warning when using a custom IterableDataset:
Total length of `CombinedLoader` across ranks is zero. Please make sure this was your intention.
As well as some semaphore leaks, interestingly enough 🤔

@awaelchli (Contributor)

Yes, it is because of the IterableDataset. With this hint, here is a reproducible example:

import os

import torch
from torch.utils.data import DataLoader, IterableDataset

from lightning.pytorch import LightningModule, Trainer


# Iterable-style dataset: defines no __len__, so the DataLoader's length
# cannot be determined up front (this is the source of the warning).
class RandomDataset(IterableDataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __iter__(self):
        for i in range(self.len):
            yield self.data[i]


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()
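
Note the key difference from the first snippet: RandomDataset is now an IterableDataset, which defines no __len__, so the DataLoader's total length cannot be determined up front and the zero-length warning fires.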

@carmocca (Contributor)

The relevant logic to address is in this function: https://github.com/Lightning-AI/lightning/blob/4f82068bcf21f5008aecd46426c806514209112c/src/lightning/pytorch/utilities/data.py#L91-L133

I opened #17218, which changes it to skip all those checks when __len__ is not defined, and to warn only when one or all ranks report a length of 0. This is more in line with what the function originally intended in #9827.
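
A minimal sketch of that guarded check, simplified to a single rank (check_total_length is an illustrative name, not Lightning's actual internal function):

from torch.utils.data import DataLoader


def check_total_length(dataloader: DataLoader) -> None:
    try:
        # raises TypeError for loaders backed by an IterableDataset without __len__
        length = len(dataloader)
    except (TypeError, NotImplementedError):
        # length unknown up front: skip the check entirely instead of warning
        return
    if length == 0:
        print(
            "Total length of `DataLoader` across ranks is zero. "
            "Please make sure this was your intention."
        )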
