Skip to content

Commit 5ee38fe

Browse files
authored
[INTERPRETER] Support generic reduce (#3412)
This is a simple and general version that matches the underlying triton implementation when reducing within a single thread. `numpy.func` is not used because its reduce op only supports binary inputs, but triton's `combine_fn` can have arbitrary number of inputs.
1 parent bd2f99e commit 5ee38fe

File tree

3 files changed

+110
-65
lines changed

3 files changed

+110
-65
lines changed

python/test/unit/language/test_core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1635,6 +1635,7 @@ def kernel(X, Z1, Z2, N: tl.constexpr):
16351635
np.testing.assert_equal(to_numpy(z2_ref), to_numpy(z2))
16361636

16371637

1638+
@pytest.mark.interpreter
16381639
def test_split_to_scalar(device):
16391640

16401641
@triton.jit
@@ -2602,6 +2603,7 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis):
26022603
np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
26032604

26042605

2606+
@pytest.mark.interpreter
26052607
def test_generic_reduction(device):
26062608

26072609
@triton.jit

python/triton/language/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1968,7 +1968,7 @@ def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _builder=None
19681968
# -----------------------
19691969

19701970

1971-
def _add_scan_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]:
1971+
def _add_scan_docstr(name: str) -> Callable[[T], T]:
19721972

19731973
def _decorator(func: T) -> T:
19741974
docstr = """

python/triton/runtime/interpreter.py

Lines changed: 107 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -484,29 +484,89 @@ def _get_transpose(self):
484484
tensor.T = property(_get_transpose)
485485

486486

487-
def _patch_lang_core(lang, builder):
488-
for name, member in inspect.getmembers(lang):
489-
if tl.core.is_builtin(member):
490-
_patch_attr(lang, name, member, builder)
491-
# reduce is better off with a separate patch due to how
492-
# the builder currently interfaces with custom functions
487+
def _patch_reduce_scan(lang):
493488

494-
def _new_reduce(input, axis, combine_fn, **kwargs):
495-
if axis is not None and axis >= len(input.shape):
496-
raise ValueError(f"axis {axis} out of bounds for shape {input.shape}")
489+
def _check_axis(tensor, axis):
490+
if axis is not None and axis >= len(tensor.shape):
491+
raise ValueError(f"axis {axis} out of bounds for shape {tensor.shape}")
497492

498-
def _to_tensor(ret, dtype):
499-
if ret.shape:
500-
ret_type = tl.block_type(dtype, ret.shape)
493+
def _to_tensor(ret, dtype):
494+
if hasattr(ret, "shape") and ret.shape:
495+
ret_type = tl.block_type(dtype, ret.shape)
496+
else:
497+
ret = np.array([ret], dtype=_get_np_dtype(dtype))
498+
ret_type = dtype
499+
return tl.core.tensor(TensorHandle(ret, dtype), ret_type)
500+
501+
def _generic_reduce(input, axis, combine_fn, keep_dims):
502+
503+
def _check_axis_and_unravel(input, axis):
504+
ret = []
505+
if not isinstance(input, tuple):
506+
input = (input, )
507+
for data in input:
508+
if isinstance(data, tl.core.tensor):
509+
if axis is not None:
510+
_check_axis(data, axis)
511+
ret.append(data)
512+
else:
513+
axis = 0
514+
ret.append(_to_tensor(data.handle.data.flatten(), data.dtype))
515+
return tuple(ret), axis
516+
517+
original_input = input
518+
original_axis = axis
519+
input, axis = _check_axis_and_unravel(input, axis)
520+
input_data = []
521+
output_data = []
522+
input_shape = None
523+
output_shape = None
524+
for arg in input:
525+
if isinstance(arg, tl.core.tensor):
526+
input_shape = arg.handle.data.shape
527+
input_data.append(arg.handle.data)
528+
output_shape = input_shape[0:axis] + input_shape[axis + 1:]
529+
output_data.append(np.zeros(output_shape, dtype=arg.handle.data.dtype))
530+
if not input_shape:
531+
raise ValueError("no tensors found in input")
532+
# Reduce on axis
533+
for i in range(len(input_data[0])):
534+
# Recover input_index from i using input_shape
535+
input_index = np.unravel_index(i, input_shape)
536+
output_index = input_index[0:axis] + input_index[axis + 1:]
537+
input_tuple = tuple(d[input_index] for d in input_data)
538+
if input_index[axis] == 0:
539+
# First element
540+
for j in range(len(output_data)):
541+
output_data[j][output_index] = input_tuple[j]
501542
else:
502-
ret = np.array([ret], dtype=_get_np_dtype(dtype))
503-
ret_type = dtype
504-
return tl.core.tensor(TensorHandle(ret, dtype), ret_type)
505-
506-
def _min_max(input, val_reduce_op, idx_reduce_op=None, axis=None, return_indices_tie_break_left=True,
507-
keepdims=False):
508-
if return_indices_tie_break_left is False:
509-
raise NotImplementedError("return_indices_tie_break_left=False not supported in interpreter mode")
543+
acc_tuple = tuple(o[output_index] for o in output_data)
544+
acc_tuple = combine_fn.fn(*acc_tuple, *input_tuple)
545+
for j in range(len(output_data)):
546+
output_data[j][output_index] = acc_tuple[j]
547+
# Pack output
548+
ret = []
549+
for data in output_data:
550+
if keep_dims:
551+
if original_axis is not None:
552+
data = np.expand_dims(data, axis)
553+
else:
554+
input_shape = original_input[0].handle.data.shape
555+
for _ in range(len(input_shape)):
556+
data = np.expand_dims(data, 0)
557+
558+
elif original_axis is None:
559+
# Take a scalar
560+
data = data.item()
561+
ret.append(_to_tensor(data, input[0].dtype))
562+
return ret[0] if len(ret) == 1 else tuple(ret)
563+
564+
def _new_reduce(input, axis, combine_fn, keep_dims=False, **kwargs):
565+
566+
def _min_max(input, val_reduce_op, idx_reduce_op=None, axis=None, keepdims=False):
567+
# If input is a tuple, it must be (val, index), and we only take val
568+
input = input[0] if isinstance(input, tuple) else input
569+
_check_axis(input, axis)
510570
val = None
511571
idx = None
512572
if val_reduce_op:
@@ -523,53 +583,42 @@ def _min_max(input, val_reduce_op, idx_reduce_op=None, axis=None, return_indices
523583
raise ValueError("val_reduce_op and idx_reduce_op are both None")
524584

525585
def _sum(input, axis=None, keepdims=False):
586+
_check_axis(input, axis)
526587
return _to_tensor(np.sum(input.handle.data, axis=axis, keepdims=keepdims), input.dtype)
527588

528-
keep_dims = kwargs.get("keep_dims", False)
529-
return_indices = kwargs.get("return_indices", False)
530-
return_indices_tile_break_left = kwargs.get("return_indices_tile_break_left", True)
531-
fn = combine_fn.fn.__name__
532589
mapping = {
533-
"_elementwise_min": #
534-
functools.partial(_min_max, val_reduce_op=np.min, idx_reduce_op=np.argmin if return_indices else None,
535-
return_indices_tie_break_left=return_indices_tile_break_left), #
536-
"_elementwise_max": #
537-
functools.partial(_min_max, val_reduce_op=np.max, idx_reduce_op=np.argmax if return_indices else None,
538-
return_indices_tie_break_left=return_indices_tile_break_left), #
539-
"_argmin_combine": #
540-
functools.partial(_min_max, val_reduce_op=None, idx_reduce_op=np.argmin,
541-
return_indices_tie_break_left=return_indices_tile_break_left), #
542-
"_argmax_combine": #
543-
functools.partial(_min_max, val_reduce_op=None, idx_reduce_op=np.argmax,
544-
return_indices_tie_break_left=return_indices_tile_break_left), #
545-
"_sum_combine": _sum
590+
tl.standard._argmin_combine_tie_break_left: #
591+
functools.partial(_min_max, val_reduce_op=np.min, idx_reduce_op=np.argmin), #
592+
tl.standard._argmax_combine_tie_break_left: #
593+
functools.partial(_min_max, val_reduce_op=np.max, idx_reduce_op=np.argmax), #
594+
tl.standard._elementwise_max: functools.partial(_min_max, val_reduce_op=np.max, idx_reduce_op=None), #
595+
tl.standard._elementwise_min: functools.partial(_min_max, val_reduce_op=np.min, idx_reduce_op=None), #
596+
tl.standard._sum_combine: _sum, #
546597
}
547-
if fn not in mapping:
548-
raise ValueError(f"fn {fn} not supported")
549-
return mapping[fn](input, axis=axis, keepdims=keep_dims)
598+
if combine_fn not in mapping:
599+
# Fall back to the slow mode
600+
return _generic_reduce(input, axis, combine_fn, keep_dims)
601+
return mapping[combine_fn](input, axis=axis, keepdims=keep_dims)
550602

551603
def _new_scan(input, axis, combine_fn, **kwargs):
552-
fn = combine_fn.fn.__name__
553604
mapping = {
554-
"_sum_combine": np.cumsum,
605+
tl.standard._sum_combine: np.cumsum,
555606
}
556-
ret = mapping[fn](input.handle.data, axis=axis)
607+
ret = mapping[combine_fn](input.handle.data, axis=axis)
557608
ret_type = tl.block_type(input.dtype, ret.shape)
558609
return tl.core.tensor(TensorHandle(ret, input.dtype), ret_type)
559610

560-
def _new_reduce_scan_wrapper(mode, input, axis=None, **kwargs):
561-
impl_fn = _new_scan if mode.startswith("cum") else _new_reduce
562-
mode = mode[3:] if mode.startswith("cum") else mode
563-
combine_fn = {
564-
"min": tl.standard._elementwise_min,
565-
"max": tl.standard._elementwise_max,
566-
"sum": tl.standard._sum_combine,
567-
"argmin": tl.standard._argmin_combine,
568-
"argmax": tl.standard._argmax_combine,
569-
}
570-
if mode not in combine_fn:
571-
raise ValueError(f"mode {mode} not supported")
572-
return impl_fn(input, axis, combine_fn[mode], **kwargs)
611+
tl.reduce = _new_reduce
612+
tl.associative_scan = _new_scan
613+
# FIXME(Keren): This is a workaround because some core functions use core.reduce but not tl.reduce
614+
tl.core.reduce = _new_reduce
615+
tl.core.associative_scan = _new_scan
616+
617+
618+
def _patch_lang_core(lang, builder):
619+
for name, member in inspect.getmembers(lang):
620+
if tl.core.is_builtin(member):
621+
_patch_attr(lang, name, member, builder)
573622

574623
def _new_to_ir(self, builder):
575624
# We need to specify signedness for integer types in the numpy mode
@@ -611,14 +660,6 @@ def _new_to_ir(self, builder):
611660
return builder.get_double_ty()
612661
raise ValueError(f'fail to convert {self} to ir type')
613662

614-
lang.reduce = _new_reduce
615-
lang.min = functools.partial(_new_reduce_scan_wrapper, "min")
616-
lang.max = functools.partial(_new_reduce_scan_wrapper, "max")
617-
lang.sum = functools.partial(_new_reduce_scan_wrapper, "sum")
618-
lang.argmin = functools.partial(_new_reduce_scan_wrapper, "argmin")
619-
lang.argmax = functools.partial(_new_reduce_scan_wrapper, "argmax")
620-
lang.cumsum = functools.partial(_new_reduce_scan_wrapper, "cumsum")
621-
622663
# can't just map lang.static_range to `range`, because `tl.static_range`
623664
# can get `step` passed by keyword
624665
def _new_range(arg1, arg2=None, step=None, **kwargs):
@@ -638,6 +679,8 @@ def _new_static_assert(cond, msg=""):
638679
lang.static_assert = _new_static_assert
639680
lang.dtype.to_ir = _new_to_ir
640681

682+
_patch_reduce_scan(lang)
683+
641684

642685
def _patch_lang_math(lang):
643686
mapping = {

0 commit comments

Comments
 (0)