From 66c0d607a1e77f6632a94138ee19ac5b6a5911a6 Mon Sep 17 00:00:00 2001 From: TBD1 Date: Tue, 29 Aug 2023 11:08:39 +0000 Subject: [PATCH 1/2] add weight and density parameter in histogram op && add histogram_bin_edges op --- paddle/phi/api/yaml/op_compat.yaml | 2 +- paddle/phi/api/yaml/ops.yaml | 3 +- paddle/phi/infermeta/binary.cc | 54 +++++++ paddle/phi/infermeta/binary.h | 8 + paddle/phi/infermeta/unary.cc | 21 --- paddle/phi/infermeta/unary.h | 2 - paddle/phi/kernels/cpu/histogram_kernel.cc | 55 +++++-- paddle/phi/kernels/gpu/histogram_kernel.cu | 122 ++++++++++++-- paddle/phi/kernels/histogram_kernel.h | 6 +- python/paddle/__init__.py | 2 + python/paddle/tensor/__init__.py | 2 + python/paddle/tensor/creation.py | 51 ++++++ python/paddle/tensor/linalg.py | 42 +++-- test/legacy_test/test_histogram_bin_edges.py | 85 ++++++++++ test/legacy_test/test_histogram_op.py | 157 +++++++++++++++++-- 15 files changed, 534 insertions(+), 78 deletions(-) create mode 100644 test/legacy_test/test_histogram_bin_edges.py diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 00f0c9f300a93..c0c429d5b2509 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1363,7 +1363,7 @@ - op : histogram inputs : - input : X + {input : X, weight : Weight} outputs : out : Out diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index eca5b93e24f88..d541b40a2e37f 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1144,12 +1144,13 @@ backward : heaviside_grad - op : histogram - args : (Tensor input, int64_t bins = 100, int min = 0, int max = 0) + args : (Tensor input, Tensor weight, int64_t bins = 100, int min = 0, int max = 0, bool density = false) output : Tensor(out) infer_meta : func : HistogramInferMeta kernel : func : histogram + optional : weight - op : huber_loss args : (Tensor input, Tensor label, float delta) diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 0118db6041203..e09fae890b27f 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -1671,6 +1671,60 @@ void GridSampleBaseInferMeta(const MetaTensor& x, out->share_lod(x); } +void HistogramInferMeta(const MetaTensor& input, + const MetaTensor& weight, + int64_t bins, + int min, + int max, + bool density, + MetaTensor* out) { + auto input_dim = input.dims(); + if (weight) { + auto weight_dim = weight.dims(); + PADDLE_ENFORCE_EQ( + weight_dim, + input_dim, + phi::errors::InvalidArgument( + "The 'shape' of Input(Weight) must be equal to the 'shape' of " + "Input(X)." + "But received: the 'shape' of Input(Weight) is [%s]," + "the 'shape' of Input(X) is [%s]", + weight_dim, + input_dim)); + PADDLE_ENFORCE_EQ( + input.dtype() == weight.dtype(), + true, + phi::errors::InvalidArgument( + "The 'dtpye' of Input(Weight) must be equal to the 'dtype' of " + "Input(X)." + "But received: the 'dtype' of Input(Weight) is [%s]," + "the 'dtype' of Input(X) is [%s]", + weight.dtype(), + input.dtype())); + } + PADDLE_ENFORCE_GE(bins, + 1, + phi::errors::InvalidArgument( + "The bins should be greater than or equal to 1." + "But received nbins is %d", + bins)); + PADDLE_ENFORCE_GE( + max, + min, + phi::errors::InvalidArgument("max must be larger or equal to min." + "But received max is %d, min is %d", + max, + min)); + + out->set_dims({bins}); + out->share_lod(input); + if (density) { + out->set_dtype(DataType::FLOAT32); + } else { + out->set_dtype(input.dtype()); + } +} + void HuberLossInferMeta(const MetaTensor& input, const MetaTensor& label, float delta, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 8aa4114e74046..bd3d1e43d7d5b 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -270,6 +270,14 @@ void GridSampleBaseInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void HistogramInferMeta(const MetaTensor& input, + const MetaTensor& weight, + int64_t bins, + int min, + int max, + bool density, + MetaTensor* out); + void HuberLossInferMeta(const MetaTensor& input_meta, const MetaTensor& label_meta, float delta, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index fc4e4483d17e0..3042ea84cc1ef 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1751,27 +1751,6 @@ void GumbelSoftmaxInferMeta(const MetaTensor& x, UnchangedInferMetaCheckAxis(x, axis, out); } -void HistogramInferMeta( - const MetaTensor& input, int64_t bins, int min, int max, MetaTensor* out) { - PADDLE_ENFORCE_GE(bins, - 1, - phi::errors::InvalidArgument( - "The bins should be greater than or equal to 1." - "But received nbins is %d", - bins)); - PADDLE_ENFORCE_GE( - max, - min, - phi::errors::InvalidArgument("max must be larger or equal to min." - "But received max is %d, min is %d", - max, - min)); - - out->set_dims({bins}); - out->share_lod(input); - out->set_dtype(DataType::INT64); -} - void IdentityLossInferMeta(const MetaTensor& x, int reduction, MetaTensor* out) { diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 14c3ec0f857bd..10f0ef1afa51a 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -266,8 +266,6 @@ void GumbelSoftmaxInferMeta(const MetaTensor& x, bool hard, int axis, MetaTensor* out); -void HistogramInferMeta( - const MetaTensor& input, int64_t bins, int min, int max, MetaTensor* out); void IdentityLossInferMeta(const MetaTensor& x, int reduction, MetaTensor* out); diff --git a/paddle/phi/kernels/cpu/histogram_kernel.cc b/paddle/phi/kernels/cpu/histogram_kernel.cc index 030dee9908b31..b5562688e8ac8 100644 --- a/paddle/phi/kernels/cpu/histogram_kernel.cc +++ b/paddle/phi/kernels/cpu/histogram_kernel.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -23,9 +23,11 @@ namespace phi { template void HistogramKernel(const Context& dev_ctx, const DenseTensor& input, + const paddle::optional& weight, int64_t bins, int min, int max, + bool density, DenseTensor* output) { auto& nbins = bins; auto& minval = min; @@ -34,11 +36,11 @@ void HistogramKernel(const Context& dev_ctx, const T* input_data = input.data(); auto input_numel = input.numel(); - int64_t* out_data = dev_ctx.template Alloc(output); - phi::funcs::SetConstant()( - dev_ctx, output, static_cast(0)); - - if (input_data == nullptr) return; + if (input_data == nullptr) { + dev_ctx.template Alloc(output); + phi::funcs::SetConstant()(dev_ctx, output, static_cast(0)); + return; + } T output_min = static_cast(minval); T output_max = static_cast(maxval); @@ -67,11 +69,40 @@ void HistogramKernel(const Context& dev_ctx, maxval, minval)); - for (int64_t i = 0; i < input_numel; i++) { - if (input_data[i] >= output_min && input_data[i] <= output_max) { - const int64_t bin = (int64_t)((input_data[i] - output_min) * nbins / - (output_max - output_min)); - out_data[std::min(bin, nbins - 1)] += 1; + bool has_weight = weight.is_initialized(); + auto weight_data = (weight.get_ptr() == nullptr ? nullptr : weight.get_ptr()->data()); + + // compute output + if (density) { + T total = static_cast(0); + for(int64_t i = 0; i < input_numel; i++) { + if (input_data[i] >= output_min && input_data[i] <= output_max) { + total += has_weight ? static_cast(weight_data[i]) : static_cast(1); + } + } + float* out_data = dev_ctx.template Alloc(output); + phi::funcs::SetConstant()(dev_ctx, output, static_cast(0)); + + const float interval_len = static_cast(output_max - output_min) / nbins; + for (int64_t i = 0; i < input_numel; i++) { + if (input_data[i] >= output_min && input_data[i] <= output_max) { + const int64_t bin = (int64_t)((input_data[i] - output_min) * nbins / + (output_max - output_min)); + T weight_idx = weight_data == nullptr ? 1 : weight_data[i]; + out_data[std::min(bin, nbins - 1)] += (static_cast(weight_idx) + / total) / interval_len; + } + } + } else { + T* out_data = dev_ctx.template Alloc(output); + phi::funcs::SetConstant()(dev_ctx, output, static_cast(0)); + for (int64_t i = 0; i < input_numel; i++) { + if (input_data[i] >= output_min && input_data[i] <= output_max) { + const int64_t bin = (int64_t)((input_data[i] - output_min) * nbins / + (output_max - output_min)); + T weight_idx = weight_data == nullptr ? 1 : weight_data[i]; + out_data[std::min(bin, nbins - 1)] += weight_idx; + } } } } @@ -86,5 +117,5 @@ PD_REGISTER_KERNEL(histogram, double, int, int64_t) { - kernel->OutputAt(0).SetDataType(paddle::DataType::INT64); + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); } diff --git a/paddle/phi/kernels/gpu/histogram_kernel.cu b/paddle/phi/kernels/gpu/histogram_kernel.cu index aa10aea35f867..cea9977d3339f 100644 --- a/paddle/phi/kernels/gpu/histogram_kernel.cu +++ b/paddle/phi/kernels/gpu/histogram_kernel.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,6 +14,9 @@ #include "paddle/phi/kernels/histogram_kernel.h" +#include +#include + #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" @@ -44,11 +47,15 @@ __device__ static IndexType GetBin(T input_value, template __global__ void KernelHistogram(const T* input, const int total_elements, + const bool has_weight, + const T* weight, const int64_t nbins, const T* min_value, const T* max_value, - int64_t* output) { - extern __shared__ int64_t buf_hist[]; + T* output) { + // extern __shared__ T buf_hist1[]; + extern __shared__ __align__(sizeof(T)) unsigned char buf_hist_tmp[]; + T *buf_hist = reinterpret_cast(buf_hist_tmp); for (int i = threadIdx.x; i < nbins; i += blockDim.x) { buf_hist[i] = 0; } @@ -57,10 +64,11 @@ __global__ void KernelHistogram(const T* input, CUDA_KERNEL_LOOP(input_index, total_elements) { // const IndexType input_index = threadIdx.x + blockIdx.x * blockDim.x; const auto input_value = input[input_index]; + const auto weight_value = has_weight ? weight[input_index] : static_cast(1); if (input_value >= *min_value && input_value <= *max_value) { const IndexType output_index = GetBin(input_value, *min_value, *max_value, nbins); - phi::CudaAtomicAdd(&buf_hist[output_index], 1); + phi::CudaAtomicAdd(&buf_hist[output_index], weight_value); } } __syncthreads(); @@ -70,6 +78,73 @@ __global__ void KernelHistogram(const T* input, } } +template +__global__ void KernelHistogramDensity(const T* input, + const int total_elements, + const bool has_weight, + const T* weight, + const int64_t nbins, + const T* min_value, + const T* max_value, + float* output) { + T count_weight = 0; + T total_weight; + __shared__ T total[BlockSize]; + // extern __shared__ float buf_histd[]; + extern __shared__ __align__(sizeof(float)) unsigned char buf_histd_tmp[]; + float *buf_histd = reinterpret_cast(buf_histd_tmp); + + for (int i = threadIdx.x; i < (total_elements); i += BlockSize) { + const auto input_value = input[i]; + const auto weight_value = has_weight ? weight[i] : static_cast(1); + if (input_value >= *min_value && input_value <= *max_value) { + count_weight += weight_value; + } + } + total[threadIdx.x] = count_weight; + __syncthreads(); + +// reduce the count with init value 0, and output accuracy. +#ifdef PADDLE_WITH_CUDA + total_weight = thrust::reduce(thrust::device, total, total + BlockSize, 0.0); +#else + // HIP thrust::reduce not support __device__ + for (int s = BlockSize / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + total[threadIdx.x] += total[threadIdx.x + s]; + } + __syncthreads(); + } + total_weight = total[0]; +#endif + + for (int i = threadIdx.x; i < nbins; i += blockDim.x) { + buf_histd[i] = 0; + } + __syncthreads(); + + const float interval_len = static_cast(*max_value - *min_value) + / nbins; + CUDA_KERNEL_LOOP(input_index, total_elements) { + // const IndexType input_index = threadIdx.x + blockIdx.x * blockDim.x; + const auto input_value = input[input_index]; + auto weight_value = has_weight ? weight[input_index] : static_cast(1); + if (input_value >= *min_value && input_value <= *max_value) { + const IndexType output_index = + GetBin(input_value, *min_value, *max_value, nbins); + float prob_value = static_cast(weight_value) + / static_cast(total_weight) + / interval_len; + phi::CudaAtomicAdd(&buf_histd[output_index], prob_value); + } + } + __syncthreads(); + + for (int i = threadIdx.x; i < nbins; i += blockDim.x) { + phi::CudaAtomicAdd(&output[i], buf_histd[i]); + } +} + template __global__ void KernelMinMax(const T* input, const int numel, @@ -127,9 +202,11 @@ __global__ void KernelMinMax(const T min_value, template void HistogramKernel(const Context& dev_ctx, const DenseTensor& input, + const paddle::optional& weight, int64_t bins, int min, int max, + bool density, DenseTensor* output) { auto& nbins = bins; auto& minval = min; @@ -138,11 +215,11 @@ void HistogramKernel(const Context& dev_ctx, const T* input_data = input.data(); const int input_numel = input.numel(); - int64_t* out_data = dev_ctx.template Alloc(output); - phi::funcs::SetConstant()( - dev_ctx, output, static_cast(0)); - - if (input_data == nullptr) return; + if (input_data == nullptr) { + dev_ctx.template Alloc(output); + phi::funcs::SetConstant()(dev_ctx, output, static_cast(0)); + return; + } T output_min = static_cast(minval); T output_max = static_cast(maxval); @@ -178,12 +255,27 @@ void HistogramKernel(const Context& dev_ctx, maxval, minval)); + bool has_weight = weight.is_initialized(); + const T* weight_data = has_weight ? weight->data() : nullptr; + auto stream = dev_ctx.stream(); - KernelHistogram<<>>( - input_data, input_numel, nbins, min_block_ptr, max_block_ptr, out_data); + if(!density) { + T* out_data = dev_ctx.template Alloc(output); + phi::funcs::SetConstant()(dev_ctx, output, static_cast(0)); + KernelHistogram + <<>>( + input_data, input_numel, has_weight, weight_data, + nbins, min_block_ptr, max_block_ptr, out_data); + } else { + float* out_data = dev_ctx.template Alloc(output); + phi::funcs::SetConstant()(dev_ctx, output, static_cast(0)); + KernelHistogramDensity + <<>>( + input_data, input_numel, has_weight, weight_data, + nbins, min_block_ptr, max_block_ptr, out_data); + } } } // namespace phi @@ -196,5 +288,5 @@ PD_REGISTER_KERNEL(histogram, double, int, int64_t) { - kernel->OutputAt(0).SetDataType(paddle::DataType::INT64); + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); } diff --git a/paddle/phi/kernels/histogram_kernel.h b/paddle/phi/kernels/histogram_kernel.h index 0020f7b0435da..19e5411fbda27 100644 --- a/paddle/phi/kernels/histogram_kernel.h +++ b/paddle/phi/kernels/histogram_kernel.h @@ -1,4 +1,4 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,14 +15,18 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" +#include "paddle/utils/optional.h" + namespace phi { template void HistogramKernel(const Context& dev_ctx, const DenseTensor& input, + const paddle::optional& weight, int64_t bins, int min, int max, + bool density, DenseTensor* output); } // namespace phi diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index ca7ab3485e52d..289b9544a67e7 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -122,6 +122,7 @@ from .tensor.creation import tril_indices # noqa: F401 from .tensor.creation import triu_indices # noqa: F401 from .tensor.creation import polar # noqa: F401 +from .tensor.creation import histogram_bin_edges # noqa: F401 from .tensor.linalg import matmul # noqa: F401 from .tensor.linalg import dot # noqa: F401 from .tensor.linalg import norm # noqa: F401 @@ -827,6 +828,7 @@ 'trapezoid', 'cumulative_trapezoid', 'polar', + 'histogram_bin_edges', 'vander', 'unflatten', 'as_strided', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 841925f8b7ff8..65ef37817afa1 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -43,6 +43,7 @@ from .creation import empty_like # noqa: F401 from .creation import complex # noqa: F401 from .creation import polar # noqa: F401 +from .creation import histogram_bin_edges # noqa: F401 from .linalg import matmul # noqa: F401 from .linalg import dot # noqa: F401 from .linalg import cov # noqa: F401 @@ -657,6 +658,7 @@ 'trapezoid', 'cumulative_trapezoid', 'polar', + 'histogram_bin_edges' 'sigmoid', 'sigmoid_', 'vander', diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 528a3c40b9a4b..e8e8bc13d7240 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -2659,3 +2659,54 @@ def polar(abs, angle, name=None): ) return paddle.complex(abs * paddle.cos(angle), abs * paddle.sin(angle)) + + +def histogram_bin_edges(input, bins=100, range=None, weight=None, name=None): + """ + Computes only the edges of the bins used by the histogram function. + + Args: + input (Tensor): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor + should be float32, float64, int32, int64. + bins (int, optional): number of histogram bins. + range (list | tuple): The lower and upper range of the bins. If None, `range` is simply (input.min(), input.max()). + The first element of the range must be less than or equal to the second. Default: None. + weight (Tensor, optional): Weight for each value in the input tensor. Should have the same shape and data type as input. + This is currently not used by any of the bin estimators, but may be in the future. Default: None. + name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + + Returns: + Tensor: the values of the histogram and the bin edges. The output data type will be float32. + + Examples: + .. code-block:: python + + import paddle + + inputs = paddle.to_tensor([1, 2, 1]) + result = paddle.histogram_bin_edges(inputs, bins=4, range=(0, 3)) + print(result) # [0., 0.75, 1.5, 2.25, 3.] + + """ + check_type(input, 'input', (Variable), 'histogram_bin_edges') + check_dtype( + input.dtype, + 'input', + ['float32', 'float64', 'int32', 'int64'], + 'histogram_bin_edges', + ) + check_type(bins, 'bins', int, 'histogram_bin_edges') + if range is None: + start = paddle.max(input) + stop = paddle.min(input) + else: + check_type(range, 'range', (list, tuple), 'histogram_bin_edges') + if len(range) != 2: + raise ValueError("The length of range should be equal 2") + start, stop = range + if start > stop: + raise ValueError("max must be larger than min in range parameter") + if (stop - start) == 0: + start = start - 0.5 + stop = stop + 0.5 + return linspace(start, stop, bins+1) diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 0e522533c8512..ff04d0f1787ed 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1631,7 +1631,7 @@ def bmm(x, y, name=None): return out -def histogram(input, bins=100, min=0, max=0, name=None): +def histogram(input, weight=None, bins=100, min=0, max=0, density=False, name=None): """ Computes the histogram of a tensor. The elements are sorted into equal width bins between min and max. If min and max are both zero, the minimum and maximum values of the data are used. @@ -1639,13 +1639,17 @@ def histogram(input, bins=100, min=0, max=0, name=None): Args: input (Tensor): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor should be float32, float64, int32, int64. + weight (Tensor, optional): Weight for each value in the input tensor. Should have the same shape and data type as input. + Default is None. bins (int, optional): number of histogram bins. min (int, optional): lower end of the range (inclusive). max (int, optional): upper end of the range (inclusive). + density (bool, optional): If False, the result will contain the count(or total weight) in each bin. If True, the result is the + value of the probability density function at the bin, normalized such that the integral over the range is 1. name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: - Tensor: data type is int64, shape is (nbins,). + Tensor: data type is float32, shape is (nbins,). Examples: .. code-block:: python @@ -1657,20 +1661,36 @@ def histogram(input, bins=100, min=0, max=0, name=None): print(result) # [0, 2, 1, 0] """ if in_dynamic_mode(): - return _C_ops.histogram(input, bins, min, max) + return _C_ops.histogram(input, weight, bins, min, max, density) else: helper = LayerHelper('histogram', **locals()) check_variable_and_dtype( input, 'X', ['int32', 'int64', 'float32', 'float64'], 'histogram' ) - out = helper.create_variable_for_type_inference(VarDesc.VarType.INT64) - helper.append_op( - type='histogram', - inputs={'X': input}, - outputs={'Out': out}, - attrs={'bins': bins, 'min': min, 'max': max}, - ) - return out + if weight is not None: + check_variable_and_dtype( + weight, + 'Weight', + ['int32', 'int64', 'float32', 'float64'], + 'histogram', + ) + if input.dtype != weight.dtype: + raise ValueError( + "Only support input and weight have the same dtype." + ) + check_type(bins, 'bins', int, 'histogram') + check_type(density, 'density', bool, 'histogram') + if density: + out = helper.create_variable_for_type_inference(VarDesc.VarType.FLOAT32) + else: + out = helper.create_variable_for_type_inference(dtype=input.dtype) + helper.append_op( + type='histogram', + inputs={'X': input, 'Weight': weight}, + outputs={'Out': out}, + attrs={'bins': bins, 'min': min, 'max': max, 'density': density}, + ) + return out def bincount(x, weights=None, minlength=0, name=None): diff --git a/test/legacy_test/test_histogram_bin_edges.py b/test/legacy_test/test_histogram_bin_edges.py new file mode 100644 index 0000000000000..4f50ddab39f57 --- /dev/null +++ b/test/legacy_test/test_histogram_bin_edges.py @@ -0,0 +1,85 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from eager_op_test import OpTest, paddle_static_guard + +import paddle +from paddle import fluid +from paddle.fluid import Program, program_guard + +class TestHistogramBinEdgesAPI(unittest.TestCase): + """Test histogram_bin_edges api.""" + def setUp(self): + self.input_np = np.random.uniform(-5, 5, [2, 3]).astype(np.float32) + self.bins = 4 + self.range = (0., 3.) + self.place = [paddle.CPUPlace()] + + def test_api_static(self): + paddle.enable_static() + def run(place): + with paddle.static.program_guard(paddle.static.Program()): + inputs = paddle.static.data( + name='input', dtype='float32', shape=[2, 3] + ) + out = paddle.histogram_bin_edges(inputs, self.bins, self.range) + exe = paddle.static.Executor(place) + res = exe.run( + feed={'input': self.input_np}, + fetch_list=[out], + ) + out_ref = np.histogram_bin_edges(self.input_np, self.bins, self.range) + np.testing.assert_allclose(out_ref, res[0], rtol=1e-05) + + for place in self.place: + run(place) + + def test_api_dygraph(self): + def run(place): + paddle.disable_static(place) + inputs = paddle.to_tensor(self.input_np) + out1 = paddle.histogram_bin_edges(inputs, bins=4, range=(0, 3)) + + out_ref1 = np.histogram_bin_edges(self.input_np, bins=4, range=(0, 3)) + np.testing.assert_allclose(out_ref1, out1.numpy(), rtol=1e-05) + paddle.enable_static() + + for place in self.place: + run(place) + + def test_errors(self): + input = paddle.to_tensor(self.input_np) + bins = self.bins + range = self.range + # bin dtype is not int + self.assertRaises( + TypeError, + paddle.histogram_bin_edges, + input, bins=1.5, range=range + ) + # the range len is not equal 2 + self.assertRaises( + ValueError, paddle.histogram_bin_edges, input, bins=bins, range=(0, 2, 3) + ) + # the min of range greater than max + self.assertRaises( + ValueError, paddle.histogram_bin_edges, input, bins=bins, range=(3, 0) + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/legacy_test/test_histogram_op.py b/test/legacy_test/test_histogram_op.py index 36df3209ac76e..fe3fbf56fead8 100644 --- a/test/legacy_test/test_histogram_op.py +++ b/test/legacy_test/test_histogram_op.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,7 +24,6 @@ class TestHistogramOpAPI(unittest.TestCase): """Test histogram api.""" - def test_static_graph(self): startup_program = fluid.Program() train_program = fluid.Program() @@ -32,18 +31,23 @@ def test_static_graph(self): inputs = paddle.static.data( name='input', dtype='int64', shape=[2, 3] ) - output = paddle.histogram(inputs, bins=5, min=1, max=5) + weight = paddle.static.data( + name='weight', dtype='int64', shape=[2, 3] + ) + output = paddle.histogram(inputs, weight, bins=5, min=1, max=5) place = fluid.CPUPlace() if fluid.core.is_compiled_with_cuda(): place = fluid.CUDAPlace(0) exe = fluid.Executor(place) exe.run(startup_program) img = np.array([[2, 4, 2], [2, 5, 4]]).astype(np.int64) + w = np.array([[1, 3, 6], [2, 2, 8]]).astype(np.int64) res = exe.run( - train_program, feed={'input': img}, fetch_list=[output] + train_program, feed={'input': img, 'weight': w}, + fetch_list=[output] ) actual = np.array(res[0]) - expected = np.array([0, 3, 0, 2, 1]).astype(np.int64) + expected = np.array([0, 9, 0, 11, 2]).astype(np.int64) self.assertTrue( (actual == expected).all(), msg='histogram output is wrong, out =' + str(actual), @@ -53,8 +57,8 @@ def test_dygraph(self): with fluid.dygraph.guard(): inputs_np = np.array([[2, 4, 2], [2, 5, 4]]).astype(np.int64) inputs = fluid.dygraph.to_variable(inputs_np) - actual = paddle.histogram(inputs, bins=5, min=1, max=5) - expected = np.array([0, 3, 0, 2, 1]).astype(np.int64) + actual = paddle.histogram(inputs, bins=5, min=1, max=5, density=True) + expected = np.array([0., 0.625, 0., 0.4166667, 0.20833334]).astype(np.float32) self.assertTrue( (actual.numpy() == expected).all(), msg='histogram output is wrong, out =' + str(actual.numpy()), @@ -62,7 +66,7 @@ def test_dygraph(self): inputs_np = np.array([[2, 4, 2], [2, 5, 4]]).astype(np.int64) inputs = paddle.to_tensor(inputs_np) - actual = paddle.histogram(inputs, bins=5, min=1, max=5) + actual = paddle.histogram(inputs, bins=5, min=1, max=5, density=True) self.assertTrue( (actual.numpy() == expected).all(), msg='histogram output is wrong, out =' + str(actual.numpy()), @@ -99,7 +103,7 @@ def net_func(): input_value = paddle.tensor.fill_constant( shape=[3, 4], dtype='float32', value=3.0 ) - paddle.histogram(input=input_value, bins=1, min=5, max=1) + paddle.histogram(input=input_value, bins=1, min=5, max=1, density=False) with self.assertRaises(ValueError): self.run_network(net_func) @@ -122,6 +126,32 @@ def test_type_errors(self): self.assertRaises( TypeError, paddle.histogram, 1, bins=5, min=1, max=5 ) + # The weight type must be Variable. + self.assertRaises( + TypeError, paddle.histogram, [1], weight=1, bins=5, min=1, max=5 + ) + # The weight type must be equal the input type. + x_int32 = paddle.static.data( + name='x_int32', shape=[4, 3], dtype='int32' + ) + weight_float32 = paddle.static.data( + name='weight_float32', shape=[4, 3], dtype='float32' + ) + self.assertRaises( + ValueError, paddle.histogram, x_int32, weight=weight_float32, + bins=5, min=1, max=5 + ) + # The weight shape must be equal the input shape. + x_shape = paddle.static.data( + name='x_shape', shape=[4, 3], dtype='int32' + ) + w_shape = paddle.static.data( + name='w_shape', shape=[3, 4], dtype='int32' + ) + self.assertRaises( + ValueError, paddle.histogram, x_shape, weight=w_shape, + bins=5, min=1, max=5 + ) # The input type must be 'int32', 'int64', 'float32', 'float64' x_bool = paddle.static.data( name='x_bool', shape=[4, 3], dtype='bool' @@ -134,35 +164,134 @@ def test_type_errors(self): class TestHistogramOp(OpTest): def setUp(self): self.op_type = "histogram" - self.init_test_case() - np_input = np.random.uniform(low=0.0, high=20.0, size=self.in_shape) self.python_api = paddle.histogram + self.init_test_case() + np_input = np.random.randint(low=0, high=20, size=self.in_shape).astype(self.dtype) self.inputs = {"X": np_input} self.init_attrs() + self.place = [paddle.CPUPlace()] Out, _ = np.histogram( - np_input, bins=self.bins, range=(self.min, self.max) + np_input, bins=self.bins, range=(self.min, self.max), density=self.density ) - self.outputs = {"Out": Out.astype(np.int64)} + self.outputs = {"Out": Out} def init_test_case(self): self.in_shape = (10, 12) self.bins = 5 self.min = 1 self.max = 5 + self.density = False + self.dtype = np.int32 def init_attrs(self): - self.attrs = {"bins": self.bins, "min": self.min, "max": self.max} + self.attrs = {"bins": self.bins, + "min": self.min, "max": self.max, + "density": self.density} def test_check_output(self): self.check_output() +class TestCase1(TestHistogramOp): + # with weights(FLOAT32) + def setUp(self): + self.op_type = "histogram" + self.init_test_case() + np_input = np.random.uniform(low=0.0, high=20.0, size=self.in_shape).astype(self.dtype) + self.np_weight = np.random.uniform(low=0, high=20, size=self.in_shape).astype(self.dtype) + self.python_api = paddle.histogram + self.inputs = {"X": np_input, "Weight": self.np_weight} + self.init_attrs() + self.place = [paddle.CPUPlace()] + Out, _ = np.histogram( + np_input, bins=self.bins, range=(self.min, self.max), + weights=self.np_weight, density=self.density + ) + self.outputs = {"Out": Out} + + def init_test_case(self): + self.in_shape = (10, 12) + self.bins = 5 + self.min = 1 + self.max = 5 + self.density = False + self.dtype = np.float32 + + +class TestCase2(TestHistogramOp): + # with weights(FLOAT64) + def setUp(self): + self.op_type = "histogram" + self.init_test_case() + np_input = np.random.uniform(low=0.0, high=20.0, size=self.in_shape).astype(self.dtype) + self.np_weight = np.random.uniform(low=0.0, high=20.0, size=self.in_shape).astype(self.dtype) + self.python_api = paddle.histogram + self.inputs = {"X": np_input, "Weight": self.np_weight} + self.init_attrs() + self.place = [paddle.CPUPlace()] + + Out, _ = np.histogram( + np_input, bins=self.bins, range=(self.min, self.max), + weights=self.np_weight, density=self.density + ) + self.outputs = {"Out": Out} + + def init_test_case(self): + self.in_shape = (10, 12) + self.bins = 5 + self.min = 1 + self.max = 5 + self.density = True + self.dtype = np.float64 + + +class TestCase3(TestHistogramOp): + # with weights(INT64) + def setUp(self): + self.op_type = "histogram" + self.init_test_case() + np_input = np.random.randint(low=0, high=20, size=self.in_shape).astype(self.dtype) + self.np_weight = np.random.randint(low=0, high=20, size=self.in_shape).astype(self.dtype) + self.python_api = paddle.histogram + self.inputs = {"X": np_input, "Weight": self.np_weight} + self.init_attrs() + self.place = [paddle.CPUPlace()] + Out, _ = np.histogram( + np_input, bins=self.bins, range=(self.min, self.max), + weights=self.np_weight, density=self.density + ) + self.outputs = {"Out": Out} + + def init_test_case(self): + self.in_shape = (10, 12) + self.bins = 5 + self.min = 1 + self.max = 5 + self.density = False + self.dtype = np.int64 + + class TestHistogramOp_ZeroDim(TestHistogramOp): + def setUp(self): + self.op_type = "histogram" + self.python_api = paddle.histogram + self.init_test_case() + np_input = np.random.randint(low=0, high=20, size=self.in_shape).astype(self.dtype) + self.inputs = {"X": np_input} + self.init_attrs() + self.place = [paddle.CPUPlace()] + Out, _ = np.histogram(np_input, bins=self.bins, range=(self.min, self.max)) + self.outputs = {"Out": Out} + def init_test_case(self): self.in_shape = [] self.bins = 5 self.min = 1 self.max = 5 + self.dtype = np.float32 + + def init_attrs(self): + self.attrs = {"bins": self.bins, "min": self.min, "max": self.max} if __name__ == "__main__": From aaabf00550262ae39e5a4ce4d467efd2e9c5d244 Mon Sep 17 00:00:00 2001 From: TBD1 <798934910@qq.com> Date: Wed, 30 Aug 2023 14:12:12 +0800 Subject: [PATCH 2/2] update --- paddle/phi/kernels/cpu/histogram_kernel.cc | 20 ++-- paddle/phi/kernels/gpu/histogram_kernel.cu | 56 +++++---- python/paddle/tensor/__init__.py | 2 +- python/paddle/tensor/creation.py | 10 +- python/paddle/tensor/linalg.py | 4 +- test/legacy_test/test_histogram_bin_edges.py | 33 ++++-- test/legacy_test/test_histogram_op.py | 114 ++++++++++++++----- 7 files changed, 159 insertions(+), 80 deletions(-) diff --git a/paddle/phi/kernels/cpu/histogram_kernel.cc b/paddle/phi/kernels/cpu/histogram_kernel.cc index b5562688e8ac8..a5462940564a0 100644 --- a/paddle/phi/kernels/cpu/histogram_kernel.cc +++ b/paddle/phi/kernels/cpu/histogram_kernel.cc @@ -70,27 +70,31 @@ void HistogramKernel(const Context& dev_ctx, minval)); bool has_weight = weight.is_initialized(); - auto weight_data = (weight.get_ptr() == nullptr ? nullptr : weight.get_ptr()->data()); + auto weight_data = + (weight.get_ptr() == nullptr ? nullptr : weight.get_ptr()->data()); // compute output if (density) { T total = static_cast(0); - for(int64_t i = 0; i < input_numel; i++) { + for (int64_t i = 0; i < input_numel; i++) { if (input_data[i] >= output_min && input_data[i] <= output_max) { - total += has_weight ? static_cast(weight_data[i]) : static_cast(1); + total += + has_weight ? static_cast(weight_data[i]) : static_cast(1); } } float* out_data = dev_ctx.template Alloc(output); - phi::funcs::SetConstant()(dev_ctx, output, static_cast(0)); + phi::funcs::SetConstant()( + dev_ctx, output, static_cast(0)); - const float interval_len = static_cast(output_max - output_min) / nbins; + const float interval_len = + static_cast(output_max - output_min) / nbins; for (int64_t i = 0; i < input_numel; i++) { if (input_data[i] >= output_min && input_data[i] <= output_max) { const int64_t bin = (int64_t)((input_data[i] - output_min) * nbins / (output_max - output_min)); T weight_idx = weight_data == nullptr ? 1 : weight_data[i]; - out_data[std::min(bin, nbins - 1)] += (static_cast(weight_idx) - / total) / interval_len; + out_data[std::min(bin, nbins - 1)] += + (static_cast(weight_idx) / total) / interval_len; } } } else { @@ -102,7 +106,7 @@ void HistogramKernel(const Context& dev_ctx, (output_max - output_min)); T weight_idx = weight_data == nullptr ? 1 : weight_data[i]; out_data[std::min(bin, nbins - 1)] += weight_idx; - } + } } } } diff --git a/paddle/phi/kernels/gpu/histogram_kernel.cu b/paddle/phi/kernels/gpu/histogram_kernel.cu index cea9977d3339f..5f6dad3ebd420 100644 --- a/paddle/phi/kernels/gpu/histogram_kernel.cu +++ b/paddle/phi/kernels/gpu/histogram_kernel.cu @@ -53,9 +53,8 @@ __global__ void KernelHistogram(const T* input, const T* min_value, const T* max_value, T* output) { - // extern __shared__ T buf_hist1[]; extern __shared__ __align__(sizeof(T)) unsigned char buf_hist_tmp[]; - T *buf_hist = reinterpret_cast(buf_hist_tmp); + T* buf_hist = reinterpret_cast(buf_hist_tmp); for (int i = threadIdx.x; i < nbins; i += blockDim.x) { buf_hist[i] = 0; } @@ -64,7 +63,8 @@ __global__ void KernelHistogram(const T* input, CUDA_KERNEL_LOOP(input_index, total_elements) { // const IndexType input_index = threadIdx.x + blockIdx.x * blockDim.x; const auto input_value = input[input_index]; - const auto weight_value = has_weight ? weight[input_index] : static_cast(1); + const auto weight_value = + has_weight ? weight[input_index] : static_cast(1); if (input_value >= *min_value && input_value <= *max_value) { const IndexType output_index = GetBin(input_value, *min_value, *max_value, nbins); @@ -90,9 +90,8 @@ __global__ void KernelHistogramDensity(const T* input, T count_weight = 0; T total_weight; __shared__ T total[BlockSize]; - // extern __shared__ float buf_histd[]; extern __shared__ __align__(sizeof(float)) unsigned char buf_histd_tmp[]; - float *buf_histd = reinterpret_cast(buf_histd_tmp); + float* buf_histd = reinterpret_cast(buf_histd_tmp); for (int i = threadIdx.x; i < (total_elements); i += BlockSize) { const auto input_value = input[i]; @@ -122,9 +121,9 @@ __global__ void KernelHistogramDensity(const T* input, buf_histd[i] = 0; } __syncthreads(); - - const float interval_len = static_cast(*max_value - *min_value) - / nbins; + + const float interval_len = + static_cast(*max_value - *min_value) / nbins; CUDA_KERNEL_LOOP(input_index, total_elements) { // const IndexType input_index = threadIdx.x + blockIdx.x * blockDim.x; const auto input_value = input[input_index]; @@ -132,9 +131,8 @@ __global__ void KernelHistogramDensity(const T* input, if (input_value >= *min_value && input_value <= *max_value) { const IndexType output_index = GetBin(input_value, *min_value, *max_value, nbins); - float prob_value = static_cast(weight_value) - / static_cast(total_weight) - / interval_len; + float prob_value = static_cast(weight_value) / + static_cast(total_weight) / interval_len; phi::CudaAtomicAdd(&buf_histd[output_index], prob_value); } } @@ -259,22 +257,36 @@ void HistogramKernel(const Context& dev_ctx, const T* weight_data = has_weight ? weight->data() : nullptr; auto stream = dev_ctx.stream(); - if(!density) { + if (!density) { T* out_data = dev_ctx.template Alloc(output); phi::funcs::SetConstant()(dev_ctx, output, static_cast(0)); - KernelHistogram - <<>>( - input_data, input_numel, has_weight, weight_data, - nbins, min_block_ptr, max_block_ptr, out_data); + KernelHistogram<<>>(input_data, + input_numel, + has_weight, + weight_data, + nbins, + min_block_ptr, + max_block_ptr, + out_data); } else { float* out_data = dev_ctx.template Alloc(output); - phi::funcs::SetConstant()(dev_ctx, output, static_cast(0)); + phi::funcs::SetConstant()( + dev_ctx, output, static_cast(0)); KernelHistogramDensity - <<>>( - input_data, input_numel, has_weight, weight_data, - nbins, min_block_ptr, max_block_ptr, out_data); + <<>>(input_data, + input_numel, + has_weight, + weight_data, + nbins, + min_block_ptr, + max_block_ptr, + out_data); } } diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 65ef37817afa1..a3224f5769b98 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -658,7 +658,7 @@ 'trapezoid', 'cumulative_trapezoid', 'polar', - 'histogram_bin_edges' + 'histogram_bin_edges', 'sigmoid', 'sigmoid_', 'vander', diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index e8e8bc13d7240..f57d0c6fc2d49 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -2664,17 +2664,17 @@ def polar(abs, angle, name=None): def histogram_bin_edges(input, bins=100, range=None, weight=None, name=None): """ Computes only the edges of the bins used by the histogram function. - + Args: input (Tensor): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor should be float32, float64, int32, int64. bins (int, optional): number of histogram bins. range (list | tuple): The lower and upper range of the bins. If None, `range` is simply (input.min(), input.max()). The first element of the range must be less than or equal to the second. Default: None. - weight (Tensor, optional): Weight for each value in the input tensor. Should have the same shape and data type as input. + weight (Tensor, optional): Weight for each value in the input tensor. Should have the same shape and data type as input. This is currently not used by any of the bin estimators, but may be in the future. Default: None. name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. - + Returns: Tensor: the values of the histogram and the bin edges. The output data type will be float32. @@ -2686,7 +2686,7 @@ def histogram_bin_edges(input, bins=100, range=None, weight=None, name=None): inputs = paddle.to_tensor([1, 2, 1]) result = paddle.histogram_bin_edges(inputs, bins=4, range=(0, 3)) print(result) # [0., 0.75, 1.5, 2.25, 3.] - + """ check_type(input, 'input', (Variable), 'histogram_bin_edges') check_dtype( @@ -2709,4 +2709,4 @@ def histogram_bin_edges(input, bins=100, range=None, weight=None, name=None): if (stop - start) == 0: start = start - 0.5 stop = stop + 0.5 - return linspace(start, stop, bins+1) + return linspace(start, stop, bins + 1) diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index ff04d0f1787ed..b9639a5eef497 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1631,7 +1631,9 @@ def bmm(x, y, name=None): return out -def histogram(input, weight=None, bins=100, min=0, max=0, density=False, name=None): +def histogram( + input, weight=None, bins=100, min=0, max=0, density=False, name=None +): """ Computes the histogram of a tensor. The elements are sorted into equal width bins between min and max. If min and max are both zero, the minimum and maximum values of the data are used. diff --git a/test/legacy_test/test_histogram_bin_edges.py b/test/legacy_test/test_histogram_bin_edges.py index 4f50ddab39f57..c5c6b18558b5f 100644 --- a/test/legacy_test/test_histogram_bin_edges.py +++ b/test/legacy_test/test_histogram_bin_edges.py @@ -15,22 +15,22 @@ import unittest import numpy as np -from eager_op_test import OpTest, paddle_static_guard import paddle -from paddle import fluid -from paddle.fluid import Program, program_guard + class TestHistogramBinEdgesAPI(unittest.TestCase): """Test histogram_bin_edges api.""" + def setUp(self): self.input_np = np.random.uniform(-5, 5, [2, 3]).astype(np.float32) self.bins = 4 - self.range = (0., 3.) + self.range = (0.0, 3.0) self.place = [paddle.CPUPlace()] def test_api_static(self): paddle.enable_static() + def run(place): with paddle.static.program_guard(paddle.static.Program()): inputs = paddle.static.data( @@ -42,7 +42,9 @@ def run(place): feed={'input': self.input_np}, fetch_list=[out], ) - out_ref = np.histogram_bin_edges(self.input_np, self.bins, self.range) + out_ref = np.histogram_bin_edges( + self.input_np, self.bins, self.range + ) np.testing.assert_allclose(out_ref, res[0], rtol=1e-05) for place in self.place: @@ -53,8 +55,9 @@ def run(place): paddle.disable_static(place) inputs = paddle.to_tensor(self.input_np) out1 = paddle.histogram_bin_edges(inputs, bins=4, range=(0, 3)) - - out_ref1 = np.histogram_bin_edges(self.input_np, bins=4, range=(0, 3)) + out_ref1 = np.histogram_bin_edges( + self.input_np, bins=4, range=(0, 3) + ) np.testing.assert_allclose(out_ref1, out1.numpy(), rtol=1e-05) paddle.enable_static() @@ -67,17 +70,23 @@ def test_errors(self): range = self.range # bin dtype is not int self.assertRaises( - TypeError, - paddle.histogram_bin_edges, - input, bins=1.5, range=range + TypeError, paddle.histogram_bin_edges, input, bins=1.5, range=range ) # the range len is not equal 2 self.assertRaises( - ValueError, paddle.histogram_bin_edges, input, bins=bins, range=(0, 2, 3) + ValueError, + paddle.histogram_bin_edges, + input, + bins=bins, + range=(0, 2, 3), ) # the min of range greater than max self.assertRaises( - ValueError, paddle.histogram_bin_edges, input, bins=bins, range=(3, 0) + ValueError, + paddle.histogram_bin_edges, + input, + bins=bins, + range=(3, 0), ) diff --git a/test/legacy_test/test_histogram_op.py b/test/legacy_test/test_histogram_op.py index fe3fbf56fead8..8265ea12e6c9a 100644 --- a/test/legacy_test/test_histogram_op.py +++ b/test/legacy_test/test_histogram_op.py @@ -24,6 +24,7 @@ class TestHistogramOpAPI(unittest.TestCase): """Test histogram api.""" + def test_static_graph(self): startup_program = fluid.Program() train_program = fluid.Program() @@ -43,8 +44,9 @@ def test_static_graph(self): img = np.array([[2, 4, 2], [2, 5, 4]]).astype(np.int64) w = np.array([[1, 3, 6], [2, 2, 8]]).astype(np.int64) res = exe.run( - train_program, feed={'input': img, 'weight': w}, - fetch_list=[output] + train_program, + feed={'input': img, 'weight': w}, + fetch_list=[output], ) actual = np.array(res[0]) expected = np.array([0, 9, 0, 11, 2]).astype(np.int64) @@ -57,16 +59,21 @@ def test_dygraph(self): with fluid.dygraph.guard(): inputs_np = np.array([[2, 4, 2], [2, 5, 4]]).astype(np.int64) inputs = fluid.dygraph.to_variable(inputs_np) - actual = paddle.histogram(inputs, bins=5, min=1, max=5, density=True) - expected = np.array([0., 0.625, 0., 0.4166667, 0.20833334]).astype(np.float32) + actual = paddle.histogram( + inputs, bins=5, min=1, max=5, density=True + ) + expected = np.array( + [0.0, 0.625, 0.0, 0.4166667, 0.20833334] + ).astype(np.float32) self.assertTrue( (actual.numpy() == expected).all(), msg='histogram output is wrong, out =' + str(actual.numpy()), ) - inputs_np = np.array([[2, 4, 2], [2, 5, 4]]).astype(np.int64) inputs = paddle.to_tensor(inputs_np) - actual = paddle.histogram(inputs, bins=5, min=1, max=5, density=True) + actual = paddle.histogram( + inputs, bins=5, min=1, max=5, density=True + ) self.assertTrue( (actual.numpy() == expected).all(), msg='histogram output is wrong, out =' + str(actual.numpy()), @@ -103,7 +110,9 @@ def net_func(): input_value = paddle.tensor.fill_constant( shape=[3, 4], dtype='float32', value=3.0 ) - paddle.histogram(input=input_value, bins=1, min=5, max=1, density=False) + paddle.histogram( + input=input_value, bins=1, min=5, max=1, density=False + ) with self.assertRaises(ValueError): self.run_network(net_func) @@ -138,8 +147,13 @@ def test_type_errors(self): name='weight_float32', shape=[4, 3], dtype='float32' ) self.assertRaises( - ValueError, paddle.histogram, x_int32, weight=weight_float32, - bins=5, min=1, max=5 + ValueError, + paddle.histogram, + x_int32, + weight=weight_float32, + bins=5, + min=1, + max=5, ) # The weight shape must be equal the input shape. x_shape = paddle.static.data( @@ -149,8 +163,13 @@ def test_type_errors(self): name='w_shape', shape=[3, 4], dtype='int32' ) self.assertRaises( - ValueError, paddle.histogram, x_shape, weight=w_shape, - bins=5, min=1, max=5 + ValueError, + paddle.histogram, + x_shape, + weight=w_shape, + bins=5, + min=1, + max=5, ) # The input type must be 'int32', 'int64', 'float32', 'float64' x_bool = paddle.static.data( @@ -166,12 +185,17 @@ def setUp(self): self.op_type = "histogram" self.python_api = paddle.histogram self.init_test_case() - np_input = np.random.randint(low=0, high=20, size=self.in_shape).astype(self.dtype) + np_input = np.random.randint(low=0, high=20, size=self.in_shape).astype( + self.dtype + ) self.inputs = {"X": np_input} self.init_attrs() self.place = [paddle.CPUPlace()] Out, _ = np.histogram( - np_input, bins=self.bins, range=(self.min, self.max), density=self.density + np_input, + bins=self.bins, + range=(self.min, self.max), + density=self.density, ) self.outputs = {"Out": Out} @@ -184,9 +208,12 @@ def init_test_case(self): self.dtype = np.int32 def init_attrs(self): - self.attrs = {"bins": self.bins, - "min": self.min, "max": self.max, - "density": self.density} + self.attrs = { + "bins": self.bins, + "min": self.min, + "max": self.max, + "density": self.density, + } def test_check_output(self): self.check_output() @@ -197,15 +224,22 @@ class TestCase1(TestHistogramOp): def setUp(self): self.op_type = "histogram" self.init_test_case() - np_input = np.random.uniform(low=0.0, high=20.0, size=self.in_shape).astype(self.dtype) - self.np_weight = np.random.uniform(low=0, high=20, size=self.in_shape).astype(self.dtype) + np_input = np.random.uniform( + low=0.0, high=20.0, size=self.in_shape + ).astype(self.dtype) + self.np_weight = np.random.uniform( + low=0, high=20, size=self.in_shape + ).astype(self.dtype) self.python_api = paddle.histogram self.inputs = {"X": np_input, "Weight": self.np_weight} self.init_attrs() self.place = [paddle.CPUPlace()] Out, _ = np.histogram( - np_input, bins=self.bins, range=(self.min, self.max), - weights=self.np_weight, density=self.density + np_input, + bins=self.bins, + range=(self.min, self.max), + weights=self.np_weight, + density=self.density, ) self.outputs = {"Out": Out} @@ -223,16 +257,23 @@ class TestCase2(TestHistogramOp): def setUp(self): self.op_type = "histogram" self.init_test_case() - np_input = np.random.uniform(low=0.0, high=20.0, size=self.in_shape).astype(self.dtype) - self.np_weight = np.random.uniform(low=0.0, high=20.0, size=self.in_shape).astype(self.dtype) + np_input = np.random.uniform( + low=0.0, high=20.0, size=self.in_shape + ).astype(self.dtype) + self.np_weight = np.random.uniform( + low=0.0, high=20.0, size=self.in_shape + ).astype(self.dtype) self.python_api = paddle.histogram self.inputs = {"X": np_input, "Weight": self.np_weight} self.init_attrs() self.place = [paddle.CPUPlace()] Out, _ = np.histogram( - np_input, bins=self.bins, range=(self.min, self.max), - weights=self.np_weight, density=self.density + np_input, + bins=self.bins, + range=(self.min, self.max), + weights=self.np_weight, + density=self.density, ) self.outputs = {"Out": Out} @@ -250,15 +291,22 @@ class TestCase3(TestHistogramOp): def setUp(self): self.op_type = "histogram" self.init_test_case() - np_input = np.random.randint(low=0, high=20, size=self.in_shape).astype(self.dtype) - self.np_weight = np.random.randint(low=0, high=20, size=self.in_shape).astype(self.dtype) + np_input = np.random.randint(low=0, high=20, size=self.in_shape).astype( + self.dtype + ) + self.np_weight = np.random.randint( + low=0, high=20, size=self.in_shape + ).astype(self.dtype) self.python_api = paddle.histogram self.inputs = {"X": np_input, "Weight": self.np_weight} self.init_attrs() self.place = [paddle.CPUPlace()] Out, _ = np.histogram( - np_input, bins=self.bins, range=(self.min, self.max), - weights=self.np_weight, density=self.density + np_input, + bins=self.bins, + range=(self.min, self.max), + weights=self.np_weight, + density=self.density, ) self.outputs = {"Out": Out} @@ -276,11 +324,15 @@ def setUp(self): self.op_type = "histogram" self.python_api = paddle.histogram self.init_test_case() - np_input = np.random.randint(low=0, high=20, size=self.in_shape).astype(self.dtype) + np_input = np.random.randint(low=0, high=20, size=self.in_shape).astype( + self.dtype + ) self.inputs = {"X": np_input} self.init_attrs() self.place = [paddle.CPUPlace()] - Out, _ = np.histogram(np_input, bins=self.bins, range=(self.min, self.max)) + Out, _ = np.histogram( + np_input, bins=self.bins, range=(self.min, self.max) + ) self.outputs = {"Out": Out} def init_test_case(self): @@ -289,7 +341,7 @@ def init_test_case(self): self.min = 1 self.max = 5 self.dtype = np.float32 - + def init_attrs(self): self.attrs = {"bins": self.bins, "min": self.min, "max": self.max}