
Commit

Merge pull request #887 from AntonioCarta/early_stopping
Early stopping & model checkpoint on iterations
AntonioCarta authored Jan 21, 2022
2 parents c8d11a5 + f779a4b commit efa5223
Showing 5 changed files with 223 additions and 22 deletions.
55 changes: 53 additions & 2 deletions avalanche/benchmarks/generators/benchmark_generators.py
@@ -576,7 +576,7 @@ def random_validation_split_strategy(
valid_n_instances = int(validation_size)
if valid_n_instances > len(exp_dataset):
raise ValueError(
f'Can\'t create the validation experience: nott enough '
f'Can\'t create the validation experience: not enough '
f'instances. Required {valid_n_instances}, got only'
f'{len(exp_dataset)}')

@@ -590,6 +590,57 @@ def random_validation_split_strategy(
return result_train_dataset, result_valid_dataset


def class_balanced_split_strategy(
validation_size: Union[int, float],
experience: Experience):
"""Class-balanced train/validation splits.
This splitting strategy splits `experience` into two experiences
(train and validation) using a class-balanced split: samples of each
class are chosen at random, and a fraction `validation_size` of each
class is moved to the validation experience.
:param validation_size: The fraction of samples to allocate to the
validation experience, as a float between 0 and 1.
:param experience: The experience to split.
:return: A tuple containing 2 elements: the new training and validation
datasets.
"""
if not isinstance(validation_size, float):
raise ValueError("validation_size must be an integer")
if not 0.0 <= validation_size <= 1.0:
raise ValueError("validation_size must be a float in [0, 1].")

exp_dataset = experience.dataset
if validation_size > len(exp_dataset):
raise ValueError(
f'Can\'t create the validation experience: not enough '
f'instances. Required {validation_size}, got only '
f'{len(exp_dataset)}')

exp_indices = list(range(len(exp_dataset)))
exp_classes = experience.classes_in_this_experience

# shuffle exp_indices
exp_indices = torch.as_tensor(exp_indices)[
torch.randperm(len(exp_indices))]
# shuffle the targets as well
exp_targets = torch.as_tensor(experience.dataset.targets)[exp_indices]

train_exp_indices = []
valid_exp_indices = []
for cid in exp_classes: # split indices for each class separately.
c_indices = exp_indices[exp_targets == cid]
valid_n_instances = int(validation_size * len(c_indices))
valid_exp_indices.extend(c_indices[:valid_n_instances])
train_exp_indices.extend(c_indices[valid_n_instances:])

result_train_dataset = AvalancheSubset(
exp_dataset, indices=train_exp_indices)
result_valid_dataset = AvalancheSubset(
exp_dataset, indices=valid_exp_indices)
return result_train_dataset, result_valid_dataset


def _gen_split(split_generator: Iterable[Tuple[AvalancheDataset,
AvalancheDataset]]) -> \
Tuple[Generator[AvalancheDataset, None, None],
@@ -630,7 +681,7 @@ def _lazy_train_val_split(

def benchmark_with_validation_stream(
benchmark_instance: GenericCLScenario,
validation_size: Union[int, float],
validation_size: Union[int, float] = 0.5,
shuffle: bool = False,
input_stream: str = 'train',
output_stream: str = 'valid',
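
A minimal usage sketch of the new class-balanced split, modelled on the unit test added further down. `get_fast_benchmark` is the repository's own test helper and the 0.2 ratio is an arbitrary illustrative value; treat this as a sketch rather than documented API usage.

import torch

from avalanche.benchmarks.generators.benchmark_generators import \
    class_balanced_split_strategy
from tests.unit_tests_utils import get_fast_benchmark

# The test helper builds a small synthetic benchmark; any benchmark works.
benchmark = get_fast_benchmark(n_samples_per_class=1000)
experience = benchmark.train_stream[0]

# Move 20% of each class, chosen at random, into the validation split.
train_data, valid_data = class_balanced_split_strategy(0.2, experience)

# Every class keeps roughly a 4:1 train/validation ratio.
for cid in experience.classes_in_this_experience:
    n_train = int((torch.as_tensor(train_data.targets) == cid).sum())
    n_valid = int((torch.as_tensor(valid_data.targets) == cid).sum())
    print(f"class {cid}: {n_train} train / {n_valid} valid")

The same kind of callable is what `benchmark_with_validation_stream` (whose `validation_size` now defaults to 0.5) uses to split each experience; plugging the class-balanced variant into the generator would rely on a custom split-strategy hook, which is an assumption not shown in this diff.
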
73 changes: 58 additions & 15 deletions avalanche/training/plugins/early_stopping.py
@@ -1,21 +1,34 @@
import operator
import warnings
from copy import deepcopy

from avalanche.training.plugins import StrategyPlugin


class EarlyStoppingPlugin(StrategyPlugin):
""" Early stopping plugin.
""" Early stopping and model checkpoint plugin.
The plugin monitors a metric and stops the training loop when the metric
has stopped improving for `patience` epochs (or iterations, depending on
`peval_mode`). After training, the best model's checkpoint is loaded.
.. warning::
The plugin checks the metric value, which is updated by the strategy
during the evaluation. This means that you must ensure that the
evaluation is called frequently enough during the training loop.
For example, if you set `patience=1`, you must also set `eval_every=1`
in the `BaseStrategy`, otherwise the metric won't be updated after
every epoch/iteration. Similarly, `peval_mode` must have the same
value.
Simple plugin stopping the training when the accuracy on the
corresponding validation metric stopped progressing for a few epochs.
The state of the best model is saved after each improvement on the
given metric and is loaded back into the model before stopping the
training procedure.
"""

def __init__(self, patience: int, val_stream_name: str,
metric_name: str = 'Top1_Acc_Stream', mode: str = 'max'):
"""
metric_name: str = 'Top1_Acc_Stream', mode: str = 'max',
peval_mode: str = 'epoch'):
""" Init.
:param patience: Number of epochs to wait before stopping the training.
:param val_stream_name: Name of the validation stream to search in the
metrics. The corresponding stream will be used to keep track of the
@@ -24,36 +37,66 @@ def __init__(self, patience: int, val_stream_name: str,
reported in the evaluator.
:param mode: Must be "max" or "min". max (resp. min) means that the
given metric should be maximized (resp. minimized).
:param peval_mode: one of {'epoch', 'iteration'}. Decides whether the
early stopping should happen after `patience`
epochs or iterations (Default='epoch').
"""
super().__init__()
self.val_stream_name = val_stream_name
self.patience = patience

assert peval_mode in {'epoch', 'iteration'}
self.peval_mode = peval_mode

self.metric_name = metric_name
self.metric_key = f'{self.metric_name}/eval_phase/' \
f'{self.val_stream_name}'
print(self.metric_key)
if mode not in ('max', 'min'):
raise ValueError(f'Mode must be "max" or "min", got {mode}.')
self.operator = operator.gt if mode == 'max' else operator.lt

self.best_state = None # Contains the best parameters
self.best_val = None
self.best_epoch = None
self.best_step = None

def before_training(self, strategy, **kwargs):
self.best_state = None
self.best_val = None
self.best_epoch = None
self.best_step = None

def before_training_iteration(self, strategy, **kwargs):
if self.peval_mode == 'iteration':
self._update_best(strategy)
curr_step = self._get_strategy_counter(strategy)
if curr_step - self.best_step >= self.patience:
strategy.model.load_state_dict(self.best_state)
strategy.stop_training()

def before_training_epoch(self, strategy, **kwargs):
self._update_best(strategy)
if strategy.clock.train_exp_epochs - self.best_epoch >= self.patience:
strategy.model.load_state_dict(self.best_state)
strategy.stop_training()
if self.peval_mode == 'epoch':
self._update_best(strategy)
curr_step = self._get_strategy_counter(strategy)
if curr_step - self.best_step >= self.patience:
strategy.model.load_state_dict(self.best_state)
strategy.stop_training()

def _update_best(self, strategy):
res = strategy.evaluator.get_last_metrics()
val_acc = res.get(self.metric_key)
if val_acc is None:
warnings.warn(
f"Metric {self.metric_name} used by the EarlyStopping plugin "
f"is not computed yet. EarlyStopping will not be triggered.")
if self.best_val is None or self.operator(val_acc, self.best_val):
self.best_state = deepcopy(strategy.model.state_dict())
self.best_val = val_acc
self.best_epoch = strategy.clock.train_exp_epochs
self.best_step = self._get_strategy_counter(strategy)

def _get_strategy_counter(self, strategy):
if self.peval_mode == 'epoch':
return strategy.clock.train_exp_epochs
elif self.peval_mode == 'iteration':
return strategy.clock.train_exp_iterations
else:
raise ValueError("Invalid `peval_mode`:", self.peval_mode)
28 changes: 27 additions & 1 deletion tests/test_high_level_generators.py
@@ -4,6 +4,7 @@
from os.path import expanduser

import torch
from numpy.testing import assert_almost_equal, assert_allclose
from torchvision.datasets import CIFAR10, MNIST
from torchvision.datasets.utils import download_url, extract_archive
from torchvision.transforms import ToTensor
@@ -12,11 +13,13 @@
tensors_benchmark, paths_benchmark, data_incremental_benchmark, \
benchmark_with_validation_stream
from avalanche.benchmarks.datasets import default_dataset_location
from avalanche.benchmarks.generators.benchmark_generators import \
class_balanced_split_strategy
from avalanche.benchmarks.scenarios.generic_benchmark_creation import \
create_lazy_generic_benchmark, LazyStreamDefinition
from avalanche.benchmarks.utils import AvalancheDataset, \
AvalancheTensorDataset, AvalancheDatasetType
from tests.unit_tests_utils import common_setups
from tests.unit_tests_utils import common_setups, get_fast_benchmark


class HighLevelGeneratorTests(unittest.TestCase):
@@ -647,3 +650,26 @@ def test_gen():
torch.equal(
test_y,
valid_benchmark.test_stream[0].dataset[:][1]))


class DataSplitStrategiesTests(unittest.TestCase):
def test_dataset_benchmark(self):
benchmark = get_fast_benchmark(n_samples_per_class=1000)
exp = benchmark.train_stream[0]
num_classes = len(exp.classes_in_this_experience)

train_d, valid_d = class_balanced_split_strategy(0.5, exp)
assert abs(len(train_d) - len(valid_d)) <= num_classes
for cid in exp.classes_in_this_experience:
train_cnt = (torch.as_tensor(train_d.targets) == cid).sum()
valid_cnt = (torch.as_tensor(valid_d.targets) == cid).sum()
assert abs(train_cnt - valid_cnt) <= 1

ratio = 0.123
len_data = len(exp.dataset)
train_d, valid_d = class_balanced_split_strategy(ratio, exp)
assert_almost_equal(len(valid_d) / len_data, ratio, decimal=2)
for cid in exp.classes_in_this_experience:
data_cnt = (torch.as_tensor(exp.dataset.targets) == cid).sum()
valid_cnt = (torch.as_tensor(valid_d.targets) == cid).sum()
assert_almost_equal(valid_cnt / data_cnt, ratio, decimal=2)
85 changes: 83 additions & 2 deletions tests/training/test_plugins.py
@@ -19,8 +19,10 @@
from avalanche.evaluation.metric_results import MetricValue
from avalanche.evaluation.metrics import Mean
from avalanche.logging import TextLogger
from avalanche.models import BaseModel
from avalanche.training.plugins import StrategyPlugin, EvaluationPlugin
from avalanche.models import BaseModel, SimpleMLP
from avalanche.training.plugins import StrategyPlugin, EvaluationPlugin, \
EarlyStoppingPlugin
from avalanche.training.plugins.clock import Clock
from avalanche.training.plugins.lr_scheduling import LRSchedulerPlugin
from avalanche.training.strategies import Naive

@@ -553,5 +555,84 @@ def test_publish_metric(self):
assert len(ep.get_all_metrics()['metric'][1]) == 1


class EarlyStoppingPluginTest(unittest.TestCase):
def test_early_stop_epochs(self):
class MockEvaluator:
def __init__(self, clock, metrics):
self.clock = clock
self.metrics = metrics

def get_last_metrics(self):
idx = self.clock.train_exp_iterations
return {'Top1_Acc_Stream/eval_phase/a': self.metrics[idx]}

class ESMockStrategy:
"""An empty strategy to test early stopping."""
def __init__(self, p, metric_vals):
self.p = p
self.clock = Clock()
self.evaluator = MockEvaluator(self.clock, metric_vals)

self.model = SimpleMLP()

def before_training_iteration(self):
self.p.before_training_iteration(self)
self.clock.before_training_iteration(self)

def before_training_epoch(self):
self.p.before_training_epoch(self)
self.clock.before_training_epoch(self)

def after_training_iteration(self):
self.p.after_training_iteration(self)
self.clock.after_training_iteration(self)

def after_training_epoch(self):
self.p.after_training_epoch(self)
self.clock.after_training_epoch(self)

def stop_training(self):
raise StopIteration()

def run_es(mvals, p):
strat = ESMockStrategy(p, mvals)
for t in range(100):
try:
if t % 10 == 0:
strat.before_training_epoch()
strat.before_training_iteration()
strat.after_training_iteration()
if t % 10 == 9:
strat.after_training_epoch()
except StopIteration:
break
return strat

# best on epoch
metric_vals = list(range(200))
p = EarlyStoppingPlugin(5, val_stream_name='a')
run_es(metric_vals, p)
print(f"best step={p.best_step}, val={p.best_val}")
assert p.best_step == 9
assert p.best_val == 90

# best on iteration
metric_vals = list(range(200))
p = EarlyStoppingPlugin(5, val_stream_name='a', peval_mode='iteration')
run_es(metric_vals, p)
print(f"best step={p.best_step}, val={p.best_val}")
assert p.best_step == 99
assert p.best_val == 99

# check patience
metric_vals = list([1 for _ in range(200)])
p = EarlyStoppingPlugin(5, val_stream_name='a')
strat = run_es(metric_vals, p)
print(f"best step={p.best_step}, val={p.best_val}")
assert p.best_step == 0
assert strat.clock.train_exp_epochs == p.patience
assert p.best_val == 1


if __name__ == '__main__':
unittest.main()
4 changes: 2 additions & 2 deletions tests/unit_tests_utils.py
@@ -49,8 +49,8 @@ def load_benchmark(use_task_labels=False, fast_test=True):
return my_nc_benchmark


def get_fast_benchmark(use_task_labels=False, shuffle=True):
n_samples_per_class = 100
def get_fast_benchmark(use_task_labels=False, shuffle=True,
n_samples_per_class=100):
dataset = make_classification(
n_samples=10 * n_samples_per_class,
n_classes=10,
