From 9fe3ee7de3811aeae9956184a9b4aacc84133fb6 Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Thu, 28 Apr 2022 14:55:40 +0000 Subject: [PATCH 01/18] nanmedian op --- paddle/fluid/operators/nanmedian_op.cc | 120 +++++++++ paddle/phi/infermeta/unary.cc | 24 ++ paddle/phi/infermeta/unary.h | 6 + .../phi/kernels/cpu/nanmedian_grad_kernel.cc | 69 +++++ paddle/phi/kernels/cpu/nanmedian_kernel.cc | 102 +++++++ .../phi/kernels/gpu/nanmedian_grad_kernel.cu | 83 ++++++ paddle/phi/kernels/gpu/nanmedian_kernel.cu | 251 ++++++++++++++++++ paddle/phi/kernels/nanmedian_grad_kernel.h | 26 ++ paddle/phi/kernels/nanmedian_kernel.h | 26 ++ paddle/phi/ops/compat/nanmedian_sig.cc | 33 +++ python/paddle/__init__.py | 1 + .../fluid/tests/unittests/test_nanmedian.py | 190 +++++++++++++ python/paddle/tensor/__init__.py | 2 + python/paddle/tensor/stat.py | 132 +++++++++ tools/parallel_UT_rule.py | 2 +- 15 files changed, 1066 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/nanmedian_op.cc create mode 100644 paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/nanmedian_kernel.cc create mode 100644 paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/nanmedian_kernel.cu create mode 100644 paddle/phi/kernels/nanmedian_grad_kernel.h create mode 100644 paddle/phi/kernels/nanmedian_kernel.h create mode 100644 paddle/phi/ops/compat/nanmedian_sig.cc create mode 100644 python/paddle/fluid/tests/unittests/test_nanmedian.py diff --git a/paddle/fluid/operators/nanmedian_op.cc b/paddle/fluid/operators/nanmedian_op.cc new file mode 100644 index 0000000000000..3db9d9a0263f5 --- /dev/null +++ b/paddle/fluid/operators/nanmedian_op.cc @@ -0,0 +1,120 @@ +/*Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" + +namespace paddle { +namespace operators { + +class NanmedianOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } +}; + +class NanmedianOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor, default Tensor), " + "the input feature data of NanmedianOp, dtype should be" + "int32, int64, float16, float32, float64."); + AddAttr( + "ignore_nan", + "(bool, default true) Set to true if nan values should be ignored. " + "Set to false when no nan value in x were considered. ") + .SetDefault(true); + AddOutput("Medians", + "The calculation differs in the odd or even of the valid " + "elements amount." + "Along the axis, two elements contributed to the median value in " + "each row." + "If the amount of valid elements were even, both were the same.") + .AsIntermediate() + .AsExtra(); + AddOutput("Out", + "(Tensor, default Tensor)," + " the output of NanmedianOp, whose dtype is the same as X"); + AddComment(R"DOC( + Nanmedian operator + + This operator is considered as an extention of median operation, + which supports specifically the case of nan values in the input. + + If all the elements in input are NaN it will also return NaN. + If no elements in input are Nan, this op is identical to thie median op. + + This operator can also supports multiple axis, + and could be switched to median operator when `ignore_nan` were set to False. + )DOC"); + } +}; + +template +class NanmedianGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr op) const override { + op->SetType("nanmedian_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput("Medians", this->Output("Medians")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + } +}; + +class NanmedianGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "nanmedian"); + OP_INOUT_CHECK(ctx->HasInput("Medians"), "Input", "Medians", "nanmedian"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + framework::GradVarName("Out"), "nanmedian"); + + auto x_dims = ctx->GetInputDim("X"); + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(nanmedian, NanmedianInferShapeFunctor, + PD_INFER_META(phi::NanmedianInferMeta)); + +REGISTER_OPERATOR(nanmedian, ops::NanmedianOp, ops::NanmedianOpMaker, + ops::NanmedianGradMaker, + ops::NanmedianGradMaker, + NanmedianInferShapeFunctor); + +REGISTER_OPERATOR(nanmedian_grad, ops::NanmedianGradOp); diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 400c56db3efc2..ab22269c160a3 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1203,6 +1203,30 @@ void MultinomialInferMeta(const MetaTensor& x, out->set_dtype(DataType::INT64); } +void NanmedianInferMeta(const MetaTensor& x, + bool ignore_nan, + MetaTensor* out, + MetaTensor* medians) { + auto x_dim = x.dims(); + int64_t x_rank = x_dim.size(); + + std::vector out_dims(x_rank); + std::vector median_dims(x_rank); + for (int64_t i = 0; i < x_rank - 1; i++) { + out_dims[i] = x_dim[i]; + median_dims[i] = x_dim[i]; + } + + out_dims[x_rank - 1] = 1; + median_dims[x_rank - 1] = 2; + + out->set_dims(make_ddim(out_dims)); + out->set_dtype(x.dtype()); + + medians->set_dims(make_ddim(median_dims)); + medians->set_dtype(x.dtype()); +} + void NormInferMeta(const MetaTensor& x, int axis, float epsilon, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index c67eb2068d8bf..613ee0d2ebad7 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -173,6 +173,12 @@ void MultinomialInferMeta(const MetaTensor& x, int num_samples, bool replacement, MetaTensor* out); + +void NanmedianInferMeta(const MetaTensor& x, + bool ignore_nan, + MetaTensor* out, + MetaTensor* medians); + void NormInferMeta(const MetaTensor& x, int axis, float epsilon, diff --git a/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc b/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc new file mode 100644 index 0000000000000..c29d23290b755 --- /dev/null +++ b/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc @@ -0,0 +1,69 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/nanmedian_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void NanmedianGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& medians, + const DenseTensor& out_grad, + DenseTensor* x_grad) { + const T* x_ptr = x.data(); + const T* m_ptr = medians.data(); + const T* out_grad_ptr = out_grad.data(); + + int64_t numel = x.numel(); + auto x_dim = x.dims(); + int64_t x_rank = x_dim.size(); + int64_t stride = x_dim[x_rank - 1]; + auto zero = static_cast(0); + + if (x_grad) { + T* x_grad_ptr = dev_ctx.template Alloc(x_grad); + int64_t i = 0; + for (i = 0; i < numel; i++) { + if (std::isnan(x_ptr[i])) { + x_grad_ptr[i] = zero; + continue; + } + + int64_t row = static_cast(i / stride); + int64_t m_row = 2 * row; + if (std::isnan(m_ptr[m_row]) || + (x_ptr[i] != m_ptr[m_row] && x_ptr[i] != m_ptr[m_row + 1])) { + x_grad_ptr[i] = zero; + continue; + } + + x_grad_ptr[i] = out_grad_ptr[row]; + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(nanmedian_grad, + CPU, + ALL_LAYOUT, + phi::NanmedianGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/cpu/nanmedian_kernel.cc b/paddle/phi/kernels/cpu/nanmedian_kernel.cc new file mode 100644 index 0000000000000..ae2ecdb555b8d --- /dev/null +++ b/paddle/phi/kernels/cpu/nanmedian_kernel.cc @@ -0,0 +1,102 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/nanmedian_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void NanmedianKernel(const Context& dev_ctx, + const DenseTensor& x, + bool ignore_nan, + DenseTensor* out, + DenseTensor* medians) { + const T* x_ptr = x.data(); + T* o_ptr = dev_ctx.template Alloc(out); + T* m_ptr = dev_ctx.template Alloc(medians); + + int64_t numel = x.numel(); + auto x_dim = x.dims(); + int64_t x_rank = x_dim.size(); + int64_t stride = x_dim[x_rank - 1]; + int64_t pre_dim = numel / stride; + int64_t i = 0; + + bool all_nan = true; + for (i = 0; i < numel; i++) { + if (!std::isnan(*(x_ptr + i))) { + all_nan = false; + break; + } + } + + if (all_nan) { + for (i = 0; i < pre_dim; i++) { + o_ptr[i] = x_ptr[0]; + m_ptr[2 * i] = x_ptr[0]; + m_ptr[2 * i + 1] = x_ptr[0]; + } + return; + } + + std::vector col_vec; + col_vec.reserve(stride); + col_vec.resize(stride); + for (i = 0; i < pre_dim; i++) { + col_vec.clear(); + col_vec.insert( + col_vec.begin(), x_ptr + i * stride, x_ptr + (i + 1) * stride); + + int64_t num_nan = + std::count_if(col_vec.begin(), col_vec.end(), std::isnan); + int64_t pos = (stride - num_nan - 1) / 2; + std::nth_element(col_vec.begin(), + col_vec.begin() + pos, + col_vec.end(), + [](const T& l, const T& r) { + return (!std::isnan(static_cast(l)) && + std::isnan(static_cast(r))) || + (l < r); + }); + + m_ptr[2 * i] = col_vec[pos]; + m_ptr[2 * i + 1] = col_vec[pos]; + if ((stride - num_nan) % 2 == 0) { + std::nth_element(col_vec.begin(), + col_vec.begin() + pos + 1, + col_vec.end(), + [](const T& l, const T& r) { + return (!std::isnan(static_cast(l)) && + std::isnan(static_cast(r))) || + (l < r); + }); + m_ptr[2 * i + 1] = col_vec[pos + 1]; + } + o_ptr[i] = static_cast((m_ptr[2 * i] + m_ptr[2 * i + 1]) / 2.0); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(nanmedian, + CPU, + ALL_LAYOUT, + phi::NanmedianKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu b/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu new file mode 100644 index 0000000000000..f316ea92ab319 --- /dev/null +++ b/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu @@ -0,0 +1,83 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/nanmedian_grad_kernel.h" + +#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_meta.h" + +namespace phi { + +using paddle::platform::PADDLE_CUDA_NUM_THREADS; +inline int GET_BLOCKS(const int N) { + return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS; +} + +template +__global__ void KernelNanmedianGrad(const T* x_ptr, + const T* medians_ptr, + const T* out_grad_ptr, + T* x_grad_ptr, + int64_t stride, + int64_t numel) { + auto zero = static_cast(0); + CUDA_KERNEL_LOOP(index, numel) { + int64_t row = static_cast(index / stride); + int64_t m_row = 2 * row; + if (isnan(x_ptr[index]) || isnan(medians_ptr[m_row]) || + (x_ptr[index] != medians_ptr[m_row] && + x_ptr[index] != medians_ptr[m_row + 1])) { + x_grad_ptr[index] = zero; + } else { + x_grad_ptr[index] = out_grad_ptr[row]; + } + } +} + +template +void NanmedianGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& medians, + const DenseTensor& out_grad, + DenseTensor* x_grad) { + auto stream = dev_ctx.stream(); + const T* x_ptr = x.data(); + const T* m_ptr = medians.data(); + const T* out_grad_ptr = out_grad.data(); + T* x_grad_ptr = dev_ctx.template Alloc(x_grad); + + int64_t numel = x.numel(); + auto x_dim = x.dims(); + int64_t x_rank = x_dim.size(); + int64_t stride = x_dim[x_rank - 1]; + + KernelNanmedianGrad< + T><<>>( + x_ptr, m_ptr, out_grad_ptr, x_grad_ptr, stride, numel); +} + +} // namespace phi + +PD_REGISTER_KERNEL(nanmedian_grad, + GPU, + ALL_LAYOUT, + phi::NanmedianGradKernel, + float, + double, + int, + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/nanmedian_kernel.cu b/paddle/phi/kernels/gpu/nanmedian_kernel.cu new file mode 100644 index 0000000000000..8112a425bea0a --- /dev/null +++ b/paddle/phi/kernels/gpu/nanmedian_kernel.cu @@ -0,0 +1,251 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/nanmedian_kernel.h" + +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/top_k_kernel.h" + +namespace phi { + +using paddle::platform::PADDLE_CUDA_NUM_THREADS; + +inline int GET_BLOCKS(const int N) { + return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS; +} + +template +__global__ void KernelNanCounts(const T* input, + const int numel, + const int64_t pre_dim, + const int64_t stride, + T min_val, + int64_t* nan_total, + int64_t* nan_counts, + T* output) { + extern __shared__ int64_t buf[]; + for (int i = threadIdx.x; i < pre_dim; i += blockDim.x) { + buf[i] = 0; + } + __syncthreads(); + + CUDA_KERNEL_LOOP(index, numel) { + const T x = input[index]; + if (isnan(x)) { + auto bin = static_cast(index / stride); + paddle::platform::CudaAtomicAdd(&buf[bin], 1); + output[index] = min_val; + } + } + __syncthreads(); + + for (int i = threadIdx.x; i < pre_dim; i += blockDim.x) { + paddle::platform::CudaAtomicAdd(&nan_counts[i], buf[i]); + paddle::platform::CudaAtomicAdd(&nan_total[0], buf[i]); + paddle::platform::CudaAtomicMax(&nan_total[1], buf[i]); + } +} + +template +__global__ void CalcMedianKernel(const T* sort_out, + T* median_val, + T* output, + const bool is_odd, + const int64_t pre_dim, + const int64_t stride) { + T div_factor = static_cast(2.0); + CUDA_KERNEL_LOOP(index, pre_dim) { + int64_t pos = static_cast((index + 1) * stride) - 1; + if (is_odd) { + median_val[index * 2] = sort_out[pos]; + median_val[index * 2 + 1] = sort_out[pos]; + output[index] = sort_out[pos]; + } else { + median_val[index * 2] = pos > 1 ? sort_out[pos - 1] : sort_out[pos]; + median_val[index * 2 + 1] = sort_out[pos]; + output[index] = + (median_val[index * 2] + median_val[index * 2 + 1]) / div_factor; + } + } +} + +template +__global__ void CalcNanmedianKernel(const T* sort_out, + int64_t* nan_counts, + T* median_val, + T* output, + const bool is_odd, + const int64_t pre_dim, + const int64_t max_nan_num, + const int64_t stride) { + T div_factor = static_cast(2.0); + T nan_val = std::numeric_limits::quiet_NaN(); + + CUDA_KERNEL_LOOP(index, pre_dim) { + int64_t pos = static_cast(index * max_nan_num); + int64_t nan_cnt = nan_counts[index]; + if (nan_cnt == stride) { + median_val[index * 2] = nan_val; + median_val[index * 2 + 1] = nan_val; + output[index] = nan_val; + } else { + bool check_odd = is_odd; + if (nan_cnt > 0) { + int64_t nan_k = static_cast(stride - nan_cnt); + int64_t new_k = static_cast(nan_k >> 1); + pos += new_k - 1; + check_odd = nan_k & 1; + } else { + pos += max_nan_num - 1; + } + + if (check_odd) { + median_val[index * 2] = sort_out[pos]; + median_val[index * 2 + 1] = sort_out[pos]; + output[index] = sort_out[pos]; + } else { + median_val[index * 2] = pos > 1 ? sort_out[pos - 1] : sort_out[pos]; + median_val[index * 2 + 1] = sort_out[pos]; + output[index] = + (median_val[index * 2] + median_val[index * 2 + 1]) / div_factor; + } + } + } +} + +template +void NanmedianKernel(const Context& dev_ctx, + const DenseTensor& x, + bool ignore_nan, + DenseTensor* out, + DenseTensor* medians) { + auto stream = dev_ctx.stream(); + auto* ctx = + reinterpret_cast(&dev_ctx); + + const T* x_ptr = x.data(); + T* o_ptr = dev_ctx.template Alloc(out); + T* m_ptr = dev_ctx.template Alloc(medians); + + int64_t numel = x.numel(); + auto x_dim = x.dims(); + int64_t x_rank = x_dim.size(); + int64_t stride = x_dim[x_rank - 1]; + int64_t pre_dim = numel / stride; + int64_t i = 0; + + int64_t half_stride = (stride >> 1) + 1; + bool is_ori_odd = stride & 1; + + DenseTensor sort_out; + auto sort_dim = x.dims(); + sort_dim[x_rank - 1] = half_stride; + + sort_out.Resize(sort_dim); + dev_ctx.template Alloc(&sort_out); + T* sort_out_ptr = sort_out.data(); + + std::vector out_dim_vec = vectorize(sort_dim); + DenseTensor indices = phi::Empty(dev_ctx, IntArray(out_dim_vec)); + + if (ignore_nan) { + DenseTensor nan_counts, nan_stat, nonnan_x; + + nan_counts.Resize(phi::make_ddim({pre_dim})); + dev_ctx.template Alloc(&nan_counts); + int64_t* nan_counts_ptr = nan_counts.data(); + + nan_stat.Resize(phi::make_ddim({2})); + int64_t* nan_stat_mem = dev_ctx.template Alloc(&nan_stat); + int64_t* nan_stat_ptr = nan_stat.data(); + + nonnan_x.Resize(x.dims()); + dev_ctx.template Alloc(&nonnan_x); + T* nonnan_x_ptr = nonnan_x.data(); + + KernelNanCounts<<>>(x_ptr, + numel, + pre_dim, + stride, + std::numeric_limits::min(), + nan_stat_ptr, + nan_counts_ptr, + nonnan_x_ptr); + + auto nan_stat_mem_cpu = + paddle::memory::Alloc(phi::CPUPlace(), sizeof(int64_t) * 2); + int64_t* nan_stat_cpu_ptr = + reinterpret_cast(nan_stat_mem_cpu->ptr()); + paddle::memory::Copy(phi::CPUPlace(), + nan_stat_cpu_ptr, + dev_ctx.GetPlace(), + nan_stat_mem, + sizeof(int64_t) * 2, + stream); + + // all elements are nan + T nan_val = std::numeric_limits::quiet_NaN(); + if (nan_stat_cpu_ptr[0] == numel) { + FullLikeKernel(dev_ctx, x, nan_val, x.dtype(), out); + return; + } + + if (nan_stat_cpu_ptr[0] > 0) { + int64_t max_nan_num = nan_stat_cpu_ptr[1]; + TopkKernel( + dev_ctx, x, Scalar(max_nan_num), -1, true, true, &sort_out, &indices); + + CalcNanmedianKernel< + T><<>>( + sort_out_ptr, + nan_counts_ptr, + m_ptr, + o_ptr, + is_ori_odd, + pre_dim, + max_nan_num, + stride); + + return; + } + } + + TopkKernel( + dev_ctx, x, Scalar(half_stride), -1, true, true, &sort_out, &indices); + + CalcMedianKernel< + T><<>>( + sort_out_ptr, m_ptr, o_ptr, is_ori_odd, pre_dim, half_stride); +} + +} // namespace phi + +PD_REGISTER_KERNEL(nanmedian, + GPU, + ALL_LAYOUT, + phi::NanmedianKernel, + float, + double, + int, + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/nanmedian_grad_kernel.h b/paddle/phi/kernels/nanmedian_grad_kernel.h new file mode 100644 index 0000000000000..714b2b8192d86 --- /dev/null +++ b/paddle/phi/kernels/nanmedian_grad_kernel.h @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void NanmedianGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& medians, + const DenseTensor& out_grad, + DenseTensor* x_grad); +} // namespace phi diff --git a/paddle/phi/kernels/nanmedian_kernel.h b/paddle/phi/kernels/nanmedian_kernel.h new file mode 100644 index 0000000000000..e30472550399c --- /dev/null +++ b/paddle/phi/kernels/nanmedian_kernel.h @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void NanmedianKernel(const Context& dev_ctx, + const DenseTensor& x, + bool ignore_nan, + DenseTensor* out, + DenseTensor* medians); +} // namespace phi diff --git a/paddle/phi/ops/compat/nanmedian_sig.cc b/paddle/phi/ops/compat/nanmedian_sig.cc new file mode 100644 index 0000000000000..58cb13e344232 --- /dev/null +++ b/paddle/phi/ops/compat/nanmedian_sig.cc @@ -0,0 +1,33 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature NanmedianOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature( + "nanmedian", {"X"}, {"ignore_nan"}, {"Out", "Medians"}); +} + +KernelSignature NanmedianGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "nanmedian_grad", {"X", "Medians", "Out@GRAD"}, {}, {"X@GRAD"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(nanmedian, phi::NanmedianOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(nanmedian_grad, phi::NanmedianGradOpArgumentMapping); diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index cb0135d9b4c29..7064aca53d110 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -329,6 +329,7 @@ from .tensor.stat import var # noqa: F401 from .tensor.stat import numel # noqa: F401 from .tensor.stat import median # noqa: F401 +from .tensor.stat import nanmedian # noqa: F401 from .tensor.stat import quantile # noqa: F401 from .tensor.stat import nanquantile # noqa: F401 from .device import get_cudnn_version # noqa: F401 diff --git a/python/paddle/fluid/tests/unittests/test_nanmedian.py b/python/paddle/fluid/tests/unittests/test_nanmedian.py new file mode 100644 index 0000000000000..20c8e816297e0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_nanmedian.py @@ -0,0 +1,190 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid.core as core + +np.random.seed(10) + + +class TestNanmedian(unittest.TestCase): + def setUp(self): + single_axis_shape = (120) + multi_axis_shape = (2, 3, 4, 5) + + self.fake_data = { + "single_axis_normal": + np.random.uniform(-1, 1, single_axis_shape).astype(np.float32), + "multi_axis_normal": + np.random.uniform(-1, 1, multi_axis_shape).astype(np.float32), + "single_axis_all_nan": np.full(single_axis_shape, np.nan), + "multi_axis_all_nan": np.full(multi_axis_shape, np.nan), + } + + single_partial_nan = self.fake_data["single_axis_normal"] + single_partial_nan[single_partial_nan > 0] = np.nan + multi_partial_nan = self.fake_data["multi_axis_normal"] + multi_partial_nan[multi_partial_nan > 0] = np.nan + self.fake_data["single_axis_partial_nan"] = single_partial_nan + self.fake_data["multi_axis_partial_nan"] = multi_partial_nan + + row_data = np.random.uniform(-1, 1, multi_axis_shape).astype(np.float32) + row_data[:, :, :, 0] = np.nan + row_data[:, :, :2, 1] = np.nan + row_data[:, :, 2:, 2] = np.nan + self.fake_data["row_nan_even"] = row_data + self.fake_data["row_nan_float64"] = row_data.astype(np.float64) + self.fake_data["row_nan_int64"] = row_data.astype(np.int64) + self.fake_data["row_nan_int32"] = row_data.astype(np.int32) + + col_data = np.random.uniform(-1, 1, multi_axis_shape).astype(np.float32) + col_data[:, :, 0, :] = np.nan + col_data[:, :, 1, :3] = np.nan + col_data[:, :, 2, 3:] = np.nan + self.fake_data["col_nan_odd"] = col_data + + self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + self.axis_candiate_list = [ + None, 0, 2, -1, -2, (1, 2), [0, -1], [0, 1, 3], (1, 2, -3), + [0, 2, 1, 3] + ] + + def test_api_static(self): + data = self.fake_data["col_nan_odd"] + paddle.enable_static() + np_res = np.nanmedian(data, keepdims=True) + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data('X', data.shape) + out1 = paddle.nanmedian(x, keepdim=True) + out2 = paddle.tensor.nanmedian(x, keepdim=True) + out3 = paddle.tensor.stat.nanmedian(x, keepdim=True) + axis = np.arange(len(data.shape)).tolist() + out4 = paddle.nanmedian(x, axis=axis, keepdim=True) + out5 = paddle.nanmedian(x, axis=tuple(axis), keepdim=True) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': data}, + fetch_list=[out1, out2, out3, out4, out5]) + + for out in res: + self.assertTrue(np.allclose(np_res, out, equal_nan=True)) + + def test_api_dygraph(self): + paddle.disable_static(self.place) + + def clean_axis_numpy(axis, shape_len): + if isinstance(axis, tuple): + axis = list(axis) + if isinstance(axis, list): + for k in range(len(axis)): + if axis[k] < 0: + axis[k] += shape_len + axis = set(axis) + return axis + + def test_data_case(data, ignore_nan=True): + for keep_dim in [False, True]: + np_res = np.nanmedian(data, keepdims=keep_dim) + pd_res = paddle.nanmedian( + paddle.to_tensor(data), + ignore_nan=ignore_nan, + keepdim=keep_dim) + self.assertTrue( + np.allclose( + np_res, pd_res.numpy(), equal_nan=True)) + + def test_axis_case(data, axis, ignore_nan=True): + pd_res = paddle.nanmedian( + paddle.to_tensor(data), + axis=axis, + ignore_nan=ignore_nan, + keepdim=False) + axis = clean_axis_numpy(axis, len(data.shape)) + np_res = np.nanmedian(data, axis=axis, keepdims=False) + self.assertTrue(np.allclose(np_res, pd_res.numpy(), equal_nan=True)) + + for name, data in self.fake_data.items(): + test_data_case(data) + if "_normal" in name: + test_data_case(data, ignore_nan=False) + + for axis in self.axis_candiate_list: + test_axis_case(self.fake_data["row_nan_even"], axis) + test_axis_case(self.fake_data["col_nan_odd"], axis) + + paddle.enable_static() + + def test_errors(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data("X", [10, 12]) + + def test_dtype(): + x2 = paddle.fluid.data('X2', [10, 12], 'bool') + paddle.nanmedian(x2) + + def test_empty_axis(): + paddle.nanmedian(x, axis=[], keepdim=True) + + def test_axis_not_in_range(): + paddle.nanmedian(x, axis=3, keepdim=True) + + self.assertRaises(TypeError, test_dtype) + self.assertRaises(ValueError, test_empty_axis) + self.assertRaises(ValueError, test_axis_not_in_range) + + def test_dygraph(self): + paddle.disable_static(place=self.place) + with paddle.fluid.dygraph.guard(): + data = self.fake_data["col_nan_odd"] + out = paddle.nanmedian(paddle.to_tensor(data), keepdim=True) + np_res = np.nanmedian(data, keepdims=True) + self.assertTrue(np.allclose(np_res, out, equal_nan=True)) + paddle.enable_static() + + def test_check_grad(self): + paddle.disable_static(place=self.place) + shape = (4, 5) + x_np = np.random.uniform(-1, 1, shape).astype(np.float64) + x_np[0, :] = np.nan + x_np[1, :3] = np.nan + x_np[2, 3:] = np.nan + x_np_sorted = np.sort(x_np) + nan_counts = np.count_nonzero(np.isnan(x_np).astype(np.int32), axis=1) + np_grad = np.zeros((shape)) + for i in range(shape[0]): + valid_cnts = shape[1] - nan_counts[i] + if valid_cnts == 0: + continue + + mid = int(valid_cnts / 2) + targets = [x_np_sorted[i, mid]] + if valid_cnts % 2 == 0 and mid > 0: + targets.append(x_np_sorted[i, mid - 1]) + for j in range(shape[1]): + if x_np[i, j] in targets: + np_grad[i, j] = 1 + + x_tensor = paddle.to_tensor(x_np, stop_gradient=False) + y = paddle.nanmedian(x_tensor, axis=1, keepdim=True) + dx = paddle.grad(y, x_tensor)[0].numpy() + self.assertTrue(np.allclose(np_grad, dx, equal_nan=True)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 5f0fb4336e014..c652699b9dd56 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -261,6 +261,7 @@ from .stat import var # noqa: F401 from .stat import numel # noqa: F401 from .stat import median # noqa: F401 +from .stat import nanmedian # noqa: F401 from .stat import quantile # noqa: F401 from .stat import nanquantile # noqa: F401 @@ -445,6 +446,7 @@ 'var', 'numel', 'median', + 'nanmedian', 'quantile', 'nanquantile', 'is_complex', diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 991b86fd47d16..8db170e55574e 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -253,6 +253,138 @@ def numel(x, name=None): return out +def nanmedian(x, axis=None, ignore_nan=True, keepdim=True, name=None): + """ + Compute the median along the specified axis, while ignoring NaNs. + + Args: + x (Tensor): The input Tensor, it's data type can be int32, int64, float16, float32, float64. + axis (None|int|list|tuple, optional): + The axis along which to perform median calculations ``axis`` should be int. + ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` . + If ``axis`` is less than 0, it works the same way as :math:`axis + D`. + If ``axis`` is None, median is calculated over all elements of ``x``. Default is None. + ignore_nan (bool, optional): Whether to ignore nan values when median was calculated. + If `ignore_nan` is True, the calculation process is the same as `median` operator. + Default is True. + keepdim (bool, optional): Whether to reserve the reduced dimension(s) + in the output Tensor. If ``keepdim`` is True, the dimensions of + the output Tensor is the same as ``x`` except in the reduced + dimensions(it is of size 1 in this case). Otherwise, the shape of + the output Tensor is squeezed in ``axis`` . Default is False. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor, results of median along ``axis`` of ``x``. The output dtype is the same as `x`. + + Examples: + .. code-block:: python + + import paddle + x = paddle.to_tensor([[np.nan, 2. , 3. ], [0. , 1. , 2. ]]) + + y1 = x.nanmedian() + # y1 is [[[2.]] + + y2 = x.nanmedian(0) + # y2 is [[0., 1.5, 2.5]] + + y3 = x.nanmedian(0, keepdim=False) + # y3 is [0., 1.5, 2.5] + """ + if not isinstance(x, Variable): + raise TypeError("In median, the input x should be a Tensor.") + + dims = len(x.shape) + out_shape = list(x.shape) + if axis is None: + x = paddle.flatten(x) + axis = 0 + out_shape = [1] * dims if keepdim else [1] + else: + if isinstance(axis, tuple): + axis = list(axis) + + if isinstance(axis, list): + for k in range(len(axis)): + if axis[k] < 0: + axis[k] += dims + axis = list(set(axis)) + if len(axis) <= 0: + raise ValueError("axis should not be empty") + + axis_src, axis_dst = [], [] + for axis_single in axis: + if not isinstance(axis_single, int) or not ( + axis_single < dims and axis_single >= -dims): + raise ValueError( + "Axis should be None, int, or a list, element should in range [-rank(x), rank(x))." + ) + axis_src.append(axis_single) + out_shape[axis_single] = 1 if keepdim else 0 + + axis_dst = list(range(-len(axis), 0)) + x = paddle.moveaxis(x, axis_src, axis_dst) + x = paddle.flatten(x, axis_dst[0], axis_dst[-1]) + axis = axis_dst[0] + else: + if not isinstance(axis, int) or not (axis < dims and axis >= -dims): + raise ValueError( + "Axis should be None, int, or a list, element should in range [-rank(x), rank(x))." + ) + if axis < 0: + axis += dims + out_shape[axis] = 1 if keepdim else 0 + + trans_axis = [] + last_axis = len(x.shape) - 1 + if len(x.shape) > 1 and axis is not None and axis != last_axis: + for i in range(axis): + trans_axis.append(i) + trans_axis.append(last_axis) + for i in range(axis + 1, last_axis): + trans_axis.append(i) + trans_axis.append(axis) + x = paddle.transpose(x, perm=trans_axis) + + if len(x.shape) == 1 and axis == 0: + axis = None + else: + axis = len(x.shape) - 1 + res_shape = [x for x in out_shape if x > 0] + + if _in_legacy_dygraph(): + medians, out = _C_ops.nanmedian(x, 'ignore_nan', ignore_nan) + if len(trans_axis) > 1: + out = paddle.transpose(out, perm=trans_axis) + if len(res_shape) > 0: + return paddle.reshape(out, res_shape) + + return out + + check_variable_and_dtype( + x, 'X', ['int32', 'int64', 'float16', 'float32', 'float64'], + 'nanmedian') + + helper = LayerHelper('nanmedian', **locals()) + attrs = {'ignore_nan': ignore_nan} + out = helper.create_variable_for_type_inference(x.dtype) + medians = helper.create_variable_for_type_inference(x.dtype) + helper.append_op( + type='nanmedian', + inputs={'X': x}, + outputs={'Out': out, + 'Medians': medians}, + attrs=attrs) + + if len(trans_axis) > 1: + out = paddle.transpose(out, perm=trans_axis) + if len(res_shape) > 0: + return paddle.reshape(out, res_shape) + return out + + def median(x, axis=None, keepdim=False, name=None): """ Compute the median along the specified axis. diff --git a/tools/parallel_UT_rule.py b/tools/parallel_UT_rule.py index 5088ad3457fb9..7702e8be9c958 100755 --- a/tools/parallel_UT_rule.py +++ b/tools/parallel_UT_rule.py @@ -824,7 +824,7 @@ 'test_mean_op', 'test_is_tensor', 'test_run_program_op', 'test_cuda_random_seed', 'test_linear_interp_op', 'test_fuse_all_reduce_pass', 'tensor_util_test', 'test_median', - 'test_linear', 'test_imperative_qat_amp', + 'test_nanmedian', 'test_linear', 'test_imperative_qat_amp', 'test_truncated_gaussian_random_op', 'test_lstm_cudnn_op', 'copy_same_tensor_test', 'test_squeeze2_op', 'naive_best_fit_allocator_test', 'test_model', 'test_py_reader_combination', From 7eae9c2866ba33cb18a07c918bc2960ae4220211 Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Fri, 29 Apr 2022 05:13:23 +0000 Subject: [PATCH 02/18] =?UTF-8?q?=E4=BF=AE=E6=94=B9cuda=20kernel=E7=9A=84b?= =?UTF-8?q?ug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/phi/kernels/gpu/nanmedian_kernel.cu | 55 ++++++++++++------- .../fluid/tests/unittests/test_nanmedian.py | 6 +- 2 files changed, 37 insertions(+), 24 deletions(-) diff --git a/paddle/phi/kernels/gpu/nanmedian_kernel.cu b/paddle/phi/kernels/gpu/nanmedian_kernel.cu index 8112a425bea0a..103f255c7b2e6 100644 --- a/paddle/phi/kernels/gpu/nanmedian_kernel.cu +++ b/paddle/phi/kernels/gpu/nanmedian_kernel.cu @@ -43,7 +43,14 @@ __global__ void KernelNanCounts(const T* input, extern __shared__ int64_t buf[]; for (int i = threadIdx.x; i < pre_dim; i += blockDim.x) { buf[i] = 0; + nan_counts[i] = 0; } + + if (threadIdx.x == 0) { + nan_total[0] = 0; + nan_total[1] = 0; + } + __syncthreads(); CUDA_KERNEL_LOOP(index, numel) { @@ -51,6 +58,7 @@ __global__ void KernelNanCounts(const T* input, if (isnan(x)) { auto bin = static_cast(index / stride); paddle::platform::CudaAtomicAdd(&buf[bin], 1); + // NOTE: at this moment paddle.sort does not suppert nan values output[index] = min_val; } } @@ -59,7 +67,7 @@ __global__ void KernelNanCounts(const T* input, for (int i = threadIdx.x; i < pre_dim; i += blockDim.x) { paddle::platform::CudaAtomicAdd(&nan_counts[i], buf[i]); paddle::platform::CudaAtomicAdd(&nan_total[0], buf[i]); - paddle::platform::CudaAtomicMax(&nan_total[1], buf[i]); + paddle::platform::CudaAtomicMax(&nan_total[1], stride - buf[i]); } } @@ -93,13 +101,12 @@ __global__ void CalcNanmedianKernel(const T* sort_out, T* output, const bool is_odd, const int64_t pre_dim, - const int64_t max_nan_num, - const int64_t stride) { - T div_factor = static_cast(2.0); - T nan_val = std::numeric_limits::quiet_NaN(); - + const int64_t max_valid_num, + const int64_t stride, + const T div_factor, + const T nan_val) { CUDA_KERNEL_LOOP(index, pre_dim) { - int64_t pos = static_cast(index * max_nan_num); + int64_t pos = static_cast(index * max_valid_num); int64_t nan_cnt = nan_counts[index]; if (nan_cnt == stride) { median_val[index * 2] = nan_val; @@ -107,21 +114,17 @@ __global__ void CalcNanmedianKernel(const T* sort_out, output[index] = nan_val; } else { bool check_odd = is_odd; - if (nan_cnt > 0) { - int64_t nan_k = static_cast(stride - nan_cnt); - int64_t new_k = static_cast(nan_k >> 1); - pos += new_k - 1; - check_odd = nan_k & 1; - } else { - pos += max_nan_num - 1; - } + int64_t nan_k = + nan_cnt > 0 ? static_cast(stride - nan_cnt) : max_valid_num; + pos += static_cast(nan_k >> 1); + check_odd = nan_k & 1; if (check_odd) { median_val[index * 2] = sort_out[pos]; median_val[index * 2 + 1] = sort_out[pos]; output[index] = sort_out[pos]; } else { - median_val[index * 2] = pos > 1 ? sort_out[pos - 1] : sort_out[pos]; + median_val[index * 2] = pos > 0 ? sort_out[pos - 1] : sort_out[pos]; median_val[index * 2 + 1] = sort_out[pos]; output[index] = (median_val[index * 2] + median_val[index * 2 + 1]) / div_factor; @@ -211,9 +214,17 @@ void NanmedianKernel(const Context& dev_ctx, } if (nan_stat_cpu_ptr[0] > 0) { - int64_t max_nan_num = nan_stat_cpu_ptr[1]; - TopkKernel( - dev_ctx, x, Scalar(max_nan_num), -1, true, true, &sort_out, &indices); + int64_t max_valid_num = nan_stat_cpu_ptr[1]; + T div_factor = static_cast(2.0); + + TopkKernel(dev_ctx, + x, + Scalar(max_valid_num), + -1, + true, + true, + &sort_out, + &indices); CalcNanmedianKernel< T><<>>( @@ -223,8 +234,10 @@ void NanmedianKernel(const Context& dev_ctx, o_ptr, is_ori_odd, pre_dim, - max_nan_num, - stride); + max_valid_num, + stride, + div_factor, + nan_val); return; } diff --git a/python/paddle/fluid/tests/unittests/test_nanmedian.py b/python/paddle/fluid/tests/unittests/test_nanmedian.py index 20c8e816297e0..93d6f487e6f38 100644 --- a/python/paddle/fluid/tests/unittests/test_nanmedian.py +++ b/python/paddle/fluid/tests/unittests/test_nanmedian.py @@ -19,7 +19,7 @@ import paddle import paddle.fluid.core as core -np.random.seed(10) +np.random.seed(102) class TestNanmedian(unittest.TestCase): @@ -36,9 +36,9 @@ def setUp(self): "multi_axis_all_nan": np.full(multi_axis_shape, np.nan), } - single_partial_nan = self.fake_data["single_axis_normal"] + single_partial_nan = self.fake_data["single_axis_normal"].copy() single_partial_nan[single_partial_nan > 0] = np.nan - multi_partial_nan = self.fake_data["multi_axis_normal"] + multi_partial_nan = self.fake_data["multi_axis_normal"].copy() multi_partial_nan[multi_partial_nan > 0] = np.nan self.fake_data["single_axis_partial_nan"] = single_partial_nan self.fake_data["multi_axis_partial_nan"] = multi_partial_nan From 1f2a6e6f3bb0f51f83ab3260f4eb06a8218bd88c Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Fri, 29 Apr 2022 06:52:38 +0000 Subject: [PATCH 03/18] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dcount=5Fif=E5=9C=A8?= =?UTF-8?q?=E5=85=B6=E4=BB=96=E7=A1=AC=E4=BB=B6=E5=B9=B3=E5=8F=B0=E4=B8=8D?= =?UTF-8?q?=E5=85=BC=E5=AE=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/phi/kernels/cpu/nanmedian_kernel.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/paddle/phi/kernels/cpu/nanmedian_kernel.cc b/paddle/phi/kernels/cpu/nanmedian_kernel.cc index ae2ecdb555b8d..806ca63eb49f2 100644 --- a/paddle/phi/kernels/cpu/nanmedian_kernel.cc +++ b/paddle/phi/kernels/cpu/nanmedian_kernel.cc @@ -62,7 +62,10 @@ void NanmedianKernel(const Context& dev_ctx, col_vec.begin(), x_ptr + i * stride, x_ptr + (i + 1) * stride); int64_t num_nan = - std::count_if(col_vec.begin(), col_vec.end(), std::isnan); + std::count_if(col_vec.begin(), col_vec.end(), [&](const T& val) { + return std::isnan(static_cast(val)); + }); + int64_t pos = (stride - num_nan - 1) / 2; std::nth_element(col_vec.begin(), col_vec.begin() + pos, From d5c35d8d05f559ac0f76efc03c356d3925b15484 Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Sat, 30 Apr 2022 07:01:34 +0000 Subject: [PATCH 04/18] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=9F=90=E4=BA=9Bcpu?= =?UTF-8?q?=E7=A1=AC=E4=BB=B6=E4=B8=8D=E5=85=BC=E5=AE=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/phi/kernels/cpu/nanmedian_kernel.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/kernels/cpu/nanmedian_kernel.cc b/paddle/phi/kernels/cpu/nanmedian_kernel.cc index 806ca63eb49f2..287276782c211 100644 --- a/paddle/phi/kernels/cpu/nanmedian_kernel.cc +++ b/paddle/phi/kernels/cpu/nanmedian_kernel.cc @@ -38,7 +38,7 @@ void NanmedianKernel(const Context& dev_ctx, bool all_nan = true; for (i = 0; i < numel; i++) { - if (!std::isnan(*(x_ptr + i))) { + if (!std::isnan(static_cast(*(x_ptr + i)))) { all_nan = false; break; } From 4ec331b15a65e91156d6c2e5a78c2a699e44a2a4 Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Sat, 30 Apr 2022 07:34:59 +0000 Subject: [PATCH 05/18] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=9F=90=E4=BA=9Bcpu?= =?UTF-8?q?=E7=A1=AC=E4=BB=B6=E4=B8=8D=E5=85=BC=E5=AE=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc b/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc index c29d23290b755..c032732374702 100644 --- a/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc @@ -39,14 +39,14 @@ void NanmedianGradKernel(const Context& dev_ctx, T* x_grad_ptr = dev_ctx.template Alloc(x_grad); int64_t i = 0; for (i = 0; i < numel; i++) { - if (std::isnan(x_ptr[i])) { + if (std::isnan(static_cast(x_ptr[i]))) { x_grad_ptr[i] = zero; continue; } int64_t row = static_cast(i / stride); int64_t m_row = 2 * row; - if (std::isnan(m_ptr[m_row]) || + if (std::isnan(static_cast(m_ptr[m_row])) || (x_ptr[i] != m_ptr[m_row] && x_ptr[i] != m_ptr[m_row + 1])) { x_grad_ptr[i] = zero; continue; From 24424a75cbb68156f898ee090d0f1c4b290b15d3 Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Sat, 30 Apr 2022 08:01:29 +0000 Subject: [PATCH 06/18] =?UTF-8?q?=E4=BF=AE=E5=A4=8Disnan=E5=88=A4=E6=96=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc | 4 ++-- paddle/phi/kernels/cpu/nanmedian_kernel.cc | 12 ++++++------ paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu | 3 ++- paddle/phi/kernels/gpu/nanmedian_kernel.cu | 2 +- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc b/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc index c032732374702..b0c782e3f5ad5 100644 --- a/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc @@ -39,14 +39,14 @@ void NanmedianGradKernel(const Context& dev_ctx, T* x_grad_ptr = dev_ctx.template Alloc(x_grad); int64_t i = 0; for (i = 0; i < numel; i++) { - if (std::isnan(static_cast(x_ptr[i]))) { + if (std::isnan(static_cast(x_ptr[i]))) { x_grad_ptr[i] = zero; continue; } int64_t row = static_cast(i / stride); int64_t m_row = 2 * row; - if (std::isnan(static_cast(m_ptr[m_row])) || + if (std::isnan(static_cast(m_ptr[m_row])) || (x_ptr[i] != m_ptr[m_row] && x_ptr[i] != m_ptr[m_row + 1])) { x_grad_ptr[i] = zero; continue; diff --git a/paddle/phi/kernels/cpu/nanmedian_kernel.cc b/paddle/phi/kernels/cpu/nanmedian_kernel.cc index 287276782c211..a750d245c7aae 100644 --- a/paddle/phi/kernels/cpu/nanmedian_kernel.cc +++ b/paddle/phi/kernels/cpu/nanmedian_kernel.cc @@ -38,7 +38,7 @@ void NanmedianKernel(const Context& dev_ctx, bool all_nan = true; for (i = 0; i < numel; i++) { - if (!std::isnan(static_cast(*(x_ptr + i)))) { + if (!std::isnan(static_cast(*(x_ptr + i)))) { all_nan = false; break; } @@ -63,7 +63,7 @@ void NanmedianKernel(const Context& dev_ctx, int64_t num_nan = std::count_if(col_vec.begin(), col_vec.end(), [&](const T& val) { - return std::isnan(static_cast(val)); + return std::isnan(static_cast(val)); }); int64_t pos = (stride - num_nan - 1) / 2; @@ -71,8 +71,8 @@ void NanmedianKernel(const Context& dev_ctx, col_vec.begin() + pos, col_vec.end(), [](const T& l, const T& r) { - return (!std::isnan(static_cast(l)) && - std::isnan(static_cast(r))) || + return (!std::isnan(static_cast(l)) && + std::isnan(static_cast(r))) || (l < r); }); @@ -83,8 +83,8 @@ void NanmedianKernel(const Context& dev_ctx, col_vec.begin() + pos + 1, col_vec.end(), [](const T& l, const T& r) { - return (!std::isnan(static_cast(l)) && - std::isnan(static_cast(r))) || + return (!std::isnan(static_cast(l)) && + std::isnan(static_cast(r))) || (l < r); }); m_ptr[2 * i + 1] = col_vec[pos + 1]; diff --git a/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu b/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu index f316ea92ab319..a6e1e6228018c 100644 --- a/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu @@ -38,7 +38,8 @@ __global__ void KernelNanmedianGrad(const T* x_ptr, CUDA_KERNEL_LOOP(index, numel) { int64_t row = static_cast(index / stride); int64_t m_row = 2 * row; - if (isnan(x_ptr[index]) || isnan(medians_ptr[m_row]) || + if (isnan(static_cast(x_ptr[index])) || + isnan(static_cast(medians_ptr[m_row])) || (x_ptr[index] != medians_ptr[m_row] && x_ptr[index] != medians_ptr[m_row + 1])) { x_grad_ptr[index] = zero; diff --git a/paddle/phi/kernels/gpu/nanmedian_kernel.cu b/paddle/phi/kernels/gpu/nanmedian_kernel.cu index 103f255c7b2e6..5149c6d6baeac 100644 --- a/paddle/phi/kernels/gpu/nanmedian_kernel.cu +++ b/paddle/phi/kernels/gpu/nanmedian_kernel.cu @@ -55,7 +55,7 @@ __global__ void KernelNanCounts(const T* input, CUDA_KERNEL_LOOP(index, numel) { const T x = input[index]; - if (isnan(x)) { + if (isnan(static_cast(x))) { auto bin = static_cast(index / stride); paddle::platform::CudaAtomicAdd(&buf[bin], 1); // NOTE: at this moment paddle.sort does not suppert nan values From a0e6c3c7b6eabb4b83a968e46a193c5e88ed60c0 Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Sat, 30 Apr 2022 10:47:18 +0000 Subject: [PATCH 07/18] =?UTF-8?q?=E5=85=BC=E5=AE=B9numpy=E4=BD=8E=E7=89=88?= =?UTF-8?q?=E6=9C=AC=E4=B8=8D=E6=94=AF=E6=8C=81=E5=85=A8=E9=83=A8nan?= =?UTF-8?q?=E7=9A=84=E6=83=85=E5=86=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/paddle/fluid/tests/unittests/test_nanmedian.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_nanmedian.py b/python/paddle/fluid/tests/unittests/test_nanmedian.py index 93d6f487e6f38..384b6c03f2107 100644 --- a/python/paddle/fluid/tests/unittests/test_nanmedian.py +++ b/python/paddle/fluid/tests/unittests/test_nanmedian.py @@ -99,6 +99,12 @@ def clean_axis_numpy(axis, shape_len): def test_data_case(data, ignore_nan=True): for keep_dim in [False, True]: + if np.isnan(data).all() and keep_dim: + np_ver = np.version.version.split('.') + if int(np_ver[0]) < 1 or int(np_ver[1]) <= 20: + print("This numpy version does not support all nan elements when keepdim is True") + continue + np_res = np.nanmedian(data, keepdims=keep_dim) pd_res = paddle.nanmedian( paddle.to_tensor(data), From 0dac2bda99cc06069a37acc4aa87129a0b30ebd8 Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Sat, 30 Apr 2022 10:47:58 +0000 Subject: [PATCH 08/18] =?UTF-8?q?=E5=85=BC=E5=AE=B9numpy=E4=BD=8E=E7=89=88?= =?UTF-8?q?=E6=9C=AC=E4=B8=8D=E6=94=AF=E6=8C=81=E5=85=A8=E9=83=A8nan?= =?UTF-8?q?=E7=9A=84=E6=83=85=E5=86=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/paddle/fluid/tests/unittests/test_nanmedian.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_nanmedian.py b/python/paddle/fluid/tests/unittests/test_nanmedian.py index 384b6c03f2107..6486b71b137b6 100644 --- a/python/paddle/fluid/tests/unittests/test_nanmedian.py +++ b/python/paddle/fluid/tests/unittests/test_nanmedian.py @@ -102,7 +102,9 @@ def test_data_case(data, ignore_nan=True): if np.isnan(data).all() and keep_dim: np_ver = np.version.version.split('.') if int(np_ver[0]) < 1 or int(np_ver[1]) <= 20: - print("This numpy version does not support all nan elements when keepdim is True") + print( + "This numpy version does not support all nan elements when keepdim is True" + ) continue np_res = np.nanmedian(data, keepdims=keep_dim) From 2a944f609971d0b3dc6f91a8c5e61f0b8817cadb Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Sun, 1 May 2022 15:29:56 +0800 Subject: [PATCH 09/18] fix code example --- python/paddle/tensor/stat.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 8db170e55574e..a257709c7a611 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -280,8 +280,10 @@ def nanmedian(x, axis=None, ignore_nan=True, keepdim=True, name=None): Examples: .. code-block:: python + :name: nanmedian-example import paddle + import numpy as np x = paddle.to_tensor([[np.nan, 2. , 3. ], [0. , 1. , 2. ]]) y1 = x.nanmedian() From 06af1838d8a804000c68cfdd18a5f29ead4eaf8e Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Thu, 5 May 2022 05:14:26 +0000 Subject: [PATCH 10/18] fix api comment error --- python/paddle/tensor/stat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index a257709c7a611..02ed93d0a9f39 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -265,7 +265,7 @@ def nanmedian(x, axis=None, ignore_nan=True, keepdim=True, name=None): If ``axis`` is less than 0, it works the same way as :math:`axis + D`. If ``axis`` is None, median is calculated over all elements of ``x``. Default is None. ignore_nan (bool, optional): Whether to ignore nan values when median was calculated. - If `ignore_nan` is True, the calculation process is the same as `median` operator. + If `ignore_nan` is False, the calculation process is the same as `median` operator. Default is True. keepdim (bool, optional): Whether to reserve the reduced dimension(s) in the output Tensor. If ``keepdim`` is True, the dimensions of From 39f5eb9b5ea1780d1dd3c6bc14a52f4dbc776d60 Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Tue, 10 May 2022 07:08:39 +0000 Subject: [PATCH 11/18] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=8F=8D=E5=90=91?= =?UTF-8?q?=E4=BC=A0=E6=92=AD=E9=80=BB=E8=BE=91=E4=BB=A5=E5=8F=8Ac++?= =?UTF-8?q?=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/fluid/operators/nanmedian_op.cc | 61 ++--- paddle/phi/infermeta/unary.cc | 74 ++++-- paddle/phi/infermeta/unary.h | 12 +- .../phi/kernels/cpu/nanmedian_grad_kernel.cc | 85 ++++--- paddle/phi/kernels/cpu/nanmedian_kernel.cc | 214 ++++++++++++----- .../phi/kernels/gpu/nanmedian_grad_kernel.cu | 78 +++++-- paddle/phi/kernels/gpu/nanmedian_kernel.cu | 215 ++++++++++-------- paddle/phi/kernels/nanmedian_grad_kernel.h | 49 +++- paddle/phi/kernels/nanmedian_kernel.h | 50 +++- paddle/phi/ops/compat/nanmedian_sig.cc | 8 +- .../fluid/tests/unittests/test_nanmedian.py | 21 +- python/paddle/tensor/stat.py | 99 +++----- 12 files changed, 630 insertions(+), 336 deletions(-) diff --git a/paddle/fluid/operators/nanmedian_op.cc b/paddle/fluid/operators/nanmedian_op.cc index 3db9d9a0263f5..01a68b03cac46 100644 --- a/paddle/fluid/operators/nanmedian_op.cc +++ b/paddle/fluid/operators/nanmedian_op.cc @@ -1,4 +1,4 @@ -/*Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +/*Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -34,36 +34,42 @@ class NanmedianOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", - "(Tensor, default Tensor), " + "(Tensor), " "the input feature data of NanmedianOp, dtype should be" - "int32, int64, float16, float32, float64."); - AddAttr( - "ignore_nan", - "(bool, default true) Set to true if nan values should be ignored. " - "Set to false when no nan value in x were considered. ") - .SetDefault(true); - AddOutput("Medians", - "The calculation differs in the odd or even of the valid " - "elements amount." - "Along the axis, two elements contributed to the median value in " - "each row." - "If the amount of valid elements were even, both were the same.") + "int32, int64, float16, float32 or float64."); + AddOutput( + "MedianIndex", + "Store the index position of median values, The calculation differs " + "in the odd or even valid elements numbers." + "Along the axis, two elements contributed to the median value in " + "each row." + "If the amount of valid elements were even, both were the same.") .AsIntermediate() .AsExtra(); AddOutput("Out", - "(Tensor, default Tensor)," + "(Tensor)," " the output of NanmedianOp, whose dtype is the same as X"); + AddAttr("keepdim", + "(bool, default true) " + "If true, retain the reduced axis with length 1.") + .SetDefault(true); + AddAttr>("axes", + "(std::vector). List of integers," + " indicating the dimensions to calculate medians") + .SetDefault({}); AddComment(R"DOC( Nanmedian operator This operator is considered as an extention of median operation, - which supports specifically the case of nan values in the input. + which supports specifically the case of NaN values in the input. If all the elements in input are NaN it will also return NaN. If no elements in input are Nan, this op is identical to thie median op. - This operator can also supports multiple axis, - and could be switched to median operator when `ignore_nan` were set to False. + If the valid count of elements is a even number, the average value of + the elements in the middle is calculated as the median. + + This operator can also supports multiple axis. )DOC"); } }; @@ -76,9 +82,10 @@ class NanmedianGradMaker : public framework::SingleGradOpMaker { void Apply(GradOpPtr op) const override { op->SetType("nanmedian_grad"); op->SetInput("X", this->Input("X")); - op->SetInput("Medians", this->Output("Medians")); + op->SetInput("MedianIndex", this->Output("MedianIndex")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetAttrMap(this->Attrs()); } }; @@ -86,16 +93,6 @@ class NanmedianGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "nanmedian"); - OP_INOUT_CHECK(ctx->HasInput("Medians"), "Input", "Medians", "nanmedian"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", - framework::GradVarName("Out"), "nanmedian"); - - auto x_dims = ctx->GetInputDim("X"); - ctx->SetOutputDim(framework::GradVarName("X"), x_dims); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -117,4 +114,8 @@ REGISTER_OPERATOR(nanmedian, ops::NanmedianOp, ops::NanmedianOpMaker, ops::NanmedianGradMaker, NanmedianInferShapeFunctor); -REGISTER_OPERATOR(nanmedian_grad, ops::NanmedianGradOp); +DECLARE_INFER_SHAPE_FUNCTOR(nanmedian_grad, NanmedianGradInferShapeFunctor, + PD_INFER_META(phi::NanmedianGradInferMeta)); + +REGISTER_OPERATOR(nanmedian_grad, ops::NanmedianGradOp, + NanmedianGradInferShapeFunctor); diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 3d033fbd5b70a..7a93f6532b7cb 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1244,27 +1244,73 @@ void MultinomialInferMeta(const MetaTensor& x, } void NanmedianInferMeta(const MetaTensor& x, - bool ignore_nan, + const IntArray& axes, + bool keep_dim, MetaTensor* out, - MetaTensor* medians) { + MetaTensor* median_index) { + std::vector axis_list = axes.GetData(); auto x_dim = x.dims(); int64_t x_rank = x_dim.size(); + out->set_dtype(x.dtype()); + median_index->set_dtype(DataType::INT64); + median_index->set_dims(make_ddim({x.numel() * 2})); + + std::vector out_dim; + if (axis_list.empty()) { + if (keep_dim) { + for (int64_t i = 0; i < x_rank; i++) { + out_dim.push_back(1); + } + } else { + out_dim.push_back(1); + } + } else { + std::vector cleaned_axis; + for (auto& axis : axis_list) { + if (axis < 0) axis += x_rank; - std::vector out_dims(x_rank); - std::vector median_dims(x_rank); - for (int64_t i = 0; i < x_rank - 1; i++) { - out_dims[i] = x_dim[i]; - median_dims[i] = x_dim[i]; - } + PADDLE_ENFORCE_LT( + axis, + x_rank, + errors::InvalidArgument( + "Attr(axis) value should be in range [-R, R-1], R is " + "the rank of Input(X). But received axis: %d, R: %d. " + "Current Input(X)'s shape is=[%s].", + axis, + x_rank, + x_dim)); - out_dims[x_rank - 1] = 1; - median_dims[x_rank - 1] = 2; + PADDLE_ENFORCE_EQ( + std::find(cleaned_axis.begin(), cleaned_axis.end(), axis), + cleaned_axis.end(), + errors::InvalidArgument("Attr(axes) has duplicated elements: %d.", + static_cast(axis))); - out->set_dims(make_ddim(out_dims)); - out->set_dtype(x.dtype()); + cleaned_axis.push_back(axis); + } + + for (int64_t i = 0; i < x_rank; i++) { + if (std::find(cleaned_axis.begin(), cleaned_axis.end(), i) == + cleaned_axis.end()) { + out_dim.push_back(x_dim[i]); + } else if (keep_dim) { + out_dim.push_back(1); + } + } + } - medians->set_dims(make_ddim(median_dims)); - medians->set_dtype(x.dtype()); + out->set_dims(make_ddim(out_dim)); +} + +void NanmedianGradInferMeta(const MetaTensor& x, + const MetaTensor& median_index, + const MetaTensor& out_grad, + const IntArray& axes, + bool keep_dim, + MetaTensor* x_grad) { + auto x_dims = x.dims(); + x_grad->set_dims(x_dims); + x_grad->set_dtype(x.dtype()); } void NormInferMeta(const MetaTensor& x, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 8fddb48ef635d..80eca26150238 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -179,9 +179,17 @@ void MultinomialInferMeta(const MetaTensor& x, MetaTensor* out); void NanmedianInferMeta(const MetaTensor& x, - bool ignore_nan, + const IntArray& axes, + bool keep_dim, MetaTensor* out, - MetaTensor* medians); + MetaTensor* median_index); + +void NanmedianGradInferMeta(const MetaTensor& x, + const MetaTensor& median_index, + const MetaTensor& out_grad, + const IntArray& axes, + bool keep_dim, + MetaTensor* x_grad); void NormInferMeta(const MetaTensor& x, int axis, diff --git a/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc b/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc index b0c782e3f5ad5..a66e505ad122d 100644 --- a/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc @@ -13,50 +13,75 @@ // limitations under the License. #include "paddle/phi/kernels/nanmedian_grad_kernel.h" - #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" namespace phi { template -void NanmedianGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& medians, - const DenseTensor& out_grad, - DenseTensor* x_grad) { - const T* x_ptr = x.data(); - const T* m_ptr = medians.data(); - const T* out_grad_ptr = out_grad.data(); +void CalcMedianGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& median_index, + const DenseTensor& out_grad, + const IntArray& axes, + DenseTensor* x_grad, + T* x_grad_ptr) { + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, x_grad, static_cast(0)); + if (!x_grad_ptr) return; + const int64_t* m_ptr = median_index.data(); + const T* out_grad_ptr = out_grad.data(); int64_t numel = x.numel(); auto x_dim = x.dims(); - int64_t x_rank = x_dim.size(); - int64_t stride = x_dim[x_rank - 1]; - auto zero = static_cast(0); - - if (x_grad) { - T* x_grad_ptr = dev_ctx.template Alloc(x_grad); - int64_t i = 0; - for (i = 0; i < numel; i++) { - if (std::isnan(static_cast(x_ptr[i]))) { - x_grad_ptr[i] = zero; - continue; - } + int64_t rank = x_dim.size(); + int64_t stride = x_dim[rank - 1]; - int64_t row = static_cast(i / stride); - int64_t m_row = 2 * row; - if (std::isnan(static_cast(m_ptr[m_row])) || - (x_ptr[i] != m_ptr[m_row] && x_ptr[i] != m_ptr[m_row + 1])) { - x_grad_ptr[i] = zero; - continue; - } - - x_grad_ptr[i] = out_grad_ptr[row]; + int64_t pre_dim = numel / stride; + int64_t i = 0; + int64_t offset = 0; + for (i = 0; i < pre_dim; i++) { + if (m_ptr[2 * i] >= 0) { + x_grad_ptr[offset + m_ptr[2 * i]] = out_grad_ptr[i]; + x_grad_ptr[offset + m_ptr[2 * i + 1]] = out_grad_ptr[i]; } + offset += stride; + } +} + +template +void BaseMedianGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& median_index, + const DenseTensor& out_grad, + const IntArray& axes, + DenseTensor* x_grad) { + auto rank = x.dims().size(); + T* x_grad_ptr = dev_ctx.template Alloc(x_grad); + if (axes.size() && (rank > 1)) { + DenseTensor tmp_x_grad(*x_grad); + CalcMedianGradKernel( + dev_ctx, x, median_index, out_grad, axes, &tmp_x_grad, x_grad_ptr); + PostprocessMedianGradKernel(dev_ctx, &tmp_x_grad, axes, x_grad); + } else { + CalcMedianGradKernel( + dev_ctx, x, median_index, out_grad, axes, x_grad, x_grad_ptr); } } +template +void NanmedianGradKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& median_index, + const DenseTensor& out_grad, + const IntArray& axes, + bool keep_dim, + DenseTensor* x_grad) { + BaseMedianGradKernel( + dev_ctx, input, median_index, out_grad, axes, x_grad); +} + } // namespace phi PD_REGISTER_KERNEL(nanmedian_grad, diff --git a/paddle/phi/kernels/cpu/nanmedian_kernel.cc b/paddle/phi/kernels/cpu/nanmedian_kernel.cc index a750d245c7aae..56f7fb260a7e3 100644 --- a/paddle/phi/kernels/cpu/nanmedian_kernel.cc +++ b/paddle/phi/kernels/cpu/nanmedian_kernel.cc @@ -13,21 +13,102 @@ // limitations under the License. #include "paddle/phi/kernels/nanmedian_kernel.h" - #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/top_k_kernel.h" namespace phi { template -void NanmedianKernel(const Context& dev_ctx, - const DenseTensor& x, - bool ignore_nan, - DenseTensor* out, - DenseTensor* medians) { +void CalcMedianFunc(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& nan_counts, + bool ignore_nan, + int64_t sort_k, + int64_t stride, + int64_t pre_dim, + T* o_ptr, + int64_t* m_ptr) { + bool should_ignore_nan = ignore_nan; + DenseTensor sort_out; + DenseTensor sort_indices; + auto sort_dim = x.dims(); + int64_t rank = sort_dim.size(); + sort_dim[rank - 1] = sort_k; + sort_out.Resize(sort_dim); + sort_indices.Resize(sort_dim); + + dev_ctx.template Alloc(&sort_out); + T* sort_out_ptr = sort_out.data(); + dev_ctx.template Alloc(&sort_indices); + int64_t* sort_indices_ptr = sort_indices.data(); + + TopkKernel( + dev_ctx, x, Scalar(sort_k), -1, false, true, &sort_out, &sort_indices); + + T div_factor = static_cast(2.0); + int64_t pos = 0; + int64_t i = 0; + bool is_ori_odd = stride & 1; + if (should_ignore_nan) { + for (i = 0; i < pre_dim; i++) { + if (nan_counts[i] == stride) { + m_ptr[i * 2] = -1; + m_ptr[i * 2 + 1] = -1; + o_ptr[i] = sort_out_ptr[pos]; + } else { + int64_t nan_k = nan_counts[i] > 0 + ? static_cast(stride - nan_counts[i]) + : sort_k; + int64_t row_pos = static_cast(nan_k >> 1); + int64_t off_set = pos + row_pos; + if (nan_k & 1) { + m_ptr[2 * i] = sort_indices_ptr[off_set]; + m_ptr[2 * i + 1] = sort_indices_ptr[off_set]; + o_ptr[i] = sort_out_ptr[off_set]; + } else { + m_ptr[2 * i] = row_pos > 0 ? sort_indices_ptr[off_set - 1] + : sort_indices_ptr[off_set]; + m_ptr[2 * i + 1] = sort_indices_ptr[off_set]; + T m_val_left = + row_pos > 0 ? sort_out_ptr[off_set - 1] : sort_out_ptr[off_set]; + T m_val_right = sort_out_ptr[off_set]; + o_ptr[i] = (m_val_left + m_val_right) / div_factor; + } + } + pos += sort_k; + } + } else { + pos = -1; + if (is_ori_odd) { + for (i = 0; i < pre_dim; i++) { + pos += sort_k; + o_ptr[i] = sort_out_ptr[pos]; + m_ptr[2 * i] = sort_indices_ptr[pos]; + m_ptr[2 * i + 1] = sort_indices_ptr[pos]; + } + } else { + for (i = 0; i < pre_dim; i++) { + pos += sort_k; + m_ptr[2 * i] = + sort_k > 1 ? sort_indices_ptr[pos - 1] : sort_indices_ptr[pos]; + m_ptr[2 * i + 1] = sort_indices_ptr[pos]; + T m_val_left = sort_k > 1 ? sort_out_ptr[pos - 1] : sort_out_ptr[pos]; + T m_val_right = sort_out_ptr[pos]; + o_ptr[i] = (m_val_left + m_val_right) / div_factor; + } + } + } +} + +template +void ProcessMedianKernel(const Context& dev_ctx, + const DenseTensor& x, + T* o_ptr, + int64_t* m_ptr, + bool ignore_nan) { + bool should_ignore_nan = ignore_nan; const T* x_ptr = x.data(); - T* o_ptr = dev_ctx.template Alloc(out); - T* m_ptr = dev_ctx.template Alloc(medians); int64_t numel = x.numel(); auto x_dim = x.dims(); @@ -36,61 +117,82 @@ void NanmedianKernel(const Context& dev_ctx, int64_t pre_dim = numel / stride; int64_t i = 0; - bool all_nan = true; - for (i = 0; i < numel; i++) { - if (!std::isnan(static_cast(*(x_ptr + i)))) { - all_nan = false; - break; + int64_t max_valid_num = 0; + std::vector nan_counts; + if (should_ignore_nan) { + int64_t total_nan_num = 0; + std::vector col_vec; + col_vec.reserve(stride); + col_vec.resize(stride); + nan_counts.clear(); + nan_counts.reserve(pre_dim); + nan_counts.resize(pre_dim); + for (int64_t i = 0; i < pre_dim; i++) { + col_vec.clear(); + col_vec.insert( + col_vec.begin(), x_ptr + i * stride, x_ptr + (i + 1) * stride); + nan_counts[i] = + std::count_if(col_vec.begin(), col_vec.end(), [&](const T& val) { + return std::isnan(static_cast(val)); + }); + total_nan_num += nan_counts[i]; + if (stride - nan_counts[i] > max_valid_num) + max_valid_num = stride - nan_counts[i]; } - } - - if (all_nan) { - for (i = 0; i < pre_dim; i++) { - o_ptr[i] = x_ptr[0]; - m_ptr[2 * i] = x_ptr[0]; - m_ptr[2 * i + 1] = x_ptr[0]; + // all elems are nan + if (total_nan_num == numel) { + for (i = 0; i < pre_dim; i++) { + o_ptr[i] = x_ptr[0]; + m_ptr[2 * i] = -1; + m_ptr[2 * i + 1] = -1; + } + return; } - return; + should_ignore_nan = total_nan_num > 0; } - std::vector col_vec; - col_vec.reserve(stride); - col_vec.resize(stride); - for (i = 0; i < pre_dim; i++) { - col_vec.clear(); - col_vec.insert( - col_vec.begin(), x_ptr + i * stride, x_ptr + (i + 1) * stride); + int64_t sort_k = should_ignore_nan ? max_valid_num : ((stride >> 1) + 1); + CalcMedianFunc(dev_ctx, + x, + nan_counts, + should_ignore_nan, + sort_k, + stride, + pre_dim, + o_ptr, + m_ptr); +} - int64_t num_nan = - std::count_if(col_vec.begin(), col_vec.end(), [&](const T& val) { - return std::isnan(static_cast(val)); - }); +template +void BaseMedianKernel(const Context& dev_ctx, + const DenseTensor& input, + const IntArray& axes, + DenseTensor* out, + DenseTensor* median_index, + bool ignore_nan) { + DenseTensor x; + auto rank = input.dims().size(); + if (axes.size() == 0) { + x = input; + x.Resize({input.numel()}); + } else if (rank > 1) { + PreprocessMedianKernel(dev_ctx, input, axes, &x); + } - int64_t pos = (stride - num_nan - 1) / 2; - std::nth_element(col_vec.begin(), - col_vec.begin() + pos, - col_vec.end(), - [](const T& l, const T& r) { - return (!std::isnan(static_cast(l)) && - std::isnan(static_cast(r))) || - (l < r); - }); + T* o_ptr = dev_ctx.template Alloc(out); + int64_t* m_ptr = dev_ctx.template Alloc(median_index); + ProcessMedianKernel(dev_ctx, x, o_ptr, m_ptr, ignore_nan); + out->Resize(out->dims()); +} - m_ptr[2 * i] = col_vec[pos]; - m_ptr[2 * i + 1] = col_vec[pos]; - if ((stride - num_nan) % 2 == 0) { - std::nth_element(col_vec.begin(), - col_vec.begin() + pos + 1, - col_vec.end(), - [](const T& l, const T& r) { - return (!std::isnan(static_cast(l)) && - std::isnan(static_cast(r))) || - (l < r); - }); - m_ptr[2 * i + 1] = col_vec[pos + 1]; - } - o_ptr[i] = static_cast((m_ptr[2 * i] + m_ptr[2 * i + 1]) / 2.0); - } +template +void NanmedianKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& axes, + bool keepdim, + DenseTensor* out, + DenseTensor* median_index) { + BaseMedianKernel(dev_ctx, x, axes, out, median_index, true); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu b/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu index a6e1e6228018c..7764fadc5f79b 100644 --- a/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/kernels/nanmedian_grad_kernel.h" - #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_meta.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/nanmedian_grad_kernel.h" namespace phi { @@ -29,46 +29,76 @@ inline int GET_BLOCKS(const int N) { template __global__ void KernelNanmedianGrad(const T* x_ptr, - const T* medians_ptr, + const int64_t* medians_ptr, const T* out_grad_ptr, T* x_grad_ptr, int64_t stride, - int64_t numel) { - auto zero = static_cast(0); - CUDA_KERNEL_LOOP(index, numel) { - int64_t row = static_cast(index / stride); - int64_t m_row = 2 * row; - if (isnan(static_cast(x_ptr[index])) || - isnan(static_cast(medians_ptr[m_row])) || - (x_ptr[index] != medians_ptr[m_row] && - x_ptr[index] != medians_ptr[m_row + 1])) { - x_grad_ptr[index] = zero; - } else { - x_grad_ptr[index] = out_grad_ptr[row]; + int64_t pre_dim) { + CUDA_KERNEL_LOOP(index, pre_dim) { + int64_t offset = index * stride; + if (medians_ptr[2 * index] >= 0) { + x_grad_ptr[offset + medians_ptr[2 * index]] = out_grad_ptr[index]; + x_grad_ptr[offset + medians_ptr[2 * index + 1]] = out_grad_ptr[index]; } } } template -void NanmedianGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& medians, - const DenseTensor& out_grad, - DenseTensor* x_grad) { +void CalcMedianGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& median_index, + const DenseTensor& out_grad, + DenseTensor* x_grad, + T* x_grad_ptr) { + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, x_grad, static_cast(0)); + auto stream = dev_ctx.stream(); const T* x_ptr = x.data(); - const T* m_ptr = medians.data(); + const int64_t* m_ptr = median_index.data(); const T* out_grad_ptr = out_grad.data(); - T* x_grad_ptr = dev_ctx.template Alloc(x_grad); int64_t numel = x.numel(); auto x_dim = x.dims(); int64_t x_rank = x_dim.size(); int64_t stride = x_dim[x_rank - 1]; + int64_t pre_dim = numel / stride; KernelNanmedianGrad< - T><<>>( - x_ptr, m_ptr, out_grad_ptr, x_grad_ptr, stride, numel); + T><<>>( + x_ptr, m_ptr, out_grad_ptr, x_grad_ptr, stride, pre_dim); +} + +template +void BaseMedianGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& median_index, + const DenseTensor& out_grad, + const IntArray& axes, + DenseTensor* x_grad) { + auto rank = x.dims().size(); + T* x_grad_ptr = dev_ctx.template Alloc(x_grad); + if (axes.size() && (rank > 1)) { + DenseTensor tmp_x_grad(*x_grad); + CalcMedianGradKernel( + dev_ctx, x, median_index, out_grad, &tmp_x_grad, x_grad_ptr); + PostprocessMedianGradKernel(dev_ctx, &tmp_x_grad, axes, x_grad); + } else { + CalcMedianGradKernel( + dev_ctx, x, median_index, out_grad, x_grad, x_grad_ptr); + } +} + +template +void NanmedianGradKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& median_index, + const DenseTensor& out_grad, + const IntArray& axes, + bool keep_dim, + DenseTensor* x_grad) { + BaseMedianGradKernel( + dev_ctx, input, median_index, out_grad, axes, x_grad); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/nanmedian_kernel.cu b/paddle/phi/kernels/gpu/nanmedian_kernel.cu index 5149c6d6baeac..e21593bcc4c2e 100644 --- a/paddle/phi/kernels/gpu/nanmedian_kernel.cu +++ b/paddle/phi/kernels/gpu/nanmedian_kernel.cu @@ -12,15 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/kernels/nanmedian_kernel.h" - #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/nanmedian_kernel.h" #include "paddle/phi/kernels/top_k_kernel.h" namespace phi { @@ -38,8 +36,7 @@ __global__ void KernelNanCounts(const T* input, const int64_t stride, T min_val, int64_t* nan_total, - int64_t* nan_counts, - T* output) { + int64_t* nan_counts) { extern __shared__ int64_t buf[]; for (int i = threadIdx.x; i < pre_dim; i += blockDim.x) { buf[i] = 0; @@ -58,8 +55,6 @@ __global__ void KernelNanCounts(const T* input, if (isnan(static_cast(x))) { auto bin = static_cast(index / stride); paddle::platform::CudaAtomicAdd(&buf[bin], 1); - // NOTE: at this moment paddle.sort does not suppert nan values - output[index] = min_val; } } __syncthreads(); @@ -72,32 +67,36 @@ __global__ void KernelNanCounts(const T* input, } template -__global__ void CalcMedianKernel(const T* sort_out, - T* median_val, +__global__ void CalcMedianKernel(const T* sort_out_ptr, + const int64_t* sort_indices_ptr, + int64_t* median_val, T* output, + T div_factor, const bool is_odd, const int64_t pre_dim, const int64_t stride) { - T div_factor = static_cast(2.0); CUDA_KERNEL_LOOP(index, pre_dim) { int64_t pos = static_cast((index + 1) * stride) - 1; if (is_odd) { - median_val[index * 2] = sort_out[pos]; - median_val[index * 2 + 1] = sort_out[pos]; - output[index] = sort_out[pos]; + median_val[index * 2] = sort_indices_ptr[pos]; + median_val[index * 2 + 1] = sort_indices_ptr[pos]; + output[index] = sort_out_ptr[pos]; } else { - median_val[index * 2] = pos > 1 ? sort_out[pos - 1] : sort_out[pos]; - median_val[index * 2 + 1] = sort_out[pos]; - output[index] = - (median_val[index * 2] + median_val[index * 2 + 1]) / div_factor; + median_val[index * 2] = + pos > 0 ? sort_indices_ptr[pos - 1] : sort_indices_ptr[pos]; + median_val[index * 2 + 1] = sort_indices_ptr[pos]; + T median_val_left = pos > 0 ? sort_out_ptr[pos - 1] : sort_out_ptr[pos]; + T median_val_right = sort_out_ptr[pos]; + output[index] = (median_val_left + median_val_right) / div_factor; } } } template -__global__ void CalcNanmedianKernel(const T* sort_out, +__global__ void CalcNanmedianKernel(const T* sort_out_ptr, + const int64_t* sort_indices_ptr, int64_t* nan_counts, - T* median_val, + int64_t* median_val, T* output, const bool is_odd, const int64_t pre_dim, @@ -109,44 +108,44 @@ __global__ void CalcNanmedianKernel(const T* sort_out, int64_t pos = static_cast(index * max_valid_num); int64_t nan_cnt = nan_counts[index]; if (nan_cnt == stride) { - median_val[index * 2] = nan_val; - median_val[index * 2 + 1] = nan_val; + median_val[index * 2] = -1; + median_val[index * 2 + 1] = -1; output[index] = nan_val; } else { - bool check_odd = is_odd; int64_t nan_k = nan_cnt > 0 ? static_cast(stride - nan_cnt) : max_valid_num; - pos += static_cast(nan_k >> 1); - check_odd = nan_k & 1; + int64_t row_pos = static_cast(nan_k >> 1); + pos += row_pos; - if (check_odd) { - median_val[index * 2] = sort_out[pos]; - median_val[index * 2 + 1] = sort_out[pos]; - output[index] = sort_out[pos]; + if (nan_k & 1) { + median_val[index * 2] = sort_indices_ptr[pos]; + median_val[index * 2 + 1] = sort_indices_ptr[pos]; + output[index] = sort_out_ptr[pos]; } else { - median_val[index * 2] = pos > 0 ? sort_out[pos - 1] : sort_out[pos]; - median_val[index * 2 + 1] = sort_out[pos]; - output[index] = - (median_val[index * 2] + median_val[index * 2 + 1]) / div_factor; + median_val[index * 2] = + pos > 0 ? sort_indices_ptr[pos - 1] : sort_indices_ptr[pos]; + median_val[index * 2 + 1] = sort_indices_ptr[pos]; + T median_val_left = pos > 0 ? sort_out_ptr[pos - 1] : sort_out_ptr[pos]; + T median_val_right = sort_out_ptr[pos]; + output[index] = (median_val_left + median_val_right) / div_factor; } } } } template -void NanmedianKernel(const Context& dev_ctx, - const DenseTensor& x, - bool ignore_nan, - DenseTensor* out, - DenseTensor* medians) { +void ProcessMedianKernel(const Context& dev_ctx, + const DenseTensor& x, + bool ignore_nan, + DenseTensor* out, + int64_t* m_ptr) { + bool should_ignore_nan = ignore_nan; auto stream = dev_ctx.stream(); auto* ctx = reinterpret_cast(&dev_ctx); const T* x_ptr = x.data(); T* o_ptr = dev_ctx.template Alloc(out); - T* m_ptr = dev_ctx.template Alloc(medians); - int64_t numel = x.numel(); auto x_dim = x.dims(); int64_t x_rank = x_dim.size(); @@ -154,35 +153,17 @@ void NanmedianKernel(const Context& dev_ctx, int64_t pre_dim = numel / stride; int64_t i = 0; - int64_t half_stride = (stride >> 1) + 1; - bool is_ori_odd = stride & 1; - - DenseTensor sort_out; - auto sort_dim = x.dims(); - sort_dim[x_rank - 1] = half_stride; - - sort_out.Resize(sort_dim); - dev_ctx.template Alloc(&sort_out); - T* sort_out_ptr = sort_out.data(); - - std::vector out_dim_vec = vectorize(sort_dim); - DenseTensor indices = phi::Empty(dev_ctx, IntArray(out_dim_vec)); - - if (ignore_nan) { - DenseTensor nan_counts, nan_stat, nonnan_x; - + DenseTensor nan_counts, nan_stat; + int64_t* nan_counts_ptr; + int64_t max_valid_num = 0; + if (should_ignore_nan) { nan_counts.Resize(phi::make_ddim({pre_dim})); dev_ctx.template Alloc(&nan_counts); - int64_t* nan_counts_ptr = nan_counts.data(); - + nan_counts_ptr = nan_counts.data(); nan_stat.Resize(phi::make_ddim({2})); int64_t* nan_stat_mem = dev_ctx.template Alloc(&nan_stat); int64_t* nan_stat_ptr = nan_stat.data(); - nonnan_x.Resize(x.dims()); - dev_ctx.template Alloc(&nonnan_x); - T* nonnan_x_ptr = nonnan_x.data(); - KernelNanCounts<<::min(), nan_stat_ptr, - nan_counts_ptr, - nonnan_x_ptr); + nan_counts_ptr); auto nan_stat_mem_cpu = paddle::memory::Alloc(phi::CPUPlace(), sizeof(int64_t) * 2); @@ -206,49 +186,94 @@ void NanmedianKernel(const Context& dev_ctx, sizeof(int64_t) * 2, stream); - // all elements are nan + // all elements are nan values T nan_val = std::numeric_limits::quiet_NaN(); if (nan_stat_cpu_ptr[0] == numel) { FullLikeKernel(dev_ctx, x, nan_val, x.dtype(), out); return; } - if (nan_stat_cpu_ptr[0] > 0) { - int64_t max_valid_num = nan_stat_cpu_ptr[1]; - T div_factor = static_cast(2.0); + should_ignore_nan = nan_stat_cpu_ptr[0] > 0; + max_valid_num = nan_stat_cpu_ptr[1]; + } + + int64_t sort_k = should_ignore_nan ? max_valid_num : ((stride >> 1) + 1); + bool is_ori_odd = stride & 1; - TopkKernel(dev_ctx, - x, - Scalar(max_valid_num), - -1, - true, - true, - &sort_out, - &indices); + DenseTensor sort_out, sort_indices; + auto sort_dim = x.dims(); + int64_t rank = sort_dim.size(); + sort_dim[rank - 1] = sort_k; + sort_out.Resize(sort_dim); + sort_indices.Resize(sort_dim); - CalcNanmedianKernel< - T><<>>( - sort_out_ptr, - nan_counts_ptr, - m_ptr, - o_ptr, - is_ori_odd, - pre_dim, - max_valid_num, - stride, - div_factor, - nan_val); + dev_ctx.template Alloc(&sort_out); + T* sort_out_ptr = sort_out.data(); + dev_ctx.template Alloc(&sort_indices); + int64_t* sort_indices_ptr = sort_indices.data(); - return; - } + TopkKernel( + dev_ctx, x, Scalar(sort_k), -1, false, true, &sort_out, &sort_indices); + + T div_factor = static_cast(2.0); + T nan_val = std::numeric_limits::quiet_NaN(); + if (should_ignore_nan) { + CalcNanmedianKernel< + T><<>>( + sort_out_ptr, + sort_indices_ptr, + nan_counts_ptr, + m_ptr, + o_ptr, + is_ori_odd, + pre_dim, + max_valid_num, + stride, + div_factor, + nan_val); + } else { + CalcMedianKernel< + T><<>>( + sort_out_ptr, + sort_indices_ptr, + m_ptr, + o_ptr, + div_factor, + is_ori_odd, + pre_dim, + sort_k); } +} - TopkKernel( - dev_ctx, x, Scalar(half_stride), -1, true, true, &sort_out, &indices); +template +void BaseMedianKernel(const Context& dev_ctx, + const DenseTensor& input, + const IntArray& axes, + bool ignore_nan, + DenseTensor* out, + DenseTensor* median_index) { + DenseTensor x; + auto rank = input.dims().size(); + if (axes.size() == 0) { + x = input; + x.Resize({input.numel()}); + } else if (rank > 1) { + PreprocessMedianKernel(dev_ctx, input, axes, &x); + } + + int64_t* m_ptr = dev_ctx.template Alloc(median_index); + ProcessMedianKernel(dev_ctx, x, ignore_nan, out, m_ptr); + out->Resize(out->dims()); +} - CalcMedianKernel< - T><<>>( - sort_out_ptr, m_ptr, o_ptr, is_ori_odd, pre_dim, half_stride); +template +void NanmedianKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& axes, + bool keepdim, + DenseTensor* out, + DenseTensor* median_index) { + BaseMedianKernel(dev_ctx, x, axes, true, out, median_index); } } // namespace phi diff --git a/paddle/phi/kernels/nanmedian_grad_kernel.h b/paddle/phi/kernels/nanmedian_grad_kernel.h index 714b2b8192d86..dc7321c1aa751 100644 --- a/paddle/phi/kernels/nanmedian_grad_kernel.h +++ b/paddle/phi/kernels/nanmedian_grad_kernel.h @@ -13,14 +13,61 @@ // limitations under the License. #pragma once +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/math_function.h" namespace phi { +template +void PostprocessMedianGradKernel(const Context& dev_ctx, + DenseTensor* input, + const IntArray& raw_axes, + DenseTensor* x) { + auto input_dim = input->dims(); + auto rank = input_dim.size(); + + std::vector axes = raw_axes.GetData(); + int64_t axes_size = static_cast(axes.size()); + for (int64_t i = 0; i < axes_size; i++) { + if (axes[i] < 0) { + axes[i] += rank; + } + } + + std::vector trans_back; + std::vector reshape_back; + trans_back.reserve(rank); + trans_back.resize(rank); + + int offset = 0; + for (int64_t i = 0; i < rank; i++) { + if (std::find(axes.begin(), axes.end(), i) == axes.end()) { + reshape_back.push_back(input_dim[i]); + trans_back[i] = offset; + offset += 1; + } + } + + for (int64_t i = 0; i < rank; i++) { + if (std::find(axes.begin(), axes.end(), i) != axes.end()) { + trans_back[i] = offset; + reshape_back.push_back(input_dim[i]); + offset += 1; + } + } + + input->Resize(make_ddim(reshape_back)); + funcs::TransCompute( + static_cast(trans_back.size()), dev_ctx, *input, x, trans_back); +} + template void NanmedianGradKernel(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& medians, + const DenseTensor& median_index, const DenseTensor& out_grad, + const IntArray& axes, + bool keep_dim, DenseTensor* x_grad); } // namespace phi diff --git a/paddle/phi/kernels/nanmedian_kernel.h b/paddle/phi/kernels/nanmedian_kernel.h index e30472550399c..513bb49022242 100644 --- a/paddle/phi/kernels/nanmedian_kernel.h +++ b/paddle/phi/kernels/nanmedian_kernel.h @@ -13,14 +13,62 @@ // limitations under the License. #pragma once +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/math_function.h" namespace phi { +template +void PreprocessMedianKernel(const Context& dev_ctx, + const DenseTensor& input, + const IntArray& raw_axes, + DenseTensor* x) { + auto input_dim = input.dims(); + auto rank = input_dim.size(); + std::vector perm; + std::vector reshape; + + std::vector axes = raw_axes.GetData(); + int64_t axes_size = static_cast(axes.size()); + for (int64_t i = 0; i < axes_size; i++) { + if (axes[i] < 0) { + axes[i] += rank; + } + } + + for (int64_t i = 0; i < rank; i++) { + if (std::find(axes.begin(), axes.end(), i) == axes.end()) { + perm.push_back(i); + reshape.push_back(input_dim[i]); + } + } + + int64_t post_numel = 1; + for (int64_t i = 0; i < rank; i++) { + if (std::find(axes.begin(), axes.end(), i) != axes.end()) { + perm.push_back(i); + post_numel *= input_dim[i]; + } + } + reshape.push_back(post_numel); + + DDim trans_dim(input_dim); + int ndims = perm.size(); + for (int i = 0; i < ndims; i++) { + trans_dim[i] = input_dim[perm[i]]; + } + x->mutable_data(trans_dim, dev_ctx.GetPlace()); + funcs::TransCompute(ndims, dev_ctx, input, x, perm); + + x->Resize(make_ddim(reshape)); +} + template void NanmedianKernel(const Context& dev_ctx, const DenseTensor& x, - bool ignore_nan, + const IntArray& axes, + bool keep_dim, DenseTensor* out, DenseTensor* medians); } // namespace phi diff --git a/paddle/phi/ops/compat/nanmedian_sig.cc b/paddle/phi/ops/compat/nanmedian_sig.cc index 58cb13e344232..5c6d2b38b39dd 100644 --- a/paddle/phi/ops/compat/nanmedian_sig.cc +++ b/paddle/phi/ops/compat/nanmedian_sig.cc @@ -18,13 +18,15 @@ namespace phi { KernelSignature NanmedianOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature( - "nanmedian", {"X"}, {"ignore_nan"}, {"Out", "Medians"}); + "nanmedian", {"X"}, {"axes", "keepdim"}, {"Out", "MedianIndex"}); } KernelSignature NanmedianGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature( - "nanmedian_grad", {"X", "Medians", "Out@GRAD"}, {}, {"X@GRAD"}); + return KernelSignature("nanmedian_grad", + {"X", "MedianIndex", "Out@GRAD"}, + {"axes", "keepdim"}, + {"X@GRAD"}); } } // namespace phi diff --git a/python/paddle/fluid/tests/unittests/test_nanmedian.py b/python/paddle/fluid/tests/unittests/test_nanmedian.py index 6486b71b137b6..bc6b13c86f676 100644 --- a/python/paddle/fluid/tests/unittests/test_nanmedian.py +++ b/python/paddle/fluid/tests/unittests/test_nanmedian.py @@ -61,7 +61,7 @@ def setUp(self): self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ else paddle.CPUPlace() self.axis_candiate_list = [ - None, 0, 2, -1, -2, (1, 2), [0, -1], [0, 1, 3], (1, 2, -3), + None, 0, 2, -1, -2, (1, 2), [0, -1], [0, 1, 3], (1, 2, 3), [0, 2, 1, 3] ] @@ -97,7 +97,7 @@ def clean_axis_numpy(axis, shape_len): axis = set(axis) return axis - def test_data_case(data, ignore_nan=True): + def test_data_case(data): for keep_dim in [False, True]: if np.isnan(data).all() and keep_dim: np_ver = np.version.version.split('.') @@ -109,19 +109,14 @@ def test_data_case(data, ignore_nan=True): np_res = np.nanmedian(data, keepdims=keep_dim) pd_res = paddle.nanmedian( - paddle.to_tensor(data), - ignore_nan=ignore_nan, - keepdim=keep_dim) + paddle.to_tensor(data), keepdim=keep_dim) self.assertTrue( np.allclose( np_res, pd_res.numpy(), equal_nan=True)) - def test_axis_case(data, axis, ignore_nan=True): + def test_axis_case(data, axis): pd_res = paddle.nanmedian( - paddle.to_tensor(data), - axis=axis, - ignore_nan=ignore_nan, - keepdim=False) + paddle.to_tensor(data), axis=axis, keepdim=False) axis = clean_axis_numpy(axis, len(data.shape)) np_res = np.nanmedian(data, axis=axis, keepdims=False) self.assertTrue(np.allclose(np_res, pd_res.numpy(), equal_nan=True)) @@ -129,7 +124,7 @@ def test_axis_case(data, axis, ignore_nan=True): for name, data in self.fake_data.items(): test_data_case(data) if "_normal" in name: - test_data_case(data, ignore_nan=False) + test_data_case(data) for axis in self.axis_candiate_list: test_axis_case(self.fake_data["row_nan_even"], axis) @@ -152,9 +147,13 @@ def test_empty_axis(): def test_axis_not_in_range(): paddle.nanmedian(x, axis=3, keepdim=True) + def test_duplicated_axis(): + paddle.nanmedian(x, axis=[1, -1], keepdim=True) + self.assertRaises(TypeError, test_dtype) self.assertRaises(ValueError, test_empty_axis) self.assertRaises(ValueError, test_axis_not_in_range) + self.assertRaises(ValueError, test_duplicated_axis) def test_dygraph(self): paddle.disable_static(place=self.place) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 02ed93d0a9f39..9746a658b576f 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -253,10 +253,13 @@ def numel(x, name=None): return out -def nanmedian(x, axis=None, ignore_nan=True, keepdim=True, name=None): +def nanmedian(x, axis=None, keepdim=True, name=None): """ Compute the median along the specified axis, while ignoring NaNs. + If the valid count of elements is a even number, + the average value of both elements in the middle is calculated as the median. + Args: x (Tensor): The input Tensor, it's data type can be int32, int64, float16, float32, float64. axis (None|int|list|tuple, optional): @@ -264,9 +267,6 @@ def nanmedian(x, axis=None, ignore_nan=True, keepdim=True, name=None): ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` . If ``axis`` is less than 0, it works the same way as :math:`axis + D`. If ``axis`` is None, median is calculated over all elements of ``x``. Default is None. - ignore_nan (bool, optional): Whether to ignore nan values when median was calculated. - If `ignore_nan` is False, the calculation process is the same as `median` operator. - Default is True. keepdim (bool, optional): Whether to reserve the reduced dimension(s) in the output Tensor. If ``keepdim`` is True, the dimensions of the output Tensor is the same as ``x`` except in the reduced @@ -298,71 +298,37 @@ def nanmedian(x, axis=None, ignore_nan=True, keepdim=True, name=None): if not isinstance(x, Variable): raise TypeError("In median, the input x should be a Tensor.") + if isinstance(axis, (list, tuple)) and len(axis) == 0: + raise ValueError("Axis list should not be empty.") + dims = len(x.shape) - out_shape = list(x.shape) if axis is None: - x = paddle.flatten(x) - axis = 0 - out_shape = [1] * dims if keepdim else [1] - else: - if isinstance(axis, tuple): - axis = list(axis) + axis = [] + elif isinstance(axis, tuple): + axis = list(axis) + elif isinstance(axis, int): + axis = [axis] - if isinstance(axis, list): - for k in range(len(axis)): - if axis[k] < 0: - axis[k] += dims - axis = list(set(axis)) - if len(axis) <= 0: - raise ValueError("axis should not be empty") + if not isinstance(axis, list): + raise ValueError( + "Axis should be None, int, or a list, element should in range [-rank(x), rank(x))." + ) - axis_src, axis_dst = [], [] - for axis_single in axis: - if not isinstance(axis_single, int) or not ( - axis_single < dims and axis_single >= -dims): - raise ValueError( - "Axis should be None, int, or a list, element should in range [-rank(x), rank(x))." - ) - axis_src.append(axis_single) - out_shape[axis_single] = 1 if keepdim else 0 + for i in range(len(axis)): + if not isinstance(axis[i], int) or not (axis[i] < dims and + axis[i] >= -dims): + raise ValueError( + "Axis should be None, int, or a list, element should in range [-rank(x), rank(x))." + ) + if axis[i] < 0: + axis[i] += dims - axis_dst = list(range(-len(axis), 0)) - x = paddle.moveaxis(x, axis_src, axis_dst) - x = paddle.flatten(x, axis_dst[0], axis_dst[-1]) - axis = axis_dst[0] - else: - if not isinstance(axis, int) or not (axis < dims and axis >= -dims): - raise ValueError( - "Axis should be None, int, or a list, element should in range [-rank(x), rank(x))." - ) - if axis < 0: - axis += dims - out_shape[axis] = 1 if keepdim else 0 - - trans_axis = [] - last_axis = len(x.shape) - 1 - if len(x.shape) > 1 and axis is not None and axis != last_axis: - for i in range(axis): - trans_axis.append(i) - trans_axis.append(last_axis) - for i in range(axis + 1, last_axis): - trans_axis.append(i) - trans_axis.append(axis) - x = paddle.transpose(x, perm=trans_axis) - - if len(x.shape) == 1 and axis == 0: - axis = None - else: - axis = len(x.shape) - 1 - res_shape = [x for x in out_shape if x > 0] + if len(axis) != len(set(axis)): + raise ValueError("Axis has duplicated elements.") if _in_legacy_dygraph(): - medians, out = _C_ops.nanmedian(x, 'ignore_nan', ignore_nan) - if len(trans_axis) > 1: - out = paddle.transpose(out, perm=trans_axis) - if len(res_shape) > 0: - return paddle.reshape(out, res_shape) - + median_index, out = _C_ops.nanmedian(x, 'axes', axis, 'keepdim', + keepdim) return out check_variable_and_dtype( @@ -370,20 +336,15 @@ def nanmedian(x, axis=None, ignore_nan=True, keepdim=True, name=None): 'nanmedian') helper = LayerHelper('nanmedian', **locals()) - attrs = {'ignore_nan': ignore_nan} + attrs = {'axes': axis, 'keepdim': keepdim} out = helper.create_variable_for_type_inference(x.dtype) medians = helper.create_variable_for_type_inference(x.dtype) helper.append_op( type='nanmedian', inputs={'X': x}, outputs={'Out': out, - 'Medians': medians}, + 'MedianIndex': medians}, attrs=attrs) - - if len(trans_axis) > 1: - out = paddle.transpose(out, perm=trans_axis) - if len(res_shape) > 0: - return paddle.reshape(out, res_shape) return out From bcfb015f5d1e2b7611cc728b317d6f336f9eb7fb Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Wed, 11 May 2022 14:19:11 +0000 Subject: [PATCH 12/18] =?UTF-8?q?=E5=AE=8C=E6=88=90=E4=BF=AE=E6=94=B9?= =?UTF-8?q?=E5=BB=BA=E8=AE=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/fluid/operators/nanmedian_op.cc | 2 +- .../phi/kernels/cpu/nanmedian_grad_kernel.cc | 9 +++-- paddle/phi/kernels/cpu/nanmedian_kernel.cc | 35 ++++++++++--------- .../phi/kernels/gpu/nanmedian_grad_kernel.cu | 16 ++++++--- paddle/phi/kernels/gpu/nanmedian_kernel.cu | 4 +-- paddle/phi/ops/compat/nanmedian_sig.cc | 4 +-- .../fluid/tests/unittests/test_nanmedian.py | 7 ++-- python/paddle/tensor/stat.py | 4 +-- 8 files changed, 47 insertions(+), 34 deletions(-) diff --git a/paddle/fluid/operators/nanmedian_op.cc b/paddle/fluid/operators/nanmedian_op.cc index 01a68b03cac46..8cb2c84321457 100644 --- a/paddle/fluid/operators/nanmedian_op.cc +++ b/paddle/fluid/operators/nanmedian_op.cc @@ -53,7 +53,7 @@ class NanmedianOpMaker : public framework::OpProtoAndCheckerMaker { "(bool, default true) " "If true, retain the reduced axis with length 1.") .SetDefault(true); - AddAttr>("axes", + AddAttr>("axis", "(std::vector). List of integers," " indicating the dimensions to calculate medians") .SetDefault({}); diff --git a/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc b/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc index a66e505ad122d..156124c214895 100644 --- a/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc @@ -41,10 +41,15 @@ void CalcMedianGradKernel(const Context& dev_ctx, int64_t pre_dim = numel / stride; int64_t i = 0; int64_t offset = 0; + T div_factor = static_cast(2.0); for (i = 0; i < pre_dim; i++) { if (m_ptr[2 * i] >= 0) { - x_grad_ptr[offset + m_ptr[2 * i]] = out_grad_ptr[i]; - x_grad_ptr[offset + m_ptr[2 * i + 1]] = out_grad_ptr[i]; + if (m_ptr[2 * i] == m_ptr[2 * i + 1]) { + x_grad_ptr[offset + m_ptr[2 * i]] = out_grad_ptr[i]; + } else { + x_grad_ptr[offset + m_ptr[2 * i]] = out_grad_ptr[i] / div_factor; + x_grad_ptr[offset + m_ptr[2 * i + 1]] = out_grad_ptr[i] / div_factor; + } } offset += stride; } diff --git a/paddle/phi/kernels/cpu/nanmedian_kernel.cc b/paddle/phi/kernels/cpu/nanmedian_kernel.cc index 56f7fb260a7e3..e207366bd04cc 100644 --- a/paddle/phi/kernels/cpu/nanmedian_kernel.cc +++ b/paddle/phi/kernels/cpu/nanmedian_kernel.cc @@ -47,49 +47,50 @@ void CalcMedianFunc(const Context& dev_ctx, dev_ctx, x, Scalar(sort_k), -1, false, true, &sort_out, &sort_indices); T div_factor = static_cast(2.0); - int64_t pos = 0; + int64_t offset = 0; int64_t i = 0; bool is_ori_odd = stride & 1; if (should_ignore_nan) { for (i = 0; i < pre_dim; i++) { + offset = i * pre_dim; if (nan_counts[i] == stride) { m_ptr[i * 2] = -1; m_ptr[i * 2 + 1] = -1; - o_ptr[i] = sort_out_ptr[pos]; + o_ptr[i] = sort_out_ptr[offset]; } else { int64_t nan_k = nan_counts[i] > 0 ? static_cast(stride - nan_counts[i]) : sort_k; int64_t row_pos = static_cast(nan_k >> 1); - int64_t off_set = pos + row_pos; + int64_t pos = offset + row_pos; if (nan_k & 1) { - m_ptr[2 * i] = sort_indices_ptr[off_set]; - m_ptr[2 * i + 1] = sort_indices_ptr[off_set]; - o_ptr[i] = sort_out_ptr[off_set]; + m_ptr[2 * i] = sort_indices_ptr[pos]; + m_ptr[2 * i + 1] = sort_indices_ptr[pos]; + o_ptr[i] = sort_out_ptr[pos]; } else { - m_ptr[2 * i] = row_pos > 0 ? sort_indices_ptr[off_set - 1] - : sort_indices_ptr[off_set]; - m_ptr[2 * i + 1] = sort_indices_ptr[off_set]; + m_ptr[2 * i] = + row_pos > 0 ? sort_indices_ptr[pos - 1] : sort_indices_ptr[pos]; + m_ptr[2 * i + 1] = sort_indices_ptr[pos]; T m_val_left = - row_pos > 0 ? sort_out_ptr[off_set - 1] : sort_out_ptr[off_set]; - T m_val_right = sort_out_ptr[off_set]; + row_pos > 0 ? sort_out_ptr[pos - 1] : sort_out_ptr[pos]; + T m_val_right = sort_out_ptr[pos]; o_ptr[i] = (m_val_left + m_val_right) / div_factor; } } - pos += sort_k; } } else { - pos = -1; if (is_ori_odd) { for (i = 0; i < pre_dim; i++) { - pos += sort_k; + offset = i * sort_k; + int64_t pos = offset + sort_k - 1; o_ptr[i] = sort_out_ptr[pos]; m_ptr[2 * i] = sort_indices_ptr[pos]; m_ptr[2 * i + 1] = sort_indices_ptr[pos]; } } else { for (i = 0; i < pre_dim; i++) { - pos += sort_k; + offset = i * sort_k; + int64_t pos = offset + sort_k - 1; m_ptr[2 * i] = sort_k > 1 ? sort_indices_ptr[pos - 1] : sort_indices_ptr[pos]; m_ptr[2 * i + 1] = sort_indices_ptr[pos]; @@ -172,10 +173,10 @@ void BaseMedianKernel(const Context& dev_ctx, bool ignore_nan) { DenseTensor x; auto rank = input.dims().size(); - if (axes.size() == 0) { + if ((axes.size() == 0) || rank <= 1) { x = input; x.Resize({input.numel()}); - } else if (rank > 1) { + } else { PreprocessMedianKernel(dev_ctx, input, axes, &x); } diff --git a/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu b/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu index 7764fadc5f79b..a7cd49c0e53f3 100644 --- a/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu @@ -33,12 +33,19 @@ __global__ void KernelNanmedianGrad(const T* x_ptr, const T* out_grad_ptr, T* x_grad_ptr, int64_t stride, - int64_t pre_dim) { + int64_t pre_dim, + T div_factor) { CUDA_KERNEL_LOOP(index, pre_dim) { int64_t offset = index * stride; if (medians_ptr[2 * index] >= 0) { - x_grad_ptr[offset + medians_ptr[2 * index]] = out_grad_ptr[index]; - x_grad_ptr[offset + medians_ptr[2 * index + 1]] = out_grad_ptr[index]; + if (medians_ptr[2 * index] == medians_ptr[2 * index + 1]) { + x_grad_ptr[offset + medians_ptr[2 * index]] = out_grad_ptr[index]; + } else { + x_grad_ptr[offset + medians_ptr[2 * index]] = + out_grad_ptr[index] / div_factor; + x_grad_ptr[offset + medians_ptr[2 * index + 1]] = + out_grad_ptr[index] / div_factor; + } } } } @@ -64,9 +71,10 @@ void CalcMedianGradKernel(const Context& dev_ctx, int64_t stride = x_dim[x_rank - 1]; int64_t pre_dim = numel / stride; + T div_factor = static_cast(2.0); KernelNanmedianGrad< T><<>>( - x_ptr, m_ptr, out_grad_ptr, x_grad_ptr, stride, pre_dim); + x_ptr, m_ptr, out_grad_ptr, x_grad_ptr, stride, pre_dim, div_factor); } template diff --git a/paddle/phi/kernels/gpu/nanmedian_kernel.cu b/paddle/phi/kernels/gpu/nanmedian_kernel.cu index e21593bcc4c2e..3cd81daa47f1e 100644 --- a/paddle/phi/kernels/gpu/nanmedian_kernel.cu +++ b/paddle/phi/kernels/gpu/nanmedian_kernel.cu @@ -254,10 +254,10 @@ void BaseMedianKernel(const Context& dev_ctx, DenseTensor* median_index) { DenseTensor x; auto rank = input.dims().size(); - if (axes.size() == 0) { + if ((axes.size() == 0) || rank <= 1) { x = input; x.Resize({input.numel()}); - } else if (rank > 1) { + } else { PreprocessMedianKernel(dev_ctx, input, axes, &x); } diff --git a/paddle/phi/ops/compat/nanmedian_sig.cc b/paddle/phi/ops/compat/nanmedian_sig.cc index 5c6d2b38b39dd..5ca0d450e3b41 100644 --- a/paddle/phi/ops/compat/nanmedian_sig.cc +++ b/paddle/phi/ops/compat/nanmedian_sig.cc @@ -18,14 +18,14 @@ namespace phi { KernelSignature NanmedianOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature( - "nanmedian", {"X"}, {"axes", "keepdim"}, {"Out", "MedianIndex"}); + "nanmedian", {"X"}, {"axis", "keepdim"}, {"Out", "MedianIndex"}); } KernelSignature NanmedianGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("nanmedian_grad", {"X", "MedianIndex", "Out@GRAD"}, - {"axes", "keepdim"}, + {"axis", "keepdim"}, {"X@GRAD"}); } diff --git a/python/paddle/fluid/tests/unittests/test_nanmedian.py b/python/paddle/fluid/tests/unittests/test_nanmedian.py index bc6b13c86f676..2e1f13a8c7d9f 100644 --- a/python/paddle/fluid/tests/unittests/test_nanmedian.py +++ b/python/paddle/fluid/tests/unittests/test_nanmedian.py @@ -123,8 +123,6 @@ def test_axis_case(data, axis): for name, data in self.fake_data.items(): test_data_case(data) - if "_normal" in name: - test_data_case(data) for axis in self.axis_candiate_list: test_axis_case(self.fake_data["row_nan_even"], axis) @@ -181,11 +179,12 @@ def test_check_grad(self): mid = int(valid_cnts / 2) targets = [x_np_sorted[i, mid]] - if valid_cnts % 2 == 0 and mid > 0: + is_odd = valid_cnts % 2 + if not is_odd and mid > 0: targets.append(x_np_sorted[i, mid - 1]) for j in range(shape[1]): if x_np[i, j] in targets: - np_grad[i, j] = 1 + np_grad[i, j] = 1 if is_odd else 0.5 x_tensor = paddle.to_tensor(x_np, stop_gradient=False) y = paddle.nanmedian(x_tensor, axis=1, keepdim=True) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 9746a658b576f..2dfb9f5b364ad 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -327,7 +327,7 @@ def nanmedian(x, axis=None, keepdim=True, name=None): raise ValueError("Axis has duplicated elements.") if _in_legacy_dygraph(): - median_index, out = _C_ops.nanmedian(x, 'axes', axis, 'keepdim', + median_index, out = _C_ops.nanmedian(x, 'axis', axis, 'keepdim', keepdim) return out @@ -336,7 +336,7 @@ def nanmedian(x, axis=None, keepdim=True, name=None): 'nanmedian') helper = LayerHelper('nanmedian', **locals()) - attrs = {'axes': axis, 'keepdim': keepdim} + attrs = {'axis': axis, 'keepdim': keepdim} out = helper.create_variable_for_type_inference(x.dtype) medians = helper.create_variable_for_type_inference(x.dtype) helper.append_op( From 718fcdbce2f01be3af6887b1e4eb2d415b571a55 Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Thu, 12 May 2022 07:56:33 +0000 Subject: [PATCH 13/18] typo pre_dim --- paddle/phi/kernels/cpu/nanmedian_kernel.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/kernels/cpu/nanmedian_kernel.cc b/paddle/phi/kernels/cpu/nanmedian_kernel.cc index e207366bd04cc..ed38405c9179f 100644 --- a/paddle/phi/kernels/cpu/nanmedian_kernel.cc +++ b/paddle/phi/kernels/cpu/nanmedian_kernel.cc @@ -52,7 +52,7 @@ void CalcMedianFunc(const Context& dev_ctx, bool is_ori_odd = stride & 1; if (should_ignore_nan) { for (i = 0; i < pre_dim; i++) { - offset = i * pre_dim; + offset = i * sort_k; if (nan_counts[i] == stride) { m_ptr[i * 2] = -1; m_ptr[i * 2 + 1] = -1; From 8c158b5079761ac95705b7f521149c36e0fd5f62 Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Fri, 20 May 2022 17:10:21 +0000 Subject: [PATCH 14/18] update en docs, test=document_fix --- python/paddle/tensor/stat.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 2dfb9f5b364ad..8cf42a4f8f146 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -263,7 +263,7 @@ def nanmedian(x, axis=None, keepdim=True, name=None): Args: x (Tensor): The input Tensor, it's data type can be int32, int64, float16, float32, float64. axis (None|int|list|tuple, optional): - The axis along which to perform median calculations ``axis`` should be int. + The axis along which to perform median calculations ``axis`` should be int or list of int. ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` . If ``axis`` is less than 0, it works the same way as :math:`axis + D`. If ``axis`` is None, median is calculated over all elements of ``x``. Default is None. @@ -271,7 +271,7 @@ def nanmedian(x, axis=None, keepdim=True, name=None): in the output Tensor. If ``keepdim`` is True, the dimensions of the output Tensor is the same as ``x`` except in the reduced dimensions(it is of size 1 in this case). Otherwise, the shape of - the output Tensor is squeezed in ``axis`` . Default is False. + the output Tensor is squeezed in ``axis`` . Default is True. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. @@ -287,13 +287,16 @@ def nanmedian(x, axis=None, keepdim=True, name=None): x = paddle.to_tensor([[np.nan, 2. , 3. ], [0. , 1. , 2. ]]) y1 = x.nanmedian() - # y1 is [[[2.]] + # y1 is [[2.]] y2 = x.nanmedian(0) # y2 is [[0., 1.5, 2.5]] y3 = x.nanmedian(0, keepdim=False) # y3 is [0., 1.5, 2.5] + + y4 = x.nanmedian((1, 2)) + # y4 is [[2.]] """ if not isinstance(x, Variable): raise TypeError("In median, the input x should be a Tensor.") From 46dc918c9c99d7d4ec368f2e6019c4e3ad4cff80 Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Mon, 23 May 2022 03:06:47 +0000 Subject: [PATCH 15/18] remove numpy in en doc, test=document_fix --- python/paddle/tensor/stat.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 9a7fe6f8c32e9..b2475ffa6efcf 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -271,8 +271,7 @@ def nanmedian(x, axis=None, keepdim=True, name=None): :name: nanmedian-example import paddle - import numpy as np - x = paddle.to_tensor([[np.nan, 2. , 3. ], [0. , 1. , 2. ]]) + x = paddle.to_tensor([[float('nan'), 2. , 3. ], [0. , 1. , 2. ]]) y1 = x.nanmedian() # y1 is [[2.]] @@ -283,7 +282,7 @@ def nanmedian(x, axis=None, keepdim=True, name=None): y3 = x.nanmedian(0, keepdim=False) # y3 is [0., 1.5, 2.5] - y4 = x.nanmedian((1, 2)) + y4 = x.nanmedian((0, 1)) # y4 is [[2.]] """ if not isinstance(x, Variable): From 117e1021dea7b27bd926255fbdf21ac4a1045526 Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Wed, 25 May 2022 03:35:26 +0000 Subject: [PATCH 16/18] add r,test=document_fix --- python/paddle/tensor/stat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 2dfb9f5b364ad..2262dc7dc1ecf 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -254,7 +254,7 @@ def numel(x, name=None): def nanmedian(x, axis=None, keepdim=True, name=None): - """ + r""" Compute the median along the specified axis, while ignoring NaNs. If the valid count of elements is a even number, From a3b23f6bc4cdd47bee6ad88fba3fa959ca4b9dbc Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Thu, 26 May 2022 05:41:28 +0000 Subject: [PATCH 17/18] =?UTF-8?q?=E6=B7=BB=E5=8A=A0api=E5=88=B0all?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/paddle/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 936d5a159aa2e..930918e967eed 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -499,6 +499,7 @@ 'load', 'numel', 'median', + 'nanmedian', 'quantile', 'nanquantile', 'no_grad', From 6744d0a555f8c10e93fd3f4864ef7c9f75d2dddc Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Thu, 26 May 2022 11:21:32 +0000 Subject: [PATCH 18/18] follow advice from chenwhql --- paddle/fluid/operators/nanmedian_op.cc | 6 +++++- paddle/phi/infermeta/backward.cc | 11 +++++++++++ paddle/phi/infermeta/backward.h | 7 +++++++ paddle/phi/infermeta/unary.cc | 11 ----------- paddle/phi/infermeta/unary.h | 7 ------- paddle/phi/kernels/gpu/nanmedian_kernel.cu | 2 -- paddle/phi/kernels/nanmedian_kernel.h | 3 ++- 7 files changed, 25 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/operators/nanmedian_op.cc b/paddle/fluid/operators/nanmedian_op.cc index 8cb2c84321457..23a497bdb1d3d 100644 --- a/paddle/fluid/operators/nanmedian_op.cc +++ b/paddle/fluid/operators/nanmedian_op.cc @@ -1,8 +1,11 @@ -/*Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -14,6 +17,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" #include "paddle/phi/infermeta/unary.h" namespace paddle { diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 602942abf4d34..f6b9532166aa3 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -377,6 +377,17 @@ void MultiplexGradInferMeta(const MetaTensor& ids, } } +void NanmedianGradInferMeta(const MetaTensor& x, + const MetaTensor& median_index, + const MetaTensor& out_grad, + const IntArray& axes, + bool keep_dim, + MetaTensor* x_grad) { + auto x_dims = x.dims(); + x_grad->set_dims(x_dims); + x_grad->set_dtype(x.dtype()); +} + void NllLossGradInferMeta(const MetaTensor& x, const MetaTensor& label, paddle::optional weight, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index c35b58d0f56e4..696bb51c2334d 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -168,6 +168,13 @@ void MultiplexGradInferMeta(const MetaTensor& ids, const MetaTensor& out_grad, std::vector ins_grad); +void NanmedianGradInferMeta(const MetaTensor& x, + const MetaTensor& median_index, + const MetaTensor& out_grad, + const IntArray& axes, + bool keep_dim, + MetaTensor* x_grad); + void NllLossGradInferMeta(const MetaTensor& input, const MetaTensor& label, paddle::optional weight, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 307519c029b8d..f736bf50162d8 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1305,17 +1305,6 @@ void NanmedianInferMeta(const MetaTensor& x, out->set_dims(make_ddim(out_dim)); } -void NanmedianGradInferMeta(const MetaTensor& x, - const MetaTensor& median_index, - const MetaTensor& out_grad, - const IntArray& axes, - bool keep_dim, - MetaTensor* x_grad) { - auto x_dims = x.dims(); - x_grad->set_dims(x_dims); - x_grad->set_dtype(x.dtype()); -} - void NormInferMeta(const MetaTensor& x, int axis, float epsilon, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 18524cb1d5d8e..c21ef0e2d1103 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -185,13 +185,6 @@ void NanmedianInferMeta(const MetaTensor& x, MetaTensor* out, MetaTensor* median_index); -void NanmedianGradInferMeta(const MetaTensor& x, - const MetaTensor& median_index, - const MetaTensor& out_grad, - const IntArray& axes, - bool keep_dim, - MetaTensor* x_grad); - void NormInferMeta(const MetaTensor& x, int axis, float epsilon, diff --git a/paddle/phi/kernels/gpu/nanmedian_kernel.cu b/paddle/phi/kernels/gpu/nanmedian_kernel.cu index 3cd81daa47f1e..5975e2748997e 100644 --- a/paddle/phi/kernels/gpu/nanmedian_kernel.cu +++ b/paddle/phi/kernels/gpu/nanmedian_kernel.cu @@ -141,8 +141,6 @@ void ProcessMedianKernel(const Context& dev_ctx, int64_t* m_ptr) { bool should_ignore_nan = ignore_nan; auto stream = dev_ctx.stream(); - auto* ctx = - reinterpret_cast(&dev_ctx); const T* x_ptr = x.data(); T* o_ptr = dev_ctx.template Alloc(out); diff --git a/paddle/phi/kernels/nanmedian_kernel.h b/paddle/phi/kernels/nanmedian_kernel.h index 513bb49022242..374f420381bdc 100644 --- a/paddle/phi/kernels/nanmedian_kernel.h +++ b/paddle/phi/kernels/nanmedian_kernel.h @@ -58,7 +58,8 @@ void PreprocessMedianKernel(const Context& dev_ctx, for (int i = 0; i < ndims; i++) { trans_dim[i] = input_dim[perm[i]]; } - x->mutable_data(trans_dim, dev_ctx.GetPlace()); + x->Resize(trans_dim); + dev_ctx.template Alloc(x); funcs::TransCompute(ndims, dev_ctx, input, x, perm); x->Resize(make_ddim(reshape));