Allow clipping the dynamic loss scale to a minimum value
PiperOrigin-RevId: 528490282
Flax Team committed May 1, 2023
1 parent 12b0245 · commit bb76e47
Showing 1 changed file with 7 additions and 0 deletions.
flax/training/dynamic_scale.py
@@ -78,12 +78,17 @@ def loss_fn(p):
       be increased (default: 2000).
     fin_steps: indicates how many gradient steps in a row have been finite.
     scale: the current scale by which the loss is multiplied.
+    minimum_scale: the minimum value that the scale can take (default: the
+      smallest positive number representable in floating point).
   """
   growth_factor: float = struct.field(pytree_node=False, default=2.0)
   backoff_factor: float = struct.field(pytree_node=False, default=0.5)
   growth_interval: int = struct.field(pytree_node=False, default=2000)
   fin_steps: Array = 0
   scale: Array = 65536.0
+  minimum_scale: Optional[float] = struct.field(
+      pytree_node=False, default=jnp.finfo(jnp.float32).tiny
+  )
 
   def value_and_grad(self, fun: Callable[..., Any],
                      argnums: Union[int, Sequence[int]] = 0,
@@ -137,6 +142,8 @@ def grad_fn_wrapper(*args):
         jnp.minimum(self.scale * self.growth_factor, jnp.finfo(jnp.float32).max),
         self.scale)
     inf_scale = self.scale * self.backoff_factor
+    if self.minimum_scale is not None:
+      inf_scale = jnp.maximum(inf_scale, self.minimum_scale)
     new_scale = jnp.where(finite, fin_scale, inf_scale)
     new_fin_steps = jnp.where(grow | (~finite), 0, self.fin_steps + 1)

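For reference, a minimal usage sketch (not part of this commit), modeled on the example in the DynamicScale class docstring; the floor value 2.0 ** -14 and the quadratic loss_fn are illustrative assumptions:

import jax
import jax.numpy as jnp
from flax.training import dynamic_scale as dynamic_scale_lib

def loss_fn(p):
  # Compute the loss in float16 so small gradients can underflow
  # without loss scaling.
  return jnp.asarray(p, jnp.float16) ** 2

p = jnp.array(1., jnp.float32)
# Clip the scale so repeated backoffs cannot shrink it below 2**-14
# (an illustrative floor; the default is jnp.finfo(jnp.float32).tiny).
dyn_scale = dynamic_scale_lib.DynamicScale(minimum_scale=2.0 ** -14)
compute_grad = jax.jit(lambda ds, p: ds.value_and_grad(loss_fn)(p))
for _ in range(100):
  dyn_scale, is_fin, loss, grad = compute_grad(dyn_scale, p)
  # Skip the update on steps where the scaled gradients were not finite.
  p += jnp.where(is_fin, 0.01 * grad, 0.)

With minimum_scale set, a long run of non-finite gradients can no longer drive the scale toward zero: each backoff multiplies the scale by backoff_factor, but the result is clipped at the floor.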
