Skip to content

Commit 88163d9

Browse files
author
Diptorup Deb
authored
Merge pull request #1035 from IntelPython/feature/complex_number_dpjit
Enabled complex number in dpjit.
2 parents 0e7f98f + 28d7234 commit 88163d9

File tree

3 files changed

+143
-4
lines changed

3 files changed

+143
-4
lines changed

numba_dpex/core/passes/parfor_lowering_pass.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ def _submit_parfor_kernel(
7070
if isinstance(arg_type, DpnpNdArray):
7171
# FIXME: Remove magic constants
7272
num_flattened_args += 5 + (2 * arg_type.ndim)
73+
elif arg_type == types.complex64 or arg_type == types.complex128:
74+
num_flattened_args += 2
7375
else:
7476
num_flattened_args += 1
7577

@@ -97,10 +99,33 @@ def _submit_parfor_kernel(
9799
# FIXME: Get rid of magic constants
98100
kernel_arg_num += 5 + (2 * argtype.ndim)
99101
else:
100-
ir_builder.build_arg(
101-
llvm_val, argtype, args_list, args_ty_list, kernel_arg_num
102-
)
103-
kernel_arg_num += 1
102+
if argtype == types.complex64:
103+
ir_builder.build_complex_arg(
104+
llvm_val,
105+
types.float32,
106+
args_list,
107+
args_ty_list,
108+
kernel_arg_num,
109+
)
110+
kernel_arg_num += 2
111+
elif argtype == types.complex128:
112+
ir_builder.build_complex_arg(
113+
llvm_val,
114+
types.float64,
115+
args_list,
116+
args_ty_list,
117+
kernel_arg_num,
118+
)
119+
kernel_arg_num += 2
120+
else:
121+
ir_builder.build_arg(
122+
llvm_val,
123+
argtype,
124+
args_list,
125+
args_ty_list,
126+
kernel_arg_num,
127+
)
128+
kernel_arg_num += 1
104129

105130
# Create a global range over which to submit the kernel based on the
106131
# loop_ranges of the parfor

numba_dpex/core/utils/kernel_launcher.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,29 @@ def build_arg(self, val, ty, arg_list, args_ty_list, arg_num):
127127
numba_type_to_dpctl_typenum(self.context, ty), kernel_arg_ty_dst
128128
)
129129

130+
def build_complex_arg(self, val, ty, arg_list, args_ty_list, arg_num):
131+
"""Creates a list of LLVM Values for an unpacked complex kernel
132+
argument.
133+
"""
134+
self._build_array_attr_arg(
135+
array_val=val,
136+
array_attr_pos=0,
137+
array_attr_ty=ty,
138+
arg_list=arg_list,
139+
args_ty_list=args_ty_list,
140+
arg_num=arg_num,
141+
)
142+
arg_num += 1
143+
self._build_array_attr_arg(
144+
array_val=val,
145+
array_attr_pos=1,
146+
array_attr_ty=ty,
147+
arg_list=arg_list,
148+
args_ty_list=args_ty_list,
149+
arg_num=arg_num,
150+
)
151+
arg_num += 1
152+
130153
def build_array_arg(
131154
self, array_val, array_rank, arg_list, args_ty_list, arg_num
132155
):
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import dpnp
6+
import numba as nb
7+
import numpy
8+
import pytest
9+
10+
import numba_dpex as dpex
11+
12+
N = 1024
13+
14+
15+
@dpex.dpjit
16+
def prange_arg(a, b, c):
17+
for i in nb.prange(a.shape[0]):
18+
b[i] = a[i] * c
19+
20+
21+
@dpex.dpjit
22+
def prange_array(a, b, c):
23+
for i in nb.prange(a.shape[0]):
24+
b[i] = a[i] * c[i]
25+
26+
27+
list_of_dtypes = [
28+
dpnp.complex64,
29+
dpnp.complex128,
30+
]
31+
32+
list_of_usm_types = ["shared", "device", "host"]
33+
34+
35+
@pytest.fixture(params=list_of_dtypes)
36+
def input_arrays(request):
37+
a = dpnp.ones(N, dtype=request.param)
38+
c = dpnp.zeros(N, dtype=request.param)
39+
b = dpnp.empty_like(a)
40+
return a, b, c
41+
42+
43+
def test_dpjit_scalar_arg_types(input_arrays):
44+
"""Tests passing float and complex type dpnp arrays to a dpjit prange function.
45+
46+
Args:
47+
input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel.
48+
"""
49+
s = 2
50+
a, b, _ = input_arrays
51+
52+
prange_arg(a, b, s)
53+
54+
nb = dpnp.asnumpy(b)
55+
nexpected = numpy.full_like(nb, fill_value=2)
56+
57+
assert numpy.allclose(nb, nexpected)
58+
59+
60+
def test_dpjit_arg_complex_scalar(input_arrays):
61+
"""Tests passing complex type scalar and dpnp arrays to a dpjit prange function.
62+
63+
Args:
64+
input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel.
65+
"""
66+
s = 2 + 1j
67+
a, b, _ = input_arrays
68+
69+
prange_arg(a, b, s)
70+
71+
nb = dpnp.asnumpy(b)
72+
nexpected = numpy.full_like(nb, fill_value=2 + 1j)
73+
74+
assert numpy.allclose(nb, nexpected)
75+
76+
77+
def test_dpjit_arg_complex_array(input_arrays):
78+
"""Tests passing complex type dpnp arrays to a dpjit prange function.
79+
80+
Args:
81+
input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel.
82+
"""
83+
84+
a, b, c = input_arrays
85+
86+
prange_array(a, b, c)
87+
88+
nb = dpnp.asnumpy(b)
89+
nexpected = numpy.full_like(nb, fill_value=0 + 0j)
90+
91+
assert numpy.allclose(nb, nexpected)

0 commit comments

Comments
 (0)