Skip to content

Commit 356087d

Browse files
authored
Add a kernel for median (#74767)

* fix
* fix
* fix
* Add a median kernel; evenly split gradients in the nanmedian and median backward passes
* fix
* Split gradients evenly, per strategy
* fix
* fix
* fix
* fix
* fix
* fix
* Update copyright
* fix
* fix
* fix
* fix
1 parent e3ccc1e commit 356087d

22 files changed

+1726
-251
lines changed

paddle/phi/infermeta/backward.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,6 +1135,19 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
11351135
dx->share_meta(x);
11361136
}
11371137

1138+
void MedianGradInferMeta(const MetaTensor& x,
1139+
const MetaTensor& median_data,
1140+
const MetaTensor& median_index,
1141+
const MetaTensor& out_grad,
1142+
const IntArray& axes,
1143+
bool keep_dim,
1144+
const std::string& mode,
1145+
MetaTensor* x_grad) {
1146+
auto x_dims = x.dims();
1147+
x_grad->set_dims(x_dims);
1148+
x_grad->set_dtype(x.dtype());
1149+
}
1150+
11381151
void MemoryEfficientAttentionGradInferMeta(const MetaTensor& query,
11391152
const MetaTensor& key,
11401153
const MetaTensor& value,
@@ -1417,6 +1430,7 @@ void MultiplexGradInferMeta(const MetaTensor& ids,
14171430
}
14181431

14191432
void NanmedianGradInferMeta(const MetaTensor& x,
1433+
const MetaTensor& median_data,
14201434
const MetaTensor& median_index,
14211435
const MetaTensor& out_grad,
14221436
const IntArray& axes,

paddle/phi/infermeta/backward.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,15 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
446446
bool ceil_mode,
447447
MetaTensor* dx);
448448

449+
// Infers meta for median_grad: x_grad mirrors x's dims and dtype.
void MedianGradInferMeta(const MetaTensor& x,
                         const MetaTensor& median_data,
                         const MetaTensor& median_index,
                         const MetaTensor& out_grad,
                         const IntArray& axes,
                         bool keep_dim,
                         const std::string& mode,
                         MetaTensor* x_grad);
457+
449458
void MeshgridGradInferMeta(const std::vector<const MetaTensor*>& inputs,
450459
const std::vector<const MetaTensor*>& outputs_grad,
451460
std::vector<MetaTensor*> inputs_grad);
@@ -525,6 +534,7 @@ void MultiplexGradInferMeta(const MetaTensor& ids,
525534
std::vector<MetaTensor*> ins_grad);
526535

527536
void NanmedianGradInferMeta(const MetaTensor& x,
537+
const MetaTensor& median_data,
528538
const MetaTensor& median_index,
529539
const MetaTensor& out_grad,
530540
const IntArray& axes,

paddle/phi/infermeta/unary.cc

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2890,6 +2890,80 @@ void MeanAllInferMeta(const MetaTensor& x, MetaTensor* out) {
28902890
out->set_layout(x.layout());
28912891
}
28922892

2893+
void MedianInferMeta(const MetaTensor& x,
2894+
const IntArray& axes,
2895+
bool keep_dim,
2896+
const std::string& mode,
2897+
MetaTensor* out,
2898+
MetaTensor* median_index) {
2899+
std::vector<int64_t> axis_list = axes.GetData();
2900+
auto x_dim = x.dims();
2901+
int64_t x_rank = x_dim.size();
2902+
2903+
std::vector<int64_t> out_dim;
2904+
if (axis_list.empty()) {
2905+
if (keep_dim) {
2906+
for (int64_t i = 0; i < x_rank; i++) {
2907+
out_dim.push_back(1);
2908+
}
2909+
}
2910+
} else {
2911+
std::vector<int64_t> formatted_axis;
2912+
for (auto& axis : axis_list) {
2913+
if (x_rank == 0) {
2914+
PADDLE_ENFORCE_EQ(axis == 0 || axis == -1,
2915+
true,
2916+
common::errors::InvalidArgument(
2917+
"When input 0D Tensor, each element of the axis "
2918+
"can only be -1, 0, None"));
2919+
} else {
2920+
PADDLE_ENFORCE_LT(axis,
2921+
x_rank,
2922+
errors::InvalidArgument(
2923+
"each element of the axis should be in the "
2924+
"range [ -dimension(X), dimension(X) ) "
2925+
"which dimension = %d. But received axis = %d.",
2926+
x_rank,
2927+
axis));
2928+
PADDLE_ENFORCE_GE(axis,
2929+
-x_rank,
2930+
errors::InvalidArgument(
2931+
"each element of the axis should be in the "
2932+
"range [ -dimension(X), dimension(X) ) "
2933+
"which dimension = %d. But received axis = %d.",
2934+
x_rank,
2935+
axis));
2936+
}
2937+
if (axis < 0) axis += x_rank;
2938+
PADDLE_ENFORCE_EQ(
2939+
std::find(formatted_axis.begin(), formatted_axis.end(), axis),
2940+
formatted_axis.end(),
2941+
errors::InvalidArgument("Attr(axes) has duplicated elements: %d.",
2942+
static_cast<int>(axis)));
2943+
2944+
formatted_axis.push_back(axis);
2945+
}
2946+
2947+
for (int64_t i = 0; i < x_rank; i++) {
2948+
if (std::find(formatted_axis.begin(), formatted_axis.end(), i) ==
2949+
formatted_axis.end()) {
2950+
out_dim.push_back(x_dim[i]); // NOLINT
2951+
} else if (keep_dim) {
2952+
out_dim.push_back(1);
2953+
}
2954+
}
2955+
}
2956+
out->set_dtype(x.dtype());
2957+
out->set_dims(make_ddim(out_dim));
2958+
2959+
auto median_dim = out_dim;
2960+
if (mode == "avg") {
2961+
median_dim.push_back(2);
2962+
}
2963+
median_index->set_dtype(DataType::INT64);
2964+
median_index->set_dims(make_ddim(median_dim));
2965+
}
2966+
28932967
void ModeInferMeta(const MetaTensor& x,
28942968
int axis,
28952969
bool keepdim,

paddle/phi/infermeta/unary.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,13 @@ void MaxPoolV2InferMeta(const MetaTensor& x,
468468

469469
void MeanAllInferMeta(const MetaTensor& x, MetaTensor* out);
470470

471+
// Infers meta for median: `out` is x reduced over `axes` (keep_dim-aware);
// `median_index` is INT64 with a trailing dim of 2 when mode == "avg".
void MedianInferMeta(const MetaTensor& x,
                     const IntArray& axes,
                     bool keep_dim,
                     const std::string& mode,
                     MetaTensor* out,
                     MetaTensor* median_index);
477+
471478
void ModeInferMeta(const MetaTensor& x,
472479
int axis,
473480
bool keepdim,
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/kernels/median_grad_kernel.h"
16+
17+
#include <math.h>
18+
#include "paddle/phi/backends/cpu/cpu_context.h"
19+
#include "paddle/phi/core/kernel_registry.h"
20+
#include "paddle/phi/kernels/funcs/math_function.h"
21+
#include "paddle/phi/kernels/funcs/nanmedian_utils.h"
22+
23+
namespace phi {
24+
25+
// Scatters each row's upstream gradient onto the single element that was
// selected as the median (one recorded index per row of length `stride`;
// `pre_dim` rows total).
//
// Rows whose recorded index is negative are skipped — presumably the producer
// kernel uses a negative index to mark rows with no valid median (TODO
// confirm against the forward kernel).
template <typename T>
void CalcMedianMinGrad(int64_t pre_dim,
                       int64_t stride,
                       const int64_t* m_data,
                       T* dx_data,
                       const T* dout_data) {
  for (int64_t row = 0; row < pre_dim; ++row) {
    const int64_t idx = m_data[row];
    if (idx >= 0) {
      dx_data[row * stride + idx] = dout_data[row];
    }
  }
}
40+
41+
template <typename T>
42+
void CalcMedianGradEvenly(int64_t pre_dim,
43+
int64_t stride,
44+
const DenseTensor& x,
45+
const T* m_data,
46+
const int64_t* m_index,
47+
T* dx_data,
48+
const T* dout_data) {
49+
int64_t i = 0, j = 0;
50+
int64_t offset = 0;
51+
std::vector<int64_t> data_index;
52+
const T* x_data = x.data<T>();
53+
for (i = 0; i < pre_dim; i++) {
54+
data_index.clear();
55+
for (j = 0; j < stride; j++) {
56+
if ((m_data[i] == x_data[offset + j]) ||
57+
(isnan(static_cast<float>(m_data[i])) &&
58+
isnan(static_cast<float>(x_data[offset + j])))) {
59+
data_index.push_back(offset + j);
60+
}
61+
}
62+
if (data_index.size() == 0) {
63+
if (m_index[2 * i] == m_index[2 * i + 1]) {
64+
dx_data[offset + m_index[2 * i]] = dout_data[i];
65+
} else {
66+
dx_data[offset + m_index[2 * i]] = dout_data[i] / static_cast<T>(2.0);
67+
dx_data[offset + m_index[2 * i + 1]] =
68+
dout_data[i] / static_cast<T>(2.0);
69+
}
70+
} else {
71+
for (j = 0; j < data_index.size(); j++) {
72+
dx_data[data_index[j]] =
73+
dout_data[i] / static_cast<T>(data_index.size());
74+
}
75+
}
76+
77+
offset += stride;
78+
}
79+
}
80+
81+
// Core CPU backward for median over the LAST axis of a pre-flattened input.
//
// Treats x as (pre_dim, stride) where stride is x's last dim. Allocates and
// zero-fills x_grad, then dispatches:
//   - evenly == false: scatter dout to the one recorded index per row
//     (CalcMedianMinGrad);
//   - evenly == true: split dout across all elements equal to the median
//     value, with an index-pair fallback (CalcMedianGradEvenly).
//
// `mode` is accepted but not read here; the caller folds it into `evenly`.
// Assumes stride > 0 — caller returns early for empty inputs (TODO confirm
// all call sites guard the zero-numel case).
template <typename T, typename Context>
void CalcMedianGradKernel_CPU(const Context& dev_ctx,
                              const DenseTensor& x,
                              const DenseTensor& median_data,
                              const DenseTensor& median_index,
                              const DenseTensor& out_grad,
                              const std::string& mode,
                              const bool evenly,
                              DenseTensor* x_grad) {
  T* dx_data = dev_ctx.template Alloc<T>(x_grad);
  if (!dx_data) return;

  // Non-selected positions receive zero gradient.
  phi::funcs::SetConstant<Context, T> set_zero;
  set_zero(dev_ctx, x_grad, static_cast<T>(0));

  const int64_t* m_index = median_index.data<int64_t>();
  const T* m_data = median_data.data<T>();
  const T* dout_data = out_grad.data<T>();
  int64_t numel = x.numel();
  auto x_dim = x.dims();
  int64_t rank = x_dim.size();
  // Reduction always runs over the last axis: rows of length `stride`.
  int64_t stride = x_dim[static_cast<int>(rank - 1)];
  int64_t pre_dim = numel / stride;
  if (!evenly) {
    CalcMedianMinGrad(pre_dim, stride, m_index, dx_data, dout_data);
  } else {
    CalcMedianGradEvenly(
        pre_dim, stride, x, m_data, m_index, dx_data, dout_data);
  }
}
111+
112+
// CPU backward kernel for the median op.
//
// Normalizes the input so the reduction is over the last axis, computes the
// gradient there, and (when axes required a transpose) maps it back to x's
// original layout.
//
// x            : forward input.
// median_data  : forward median values.
// median_index : forward element indices (pairs in "avg" mode).
// out_grad     : upstream gradient, shaped like the forward output.
// axes         : reduction axes from the forward op.
// keepdim      : unused here — gradient shape is always x's shape.
// mode         : "avg" selects evenly-split gradients.
// x_grad       : output, same dims/dtype as x.
template <typename T, typename Context>
void MedianGradKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const DenseTensor& median_data,
                      const DenseTensor& median_index,
                      const DenseTensor& out_grad,
                      const IntArray& axes,
                      bool keepdim UNUSED,
                      const std::string& mode,
                      DenseTensor* x_grad) {
  // Empty gradient: just allocate and return.
  if (x_grad && x_grad->numel() == 0) {
    dev_ctx.template Alloc<T>(x_grad);
    return;
  }
  // Even split applies when reducing over multiple/all axes, or in "avg"
  // mode; a single-axis non-avg reduction uses the single recorded index.
  bool evenly = (axes.size() != 1 || mode == "avg");
  DenseTensor tmp_x;
  auto rank = x.dims().size();
  if ((axes.size() == 0) || rank <= 1) {
    // Full reduction (or effectively 1-D input): flatten and reduce the
    // single remaining axis directly.
    tmp_x = x;
    tmp_x.Resize({x.numel()});
    CalcMedianGradKernel_CPU<T, Context>(dev_ctx,
                                         tmp_x,
                                         median_data,
                                         median_index,
                                         out_grad,
                                         mode,
                                         evenly,
                                         x_grad);
  } else {
    // Move the reduced axes to the back (presumably transpose + reshape —
    // see funcs::PreprocessMedianKernel), compute the gradient in that
    // layout, then scatter it back into x's layout.
    funcs::PreprocessMedianKernel<T, Context>(dev_ctx, x, axes, &tmp_x);

    DenseTensor tmp_x_grad;
    // Same numel as x; PostprocessMedianGradKernel handles the remapping.
    tmp_x_grad.Resize(x_grad->dims());
    CalcMedianGradKernel_CPU<T, Context>(dev_ctx,
                                         tmp_x,
                                         median_data,
                                         median_index,
                                         out_grad,
                                         mode,
                                         evenly,
                                         &tmp_x_grad);

    dev_ctx.template Alloc<T>(x_grad);
    funcs::PostprocessMedianGradKernel<T, Context>(
        dev_ctx, &tmp_x_grad, axes, x_grad);
  }
}
159+
160+
} // namespace phi
161+
162+
// Register the CPU median_grad kernel for all supported element types.
PD_REGISTER_KERNEL(median_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::MedianGradKernel,
                   float,
                   double,
                   int,
                   int64_t) {}

0 commit comments

Comments
 (0)