Fix self.log(on_epoch=True, reduce_fx=sum) on_batch_start (#9791)
carmocca authored Oct 19, 2021
1 parent d45897d commit e44921e
Showing 6 changed files with 68 additions and 19 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -553,6 +553,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `broadcast` in `DDPPlugin` and `DDPSpawnPlugin` to respect the `src` input ([#9691](https://github.com/PyTorchLightning/pytorch-lightning/pull/9691))


- Fixed `self.log(on_epoch=True, reduce_fx=sum)` for the `on_batch_start` and `on_train_batch_start` hooks ([#9791](https://github.com/PyTorchLightning/pytorch-lightning/pull/9791))


- Fixed `self.log(on_epoch=True)` for the `on_batch_start` and `on_train_batch_start` hooks ([#9780](https://github.com/PyTorchLightning/pytorch-lightning/pull/9780))


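For context, here is a minimal sketch of the logging pattern these entries refer to: a `LightningModule` that logs with sum reduction from `on_train_batch_start`, mirroring the test added at the bottom of this commit. The class name and metric name are illustrative, not part of the library.

```python
from pytorch_lightning import LightningModule


class SumLoggingModel(LightningModule):
    def on_train_batch_start(self, batch, batch_idx):
        # Logged once per batch; with reduce_fx="sum" the epoch-level value
        # should be the plain sum of the logged values (1.0 * num_batches),
        # which is the behavior this commit fixes for the *_batch_start hooks.
        self.log("batches_seen", 1.0, on_epoch=True, reduce_fx="sum")
```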
4 changes: 2 additions & 2 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -233,10 +233,10 @@ def _on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx:
Raises:
AssertionError: If the number of dataloaders is None (has not yet been set).
"""
self.trainer.logger_connector.on_batch_start(batch_idx)
self.trainer.logger_connector.on_batch_start(batch_idx, batch)

assert self._num_dataloaders is not None
self.trainer.logger_connector.on_evaluation_batch_start(batch, dataloader_idx, self._num_dataloaders)
self.trainer.logger_connector.on_evaluation_batch_start(dataloader_idx, self._num_dataloaders)

if self.trainer.testing:
self.trainer.call_hook("on_test_batch_start", batch, batch_idx, dataloader_idx)
7 changes: 6 additions & 1 deletion pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -153,12 +153,15 @@ def advance(self, *args: Any, **kwargs: Any) -> None:

self.batch_progress.increment_ready()

# cache the batch size value to avoid extracting it again after the batch loop runs as the value will be
# different if tbptt is enabled
batch_size = self.trainer.logger_connector.on_batch_start(batch_idx, batch)

if batch is None:
self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
batch_output = []
else:
# hook
self.trainer.logger_connector.on_batch_start(batch_idx)
response = self.trainer.call_hook("on_batch_start")
if response == -1:
self.batch_progress.increment_processed()
@@ -183,6 +186,8 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
with self.trainer.profiler.profile("run_training_batch"):
batch_output = self.batch_loop.run(batch, batch_idx)

self.trainer._results.batch_size = batch_size

self.batch_progress.increment_processed()

# update non-plateau LR schedulers
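The restore of `batch_size` after the batch loop, rather than a second extraction, matters because (as the new comment notes) truncated back-propagation through time splits the batch, so the value held by `trainer._results` after `batch_loop.run` reflects the last split. Caching the size of the full batch up front and writing it back afterwards keeps end-of-batch hooks such as `on_train_batch_end` logging against the full batch size.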
23 changes: 12 additions & 11 deletions pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -138,15 +138,11 @@ def _increment_eval_log_step(self) -> None:
elif self.trainer.state.stage is RunningStage.TESTING:
self._test_log_step += 1

def on_evaluation_batch_start(self, batch: Any, dataloader_idx: int, num_dataloaders: int) -> None:
def on_evaluation_batch_start(self, dataloader_idx: int, num_dataloaders: int) -> None:
model = self.trainer.lightning_module
# set dataloader_idx only if multiple ones
model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None

# track batch_size
assert self.trainer._results is not None
self.trainer._results.extract_batch_size(batch)

def update_eval_step_metrics(self) -> None:
if self.trainer.sanity_checking:
return
@@ -213,12 +209,8 @@ def update_eval_epoch_metrics(self) -> List[_OUT_DICT]:
"""

def on_train_split_start(self, split_idx: int, split_batch: Any) -> None:
assert self.trainer._results is not None
# when the user requests `dataloader_iter`, we can't track the batch_size
# and this is left to user responsibility.
if isinstance(split_batch, pl.utilities.fetching.DataLoaderIterDataFetcher):
self.trainer._results.extract_batch_size(split_batch)
self._split_idx = split_idx
self.on_new_batch(split_batch)

def update_train_step_metrics(self) -> None:
if self.trainer.fit_loop._should_accumulate() and self.trainer.lightning_module.automatic_optimization:
@@ -261,12 +253,21 @@ def _log_gpus_metrics(self) -> None:
Utilities and properties
"""

def on_new_batch(self, batch: Any) -> int:
# when the user requests `dataloader_iter`, we can't track the batch_size
# and this is left to user responsibility.
if not isinstance(batch, pl.utilities.fetching.StepFuncDataLoaderIter):
assert self.trainer._results is not None
return self.trainer._results.extract_batch_size(batch)
return 1

def on_epoch_start(self) -> None:
self._epoch_end_reached = False

def on_batch_start(self, batch_idx: int) -> None:
def on_batch_start(self, batch_idx: int, batch: Any) -> int:
self._batch_idx = batch_idx
self._epoch_end_reached = False
return self.on_new_batch(batch)

def epoch_end_reached(self) -> None:
self._epoch_end_reached = True
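The new `on_new_batch` helper consolidates the batch-size tracking that was previously duplicated in `on_evaluation_batch_start` and `on_train_split_start`, and it tightens the guard: extraction is now skipped when the user requests `dataloader_iter` (the `StepFuncDataLoaderIter` case) and falls back to a size of 1, whereas the old `on_train_split_start` only attempted extraction when the batch *was* an iterator fetcher, contradicting its own comment.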
10 changes: 6 additions & 4 deletions pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -204,7 +204,7 @@ def update(self, value: _IN_METRIC, batch_size: torch.Tensor) -> None:
elif self.meta.is_max_reduction or self.meta.is_min_reduction:
self.value = self.meta.reduce_fx(self.value, value.mean())
elif self.meta.is_sum_reduction:
self.value += value.mean() * batch_size
self.value += value.mean()
else:
self.value = value
self._forward_cache = value._forward_cache
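To make the effect of dropping the `* batch_size` factor concrete, here is a small self-contained illustration in plain PyTorch (not the library's internal `ResultMetric`), using the same numbers as the test added at the end of this commit: three batches of size 2, each logging 1.0. The weighted-mean branch is assumed to keep the factor because it divides by the cumulated batch size at epoch end.

```python
import torch

logged = [torch.tensor(1.0)] * 3  # one logged value per batch
batch_size = 2

# Weighted-mean reduction keeps the batch-size factor because the epoch value
# is computed by dividing by the cumulated batch size:
weighted_total = sum(v * batch_size for v in logged)      # tensor(6.)
epoch_mean = weighted_total / (batch_size * len(logged))  # tensor(1.)

# Sum reduction must not scale by batch size; the expected epoch value is the
# plain sum of the logged values:
epoch_sum = sum(logged)  # tensor(3.), matching the `== 3` assertions in the new test

print(epoch_mean, epoch_sum)
```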
@@ -550,11 +550,13 @@ def fn(item: ResultMetric) -> None:

apply_to_collection(self, ResultMetric, fn)

def extract_batch_size(self, batch: Any) -> None:
def extract_batch_size(self, batch: Any) -> int:
try:
self.batch_size = extract_batch_size(batch)
batch_size = extract_batch_size(batch)
except RecursionError:
self.batch_size = 1
batch_size = 1
self.batch_size = batch_size # the setter converts it to `Tensor`
return batch_size

def to(self, *args: Any, **kwargs: Any) -> "ResultCollection":
"""Move all data to the given device."""
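For reference, the `extract_batch_size` call above comes from the library's utilities and infers a batch size from an arbitrary batch object. A rough sketch of the idea (a simplification under stated assumptions, not the library's actual implementation) could look like this:

```python
from collections.abc import Mapping, Sequence

import torch


def extract_batch_size_sketch(batch) -> int:
    """Simplified illustration: follow the first branch of nested dicts and
    sequences, return the leading dimension of the first tensor reached."""
    if isinstance(batch, torch.Tensor):
        return batch.size(0)
    if isinstance(batch, Mapping):
        for value in batch.values():
            return extract_batch_size_sketch(value)
    if isinstance(batch, Sequence) and not isinstance(batch, str):
        for value in batch:
            return extract_batch_size_sketch(value)
    return 1  # nothing tensor-like found, fall back to a batch size of 1
```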
40 changes: 39 additions & 1 deletion tests/trainer/logging_/test_train_loop_logging.py
@@ -20,13 +20,14 @@
import numpy as np
import pytest
import torch
from torch.utils.data import DataLoader
from torchmetrics import Accuracy

from pytorch_lightning import callbacks, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBar
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel, RandomDictDataset
from tests.helpers.boring_model import BoringModel, RandomDataset, RandomDictDataset
from tests.helpers.runif import RunIf


@@ -699,3 +700,40 @@ def on_before_backward(self, loss: torch.Tensor) -> None:
gpus=1,
)
trainer.fit(TestModel())


def test_on_epoch_logging_with_sum_and_on_batch_start(tmpdir):
class TestModel(BoringModel):
def on_train_epoch_end(self):
assert all(v == 3 for v in self.trainer.callback_metrics.values())

def on_validation_epoch_end(self):
assert all(v == 3 for v in self.trainer.callback_metrics.values())

def on_train_batch_start(self, batch, batch_idx):
assert self.trainer._results.batch_size == 2
self.log("on_train_batch_start", 1.0, reduce_fx="sum")

def on_train_batch_end(self, outputs, batch, batch_idx):
assert self.trainer._results.batch_size == 2
self.log("on_train_batch_end", 1.0, reduce_fx="sum")

def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
assert self.trainer._results.batch_size == 2
self.log("on_validation_batch_start", 1.0, reduce_fx="sum")

def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
assert self.trainer._results.batch_size == 2
self.log("on_validation_batch_end", 1.0, reduce_fx="sum")

model = TestModel()
trainer = Trainer(
enable_progress_bar=False,
limit_train_batches=3,
limit_val_batches=3,
num_sanity_val_steps=3,
max_epochs=1,
)
train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
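The `== 3` assertions in the epoch-end hooks follow directly from the sum reduction: each hook logs 1.0 once per batch, and with `limit_train_batches=3` and `limit_val_batches=3` the epoch-level value is 1.0 summed over three batches, independent of the dataloader's batch size of 2, which is exactly the behavior restored by removing the `* batch_size` factor in `result.py` above.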
