Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update optimizer for 2.0 #26288

Merged
merged 35 commits into from
Aug 23, 2020
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
e45dcff
add doc; notest
MRXLT Aug 14, 2020
85b3f92
fix doc; notest
MRXLT Aug 14, 2020
cbcd950
update doc; notest
MRXLT Aug 14, 2020
9661a54
refine optimizer && adam
MRXLT Aug 14, 2020
f542d77
fix conflict
MRXLT Aug 17, 2020
73baac0
refine optimizer; notest
MRXLT Aug 18, 2020
5a55869
add adam
MRXLT Aug 18, 2020
fd34fbd
fix doc
MRXLT Aug 18, 2020
f5e6881
Merge remote-tracking branch 'upstream/develop' into 2.0-op
MRXLT Aug 18, 2020
a715c46
Merge remote-tracking branch 'upstream/develop' into 2.0-op
MRXLT Aug 19, 2020
e67cd86
fix doc && add adamw; notest
MRXLT Aug 19, 2020
da4025d
add error message
MRXLT Aug 19, 2020
f3699cb
bug fix
MRXLT Aug 19, 2020
6f00384
refine rmsprop && adamax
MRXLT Aug 19, 2020
654377d
fix ci
MRXLT Aug 19, 2020
fa7ccb1
buf fix
MRXLT Aug 19, 2020
9aaf899
update comment
MRXLT Aug 19, 2020
b727dad
unify arguments place; notest
MRXLT Aug 20, 2020
9cf4c3b
fix ut, test=develop
mapingshuo Aug 20, 2020
2e8d253
bug fix
MRXLT Aug 20, 2020
00c38fc
fix conflicts, test=develop
mapingshuo Aug 20, 2020
b75ab16
add examples code
MRXLT Aug 20, 2020
84205ce
Merge remote-tracking branch 'origin/2.0-op' into 2.0-op
MRXLT Aug 20, 2020
b6fa771
bug fix
MRXLT Aug 20, 2020
9cd1838
fix comments
MRXLT Aug 20, 2020
95310f5
fix sample code
MRXLT Aug 20, 2020
ce31795
add sample code for Optimizer
MRXLT Aug 20, 2020
0780b9c
add adamax ut, test=develop
mapingshuo Aug 21, 2020
87a7f56
fix rmsprop ut, test=develop
mapingshuo Aug 21, 2020
06f3c73
add ut for optimizer.py and adamw.py
MRXLT Aug 21, 2020
fd67080
Merge branch '2.0-op' of https://github.com/MRXLT/Paddle into 2.0-op
MRXLT Aug 21, 2020
b00b85f
remove TestAdamOptimizerBetaVariable
MRXLT Aug 21, 2020
6cc0fc2
update api && add ut
MRXLT Aug 21, 2020
5d42420
update doc && fix ut
MRXLT Aug 21, 2020
9094782
add ut
MRXLT Aug 23, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions python/paddle/fluid/tests/unittests/test_adam_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from paddle.fluid import core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
import paddle


class TestAdamOp1(OpTest):
Expand Down Expand Up @@ -443,5 +444,50 @@ def test_with_place(place, shape):
test_with_place(place, shape)


class TestAdamOpV2(unittest.TestCase):
def test_adam_op(self):
place = fluid.CPUPlace()
shape = [2, 3, 8, 8]
exe = fluid.Executor(place)
train_prog = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(train_prog, startup):
with fluid.unique_name.guard():
data = fluid.data(name="data", shape=shape)
conv = fluid.layers.conv2d(data, 8, 3)
loss = fluid.layers.reduce_mean(conv)

beta1 = fluid.layers.create_global_var(
shape=[1], value=0.85, dtype='float32', persistable=True)
beta2 = fluid.layers.create_global_var(
shape=[1], value=0.95, dtype='float32', persistable=True)
betas = [beta1, beta2]
opt = paddle.optimizer.Adam(
learning_rate=1e-5,
beta1=beta1,
beta2=beta2,
weight_decay=0.01,
epsilon=1e-8)
opt.minimize(loss)

exe.run(startup)
data_np = np.random.random(shape).astype('float32')
rets = exe.run(train_prog, feed={"data": data_np}, fetch_list=[loss])
assert rets[0] is not None

def test_adam_op_dygraph(self):
with fluid.dygraph.guard():
value = np.arange(26).reshape(2, 13).astype("float32")
a = fluid.dygraph.to_variable(value)
linear = fluid.Linear(13, 5, dtype="float32")

adam = paddle.optimizer.Adam(
learning_rate=0.01, parameters=linear.parameters())
out = linear(a)
out.backward()
adam.step()
adam.clear_gradients()


if __name__ == "__main__":
unittest.main()
12 changes: 7 additions & 5 deletions python/paddle/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,23 @@

__all__ = [
'Adadelta', 'AdadeltaOptimizer', 'Adagrad', 'AdagradOptimizer', 'Adam',
'Adamax', 'AdamaxOptimizer', 'AdamOptimizer', 'DecayedAdagrad',
'Adamax', 'DecayedAdagrad', 'AdamW'
'DecayedAdagradOptimizer', 'DGCMomentumOptimizer', 'Dpsgd',
'DpsgdOptimizer', 'ExponentialMovingAverage', 'Ftrl', 'FtrlOptimizer',
'LambOptimizer', 'LarsMomentum', 'LarsMomentumOptimizer',
'LookaheadOptimizer', 'ModelAverage', 'Momentum', 'MomentumOptimizer',
'PipelineOptimizer', 'RecomputeOptimizer', 'RMSPropOptimizer', 'SGD',
'SGDOptimizer'
'SGDOptimizer', 'Optimizer'
]


from ..fluid.optimizer import SGD, Momentum, Adagrad, Adam, Adamax, Dpsgd, DecayedAdagrad, \
Ftrl, SGDOptimizer, MomentumOptimizer, AdagradOptimizer, \
AdamOptimizer, AdamaxOptimizer, DpsgdOptimizer, \
from ..fluid.optimizer import SGD, Momentum, Adagrad, Dpsgd, DecayedAdagrad, \
Ftrl, SGDOptimizer, MomentumOptimizer, AdagradOptimizer, DpsgdOptimizer, \
DecayedAdagradOptimizer, RMSPropOptimizer, FtrlOptimizer, Adadelta, \
AdadeltaOptimizer, ModelAverage, LarsMomentum, \
LarsMomentumOptimizer, DGCMomentumOptimizer, LambOptimizer, \
ExponentialMovingAverage, PipelineOptimizer, LookaheadOptimizer, \
RecomputeOptimizer

from .optimizer import Optimizer
from .adam import Adam
286 changes: 286 additions & 0 deletions python/paddle/optimizer/adam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .optimizer import Optimizer
from ..fluid import core
from ..fluid import framework
from ..fluid.framework import Variable


class Adam(Optimizer):
"""
The Adam optimizer uses an optimization described at the end
of section 2 of `Adam paper <https://arxiv.org/abs/1412.6980>`_ ,
it can dynamically adjusts the learning rate of each parameter using
the 1st moment estimates and the 2nd moment estimates of the gradient.

The parameter ``param_out`` update rule with gradient ``grad``:

.. math::

t & = t + 1

moment\_1\_out & = {\\beta}_1 * moment\_1 + (1 - {\\beta}_1) * grad

moment\_2\_out & = {\\beta}_2 * moment\_2 + (1 - {\\beta}_2) * grad * grad

learning\_rate & = learning\_rate * \\
\\frac{\sqrt{1 - {\\beta}_2^t}}{1 - {\\beta}_1^t}

param\_out & = param - learning\_rate * \\frac{moment\_1}{\sqrt{moment\_2} + \epsilon}

Related paper: `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_

Args:
learning_rate (float|Tensor, optional): The learning rate used to update ``Parameter``.
It can be a float value or a ``Variable`` with a float type. The default value is 0.001.
beta1 (float|Variable, optional): The exponential decay rate for the 1st moment estimates.
It should be a float number or a Variable with shape [1] and data type as float32.
The default value is 0.9.
beta2 (float|Variable, optional): The exponential decay rate for the 2nd moment estimates.
It should be a float number or a Variable with shape [1] and data type as float32.
The default value is 0.999.
epsilon (float, optional): A small float value for numerical stability.
The default value is 1e-08.
parameters (Iterable, optional): Iterable of ``Tensor`` names to update to minimize ``loss``. \
This parameter is required in dygraph mode. \
The default value is None in static mode, at this time all parameters will be updated.
weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \
It canbe a float value as coeff of L2 regularization or \
:ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`.
If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already, \
the regularization setting here in optimizer will be ignored for this parameter. \
Otherwis, the regularization setting here in optimizer will take effect. \
Default None, meaning there is no regularization.
grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of
some derived class of ``GradientClipBase`` . There are three cliping strategies
( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
:ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
name (str, optional): Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name`.
The default value is None.
lazy_mode (bool, optional): The official Adam algorithm has two moving-average accumulators.
The accumulators are updated at every step. Every element of the two moving-average
is updated in both dense mode and sparse mode. If the size of parameter is very large,
then the update may be very slow. The lazy mode only update the element that has
gradient in current mini-batch, so it will be much more faster. But this mode has
different semantics with the original Adam algorithm and may lead to different result.
The default value is False.

Examples:
.. code-block:: python

import paddle
import paddle.fluid as fluid

place = fluid.CPUPlace()
main = fluid.Program()
with fluid.program_guard(main):
x = fluid.data(name='x', shape=[None, 13], dtype='float32')
y = fluid.data(name='y', shape=[None, 1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)

adam_optimizer = paddle.optimizer.Adam(0.01)
adam_optimizer.minimize(avg_cost)

fetch_list = [avg_cost]
train_reader = paddle.batch(
paddle.dataset.uci_housing.train(), batch_size=1)
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
for data in train_reader():
exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)

.. code-block:: python

# Adam with beta1/beta2 as Variable
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers.learning_rate_scheduler as lr_scheduler

place = fluid.CPUPlace()
main = fluid.Program()
with fluid.program_guard(main):
x = fluid.data(name='x', shape=[None, 13], dtype='float32')
y = fluid.data(name='y', shape=[None, 1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)

# define beta decay variable
def get_decayed_betas(beta1_init, beta2_init, decay_steps, decay_rate):
global_step = lr_scheduler._decay_step_counter()

beta1 = fluid.layers.create_global_var(
shape=[1],
value=float(beta1_init),
dtype='float32',
# set persistable for save checkpoints and resume
persistable=True,
name="beta1")
beta2 = fluid.layers.create_global_var(
shape=[1],
value=float(beta2_init),
dtype='float32',
# set persistable for save checkpoints and resume
persistable=True,
name="beta2")

div_res = global_step / decay_steps
decayed_beta1 = beta1_init * (decay_rate**div_res)
decayed_beta2 = beta2_init * (decay_rate**div_res)
fluid.layers.assign(decayed_beta1, beta1)
fluid.layers.assign(decayed_beta2, beta2)

return beta1, beta2

beta1, beta2 = get_decayed_betas(0.9, 0.99, 1e5, 0.9)
adam_optimizer = paddle.optimizer.Adam(
learning_rate=0.01,
beta1=beta1,
beta2=beta2)
adam_optimizer.minimize(avg_cost)

fetch_list = [avg_cost]
train_reader = paddle.batch(
paddle.dataset.uci_housing.train(), batch_size=1)
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
for data in train_reader():
exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)
"""
_moment1_acc_str = "moment1"
_moment2_acc_str = "moment2"
_beta1_pow_acc_str = "beta1_pow_acc"
_beta2_pow_acc_str = "beta2_pow_acc"

def __init__(self,
learning_rate=0.001,
beta1=0.9,
beta2=0.999,
epsilon=1e-8,
parameters=None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parameters 的位置能上前移动么,毕竟动态图强依赖这个参数

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为了与其他优化器保持一致,暂时先不移动这个参数

weight_decay=None,
grad_clip=None,
name=None,
lazy_mode=False):
assert learning_rate is not None
assert beta1 is not None
assert beta2 is not None
assert epsilon is not None
super(Adam, self).__init__(
learning_rate=learning_rate,
parameters=parameters,
weight_decay=weight_decay,
grad_clip=grad_clip,
name=name)
self.type = "adam"
self._beta1 = beta1
self._beta2 = beta2
self._epsilon = epsilon
self._lazy_mode = lazy_mode

def _create_accumulators(self, block, parameters):
assert isinstance(block, framework.Block)

# Create accumulator tensors for first and second moments
for p in parameters:
self._add_accumulator(self._moment1_acc_str, p)
self._add_accumulator(self._moment2_acc_str, p)
self._add_accumulator(
name=self._beta1_pow_acc_str,
param=p,
fill_value=0.9 if isinstance(self._beta1, Variable) \
else self._beta1,
shape=[1],
type=core.VarDesc.VarType.LOD_TENSOR, device='cpu')
self._add_accumulator(
name=self._beta2_pow_acc_str,
param=p,
fill_value=0.999 if isinstance(self._beta2, Variable) \
else self._beta2,
shape=[1],
type=core.VarDesc.VarType.LOD_TENSOR, device='cpu')

def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block)

moment1 = self._get_accumulator(self._moment1_acc_str,
param_and_grad[0])
moment2 = self._get_accumulator(self._moment2_acc_str,
param_and_grad[0])
beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
param_and_grad[0])
beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
param_and_grad[0])
lr = self._create_param_lr(param_and_grad)
# create the adam optimize op

if framework.in_dygraph_mode():
_beta1 = self._beta1 if not isinstance(
self._beta1, Variable) else self._beta1.numpy().item(0)
_beta2 = self._beta2 if not isinstance(
self._beta2, Variable) else self._beta2.numpy().item(0)
_, _, _, _, _ = core.ops.adam(
param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
beta1_pow_acc, beta2_pow_acc, param_and_grad[0], moment1,
moment2, beta1_pow_acc, beta2_pow_acc, 'epsilon', self._epsilon,
'lazy_mode', self._lazy_mode, 'min_row_size_to_use_multithread',
1000, 'beta1', _beta1, 'beta2', _beta2)

return None

inputs = {
"Param": [param_and_grad[0]],
"Grad": [param_and_grad[1]],
"LearningRate": [lr],
"Moment1": [moment1],
"Moment2": [moment2],
"Beta1Pow": [beta1_pow_acc],
"Beta2Pow": [beta2_pow_acc]
}
outputs = {
"ParamOut": [param_and_grad[0]],
"Moment1Out": [moment1],
"Moment2Out": [moment2],
"Beta1PowOut": [beta1_pow_acc],
"Beta2PowOut": [beta2_pow_acc],
}
attrs = {
"epsilon": self._epsilon,
"lazy_mode": self._lazy_mode,
"min_row_size_to_use_multithread": 1000
}

if isinstance(self._beta1, Variable):
inputs['Beta1Tensor'] = self._beta1
else:
attrs['beta1'] = self._beta1
if isinstance(self._beta2, Variable):
inputs['Beta2Tensor'] = self._beta2
else:
attrs['beta2'] = self._beta2

adam_op = block.append_op(
type=self.type,
inputs=inputs,
outputs=outputs,
attrs=attrs,
stop_gradient=True)

return adam_op
Loading