Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
TBD1 committed Aug 30, 2023
1 parent 66c0d60 commit aaabf00
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 80 deletions.
20 changes: 12 additions & 8 deletions paddle/phi/kernels/cpu/histogram_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,27 +70,31 @@ void HistogramKernel(const Context& dev_ctx,
minval));

bool has_weight = weight.is_initialized();
auto weight_data = (weight.get_ptr() == nullptr ? nullptr : weight.get_ptr()->data<T>());
auto weight_data =
(weight.get_ptr() == nullptr ? nullptr : weight.get_ptr()->data<T>());

// compute output
if (density) {
T total = static_cast<T>(0);
for(int64_t i = 0; i < input_numel; i++) {
for (int64_t i = 0; i < input_numel; i++) {
if (input_data[i] >= output_min && input_data[i] <= output_max) {
total += has_weight ? static_cast<T>(weight_data[i]) : static_cast<T>(1);
total +=
has_weight ? static_cast<T>(weight_data[i]) : static_cast<T>(1);
}
}
float* out_data = dev_ctx.template Alloc<float>(output);
phi::funcs::SetConstant<Context, float>()(dev_ctx, output, static_cast<float>(0));
phi::funcs::SetConstant<Context, float>()(
dev_ctx, output, static_cast<float>(0));

const float interval_len = static_cast<float>(output_max - output_min) / nbins;
const float interval_len =
static_cast<float>(output_max - output_min) / nbins;
for (int64_t i = 0; i < input_numel; i++) {
if (input_data[i] >= output_min && input_data[i] <= output_max) {
const int64_t bin = (int64_t)((input_data[i] - output_min) * nbins /
(output_max - output_min));
T weight_idx = weight_data == nullptr ? 1 : weight_data[i];
out_data[std::min(bin, nbins - 1)] += (static_cast<float>(weight_idx)
/ total) / interval_len;
out_data[std::min(bin, nbins - 1)] +=
(static_cast<float>(weight_idx) / total) / interval_len;
}
}
} else {
Expand All @@ -102,7 +106,7 @@ void HistogramKernel(const Context& dev_ctx,
(output_max - output_min));
T weight_idx = weight_data == nullptr ? 1 : weight_data[i];
out_data[std::min(bin, nbins - 1)] += weight_idx;
}
}
}
}
}
Expand Down
56 changes: 34 additions & 22 deletions paddle/phi/kernels/gpu/histogram_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,8 @@ __global__ void KernelHistogram(const T* input,
const T* min_value,
const T* max_value,
T* output) {
// extern __shared__ T buf_hist1[];
extern __shared__ __align__(sizeof(T)) unsigned char buf_hist_tmp[];
T *buf_hist = reinterpret_cast<T *>(buf_hist_tmp);
T* buf_hist = reinterpret_cast<T*>(buf_hist_tmp);
for (int i = threadIdx.x; i < nbins; i += blockDim.x) {
buf_hist[i] = 0;
}
Expand All @@ -64,7 +63,8 @@ __global__ void KernelHistogram(const T* input,
CUDA_KERNEL_LOOP(input_index, total_elements) {
// const IndexType input_index = threadIdx.x + blockIdx.x * blockDim.x;
const auto input_value = input[input_index];
const auto weight_value = has_weight ? weight[input_index] : static_cast<T>(1);
const auto weight_value =
has_weight ? weight[input_index] : static_cast<T>(1);
if (input_value >= *min_value && input_value <= *max_value) {
const IndexType output_index =
GetBin<T, IndexType>(input_value, *min_value, *max_value, nbins);
Expand All @@ -90,9 +90,8 @@ __global__ void KernelHistogramDensity(const T* input,
T count_weight = 0;
T total_weight;
__shared__ T total[BlockSize];
// extern __shared__ float buf_histd[];
extern __shared__ __align__(sizeof(float)) unsigned char buf_histd_tmp[];
float *buf_histd = reinterpret_cast<float *>(buf_histd_tmp);
float* buf_histd = reinterpret_cast<float*>(buf_histd_tmp);

for (int i = threadIdx.x; i < (total_elements); i += BlockSize) {
const auto input_value = input[i];
Expand Down Expand Up @@ -122,19 +121,18 @@ __global__ void KernelHistogramDensity(const T* input,
buf_histd[i] = 0;
}
__syncthreads();
const float interval_len = static_cast<float>(*max_value - *min_value)
/ nbins;

const float interval_len =
static_cast<float>(*max_value - *min_value) / nbins;
CUDA_KERNEL_LOOP(input_index, total_elements) {
// const IndexType input_index = threadIdx.x + blockIdx.x * blockDim.x;
const auto input_value = input[input_index];
auto weight_value = has_weight ? weight[input_index] : static_cast<T>(1);
if (input_value >= *min_value && input_value <= *max_value) {
const IndexType output_index =
GetBin<T, IndexType>(input_value, *min_value, *max_value, nbins);
float prob_value = static_cast<float>(weight_value)
/ static_cast<float>(total_weight)
/ interval_len;
float prob_value = static_cast<float>(weight_value) /
static_cast<float>(total_weight) / interval_len;
phi::CudaAtomicAdd(&buf_histd[output_index], prob_value);
}
}
Expand Down Expand Up @@ -259,22 +257,36 @@ void HistogramKernel(const Context& dev_ctx,
const T* weight_data = has_weight ? weight->data<T>() : nullptr;

auto stream = dev_ctx.stream();
if(!density) {
if (!density) {
T* out_data = dev_ctx.template Alloc<T>(output);
phi::funcs::SetConstant<Context, T>()(dev_ctx, output, static_cast<T>(0));
KernelHistogram<T, IndexType>
<<<GET_BLOCKS(input_numel), PADDLE_CUDA_NUM_THREADS,
nbins * sizeof(int64_t), stream>>>(
input_data, input_numel, has_weight, weight_data,
nbins, min_block_ptr, max_block_ptr, out_data);
KernelHistogram<T, IndexType><<<GET_BLOCKS(input_numel),
PADDLE_CUDA_NUM_THREADS,
nbins * sizeof(int64_t),
stream>>>(input_data,
input_numel,
has_weight,
weight_data,
nbins,
min_block_ptr,
max_block_ptr,
out_data);
} else {
float* out_data = dev_ctx.template Alloc<float>(output);
phi::funcs::SetConstant<Context, float>()(dev_ctx, output, static_cast<float>(0));
phi::funcs::SetConstant<Context, float>()(
dev_ctx, output, static_cast<float>(0));
KernelHistogramDensity<PADDLE_CUDA_NUM_THREADS, T, IndexType>
<<<GET_BLOCKS(input_numel), PADDLE_CUDA_NUM_THREADS,
nbins * sizeof(int64_t), stream>>>(
input_data, input_numel, has_weight, weight_data,
nbins, min_block_ptr, max_block_ptr, out_data);
<<<GET_BLOCKS(input_numel),
PADDLE_CUDA_NUM_THREADS,
nbins * sizeof(int64_t),
stream>>>(input_data,
input_numel,
has_weight,
weight_data,
nbins,
min_block_ptr,
max_block_ptr,
out_data);
}
}

Expand Down
2 changes: 1 addition & 1 deletion python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@
'trapezoid',
'cumulative_trapezoid',
'polar',
'histogram_bin_edges'
'histogram_bin_edges',
'sigmoid',
'sigmoid_',
'vander',
Expand Down
10 changes: 5 additions & 5 deletions python/paddle/tensor/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2664,17 +2664,17 @@ def polar(abs, angle, name=None):
def histogram_bin_edges(input, bins=100, range=None, weight=None, name=None):
"""
Computes only the edges of the bins used by the histogram function.
Args:
input (Tensor): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor
should be float32, float64, int32, int64.
bins (int, optional): Number of histogram bins. Default: 100.
range (list | tuple, optional): The lower and upper range of the bins. If None, `range` is simply (input.min(), input.max()).
The first element of the range must be less than or equal to the second. Default: None.
weight (Tensor, optional): Weight for each value in the input tensor. Should have the same shape and data type as input.
weight (Tensor, optional): Weight for each value in the input tensor. Should have the same shape and data type as input.
This is currently not used by any of the bin estimators, but may be in the future. Default: None.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
Returns:
Tensor: the bin edges of the histogram, a 1-D Tensor with ``bins + 1`` elements. The output data type will be float32.
Expand All @@ -2686,7 +2686,7 @@ def histogram_bin_edges(input, bins=100, range=None, weight=None, name=None):
inputs = paddle.to_tensor([1, 2, 1])
result = paddle.histogram_bin_edges(inputs, bins=4, range=(0, 3))
print(result) # [0., 0.75, 1.5, 2.25, 3.]
"""
check_type(input, 'input', (Variable), 'histogram_bin_edges')
check_dtype(
Expand All @@ -2709,4 +2709,4 @@ def histogram_bin_edges(input, bins=100, range=None, weight=None, name=None):
if (stop - start) == 0:
start = start - 0.5
stop = stop + 0.5
return linspace(start, stop, bins+1)
return linspace(start, stop, bins + 1)
4 changes: 3 additions & 1 deletion python/paddle/tensor/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1631,7 +1631,9 @@ def bmm(x, y, name=None):
return out


def histogram(input, weight=None, bins=100, min=0, max=0, density=False, name=None):
def histogram(
input, weight=None, bins=100, min=0, max=0, density=False, name=None
):
"""
Computes the histogram of a tensor. The elements are sorted into equal width bins between min and max.
If min and max are both zero, the minimum and maximum values of the data are used.
Expand Down
33 changes: 21 additions & 12 deletions test/legacy_test/test_histogram_bin_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,22 @@
import unittest

import numpy as np
from eager_op_test import OpTest, paddle_static_guard

import paddle
from paddle import fluid
from paddle.fluid import Program, program_guard


class TestHistogramBinEdgesAPI(unittest.TestCase):
"""Test histogram_bin_edges api."""

def setUp(self):
self.input_np = np.random.uniform(-5, 5, [2, 3]).astype(np.float32)
self.bins = 4
self.range = (0., 3.)
self.range = (0.0, 3.0)
self.place = [paddle.CPUPlace()]

def test_api_static(self):
paddle.enable_static()

def run(place):
with paddle.static.program_guard(paddle.static.Program()):
inputs = paddle.static.data(
Expand All @@ -42,7 +42,9 @@ def run(place):
feed={'input': self.input_np},
fetch_list=[out],
)
out_ref = np.histogram_bin_edges(self.input_np, self.bins, self.range)
out_ref = np.histogram_bin_edges(
self.input_np, self.bins, self.range
)
np.testing.assert_allclose(out_ref, res[0], rtol=1e-05)

for place in self.place:
Expand All @@ -53,8 +55,9 @@ def run(place):
paddle.disable_static(place)
inputs = paddle.to_tensor(self.input_np)
out1 = paddle.histogram_bin_edges(inputs, bins=4, range=(0, 3))

out_ref1 = np.histogram_bin_edges(self.input_np, bins=4, range=(0, 3))
out_ref1 = np.histogram_bin_edges(
self.input_np, bins=4, range=(0, 3)
)
np.testing.assert_allclose(out_ref1, out1.numpy(), rtol=1e-05)
paddle.enable_static()

Expand All @@ -67,17 +70,23 @@ def test_errors(self):
range = self.range
# bin dtype is not int
self.assertRaises(
TypeError,
paddle.histogram_bin_edges,
input, bins=1.5, range=range
TypeError, paddle.histogram_bin_edges, input, bins=1.5, range=range
)
# the length of range is not equal to 2
self.assertRaises(
ValueError, paddle.histogram_bin_edges, input, bins=bins, range=(0, 2, 3)
ValueError,
paddle.histogram_bin_edges,
input,
bins=bins,
range=(0, 2, 3),
)
# the min of range is greater than the max
self.assertRaises(
ValueError, paddle.histogram_bin_edges, input, bins=bins, range=(3, 0)
ValueError,
paddle.histogram_bin_edges,
input,
bins=bins,
range=(3, 0),
)


Expand Down
Loading

0 comments on commit aaabf00

Please sign in to comment.