-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added accumulation of loggers' metrics for the same steps #1278
Changes from all commits
c5c467e
625c01a
05b9d55
a14da90
30d0347
f222da1
0d69480
40812f8
06f450c
c4ee89b
d206a13
aa22904
bab529d
39745aa
73a724a
c586217
e48a711
8064470
b11a80e
8e7e28f
1143294
a04f2c3
94dd1f8
7ba9c96
e9fc69a
6b90977
41aff77
2fd1c79
5afe31d
56a9b71
c6ab2e0
014b6dd
0cf3122
64ac60e
2ecb8c5
382aadb
80da0fb
6b64fde
5f1f557
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,12 @@ | ||
import argparse | ||
import functools | ||
import operator | ||
from abc import ABC, abstractmethod | ||
from argparse import Namespace | ||
from functools import wraps | ||
from typing import Union, Optional, Dict, Iterable, Any, Callable, List | ||
from typing import Union, Optional, Dict, Iterable, Any, Callable, List, Sequence, Mapping, Tuple | ||
|
||
import numpy as np | ||
import torch | ||
|
||
|
||
|
@@ -25,22 +28,119 @@ def wrapped_fn(self, *args, **kwargs): | |
class LightningLoggerBase(ABC): | ||
"""Base class for experiment loggers.""" | ||
|
||
def __init__(self): | ||
def __init__( | ||
self, | ||
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None, | ||
agg_default_func: Callable[[Sequence[float]], float] = np.mean | ||
): | ||
""" | ||
Args: | ||
agg_key_funcs: | ||
Dictionary which maps a metric name to a function, which will | ||
aggregate the metric values for the same steps. | ||
agg_default_func: | ||
Default function to aggregate metric values. If some metric name | ||
is not present in the `agg_key_funcs` dictionary, then the | ||
`agg_default_func` will be used for aggregation. | ||
|
||
Notes: | ||
`agg_key_funcs` and `agg_default_func` are used only when one logs metrics with | ||
`LightningLoggerBase.agg_and_log_metrics` method. | ||
""" | ||
self._rank = 0 | ||
self._prev_step = -1 | ||
self._metrics_to_agg: List[Dict[str, float]] = [] | ||
self._agg_key_funcs = agg_key_funcs if agg_key_funcs else {} | ||
self._agg_default_func = agg_default_func | ||
|
||
def update_agg_funcs( | ||
self, | ||
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None, | ||
agg_default_func: Callable[[Sequence[float]], float] = np.mean | ||
): | ||
"""Update aggregation methods. | ||
|
||
Args: | ||
agg_key_funcs: | ||
Dictionary which maps a metric name to a function, which will | ||
aggregate the metric values for the same steps. | ||
agg_default_func: | ||
Default function to aggregate metric values. If some metric name | ||
is not present in the `agg_key_funcs` dictionary, then the | ||
`agg_default_func` will be used for aggregation. | ||
""" | ||
if agg_key_funcs: | ||
self._agg_key_funcs.update(agg_key_funcs) | ||
if agg_default_func: | ||
self._agg_default_func = agg_default_func | ||
|
||
@property | ||
@abstractmethod | ||
def experiment(self) -> Any: | ||
"""Return the experiment object associated with this logger""" | ||
|
||
def _aggregate_metrics( | ||
self, metrics: Dict[str, float], step: Optional[int] = None | ||
) -> Tuple[int, Optional[Dict[str, float]]]: | ||
"""Aggregates metrics. | ||
|
||
Args: | ||
metrics: Dictionary with metric names as keys and measured quantities as values | ||
step: Step number at which the metrics should be recorded | ||
|
||
Returns: | ||
Step and aggregated metrics. The return value could be None. In such case, metrics | ||
are added to the aggregation list, but not aggregated yet. | ||
""" | ||
# if we are still receiving metrics for the same step, just accumulate them | ||
if step == self._prev_step: | ||
self._metrics_to_agg.append(metrics) | ||
return step, None | ||
|
||
# compute the metrics | ||
agg_step, agg_mets = self._finalize_agg_metrics() | ||
|
||
# as new step received reset accumulator | ||
self._metrics_to_agg = [metrics] | ||
self._prev_step = step | ||
return agg_step, agg_mets | ||
|
||
def _finalize_agg_metrics(self): | ||
"""Aggregate accumulated metrics. This shall be called in close.""" | ||
# compute the metrics | ||
if not self._metrics_to_agg: | ||
agg_mets = None | ||
elif len(self._metrics_to_agg) == 1: | ||
agg_mets = self._metrics_to_agg[0] | ||
else: | ||
agg_mets = merge_dicts(self._metrics_to_agg, self._agg_key_funcs, self._agg_default_func) | ||
return self._prev_step, agg_mets | ||
|
||
def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): | ||
"""Aggregates and records metrics. | ||
This method doesn't log the passed metrics instantaneously, but instead | ||
it aggregates them and logs only if metrics are ready to be logged. | ||
|
||
Args: | ||
metrics: Dictionary with metric names as keys and measured quantities as values | ||
step: Step number at which the metrics should be recorded | ||
""" | ||
agg_step, metrics_to_log = self._aggregate_metrics(metrics=metrics, step=step) | ||
|
||
if metrics_to_log is not None: | ||
self.log_metrics(metrics=metrics_to_log, step=agg_step) | ||
|
||
@abstractmethod | ||
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): | ||
"""Record metrics. | ||
"""Records metrics. | ||
This method logs metrics as soon as it receives them. If you want to aggregate | ||
metrics for one specific `step`, use the `agg_and_log_metrics` method. | ||
|
||
Args: | ||
metrics: Dictionary with metric names as keys and measured quantities as values | ||
step: Step number at which the metrics should be recorded | ||
""" | ||
pass | ||
|
||
@staticmethod | ||
def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]: | ||
|
@@ -131,7 +231,10 @@ def finalize(self, status: str) -> None: | |
|
||
def close(self) -> None: | ||
"""Do any cleanup that is necessary to close an experiment.""" | ||
pass | ||
agg_step, metrics_to_log = self._finalize_agg_metrics() | ||
|
||
if metrics_to_log is not None: | ||
self.log_metrics(metrics=metrics_to_log, step=agg_step) | ||
|
||
@property | ||
def rank(self) -> int: | ||
|
@@ -200,3 +303,48 @@ def name(self) -> str: | |
@property | ||
def version(self) -> str: | ||
return '_'.join([str(logger.version) for logger in self._logger_iterable]) | ||
|
||
|
||
def merge_dicts( | ||
dicts: Sequence[Mapping], | ||
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None, | ||
default_func: Callable[[Sequence[float]], float] = np.mean | ||
) -> Dict: | ||
"""Merge a sequence of dictionaries into one dictionary by aggregating the | ||
same keys with some given function. | ||
|
||
Args: | ||
dicts: | ||
Sequence of dictionaries to be merged. | ||
agg_key_funcs: | ||
Mapping from key name to function. This function will aggregate a | ||
list of values, obtained from the same key of all dictionaries. | ||
If some key has no specified aggregation function, the default one | ||
will be used. Default is: None (all keys will be aggregated by the | ||
default function). | ||
default_func: | ||
Default function to aggregate keys, which are not present in the | ||
`agg_key_funcs` map. | ||
|
||
Returns: | ||
Dictionary with merged values. | ||
|
||
Examples: | ||
>>> import pprint | ||
>>> d1 = {'a': 1.7, 'b': 2.0, 'c': 1} | ||
>>> d2 = {'a': 1.1, 'b': 2.2, 'v': 1} | ||
>>> d3 = {'a': 1.1, 'v': 2.3} | ||
>>> dflt_func = min | ||
>>> agg_funcs = {'a': np.mean, 'v': max} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. won't numpy functions slow things down because everything needs to go to cpu? we don't want to move things to CPU for the user ever haha. Every cpu calls slows training down a ton There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so to have it rather as Tensor... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @alexeykarnachev we may use the Running accum or make another for full accum and extending every N steps and copy existing... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so the structure will change from list of dict to dict of tensors but not sure if it makes much faster... also it will allow us to use agg implemented for Torch There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmmm. I thought, that at this point all values are already on cpu. And few lines above this call, we transform the metrics to the scalars: So, do we really need tensors here? Metrics which come to the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @Borda , the def training_step(self, batch, batch_idx):
loss, logits, _ = self.forward(batch)
lr = self.trainer.optimizers[0].param_groups[0]['lr']
log = {'Loss/train': loss, 'Learning-Rate': lr}
return {'loss': loss, 'log': log} Here, the What do you think on this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. agree, let's get this done and think about speedup later... :] There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok perfect. this might actually be the main cause of the minimal speed discrepancy between lightning and pure pytorch. |
||
>>> pprint.pprint(merge_dicts([d1, d2, d3], agg_funcs, dflt_func)) | ||
{'a': 1.3, 'b': 2.0, 'c': 1, 'v': 2.3} | ||
""" | ||
|
||
keys = list(functools.reduce(operator.or_, [set(d.keys()) for d in dicts])) | ||
d_out = {} | ||
for k in keys: | ||
fn = agg_key_funcs.get(k, default_func) if agg_key_funcs else default_func | ||
agg_val = fn([v for v in [d_in.get(k) for d_in in dicts] if v is not None]) | ||
d_out[k] = agg_val | ||
|
||
return d_out |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
like we can specify the mean/avg/... here