From 96fdcfd77ddea657a72c09f61ae0c2b4ec953e89 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Tue, 18 Aug 2020 09:06:39 +0000 Subject: [PATCH 1/6] api2.0 paddle.nn.Bilinear and paddle.nn.functional.bilinear, test=develop --- paddle/fluid/pybind/op_function_generator.cc | 1 + .../tests/unittests/test_bilinear_api.py | 62 +++++++++++++ python/paddle/nn/__init__.py | 1 + python/paddle/nn/functional/__init__.py | 1 + python/paddle/nn/functional/common.py | 70 ++++++++++++++- python/paddle/nn/layer/common.py | 90 ++++++++++++++++++- 6 files changed, 223 insertions(+), 2 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_bilinear_api.py diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index 93ba9feedf95b..671de8818ebb6 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -41,6 +41,7 @@ std::map> op_ins_map = { {"fake_quantize_dequantize_moving_average_abs_max", {"X", "InScale", "InAccum", "InState"}}, {"nll_loss", {"X", "Label", "Weight"}}, + {"bilinear_tensor_product", {"X", "Y", "Weight", "Bias"}}, }; // NOTE(zhiqiu): Like op_ins_map. diff --git a/python/paddle/fluid/tests/unittests/test_bilinear_api.py b/python/paddle/fluid/tests/unittests/test_bilinear_api.py new file mode 100644 index 0000000000000..0d7a165e3a3d1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_bilinear_api.py @@ -0,0 +1,62 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle.fluid as fluid +from op_test import OpTest +import paddle + + +class TestBilinearAPI(unittest.TestCase): + def test_api(self): + with fluid.program_guard(fluid.default_startup_program(), + fluid.default_main_program()): + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + + data1 = fluid.data(name='X1', shape=[5, 5], dtype='float32') + data2 = fluid.data(name='X2', shape=[5, 4], dtype='float32') + + layer1 = np.random.random((5, 5)).astype('float32') + layer2 = np.random.random((5, 4)).astype('float32') + + bilinear = paddle.nn.Bilinear( + in1_features=5, in2_features=4, out_features=1000) + ret = bilinear(data1, data2) + + exe.run(fluid.default_startup_program()) + ret_fetch = exe.run(feed={'X1': layer1, + 'X2': layer2}, + fetch_list=[ret.name]) + self.assertEqual(ret_fetch[0].shape, (5, 1000)) + + +class TestBilinearAPIDygraph(unittest.TestCase): + def test_api(self): + with fluid.dygraph.guard(): + layer1 = np.random.random((5, 5)).astype('float32') + layer2 = np.random.random((5, 4)).astype('float32') + bilinear = paddle.nn.Bilinear( + in1_features=5, in2_features=4, out_features=1000) + ret = bilinear( + fluid.dygraph.base.to_variable(layer1), + fluid.dygraph.base.to_variable(layer2)) + self.assertEqual(ret.shape, [5, 1000]) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index a52d45521fd1b..2bd3d6515dd40 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -66,6 +66,7 @@ from .layer.common import Linear #DEFINE_ALIAS from .layer.common import Flatten #DEFINE_ALIAS from .layer.common import UpSample #DEFINE_ALIAS +from .layer.common import Bilinear #DEFINE_ALIAS from .layer.conv import Conv2D #DEFINE_ALIAS from .layer.conv import Conv2DTranspose #DEFINE_ALIAS from .layer.conv import Conv3D #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index bc71b8bdf06d2..06f1a36859ad6 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -62,6 +62,7 @@ # from .common import bilinear_tensor_product #DEFINE_ALIAS from .common import assign #DEFINE_ALIAS from .common import interpolate #DEFINE_ALIAS +from .common import bilinear #DEFINE_ALIAS from .conv import conv2d #DEFINE_ALIAS from .conv import conv2d_transpose #DEFINE_ALIAS from .conv import conv3d #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index fe41cb6e64c34..2cc097039e9a1 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -27,6 +27,9 @@ #from ...fluid.layers import fc #DEFINE_ALIAS from ...fluid.layers import pad_constant_like #DEFINE_ALIAS +from ...fluid.framework import in_dygraph_mode +from ...fluid import core, dygraph_utils +from ...fluid.data_feeder import check_variable_and_dtype __all__ = [ 'dropout', @@ -40,7 +43,8 @@ 'unfold', # 'bilinear_tensor_product', 'assign', - 'interpolate' + 'interpolate', + 'bilinear' ] @@ -446,3 +450,67 @@ def _is_list_or_turple_(data): outputs={"Out": out}, attrs=attrs) return out + + +def bilinear(x1, x2, weight, bias=None, name=None): + """ + + This layer performs bilinear on two inputs. + + .. math:: + out_{i} = x1 * W_{i} * {x2^\mathrm{T}}, i=0,1,...,size-1 + out = out + b + + In this formula: + - :math:`x1`: the first input contains in1_features elements, shape is [batch_size, in1_features]. + - :math:`x2`: the second input contains in2_features elements, shape is [batch_size, in2_features]. + - :math:`W_{i}`: the i-th learned weight, shape is [in1_features, in2_features], and learned weight's shape is [out_features, in1_features, in2_features]. + - :math:`out_{i}`: the i-th element of out, shape is [batch_size, out_features]. + - :math:`b`: the learned bias, shape is [1, out_features]. + - :math:`x2^\mathrm{T}`: the transpose of :math:`x2`. + + Parameters: + x1 (Tensor): the first input tensor, it's data type should be float32, float64. + x2 (Tensor): the second input tensor, it's data type should be float32, float64. + weight (Parameter): The learnable weights of this layer, shape is [out_features, in1_features, in2_features]. + bias (Parameter, optional): The learnable bias(Bias) of this layer, shape is [1, out_features]. If it is set to None, no bias will be added to the output units. The default value is None. + name (str, optional): The default value is None. Normally there is no need for user + to set this property. For more information, please refer to :ref:`api_guide_Name`. Default: None. + + Returns: + Variable: A 2-D Tensor of shape [batch_size, out_features]. + + Examples: + .. code-block:: python + + import paddle + import numpy + import paddle.nn.functional as F + + with paddle.fluid.dygraph.guard(): + x1 = numpy.random.random((5, 5)).astype('float32') + x2 = numpy.random.random((5, 4)).astype('float32') + w = numpy.random.random((1000, 5, 4)).astype('float32') + b = numpy.random.random((1, 1000)).astype('float32') + + result = F.bilinear(paddle.to_tensor(x1), paddle.to_tensor(x2), paddle.to_tensor(w), paddle.to_tensor(b)) # result shape [5, 1000] + + """ + + if in_dygraph_mode(): + return core.ops.bilinear_tensor_product(x1, x2, weight, bias) + + check_variable_and_dtype(x1, 'x1', ['float32', 'float64'], 'bilinear') + check_variable_and_dtype(x2, 'x2', ['float32', 'float64'], 'bilinear') + + inputs = {"X": x1, "Y": x2, "Weight": weight} + if bias is not None: + inputs["Bias"] = bias + + helper = LayerHelper("bilinear", **locals()) + out = helper.create_variable_for_type_inference(dtype=x1.dtype) + + helper.append_op( + type="bilinear_tensor_product", inputs=inputs, outputs={"Out": out}) + + return out diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index 45259bea49d42..27a218199017f 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -23,7 +23,7 @@ __all__ = [ 'BilinearTensorProduct', 'Pool2D', 'Embedding', 'Linear', 'UpSample', - 'Pad2D' + 'Pad2D', 'Bilinear' ] @@ -342,3 +342,91 @@ def forward(self, input): mode=self._mode, pad_value=self._pad_value, data_format=self._data_format) + + +class Bilinear(layers.Layer): + """ + + This layer performs bilinear on two inputs. + + .. math:: + out_{i} = x1 * W_{i} * {x2^\mathrm{T}}, i=0,1,...,size-1 + out = out + b + + In this formula: + - :math:`x1`: the first input contains in1_features elements, shape is [batch_size, in1_features]. + - :math:`x2`: the second input contains in2_features elements, shape is [batch_size, in2_features]. + - :math:`W_{i}`: the i-th learned weight, shape is [in1_features, in2_features], and learned weight's shape is [out_features, in1_features, in2_features]. + - :math:`out_{i}`: the i-th element of out, shape is [batch_size, out_features]. + - :math:`b`: the learned bias, shape is [1, out_features]. + - :math:`x2^\mathrm{T}`: the transpose of :math:`x2`. + + Parameters: + in1_features (int): The dimension of each first input(`x1`). + in2_features (int): The dimension of each second input(`x2`). + out_features (int): The dimension of output of this layer. + weight_attr (ParamAttr, optional): The parameter attribute for the learnable w, parameters/weights of + this layer. The default value is None. + bias_attr (ParamAttr, optional): The parameter attribute for the bias + of this layer. If it is set to False, no bias will be added to the output units. + If it is set to None, the bias is initialized zero. The default value is None. + name (str, optional): The default value is None. Normally there is no need for user + to set this property. For more information, please refer to :ref:`api_guide_Name`. Default: None. + + Attribute: + **weight** (Parameter): the learnable weights of this layer. + + **bias** (Parameter): the learnable bias of this layer. + + Returns: + Variable: A 2-D Tensor of shape [batch_size, out_features]. + + Examples: + .. code-block:: python + + import paddle + import numpy + + with paddle.fluid.dygraph.guard(): + layer1 = numpy.random.random((5, 5)).astype('float32') + layer2 = numpy.random.random((5, 4)).astype('float32') + bilinear = paddle.nn.Bilinear( + in1_features=5, in2_features=4, out_features=1000) + result = bilinear(paddle.to_tensor(layer1), + paddle.to_tensor(layer2)) # result shape [5, 1000] + + """ + + def __init__(self, + in1_features, + in2_features, + out_features, + weight_attr=None, + bias_attr=None, + name=None): + super(Bilinear, self).__init__() + self._weight_attr = weight_attr + self._bias_attr = bias_attr + self._name = name + self._in1_features = in1_features + self._in2_features = in2_features + self._out_features = out_features + self._dtype = self._helper.get_default_dtype() + + weight_shape = [ + self._out_features, self._in1_features, self._in2_features + ] + self.weight = self.create_parameter( + attr=self._weight_attr, + shape=weight_shape, + dtype=self._dtype, + is_bias=False) + bias_shape = [1, self._out_features] + self.bias = self.create_parameter( + attr=self._bias_attr, + shape=bias_shape, + dtype=self._dtype, + is_bias=True) + + def forward(self, x1, x2): + return F.bilinear(x1, x2, self.weight, self.bias, self._name) From 942ca0ca954b06d79a2ad300a5aeb8d005e33c25 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Tue, 18 Aug 2020 09:39:21 +0000 Subject: [PATCH 2/6] api2.0 fix code examples, test=develop --- python/paddle/nn/functional/common.py | 12 ++++++------ python/paddle/nn/layer/common.py | 14 +++++++------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 2cc097039e9a1..4c6912af38251 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -487,13 +487,13 @@ def bilinear(x1, x2, weight, bias=None, name=None): import numpy import paddle.nn.functional as F - with paddle.fluid.dygraph.guard(): - x1 = numpy.random.random((5, 5)).astype('float32') - x2 = numpy.random.random((5, 4)).astype('float32') - w = numpy.random.random((1000, 5, 4)).astype('float32') - b = numpy.random.random((1, 1000)).astype('float32') + paddle.disable_static() + x1 = numpy.random.random((5, 5)).astype('float32') + x2 = numpy.random.random((5, 4)).astype('float32') + w = numpy.random.random((1000, 5, 4)).astype('float32') + b = numpy.random.random((1, 1000)).astype('float32') - result = F.bilinear(paddle.to_tensor(x1), paddle.to_tensor(x2), paddle.to_tensor(w), paddle.to_tensor(b)) # result shape [5, 1000] + result = F.bilinear(paddle.to_tensor(x1), paddle.to_tensor(x2), paddle.to_tensor(w), paddle.to_tensor(b)) # result shape [5, 1000] """ diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index 27a218199017f..80eb8cfe7317d 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -387,13 +387,13 @@ class Bilinear(layers.Layer): import paddle import numpy - with paddle.fluid.dygraph.guard(): - layer1 = numpy.random.random((5, 5)).astype('float32') - layer2 = numpy.random.random((5, 4)).astype('float32') - bilinear = paddle.nn.Bilinear( - in1_features=5, in2_features=4, out_features=1000) - result = bilinear(paddle.to_tensor(layer1), - paddle.to_tensor(layer2)) # result shape [5, 1000] + paddle.disable_static() + layer1 = numpy.random.random((5, 5)).astype('float32') + layer2 = numpy.random.random((5, 4)).astype('float32') + bilinear = paddle.nn.Bilinear( + in1_features=5, in2_features=4, out_features=1000) + result = bilinear(paddle.to_tensor(layer1), + paddle.to_tensor(layer2)) # result shape [5, 1000] """ From 8a85a53f36d729e6c1b082633180dbc46b9c433f Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Wed, 19 Aug 2020 02:35:23 +0000 Subject: [PATCH 3/6] modify test_bilinear_api, about place,to_tensor , test=develop --- .../tests/unittests/test_bilinear_api.py | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_bilinear_api.py b/python/paddle/fluid/tests/unittests/test_bilinear_api.py index 0d7a165e3a3d1..24eae4797de85 100644 --- a/python/paddle/fluid/tests/unittests/test_bilinear_api.py +++ b/python/paddle/fluid/tests/unittests/test_bilinear_api.py @@ -15,17 +15,22 @@ from __future__ import print_function import unittest -import numpy as np -import paddle.fluid as fluid from op_test import OpTest + import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +import numpy as np class TestBilinearAPI(unittest.TestCase): def test_api(self): with fluid.program_guard(fluid.default_startup_program(), fluid.default_main_program()): - place = fluid.CUDAPlace(0) + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + else: + place = core.CPUPlace() exe = fluid.Executor(place) data1 = fluid.data(name='X1', shape=[5, 5], dtype='float32') @@ -47,15 +52,13 @@ def test_api(self): class TestBilinearAPIDygraph(unittest.TestCase): def test_api(self): - with fluid.dygraph.guard(): - layer1 = np.random.random((5, 5)).astype('float32') - layer2 = np.random.random((5, 4)).astype('float32') - bilinear = paddle.nn.Bilinear( - in1_features=5, in2_features=4, out_features=1000) - ret = bilinear( - fluid.dygraph.base.to_variable(layer1), - fluid.dygraph.base.to_variable(layer2)) - self.assertEqual(ret.shape, [5, 1000]) + paddle.disable_static() + layer1 = np.random.random((5, 5)).astype('float32') + layer2 = np.random.random((5, 4)).astype('float32') + bilinear = paddle.nn.Bilinear( + in1_features=5, in2_features=4, out_features=1000) + ret = bilinear(paddle.to_tensor(layer1), paddle.to_tensor(layer2)) + self.assertEqual(ret.shape, [5, 1000]) if __name__ == "__main__": From e822fc15ee0431d5c122f395ddef35186228ac51 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Thu, 20 Aug 2020 12:29:03 +0000 Subject: [PATCH 4/6] re pass pre-commit, test=develop --- python/paddle/nn/functional/common.py | 1 + python/paddle/nn/layer/common.py | 23 ++++++----------------- 2 files changed, 7 insertions(+), 17 deletions(-) diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index adadd89d4fa04..5c2b25bcd85a9 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -459,6 +459,7 @@ def _is_list_or_turple_(data): attrs=attrs) return out + def bilinear(x1, x2, weight, bias=None, name=None): """ diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index d9e4b8cf0487d..9ff73924369c5 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -23,23 +23,10 @@ __all__ = [ 'BilinearTensorProduct', 'Pool2D', 'Embedding', 'Linear', 'UpSample', - 'Pad2D', 'Bilinear', - 'BilinearTensorProduct', - 'Pool2D', - 'Embedding', - 'Linear', - 'UpSample', - 'Pad2D', - 'ReflectionPad1d', - 'ReplicationPad1d', - 'ConstantPad1d', - 'ReflectionPad2d', - 'ReplicationPad2d', - 'ConstantPad2d', - 'ZeroPad2d', - 'ConstantPad3d', - 'ReplicationPad3d', - 'CosineSimilarity' + 'Pad2D', 'Bilinear', 'BilinearTensorProduct', 'Pool2D', 'Embedding', + 'Linear', 'UpSample', 'Pad2D', 'ReflectionPad1d', 'ReplicationPad1d', + 'ConstantPad1d', 'ReflectionPad2d', 'ReplicationPad2d', 'ConstantPad2d', + 'ZeroPad2d', 'ConstantPad3d', 'ReplicationPad3d', 'CosineSimilarity' ] @@ -349,6 +336,7 @@ def forward(self, input): pad_value=self._pad_value, data_format=self._data_format) + class Bilinear(layers.Layer): """ @@ -436,6 +424,7 @@ def __init__(self, def forward(self, x1, x2): return F.bilinear(x1, x2, self.weight, self.bias, self._name) + class ReflectionPad1d(layers.Layer): """ This interface is used to construct a callable object of the ``ReflectionPad1d`` class. From aecc0663b575b713c62c5c099103db8ae8a2da0e Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Sun, 23 Aug 2020 19:24:07 +0800 Subject: [PATCH 5/6] Update common.py --- python/paddle/nn/functional/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index ecf981a339538..2bc1f6556d5a4 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -54,7 +54,7 @@ 'assign', 'interpolate', 'bilinear', - 'cosine_similarity' + 'cosine_similarity', ] From 357bb790fa2dd85037115cbe930dfc9c920a6815 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Mon, 24 Aug 2020 01:23:12 +0000 Subject: [PATCH 6/6] fix BilinearTensorProduct ci error, test=develop --- python/paddle/nn/__init__.py | 1 + python/paddle/nn/layer/common.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 8ef5fd9948615..fc8e93d49be7c 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -71,6 +71,7 @@ from .layer.activation import Tanhshrink #DEFINE_ALIAS from .layer.activation import LogSoftmax #DEFINE_ALIAS from .layer.activation import HSigmoid #DEFINE_ALIAS +from .layer.common import BilinearTensorProduct #DEFINE_ALIAS from .layer.common import Pool2D #DEFINE_ALIAS from .layer.common import Pad2D #DEFINE_ALIAS from .layer.common import ReflectionPad1d #DEFINE_ALIAS diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index 455b8d3f5a7a1..4680113f8d0c7 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -13,6 +13,7 @@ # limitations under the License. # TODO: define the common classes to build a neural network +from ...fluid.dygraph import BilinearTensorProduct #DEFINE_ALIAS from ...fluid.dygraph import Pool2D #DEFINE_ALIAS from ...fluid.dygraph import Embedding #DEFINE_ALIAS from ...fluid.dygraph import Linear #DEFINE_ALIAS @@ -22,6 +23,7 @@ from ...fluid.framework import _dygraph_tracer __all__ = [ + 'BilinearTensorProduct', 'Pool2D', 'Embedding', 'Linear',