Add _get_parameter method to Lamb optimizer (#39416)
* add _get_parameter func to lamb

* remove duplicate code
sneaxiy authored Feb 10, 2022
1 parent 32d79bb commit c47d672
Showing 2 changed files with 53 additions and 7 deletions.
35 changes: 30 additions & 5 deletions python/paddle/fluid/tests/unittests/test_lambv2_op.py
@@ -195,32 +195,57 @@ def check_main(self, x_np, place, multi_precision=False, seed=10, n=10):
                 hidden = linear(x)
                 loss = paddle.mean(hidden)

-            optimizer = paddle.optimizer.Lamb(learning_rate=1e-3)
-            optimizer._multi_precision = multi_precision
+            original_optimizer = paddle.optimizer.Lamb(learning_rate=1e-3)
+            original_optimizer._multi_precision = multi_precision
             if multi_precision:
                 optimizer = paddle.static.amp.decorate(
-                    optimizer, use_pure_fp16=True, use_fp16_guard=True)
+                    original_optimizer, use_pure_fp16=True, use_fp16_guard=True)
+            else:
+                optimizer = original_optimizer
             optimizer.minimize(loss)

         weight, bias = linear.weight, linear.bias
-        scope = paddle.static.Scope()
         exe = paddle.static.Executor(place)
+        scope = paddle.static.Scope()
         x = main_prog.global_block().var(x.name)
         if x.dtype == core.VarDesc.VarType.FP16:
             x_np = x_np.astype(np.float16)

+        def get_parameter(var):
+            name = var if isinstance(var, (str, bytes)) else var.name
+            params = original_optimizer._get_parameter(name, scope)
+            assert isinstance(params, (list, tuple))
+            params = list(params)
+            assert len(params) == 2
+            if multi_precision:
+                params[0] = np.array(params[0])
+                params[1] = np.array(params[1])
+                self.assertTrue(
+                    np.array_equal(params[0], params[1].astype(np.float16)))
+                return params[0].astype(np.float32)
+            else:
+                self.assertTrue(params[0] is not None)
+                self.assertTrue(params[1] is None)
+                params[0] = np.array(params[0])
+                return params[0]
+
         with paddle.static.scope_guard(scope):
             exe.run(startup_prog)
             if multi_precision:
                 optimizer.amp_init(place)

+            weight_np, bias_np = None, None
             for i in range(n):
                 feed_dict = {x.name: x_np}
                 weight_np, bias_np = exe.run(main_prog,
                                              feed=feed_dict,
                                              fetch_list=[weight, bias])
-            return weight_np.astype('float32'), bias_np.astype('float32')
+            weight_np = weight_np.astype('float32')
+            bias_np = bias_np.astype('float32')
+            self.assertTrue(
+                np.array_equal(weight_np, get_parameter(weight)))
+            self.assertTrue(np.array_equal(bias_np, get_parameter(bias)))
+            return weight_np, bias_np

     @switch_to_static_graph
     def test_main(self):
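
The hunk above folds the new `_get_parameter` check into the existing multi-precision test: after `n` training steps, the fetched `weight`/`bias` values must match what `_get_parameter` reads back from the scope. Outside the unittest harness, a minimal sketch of the same flow (hypothetical, not part of this commit, using only the public static-graph APIs already exercised by the test) could look like:

import numpy as np
import paddle

paddle.enable_static()

# Build a tiny static program optimized by Lamb.
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[None, 10], dtype='float32')
    linear = paddle.nn.Linear(10, 2)
    loss = paddle.mean(linear(x))
    optimizer = paddle.optimizer.Lamb(learning_rate=1e-3)
    optimizer.minimize(loss)

scope = paddle.static.Scope()
exe = paddle.static.Executor(paddle.CPUPlace())
with paddle.static.scope_guard(scope):
    exe.run(startup_prog)
    exe.run(main_prog,
            feed={'x': np.random.rand(4, 10).astype('float32')},
            fetch_list=[loss])
    # Read the trained weight back through the new helper; without
    # multi-precision the master slot is expected to be None.
    param, master = optimizer._get_parameter(linear.weight.name, scope)
    assert master is None
    print(np.array(param).shape)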
25 changes: 23 additions & 2 deletions python/paddle/optimizer/lamb.py
Expand Up @@ -20,6 +20,7 @@
from ..fluid import unique_name
from ..fluid.layer_helper import LayerHelper
from paddle import _C_ops
from paddle.fluid.executor import global_scope

__all__ = []

Expand Down Expand Up @@ -131,9 +132,25 @@ def __init__(self,
'exclude_from_weight_decay_fn': exclude_from_weight_decay_fn,
}
self._master_weights = {}
self._used_master_weights = {}
# TODO(zengjinle): expose API as soon as possible
self._multi_precision = False

def _get_parameter(self, name, scope=None):
if scope is None:
scope = global_scope()

p_t = scope.find_var(name).get_tensor()

master_name = self._used_master_weights.get(name)
if master_name is not None:
master_p_t = scope.find_var(master_name).get_tensor()
assert master_p_t._dtype() != p_t._dtype()
assert master_p_t.shape() == p_t.shape()
else:
master_p_t = None
return p_t, master_p_t

def _create_master_weight(self, param):
assert self._multi_precision
if param.name in self._master_weights:
Expand Down Expand Up @@ -243,8 +260,12 @@ def _append_optimize_op(self, block, param_and_grad):

find_master = self._multi_precision and param_and_grad[
0].dtype == core.VarDesc.VarType.FP16
master_weight = self._master_weights[param_and_grad[0]
.name] if find_master else None
p_name = param_and_grad[0].name
if find_master:
master_weight = self._master_weights[p_name]
self._used_master_weights[p_name] = master_weight.name
else:
master_weight = None
found_inf = self._get_auxiliary_var('found_inf')

if framework.in_dygraph_mode():
Expand Down
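
Taken together, the optimizer change records the FP32 master weight name in `self._used_master_weights` whenever `_append_optimize_op` runs for an FP16 parameter, so `_get_parameter(name, scope)` can later return the `(param, master_weight)` tensor pair; for parameters trained without multi-precision the second element is simply `None`. A hedged sketch of a caller-side wrapper (hypothetical helper, mirroring the test's `get_parameter`) that collapses the pair into one FP32 numpy array:

import numpy as np

def fetch_fp32_parameter(optimizer, name, scope=None):
    # Hypothetical convenience wrapper around Lamb._get_parameter; passing
    # scope=None makes the method fall back to the global scope lookup.
    param, master = optimizer._get_parameter(name, scope)
    if master is not None:
        # Multi-precision: the FP32 master copy carries the full-precision value.
        return np.array(master).astype('float32')
    return np.array(param).astype('float32')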
