@@ -464,6 +464,8 @@ class MaxPool2dGradFunctor<CPUContext, T> {
464464template class MaxPool2dGradFunctor <CPUContext, float >;
465465template class MaxPool2dGradFunctor <CPUContext, double >;
466466
467+ template class MaxPool2dGradFunctor <CPUContext, dtype::float16>;
468+
467469template class Pool2dFunctor <CPUContext, MaxPool<float >, float >;
468470template class Pool2dFunctor <CPUContext, AvgPool<float >, float >;
469471template class Pool2dFunctor <CPUContext, LPPool<float >, float >;
@@ -477,6 +479,24 @@ template class Pool2dGradFunctor<CPUContext, MaxPoolGrad<double>, double>;
477479template class Pool2dGradFunctor <CPUContext, AvgPoolGrad<double >, double >;
478480template class Pool2dGradFunctor <CPUContext, LPPoolGrad<double >, double >;
479481
482+ template class Pool2dFunctor <phi::CPUContext,
483+ MaxPool<dtype::float16>,
484+ dtype::float16>;
485+ template class Pool2dFunctor <phi::CPUContext,
486+ AvgPool<dtype::float16>,
487+ dtype::float16>;
488+ template class Pool2dFunctor <phi::CPUContext,
489+ LPPool<dtype::float16>,
490+ dtype::float16>;
491+ template class Pool2dGradFunctor <phi::CPUContext,
492+ MaxPoolGrad<dtype::float16>,
493+ dtype::float16>;
494+ template class Pool2dGradFunctor <phi::CPUContext,
495+ AvgPoolGrad<dtype::float16>,
496+ dtype::float16>;
497+ template class Pool2dGradFunctor <phi::CPUContext,
498+ LPPoolGrad<dtype::float16>,
499+ dtype::float16>;
480500/*
481501 * Tensors are in NCDHW or NDHWC format.
482502 * Ksize, strides, paddings are three elements. These three elements represent
@@ -1057,6 +1077,7 @@ class MaxPool3dGradFunctor<CPUContext, T> {
10571077};
10581078template class MaxPool3dGradFunctor <CPUContext, float >;
10591079template class MaxPool3dGradFunctor <CPUContext, double >;
1080+ template class MaxPool3dGradFunctor <CPUContext, dtype::float16>;
10601081
10611082template class Pool3dFunctor <CPUContext, MaxPool<float >, float >;
10621083template class Pool3dFunctor <CPUContext, AvgPool<float >, float >;
@@ -1067,6 +1088,21 @@ template class Pool3dFunctor<CPUContext, AvgPool<double>, double>;
10671088template class Pool3dGradFunctor <CPUContext, MaxPoolGrad<double >, double >;
10681089template class Pool3dGradFunctor <CPUContext, AvgPoolGrad<double >, double >;
10691090
1091+ template class Pool3dFunctor <phi::CPUContext,
1092+ MaxPool<dtype::float16>,
1093+ dtype::float16>;
1094+ template class Pool3dFunctor <phi::CPUContext,
1095+ AvgPool<dtype::float16>,
1096+ dtype::float16>;
1097+ template class Pool3dFunctor <phi::CPUContext,
1098+ LPPool<dtype::float16>,
1099+ dtype::float16>;
1100+ template class Pool3dGradFunctor <phi::CPUContext,
1101+ MaxPoolGrad<dtype::float16>,
1102+ dtype::float16>;
1103+ template class Pool3dGradFunctor <phi::CPUContext,
1104+ AvgPoolGrad<dtype::float16>,
1105+ dtype::float16>;
10701106/*
10711107 * All tensors are in NCHW format.
10721108 * Ksize, strides, paddings are two elements. These two elements represent
0 commit comments