Merge pull request #965 from levskaya:optimizer_fixes
PiperOrigin-RevId: 355229213
Flax Authors committed Feb 2, 2021
2 parents b112b11 + f9038b1, commit 61580b9
Showing 2 changed files with 2 additions and 2 deletions.

flax/optim/adam.py: 1 addition & 1 deletion

@@ -98,7 +98,7 @@ def apply_param_gradient(self, step, hyper_params, param, state, grad):
     grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq
 
     # bias correction
-    t = step + 1.
+    t = jnp.array(step + 1, lax.dtype(param.dtype))
     grad_ema_corr = grad_ema / (1 - beta1 ** t)
     grad_sq_ema_corr = grad_sq_ema / (1 - beta2 ** t)

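For context, a minimal standalone sketch of the bias-correction step as it reads after this change. The helper name bias_corrected_moments and the hyperparameter defaults are illustrative, not Flax API; in Flax the logic lives inside apply_param_gradient.

import jax.numpy as jnp
from jax import lax

def bias_corrected_moments(step, param, grad_ema, grad_sq_ema,
                           beta1=0.9, beta2=0.999):
  # Illustrative helper: cast the integer step count to the parameter's
  # dtype before exponentiation, as the patched line above does.
  t = jnp.array(step + 1, lax.dtype(param.dtype))
  grad_ema_corr = grad_ema / (1. - beta1 ** t)
  grad_sq_ema_corr = grad_sq_ema / (1. - beta2 ** t)
  return grad_ema_corr, grad_sq_ema_corr
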
flax/optim/lamb.py: 1 addition & 1 deletion

@@ -74,7 +74,7 @@ def apply_param_gradient(self, step, hyper_params, param, state, grad):
     grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad
     grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq
 
-    t = step + 1.
+    t = jnp.array(step + 1, lax.dtype(param.dtype))
     grad_ema_corr = grad_ema / (1. - beta1 ** t)
     grad_sq_ema_corr = grad_sq_ema / (1. - beta2 ** t)

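Both files get the same one-line change, presumably so the bias-correction arithmetic stays in the parameter's dtype (for example a bfloat16 or float16 parameter) rather than picking up the dtype of a Python-float step count. A quick check under that assumption, using a made-up bfloat16 parameter:

import jax.numpy as jnp
from jax import lax

param = jnp.zeros((3,), dtype=jnp.bfloat16)    # hypothetical half-precision parameter
grad_ema = jnp.zeros_like(param)               # first-moment EMA, same dtype as param
step = jnp.array(0, dtype=jnp.int32)           # integer step counter

t = jnp.array(step + 1, lax.dtype(param.dtype))
print(t.dtype)                                 # bfloat16
print((grad_ema / (1. - 0.9 ** t)).dtype)      # bfloat16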
