Skip to content

Commit

Permalink
Merge pull request #877 from AntonioCarta/oob_metrics
Browse files Browse the repository at this point in the history
evaluation can receive metric values from anywhere
  • Loading branch information
Andrea Cossu authored Jan 11, 2022
2 parents bebb074 + 2657294 commit 128491b
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 16 deletions.
38 changes: 23 additions & 15 deletions avalanche/training/plugins/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections import defaultdict
from typing import Union, Sequence, TYPE_CHECKING

from avalanche.evaluation.metric_results import MetricValue
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
from avalanche.training.plugins.strategy_plugin import StrategyPlugin
from avalanche.logging import StrategyLogger, InteractiveLogger
Expand Down Expand Up @@ -110,6 +111,9 @@ def __init__(self,
self._active = True
"""If False, no metrics will be collected."""

self._metric_values = []
"""List of metrics that have yet to be processed by loggers."""

@property
def active(self):
return self._active
Expand All @@ -120,31 +124,35 @@ def active(self, value):
"Active must be set as either True or False"
self._active = value

def publish_metric_value(self, mval: MetricValue):
"""Publish a MetricValue to be processed by the loggers."""
self._metric_values.append(mval)

name = mval.name
x = mval.x_plot
val = mval.value
if self.collect_all:
self.all_metric_results[name][0].append(x)
self.all_metric_results[name][1].append(val)
self.last_metric_results[name] = val

def _update_metrics(self, strategy: 'BaseStrategy', callback: str):
"""Call the metric plugins with the correct callback `callback` and
update the loggers with the new metric values."""
if not self._active:
return []

metric_values = []
for metric in self.metrics:
metric_result = getattr(metric, callback)(strategy)
if isinstance(metric_result, Sequence):
metric_values += list(metric_result)
for mval in metric_result:
self.publish_metric_value(mval)
elif metric_result is not None:
metric_values.append(metric_result)

for metric_value in metric_values:
name = metric_value.name
x = metric_value.x_plot
val = metric_value.value
if self.collect_all:
self.all_metric_results[name][0].append(x)
self.all_metric_results[name][1].append(val)

self.last_metric_results[name] = val
self.publish_metric_value(metric_result)

for logger in self.loggers:
getattr(logger, callback)(strategy, metric_values)
return metric_values
getattr(logger, callback)(strategy, self._metric_values)
self._metric_values = []

def get_last_metrics(self):
"""
Expand Down
13 changes: 12 additions & 1 deletion tests/training/test_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
from avalanche.benchmarks import nc_benchmark, GenericCLScenario, \
benchmark_with_validation_stream
from avalanche.benchmarks.utils.data_loader import TaskBalancedDataLoader
from avalanche.evaluation.metric_results import MetricValue
from avalanche.evaluation.metrics import Mean
from avalanche.logging import TextLogger
from avalanche.models import BaseModel
from avalanche.training.plugins import StrategyPlugin
from avalanche.training.plugins import StrategyPlugin, EvaluationPlugin
from avalanche.training.plugins.lr_scheduling import LRSchedulerPlugin
from avalanche.training.strategies import Naive

Expand Down Expand Up @@ -542,5 +543,15 @@ def get_features(self, x):
return x


class EvaluationPluginTest(unittest.TestCase):
def test_publish_metric(self):
ep = EvaluationPlugin()
mval = MetricValue(self, 'metric', 1.0, 0)
ep.publish_metric_value(mval)

# check key exists
assert len(ep.get_all_metrics()['metric'][1]) == 1


if __name__ == '__main__':
unittest.main()

0 comments on commit 128491b

Please sign in to comment.