Skip to content

Commit eacca8a

Browse files
committed
wip
1 parent 2b173f8 commit eacca8a

File tree

10 files changed

+229
-413
lines changed

10 files changed

+229
-413
lines changed

numba_dpex/core/descriptor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class DpexTargetOptions(CPUTargetOptions):
4848
no_compile = _option_mapping("no_compile")
4949
inline_threshold = _option_mapping("inline_threshold")
5050
_compilation_mode = _option_mapping("_compilation_mode")
51-
_reduction_kernel_variables = _option_mapping("_reduction_kernel_variables")
51+
_parfor_args = _option_mapping("_parfor_args")
5252

5353
def finalize(self, flags, options):
5454
super().finalize(flags, options)
@@ -64,7 +64,7 @@ def finalize(self, flags, options):
6464
_inherit_if_not_set(
6565
flags, options, "_compilation_mode", CompilationMode.KERNEL
6666
)
67-
_inherit_if_not_set(flags, options, "_reduction_kernel_variables", None)
67+
_inherit_if_not_set(flags, options, "_parfor_args", None)
6868

6969

7070
class DpexKernelTarget(TargetDescriptor):

numba_dpex/core/parfors/kernel_builder.py

Lines changed: 5 additions & 228 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
from numba_dpex.core import config
2929
from numba_dpex.core.decorators import kernel
30+
from numba_dpex.core.parfors.kernel_parfor_pass import ParforArguments
3031
from numba_dpex.core.types.kernel_api.index_space_ids import ItemType
3132
from numba_dpex.core.utils.call_kernel_builder import SPIRVKernelModule
3233
from numba_dpex.kernel_api_impl.spirv import spirv_generator
@@ -44,83 +45,21 @@
4445
class ParforKernel:
4546
def __init__(
4647
self,
47-
name,
48-
kernel,
4948
signature,
5049
kernel_args,
5150
kernel_arg_types,
52-
queue: dpctl.SyclQueue,
5351
local_accessors=None,
5452
work_group_size=None,
5553
kernel_module=None,
5654
):
57-
self.name = name
58-
self.kernel = kernel
5955
self.signature = signature
6056
self.kernel_args = kernel_args
6157
self.kernel_arg_types = kernel_arg_types
62-
self.queue = queue
6358
self.local_accessors = local_accessors
6459
self.work_group_size = work_group_size
6560
self.kernel_module = kernel_module
6661

6762

68-
def _print_block(block):
69-
for i, inst in enumerate(block.body):
70-
print(" ", i, inst)
71-
72-
73-
def _print_body(body_dict):
74-
"""Pretty-print a set of IR blocks."""
75-
for label, block in body_dict.items():
76-
print("label: ", label)
77-
_print_block(block)
78-
79-
80-
def _compile_kernel_parfor(
81-
sycl_queue, kernel_name, func_ir, argtypes, debug=False
82-
):
83-
with target_override(dpex_kernel_target.target_context.target_name):
84-
cres = compile_numba_ir_with_dpex(
85-
pyfunc=func_ir,
86-
pyfunc_name=kernel_name,
87-
args=argtypes,
88-
return_type=None,
89-
debug=debug,
90-
is_kernel=True,
91-
typing_context=dpex_kernel_target.typing_context,
92-
target_context=dpex_kernel_target.target_context,
93-
extra_compile_flags=None,
94-
)
95-
cres.library.inline_threshold = config.INLINE_THRESHOLD
96-
cres.library._optimize_final_module()
97-
func = cres.library.get_function(cres.fndesc.llvm_func_name)
98-
kernel = dpex_kernel_target.target_context.prepare_spir_kernel(
99-
func, cres.signature.args
100-
)
101-
spirv_module = spirv_generator.llvm_to_spirv(
102-
dpex_kernel_target.target_context,
103-
kernel.module.__str__(),
104-
kernel.module.as_bitcode(),
105-
)
106-
107-
dpctl_create_program_from_spirv_flags = []
108-
if debug or config.DPEX_OPT == 0:
109-
# if debug is ON we need to pass additional flags to igc.
110-
dpctl_create_program_from_spirv_flags = ["-g", "-cl-opt-disable"]
111-
112-
# create a sycl::kernel_bundle
113-
kernel_bundle = dpctl_prog.create_program_from_spirv(
114-
sycl_queue,
115-
spirv_module,
116-
" ".join(dpctl_create_program_from_spirv_flags),
117-
)
118-
# create a sycl::kernel
119-
sycl_kernel = kernel_bundle.get_sycl_kernel(kernel.name)
120-
121-
return sycl_kernel
122-
123-
12463
def _legalize_names_with_typemap(names, typemap):
12564
"""Replace illegal characters in Numba IR var names.
12665
@@ -197,69 +136,6 @@ def _replace_var_with_array(vars, loop_body, typemap, calltypes):
197136
typemap[v] = types.npytypes.Array(el_typ, 1, "C")
198137

199138

200-
def _find_setitems_block(setitems, block, typemap):
201-
for inst in block.body:
202-
if isinstance(inst, ir.StaticSetItem) or isinstance(inst, ir.SetItem):
203-
setitems.add(inst.target.name)
204-
elif isinstance(inst, parfor.Parfor):
205-
_find_setitems_block(setitems, inst.init_block, typemap)
206-
_find_setitems_body(setitems, inst.loop_body, typemap)
207-
208-
209-
def _find_setitems_body(setitems, loop_body, typemap):
210-
"""
211-
Find the arrays that are written into (goes into setitems)
212-
"""
213-
for label, block in loop_body.items():
214-
_find_setitems_block(setitems, block, typemap)
215-
216-
217-
def _replace_sentinel_with_parfor_body(kernel_ir, sentinel_name, loop_body):
218-
# new label for splitting sentinel block
219-
new_label = max(loop_body.keys()) + 1
220-
221-
# Search all the block in the kernel function for the sentinel assignment.
222-
for label, block in kernel_ir.blocks.items():
223-
for i, inst in enumerate(block.body):
224-
if (
225-
isinstance(inst, ir.Assign)
226-
and inst.target.name == sentinel_name
227-
):
228-
# We found the sentinel assignment.
229-
loc = inst.loc
230-
scope = block.scope
231-
# split block across __sentinel__
232-
# A new block is allocated for the statements prior to the
233-
# sentinel but the new block maintains the current block label.
234-
prev_block = ir.Block(scope, loc)
235-
prev_block.body = block.body[:i]
236-
237-
# The current block is used for statements after the sentinel.
238-
block.body = block.body[i + 1 :] # noqa: E203
239-
# But the current block gets a new label.
240-
body_first_label = min(loop_body.keys())
241-
242-
# The previous block jumps to the minimum labelled block of the
243-
# parfor body.
244-
prev_block.append(ir.Jump(body_first_label, loc))
245-
# Add all the parfor loop body blocks to the kernel function's
246-
# IR.
247-
for loop, b in loop_body.items():
248-
kernel_ir.blocks[loop] = b
249-
body_last_label = max(loop_body.keys())
250-
kernel_ir.blocks[new_label] = block
251-
kernel_ir.blocks[label] = prev_block
252-
# Add a jump from the last parfor body block to the block
253-
# containing statements after the sentinel.
254-
kernel_ir.blocks[body_last_label].append(
255-
ir.Jump(new_label, loc)
256-
)
257-
break
258-
else:
259-
continue
260-
break
261-
262-
263139
def create_kernel_for_parfor(
264140
lowerer,
265141
parfor_node,
@@ -375,127 +251,28 @@ def create_kernel_for_parfor(
375251
loop_ranges=loop_ranges,
376252
param_dict=param_dict,
377253
)
378-
kernel_ir = kernel_template.kernel_ir
379254

380-
kernel_dispatcher: SPIRVKernelDispatcher = kernel(kernel_template._py_func)
381-
382-
if config.DEBUG_ARRAY_OPT:
383-
print("kernel_ir dump ", type(kernel_ir))
384-
kernel_ir.dump()
385-
print("loop_body dump ", type(loop_body))
386-
_print_body(loop_body)
387-
388-
# rename all variables in kernel_ir afresh
389-
var_table = get_name_var_table(kernel_ir.blocks)
390-
new_var_dict = {}
391-
reserved_names = (
392-
[sentinel_name] + list(param_dict.values()) + legal_loop_indices
255+
kernel_dispatcher: SPIRVKernelDispatcher = kernel(
256+
kernel_template._py_func,
257+
_parfor_args=ParforArguments(loop_body=loop_body),
393258
)
394-
for name, var in var_table.items():
395-
if not (name in reserved_names):
396-
new_var_dict[name] = mk_unique_var(name)
397-
replace_var_names(kernel_ir.blocks, new_var_dict)
398-
if config.DEBUG_ARRAY_OPT:
399-
print("kernel_ir dump after renaming ")
400-
kernel_ir.dump()
401259

402-
kernel_param_types = param_types
403-
404-
if config.DEBUG_ARRAY_OPT:
405-
print(
406-
"kernel_param_types = ",
407-
type(kernel_param_types),
408-
"\n",
409-
kernel_param_types,
410-
)
411-
412-
kernel_stub_last_label = max(kernel_ir.blocks.keys()) + 1
413-
414-
# Add kernel stub last label to each parfor.loop_body label to prevent
415-
# label conflicts.
416-
loop_body = add_offset_to_labels(loop_body, kernel_stub_last_label)
417-
418-
_replace_sentinel_with_parfor_body(kernel_ir, sentinel_name, loop_body)
419-
420-
if config.DEBUG_ARRAY_OPT:
421-
print("kernel_ir last dump before renaming")
422-
kernel_ir.dump()
423-
424-
kernel_ir.blocks = rename_labels(kernel_ir.blocks)
425-
remove_dels(kernel_ir.blocks)
426-
427-
old_alias = flags.noalias
428-
if not has_aliases:
429-
if config.DEBUG_ARRAY_OPT:
430-
print("No aliases found so adding noalias flag.")
431-
flags.noalias = True
432-
433-
remove_dead(kernel_ir.blocks, kernel_ir.arg_names, kernel_ir, typemap)
434-
435-
if config.DEBUG_ARRAY_OPT:
436-
print("kernel_ir after remove dead")
437-
kernel_ir.dump()
438-
439-
# The first argument to a range kernel is a kernel_api.Item object. The
440-
# ``Item`` object is used by the kernel_api.spirv backend to generate the
441-
# correct SPIR-V indexing instructions. Since, the argument is not something
442-
# available originally in the kernel_param_types, we add it at this point to
443-
# make sure the kernel signature matches the actual generated code.
444260
ty_item = ItemType(parfor_dim)
445-
kernel_param_types = (ty_item, *kernel_param_types)
261+
kernel_param_types = (ty_item, *param_types)
446262
kernel_sig = signature(types.none, *kernel_param_types)
447263

448-
if config.DEBUG_ARRAY_OPT:
449-
sys.stdout.flush()
450-
451-
if config.DEBUG_ARRAY_OPT:
452-
print("after DUFunc inline".center(80, "-"))
453-
kernel_ir.dump()
454-
455-
# The ParforLegalizeCFD pass has already ensured that the LHS and RHS
456-
# arrays are on same device. We can take the queue from the first input
457-
# array and use that to compile the kernel.
458-
459-
exec_queue: dpctl.SyclQueue = None
460-
461-
for arg in parfor_args:
462-
obj = typemap[arg]
463-
if isinstance(obj, DpnpNdArray):
464-
filter_string = obj.queue.sycl_device
465-
# FIXME: A better design is required so that we do not have to
466-
# create a queue every time.
467-
exec_queue = dpctl.get_device_cached_queue(filter_string)
468-
469-
if not exec_queue:
470-
raise AssertionError(
471-
"No execution found for parfor. No way to compile the kernel!"
472-
)
473-
474-
sycl_kernel = _compile_kernel_parfor(
475-
exec_queue,
476-
kernel_name,
477-
kernel_ir,
478-
kernel_param_types,
479-
debug=flags.debuginfo,
480-
)
481-
482264
kcres: _SPIRVKernelCompileResult = kernel_dispatcher.get_compile_result(
483265
types.void(*kernel_param_types) # kernel signature
484266
)
485267
kernel_module: SPIRVKernelModule = kcres.kernel_device_ir_module
486268

487-
flags.noalias = old_alias
488-
489269
if config.DEBUG_ARRAY_OPT:
490270
print("kernel_sig = ", kernel_sig)
491271

492272
return ParforKernel(
493-
name=kernel_name,
494-
kernel=sycl_kernel,
495273
signature=kernel_sig,
496274
kernel_args=parfor_args,
497275
kernel_arg_types=func_arg_types,
498-
queue=exec_queue,
499276
kernel_module=kernel_module,
500277
)
501278

0 commit comments

Comments
 (0)