Skip to content

Commit 8a5e849

Browse files
committed
Enabled complex number in dpjit.
1 parent 0e7f98f commit 8a5e849

File tree

3 files changed

+175
-4
lines changed

3 files changed

+175
-4
lines changed

numba_dpex/core/passes/parfor_lowering_pass.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ def _submit_parfor_kernel(
7070
if isinstance(arg_type, DpnpNdArray):
7171
# FIXME: Remove magic constants
7272
num_flattened_args += 5 + (2 * arg_type.ndim)
73+
elif arg_type == types.complex64 or arg_type == types.complex128:
74+
num_flattened_args += 2
7375
else:
7476
num_flattened_args += 1
7577

@@ -97,10 +99,33 @@ def _submit_parfor_kernel(
9799
# FIXME: Get rid of magic constants
98100
kernel_arg_num += 5 + (2 * argtype.ndim)
99101
else:
100-
ir_builder.build_arg(
101-
llvm_val, argtype, args_list, args_ty_list, kernel_arg_num
102-
)
103-
kernel_arg_num += 1
102+
if argtype == types.complex64:
103+
ir_builder.build_complex_arg(
104+
llvm_val,
105+
types.float32,
106+
args_list,
107+
args_ty_list,
108+
kernel_arg_num,
109+
)
110+
kernel_arg_num += 2
111+
elif argtype == types.complex128:
112+
ir_builder.build_complex_arg(
113+
llvm_val,
114+
types.float64,
115+
args_list,
116+
args_ty_list,
117+
kernel_arg_num,
118+
)
119+
kernel_arg_num += 2
120+
else:
121+
ir_builder.build_arg(
122+
llvm_val,
123+
argtype,
124+
args_list,
125+
args_ty_list,
126+
kernel_arg_num,
127+
)
128+
kernel_arg_num += 1
104129

105130
# Create a global range over which to submit the kernel based on the
106131
# loop_ranges of the parfor

numba_dpex/core/utils/kernel_launcher.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,29 @@ def build_arg(self, val, ty, arg_list, args_ty_list, arg_num):
127127
numba_type_to_dpctl_typenum(self.context, ty), kernel_arg_ty_dst
128128
)
129129

130+
def build_complex_arg(self, val, ty, arg_list, args_ty_list, arg_num):
131+
"""Creates a list of LLVM Values for an unpacked complex kernel
132+
argument.
133+
"""
134+
self._build_array_attr_arg(
135+
array_val=val,
136+
array_attr_pos=0,
137+
array_attr_ty=ty,
138+
arg_list=arg_list,
139+
args_ty_list=args_ty_list,
140+
arg_num=arg_num,
141+
)
142+
arg_num += 1
143+
self._build_array_attr_arg(
144+
array_val=val,
145+
array_attr_pos=1,
146+
array_attr_ty=ty,
147+
arg_list=arg_list,
148+
args_ty_list=args_ty_list,
149+
arg_num=arg_num,
150+
)
151+
arg_num += 1
152+
130153
def build_array_arg(
131154
self, array_val, array_rank, arg_list, args_ty_list, arg_num
132155
):
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import dpnp
6+
import numba as nb
7+
import numpy
8+
import pytest
9+
10+
import numba_dpex as dpex
11+
12+
N = 1024
13+
14+
15+
@dpex.dpjit
16+
def prange_arg(a, b, c):
17+
for i in nb.prange(a.shape[0]):
18+
b[i] = a[i] * c
19+
20+
21+
@dpex.dpjit
22+
def prange_with_bool_arg(a, b, test):
23+
for i in nb.prange(a.shape[0]):
24+
if test:
25+
b[i] = a[i] + a[i]
26+
else:
27+
b[i] = a[i] - a[i]
28+
29+
30+
@dpex.dpjit
31+
def prange_array(a, b, c):
32+
for i in nb.prange(a.shape[0]):
33+
b[i] = a[i] * c[i]
34+
35+
36+
list_of_dtypes = [
37+
dpnp.complex64,
38+
dpnp.complex128,
39+
]
40+
41+
list_of_usm_types = ["shared", "device", "host"]
42+
43+
44+
@pytest.fixture(params=list_of_dtypes)
45+
def input_arrays(request):
46+
a = dpnp.ones(N, dtype=request.param)
47+
c = dpnp.zeros(N, dtype=request.param)
48+
b = dpnp.empty_like(a)
49+
return a, b, c
50+
51+
52+
def test_dpjit_scalar_arg_types(input_arrays):
53+
"""Tests passing float and complex type dpnp arrays to a dpjit prange function.
54+
55+
Args:
56+
input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel.
57+
"""
58+
s = 2
59+
a, b, _ = input_arrays
60+
61+
prange_arg(a, b, s)
62+
63+
nb = dpnp.asnumpy(b)
64+
nexpected = numpy.full_like(nb, fill_value=2)
65+
66+
assert numpy.allclose(nb, nexpected)
67+
68+
69+
def test_bool_kernel_arg_type(input_arrays):
70+
"""Tests passing boolean arguments to a dpjit prange function.
71+
72+
Args:
73+
input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel.
74+
"""
75+
a, b, _ = input_arrays
76+
77+
prange_with_bool_arg(a, b, True)
78+
79+
nb = dpnp.asnumpy(b)
80+
nexpected_true = numpy.full_like(nb, fill_value=2)
81+
82+
assert numpy.allclose(nb, nexpected_true)
83+
84+
prange_with_bool_arg(a, b, False)
85+
86+
nb = dpnp.asnumpy(b)
87+
nexpected_false = numpy.zeros_like(nb)
88+
89+
assert numpy.allclose(nb, nexpected_false)
90+
91+
92+
def test_dpjit_arg_complex_scalar(input_arrays):
93+
"""Tests passing complex type scalar and dpnp arrays to a dpjit prange function.
94+
95+
Args:
96+
input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel.
97+
"""
98+
s = 2 + 1j
99+
a, b, _ = input_arrays
100+
101+
prange_arg(a, b, s)
102+
103+
nb = dpnp.asnumpy(b)
104+
nexpected = numpy.full_like(nb, fill_value=2 + 1j)
105+
106+
assert numpy.allclose(nb, nexpected)
107+
108+
109+
def test_dpjit_arg_complex_array(input_arrays):
110+
"""Tests passing complex type dpnp arrays to a dpjit prange function.
111+
112+
Args:
113+
input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel.
114+
"""
115+
116+
a, b, c = input_arrays
117+
118+
prange_array(a, b, c)
119+
120+
nb = dpnp.asnumpy(b)
121+
nexpected = numpy.full_like(nb, fill_value=0 + 0j)
122+
123+
assert numpy.allclose(nb, nexpected)

0 commit comments

Comments
 (0)