
Commit

[bugfix] Resolve metrics not being properly reset on validation epoch end (#9717)

Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
tchaton and carmocca authored Sep 27, 2021
1 parent faee9df commit 64bbebc
Showing 4 changed files with 64 additions and 4 deletions.
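For context, a minimal sketch of the behaviour being fixed (illustrative only, not part of this commit; it assumes `torchmetrics.MeanMetric` is available): a `torchmetrics.Metric` keeps internal state across `update()` calls, so a validation metric that is never reset keeps averaging batches from earlier epochs together with the current one.

import torch
from torchmetrics import MeanMetric

val_loss = MeanMetric()

# "epoch 0": three logged batch values
for v in (1.0, 2.0, 3.0):
    val_loss.update(torch.tensor(v))
assert val_loss.compute() == 2.0  # mean over epoch 0 only

# Without a reset, epoch-1 values would be pooled with the epoch-0 state.
val_loss.reset()  # what the trainer is expected to trigger at validation epoch end

# "epoch 1": a clean per-epoch mean
for v in (10.0, 20.0, 30.0):
    val_loss.update(torch.tensor(v))
assert val_loss.compute() == 20.0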
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -429,6 +429,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `lr_find` to generate same results on multiple calls ([#9704](https://github.com/PyTorchLightning/pytorch-lightning/pull/9704))


- Fixed `reset` metrics on validation epoch end ([#9717](https://github.com/PyTorchLightning/pytorch-lightning/pull/9717))



## [1.4.8] - 2021-09-22

- Fixed error reporting in DDP process reconciliation when processes are launched by an external agent ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389))
4 changes: 2 additions & 2 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -213,8 +213,8 @@ def _on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
         else:
             self.trainer.call_hook("on_validation_end", *args, **kwargs)
 
-        # reset any `torchmetrics.Metric` and the logger connector state
-        self.trainer.logger_connector.reset_results(metrics=True)
+        # reset the logger connector state
+        self.trainer.logger_connector.reset_results()
 
     def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
         """Runs ``on_epoch_start`` and ``on_{validation/test}_epoch_start`` hooks."""
@@ -291,9 +291,9 @@ def reset_metrics(self) -> None:
         self._logged_metrics = {}
         self._callback_metrics = {}
 
-    def reset_results(self, metrics: Optional[bool] = None) -> None:
+    def reset_results(self) -> None:
         if self.trainer._results is not None:
-            self.trainer._results.reset(metrics=metrics)
+            self.trainer._results.reset()
 
         self._batch_idx = None
         self._split_idx = None
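This `reset_results` hunk is the logger connector side of the change: the `metrics` flag is dropped from the signature, so the underlying result collection is now reset unconditionally. As a rough, hypothetical illustration of what such a flag usually selects (this is not the actual `ResultCollection` code), a `metrics` argument typically filters which stored items get reset:

# Hypothetical sketch: metrics=None resets everything, metrics=True only
# torchmetrics-backed entries, metrics=False only plain tensor results.
def reset(items, metrics=None):
    for item in items:
        if metrics is None or metrics == item.is_metric:
            item.reset()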
56 changes: 56 additions & 0 deletions tests/trainer/logging_/test_eval_loop_logging.py
@@ -602,3 +602,59 @@ def validation_step(self, batch, batch_idx):
    )

    trainer.fit(model)


@pytest.mark.parametrize("val_check_interval", [0.5, 1.0])
def test_multiple_dataloaders_reset(val_check_interval, tmpdir):
    class TestModel(BoringModel):
        def training_step(self, batch, batch_idx):
            out = super().training_step(batch, batch_idx)
            value = 1 + batch_idx
            if self.current_epoch != 0:
                value *= 10
            self.log("batch_idx", value, on_step=True, on_epoch=True, prog_bar=True)
            return out

        def training_epoch_end(self, outputs):
            metrics = self.trainer.progress_bar_metrics
            v = 15 if self.current_epoch == 0 else 150
            assert metrics["batch_idx_epoch"] == (v / 5.0)

        def validation_step(self, batch, batch_idx, dataloader_idx):
            value = (1 + batch_idx) * (1 + dataloader_idx)
            if self.current_epoch != 0:
                value *= 10
            self.log("val_loss", value, on_step=False, on_epoch=True, prog_bar=True, logger=True)
            return value

        def validation_epoch_end(self, outputs):
            if self.current_epoch == 0:
                assert sum(outputs[0]) / 5 == 3
                assert sum(outputs[1]) / 5 == 6
            else:
                assert sum(outputs[0]) / 5 == 30
                assert sum(outputs[1]) / 5 == 60

            tot_loss = torch.mean(torch.tensor(outputs, dtype=torch.float))
            if self.current_epoch == 0:
                assert tot_loss == (3 + 6) / 2
            else:
                assert tot_loss == (30 + 60) / 2
            assert self.trainer._results["validation_step.val_loss.0"].cumulated_batch_size == 5
            assert self.trainer._results["validation_step.val_loss.1"].cumulated_batch_size == 5

        def val_dataloader(self):
            return [super().val_dataloader(), super().val_dataloader()]

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=5,
        limit_val_batches=5,
        num_sanity_val_steps=0,
        val_check_interval=val_check_interval,
        max_epochs=3,
        log_every_n_steps=1,
        weights_summary=None,
    )
    trainer.fit(model)
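The expected values in the new test follow directly from how the values are logged; a quick standalone check of the arithmetic (not part of the test file):

# training_step logs value = 1 + batch_idx over 5 batches -> 1..5 (sum 15, mean 3);
# from epoch 1 on, every value is multiplied by 10 (sum 150, mean 30).
train_values = [1 + i for i in range(5)]
assert sum(train_values) == 15 and sum(train_values) / 5 == 3.0

# validation_step logs (1 + batch_idx) * (1 + dataloader_idx):
dl0 = [(1 + i) * 1 for i in range(5)]  # [1, 2, 3, 4, 5] -> mean 3
dl1 = [(1 + i) * 2 for i in range(5)]  # [2, 4, 6, 8, 10] -> mean 6
assert sum(dl0) / 5 == 3 and sum(dl1) / 5 == 6

# In later epochs the same pattern scaled by 10 gives means 30 and 60, and the
# combined mean goes from (3 + 6) / 2 to (30 + 60) / 2 -- which only holds if
# the results were properly reset between validation epochs.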

