|
27 | 27 |
|
28 | 28 | from numba_dpex.core import config |
29 | 29 | from numba_dpex.core.decorators import kernel |
| 30 | +from numba_dpex.core.parfors.kernel_parfor_pass import ParforArguments |
30 | 31 | from numba_dpex.core.types.kernel_api.index_space_ids import ItemType |
31 | 32 | from numba_dpex.core.utils.call_kernel_builder import SPIRVKernelModule |
32 | 33 | from numba_dpex.kernel_api_impl.spirv import spirv_generator |
|
44 | 45 | class ParforKernel: |
45 | 46 | def __init__( |
46 | 47 | self, |
47 | | - name, |
48 | | - kernel, |
49 | 48 | signature, |
50 | 49 | kernel_args, |
51 | 50 | kernel_arg_types, |
52 | | - queue: dpctl.SyclQueue, |
53 | 51 | local_accessors=None, |
54 | 52 | work_group_size=None, |
55 | 53 | kernel_module=None, |
56 | 54 | ): |
57 | | - self.name = name |
58 | | - self.kernel = kernel |
59 | 55 | self.signature = signature |
60 | 56 | self.kernel_args = kernel_args |
61 | 57 | self.kernel_arg_types = kernel_arg_types |
62 | | - self.queue = queue |
63 | 58 | self.local_accessors = local_accessors |
64 | 59 | self.work_group_size = work_group_size |
65 | 60 | self.kernel_module = kernel_module |
66 | 61 |
|
67 | 62 |
|
68 | | -def _print_block(block): |
69 | | - for i, inst in enumerate(block.body): |
70 | | - print(" ", i, inst) |
71 | | - |
72 | | - |
73 | | -def _print_body(body_dict): |
74 | | - """Pretty-print a set of IR blocks.""" |
75 | | - for label, block in body_dict.items(): |
76 | | - print("label: ", label) |
77 | | - _print_block(block) |
78 | | - |
79 | | - |
80 | | -def _compile_kernel_parfor( |
81 | | - sycl_queue, kernel_name, func_ir, argtypes, debug=False |
82 | | -): |
83 | | - with target_override(dpex_kernel_target.target_context.target_name): |
84 | | - cres = compile_numba_ir_with_dpex( |
85 | | - pyfunc=func_ir, |
86 | | - pyfunc_name=kernel_name, |
87 | | - args=argtypes, |
88 | | - return_type=None, |
89 | | - debug=debug, |
90 | | - is_kernel=True, |
91 | | - typing_context=dpex_kernel_target.typing_context, |
92 | | - target_context=dpex_kernel_target.target_context, |
93 | | - extra_compile_flags=None, |
94 | | - ) |
95 | | - cres.library.inline_threshold = config.INLINE_THRESHOLD |
96 | | - cres.library._optimize_final_module() |
97 | | - func = cres.library.get_function(cres.fndesc.llvm_func_name) |
98 | | - kernel = dpex_kernel_target.target_context.prepare_spir_kernel( |
99 | | - func, cres.signature.args |
100 | | - ) |
101 | | - spirv_module = spirv_generator.llvm_to_spirv( |
102 | | - dpex_kernel_target.target_context, |
103 | | - kernel.module.__str__(), |
104 | | - kernel.module.as_bitcode(), |
105 | | - ) |
106 | | - |
107 | | - dpctl_create_program_from_spirv_flags = [] |
108 | | - if debug or config.DPEX_OPT == 0: |
109 | | - # if debug is ON we need to pass additional flags to igc. |
110 | | - dpctl_create_program_from_spirv_flags = ["-g", "-cl-opt-disable"] |
111 | | - |
112 | | - # create a sycl::kernel_bundle |
113 | | - kernel_bundle = dpctl_prog.create_program_from_spirv( |
114 | | - sycl_queue, |
115 | | - spirv_module, |
116 | | - " ".join(dpctl_create_program_from_spirv_flags), |
117 | | - ) |
118 | | - # create a sycl::kernel |
119 | | - sycl_kernel = kernel_bundle.get_sycl_kernel(kernel.name) |
120 | | - |
121 | | - return sycl_kernel |
122 | | - |
123 | | - |
124 | 63 | def _legalize_names_with_typemap(names, typemap): |
125 | 64 | """Replace illegal characters in Numba IR var names. |
126 | 65 |
|
@@ -197,69 +136,6 @@ def _replace_var_with_array(vars, loop_body, typemap, calltypes): |
197 | 136 | typemap[v] = types.npytypes.Array(el_typ, 1, "C") |
198 | 137 |
|
199 | 138 |
|
200 | | -def _find_setitems_block(setitems, block, typemap): |
201 | | - for inst in block.body: |
202 | | - if isinstance(inst, ir.StaticSetItem) or isinstance(inst, ir.SetItem): |
203 | | - setitems.add(inst.target.name) |
204 | | - elif isinstance(inst, parfor.Parfor): |
205 | | - _find_setitems_block(setitems, inst.init_block, typemap) |
206 | | - _find_setitems_body(setitems, inst.loop_body, typemap) |
207 | | - |
208 | | - |
209 | | -def _find_setitems_body(setitems, loop_body, typemap): |
210 | | - """ |
211 | | - Find the arrays that are written into (goes into setitems) |
212 | | - """ |
213 | | - for label, block in loop_body.items(): |
214 | | - _find_setitems_block(setitems, block, typemap) |
215 | | - |
216 | | - |
217 | | -def _replace_sentinel_with_parfor_body(kernel_ir, sentinel_name, loop_body): |
218 | | - # new label for splitting sentinel block |
219 | | - new_label = max(loop_body.keys()) + 1 |
220 | | - |
221 | | - # Search all the blocks in the kernel function for the sentinel assignment. |
222 | | - for label, block in kernel_ir.blocks.items(): |
223 | | - for i, inst in enumerate(block.body): |
224 | | - if ( |
225 | | - isinstance(inst, ir.Assign) |
226 | | - and inst.target.name == sentinel_name |
227 | | - ): |
228 | | - # We found the sentinel assignment. |
229 | | - loc = inst.loc |
230 | | - scope = block.scope |
231 | | - # split block across __sentinel__ |
232 | | - # A new block is allocated for the statements prior to the |
233 | | - # sentinel but the new block maintains the current block label. |
234 | | - prev_block = ir.Block(scope, loc) |
235 | | - prev_block.body = block.body[:i] |
236 | | - |
237 | | - # The current block is used for statements after the sentinel. |
238 | | - block.body = block.body[i + 1 :] # noqa: E203 |
239 | | - # But the current block gets a new label. |
240 | | - body_first_label = min(loop_body.keys()) |
241 | | - |
242 | | - # The previous block jumps to the minimum labelled block of the |
243 | | - # parfor body. |
244 | | - prev_block.append(ir.Jump(body_first_label, loc)) |
245 | | - # Add all the parfor loop body blocks to the kernel function's |
246 | | - # IR. |
247 | | - for loop, b in loop_body.items(): |
248 | | - kernel_ir.blocks[loop] = b |
249 | | - body_last_label = max(loop_body.keys()) |
250 | | - kernel_ir.blocks[new_label] = block |
251 | | - kernel_ir.blocks[label] = prev_block |
252 | | - # Add a jump from the last parfor body block to the block |
253 | | - # containing statements after the sentinel. |
254 | | - kernel_ir.blocks[body_last_label].append( |
255 | | - ir.Jump(new_label, loc) |
256 | | - ) |
257 | | - break |
258 | | - else: |
259 | | - continue |
260 | | - break |
261 | | - |
262 | | - |
263 | 139 | def create_kernel_for_parfor( |
264 | 140 | lowerer, |
265 | 141 | parfor_node, |
@@ -375,127 +251,28 @@ def create_kernel_for_parfor( |
375 | 251 | loop_ranges=loop_ranges, |
376 | 252 | param_dict=param_dict, |
377 | 253 | ) |
378 | | - kernel_ir = kernel_template.kernel_ir |
379 | 254 |
|
380 | | - kernel_dispatcher: SPIRVKernelDispatcher = kernel(kernel_template._py_func) |
381 | | - |
382 | | - if config.DEBUG_ARRAY_OPT: |
383 | | - print("kernel_ir dump ", type(kernel_ir)) |
384 | | - kernel_ir.dump() |
385 | | - print("loop_body dump ", type(loop_body)) |
386 | | - _print_body(loop_body) |
387 | | - |
388 | | - # rename all variables in kernel_ir afresh |
389 | | - var_table = get_name_var_table(kernel_ir.blocks) |
390 | | - new_var_dict = {} |
391 | | - reserved_names = ( |
392 | | - [sentinel_name] + list(param_dict.values()) + legal_loop_indices |
| 255 | + kernel_dispatcher: SPIRVKernelDispatcher = kernel( |
| 256 | + kernel_template._py_func, |
| 257 | + _parfor_args=ParforArguments(loop_body=loop_body), |
393 | 258 | ) |
394 | | - for name, var in var_table.items(): |
395 | | - if not (name in reserved_names): |
396 | | - new_var_dict[name] = mk_unique_var(name) |
397 | | - replace_var_names(kernel_ir.blocks, new_var_dict) |
398 | | - if config.DEBUG_ARRAY_OPT: |
399 | | - print("kernel_ir dump after renaming ") |
400 | | - kernel_ir.dump() |
401 | 259 |
|
402 | | - kernel_param_types = param_types |
403 | | - |
404 | | - if config.DEBUG_ARRAY_OPT: |
405 | | - print( |
406 | | - "kernel_param_types = ", |
407 | | - type(kernel_param_types), |
408 | | - "\n", |
409 | | - kernel_param_types, |
410 | | - ) |
411 | | - |
412 | | - kernel_stub_last_label = max(kernel_ir.blocks.keys()) + 1 |
413 | | - |
414 | | - # Add kernel stub last label to each parfor.loop_body label to prevent |
415 | | - # label conflicts. |
416 | | - loop_body = add_offset_to_labels(loop_body, kernel_stub_last_label) |
417 | | - |
418 | | - _replace_sentinel_with_parfor_body(kernel_ir, sentinel_name, loop_body) |
419 | | - |
420 | | - if config.DEBUG_ARRAY_OPT: |
421 | | - print("kernel_ir last dump before renaming") |
422 | | - kernel_ir.dump() |
423 | | - |
424 | | - kernel_ir.blocks = rename_labels(kernel_ir.blocks) |
425 | | - remove_dels(kernel_ir.blocks) |
426 | | - |
427 | | - old_alias = flags.noalias |
428 | | - if not has_aliases: |
429 | | - if config.DEBUG_ARRAY_OPT: |
430 | | - print("No aliases found so adding noalias flag.") |
431 | | - flags.noalias = True |
432 | | - |
433 | | - remove_dead(kernel_ir.blocks, kernel_ir.arg_names, kernel_ir, typemap) |
434 | | - |
435 | | - if config.DEBUG_ARRAY_OPT: |
436 | | - print("kernel_ir after remove dead") |
437 | | - kernel_ir.dump() |
438 | | - |
439 | | - # The first argument to a range kernel is a kernel_api.Item object. The |
440 | | - # ``Item`` object is used by the kernel_api.spirv backend to generate the |
441 | | - # correct SPIR-V indexing instructions. Since the argument is not something |
442 | | - # available originally in the kernel_param_types, we add it at this point to |
443 | | - # make sure the kernel signature matches the actual generated code. |
444 | 260 | ty_item = ItemType(parfor_dim) |
445 | | - kernel_param_types = (ty_item, *kernel_param_types) |
| 261 | + kernel_param_types = (ty_item, *param_types) |
446 | 262 | kernel_sig = signature(types.none, *kernel_param_types) |
447 | 263 |
|
448 | | - if config.DEBUG_ARRAY_OPT: |
449 | | - sys.stdout.flush() |
450 | | - |
451 | | - if config.DEBUG_ARRAY_OPT: |
452 | | - print("after DUFunc inline".center(80, "-")) |
453 | | - kernel_ir.dump() |
454 | | - |
455 | | - # The ParforLegalizeCFD pass has already ensured that the LHS and RHS |
456 | | - # arrays are on the same device. We can take the queue from the first input |
457 | | - # array and use that to compile the kernel. |
458 | | - |
459 | | - exec_queue: dpctl.SyclQueue = None |
460 | | - |
461 | | - for arg in parfor_args: |
462 | | - obj = typemap[arg] |
463 | | - if isinstance(obj, DpnpNdArray): |
464 | | - filter_string = obj.queue.sycl_device |
465 | | - # FIXME: A better design is required so that we do not have to |
466 | | - # create a queue every time. |
467 | | - exec_queue = dpctl.get_device_cached_queue(filter_string) |
468 | | - |
469 | | - if not exec_queue: |
470 | | - raise AssertionError( |
471 | | - "No execution found for parfor. No way to compile the kernel!" |
472 | | - ) |
473 | | - |
474 | | - sycl_kernel = _compile_kernel_parfor( |
475 | | - exec_queue, |
476 | | - kernel_name, |
477 | | - kernel_ir, |
478 | | - kernel_param_types, |
479 | | - debug=flags.debuginfo, |
480 | | - ) |
481 | | - |
482 | 264 | kcres: _SPIRVKernelCompileResult = kernel_dispatcher.get_compile_result( |
483 | 265 | types.void(*kernel_param_types) # kernel signature |
484 | 266 | ) |
485 | 267 | kernel_module: SPIRVKernelModule = kcres.kernel_device_ir_module |
486 | 268 |
|
487 | | - flags.noalias = old_alias |
488 | | - |
489 | 269 | if config.DEBUG_ARRAY_OPT: |
490 | 270 | print("kernel_sig = ", kernel_sig) |
491 | 271 |
|
492 | 272 | return ParforKernel( |
493 | | - name=kernel_name, |
494 | | - kernel=sycl_kernel, |
495 | 273 | signature=kernel_sig, |
496 | 274 | kernel_args=parfor_args, |
497 | 275 | kernel_arg_types=func_arg_types, |
498 | | - queue=exec_queue, |
499 | 276 | kernel_module=kernel_module, |
500 | 277 | ) |
501 | 278 |
|
|
0 commit comments