diff --git a/ignite/handlers/lr_finder.py b/ignite/handlers/lr_finder.py
index 3b79d7a265b..7eb6c284a83 100644
--- a/ignite/handlers/lr_finder.py
+++ b/ignite/handlers/lr_finder.py
@@ -14,7 +14,7 @@
 import ignite.distributed as idist
 from ignite.engine import Engine, Events
 from ignite.handlers import Checkpoint
-from ignite.handlers.param_scheduler import LRScheduler, PiecewiseLinear
+from ignite.handlers.param_scheduler import LRScheduler, ParamGroupScheduler, PiecewiseLinear
 
 
 class FastaiLRFinder:
@@ -74,11 +74,12 @@ class FastaiLRFinder:
     .. versionadded:: 0.4.6
     """
 
+    _lr_schedule: Union[LRScheduler, PiecewiseLinear, ParamGroupScheduler]
+
     def __init__(self) -> None:
         self._diverge_flag = False
         self._history = {}  # type: Dict[str, List[Any]]
         self._best_loss = None
-        self._lr_schedule = None  # type: Optional[Union[LRScheduler, PiecewiseLinear]]
         self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
 
     def _run(
@@ -87,8 +88,8 @@ def _run(
         optimizer: Optimizer,
         output_transform: Callable,
         num_iter: int,
-        start_lr: float,
-        end_lr: float,
+        start_lrs: List[float],
+        end_lrs: List[float],
         step_mode: str,
         smooth_f: float,
         diverge_th: float,
@@ -118,22 +119,35 @@ def _run(
             )
 
         self.logger.debug(f"Running LR finder for {num_iter} iterations")
-        if start_lr is None:
-            start_lr = optimizer.param_groups[0]["lr"]
 
+        # Initialize the proper learning rate policy
         if step_mode.lower() == "exp":
-            start_lr_list = [start_lr] * len(optimizer.param_groups)
-            self._lr_schedule = LRScheduler(_ExponentialLR(optimizer, start_lr_list, end_lr, num_iter))
+            self._lr_schedule = LRScheduler(_ExponentialLR(optimizer, start_lrs, end_lrs, num_iter))
         else:
-            self._lr_schedule = PiecewiseLinear(
-                optimizer, param_name="lr", milestones_values=[(0, start_lr), (num_iter, end_lr)]
-            )
+            if len(start_lrs) == 1:
+                self._lr_schedule = PiecewiseLinear(
+                    optimizer,
+                    param_name="lr",
+                    milestones_values=[(0, start_lrs[0]), (num_iter, end_lrs[0])],
+                )
+            else:
+                self._lr_schedule = ParamGroupScheduler(
+                    [
+                        PiecewiseLinear(
+                            optimizer,
+                            param_name="lr",
+                            milestones_values=[(0, start_lrs[i]), (num_iter, end_lrs[i])],
+                            param_group_index=i,
+                        )
+                        for i in range(len(optimizer.param_groups))
+                    ]
+                )
+
         if not trainer.has_event_handler(self._lr_schedule):
             trainer.add_event_handler(Events.ITERATION_COMPLETED, self._lr_schedule, num_iter)
 
     def _reset(self, trainer: Engine) -> None:
         self.logger.debug("Completed LR finder run")
-        trainer.remove_event_handler(self._lr_schedule, Events.ITERATION_COMPLETED)  # type: ignore[arg-type]
+        trainer.remove_event_handler(self._lr_schedule, Events.ITERATION_COMPLETED)
 
         trainer.remove_event_handler(self._log_lr_and_loss, Events.ITERATION_COMPLETED)
         trainer.remove_event_handler(self._reached_num_iterations, Events.ITERATION_COMPLETED)
@@ -157,7 +171,7 @@ def _log_lr_and_loss(self, trainer: Engine, output_transform: Callable, smooth_f
                 f"but got output of type {type(loss).__name__}"
             )
         loss = idist.all_reduce(loss)
-        lr = self._lr_schedule.get_param()  # type: ignore[union-attr]
+        lr = self._lr_schedule.get_param()
         self._history["lr"].append(lr)
         if trainer.state.iteration == 1:
             self._best_loss = loss
@@ -251,7 +265,6 @@ def plot(
         )
         if not self._history:
             raise RuntimeError("learning rate finder didn't run yet so results can't be plotted")
-
         if skip_start < 0:
             raise ValueError("skip_start cannot be negative")
         if skip_end < 0:
@@ -367,8 +380,8 @@ def attach(
         to_save: Mapping,
         output_transform: Callable = lambda output: output,
         num_iter: Optional[int] = None,
-        start_lr: Optional[float] = None,
-        end_lr: float = 10.0,
+        start_lr: Optional[Union[float, List[float]]] = None,
+        end_lr: Optional[Union[float, List[float]]] = 10.0,
         step_mode: str = "exp",
         smooth_f: float = 0.05,
         diverge_th: float = 5.0,
@@ -432,8 +445,37 @@ def attach(
             raise TypeError(f"if provided, num_iter should be an integer, but give {num_iter}")
         if num_iter <= 0:
             raise ValueError(f"if provided, num_iter should be positive, but give {num_iter}")
-        if isinstance(start_lr, (float, int)) and start_lr >= end_lr:
-            raise ValueError(f"start_lr must be less than end_lr, start_lr={start_lr} vs end_lr={end_lr}")
+
+        optimizer = to_save["optimizer"]
+        if start_lr is None:
+            start_lrs = [pg["lr"] for pg in optimizer.param_groups]
+        elif isinstance(start_lr, float):
+            start_lrs = [start_lr] * len(optimizer.param_groups)
+        elif isinstance(start_lr, list):
+            if len(start_lr) != len(optimizer.param_groups):
+                raise ValueError(
+                    "Number of values of start_lr should be equal to optimizer values."
+                    f"start_lr values:{len(start_lr)} optimizer values: {len(optimizer.param_groups)}"
+                )
+            start_lrs = start_lr
+        else:
+            raise TypeError(f"start_lr should be a float or list of floats, but given {type(start_lr)}")
+
+        if isinstance(end_lr, float):
+            end_lrs = [end_lr] * len(optimizer.param_groups)
+        elif isinstance(end_lr, list):
+            if len(end_lr) != len(optimizer.param_groups):
+                raise ValueError(
+                    "Number of values of end_lr should be equal to optimizer values."
+                    f"end_lr values:{len(end_lr)} optimizer values: {len(optimizer.param_groups)}"
+                )
+            end_lrs = end_lr
+        else:
+            raise TypeError(f"end_lr should be a float or list of floats, but given {type(end_lr)}")
+
+        for start, end in zip(start_lrs, end_lrs):
+            if start >= end:
+                raise ValueError(f"start_lr must be less than end_lr, start_lr={start_lr} vs end_lr={end_lr}")
 
         # store to_save
         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -443,7 +485,6 @@ def attach(
             cache_filepath = Path(tmpdirname) / "ignite_lr_finder_cache.pt"
             torch.save(obj, cache_filepath.as_posix())
 
-            optimizer = to_save["optimizer"]
             # Attach handlers
             if not trainer.has_event_handler(self._run):
                 trainer.add_event_handler(
@@ -452,8 +493,8 @@ def attach(
                     optimizer,
                     output_transform,
                     num_iter,
-                    start_lr,
-                    end_lr,
+                    start_lrs,
+                    end_lrs,
                     step_mode,
                     smooth_f,
                     diverge_th,
@@ -479,23 +520,24 @@ class _ExponentialLR(_LRScheduler):
     Args:
         optimizer: wrapped optimizer.
-        end_lr: the initial learning rate which is the lower
-            boundary of the test. Default: 10.
+        start_lrs: the initial learning rate for parameter groups.
+        end_lrs: the final learning rate for parameter groups.
         num_iter: the number of iterations over which the test
             occurs. Default: 100.
         last_epoch: the index of last epoch. Default: -1.
-
     """
 
-    def __init__(self, optimizer: Optimizer, start_lr: List[float], end_lr: float, num_iter: int, last_epoch: int = -1):
-        self.end_lr = end_lr
+    def __init__(
+        self, optimizer: Optimizer, start_lrs: List[float], end_lrs: List[float], num_iter: int, last_epoch: int = -1
+    ):
+        self.end_lrs = end_lrs
         self.num_iter = num_iter
         super(_ExponentialLR, self).__init__(optimizer, last_epoch)
 
         # override base_lrs
-        self.base_lrs = start_lr
+        self.base_lrs = start_lrs
 
-    def get_lr(self) -> List[float]:  # type: ignore
+    def get_lr(self) -> List[float]:  # type: ignore[override]
         curr_iter = self.last_epoch + 1
         r = curr_iter / self.num_iter
-        return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]
+        return [base_lr * (end_lr / base_lr) ** r for end_lr, base_lr in zip(self.end_lrs, self.base_lrs)]
diff --git a/tests/ignite/handlers/test_lr_finder.py b/tests/ignite/handlers/test_lr_finder.py
index 1ead99e45c0..333ceb5562c 100644
--- a/tests/ignite/handlers/test_lr_finder.py
+++ b/tests/ignite/handlers/test_lr_finder.py
@@ -161,32 +161,32 @@ def mnist_dataloader():
 
 def test_attach_incorrect_input_args(lr_finder, dummy_engine, model, optimizer, dataloader):
     with pytest.raises(TypeError, match=r"Argument to_save should be a mapping"):
-        with lr_finder.attach(dummy_engine, to_save=123) as f:
+        with lr_finder.attach(dummy_engine, to_save=123):
             pass
 
     with pytest.raises(TypeError, match=r"Object should have `state_dict` method"):
-        with lr_finder.attach(dummy_engine, to_save={1: 2}) as f:
+        with lr_finder.attach(dummy_engine, to_save={1: 2}):
             pass
 
     with pytest.raises(ValueError, match=r"Mapping to_save should contain 'optimizer' key"):
-        with lr_finder.attach(dummy_engine, to_save={"model": model}) as f:
+        with lr_finder.attach(dummy_engine, to_save={"model": model}):
             pass
 
     to_save = {"model": model, "optimizer": optimizer}
     with pytest.raises(ValueError, match=r"smooth_f is outside the range \[0, 1\]"):
-        with lr_finder.attach(dummy_engine, to_save=to_save, smooth_f=234) as f:
+        with lr_finder.attach(dummy_engine, to_save=to_save, smooth_f=234):
             pass
 
     with pytest.raises(ValueError, match=r"diverge_th should be larger than 1"):
-        with lr_finder.attach(dummy_engine, to_save=to_save, diverge_th=0.0) as f:
+        with lr_finder.attach(dummy_engine, to_save=to_save, diverge_th=0.0):
             pass
 
     with pytest.raises(TypeError, match=r"if provided, num_iter should be an integer"):
-        with lr_finder.attach(dummy_engine, to_save=to_save, num_iter=0.0) as f:
+        with lr_finder.attach(dummy_engine, to_save=to_save, num_iter=0.0):
             pass
 
     with pytest.raises(ValueError, match=r"if provided, num_iter should be positive"):
-        with lr_finder.attach(dummy_engine, to_save=to_save, num_iter=0) as f:
+        with lr_finder.attach(dummy_engine, to_save=to_save, num_iter=0):
             pass
 
     with pytest.raises(TypeError, match=r"Object to_save\['optimizer'] should be torch optimizer"):
@@ -194,7 +194,7 @@ def test_attach_incorrect_input_args(lr_finder, dummy_engine, model, optimizer,
         pass
 
     with pytest.raises(ValueError, match=r"step_mode should be 'exp' or 'linear'"):
-        with lr_finder.attach(dummy_engine, to_save=to_save, step_mode="abc") as f:
+        with lr_finder.attach(dummy_engine, to_save=to_save, step_mode="abc"):
            pass
 
     with lr_finder.attach(dummy_engine, to_save) as trainer_with_finder:
@@ -205,6 +205,20 @@ def test_attach_incorrect_input_args(lr_finder, dummy_engine, model, optimizer,
     with pytest.raises(ValueError, match=r"skip_end cannot be negative"):
         lr_finder.plot(skip_end=-1)
 
+    with pytest.raises(ValueError, match=r"Number of values of start_lr should be equal to optimizer values."):
+        with lr_finder.attach(dummy_engine, to_save, start_lr=[0.1, 0.1]):
+            pass
+    with pytest.raises(ValueError, match=r"Number of values of end_lr should be equal to optimizer values."):
+        with lr_finder.attach(dummy_engine, to_save, end_lr=[0.1, 0.1]):
+            pass
+
+    with pytest.raises(TypeError, match=r"start_lr should be a float or list of floats"):
+        with lr_finder.attach(dummy_engine, to_save, start_lr=1):
+            pass
+    with pytest.raises(TypeError, match=r"end_lr should be a float or list of floats"):
+        with lr_finder.attach(dummy_engine, to_save, end_lr=1):
+            pass
+
 
 def test_attach_without_with(lr_finder, dummy_engine, to_save):
     _ = lr_finder.attach(dummy_engine, to_save=to_save)
@@ -232,15 +246,22 @@ def test_with_attach(lr_finder, to_save, dummy_engine, dataloader):
         assert len(dummy_engine._event_handlers[event]) == 0
 
 
-def test_wrong_values_start_lr_and_end_lr(lr_finder, to_save, dummy_engine, dataloader):
+def test_wrong_values_start_lr_and_end_lr(
+    lr_finder, dummy_engine, to_save, dummy_engine_mulitple_param_groups, to_save_mulitple_param_groups
+):
     with pytest.raises(ValueError, match=r"start_lr must be less than end_lr"):
-        with lr_finder.attach(dummy_engine, to_save=to_save, start_lr=10, end_lr=1) as trainer_with_finder:
-            trainer_with_finder.run(dataloader)
+        with lr_finder.attach(dummy_engine, to_save=to_save, start_lr=10.0, end_lr=1.0):
+            pass
 
     with pytest.raises(ValueError, match=r"start_lr must be less than end_lr"):
-        with lr_finder.attach(dummy_engine, to_save=to_save, start_lr=10, end_lr=10) as trainer_with_finder:
-            trainer_with_finder.run(dataloader)
+        with lr_finder.attach(
+            dummy_engine_mulitple_param_groups,
+            to_save=to_save_mulitple_param_groups,
+            start_lr=[1.0, 10.0, 5.0],
+            end_lr=[10.0, 10.0, 10.0],
+        ):
+            pass
 
 
 def test_model_optimizer_reset(lr_finder, to_save, dummy_engine, dataloader):
@@ -275,6 +296,24 @@ def test_lr_policy(lr_finder, to_save, dummy_engine, dataloader):
     assert all([lr[i - 1] < lr[i] for i in range(1, len(lr))])
 
 
+@pytest.mark.parametrize("step_mode", ["exp", "linear"])
+def test_multiple_optimizers(
+    lr_finder, dummy_engine_mulitple_param_groups, to_save_mulitple_param_groups, dataloader, step_mode
+):
+    start_lr = [0.1, 0.1, 0.01]
+    end_lr = [1.0, 1.0, 1.0]
+    with lr_finder.attach(
+        dummy_engine_mulitple_param_groups,
+        to_save_mulitple_param_groups,
+        start_lr=start_lr,
+        end_lr=end_lr,
+        step_mode=step_mode,
+    ) as trainer:
+        trainer.run(dataloader)
+    groups_lrs = lr_finder.get_results()["lr"]
+    assert [all([group_lrs[i - 1] < group_lrs[i] for i in range(1, len(group_lrs))]) for group_lrs in groups_lrs]
+
+
 def assert_output_sizes(lr_finder, dummy_engine):
     iteration = dummy_engine.state.iteration
     lr_finder_results = lr_finder.get_results()
@@ -313,7 +352,7 @@ def test_num_iter_is_not_enough(lr_finder, to_save, dummy_engine, dataloader):
 
 
 def test_detach_terminates(lr_finder, to_save, dummy_engine, dataloader):
-    with lr_finder.attach(dummy_engine, to_save, end_lr=100, diverge_th=2) as trainer_with_finder:
+    with lr_finder.attach(dummy_engine, to_save, end_lr=100.0, diverge_th=2) as trainer_with_finder:
         trainer_with_finder.run(dataloader)
 
     dummy_engine.run(dataloader, max_epochs=3)
@@ -335,7 +374,7 @@ def test_different_num_iters(lr_finder, to_save, dummy_engine, dataloader):
 @pytest.mark.parametrize("step_mode", ["exp", "linear"])
 def test_start_lr(lr_finder, to_save, dummy_engine, dataloader, step_mode):
     with lr_finder.attach(
-        dummy_engine, to_save, start_lr=0.01, end_lr=10, num_iter=5, step_mode=step_mode, diverge_th=1
+        dummy_engine, to_save, start_lr=0.01, end_lr=10.0, num_iter=5, step_mode=step_mode, diverge_th=1
     ) as trainer_with_finder:
         trainer_with_finder.run(dataloader)
     history = lr_finder.get_results()
@@ -486,7 +525,7 @@ def test_no_matplotlib(no_site_packages, lr_finder):
 
 
 def test_plot_single_param_group(dirname, lr_finder, mnist_to_save, dummy_engine_mnist, mnist_dataloader):
-    with lr_finder.attach(dummy_engine_mnist, mnist_to_save, end_lr=20, smooth_f=0.04) as trainer_with_finder:
+    with lr_finder.attach(dummy_engine_mnist, mnist_to_save, end_lr=20.0, smooth_f=0.04) as trainer_with_finder:
         trainer_with_finder.run(mnist_dataloader)
 
     def _test(ax):
@@ -516,7 +555,7 @@ def test_plot_multiple_param_groups(
 ):
 
     with lr_finder.attach(
-        dummy_engine_mulitple_param_groups, to_save_mulitple_param_groups, end_lr=20, smooth_f=0.04
+        dummy_engine_mulitple_param_groups, to_save_mulitple_param_groups, end_lr=20.0, smooth_f=0.04
     ) as trainer_with_finder:
         trainer_with_finder.run(dataloader_plot)
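
Usage note (illustrative, not part of the patch): a minimal sketch of how the per-parameter-group `start_lr`/`end_lr` API introduced above might be exercised, assuming an ignite build that includes this change; the toy model, optimizer groups, and in-memory dataloader below are made-up placeholders.

```python
import torch
from torch import nn

from ignite.engine import create_supervised_trainer
from ignite.handlers import FastaiLRFinder

# Toy model whose parameters are split into two optimizer param groups.
model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 1))
optimizer = torch.optim.SGD(
    [
        {"params": model[0].parameters(), "lr": 1e-4},
        {"params": model[1].parameters(), "lr": 1e-3},
    ]
)
trainer = create_supervised_trainer(model, optimizer, nn.MSELoss())
dataloader = [(torch.rand(8, 10), torch.rand(8, 1)) for _ in range(100)]

lr_finder = FastaiLRFinder()
to_save = {"model": model, "optimizer": optimizer}

# One start/end value per param group; list lengths must equal
# len(optimizer.param_groups), otherwise attach() raises the ValueError tested above.
with lr_finder.attach(trainer, to_save, start_lr=[1e-4, 1e-3], end_lr=[1.0, 10.0]) as trainer_with_finder:
    trainer_with_finder.run(dataloader)

# With multiple param groups, each history entry holds one lr per group.
print(lr_finder.get_results()["lr"][:3])
```

With the default `step_mode="exp"`, `_ExponentialLR` drives each group from its own start value to its own end value via `lr_i = start_i * (end_i / start_i) ** (curr_iter / num_iter)`; with `step_mode="linear"`, each group gets its own `PiecewiseLinear` schedule, wrapped in a `ParamGroupScheduler` when there is more than one group.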