Skip to content

Commit 0e7f98f

Browse files
author
Diptorup Deb
authored
Merge pull request #1033 from IntelPython/feature/complex_number
Enabled complex type for kernel.
2 parents 20d4340 + 91c4932 commit 0e7f98f

File tree

3 files changed

+79
-4
lines changed

3 files changed

+79
-4
lines changed

numba_dpex/core/kernel_interface/arg_pack_unpacker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,9 @@ def _unpack_argument(self, ty, val, access_specifier):
193193
elif ty == types.boolean:
194194
return ctypes.c_uint8(int(val))
195195
elif ty == types.complex64:
196-
raise UnsupportedKernelArgumentError(ty, val, self._pyfunc_name)
196+
return [ctypes.c_float(val.real), ctypes.c_float(val.imag)]
197197
elif ty == types.complex128:
198-
raise UnsupportedKernelArgumentError(ty, val, self._pyfunc_name)
198+
return [ctypes.c_double(val.real), ctypes.c_double(val.imag)]
199199
else:
200200
raise UnsupportedKernelArgumentError(ty, val, self._pyfunc_name)
201201

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import dpnp
6+
import numpy
7+
import pytest
8+
9+
import numba_dpex as dpex
10+
11+
N = 1024
12+
13+
14+
@dpex.kernel
15+
def kernel_scalar(a, b, c):
16+
i = dpex.get_global_id(0)
17+
b[i] = a[i] * c
18+
19+
20+
@dpex.kernel
21+
def kernel_array(a, b, c):
22+
i = dpex.get_global_id(0)
23+
b[i] = a[i] * c[i]
24+
25+
26+
list_of_dtypes = [
27+
dpnp.complex64,
28+
dpnp.complex128,
29+
]
30+
31+
list_of_usm_types = ["shared", "device", "host"]
32+
33+
34+
@pytest.fixture(params=list_of_dtypes)
35+
def input_arrays(request):
36+
a = dpnp.ones(N, dtype=request.param)
37+
c = dpnp.zeros(N, dtype=request.param)
38+
b = dpnp.empty_like(a)
39+
return a, b, c
40+
41+
42+
def test_numeric_kernel_arg_complex_scalar(input_arrays):
43+
"""Tests passing complex type scalar and dpnp arrays to a kernel function.
44+
45+
Args:
46+
input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel.
47+
"""
48+
s = 2 + 1j
49+
a, b, _ = input_arrays
50+
51+
kernel_scalar[dpex.Range(N)](a, b, s)
52+
53+
nb = dpnp.asnumpy(b)
54+
nexpected = numpy.full_like(nb, fill_value=2 + 1j)
55+
56+
assert numpy.allclose(nb, nexpected)
57+
58+
59+
def test_numeric_kernel_arg_complex_array(input_arrays):
60+
"""Tests passing complex type dpnp arrays to a kernel function.
61+
62+
Args:
63+
input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel.
64+
"""
65+
66+
a, b, c = input_arrays
67+
68+
kernel_array[dpex.Range(N)](a, b, c)
69+
70+
nb = dpnp.asnumpy(b)
71+
nexpected = numpy.full_like(nb, fill_value=0 + 0j)
72+
73+
assert numpy.allclose(nb, nexpected)

numba_dpex/tests/kernel_tests/test_scalar_arg_types.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ def kernel_with_bool_arg(a, b, test):
3131
dpnp.int64,
3232
dpnp.float32,
3333
dpnp.float64,
34+
dpnp.complex64,
35+
dpnp.complex128,
3436
]
3537

3638
list_of_usm_types = ["shared", "device", "host"]
@@ -43,8 +45,8 @@ def input_arrays(request):
4345
return a, b
4446

4547

46-
def test_numeric_kernel_arg_types(input_arrays):
47-
"""Tests passing float and int type scalar arguments to a kernel function.
48+
def test_numeric_kernel_arg_types1(input_arrays):
49+
"""Tests passing float, int and complex type dpnp arrays to a kernel function.
4850
4951
Args:
5052
input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel.

0 commit comments

Comments
 (0)