Added ReduceLROnPlateau callback for VI. #7011

Open · wants to merge 16 commits into base: main
73 changes: 72 additions & 1 deletion pymc/variational/callbacks.py
@@ -18,7 +18,7 @@

import numpy as np

__all__ = ["Callback", "CheckParametersConvergence", "Tracker"]
__all__ = ["Callback", "CheckParametersConvergence", "ReduceLROnPlateau", "Tracker"]


class Callback:
@@ -93,6 +93,77 @@
return np.concatenate([sh.get_value().flatten() for sh in shared_list])


class ReduceLROnPlateau(Callback):
    """Reduce the learning rate when the loss has stopped improving.

    This is inspired by the Keras callback of the same name:
    https://github.com/keras-team/keras/blob/v2.14.0/keras/callbacks.py

    Parameters
    ----------
    optimizer: callable
        PyMC optimizer
    factor: float
        factor by which the learning rate will be reduced: `new_lr = lr * factor`
    patience: int
        number of iterations with no improvement after which the learning rate will be reduced
    min_lr: float
        lower bound on the learning rate
    cooldown: int
        number of iterations to wait before resuming normal operation after the learning rate has been reduced
    """

    def __init__(
        self,
        optimizer,
Member: Does the user have to provide this? Can it instead be inferred somehow from the host VI object? It's ugly to have to pass the optimizer twice (once for the VI itself, then again in the callback).

Author: Well, this would be great, but I haven't figured out whether it's possible. Probably one for someone more familiar with the codebase :)

        factor=0.1,
        patience=10,
        min_lr=1e-6,
        cooldown=0,
    ):
        self.optimizer = optimizer
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.cooldown = cooldown

        self.cooldown_counter = 0
        self.wait = 0
        self.best = float("inf")

    def __call__(self, approx, loss_hist, i):
        current = loss_hist[-1]

        if np.isinf(current):
            return

        if self.in_cooldown():
            self.cooldown_counter -= 1
            self.wait = 0
            return

        if current < self.best:
            self.best = current
            self.wait = 0
        elif not np.isinf(self.best):
            self.wait += 1
            if self.wait >= self.patience:
                self.reduce_lr()
                self.cooldown_counter = self.cooldown
                self.wait = 0

    def reduce_lr(self):
Member: I would still prefer that this was done symbolically with shared variables, because it will allow for composition between learning rate annealing strategies.

        old_lr = float(self.optimizer.keywords["learning_rate"])
        if old_lr > self.min_lr:
            new_lr = max(old_lr * self.factor, self.min_lr)
            self.optimizer.keywords["learning_rate"] = new_lr

    def in_cooldown(self):
        return self.cooldown_counter > 0
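
Regarding the reviewer's shared-variable suggestion above, a rough sketch (not part of this PR) of how the learning rate could live in a pytensor shared variable that callbacks mutate in place; the class name `SharedLRReduceOnPlateau` is hypothetical, and it assumes `pm.adam` accepts a symbolic `learning_rate`:

```python
import numpy as np
import pytensor
import pymc as pm

# Hypothetical sketch: keep the learning rate in a shared variable so several
# annealing strategies can compose by updating the same symbolic quantity.
lr = pytensor.shared(np.float64(0.1), name="learning_rate")
optimizer = pm.adam(learning_rate=lr)  # assumes a symbolic learning rate is accepted here


class SharedLRReduceOnPlateau:
    """Hypothetical plateau callback that mutates a shared learning rate."""

    def __init__(self, lr, factor=0.1, patience=10, min_lr=1e-6):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.wait = 0
        self.best = float("inf")

    def __call__(self, approx, loss_hist, i):
        current = loss_hist[-1]
        if np.isinf(current):
            return
        if current < self.best:
            self.best = current
            self.wait = 0
            return
        self.wait += 1
        if self.wait >= self.patience:
            new_lr = max(float(self.lr.get_value()) * self.factor, self.min_lr)
            self.lr.set_value(new_lr)  # takes effect on the next gradient update
            self.wait = 0
```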


class Tracker(Callback):
"""
Helper class to record arbitrary stats during VI
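For context, a minimal usage sketch of the new callback (not part of the diff); the toy model is illustrative, and it assumes `pm.fit` forwards the `obj_optimizer` and `callbacks` keywords to the VI objective as in current PyMC:

```python
import numpy as np
import pymc as pm

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=np.random.randn(100))

    optimizer = pm.adam(learning_rate=0.1)
    reduce_lr = pm.variational.callbacks.ReduceLROnPlateau(
        optimizer=optimizer, factor=0.5, patience=50, min_lr=1e-5
    )
    # The same optimizer is passed twice, once to the objective and once to the
    # callback, which is the duplication the first review comment points out.
    approx = pm.fit(
        n=10_000,
        method="advi",
        obj_optimizer=optimizer,
        callbacks=[reduce_lr],
    )
```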
30 changes: 30 additions & 0 deletions tests/variational/test_callbacks.py
@@ -52,3 +52,33 @@ def test_tracker_callback():
    tracker = pm.callbacks.Tracker(bad=lambda t: t)  # bad signature
    with pytest.raises(TypeError):
        tracker(None, None, 1)


def test_reducelronplateau_callback():
    optimizer = pm.adam(learning_rate=0.1)
    cb = pm.variational.callbacks.ReduceLROnPlateau(
        optimizer=optimizer,
        patience=1,
        min_lr=0.001,
    )
    cb(None, [float("inf")], 1)
    np.testing.assert_almost_equal(optimizer.keywords["learning_rate"], 0.1)
    assert cb.best == float("inf")
    cb(None, [float("inf"), 2], 1)
    np.testing.assert_almost_equal(optimizer.keywords["learning_rate"], 0.1)
    assert cb.best == 2
    cb(None, [float("inf"), 2, 1], 1)
    np.testing.assert_almost_equal(optimizer.keywords["learning_rate"], 0.1)
    assert cb.best == 1
    cb(None, [float("inf"), 2, 1, 99], 1)
    np.testing.assert_almost_equal(optimizer.keywords["learning_rate"], 0.01)
    assert cb.best == 1
    cb(None, [float("inf"), 2, 1, 99, 0.9], 1)
    np.testing.assert_almost_equal(optimizer.keywords["learning_rate"], 0.01)
    assert cb.best == 0.9
    cb(None, [float("inf"), 2, 1, 99, 0.9, 99], 1)
    np.testing.assert_almost_equal(optimizer.keywords["learning_rate"], 0.001)
    assert cb.best == 0.9
    cb(None, [float("inf"), 2, 1, 99, 0.9, 99, 99], 1)
    np.testing.assert_almost_equal(optimizer.keywords["learning_rate"], 0.001)
    assert cb.best == 0.9