diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 495ba53cd7613..fd7116f079d66 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1370,7 +1370,7 @@ - op : histogram inputs : - input : X + {input : X, weight : Weight} outputs : out : Out diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 4e67144ba8a89..34c4383fb7b6f 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1144,12 +1144,13 @@ backward : heaviside_grad - op : histogram - args : (Tensor input, int64_t bins = 100, int min = 0, int max = 0) + args : (Tensor input, Tensor weight, int64_t bins = 100, int min = 0, int max = 0, bool density = false) output : Tensor(out) infer_meta : func : HistogramInferMeta kernel : func : histogram + optional : weight - op : huber_loss args : (Tensor input, Tensor label, float delta) diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index a66790d0ce6cd..1f8fec18a43a9 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -1711,6 +1711,60 @@ void GridSampleBaseInferMeta(const MetaTensor& x, out->share_lod(x); } +void HistogramInferMeta(const MetaTensor& input, + const MetaTensor& weight, + int64_t bins, + int min, + int max, + bool density, + MetaTensor* out) { + auto input_dim = input.dims(); + if (weight) { + auto weight_dim = weight.dims(); + PADDLE_ENFORCE_EQ( + weight_dim, + input_dim, + phi::errors::InvalidArgument( + "The 'shape' of Input(Weight) must be equal to the 'shape' of " + "Input(X)." + "But received: the 'shape' of Input(Weight) is [%s]," + "the 'shape' of Input(X) is [%s]", + weight_dim, + input_dim)); + PADDLE_ENFORCE_EQ( + input.dtype() == weight.dtype(), + true, + phi::errors::InvalidArgument( + "The 'dtpye' of Input(Weight) must be equal to the 'dtype' of " + "Input(X)." + "But received: the 'dtype' of Input(Weight) is [%s]," + "the 'dtype' of Input(X) is [%s]", + weight.dtype(), + input.dtype())); + } + PADDLE_ENFORCE_GE(bins, + 1, + phi::errors::InvalidArgument( + "The bins should be greater than or equal to 1." + "But received nbins is %d", + bins)); + PADDLE_ENFORCE_GE( + max, + min, + phi::errors::InvalidArgument("max must be larger or equal to min." + "But received max is %d, min is %d", + max, + min)); + + out->set_dims({bins}); + out->share_lod(input); + if (density) { + out->set_dtype(DataType::FLOAT32); + } else { + out->set_dtype(input.dtype()); + } +} + void HuberLossInferMeta(const MetaTensor& input, const MetaTensor& label, float delta, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 887da467e07b1..aa9d1d55c6871 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -275,6 +275,14 @@ void GridSampleBaseInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void HistogramInferMeta(const MetaTensor& input, + const MetaTensor& weight, + int64_t bins, + int min, + int max, + bool density, + MetaTensor* out); + void HuberLossInferMeta(const MetaTensor& input_meta, const MetaTensor& label_meta, float delta, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index aa1b6526cd5f8..4e56d21579e3d 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1754,27 +1754,6 @@ void GumbelSoftmaxInferMeta(const MetaTensor& x, UnchangedInferMetaCheckAxis(x, axis, out); } -void HistogramInferMeta( - const MetaTensor& input, int64_t bins, int min, int max, MetaTensor* out) { - PADDLE_ENFORCE_GE(bins, - 1, - phi::errors::InvalidArgument( - "The bins should be greater than or equal to 1." - "But received nbins is %d", - bins)); - PADDLE_ENFORCE_GE( - max, - min, - phi::errors::InvalidArgument("max must be larger or equal to min." - "But received max is %d, min is %d", - max, - min)); - - out->set_dims({bins}); - out->share_lod(input); - out->set_dtype(DataType::INT64); -} - void IdentityLossInferMeta(const MetaTensor& x, int reduction, MetaTensor* out) { diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index a3b7e87d86d0b..e331eb2bae708 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -266,8 +266,6 @@ void GumbelSoftmaxInferMeta(const MetaTensor& x, bool hard, int axis, MetaTensor* out); -void HistogramInferMeta( - const MetaTensor& input, int64_t bins, int min, int max, MetaTensor* out); void IdentityLossInferMeta(const MetaTensor& x, int reduction, MetaTensor* out); diff --git a/paddle/phi/kernels/cpu/histogram_kernel.cc b/paddle/phi/kernels/cpu/histogram_kernel.cc index 030dee9908b31..a5462940564a0 100644 --- a/paddle/phi/kernels/cpu/histogram_kernel.cc +++ b/paddle/phi/kernels/cpu/histogram_kernel.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -23,9 +23,11 @@ namespace phi { template void HistogramKernel(const Context& dev_ctx, const DenseTensor& input, + const paddle::optional& weight, int64_t bins, int min, int max, + bool density, DenseTensor* output) { auto& nbins = bins; auto& minval = min; @@ -34,11 +36,11 @@ void HistogramKernel(const Context& dev_ctx, const T* input_data = input.data(); auto input_numel = input.numel(); - int64_t* out_data = dev_ctx.template Alloc(output); - phi::funcs::SetConstant()( - dev_ctx, output, static_cast(0)); - - if (input_data == nullptr) return; + if (input_data == nullptr) { + dev_ctx.template Alloc(output); + phi::funcs::SetConstant()(dev_ctx, output, static_cast(0)); + return; + } T output_min = static_cast(minval); T output_max = static_cast(maxval); @@ -67,11 +69,44 @@ void HistogramKernel(const Context& dev_ctx, maxval, minval)); - for (int64_t i = 0; i < input_numel; i++) { - if (input_data[i] >= output_min && input_data[i] <= output_max) { - const int64_t bin = (int64_t)((input_data[i] - output_min) * nbins / - (output_max - output_min)); - out_data[std::min(bin, nbins - 1)] += 1; + bool has_weight = weight.is_initialized(); + auto weight_data = + (weight.get_ptr() == nullptr ? nullptr : weight.get_ptr()->data()); + + // compute output + if (density) { + T total = static_cast(0); + for (int64_t i = 0; i < input_numel; i++) { + if (input_data[i] >= output_min && input_data[i] <= output_max) { + total += + has_weight ? static_cast(weight_data[i]) : static_cast(1); + } + } + float* out_data = dev_ctx.template Alloc(output); + phi::funcs::SetConstant()( + dev_ctx, output, static_cast(0)); + + const float interval_len = + static_cast(output_max - output_min) / nbins; + for (int64_t i = 0; i < input_numel; i++) { + if (input_data[i] >= output_min && input_data[i] <= output_max) { + const int64_t bin = (int64_t)((input_data[i] - output_min) * nbins / + (output_max - output_min)); + T weight_idx = weight_data == nullptr ? 1 : weight_data[i]; + out_data[std::min(bin, nbins - 1)] += + (static_cast(weight_idx) / total) / interval_len; + } + } + } else { + T* out_data = dev_ctx.template Alloc(output); + phi::funcs::SetConstant()(dev_ctx, output, static_cast(0)); + for (int64_t i = 0; i < input_numel; i++) { + if (input_data[i] >= output_min && input_data[i] <= output_max) { + const int64_t bin = (int64_t)((input_data[i] - output_min) * nbins / + (output_max - output_min)); + T weight_idx = weight_data == nullptr ? 1 : weight_data[i]; + out_data[std::min(bin, nbins - 1)] += weight_idx; + } } } } @@ -86,5 +121,5 @@ PD_REGISTER_KERNEL(histogram, double, int, int64_t) { - kernel->OutputAt(0).SetDataType(paddle::DataType::INT64); + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); } diff --git a/paddle/phi/kernels/gpu/histogram_kernel.cu b/paddle/phi/kernels/gpu/histogram_kernel.cu index aa10aea35f867..5f6dad3ebd420 100644 --- a/paddle/phi/kernels/gpu/histogram_kernel.cu +++ b/paddle/phi/kernels/gpu/histogram_kernel.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,6 +14,9 @@ #include "paddle/phi/kernels/histogram_kernel.h" +#include +#include + #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" @@ -44,11 +47,14 @@ __device__ static IndexType GetBin(T input_value, template __global__ void KernelHistogram(const T* input, const int total_elements, + const bool has_weight, + const T* weight, const int64_t nbins, const T* min_value, const T* max_value, - int64_t* output) { - extern __shared__ int64_t buf_hist[]; + T* output) { + extern __shared__ __align__(sizeof(T)) unsigned char buf_hist_tmp[]; + T* buf_hist = reinterpret_cast(buf_hist_tmp); for (int i = threadIdx.x; i < nbins; i += blockDim.x) { buf_hist[i] = 0; } @@ -57,10 +63,12 @@ __global__ void KernelHistogram(const T* input, CUDA_KERNEL_LOOP(input_index, total_elements) { // const IndexType input_index = threadIdx.x + blockIdx.x * blockDim.x; const auto input_value = input[input_index]; + const auto weight_value = + has_weight ? weight[input_index] : static_cast(1); if (input_value >= *min_value && input_value <= *max_value) { const IndexType output_index = GetBin(input_value, *min_value, *max_value, nbins); - phi::CudaAtomicAdd(&buf_hist[output_index], 1); + phi::CudaAtomicAdd(&buf_hist[output_index], weight_value); } } __syncthreads(); @@ -70,6 +78,71 @@ __global__ void KernelHistogram(const T* input, } } +template +__global__ void KernelHistogramDensity(const T* input, + const int total_elements, + const bool has_weight, + const T* weight, + const int64_t nbins, + const T* min_value, + const T* max_value, + float* output) { + T count_weight = 0; + T total_weight; + __shared__ T total[BlockSize]; + extern __shared__ __align__(sizeof(float)) unsigned char buf_histd_tmp[]; + float* buf_histd = reinterpret_cast(buf_histd_tmp); + + for (int i = threadIdx.x; i < (total_elements); i += BlockSize) { + const auto input_value = input[i]; + const auto weight_value = has_weight ? weight[i] : static_cast(1); + if (input_value >= *min_value && input_value <= *max_value) { + count_weight += weight_value; + } + } + total[threadIdx.x] = count_weight; + __syncthreads(); + +// reduce the count with init value 0, and output accuracy. +#ifdef PADDLE_WITH_CUDA + total_weight = thrust::reduce(thrust::device, total, total + BlockSize, 0.0); +#else + // HIP thrust::reduce not support __device__ + for (int s = BlockSize / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + total[threadIdx.x] += total[threadIdx.x + s]; + } + __syncthreads(); + } + total_weight = total[0]; +#endif + + for (int i = threadIdx.x; i < nbins; i += blockDim.x) { + buf_histd[i] = 0; + } + __syncthreads(); + + const float interval_len = + static_cast(*max_value - *min_value) / nbins; + CUDA_KERNEL_LOOP(input_index, total_elements) { + // const IndexType input_index = threadIdx.x + blockIdx.x * blockDim.x; + const auto input_value = input[input_index]; + auto weight_value = has_weight ? weight[input_index] : static_cast(1); + if (input_value >= *min_value && input_value <= *max_value) { + const IndexType output_index = + GetBin(input_value, *min_value, *max_value, nbins); + float prob_value = static_cast(weight_value) / + static_cast(total_weight) / interval_len; + phi::CudaAtomicAdd(&buf_histd[output_index], prob_value); + } + } + __syncthreads(); + + for (int i = threadIdx.x; i < nbins; i += blockDim.x) { + phi::CudaAtomicAdd(&output[i], buf_histd[i]); + } +} + template __global__ void KernelMinMax(const T* input, const int numel, @@ -127,9 +200,11 @@ __global__ void KernelMinMax(const T min_value, template void HistogramKernel(const Context& dev_ctx, const DenseTensor& input, + const paddle::optional& weight, int64_t bins, int min, int max, + bool density, DenseTensor* output) { auto& nbins = bins; auto& minval = min; @@ -138,11 +213,11 @@ void HistogramKernel(const Context& dev_ctx, const T* input_data = input.data(); const int input_numel = input.numel(); - int64_t* out_data = dev_ctx.template Alloc(output); - phi::funcs::SetConstant()( - dev_ctx, output, static_cast(0)); - - if (input_data == nullptr) return; + if (input_data == nullptr) { + dev_ctx.template Alloc(output); + phi::funcs::SetConstant()(dev_ctx, output, static_cast(0)); + return; + } T output_min = static_cast(minval); T output_max = static_cast(maxval); @@ -178,12 +253,41 @@ void HistogramKernel(const Context& dev_ctx, maxval, minval)); + bool has_weight = weight.is_initialized(); + const T* weight_data = has_weight ? weight->data() : nullptr; + auto stream = dev_ctx.stream(); - KernelHistogram<<>>( - input_data, input_numel, nbins, min_block_ptr, max_block_ptr, out_data); + if (!density) { + T* out_data = dev_ctx.template Alloc(output); + phi::funcs::SetConstant()(dev_ctx, output, static_cast(0)); + KernelHistogram<<>>(input_data, + input_numel, + has_weight, + weight_data, + nbins, + min_block_ptr, + max_block_ptr, + out_data); + } else { + float* out_data = dev_ctx.template Alloc(output); + phi::funcs::SetConstant()( + dev_ctx, output, static_cast(0)); + KernelHistogramDensity + <<>>(input_data, + input_numel, + has_weight, + weight_data, + nbins, + min_block_ptr, + max_block_ptr, + out_data); + } } } // namespace phi @@ -196,5 +300,5 @@ PD_REGISTER_KERNEL(histogram, double, int, int64_t) { - kernel->OutputAt(0).SetDataType(paddle::DataType::INT64); + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); } diff --git a/paddle/phi/kernels/histogram_kernel.h b/paddle/phi/kernels/histogram_kernel.h index 0020f7b0435da..19e5411fbda27 100644 --- a/paddle/phi/kernels/histogram_kernel.h +++ b/paddle/phi/kernels/histogram_kernel.h @@ -1,4 +1,4 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,14 +15,18 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" +#include "paddle/utils/optional.h" + namespace phi { template void HistogramKernel(const Context& dev_ctx, const DenseTensor& input, + const paddle::optional& weight, int64_t bins, int min, int max, + bool density, DenseTensor* output); } // namespace phi diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 68ac8d3a8a577..1759019d6b92c 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -122,6 +122,7 @@ from .tensor.creation import tril_indices # noqa: F401 from .tensor.creation import triu_indices # noqa: F401 from .tensor.creation import polar # noqa: F401 +from .tensor.creation import histogram_bin_edges # noqa: F401 from .tensor.linalg import matmul # noqa: F401 from .tensor.linalg import dot # noqa: F401 from .tensor.linalg import norm # noqa: F401 @@ -829,6 +830,7 @@ 'trapezoid', 'cumulative_trapezoid', 'polar', + 'histogram_bin_edges', 'vander', 'unflatten', 'as_strided', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 841925f8b7ff8..a3224f5769b98 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -43,6 +43,7 @@ from .creation import empty_like # noqa: F401 from .creation import complex # noqa: F401 from .creation import polar # noqa: F401 +from .creation import histogram_bin_edges # noqa: F401 from .linalg import matmul # noqa: F401 from .linalg import dot # noqa: F401 from .linalg import cov # noqa: F401 @@ -657,6 +658,7 @@ 'trapezoid', 'cumulative_trapezoid', 'polar', + 'histogram_bin_edges', 'sigmoid', 'sigmoid_', 'vander', diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index ff92f9baae541..85a2a42cd972f 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -2659,3 +2659,54 @@ def polar(abs, angle, name=None): ) return paddle.complex(abs * paddle.cos(angle), abs * paddle.sin(angle)) + + +def histogram_bin_edges(input, bins=100, range=None, weight=None, name=None): + """ + Computes only the edges of the bins used by the histogram function. + + Args: + input (Tensor): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor + should be float32, float64, int32, int64. + bins (int, optional): number of histogram bins. + range (list | tuple): The lower and upper range of the bins. If None, `range` is simply (input.min(), input.max()). + The first element of the range must be less than or equal to the second. Default: None. + weight (Tensor, optional): Weight for each value in the input tensor. Should have the same shape and data type as input. + This is currently not used by any of the bin estimators, but may be in the future. Default: None. + name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + + Returns: + Tensor: the values of the histogram and the bin edges. The output data type will be float32. + + Examples: + .. code-block:: python + + import paddle + + inputs = paddle.to_tensor([1, 2, 1]) + result = paddle.histogram_bin_edges(inputs, bins=4, range=(0, 3)) + print(result) # [0., 0.75, 1.5, 2.25, 3.] + + """ + check_type(input, 'input', (Variable), 'histogram_bin_edges') + check_dtype( + input.dtype, + 'input', + ['float32', 'float64', 'int32', 'int64'], + 'histogram_bin_edges', + ) + check_type(bins, 'bins', int, 'histogram_bin_edges') + if range is None: + start = paddle.max(input) + stop = paddle.min(input) + else: + check_type(range, 'range', (list, tuple), 'histogram_bin_edges') + if len(range) != 2: + raise ValueError("The length of range should be equal 2") + start, stop = range + if start > stop: + raise ValueError("max must be larger than min in range parameter") + if (stop - start) == 0: + start = start - 0.5 + stop = stop + 0.5 + return linspace(start, stop, bins + 1) diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 60c5afb99fc7b..65f1fe6ffa7af 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1685,7 +1685,9 @@ def bmm(x, y, name=None): return out -def histogram(input, bins=100, min=0, max=0, name=None): +def histogram( + input, weight=None, bins=100, min=0, max=0, density=False, name=None +): """ Computes the histogram of a tensor. The elements are sorted into equal width bins between min and max. If min and max are both zero, the minimum and maximum values of the data are used. @@ -1693,13 +1695,17 @@ def histogram(input, bins=100, min=0, max=0, name=None): Args: input (Tensor): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor should be float32, float64, int32, int64. + weight (Tensor, optional): Weight for each value in the input tensor. Should have the same shape and data type as input. + Default is None. bins (int, optional): number of histogram bins. Default: 100. min (int, optional): lower end of the range (inclusive). Default: 0. max (int, optional): upper end of the range (inclusive). Default: 0. + density (bool, optional): If False, the result will contain the count(or total weight) in each bin. If True, the result is the + value of the probability density function at the bin, normalized such that the integral over the range is 1. name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: - Tensor: data type is int64, shape is (nbins,). + Tensor: data type is float32, shape is (nbins,). Examples: .. code-block:: python @@ -1713,20 +1719,36 @@ def histogram(input, bins=100, min=0, max=0, name=None): [0, 2, 1, 0]) """ if in_dynamic_mode(): - return _C_ops.histogram(input, bins, min, max) + return _C_ops.histogram(input, weight, bins, min, max, density) else: helper = LayerHelper('histogram', **locals()) check_variable_and_dtype( input, 'X', ['int32', 'int64', 'float32', 'float64'], 'histogram' ) - out = helper.create_variable_for_type_inference(VarDesc.VarType.INT64) - helper.append_op( - type='histogram', - inputs={'X': input}, - outputs={'Out': out}, - attrs={'bins': bins, 'min': min, 'max': max}, - ) - return out + if weight is not None: + check_variable_and_dtype( + weight, + 'Weight', + ['int32', 'int64', 'float32', 'float64'], + 'histogram', + ) + if input.dtype != weight.dtype: + raise ValueError( + "Only support input and weight have the same dtype." + ) + check_type(bins, 'bins', int, 'histogram') + check_type(density, 'density', bool, 'histogram') + if density: + out = helper.create_variable_for_type_inference(VarDesc.VarType.FLOAT32) + else: + out = helper.create_variable_for_type_inference(dtype=input.dtype) + helper.append_op( + type='histogram', + inputs={'X': input, 'Weight': weight}, + outputs={'Out': out}, + attrs={'bins': bins, 'min': min, 'max': max, 'density': density}, + ) + return out def bincount(x, weights=None, minlength=0, name=None): diff --git a/test/legacy_test/test_histogram_bin_edges.py b/test/legacy_test/test_histogram_bin_edges.py new file mode 100644 index 0000000000000..c5c6b18558b5f --- /dev/null +++ b/test/legacy_test/test_histogram_bin_edges.py @@ -0,0 +1,94 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle + + +class TestHistogramBinEdgesAPI(unittest.TestCase): + """Test histogram_bin_edges api.""" + + def setUp(self): + self.input_np = np.random.uniform(-5, 5, [2, 3]).astype(np.float32) + self.bins = 4 + self.range = (0.0, 3.0) + self.place = [paddle.CPUPlace()] + + def test_api_static(self): + paddle.enable_static() + + def run(place): + with paddle.static.program_guard(paddle.static.Program()): + inputs = paddle.static.data( + name='input', dtype='float32', shape=[2, 3] + ) + out = paddle.histogram_bin_edges(inputs, self.bins, self.range) + exe = paddle.static.Executor(place) + res = exe.run( + feed={'input': self.input_np}, + fetch_list=[out], + ) + out_ref = np.histogram_bin_edges( + self.input_np, self.bins, self.range + ) + np.testing.assert_allclose(out_ref, res[0], rtol=1e-05) + + for place in self.place: + run(place) + + def test_api_dygraph(self): + def run(place): + paddle.disable_static(place) + inputs = paddle.to_tensor(self.input_np) + out1 = paddle.histogram_bin_edges(inputs, bins=4, range=(0, 3)) + out_ref1 = np.histogram_bin_edges( + self.input_np, bins=4, range=(0, 3) + ) + np.testing.assert_allclose(out_ref1, out1.numpy(), rtol=1e-05) + paddle.enable_static() + + for place in self.place: + run(place) + + def test_errors(self): + input = paddle.to_tensor(self.input_np) + bins = self.bins + range = self.range + # bin dtype is not int + self.assertRaises( + TypeError, paddle.histogram_bin_edges, input, bins=1.5, range=range + ) + # the range len is not equal 2 + self.assertRaises( + ValueError, + paddle.histogram_bin_edges, + input, + bins=bins, + range=(0, 2, 3), + ) + # the min of range greater than max + self.assertRaises( + ValueError, + paddle.histogram_bin_edges, + input, + bins=bins, + range=(3, 0), + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/legacy_test/test_histogram_op.py b/test/legacy_test/test_histogram_op.py index 8ca5b4dd500da..5ea385b36e0fa 100644 --- a/test/legacy_test/test_histogram_op.py +++ b/test/legacy_test/test_histogram_op.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -32,6 +32,9 @@ def test_static_graph(self): inputs = paddle.static.data( name='input', dtype='int64', shape=[2, 3] ) + weight = paddle.static.data( + name='weight', dtype='int64', shape=[2, 3] + ) output = paddle.histogram(inputs, bins=5, min=1, max=5) place = base.CPUPlace() if base.core.is_compiled_with_cuda(): @@ -39,11 +42,14 @@ def test_static_graph(self): exe = base.Executor(place) exe.run(startup_program) img = np.array([[2, 4, 2], [2, 5, 4]]).astype(np.int64) + w = np.array([[1, 3, 6], [2, 2, 8]]).astype(np.int64) res = exe.run( - train_program, feed={'input': img}, fetch_list=[output] + train_program, + feed={'input': img, 'weight': w}, + fetch_list=[output], ) actual = np.array(res[0]) - expected = np.array([0, 3, 0, 2, 1]).astype(np.int64) + expected = np.array([0, 9, 0, 11, 2]).astype(np.int64) self.assertTrue( (actual == expected).all(), msg='histogram output is wrong, out =' + str(actual), @@ -53,16 +59,21 @@ def test_dygraph(self): with base.dygraph.guard(): inputs_np = np.array([[2, 4, 2], [2, 5, 4]]).astype(np.int64) inputs = base.dygraph.to_variable(inputs_np) - actual = paddle.histogram(inputs, bins=5, min=1, max=5) - expected = np.array([0, 3, 0, 2, 1]).astype(np.int64) + actual = paddle.histogram( + inputs, bins=5, min=1, max=5, density=True + ) + expected = np.array( + [0.0, 0.625, 0.0, 0.4166667, 0.20833334] + ).astype(np.float32) self.assertTrue( (actual.numpy() == expected).all(), msg='histogram output is wrong, out =' + str(actual.numpy()), ) - inputs_np = np.array([[2, 4, 2], [2, 5, 4]]).astype(np.int64) inputs = paddle.to_tensor(inputs_np) - actual = paddle.histogram(inputs, bins=5, min=1, max=5) + actual = paddle.histogram( + inputs, bins=5, min=1, max=5, density=True + ) self.assertTrue( (actual.numpy() == expected).all(), msg='histogram output is wrong, out =' + str(actual.numpy()), @@ -99,7 +110,9 @@ def net_func(): input_value = paddle.tensor.fill_constant( shape=[3, 4], dtype='float32', value=3.0 ) - paddle.histogram(input=input_value, bins=1, min=5, max=1) + paddle.histogram( + input=input_value, bins=1, min=5, max=1, density=False + ) with self.assertRaises(ValueError): self.run_network(net_func) @@ -122,6 +135,42 @@ def test_type_errors(self): self.assertRaises( TypeError, paddle.histogram, 1, bins=5, min=1, max=5 ) + # The weight type must be Variable. + self.assertRaises( + TypeError, paddle.histogram, [1], weight=1, bins=5, min=1, max=5 + ) + # The weight type must be equal the input type. + x_int32 = paddle.static.data( + name='x_int32', shape=[4, 3], dtype='int32' + ) + weight_float32 = paddle.static.data( + name='weight_float32', shape=[4, 3], dtype='float32' + ) + self.assertRaises( + ValueError, + paddle.histogram, + x_int32, + weight=weight_float32, + bins=5, + min=1, + max=5, + ) + # The weight shape must be equal the input shape. + x_shape = paddle.static.data( + name='x_shape', shape=[4, 3], dtype='int32' + ) + w_shape = paddle.static.data( + name='w_shape', shape=[3, 4], dtype='int32' + ) + self.assertRaises( + ValueError, + paddle.histogram, + x_shape, + weight=w_shape, + bins=5, + min=1, + max=5, + ) # The input type must be 'int32', 'int64', 'float32', 'float64' x_bool = paddle.static.data( name='x_bool', shape=[4, 3], dtype='bool' @@ -134,35 +183,167 @@ def test_type_errors(self): class TestHistogramOp(OpTest): def setUp(self): self.op_type = "histogram" - self.init_test_case() - np_input = np.random.uniform(low=0.0, high=20.0, size=self.in_shape) self.python_api = paddle.histogram + self.init_test_case() + np_input = np.random.randint(low=0, high=20, size=self.in_shape).astype( + self.dtype + ) self.inputs = {"X": np_input} self.init_attrs() + self.place = [paddle.CPUPlace()] Out, _ = np.histogram( - np_input, bins=self.bins, range=(self.min, self.max) + np_input, + bins=self.bins, + range=(self.min, self.max), + density=self.density, ) - self.outputs = {"Out": Out.astype(np.int64)} + self.outputs = {"Out": Out} def init_test_case(self): self.in_shape = (10, 12) self.bins = 5 self.min = 1 self.max = 5 + self.density = False + self.dtype = np.int32 def init_attrs(self): - self.attrs = {"bins": self.bins, "min": self.min, "max": self.max} + self.attrs = { + "bins": self.bins, + "min": self.min, + "max": self.max, + "density": self.density, + } def test_check_output(self): self.check_output() +class TestCase1(TestHistogramOp): + # with weights(FLOAT32) + def setUp(self): + self.op_type = "histogram" + self.init_test_case() + np_input = np.random.uniform( + low=0.0, high=20.0, size=self.in_shape + ).astype(self.dtype) + self.np_weight = np.random.uniform( + low=0, high=20, size=self.in_shape + ).astype(self.dtype) + self.python_api = paddle.histogram + self.inputs = {"X": np_input, "Weight": self.np_weight} + self.init_attrs() + self.place = [paddle.CPUPlace()] + Out, _ = np.histogram( + np_input, + bins=self.bins, + range=(self.min, self.max), + weights=self.np_weight, + density=self.density, + ) + self.outputs = {"Out": Out} + + def init_test_case(self): + self.in_shape = (10, 12) + self.bins = 5 + self.min = 1 + self.max = 5 + self.density = False + self.dtype = np.float32 + + +class TestCase2(TestHistogramOp): + # with weights(FLOAT64) + def setUp(self): + self.op_type = "histogram" + self.init_test_case() + np_input = np.random.uniform( + low=0.0, high=20.0, size=self.in_shape + ).astype(self.dtype) + self.np_weight = np.random.uniform( + low=0.0, high=20.0, size=self.in_shape + ).astype(self.dtype) + self.python_api = paddle.histogram + self.inputs = {"X": np_input, "Weight": self.np_weight} + self.init_attrs() + self.place = [paddle.CPUPlace()] + + Out, _ = np.histogram( + np_input, + bins=self.bins, + range=(self.min, self.max), + weights=self.np_weight, + density=self.density, + ) + self.outputs = {"Out": Out} + + def init_test_case(self): + self.in_shape = (10, 12) + self.bins = 5 + self.min = 1 + self.max = 5 + self.density = True + self.dtype = np.float64 + + +class TestCase3(TestHistogramOp): + # with weights(INT64) + def setUp(self): + self.op_type = "histogram" + self.init_test_case() + np_input = np.random.randint(low=0, high=20, size=self.in_shape).astype( + self.dtype + ) + self.np_weight = np.random.randint( + low=0, high=20, size=self.in_shape + ).astype(self.dtype) + self.python_api = paddle.histogram + self.inputs = {"X": np_input, "Weight": self.np_weight} + self.init_attrs() + self.place = [paddle.CPUPlace()] + Out, _ = np.histogram( + np_input, + bins=self.bins, + range=(self.min, self.max), + weights=self.np_weight, + density=self.density, + ) + self.outputs = {"Out": Out} + + def init_test_case(self): + self.in_shape = (10, 12) + self.bins = 5 + self.min = 1 + self.max = 5 + self.density = False + self.dtype = np.int64 + + class TestHistogramOp_ZeroDim(TestHistogramOp): + def setUp(self): + self.op_type = "histogram" + self.python_api = paddle.histogram + self.init_test_case() + np_input = np.random.randint(low=0, high=20, size=self.in_shape).astype( + self.dtype + ) + self.inputs = {"X": np_input} + self.init_attrs() + self.place = [paddle.CPUPlace()] + Out, _ = np.histogram( + np_input, bins=self.bins, range=(self.min, self.max) + ) + self.outputs = {"Out": Out} + def init_test_case(self): self.in_shape = [] self.bins = 5 self.min = 1 self.max = 5 + self.dtype = np.float32 + + def init_attrs(self): + self.attrs = {"bins": self.bins, "min": self.min, "max": self.max} if __name__ == "__main__":