pl_module.log(name, value) does not work in on_*_batch_end hooks #9772

Closed
michalivne opened this issue Sep 30, 2021 · 5 comments
Labels
bug Something isn't working help wanted Open to be worked on logging Related to the `LoggerConnector` and `log()`

Comments

@michalivne

🐛 Bug

Logging from Trainer callbacks does not seem to log any value to WandB.
The code has been verified to call the relevant callbacks during training, and to call the pl_module.log(...) method, but no value is logged. The same callbacks, when implemented inside pl_module, do log the value.
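
For reference, a minimal sketch (mine, not from the original report) of what "implemented inside pl_module" means here, with the hook signature matching pytorch-lightning 1.4.x:

from pytorch_lightning import LightningModule


class TimedModule(LightningModule):
    # Hypothetical module, reduced to just the hook in question.
    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
        # Per the report above, the same call made from a module hook does record the value.
        self.log("train_step_timing", 0.0)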

This bug was observed while working on https://github.com/NVIDIA/NeMo,
and might be related to #4611.

To Reproduce

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback


class TimingCallback(Callback):
    """
    Logs execution time of train/val/test steps
    """

    def _on_batch_start(self, name):
        # In the full callback this would record a start time; a no-op is enough to reproduce the bug.
        pass

    def _on_batch_end(self, name, pl_module):
        # A constant stands in for the measured step duration.
        pl_module.log(name, 0.0)

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        self._on_batch_start("train_step_timing")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self._on_batch_end("train_step_timing", pl_module)

    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        self._on_batch_start("validation_step_timing")

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self._on_batch_end("validation_step_timing", pl_module)

    def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        self._on_batch_start("test_step_timing")

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self._on_batch_end("test_step_timing", pl_module)


trainer = Trainer(callbacks=[TimingCallback()])

Expected behavior

Timer values (a constant 0.0 above) should be logged.

Environment

* CUDA:
        - GPU:
                - Quadro GV100
                - Quadro GV100
                - Quadro GV100
                - Quadro GV100
        - available:         True
        - version:           11.1
* Packages:
        - numpy:             1.20.3
        - pyTorch_debug:     False
        - pyTorch_version:   1.9.0
        - pytorch-lightning: 1.4.8
        - tqdm:              4.62.0
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         x86_64
        - python:            3.8.8
        - version:           #242-Ubuntu SMP Fri Apr 16 09:57:56 UTC 2021

Additional context

@michalivne michalivne added bug Something isn't working help wanted Open to be worked on labels Sep 30, 2021
@awaelchli awaelchli added the logging Related to the `LoggerConnector` and `log()` label Sep 30, 2021
@awaelchli
Contributor

awaelchli commented Sep 30, 2021

In this section of the docs it says:

Use the log() method to log from anywhere in a lightning module and callbacks except functions with batch_start in their names.

There is no mention of on_*_end hooks, so I assume this is expected to work.

Here is a working example based on your code:

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer

from pytorch_lightning.callbacks import Callback


class TimingCallback(Callback):
    """
    Logs execution time of train/val/test steps
    """

    def _on_batch_start(self, name):
        pass

    def _on_batch_end(self, name, pl_module):
        # Key difference from the original repro: log per step, not per epoch.
        pl_module.log(name, 0.0, on_step=True, on_epoch=False)

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        self._on_batch_start("train_step_timing")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self._on_batch_end("train_step_timing", pl_module)

    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        self._on_batch_start("validation_step_timing")

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self._on_batch_end("validation_step_timing", pl_module)

    def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        self._on_batch_start("test_step_timing")

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self._on_batch_end("test_step_timing", pl_module)


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir="tensorboard_logs",
        num_sanity_val_steps=0,
        max_epochs=1,
        weights_summary=None,
        callbacks=[TimingCallback()],
        log_every_n_steps=1,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)


if __name__ == "__main__":
    run()

With screenshot (image not preserved here) showing the resulting logged values.

The important part is to set on_step=True and on_epoch=False, and to set log_every_n_steps to a sufficiently small value (for toy examples).

Maybe we should infer the defaults for on_step and on_epoch for these batch hooks automatically, like we do for other methods?
cc @carmocca in case I'm missing a reason why it might not be possible.

@awaelchli awaelchli changed the title pl_module.log(name, value) does not work in Trainer.on_{train,validation,test}_batch_end pl_module.log(name, value) does not work in on_{train,validation,test}_batch_end hooks Sep 30, 2021
@awaelchli awaelchli changed the title pl_module.log(name, value) does not work in on_{train,validation,test}_batch_end hooks pl_module.log(name, value) does not work in on_*_batch_end hooks Sep 30, 2021
@michalivne
Author

Thanks for the quick reply @awaelchli!
Let me check my logging parameters; that might be the issue.
I will close the issue once tested.

@michalivne
Author

michalivne commented Oct 1, 2021

Setting the logging parameters indeed solved the problem. Thanks!

@carmocca
Contributor

carmocca commented Oct 1, 2021

Hi!

Logging this with on_step=False, on_epoch=True (the default) is supported, but the values will not be accessible until the end of the epoch.

You can verify this by checking trainer.callback_metrics in the on_train_epoch_end hook.
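
A minimal sketch of that check (my example, not from the thread):

from pytorch_lightning.callbacks import Callback


class CallbackMetricsInspector(Callback):
    # Hypothetical callback, only to illustrate where the epoch-reduced value shows up.
    def on_train_epoch_end(self, trainer, pl_module):
        # With the default on_step=False, on_epoch=True, the reduced value is only
        # present in trainer.callback_metrics once the epoch has ended.
        print(trainer.callback_metrics.get("train_step_timing"))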

Maybe we should infer the defaults for on_step and on_epoch for these batch hooks automatically like we do for other methods?

Would users expect that values logged in these hooks have those defaults set? Changing this could be problematic in terms of backwards compatibility.

@michalivne
Author

Perhaps raising a warning in cases where logging behaves in a less intuitive manner (such as this one) would help other users in the future?
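
As a purely hypothetical sketch of that suggestion (no such check exists in pytorch-lightning), it could look roughly like this:

import warnings


def warn_on_epoch_only_batch_logging(hook_name: str, on_step: bool, on_epoch: bool) -> None:
    # Hypothetical helper: warn when a value logged from an on_*_batch_end hook falls
    # back to the epoch-only defaults, so the user learns why nothing appears per step.
    if hook_name.endswith("_batch_end") and not on_step and on_epoch:
        warnings.warn(
            f"self.log() was called in {hook_name} with on_step=False, on_epoch=True; "
            "the value will only be available at the end of the epoch."
        )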
