diff --git a/CHANGELOG.md b/CHANGELOG.md
index 81d18a7d49279..83bfbef4d9269 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added same step loggers' metrics aggregation ([#1278](https://github.com/PyTorchLightning/pytorch-lightning/pull/1278))
 - Added parity test between a vanilla MNIST model and lightning model ([#1284](https://github.com/PyTorchLightning/pytorch-lightning/pull/1284))
 - Added parity test between a vanilla RNN model and lightning model ([#1351](https://github.com/PyTorchLightning/pytorch-lightning/pull/1351))
 - Added Reinforcement Learning - Deep Q-network (DQN) lightning example ([#1232](https://github.com/PyTorchLightning/pytorch-lightning/pull/1232))
@@ -30,6 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Changed (renamed and refactored) `TensorRunningMean` -> `TensorRunningAccum`: running accumulations were generalized. ([#1278](https://github.com/PyTorchLightning/pytorch-lightning/pull/1278))
 - Changed `progress_bar_refresh_rate` trainer flag to disable progress bar when set to 0. ([#1108](https://github.com/PyTorchLightning/pytorch-lightning/pull/1108))
 - Enhanced `load_from_checkpoint` to also forward params to the model ([#1307](https://github.com/PyTorchLightning/pytorch-lightning/pull/1307))
 - Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))
diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py
index 38800e4945a1e..dcaddadfaca2b 100644
--- a/pytorch_lightning/loggers/base.py
+++ b/pytorch_lightning/loggers/base.py
@@ -1,9 +1,12 @@
 import argparse
+import functools
+import operator
 from abc import ABC, abstractmethod
 from argparse import Namespace
 from functools import wraps
-from typing import Union, Optional, Dict, Iterable, Any, Callable, List
+from typing import Union, Optional, Dict, Iterable, Any, Callable, List, Sequence, Mapping, Tuple
 
+import numpy as np
 import torch
 
 
@@ -25,22 +28,119 @@ def wrapped_fn(self, *args, **kwargs):
 class LightningLoggerBase(ABC):
     """Base class for experiment loggers."""
 
-    def __init__(self):
+    def __init__(
+            self,
+            agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
+            agg_default_func: Callable[[Sequence[float]], float] = np.mean
+    ):
+        """
+        Args:
+            agg_key_funcs:
+                Dictionary which maps a metric name to a function, which will
+                aggregate the metric values for the same steps.
+            agg_default_func:
+                Default function to aggregate metric values. If some metric name
+                is not present in the `agg_key_funcs` dictionary, then the
+                `agg_default_func` will be used for aggregation.
+
+        Notes:
+            `agg_key_funcs` and `agg_default_func` are used only when one logs metrics with
+            the `LightningLoggerBase.agg_and_log_metrics` method.
+        """
         self._rank = 0
+        self._prev_step = -1
+        self._metrics_to_agg: List[Dict[str, float]] = []
+        self._agg_key_funcs = agg_key_funcs if agg_key_funcs else {}
+        self._agg_default_func = agg_default_func
+
+    def update_agg_funcs(
+            self,
+            agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
+            agg_default_func: Callable[[Sequence[float]], float] = np.mean
+    ):
+        """Update aggregation methods.
+
+        Args:
+            agg_key_funcs:
+                Dictionary which maps a metric name to a function, which will
+                aggregate the metric values for the same steps.
+            agg_default_func:
+                Default function to aggregate metric values. If some metric name
+                is not present in the `agg_key_funcs` dictionary, then the
+                `agg_default_func` will be used for aggregation.
+        """
+        if agg_key_funcs:
+            self._agg_key_funcs.update(agg_key_funcs)
+        if agg_default_func:
+            self._agg_default_func = agg_default_func
 
     @property
     @abstractmethod
     def experiment(self) -> Any:
         """Return the experiment object associated with this logger"""
 
+    def _aggregate_metrics(
+            self, metrics: Dict[str, float], step: Optional[int] = None
+    ) -> Tuple[int, Optional[Dict[str, float]]]:
+        """Aggregates metrics.
+
+        Args:
+            metrics: Dictionary with metric names as keys and measured quantities as values
+            step: Step number at which the metrics should be recorded
+
+        Returns:
+            Step and aggregated metrics. The aggregated metrics may be None. In that case, the
+            given metrics were only added to the aggregation list, but are not aggregated yet.
+        """
+        # if we are still receiving metrics for the same step, just accumulate them
+        if step == self._prev_step:
+            self._metrics_to_agg.append(metrics)
+            return step, None
+
+        # compute the metrics
+        agg_step, agg_mets = self._finalize_agg_metrics()
+
+        # a new step was received, so reset the accumulator
+        self._metrics_to_agg = [metrics]
+        self._prev_step = step
+        return agg_step, agg_mets
+
+    def _finalize_agg_metrics(self):
+        """Aggregate accumulated metrics. This shall be called in `close`."""
+        # compute the metrics
+        if not self._metrics_to_agg:
+            agg_mets = None
+        elif len(self._metrics_to_agg) == 1:
+            agg_mets = self._metrics_to_agg[0]
+        else:
+            agg_mets = merge_dicts(self._metrics_to_agg, self._agg_key_funcs, self._agg_default_func)
+        return self._prev_step, agg_mets
+
+    def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
+        """Aggregates and records metrics.
+        This method doesn't log the passed metrics instantaneously, but instead
+        it aggregates them and logs them only once they are ready to be logged.
+
+        Args:
+            metrics: Dictionary with metric names as keys and measured quantities as values
+            step: Step number at which the metrics should be recorded
+        """
+        agg_step, metrics_to_log = self._aggregate_metrics(metrics=metrics, step=step)
+
+        if metrics_to_log is not None:
+            self.log_metrics(metrics=metrics_to_log, step=agg_step)
+
     @abstractmethod
     def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
-        """Record metrics.
+        """Records metrics.
+        This method logs metrics as soon as it receives them. If you want to aggregate
+        metrics for one specific `step`, use the `agg_and_log_metrics` method.
 
         Args:
             metrics: Dictionary with metric names as keys and measured quantities as values
             step: Step number at which the metrics should be recorded
         """
+        pass
 
     @staticmethod
     def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:
@@ -131,7 +231,10 @@ def finalize(self, status: str) -> None:
 
     def close(self) -> None:
         """Do any cleanup that is necessary to close an experiment."""
-        pass
+        agg_step, metrics_to_log = self._finalize_agg_metrics()
+
+        if metrics_to_log is not None:
+            self.log_metrics(metrics=metrics_to_log, step=agg_step)
 
     @property
     def rank(self) -> int:
@@ -200,3 +303,48 @@ def name(self) -> str:
     @property
     def version(self) -> str:
         return '_'.join([str(logger.version) for logger in self._logger_iterable])
+
+
+def merge_dicts(
+        dicts: Sequence[Mapping],
+        agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
+        default_func: Callable[[Sequence[float]], float] = np.mean
+) -> Dict:
+    """Merge a sequence of dictionaries into one dictionary by aggregating the
+    same keys with some given function.
+
+    Args:
+        dicts:
+            Sequence of dictionaries to be merged.
+        agg_key_funcs:
+            Mapping from key name to function. This function will aggregate a
+            list of values, obtained from the same key of all dictionaries.
+            If some key has no specified aggregation function, the default one
+            will be used. Default is: None (all keys will be aggregated by the
+            default function).
+        default_func:
+            Default function to aggregate keys which are not present in the
+            `agg_key_funcs` map.
+
+    Returns:
+        Dictionary with merged values.
+
+    Examples:
+        >>> import pprint
+        >>> d1 = {'a': 1.7, 'b': 2.0, 'c': 1}
+        >>> d2 = {'a': 1.1, 'b': 2.2, 'v': 1}
+        >>> d3 = {'a': 1.1, 'v': 2.3}
+        >>> dflt_func = min
+        >>> agg_funcs = {'a': np.mean, 'v': max}
+        >>> pprint.pprint(merge_dicts([d1, d2, d3], agg_funcs, dflt_func))
+        {'a': 1.3, 'b': 2.0, 'c': 1, 'v': 2.3}
+    """
+
+    keys = list(functools.reduce(operator.or_, [set(d.keys()) for d in dicts]))
+    d_out = {}
+    for k in keys:
+        fn = agg_key_funcs.get(k, default_func) if agg_key_funcs else default_func
+        agg_val = fn([v for v in [d_in.get(k) for d_in in dicts] if v is not None])
+        d_out[k] = agg_val
+
+    return d_out
diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py
index 313eed54d3e6b..c8549da0957b5 100644
--- a/pytorch_lightning/trainer/logging.py
+++ b/pytorch_lightning/trainer/logging.py
@@ -71,7 +71,7 @@ def log_metrics(self, metrics, grad_norm_dic, step=None):
         step = step if step is not None else self.global_step
         # log actual metrics
         if self.proc_rank == 0 and self.logger is not None:
-            self.logger.log_metrics(scalar_metrics, step=step)
+            self.logger.agg_and_log_metrics(scalar_metrics, step=step)
             self.logger.save()
 
     def add_tqdm_metrics(self, metrics):
diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py
index 6c6129de466e9..925d96ed8baa3 100644
--- a/pytorch_lightning/trainer/supporters.py
+++ b/pytorch_lightning/trainer/supporters.py
@@ -1,13 +1,12 @@
 import torch
 
 
-class TensorRunningMean(object):
-    """
-    Tracks a running mean without graph references.
-    Round robbin for the mean
+class TensorRunningAccum(object):
+    """Tracks running accumulation values (min, max, mean) without graph
+    references.
 
     Examples:
-        >>> accum = TensorRunningMean(5)
+        >>> accum = TensorRunningAccum(5)
         >>> accum.last(), accum.mean()
         (None, None)
         >>> accum.append(torch.tensor(1.5))
@@ -18,8 +17,8 @@ class TensorRunningMean(object):
         (tensor(2.5000), tensor(2.))
         >>> accum.reset()
         >>> _= [accum.append(torch.tensor(i)) for i in range(13)]
-        >>> accum.last(), accum.mean()
-        (tensor(12.), tensor(10.))
+        >>> accum.last(), accum.mean(), accum.min(), accum.max()
+        (tensor(12.), tensor(10.), tensor(8.), tensor(12.))
     """
     def __init__(self, window_length: int):
         self.window_length = window_length
@@ -29,13 +28,16 @@ def __init__(self, window_length: int):
         self.rotated: bool = False
 
     def reset(self) -> None:
-        self = TensorRunningMean(self.window_length)
+        """Empty the accumulator."""
+        self = TensorRunningAccum(self.window_length)
 
     def last(self):
+        """Get the last added element."""
         if self.last_idx is not None:
             return self.memory[self.last_idx]
 
     def append(self, x):
+        """Add an element to the accumulator."""
         # ensure same device and type
         if self.memory.device != x.device or self.memory.type() != x.type():
             x = x.to(self.memory)
@@ -54,5 +56,20 @@ def append(self, x):
             self.rotated = True
 
     def mean(self):
+        """Get mean value from stored elements."""
+        return self._agg_memory('mean')
+
+    def max(self):
+        """Get maximal value from stored elements."""
+        return self._agg_memory('max')
+
+    def min(self):
+        """Get minimal value from stored elements."""
+        return self._agg_memory('min')
+
+    def _agg_memory(self, how: str):
         if self.last_idx is not None:
-            return self.memory.mean() if self.rotated else self.memory[:self.current_idx].mean()
+            if self.rotated:
+                return getattr(self.memory, how)()
+            else:
+                return getattr(self.memory[:self.current_idx], how)()
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index d470bcb35a8cd..5b140064939a3 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -28,7 +28,7 @@
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
 from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
-from pytorch_lightning.trainer.supporters import TensorRunningMean
+from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.trainer.training_io import TrainerIOMixin
 from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
@@ -342,7 +342,7 @@ def __init__(
 
         # training bookeeping
         self.total_batch_idx = 0
-        self.running_loss = TensorRunningMean(window_length=20)
+        self.running_loss = TensorRunningAccum(window_length=20)
         self.batch_idx = 0
         self.tqdm_metrics = {}
         self.callback_metrics = {}
@@ -551,20 +551,19 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
         allowed_types = (str, float, int, bool)
 
         # TODO: get "help" from docstring :)
-        for arg, arg_types, arg_default in cls.get_init_arguments_and_types():
-            if arg not in depr_arg_names:
-                for allowed_type in allowed_types:
-                    if allowed_type in arg_types:
-                        if allowed_type is bool:
-                            allowed_type = lambda x: bool(distutils.util.strtobool(x))
-                        parser.add_argument(
-                            f'--{arg}',
-                            default=arg_default,
-                            type=allowed_type,
-                            dest=arg,
-                            help='autogenerated by pl.Trainer'
-                        )
-                        break
+        for arg, arg_types, arg_default in (at for at in cls.get_init_arguments_and_types()
+                                            if at[0] not in depr_arg_names):
+            for allowed_type in (at for at in allowed_types if at in arg_types):
+                if allowed_type is bool:
+                    allowed_type = lambda x: bool(distutils.util.strtobool(x))
+                parser.add_argument(
+                    f'--{arg}',
+                    default=arg_default,
+                    type=allowed_type,
+                    dest=arg,
+                    help='autogenerated by pl.Trainer'
+                )
+                break
 
         return parser
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 73333fb7d59cd..b6ec1a272d8cf 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -147,7 +147,7 @@ def training_step(self, batch, batch_idx):
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.trainer.supporters import TensorRunningMean
+from pytorch_lightning.trainer.supporters import TensorRunningAccum
 
 try:
     from apex import amp
@@ -337,7 +337,7 @@ def train(self):
             self.accumulation_scheduler.on_epoch_start(self, self.get_model())
 
             # stores accumulated grad fractions per batch
-            self.batch_loss_value = TensorRunningMean(
+            self.batch_loss_value = TensorRunningAccum(
                 window_length=self.accumulate_grad_batches
             )
diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py
index d7f3503cf8193..b6614be345c0c 100644
--- a/tests/loggers/test_base.py
+++ b/tests/loggers/test_base.py
@@ -1,6 +1,9 @@
 import pickle
+from collections import OrderedDict
 from unittest.mock import MagicMock
 
+import numpy as np
+
 import tests.base.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import LightningLoggerBase, rank_zero_only, LoggerCollection
@@ -56,6 +59,18 @@ def version(self):
         return "1"
 
 
+class StoreHistoryLogger(CustomLogger):
+    def __init__(self):
+        super().__init__()
+        self.history = {}
+
+    @rank_zero_only
+    def log_metrics(self, metrics, step):
+        if step not in self.history:
+            self.history[step] = {}
+        self.history[step].update(metrics)
+
+
 def test_custom_logger(tmpdir):
     hparams = tutils.get_default_hparams()
     model = LightningTestModel(hparams)
@@ -153,5 +168,19 @@ def decorated(metrics, step):
         num_sanity_val_steps=0,
     )
     trainer = Trainer(**trainer_options)
-    trainer.logger.log_metrics = _log_metrics_decorator(trainer.logger.log_metrics)
+    trainer.logger.log_metrics = _log_metrics_decorator(
+        trainer.logger.log_metrics)
     trainer.fit(model)
+
+
+def test_with_accumulate_grad_batches():
+    """Checks if the logging is performed once for `accumulate_grad_batches` steps."""
+    logger = StoreHistoryLogger()
+
+    np.random.seed(42)
+    for i, loss in enumerate(np.random.random(10)):
+        logger.agg_and_log_metrics({'loss': loss}, step=int(i / 5))
+
+    assert logger.history == {0: {'loss': 0.5623850983416314}}
+    logger.close()
+    assert logger.history == {0: {'loss': 0.5623850983416314}, 1: {'loss': 0.4778883735637184}}
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index f80870ec28c34..c26da3b7b280c 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1,17 +1,14 @@
 import glob
 import math
 import os
-from argparse import Namespace
+from argparse import Namespace, ArgumentParser
 
 import pytest
 import torch
 
 import tests.base.utils as tutils
 from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import (
-    EarlyStopping,
-    ModelCheckpoint,
-)
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning import Callback
 from pytorch_lightning.core.lightning import load_hparams_from_tags_csv
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
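
Below is a minimal usage sketch (not part of the patch) of the aggregation path introduced above: a custom logger only implements `log_metrics`, while callers go through `agg_and_log_metrics`, which merges every metrics dict reported for the same `step` (per-key functions from `agg_key_funcs`, `np.mean` otherwise) and flushes the merged result once a new step arrives or `close()` is called. This is also the path the trainer now takes, since `TrainerLoggingMixin.log_metrics` calls `self.logger.agg_and_log_metrics(...)` instead of `self.logger.log_metrics(...)`. The `DictLogger` class and the metric values are hypothetical; the behaviour mirrors the `StoreHistoryLogger` used in tests/loggers/test_base.py.

import numpy as np

from pytorch_lightning.loggers import LightningLoggerBase, rank_zero_only


class DictLogger(LightningLoggerBase):
    """Hypothetical logger that keeps every logged metrics dict in memory, keyed by step."""

    def __init__(self):
        # aggregate 'lr' values reported for the same step with `max`; all other keys use np.mean
        super().__init__(agg_key_funcs={'lr': max})
        self.history = {}

    @rank_zero_only
    def log_metrics(self, metrics, step):
        self.history.setdefault(step, {}).update(metrics)

    @rank_zero_only
    def log_hyperparams(self, params):
        pass

    @property
    def experiment(self):
        return None

    @property
    def name(self):
        return 'dict'

    @property
    def version(self):
        return '0'


logger = DictLogger()
# three reports for step 0 are only accumulated, nothing is logged yet
logger.agg_and_log_metrics({'loss': 1.0, 'lr': 0.1}, step=0)
logger.agg_and_log_metrics({'loss': 0.5, 'lr': 0.2}, step=0)
logger.agg_and_log_metrics({'loss': 0.75, 'lr': 0.3}, step=0)
# the first report for step 1 flushes the aggregated step-0 metrics to log_metrics
logger.agg_and_log_metrics({'loss': 0.25, 'lr': 0.1}, step=1)
assert logger.history[0] == {'loss': 0.75, 'lr': 0.3}  # mean of the losses, max of the lrs
# close() flushes whatever is still pending (step 1 here)
logger.close()
assert logger.history[1] == {'loss': 0.25, 'lr': 0.1}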