3030from ..descriptor import dpex_kernel_target
3131from ..passes import parfor
3232from ..types .dpnp_ndarray_type import DpnpNdArray
33+ from .kernel_templates import RangeKernelTemplate
3334
3435
35- class GufuncKernel :
36+ class ParforKernel :
3637 def __init__ (
3738 self ,
3839 name ,
@@ -172,131 +173,6 @@ def _replace_var_with_array(vars, loop_body, typemap, calltypes):
172173 typemap [v ] = types .npytypes .Array (el_typ , 1 , "C" )
173174
174175
175- def _dbgprint_after_each_array_assignments (lowerer , loop_body , typemap ):
176- for label , block in loop_body .items ():
177- new_block = block .copy ()
178- new_block .clear ()
179- loc = block .loc
180- scope = block .scope
181- for inst in block .body :
182- new_block .append (inst )
183- # Append print after assignment
184- if isinstance (inst , ir .Assign ):
185- # Only apply to numbers
186- if typemap [inst .target .name ] not in types .number_domain :
187- continue
188-
189- # Make constant string
190- strval = "{} =" .format (inst .target .name )
191- strconsttyp = types .StringLiteral (strval )
192-
193- lhs = ir .Var (scope , mk_unique_var ("str_const" ), loc )
194- assign_lhs = ir .Assign (
195- value = ir .Const (value = strval , loc = loc ), target = lhs , loc = loc
196- )
197- typemap [lhs .name ] = strconsttyp
198- new_block .append (assign_lhs )
199-
200- # Make print node
201- print_node = ir .Print (
202- args = [lhs , inst .target ], vararg = None , loc = loc
203- )
204- new_block .append (print_node )
205- sig = signature (
206- types .none , typemap [lhs .name ], typemap [inst .target .name ]
207- )
208- lowerer .fndesc .calltypes [print_node ] = sig
209- loop_body [label ] = new_block
210-
211-
212- def _generate_kernel_stub_as_string (
213- kernel_name ,
214- parfor_params ,
215- parfor_dim ,
216- legal_loop_indices ,
217- loop_ranges ,
218- param_dict ,
219- has_reduction ,
220- redvars ,
221- typemap ,
222- redvars_dict ,
223- sentinel_name ,
224- ):
225- """Generates a stub dpex kernel for the parfor.
226-
227- Returns:
228- str: A string representing a stub kernel function for the parfor.
229- """
230- kernel_txt = ""
231-
232- # Create the dpex kernel function.
233- kernel_txt += "def " + kernel_name
234- kernel_txt += "(" + (", " .join (parfor_params )) + "):\n "
235- global_id_dim = 0
236- for_loop_dim = parfor_dim
237-
238- if parfor_dim > 3 :
239- raise NotImplementedError
240- global_id_dim = 3
241- else :
242- global_id_dim = parfor_dim
243-
244- for eachdim in range (global_id_dim ):
245- kernel_txt += (
246- " "
247- + legal_loop_indices [eachdim ]
248- + " = "
249- + "dpex.get_global_id("
250- + str (eachdim )
251- + ")\n "
252- )
253-
254- for eachdim in range (global_id_dim , for_loop_dim ):
255- for indent in range (1 + (eachdim - global_id_dim )):
256- kernel_txt += " "
257-
258- start , stop , step = loop_ranges [eachdim ]
259- start = param_dict .get (str (start ), start )
260- stop = param_dict .get (str (stop ), stop )
261- kernel_txt += (
262- "for "
263- + legal_loop_indices [eachdim ]
264- + " in range("
265- + str (start )
266- + ", "
267- + str (stop )
268- + " + 1):\n "
269- )
270-
271- for eachdim in range (global_id_dim , for_loop_dim ):
272- for indent in range (1 + (eachdim - global_id_dim )):
273- kernel_txt += " "
274-
275- # Add the sentinel assignment so that we can find the loop body position
276- # in the IR.
277- kernel_txt += " "
278- kernel_txt += sentinel_name + " = 0\n "
279-
280- # A kernel function does not return anything
281- kernel_txt += " return None\n "
282-
283- return kernel_txt
284-
285-
286- def _wrap_loop_body (loop_body ):
287- blocks = loop_body .copy () # shallow copy is enough
288- first_label = min (blocks .keys ())
289- last_label = max (blocks .keys ())
290- loc = blocks [last_label ].loc
291- blocks [last_label ].body .append (ir .Jump (first_label , loc ))
292- return blocks
293-
294-
295- def _unwrap_loop_body (loop_body ):
296- last_label = max (loop_body .keys ())
297- loop_body [last_label ].body = loop_body [last_label ].body [:- 1 ]
298-
299-
300176def _find_setitems_block (setitems , block , typemap ):
301177 for inst in block .body :
302178 if isinstance (inst , ir .StaticSetItem ) or isinstance (inst , ir .SetItem ):
@@ -403,18 +279,8 @@ def create_kernel_for_parfor(
403279
404280 # Get all the parfor params.
405281 parfor_params = parfor_node .params
406-
407- # Get all parfor reduction vars, and operators.
408282 typemap = lowerer .fndesc .typemap
409283
410- parfor_redvars , parfor_reddict = parfor .get_parfor_reductions (
411- lowerer .func_ir , parfor_node , parfor_params , lowerer .fndesc .calltypes
412- )
413- has_reduction = False if len (parfor_redvars ) == 0 else True
414-
415- if has_reduction :
416- raise NotImplementedError
417-
418284 # Compute just the parfor inputs as a set difference.
419285 parfor_inputs = sorted (list (set (parfor_params ) - set (parfor_outputs )))
420286
@@ -435,7 +301,6 @@ def create_kernel_for_parfor(
435301 # dict of potentially illegal param name to guaranteed legal name.
436302 param_dict = _legalize_names_with_typemap (parfor_params , typemap )
437303 ind_dict = _legalize_names_with_typemap (loop_indices , typemap )
438- redvars_dict = legalize_names (parfor_redvars )
439304
440305 # Compute a new list of legal loop index names.
441306 legal_loop_indices = [ind_dict [v ] for v in loop_indices ]
@@ -477,35 +342,16 @@ def create_kernel_for_parfor(
477342 # Determine the unique names of the kernel functions.
478343 kernel_name = "__numba_parfor_kernel_%s" % (parfor_node .id )
479344
480- kernel_fn_txt = _generate_kernel_stub_as_string (
481- kernel_name ,
482- parfor_params ,
483- parfor_dim ,
484- legal_loop_indices ,
485- loop_ranges ,
486- param_dict ,
487- has_reduction ,
488- parfor_redvars ,
489- typemap ,
490- redvars_dict ,
491- sentinel_name ,
345+ kernel_template = RangeKernelTemplate (
346+ kernel_name = kernel_name ,
347+ kernel_params = parfor_params ,
348+ kernel_rank = parfor_dim ,
349+ ivar_names = legal_loop_indices ,
350+ sentinel_name = sentinel_name ,
351+ loop_ranges = loop_ranges ,
352+ param_dict = param_dict ,
492353 )
493-
494- if config .DEBUG_ARRAY_OPT :
495- print ("kernel_fn_txt = " , type (kernel_fn_txt ), "\n " , kernel_fn_txt )
496- sys .stdout .flush ()
497-
498- # Exec the kernel_fn_txt string into existence.
499- globls = {"dpnp" : dpnp , "numba" : numba , "dpex" : dpex }
500- locls = {}
501- exec (kernel_fn_txt , globls , locls )
502-
503- kernel_fn = locls [kernel_name ]
504-
505- if config .DEBUG_ARRAY_OPT :
506- print ("kernel_fn = " , type (kernel_fn ), "\n " , kernel_fn )
507- # Get the IR for the kernel_fn dpex kernel
508- kernel_ir = compiler .run_frontend (kernel_fn )
354+ kernel_ir = kernel_template .kernel_ir
509355
510356 if config .DEBUG_ARRAY_OPT :
511357 print ("kernel_ir dump " , type (kernel_ir ))
@@ -527,25 +373,21 @@ def create_kernel_for_parfor(
527373 print ("kernel_ir dump after renaming " )
528374 kernel_ir .dump ()
529375
530- gufunc_param_types = param_types
376+ kernel_param_types = param_types
531377
532378 if config .DEBUG_ARRAY_OPT :
533379 print (
534- "gufunc_param_types = " ,
535- type (gufunc_param_types ),
380+ "kernel_param_types = " ,
381+ type (kernel_param_types ),
536382 "\n " ,
537- gufunc_param_types ,
383+ kernel_param_types ,
538384 )
539385
540- gufunc_stub_last_label = max (kernel_ir .blocks .keys ()) + 1
386+ kernel_stub_last_label = max (kernel_ir .blocks .keys ()) + 1
541387
542- # Add gufunc stub last label to each parfor.loop_body label to prevent
388+ # Add kernel stub last label to each parfor.loop_body label to prevent
543389 # label conflicts.
544- loop_body = add_offset_to_labels (loop_body , gufunc_stub_last_label )
545-
546- # If enabled, add a print statement after every assignment.
547- if config .DEBUG_ARRAY_OPT_RUNTIME :
548- _dbgprint_after_each_array_assignments (lowerer , loop_body , typemap )
390+ loop_body = add_offset_to_labels (loop_body , kernel_stub_last_label )
549391
550392 _replace_sentinel_with_parfor_body (kernel_ir , sentinel_name , loop_body )
551393
@@ -565,10 +407,10 @@ def create_kernel_for_parfor(
565407 remove_dead (kernel_ir .blocks , kernel_ir .arg_names , kernel_ir , typemap )
566408
567409 if config .DEBUG_ARRAY_OPT :
568- print ("gufunc_ir after remove dead" )
410+ print ("kernel_ir after remove dead" )
569411 kernel_ir .dump ()
570412
571- kernel_sig = signature (types .none , * gufunc_param_types )
413+ kernel_sig = signature (types .none , * kernel_param_types )
572414
573415 if config .DEBUG_ARRAY_OPT :
574416 sys .stdout .flush ()
@@ -597,7 +439,7 @@ def create_kernel_for_parfor(
597439 exec_queue ,
598440 kernel_name ,
599441 kernel_ir ,
600- gufunc_param_types ,
442+ kernel_param_types ,
601443 debug = flags .debuginfo ,
602444 )
603445
@@ -606,7 +448,7 @@ def create_kernel_for_parfor(
606448 if config .DEBUG_ARRAY_OPT :
607449 print ("kernel_sig = " , kernel_sig )
608450
609- return GufuncKernel (
451+ return ParforKernel (
610452 name = kernel_name ,
611453 kernel = sycl_kernel ,
612454 signature = kernel_sig ,
0 commit comments