refine adam/strided_slice && fix doc for rmsprop/unstack (#27740)
* refine parameters order && doc

* update rmsprop doc

* refine adam/transpose/unstack/strided_slice

* fix bug && doc

* fix doc

* bug fix

* bug fix

* fix doc

* fix doc

* fix doc

* fix doc

* deprecate old strided_slice

* update doc

* set default value for name

* update doc
MRXLT authored Oct 12, 2020
1 parent e96fc6a commit 84d8e49
Showing 7 changed files with 147 additions and 67 deletions.
10 changes: 6 additions & 4 deletions python/paddle/fluid/layers/nn.py
@@ -10241,9 +10241,9 @@ def unstack(x, axis=0, num=None):
Examples:
.. code-block:: python

import paddle.fluid as fluid
x = fluid.data(name='x', shape=[2, 3, 5], dtype='float32') # create a tensor with shape=[2, 3, 5]
y = fluid.layers.unstack(x, axis=1) # unstack along the second axis, which results in 3 tensors with shape=[2, 5]
import paddle
x = paddle.ones(name='x', shape=[2, 3, 5], dtype='float32') # create a tensor with shape=[2, 3, 5]
y = paddle.unstack(x, axis=1) # unstack along the second axis, which results in 3 tensors with shape=[2, 5]

"""
helper = LayerHelper('unstack', **locals())
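For reference, a minimal runnable sketch of the updated dygraph-style example above; it is not part of the diff and assumes Paddle 2.0 defaults (dynamic graph mode enabled):

import paddle

x = paddle.ones([2, 3, 5], dtype="float32")
ys = paddle.unstack(x, axis=1)   # a Python list of 3 tensors
print(len(ys), ys[0].shape)      # 3 [2, 5]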
@@ -11017,7 +11017,7 @@ def slice(input, axes, starts, ends):
return out


@templatedoc()
@deprecated(since='2.0.0', update_to="paddle.strided_slice")
def strided_slice(input, axes, starts, ends, strides):
"""
:alias_main: paddle.strided_slice
@@ -11095,7 +11095,9 @@ def strided_slice(input, axes, starts, ends, strides):
.. code-block:: python

import paddle.fluid as fluid
import paddle

paddle.enable_static()
input = fluid.data(
name="input", shape=[3, 4, 5, 6], dtype='float32')

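The deprecation above points users at paddle.strided_slice, which this commit adds in python/paddle/tensor/manipulation.py (shown later in this diff). A hedged migration sketch, assuming Paddle 2.0 dynamic-graph defaults; shapes and values are illustrative:

import paddle

x = paddle.reshape(paddle.arange(24, dtype="float32"), [2, 3, 4])
# Keep positions 0..1 of axis 1 and every second position of axis 2.
y = paddle.strided_slice(x, axes=[1, 2], starts=[0, 0], ends=[2, 4], strides=[1, 2])
print(y.shape)  # [2, 2, 2]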
13 changes: 13 additions & 0 deletions python/paddle/fluid/tests/unittests/test_strided_slice_op.py
@@ -16,6 +16,9 @@
import numpy as np
import unittest
import paddle.fluid as fluid
import paddle

paddle.enable_static()


def strided_slice_native_forward(input, axes, starts, ends, strides):
@@ -498,6 +501,16 @@ def test_1(self):
assert np.array_equal(res_6, input[-3:3, 0:100:2, :, -1:2:-1])
assert np.array_equal(res_7, input[-1, 0:100:2, :, -1:2:-1])

def test_dygraph_op(self):
x = paddle.zeros(shape=[3, 4, 5, 6], dtype="float32")
axes = [1, 2, 3]
starts = [-3, 0, 2]
ends = [3, 2, 4]
strides_1 = [1, 1, 1]
sliced_1 = paddle.strided_slice(
x, axes=axes, starts=starts, ends=ends, strides=strides_1)
assert sliced_1.shape == (3, 2, 2, 2)


if __name__ == "__main__":
unittest.main()
15 changes: 6 additions & 9 deletions python/paddle/incubate/complex/tensor/manipulation.py
@@ -128,16 +128,13 @@ def transpose(x, perm, name=None):
.. code-block:: python
import paddle
import numpy as np
import paddle.fluid.dygraph as dg
with dg.guard():
a = np.array([[1.0 + 1.0j, 2.0 + 1.0j], [3.0+1.0j, 4.0+1.0j]])
x = dg.to_variable(a)
y = paddle.complex.transpose(x, [1, 0])
print(y.numpy())
# [[1.+1.j 3.+1.j]
# [2.+1.j 4.+1.j]]
x = paddle.to_tensor([[1.0 + 1.0j, 2.0 + 1.0j], [3.0+1.0j, 4.0+1.0j], [5.0+1.0j, 6.0+1.0j]])
x_transposed = paddle.complex.transpose(x, [1, 0])
print(x_transposed.numpy())
#[[1.+1.j 3.+1.j 5.+1.j]
# [2.+1.j 4.+1.j 6.+1.j]]
"""
complex_variable_exists([x], "transpose")
real = layers.transpose(x.real, perm, name)
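A small numpy cross-check of the output claimed in the rewritten example above (illustrative only, not part of the diff): transposing with perm=[1, 0] is plain matrix transposition, so the complex entries land as shown.

import numpy as np

a = np.array([[1.0 + 1.0j, 2.0 + 1.0j],
              [3.0 + 1.0j, 4.0 + 1.0j],
              [5.0 + 1.0j, 6.0 + 1.0j]])
print(a.T)
# [[1.+1.j 3.+1.j 5.+1.j]
#  [2.+1.j 4.+1.j 6.+1.j]]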
38 changes: 15 additions & 23 deletions python/paddle/optimizer/adam.py
@@ -29,7 +29,7 @@ class Adam(Optimizer):
of section 2 of `Adam paper <https://arxiv.org/abs/1412.6980>`_ ,
it can dynamically adjust the learning rate of each parameter using
the 1st moment estimates and the 2nd moment estimates of the gradient.
The update rule for parameter ``param_out`` with gradient ``grad`` is:
.. math::
@@ -68,31 +68,28 @@ class Adam(Optimizer):
the regularization setting here in optimizer will be ignored for this parameter. \
Otherwise, the regularization setting here in optimizer will take effect. \
Default None, meaning there is no regularization.
grad_clip (GradientClipBase, optional): Gradient clipping strategy; it is an instance of
some derived class of ``GradientClipBase`` . There are three clipping strategies
( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
grad_clip (GradientClipBase, optional): Gradient clipping strategy; it is an instance of
some derived class of ``GradientClipBase`` . There are three clipping strategies
( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
:ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
name (str, optional): Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name`.
The default value is None.
lazy_mode (bool, optional): The official Adam algorithm has two moving-average accumulators.
The accumulators are updated at every step. Every element of the two moving-averages
is updated in both dense mode and sparse mode. If the size of a parameter is very large,
the update may be very slow. The lazy mode only updates the elements that have
gradients in the current mini-batch, so it will be much faster. However, this mode has
different semantics from the original Adam algorithm and may lead to different results.
The default value is False.
name (str, optional): Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name`.
The default value is None.
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
inp = paddle.to_tensor(inp)
inp = paddle.rand([10,10], dtype="float32")
out = linear(inp)
loss = paddle.mean(out)
adam = paddle.optimizer.Adam(learning_rate=0.1,
@@ -105,12 +102,9 @@ class Adam(Optimizer):
# Adam with beta1/beta2 as Tensor and weight_decay as float
import paddle
import numpy as np
paddle.disable_static()
inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
inp = paddle.to_tensor(inp)
inp = paddle.rand([10,10], dtype="float32")
out = linear(inp)
loss = paddle.mean(out)
@@ -140,8 +134,8 @@ def __init__(self,
parameters=None,
weight_decay=None,
grad_clip=None,
name=None,
lazy_mode=False):
lazy_mode=False,
name=None):
assert learning_rate is not None
assert beta1 is not None
assert beta2 is not None
@@ -258,21 +252,19 @@ def _append_optimize_op(self, block, param_and_grad):
def step(self):
"""
Execute the optimizer and update parameters once.
Returns:
None
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
value = np.arange(26).reshape(2, 13).astype("float32")
a = paddle.to_tensor(value)
a = paddle.rand([2,13], dtype="float32")
linear = paddle.nn.Linear(13, 5)
# This can be any optimizer supported by dygraph.
adam = paddle.optimizer.Adam(learning_rate = 0.01,
adam = paddle.optimizer.Adam(learning_rate = 0.01,
parameters = linear.parameters())
out = linear(a)
out.backward()
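Since this commit reorders the keyword arguments (lazy_mode now comes before name), a minimal dygraph sketch is shown below; it is not part of the diff and assumes Paddle 2.0 defaults. Passing both arguments as keywords keeps user code independent of the positional order:

import paddle

linear = paddle.nn.Linear(10, 10)
inp = paddle.rand([10, 10], dtype="float32")
adam = paddle.optimizer.Adam(learning_rate=0.1,
                             parameters=linear.parameters(),
                             lazy_mode=False,
                             name="adam_example")  # "adam_example" is an arbitrary illustrative name
loss = paddle.mean(linear(inp))
loss.backward()
adam.step()
adam.clear_grad()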
27 changes: 12 additions & 15 deletions python/paddle/optimizer/adamw.py
@@ -23,7 +23,7 @@

class AdamW(Adam):
"""
The AdamW optimizer is implemented based on the AdamW Optimization
The AdamW optimizer is implemented based on the AdamW Optimization
in paper `DECOUPLED WEIGHT DECAY REGULARIZATION <https://arxiv.org/pdf/1711.05101.pdf>`_.
It resolves the problem of L2 regularization failure in the Adam optimizer.
@@ -32,7 +32,7 @@ class AdamW(Adam):
t & = t + 1
moment\_1\_out & = {\\beta}_1 * moment\_1 + (1 - {\\beta}_1) * grad
moment\_2\_out & = {\\beta}_2 * moment\_2 + (1 - {\\beta}_2) * grad * grad
learning\_rate & = learning\_rate * \\
@@ -57,35 +57,32 @@ class AdamW(Adam):
The default value is 1e-08.
weight_decay (float|Tensor, optional): The weight decay coefficient, it can be float or Tensor. The default value is 0.01.
apply_decay_param_fun (function|None, optional): If it is not None,
only the tensors for which apply_decay_param_fun(Tensor)==True
only the tensors for which apply_decay_param_fun(Tensor)==True
will have weight decay applied. It only works when we want to specify tensors.
Default: None.
grad_clip (GradientClipBase, optional): Gradient clipping strategy; it is an instance of
some derived class of ``GradientClipBase`` . There are three clipping strategies
( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
grad_clip (GradientClipBase, optional): Gradient clipping strategy; it is an instance of
some derived class of ``GradientClipBase`` . There are three clipping strategies
( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
:ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
name (str, optional): Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name`.
The default value is None.
lazy_mode (bool, optional): The official Adam algorithm has two moving-average accumulators.
The accumulators are updated at every step. Every element of the two moving-averages
is updated in both dense mode and sparse mode. If the size of a parameter is very large,
the update may be very slow. The lazy mode only updates the elements that have
gradients in the current mini-batch, so it will be much faster. However, this mode has
different semantics from the original Adam algorithm and may lead to different results.
The default value is False.
name (str, optional): Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name`.
The default value is None.
**Notes**:
**Currently, AdamW doesn't support sparse parameter optimization.**
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
inp = paddle.to_tensor(inp)
inp = paddle.rand([10,10], dtype="float32")
out = linear(inp)
loss = paddle.mean(out)
@@ -112,8 +109,8 @@ def __init__(self,
weight_decay=0.01,
apply_decay_param_fun=None,
grad_clip=None,
name=None,
lazy_mode=False):
lazy_mode=False,
name=None):
assert learning_rate is not None
assert beta1 is not None
assert beta2 is not None
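A hedged sketch of apply_decay_param_fun, which the docstring above describes: the predicate receives a parameter name and decides whether that parameter gets decoupled weight decay. The "bias"-substring filter and the parameter naming it relies on are illustrative assumptions, not part of the diff:

import paddle

linear = paddle.nn.Linear(10, 10)
# Apply decoupled weight decay only to parameters whose names pass the predicate;
# the "bias" substring check is illustrative, default Paddle parameter names may differ.
adamw = paddle.optimizer.AdamW(learning_rate=0.01,
                               parameters=linear.parameters(),
                               weight_decay=0.01,
                               apply_decay_param_fun=lambda name: "bias" not in name)
loss = paddle.mean(linear(paddle.rand([4, 10], dtype="float32")))
loss.backward()
adamw.step()
adamw.clear_grad()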
24 changes: 9 additions & 15 deletions python/paddle/optimizer/rmsprop.py
@@ -90,9 +90,9 @@ class RMSProp(Optimizer):
the regularization setting here in optimizer will be ignored for this parameter. \
Otherwise, the regularization setting here in optimizer will take effect. \
Default None, meaning there is no regularization.
grad_clip (GradientClipBase, optional): Gradient clipping strategy; it is an instance of
some derived class of ``GradientClipBase`` . There are three clipping strategies
( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
grad_clip (GradientClipBase, optional): Gradient clipping strategy; it is an instance of
some derived class of ``GradientClipBase`` . There are three clipping strategies
( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
:ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
name (str, optional): This parameter is used by developers to print debugging information. \
For details, please refer to :ref:`api_guide_Name`. Default is None.
@@ -104,24 +104,18 @@
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
inp = paddle.rand([10,10], dtype="float32")
linear = paddle.nn.Linear(10, 10)
inp = paddle.to_tensor(inp)
out = linear(inp)
loss = paddle.mean(out)
beta1 = paddle.to_tensor([0.9], dtype="float32")
beta2 = paddle.to_tensor([0.99], dtype="float32")
adam = paddle.optimizer.RMSProp(learning_rate=0.1,
parameters=linear.parameters(),
weight_decay=0.01)
rmsprop = paddle.optimizer.RMSProp(learning_rate=0.1,
parameters=linear.parameters(),
weight_decay=0.01)
out.backward()
adam.step()
adam.clear_grad()
rmsprop.step()
rmsprop.clear_grad()
"""

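A hedged sketch combining RMSProp with one of the clipping strategies listed in the docstring; paddle.nn.ClipGradByGlobalNorm is assumed here to be the Paddle 2.0 alias of the fluid clip class referenced above, and the clip_norm value is illustrative:

import paddle

linear = paddle.nn.Linear(10, 10)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
rmsprop = paddle.optimizer.RMSProp(learning_rate=0.1,
                                   parameters=linear.parameters(),
                                   grad_clip=clip)
loss = paddle.mean(linear(paddle.rand([4, 10], dtype="float32")))
loss.backward()
rmsprop.step()
rmsprop.clear_grad()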
87 changes: 86 additions & 1 deletion python/paddle/tensor/manipulation.py
@@ -25,7 +25,6 @@
# TODO: define functions to manipulate a tensor
from ..fluid.layers import cast #DEFINE_ALIAS
from ..fluid.layers import slice #DEFINE_ALIAS
from ..fluid.layers import strided_slice #DEFINE_ALIAS
from ..fluid.layers import transpose #DEFINE_ALIAS
from ..fluid.layers import unstack #DEFINE_ALIAS

@@ -1461,3 +1460,89 @@ def gather_nd(x, index, name=None):
"""

return paddle.fluid.layers.gather_nd(input=x, index=index, name=name)


def strided_slice(x, axes, starts, ends, strides, name=None):
"""
This operator produces a slice of ``x`` along multiple axes. Similar to numpy:
https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
The operator uses the ``axes``, ``starts`` and ``ends`` attributes to specify the start and
end positions for each axis listed in ``axes``, and slices the input tensor accordingly.
If a negative value such as :math:`-i` is passed to ``starts`` or ``ends``, it represents
the :math:`i`-th position counted from the end of that axis (here 0 is the initial position).
``strides`` gives the step of slicing; if a stride is negative, the slice runs in the
opposite direction. If the value passed to ``starts`` or ``ends`` is greater than n
(the number of elements in this dimension), it is clamped to n. For slicing to the end of a
dimension with unknown size, it is recommended to pass in INT_MAX. The sizes of ``axes``,
``starts``, ``ends`` and ``strides`` must be equal.
Following examples will explain how strided_slice works:
.. code-block:: text
Case1:
Given:
data = [ [1, 2, 3, 4], [5, 6, 7, 8], ]
axes = [0, 1]
starts = [1, 0]
ends = [2, 3]
strides = [1, 1]
Then:
result = [ [5, 6, 7], ]
Case2:
Given:
data = [ [1, 2, 3, 4], [5, 6, 7, 8], ]
axes = [0, 1]
starts = [0, 1]
ends = [2, 0]
strides = [1, -1]
Then:
result = [ [8, 7, 6], ]
Case3:
Given:
data = [ [1, 2, 3, 4], [5, 6, 7, 8], ]
axes = [0, 1]
starts = [0, 1]
ends = [-1, 1000]
strides = [1, 3]
Then:
result = [ [2], ]
Args:
x (Tensor): An N-D ``Tensor``. The data type is ``float32``, ``float64``, ``int32`` or ``int64``.
axes (list|tuple): The data type is ``int32`` . Axes that ``starts`` and ``ends`` apply to.
It is optional; if it is not provided, it will be treated as :math:`[0,1,...,len(starts)-1]`.
starts (list|tuple|Tensor): The data type is ``int32`` . If ``starts`` is a list or tuple, its elements should be integers or Tensors with shape [1]. If ``starts`` is a Tensor, it should be a 1-D Tensor. It represents the starting indices of the corresponding axes in ``axes``.
ends (list|tuple|Tensor): The data type is ``int32`` . If ``ends`` is a list or tuple, its elements
should be integers or Tensors with shape [1]. If ``ends`` is a Tensor, it should be a 1-D Tensor. It represents the ending indices of the corresponding axes in ``axes``.
strides (list|tuple|Tensor): The data type is ``int32`` . If ``strides`` is a list or tuple, its elements
should be integers or Tensors with shape [1]. If ``strides`` is a Tensor, it should be a 1-D Tensor. It represents the slice step of the corresponding axes in ``axes``.
name(str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns:
Tensor: A ``Tensor`` with the same dimension as ``x``. The data type is the same as that of ``x``.
Examples:
.. code-block:: python
import paddle
x = paddle.zeros(shape=[3,4,5,6], dtype="float32")
# example 1:
# attr starts is a list which does not contain any Tensor.
axes = [1, 2, 3]
starts = [-3, 0, 2]
ends = [3, 2, 4]
strides_1 = [1, 1, 1]
strides_2 = [1, 1, 2]
sliced_1 = paddle.strided_slice(x, axes=axes, starts=starts, ends=ends, strides=strides_1)
# sliced_1 is x[:, 1:3:1, 0:2:1, 2:4:1].
# example 2:
# attr starts is a list which contains a Tensor.
minus_3 = paddle.fill_constant([1], "int32", -3)
sliced_2 = paddle.strided_slice(x, axes=axes, starts=[minus_3, 0, 2], ends=ends, strides=strides_2)
# sliced_2 is x[:, 1:3:1, 0:2:1, 2:4:2].
"""

return paddle.fluid.layers.strided_slice(
input=x, axes=axes, starts=starts, ends=ends, strides=strides)
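A numpy cross-check of Case 1 and Case 3 from the new docstring (illustrative only, not part of the diff; for these two cases paddle.strided_slice matches numpy basic slicing):

import numpy as np

data = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
print(data[1:2:1, 0:3:1])      # Case 1 -> [[5 6 7]]
print(data[0:-1:1, 1:1000:3])  # Case 3 -> [[2]]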
