diff --git a/numba_dpex/core/passes/parfor_lowering_pass.py b/numba_dpex/core/passes/parfor_lowering_pass.py index 1afd0b16f8..434ed0bbca 100644 --- a/numba_dpex/core/passes/parfor_lowering_pass.py +++ b/numba_dpex/core/passes/parfor_lowering_pass.py @@ -70,6 +70,8 @@ def _submit_parfor_kernel( if isinstance(arg_type, DpnpNdArray): # FIXME: Remove magic constants num_flattened_args += 5 + (2 * arg_type.ndim) + elif arg_type == types.complex64 or arg_type == types.complex128: + num_flattened_args += 2 else: num_flattened_args += 1 @@ -97,10 +99,33 @@ def _submit_parfor_kernel( # FIXME: Get rid of magic constants kernel_arg_num += 5 + (2 * argtype.ndim) else: - ir_builder.build_arg( - llvm_val, argtype, args_list, args_ty_list, kernel_arg_num - ) - kernel_arg_num += 1 + if argtype == types.complex64: + ir_builder.build_complex_arg( + llvm_val, + types.float32, + args_list, + args_ty_list, + kernel_arg_num, + ) + kernel_arg_num += 2 + elif argtype == types.complex128: + ir_builder.build_complex_arg( + llvm_val, + types.float64, + args_list, + args_ty_list, + kernel_arg_num, + ) + kernel_arg_num += 2 + else: + ir_builder.build_arg( + llvm_val, + argtype, + args_list, + args_ty_list, + kernel_arg_num, + ) + kernel_arg_num += 1 # Create a global range over which to submit the kernel based on the # loop_ranges of the parfor diff --git a/numba_dpex/core/utils/kernel_launcher.py b/numba_dpex/core/utils/kernel_launcher.py index d9652a744a..5bd8851e89 100644 --- a/numba_dpex/core/utils/kernel_launcher.py +++ b/numba_dpex/core/utils/kernel_launcher.py @@ -127,6 +127,29 @@ def build_arg(self, val, ty, arg_list, args_ty_list, arg_num): numba_type_to_dpctl_typenum(self.context, ty), kernel_arg_ty_dst ) + def build_complex_arg(self, val, ty, arg_list, args_ty_list, arg_num): + """Creates a list of LLVM Values for an unpacked complex kernel + argument. + """ + self._build_array_attr_arg( + array_val=val, + array_attr_pos=0, + array_attr_ty=ty, + arg_list=arg_list, + args_ty_list=args_ty_list, + arg_num=arg_num, + ) + arg_num += 1 + self._build_array_attr_arg( + array_val=val, + array_attr_pos=1, + array_attr_ty=ty, + arg_list=arg_list, + args_ty_list=args_ty_list, + arg_num=arg_num, + ) + arg_num += 1 + def build_array_arg( self, array_val, array_rank, arg_list, args_ty_list, arg_num ): diff --git a/numba_dpex/tests/dpjit_tests/test_dpjit_complex_arg_types.py b/numba_dpex/tests/dpjit_tests/test_dpjit_complex_arg_types.py new file mode 100644 index 0000000000..13b9d3fdc1 --- /dev/null +++ b/numba_dpex/tests/dpjit_tests/test_dpjit_complex_arg_types.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +import dpnp +import numba as nb +import numpy +import pytest + +import numba_dpex as dpex + +N = 1024 + + +@dpex.dpjit +def prange_arg(a, b, c): + for i in nb.prange(a.shape[0]): + b[i] = a[i] * c + + +@dpex.dpjit +def prange_array(a, b, c): + for i in nb.prange(a.shape[0]): + b[i] = a[i] * c[i] + + +list_of_dtypes = [ + dpnp.complex64, + dpnp.complex128, +] + +list_of_usm_types = ["shared", "device", "host"] + + +@pytest.fixture(params=list_of_dtypes) +def input_arrays(request): + a = dpnp.ones(N, dtype=request.param) + c = dpnp.zeros(N, dtype=request.param) + b = dpnp.empty_like(a) + return a, b, c + + +def test_dpjit_scalar_arg_types(input_arrays): + """Tests passing float and complex type dpnp arrays to a dpjit prange function. + + Args: + input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel. + """ + s = 2 + a, b, _ = input_arrays + + prange_arg(a, b, s) + + nb = dpnp.asnumpy(b) + nexpected = numpy.full_like(nb, fill_value=2) + + assert numpy.allclose(nb, nexpected) + + +def test_dpjit_arg_complex_scalar(input_arrays): + """Tests passing complex type scalar and dpnp arrays to a dpjit prange function. + + Args: + input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel. + """ + s = 2 + 1j + a, b, _ = input_arrays + + prange_arg(a, b, s) + + nb = dpnp.asnumpy(b) + nexpected = numpy.full_like(nb, fill_value=2 + 1j) + + assert numpy.allclose(nb, nexpected) + + +def test_dpjit_arg_complex_array(input_arrays): + """Tests passing complex type dpnp arrays to a dpjit prange function. + + Args: + input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel. + """ + + a, b, c = input_arrays + + prange_array(a, b, c) + + nb = dpnp.asnumpy(b) + nexpected = numpy.full_like(nb, fill_value=0 + 0j) + + assert numpy.allclose(nb, nexpected)