
Gradient norms are not logged unless row_log_interval==1 #3487

Closed · Tim-Chard opened this issue Sep 13, 2020 · 2 comments · Fixed by #3489
Labels: bug · help wanted · priority: 0

Comments

@Tim-Chard (Contributor)
🐛 Bug

In version 0.9, the guard that computes the gradient norms and the guard that logs the metrics cannot both be satisfied in the same batch unless row_log_interval is 1. In most places the guard is (batch_idx + 1) % self.row_log_interval == 0, for example here:

https://github.com/PyTorchLightning/pytorch-lightning/blob/b40de5464a953ff5866a255f4670d318bd8fd65a/pytorch_lightning/trainer/training_loop.py#L749-L757

However, in run_batch_backward_pass it is batch_idx % self.row_log_interval == 0:

https://github.com/PyTorchLightning/pytorch-lightning/blob/b40de5464a953ff5866a255f4670d318bd8fd65a/pytorch_lightning/trainer/training_loop.py#L929-L939
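Because the two moduli differ by exactly one, no batch index can satisfy both guards when row_log_interval > 1, so the computed norms are never logged. A quick standalone check (plain Python, independent of Lightning) illustrates the mismatch:

# For each interval, collect the batch indices that satisfy BOTH guards.
# Only interval == 1 produces any overlap; for larger intervals the two
# conditions select disjoint sets of batches.
for interval in (1, 2, 5):
    overlap = [i for i in range(20)
               if (i + 1) % interval == 0 and i % interval == 0]
    print(interval, overlap)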

To Reproduce

Steps to reproduce the behavior:

  1. Run the code sample below (taken from "Track_grad_norm only tracks the parameters of the last optimizer defined" #1527).
  2. Confirm that gradients are not being logged in TensorBoard.
  3. Change row_log_interval to 1 and rerun the code.
  4. Confirm that gradients are now being logged.
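To inspect the logged scalars, run tensorboard --logdir tb_logs (tb_logs is the directory the sample's logger writes to).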

Code sample

import pytorch_lightning as pl
from torch.utils.data import TensorDataset, DataLoader
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from torch.optim import SGD
import torch.nn as nn
import torch


class MWENet(pl.LightningModule):
    def __init__(self):
        super(MWENet, self).__init__()

        self.first = nn.Conv2d(1, 1, 3)
        self.second = nn.Conv2d(1, 1, 3)
        self.loss = nn.L1Loss()

    def train_dataloader(self):
        xs, ys = torch.zeros(16, 1, 10, 10), torch.ones(16, 1, 6, 6)
        ds = TensorDataset(xs, ys)
        return DataLoader(ds)

    def forward(self, xs):
        out = self.first(xs)
        out = self.second(out)
        return out

    def configure_optimizers(self):
        first = SGD(self.first.parameters(), lr=0.01)
        second = SGD(self.second.parameters(), lr=0.01)
        return [second, first]

    def training_step(self, batch, batch_idx, optimizer_idx):
        xs, ys = batch
        out = self.forward(xs)
        return {'loss': self.loss(out, ys)}


net = MWENet()
logger = TensorBoardLogger('tb_logs', name='testing')
trainer = pl.Trainer(
    track_grad_norm=2,     # track the 2-norm of the gradients
    row_log_interval=2,    # any value > 1 triggers the bug
    max_epochs=50,
    logger=logger)
trainer.fit(net)

Expected behavior

Gradient norms should be logged whenever track_grad_norm is set, regardless of the value of row_log_interval.
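Since the off-by-one guard in run_batch_backward_pass is the only mismatch, aligning it with the guard used elsewhere should restore logging; presumably this is what #3489 does, though the actual patch is not reproduced here. A minimal sketch of the aligned guard:

# Sketch only: mirrors the guard used in the rest of training_loop.py,
# not the verified contents of PR #3489.
def should_compute_grad_norms(batch_idx: int, row_log_interval: int) -> bool:
    # (batch_idx + 1) matches the logging guard, so norms are computed
    # in exactly the batches where metrics are written out
    return (batch_idx + 1) % row_log_interval == 0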

Environment

  • CUDA:
    - GPU: GeForce GTX 1080 Ti
    - available: True
    - version: 10.2
  • Packages:
    - numpy: 1.19.1
    - pyTorch_debug: False
    - pyTorch_version: 1.6.0
    - pytorch-lightning: 0.9.0
    - tensorboard: 2.2.1
    - tqdm: 4.48.2
  • System:
    - OS: Windows
    - architecture: 64bit, WindowsPE
    - processor: AMD64 Family 23 Model 113 Stepping 0, AuthenticAMD
    - python: 3.7.9
    - version: 10.0.19041

Additional context

@Tim-Chard added the bug and help wanted labels on Sep 13, 2020
@github-actions (Contributor)

Hi! Thanks for your contribution, great first issue!

@rohitgr7 (Contributor)

Yeah, it's a bug. Mind sending a PR?

@awaelchli self-assigned this on Sep 14, 2020
@Borda added the priority: 0 label on Sep 15, 2020