Fix self.log(on_epoch=True, reduce_fx=sum) on_batch_start #9791

Merged: 13 commits (Oct 19, 2021)

3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -536,6 +536,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `broadcast` in `DDPPlugin` and `DDPSpawnPlugin` to respect the `src` input ([#9691](https://github.com/PyTorchLightning/pytorch-lightning/pull/9691))


+- Fixed `self.log(on_epoch=True, reduce_fx=sum)` for the `on_batch_start` and `on_train_batch_start` hooks ([#9791](https://github.com/PyTorchLightning/pytorch-lightning/pull/9791))
+
+
 - Fixed `self.log(on_epoch=True)` for the `on_batch_start` and `on_train_batch_start` hooks ([#9780](https://github.com/PyTorchLightning/pytorch-lightning/pull/9780))


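For context, a minimal sketch of the user-facing call that the `reduce_fx=sum` changelog entry added above refers to. The module below is illustrative and not part of this PR; only the `self.log(..., on_epoch=True, reduce_fx=sum)` call from an `on_train_batch_start` hook reflects the fixed behaviour.

```python
import pytorch_lightning as pl


class SketchModule(pl.LightningModule):  # illustrative module, not from this PR
    def on_train_batch_start(self, batch, batch_idx):
        # before this fix, a sum-reduced value logged from this hook was weighted by the
        # batch size during accumulation, inflating the epoch-end sum
        self.log("batches_seen", 1.0, on_step=False, on_epoch=True, reduce_fx=sum)
```
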
4 changes: 2 additions & 2 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -233,10 +233,10 @@ def _on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx:
         Raises:
             AssertionError: If the number of dataloaders is None (has not yet been set).
         """
-        self.trainer.logger_connector.on_batch_start(batch_idx)
+        self.trainer.logger_connector.on_batch_start(batch_idx, batch)

         assert self._num_dataloaders is not None
-        self.trainer.logger_connector.on_evaluation_batch_start(batch, dataloader_idx, self._num_dataloaders)
+        self.trainer.logger_connector.on_evaluation_batch_start(dataloader_idx, self._num_dataloaders)

         if self.trainer.testing:
             self.trainer.call_hook("on_test_batch_start", batch, batch_idx, dataloader_idx)
7 changes: 6 additions & 1 deletion pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -153,12 +153,15 @@ def advance(self, *args: Any, **kwargs: Any) -> None:

         self.batch_progress.increment_ready()

+        # cache the batch size value to avoid extracting it again after the batch loop runs as the value will be
+        # different if tbptt is enabled
+        batch_size = self.trainer.logger_connector.on_batch_start(batch_idx, batch)
+
         if batch is None:
             self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
             batch_output = []
         else:
             # hook
-            self.trainer.logger_connector.on_batch_start(batch_idx)
             response = self.trainer.call_hook("on_batch_start")
             if response == -1:
                 self.batch_progress.increment_processed()
@@ -183,6 +186,8 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
             with self.trainer.profiler.profile("run_training_batch"):
                 batch_output = self.batch_loop.run(batch, batch_idx)

+        self.trainer._results.batch_size = batch_size
+
         self.batch_progress.increment_processed()

         # update non-plateau LR schedulers
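The comment in the hunk above explains why the batch size is extracted once before the batch loop and written back afterwards. A rough, hedged sketch of that cache-and-restore pattern follows; the class and names below are stand-ins for illustration, not Lightning's real objects.

```python
class SketchResults:
    batch_size: int = 1


def run_batch(results: SketchResults, full_batch_size: int, splits: list) -> None:
    cached = full_batch_size              # extracted once up front (as on_batch_start now does)
    for split in splits:
        results.batch_size = len(split)   # a tbptt split may overwrite the tracked value
    results.batch_size = cached           # restored so end-of-batch logging sees the full batch's size
```
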
@@ -138,15 +138,11 @@ def _increment_eval_log_step(self) -> None:
         elif self.trainer.state.stage is RunningStage.TESTING:
             self._test_log_step += 1

-    def on_evaluation_batch_start(self, batch: Any, dataloader_idx: int, num_dataloaders: int) -> None:
+    def on_evaluation_batch_start(self, dataloader_idx: int, num_dataloaders: int) -> None:
         model = self.trainer.lightning_module
         # set dataloader_idx only if multiple ones
         model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None

-        # track batch_size
-        assert self.trainer._results is not None
-        self.trainer._results.extract_batch_size(batch)
-
     def update_eval_step_metrics(self) -> None:
         if self.trainer.sanity_checking:
             return
@@ -213,12 +209,8 @@ def update_eval_epoch_metrics(self) -> List[_OUT_DICT]:
     """

     def on_train_split_start(self, split_idx: int, split_batch: Any) -> None:
-        assert self.trainer._results is not None
-        # when the user requests `dataloader_iter`, we can't track the batch_size
-        # and this is left to user responsibility.
-        if isinstance(split_batch, pl.utilities.fetching.DataLoaderIterDataFetcher):
-            self.trainer._results.extract_batch_size(split_batch)
         self._split_idx = split_idx
+        self.on_new_batch(split_batch)

     def update_train_step_metrics(self) -> None:
         if self.trainer.fit_loop._should_accumulate() and self.trainer.lightning_module.automatic_optimization:
@@ -261,12 +253,21 @@ def _log_gpus_metrics(self) -> None:
     Utilities and properties
     """

+    def on_new_batch(self, batch: Any) -> int:
+        # when the user requests `dataloader_iter`, we can't track the batch_size
+        # and this is left to user responsibility.
+        if not isinstance(batch, pl.utilities.fetching.StepFuncDataLoaderIter):
+            assert self.trainer._results is not None
+            return self.trainer._results.extract_batch_size(batch)
+        return 1
+
     def on_epoch_start(self) -> None:
         self._epoch_end_reached = False

-    def on_batch_start(self, batch_idx: int) -> None:
+    def on_batch_start(self, batch_idx: int, batch: Any) -> int:
         self._batch_idx = batch_idx
         self._epoch_end_reached = False
+        return self.on_new_batch(batch)

     def epoch_end_reached(self) -> None:
         self._epoch_end_reached = True
@@ -204,7 +204,7 @@ def update(self, value: _IN_METRIC, batch_size: torch.Tensor) -> None:
             elif self.meta.is_max_reduction or self.meta.is_min_reduction:
                 self.value = self.meta.reduce_fx(self.value, value.mean())
             elif self.meta.is_sum_reduction:
-                self.value += value.mean() * batch_size
+                self.value += value.mean()
         else:
             self.value = value
             self._forward_cache = value._forward_cache
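A back-of-the-envelope check of why the `* batch_size` factor is dropped for the sum reduction (plain Python, not the library code): the mean reduction keeps the weighting because it later divides by the cumulated batch size, while the sum reduction should return the plain sum of the logged values, which is exactly what the new test below asserts (all callback metrics equal 3 after 3 batches).

```python
value, batch_size, n_batches = 1.0, 2, 3

# mean reduction: weighted accumulation, later divided by the cumulated batch size
weighted = sum(value * batch_size for _ in range(n_batches))  # 6.0
assert weighted / (batch_size * n_batches) == 1.0

# sum reduction: plain accumulation, no batch-size weighting
assert sum(value for _ in range(n_batches)) == 3.0  # the old code would have produced 6.0
```
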
@@ -550,11 +550,13 @@ def fn(item: ResultMetric) -> None:

         apply_to_collection(self, ResultMetric, fn)

-    def extract_batch_size(self, batch: Any) -> None:
+    def extract_batch_size(self, batch: Any) -> int:
         try:
-            self.batch_size = extract_batch_size(batch)
+            batch_size = extract_batch_size(batch)
         except RecursionError:
-            self.batch_size = 1
+            batch_size = 1
+        self.batch_size = batch_size  # the setter converts it to `Tensor`
+        return batch_size

     def to(self, *args: Any, **kwargs: Any) -> "ResultCollection":
         """Move all data to the given device."""
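For intuition, here is a rough sketch of what the module-level `extract_batch_size(batch)` helper wrapped above does: find the first tensor in the (possibly nested) batch and use its leading dimension, which is why the `RandomDataset` batches in the new test report a batch size of 2. This is a simplified illustration, not PyTorch Lightning's implementation; the method above additionally falls back to 1 on `RecursionError`.

```python
from collections.abc import Mapping, Sequence
from typing import Any

import torch


def sketch_extract_batch_size(batch: Any) -> int:
    """Simplified stand-in for the helper: leading dimension of the first tensor found."""
    if isinstance(batch, torch.Tensor):
        return batch.size(0)
    if isinstance(batch, Mapping):
        batch = list(batch.values())
    if isinstance(batch, Sequence) and not isinstance(batch, (str, bytes)):
        for item in batch:
            try:
                return sketch_extract_batch_size(item)
            except ValueError:
                continue
    raise ValueError("no tensor found in batch")


assert sketch_extract_batch_size({"x": torch.rand(2, 32), "idx": torch.zeros(2)}) == 2
```
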
40 changes: 39 additions & 1 deletion tests/trainer/logging_/test_train_loop_logging.py
@@ -20,13 +20,14 @@
 import numpy as np
 import pytest
 import torch
+from torch.utils.data import DataLoader
 from torchmetrics import Accuracy

 from pytorch_lightning import callbacks, Trainer
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBar
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers.boring_model import BoringModel, RandomDictDataset
+from tests.helpers.boring_model import BoringModel, RandomDataset, RandomDictDataset
 from tests.helpers.runif import RunIf


@@ -721,3 +722,40 @@ def on_before_backward(self, loss: torch.Tensor) -> None:
         gpus=1,
     )
     trainer.fit(TestModel())
+
+
+def test_on_epoch_logging_with_sum_and_on_batch_start(tmpdir):
+    class TestModel(BoringModel):
+        def on_train_epoch_end(self):
+            assert all(v == 3 for v in self.trainer.callback_metrics.values())
+
+        def on_validation_epoch_end(self):
+            assert all(v == 3 for v in self.trainer.callback_metrics.values())
+
+        def on_train_batch_start(self, batch, batch_idx):
+            assert self.trainer._results.batch_size == 2
+            self.log("on_train_batch_start", 1.0, reduce_fx="sum")
+
+        def on_train_batch_end(self, outputs, batch, batch_idx):
+            assert self.trainer._results.batch_size == 2
+            self.log("on_train_batch_end", 1.0, reduce_fx="sum")
+
+        def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
+            assert self.trainer._results.batch_size == 2
+            self.log("on_validation_batch_start", 1.0, reduce_fx="sum")
+
+        def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+            assert self.trainer._results.batch_size == 2
+            self.log("on_validation_batch_end", 1.0, reduce_fx="sum")
+
+    model = TestModel()
+    trainer = Trainer(
+        enable_progress_bar=False,
+        limit_train_batches=3,
+        limit_val_batches=3,
+        num_sanity_val_steps=3,
+        max_epochs=1,
+    )
+    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
+    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
+    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
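
One possible way to run only the new test from a source checkout (assuming pytest and the repository's test requirements are installed; the node id below simply combines the file path and test name added in this diff):

```python
import pytest

pytest.main(
    ["tests/trainer/logging_/test_train_loop_logging.py::test_on_epoch_logging_with_sum_and_on_batch_start", "-v"]
)
```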