Eval epoch can now log independently (#3843)
* ref: routed epoch outputs to logger

* ref: routed epoch outputs to logger

* ref: routed epoch outputs to logger

* ref: routed epoch outputs to logger
williamFalcon authored Oct 4, 2020
1 parent b76fc5b commit 1aa9d39
Showing 5 changed files with 112 additions and 57 deletions.
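In user-facing terms, this change means that metrics logged with `self.log(...)` from `validation_epoch_end` (or `test_epoch_end`) reach the logger, the progress bar and the callback metrics on their own, without returning an `EvalResult` or a metrics dict. A minimal usage sketch, modelled on the `test_eval_epoch_logging` test added in this commit (the module, dataloaders and metric names here are illustrative, not part of the diff):

```python
# Illustrative sketch only, mirroring the new test below: self.log(...) inside
# validation_epoch_end is now enough to reach the logger and the progress bar.
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl


class ValEpochLoggingModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def validation_step(self, batch, batch_idx):
        return {"x": self.layer(batch).sum()}

    def validation_epoch_end(self, outputs):
        # logged once per eval epoch; 'c' also appears in the progress bar
        self.log("c", torch.tensor(2), on_epoch=True, prog_bar=True, logger=True)
        self.log("d/e/f", 2)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(torch.randn(64, 32), batch_size=32)

    def val_dataloader(self):
        return DataLoader(torch.randn(64, 32), batch_size=32)


trainer = pl.Trainer(max_epochs=1, weights_summary=None)
trainer.fit(ValEpochLoggingModel())
# 'c' and 'd/e/f' end up in trainer.logged_metrics; 'c' in trainer.progress_bar_metrics
```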
54 changes: 45 additions & 9 deletions pytorch_lightning/trainer/connectors/logger_connector.py
@@ -22,6 +22,7 @@
from pprint import pprint
from typing import Iterable
from copy import deepcopy
from collections import ChainMap


class LoggerConnector:
@@ -105,12 +106,12 @@ def add_progress_bar_metrics(self, metrics):

self.trainer.dev_debugger.track_pbar_metrics_history(metrics)

def on_evaluation_epoch_end(self, eval_results, using_eval_result, test_mode):
self._track_callback_metrics(eval_results, using_eval_result)
self._log_on_evaluation_epoch_end_metrics()
def on_evaluation_epoch_end(self, deprecated_eval_results, epoch_logs, using_eval_result, test_mode):
self._track_callback_metrics(deprecated_eval_results, using_eval_result)
self._log_on_evaluation_epoch_end_metrics(epoch_logs)

# TODO: deprecate parts of this for 1.0 (when removing results)
self.__process_eval_epoch_end_results_and_log_legacy(eval_results, test_mode)
self.__process_eval_epoch_end_results_and_log_legacy(deprecated_eval_results, test_mode)

# get the final loop results
eval_loop_results = self._get_evaluate_epoch_results(test_mode)
@@ -131,15 +132,43 @@ def _get_evaluate_epoch_results(self, test_mode):
self.eval_loop_results = []
return results

def _log_on_evaluation_epoch_end_metrics(self):
def _log_on_evaluation_epoch_end_metrics(self, epoch_logs):
step_metrics = self.trainer.evaluation_loop.step_metrics

num_loaders = len(step_metrics)

# clear mem
self.trainer.evaluation_loop.step_metrics = []

num_loaders = len(step_metrics)
if self.trainer.running_sanity_check:
return

# track all metrics we want to log
metrics_to_log = []

# process metrics per dataloader
# ---------------------------
# UPDATE EPOCH LOGGED METRICS
# ---------------------------
# (ie: in methods at the val_epoch_end level)
# union the epoch logs with whatever was returned from loaders and reduced
epoch_logger_metrics = epoch_logs.get_epoch_log_metrics()
epoch_pbar_metrics = epoch_logs.get_epoch_pbar_metrics()

self.logged_metrics.update(epoch_logger_metrics)
self.progress_bar_metrics.update(epoch_pbar_metrics)

# enable the metrics to be monitored
self.callback_metrics.update(epoch_logger_metrics)
self.callback_metrics.update(epoch_pbar_metrics)

if len(epoch_logger_metrics) > 0:
metrics_to_log.append(epoch_logger_metrics)

# --------------------------------
# UPDATE METRICS PER DATALOADER
# --------------------------------
# each dataloader aggregated metrics
# now we log all of them
for dl_idx, dl_metrics in enumerate(step_metrics):
if len(dl_metrics) == 0:
continue
@@ -162,7 +191,13 @@ def _log_on_evaluation_epoch_end_metrics(self):
self.eval_loop_results.append(deepcopy(self.callback_metrics))

# actually log
self.log_metrics(logger_metrics, {}, step=self.trainer.global_step)
if len(logger_metrics) > 0:
metrics_to_log.append(logger_metrics)

# log all the metrics as a single dict
metrics_to_log = dict(ChainMap(*metrics_to_log))
if len(metrics_to_log) > 0:
self.log_metrics(metrics_to_log, {}, step=self.trainer.global_step)

def __rename_keys_by_dataloader_idx(self, metrics, dataloader_idx, num_loaders):
if num_loaders == 1:
@@ -240,7 +275,8 @@ def __process_eval_epoch_end_results_and_log_legacy(self, eval_results, test_mode):
self.trainer.logger_connector.add_progress_bar_metrics(prog_bar_metrics)

# log metrics
self.trainer.logger_connector.log_metrics(log_metrics, {})
if len(log_metrics) > 0:
self.trainer.logger_connector.log_metrics(log_metrics, {})

# track metrics for callbacks (all prog bar, logged and callback metrics)
self.trainer.logger_connector.callback_metrics.update(callback_metrics)
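For reference, the single `log_metrics` call in `_log_on_evaluation_epoch_end_metrics` above flattens the epoch-level dict and the per-dataloader dicts with `collections.ChainMap`; when the same key appears in several dicts, the one appended first takes precedence. A standalone sketch of that merge (the metric names are invented, not from the diff):

```python
# Standalone sketch of the merge above (metric names invented).
from collections import ChainMap

# epoch-level metrics, e.g. logged from validation_epoch_end
epoch_logger_metrics = {"val_loss": 0.31, "val_acc": 0.88}
# reduced metrics from two validation dataloaders
dl_metrics = [{"step_acc/dataloader_idx_0": 0.87}, {"step_acc/dataloader_idx_1": 0.84}]

metrics_to_log = [epoch_logger_metrics] + [m for m in dl_metrics if len(m) > 0]

# ChainMap flattens everything into one dict for a single log_metrics call
merged = dict(ChainMap(*metrics_to_log))
assert merged == {
    "val_loss": 0.31,
    "val_acc": 0.88,
    "step_acc/dataloader_idx_0": 0.87,
    "step_acc/dataloader_idx_1": 0.84,
}

# on key collisions the mapping listed first wins
assert dict(ChainMap({"val_acc": 0.9}, {"val_acc": 0.1}))["val_acc"] == 0.9
```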
20 changes: 10 additions & 10 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -171,19 +171,23 @@ def evaluation_epoch_end(self, num_dataloaders):
using_eval_result = self.is_using_eval_results()

# call the model epoch end
eval_results = self.__run_eval_epoch_end(num_dataloaders, using_eval_result)
deprecated_results = self.__run_eval_epoch_end(num_dataloaders, using_eval_result)

# 1.0
epoch_logs = self.trainer.get_model()._results

# enable returning anything
for r in eval_results:
for i, r in enumerate(deprecated_results):
if not isinstance(r, (dict, Result, torch.Tensor)):
return []
deprecated_results[i] = []

return eval_results
return deprecated_results, epoch_logs

def log_epoch_metrics(self, eval_results, test_mode):
def log_epoch_metrics(self, deprecated_eval_results, epoch_logs, test_mode):
using_eval_result = self.is_using_eval_results()
eval_loop_results = self.trainer.logger_connector.on_evaluation_epoch_end(
eval_results,
deprecated_eval_results,
epoch_logs,
using_eval_result,
test_mode
)
@@ -228,10 +232,6 @@ def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
if using_eval_result and not user_reduced:
eval_results = self.__auto_reduce_result_objs(outputs)

result = model._results
if len(result) > 0 and eval_results is None:
eval_results = result.get_epoch_log_metrics()

if not isinstance(eval_results, list):
eval_results = [eval_results]

8 changes: 5 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -603,10 +603,12 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None):
self.evaluation_loop.step_metrics.append(dl_step_metrics)

# lightning module method
eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders))
deprecated_eval_results, epoch_logs = self.evaluation_loop.evaluation_epoch_end(
num_dataloaders=len(dataloaders)
)

# bookkeeping
eval_loop_results = self.evaluation_loop.log_epoch_metrics(eval_results, test_mode)
eval_loop_results = self.evaluation_loop.log_epoch_metrics(deprecated_eval_results, epoch_logs, test_mode)
self.evaluation_loop.predictions.to_disk()

# hook
@@ -619,7 +621,7 @@ def run_test(self):
# hook
self.evaluation_loop.on_evaluation_end()

return eval_loop_results, eval_results
return eval_loop_results, deprecated_eval_results

def run_test(self):
# only load test dataloader for testing
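Putting the three files together, the call flow is roughly the following (a simplified paraphrase for orientation, not the real Lightning internals):

```python
# Simplified paraphrase of the new wiring: the epoch-end step now returns two
# things -- whatever the (deprecated) *_epoch_end hooks returned, plus the
# Result object that self.log(...) calls were routed into -- and both are
# handed to the logger connector together.

def run_evaluation_sketch(trainer, test_mode=False):
    evaluation_loop = trainer.evaluation_loop

    # 1) run the model's *_epoch_end hook and collect the epoch-level logs
    deprecated_eval_results, epoch_logs = evaluation_loop.evaluation_epoch_end(
        num_dataloaders=1  # hardcoded here for brevity
    )

    # 2) the logger connector merges the epoch logs with the per-dataloader
    #    metrics and emits a single log_metrics call (see logger_connector.py above)
    eval_loop_results = evaluation_loop.log_epoch_metrics(
        deprecated_eval_results, epoch_logs, test_mode
    )

    # 3) callers still receive the legacy results for backwards compatibility
    return eval_loop_results, deprecated_eval_results
```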
@@ -63,7 +63,7 @@ def test_val_step_result_callbacks(tmpdir):

# did not request any metrics to log (except the metrics saying which epoch we are on)
assert len(trainer.logger_connector.progress_bar_metrics) == 0
assert len(trainer.dev_debugger.logged_metrics) == 5
assert len(trainer.dev_debugger.logged_metrics) == 0


def test_val_step_using_train_callbacks(tmpdir):
@@ -112,7 +112,7 @@ def test_val_step_using_train_callbacks(tmpdir):

# did not request any metrics to log (except the metrics saying which epoch we are on)
assert len(trainer.logger_connector.progress_bar_metrics) == 0
assert len(trainer.dev_debugger.logged_metrics) == expected_epochs
assert len(trainer.dev_debugger.logged_metrics) == 0


def test_val_step_only_epoch_metrics(tmpdir):
@@ -210,40 +210,9 @@ def test_val_step_only_step_metrics(tmpdir):
assert len(trainer.dev_debugger.early_stopping_history) == 0

# make sure we logged the exact number of metrics
assert len(trainer.dev_debugger.logged_metrics) == epochs * batches + (epochs)
assert len(trainer.dev_debugger.logged_metrics) == epochs * batches
assert len(trainer.dev_debugger.pbar_added_metrics) == epochs * batches + (epochs)

# make sure we logged the correct epoch metrics
total_empty_epoch_metrics = 0
epoch = 0
for metric in trainer.dev_debugger.logged_metrics:
if 'epoch' in metric:
epoch += 1
if len(metric) > 2:
assert 'no_val_no_pbar' not in metric
assert 'val_step_pbar_acc' not in metric
assert metric[f'val_step_log_acc/epoch_{epoch}']
assert metric[f'val_step_log_pbar_acc/epoch_{epoch}']
else:
total_empty_epoch_metrics += 1

assert total_empty_epoch_metrics == 3

# make sure we logged the correct epoch pbar metrics
total_empty_epoch_metrics = 0
for metric in trainer.dev_debugger.pbar_added_metrics:
if 'epoch' in metric:
epoch += 1
if len(metric) > 2:
assert 'no_val_no_pbar' not in metric
assert 'val_step_log_acc' not in metric
assert metric['val_step_log_pbar_acc']
assert metric['val_step_pbar_acc']
else:
total_empty_epoch_metrics += 1

assert total_empty_epoch_metrics == 3

# only 1 checkpoint expected since values didn't change after that
assert len(trainer.dev_debugger.checkpoint_callback_history) == 1

50 changes: 49 additions & 1 deletion tests/trainer/logging/test_eval_loop_logging_1_0.py
@@ -4,9 +4,10 @@
from pytorch_lightning import Trainer
from pytorch_lightning import callbacks
from tests.base.deterministic_model import DeterministicModel
from tests.base import SimpleModule
from tests.base import SimpleModule, BoringModel
import os
import torch
import pytest


def test__validation_step__log(tmpdir):
@@ -148,6 +149,53 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
assert expected_cb_metrics == callback_metrics


@pytest.mark.parametrize(['batches', 'log_interval', 'max_epochs'], [(1, 1, 1), (64, 32, 2)])
def test_eval_epoch_logging(tmpdir, batches, log_interval, max_epochs):
"""
Tests that only training_step can be used
"""
os.environ['PL_DEV_DEBUG'] = '1'

class TestModel(BoringModel):
def validation_epoch_end(self, outputs):
self.log('c', torch.tensor(2), on_epoch=True, prog_bar=True, logger=True)
self.log('d/e/f', 2)

model = TestModel()

trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=batches,
limit_val_batches=batches,
max_epochs=max_epochs,
row_log_interval=log_interval,
weights_summary=None,
)
trainer.fit(model)

# make sure all the metrics are available for callbacks
logged_metrics = set(trainer.logged_metrics.keys())
expected_logged_metrics = {
'c',
'd/e/f',
}
assert logged_metrics == expected_logged_metrics

pbar_metrics = set(trainer.progress_bar_metrics.keys())
expected_pbar_metrics = {'c'}
assert pbar_metrics == expected_pbar_metrics

callback_metrics = set(trainer.callback_metrics.keys())
expected_callback_metrics = set()
expected_callback_metrics = expected_callback_metrics.union(logged_metrics)
expected_callback_metrics = expected_callback_metrics.union(pbar_metrics)
callback_metrics.remove('debug_epoch')
assert callback_metrics == expected_callback_metrics

# assert the loggers received the expected number
assert len(trainer.dev_debugger.logged_metrics) == max_epochs


def test_monitor_val_epoch_end(tmpdir):
epoch_min_loss_override = 0
model = SimpleModule()
