pl_module.log(name, value) does not work in on_*_batch_end hooks #9772

Comments
In this section of the docs, logging from a Callback is described, but there is no mention of the on_step/on_epoch defaults that apply in these hooks. Here is a working example based on your code:

import os
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback


class TimingCallback(Callback):
    """
    Logs execution time of train/val/test steps
    """

    def _on_batch_start(self, name):
        pass

    def _on_batch_end(self, name, pl_module):
        pl_module.log(name, 0.0, on_step=True, on_epoch=False)

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        self._on_batch_start("train_step_timing")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self._on_batch_end("train_step_timing", pl_module)

    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        self._on_batch_start("validation_step_timing")

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self._on_batch_end("validation_step_timing", pl_module)

    def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        self._on_batch_start("test_step_timing")

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self._on_batch_end("test_step_timing", pl_module)


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    model = BoringModel()
    trainer = Trainer(
        default_root_dir="tensorboard_logs",
        num_sanity_val_steps=0,
        max_epochs=1,
        weights_summary=None,
        callbacks=[TimingCallback()],
        log_every_n_steps=1,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)


if __name__ == "__main__":
    run()

With screenshot: [screenshot omitted]

The important part is to set on_step=True (and on_epoch=False) in the log call. Maybe we should infer the defaults for on_step and on_epoch for these batch hooks automatically, like we do for other methods?
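For clarity, a minimal sketch of the difference inside a *_batch_end hook (same TimingCallback as above; the first call is shown only for contrast and is not part of the fix):

# Relying on the implicit defaults in a *_batch_end hook: the value may only be
# reduced and logged at the end of the epoch, so no per-step curve shows up.
pl_module.log("train_step_timing", 0.0)

# Setting the flags explicitly logs the value on every step:
pl_module.log("train_step_timing", 0.0, on_step=True, on_epoch=False)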
Thanks for the quick reply @awaelchli !

Setting the logging parameters indeed solved the problem. Thanks!
Hi! Logging this with explicit on_step/on_epoch arguments works. You can verify this by checking the metrics the trainer collected after fitting.
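For example, a small sketch of how the logged values could be inspected after fit, assuming the run() example above keeps a reference to the trainer (logged_metrics and callback_metrics are standard Trainer attributes):

# After trainer.fit(...) has returned:
print(trainer.logged_metrics)    # metrics that were sent to the logger
print(trainer.callback_metrics)  # metrics exposed to callbacks (e.g. for checkpointing)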
Would users expect that values logged in these hooks have those defaults set? Changing this could be problematic in terms of backwards compatibility.
Perhaps raising a warning in cases where the logging behaves in a less intuitive manner (such as this case) would help other users in the future?
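A rough sketch of what such a warning could look like (purely illustrative, not Lightning's actual implementation; the helper name and the hook-name check are assumptions):

import warnings

def warn_on_default_flags(hook_name, on_step, on_epoch):
    # Hypothetical check: warn when log() is called from a *_batch_end hook
    # while relying on the implicit on_step/on_epoch defaults.
    if on_step is None and on_epoch is None and hook_name.endswith("_batch_end"):
        warnings.warn(
            f"self.log was called from {hook_name} without explicit on_step/on_epoch; "
            "the value may only be reduced and logged at the end of the epoch.",
            UserWarning,
        )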
🐛 Bug
Logging in Trainer callbacks does not seem to log any value into WandB. The code has been validated to call the relevant callbacks during training and to call the pl_module.log(...) method, but no value is ever logged. The same callbacks, when implemented inside pl_module, do log the value.

This bug has been observed when working on https://github.com/NVIDIA/NeMo, and might be related to #4611
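For contrast, a minimal sketch of the variant that reportedly does work, with the hook implemented on the LightningModule itself rather than in a Trainer callback (the class name is illustrative, and the exact hook signature varies across Lightning versions):

class BoringModelWithHook(BoringModel):
    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
        # Same metric, logged from the module-level hook instead of a Callback.
        self.log("train_step_timing", 0.0)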
To Reproduce
Expected behavior
Timer values (a constant 0.0 above) to be logged.
Environment
Additional context