From 1aa9d39506fa00558c07900924aa4e54ad9a762d Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 4 Oct 2020 13:36:35 -0400 Subject: [PATCH] Eval epoch can now log independently (#3843) * ref: routed epoch outputs to logger * ref: routed epoch outputs to logger * ref: routed epoch outputs to logger * ref: routed epoch outputs to logger --- .../trainer/connectors/logger_connector.py | 54 +++++++++++++++---- pytorch_lightning/trainer/evaluation_loop.py | 20 +++---- pytorch_lightning/trainer/trainer.py | 8 +-- .../test_validation_steps_result_return.py | 37 ++----------- .../logging/test_eval_loop_logging_1_0.py | 50 ++++++++++++++++- 5 files changed, 112 insertions(+), 57 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector.py index 743af1223f37e..e7fcc0c005fe2 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector.py @@ -22,6 +22,7 @@ from pprint import pprint from typing import Iterable from copy import deepcopy +from collections import ChainMap class LoggerConnector: @@ -105,12 +106,12 @@ def add_progress_bar_metrics(self, metrics): self.trainer.dev_debugger.track_pbar_metrics_history(metrics) - def on_evaluation_epoch_end(self, eval_results, using_eval_result, test_mode): - self._track_callback_metrics(eval_results, using_eval_result) - self._log_on_evaluation_epoch_end_metrics() + def on_evaluation_epoch_end(self, deprecated_eval_results, epoch_logs, using_eval_result, test_mode): + self._track_callback_metrics(deprecated_eval_results, using_eval_result) + self._log_on_evaluation_epoch_end_metrics(epoch_logs) # TODO: deprecate parts of this for 1.0 (when removing results) - self.__process_eval_epoch_end_results_and_log_legacy(eval_results, test_mode) + self.__process_eval_epoch_end_results_and_log_legacy(deprecated_eval_results, test_mode) # get the final loop results eval_loop_results = self._get_evaluate_epoch_results(test_mode) @@ -131,15 +132,43 @@ def _get_evaluate_epoch_results(self, test_mode): self.eval_loop_results = [] return results - def _log_on_evaluation_epoch_end_metrics(self): + def _log_on_evaluation_epoch_end_metrics(self, epoch_logs): step_metrics = self.trainer.evaluation_loop.step_metrics + num_loaders = len(step_metrics) + # clear mem self.trainer.evaluation_loop.step_metrics = [] - num_loaders = len(step_metrics) + if self.trainer.running_sanity_check: + return + + # track all metrics we want to log + metrics_to_log = [] - # process metrics per dataloader + # --------------------------- + # UPDATE EPOCH LOGGED METRICS + # --------------------------- + # (ie: in methods at the val_epoch_end level) + # union the epoch logs with whatever was returned from loaders and reduced + epoch_logger_metrics = epoch_logs.get_epoch_log_metrics() + epoch_pbar_metrics = epoch_logs.get_epoch_pbar_metrics() + + self.logged_metrics.update(epoch_logger_metrics) + self.progress_bar_metrics.update(epoch_pbar_metrics) + + # enable the metrics to be monitored + self.callback_metrics.update(epoch_logger_metrics) + self.callback_metrics.update(epoch_pbar_metrics) + + if len(epoch_logger_metrics) > 0: + metrics_to_log.append(epoch_logger_metrics) + + # -------------------------------- + # UPDATE METRICS PER DATALOADER + # -------------------------------- + # each dataloader aggregated metrics + # now we log all of them for dl_idx, dl_metrics in enumerate(step_metrics): if len(dl_metrics) == 0: continue @@ -162,7 +191,13 @@ def _log_on_evaluation_epoch_end_metrics(self): self.eval_loop_results.append(deepcopy(self.callback_metrics)) # actually log - self.log_metrics(logger_metrics, {}, step=self.trainer.global_step) + if len(epoch_logger_metrics) > 0: + metrics_to_log.append(epoch_logger_metrics) + + # log all the metrics as a s single dict + metrics_to_log = dict(ChainMap(*metrics_to_log)) + if len(metrics_to_log) > 0: + self.log_metrics(metrics_to_log, {}, step=self.trainer.global_step) def __rename_keys_by_dataloader_idx(self, metrics, dataloader_idx, num_loaders): if num_loaders == 1: @@ -240,7 +275,8 @@ def __process_eval_epoch_end_results_and_log_legacy(self, eval_results, test_mod self.trainer.logger_connector.add_progress_bar_metrics(prog_bar_metrics) # log metrics - self.trainer.logger_connector.log_metrics(log_metrics, {}) + if len(log_metrics) > 0: + self.trainer.logger_connector.log_metrics(log_metrics, {}) # track metrics for callbacks (all prog bar, logged and callback metrics) self.trainer.logger_connector.callback_metrics.update(callback_metrics) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 38eb5b9465bc9..0f46d6c4cbddb 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -171,19 +171,23 @@ def evaluation_epoch_end(self, num_dataloaders): using_eval_result = self.is_using_eval_results() # call the model epoch end - eval_results = self.__run_eval_epoch_end(num_dataloaders, using_eval_result) + deprecated_results = self.__run_eval_epoch_end(num_dataloaders, using_eval_result) + + # 1.0 + epoch_logs = self.trainer.get_model()._results # enable returning anything - for r in eval_results: + for i, r in enumerate(deprecated_results): if not isinstance(r, (dict, Result, torch.Tensor)): - return [] + deprecated_results[i] = [] - return eval_results + return deprecated_results, epoch_logs - def log_epoch_metrics(self, eval_results, test_mode): + def log_epoch_metrics(self, deprecated_eval_results, epoch_logs, test_mode): using_eval_result = self.is_using_eval_results() eval_loop_results = self.trainer.logger_connector.on_evaluation_epoch_end( - eval_results, + deprecated_eval_results, + epoch_logs, using_eval_result, test_mode ) @@ -228,10 +232,6 @@ def __run_eval_epoch_end(self, num_dataloaders, using_eval_result): if using_eval_result and not user_reduced: eval_results = self.__auto_reduce_result_objs(outputs) - result = model._results - if len(result) > 0 and eval_results is None: - eval_results = result.get_epoch_log_metrics() - if not isinstance(eval_results, list): eval_results = [eval_results] diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 023a02ae9bfa8..87f85f938532e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -603,10 +603,12 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None): self.evaluation_loop.step_metrics.append(dl_step_metrics) # lightning module method - eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders)) + deprecated_eval_results, epoch_logs = self.evaluation_loop.evaluation_epoch_end( + num_dataloaders=len(dataloaders) + ) # bookkeeping - eval_loop_results = self.evaluation_loop.log_epoch_metrics(eval_results, test_mode) + eval_loop_results = self.evaluation_loop.log_epoch_metrics(deprecated_eval_results, epoch_logs, test_mode) self.evaluation_loop.predictions.to_disk() # hook @@ -619,7 +621,7 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None): # hook self.evaluation_loop.on_evaluation_end() - return eval_loop_results, eval_results + return eval_loop_results, deprecated_eval_results def run_test(self): # only load test dataloader for testing diff --git a/tests/trainer/legacy_deprecate_flow_log_tests/test_validation_steps_result_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_validation_steps_result_return.py index e7e305adc576d..a43b50c442dac 100644 --- a/tests/trainer/legacy_deprecate_flow_log_tests/test_validation_steps_result_return.py +++ b/tests/trainer/legacy_deprecate_flow_log_tests/test_validation_steps_result_return.py @@ -63,7 +63,7 @@ def test_val_step_result_callbacks(tmpdir): # did not request any metrics to log (except the metrics saying which epoch we are on) assert len(trainer.logger_connector.progress_bar_metrics) == 0 - assert len(trainer.dev_debugger.logged_metrics) == 5 + assert len(trainer.dev_debugger.logged_metrics) == 0 def test_val_step_using_train_callbacks(tmpdir): @@ -112,7 +112,7 @@ def test_val_step_using_train_callbacks(tmpdir): # did not request any metrics to log (except the metrics saying which epoch we are on) assert len(trainer.logger_connector.progress_bar_metrics) == 0 - assert len(trainer.dev_debugger.logged_metrics) == expected_epochs + assert len(trainer.dev_debugger.logged_metrics) == 0 def test_val_step_only_epoch_metrics(tmpdir): @@ -210,40 +210,9 @@ def test_val_step_only_step_metrics(tmpdir): assert len(trainer.dev_debugger.early_stopping_history) == 0 # make sure we logged the exact number of metrics - assert len(trainer.dev_debugger.logged_metrics) == epochs * batches + (epochs) + assert len(trainer.dev_debugger.logged_metrics) == epochs * batches assert len(trainer.dev_debugger.pbar_added_metrics) == epochs * batches + (epochs) - # make sure we logged the correct epoch metrics - total_empty_epoch_metrics = 0 - epoch = 0 - for metric in trainer.dev_debugger.logged_metrics: - if 'epoch' in metric: - epoch += 1 - if len(metric) > 2: - assert 'no_val_no_pbar' not in metric - assert 'val_step_pbar_acc' not in metric - assert metric[f'val_step_log_acc/epoch_{epoch}'] - assert metric[f'val_step_log_pbar_acc/epoch_{epoch}'] - else: - total_empty_epoch_metrics += 1 - - assert total_empty_epoch_metrics == 3 - - # make sure we logged the correct epoch pbar metrics - total_empty_epoch_metrics = 0 - for metric in trainer.dev_debugger.pbar_added_metrics: - if 'epoch' in metric: - epoch += 1 - if len(metric) > 2: - assert 'no_val_no_pbar' not in metric - assert 'val_step_log_acc' not in metric - assert metric['val_step_log_pbar_acc'] - assert metric['val_step_pbar_acc'] - else: - total_empty_epoch_metrics += 1 - - assert total_empty_epoch_metrics == 3 - # only 1 checkpoint expected since values didn't change after that assert len(trainer.dev_debugger.checkpoint_callback_history) == 1 diff --git a/tests/trainer/logging/test_eval_loop_logging_1_0.py b/tests/trainer/logging/test_eval_loop_logging_1_0.py index 9b6b50898e1b9..d417f9a2f6ad4 100644 --- a/tests/trainer/logging/test_eval_loop_logging_1_0.py +++ b/tests/trainer/logging/test_eval_loop_logging_1_0.py @@ -4,9 +4,10 @@ from pytorch_lightning import Trainer from pytorch_lightning import callbacks from tests.base.deterministic_model import DeterministicModel -from tests.base import SimpleModule +from tests.base import SimpleModule, BoringModel import os import torch +import pytest def test__validation_step__log(tmpdir): @@ -148,6 +149,53 @@ def backward(self, trainer, loss, optimizer, optimizer_idx): assert expected_cb_metrics == callback_metrics +@pytest.mark.parametrize(['batches', 'log_interval', 'max_epochs'], [(1, 1, 1), (64, 32, 2)]) +def test_eval_epoch_logging(tmpdir, batches, log_interval, max_epochs): + """ + Tests that only training_step can be used + """ + os.environ['PL_DEV_DEBUG'] = '1' + + class TestModel(BoringModel): + def validation_epoch_end(self, outputs): + self.log('c', torch.tensor(2), on_epoch=True, prog_bar=True, logger=True) + self.log('d/e/f', 2) + + model = TestModel() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=batches, + limit_val_batches=batches, + max_epochs=max_epochs, + row_log_interval=log_interval, + weights_summary=None, + ) + trainer.fit(model) + + # make sure all the metrics are available for callbacks + logged_metrics = set(trainer.logged_metrics.keys()) + expected_logged_metrics = { + 'c', + 'd/e/f', + } + assert logged_metrics == expected_logged_metrics + + pbar_metrics = set(trainer.progress_bar_metrics.keys()) + expected_pbar_metrics = {'c'} + assert pbar_metrics == expected_pbar_metrics + + callback_metrics = set(trainer.callback_metrics.keys()) + expected_callback_metrics = set() + expected_callback_metrics = expected_callback_metrics.union(logged_metrics) + expected_callback_metrics = expected_callback_metrics.union(pbar_metrics) + callback_metrics.remove('debug_epoch') + assert callback_metrics == expected_callback_metrics + + # assert the loggers received the expected number + assert len(trainer.dev_debugger.logged_metrics) == max_epochs + + def test_monitor_val_epoch_end(tmpdir): epoch_min_loss_override = 0 model = SimpleModule()