From ae9ced0a73fc814d5aa7cb4f573f03429a608c64 Mon Sep 17 00:00:00 2001 From: James Martens Date: Wed, 27 Dec 2023 09:24:10 -0800 Subject: [PATCH] Improving docstrings for schedules. PiperOrigin-RevId: 594048775 --- optax/schedules/_schedule.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/optax/schedules/_schedule.py b/optax/schedules/_schedule.py index d7cf2d690..641f95cd8 100644 --- a/optax/schedules/_schedule.py +++ b/optax/schedules/_schedule.py @@ -57,8 +57,8 @@ def polynomial_schedule( init_value: initial value for the scalar to be annealed. end_value: end value of the scalar to be annealed. power: the power of the polynomial used to transition from init to end. - transition_steps: number of steps over which annealing takes place, - the scalar starts changing at `transition_begin` steps and completes + transition_steps: number of steps over which annealing takes place. + The scalar starts changing at `transition_begin` steps and completes the transition by `transition_begin + transition_steps` steps. If `transition_steps <= 0`, then the entire annealing process is disabled and the value is held fixed at `init_value`. @@ -76,7 +76,7 @@ def polynomial_schedule( if transition_begin < 0: logging.info( - 'An exponential schedule was set with a negative `transition_begin` ' + 'A polynomial schedule was set with a negative `transition_begin` ' 'value; this will result in `transition_begin` falling back to `0`.') transition_begin = 0 @@ -142,10 +142,12 @@ def exponential_decay( """Constructs a schedule with either continuous or discrete exponential decay. This function applies an exponential decay function to a provided initial - value. The function returns the decayed value as follows: + value. When `count >= transition_begin` the function returns the decayed value + as follows: ``` - decayed_value = init_value * decay_rate ^ (count / transition_steps) + decayed_value = init_value * decay_rate ^ ((count - transition_begin) + / transition_steps) ``` If the argument `staircase` is `True`, then `count / transition_steps` is