diff --git a/tilelang/language/reduce.py b/tilelang/language/reduce.py index fcc01b5a0..e229a7952 100644 --- a/tilelang/language/reduce.py +++ b/tilelang/language/reduce.py @@ -24,6 +24,16 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea Returns: tir.Call: Handle to the reduction operation """ + # input shape: [X, d, Y], expected output shape: [X, Y] or [X, 1, Y] + expected_shapes = [ + buffer.shape[:dim] + buffer.shape[dim + 1:], + buffer.shape[:dim] + [1] + buffer.shape[dim + 1:] + ] + if list(out.shape) not in expected_shapes: + expected_shapes_str = ' or '.join(map(str, expected_shapes)) + raise ValueError( + f"Invalid reduce output shape, buffer shape is {buffer.shape}, dim is {dim}, " + f"output shape is {out.shape}, expected shapes are {expected_shapes_str}") buffer = buffer.access_ptr("r") out = out.access_ptr("w") return tir.call_intrin(