@@ -484,29 +484,89 @@ def _get_transpose(self):
 tensor.T = property(_get_transpose)


-def _patch_lang_core(lang, builder):
-    for name, member in inspect.getmembers(lang):
-        if tl.core.is_builtin(member):
-            _patch_attr(lang, name, member, builder)
-    # reduce is better off with a separate patch due to how
-    # the builder currently interfaces with custom functions
+def _patch_reduce_scan(lang):

-    def _new_reduce(input, axis, combine_fn, **kwargs):
-        if axis is not None and axis >= len(input.shape):
-            raise ValueError(f"axis {axis} out of bounds for shape {input.shape}")
+    def _check_axis(tensor, axis):
+        if axis is not None and axis >= len(tensor.shape):
+            raise ValueError(f"axis {axis} out of bounds for shape {tensor.shape}")

-        def _to_tensor(ret, dtype):
-            if ret.shape:
-                ret_type = tl.block_type(dtype, ret.shape)
+    def _to_tensor(ret, dtype):
+        if hasattr(ret, "shape") and ret.shape:
+            ret_type = tl.block_type(dtype, ret.shape)
+        else:
+            ret = np.array([ret], dtype=_get_np_dtype(dtype))
+            ret_type = dtype
+        return tl.core.tensor(TensorHandle(ret, dtype), ret_type)
+
+    def _generic_reduce(input, axis, combine_fn, keep_dims):
+
+        def _check_axis_and_unravel(input, axis):
+            ret = []
+            if not isinstance(input, tuple):
+                input = (input, )
+            for data in input:
+                if isinstance(data, tl.core.tensor):
+                    if axis is not None:
+                        _check_axis(data, axis)
+                        ret.append(data)
+                    else:
+                        axis = 0
+                        ret.append(_to_tensor(data.handle.data.flatten(), data.dtype))
+            return tuple(ret), axis
+
+        original_input = input
+        original_axis = axis
+        input, axis = _check_axis_and_unravel(input, axis)
+        input_data = []
+        output_data = []
+        input_shape = None
+        output_shape = None
+        for arg in input:
+            if isinstance(arg, tl.core.tensor):
+                input_shape = arg.handle.data.shape
+                input_data.append(arg.handle.data)
+                output_shape = input_shape[0:axis] + input_shape[axis + 1:]
+                output_data.append(np.zeros(output_shape, dtype=arg.handle.data.dtype))
+        if not input_shape:
+            raise ValueError("no tensors found in input")
+        # Reduce on axis
+        for i in range(input_data[0].size):  # all flat indices, not just the first dimension
+            # Recover input_index from i using input_shape
+            input_index = np.unravel_index(i, input_shape)
+            output_index = input_index[0:axis] + input_index[axis + 1:]
+            input_tuple = tuple(d[input_index] for d in input_data)
+            if input_index[axis] == 0:
+                # First element
+                for j in range(len(output_data)):
+                    output_data[j][output_index] = input_tuple[j]
             else:
-                ret = np.array([ret], dtype=_get_np_dtype(dtype))
-                ret_type = dtype
-            return tl.core.tensor(TensorHandle(ret, dtype), ret_type)
-
-        def _min_max(input, val_reduce_op, idx_reduce_op=None, axis=None, return_indices_tie_break_left=True,
-                     keepdims=False):
-            if return_indices_tie_break_left is False:
-                raise NotImplementedError("return_indices_tie_break_left=False not supported in interpreter mode")
+                acc_tuple = tuple(o[output_index] for o in output_data)
+                acc_tuple = combine_fn.fn(*acc_tuple, *input_tuple)
+                for j in range(len(output_data)):
+                    output_data[j][output_index] = acc_tuple[j]
+        # Pack output
+        ret = []
+        for data in output_data:
+            if keep_dims:
+                if original_axis is not None:
+                    data = np.expand_dims(data, axis)
+                else:
+                    input_shape = original_input[0].handle.data.shape
+                    for _ in range(len(input_shape)):
+                        data = np.expand_dims(data, 0)
+
+            elif original_axis is None:
+                # Take a scalar
+                data = data.item()
+            ret.append(_to_tensor(data, input[0].dtype))
+        return ret[0] if len(ret) == 1 else tuple(ret)
+
+    def _new_reduce(input, axis, combine_fn, keep_dims=False, **kwargs):
+
+        def _min_max(input, val_reduce_op, idx_reduce_op=None, axis=None, keepdims=False):
+            # If input is a tuple, it must be (val, index), and we only take val
+            input = input[0] if isinstance(input, tuple) else input
+            _check_axis(input, axis)
             val = None
             idx = None
             if val_reduce_op:
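For reference, a standalone NumPy sketch (hypothetical names, not part of the patch) of the flat-index accumulation scheme `_generic_reduce` uses: seed the accumulator with the first element along the reduced axis, then fold every later element in with the combine function.

```python
import numpy as np

# Minimal sketch of _generic_reduce's accumulation scheme, assuming a single
# ndarray input and a plain binary combine function.
def generic_reduce_sketch(data, axis, combine):
    out_shape = data.shape[:axis] + data.shape[axis + 1:]
    out = np.zeros(out_shape, dtype=data.dtype)
    for i in range(data.size):  # every flat index of the input
        idx = np.unravel_index(i, data.shape)
        out_idx = idx[:axis] + idx[axis + 1:]
        if idx[axis] == 0:
            out[out_idx] = data[idx]  # first element seeds the accumulator
        else:
            out[out_idx] = combine(out[out_idx], data[idx])
    return out

x = np.arange(6).reshape(2, 3)
assert (generic_reduce_sketch(x, 1, lambda a, b: a + b) == x.sum(axis=1)).all()
```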
@@ -523,53 +583,42 @@ def _min_max(input, val_reduce_op, idx_reduce_op=None, axis=None, return_indices
             raise ValueError("val_reduce_op and idx_reduce_op are both None")

         def _sum(input, axis=None, keepdims=False):
+            _check_axis(input, axis)
             return _to_tensor(np.sum(input.handle.data, axis=axis, keepdims=keepdims), input.dtype)

-        keep_dims = kwargs.get("keep_dims", False)
-        return_indices = kwargs.get("return_indices", False)
-        return_indices_tile_break_left = kwargs.get("return_indices_tile_break_left", True)
-        fn = combine_fn.fn.__name__
         mapping = {
-            "_elementwise_min":  #
-            functools.partial(_min_max, val_reduce_op=np.min, idx_reduce_op=np.argmin if return_indices else None,
-                              return_indices_tie_break_left=return_indices_tile_break_left),  #
-            "_elementwise_max":  #
-            functools.partial(_min_max, val_reduce_op=np.max, idx_reduce_op=np.argmax if return_indices else None,
-                              return_indices_tie_break_left=return_indices_tile_break_left),  #
-            "_argmin_combine":  #
-            functools.partial(_min_max, val_reduce_op=None, idx_reduce_op=np.argmin,
-                              return_indices_tie_break_left=return_indices_tile_break_left),  #
-            "_argmax_combine":  #
-            functools.partial(_min_max, val_reduce_op=None, idx_reduce_op=np.argmax,
-                              return_indices_tie_break_left=return_indices_tile_break_left),  #
-            "_sum_combine": _sum
+            tl.standard._argmin_combine_tie_break_left:  #
+            functools.partial(_min_max, val_reduce_op=np.min, idx_reduce_op=np.argmin),  #
+            tl.standard._argmax_combine_tie_break_left:  #
+            functools.partial(_min_max, val_reduce_op=np.max, idx_reduce_op=np.argmax),  #
+            tl.standard._elementwise_max: functools.partial(_min_max, val_reduce_op=np.max, idx_reduce_op=None),  #
+            tl.standard._elementwise_min: functools.partial(_min_max, val_reduce_op=np.min, idx_reduce_op=None),  #
+            tl.standard._sum_combine: _sum,  #
         }
-        if fn not in mapping:
-            raise ValueError(f"fn {fn} not supported")
-        return mapping[fn](input, axis=axis, keepdims=keep_dims)
+        if combine_fn not in mapping:
+            # Fall back to the slow mode
+            return _generic_reduce(input, axis, combine_fn, keep_dims)
+        return mapping[combine_fn](input, axis=axis, keepdims=keep_dims)

     def _new_scan(input, axis, combine_fn, **kwargs):
-        fn = combine_fn.fn.__name__
         mapping = {
-            "_sum_combine": np.cumsum,
+            tl.standard._sum_combine: np.cumsum,
         }
-        ret = mapping[fn](input.handle.data, axis=axis)
+        ret = mapping[combine_fn](input.handle.data, axis=axis)
         ret_type = tl.block_type(input.dtype, ret.shape)
         return tl.core.tensor(TensorHandle(ret, input.dtype), ret_type)

-    def _new_reduce_scan_wrapper(mode, input, axis=None, **kwargs):
-        impl_fn = _new_scan if mode.startswith("cum") else _new_reduce
-        mode = mode[3:] if mode.startswith("cum") else mode
-        combine_fn = {
-            "min": tl.standard._elementwise_min,
-            "max": tl.standard._elementwise_max,
-            "sum": tl.standard._sum_combine,
-            "argmin": tl.standard._argmin_combine,
-            "argmax": tl.standard._argmax_combine,
-        }
-        if mode not in combine_fn:
-            raise ValueError(f"mode {mode} not supported")
-        return impl_fn(input, axis, combine_fn[mode], **kwargs)
+    tl.reduce = _new_reduce
+    tl.associative_scan = _new_scan
+    # FIXME(Keren): This is a workaround because some core functions use core.reduce but not tl.reduce
+    tl.core.reduce = _new_reduce
+    tl.core.associative_scan = _new_scan
+
+
+def _patch_lang_core(lang, builder):
+    for name, member in inspect.getmembers(lang):
+        if tl.core.is_builtin(member):
+            _patch_attr(lang, name, member, builder)

     def _new_to_ir(self, builder):
         # We need to specify signedness for integer types in the numpy mode
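Aside: the new `mapping` is keyed on the `tl.standard` combine functions themselves rather than on `combine_fn.fn.__name__` as before. A hypothetical sketch (not from the patch) of the hazard in the old name-keyed dispatch:

```python
# A user-defined combine that happens to share the builtin's name...
def _sum_combine(a, b):
    return max(a, b)  # ...but has different semantics

# Keyed on __name__, the old mapping would dispatch this to np.sum and
# silently compute the wrong result; keyed on the function object, the new
# mapping misses and _new_reduce falls through to _generic_reduce, which
# applies the user's combine_fn faithfully.
```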
@@ -611,14 +660,6 @@ def _new_to_ir(self, builder):
             return builder.get_double_ty()
         raise ValueError(f'fail to convert {self} to ir type')

-    lang.reduce = _new_reduce
-    lang.min = functools.partial(_new_reduce_scan_wrapper, "min")
-    lang.max = functools.partial(_new_reduce_scan_wrapper, "max")
-    lang.sum = functools.partial(_new_reduce_scan_wrapper, "sum")
-    lang.argmin = functools.partial(_new_reduce_scan_wrapper, "argmin")
-    lang.argmax = functools.partial(_new_reduce_scan_wrapper, "argmax")
-    lang.cumsum = functools.partial(_new_reduce_scan_wrapper, "cumsum")
-
     # can't just map lang.static_range to `range`, because `tl.static_range`
     # can get `step` passed by keyword
     def _new_range(arg1, arg2=None, step=None, **kwargs):
@@ -638,6 +679,8 @@ def _new_static_assert(cond, msg=""):
     lang.static_assert = _new_static_assert
     lang.dtype.to_ir = _new_to_ir

+    _patch_reduce_scan(lang)
+

 def _patch_lang_math(lang):
     mapping = {
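End to end, a hypothetical interpreter-mode kernel (not part of this PR) whose combine function has no fast-path entry, so `_new_reduce` routes it through `_generic_reduce`. This assumes a Triton build carrying this patch, with `TRITON_INTERPRET=1` set before `triton` is imported:

```python
import os
os.environ["TRITON_INTERPRET"] = "1"  # must be set before importing triton

import torch
import triton
import triton.language as tl


@triton.jit
def _mul_combine(a, b):
    # Product reduction: no entry in the fast-path mapping above,
    # so _new_reduce falls back to _generic_reduce.
    return a * b


@triton.jit
def prod_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
    x = tl.load(x_ptr + tl.arange(0, BLOCK))
    tl.store(out_ptr, tl.reduce(x, 0, _mul_combine))


x = torch.arange(1, 9, dtype=torch.float32)
out = torch.empty(1, dtype=torch.float32)
prod_kernel[(1, )](x, out, BLOCK=8)
assert out.item() == 40320.0  # 8! = product of 1..8
```

Because the fast path is keyed by function object, `_mul_combine` misses the mapping and exercises the generic path, which applies `_mul_combine.fn` to NumPy scalars one flat index at a time.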