diff --git a/paddle/phi/kernels/funcs/maxouting.cu b/paddle/phi/kernels/funcs/maxouting.cu index 89450dbd5c60bd..146bb1aca4c1bd 100644 --- a/paddle/phi/kernels/funcs/maxouting.cu +++ b/paddle/phi/kernels/funcs/maxouting.cu @@ -175,9 +175,11 @@ void MaxOutGradFunctor::operator()( } template class MaxOutGradFunctor; +template class MaxOutGradFunctor; template class MaxOutGradFunctor; template class MaxOutFunctor; +template class MaxOutFunctor; template class MaxOutFunctor; } // namespace funcs diff --git a/paddle/phi/kernels/gpu/maxout_grad_kernel.cu b/paddle/phi/kernels/gpu/maxout_grad_kernel.cu index a405f38523a75a..7d59436019c715 100644 --- a/paddle/phi/kernels/gpu/maxout_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/maxout_grad_kernel.cu @@ -15,5 +15,10 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/maxout_grad_kernel_impl.h" -PD_REGISTER_KERNEL( - maxout_grad, GPU, ALL_LAYOUT, phi::MaxOutGradKernel, float, double) {} +PD_REGISTER_KERNEL(maxout_grad, + GPU, + ALL_LAYOUT, + phi::MaxOutGradKernel, + float, + phi::dtype::float16, + double) {} diff --git a/paddle/phi/kernels/gpu/maxout_kernel.cu b/paddle/phi/kernels/gpu/maxout_kernel.cu index e5407a4925c840..4871046450264c 100644 --- a/paddle/phi/kernels/gpu/maxout_kernel.cu +++ b/paddle/phi/kernels/gpu/maxout_kernel.cu @@ -15,4 +15,10 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/maxout_kernel_impl.h" -PD_REGISTER_KERNEL(maxout, GPU, ALL_LAYOUT, phi::MaxOutKernel, float, double) {} +PD_REGISTER_KERNEL(maxout, + GPU, + ALL_LAYOUT, + phi::MaxOutKernel, + float, + phi::dtype::float16, + double) {} diff --git a/python/paddle/fluid/tests/unittests/test_maxout_op.py b/python/paddle/fluid/tests/unittests/test_maxout_op.py index 678dd55fe92c1d..b6d339c3aab283 100644 --- a/python/paddle/fluid/tests/unittests/test_maxout_op.py +++ b/python/paddle/fluid/tests/unittests/test_maxout_op.py @@ -136,5 +136,40 @@ def test_errors(self): self.assertRaises(ValueError, F.maxout, x_float32, 2, 2) +class TestMaxOutOpFP16(TestMaxOutOp): + def set_attrs(self): + self.dtype = 'float16' + + +class TestMaxoutFP16Case1(TestMaxOutOpFP16): + def set_attrs(self): + self.axis = -1 + + +class TestMaxoutFP16Case2(TestMaxOutOpFP16): + def set_attrs(self): + self.axis = 3 + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestMaxoutStaticAPIFP16(unittest.TestCase): + def setUp(self): + self.x_np = np.random.uniform(-1, 1, [2, 6, 5, 4]).astype(np.float16) + self.groups = 2 + self.axis = 1 + self.place = paddle.CUDAPlace(0) + + def test_static_api(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype) + out = F.maxout(x, self.groups, self.axis) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out]) + out_ref = maxout_forward_naive(self.x_np, self.groups, self.axis) + np.testing.assert_allclose(out_ref, res[0], rtol=1e-05) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 88a01799165684..8984800b9062ed 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -784,7 +784,7 @@ def maxout(x, groups, axis=1, name=None): Parameters: x (Tensor): The input is 4-D Tensor with shape [N, C, H, W] or [N, H, W, C], the data type - of input is float32 or float64. + of input is float16, float32 or float64. groups (int): The groups number of maxout. `groups` specifies the index of channel dimension where maxout will be performed. This must be a factor of number of features. @@ -819,7 +819,9 @@ def maxout(x, groups, axis=1, name=None): if in_dygraph_mode(): return _C_ops.maxout(x, groups, axis) else: - check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'maxout') + check_variable_and_dtype( + x, 'x', ['float16', 'float32', 'float64'], 'maxout' + ) if axis not in [1, -1, 3]: raise ValueError( "Attr(axis) should be 1 when data format is NCHW, -1 or 3 when data format is NHWC. Received "