Merge pull request #1 from PaddlePaddle/release/2.2
[cherry pick]split minimize and add unscale_ for GradScaler (#35927)
YuanRisheng authored Sep 26, 2021
2 parents c658c79 + e262125 commit f58e33c
Showing 5 changed files with 260 additions and 45 deletions.
99 changes: 92 additions & 7 deletions python/paddle/amp/grad_scaler.py
@@ -13,18 +13,28 @@
# limitations under the License.

from paddle.fluid.dygraph.amp import AmpScaler
from paddle.fluid.dygraph.amp import OptimizerState
from collections import defaultdict

__all__ = []


def _refresh_optimizer_state():
return {"state": OptimizerState.INIT}


class GradScaler(AmpScaler):
"""
GradScaler is used for Auto-Mixed-Precision training in dynamic graph mode.
It controls the scaling of the loss and helps avoid numerical overflow.
The object of this class has two methods `scale()`, `minimize()`.
The object of this class provides the methods `scale()`, `unscale_()`, `minimize()`, `step()`, `update()`, and the `get`/`set` APIs for its parameters.
`scale()` is used to multiply the loss by a scale ratio.
`minimize()` is similar as `optimizer.minimize()`, performs parameters updating.
`unscale_()` is used to unscale the gradients of the parameters, i.e. it multiplies the gradients of the parameters by 1/(scale ratio).
`minimize()` is similar to `optimizer.minimize()`; it performs the parameter update and also updates the loss scaling, which is equivalent to calling `step()` followed by `update()`.
`step()` is similar to `optimizer.step()`, which performs the parameter update.
`update()` is used to update the loss scaling.
Commonly, it is used together with `paddle.amp.auto_cast` to achieve Auto-Mixed-Precision in
dynamic graph mode.
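The class-level example is collapsed in this diff view, so here is a minimal sketch of the decoupled workflow introduced by this change (scale, backward, optional unscale_, step, update). It mirrors the docstring examples below and assumes a GPU build of Paddle and a toy Conv2D model; the explicit `unscale_()` call is optional and only needed if you want to work with unscaled gradients before `step()`.

# required: gpu
import paddle

model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

data = paddle.rand([10, 3, 32, 32])
with paddle.amp.auto_cast():
    conv = model(data)
    loss = paddle.mean(conv)
scaled = scaler.scale(loss)    # scale the loss
scaled.backward()              # backward pass on the scaled loss
scaler.unscale_(optimizer)     # optional: unscale gradients explicitly
scaler.step(optimizer)         # update parameters (skipped if Inf/NaN was found)
scaler.update()                # update the loss scaling ratio
optimizer.clear_grad()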
@@ -115,7 +125,7 @@ def minimize(self, optimizer, *args, **kwargs):
This function is similar to `optimizer.minimize()`, which performs the parameter update.
If the scaled gradients of the parameters contain NaN or Inf, the parameter update is skipped.
Otherwise, it first unscales the scaled gradients of parameters, then updates the parameters.
Otherwise, if `unscale_()` has not been called yet, it first unscales the scaled gradients of the parameters and then updates the parameters.
Finally, the loss scaling ratio is updated.
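For comparison, a minimal sketch of this single-call path, where `minimize()` unscales the gradients (unless `unscale_()` was already called), steps the optimizer, and updates the loss scaling in one go; same toy-model assumptions as in the sketch above.

# required: gpu
import paddle

model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

data = paddle.rand([10, 3, 32, 32])
with paddle.amp.auto_cast():
    loss = paddle.mean(model(data))
scaled = scaler.scale(loss)           # scale the loss
scaled.backward()                     # backward pass on the scaled loss
scaler.minimize(optimizer, scaled)    # unscale (if needed) + step + update in one call
optimizer.clear_grad()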
@@ -151,16 +161,18 @@ def step(self, optimizer):
This function is similar to `optimizer.step()`, which performs the parameter update.
If the scaled gradients of the parameters contain NaN or Inf, the parameter update is skipped.
Otherwise, it first unscales the scaled gradients of parameters, then updates the parameters.
Otherwise, if `unscale_()` has not been called yet, it first unscales the scaled gradients of the parameters and then updates the parameters.
Args:
optimizer(Optimizer): The optimizer used to update parameters.
Examples:
.. code-block:: python
# required: gpu
import paddle
model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
@@ -170,24 +182,97 @@ def step(self, optimizer):
loss = paddle.mean(conv)
scaled = scaler.scale(loss) # scale the loss
scaled.backward() # do backward
scaler.step(optimizer)
scaler.step(optimizer) # update parameters
scaler.update() # update the loss scaling ratio
optimizer.clear_grad()
"""
if not self._enable:
return optimizer.step()

optimizer_state = self._optimizer_states[id(optimizer)]
if optimizer_state["state"] is OptimizerState.STEPPED:
raise RuntimeError(
"step() has already been called since the last update().")

# unscale the grad
self._unscale(optimizer)
if optimizer_state["state"] is OptimizerState.INIT:
self._unscale(optimizer)

if self._found_inf:
self._cache_founf_inf = True
else:
optimizer.step()
self._cache_founf_inf = False

optimizer_state["state"] = OptimizerState.STEPPED

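# when dynamic loss scaling is disabled, update() does not reset the per-optimizer state, so reset it here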
if not self._use_dynamic_loss_scaling:
self._optimizer_states = defaultdict(_refresh_optimizer_state)

def update(self):
"""
Updates the loss_scaling.
Examples:
.. code-block:: python
# required: gpu
import paddle
model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
data = paddle.rand([10, 3, 32, 32])
with paddle.amp.auto_cast():
conv = model(data)
loss = paddle.mean(conv)
scaled = scaler.scale(loss) # scale the loss
scaled.backward() # do backward
scaler.step(optimizer) # update parameters
scaler.update() # update the loss scaling ratio
optimizer.clear_grad()
"""
if not self._enable:
return
if self._use_dynamic_loss_scaling:
# update the scale
self._update()
self._optimizer_states = defaultdict(_refresh_optimizer_state)
return

def unscale_(self, optimizer):
"""
Unscales the gradients of the parameters, i.e. multiplies the gradients of the parameters by 1/(loss scaling ratio).
If this instance of :class:`GradScaler` is not enabled, the gradients are left unmodified.
Args:
optimizer(Optimizer): The optimizer used to update parameters.
Returns:
None. The gradients of the parameters are unscaled in place.
Examples:
.. code-block:: python
# required: gpu
import paddle
model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
data = paddle.rand([10, 3, 32, 32])
with paddle.amp.auto_cast():
conv = model(data)
loss = paddle.mean(conv)
scaled = scaler.scale(loss) # scale the loss
scaled.backward() # do backward
scaler.unscale_(optimizer) # unscale the gradients of the parameters
scaler.step(optimizer)
scaler.update()
optimizer.clear_grad()
"""
return super(GradScaler, self)._unscale(optimizer)

def is_enable(self):
"""
@@ -329,6 +329,7 @@ def _broadcast_final_loss(self):
def _optimizer_step(self):
if self.scaler:
self.scaler.step(self.optimizer)
self.scaler.update()
else:
self.optimizer.step()

114 changes: 77 additions & 37 deletions python/paddle/fluid/dygraph/amp/loss_scaler.py
@@ -21,8 +21,20 @@
import warnings
import numpy as np
from paddle import _C_ops
from collections import defaultdict
from enum import Enum

__all__ = ['AmpScaler']
__all__ = ['AmpScaler', 'OptimizerState']


class OptimizerState(Enum):
INIT = 0
UNSCALED = 1
STEPPED = 2


def _refresh_optimizer_state():
return {"state": OptimizerState.INIT}
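The per-optimizer state machine above (INIT -> UNSCALED -> STEPPED, reset by `update()`/`minimize()`) is what backs the new call-order checks in this change. Below is a minimal sketch of the enforced behaviour, shown through the public `paddle.amp.GradScaler` subclass and assuming the same toy Conv2D setup as the docstring examples; the error messages in the comments are the ones raised in this diff.

# required: gpu
import paddle

model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

data = paddle.rand([10, 3, 32, 32])
with paddle.amp.auto_cast():
    loss = paddle.mean(model(data))
scaled = scaler.scale(loss)
scaled.backward()

scaler.unscale_(optimizer)           # INIT -> UNSCALED
try:
    scaler.unscale_(optimizer)       # a second unscale_ before update() is rejected
except RuntimeError as e:
    print(e)  # "unscale_() has already been called on this optimizer since the last update()."

scaler.step(optimizer)               # UNSCALED -> STEPPED; no second implicit unscale
try:
    scaler.step(optimizer)           # a second step() before update() is rejected
except RuntimeError as e:
    print(e)  # "step() has already been called since the last update()."

scaler.update()                      # resets every tracked optimizer back to INIT
optimizer.clear_grad()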


class AmpScaler(object):
@@ -31,10 +43,11 @@ class AmpScaler(object):
AmpScaler is used for Auto-Mixed-Precision training/inferring in imperative
mode. It controls the scaling of the loss and helps avoid numerical overflow.
The object of this class has two methods `scale()`, `minimize()`.
The object of this class provides the methods `scale()`, `unscale_()`, `minimize()`, and the `get`/`set` APIs for its parameters.
`scale()` is used to multiply the loss by a scale ratio.
`minimize()` is similar as `Optimizer.minimize()`, performs parameters updating.
`unscale_()` is used to unscale the gradients of the parameters, i.e. it multiplies the gradients of the parameters by 1/(scale ratio).
`minimize()` is similar to `optimizer.minimize()`; it performs the parameter update and also updates the loss scaling.
Commonly, it is used together with `amp_guard` to achieve Auto-Mixed-Precision in
imperative mode.
@@ -117,6 +130,7 @@ def __init__(self,
self._scale = to_variable(
np.array([self._init_loss_scaling]).astype(np.float32))
self._cache_founf_inf = None
self._optimizer_states = defaultdict(_refresh_optimizer_state)

def scale(self, var):
"""
@@ -129,24 +143,25 @@ def scale(self, var):
The scaled variable or original variable.
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
with fluid.dygraph.guard():
model = fluid.dygraph.Conv2D(3, 2, 3)
optimizer = fluid.optimizer.SGDOptimizer(
learning_rate=0.01, parameter_list=model.parameters())
scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024)
data = fluid.dygraph.to_variable(data)
with fluid.dygraph.amp_guard():
conv = model(data)
loss = fluid.layers.reduce_mean(conv)
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
"""
check_type(var, "var", core.VarBase, 'AmpScaler.scale()')

@@ -160,7 +175,7 @@ def minimize(self, optimizer, *args, **kwargs):
This function is similar to `Optimizer.minimize()`, which performs the parameter update.
If the scaled gradients of the parameters contain NaN or Inf, the parameter update is skipped.
Otherwise, it first unscales the scaled gradients of parameters, then updates the parameters.
Otherwise, if `unscale_()` has not been called yet, it first unscales the scaled gradients of the parameters and then updates the parameters.
Finally, the loss scaling ratio is updated.
@@ -170,30 +185,34 @@
kwargs: Keyword arguments, which will be forward to `Optimizer.minimize()`.
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
with fluid.dygraph.guard():
model = fluid.dygraph.Conv2D(3, 2, 3)
optimizer = fluid.optimizer.SGDOptimizer(
learning_rate=0.01, parameter_list=model.parameters())
scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024)
data = fluid.dygraph.to_variable(data)
with fluid.dygraph.amp_guard():
conv = model(data)
loss = fluid.layers.reduce_mean(conv)
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
"""
if not self._enable:
return optimizer.minimize(*args, **kwargs)

optimizer_state = self._optimizer_states[id(optimizer)]

# unscale the grad
self._unscale(optimizer)
if optimizer_state["state"] is OptimizerState.INIT:
self._unscale(optimizer)

optimize_ops, params_grads = (None, None)

@@ -207,12 +226,31 @@ def minimize(self, optimizer, *args, **kwargs):
# update the scale
self._update()

self._optimizer_states = defaultdict(_refresh_optimizer_state)

return optimize_ops, params_grads

def _unscale(self, optimizer):
"""
Unscales the gradients of the parameters, i.e. multiplies the gradients of the parameters by 1/(loss scaling ratio).
If this instance of :class:`AmpScaler` is not enabled, the gradients are left unmodified.
Args:
optimizer(Optimizer): The optimizer used to update parameters.
Returns:
None. The gradients of the parameters are unscaled in place.
"""
if not self._enable:
return

optimizer_state = self._optimizer_states[id(optimizer)]

if optimizer_state["state"] is OptimizerState.UNSCALED:
raise RuntimeError(
"unscale_() has already been called on this optimizer since the last update()."
)
elif optimizer_state["state"] is OptimizerState.STEPPED:
raise RuntimeError("unscale_() is being called after step().")

if getattr(optimizer, '_param_groups', None) and isinstance(
optimizer._param_groups[0], dict):
param_grads = []
@@ -256,6 +294,8 @@ def _unscale(self, optimizer):
temp_found_inf_fp32)
self._found_inf = temp_found_inf_fp16 or temp_found_inf_fp32

optimizer_state["state"] = OptimizerState.UNSCALED

def _update(self):
"""
Updates the loss_scaling.
@@ -48,6 +48,7 @@ def train_batch(self, batch, model, optimizer, is_mp):
scaled.backward() # do backward

scaler.step(optimizer) # update parameters
scaler.update()
optimizer.clear_grad()
return scaled


1 comment on commit f58e33c

@paddle-bot-old
Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉
