Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add end_scale argument #975

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions optax/contrib/_reduce_on_plateau.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def reduce_on_plateau(
atol: float = 0.0,
cooldown: int = 0,
accumulation_size: int = 1,
min_scale: float = 0.0,
) -> base.GradientTransformationExtraArgs:
"""Reduce learning rate when a metric has stopped improving.

Expand All @@ -62,17 +63,23 @@ def reduce_on_plateau(
atol: Absolute tolerance for measuring new optimum.
cooldown: Number of iterations to wait before resuming normal operation
after scale has been reduced.
accumulation_size: Number of valeus to aggregate before applying the logic
accumulation_size: Number of values to aggregate before applying the logic
of reduce on plateau. If the value fed to the optimizer is a test value,
simply take 1 (default). If the value fed to the optimizer is the loss on
the current minibatch, consider using a larger accumulation size.
min_scale: Scale at which the learning rate decay stops.

Returns:
A GradientTransformationExtraArgs object.

.. seealso::
* :doc:`../../_collections/examples/contrib/reduce_on_plateau` example.
"""
if factor <= 0.0 or factor >= 1.0:
raise ValueError(
f"Factor must be in the range (0, 1), got factor = {factor}."
)

if rtol < 0.0 or atol < 0.0:
raise ValueError(
"Both rtol and atol must be non-negative, got "
Expand Down Expand Up @@ -124,14 +131,18 @@ def not_in_cooldown():
new_plateau_count = jnp.where(
curr_plateau_count == patience, 0, curr_plateau_count
)
new_scale = jnp.where(
curr_plateau_count == patience,
state.scale * factor,
state.scale,
new_scale = jnp.maximum(
jnp.where(
curr_plateau_count == patience,
state.scale * factor,
state.scale,
),
min_scale,
)
new_cooldown_count = jnp.where(
curr_plateau_count == patience, cooldown, 0
).astype(jnp.int32)

return new_plateau_count, new_scale, new_cooldown_count

new_plateau_count, new_scale, new_cooldown_count = jax.lax.cond(
Expand Down
37 changes: 35 additions & 2 deletions optax/contrib/_reduce_on_plateau_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def setUp(self):
rtol=1e-4,
atol=0.0,
cooldown=self.cooldown,
accumulation_size=1
accumulation_size=1,
min_scale=0.01,
)
self.updates = {'params': jnp.array(1.0)} # dummy updates

Expand All @@ -55,7 +56,6 @@ def test_learning_rate_reduced_after_cooldown_period_is_over(
# Initialize the state
state = self.transform.init(self.updates['params'])

updates = self.updates
# Wait until patience runs out
for _ in range(self.patience + 1):
updates, state = self.transform.update(
Expand Down Expand Up @@ -138,6 +138,39 @@ def test_learning_rate_not_reduced_during_cooldown(self, enable_x64):
chex.assert_trees_all_close(plateau_count, 0)
chex.assert_trees_all_close(cooldown_count, 2)

@parameterized.parameters(False, True)
def test_learning_rate_not_reduced_after_end_scale_is_reached(
    self, enable_x64
):
  """Checks that the scale stays pinned at min_scale once it is reached."""

  # Enable float64 if requested.
  jax.config.update('jax_enable_x64', enable_x64)

  # Start from a state whose scale already equals min_scale (0.01,
  # matching the value passed to reduce_on_plateau in setUp).
  state = _reduce_on_plateau.ReduceLROnPlateauState(
      best_value=jnp.array(1.0, dtype=jnp.float32),
      plateau_count=jnp.array(0, dtype=jnp.int32),
      scale=jnp.array(0.01, dtype=jnp.float32),
      cooldown_count=jnp.array(0, dtype=jnp.int32),
      count=jnp.array(0, dtype=jnp.int32),
      avg_value=jnp.array(0.0, dtype=jnp.float32),
  )

  # Feed a non-improving value until the patience budget is exhausted.
  num_steps = self.patience + 1
  for _ in range(num_steps):
    updates, state = self.transform.update(
        updates=self.updates, state=state, value=0.1,
    )

  # The scale must not drop below min_scale; the best value is refreshed,
  # the plateau counter resets, and the cooldown period begins.
  chex.assert_trees_all_close(state.scale, 0.01)
  chex.assert_trees_all_close(state.best_value, 0.1)
  chex.assert_trees_all_close(state.plateau_count, 0)
  chex.assert_trees_all_close(state.cooldown_count, self.cooldown)
  chex.assert_trees_all_close(updates, {'params': jnp.array(0.01)})


# Entry point: run this test module with absl's test runner when executed
# directly (e.g. `python _reduce_on_plateau_test.py`).
if __name__ == '__main__':
  absltest.main()
Loading