From 44f259963025caa23975f4de588879c7fb83881f Mon Sep 17 00:00:00 2001 From: PuQing Date: Wed, 27 Mar 2024 03:17:16 +0000 Subject: [PATCH 1/3] Update thresholded_relu forward and backward functions --- paddle/phi/api/yaml/backward.yaml | 4 +- paddle/phi/kernels/activation_grad_kernel.h | 2 +- paddle/phi/kernels/activation_kernel.h | 2 +- .../phi/kernels/cpu/activation_grad_kernel.cc | 7 ++- paddle/phi/kernels/cpu/activation_kernel.cc | 7 ++- paddle/phi/kernels/funcs/activation_functor.h | 19 +++--- .../phi/kernels/gpu/activation_grad_kernel.cu | 7 ++- paddle/phi/kernels/gpu/activation_kernel.cu | 7 ++- python/paddle/nn/functional/activation.py | 13 ++-- python/paddle/nn/layer/activation.py | 10 +-- test/legacy_test/test_activation_op.py | 62 ++++++++++++------- 11 files changed, 85 insertions(+), 55 deletions(-) diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index c53f81cad71f4..20e974ebab000 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -2504,8 +2504,8 @@ func : tensor_unfold_grad - backward_op : thresholded_relu_grad - forward : thresholded_relu (Tensor x, float threshold) -> Tensor(out) - args : (Tensor x, Tensor out_grad, float threshold) + forward : thresholded_relu (Tensor x, float threshold, float value) -> Tensor(out) + args : (Tensor x, Tensor out_grad, float threshold, float value) output : Tensor(x_grad) infer_meta : func : UnchangedInferMeta diff --git a/paddle/phi/kernels/activation_grad_kernel.h b/paddle/phi/kernels/activation_grad_kernel.h index b2fae7b0406e0..1db80db5248d8 100644 --- a/paddle/phi/kernels/activation_grad_kernel.h +++ b/paddle/phi/kernels/activation_grad_kernel.h @@ -307,7 +307,6 @@ DECLARE_ACTIVATION_GRAD_KERNEL_NODEP(Floor); DECLARE_ACTIVATION_GRAD_KERNEL_NODEP(Ceil); DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu, alpha); -DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(ThresholdedRelu, threshold); DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink, lambda); DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(HardShrink, threshold); DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Logit, eps); @@ -317,6 +316,7 @@ DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Celu, alpha); DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(HardTanh, t_min, t_max); DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(STanh, scale_a, scale_b); DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(Softplus, beta, threshold); +DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(ThresholdedRelu, threshold, value); DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPOUT(HardSigmoid, slope, offset); } // namespace phi diff --git a/paddle/phi/kernels/activation_kernel.h b/paddle/phi/kernels/activation_kernel.h index 70c0187e68865..bf3cb325160d3 100644 --- a/paddle/phi/kernels/activation_kernel.h +++ b/paddle/phi/kernels/activation_kernel.h @@ -74,7 +74,6 @@ DECLARE_ACTIVATION_KERNEL(Ceil) DECLARE_ACTIVATION_KERNEL(Negative) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu, alpha) -DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, threshold) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(SoftShrink, lambda) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Mish, threshold) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(HardShrink, threshold) @@ -87,6 +86,7 @@ DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(HardTanh, t_min, t_max) DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(STanh, scale_a, scale_b) DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(Softplus, beta, threshold) DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(HardSigmoid, slope, offset) 
+DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(ThresholdedRelu, threshold, value) template void HardSwishKernel(const Context& dev_ctx, diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc index cb821233004f8..52972b51d009c 100644 --- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc @@ -155,9 +155,6 @@ DEFINE_CPU_ACTIVATION_GRAD_KERNEL_NODEP(Ceil, ZeroGradFunctor); DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu, LeakyReluGradFunctor, alpha); -DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(ThresholdedRelu, - ThresholdedReluGradFunctor, - threshold); DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink, SoftShrinkGradFunctor, lambda); @@ -184,6 +181,10 @@ DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(Softplus, SoftplusGradFunctor, beta, threshold); +DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(ThresholdedRelu, + ThresholdedReluGradFunctor, + threshold, + value); DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPOUT(HardSigmoid, HardSigmoidGradFunctor, slope, diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc index 11312aa3a7972..f3c3d8c022821 100644 --- a/paddle/phi/kernels/cpu/activation_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_kernel.cc @@ -106,9 +106,6 @@ DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Exp, ExpFunctor) DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Expm1, Expm1Functor) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, LeakyReluFunctor, alpha) -DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, - ThresholdedReluFunctor, - threshold) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, MishFunctor, threshold) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, HardShrinkFunctor, threshold) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, SoftShrinkFunctor, lambda) @@ -122,6 +119,10 @@ DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid, HardSigmoidFunctor, slope, offset) +DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(ThresholdedRelu, + ThresholdedReluFunctor, + threshold, + value) template void HardSwishKernel(const Context& dev_ctx, diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h index 8b83fcb0d10c1..42d5c326711c3 100644 --- a/paddle/phi/kernels/funcs/activation_functor.h +++ b/paddle/phi/kernels/funcs/activation_functor.h @@ -1825,22 +1825,25 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor { template struct ThresholdedReluFunctor : public BaseActivationFunctor { float threshold; + float value; typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; + return {{"threshold", &threshold}, {"value", &value}}; } template void operator()(Device d, X x, Out out) const { auto th = static_cast(threshold); // NOLINT - out.device(d) = (x > th).template cast() * x; + out.device(d) = (x > th).template cast() * x + + (x <= th).template cast() * static_cast(value); } }; template struct ThresholdedReluGradFunctor : public BaseActivationFunctor { float threshold; + float value; typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; + return {{"threshold", &threshold}, {"value", &value}}; } template struct CudaThresholdedReluFunctor : public BaseActivationFunctor { T zero = static_cast(0.0f); float threshold; + float value; typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; + return {{"threshold", &threshold}, {"value", &value}}; } - // 
thresholded_relu(x) = x > threshold ? x : 0 + // thresholded_relu(x, threshold, value) = x > threshold ? x : value __device__ __forceinline__ T operator()(const T x) const { - return x > static_cast(threshold) ? x : zero; + return x > static_cast(threshold) ? x : static_cast(value); } }; @@ -4115,9 +4119,10 @@ template struct CudaThresholdedReluGradFunctor : public BaseActivationFunctor { T zero = static_cast(0.0f); float threshold; + float value; typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; + return {{"threshold", &threshold}, {"value", &value}}; } // dx = x > threshold ? dout : 0 diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu index 7af857345cdd6..a052fca83a799 100644 --- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu @@ -209,9 +209,6 @@ DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Swish, CudaSwishGradFunctor); DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu, CudaLeakyReluGradFunctor, alpha); -DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(ThresholdedRelu, - CudaThresholdedReluGradFunctor, - threshold); DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink, CudaSoftShrinkGradFunctor, lambda); @@ -243,6 +240,10 @@ DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(Softplus, CudaSoftplusGradFunctor, beta, threshold); +DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(ThresholdedRelu, + CudaThresholdedReluGradFunctor, + threshold, + value); DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPOUT(HardSigmoid, CudaHardSigmoidGradFunctor, slope, diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu index e8dadf31fd945..30ee664cae556 100644 --- a/paddle/phi/kernels/gpu/activation_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_kernel.cu @@ -123,9 +123,6 @@ DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Expm1, CudaExpm1Functor) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LogitCUDA, CudaLogitFunctor, eps) -DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, - CudaThresholdedReluFunctor, - threshold) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, CudaHardShrinkFunctor, threshold) @@ -148,6 +145,10 @@ DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid, slope, offset) DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(Selu, CudaSeluFunctor, scale, alpha) +DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(ThresholdedRelu, + CudaThresholdedReluFunctor, + threshold, + value) template void HardSwishKernel(const Context& dev_ctx, diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index acf85a5f675ce..7a16135d6c3ff 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -1547,7 +1547,7 @@ def tanhshrink(x, name=None): return out -def thresholded_relu(x, threshold=1.0, name=None): +def thresholded_relu(x, threshold=1.0, value=0.0, name=None): r""" thresholded relu activation. @@ -1557,7 +1557,7 @@ def thresholded_relu(x, threshold=1.0, name=None): \left\{ \begin{array}{rl} x,& \text{if } \ x > threshold \\ - 0,& \text{otherwise} + value,& \text{otherwise} \end{array} \right. @@ -1565,6 +1565,7 @@ def thresholded_relu(x, threshold=1.0, name=None): Parameters: x (Tensor): The input Tensor with data type float32, float64. threshold (float, optional): The value of threshold for thresholded_relu. 
Default is 1.0 + value (float, optional): The value to replace with when x is less than or equal to threshold. Default is 0.0 name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: @@ -1584,7 +1585,7 @@ """ if in_dynamic_or_pir_mode(): - return _C_ops.thresholded_relu(x, threshold) + return _C_ops.thresholded_relu(x, threshold, value) else: check_variable_and_dtype( x, @@ -1598,19 +1599,19 @@ type='thresholded_relu', inputs={'X': x}, outputs={'Out': out}, - attrs={'threshold': threshold}, + attrs={'threshold': threshold, 'value': value}, ) return out @inplace_apis_in_dygraph_only -def thresholded_relu_(x, threshold=1.0, name=None): +def thresholded_relu_(x, threshold=1.0, value=0.0, name=None): r""" Inplace version of ``thresholded_relu`` API, the output Tensor will be inplaced with input ``x``. Please refer to :ref:`api_paddle_nn_functional_thresholded_relu`. """ if in_dynamic_mode(): - return _C_ops.thresholded_relu_(x, threshold) + return _C_ops.thresholded_relu_(x, threshold, value) def log_softmax(x, axis=-1, dtype=None, name=None): diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index 59a9436dadb51..5430d994bc847 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -1172,13 +1172,14 @@ class ThresholdedReLU(Layer): \left\{ \begin{array}{rl} x,& \text{if } \ x > threshold \\ - 0,& \text{otherwise} + value,& \text{otherwise} \end{array} \right. Parameters: threshold (float, optional): The value of threshold for ThresholdedReLU. Default is 1.0 + value (float, optional): The value to replace with when x is less than or equal to threshold. Default is 0.0 name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. 
@@ -1199,17 +1200,18 @@ class ThresholdedReLU(Layer): [2., 0., 0.]) """ - def __init__(self, threshold=1.0, name=None): + def __init__(self, threshold=1.0, value=0.0, name=None): super().__init__() self._threshold = threshold + self._value = value self._name = name def forward(self, x): - return F.thresholded_relu(x, self._threshold, self._name) + return F.thresholded_relu(x, self._threshold, self._value, self._name) def extra_repr(self): name_str = f', name={self._name}' if self._name else '' - return f'threshold={self._threshold}{name_str}' + return f'threshold={self._threshold}, value={self._value}{name_str}' class Silu(Layer): diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py index 2607f9a170ecb..6e812f4ba939c 100644 --- a/test/legacy_test/test_activation_op.py +++ b/test/legacy_test/test_activation_op.py @@ -3287,26 +3287,34 @@ def test_check_grad(self): self.check_grad( ['X'], 'Out', - check_prim=True - if self.dtype not in [np.complex64, np.complex128] - else False, + check_prim=( + True + if self.dtype not in [np.complex64, np.complex128] + else False + ), only_check_prim=self.if_only_check_prim(), check_pir=True, - check_prim_pir=True - if self.dtype not in [np.complex64, np.complex128] - else False, + check_prim_pir=( + True + if self.dtype not in [np.complex64, np.complex128] + else False + ), check_pir_onednn=self.check_pir_onednn, ) def test_check_output(self): self.check_output( - check_prim=True - if self.dtype not in [np.complex64, np.complex128] - else False, + check_prim=( + True + if self.dtype not in [np.complex64, np.complex128] + else False + ), check_pir=True, - check_prim_pir=True - if self.dtype not in [np.complex64, np.complex128] - else False, + check_prim_pir=( + True + if self.dtype not in [np.complex64, np.complex128] + else False + ), check_pir_onednn=self.check_pir_onednn, ) @@ -4720,8 +4728,8 @@ def test_errors(self): F.softsign(x_fp16) -def ref_thresholded_relu(x, threshold=1.0): - out = (x > threshold) * x +def ref_thresholded_relu(x, threshold=1.0, value=0.0): + out = (x > threshold) * x + (x <= threshold) * value return out @@ -4733,15 +4741,16 @@ def setUp(self): self.python_api = paddle.nn.functional.thresholded_relu threshold = 15 + value = 5 np.random.seed(1024) x = np.random.uniform(-20, 20, self.shape).astype(self.dtype) x[np.abs(x) < 0.005] = 0.02 - out = ref_thresholded_relu(x, threshold) + out = ref_thresholded_relu(x, threshold, value) self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)} self.outputs = {'Out': out} - self.attrs = {"threshold": threshold} + self.attrs = {"threshold": threshold, "value": value} self.convert_input_output() def init_shape(self): @@ -4769,6 +4778,7 @@ class TestThresholdedReluAPI(unittest.TestCase): # test paddle.nn.ThresholdedReLU, paddle.nn.functional.thresholded_relu def setUp(self): self.threshold = 15 + self.value = 5 np.random.seed(1024) self.x_np = np.random.uniform(-20, 20, [10, 12]).astype(np.float64) self.x_np[np.abs(self.x_np) < 0.005] = 0.02 @@ -4783,22 +4793,30 @@ def test_static_api(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype) - out1 = F.thresholded_relu(x, self.threshold) - thresholded_relu = paddle.nn.ThresholdedReLU(self.threshold) + out1 = F.thresholded_relu(x, self.threshold, self.value) + thresholded_relu = paddle.nn.ThresholdedReLU( + self.threshold, self.value + ) out2 = thresholded_relu(x) exe = paddle.static.Executor(self.place) res = 
exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2]) - out_ref = ref_thresholded_relu(self.x_np, self.threshold) + out_ref = ref_thresholded_relu( + self.x_np, self.threshold, self.value + ) for r in res: np.testing.assert_allclose(out_ref, r, rtol=1e-05) def test_dygraph_api(self): with dynamic_guard(): x = paddle.to_tensor(self.x_np) - out1 = F.thresholded_relu(x, self.threshold) - thresholded_relu = paddle.nn.ThresholdedReLU(self.threshold) + out1 = F.thresholded_relu(x, self.threshold, self.value) + thresholded_relu = paddle.nn.ThresholdedReLU( + self.threshold, self.value + ) out2 = thresholded_relu(x) - out_ref = ref_thresholded_relu(self.x_np, self.threshold) + out_ref = ref_thresholded_relu( + self.x_np, self.threshold, self.value + ) for r in [out1, out2]: np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) From 1d4a2a6dd16e1aa446a1e44c8b3aceacc9ba9f36 Mon Sep 17 00:00:00 2001 From: PuQing Date: Wed, 27 Mar 2024 03:19:07 +0000 Subject: [PATCH 2/3] Add weight parameter to histogram op --- paddle/phi/api/yaml/op_compat.yaml | 2 +- paddle/phi/api/yaml/ops.yaml | 5 +- paddle/phi/infermeta/unary.cc | 28 ++++++- paddle/phi/infermeta/unary.h | 9 ++- paddle/phi/kernels/cpu/histogram_kernel.cc | 56 ++++++++++---- paddle/phi/kernels/gpu/histogram_kernel.cu | 89 +++++++++++++++++----- paddle/phi/kernels/histogram_kernel.h | 2 + python/paddle/tensor/__init__.py | 1 + python/paddle/tensor/linalg.py | 77 +++++++++++++++++-- test/legacy_test/test_histogram_op.py | 47 ++++++++++++ 10 files changed, 269 insertions(+), 47 deletions(-) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 0c3f7488362eb..ebe70b05724e8 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1745,7 +1745,7 @@ - op : histogram inputs : - input : X + {input: X, weight: Weight} outputs : out : Out diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 3693e31721c14..44a67f5f21938 100755 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1336,8 +1336,9 @@ backward : heaviside_grad - op : histogram - args : (Tensor input, int64_t bins = 100, int min = 0, int max = 0) + args : (Tensor input, Tensor weight, bool density = false, int64_t bins = 100, int min = 0, int max = 0) output : Tensor(out) + optional: weight infer_meta : func : HistogramInferMeta kernel : @@ -2837,7 +2838,7 @@ no_need_buffer : input - op : thresholded_relu - args : (Tensor x, float threshold = 1.0) + args : (Tensor x, float threshold = 1.0, float value = 0.0) output : Tensor(out) infer_meta : func : UnchangedInferMeta diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 64262af8885d9..1cdb3cb8105b0 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -22,6 +22,7 @@ limitations under the License. 
*/ #include "paddle/phi/common/type_traits.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/core/utils/data_type.h" #include "paddle/phi/kernels/funcs/parse_qr_mode.h" #include "paddle/phi/kernels/funcs/pooling.h" @@ -1916,8 +1917,13 @@ void GumbelSoftmaxInferMeta(const MetaTensor& x, UnchangedInferMetaCheckAxis(x, axis, out); } -void HistogramInferMeta( - const MetaTensor& input, int64_t bins, int min, int max, MetaTensor* out) { +void HistogramInferMeta(const MetaTensor& input, + const MetaTensor& weight, + bool density, + int64_t bins, + int min, + int max, + MetaTensor* out) { PADDLE_ENFORCE_GE(bins, 1, phi::errors::InvalidArgument( @@ -1932,9 +1938,25 @@ void HistogramInferMeta( max, min)); + if (weight) { + auto weight_dims = weight.dims(); + PADDLE_ENFORCE_EQ( + weight_dims, + input.dims(), + phi::errors::InvalidArgument( + "The shape of weight should be equal to the shape of input." + "But received weight shape is [%s], input shape is [%s]", + weight_dims, + input.dims())); + } + out->set_dims({bins}); out->share_lod(input); - out->set_dtype(DataType::INT64); + if (density || weight) { + out->set_dtype(DataType::FLOAT32); + } else { + out->set_dtype(DataType::INT64); + } } void IdentityLossInferMeta(const MetaTensor& x, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 3314545faa185..8e4df067698dc 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -307,8 +307,13 @@ void GumbelSoftmaxInferMeta(const MetaTensor& x, bool hard, int axis, MetaTensor* out); -void HistogramInferMeta( - const MetaTensor& input, int64_t bins, int min, int max, MetaTensor* out); +void HistogramInferMeta(const MetaTensor& input, + const MetaTensor& weight, + bool density, + int64_t bins, + int min, + int max, + MetaTensor* out); void IdentityLossInferMeta(const MetaTensor& x, int reduction, MetaTensor* out); diff --git a/paddle/phi/kernels/cpu/histogram_kernel.cc b/paddle/phi/kernels/cpu/histogram_kernel.cc index 030dee9908b31..5bd94a7e1fcef 100644 --- a/paddle/phi/kernels/cpu/histogram_kernel.cc +++ b/paddle/phi/kernels/cpu/histogram_kernel.cc @@ -13,16 +13,24 @@ // limitations under the License. #include "paddle/phi/kernels/histogram_kernel.h" +#include #include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/backends/device_ext.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" +#include "paddle/utils/optional.h" namespace phi { - template void HistogramKernel(const Context& dev_ctx, const DenseTensor& input, + const paddle::optional& weight, + bool density, int64_t bins, int min, int max, @@ -32,12 +40,9 @@ void HistogramKernel(const Context& dev_ctx, auto& maxval = max; const T* input_data = input.data(); + auto weight_data = weight.get_ptr() ? 
weight.get_ptr()->data() : nullptr; auto input_numel = input.numel(); - int64_t* out_data = dev_ctx.template Alloc(output); - phi::funcs::SetConstant()( - dev_ctx, output, static_cast(0)); - if (input_data == nullptr) return; T output_min = static_cast(minval); @@ -67,11 +72,38 @@ void HistogramKernel(const Context& dev_ctx, maxval, minval)); - for (int64_t i = 0; i < input_numel; i++) { - if (input_data[i] >= output_min && input_data[i] <= output_max) { - const int64_t bin = (int64_t)((input_data[i] - output_min) * nbins / - (output_max - output_min)); - out_data[std::min(bin, nbins - 1)] += 1; + if (density || weight_data) { + float* out_data = dev_ctx.template Alloc(output); + phi::funcs::SetConstant()( + dev_ctx, output, static_cast(0)); + for (int64_t i = 0; i < input_numel; i++) { + if (input_data[i] >= output_min && input_data[i] <= output_max) { + const int64_t bin = (int64_t)((input_data[i] - output_min) * nbins / + (output_max - output_min)); + out_data[std::min(bin, nbins - 1)] += + weight_data ? static_cast(weight_data[i]) : 1; + } + } + if (density) { + DenseTensor sum = phi::Sum( + dev_ctx, *output, phi::IntArray({0}), phi::DataType::FLOAT32, false); + float* sum_data = sum.data(); + float gap = static_cast(nbins) / + static_cast((output_max - output_min)) / *sum_data; + for (int64_t i = 0; i < nbins; i++) { + out_data[i] *= gap; + } + } + } else { + int64_t* out_data = dev_ctx.template Alloc(output); + phi::funcs::SetConstant()( + dev_ctx, output, static_cast(0)); + for (int64_t i = 0; i < input_numel; i++) { + if (input_data[i] >= output_min && input_data[i] <= output_max) { + const int64_t bin = (int64_t)((input_data[i] - output_min) * nbins / + (output_max - output_min)); + out_data[std::min(bin, nbins - 1)] += 1; + } } } } @@ -85,6 +117,4 @@ PD_REGISTER_KERNEL(histogram, float, double, int, - int64_t) { - kernel->OutputAt(0).SetDataType(paddle::DataType::INT64); -} + int64_t) {} diff --git a/paddle/phi/kernels/gpu/histogram_kernel.cu b/paddle/phi/kernels/gpu/histogram_kernel.cu index aa10aea35f867..b0eaba8b6889c 100644 --- a/paddle/phi/kernels/gpu/histogram_kernel.cu +++ b/paddle/phi/kernels/gpu/histogram_kernel.cu @@ -13,13 +13,23 @@ // limitations under the License. 
#include "paddle/phi/kernels/histogram_kernel.h" +#include +#include +#include "paddle/phi/api/include/tensor.h" #include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/broadcast_function.h" +#include "paddle/phi/kernels/funcs/functors.h" #include "paddle/phi/kernels/funcs/math_cuda_utils.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" namespace phi { @@ -41,14 +51,15 @@ __device__ static IndexType GetBin(T input_value, return output_index; } -template +template __global__ void KernelHistogram(const T* input, + const T* weight, const int total_elements, const int64_t nbins, const T* min_value, const T* max_value, - int64_t* output) { - extern __shared__ int64_t buf_hist[]; + Out_T* output) { + extern __shared__ float buf_hist[]; for (int i = threadIdx.x; i < nbins; i += blockDim.x) { buf_hist[i] = 0; } @@ -60,7 +71,8 @@ __global__ void KernelHistogram(const T* input, if (input_value >= *min_value && input_value <= *max_value) { const IndexType output_index = GetBin(input_value, *min_value, *max_value, nbins); - phi::CudaAtomicAdd(&buf_hist[output_index], 1); + phi::CudaAtomicAdd(&buf_hist[output_index], + weight ? static_cast(weight[input_index]) : 1); } } __syncthreads(); @@ -124,9 +136,18 @@ __global__ void KernelMinMax(const T min_value, } } +__global__ void KernelMul(float* data, float* scale, int64_t numel) { + size_t index = threadIdx.x + blockIdx.x * blockDim.x; + if (index < numel) { + data[index] /= *scale; + } +} + template void HistogramKernel(const Context& dev_ctx, const DenseTensor& input, + const paddle::optional& weight, + bool density, int64_t bins, int min, int max, @@ -137,10 +158,7 @@ void HistogramKernel(const Context& dev_ctx, const T* input_data = input.data(); const int input_numel = input.numel(); - - int64_t* out_data = dev_ctx.template Alloc(output); - phi::funcs::SetConstant()( - dev_ctx, output, static_cast(0)); + auto weight_data = weight.get_ptr() ? 
weight.get_ptr()->data() : nullptr; if (input_data == nullptr) return; @@ -179,13 +197,50 @@ void HistogramKernel(const Context& dev_ctx, minval)); auto stream = dev_ctx.stream(); - KernelHistogram<<>>( - input_data, input_numel, nbins, min_block_ptr, max_block_ptr, out_data); -} + if (!density && !weight_data) { + int64_t* out_data = dev_ctx.template Alloc(output); + phi::funcs::SetConstant()(dev_ctx, output, 0); + KernelHistogram<<>>(input_data, + weight_data, + input_numel, + nbins, + min_block_ptr, + max_block_ptr, + out_data); + return; + + } else { + float* out_data = dev_ctx.template Alloc(output); + phi::funcs::SetConstant()( + dev_ctx, output, static_cast(0)); + KernelHistogram<<>>(input_data, + weight_data, + input_numel, + nbins, + min_block_ptr, + max_block_ptr, + out_data); + if (density) { + DenseTensor sum = phi::Sum( + dev_ctx, *output, phi::IntArray({0}), phi::DataType::FLOAT32, false); + float gap = static_cast(nbins) / + static_cast(output_max - output_min); + std::vector ins = {output}; + std::vector outs = {output}; + auto functor = phi::funcs::ScaleFunctor(gap); + phi::funcs::ElementwiseKernel(dev_ctx, ins, &outs, functor); + KernelMul<<(bins)), + PADDLE_CUDA_NUM_THREADS>>>(out_data, sum.data(), bins); + } + } +} } // namespace phi PD_REGISTER_KERNEL(histogram, @@ -195,6 +250,4 @@ PD_REGISTER_KERNEL(histogram, float, double, int, - int64_t) { - kernel->OutputAt(0).SetDataType(paddle::DataType::INT64); -} + int64_t) {} diff --git a/paddle/phi/kernels/histogram_kernel.h b/paddle/phi/kernels/histogram_kernel.h index 0020f7b0435da..0bd6e05776c5f 100644 --- a/paddle/phi/kernels/histogram_kernel.h +++ b/paddle/phi/kernels/histogram_kernel.h @@ -20,6 +20,8 @@ namespace phi { template void HistogramKernel(const Context& dev_ctx, const DenseTensor& input, + const paddle::optional& weight, + bool density, int64_t bins, int min, int max, diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 4513bcbdba8f8..4b61f8ad754e6 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -78,6 +78,7 @@ eigvals, eigvalsh, histogram, + histogram_bin_edges, histogramdd, householder_product, lstsq, diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 5ff36cdb754d5..07ee82d4c9932 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -2182,7 +2182,9 @@ def bmm(x, y, name=None): return out -def histogram(input, bins=100, min=0, max=0, name=None): +def histogram( + input, weight=None, density: bool = False, bins=100, min=0, max=0, name=None +): """ Computes the histogram of a tensor. The elements are sorted into equal width bins between min and max. If min and max are both zero, the minimum and maximum values of the data are used. 
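A minimal usage sketch of the new weighted and density behavior (illustrative; it assumes the weight and density arguments work as declared in ops.yaml above, and it mirrors the NumPy reference the new tests compare against):

import numpy as np
import paddle

x = paddle.to_tensor([1.0, 2.0, 1.0, 4.0])
w = paddle.to_tensor([0.5, 1.0, 1.5, 2.0])

# With a weight tensor, each sample contributes its weight instead of 1,
# so the result dtype is float32 rather than int64.
hist_w = paddle.histogram(x, weight=w, bins=4, min=0, max=4)

# density=True divides each (weighted) count by total_weight * bin_width,
# where bin_width = (max - min) / bins, matching np.histogram(..., density=True).
hist_d = paddle.histogram(x, weight=w, density=True, bins=4, min=0, max=4)

# NumPy references, as used by the new cases in test_histogram_op.py below.
ref_w, _ = np.histogram(x.numpy(), bins=4, range=(0, 4), weights=w.numpy())
ref_d, _ = np.histogram(
    x.numpy(), bins=4, range=(0, 4), weights=w.numpy(), density=True
)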
@@ -2210,22 +2212,79 @@ [0, 2, 1, 0]) """ if in_dynamic_or_pir_mode(): - return _C_ops.histogram(input, bins, min, max) + return _C_ops.histogram(input, weight, density, bins, min, max) else: helper = LayerHelper('histogram', **locals()) check_variable_and_dtype( input, 'X', ['int32', 'int64', 'float32', 'float64'], 'histogram' ) - out = helper.create_variable_for_type_inference(VarDesc.VarType.INT64) + + if density or weight: + check_variable_and_dtype( + weight, + "Weight", + ["int32", "int64", "float32", "float64"], + "histogram", + ) + out = helper.create_variable_for_type_inference( + dtype=VarDesc.VarType.FLOAT32 + ) + else: + out = helper.create_variable_for_type_inference( + dtype=VarDesc.VarType.INT64 + ) helper.append_op( type='histogram', - inputs={'X': input}, + inputs={'X': input, "Weight": weight}, outputs={'Out': out}, - attrs={'bins': bins, 'min': min, 'max': max}, + attrs={'density': density, 'bins': bins, 'min': min, 'max': max}, ) return out +def histogram_bin_edges(input, bins=100, range=None, name=None): + """ + Computes only the edges of the bins used by the histogram function. + Args: + input (Tensor): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor + should be float32, float64, int32, int64. + bins (int, optional): number of histogram bins. + range (list | tuple): The lower and upper range of the bins. If None, `range` is simply (input.min(), input.max()). + The first element of the range must be less than or equal to the second. Default: None. + name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + Returns: + Tensor: the bin edges. The output data type will be float32. + Examples: + .. code-block:: python + import paddle + inputs = paddle.to_tensor([1, 2, 1]) + result = paddle.histogram_bin_edges(inputs, bins=4, range=(0, 3)) + print(result) # [0., 0.75, 1.5, 2.25, 3.] + """ + check_type(input, 'input', (Variable), 'histogram_bin_edges') + check_dtype( + input.dtype, + 'input', + ['float32', 'float64', 'int32', 'int64'], + 'histogram_bin_edges', + ) + check_type(bins, 'bins', int, 'histogram_bin_edges') + if range is None: + start = paddle.min(input) + stop = paddle.max(input) + else: + check_type(range, 'range', (list, tuple), 'histogram_bin_edges') + if len(range) != 2: + raise ValueError("The length of range should be equal to 2") + start, stop = range + if start > stop: + raise ValueError("max must be larger than min in range parameter") + if (stop - start) == 0: + start = start - 0.5 + stop = stop + 0.5 + return paddle.linspace(start, stop, bins + 1, name=name) + + def bincount(x, weights=None, minlength=0, name=None): """ Computes frequency of each value in the input tensor. 
@@ -4331,9 +4390,11 @@ def _householder_product(x, tau): Q = paddle.static.setitem( Q, (slice(None), slice(i, None)), - Q[:, i:] - (Q[:, i:] @ w @ w.T * tau[i]) - if x.dtype in [paddle.complex128, paddle.complex64] - else Q[:, i:] - (Q[:, i:] @ w @ w.T * tau[i]), + ( + Q[:, i:] - (Q[:, i:] @ w @ w.T * tau[i]) + if x.dtype in [paddle.complex128, paddle.complex64] + else Q[:, i:] - (Q[:, i:] @ w @ w.T * tau[i]) + ), ) return Q[:, :n] diff --git a/test/legacy_test/test_histogram_op.py b/test/legacy_test/test_histogram_op.py index 06d7bec545087..f7098efd77f72 100644 --- a/test/legacy_test/test_histogram_op.py +++ b/test/legacy_test/test_histogram_op.py @@ -159,6 +159,53 @@ def test_check_output(self): self.check_output(check_pir=True) +class TestHistogramOpApi(OpTest): + def setUp(self): + self.op_type = "histogram" + self.init_test_case() + self.python_api = paddle.histogram + self.init_attrs() + Out = np.histogram( + a=self.inputs["X"], + weights=self.inputs["Weight"], + bins=self.bins, + range=(self.min, self.max), + density=self.density, + ) + self.outputs = {"Out": Out[0].astype(np.float32)} + + def init_test_case(self): + self.in_shape = (10, 12) + self.density = False + self.bins = 5 + self.min = 1 + self.max = 5 + + def init_attrs(self): + self.inputs = { + "X": np.random.uniform(low=0.0, high=20.0, size=self.in_shape), + "Weight": np.random.uniform(low=0.0, high=1.0, size=self.in_shape), + } + self.attrs = { + "density": self.density, + "bins": self.bins, + "min": self.min, + "max": self.max, + } + + def test_check_output(self): + self.check_output(check_pir=True) + + +class TestHistogramOpDensity(TestHistogramOpApi): + def init_test_case(self): + self.in_shape = (10, 12) + self.density = True + self.bins = 5 + self.min = 1 + self.max = 5 + + class TestHistogramOp_ZeroDim(TestHistogramOp): def init_test_case(self): self.in_shape = [] From 33ba33148328cf3685fa6c86f6c30dfb553f5142 Mon Sep 17 00:00:00 2001 From: PuQing Date: Sat, 30 Mar 2024 13:39:40 +0000 Subject: [PATCH 3/3] Fix function parameter order and update test case --- python/paddle/tensor/linalg.py | 2 +- test/legacy_test/test_imperative_layers.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 07ee82d4c9932..ac80e3c53e9c5 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -2183,7 +2183,7 @@ def bmm(x, y, name=None): def histogram( - input, weight=None, density: bool = False, bins=100, min=0, max=0, name=None + input, density: bool = False, bins=100, min=0, max=0, weight=None, name=None ): """ Computes the histogram of a tensor. The elements are sorted into equal width bins between min and max. diff --git a/test/legacy_test/test_imperative_layers.py b/test/legacy_test/test_imperative_layers.py index 9906d3ba0ede0..947ab037ee89b 100644 --- a/test/legacy_test/test_imperative_layers.py +++ b/test/legacy_test/test_imperative_layers.py @@ -85,7 +85,9 @@ def test_layer_str(self): self.assertEqual(str(module), 'Tanhshrink()') module = nn.ThresholdedReLU() - self.assertEqual(str(module), 'ThresholdedReLU(threshold=1.0)') + self.assertEqual( + str(module), 'ThresholdedReLU(threshold=1.0, value=0.0)' + ) module = nn.LogSigmoid() self.assertEqual(str(module), 'LogSigmoid()')
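A small sketch of the updated ThresholdedReLU semantics exercised by the tests above (illustrative; it assumes the new value argument with its 0.0 default as declared in ops.yaml):

import paddle
import paddle.nn.functional as F

x = paddle.to_tensor([2.0, 0.5, -1.0])

# Elements at or below the threshold now map to `value` instead of 0.
y1 = F.thresholded_relu(x, threshold=1.0, value=0.5)  # [2.0, 0.5, 0.5]

layer = paddle.nn.ThresholdedReLU(threshold=1.0, value=0.5)
y2 = layer(x)  # same result as y1
print(layer)   # ThresholdedReLU(threshold=1.0, value=0.5)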