[Zero-Dim] paddle.nanmedian/nanquantile support 0D Tensor #54500

Merged Jun 14, 2023 (2 commits). Diff below shows changes from 1 commit.
50 changes: 30 additions & 20 deletions paddle/phi/infermeta/unary.cc
@@ -2260,37 +2260,47 @@ void NanmedianInferMeta(const MetaTensor& x,
for (int64_t i = 0; i < x_rank; i++) {
out_dim.push_back(1);
}
} else {
out_dim.push_back(1);
}
} else {
std::vector<int64_t> cleaned_axis;
std::vector<int64_t> formated_axis;
for (auto& axis : axis_list) {
if (x_rank == 0) {
PADDLE_ENFORCE_EQ(axis == 0 || axis == -1,
true,
phi::errors::InvalidArgument(
"When input 0D Tensor, each element of the axis "
"can only be -1, 0, None"));
} else {
PADDLE_ENFORCE_LT(axis,
x_rank,
errors::InvalidArgument(
"each element of the axis should be in the "
"range [ -dimension(X), dimension(X) ) "
"which dimesion = %d. But received axis = %d.",
x_rank,
axis));
PADDLE_ENFORCE_GE(axis,
-x_rank,
errors::InvalidArgument(
"each element of the axis should be in the "
"range [ -dimension(X), dimension(X) ) "
"which dimesion = %d. But received axis = %d.",
x_rank,
axis));
}
if (axis < 0) axis += x_rank;

PADDLE_ENFORCE_LT(
axis,
x_rank,
errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], R is "
"the rank of Input(X). But received axis: %d, R: %d. "
"Current Input(X)'s shape is=[%s].",
axis,
x_rank,
x_dim));

PADDLE_ENFORCE_EQ(
std::find(cleaned_axis.begin(), cleaned_axis.end(), axis),
cleaned_axis.end(),
std::find(formated_axis.begin(), formated_axis.end(), axis),
formated_axis.end(),
errors::InvalidArgument("Attr(axes) has duplicated elements: %d.",
static_cast<int>(axis)));

cleaned_axis.push_back(axis);
formated_axis.push_back(axis);
}

for (int64_t i = 0; i < x_rank; i++) {
if (std::find(cleaned_axis.begin(), cleaned_axis.end(), i) ==
cleaned_axis.end()) {
if (std::find(formated_axis.begin(), formated_axis.end(), i) ==
formated_axis.end()) {
out_dim.push_back(x_dim[i]);
} else if (keep_dim) {
out_dim.push_back(1);
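
Note on the hunk above: NanmedianInferMeta now appears to return a 0-D output (empty shape) instead of shape [1] when reducing over all axes without keep_dim, and for a 0-D input it only accepts axis values -1, 0, or None. A minimal standalone sketch of that shape rule follows; the helper name and the std::vector interface are illustrative only, not Paddle's API.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

// Computes the output shape of a nanmedian-style reduction.
std::vector<int64_t> NanmedianOutShape(const std::vector<int64_t>& x_dim,
                                       std::vector<int64_t> axes,
                                       bool keep_dim) {
  const int64_t x_rank = static_cast<int64_t>(x_dim.size());
  std::vector<int64_t> out_dim;
  if (axes.empty()) {
    // Reduce over everything: 0-D output, or an all-ones shape with keep_dim.
    if (keep_dim) out_dim.assign(x_rank, 1);
    return out_dim;
  }
  for (auto& axis : axes) {
    if (x_rank == 0 && axis != 0 && axis != -1)
      throw std::invalid_argument("0-D input only accepts axis -1, 0 or None");
    if (x_rank > 0 && (axis < -x_rank || axis >= x_rank))
      throw std::invalid_argument("axis out of range");
    if (axis < 0) axis += x_rank;
  }
  for (int64_t i = 0; i < x_rank; ++i) {
    const bool reduced = std::find(axes.begin(), axes.end(), i) != axes.end();
    if (!reduced) {
      out_dim.push_back(x_dim[i]);
    } else if (keep_dim) {
      out_dim.push_back(1);
    }
  }
  return out_dim;
}

int main() {
  for (int64_t d : NanmedianOutShape({2, 3, 4}, {1}, /*keep_dim=*/false))
    std::cout << d << ' ';                                        // prints: 2 4
  std::cout << '\n';
  // 0-D input, axes empty (None): 0-D output, i.e. an empty shape.
  std::cout << NanmedianOutShape({}, {}, false).size() << '\n';   // prints: 0
}
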
73 changes: 35 additions & 38 deletions paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc
@@ -17,7 +17,7 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/nanmedian_grad_kernel_impl.h"
#include "paddle/phi/kernels/funcs/nanmedian_utils.h"

namespace phi {

@@ -26,67 +26,64 @@ void CalcMedianGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& median_index,
const DenseTensor& out_grad,
const IntArray& axes UNUSED,
DenseTensor* x_grad,
T* x_grad_ptr) {
DenseTensor* x_grad) {
T* dx_data = dev_ctx.template Alloc<T>(x_grad);
if (!dx_data) return;

phi::funcs::SetConstant<Context, T> set_zero;
set_zero(dev_ctx, x_grad, static_cast<T>(0));
if (!x_grad_ptr) return;

const int64_t* m_ptr = median_index.data<int64_t>();
const T* out_grad_ptr = out_grad.data<T>();
const int64_t* m_data = median_index.data<int64_t>();
const T* dout_data = out_grad.data<T>();
int64_t numel = x.numel();
auto x_dim = x.dims();
int64_t rank = x_dim.size();
int64_t stride = x_dim[rank - 1];

int64_t pre_dim = numel / stride;

int64_t i = 0;
int64_t offset = 0;
T div_factor = static_cast<T>(2.0);
for (i = 0; i < pre_dim; i++) {
if (m_ptr[2 * i] >= 0) {
if (m_ptr[2 * i] == m_ptr[2 * i + 1]) {
x_grad_ptr[offset + m_ptr[2 * i]] = out_grad_ptr[i];
if (m_data[2 * i] >= 0) {
if (m_data[2 * i] == m_data[2 * i + 1]) {
dx_data[offset + m_data[2 * i]] = dout_data[i];
} else {
x_grad_ptr[offset + m_ptr[2 * i]] = out_grad_ptr[i] / div_factor;
x_grad_ptr[offset + m_ptr[2 * i + 1]] = out_grad_ptr[i] / div_factor;
dx_data[offset + m_data[2 * i]] = dout_data[i] / static_cast<T>(2.0);
dx_data[offset + m_data[2 * i + 1]] =
dout_data[i] / static_cast<T>(2.0);
}
}
offset += stride;
}
}

template <typename T, typename Context>
void BaseMedianGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& median_index,
const DenseTensor& out_grad,
const IntArray& axes,
DenseTensor* x_grad) {
auto rank = x.dims().size();
T* x_grad_ptr = dev_ctx.template Alloc<T>(x_grad);
if (axes.size() && (rank > 1)) {
DenseTensor tmp_x_grad(*x_grad);
CalcMedianGradKernel<T, Context>(
dev_ctx, x, median_index, out_grad, axes, &tmp_x_grad, x_grad_ptr);
PostprocessMedianGradKernel<T, Context>(dev_ctx, &tmp_x_grad, axes, x_grad);
} else {
CalcMedianGradKernel<T, Context>(
dev_ctx, x, median_index, out_grad, axes, x_grad, x_grad_ptr);
}
}

template <typename T, typename Context>
void NanmedianGradKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& x,
const DenseTensor& median_index,
const DenseTensor& out_grad,
const IntArray& axes,
bool keep_dim UNUSED,
bool keepdim UNUSED,
DenseTensor* x_grad) {
BaseMedianGradKernel<T, Context>(
dev_ctx, input, median_index, out_grad, axes, x_grad);
DenseTensor tmp_x;
auto rank = x.dims().size();
if ((axes.size() == 0) || rank <= 1) {
tmp_x = x;
tmp_x.Resize({x.numel()});
CalcMedianGradKernel<T, Context>(
dev_ctx, tmp_x, median_index, out_grad, x_grad);
} else {
funcs::PreprocessMedianKernel<T, Context>(dev_ctx, x, axes, &tmp_x);

DenseTensor tmp_x_grad;
tmp_x_grad.Resize(x_grad->dims());
CalcMedianGradKernel<T, Context>(
dev_ctx, tmp_x, median_index, out_grad, &tmp_x_grad);

dev_ctx.template Alloc<T>(x_grad);
funcs::PostprocessMedianGradKernel<T, Context>(
dev_ctx, &tmp_x_grad, axes, x_grad);
}
}

} // namespace phi
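
For context on the refactored backward pass above: median_index holds two positions per reduced row, equal when the number of valid elements is odd and distinct when it is even, in which case the upstream gradient is split evenly between the two middle elements; -1 marks a row that was entirely NaN. A standalone sketch of that scatter, assuming plain arrays in place of DenseTensor (function and variable names are illustrative):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Scatters the upstream gradient back to the median position(s) of each row.
void NanmedianGradRows(const std::vector<float>& dout,       // [pre_dim]
                       const std::vector<int64_t>& m_index,  // [pre_dim * 2]
                       int64_t stride,                       // row length
                       std::vector<float>* dx) {             // [pre_dim * stride]
  std::fill(dx->begin(), dx->end(), 0.0f);
  const int64_t pre_dim = static_cast<int64_t>(dout.size());
  int64_t offset = 0;
  for (int64_t i = 0; i < pre_dim; ++i) {
    const int64_t lo = m_index[2 * i];
    const int64_t hi = m_index[2 * i + 1];
    if (lo >= 0) {                                 // lo == -1: all-NaN row
      if (lo == hi) {
        (*dx)[offset + lo] = dout[i];              // odd count: one median
      } else {
        (*dx)[offset + lo] = dout[i] / 2.0f;       // even count: gradient is
        (*dx)[offset + hi] = dout[i] / 2.0f;       // shared by the two middles
      }
    }
    offset += stride;
  }
}

int main() {
  // One row of 4 valid elements whose median averaged positions 1 and 2.
  std::vector<float> dx(4);
  NanmedianGradRows({1.0f}, {1, 2}, 4, &dx);
  for (float v : dx) std::cout << v << ' ';        // prints: 0 0.5 0.5 0
  std::cout << '\n';
}
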
69 changes: 28 additions & 41 deletions paddle/phi/kernels/cpu/nanmedian_kernel.cc
@@ -16,7 +16,7 @@

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/nanmedian_kernel_impl.h"
#include "paddle/phi/kernels/funcs/nanmedian_utils.h"
#include "paddle/phi/kernels/top_k_kernel.h"

namespace phi {
@@ -31,7 +31,6 @@ void CalcMedianFunc(const Context& dev_ctx,
int64_t pre_dim,
T* o_ptr,
int64_t* m_ptr) {
bool should_ignore_nan = ignore_nan;
DenseTensor sort_out;
DenseTensor sort_indices;
auto sort_dim = x.dims();
@@ -52,7 +51,7 @@
int64_t offset = 0;
int64_t i = 0;
bool is_ori_odd = stride & 1;
if (should_ignore_nan) {
if (ignore_nan) {
for (i = 0; i < pre_dim; i++) {
offset = i * sort_k;
if (nan_counts[i] == stride) {
@@ -107,11 +106,11 @@
template <typename T, typename Context>
void ProcessMedianKernel(const Context& dev_ctx,
const DenseTensor& x,
T* o_ptr,
int64_t* m_ptr,
bool ignore_nan) {
bool should_ignore_nan = ignore_nan;
const T* x_ptr = x.data<T>();
DenseTensor* out,
DenseTensor* median_index) {
const T* x_data = x.data<T>();
T* out_data = dev_ctx.template Alloc<T>(out);
int64_t* m_data = dev_ctx.template Alloc<int64_t>(median_index);

int64_t numel = x.numel();
auto x_dim = x.dims();
@@ -122,7 +121,8 @@ void ProcessMedianKernel(const Context& dev_ctx,

int64_t max_valid_num = 0;
std::vector<int64_t> nan_counts;
if (should_ignore_nan) {
bool ignore_nan = true;
if (ignore_nan) {
int64_t total_nan_num = 0;
std::vector<T> col_vec;
col_vec.reserve(stride);
@@ -133,7 +133,7 @@
for (int64_t i = 0; i < pre_dim; i++) {
col_vec.clear();
col_vec.insert(
col_vec.begin(), x_ptr + i * stride, x_ptr + (i + 1) * stride);
col_vec.begin(), x_data + i * stride, x_data + (i + 1) * stride);
nan_counts[i] =
std::count_if(col_vec.begin(), col_vec.end(), [&](const T& val) {
return std::isnan(static_cast<float>(val));
@@ -145,47 +145,25 @@
// all elems are nan
if (total_nan_num == numel) {
for (i = 0; i < pre_dim; i++) {
o_ptr[i] = x_ptr[0];
m_ptr[2 * i] = -1;
m_ptr[2 * i + 1] = -1;
out_data[i] = std::numeric_limits<T>::quiet_NaN();
m_data[2 * i] = -1;
m_data[2 * i + 1] = -1;
}
return;
}
should_ignore_nan = total_nan_num > 0;
ignore_nan = total_nan_num > 0;
}

int64_t sort_k = should_ignore_nan ? max_valid_num : ((stride >> 1) + 1);
int64_t sort_k = ignore_nan ? max_valid_num : ((stride >> 1) + 1);
CalcMedianFunc<T, Context>(dev_ctx,
x,
nan_counts,
should_ignore_nan,
ignore_nan,
sort_k,
stride,
pre_dim,
o_ptr,
m_ptr);
}

template <typename T, typename Context>
void BaseMedianKernel(const Context& dev_ctx,
const DenseTensor& input,
const IntArray& axes,
DenseTensor* out,
DenseTensor* median_index,
bool ignore_nan) {
DenseTensor x;
auto rank = input.dims().size();
if ((axes.size() == 0) || rank <= 1) {
x = input;
x.Resize({input.numel()});
} else {
PreprocessMedianKernel<T, Context>(dev_ctx, input, axes, &x);
}

T* o_ptr = dev_ctx.template Alloc<T>(out);
int64_t* m_ptr = dev_ctx.template Alloc<int64_t>(median_index);
ProcessMedianKernel<T, Context>(dev_ctx, x, o_ptr, m_ptr, ignore_nan);
out->Resize(out->dims());
out_data,
m_data);
}

template <typename T, typename Context>
@@ -195,7 +173,16 @@ void NanmedianKernel(const Context& dev_ctx,
bool keepdim UNUSED,
DenseTensor* out,
DenseTensor* median_index) {
BaseMedianKernel<T, Context>(dev_ctx, x, axes, out, median_index, true);
DenseTensor tmp_x;
auto rank = x.dims().size();
if ((axes.size() == 0) || rank <= 1) {
tmp_x = x;
tmp_x.Resize({x.numel()});
} else {
funcs::PreprocessMedianKernel<T, Context>(dev_ctx, x, axes, &tmp_x);
}

ProcessMedianKernel<T, Context>(dev_ctx, tmp_x, out, median_index);
}

} // namespace phi
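
The forward kernel above sorts only as many elements per row as the median needs (max_valid_num when NaNs are being skipped, stride/2 + 1 otherwise) and, as the hunk suggests, now writes NaN rather than x[0] for rows that are entirely NaN. A simplified per-row reduction with the same NaN convention is sketched below, using std::sort where the real kernel goes through phi's TopK; it is illustrative only and omits the median_index bookkeeping that the backward pass relies on.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// Median of one row with NaNs skipped; an all-NaN row yields NaN.
float NanmedianRow(std::vector<float> row) {
  // Move the valid values to the front, NaNs to the back.
  auto valid_end = std::stable_partition(
      row.begin(), row.end(), [](float v) { return !std::isnan(v); });
  const int64_t n = valid_end - row.begin();
  if (n == 0) return std::numeric_limits<float>::quiet_NaN();
  std::sort(row.begin(), valid_end);
  return (n & 1) ? row[n / 2]                             // odd valid count
                 : (row[n / 2 - 1] + row[n / 2]) / 2.0f;  // even valid count
}

int main() {
  std::cout << NanmedianRow({3.0f, NAN, 1.0f, 2.0f, NAN}) << '\n';  // prints: 2
  std::cout << NanmedianRow({NAN, NAN}) << '\n';                    // prints: nan
}
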
paddle/phi/kernels/funcs/nanmedian_utils.h
@@ -15,9 +15,51 @@
#pragma once

#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/nanmedian_kernel.h"

namespace phi {
namespace funcs {

template <typename T, typename Context>
void PostprocessMedianGradKernel(const Context& dev_ctx,
DenseTensor* input,
const IntArray& raw_axes,
DenseTensor* x) {
auto input_dim = input->dims();
auto rank = input_dim.size();

std::vector<int64_t> axes = raw_axes.GetData();
int64_t axes_size = static_cast<int>(axes.size());
for (int64_t i = 0; i < axes_size; i++) {
if (axes[i] < 0) {
axes[i] += rank;
}
}

std::vector<int> trans_back;
std::vector<int> reshape_back;
trans_back.resize(rank);

int offset = 0;
for (int64_t i = 0; i < rank; i++) {
if (std::find(axes.begin(), axes.end(), i) == axes.end()) {
reshape_back.push_back(input_dim[i]);
trans_back[i] = offset;
offset += 1;
}
}

for (int64_t i = 0; i < rank; i++) {
if (std::find(axes.begin(), axes.end(), i) != axes.end()) {
trans_back[i] = offset;
reshape_back.push_back(input_dim[i]);
offset += 1;
}
}

input->Resize(make_ddim(reshape_back));
funcs::TransCompute<Context, T>(
static_cast<int>(trans_back.size()), dev_ctx, *input, x, trans_back);
}

template <typename T, typename Context>
void PreprocessMedianKernel(const Context& dev_ctx,
@@ -65,4 +107,5 @@ void PreprocessMedianKernel(const Context& dev_ctx,
x->Resize(make_ddim(reshape));
}

} // namespace funcs
} // namespace phi
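
The two helpers collected in this header share one piece of bookkeeping: PreprocessMedianKernel transposes the reduced axes to the end and flattens them so the kernels see a [pre_dim, stride] block, and PostprocessMedianGradKernel builds the inverse permutation (trans_back) to put the gradient back into the original layout. A small sketch of that permutation logic follows, assuming numpy-style transpose semantics for TransCompute; std::vector stands in for DDim/IntArray and the function names are illustrative.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Permutation used by the preprocess step: non-reduced axes first, reduced
// axes last, so the reduced axes can be flattened into one trailing dim.
std::vector<int> MedianForwardPerm(int64_t rank, std::vector<int64_t> axes) {
  for (auto& a : axes)
    if (a < 0) a += rank;                                  // normalize negatives
  std::vector<int> perm;
  for (int64_t i = 0; i < rank; ++i)
    if (std::find(axes.begin(), axes.end(), i) == axes.end())
      perm.push_back(static_cast<int>(i));
  for (int64_t i = 0; i < rank; ++i)
    if (std::find(axes.begin(), axes.end(), i) != axes.end())
      perm.push_back(static_cast<int>(i));
  return perm;
}

// The grad postprocess needs the inverse: original axis i sits at position
// inv[i] of the transposed tensor, which matches what trans_back records.
std::vector<int> InversePerm(const std::vector<int>& perm) {
  std::vector<int> inv(perm.size());
  for (size_t j = 0; j < perm.size(); ++j) inv[perm[j]] = static_cast<int>(j);
  return inv;
}

int main() {
  const auto perm = MedianForwardPerm(4, {0, 2});   // reduce axes 0 and 2
  const auto inv = InversePerm(perm);
  for (int p : perm) std::cout << p << ' ';         // prints: 1 3 0 2
  std::cout << "| ";
  for (int p : inv) std::cout << p << ' ';          // prints: 2 0 3 1
  std::cout << '\n';
}
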