diff --git a/paddle/fluid/operators/rrelu_op.cc b/paddle/fluid/operators/rrelu_op.cc
new file mode 100644
index 0000000000000..fe5a832e5e71b
--- /dev/null
+++ b/paddle/fluid/operators/rrelu_op.cc
@@ -0,0 +1,142 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <memory>
+#include <string>
+
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/infermeta/unary.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+class RReluOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  }
+};
+
+class RReluOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "The input of RReLU op.");
+    AddOutput("Out", "The output of RReLU op.");
+    AddOutput("Mask",
+              "The randomly sampled RReLU mask, computed from X. "
+              "Mask has the same shape as X. Mask[i] is 1 if X[i] >= 0. "
+              "If X[i] < 0, Mask[i] is a value sampled from a uniform "
+              "distribution in training, and (lower + upper) / 2.0 in "
+              "inference.")
+        .AsIntermediate()
+        .AsExtra();
+    AddAttr<bool>("is_test",
+                  "(bool, default false) Set to true for inference only, false "
+                  "for training. Some layers may run faster when this is true.")
+        .SetDefault(false);
+    // AddAttr<bool>("fix_seed",
+    //               "(bool, default false) A flag indicating whether to use a "
+    //               "fixed seed to generate the random mask. NOTE: DO NOT set "
+    //               "this flag to true in training. It is only useful in "
+    //               "unittests or for debugging, so that the same values are "
+    //               "always sampled.")
+    //     .SetDefault(false)
+    //     .AsExtra();
+
+    // AddAttr<int>("seed", "RReLU random seed.")
+    //     .SetDefault(0)
+    //     .AsExtra();
+
+    AddAttr<float>("lower", "Lower bound of the uniform distribution.")
+        .SetDefault(0.125f)
+        .AddCustomChecker([](const float& lower) {
+          PADDLE_ENFORCE_EQ(lower >= 0.0f && lower <= 1.0f, true,
+                            platform::errors::InvalidArgument(
+                                "'rrelu lower' must be in [0, 1]."));
+        });
+
+    AddAttr<float>("upper", "Upper bound of the uniform distribution.")
+        .SetDefault(0.3333f)
+        .AddCustomChecker([](const float& upper) {
+          PADDLE_ENFORCE_EQ(upper >= 0.0f && upper <= 1.0f, true,
+                            platform::errors::InvalidArgument(
+                                "'rrelu upper' must be in [0, 1]."));
+        });
+    AddComment(R"DOC(
+RReLU Operator.
+
+Applies the randomized leaky rectified linear unit function, element-wise,
+as described in the paper:
+
+`Empirical Evaluation of Rectified Activations in Convolutional Network`_.
+
+The function is defined as:
+
+.. math::
+    \text{RReLU}(x) =
+    \begin{cases}
+        x  & \text{if } x \geq 0 \\
+        ax & \text{otherwise}
+    \end{cases}
+
+where :math:`a` is randomly sampled from the uniform distribution
+:math:`\mathcal{U}(\text{lower}, \text{upper})`.
+
+See: https://arxiv.org/pdf/1505.00853.pdf
+
+)DOC");
+  }
+};
+
+class RReluGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+};
+
+template <typename T>
+class RReluGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> op) const override {
+    op->SetType("rrelu_grad");
+    op->SetInput("Mask", this->Output("Mask"));
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+DECLARE_INFER_SHAPE_FUNCTOR(rrelu, RReluInferShapeFunctor,
+                            PD_INFER_META(phi::RReluInferMeta));
+
+REGISTER_OPERATOR(rrelu, ops::RReluOp, ops::RReluOpMaker,
+                  ops::RReluGradOpMaker<paddle::framework::OpDesc>,
+                  ops::RReluGradOpMaker<paddle::imperative::OpBase>,
+                  RReluInferShapeFunctor);
+
+DECLARE_INFER_SHAPE_FUNCTOR(rrelu_grad, RReluGradInferShapeFunctor,
+                            PD_INFER_META(phi::RReluGradInferMeta));
+REGISTER_OPERATOR(rrelu_grad, ops::RReluGradOp, RReluGradInferShapeFunctor);
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index eda461be95a40..943574391488f 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -1915,6 +1915,56 @@ void RollInferMeta(const MetaTensor& x,
   out->set_dtype(x.dtype());
 }
 
+void RReluInferMeta(const MetaTensor& x,
+                    float lower,
+                    float upper,
+                    bool is_test,
+                    // bool fix_seed,
+                    // int seed,
+                    MetaTensor* out,
+                    MetaTensor* mask) {
+  auto x_dims = x.dims();
+  PADDLE_ENFORCE_GE(lower,
+                    0,
+                    phi::errors::InvalidArgument(
+                        "The lower value should be greater than or equal to 0. "
+                        "But received lower value = %f.",
+                        lower));
+  PADDLE_ENFORCE_LE(upper,
+                    1,
+                    phi::errors::InvalidArgument(
+                        "The upper value should be less than or equal to 1. "
+                        "But received upper value = %f.",
+                        upper));
+  PADDLE_ENFORCE_GE(
+      upper,
+      lower,
+      phi::errors::InvalidArgument(
+          "The upper value should be greater than or equal to the lower value. "
+          "But received upper value = %f, lower value = %f.",
+          upper,
+          lower));
+
+  out->set_dims(x_dims);
+  out->set_dtype(x.dtype());
+  out->set_layout(x.layout());
+  out->share_lod(x);
+
+  if (mask != nullptr) {
+    mask->set_dims(x_dims);
+    mask->set_dtype(x.dtype());
+    mask->set_layout(x.layout());
+  }
+}
+
+void RReluGradInferMeta(const MetaTensor& out_grad,
+                        const MetaTensor& mask,
+                        MetaTensor* x_grad) {
+  x_grad->set_dims(out_grad.dims());
+  x_grad->set_dtype(out_grad.dtype());
+  x_grad->share_lod(out_grad);
+}
+
 void SetValueInferMeta(const MetaTensor& x, MetaTensor* out) {
   auto in_dims = x.dims();
   PADDLE_ENFORCE_LT(
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index 559857bd6ce9b..ade722074605a 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -273,6 +273,17 @@ void RollInferMeta(const MetaTensor& x,
                    const std::vector<int64_t>& axis,
                    MetaTensor* out);
 
+void RReluInferMeta(const MetaTensor& x,
+                    float lower,
+                    float upper,
+                    bool is_test,
+                    MetaTensor* out,
+                    MetaTensor* mask);
+
+void RReluGradInferMeta(const MetaTensor& out_grad,
+                        const MetaTensor& mask,
+                        MetaTensor* x_grad);
+
 void SetValueInferMeta(const MetaTensor& x, MetaTensor* out);
 
 void ShapeInferMeta(const MetaTensor& input, MetaTensor* out);
diff --git a/paddle/phi/kernels/cpu/rrelu_grad_kernel.cc b/paddle/phi/kernels/cpu/rrelu_grad_kernel.cc
new file mode 100644
index 0000000000000..0a076c33dfec1
--- /dev/null
+++ b/paddle/phi/kernels/cpu/rrelu_grad_kernel.cc
@@ -0,0 +1,50 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/rrelu_grad_kernel.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void RReluGradKernel(const Context& ctx,
+                     const DenseTensor& mask,
+                     const DenseTensor& out_grad,
+                     DenseTensor* x_grad) {
+  x_grad->mutable_data<T>(ctx.GetPlace());
+
+  auto dX = EigenVector<T>::Flatten(*x_grad);
+  auto dY = EigenVector<T>::Flatten(out_grad);
+  auto M = EigenVector<T>::Flatten(mask);
+
+  auto& place = *ctx.eigen_device();
+
+  // dX = dY * Mask; the mask already stores the sampled slope (or 1).
+  dX.device(place) = dY * M;
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(rrelu_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::RReluGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/cpu/rrelu_kernel.cc b/paddle/phi/kernels/cpu/rrelu_kernel.cc
new file mode 100644
index 0000000000000..33f5cfd4fa3e7
--- /dev/null
+++ b/paddle/phi/kernels/cpu/rrelu_kernel.cc
@@ -0,0 +1,86 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <random>
+
+#include "paddle/phi/kernels/rrelu_kernel.h"
+#include "paddle/fluid/framework/generator.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void RReluKernel(const Context& ctx,
+                 const DenseTensor& x,
+                 const float lower,
+                 const float upper,
+                 bool is_test,
+                 // bool fix_seed,
+                 // int seed,
+                 DenseTensor* out,
+                 DenseTensor* mask) {
+  const T* x_data = x.data<T>();
+  T* out_data = ctx.template Alloc<T>(out);
+  T* mask_data = ctx.template Alloc<T>(mask);
+  // Equivalent allocation through mutable_data:
+  // auto* out_data = out->mutable_data<T>(ctx.GetPlace());
+  // auto* mask_data = mask->mutable_data<T>(ctx.GetPlace());
+  uint64_t size = x.numel();
+  auto zero = static_cast<T>(0);
+  auto one = static_cast<T>(1);
+
+  if (!is_test) {
+    // int seed_data = fix_seed ? seed : 0;
+    // auto engine = paddle::framework::GetCPURandomEngine(seed_data);
+
+    auto gen = ctx.GetGenerator();
+    auto engine = gen->GetCPUEngine();
+    std::uniform_real_distribution<float> dist(lower, upper);
+
+    for (uint64_t i = 0; i < size; ++i) {
+      if (x_data[i] >= zero) {
+        mask_data[i] = one;
+        out_data[i] = x_data[i];
+      } else {
+        auto random_sampled_value = static_cast<T>(dist(*engine));
+        mask_data[i] = random_sampled_value;
+        out_data[i] = x_data[i] * random_sampled_value;
+      }
+    }
+  } else {
+    auto middle_value = static_cast<T>((lower + upper) / 2.0f);
+    for (uint64_t i = 0; i < size; ++i) {
+      if (x_data[i] >= zero) {
+        out_data[i] = x_data[i];
+        mask_data[i] = one;
+      } else {
+        out_data[i] = x_data[i] * middle_value;
+        mask_data[i] = middle_value;
+      }
+    }
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(rrelu,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::RReluKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/gpu/rrelu_grad_kernel.cu b/paddle/phi/kernels/gpu/rrelu_grad_kernel.cu
new file mode 100644
index 0000000000000..0acd83d2600fa
--- /dev/null
+++ b/paddle/phi/kernels/gpu/rrelu_grad_kernel.cu
@@ -0,0 +1,79 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// #include "paddle/phi/kernels/gpu/rrelu_impl.cu.h"
+#include "paddle/phi/kernels/rrelu_grad_kernel.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
+
+namespace phi {
+
+// template <typename T, typename Context>
+// void RReluGradKernel(const Context& dev_ctx,
+//                      const DenseTensor& mask,
+//                      const DenseTensor& out_grad,
+//                      DenseTensor* x_grad) {
+//   x_grad->mutable_data<T>(dev_ctx.GetPlace());
+//   auto size = mask.numel();
+//   paddle::operators::RReluGradGPUKernelDriver<T>(
+//       dev_ctx, out_grad, mask, x_grad);
+// }
+
+template <typename T>
+struct RReluGradCudaFunctor {
+ public:
+  RReluGradCudaFunctor(const T* mask, const T* out_grad, T* x_grad)
+      : mask_(mask), out_grad_(out_grad), x_grad_(x_grad) {}
+
+  __device__ void operator()(int64_t idx) {
+    x_grad_[idx] = mask_[idx] * out_grad_[idx];
+  }
+
+ private:
+  const T* mask_;
+  const T* out_grad_;
+  T* x_grad_;
+};
+
+template <typename T, typename Context>
+void RReluGradKernel(const Context& ctx,
+                     const DenseTensor& mask,
+                     const DenseTensor& out_grad,
+                     DenseTensor* x_grad) {
+  const T* mask_data = mask.data<T>();
+  const T* out_grad_data = out_grad.data<T>();
+  T* x_grad_data = ctx.template Alloc<T>(x_grad);
+  auto size = mask.numel();
+
+  phi::funcs::ForRange<Context> for_range(ctx, size);
+
+  RReluGradCudaFunctor<T> functor(mask_data, out_grad_data, x_grad_data);
+  for_range(functor);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(rrelu_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::RReluGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
+// phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/rrelu_kernel.cu b/paddle/phi/kernels/gpu/rrelu_kernel.cu
new file mode 100644
index 0000000000000..990ebb11b849d
--- /dev/null
+++ b/paddle/phi/kernels/gpu/rrelu_kernel.cu
@@ -0,0 +1,129 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/rrelu_kernel.h"
+// #include "paddle/phi/kernels/gpu/rrelu_impl.cu.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
+#include "paddle/phi/kernels/funcs/distribution_helper.h"
+
+namespace phi {
+
+template <typename T>
+struct RReluTrainCudaFunctor {
+ public:
+  RReluTrainCudaFunctor(const T* in, T* out, T* mask)
+      : in_(in), out_(out), mask_(mask) {
+    zero_ = static_cast<T>(0);
+    one_ = static_cast<T>(1);
+  }
+
+  __device__ void operator()(int64_t idx) {
+    if (in_[idx] < zero_) {
+      out_[idx] = in_[idx] * mask_[idx];
+    } else {
+      mask_[idx] = one_;
+      out_[idx] = in_[idx];
+    }
+  }
+
+ private:
+  const T* in_;
+  T* out_;
+  T* mask_;
+  T zero_;
+  T one_;
+};
+
+template <typename T>
+struct RReluTestCudaFunctor {
+ public:
+  RReluTestCudaFunctor(const T* in, T* out, T* mask, T mid_value)
+      : in_(in), out_(out), mask_(mask), mid_value_(mid_value) {
+    zero_ = static_cast<T>(0);
+    one_ = static_cast<T>(1);
+  }
+
+  __device__ void operator()(int64_t idx) {
+    if (in_[idx] < zero_) {
+      mask_[idx] = mid_value_;
+      out_[idx] = in_[idx] * mid_value_;
+    } else {
+      mask_[idx] = one_;
+      out_[idx] = in_[idx];
+    }
+  }
+
+ private:
+  const T* in_;
+  T* out_;
+  T* mask_;
+  T mid_value_;
+  T zero_;
+  T one_;
+};
+
+template <typename T, typename Context>
+void RReluKernel(const Context& ctx,
+                 const DenseTensor& x,
+                 const float lower,
+                 const float upper,
+                 bool is_test,
+                 DenseTensor* out,
+                 DenseTensor* mask) {
+  const T* x_data = x.data<T>();
+  T* out_data = ctx.template Alloc<T>(out);
+  T* mask_data = ctx.template Alloc<T>(mask);
+  auto size = x.numel();
+  if (size <= 0) return;
+
+  phi::funcs::ForRange<Context> for_range(ctx, size);
+
+  if (!is_test) {
+    using MT = typename kps::details::MPTypeTrait<T>::Type;
+    funcs::uniform_distribution<MT> dist;
+    funcs::uniform_real_transform<MT> trans(lower, upper);
+    funcs::distribution_and_transform<T>(ctx, mask, dist, trans);
+
+    RReluTrainCudaFunctor<T> functor(x_data, out_data, mask_data);
+    for_range(functor);
+  } else {
+    T mid_value = static_cast<T>((lower + upper) / 2.0f);
+    RReluTestCudaFunctor<T> functor(x_data, out_data, mask_data, mid_value);
+    for_range(functor);
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(rrelu,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::RReluKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
+// phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/rrelu_grad_kernel.h b/paddle/phi/kernels/rrelu_grad_kernel.h
new file mode 100644
index 0000000000000..cb2424633e27e
--- /dev/null
+++ b/paddle/phi/kernels/rrelu_grad_kernel.h
@@ -0,0 +1,26 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void RReluGradKernel(const Context& ctx,
+                     const DenseTensor& mask,
+                     const DenseTensor& out_grad,
+                     DenseTensor* x_grad);
+}  // namespace phi
diff --git a/paddle/phi/kernels/rrelu_kernel.h b/paddle/phi/kernels/rrelu_kernel.h
new file mode 100644
index 0000000000000..142a8522d6200
--- /dev/null
+++ b/paddle/phi/kernels/rrelu_kernel.h
@@ -0,0 +1,31 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void RReluKernel(const Context& ctx,
+                 const DenseTensor& x,
+                 const float lower,
+                 const float upper,
+                 bool is_test,
+                 // bool fix_seed,
+                 // int seed,
+                 DenseTensor* out,
+                 DenseTensor* mask);
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/rrelu_sig.cc b/paddle/phi/ops/compat/rrelu_sig.cc
new file mode 100644
index 0000000000000..a0c0002d4d4ad
--- /dev/null
+++ b/paddle/phi/ops/compat/rrelu_sig.cc
@@ -0,0 +1,39 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature RReluOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "rrelu", {"X"}, {"lower", "upper", "is_test"}, {"Out", "Mask"});
+}
+
+KernelSignature RReluGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("rrelu_grad", {"Mask", "Out@GRAD"}, {}, {"X@GRAD"});
+}
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(rrelu, phi::RReluOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(rrelu_grad, phi::RReluGradOpArgumentMapping);
diff --git a/python/paddle/fluid/tests/unittests/test_rrelu_op.py b/python/paddle/fluid/tests/unittests/test_rrelu_op.py
new file mode 100644
index 0000000000000..add2d9b0b9a17
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_rrelu_op.py
@@ -0,0 +1,555 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle.fluid.core as core +from op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16 +import paddle +import paddle.static as static +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard +from paddle.fluid.framework import _test_eager_guard, _enable_legacy_dygraph +import os + +from paddle import _C_ops + + +def rrelu_inference(x, lower, upper): + # use copy of input to avoid changing the value of input in the following calculation + x_t = x.copy() + alpha = (lower + upper) / 2.0 + return np.where(x_t < 0, alpha * x_t, x_t) + + +def check_element_range_of_rrelu_output_in_training(input: np.ndarray, op_output: np.ndarray, lower: float, upper: float): + """ + input: x + op outpout: y + check that: + if x[i] >= 0, then x[i] == y[i] + if x[i] < 0, then upper * x[i] <= y[i] <= lower * x[i] + + return: True: the test is passed; + False: the test is not passed + """ + # use copy of input to avoid changing the value of input in the following calculation + input, op_output = input.copy(), op_output.copy() + passed_1 = np.allclose(input[input >= 0], op_output[input >= 0]) + if passed_1 == False: + return False + passed_2 = (op_output[input < 0] <= (input[input < 0] * lower)).all() + if passed_2 == False: + return False + passed_3 = (op_output[input < 0] >= (input[input < 0] * upper)).all() + return passed_3 + + +def check_negative_elements_distribution_of_rrelu_output_in_training(input: np.ndarray, op_output: np.ndarray, lower: float, upper: float, num_segments: int, scale: float): + """ + input: x + op_output: y + + Only check negative elements. + Divide the interval [lower, upper] into num_segments equal parts. + [a, b] is a small interval in [lower, upper]. 0 <=a <= b <= 1 + count = the number of i that satisfies x[i] < 0 and b * x[i] <= y[i] <= a * x[i] + Then count / x.size >= scale * (b - a) / (upper - lower) + + scale is recommended to be in the range [0.1, 0.8] + + if this check is passed, you can "roughly" believe that the function of + RReLU API has been properly implemented. 
+ the function of RReLU API can be shown as follows: + out = np.where(x < 0, + np.random.uniform(lower, upper, x.shape) * x, x) + """ + # use copy of input to avoid changing the value of input in the following calculation + input, op_output = input.copy(), op_output.copy() + num_negative_elements = np.sum(input < 0) + one_part_length = (upper - lower) / num_segments + special_alphas = [] + for i in range(num_segments): + alpha = lower + i * one_part_length + special_alphas.append(alpha) + special_alphas.append(upper) + + for i in range(num_segments): + bool_array_1 = op_output[input < 0] <= (input[input < 0] * special_alphas[i]) + bool_array_2 = op_output[input < 0] >= (input[input < 0] * special_alphas[i+1]) + count = np.sum(bool_array_1 * bool_array_2) + # print(i, "{}%".format(count / num_negative_elements * 100)) + if count / num_negative_elements < scale * 1 / num_segments: + return False + return True + + +# class TestRReluOpInference(OpTest): +# """ +# test the inference mode of rrelu op, +# you can subclass this class and modify "setUp" method +# as you want +# """ +# def setUp(self): +# self.op_type = "rrelu" +# self.lower = 0.1 +# self.upper = 0.3 +# # self.fix_seed = True +# # self.seed = 1 +# self.dtype = "float64" +# self.x_shape = [2, 3, 4, 5] +# self.x_low = -1 +# self.x_high = 1 +# self.init() + +# def init(self): +# x_np = np.random.uniform(self.x_low, self.x_high, self.x_shape).astype(self.dtype) +# out_np = rrelu_inference(x_np, self.lower, self.upper) +# mask_np = np.ones(self.x_shape).astype(self.dtype) +# mask_np[x_np < 0] = (self.lower + self.upper) / 2.0 + +# self.inputs = {'X': x_np} +# self.outputs = {'Out': out_np, 'Mask': mask_np} +# self.attrs = { +# 'lower': self.lower, +# "upper": self.upper, +# "is_test": True, +# # "fix_seed": self.fix_seed, +# # "seed": self.seed +# } + +# def test_check_output(self): +# self.check_output() + +# def test_check_grad(self): +# self.check_grad(['X'], 'Out') + + +# class TestRReluOpInference2(TestRReluOpInference): +# def setUp(self): +# self.op_type = "rrelu" +# self.lower = 0.3 +# self.upper = 0.99 +# # self.fix_seed = True +# # self.seed = 198 +# self.dtype = "float64" +# self.x_shape = [20, 10] +# self.x_low = -9 +# self.x_high = -1 +# self.init() + + +# class TestRReluOpInference3(TestRReluOpInference): +# def setUp(self): +# self.op_type = "rrelu" +# self.lower = 0.8 +# self.upper = 0.99 +# self.fix_seed = False +# self.seed = 198 +# self.dtype = "float32" +# self.x_shape = [2, 100] +# self.x_low = -9 +# self.x_high = 10 +# self.init() + +# def test_check_output(self): +# self.check_output(atol=1e-3) + + +# class TestRReluOpTraining(OpTest): +# """ +# test the training mode of rrelu op, but +# set lower to be equal to upper, +# you can subclass this class and modify "setUp" method +# as you want +# """ +# def setUp(self): +# self.op_type = "rrelu" +# self.lower = 0.1 +# self.fix_seed = True +# self.seed = 1 +# self.dtype = "float64" +# self.x_shape = [2, 3, 4, 5] +# self.x_low = -1 +# self.x_high = 1 +# self.init() + +# def init(self): +# x_np = np.random.uniform(self.x_low, self.x_high, self.x_shape).astype(self.dtype) +# out_np = rrelu_inference(x_np, self.lower, self.lower) +# mask_np = np.ones(self.x_shape).astype(self.dtype) +# mask_np[x_np < 0] = self.lower + +# self.inputs = {'X': x_np} +# self.outputs = {'Out': out_np, 'Mask': mask_np} +# self.attrs = { +# 'lower': self.lower, +# "upper": self.lower, +# "is_test": False, +# "fix_seed": self.fix_seed, +# "seed": self.seed +# } + +# def test_check_output(self): +# 
self.check_output() + +# def test_check_grad(self): +# self.check_grad(['X'], 'Out') + + +# class TestRReluOpTraining2(TestRReluOpTraining): +# def setUp(self): +# self.op_type = "rrelu" +# self.lower = 0.897 +# self.fix_seed = True +# self.seed = 123 +# self.dtype = "float64" +# self.x_shape = [11, 4, 5] +# self.x_low = -10 +# self.x_high = 10 +# self.init() + + +# class TestRReluOpTraining3(TestRReluOpTraining): +# def setUp(self): +# self.op_type = "rrelu" +# self.lower = 0.0786 +# self.fix_seed = False +# self.seed = 123 +# self.dtype = "float64" +# self.x_shape = [2, 3, 4, 5] +# self.x_low = -100 +# self.x_high = 10 +# self.init() + + +# class TestRReluOp(OpTest): +# def setUp(self): +# self.op_type = "rrelu" +# self.inputs = {'X': np.random.random((32, 64)).astype("float64")} +# self.attrs = { +# 'lower': 0.0, 'upper': 0.8, +# 'fix_seed': False, 'is_test': False} +# self.outputs = { +# 'Out': self.inputs['X'], +# 'Mask': np.ones((32, 64)).astype("float64") +# } + +# def test_check_output(self): +# self.check_output() + +# def test_check_grad_normal(self): +# self.check_grad(['X'], 'Out') + + +# class TestRReluOpInput1d(OpTest): +# def setUp(self): +# self.op_type = "rrelu" +# self.inputs = {'X': np.random.random((2000, )).astype("float64")} +# self.attrs = { +# 'lower': 0.2, 'upper': 0.7, +# 'fix_seed': True, 'is_test': False} +# self.outputs = { +# 'Out': self.inputs['X'], +# 'Mask': np.ones((2000)).astype('float64') +# } + +# def test_check_output(self): +# self.check_output() + +# def test_check_grad_normal(self): +# self.check_grad(['X'], 'Out') + + +# class TestRReluOp2(TestRReluOp): +# def setUp(self): +# self.op_type = "rrelu" +# self.inputs = {'X': np.random.uniform(-100, -10, [19, 3, 4]).astype('float64')} +# self.attrs = { +# 'lower': 0, 'upper': 0, +# 'fix_seed': True, 'is_test': False} +# self.outputs = { +# 'Out': np.zeros([19, 3, 4]).astype('float64'), +# 'Mask': np.zeros([19, 3, 4]).astype('float64') +# } + + +# class TestRReluOp3(TestRReluOp): +# def setUp(self): +# self.op_type = "rrelu" +# self.inputs = {'X': np.random.uniform(-10, 10, [2, 30, 4]).astype('float64')} +# self.attrs = { +# 'lower': 1, 'upper': 1, +# 'fix_seed': False, 'is_test': False} +# self.outputs = { +# 'Out': self.inputs['X'], +# 'Mask': np.ones([2, 30, 4]).astype('float64') +# } + + +# @skip_check_grad_ci(reason="For inference, check_grad is not required.") +# class TestRReluOp9(OpTest): +# def setUp(self): +# self.op_type = "rrelu" +# self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")} +# self.attrs = { +# 'is_test': False +# } +# self.outputs = {'Out': self.inputs['X']} + +# def test_check_output(self): +# self.check_output() + +##################################################################3333 +#The following tests are passed +@unittest.skipIf( + not core.is_compiled_with_cuda() or not core.op_support_gpu("rrelu"), + "core is not compiled with CUDA or core is not support rrelu") +@skip_check_grad_ci(reason="For inference, check_grad is not required.") +class TestFP16RReluOp(OpTest): + def setUp(self): + self.op_type = "rrelu" + self.init_test_case() + + x_np = np.random.uniform(-1, 1, self.x_shape).astype("float16") + out_np = rrelu_inference(x_np, self.lower, self.upper) + mask_np = np.ones(self.x_shape).astype("float16") + mask_np[x_np < 0] = (self.lower + self.upper) / 2.0 + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x_np)} + self.attrs = { + 'lower': self.lower, + 'upper': self.upper, + 'is_test': True + } + self.outputs = {'Out': out_np, 'Mask': mask_np} + + 
def init_test_case(self): + self.x_shape = [32, 64] + self.lower = 0.17 + self.upper = 0.89 + + def test_check_output(self): + self.check_output_with_place(core.CUDAPlace(0), atol=1e-3) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() or not core.op_support_gpu("rrelu"), + "core is not compiled with CUDA or core is not support rrelu") +@skip_check_grad_ci(reason="For inference, check_grad is not required.") +class TestFP16RReluOp2(TestFP16RReluOp): + def init_test_case(self): + self.x_shape = [21, 3, 7] + self.lower = 0.1 + self.upper = 0.127 + + +# class TestBF16RReluOp(OpTest): +# def setUp(self): +# self.op_type = "rrelu" +# self.dtype = np.uint16 +# self.lower = self.upper = 0.78 + +# x_shape = (32, 64) +# x_np = np.random.uniform(-2, 3, x_shape).astype("float32") +# out_np = rrelu_inference(x_np, self.lower, self.upper) +# mask_np = np.ones(x_shape).astype("float32") +# mask_np[x_np < 0] = self.lower +# self.inputs = {'X': convert_float_to_uint16(x_np)} +# self.attrs = { +# 'lower': self.lower, +# 'upper': self.upper, +# 'is_test': False +# } +# self.outputs = { +# 'Out': convert_float_to_uint16(out_np), +# 'Mask': convert_float_to_uint16(mask_np) +# } + +# def test_check_output(self): +# self.check_output() + +# def test_check_grad_normal(self): +# self.check_grad(['X'], 'Out') + + +class TestRReluFAPI(unittest.TestCase): + def setUp(self): + np.random.seed(123) + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def check_static_result(self, place): + paddle.enable_static() + with static.program_guard(static.Program(), static.Program()): + input = static.data(name="input", shape=[-1, -1], dtype="float32") + exe = static.Executor(place) + + res1 = paddle.nn.functional.rrelu(x=input, lower=1, upper=1, training=False) + res2 = paddle.nn.functional.rrelu(x=input, lower=1, upper=1, training=True) + in_np = np.random.uniform(-3, 2, [40, 40]).astype("float32") + res_np = in_np + for res in [res1, res2]: + fetches = exe.run(static.default_main_program(), + feed={"input": in_np}, + fetch_list=[res]) + self.assertTrue(np.allclose(fetches[0], res_np)) + + lower, upper = 0.17, 0.99 + res3 = paddle.nn.functional.rrelu(x=input, lower=lower, upper=upper, training=False) + in_np = np.random.uniform(-4, 1, [20, 20]).astype("float32") + res_np = rrelu_inference(in_np, lower, upper) + fetches = exe.run(static.default_main_program(), + feed={"input": in_np}, + fetch_list=[res3]) + self.assertTrue(np.allclose(fetches[0], res_np)) + + lower = upper = 0.23 + res4 = paddle.nn.functional.rrelu(x=input, lower=lower, upper=upper, training=True) + in_np = np.random.uniform(-5, 2, [11, 20]).astype("float32") + res_np = rrelu_inference(in_np, lower, upper) + fetches = exe.run(static.default_main_program(), + feed={"input": in_np}, + fetch_list=[res4]) + self.assertTrue(np.allclose(fetches[0], res_np)) + + # Attention: this part is important!!! 
+ lower, upper = 0.2, 0.9 + res5 = paddle.nn.functional.rrelu(x=input, lower=lower, upper=upper, training=True) + in_np = np.random.uniform(-50, 1, [40, 30]).astype("float32") + fetches = exe.run(static.default_main_program(), + feed={"input": in_np}, + fetch_list=[res5]) + passed_1 = check_element_range_of_rrelu_output_in_training( + in_np, fetches[0], lower=lower, upper=upper + ) + self.assertTrue(passed_1) + passed_2 = check_negative_elements_distribution_of_rrelu_output_in_training( + in_np, fetches[0], lower=lower, upper=upper, num_segments=5, scale=0.8 + ) + self.assertTrue(passed_2) + + paddle.disable_static() + + def test_static(self): + for place in self.places: + self.check_static_result(place=place) + + def test_dygraph(self): + paddle.disable_static() + for place in self.places: + in_np = np.random.uniform(-3, 2, [2, 7, 40]).astype("float32") + res_np = in_np + in_tensor = paddle.to_tensor(in_np, place=place) + res1 = paddle.nn.functional.rrelu(x=in_tensor, lower=1, upper=1, training=False) + res2 = paddle.nn.functional.rrelu(x=in_tensor, lower=1, upper=1, training=True) + for res in [res1, res2]: + self.assertTrue(np.allclose(res.numpy(), res_np)) + + lower, upper = 0.17, 0.99 + in_np = np.random.uniform(-4, 1, [20, 20]).astype("float32") + res_np = rrelu_inference(in_np, lower, upper) + in_tensor = paddle.to_tensor(in_np, place=place) + res3 = paddle.nn.functional.rrelu(x=in_tensor, lower=lower, upper=upper, training=False) + self.assertTrue(np.allclose(res3.numpy(), res_np)) + + lower = upper = 0.23 + in_np = np.random.uniform(-5, 2, [11, 20]).astype("float32") + res_np = rrelu_inference(in_np, lower, upper) + in_tensor = paddle.to_tensor(in_np, place=place) + res4 = paddle.nn.functional.rrelu(x=in_tensor, lower=lower, upper=upper, training=True) + self.assertTrue(np.allclose(res4.numpy(), res_np)) + + #Attention: this part is important!!! + lower, upper = 0.23, 0.99 + in_np = np.random.uniform(-50, 1, [11, 20, 3]).astype("float32") + in_tensor = paddle.to_tensor(in_np, place=place) + res5 = paddle.nn.functional.rrelu(x=in_tensor, lower=lower, upper=upper, training=True) + passed_1 = check_element_range_of_rrelu_output_in_training( + in_np, res5.numpy(), lower=lower, upper=upper + ) + self.assertTrue(passed_1) + passed_2 = check_negative_elements_distribution_of_rrelu_output_in_training( + in_np, res5.numpy(), lower=lower, upper=upper, num_segments=5, scale=0.8 + ) + self.assertTrue(passed_2) + + +class TestRReluFAPIError(unittest.TestCase): + def test_errors(self): + paddle.enable_static() + with static.program_guard(static.Program(), static.Program()): + def test_Variable(): + # the input of rrelu must be Variable. + x1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()) + paddle.nn.functional.rrelu(x1, training=True) + + self.assertRaises(TypeError, test_Variable) + + def test_Variable2(): + # the input of rrelu must be Variable. 
+ x1 = fluid.create_lod_tensor( + np.array([-1, -3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()) + paddle.nn.functional.rrelu(x1, training=False) + + self.assertRaises(TypeError, test_Variable2) + + def test_dtype(): + xr = fluid.data(name='xr', shape=[3, 4, 5, 6], dtype="int32") + paddle.nn.functional.rrelu(xr) + + self.assertRaises(TypeError, test_dtype) + + def test_lower_dtype(): + # lower should be int or float + x2 = fluid.data(name='x2', shape=[3, 4, 5, 6], dtype="float32") + paddle.nn.functional.rrelu(x2, lower='0.5', upper=0.8) + + self.assertRaises(TypeError, test_lower_dtype) + + def test_lower_value(): + # lower should be in the interval [0.0, 1.0] + x2 = fluid.data(name='x2', shape=[3, 4, 5, 6], dtype="float32") + paddle.nn.functional.rrelu(x2, lower=-0.8, upper=0.5) + + self.assertRaises(ValueError, test_lower_value) + + paddle.disable_static() + + +class TestRReluCAPI(unittest.TestCase): + def setUp(self): + np.random.seed(123) + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def test_dygraph(self): + for place in self.places: + with fluid.dygraph.guard(place): + input_np = np.random.random([40, 40]).astype("float32") + result_np = input_np + input = fluid.dygraph.to_variable(input_np) + rrelu_layer = paddle.nn.RReLU(lower=0.12, upper=0.87) + rrelu_layer.eval() + result = rrelu_layer(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + +if __name__ == '__main__': + # paddle.enable_static() + unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index bceee4b964a33..b4be291b0697f 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -51,6 +51,7 @@ from .layer.activation import ThresholdedReLU # noqa: F401 from .layer.activation import LogSoftmax # noqa: F401 from .layer.activation import Maxout # noqa: F401 +from .layer.activation import RReLU # noqa: F401 from .layer.common import Pad1D # noqa: F401 from .layer.common import Pad2D # noqa: F401 from .layer.common import ZeroPad2D # noqa: F401 @@ -313,4 +314,5 @@ def weight_norm(*args): 'MaxUnPool3D', 'HingeEmbeddingLoss', 'Identity', + 'RReLU', ] diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 68213d831c550..44acf32894588 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -47,6 +47,7 @@ from .activation import log_softmax # noqa: F401 from .activation import glu # noqa: F401 from .activation import gumbel_softmax # noqa: F401 +from .activation import rrelu # noqa: F401 from .common import dropout # noqa: F401 from .common import dropout2d # noqa: F401 from .common import dropout3d # noqa: F401 @@ -228,4 +229,5 @@ 'class_center_sample', 'sparse_attention', 'fold', + 'rrelu', ] diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index e64efda7b33bf..27752fc8a9eaf 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -22,7 +22,7 @@ import warnings
 from ...fluid.layer_helper import LayerHelper
-from ...fluid.framework import convert_np_dtype_to_dtype_
+from ...fluid.framework import convert_np_dtype_to_dtype_, default_main_program
 from ...fluid.framework import _in_legacy_dygraph, in_dygraph_mode, _non_static_mode
 from ...fluid.data_feeder import check_variable_and_dtype, check_dtype
 import paddle
@@ -548,6 +548,94 @@ def prelu(x, weight, data_format="NCHW", name=None):
     return out
 
 
+def rrelu(x, lower=1. / 8., upper=1. / 3., training=True, name=None):
+    r"""
+    rrelu activation.
+
+    .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
+        https://arxiv.org/abs/1505.00853
+
+    .. math::
+        \text{RReLU}(x) =
+        \begin{cases}
+            x  & \text{if } x \geq 0 \\
+            ax & \text{otherwise}
+        \end{cases}
+
+    where :math:`a` is randomly sampled from the uniform distribution
+    :math:`\mathcal{U}(\text{lower}, \text{upper})`.
+
+    Parameters:
+        x (Tensor): The input Tensor with data type float16, float32 or float64.
+        lower (float, optional): The lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`.
+        upper (float, optional): The upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`.
+        training (bool, optional): Whether it is in training mode. In inference mode the
+            deterministic slope (lower + upper) / 2 is used instead of a random sample. Default is True.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        A Tensor with the same data type and shape as ``x``.
+
+    Examples:
+        .. code-block:: python
+            :name: rrelu-example
+
+            import paddle
+            import paddle.nn.functional as F
+            import numpy as np
+
+            data = np.array([[[[-2.0, 3.0, -4.0, 5.0],
+                               [ 3.0, -4.0, 5.0, -6.0],
+                               [-7.0, -8.0, 8.0, 9.0]],
+                              [[ 1.0, -2.0, -3.0, 4.0],
+                               [-5.0, 6.0, 7.0, -8.0],
+                               [ 6.0, 7.0, 8.0, 9.0]]]], 'float32')
+            input_tensor = paddle.to_tensor(data)
+            out = F.rrelu(input_tensor, 0.1, 0.3)
+            #[[[[-0.20000899  3.         -0.8810822   5.        ]
+            #   [ 3.         -0.55175185  5.         -1.0776101 ]
+            #   [-1.0680687  -1.9896201   8.          9.        ]]
+            #  [[ 1.         -0.5238267  -0.65515125  4.        ]
+            #   [-1.3766339   6.          7.         -2.3465784 ]
+            #   [ 6.          7.          8.          9.        ]]]]
+    """
+    if not isinstance(lower, (float, int)) or not isinstance(upper, (float, int)):
+        raise TypeError(
+            "The lower and upper values must be float or int type. Received: lower {}, upper {}.".
+            format(lower, upper))
+
+    if lower < 0 or upper < lower or upper > 1:
+        raise ValueError(
+            "The lower and upper values must be in the range [0.0, 1.0] and upper must be "
+            "greater than or equal to lower. Received: lower={}, upper={}.".
+            format(lower, upper))
+
+    is_test = not training
+
+    if _in_legacy_dygraph():
+        out, mask = _C_ops.rrelu(x, 'lower', lower, 'upper', upper, 'is_test',
+                                 is_test)
+        return out
+
+    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'rrelu')
+
+    helper = LayerHelper('rrelu', **locals())
+    out = helper.create_variable_for_type_inference(x.dtype)
+    mask = helper.create_variable_for_type_inference(dtype=x.dtype)
+    attrs = {'lower': lower, 'upper': upper, 'is_test': is_test}
+
+    helper.append_op(
+        type='rrelu',
+        inputs={"X": x},
+        outputs={"Out": out,
+                 "Mask": mask},
+        attrs=attrs)
+    return out
+
+
 def relu(x, name=None):
     """
     relu activation.
diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py
index 31364f0281c8a..cca8c37645df6 100644
--- a/python/paddle/nn/layer/__init__.py
+++ b/python/paddle/nn/layer/__init__.py
@@ -26,6 +26,7 @@ from .activation import Sigmoid  # noqa: F401
 from .activation import Softmax  # noqa: F401
 from .activation import LogSoftmax  # noqa: F401
+from .activation import RReLU  # noqa: F401
 from .activation import Softmax2D  # noqa: F401
 from .common import Bilinear  # noqa: F401
 from .common import Pad1D  # noqa: F401
diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py
index cd82fe12fff6b..e25a588dbb6bc 100644
--- a/python/paddle/nn/layer/activation.py
+++ b/python/paddle/nn/layer/activation.py
@@ -436,6 +436,76 @@ def extra_repr(self):
             name_str)
 
 
+class RReLU(Layer):
+    r"""
+    RReLU activation layer.
+
+    `Empirical Evaluation of Rectified Activations in Convolutional Network`:
+    https://arxiv.org/abs/1505.00853
+
+    .. math::
+        \text{RReLU}(x) =
+        \begin{cases}
+            x  & \text{if } x \geq 0 \\
+            ax & \text{otherwise}
+        \end{cases}
+
+    where :math:`a` is randomly sampled from the uniform distribution
+    :math:`\mathcal{U}(\text{lower}, \text{upper})`.
+
+    Parameters:
+        lower (float, optional): The lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`.
+        upper (float, optional): The upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Shape:
+        - input: Tensor with any shape. Default dtype is float32.
+        - output: Tensor with the same shape as input.
+
+    Examples:
+        .. code-block:: python
+            :name: RReLU-example
+
+            import paddle
+            import numpy as np
+
+            paddle.set_default_dtype("float64")
+
+            data = np.array([[[[-2.0, 3.0, -4.0, 5.0],
+                               [ 3.0, -4.0, 5.0, -6.0],
+                               [-7.0, -8.0, 8.0, 9.0]],
+                              [[ 1.0, -2.0, -3.0, 4.0],
+                               [-5.0, 6.0, 7.0, -8.0],
+                               [ 6.0, 7.0, 8.0, 9.0]]]], 'float64')
+            input_tensor = paddle.to_tensor(data)
+            rrelu_layer = paddle.nn.RReLU(0.1, 0.3)
+            output = rrelu_layer(input_tensor)
+            #[[[[-0.20000899  3.         -0.88108218  5.        ]
+            #   [ 3.         -0.55175185  5.         -1.07761011]
+            #   [-1.06806871 -1.98962009  8.          9.        ]]
+            #  [[ 1.         -0.52382672 -0.65515128  4.        ]
+            #   [-1.37663394  6.          7.         -2.34657836]
+            #   [ 6.          7.          8.          9.        ]]]]
+    """
+
+    def __init__(self, lower=1. / 8., upper=1. / 3., name=None):
+        super(RReLU, self).__init__()
+        self.lower = lower
+        self.upper = upper
+        self.name = name
+
+    def forward(self, x):
+        return F.rrelu(
+            x, lower=self.lower, upper=self.upper,
+            training=self.training, name=self.name)
+
+    def extra_repr(self):
+        name_str = ', name={}'.format(self.name) if self.name else ''
+        return 'lower={}, upper={}{}'.format(
+            self.lower, self.upper, name_str)
+
+
 class ReLU(Layer):
     """
     ReLU Activation.
diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py
index aaa667595f94c..8e90d84d5bf3b 100755
--- a/tools/static_mode_white_list.py
+++ b/tools/static_mode_white_list.py
@@ -398,6 +398,7 @@
     'test_positive_negative_pair_op',
     'test_precision_recall_op',
     'test_prelu_op',
+    'test_rrelu_op',
     'test_prelu_mkldnn_op',
     'test_print_op',
     'test_prior_box_op',
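A quick way to exercise the new API added by this patch (not part of the diff itself): the sketch below assumes a Paddle build that already contains the rrelu kernels together with the paddle.nn.functional.rrelu and paddle.nn.RReLU entry points above. In training mode each negative element is scaled by its own slope drawn from U(lower, upper), so those values change from run to run; in inference mode the fixed slope (lower + upper) / 2 is used.

import paddle
import paddle.nn.functional as F

paddle.seed(2022)
x = paddle.to_tensor([[-1.0, 2.0], [3.0, -4.0]])

# Training mode: negatives get a random slope in [0.1, 0.3]; positives pass through.
y_train = F.rrelu(x, lower=0.1, upper=0.3, training=True)

# Inference mode: negatives are scaled by exactly (0.1 + 0.3) / 2 = 0.2.
y_eval = F.rrelu(x, lower=0.1, upper=0.3, training=False)

# The layer form follows self.training; calling .eval() switches it to the
# deterministic path.
layer = paddle.nn.RReLU(lower=0.1, upper=0.3)
layer.eval()
y_layer = layer(x)

print(y_train.numpy())
print(y_eval.numpy())   # [[-0.2  2. ] [ 3.  -0.8]]
print(y_layer.numpy())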