Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 22 additions & 180 deletions numba_dpex/core/utils/kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@
from ..descriptor import dpex_kernel_target
from ..passes import parfor
from ..types.dpnp_ndarray_type import DpnpNdArray
from .kernel_templates import RangeKernelTemplate


class GufuncKernel:
class ParforKernel:
def __init__(
self,
name,
Expand Down Expand Up @@ -172,131 +173,6 @@ def _replace_var_with_array(vars, loop_body, typemap, calltypes):
typemap[v] = types.npytypes.Array(el_typ, 1, "C")


def _dbgprint_after_each_array_assignments(lowerer, loop_body, typemap):
for label, block in loop_body.items():
new_block = block.copy()
new_block.clear()
loc = block.loc
scope = block.scope
for inst in block.body:
new_block.append(inst)
# Append print after assignment
if isinstance(inst, ir.Assign):
# Only apply to numbers
if typemap[inst.target.name] not in types.number_domain:
continue

# Make constant string
strval = "{} =".format(inst.target.name)
strconsttyp = types.StringLiteral(strval)

lhs = ir.Var(scope, mk_unique_var("str_const"), loc)
assign_lhs = ir.Assign(
value=ir.Const(value=strval, loc=loc), target=lhs, loc=loc
)
typemap[lhs.name] = strconsttyp
new_block.append(assign_lhs)

# Make print node
print_node = ir.Print(
args=[lhs, inst.target], vararg=None, loc=loc
)
new_block.append(print_node)
sig = signature(
types.none, typemap[lhs.name], typemap[inst.target.name]
)
lowerer.fndesc.calltypes[print_node] = sig
loop_body[label] = new_block


def _generate_kernel_stub_as_string(
kernel_name,
parfor_params,
parfor_dim,
legal_loop_indices,
loop_ranges,
param_dict,
has_reduction,
redvars,
typemap,
redvars_dict,
sentinel_name,
):
"""Generates a stub dpex kernel for the parfor.

Returns:
str: A string representing a stub kernel function for the parfor.
"""
kernel_txt = ""

# Create the dpex kernel function.
kernel_txt += "def " + kernel_name
kernel_txt += "(" + (", ".join(parfor_params)) + "):\n"
global_id_dim = 0
for_loop_dim = parfor_dim

if parfor_dim > 3:
raise NotImplementedError
global_id_dim = 3
else:
global_id_dim = parfor_dim

for eachdim in range(global_id_dim):
kernel_txt += (
" "
+ legal_loop_indices[eachdim]
+ " = "
+ "dpex.get_global_id("
+ str(eachdim)
+ ")\n"
)

for eachdim in range(global_id_dim, for_loop_dim):
for indent in range(1 + (eachdim - global_id_dim)):
kernel_txt += " "

start, stop, step = loop_ranges[eachdim]
start = param_dict.get(str(start), start)
stop = param_dict.get(str(stop), stop)
kernel_txt += (
"for "
+ legal_loop_indices[eachdim]
+ " in range("
+ str(start)
+ ", "
+ str(stop)
+ " + 1):\n"
)

for eachdim in range(global_id_dim, for_loop_dim):
for indent in range(1 + (eachdim - global_id_dim)):
kernel_txt += " "

# Add the sentinel assignment so that we can find the loop body position
# in the IR.
kernel_txt += " "
kernel_txt += sentinel_name + " = 0\n"

# A kernel function does not return anything
kernel_txt += " return None\n"

return kernel_txt


def _wrap_loop_body(loop_body):
blocks = loop_body.copy() # shallow copy is enough
first_label = min(blocks.keys())
last_label = max(blocks.keys())
loc = blocks[last_label].loc
blocks[last_label].body.append(ir.Jump(first_label, loc))
return blocks


def _unwrap_loop_body(loop_body):
last_label = max(loop_body.keys())
loop_body[last_label].body = loop_body[last_label].body[:-1]


def _find_setitems_block(setitems, block, typemap):
for inst in block.body:
if isinstance(inst, ir.StaticSetItem) or isinstance(inst, ir.SetItem):
Expand Down Expand Up @@ -403,18 +279,8 @@ def create_kernel_for_parfor(

# Get all the parfor params.
parfor_params = parfor_node.params

# Get all parfor reduction vars, and operators.
typemap = lowerer.fndesc.typemap

parfor_redvars, parfor_reddict = parfor.get_parfor_reductions(
lowerer.func_ir, parfor_node, parfor_params, lowerer.fndesc.calltypes
)
has_reduction = False if len(parfor_redvars) == 0 else True

if has_reduction:
raise NotImplementedError

# Compute just the parfor inputs as a set difference.
parfor_inputs = sorted(list(set(parfor_params) - set(parfor_outputs)))

Expand All @@ -435,7 +301,6 @@ def create_kernel_for_parfor(
# dict of potentially illegal param name to guaranteed legal name.
param_dict = _legalize_names_with_typemap(parfor_params, typemap)
ind_dict = _legalize_names_with_typemap(loop_indices, typemap)
redvars_dict = legalize_names(parfor_redvars)

# Compute a new list of legal loop index names.
legal_loop_indices = [ind_dict[v] for v in loop_indices]
Expand Down Expand Up @@ -477,35 +342,16 @@ def create_kernel_for_parfor(
# Determine the unique names of the kernel functions.
kernel_name = "__numba_parfor_kernel_%s" % (parfor_node.id)

kernel_fn_txt = _generate_kernel_stub_as_string(
kernel_name,
parfor_params,
parfor_dim,
legal_loop_indices,
loop_ranges,
param_dict,
has_reduction,
parfor_redvars,
typemap,
redvars_dict,
sentinel_name,
kernel_template = RangeKernelTemplate(
kernel_name=kernel_name,
kernel_params=parfor_params,
kernel_rank=parfor_dim,
ivar_names=legal_loop_indices,
sentinel_name=sentinel_name,
loop_ranges=loop_ranges,
param_dict=param_dict,
)

if config.DEBUG_ARRAY_OPT:
print("kernel_fn_txt = ", type(kernel_fn_txt), "\n", kernel_fn_txt)
sys.stdout.flush()

# Exec the kernel_fn_txt string into existence.
globls = {"dpnp": dpnp, "numba": numba, "dpex": dpex}
locls = {}
exec(kernel_fn_txt, globls, locls)

kernel_fn = locls[kernel_name]

if config.DEBUG_ARRAY_OPT:
print("kernel_fn = ", type(kernel_fn), "\n", kernel_fn)
# Get the IR for the kernel_fn dpex kernel
kernel_ir = compiler.run_frontend(kernel_fn)
kernel_ir = kernel_template.kernel_ir

if config.DEBUG_ARRAY_OPT:
print("kernel_ir dump ", type(kernel_ir))
Expand All @@ -527,25 +373,21 @@ def create_kernel_for_parfor(
print("kernel_ir dump after renaming ")
kernel_ir.dump()

gufunc_param_types = param_types
kernel_param_types = param_types

if config.DEBUG_ARRAY_OPT:
print(
"gufunc_param_types = ",
type(gufunc_param_types),
"kernel_param_types = ",
type(kernel_param_types),
"\n",
gufunc_param_types,
kernel_param_types,
)

gufunc_stub_last_label = max(kernel_ir.blocks.keys()) + 1
kernel_stub_last_label = max(kernel_ir.blocks.keys()) + 1

# Add gufunc stub last label to each parfor.loop_body label to prevent
# Add kernel stub last label to each parfor.loop_body label to prevent
# label conflicts.
loop_body = add_offset_to_labels(loop_body, gufunc_stub_last_label)

# If enabled, add a print statement after every assignment.
if config.DEBUG_ARRAY_OPT_RUNTIME:
_dbgprint_after_each_array_assignments(lowerer, loop_body, typemap)
loop_body = add_offset_to_labels(loop_body, kernel_stub_last_label)

_replace_sentinel_with_parfor_body(kernel_ir, sentinel_name, loop_body)

Expand All @@ -565,10 +407,10 @@ def create_kernel_for_parfor(
remove_dead(kernel_ir.blocks, kernel_ir.arg_names, kernel_ir, typemap)

if config.DEBUG_ARRAY_OPT:
print("gufunc_ir after remove dead")
print("kernel_ir after remove dead")
kernel_ir.dump()

kernel_sig = signature(types.none, *gufunc_param_types)
kernel_sig = signature(types.none, *kernel_param_types)

if config.DEBUG_ARRAY_OPT:
sys.stdout.flush()
Expand Down Expand Up @@ -597,7 +439,7 @@ def create_kernel_for_parfor(
exec_queue,
kernel_name,
kernel_ir,
gufunc_param_types,
kernel_param_types,
debug=flags.debuginfo,
)

Expand All @@ -606,7 +448,7 @@ def create_kernel_for_parfor(
if config.DEBUG_ARRAY_OPT:
print("kernel_sig = ", kernel_sig)

return GufuncKernel(
return ParforKernel(
name=kernel_name,
kernel=sycl_kernel,
signature=kernel_sig,
Expand Down
9 changes: 9 additions & 0 deletions numba_dpex/core/utils/kernel_templates/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# SPDX-FileCopyrightText: 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

from numba_dpex.core.utils.kernel_templates.range_kernel_template import (
RangeKernelTemplate,
)

__all__ = ["RangeKernelTemplate"]
Loading