Merge pull request #1048 from IntelPython/fix/mul_reduction
Fixed bug in reduction mul operation for dpjit.
Diptorup Deb authored May 18, 2023
2 parents ee08e9b + 512cabb commit f489858
Showing 2 changed files with 45 additions and 15 deletions.
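
Background for the diff below: the dpjit reduction kernel template hard-coded the accumulator's initial value to 0 and the final combine step to +=, which is only correct for sum reductions, so product (*=) reductions produced incorrect results. A minimal sketch of the kind of prange code affected follows; the imports mirror the test module in this commit (nb for numba, dpex for numba_dpex) and the sample input and function name are made-up illustrations, not code from this change.

import dpnp
import numba as nb

import numba_dpex as dpex


@dpex.dpjit
def prod_prange(a):
    # Product reduction over a prange loop. Before this fix, the generated
    # reduction code seeded the accumulator with 0 and folded partial
    # results with +=, so a *= reduction came out wrong.
    t = 1
    for i in nb.prange(a.shape[0]):
        t *= a[i]
    return t


a = dpnp.full(10, 2, dtype=dpnp.int64)
print(prod_prange(a))  # expected 2**10 == 1024 once imul reductions are handled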
18 changes: 14 additions & 4 deletions numba_dpex/core/utils/kernel_templates/reduction_template.py
@@ -99,7 +99,9 @@ def _generate_kernel_stub_as_string(self):
         for redvar in self._redvars:
             legal_redvar = self._redvars_dict[redvar]
             gufunc_txt += " "
-            gufunc_txt += legal_redvar + " = 0\n"
+            gufunc_txt += legal_redvar + " = "
+            gufunc_txt += f"{self._parfor_reddict[redvar].init_val} \n"
+
         gufunc_txt += " "
         gufunc_txt += self._sentinel_name + " = 0\n"

@@ -265,8 +267,15 @@ def _generate_kernel_stub_as_string(self):
         )

         for i, redvar in enumerate(self._redvars):
-            gufunc_txt += f" {self._final_sum_var_name[i]}[0] += \
-                {self._partial_sum_var_name[i]}[j]\n"
+            redop = self._parfor_reddict[redvar].redop
+            if redop == operator.iadd:
+                gufunc_txt += f" {self._final_sum_var_name[i]}[0] += \
+                    {self._partial_sum_var_name[i]}[j]\n"
+            elif redop == operator.imul:
+                gufunc_txt += f" {self._final_sum_var_name[i]}[0] *= \
+                    {self._partial_sum_var_name[i]}[j]\n"
+            else:
+                raise NotImplementedError

         gufunc_txt += (
             f" for j in range ({self._global_size_mod_var_name[0]}) :\n"
@@ -275,7 +284,8 @@
         for redvar in self._redvars:
             legal_redvar = self._redvars_dict[redvar]
             gufunc_txt += " "
-            gufunc_txt += legal_redvar + " = 0\n"
+            gufunc_txt += legal_redvar + " = "
+            gufunc_txt += f"{self._parfor_reddict[redvar].init_val}\n"

         gufunc_txt += (
             " "
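
The hunks above are the core of the fix: the accumulator is now seeded with the reduction's identity value (init_val) instead of a literal 0, and the final combine loop dispatches on the parfor's reduction operator. A plain-Python sketch of that combine logic follows; combine_partials, partial, redop, and init_val are stand-in names for illustration, not the identifiers the template actually generates.

import operator


def combine_partials(partial, redop, init_val):
    # Seed with the reduction identity (0 for iadd, 1 for imul), then fold
    # in each work-group's partial result with the matching operator.
    result = init_val
    for p in partial:
        if redop == operator.iadd:
            result += p
        elif redop == operator.imul:
            result *= p
        else:
            raise NotImplementedError(f"unsupported reduction op: {redop}")
    return result


# Seeding a product with 0 (the old behavior) collapses it to 0; the
# identity 1 lets the partial products combine correctly.
assert combine_partials([2, 3, 4], operator.imul, 1) == 24
assert combine_partials([2, 3, 4], operator.iadd, 0) == 9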
42 changes: 31 additions & 11 deletions numba_dpex/tests/dpjit_tests/test_dpjit_reduction.py
@@ -9,11 +9,11 @@

 import numba_dpex as dpex

-N = 100
+N = 10


 @dpex.dpjit
-def vecadd_prange(a, b):
+def vecadd_prange1(a, b):
     s = 0
     t = 0
     for i in nb.prange(a.shape[0]):
@@ -24,13 +24,21 @@ def vecadd_prange(a, b):


 @dpex.dpjit
-def vecmul_prange(a, b):
+def vecadd_prange2(a, b):
     t = 0
     for i in nb.prange(a.shape[0]):
         t += a[i] * b[i]
     return t


+@dpex.dpjit
+def vecmul_prange(a, b):
+    t = 1
+    for i in nb.prange(a.shape[0]):
+        t *= a[i] + b[i]
+    return t
+
+
 @dpex.dpjit
 def vecadd_prange_float(a, b):
     s = numpy.float32(0)
@@ -57,30 +65,42 @@ def input_arrays(request):
     return a, b


-def test_dpjit_array_arg_types(input_arrays):
+def test_dpjit_array_arg_types_add1(input_arrays):
     """Tests passing float and int type dpnp arrays to a dpjit
     prange function.
     Args:
         input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel.
     """
-    s = 200
-
+    s = 20
     a, b = input_arrays

-    c = vecadd_prange(a, b)
+    c = vecadd_prange1(a, b)

     assert s == c


+def test_dpjit_array_arg_types_add2(input_arrays):
+    """Tests passing float and int type dpnp arrays to a dpjit
+    prange function.
+    Args:
+        input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel.
+    """
+    t = 45
+    a, b = input_arrays
+    d = vecadd_prange2(a, b)
+
+    assert t == d
+
+
 def test_dpjit_array_arg_types_mul(input_arrays):
     """Tests passing float and int type dpnp arrays to a dpjit
     prange function.
     Args:
         input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel.
     """
-    s = 4950
+    s = 3628800

     a, b = input_arrays

@@ -97,8 +117,8 @@ def test_dpjit_array_arg_float32_types(input_arrays):
         input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel.
     """
     s = 9900
-    a = dpnp.arange(N, dtype=dpnp.float32)
-    b = dpnp.arange(N, dtype=dpnp.float32)
+    a = dpnp.arange(100, dtype=dpnp.float32)
+    b = dpnp.arange(100, dtype=dpnp.float32)

     c = vecadd_prange_float(a, b)

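A quick check of where the new add2 and mul expectations could come from, under the assumption (not visible in this diff, since the input_arrays fixture is collapsed) that the fixture yields a = arange(N) and b = ones(N) with N = 10; in particular, 3628800 is 10!.

import math

N = 10
a = list(range(N))  # assumed fixture values: 0, 1, ..., 9
b = [1] * N         # assumed fixture values: all ones

# vecadd_prange2: t += a[i] * b[i]  ->  0 + 1 + ... + 9 == 45
assert sum(x * y for x, y in zip(a, b)) == 45

# vecmul_prange: t *= a[i] + b[i]  ->  1 * 2 * ... * 10 == 10! == 3628800
assert math.prod(x + y for x, y in zip(a, b)) == 3628800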
