Skip to content

Commit 5529363

Browse files
authored
[Bugfix] Expose alloc_reducer definition to the python side (#802)
- Introduced a new function `alloc_reducer` to allocate a reducer buffer with specified shape, data type, and reduction operation (sum, max, min). - Added detailed documentation for the function, including usage instructions and parameter descriptions. - Ensured that the function supports replication strategies and includes assertions for valid operation types and replication options. This enhancement improves the functionality of buffer management in TileLang, facilitating efficient reduction operations in parallel loops.
1 parent 91a7bb2 commit 5529363

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

tilelang/language/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
alloc_shared, # noqa: F401
4242
alloc_fragment, # noqa: F401
4343
alloc_barrier, # noqa: F401
44+
alloc_reducer, # noqa: F401
4445
)
4546
from .copy import copy, c2d_im2col # noqa: F401
4647
from .gemm import GemmWarpPolicy, gemm, gemm_v2 # noqa: F401

tilelang/language/allocate.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,40 @@ def alloc_barrier(arrive_count: int):
8787
T.Buffer: A TVM buffer object allocated as a barrier
8888
"""
8989
return T.alloc_buffer([arrive_count], "uint64", scope="shared.barrier")
90+
91+
92+
def alloc_reducer(shape, dtype, op="sum", replication=None):
93+
"""
94+
Allocate a reducer buffer.
95+
96+
Modifications needs to conform with `op`,
97+
such as `op="sum"` requires `reducer[...] += ...` and
98+
`op="max"` requires `reducer[...] = T.max(reducer[...], ...)`.
99+
100+
Only after T.fill with proper initializer the reduction may begin;
101+
only after T.finalize_reducer the partial results will be available.
102+
103+
For `op="sum"`, filled value must be 0; for min and max, the filled initializer will become max or min clamper correspondingly.
104+
You may want to use `T.max_value` for min and `T.min_value` for max.
105+
106+
Args:
107+
shape (tuple): The shape of the buffer to allocate
108+
dtype (str): The data type of the buffer (e.g., 'float32', 'int32')
109+
op (str): The reduce operation corresponded with the reducer
110+
replication (str | None): Replication strategy, can be "all" or "none". Defaults to not specified, and the compiler will do whatever it want.
111+
112+
Returns:
113+
T.Buffer: A TVM buffer object allocated in thread-private storage, available to reduce values in T.Parallel loops.
114+
"""
115+
import tilelang.language as TL
116+
117+
assert op in ["sum", "max", "min"]
118+
# TODO: support automatic layout
119+
if replication is None:
120+
replication = "none"
121+
assert replication in ["all", "none"]
122+
123+
reducer = T.alloc_buffer(shape, dtype, scope="local.fragment")
124+
TL.block_attr({"reducer_info": {reducer.data: {"rep": replication, "op": op}}})
125+
126+
return reducer

0 commit comments

Comments
 (0)