Skip to content

Commit d7e7496

Browse files
add avg/max poll for float16 (#74313)
1 parent 53b6276 commit d7e7496

File tree

5 files changed

+66
-8
lines changed

5 files changed

+66
-8
lines changed

paddle/phi/kernels/cpu/pool_grad_kernel.cc

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,13 @@
1717
#include "paddle/phi/core/kernel_registry.h"
1818
#include "paddle/phi/kernels/impl/pool_grad_kernel_impl.h"
1919

20-
PD_REGISTER_KERNEL(
21-
pool2d_grad, CPU, ALL_LAYOUT, phi::Pool2dGradKernel, float, double) {}
20+
PD_REGISTER_KERNEL(pool2d_grad,
21+
CPU,
22+
ALL_LAYOUT,
23+
phi::Pool2dGradKernel,
24+
float,
25+
double,
26+
phi::dtype::float16) {}
2227
PD_REGISTER_KERNEL(
2328
lp_pool2d_grad, CPU, ALL_LAYOUT, phi::LPPool2dGradKernel, float, double) {}
2429
PD_REGISTER_KERNEL(pool2d_double_grad,
@@ -36,8 +41,13 @@ PD_REGISTER_KERNEL(max_pool2d_with_index_grad,
3641
kernel->InputAt(1).SetDataType(phi::CppTypeToDataType<int>::Type());
3742
}
3843

39-
PD_REGISTER_KERNEL(
40-
pool3d_grad, CPU, ALL_LAYOUT, phi::Pool3dGradKernel, float, double) {}
44+
PD_REGISTER_KERNEL(pool3d_grad,
45+
CPU,
46+
ALL_LAYOUT,
47+
phi::Pool3dGradKernel,
48+
float,
49+
double,
50+
phi::dtype::float16) {}
4151
PD_REGISTER_KERNEL(max_pool3d_with_index_grad,
4252
CPU,
4353
ALL_LAYOUT,

paddle/phi/kernels/cpu/pool_kernel.cc

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,13 @@
1717
#include "paddle/phi/core/kernel_registry.h"
1818
#include "paddle/phi/kernels/impl/pool_kernel_impl.h"
1919

20-
PD_REGISTER_KERNEL(pool2d, CPU, ALL_LAYOUT, phi::Pool2dKernel, float, double) {}
20+
PD_REGISTER_KERNEL(pool2d,
21+
CPU,
22+
ALL_LAYOUT,
23+
phi::Pool2dKernel,
24+
float,
25+
double,
26+
phi::dtype::float16) {}
2127
PD_REGISTER_KERNEL(
2228
lp_pool2d, CPU, ALL_LAYOUT, phi::LPPool2dKernel, float, double) {}
2329
PD_REGISTER_KERNEL(max_pool2d_with_index,
@@ -29,7 +35,13 @@ PD_REGISTER_KERNEL(max_pool2d_with_index,
2935
kernel->OutputAt(1).SetDataType(phi::CppTypeToDataType<int>::Type());
3036
}
3137

32-
PD_REGISTER_KERNEL(pool3d, CPU, ALL_LAYOUT, phi::Pool3dKernel, float, double) {}
38+
PD_REGISTER_KERNEL(pool3d,
39+
CPU,
40+
ALL_LAYOUT,
41+
phi::Pool3dKernel,
42+
float,
43+
double,
44+
phi::dtype::float16) {}
3345
PD_REGISTER_KERNEL(max_pool3d_with_index,
3446
CPU,
3547
ALL_LAYOUT,

paddle/phi/kernels/funcs/pooling.cc

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,8 @@ class MaxPool2dGradFunctor<CPUContext, T> {
464464
template class MaxPool2dGradFunctor<CPUContext, float>;
465465
template class MaxPool2dGradFunctor<CPUContext, double>;
466466

467+
template class MaxPool2dGradFunctor<CPUContext, dtype::float16>;
468+
467469
template class Pool2dFunctor<CPUContext, MaxPool<float>, float>;
468470
template class Pool2dFunctor<CPUContext, AvgPool<float>, float>;
469471
template class Pool2dFunctor<CPUContext, LPPool<float>, float>;
@@ -477,6 +479,24 @@ template class Pool2dGradFunctor<CPUContext, MaxPoolGrad<double>, double>;
477479
template class Pool2dGradFunctor<CPUContext, AvgPoolGrad<double>, double>;
478480
template class Pool2dGradFunctor<CPUContext, LPPoolGrad<double>, double>;
479481

482+
template class Pool2dFunctor<phi::CPUContext,
483+
MaxPool<dtype::float16>,
484+
dtype::float16>;
485+
template class Pool2dFunctor<phi::CPUContext,
486+
AvgPool<dtype::float16>,
487+
dtype::float16>;
488+
template class Pool2dFunctor<phi::CPUContext,
489+
LPPool<dtype::float16>,
490+
dtype::float16>;
491+
template class Pool2dGradFunctor<phi::CPUContext,
492+
MaxPoolGrad<dtype::float16>,
493+
dtype::float16>;
494+
template class Pool2dGradFunctor<phi::CPUContext,
495+
AvgPoolGrad<dtype::float16>,
496+
dtype::float16>;
497+
template class Pool2dGradFunctor<phi::CPUContext,
498+
LPPoolGrad<dtype::float16>,
499+
dtype::float16>;
480500
/*
481501
* Tensors are in NCDHW or NDHWC format.
482502
* Ksize, strides, paddings are three elements. These three elements represent
@@ -1057,6 +1077,7 @@ class MaxPool3dGradFunctor<CPUContext, T> {
10571077
};
10581078
template class MaxPool3dGradFunctor<CPUContext, float>;
10591079
template class MaxPool3dGradFunctor<CPUContext, double>;
1080+
template class MaxPool3dGradFunctor<CPUContext, dtype::float16>;
10601081

10611082
template class Pool3dFunctor<CPUContext, MaxPool<float>, float>;
10621083
template class Pool3dFunctor<CPUContext, AvgPool<float>, float>;
@@ -1067,6 +1088,21 @@ template class Pool3dFunctor<CPUContext, AvgPool<double>, double>;
10671088
template class Pool3dGradFunctor<CPUContext, MaxPoolGrad<double>, double>;
10681089
template class Pool3dGradFunctor<CPUContext, AvgPoolGrad<double>, double>;
10691090

1091+
template class Pool3dFunctor<phi::CPUContext,
1092+
MaxPool<dtype::float16>,
1093+
dtype::float16>;
1094+
template class Pool3dFunctor<phi::CPUContext,
1095+
AvgPool<dtype::float16>,
1096+
dtype::float16>;
1097+
template class Pool3dFunctor<phi::CPUContext,
1098+
LPPool<dtype::float16>,
1099+
dtype::float16>;
1100+
template class Pool3dGradFunctor<phi::CPUContext,
1101+
MaxPoolGrad<dtype::float16>,
1102+
dtype::float16>;
1103+
template class Pool3dGradFunctor<phi::CPUContext,
1104+
AvgPoolGrad<dtype::float16>,
1105+
dtype::float16>;
10701106
/*
10711107
* All tensors are in NCHW format.
10721108
* Ksize, strides, paddings are two elements. These two elements represent

paddle/phi/kernels/gpudnn/pool_kernel.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ void Pool3dGPUDNNKernel(const Context& dev_ctx,
305305
const std::string& padding_algorithm,
306306
DenseTensor* out) {
307307
if (x.numel() == 0) {
308-
if (pooling_type == "max" || pooling_type == "avg") {
308+
if (pooling_type == "max") {
309309
phi::Full<T, Context>(
310310
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
311311
} else {

paddle/phi/kernels/impl/pool_kernel_impl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ void Pool3dKernel(const Context& dev_ctx,
389389
const std::string& padding_algorithm,
390390
DenseTensor* out) {
391391
if (x.numel() == 0) {
392-
if (pooling_type == "max" || pooling_type == "avg") {
392+
if (pooling_type == "max") {
393393
phi::Full<T, Context>(
394394
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
395395
} else {

0 commit comments

Comments
 (0)