Skip to content

Commit

Permalink
replace complex template in cast op
Browse files Browse the repository at this point in the history
  • Loading branch information
MingMingShangTian committed May 20, 2021
1 parent 574ff76 commit 026a782
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 14 deletions.
18 changes: 8 additions & 10 deletions paddle/fluid/operators/cast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,11 @@ REGISTER_OPERATOR(cast, ops::CastOp,
ops::CastOpGradMaker<paddle::framework::OpDesc>,
ops::CastOpGradMaker<paddle::imperative::OpBase>,
ops::CastOpProtoMaker);
REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>,
ops::CastOpKernel<CPU, double>,
ops::CastOpKernel<CPU, int>,
ops::CastOpKernel<CPU, int64_t>,
ops::CastOpKernel<CPU, bool>,
ops::CastOpKernel<CPU, uint8_t>,
ops::CastOpKernel<CPU, paddle::platform::float16>,
ops::CastOpKernel<CPU, paddle::platform::bfloat16>,
ops::CastOpKernel<CPU, paddle::platform::complex64>,
ops::CastOpKernel<CPU, paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
cast, ops::CastOpKernel<CPU, float>, ops::CastOpKernel<CPU, double>,
ops::CastOpKernel<CPU, int>, ops::CastOpKernel<CPU, int64_t>,
ops::CastOpKernel<CPU, bool>, ops::CastOpKernel<CPU, uint8_t>,
ops::CastOpKernel<CPU, paddle::platform::float16>,
ops::CastOpKernel<CPU, paddle::platform::bfloat16>,
ops::CastOpKernel<CPU, paddle::platform::complex<float>>,
ops::CastOpKernel<CPU, paddle::platform::complex<double>>);
8 changes: 4 additions & 4 deletions paddle/fluid/operators/cast_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ REGISTER_OP_CUDA_KERNEL(
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
#else
REGISTER_OP_CUDA_KERNEL(
cast, ops::CastOpKernel<paddle::platform::CUDADeviceContext, float>,
Expand All @@ -122,7 +122,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::bfloat16>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
#endif
73 changes: 73 additions & 0 deletions python/paddle/fluid/tests/unittests/test_complex_cast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function, division

import unittest
import numpy as np

import paddle


class TestComplexCastOp(unittest.TestCase):
    """Tests for the cast op on complex dtypes.

    Covers three directions: complex -> real/integer/bool (keeps the real
    part), real -> complex (imaginary part is zero), and
    complex64 <-> complex128 round-tripping.
    """

    def test_complex_to_real(self):
        """Casting a complex tensor to a real dtype drops the imaginary part."""
        r = np.random.random(size=[10, 10]) * 10
        i = np.random.random(size=[10, 10])

        c_t = paddle.to_tensor(r + i * 1J, dtype='complex64')

        # The result dtype must match the requested cast target.
        self.assertEqual(c_t.cast('int64').dtype, paddle.int64)
        self.assertEqual(c_t.cast('int32').dtype, paddle.int32)
        self.assertEqual(c_t.cast('float32').dtype, paddle.float32)
        self.assertEqual(c_t.cast('float64').dtype, paddle.float64)
        self.assertEqual(c_t.cast('bool').dtype, paddle.bool)

        # Values must agree with numpy's astype on the real part only.
        self.assertTrue(
            np.allclose(c_t.cast('int64').numpy(), r.astype('int64')))
        self.assertTrue(
            np.allclose(c_t.cast('int32').numpy(), r.astype('int32')))
        self.assertTrue(
            np.allclose(c_t.cast('float32').numpy(), r.astype('float32')))
        self.assertTrue(
            np.allclose(c_t.cast('float64').numpy(), r.astype('float64')))
        self.assertTrue(np.allclose(c_t.cast('bool').numpy(), r.astype('bool')))

    def test_real_to_complex(self):
        """Casting a real tensor to complex keeps it as the real component."""
        r = np.random.random(size=[10, 10]) * 10
        r_t = paddle.to_tensor(r)

        self.assertEqual(r_t.cast('complex64').dtype, paddle.complex64)
        self.assertEqual(r_t.cast('complex128').dtype, paddle.complex128)

        self.assertTrue(np.allclose(r_t.cast('complex64').real().numpy(), r))
        self.assertTrue(np.allclose(r_t.cast('complex128').real().numpy(), r))

    def test_complex64_complex128(self):
        """complex64 <-> complex128 casts round-trip within float32 precision."""
        r = np.random.random(size=[10, 10])
        i = np.random.random(size=[10, 10])

        c = r + i * 1J
        c_64 = paddle.to_tensor(c, dtype='complex64')
        c_128 = paddle.to_tensor(c, dtype='complex128')

        # BUG FIX: the original used assertTrue(a, b), where the second
        # argument is only a failure *message*, so these dtype checks could
        # never fail; they must be equality assertions.  The original also
        # cast c_128 to 'complex128' (a no-op) while expecting complex64 —
        # the intended target of the second cast is 'complex64'.
        self.assertEqual(c_64.cast('complex128').dtype, paddle.complex128)
        self.assertEqual(c_128.cast('complex64').dtype, paddle.complex64)
        # Up-cast then compare against the natively-built tensor of the
        # target precision (and vice versa); allclose absorbs the float32
        # rounding introduced by the complex64 representation.
        self.assertTrue(
            np.allclose(c_64.cast('complex128').numpy(), c_128.numpy()))
        self.assertTrue(
            np.allclose(c_128.cast('complex64').numpy(), c_64.numpy()))


# Allow running this test file directly: `python test_complex_cast.py`.
if __name__ == '__main__':
    unittest.main()

0 comments on commit 026a782

Please sign in to comment.