【PaddlePaddle Hackathon 2】16 Add new API RRelu (PaddlePaddle#41823)
* rrelu logic
* unregistered op kernel (unresolved)
* commit before merge
* enrich test cases
* fix bug in rrelu-sig
* fix tests in the CPU environment
* fix typos
* fix code format
* try to fix the test-case timeout issue
* optimize test cases
* remove seed, improve the random function
* update en doc for rrelu
* fix rrelu en docs, test=document_fix
* add paper link for en docs, test=document_fix
* update en doc
* add r, test=document_fix
Showing 17 changed files with 1,129 additions and 0 deletions.
@@ -0,0 +1,126 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <memory>
#include <string>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {

using framework::Tensor;

class RReluOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
  }
};

class RReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "The input of RReLU op.");
    AddOutput("Out", "The output of RReLU op.");
    AddOutput("Noise", "The random sampled RReLU noise.")
        .AsIntermediate()
        .AsExtra();
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, "
                  "false for training. Some layers may run faster when this "
                  "is true.")
        .SetDefault(false);
    float default_lower = 1. / 8.;
    AddAttr<float>("lower", "Lower bound of the uniform distribution.")
        .SetDefault(default_lower)
        .AddCustomChecker([](const float& lower) {
          PADDLE_ENFORCE_EQ(lower >= 0.0f && lower < 1.0f, true,
                            platform::errors::InvalidArgument(
                                "'RRelu_lower' must be in the range "
                                "[0.0, 1.0)."));
        });
    float default_upper = 1. / 3.;
    AddAttr<float>("upper", "Upper bound of the uniform distribution.")
        .SetDefault(default_upper)
        .AddCustomChecker([](const float& upper) {
          PADDLE_ENFORCE_EQ(upper > 0.0f && upper <= 1.0f, true,
                            platform::errors::InvalidArgument(
                                "'RRelu_upper' must be in the range "
                                "(0.0, 1.0]."));
        });
    AddComment(R"DOC(
RReLU Operator.

Applies the randomized leaky rectified linear unit function, element-wise,
as described in the paper:
`Empirical Evaluation of Rectified Activations in Convolutional Network`_.

The function is defined as:

.. math::
    \text{RReLU}(x) =
    \begin{cases}
        x & \text{if } x \geq 0 \\
        ax & \text{otherwise}
    \end{cases}

where :math:`a` is randomly sampled from the uniform distribution
:math:`\mathcal{U}(\text{lower}, \text{upper})`.

See: https://arxiv.org/pdf/1505.00853.pdf
)DOC");
  }
};

class RReluGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
};

template <typename T>
class RReluGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("rrelu_grad");
    op->SetInput("X", this->Input("X"));
    op->SetInput("Noise", this->Output("Noise"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(rrelu, RReluInferShapeFunctor,
                            PD_INFER_META(phi::RReluInferMeta));

REGISTER_OPERATOR(rrelu, ops::RReluOp, ops::RReluOpMaker,
                  ops::RReluGradOpMaker<paddle::framework::OpDesc>,
                  ops::RReluGradOpMaker<paddle::imperative::OpBase>,
                  RReluInferShapeFunctor);

DECLARE_INFER_SHAPE_FUNCTOR(rrelu_grad, RReluGradInferShapeFunctor,
                            PD_INFER_META(phi::RReluGradInferMeta));
REGISTER_OPERATOR(rrelu_grad, ops::RReluGradOp, RReluGradInferShapeFunctor);
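For reference, the grad maker above wires X, Noise, and Out@GRAD into rrelu_grad, which produces X@GRAD. Treating the sampled slope :math:`a` as a constant (it is drawn in the forward pass and saved in Noise), the DOC formula gives the element-wise derivative

.. math::
    \frac{\partial\, \text{RReLU}(x)}{\partial x} =
    \begin{cases}
        1 & \text{if } x \geq 0 \\
        a & \text{otherwise}
    \end{cases}

so the backward pass needs only the saved noise tensor, never a fresh sample; this is exactly what the CPU grad kernel below computes.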
@@ -0,0 +1,44 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/rrelu_grad_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void RReluGradKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     const DenseTensor& noise,
                     const DenseTensor& out_grad,
                     DenseTensor* x_grad) {
  if (!x_grad) return;
  const T* n_ptr = noise.data<T>();
  const T* x_ptr = x.data<T>();
  const T* out_grad_ptr = out_grad.data<T>();
  int numel = x.numel();

  // The forward pass recorded the slope of every element in `noise`
  // (1 for non-negative inputs), so the gradient is out_grad scaled by it.
  T* x_grad_ptr = dev_ctx.template Alloc<T>(x_grad);
  for (int i = 0; i < numel; i++) {
    x_grad_ptr[i] =
        x_ptr[i] > 0 ? out_grad_ptr[i] : n_ptr[i] * out_grad_ptr[i];
  }
}

}  // namespace phi

PD_REGISTER_KERNEL(
    rrelu_grad, CPU, ALL_LAYOUT, phi::RReluGradKernel, float, double) {}
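This backward rule is easy to sanity-check outside the framework with a finite-difference comparison, holding the sampled slope fixed the way the saved noise tensor does. A minimal standalone C++ sketch (rrelu_ref and rrelu_grad_ref are illustrative helpers, not Paddle API):

#include <cassert>
#include <cmath>
#include <cstdio>

// Reference forward with a fixed, pre-sampled slope `a` (the "noise").
double rrelu_ref(double x, double a) { return x >= 0 ? x : a * x; }

// Analytic gradient, matching RReluGradKernel: 1 for x > 0, `a` otherwise.
double rrelu_grad_ref(double x, double a) { return x > 0 ? 1.0 : a; }

int main() {
  const double a = 0.25;  // stands in for a draw from U(lower, upper)
  const double eps = 1e-6;
  const double xs[] = {-2.0, -0.5, 0.5, 3.0};  // avoid the kink at x = 0
  for (double x : xs) {
    double fd = (rrelu_ref(x + eps, a) - rrelu_ref(x - eps, a)) / (2 * eps);
    double an = rrelu_grad_ref(x, a);
    std::printf("x=%5.2f  analytic=%.6f  finite-diff=%.6f\n", x, an, fd);
    assert(std::fabs(fd - an) < 1e-4);
  }
  return 0;
}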
@@ -0,0 +1,77 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/rrelu_kernel.h"

#include <random>

#include "paddle/fluid/framework/generator.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void RReluKernel(const Context& dev_ctx,
                 const DenseTensor& x,
                 const float lower,
                 const float upper,
                 bool is_test,
                 DenseTensor* out,
                 DenseTensor* noise) {
  const T* x_ptr = x.data<T>();
  T* o_ptr = dev_ctx.template Alloc<T>(out);
  T* n_ptr = dev_ctx.template Alloc<T>(noise);
  T zero = static_cast<T>(0);
  int numel = x.numel();

  if (is_test) {
    // Inference mode: use the fixed midpoint of [lower, upper] as the
    // negative slope instead of sampling.
    T mid_val = static_cast<T>((lower + upper) / 2.0);
    for (int i = 0; i < numel; i++) {
      if (x_ptr[i] < zero) {
        o_ptr[i] = mid_val * x_ptr[i];
        n_ptr[i] = mid_val;
      } else {
        o_ptr[i] = x_ptr[i];
        n_ptr[i] = static_cast<T>(1.0);
      }
    }
    return;
  }

  // Training mode: sample a fresh slope from U(lower, upper) for every
  // negative element and record it in `noise` for the backward pass.
  auto engine = paddle::framework::GetCPURandomEngine(0);
  std::uniform_real_distribution<float> dist(lower, upper);

  for (int i = 0; i < numel; i++) {
    if (x_ptr[i] < zero) {
      T scale = static_cast<T>(dist(*engine));
      o_ptr[i] = scale * x_ptr[i];
      n_ptr[i] = scale;
    } else {
      o_ptr[i] = x_ptr[i];
      n_ptr[i] = static_cast<T>(1.0);
    }
  }
}

}  // namespace phi

PD_REGISTER_KERNEL(rrelu,
                   CPU,
                   ALL_LAYOUT,
                   phi::RReluKernel,
                   float,
                   phi::dtype::float16,
                   double) {}
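To see the two branches side by side without the framework, here is a rough standalone C++ sketch; std::mt19937 stands in for paddle::framework::GetCPURandomEngine, and rrelu_cpu is an illustrative helper, not Paddle API. With the default bounds, the inference slope is the midpoint (1/8 + 1/3) / 2 = 11/48 ≈ 0.229, while training draws an independent slope per negative element and records it in noise:

#include <cstdio>
#include <random>
#include <vector>

// Minimal stand-in for the kernel above: `noise` stores the slope that the
// backward pass reuses (1 for non-negative inputs).
void rrelu_cpu(const std::vector<float>& x, float lower, float upper,
               bool is_test, std::mt19937& engine,
               std::vector<float>* out, std::vector<float>* noise) {
  std::uniform_real_distribution<float> dist(lower, upper);
  const float mid = (lower + upper) / 2.0f;
  out->resize(x.size());
  noise->resize(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    float slope = (x[i] < 0.0f) ? (is_test ? mid : dist(engine)) : 1.0f;
    (*out)[i] = slope * x[i];
    (*noise)[i] = slope;
  }
}

int main() {
  std::mt19937 engine(42);  // fixed seed, only for a reproducible demo
  std::vector<float> x = {-1.0f, -1.0f, 2.0f}, out, noise;
  rrelu_cpu(x, 1.0f / 8.0f, 1.0f / 3.0f, /*is_test=*/false, engine,
            &out, &noise);
  std::printf("train: out[0]=%.4f out[1]=%.4f (independent slopes)\n",
              out[0], out[1]);
  rrelu_cpu(x, 1.0f / 8.0f, 1.0f / 3.0f, /*is_test=*/true, engine,
            &out, &noise);
  std::printf("test:  out[0]=%.4f == out[1]=%.4f == -(11/48)\n",
              out[0], out[1]);
  return 0;
}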