add complex support for optest (#53356)
* add complex support for optest

* add complex grad test

* append one

* move some debug info

* move some debug info

* move some debug info

* move some debug info

* add more complex test

* Fix naming ambiguity

* Revert "add more complex test"

This reverts commit dbcb051.

* change backward gradient, add TODO
GGBond8488 authored May 8, 2023
1 parent 70180df commit e522ceb
Showing 9 changed files with 157 additions and 12 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/pybind/tensor_py.h
@@ -35,6 +35,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/pybind/complex.h"
#include "paddle/phi/kernels/funcs/strided_memcpy.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/cuda_device_guard.h"
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/reduce_mean_grad_kernel.cc
@@ -47,4 +47,6 @@ PD_REGISTER_KERNEL(mean_grad,
phi::ReduceMeanGradKernel,
bool,
float,
-double) {}
+double,
+phi::dtype::complex<float>,
+phi::dtype::complex<double>) {}
11 changes: 9 additions & 2 deletions paddle/phi/kernels/cpu/reduce_mean_kernel.cc
@@ -36,5 +36,12 @@ void MeanRawKernel(const Context& dev_ctx,

} // namespace phi

-PD_REGISTER_KERNEL(
-    mean_raw, CPU, ALL_LAYOUT, phi::MeanRawKernel, float, double, bool) {}
+PD_REGISTER_KERNEL(mean_raw,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::MeanRawKernel,
+                   float,
+                   double,
+                   bool,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu
@@ -67,4 +67,6 @@ PD_REGISTER_KERNEL(mean_grad,
float,
double,
phi::dtype::float16,
-phi::dtype::bfloat16) {}
+phi::dtype::bfloat16,
+phi::dtype::complex<float>,
+phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/kps/reduce_mean_kernel.cu
@@ -48,5 +48,7 @@ PD_REGISTER_KERNEL(mean_raw,
phi::dtype::bfloat16,
float16,
int,
-int64_t) {}
+int64_t,
+phi::dtype::complex<float>,
+phi::dtype::complex<double>) {}
#endif
15 changes: 12 additions & 3 deletions paddle/phi/kernels/reduce_mean_kernel.cc
@@ -31,8 +31,15 @@ void MeanKernel(const Context& dev_ctx,

} // namespace phi

-PD_REGISTER_KERNEL(
-    mean, CPU, ALL_LAYOUT, phi::MeanKernel, float, double, bool) {}
+PD_REGISTER_KERNEL(mean,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::MeanKernel,
+                   float,
+                   double,
+                   bool,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(mean,
@@ -45,7 +52,9 @@ PD_REGISTER_KERNEL(mean,
int,
int64_t,
phi::dtype::float16,
-phi::dtype::bfloat16) {}
+phi::dtype::bfloat16,
+phi::dtype::complex<float>,
+phi::dtype::complex<double>) {}
#endif

#if defined(PADDLE_WITH_XPU_KP) && !defined(PADDLE_WITH_XPU)
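All five kernel-registration edits above follow the same pattern: PD_REGISTER_KERNEL instantiates the kernel once for each dtype in its argument list, so appending phi::dtype::complex<float> and phi::dtype::complex<double> is what makes mean and mean_grad dispatchable on complex tensors. A minimal smoke test of the result (a sketch, assuming a Paddle build that includes this commit):

    import numpy as np
    import paddle

    # Before this commit the mean/mean_grad kernels had no complex
    # instantiation, so this would fail with an unsupported-dtype error.
    x = paddle.to_tensor(
        np.array([1 + 2j, 3 + 4j, 5 - 6j], dtype=np.complex128),
        stop_gradient=False,
    )
    loss = paddle.mean(x)  # forward: the newly registered complex mean kernel
    loss.backward()        # backward: the complex mean_grad kernel
    print(x.grad)          # expected ~(1/3 + 0j) per element under a 1 + 0j adjoint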
67 changes: 63 additions & 4 deletions python/paddle/fluid/tests/unittests/eager_op_test.py
@@ -200,6 +200,10 @@ def __get_elem__(tensor, i):
    return tensor._get_float_element(i)
elif tensor_to_check_dtype == np.float64:
    return tensor._get_double_element(i)
+elif tensor_to_check_dtype == np.complex64:
+    return tensor._get_complex64_element(i)
+elif tensor_to_check_dtype == np.complex128:
+    return tensor._get_complex128_element(i)
else:
    raise TypeError(
        "Unsupported test data type %s." % tensor_to_check_dtype
@@ -224,6 +228,10 @@ def __set_elem__(tensor, i, e):
    tensor._set_float_element(i, e)
elif tensor_to_check_dtype == np.float64:
    tensor._set_double_element(i, e)
+elif tensor_to_check_dtype == np.complex64:
+    return tensor._set_complex64_element(i, e)
+elif tensor_to_check_dtype == np.complex128:
+    return tensor._set_complex128_element(i, e)
else:
    raise TypeError(
        "Unsupported test data type %s." % tensor_to_check_dtype
@@ -242,15 +250,58 @@ def __set_elem__(tensor, i, e):
__set_elem__(tensor_to_check, i, x_pos)
y_pos = get_output()

+if tensor_to_check_dtype in [np.complex64, np.complex128]:
+    if in_place:
+        set_input(scope, op, inputs, place)
+    x_pos_j = origin + 1j * delta
+    __set_elem__(tensor_to_check, i, x_pos_j)
+    y_pos_j = get_output()
+
if in_place:
    set_input(scope, op, inputs, place)

x_neg = origin - delta
__set_elem__(tensor_to_check, i, x_neg)
y_neg = get_output()

+if tensor_to_check_dtype in [np.complex64, np.complex128]:
+    if in_place:
+        set_input(scope, op, inputs, place)
+
+    x_neg_j = origin - 1j * delta
+    __set_elem__(tensor_to_check, i, x_neg_j)
+    y_neg_j = get_output()
+
+__set_elem__(tensor_to_check, i, origin)
+
+if tensor_to_check_dtype in [np.complex64, np.complex128]:
+    # Always assume a real-valued output adjoint: this function takes no
+    # dl/di input (though it should), so the imaginary part is zero there.
+    # TODO: this is a trick to stay consistent with the existing OpTest;
+    # it needs to support user-supplied output gradients.
+    f_adjoint = np.array(1 + 0j)
+    df_over_dr = (y_pos - y_neg) / delta / 2
+    df_over_di = (y_pos_j - y_neg_j) / delta / 2
+
+    dl_over_du, dl_over_dv = f_adjoint.real, f_adjoint.imag
+    du_over_dr, dv_over_dr = df_over_dr.real, df_over_dr.imag
+    du_over_di, dv_over_di = df_over_di.real, df_over_di.imag
+
+    dl_over_dr = np.sum(dl_over_du * du_over_dr + dl_over_dv * dv_over_dr)
+    dl_over_di = np.sum(dl_over_du * du_over_di + dl_over_dv * dv_over_di)
+    gradient_flat[i] = dl_over_dr + 1j * dl_over_di
+else:
+    df_over_dr = y_pos - y_neg
+    gradient_flat[i] = df_over_dr / delta / 2

-__set_elem__(tensor_to_check, i, origin)
-gradient_flat[i] = (y_pos - y_neg) / delta / 2

return gradient_flat.reshape(tensor_to_check.shape())

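In effect, the hunk above implements a four-point central difference for complex inputs: each element is perturbed along the real axis (y_pos, y_neg) and along the imaginary axis (y_pos_j, y_neg_j), and the two directional derivatives are folded with the fixed 1 + 0j output adjoint into gradient = dl/dr + 1j * dl/di. A self-contained NumPy sketch of the same scheme (the function and names here are illustrative, not Paddle API):

    import numpy as np

    def numeric_complex_grad(f, x, delta=1e-5):
        """Four-point finite difference mirroring the hunk above: perturb each
        element along the real and imaginary axes, then combine the directional
        derivatives with a fixed output adjoint of 1 + 0j."""
        f_adjoint = np.array(1 + 0j)
        grad = np.zeros_like(x)
        xf, gf = x.ravel(), grad.ravel()  # writable views into x and grad
        for i in range(xf.size):
            origin = xf[i]
            xf[i] = origin + delta       # +real perturbation
            y_pos = f(x)
            xf[i] = origin + 1j * delta  # +imaginary perturbation
            y_pos_j = f(x)
            xf[i] = origin - delta       # -real perturbation
            y_neg = f(x)
            xf[i] = origin - 1j * delta  # -imaginary perturbation
            y_neg_j = f(x)
            xf[i] = origin               # restore the original element
            df_dr = (y_pos - y_neg) / delta / 2
            df_di = (y_pos_j - y_neg_j) / delta / 2
            dl_dr = np.sum(f_adjoint.real * df_dr.real + f_adjoint.imag * df_dr.imag)
            dl_di = np.sum(f_adjoint.real * df_di.real + f_adjoint.imag * df_di.imag)
            gf[i] = dl_dr + 1j * dl_di
        return grad

    x = np.array([1 + 2j, 3 - 1j], dtype=np.complex128)
    print(numeric_complex_grad(np.mean, x))  # ~[0.5+0.j 0.5+0.j]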
@@ -375,15 +426,24 @@ def is_rocm_op_test():
def is_custom_device_op_test():
    return hasattr(cls, "use_custom_device") and cls.use_custom_device

+def is_complex_test():
+    return (
+        hasattr(cls, "test_complex")
+        and cls.test_complex
+        or (cls.dtype in [np.complex64, np.complex128])
+    )
+
if not hasattr(cls, "op_type"):
    raise AssertionError(
        "This test do not have op_type in class attrs, "
        "please set self.__class__.op_type=the_real_op_type manually."
    )

# case in NO_FP64_CHECK_GRAD_CASES and op in NO_FP64_CHECK_GRAD_OP_LIST should be fixed
-if not hasattr(cls, "no_need_check_grad") and not is_empty_grad_op(
-    cls.op_type
+if (
+    not hasattr(cls, "no_need_check_grad")
+    and not is_empty_grad_op(cls.op_type)
+    and not is_complex_test()
):
    if cls.dtype is None or (
        cls.dtype == np.float16
@@ -2496,7 +2556,6 @@ def check_grad_with_place(
max_relative_error = (
    0.001 if max_relative_error < 0.001 else max_relative_error
)
-
self._assert_is_close(
    numeric_grads,
    analytic_grads,
49 changes: 49 additions & 0 deletions python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
@@ -93,6 +93,55 @@ def if_enable_cinn(self):
    pass


+class TestComplexElementwiseMulOpWithCheckGrad(ElementwiseMulOp):
+    def setUp(self):
+        self.op_type = "elementwise_mul"
+        self.python_api = paddle.multiply
+        self.public_python_api = paddle.multiply
+        self.dtype = np.complex128
+        self.axis = -1
+        self.init_dtype()
+        self.init_input_output()
+        self.init_kernel_type()
+        self.init_axis()
+        self.if_enable_cinn()
+
+        self.inputs = {
+            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
+            'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
+        }
+        self.outputs = {'Out': self.out}
+        self.attrs = {'axis': self.axis}
+
+    def init_input_output(self):
+        self.x = np.array([3 + 4j, 1 + 2j]).astype(self.dtype)
+        self.y = np.array([3 + 4j, 5 + 6j]).astype(self.dtype)
+        self.out = np.multiply(self.x, self.y)
+
+    def if_enable_cinn(self):
+        self.enable_cinn = False
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+        )
+
+    def test_check_grad_ingore_x(self):
+        self.check_grad(
+            ['Y'],
+            'Out',
+            no_grad_set=set("X"),
+        )
+
+    def test_check_grad_ingore_y(self):
+        self.check_grad(
+            ['X'],
+            'Out',
+            no_grad_set=set('Y'),
+        )

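For reference, the analytic gradients that check_grad compares against can be cross-checked in plain NumPy, assuming the conjugate (Wirtinger) convention commonly used by frameworks for Out = X * Y, i.e. dL/dX = grad_out * conj(Y); the four-point numeric scheme above reproduces exactly this convention:

    import numpy as np

    x = np.array([3 + 4j, 1 + 2j])
    y = np.array([3 + 4j, 5 + 6j])
    grad_out = np.ones_like(x)  # the 1 + 0j output adjoint used by the op test

    dx = grad_out * np.conj(y)  # expected dL/dX: [3.-4.j 5.-6.j]
    dy = grad_out * np.conj(x)  # expected dL/dY: [3.-4.j 1.-2.j]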

class TestElementwiseMulOp_ZeroDim1(ElementwiseMulOp):
    def init_input_output(self):
        self.x = np.random.uniform(0.1, 1, []).astype(self.dtype)
14 changes: 14 additions & 0 deletions python/paddle/fluid/tests/unittests/test_reduce_op.py
@@ -57,6 +57,20 @@ def test_check_grad(self):
    self.check_grad(['X'], 'Out', check_prim=True)


+class TestComplexSumOP(TestSumOp):
+    def init_dtype(self):
+        self.dtype = np.complex128
+
+    def init_input(self):
+        self.x = np.random.random((3, 4)).astype(self.dtype)
+
+    def init_attrs(self):
+        self.attrs = {'dim': [0]}
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', check_prim=False)

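The gradient exercised here is the simplest complex case: for Out = sum(X, dim=0) every input element's partial derivative is 1, so with the 1 + 0j adjoint the expected gradient is all ones. Stated in NumPy (illustrative only):

    import numpy as np

    x = np.random.random((3, 4)).astype(np.complex128)
    out = x.sum(axis=0)
    # d(out_j)/d(x_ij) = 1 for every i, so a 1 + 0j adjoint per output
    # element broadcasts back to an all-ones gradient:
    expected_grad = np.ones_like(x)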

class TestSumOp_ZeroDim(TestSumOp):
    def init_attrs(self):
        self.attrs = {'dim': [], 'reduce_all': True}
