Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Edit LRFinder to have more than one parameter #2704

Merged
Merged 20 commits
Oct 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 71 additions & 29 deletions ignite/handlers/lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import ignite.distributed as idist
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint
from ignite.handlers.param_scheduler import LRScheduler, PiecewiseLinear
from ignite.handlers.param_scheduler import LRScheduler, ParamGroupScheduler, PiecewiseLinear


class FastaiLRFinder:
Expand Down Expand Up @@ -74,11 +74,12 @@ class FastaiLRFinder:
.. versionadded:: 0.4.6
"""

_lr_schedule: Union[LRScheduler, PiecewiseLinear, ParamGroupScheduler]

def __init__(self) -> None:
self._diverge_flag = False
self._history = {} # type: Dict[str, List[Any]]
self._best_loss = None
self._lr_schedule = None # type: Optional[Union[LRScheduler, PiecewiseLinear]]
self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__)

def _run(
Expand All @@ -87,8 +88,8 @@ def _run(
optimizer: Optimizer,
output_transform: Callable,
num_iter: int,
start_lr: float,
end_lr: float,
start_lrs: List[float],
end_lrs: List[float],
step_mode: str,
smooth_f: float,
diverge_th: float,
Expand Down Expand Up @@ -118,22 +119,35 @@ def _run(
)

self.logger.debug(f"Running LR finder for {num_iter} iterations")
if start_lr is None:
start_lr = optimizer.param_groups[0]["lr"]

# Initialize the proper learning rate policy
if step_mode.lower() == "exp":
start_lr_list = [start_lr] * len(optimizer.param_groups)
self._lr_schedule = LRScheduler(_ExponentialLR(optimizer, start_lr_list, end_lr, num_iter))
self._lr_schedule = LRScheduler(_ExponentialLR(optimizer, start_lrs, end_lrs, num_iter))
else:
self._lr_schedule = PiecewiseLinear(
optimizer, param_name="lr", milestones_values=[(0, start_lr), (num_iter, end_lr)]
)
if len(start_lrs) == 1:
self._lr_schedule = PiecewiseLinear(
optimizer,
param_name="lr",
milestones_values=[(0, start_lrs[0]), (num_iter, end_lrs[0])],
)
else:
self._lr_schedule = ParamGroupScheduler(
[
PiecewiseLinear(
optimizer,
param_name="lr",
milestones_values=[(0, start_lrs[i]), (num_iter, end_lrs[i])],
param_group_index=i,
)
for i in range(len(optimizer.param_groups))
]
)
if not trainer.has_event_handler(self._lr_schedule):
trainer.add_event_handler(Events.ITERATION_COMPLETED, self._lr_schedule, num_iter)

def _reset(self, trainer: Engine) -> None:
self.logger.debug("Completed LR finder run")
trainer.remove_event_handler(self._lr_schedule, Events.ITERATION_COMPLETED) # type: ignore[arg-type]
trainer.remove_event_handler(self._lr_schedule, Events.ITERATION_COMPLETED)
trainer.remove_event_handler(self._log_lr_and_loss, Events.ITERATION_COMPLETED)
trainer.remove_event_handler(self._reached_num_iterations, Events.ITERATION_COMPLETED)

Expand All @@ -157,7 +171,7 @@ def _log_lr_and_loss(self, trainer: Engine, output_transform: Callable, smooth_f
f"but got output of type {type(loss).__name__}"
)
loss = idist.all_reduce(loss)
lr = self._lr_schedule.get_param() # type: ignore[union-attr]
lr = self._lr_schedule.get_param()
self._history["lr"].append(lr)
if trainer.state.iteration == 1:
self._best_loss = loss
Expand Down Expand Up @@ -251,7 +265,6 @@ def plot(
)
if not self._history:
raise RuntimeError("learning rate finder didn't run yet so results can't be plotted")

if skip_start < 0:
raise ValueError("skip_start cannot be negative")
if skip_end < 0:
Expand Down Expand Up @@ -367,8 +380,8 @@ def attach(
to_save: Mapping,
output_transform: Callable = lambda output: output,
num_iter: Optional[int] = None,
start_lr: Optional[float] = None,
end_lr: float = 10.0,
start_lr: Optional[Union[float, List[float]]] = None,
end_lr: Optional[Union[float, List[float]]] = 10.0,
step_mode: str = "exp",
smooth_f: float = 0.05,
diverge_th: float = 5.0,
Expand Down Expand Up @@ -432,8 +445,37 @@ def attach(
raise TypeError(f"if provided, num_iter should be an integer, but give {num_iter}")
if num_iter <= 0:
raise ValueError(f"if provided, num_iter should be positive, but give {num_iter}")
if isinstance(start_lr, (float, int)) and start_lr >= end_lr:
raise ValueError(f"start_lr must be less than end_lr, start_lr={start_lr} vs end_lr={end_lr}")

optimizer = to_save["optimizer"]
if start_lr is None:
start_lrs = [pg["lr"] for pg in optimizer.param_groups]
elif isinstance(start_lr, float):
start_lrs = [start_lr] * len(optimizer.param_groups)
elif isinstance(start_lr, list):
if len(start_lr) != len(optimizer.param_groups):
raise ValueError(
"Number of values of start_lr should be equal to optimizer values."
f"start_lr values:{len(start_lr)} optimizer values: {len(optimizer.param_groups)}"
)
start_lrs = start_lr
else:
raise TypeError(f"start_lr should be a float or list of floats, but given {type(start_lr)}")

if isinstance(end_lr, float):
end_lrs = [end_lr] * len(optimizer.param_groups)
elif isinstance(end_lr, list):
if len(end_lr) != len(optimizer.param_groups):
raise ValueError(
"Number of values of end_lr should be equal to optimizer values."
f"end_lr values:{len(end_lr)} optimizer values: {len(optimizer.param_groups)}"
)
end_lrs = end_lr
else:
raise TypeError(f"end_lr should be a float or list of floats, but given {type(end_lr)}")

for start, end in zip(start_lrs, end_lrs):
if start >= end:
raise ValueError(f"start_lr must be less than end_lr, start_lr={start_lr} vs end_lr={end_lr}")

# store to_save
with tempfile.TemporaryDirectory() as tmpdirname:
Expand All @@ -443,7 +485,6 @@ def attach(
cache_filepath = Path(tmpdirname) / "ignite_lr_finder_cache.pt"
torch.save(obj, cache_filepath.as_posix())

optimizer = to_save["optimizer"]
# Attach handlers
if not trainer.has_event_handler(self._run):
trainer.add_event_handler(
Expand All @@ -452,8 +493,8 @@ def attach(
optimizer,
output_transform,
num_iter,
start_lr,
end_lr,
start_lrs,
end_lrs,
step_mode,
smooth_f,
diverge_th,
Expand All @@ -479,23 +520,24 @@ class _ExponentialLR(_LRScheduler):

Args:
optimizer: wrapped optimizer.
end_lr: the initial learning rate which is the lower
boundary of the test. Default: 10.
start_lrs: the initial learning rate for parameter groups.
end_lrs: the final learning rate for parameter groups.
num_iter: the number of iterations over which the test
occurs. Default: 100.
last_epoch: the index of last epoch. Default: -1.

"""

def __init__(self, optimizer: Optimizer, start_lr: List[float], end_lr: float, num_iter: int, last_epoch: int = -1):
self.end_lr = end_lr
def __init__(
self, optimizer: Optimizer, start_lrs: List[float], end_lrs: List[float], num_iter: int, last_epoch: int = -1
):
self.end_lrs = end_lrs
self.num_iter = num_iter
super(_ExponentialLR, self).__init__(optimizer, last_epoch)

# override base_lrs
self.base_lrs = start_lr
self.base_lrs = start_lrs

def get_lr(self) -> List[float]: # type: ignore
def get_lr(self) -> List[float]: # type: ignore[override]
curr_iter = self.last_epoch + 1
r = curr_iter / self.num_iter
return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]
return [base_lr * (end_lr / base_lr) ** r for end_lr, base_lr in zip(self.end_lrs, self.base_lrs)]
73 changes: 56 additions & 17 deletions tests/ignite/handlers/test_lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,40 +161,40 @@ def mnist_dataloader():
def test_attach_incorrect_input_args(lr_finder, dummy_engine, model, optimizer, dataloader):

with pytest.raises(TypeError, match=r"Argument to_save should be a mapping"):
with lr_finder.attach(dummy_engine, to_save=123) as f:
with lr_finder.attach(dummy_engine, to_save=123):
pass

with pytest.raises(TypeError, match=r"Object <class 'int'> should have `state_dict` method"):
with lr_finder.attach(dummy_engine, to_save={1: 2}) as f:
with lr_finder.attach(dummy_engine, to_save={1: 2}):
pass

with pytest.raises(ValueError, match=r"Mapping to_save should contain 'optimizer' key"):
with lr_finder.attach(dummy_engine, to_save={"model": model}) as f:
with lr_finder.attach(dummy_engine, to_save={"model": model}):
pass

to_save = {"model": model, "optimizer": optimizer}
with pytest.raises(ValueError, match=r"smooth_f is outside the range \[0, 1\]"):
with lr_finder.attach(dummy_engine, to_save=to_save, smooth_f=234) as f:
with lr_finder.attach(dummy_engine, to_save=to_save, smooth_f=234):
pass

with pytest.raises(ValueError, match=r"diverge_th should be larger than 1"):
with lr_finder.attach(dummy_engine, to_save=to_save, diverge_th=0.0) as f:
with lr_finder.attach(dummy_engine, to_save=to_save, diverge_th=0.0):
pass

with pytest.raises(TypeError, match=r"if provided, num_iter should be an integer"):
with lr_finder.attach(dummy_engine, to_save=to_save, num_iter=0.0) as f:
with lr_finder.attach(dummy_engine, to_save=to_save, num_iter=0.0):
pass

with pytest.raises(ValueError, match=r"if provided, num_iter should be positive"):
with lr_finder.attach(dummy_engine, to_save=to_save, num_iter=0) as f:
with lr_finder.attach(dummy_engine, to_save=to_save, num_iter=0):
pass

with pytest.raises(TypeError, match=r"Object to_save\['optimizer'] should be torch optimizer"):
with lr_finder.attach(dummy_engine, {"model": to_save["model"], "optimizer": to_save["model"]}):
pass

with pytest.raises(ValueError, match=r"step_mode should be 'exp' or 'linear'"):
with lr_finder.attach(dummy_engine, to_save=to_save, step_mode="abc") as f:
with lr_finder.attach(dummy_engine, to_save=to_save, step_mode="abc"):
pass

with lr_finder.attach(dummy_engine, to_save) as trainer_with_finder:
Expand All @@ -205,6 +205,20 @@ def test_attach_incorrect_input_args(lr_finder, dummy_engine, model, optimizer,
with pytest.raises(ValueError, match=r"skip_end cannot be negative"):
lr_finder.plot(skip_end=-1)

with pytest.raises(ValueError, match=r"Number of values of start_lr should be equal to optimizer values."):
with lr_finder.attach(dummy_engine, to_save, start_lr=[0.1, 0.1]):
pass
with pytest.raises(ValueError, match=r"Number of values of end_lr should be equal to optimizer values."):
with lr_finder.attach(dummy_engine, to_save, end_lr=[0.1, 0.1]):
pass

with pytest.raises(TypeError, match=r"start_lr should be a float or list of floats"):
with lr_finder.attach(dummy_engine, to_save, start_lr=1):
pass
with pytest.raises(TypeError, match=r"end_lr should be a float or list of floats"):
with lr_finder.attach(dummy_engine, to_save, end_lr=1):
pass


def test_attach_without_with(lr_finder, dummy_engine, to_save):
_ = lr_finder.attach(dummy_engine, to_save=to_save)
Expand Down Expand Up @@ -232,15 +246,22 @@ def test_with_attach(lr_finder, to_save, dummy_engine, dataloader):
assert len(dummy_engine._event_handlers[event]) == 0


def test_wrong_values_start_lr_and_end_lr(lr_finder, to_save, dummy_engine, dataloader):
def test_wrong_values_start_lr_and_end_lr(
lr_finder, dummy_engine, to_save, dummy_engine_mulitple_param_groups, to_save_mulitple_param_groups
):

with pytest.raises(ValueError, match=r"start_lr must be less than end_lr"):
with lr_finder.attach(dummy_engine, to_save=to_save, start_lr=10, end_lr=1) as trainer_with_finder:
trainer_with_finder.run(dataloader)
with lr_finder.attach(dummy_engine, to_save=to_save, start_lr=10.0, end_lr=1.0):
pass

with pytest.raises(ValueError, match=r"start_lr must be less than end_lr"):
with lr_finder.attach(dummy_engine, to_save=to_save, start_lr=10, end_lr=10) as trainer_with_finder:
trainer_with_finder.run(dataloader)
with lr_finder.attach(
dummy_engine_mulitple_param_groups,
to_save=to_save_mulitple_param_groups,
start_lr=[1.0, 10.0, 5.0],
end_lr=[10.0, 10.0, 10.0],
):
pass


def test_model_optimizer_reset(lr_finder, to_save, dummy_engine, dataloader):
Expand Down Expand Up @@ -275,6 +296,24 @@ def test_lr_policy(lr_finder, to_save, dummy_engine, dataloader):
assert all([lr[i - 1] < lr[i] for i in range(1, len(lr))])


@pytest.mark.parametrize("step_mode", ["exp", "linear"])
def test_multiple_optimizers(
    lr_finder, dummy_engine_mulitple_param_groups, to_save_mulitple_param_groups, dataloader, step_mode
):
    """Run the LR finder on an optimizer with several param groups and check
    that each group's recorded learning rates are strictly increasing."""
    start_lr = [0.1, 0.1, 0.01]
    end_lr = [1.0, 1.0, 1.0]
    with lr_finder.attach(
        dummy_engine_mulitple_param_groups,
        to_save_mulitple_param_groups,
        start_lr=start_lr,
        end_lr=end_lr,
        step_mode=step_mode,
    ) as trainer:
        trainer.run(dataloader)
    groups_lrs = lr_finder.get_results()["lr"]
    # BUG FIX: the original `assert [ ... for group_lrs in groups_lrs]` asserted a
    # non-empty list, which is always truthy, so the test could never fail.
    # Wrap the per-group monotonicity checks in all() so every group must pass.
    assert all(
        all(group_lrs[i - 1] < group_lrs[i] for i in range(1, len(group_lrs))) for group_lrs in groups_lrs
    )


def assert_output_sizes(lr_finder, dummy_engine):
iteration = dummy_engine.state.iteration
lr_finder_results = lr_finder.get_results()
Expand Down Expand Up @@ -313,7 +352,7 @@ def test_num_iter_is_not_enough(lr_finder, to_save, dummy_engine, dataloader):


def test_detach_terminates(lr_finder, to_save, dummy_engine, dataloader):
with lr_finder.attach(dummy_engine, to_save, end_lr=100, diverge_th=2) as trainer_with_finder:
with lr_finder.attach(dummy_engine, to_save, end_lr=100.0, diverge_th=2) as trainer_with_finder:
trainer_with_finder.run(dataloader)

dummy_engine.run(dataloader, max_epochs=3)
Expand All @@ -335,7 +374,7 @@ def test_different_num_iters(lr_finder, to_save, dummy_engine, dataloader):
@pytest.mark.parametrize("step_mode", ["exp", "linear"])
def test_start_lr(lr_finder, to_save, dummy_engine, dataloader, step_mode):
with lr_finder.attach(
dummy_engine, to_save, start_lr=0.01, end_lr=10, num_iter=5, step_mode=step_mode, diverge_th=1
dummy_engine, to_save, start_lr=0.01, end_lr=10.0, num_iter=5, step_mode=step_mode, diverge_th=1
) as trainer_with_finder:
trainer_with_finder.run(dataloader)
history = lr_finder.get_results()
Expand Down Expand Up @@ -486,7 +525,7 @@ def test_no_matplotlib(no_site_packages, lr_finder):

def test_plot_single_param_group(dirname, lr_finder, mnist_to_save, dummy_engine_mnist, mnist_dataloader):

with lr_finder.attach(dummy_engine_mnist, mnist_to_save, end_lr=20, smooth_f=0.04) as trainer_with_finder:
with lr_finder.attach(dummy_engine_mnist, mnist_to_save, end_lr=20.0, smooth_f=0.04) as trainer_with_finder:
trainer_with_finder.run(mnist_dataloader)

def _test(ax):
Expand Down Expand Up @@ -516,7 +555,7 @@ def test_plot_multiple_param_groups(
):

with lr_finder.attach(
dummy_engine_mulitple_param_groups, to_save_mulitple_param_groups, end_lr=20, smooth_f=0.04
dummy_engine_mulitple_param_groups, to_save_mulitple_param_groups, end_lr=20.0, smooth_f=0.04
) as trainer_with_finder:
trainer_with_finder.run(dataloader_plot)

Expand Down