diff --git a/fvcore/common/param_scheduler.py b/fvcore/common/param_scheduler.py index 5176538..6b82f0b 100644 --- a/fvcore/common/param_scheduler.py +++ b/fvcore/common/param_scheduler.py @@ -7,6 +7,7 @@ "ParamScheduler", "ConstantParamScheduler", "CosineParamScheduler", + "ExponentialParamScheduler", "LinearParamScheduler", "CompositeParamScheduler", "MultiStepParamScheduler", @@ -85,6 +86,33 @@ def __call__(self, where: float) -> float: ) +class ExponentialParamScheduler(ParamScheduler): + """ + Exponetial schedule based on start value and decay, where value is caluated + for timestep t out of T total stepsm by + param_t = start_value * (decay ** t/T).The schedule is updated after every train + step by default based on the fraction of samples seen. + + Example: + + .. code-block:: python + ExponentialParamScheduler(start_value=2.0, decay=0.02) + + Corresponds to a decreasing schedule with values in [2.0, 0.04). + """ + + def __init__( + self, + start_value: float, + decay: float, + ) -> None: + self._start_value = start_value + self._decay = decay + + def __call__(self, where: float) -> float: + return self._start_value * (self._decay ** where) + + class LinearParamScheduler(ParamScheduler): """ Linearly interpolates parameter between ``start_value`` and ``end_value``. diff --git a/tests/param_scheduler/test_scheduler_exponential.py b/tests/param_scheduler/test_scheduler_exponential.py new file mode 100644 index 0000000..4c94b97 --- /dev/null +++ b/tests/param_scheduler/test_scheduler_exponential.py @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import unittest + +from fvcore.common.param_scheduler import ExponentialParamScheduler + + +class TestExponentialScheduler(unittest.TestCase): + _num_epochs = 10 + + def _get_valid_config(self): + return {"start_value": 2.0, "decay": 0.1} + + def _get_valid_intermediate_values(self): + return [1.5887, 1.2619, 1.0024, 0.7962, 0.6325, 0.5024, 0.3991, 0.3170, 0.2518] + + def test_scheduler(self): + config = self._get_valid_config() + + scheduler = ExponentialParamScheduler(**config) + schedule = [ + round(scheduler(epoch_num / self._num_epochs), 4) + for epoch_num in range(self._num_epochs) + ] + expected_schedule = [ + config["start_value"] + ] + self._get_valid_intermediate_values() + + self.assertEqual(schedule, expected_schedule)