Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 6th No.25】为 paddle.nn.functional.threshold 进行功能对齐与功能增强 -part #63453

Merged
merged 4 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/phi/kernels/activation_grad_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,6 @@ DECLARE_ACTIVATION_GRAD_KERNEL_NODEP(Floor);
DECLARE_ACTIVATION_GRAD_KERNEL_NODEP(Ceil);

DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu, alpha);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(ThresholdedRelu, threshold);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink, lambda);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(HardShrink, threshold);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Logit, eps);
Expand All @@ -318,5 +317,6 @@ DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(HardTanh, t_min, t_max);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(STanh, scale_a, scale_b);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(Softplus, beta, threshold);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPOUT(HardSigmoid, slope, offset);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(ThresholdedRelu, threshold, value);

} // namespace phi
2 changes: 1 addition & 1 deletion paddle/phi/kernels/activation_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ DECLARE_ACTIVATION_KERNEL(Ceil)
DECLARE_ACTIVATION_KERNEL(Negative)

DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu, alpha)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, threshold)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(SoftShrink, lambda)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Mish, threshold)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(HardShrink, threshold)
Expand All @@ -87,6 +86,7 @@ DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(HardTanh, t_min, t_max)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(STanh, scale_a, scale_b)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(Softplus, beta, threshold)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(HardSigmoid, slope, offset)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(ThresholdedRelu, threshold, value)

template <typename T, typename Context>
void HardSwishKernel(const Context& dev_ctx,
Expand Down
7 changes: 4 additions & 3 deletions paddle/phi/kernels/cpu/activation_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,6 @@ DEFINE_CPU_ACTIVATION_GRAD_KERNEL_NODEP(Ceil, ZeroGradFunctor);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu,
LeakyReluGradFunctor,
alpha);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(ThresholdedRelu,
ThresholdedReluGradFunctor,
threshold);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink,
SoftShrinkGradFunctor,
lambda);
Expand Down Expand Up @@ -188,6 +185,10 @@ DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPOUT(HardSigmoid,
HardSigmoidGradFunctor,
slope,
offset);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(ThresholdedRelu,
ThresholdedReluGradFunctor,
threshold,
value);

template <typename T, typename Context>
void SiluGradKernel(const Context& dev_ctx,
Expand Down
7 changes: 4 additions & 3 deletions paddle/phi/kernels/cpu/activation_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,6 @@ DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Exp, ExpFunctor)
DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Expm1, Expm1Functor)

DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, LeakyReluFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
ThresholdedReluFunctor,
threshold)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, MishFunctor, threshold)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, HardShrinkFunctor, threshold)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, SoftShrinkFunctor, lambda)
Expand All @@ -122,6 +119,10 @@ DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid,
HardSigmoidFunctor,
slope,
offset)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(ThresholdedRelu,
ThresholdedReluFunctor,
threshold,
value)

template <typename T, typename Context>
void HardSwishKernel(const Context& dev_ctx,
Expand Down
20 changes: 12 additions & 8 deletions paddle/phi/kernels/funcs/activation_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1825,22 +1825,25 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
template <typename T>
struct ThresholdedReluFunctor : public BaseActivationFunctor<T> {
float threshold;
float value;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
return {{"threshold", &threshold}, {"value", &value}};
}

template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
auto th = static_cast<T>(threshold); // NOLINT
out.device(d) = (x > th).template cast<T>() * x;
out.device(d) = (x > th).template cast<T>() * x +
(x <= th).template cast<T>() * static_cast<T>(value);
}
};

template <typename T>
struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
float threshold;
float value;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
return {{"threshold", &threshold}, {"value", &value}};
}

template <typename Device,
Expand Down Expand Up @@ -4230,26 +4233,27 @@ struct CudaHardTanhGradFunctor : public BaseActivationFunctor<T> {

template <typename T>
struct CudaThresholdedReluFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
float threshold;
float value;

typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
return {{"threshold", &threshold}, {"value", &value}};
}

// thresholded_relu(x) = x > threshold ? x : 0
// thresholded_relu(x, threshold, value) = x > threshold ? x : value
__device__ __forceinline__ T operator()(const T x) const {
return x > static_cast<T>(threshold) ? x : zero;
return x > static_cast<T>(threshold) ? x : static_cast<T>(value);
}
};

template <typename T>
struct CudaThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
float threshold;
float value;

typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
return {{"threshold", &threshold}, {"value", &value}};
}

// dx = x > threshold ? dout : 0
Expand Down
8 changes: 4 additions & 4 deletions paddle/phi/kernels/gpu/activation_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,6 @@ DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Swish, CudaSwishGradFunctor);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu,
CudaLeakyReluGradFunctor,
alpha);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(ThresholdedRelu,
CudaThresholdedReluGradFunctor,
threshold);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink,
CudaSoftShrinkGradFunctor,
lambda);
Expand Down Expand Up @@ -247,7 +244,10 @@ DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPOUT(HardSigmoid,
CudaHardSigmoidGradFunctor,
slope,
offset);

DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(ThresholdedRelu,
CudaThresholdedReluGradFunctor,
threshold,
value);
template <typename T, typename Context>
void SiluGradKernel(const Context& dev_ctx,
const DenseTensor& x,
Expand Down
7 changes: 4 additions & 3 deletions paddle/phi/kernels/gpu/activation_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,6 @@ DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Expm1, CudaExpm1Functor)

DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LogitCUDA, CudaLogitFunctor, eps)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
CudaThresholdedReluFunctor,
threshold)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink,
CudaHardShrinkFunctor,
threshold)
Expand All @@ -148,6 +145,10 @@ DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid,
slope,
offset)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(Selu, CudaSeluFunctor, scale, alpha)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(ThresholdedRelu,
CudaThresholdedReluFunctor,
threshold,
value)

template <typename T, typename Context>
void HardSwishKernel(const Context& dev_ctx,
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/ops/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3194,8 +3194,8 @@
func : tensor_unfold_grad

- backward_op : thresholded_relu_grad
forward : thresholded_relu (Tensor x, float threshold) -> Tensor(out)
args : (Tensor x, Tensor out_grad, float threshold)
forward : thresholded_relu (Tensor x, float threshold, float value) -> Tensor(out)
args : (Tensor x, Tensor out_grad, float threshold, float value)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
Expand Down
8 changes: 8 additions & 0 deletions paddle/phi/ops/yaml/op_version.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,14 @@
comment : A flag to indicate whether to do softmax
default : "true"

- op : thresholded_relu
version :
- checkpoint : Upgrade thresholded_relu, add a new attribute [value]
action :
- add_attr : value
comment : The threshold value of thresholded_relu.
default : 0.0

- op : trace
version :
- checkpoint : Upgrade trace add a new attribute [axis2]
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4209,7 +4209,7 @@
no_need_buffer : input

- op : thresholded_relu
args : (Tensor x, float threshold = 1.0)
args : (Tensor x, float threshold = 1.0, float value = 0.0)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
Expand Down
13 changes: 7 additions & 6 deletions python/paddle/nn/functional/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1543,7 +1543,7 @@ def tanhshrink(x, name=None):
return out


def thresholded_relu(x, threshold=1.0, name=None):
def thresholded_relu(x, threshold=1.0, value=0.0, name=None):
r"""
thresholded relu activation.

Expand All @@ -1553,14 +1553,15 @@ def thresholded_relu(x, threshold=1.0, name=None):
\left\{
\begin{array}{rl}
x,& \text{if } \ x > threshold \\
0,& \text{otherwise}
value,& \text{otherwise}
\end{array}
\right.


Parameters:
x (Tensor): The input Tensor with data type float32, float64.
threshold (float, optional): The value of threshold for thresholded_relu. Default is 1.0
value (float, optional): The value to replace with when x is less than threshold. Default is 0.0
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

Returns:
Expand All @@ -1580,7 +1581,7 @@ def thresholded_relu(x, threshold=1.0, name=None):
"""

if in_dynamic_or_pir_mode():
return _C_ops.thresholded_relu(x, threshold)
return _C_ops.thresholded_relu(x, threshold, value)
else:
check_variable_and_dtype(
x,
Expand All @@ -1594,19 +1595,19 @@ def thresholded_relu(x, threshold=1.0, name=None):
type='thresholded_relu',
inputs={'X': x},
outputs={'Out': out},
attrs={'threshold': threshold},
attrs={'threshold': threshold, 'value': value},
)
return out


@inplace_apis_in_dygraph_only
def thresholded_relu_(x, threshold=1.0, name=None):
def thresholded_relu_(x, threshold=1.0, value=0.0, name=None):
r"""
Inplace version of ``thresholded_relu`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_nn_functional_thresholded_relu`.
"""
if in_dynamic_mode():
return _C_ops.thresholded_relu_(x, threshold)
return _C_ops.thresholded_relu_(x, threshold, value)


def log_softmax(x, axis=-1, dtype=None, name=None):
Expand Down
10 changes: 6 additions & 4 deletions python/paddle/nn/layer/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1164,13 +1164,14 @@ class ThresholdedReLU(Layer):
\left\{
\begin{array}{rl}
x,& \text{if } \ x > threshold \\
0,& \text{otherwise}
value,& \text{otherwise}
\end{array}
\right.


Parameters:
threshold (float, optional): The value of threshold for ThresholdedReLU. Default is 1.0
value (float, optinal): The value to replace with when x is less than threshold. Default is 0.0
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.

Expand All @@ -1191,17 +1192,18 @@ class ThresholdedReLU(Layer):
[2., 0., 0.])
"""

def __init__(self, threshold=1.0, name=None):
def __init__(self, threshold=1.0, value=0.0, name=None):
super().__init__()
self._threshold = threshold
self._value = value
self._name = name

def forward(self, x):
return F.thresholded_relu(x, self._threshold, self._name)
return F.thresholded_relu(x, self._threshold, self._value, self._name)

def extra_repr(self):
name_str = f', name={self._name}' if self._name else ''
return f'threshold={self._threshold}{name_str}'
return f'threshold={self._threshold}, value={self._value}{name_str}'


class Silu(Layer):
Expand Down
Loading