
Optimize fit_loop() to reduce train_dataloader()'s memory footprint #20382

Open

guillaume-rochette-oxb opened this issue Nov 1, 2024 · 2 comments
Labels: feature (Is an improvement or enhancement), repro needed (The issue is missing a reproducible example)

guillaume-rochette-oxb commented Nov 1, 2024

Description & Motivation

Hi,

I have noticed that the train_dataloader()'s workers were still up, idle but holding on to resources, whilst the val_dataloader()'s workers were actively delivering batches.
After some investigation, I found the following pseudo-code describing fit(), here simplified:

def fit(self):
    [...]
    for epoch in epochs:
        fit_loop()
    [...]

def fit_loop():
    [...]
    for batch in train_dataloader():
        [...]
        if should_check_val:
            val_loop()
        [...]
    [...]

def val_loop():
    [...]
    for batch in val_dataloader():
        [...]
    [...]

And the actual behaviour matches the pseudo-code, so this is not a bug and it is working as intended.

However, I've been struggling to balance data-processing speed against memory footprint when running instance-segmentation experiments on large, dense, non-public datasets.

I understand that when val_check_interval is not None, running the val_loop within the train_dataloader() loop is necessary. However, when val_check_interval is None, I think it would be beneficial to modify fit_loop() to something like,

def fit_loop():
    [...]
    for batch in train_dataloader():
        [...]
        if should_check_val and val_check_interval is not None:
            val_loop()
        [...]
    [...]
    if should_check_val and val_check_interval is None:
        val_loop()
    [...]

That way resources would be freed as soon as they're not needed.
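
For reference, this is also how a plain PyTorch DataLoader behaves on its own: with persistent_workers=False, the worker processes live for as long as the iterator object does, and releasing the exhausted iterator is what frees their RAM. A minimal sketch illustrating that behaviour (the sizes and worker counts are only illustrative):

import multiprocessing as mp

import torch
from torch.utils.data import DataLoader, TensorDataset


def main():
    dataset = TensorDataset(torch.randn(1024, 1))
    loader = DataLoader(
        dataset, batch_size=32, num_workers=4, persistent_workers=False
    )

    it = iter(loader)                 # spawns the 4 worker processes
    _ = next(it)                      # workers are now up and prefetching
    print(len(mp.active_children()))  # typically 4 while the iterator is alive

    del it                            # releasing the iterator shuts the workers down
    print(len(mp.active_children()))  # back to 0 once the shutdown has completed


if __name__ == "__main__":
    main()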

Pitch

Within the implementation, val_loop() is called from within on_advance_end(), and the fit_loop() inside run() is considerably different from the pseudo-code.
I'm assuming that we would need to modify and re-use on_advance_end() after the completion of the while-loop in run().

Is this correct?
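
In the same simplified pseudo-code style as above (the names mirror run() and on_advance_end(), but this is only an illustrative sketch, not the actual implementation), I picture the change roughly as:

def run():
    [...]
    while not done:
        advance()         # runs one training batch from train_dataloader()
        on_advance_end()  # currently also runs val_loop() here when due
    [...]
    # proposed: when val_check_interval is None, the train_dataloader() is
    # exhausted at this point; release its workers first, then run the
    # deferred val_loop(), e.g. by re-using on_advance_end()
    [...]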

Alternatives

No response

Additional context

I have made this boring.py to illustrate the situation and to have a concrete example to debug on:

import torch
from torch import Tensor
from torch.nn import Linear, MSELoss
from torch.optim import AdamW
from torch.utils.data import ConcatDataset, Dataset, DataLoader

from torchmetrics import regression

from lightning.pytorch import LightningModule, LightningDataModule, Trainer


class BoringDataset(Dataset):
    def __init__(self, num_samples: int):
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index: int) -> dict[str, Tensor]:
        x = torch.randn(1, dtype=torch.float32)
        y = 5.0 * x + 2.0
        return {"x": x, "y": y}


class BoringDataModule(LightningDataModule):
    train_datasets: list[BoringDataset]
    val_datasets: list[BoringDataset]
    test_datasets: list[BoringDataset]
    predict_datasets: list[BoringDataset]

    def __init__(
        self, num_datasets: int, num_samples: int, batch_size: int, num_workers: int
    ):
        super().__init__()
        self.num_datasets = num_datasets
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        pass

    def setup(self, stage: str | None = None):
        assert stage in ["all", "fit", "validate", "test", "predict", None]

        if stage in ["fit", "all"]:
            self.train_datasets = [
                BoringDataset(num_samples=self.num_samples)
                for _ in range(self.num_datasets)
            ]

        if stage in ["fit", "validate", "all"]:
            self.val_datasets = [
                BoringDataset(num_samples=self.num_samples)
                for _ in range(self.num_datasets)
            ]

        if stage in ["test", "all"]:
            self.test_datasets = [
                BoringDataset(num_samples=self.num_samples)
                for _ in range(self.num_datasets)
            ]

        if stage in ["predict", "all"]:
            self.predict_datasets = [
                BoringDataset(num_samples=self.num_samples)
                for _ in range(self.num_datasets)
            ]

    def teardown(self, stage: str | None = None):
        assert stage in ["all", "fit", "validate", "test", "predict", None]

        if stage in ["fit", "all"]:
            del self.train_datasets

        if stage in ["fit", "validate", "all"]:
            del self.val_datasets

        if stage in ["test", "all"]:
            del self.test_datasets

        if stage in ["predict", "all"]:
            del self.predict_datasets

    def train_dataloader(
        self,
    ) -> DataLoader:
        kwargs = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "pin_memory": False,
            "drop_last": True,
            "persistent_workers": False,
            "shuffle": True,
        }
        dataloader = DataLoader(ConcatDataset(self.train_datasets), **kwargs)
        return dataloader

    def val_dataloader(self) -> list[DataLoader]:
        kwargs = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "pin_memory": False,
            "drop_last": False,
            "persistent_workers": False,
            "shuffle": False,
        }
        dataloaders = [DataLoader(dataset, **kwargs) for dataset in self.val_datasets]
        return dataloaders

    def test_dataloader(self) -> list[DataLoader]:
        kwargs = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "pin_memory": False,
            "drop_last": False,
            "persistent_workers": False,
            "shuffle": False,
        }
        dataloaders = [DataLoader(dataset, **kwargs) for dataset in self.test_datasets]
        return dataloaders

    def predict_dataloader(self) -> list[DataLoader]:
        kwargs = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "pin_memory": False,
            "drop_last": False,
            "persistent_workers": False,
            "shuffle": False,
        }
        dataloaders = [
            DataLoader(dataset, **kwargs) for dataset in self.predict_datasets
        ]
        return dataloaders


class BoringModule(LightningModule):
    val_dataloader_idx: int = 0
    test_dataloader_idx: int = 0
    predict_dataloader_idx: int = 0

    def __init__(
        self, num_datasets: int, num_samples: int, batch_size: int, num_workers: int
    ):
        super().__init__()

        self.num_datasets = num_datasets
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        pass

    def setup(self, stage: str | None = None):
        assert stage in ["all", "fit", "validate", "test", "predict", None]

        self.datamodule = BoringDataModule(
            num_datasets=self.num_datasets,
            num_samples=self.num_samples,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )
        self.datamodule.setup(stage=stage)

        if stage in ["fit", "all"]:
            self.loss_function = MSELoss()
            self.train_metric = regression.MeanSquaredError()

        if stage in ["fit", "validate", "all"]:
            self.val_metric = regression.MeanSquaredError()

        if stage in ["test", "all"]:
            self.test_metric = regression.MeanSquaredError()

        if stage in ["predict", "all"]:
            self.predict_metric = regression.MeanSquaredError()

    def configure_model(self):
        self.model = Linear(in_features=1, out_features=1, bias=True)

    def teardown(self, stage: str | None = None):
        assert stage in ["fit", "validate", "test", "predict", "all", None]

        self.datamodule.teardown(stage=stage)
        del self.datamodule

        del self.model

        if stage in ["fit", "all"]:
            del self.loss_function
            del self.train_metric

        if stage in ["fit", "validate", "all"]:
            del self.val_metric

        if stage in ["test", "all"]:
            del self.test_metric

        if stage in ["predict", "all"]:
            del self.predict_metric

    def train_dataloader(self) -> DataLoader:
        return self.datamodule.train_dataloader()

    def val_dataloader(self) -> list[DataLoader]:
        return self.datamodule.val_dataloader()

    def test_dataloader(self) -> list[DataLoader]:
        return self.datamodule.test_dataloader()

    def predict_dataloader(self) -> list[DataLoader]:
        return self.datamodule.predict_dataloader()

    def forward(self, input: dict) -> dict:
        return {
            "y": self.model(input["x"]),
        }

    def training_step(
        self,
        input: dict,
        batch_idx: int,
    ) -> Tensor:
        output = self(input)

        train_loss = self.loss_function(input=output["y"], target=input["y"])
        self.train_metric.update(preds=output["y"], target=input["y"])

        self.log_dict(
            dictionary={"train_loss": train_loss},
            prog_bar=True,
            sync_dist=not self.training,
            add_dataloader_idx=False,
        )

        self.log_dict(
            dictionary={"train_metric": self.train_metric},
            sync_dist=not self.training,
            add_dataloader_idx=False,
        )

        return train_loss

    def validation_step(
        self,
        input: dict,
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        if self.val_dataloader_idx != dataloader_idx:
            self.val_dataloader_idx = dataloader_idx
            self.val_metric.reset()

        output = self(input)

        self.val_metric.update(preds=output["y"], target=input["y"])

        self.log_dict(
            dictionary={f"val_metric/{dataloader_idx}": self.val_metric},
            sync_dist=not self.training,
            add_dataloader_idx=False,
        )

    def test_step(
        self,
        input: dict,
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        if self.test_dataloader_idx != dataloader_idx:
            self.test_dataloader_idx = dataloader_idx
            self.test_metric.reset()

        output = self(input)

        self.test_metric.update(preds=output["y"], target=input["y"])

        self.log_dict(
            dictionary={f"test_metric/{dataloader_idx}": self.test_metric},
            sync_dist=not self.training,
            add_dataloader_idx=False,
        )

    def predict_step(
        self,
        input: dict,
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        if self.predict_dataloader_idx != dataloader_idx:
            self.predict_dataloader_idx = dataloader_idx
            self.predict_metric.reset()

        output = self(input)

        self.predict_metric.update(preds=output["y"], target=input["y"])

        self.log_dict(
            dictionary={f"predict_metric/{dataloader_idx}": self.predict_metric},
            sync_dist=not self.training,
            add_dataloader_idx=False,
        )

    def configure_optimizers(self):
        return {
            "optimizer": AdamW(
                self.model.parameters(),
                lr=1e-1,
            ),
        }


def main():
    module = BoringModule(
        num_datasets=2,
        num_samples=10000,
        batch_size=32,
        num_workers=1,
    )
    trainer = Trainer(
        logger=True,
        max_epochs=10,
        num_sanity_val_steps=0,
        log_every_n_steps=1,
        gradient_clip_val=1.0,
        benchmark=True,
        detect_anomaly=False,
        sync_batchnorm=True,
        # reload_dataloaders_every_n_epochs=0, # Neither of those two options have any effect
        # reload_dataloaders_every_n_epochs=1, # on the lifetime of the train_dataloader()'s workers
    )
    trainer.fit(model=module)


if __name__ == "__main__":
    main()

cc @Borda

lantiga (Collaborator) commented Nov 12, 2024

hey @guillaume-rochette-oxb, if you take a look at fit_loop.setup_data, tearing down dataloaders mid-flight, or even in between epochs, could have several implications.

A more benign option is to make sure that resources are freed up when it's time to load the first validation sample.
For instance, did you try emptying the cache when you reach the end of an epoch (or the start of validation)?

def on_train_epoch_end(self):
    torch.cuda.empty_cache()

I'm trying to understand where exactly resources are held in your production case (I expect you're not holding datasets in memory as in your example). If you find out, we can figure out how to help you deallocate them.

guillaume-rochette-oxb (Author) commented

Hi @lantiga,

Maybe I was not clear: I am not trying to free the VRAM cache here; there's no problem with it.

I am, however, trying to free the RAM held by the idle workers of the exhausted train_dataloader(), and that would indeed require tearing down the train_dataloader() after each epoch.
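
To show where that RAM sits, something like the following helper (an illustrative sketch, assuming psutil is installed) can be called from on_train_epoch_end() and on_validation_epoch_start() in the boring.py above; it should make the idle train_dataloader() workers and their resident memory visible while validation is running:

import os

import psutil


def worker_footprint(tag: str) -> None:
    # Sum the resident memory of the main process and all of its children;
    # the DataLoader workers are child processes of the main process.
    proc = psutil.Process(os.getpid())
    children = proc.children(recursive=True)
    rss = proc.memory_info().rss + sum(c.memory_info().rss for c in children)
    print(f"[{tag}] children={len(children)} rss={rss / 2**20:.1f} MiB")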
