Skip to content

Commit 5a2238b

Browse files
author
Diptorup Deb
authored
Merge pull request #957 from IntelPython/feature/dpnp_parfor_v2
Parfor lowering as kernels and dpnp ufunc compilation to kernels
2 parents bb8b497 + 0a343be commit 5a2238b

31 files changed

+3492
-348
lines changed

numba_dpex/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def load_dpctl_sycl_interface():
7474
f"dpctl={dpctl_version} may cause unexpected behavior"
7575
)
7676

77+
from numba import prange # noqa E402
7778

7879
import numba_dpex.core.dpjit_dispatcher # noqa E402
7980
import numba_dpex.core.offload_dispatcher # noqa E402
@@ -92,6 +93,7 @@ def load_dpctl_sycl_interface():
9293

9394
# Re-export all type names
9495
from numba_dpex.core.types import * # noqa E402
96+
from numba_dpex.dpnp_iface import dpnpimpl # noqa E402
9597
from numba_dpex.retarget import offload_to_sycl_device # noqa E402
9698

9799
if config.HAS_NON_HOST_DEVICE:

numba_dpex/_patches.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,13 +258,11 @@ def _empty_nd_impl(context, builder, arrtype, shapes):
258258
)
259259
from numba_dpex.decorators import dpjit
260260

261-
numba_config.DISABLE_PERFORMANCE_WARNINGS = 0
262261
op = dpjit(_call_usm_allocator)
263262
fnop = context.typing_context.resolve_value_type(op)
264263
# The _call_usm_allocator function will be compiled and added to registry
265264
# when the get_call_type function is invoked.
266265
fnop.get_call_type(context.typing_context, sig.args, {})
267-
numba_config.DISABLE_PERFORMANCE_WARNINGS = 1
268266
eqfn = context.get_function(fnop, sig)
269267
meminfo = eqfn(builder, args)
270268
else:
@@ -309,11 +307,17 @@ def impl(cls, allocsize, usm_type, device):
309307
return impl
310308

311309

310+
numba_config.DISABLE_PERFORMANCE_WARNINGS = 0
311+
312+
312313
def _call_usm_allocator(arrtype, size, usm_type, device):
313314
"""Trampoline to call the intrinsic used for allocation"""
314315
return arrtype._usm_allocate(size, usm_type, device)
315316

316317

318+
numba_config.DISABLE_PERFORMANCE_WARNINGS = 1
319+
320+
317321
@intrinsic
318322
def intrin_usm_alloc(typingctx, allocsize, usm_type, device):
319323
"""Intrinsic to call into the allocator for Array"""

numba_dpex/core/exceptions.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def __init__(
221221
f"Arguments {ndarray_args} are non-usm arrays, "
222222
f"and arguments {usmarray_args} are usm arrays."
223223
)
224-
elif usmarray_argnum_list:
224+
elif usmarray_argnum_list is not None:
225225
usmarray_args = ",".join([str(i) for i in usmarray_argnum_list])
226226
self.message = (
227227
f'Execution queue for kernel "{kernel_name}" could '
@@ -433,3 +433,13 @@ def __init__(self, kernel_name, argtypes) -> None:
433433
)
434434

435435
super().__init__(self.message)
436+
437+
438+
class UnsupportedParforError(Exception):
439+
"""Exception raised when a parfor node could not be lowered by Numba-dpex"""
440+
441+
def __init__(self, extra_msg=None) -> None:
442+
self.message = "Expression cannot be offloaded"
443+
if extra_msg:
444+
self.message += " due to " + extra_msg
445+
super().__init__(self.message)

numba_dpex/core/passes/__init__.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,27 @@
11
# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
22
#
33
# SPDX-License-Identifier: Apache-2.0
4+
5+
from .parfor_legalize_cfd_pass import ParforLegalizeCFDPass
6+
from .parfor_lowering_pass import ParforLoweringPass
7+
from .passes import (
8+
DumpParforDiagnostics,
9+
NoPythonBackend,
10+
ParforFusionPass,
11+
ParforPass,
12+
ParforPreLoweringPass,
13+
PreParforPass,
14+
SplitParforPass,
15+
)
16+
17+
__all__ = [
18+
"DumpParforDiagnostics",
19+
"ParforLoweringPass",
20+
"ParforLegalizeCFDPass",
21+
"ParforFusionPass",
22+
"ParforPreLoweringPass",
23+
"ParforPass",
24+
"PreParforPass",
25+
"SplitParforPass",
26+
"NoPythonBackend",
27+
]

numba_dpex/core/passes/parfor.py

Lines changed: 66 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2421,7 +2421,6 @@ def _arrayexpr_to_parfor(self, equiv_set, lhs, arrayexpr, avail_vars):
24212421
expr = arrayexpr.expr
24222422
arr_typ = pass_states.typemap[lhs.name]
24232423
el_typ = arr_typ.dtype
2424-
24252424
# generate loopnests and size variables from lhs correlations
24262425
size_vars = equiv_set.get_shape(lhs)
24272426
index_vars, loopnests = _mk_parfor_loops(
@@ -3788,6 +3787,46 @@ def _get_call_arg_types(expr, typemap):
37883787
return tuple(new_arg_typs), new_kw_types
37893788

37903789

3790+
def _ufunc_to_parfor_instr(
3791+
typemap,
3792+
op,
3793+
avail_vars,
3794+
loc,
3795+
scope,
3796+
func_ir,
3797+
out_ir,
3798+
arg_vars,
3799+
typingctx,
3800+
calltypes,
3801+
expr_out_var,
3802+
):
3803+
func_var_name = _find_func_var(typemap, op, avail_vars, loc=loc)
3804+
func_var = ir.Var(scope, mk_unique_var(func_var_name), loc)
3805+
typemap[func_var.name] = typemap[func_var_name]
3806+
func_var_def = copy.deepcopy(func_ir.get_definition(func_var_name))
3807+
if (
3808+
isinstance(func_var_def, ir.Expr)
3809+
and func_var_def.op == "getattr"
3810+
and func_var_def.attr == "sqrt"
3811+
):
3812+
g_math_var = ir.Var(scope, mk_unique_var("$math_g_var"), loc)
3813+
typemap[g_math_var.name] = types.misc.Module(math)
3814+
g_math = ir.Global("math", math, loc)
3815+
g_math_assign = ir.Assign(g_math, g_math_var, loc)
3816+
func_var_def = ir.Expr.getattr(g_math_var, "sqrt", loc)
3817+
out_ir.append(g_math_assign)
3818+
# out_ir.append(func_var_def)
3819+
ir_expr = ir.Expr.call(func_var, arg_vars, (), loc)
3820+
call_typ = typemap[func_var.name].get_call_type(
3821+
typingctx, tuple(typemap[a.name] for a in arg_vars), {}
3822+
)
3823+
calltypes[ir_expr] = call_typ
3824+
el_typ = call_typ.return_type
3825+
# signature(el_typ, el_typ)
3826+
out_ir.append(ir.Assign(func_var_def, func_var, loc))
3827+
out_ir.append(ir.Assign(ir_expr, expr_out_var, loc))
3828+
3829+
37913830
def _arrayexpr_tree_to_ir(
37923831
func_ir,
37933832
typingctx,
@@ -3852,35 +3891,33 @@ def _arrayexpr_tree_to_ir(
38523891
# elif isinstance(op, (np.ufunc, DUFunc)):
38533892
# function calls are stored in variables which are not removed
38543893
# op is typing_key to the variables type
3855-
func_var_name = _find_func_var(typemap, op, avail_vars, loc=loc)
3856-
func_var = ir.Var(scope, mk_unique_var(func_var_name), loc)
3857-
typemap[func_var.name] = typemap[func_var_name]
3858-
func_var_def = copy.deepcopy(
3859-
func_ir.get_definition(func_var_name)
3860-
)
3861-
if (
3862-
isinstance(func_var_def, ir.Expr)
3863-
and func_var_def.op == "getattr"
3864-
and func_var_def.attr == "sqrt"
3865-
):
3866-
g_math_var = ir.Var(
3867-
scope, mk_unique_var("$math_g_var"), loc
3868-
)
3869-
typemap[g_math_var.name] = types.misc.Module(math)
3870-
g_math = ir.Global("math", math, loc)
3871-
g_math_assign = ir.Assign(g_math, g_math_var, loc)
3872-
func_var_def = ir.Expr.getattr(g_math_var, "sqrt", loc)
3873-
out_ir.append(g_math_assign)
3874-
# out_ir.append(func_var_def)
3875-
ir_expr = ir.Expr.call(func_var, arg_vars, (), loc)
3876-
call_typ = typemap[func_var.name].get_call_type(
3877-
typingctx, tuple(typemap[a.name] for a in arg_vars), {}
3894+
_ufunc_to_parfor_instr(
3895+
typemap,
3896+
op,
3897+
avail_vars,
3898+
loc,
3899+
scope,
3900+
func_ir,
3901+
out_ir,
3902+
arg_vars,
3903+
typingctx,
3904+
calltypes,
3905+
expr_out_var,
38783906
)
3879-
calltypes[ir_expr] = call_typ
3880-
el_typ = call_typ.return_type
3881-
# signature(el_typ, el_typ)
3882-
out_ir.append(ir.Assign(func_var_def, func_var, loc))
3883-
out_ir.append(ir.Assign(ir_expr, expr_out_var, loc))
3907+
if hasattr(op, "is_dpnp_ufunc"):
3908+
_ufunc_to_parfor_instr(
3909+
typemap,
3910+
op,
3911+
avail_vars,
3912+
loc,
3913+
scope,
3914+
func_ir,
3915+
out_ir,
3916+
arg_vars,
3917+
typingctx,
3918+
calltypes,
3919+
expr_out_var,
3920+
)
38843921
elif isinstance(expr, ir.Var):
38853922
var_typ = typemap[expr.name]
38863923
if isinstance(var_typ, types.Array):

0 commit comments

Comments
 (0)