Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

api2.0 paddle.nn.Bilinear and paddle.nn.functional.bilinear #26399

Merged
merged 10 commits into from
Aug 24, 2020
1 change: 1 addition & 0 deletions paddle/fluid/pybind/op_function_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"fake_quantize_dequantize_moving_average_abs_max",
{"X", "InScale", "InAccum", "InState"}},
{"nll_loss", {"X", "Label", "Weight"}},
{"bilinear_tensor_product", {"X", "Y", "Weight", "Bias"}},
{"gather", {"X", "Index", "Axis"}},
};

Expand Down
65 changes: 65 additions & 0 deletions python/paddle/fluid/tests/unittests/test_bilinear_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
from op_test import OpTest

import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import numpy as np


class TestBilinearAPI(unittest.TestCase):
def test_api(self):
with fluid.program_guard(fluid.default_startup_program(),
fluid.default_main_program()):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = fluid.Executor(place)

data1 = fluid.data(name='X1', shape=[5, 5], dtype='float32')
data2 = fluid.data(name='X2', shape=[5, 4], dtype='float32')

layer1 = np.random.random((5, 5)).astype('float32')
layer2 = np.random.random((5, 4)).astype('float32')

bilinear = paddle.nn.Bilinear(
in1_features=5, in2_features=4, out_features=1000)
ret = bilinear(data1, data2)

exe.run(fluid.default_startup_program())
ret_fetch = exe.run(feed={'X1': layer1,
'X2': layer2},
fetch_list=[ret.name])
self.assertEqual(ret_fetch[0].shape, (5, 1000))


class TestBilinearAPIDygraph(unittest.TestCase):
def test_api(self):
paddle.disable_static()
layer1 = np.random.random((5, 5)).astype('float32')
layer2 = np.random.random((5, 4)).astype('float32')
bilinear = paddle.nn.Bilinear(
in1_features=5, in2_features=4, out_features=1000)
ret = bilinear(paddle.to_tensor(layer1), paddle.to_tensor(layer2))
self.assertEqual(ret.shape, [5, 1000])


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions python/paddle/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
from .layer.common import Linear #DEFINE_ALIAS
from .layer.common import Flatten #DEFINE_ALIAS
from .layer.common import UpSample #DEFINE_ALIAS
from .layer.common import Bilinear #DEFINE_ALIAS
from .layer.common import Dropout #DEFINE_ALIAS
from .layer.common import Dropout2D #DEFINE_ALIAS
from .layer.common import Dropout3D #DEFINE_ALIAS
Expand Down
1 change: 1 addition & 0 deletions python/paddle/nn/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
# from .common import bilinear_tensor_product #DEFINE_ALIAS
from .common import assign #DEFINE_ALIAS
from .common import interpolate #DEFINE_ALIAS
from .common import bilinear #DEFINE_ALIAS
from .conv import conv1d #DEFINE_ALIAS
from .conv import conv_transpose1d #DEFINE_ALIAS
from .conv import conv2d #DEFINE_ALIAS
Expand Down
67 changes: 67 additions & 0 deletions python/paddle/nn/functional/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@

#from ...fluid.layers import fc #DEFINE_ALIAS
from ...fluid.layers import pad_constant_like #DEFINE_ALIAS
from ...fluid.framework import in_dygraph_mode
from ...fluid import core, dygraph_utils
from ...fluid import core, layers
from ...fluid.data_feeder import check_variable_and_dtype

Expand All @@ -52,6 +54,7 @@
# 'bilinear_tensor_product',
'assign',
'interpolate',
'bilinear',
'cosine_similarity',
]

Expand Down Expand Up @@ -460,6 +463,70 @@ def _is_list_or_turple_(data):
return out


def bilinear(x1, x2, weight, bias=None, name=None):
"""

This layer performs bilinear on two inputs.

.. math::
out_{i} = x1 * W_{i} * {x2^\mathrm{T}}, i=0,1,...,size-1
out = out + b

In this formula:
- :math:`x1`: the first input contains in1_features elements, shape is [batch_size, in1_features].
- :math:`x2`: the second input contains in2_features elements, shape is [batch_size, in2_features].
- :math:`W_{i}`: the i-th learned weight, shape is [in1_features, in2_features], and learned weight's shape is [out_features, in1_features, in2_features].
- :math:`out_{i}`: the i-th element of out, shape is [batch_size, out_features].
- :math:`b`: the learned bias, shape is [1, out_features].
- :math:`x2^\mathrm{T}`: the transpose of :math:`x2`.

Parameters:
x1 (Tensor): the first input tensor, it's data type should be float32, float64.
x2 (Tensor): the second input tensor, it's data type should be float32, float64.
weight (Parameter): The learnable weights of this layer, shape is [out_features, in1_features, in2_features].
bias (Parameter, optional): The learnable bias(Bias) of this layer, shape is [1, out_features]. If it is set to None, no bias will be added to the output units. The default value is None.
name (str, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to :ref:`api_guide_Name`. Default: None.

Returns:
Variable: A 2-D Tensor of shape [batch_size, out_features].

Examples:
.. code-block:: python

import paddle
import numpy
import paddle.nn.functional as F

paddle.disable_static()
x1 = numpy.random.random((5, 5)).astype('float32')
x2 = numpy.random.random((5, 4)).astype('float32')
w = numpy.random.random((1000, 5, 4)).astype('float32')
b = numpy.random.random((1, 1000)).astype('float32')

result = F.bilinear(paddle.to_tensor(x1), paddle.to_tensor(x2), paddle.to_tensor(w), paddle.to_tensor(b)) # result shape [5, 1000]

"""

if in_dygraph_mode():
return core.ops.bilinear_tensor_product(x1, x2, weight, bias)

check_variable_and_dtype(x1, 'x1', ['float32', 'float64'], 'bilinear')
check_variable_and_dtype(x2, 'x2', ['float32', 'float64'], 'bilinear')

inputs = {"X": x1, "Y": x2, "Weight": weight}
if bias is not None:
inputs["Bias"] = bias

helper = LayerHelper("bilinear", **locals())
out = helper.create_variable_for_type_inference(dtype=x1.dtype)

helper.append_op(
type="bilinear_tensor_product", inputs=inputs, outputs={"Out": out})

return out


def dropout(x,
p=0.5,
axis=None,
Expand Down
114 changes: 109 additions & 5 deletions python/paddle/nn/layer/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,27 @@
from ...fluid.framework import _dygraph_tracer

__all__ = [
'BilinearTensorProduct', 'Pool2D', 'Embedding', 'Linear', 'UpSample',
'Pad2D', 'ReflectionPad1d', 'ReplicationPad1d', 'ConstantPad1d',
'ReflectionPad2d', 'ReplicationPad2d', 'ConstantPad2d', 'ZeroPad2d',
'ConstantPad3d', 'ReplicationPad3d', 'CosineSimilarity', 'Dropout',
'Dropout2D', 'Dropout3D', 'AlphaDropout'
'BilinearTensorProduct',
'Pool2D',
'Embedding',
'Linear',
'UpSample',
'Pad2D',
'ReflectionPad1d',
'ReplicationPad1d',
'ConstantPad1d',
'ReflectionPad2d',
'ReplicationPad2d',
'ConstantPad2d',
'ZeroPad2d',
'ConstantPad3d',
'ReplicationPad3d',
'CosineSimilarity',
'Dropout',
'Dropout2D',
'Dropout3D',
'Bilinear',
'AlphaDropout',
]


Expand Down Expand Up @@ -338,6 +354,94 @@ def forward(self, input):
data_format=self._data_format)


class Bilinear(layers.Layer):
"""

This layer performs bilinear on two inputs.

.. math::
out_{i} = x1 * W_{i} * {x2^\mathrm{T}}, i=0,1,...,size-1
out = out + b

In this formula:
- :math:`x1`: the first input contains in1_features elements, shape is [batch_size, in1_features].
- :math:`x2`: the second input contains in2_features elements, shape is [batch_size, in2_features].
- :math:`W_{i}`: the i-th learned weight, shape is [in1_features, in2_features], and learned weight's shape is [out_features, in1_features, in2_features].
- :math:`out_{i}`: the i-th element of out, shape is [batch_size, out_features].
- :math:`b`: the learned bias, shape is [1, out_features].
- :math:`x2^\mathrm{T}`: the transpose of :math:`x2`.

Parameters:
in1_features (int): The dimension of each first input(`x1`).
in2_features (int): The dimension of each second input(`x2`).
out_features (int): The dimension of output of this layer.
weight_attr (ParamAttr, optional): The parameter attribute for the learnable w, parameters/weights of
this layer. The default value is None.
bias_attr (ParamAttr, optional): The parameter attribute for the bias
of this layer. If it is set to False, no bias will be added to the output units.
If it is set to None, the bias is initialized zero. The default value is None.
name (str, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to :ref:`api_guide_Name`. Default: None.

Attribute:
**weight** (Parameter): the learnable weights of this layer.

**bias** (Parameter): the learnable bias of this layer.

Returns:
Variable: A 2-D Tensor of shape [batch_size, out_features].

Examples:
.. code-block:: python

import paddle
import numpy

paddle.disable_static()
layer1 = numpy.random.random((5, 5)).astype('float32')
layer2 = numpy.random.random((5, 4)).astype('float32')
bilinear = paddle.nn.Bilinear(
in1_features=5, in2_features=4, out_features=1000)
result = bilinear(paddle.to_tensor(layer1),
paddle.to_tensor(layer2)) # result shape [5, 1000]

"""

def __init__(self,
in1_features,
in2_features,
out_features,
weight_attr=None,
bias_attr=None,
name=None):
super(Bilinear, self).__init__()
self._weight_attr = weight_attr
self._bias_attr = bias_attr
self._name = name
self._in1_features = in1_features
self._in2_features = in2_features
self._out_features = out_features
self._dtype = self._helper.get_default_dtype()

weight_shape = [
self._out_features, self._in1_features, self._in2_features
]
self.weight = self.create_parameter(
attr=self._weight_attr,
shape=weight_shape,
dtype=self._dtype,
is_bias=False)
bias_shape = [1, self._out_features]
self.bias = self.create_parameter(
attr=self._bias_attr,
shape=bias_shape,
dtype=self._dtype,
is_bias=True)

def forward(self, x1, x2):
return F.bilinear(x1, x2, self.weight, self.bias, self._name)


class Dropout(layers.Layer):
"""
Dropout is a regularization technique for reducing overfitting by preventing
Expand Down