Conversation

@tzj-fxz tzj-fxz commented Oct 20, 2025

As title
For issue #1039

Summary by CodeRabbit

  • New Features

    • Added bitwise reduction operations: AND, OR, and XOR are now available for GPU-accelerated computations.
  • Tests

    • Added comprehensive test coverage for bitwise reduction operations across multiple configurations.

@tzj-fxz tzj-fxz requested a review from LeiWang1999 October 20, 2025 07:53
@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀


coderabbitai bot commented Oct 20, 2025

Walkthrough

This PR adds support for three bitwise reduction operations (BitAnd, BitOr, BitXor) across the compilation pipeline, from type definitions through codegen to Python API exposure. Additionally, it updates the flash attention backward example's postprocessing function signature to accept and return three tensors instead of one.

Changes

  • Bitwise Reduction Type Definitions — src/op/reduce.h
    Added three new enum values (kBitAnd, kBitOr, kBitXor) to ReduceTypeEnum and corresponding query methods (isBitAnd(), isBitOr(), isBitXor()) to ReduceTypeNode. Extended the ReduceType string-based constructor to recognize the "bitand", "bitor", and "bitxor" keys.
  • Bitwise Reduction Functors — src/tl_templates/cuda/reduce.h
    Added three new templated reduction functor structs: BitAndOp, BitOrOp, and BitXorOp, each implementing the corresponding bitwise operation.
  • Bitwise Reduction Codegen — src/op/reduce.cc
    Integrated bitwise reduction support throughout the codegen pipeline: init-value calculation, per-step reduction logic, reducer selection in the code generator, lowering transformations, and the final write-back to the destination buffer, including handling for the clear vs. non-clear paths.
  • Python API Exposure — tilelang/language/reduce.py, tilelang/language/__init__.py
    Added three new public reduction helpers (reduce_bitand, reduce_bitor, reduce_bitxor) that delegate to the core reduce function with the corresponding operation names, and exported them via the language package __init__.py.
  • Bitwise Reduction Tests — testing/python/math/test_math_bitwise_reduce.py
    New test module introducing a bitwise_reduce() kernel builder, a run_single_bitwise_reduce() test runner, and a test_bitwise_reduce_ops() test function covering all three bitwise operations under both clear=True and clear=False configurations.
  • Flash Attention Example Update — examples/flash_attention/example_gqa_bwd_tma_reduce.py
    Updated the postprocessing call from single-argument/single-output dq = mod_post(dq) to three-argument/three-output dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32)).
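The neutral init values and functors described above can be modeled in plain Python. This is a reference sketch of the intended semantics, not the TileLang implementation; the function names are illustrative only:

```python
from functools import reduce
import operator

def init_value(op_name, bits=32, signed=True):
    # Mirror the codegen's choice of neutral element: AND starts from
    # all-ones, OR and XOR start from zero.
    if op_name == "bitand":
        return -1 if signed else (1 << bits) - 1
    if op_name in ("bitor", "bitxor"):
        return 0
    raise ValueError(f"unknown reduce op: {op_name}")

OPS = {"bitand": operator.and_, "bitor": operator.or_, "bitxor": operator.xor}

def bitwise_reduce(values, op_name, bits=32, signed=True):
    # Fold each element into the accumulator, as the CUDA functors do per step.
    return reduce(OPS[op_name], values, init_value(op_name, bits, signed))
```

With the AND identity being all-ones, reducing an empty range correctly leaves the accumulator untouched, which is why the init value matters for the clear=True path.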

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~22 minutes

The changes involve coordinated modifications across multiple interconnected layers (type system, codegen, tests, API). While each individual piece follows established patterns (the new reduction types mirror existing ones), the scope spans the full compilation stack, requiring attention to consistency across header definitions, template implementations, and codegen logic. The new test module adds complexity but provides good coverage validation.

Possibly related PRs

  • tile-ai/tilelang#969: Updates flash attention backward example's postprocessing function to accept and return three tensors, aligning with the related change in this PR.

Suggested reviewers

  • LeiWang1999

Poem

🐰 Three bitwise friends now hop along,
And-Or-Xor, a cryptic song!
From reduce.h to kernels bright,
Post-processing tensors take their flight. ✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 40.00%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed — Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed — The title "[Feature] Support Reduce operators for bitwise and/or/xor" is clear, specific, and directly tied to the primary objective of the changeset. The PR introduces support for three new bitwise reduction operations (AND, OR, XOR) across the entire codebase: core compiler infrastructure in src/op/reduce.h and src/op/reduce.cc, CUDA implementations in src/tl_templates/cuda/reduce.h, the public API in tilelang/language/reduce.py, tests, and examples. The title concisely captures this scope without unnecessary detail, and a teammate scanning history would immediately understand that this PR adds new bitwise reduction types.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/op/reduce.cc (1)

178-187: Add an early dtype guard for bitwise reductions.

Preempt CUDA template instantiation on non-integer types. This fails fast at lowering instead of during compilation.

 Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
+  // Sanity: bitwise reductions only make sense for integer/bool dtypes
+  if (type->isBitAnd() || type->isBitOr() || type->isBitXor()) {
+    auto dd = this->dst->dtype;
+    ICHECK(dd.is_int() || dd.is_uint() || dd.is_bool())
+        << "Bitwise reductions require integer/bool dtype, got " << dd;
+  }
🧹 Nitpick comments (9)
src/op/reduce.cc (1)

237-243: Clarify always-duplicate choice for BitAnd; consider symmetry with Or/Xor.

You set need_duplicate:

  • BitAnd: always true.
  • BitOr/Xor: only when !clear.

That’s fine if BitAnd has an in-place hazard, but it’d help to document why BitAnd cannot do the in-place “clear=True” path that Or/Xor use. If not strictly required, aligning the policy across bitwise ops would simplify reasoning.

Also applies to: 339-361

src/tl_templates/cuda/reduce.h (1)

25-41: Optional: constrain functors to integral/bool T to catch misuse at compile-time.

If these ever instantiate with non-integer types, you’ll hit cryptic errors. A minimal static assert inside operator() makes failures clearer. Alternatively, rely on the new IR dtype checks.

 struct BitAndOp {
-  template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
+  template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
+    static_assert(std::is_integral<T>::value || std::is_same<T,bool>::value,
+                  "BitAndOp requires integral/bool T");
     return x & y;
   }
 };
 
 struct BitOrOp {
   template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
+    static_assert(std::is_integral<T>::value || std::is_same<T,bool>::value,
+                  "BitOrOp requires integral/bool T");
     return x | y;
   }
 };
 
 struct BitXorOp {
   template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
+    static_assert(std::is_integral<T>::value || std::is_same<T,bool>::value,
+                  "BitXorOp requires integral/bool T");
     return x ^ y;
   }
 };
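The failure mode the suggested static_assert guards against has a direct Python analogue: bitwise operators simply do not exist for floats, so rejecting non-integral operands early yields a clear message instead of a cryptic deep-instantiation error. A minimal sketch of that idea (illustrative, not TileLang code):

```python
def checked_bitand(x, y):
    # Mirror the proposed static_assert: fail with a readable message
    # rather than letting an unsupported type propagate downstream.
    if not isinstance(x, (int, bool)) or not isinstance(y, (int, bool)):
        raise TypeError("bitwise AND requires integral/bool operands")
    return x & y
```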
tilelang/language/reduce.py (2)

14-26: Update reduce() docstring to include the new reduce types.

The “reduce_type” docs still list only max/min/sum/abssum. Please add bitand/bitor/bitxor.

-        reduce_type (str): Type of reduction ('max', 'min', 'sum', 'abssum')
+        reduce_type (str): Type of reduction ('max', 'min', 'sum', 'abssum', 'bitand', 'bitor', 'bitxor')

142-155: Document dtype expectations for bitwise reducers.

These ops require integer or boolean buffers; clarify in the docstrings to prevent misuse.

-def reduce_bitand(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
-    """Perform reduce bitwise-and on input buffer, store the result to output buffer.
+def reduce_bitand(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
+    """Perform bitwise-AND reduction along `dim`.
+    Note: buffer/out dtypes must be integer or boolean.
@@
-def reduce_bitor(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
-    """Perform reduce bitwise-or on input buffer, store the result to output buffer.
+def reduce_bitor(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
+    """Perform bitwise-OR reduction along `dim`.
+    Note: buffer/out dtypes must be integer or boolean.
@@
-def reduce_bitxor(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
-    """Perform reduce bitwise-xor on input buffer, store the result to output buffer.
+def reduce_bitxor(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
+    """Perform bitwise-XOR reduction along `dim`.
+    Note: buffer/out dtypes must be integer or boolean.

Also, consider adding a small test covering boolean tensors for each reducer to validate semantics.

Also applies to: 157-170, 172-185

examples/flash_attention/example_gqa_bwd_tma_reduce.py (2)

446-447: Avoid computing K/V in postprocess when using split-path; add a dQ-only post kernel or reuse buffers

Calling mod_post(dq, zeros_like(k), zeros_like(v)) allocates and moves two large tensors only to drop them. Prefer:

  • Define a flashattn_bwd_postprocess_dqonly variant with out_idx=[3] returning just dQ_out, and use it here; or
  • Reuse small, persistent scratch buffers for K/V to avoid per-call allocs.

This reduces memory traffic and kernel time in the split path without changing numerics.


362-366: Use dk_shared for the final copy (match the dv path)

dv copies via dv_shared (swizzled), but dk copies from dk directly. For consistency and potential coalescing benefits, copy from dk_shared:

 T.copy(dk, dk_shared)
-T.copy(dk, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
+T.copy(dk_shared, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
testing/python/math/test_math_bitwise_reduce.py (3)

13-21: Remove unused name parameter to satisfy lint and simplify API

name isn’t used inside bitwise_reduce (Ruff ARG001). Drop it from the signature and call site:

- def bitwise_reduce(
-     M, N, block_M, block_N, name, func, clear=True,
- ):
+ def bitwise_reduce(
+     M, N, block_M, block_N, func, clear=True,
+ ):
-    kernel = bitwise_reduce(M, N, block_M, block_N, name, func, clear)
+    kernel = bitwise_reduce(M, N, block_M, block_N, func, clear)

Also applies to: 45-52


104-111: Strengthen clear=False coverage by pre-filling B with non-neutral values

Currently clear=False uses neutral initial B, so behavior equals clear=True. To verify accumulation semantics, initialize B with non-neutral values and include them in the CPU expected:

  • For bitand: start from a non-all-ones mask (e.g., alternating bits).
  • For bitor/bitxor: start from a non-zero pattern.

I can draft a small patch if desired.
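The point can be made concrete with a pure-Python model of the two paths (an assumed accumulate-into-destination semantics for clear=False, not the kernel itself):

```python
import operator
from functools import reduce

OPS = {"bitand": operator.and_, "bitor": operator.or_, "bitxor": operator.xor}

def expected(row, initial_b, op_name, clear):
    # clear=True discards B's prior contents; clear=False folds them in.
    acc = reduce(OPS[op_name], row)
    return acc if clear else OPS[op_name](initial_b, acc)
```

Only when initial_b is non-neutral do the two paths produce different results, which is exactly the divergence the strengthened test should assert.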


66-75: Optionally include bit 31 in patterns

Use j % 32 to exercise the sign bit in int32:

-            col_pattern = (1 << (j % 31))
+            col_pattern = (1 << (j % 32))

This expands coverage for AND/OR/XOR on the highest bit.
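A quick sanity check of the pattern change: with j % 31 the shift amount never reaches 31, so bit 31 is never exercised, while j % 32 covers every bit position of an int32:

```python
def bits_covered(modulus, n_cols=64):
    # Bit positions exercised by the column pattern (1 << (j % modulus)).
    return {j % modulus for j in range(n_cols)}
```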

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 27701c3 and 0395ea1.

📒 Files selected for processing (7)
  • examples/flash_attention/example_gqa_bwd_tma_reduce.py (1 hunks)
  • src/op/reduce.cc (7 hunks)
  • src/op/reduce.h (3 hunks)
  • src/tl_templates/cuda/reduce.h (1 hunks)
  • testing/python/math/test_math_bitwise_reduce.py (1 hunks)
  • tilelang/language/__init__.py (1 hunks)
  • tilelang/language/reduce.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
src/op/reduce.cc (1)
tilelang/language/tir/op.py (3)
  • bitwise_and (2403-2422)
  • bitwise_or (2444-2463)
  • bitwise_xor (2466-2485)
tilelang/language/__init__.py (1)
tilelang/language/reduce.py (3)
  • reduce_bitand (142-154)
  • reduce_bitor (157-169)
  • reduce_bitxor (172-184)
testing/python/math/test_math_bitwise_reduce.py (5)
tilelang/jit/__init__.py (1)
  • jit (244-317)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-104)
tilelang/language/allocate.py (2)
  • alloc_shared (21-36)
  • alloc_fragment (53-64)
tilelang/language/copy.py (1)
  • copy (10-86)
tilelang/language/reduce.py (3)
  • reduce_bitand (142-154)
  • reduce_bitor (157-169)
  • reduce_bitxor (172-184)
🪛 Ruff (0.14.0)
tilelang/language/__init__.py

61-61: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


62-62: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


63-63: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

testing/python/math/test_math_bitwise_reduce.py

18-18: Unused function argument: name

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (1)
src/op/reduce.h (1)

24-27: New bitwise ReduceType additions look correct.

Enum values, query helpers, and string mapping are consistent with existing patterns.

Also applies to: 56-59, 77-83

Comment on lines +73 to +85
} else if (type->isBitAnd()) {
if (is_int) {
return make_const(dst->dtype, -1);
} else if (is_uint) {
return make_const(dst->dtype, (1 << bits) - 1);
} else {
// Should not arrive here
return make_const(dst->dtype, -INFINITY);
}
} else if (type->isBitOr()) {
return make_zero(dst->dtype);
} else if (type->isBitXor()) {
return make_zero(dst->dtype);

⚠️ Potential issue | 🔴 Critical

Bitwise init: guard dtype and avoid 64-bit shift UB; add bool handling.

  • Current (1 << bits) - 1 overflows/UB when bits=64, and bool falls into the “-INFINITY” path.
  • Without dtype guards, bitwise reducers can instantiate CUDA functors on float/half and fail compilation.

Fix by asserting integer/bool dtypes and computing all‑ones safely.

   } else if (type->isAbsMax()) {
     return make_const(dst->dtype, 0);
-  } else if (type->isBitAnd()) {
-    if (is_int) {
-      return make_const(dst->dtype, -1);
-    } else if (is_uint) {
-      return make_const(dst->dtype, (1 << bits) - 1);
-    } else {
-      // Should not arrive here
-      return make_const(dst->dtype, -INFINITY);
-    }
-  } else if (type->isBitOr()) {
-    return make_zero(dst->dtype);
-  } else if (type->isBitXor()) {
-    return make_zero(dst->dtype);
+  } else if (type->isBitAnd()) {
+    bool is_bool = dst_dtype.is_bool();
+    ICHECK(is_int || is_uint || is_bool)
+        << "Bitwise AND reduction requires integer or bool dtype, got " << dst_dtype;
+    if (is_int) {
+      return make_const(dst->dtype, -1);
+    } else if (is_uint) {
+      uint64_t mask = (bits >= 64) ? ~uint64_t(0) : ((uint64_t(1) << bits) - 1);
+      return make_const(dst->dtype, mask);
+    } else { // bool
+      return make_const(dst->dtype, true);
+    }
+  } else if (type->isBitOr()) {
+    bool is_bool = dst_dtype.is_bool();
+    ICHECK(is_int || is_uint || is_bool)
+        << "Bitwise OR reduction requires integer or bool dtype, got " << dst_dtype;
+    return make_zero(dst->dtype);
+  } else if (type->isBitXor()) {
+    bool is_bool = dst_dtype.is_bool();
+    ICHECK(is_int || is_uint || is_bool)
+        << "Bitwise XOR reduction requires integer or bool dtype, got " << dst_dtype;
+    return make_zero(dst->dtype);

Committable suggestion skipped: line range outside the PR's diff.
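The guarded mask computation in the suggestion can be sanity-checked in Python. Python integers are arbitrary-precision, so this only models the intent of the C++ fix rather than reproducing the UB:

```python
def all_ones_mask(bits):
    # Mirror the guarded C++ form: never shift a 64-bit word by >= 64,
    # which is undefined behavior in C/C++.
    if bits >= 64:
        return 0xFFFFFFFFFFFFFFFF
    return (1 << bits) - 1
```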

Comment on lines +29 to +40
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), "int32")
A_fragment = T.alloc_fragment((block_M, block_N), "int32")
B_shared = T.alloc_shared((block_M,), "int32")
B_fragment = T.alloc_fragment((block_M), "int32")
T.copy(A[by * block_M, bx * block_N], A_shared)
T.copy(A_shared, A_fragment)
T.copy(B[by * block_M], B_shared)
T.copy(B_shared, B_fragment)
func(A_fragment, B_fragment, clear=clear)
T.copy(B_fragment, Output[by * block_M])


⚠️ Potential issue | 🟠 Major

Block copies should slice regions, not scalar indices

These T.copy calls index scalars as sources/targets where block regions are intended. Slice the regions to make extents explicit and tail-safe:

-            T.copy(A[by * block_M, bx * block_N], A_shared)
+            T.copy(A[by * block_M:(by + 1) * block_M,
+                     bx * block_N:(bx + 1) * block_N], A_shared)
             T.copy(A_shared, A_fragment)
-            T.copy(B[by * block_M], B_shared)
+            T.copy(B[by * block_M:(by + 1) * block_M], B_shared)
             T.copy(B_shared, B_fragment)
             func(A_fragment, B_fragment, clear=clear)
-            T.copy(B_fragment, Output[by * block_M])
+            T.copy(B_fragment, Output[by * block_M:(by + 1) * block_M])

This mirrors slicing used elsewhere in the codebase and avoids ambiguous pointer-style copies. As per coding guidelines.

🤖 Prompt for AI Agents
testing/python/math/test_math_bitwise_reduce.py around lines 29 to 40: several
T.copy calls are using scalar indices where a blocked region is intended; change
the source and destination arguments to use explicit slices for the block
extents (e.g. row and column ranges of length block_M/block_N) so copies operate
on the full tile and are tail-safe and consistent with the rest of the codebase
— replace scalar-indexed A[by * block_M, bx * block_N] and A_shared, B[by *
block_M] and B_shared, and B_fragment/Output indexing with proper slice ranges
covering block_M and block_N extents respectively.
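The slicing point can be illustrated with plain Python lists: explicit row/column ranges make the copied extent unambiguous and clamp safely at the boundary, which is the "tail-safe" property the review asks for (an analogy, not TileLang copy semantics):

```python
def copy_tile(A, r0, c0, block_m, block_n):
    # Explicit slice extents: the copied region is exactly the tile,
    # and slicing clamps automatically when the tile overruns the edge.
    return [row[c0:c0 + block_n] for row in A[r0:r0 + block_m]]
```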

Comment on lines +61 to +63
reduce_bitand, # noqa: F401
reduce_bitor, # noqa: F401
reduce_bitxor, # noqa: F401

⚠️ Potential issue | 🟡 Minor

Remove unused noqa markers.

These three exports are valid; the trailing “# noqa: F401” is unnecessary and flagged by Ruff (RUF100). Please drop them.

-    reduce_bitand,  # noqa: F401
-    reduce_bitor,  # noqa: F401
-    reduce_bitxor,  # noqa: F401
+    reduce_bitand,
+    reduce_bitor,
+    reduce_bitxor,
🤖 Prompt for AI Agents
In tilelang/language/__init__.py around lines 61 to 63, the three exported names
reduce_bitand, reduce_bitor, and reduce_bitxor include unnecessary "# noqa:
F401" markers; remove those trailing noqa comments so the exports remain but the
redundant suppression is gone, leaving the names listed without the "# noqa:
F401" suffix.

@LeiWang1999 LeiWang1999 merged commit ba410ae into tile-ai:main Oct 20, 2025
6 checks passed
RubiaCx pushed a commit to RubiaCx/tilelang that referenced this pull request Nov 24, 2025
* [Feature] Support Reduce operators for bitwise and/or/xor

* [Lint]
