Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix log_every_n_steps check in ThroughputMonitor #19470

Merged
merged 9 commits into from
Feb 15, 2024

Conversation

awaelchli
Copy link
Contributor

@awaelchli awaelchli commented Feb 14, 2024

What does this PR do?

Fixes #19461

The ThroughputMonitor has a check for whether accumulate_grad_batches divides log_every_n_steps. But since logging only happens when the optimizer steps (at the end of accumulation), the error checking is not needed and incorrect. This PR removes the check.

Example:

import os

import torch
from lightning.pytorch import LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset
from lightning.pytorch.callbacks import ThroughputMonitor


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
val_data = DataLoader(RandomDataset(32, 64), batch_size=2)

model = BoringModel()
trainer = Trainer(
    callbacks=ThroughputMonitor(batch_size_fn=lambda batch: batch.size(0)),
    default_root_dir=os.getcwd(),
    max_epochs=2,
    log_every_n_steps=2,
    accumulate_grad_batches=5,
    enable_model_summary=False,
)
trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)

📚 Documentation preview 📚: https://pytorch-lightning--19470.org.readthedocs.build/en/19470/

cc @Borda @carmocca

@github-actions github-actions bot added the pl Generic label for PyTorch Lightning package label Feb 14, 2024
Copy link
Contributor

github-actions bot commented Feb 14, 2024

⚡ Required checks status: All passing 🟢

Groups summary

🟢 pytorch_lightning: Tests workflow
Check ID Status
pl-cpu (macOS-11, lightning, 3.8, 1.13, oldest) success
pl-cpu (macOS-11, lightning, 3.10, 1.13) success
pl-cpu (macOS-11, lightning, 3.10, 2.1) success
pl-cpu (macOS-11, lightning, 3.10, 2.2) success
pl-cpu (ubuntu-20.04, lightning, 3.8, 1.13, oldest) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 1.13) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 2.1) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 2.2) success
pl-cpu (windows-2022, lightning, 3.8, 1.13, oldest) success
pl-cpu (windows-2022, lightning, 3.10, 1.13) success
pl-cpu (windows-2022, lightning, 3.10, 2.1) success
pl-cpu (windows-2022, lightning, 3.10, 2.2) success
pl-cpu (macOS-11, pytorch, 3.8, 1.13) success
pl-cpu (ubuntu-20.04, pytorch, 3.8, 1.13) success
pl-cpu (windows-2022, pytorch, 3.8, 1.13) success
pl-cpu (macOS-12, pytorch, 3.11, 2.0) success
pl-cpu (macOS-12, pytorch, 3.11, 2.1) success
pl-cpu (ubuntu-22.04, pytorch, 3.11, 2.0) success
pl-cpu (ubuntu-22.04, pytorch, 3.11, 2.1) success
pl-cpu (windows-2022, pytorch, 3.11, 2.0) success
pl-cpu (windows-2022, pytorch, 3.11, 2.1) success

These checks are required after the changes to src/lightning/pytorch/callbacks/throughput_monitor.py, tests/tests_pytorch/callbacks/test_throughput_monitor.py.

🟢 pytorch_lightning: Azure GPU
Check ID Status
pytorch-lightning (GPUs) (testing Lightning | latest) success
pytorch-lightning (GPUs) (testing PyTorch | latest) success

These checks are required after the changes to src/lightning/pytorch/callbacks/throughput_monitor.py, tests/tests_pytorch/callbacks/test_throughput_monitor.py.

🟢 pytorch_lightning: Benchmarks
Check ID Status
lightning.Benchmarks success

These checks are required after the changes to src/lightning/pytorch/callbacks/throughput_monitor.py.

🟢 pytorch_lightning: Docs
Check ID Status
docs-make (pytorch, doctest) success
docs-make (pytorch, html) success

These checks are required after the changes to src/lightning/pytorch/callbacks/throughput_monitor.py.

🟢 mypy
Check ID Status
mypy success

These checks are required after the changes to src/lightning/pytorch/callbacks/throughput_monitor.py.

🟢 install
Check ID Status
install-pkg (ubuntu-22.04, app, 3.8) success
install-pkg (ubuntu-22.04, app, 3.11) success
install-pkg (ubuntu-22.04, fabric, 3.8) success
install-pkg (ubuntu-22.04, fabric, 3.11) success
install-pkg (ubuntu-22.04, pytorch, 3.8) success
install-pkg (ubuntu-22.04, pytorch, 3.11) success
install-pkg (ubuntu-22.04, lightning, 3.8) success
install-pkg (ubuntu-22.04, lightning, 3.11) success
install-pkg (ubuntu-22.04, notset, 3.8) success
install-pkg (ubuntu-22.04, notset, 3.11) success
install-pkg (macOS-12, app, 3.8) success
install-pkg (macOS-12, app, 3.11) success
install-pkg (macOS-12, fabric, 3.8) success
install-pkg (macOS-12, fabric, 3.11) success
install-pkg (macOS-12, pytorch, 3.8) success
install-pkg (macOS-12, pytorch, 3.11) success
install-pkg (macOS-12, lightning, 3.8) success
install-pkg (macOS-12, lightning, 3.11) success
install-pkg (macOS-12, notset, 3.8) success
install-pkg (macOS-12, notset, 3.11) success
install-pkg (windows-2022, app, 3.8) success
install-pkg (windows-2022, app, 3.11) success
install-pkg (windows-2022, fabric, 3.8) success
install-pkg (windows-2022, fabric, 3.11) success
install-pkg (windows-2022, pytorch, 3.8) success
install-pkg (windows-2022, pytorch, 3.11) success
install-pkg (windows-2022, lightning, 3.8) success
install-pkg (windows-2022, lightning, 3.11) success
install-pkg (windows-2022, notset, 3.8) success
install-pkg (windows-2022, notset, 3.11) success

These checks are required after the changes to src/lightning/pytorch/callbacks/throughput_monitor.py.


Thank you for your contribution! 💜

Note
This comment is automatically generated and updates for 60 minutes every 180 seconds. If you have any other questions, contact carmocca for help.

@awaelchli awaelchli added bug Something isn't working callback: throughput labels Feb 14, 2024
@awaelchli awaelchli added this to the 2.2.x milestone Feb 14, 2024
@awaelchli awaelchli added the fun Staff contributions outside working hours - to differentiate from the "community" label label Feb 14, 2024
@awaelchli awaelchli marked this pull request as draft February 14, 2024 14:38
@awaelchli awaelchli marked this pull request as ready for review February 14, 2024 14:58
@awaelchli awaelchli removed the fun Staff contributions outside working hours - to differentiate from the "community" label label Feb 14, 2024
Copy link

codecov bot commented Feb 14, 2024

Codecov Report

Merging #19470 (42af40e) into master (d61f6fe) will decrease coverage by 35%.
Report is 10 commits behind head on master.
The diff coverage is 100%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #19470      +/-   ##
==========================================
- Coverage      84%      48%     -35%     
==========================================
  Files         450      442       -8     
  Lines       38091    37970     -121     
==========================================
- Hits        31814    18373   -13441     
- Misses       6277    19597   +13320     

Copy link
Contributor

@carmocca carmocca left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

@awaelchli awaelchli changed the title Fix divisibility check in ThroughputMonitor Fix log_every_n_steps check in ThroughputMonitor Feb 15, 2024
@awaelchli awaelchli changed the title Fix log_every_n_steps check in ThroughputMonitor Fix log_every_n_steps check in ThroughputMonitor Feb 15, 2024
@mergify mergify bot added the ready PRs ready to be merged label Feb 15, 2024
@awaelchli awaelchli merged commit 1967547 into master Feb 15, 2024
85 of 86 checks passed
@awaelchli awaelchli deleted the bugfix/throughput-divisible branch February 15, 2024 20:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working callback: throughput pl Generic label for PyTorch Lightning package ready PRs ready to be merged
Projects
None yet
Development

Successfully merging this pull request may close these issues.

TroughputMonitor callback does not work with gradient accumulation in Trainer
3 participants