[Feature] Support Reduce operators for bitwise and/or/xor #1074
Conversation
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough
This PR adds support for three bitwise reduction operations (BitAnd, BitOr, BitXor) across the compilation pipeline, from type definitions through codegen to Python API exposure. Additionally, it updates the flash attention backward example's postprocessing function signature to accept and return three tensors instead of one.
Estimated code review effort
🎯 3 (Moderate) | ⏱️ ~22 minutes
The changes involve coordinated modifications across multiple interconnected layers (type system, codegen, tests, API). While each individual piece follows established patterns (new reduction types mirror existing ones), the scope spans the full compilation stack, requiring attention to consistency across header definitions, template implementations, and codegen logic. The new test module adds complexity but provides good coverage validation.
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/op/reduce.cc (1)
178-187: Add an early dtype guard for bitwise reductions.
Preempt CUDA template instantiation on non-integer types; this fails fast at lowering instead of during compilation.
```diff
 Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
+  // Sanity: bitwise reductions only make sense for integer/bool dtypes
+  if (type->isBitAnd() || type->isBitOr() || type->isBitXor()) {
+    auto dd = this->dst->dtype;
+    ICHECK(dd.is_int() || dd.is_uint() || dd.is_bool())
+        << "Bitwise reductions require integer/bool dtype, got " << dd;
+  }
```
🧹 Nitpick comments (9)
src/op/reduce.cc (1)
237-243: Clarify always-duplicate choice for BitAnd; consider symmetry with Or/Xor.
You set `need_duplicate`:
- BitAnd: always true.
- BitOr/Xor: only when !clear.
That’s fine if BitAnd has an in-place hazard, but it’d help to document why BitAnd cannot do the in-place “clear=True” path that Or/Xor use. If not strictly required, aligning the policy across bitwise ops would simplify reasoning.
Also applies to: 339-361
src/tl_templates/cuda/reduce.h (1)
25-41: Optional: constrain functors to integral/bool T to catch misuse at compile time.
If these ever instantiate with non-integer types, you'll hit cryptic errors. A minimal `static_assert` inside `operator()` makes failures clearer. Alternatively, rely on the new IR dtype checks.
```diff
 struct BitAndOp {
-  template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
+  template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
+    static_assert(std::is_integral<T>::value || std::is_same<T,bool>::value,
+                  "BitAndOp requires integral/bool T");
     return x & y;
   }
 };
 struct BitOrOp {
   template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
+    static_assert(std::is_integral<T>::value || std::is_same<T,bool>::value,
+                  "BitOrOp requires integral/bool T");
     return x | y;
   }
 };
 struct BitXorOp {
   template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
+    static_assert(std::is_integral<T>::value || std::is_same<T,bool>::value,
+                  "BitXorOp requires integral/bool T");
     return x ^ y;
   }
 };
```
tilelang/language/reduce.py (2)
14-26: Update the reduce() docstring to include the new reduce types.
The "reduce_type" docs still list only max/min/sum/abssum. Please add bitand/bitor/bitxor.
```diff
-        reduce_type (str): Type of reduction ('max', 'min', 'sum', 'abssum')
+        reduce_type (str): Type of reduction ('max', 'min', 'sum', 'abssum', 'bitand', 'bitor', 'bitxor')
```
142-155: Document dtype expectations for bitwise reducers.
These ops require integer or boolean buffers; clarify this in the docstrings to prevent misuse.
```diff
-def reduce_bitand(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
-    """Perform reduce bitwise-and on input buffer, store the result to output buffer.
+def reduce_bitand(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
+    """Perform bitwise-AND reduction along `dim`.
+    Note: buffer/out dtypes must be integer or boolean.
@@
-def reduce_bitor(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
-    """Perform reduce bitwise-or on input buffer, store the result to output buffer.
+def reduce_bitor(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
+    """Perform bitwise-OR reduction along `dim`.
+    Note: buffer/out dtypes must be integer or boolean.
@@
-def reduce_bitxor(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
-    """Perform reduce bitwise-xor on input buffer, store the result to output buffer.
+def reduce_bitxor(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
+    """Perform bitwise-XOR reduction along `dim`.
+    Note: buffer/out dtypes must be integer or boolean.
```

Also, consider adding a small test covering boolean tensors for each reducer to validate semantics (a sketch follows below).
Also applies to: 157-170, 172-185
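A minimal sketch of the suggested boolean-tensor check, assuming a torch-based CPU reference in the style of the existing test. Only the expected-value computation is shown; wiring it to a TileLang kernel would follow the pattern in test_math_bitwise_reduce.py.

```python
import torch

def expected_bool_reduce(a: torch.Tensor, op: str) -> torch.Tensor:
    """CPU reference for reducing an (M, N) bool tensor along dim=1."""
    if op == "bitand":
        return a.all(dim=1)                        # AND over a row
    if op == "bitor":
        return a.any(dim=1)                        # OR over a row
    if op == "bitxor":
        return (a.sum(dim=1) % 2).to(torch.bool)   # XOR == parity of True count
    raise ValueError(f"unknown op: {op}")

a = torch.rand(8, 16) > 0.5
for op in ("bitand", "bitor", "bitxor"):
    print(op, expected_bool_reduce(a, op))
```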
examples/flash_attention/example_gqa_bwd_tma_reduce.py (2)
446-447: Avoid computing K/V in postprocess when using the split path; add a dQ-only post kernel or reuse buffers
Calling `mod_post(dq, zeros_like(k), zeros_like(v))` allocates and moves two large tensors only to drop them. Prefer one of the following (a sketch of the scratch-buffer option appears below):
- Define a `flashattn_bwd_postprocess_dqonly` variant with `out_idx=[3]` returning just `dQ_out`, and use it here; or
- Reuse small, persistent scratch buffers for K/V to avoid per-call allocations.
This reduces memory traffic and kernel time in the split path without changing numerics.
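A rough sketch of the scratch-buffer option, assuming `dq`, `k`, `v` are torch tensors and `mod_post` is the compiled postprocess kernel; `_scratch_cache` and `get_scratch` are hypothetical helpers, not existing example code.

```python
import torch

_scratch_cache = {}  # hypothetical per-process cache keyed by (shape, dtype, device)

def get_scratch(like: torch.Tensor) -> torch.Tensor:
    """Return a reusable buffer shaped like `like`, allocating only once."""
    key = (tuple(like.shape), like.dtype, like.device)
    buf = _scratch_cache.get(key)
    if buf is None:
        buf = torch.zeros_like(like)  # contents are dropped by the caller anyway
        _scratch_cache[key] = buf
    return buf

# Split path: K/V outputs are discarded, so reuse scratch instead of fresh zeros:
# dq_out = mod_post(dq, get_scratch(k), get_scratch(v))
```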
362-366: Use `dk_shared` for the final copy (match the `dv` path)
`dv` copies via `dv_shared` (swizzled), but `dk` copies from `dk` directly. For consistency and potential coalescing benefits, copy from `dk_shared`:

```diff
 T.copy(dk, dk_shared)
-T.copy(dk, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
+T.copy(dk_shared, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
```
testing/python/math/test_math_bitwise_reduce.py (3)
13-21: Remove unused `name` parameter to satisfy lint and simplify the API
`name` isn't used inside `bitwise_reduce` (Ruff ARG001). Drop it from the signature and call site:

```diff
-    def bitwise_reduce(
-        M, N, block_M, block_N, name, func, clear=True,
-    ):
+    def bitwise_reduce(
+        M, N, block_M, block_N, func, clear=True,
+    ):
```

```diff
-    kernel = bitwise_reduce(M, N, block_M, block_N, name, func, clear)
+    kernel = bitwise_reduce(M, N, block_M, block_N, func, clear)
```

Also applies to: 45-52
104-111: Strengthen `clear=False` coverage by pre-filling B with non-neutral values
Currently `clear=False` uses a neutral initial `B`, so its behavior equals `clear=True`. To verify accumulation semantics, initialize `B` with non-neutral values and include them in the CPU expectation:
- For bitand: start from a non-all-ones mask (e.g., alternating bits).
- For bitor/bitxor: start from a non-zero pattern.
I can draft a small patch if desired.
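One hedged way to realize this, assuming int32 inputs as in the test; the non-neutral initial `B` is folded into the CPU expectation to match `clear=False` semantics. Shapes and the `0x55555555` seed are illustrative choices, not from the PR.

```python
import torch

M, N = 64, 64
A = torch.randint(0, 2**31 - 1, (M, N), dtype=torch.int32, device="cuda")
B_init = torch.full((M,), 0x55555555, dtype=torch.int32, device="cuda")  # non-neutral seed

def cpu_expected(op: str) -> torch.Tensor:
    """Fold B_init into the reduction, matching clear=False semantics."""
    acc = B_init.clone()
    for j in range(N):
        col = A[:, j]
        if op == "bitand":
            acc = acc & col
        elif op == "bitor":
            acc = acc | col
        else:  # bitxor
            acc = acc ^ col
    return acc
```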
66-75: Optionally include bit 31 in the patterns
Use `j % 32` to exercise the sign bit in int32:

```diff
-            col_pattern = (1 << (j % 31))
+            col_pattern = (1 << (j % 32))
```

This expands coverage for AND/OR/XOR on the highest bit.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- examples/flash_attention/example_gqa_bwd_tma_reduce.py (1 hunks)
- src/op/reduce.cc (7 hunks)
- src/op/reduce.h (3 hunks)
- src/tl_templates/cuda/reduce.h (1 hunks)
- testing/python/math/test_math_bitwise_reduce.py (1 hunks)
- tilelang/language/__init__.py (1 hunks)
- tilelang/language/reduce.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
src/op/reduce.cc (1)
- tilelang/language/tir/op.py (3)
  - bitwise_and (2403-2422)
  - bitwise_or (2444-2463)
  - bitwise_xor (2466-2485)
tilelang/language/__init__.py (1)
- tilelang/language/reduce.py (3)
  - reduce_bitand (142-154)
  - reduce_bitor (157-169)
  - reduce_bitxor (172-184)
testing/python/math/test_math_bitwise_reduce.py (5)
- tilelang/jit/__init__.py (1)
  - jit (244-317)
- tilelang/transform/pass_config.py (1)
  - PassConfigKey (6-104)
- tilelang/language/allocate.py (2)
  - alloc_shared (21-36)
  - alloc_fragment (53-64)
- tilelang/language/copy.py (1)
  - copy (10-86)
- tilelang/language/reduce.py (3)
  - reduce_bitand (142-154)
  - reduce_bitor (157-169)
  - reduce_bitxor (172-184)
🪛 Ruff (0.14.0)
tilelang/language/__init__.py
61-61: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
62-62: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
63-63: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
testing/python/math/test_math_bitwise_reduce.py
18-18: Unused function argument: name
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (1)
src/op/reduce.h (1)
24-27: New bitwise ReduceType additions look correct.
Enum values, query helpers, and string mapping are consistent with existing patterns.
Also applies to: 56-59, 77-83
```cpp
  } else if (type->isBitAnd()) {
    if (is_int) {
      return make_const(dst->dtype, -1);
    } else if (is_uint) {
      return make_const(dst->dtype, (1 << bits) - 1);
    } else {
      // Should not arrive here
      return make_const(dst->dtype, -INFINITY);
    }
  } else if (type->isBitOr()) {
    return make_zero(dst->dtype);
  } else if (type->isBitXor()) {
    return make_zero(dst->dtype);
```
Bitwise init: guard dtype and avoid 64-bit shift UB; add bool handling.
- The current `(1 << bits) - 1` overflows (UB) when bits = 64, and bool falls into the `-INFINITY` path.
- Without dtype guards, bitwise reducers can instantiate CUDA functors on float/half and fail compilation.

Fix by asserting integer/bool dtypes and computing the all-ones constant safely:
```diff
   } else if (type->isAbsMax()) {
     return make_const(dst->dtype, 0);
-  } else if (type->isBitAnd()) {
-    if (is_int) {
-      return make_const(dst->dtype, -1);
-    } else if (is_uint) {
-      return make_const(dst->dtype, (1 << bits) - 1);
-    } else {
-      // Should not arrive here
-      return make_const(dst->dtype, -INFINITY);
-    }
-  } else if (type->isBitOr()) {
-    return make_zero(dst->dtype);
-  } else if (type->isBitXor()) {
-    return make_zero(dst->dtype);
+  } else if (type->isBitAnd()) {
+    bool is_bool = dst_dtype.is_bool();
+    ICHECK(is_int || is_uint || is_bool)
+        << "Bitwise AND reduction requires integer or bool dtype, got " << dst_dtype;
+    if (is_int) {
+      return make_const(dst->dtype, -1);
+    } else if (is_uint) {
+      uint64_t mask = (bits >= 64) ? ~uint64_t(0) : ((uint64_t(1) << bits) - 1);
+      return make_const(dst->dtype, mask);
+    } else {  // bool
+      return make_const(dst->dtype, true);
+    }
+  } else if (type->isBitOr()) {
+    bool is_bool = dst_dtype.is_bool();
+    ICHECK(is_int || is_uint || is_bool)
+        << "Bitwise OR reduction requires integer or bool dtype, got " << dst_dtype;
+    return make_zero(dst->dtype);
+  } else if (type->isBitXor()) {
+    bool is_bool = dst_dtype.is_bool();
+    ICHECK(is_int || is_uint || is_bool)
+        << "Bitwise XOR reduction requires integer or bool dtype, got " << dst_dtype;
+    return make_zero(dst->dtype);
```

Committable suggestion skipped: line range outside the PR's diff.
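As a quick sanity check of the identities this fix encodes (all-ones is the AND neutral element; zero is the OR/XOR neutral element), here is a small NumPy sketch; the dtypes are stand-ins for the IR constants.

```python
import numpy as np

for dt in (np.int32, np.uint32, np.uint64):
    info = np.iinfo(dt)
    all_ones = dt(-1) if info.min < 0 else dt(info.max)  # all bits set, no shift UB
    x = dt(0x1234ABCD)
    assert x & all_ones == x  # AND identity: all-ones
    assert x | dt(0) == x     # OR identity: zero
    assert x ^ dt(0) == x     # XOR identity: zero
print("neutral elements verified")
```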
```python
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), "int32")
            A_fragment = T.alloc_fragment((block_M, block_N), "int32")
            B_shared = T.alloc_shared((block_M,), "int32")
            B_fragment = T.alloc_fragment((block_M), "int32")
            T.copy(A[by * block_M, bx * block_N], A_shared)
            T.copy(A_shared, A_fragment)
            T.copy(B[by * block_M], B_shared)
            T.copy(B_shared, B_fragment)
            func(A_fragment, B_fragment, clear=clear)
            T.copy(B_fragment, Output[by * block_M])
```
Block copies should slice regions, not scalar indices
These T.copy calls index scalars as sources/targets where block regions are intended. Slice the regions to make extents explicit and tail-safe:
```diff
-            T.copy(A[by * block_M, bx * block_N], A_shared)
+            T.copy(A[by * block_M:(by + 1) * block_M,
+                     bx * block_N:(bx + 1) * block_N], A_shared)
             T.copy(A_shared, A_fragment)
-            T.copy(B[by * block_M], B_shared)
+            T.copy(B[by * block_M:(by + 1) * block_M], B_shared)
             T.copy(B_shared, B_fragment)
             func(A_fragment, B_fragment, clear=clear)
-            T.copy(B_fragment, Output[by * block_M])
+            T.copy(B_fragment, Output[by * block_M:(by + 1) * block_M])
```

This mirrors slicing used elsewhere in the codebase and avoids ambiguous pointer-style copies. As per coding guidelines.
🤖 Prompt for AI Agents
testing/python/math/test_math_bitwise_reduce.py around lines 29 to 40: several
T.copy calls are using scalar indices where a blocked region is intended; change
the source and destination arguments to use explicit slices for the block
extents (e.g. row and column ranges of length block_M/block_N) so copies operate
on the full tile and are tail-safe and consistent with the rest of the codebase
— replace scalar-indexed A[by * block_M, bx * block_N] and A_shared, B[by *
block_M] and B_shared, and B_fragment/Output indexing with proper slice ranges
covering block_M and block_N extents respectively.
```python
    reduce_bitand,  # noqa: F401
    reduce_bitor,  # noqa: F401
    reduce_bitxor,  # noqa: F401
```
Remove unused noqa markers.
These three exports are valid; the trailing “# noqa: F401” is unnecessary and flagged by Ruff (RUF100). Please drop them.
```diff
-    reduce_bitand,  # noqa: F401
-    reduce_bitor,  # noqa: F401
-    reduce_bitxor,  # noqa: F401
+    reduce_bitand,
+    reduce_bitor,
+    reduce_bitxor,
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
    reduce_bitand,
    reduce_bitor,
    reduce_bitxor,
```
🧰 Tools
🪛 Ruff (0.14.0)
61-61: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
62-62: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
63-63: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
🤖 Prompt for AI Agents
In tilelang/language/__init__.py around lines 61 to 63, the three exported names
reduce_bitand, reduce_bitor, and reduce_bitxor include unnecessary "# noqa:
F401" markers; remove those trailing noqa comments so the exports remain but the
redundant suppression is gone, leaving the names listed without the "# noqa:
F401" suffix.
* [Feature] Support Reduce operators for bitwise and/or/xor
* [Lint]
As title
For issue #1039
Summary by CodeRabbit
- New Features: bitwise reduction operations (reduce_bitand, reduce_bitor, reduce_bitxor) exposed through the Python API.
- Tests: a new test module covering the bitwise reductions.
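To close, a hedged end-to-end sketch of the new Python API, adapted from the test quoted above with the reviewer's slicing fix applied. The `@tilelang.jit(out_idx=[1])` wrapper and `T.Tensor` signature follow the patterns referenced in this review but are assumptions about the surrounding idiom; the one-block-per-row-tile layout is chosen so blocks never race on `Output`.

```python
import tilelang
import tilelang.language as T
import numpy as np
import torch

@tilelang.jit(out_idx=[1])  # assumed idiom: second argument is the output buffer
def row_bitor(M=256, N=128, block_M=64):
    @T.prim_func
    def main(A: T.Tensor((M, N), "int32"), Output: T.Tensor((M,), "int32")):
        # Each block reduces its full rows, so there is no cross-block
        # accumulation on Output.
        with T.Kernel(T.ceildiv(M, block_M), threads=128) as by:
            A_frag = T.alloc_fragment((block_M, N), "int32")
            B_frag = T.alloc_fragment((block_M,), "int32")
            T.copy(A[by * block_M:(by + 1) * block_M, :], A_frag)
            T.reduce_bitor(A_frag, B_frag, dim=1, clear=True)
            T.copy(B_frag, Output[by * block_M:(by + 1) * block_M])
    return main

a = torch.randint(0, 1 << 30, (256, 128), dtype=torch.int32, device="cuda")
out = row_bitor()(a)
ref = torch.from_numpy(np.bitwise_or.reduce(a.cpu().numpy(), axis=1))
torch.testing.assert_close(out.cpu(), ref)
```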