diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index cd23ea59e9e1..b908b21df478 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1133,6 +1133,10 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str, # get result type shape = input.type.shape + + rank = len(shape) + assert 0 <= axis < rank, f"axis (v={axis}) is out of range, should be within [0, {rank})" + ret_shape = [] for i, s in enumerate(shape): if i != axis: