Skip to content

Commit 33b3af4

Browse files
author
Diptorup Deb
authored
Merge pull request #967 from IntelPython/refactor/kernel_builder
Refactor/kernel builder
2 parents a282a42 + 6829e9b commit 33b3af4

File tree

4 files changed

+174
-181
lines changed

4 files changed

+174
-181
lines changed

numba_dpex/core/utils/kernel_builder.py

Lines changed: 22 additions & 180 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,10 @@
3030
from ..descriptor import dpex_kernel_target
3131
from ..passes import parfor
3232
from ..types.dpnp_ndarray_type import DpnpNdArray
33+
from .kernel_templates import RangeKernelTemplate
3334

3435

35-
class GufuncKernel:
36+
class ParforKernel:
3637
def __init__(
3738
self,
3839
name,
@@ -172,131 +173,6 @@ def _replace_var_with_array(vars, loop_body, typemap, calltypes):
172173
typemap[v] = types.npytypes.Array(el_typ, 1, "C")
173174

174175

175-
def _dbgprint_after_each_array_assignments(lowerer, loop_body, typemap):
    """Insert a debug print after every numeric assignment in ``loop_body``.

    Each block in ``loop_body`` is rebuilt from a cleared copy of itself.
    All original instructions are re-appended in order; after every
    ``ir.Assign`` whose target type is in ``types.number_domain`` two extra
    instructions are appended: an assignment of the string constant
    ``"<target name> ="`` to a fresh variable, and an ``ir.Print`` of that
    variable followed by the assigned target.

    Args:
        lowerer: Object whose ``fndesc.calltypes`` mapping receives the call
            signature of every inserted print node.
        loop_body: Mapping of label to IR block. Updated in place: every
            label is rebound to its rebuilt block.
        typemap: Mapping of variable name to type. Updated in place with the
            string-literal type of every inserted constant variable.
    """
    # NOTE(review): in the reviewed diff (commit 33b3af4) this helper sits on
    # the removed side of the hunk.
    for label, old_block in loop_body.items():
        rebuilt = old_block.copy()
        rebuilt.clear()
        loc = old_block.loc
        scope = old_block.scope

        for stmt in old_block.body:
            rebuilt.append(stmt)

            # Only assignments get a trailing print.
            if not isinstance(stmt, ir.Assign):
                continue

            # Only numeric targets are printed.
            target = stmt.target
            target_type = typemap[target.name]
            if target_type not in types.number_domain:
                continue

            # Materialise the "<name> =" prefix as a typed string constant.
            prefix = "{} =".format(target.name)
            prefix_var = ir.Var(scope, mk_unique_var("str_const"), loc)
            prefix_const = ir.Const(value=prefix, loc=loc)
            typemap[prefix_var.name] = types.StringLiteral(prefix)
            rebuilt.append(
                ir.Assign(value=prefix_const, target=prefix_var, loc=loc)
            )

            # Emit the print node and record its call signature.
            print_stmt = ir.Print(
                args=[prefix_var, target], vararg=None, loc=loc
            )
            rebuilt.append(print_stmt)
            lowerer.fndesc.calltypes[print_stmt] = signature(
                types.none, typemap[prefix_var.name], target_type
            )

        loop_body[label] = rebuilt
210-
211-
212-
def _generate_kernel_stub_as_string(
213-
kernel_name,
214-
parfor_params,
215-
parfor_dim,
216-
legal_loop_indices,
217-
loop_ranges,
218-
param_dict,
219-
has_reduction,
220-
redvars,
221-
typemap,
222-
redvars_dict,
223-
sentinel_name,
224-
):
225-
"""Generates a stub dpex kernel for the parfor.
226-
227-
Returns:
228-
str: A string representing a stub kernel function for the parfor.
229-
"""
230-
kernel_txt = ""
231-
232-
# Create the dpex kernel function.
233-
kernel_txt += "def " + kernel_name
234-
kernel_txt += "(" + (", ".join(parfor_params)) + "):\n"
235-
global_id_dim = 0
236-
for_loop_dim = parfor_dim
237-
238-
if parfor_dim > 3:
239-
raise NotImplementedError
240-
global_id_dim = 3
241-
else:
242-
global_id_dim = parfor_dim
243-
244-
for eachdim in range(global_id_dim):
245-
kernel_txt += (
246-
" "
247-
+ legal_loop_indices[eachdim]
248-
+ " = "
249-
+ "dpex.get_global_id("
250-
+ str(eachdim)
251-
+ ")\n"
252-
)
253-
254-
for eachdim in range(global_id_dim, for_loop_dim):
255-
for indent in range(1 + (eachdim - global_id_dim)):
256-
kernel_txt += " "
257-
258-
start, stop, step = loop_ranges[eachdim]
259-
start = param_dict.get(str(start), start)
260-
stop = param_dict.get(str(stop), stop)
261-
kernel_txt += (
262-
"for "
263-
+ legal_loop_indices[eachdim]
264-
+ " in range("
265-
+ str(start)
266-
+ ", "
267-
+ str(stop)
268-
+ " + 1):\n"
269-
)
270-
271-
for eachdim in range(global_id_dim, for_loop_dim):
272-
for indent in range(1 + (eachdim - global_id_dim)):
273-
kernel_txt += " "
274-
275-
# Add the sentinel assignment so that we can find the loop body position
276-
# in the IR.
277-
kernel_txt += " "
278-
kernel_txt += sentinel_name + " = 0\n"
279-
280-
# A kernel function does not return anything
281-
kernel_txt += " return None\n"
282-
283-
return kernel_txt
284-
285-
286-
def _wrap_loop_body(loop_body):
    """Return a copy of ``loop_body`` whose last block jumps to the first.

    An ``ir.Jump`` targeting the smallest label is appended to the body of
    the block with the largest label. The copy is shallow, so the block
    objects (including the one that receives the jump) are shared with
    ``loop_body``.

    Args:
        loop_body: Mapping of label to IR block.

    Returns:
        The shallow copy of ``loop_body``.
    """
    # NOTE(review): in the reviewed diff (commit 33b3af4) this helper sits on
    # the removed side of the hunk.
    wrapped = loop_body.copy()  # block objects are intentionally shared
    entry_label = min(wrapped)
    exit_block = wrapped[max(wrapped)]
    back_edge = ir.Jump(entry_label, exit_block.loc)
    exit_block.body.append(back_edge)
    return wrapped
293-
294-
295-
def _unwrap_loop_body(loop_body):
296-
last_label = max(loop_body.keys())
297-
loop_body[last_label].body = loop_body[last_label].body[:-1]
298-
299-
300176
def _find_setitems_block(setitems, block, typemap):
301177
for inst in block.body:
302178
if isinstance(inst, ir.StaticSetItem) or isinstance(inst, ir.SetItem):
@@ -403,18 +279,8 @@ def create_kernel_for_parfor(
403279

404280
# Get all the parfor params.
405281
parfor_params = parfor_node.params
406-
407-
# Get all parfor reduction vars, and operators.
408282
typemap = lowerer.fndesc.typemap
409283

410-
parfor_redvars, parfor_reddict = parfor.get_parfor_reductions(
411-
lowerer.func_ir, parfor_node, parfor_params, lowerer.fndesc.calltypes
412-
)
413-
has_reduction = False if len(parfor_redvars) == 0 else True
414-
415-
if has_reduction:
416-
raise NotImplementedError
417-
418284
# Compute just the parfor inputs as a set difference.
419285
parfor_inputs = sorted(list(set(parfor_params) - set(parfor_outputs)))
420286

@@ -435,7 +301,6 @@ def create_kernel_for_parfor(
435301
# dict of potentially illegal param name to guaranteed legal name.
436302
param_dict = _legalize_names_with_typemap(parfor_params, typemap)
437303
ind_dict = _legalize_names_with_typemap(loop_indices, typemap)
438-
redvars_dict = legalize_names(parfor_redvars)
439304

440305
# Compute a new list of legal loop index names.
441306
legal_loop_indices = [ind_dict[v] for v in loop_indices]
@@ -477,35 +342,16 @@ def create_kernel_for_parfor(
477342
# Determine the unique names of the kernel functions.
478343
kernel_name = "__numba_parfor_kernel_%s" % (parfor_node.id)
479344

480-
kernel_fn_txt = _generate_kernel_stub_as_string(
481-
kernel_name,
482-
parfor_params,
483-
parfor_dim,
484-
legal_loop_indices,
485-
loop_ranges,
486-
param_dict,
487-
has_reduction,
488-
parfor_redvars,
489-
typemap,
490-
redvars_dict,
491-
sentinel_name,
345+
kernel_template = RangeKernelTemplate(
346+
kernel_name=kernel_name,
347+
kernel_params=parfor_params,
348+
kernel_rank=parfor_dim,
349+
ivar_names=legal_loop_indices,
350+
sentinel_name=sentinel_name,
351+
loop_ranges=loop_ranges,
352+
param_dict=param_dict,
492353
)
493-
494-
if config.DEBUG_ARRAY_OPT:
495-
print("kernel_fn_txt = ", type(kernel_fn_txt), "\n", kernel_fn_txt)
496-
sys.stdout.flush()
497-
498-
# Exec the kernel_fn_txt string into existence.
499-
globls = {"dpnp": dpnp, "numba": numba, "dpex": dpex}
500-
locls = {}
501-
exec(kernel_fn_txt, globls, locls)
502-
503-
kernel_fn = locls[kernel_name]
504-
505-
if config.DEBUG_ARRAY_OPT:
506-
print("kernel_fn = ", type(kernel_fn), "\n", kernel_fn)
507-
# Get the IR for the kernel_fn dpex kernel
508-
kernel_ir = compiler.run_frontend(kernel_fn)
354+
kernel_ir = kernel_template.kernel_ir
509355

510356
if config.DEBUG_ARRAY_OPT:
511357
print("kernel_ir dump ", type(kernel_ir))
@@ -527,25 +373,21 @@ def create_kernel_for_parfor(
527373
print("kernel_ir dump after renaming ")
528374
kernel_ir.dump()
529375

530-
gufunc_param_types = param_types
376+
kernel_param_types = param_types
531377

532378
if config.DEBUG_ARRAY_OPT:
533379
print(
534-
"gufunc_param_types = ",
535-
type(gufunc_param_types),
380+
"kernel_param_types = ",
381+
type(kernel_param_types),
536382
"\n",
537-
gufunc_param_types,
383+
kernel_param_types,
538384
)
539385

540-
gufunc_stub_last_label = max(kernel_ir.blocks.keys()) + 1
386+
kernel_stub_last_label = max(kernel_ir.blocks.keys()) + 1
541387

542-
# Add gufunc stub last label to each parfor.loop_body label to prevent
388+
# Add kernel stub last label to each parfor.loop_body label to prevent
543389
# label conflicts.
544-
loop_body = add_offset_to_labels(loop_body, gufunc_stub_last_label)
545-
546-
# If enabled, add a print statement after every assignment.
547-
if config.DEBUG_ARRAY_OPT_RUNTIME:
548-
_dbgprint_after_each_array_assignments(lowerer, loop_body, typemap)
390+
loop_body = add_offset_to_labels(loop_body, kernel_stub_last_label)
549391

550392
_replace_sentinel_with_parfor_body(kernel_ir, sentinel_name, loop_body)
551393

@@ -565,10 +407,10 @@ def create_kernel_for_parfor(
565407
remove_dead(kernel_ir.blocks, kernel_ir.arg_names, kernel_ir, typemap)
566408

567409
if config.DEBUG_ARRAY_OPT:
568-
print("gufunc_ir after remove dead")
410+
print("kernel_ir after remove dead")
569411
kernel_ir.dump()
570412

571-
kernel_sig = signature(types.none, *gufunc_param_types)
413+
kernel_sig = signature(types.none, *kernel_param_types)
572414

573415
if config.DEBUG_ARRAY_OPT:
574416
sys.stdout.flush()
@@ -597,7 +439,7 @@ def create_kernel_for_parfor(
597439
exec_queue,
598440
kernel_name,
599441
kernel_ir,
600-
gufunc_param_types,
442+
kernel_param_types,
601443
debug=flags.debuginfo,
602444
)
603445

@@ -606,7 +448,7 @@ def create_kernel_for_parfor(
606448
if config.DEBUG_ARRAY_OPT:
607449
print("kernel_sig = ", kernel_sig)
608450

609-
return GufuncKernel(
451+
return ParforKernel(
610452
name=kernel_name,
611453
kernel=sycl_kernel,
612454
signature=kernel_sig,
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# SPDX-FileCopyrightText: 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from numba_dpex.core.utils.kernel_templates.range_kernel_template import (
6+
RangeKernelTemplate,
7+
)
8+
9+
__all__ = ["RangeKernelTemplate"]

0 commit comments

Comments
 (0)