3 changes: 2 additions & 1 deletion examples/flash_attention/example_gqa_bwd_tma_reduce.py
@@ -443,7 +443,8 @@ def maybe_contiguous(x):
dk = torch.empty(shape_k, dtype=torch.float16, device=q.device)
dv = torch.empty(shape_v, dtype=torch.float16, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = mod_post(dq)
dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32),
torch.zeros_like(v, dtype=torch.float32))
dk, dv = dk.sum(0), dv.sum(0)

return dq, dk, dv, None, None, None
63 changes: 62 additions & 1 deletion src/op/reduce.cc
@@ -70,6 +70,19 @@ PrimExpr ReduceOpNode::MakeInitValue() const {
}
} else if (type->isAbsMax()) {
return make_const(dst->dtype, 0);
} else if (type->isBitAnd()) {
if (is_int) {
return make_const(dst->dtype, -1);
} else if (is_uint) {
return make_const(dst->dtype, (1 << bits) - 1);
} else {
// Should not arrive here
return make_const(dst->dtype, -INFINITY);
}
} else if (type->isBitOr()) {
return make_zero(dst->dtype);
} else if (type->isBitXor()) {
return make_zero(dst->dtype);
Comment on lines +73 to +85
⚠️ Potential issue | 🔴 Critical

Bitwise init: guard dtype and avoid 64-bit shift UB; add bool handling.

  • The current (1 << bits) - 1 overflows or is undefined behavior when bits = 64, and bool dtypes fall through to the “-INFINITY” path.
  • Without dtype guards, bitwise reducers can instantiate CUDA functors on float/half and fail compilation.

Fix by asserting integer/bool dtypes and computing all‑ones safely.

   } else if (type->isAbsMax()) {
     return make_const(dst->dtype, 0);
-  } else if (type->isBitAnd()) {
-    if (is_int) {
-      return make_const(dst->dtype, -1);
-    } else if (is_uint) {
-      return make_const(dst->dtype, (1 << bits) - 1);
-    } else {
-      // Should not arrive here
-      return make_const(dst->dtype, -INFINITY);
-    }
-  } else if (type->isBitOr()) {
-    return make_zero(dst->dtype);
-  } else if (type->isBitXor()) {
-    return make_zero(dst->dtype);
+  } else if (type->isBitAnd()) {
+    bool is_bool = dst_dtype.is_bool();
+    ICHECK(is_int || is_uint || is_bool)
+        << "Bitwise AND reduction requires integer or bool dtype, got " << dst_dtype;
+    if (is_int) {
+      return make_const(dst->dtype, -1);
+    } else if (is_uint) {
+      uint64_t mask = (bits >= 64) ? ~uint64_t(0) : ((uint64_t(1) << bits) - 1);
+      return make_const(dst->dtype, mask);
+    } else { // bool
+      return make_const(dst->dtype, true);
+    }
+  } else if (type->isBitOr()) {
+    bool is_bool = dst_dtype.is_bool();
+    ICHECK(is_int || is_uint || is_bool)
+        << "Bitwise OR reduction requires integer or bool dtype, got " << dst_dtype;
+    return make_zero(dst->dtype);
+  } else if (type->isBitXor()) {
+    bool is_bool = dst_dtype.is_bool();
+    ICHECK(is_int || is_uint || is_bool)
+        << "Bitwise XOR reduction requires integer or bool dtype, got " << dst_dtype;
+    return make_zero(dst->dtype);

Committable suggestion skipped: line range outside the PR's diff.
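
As a standalone sanity check on these init values (a sketch for illustration, not part of the patch): the neutral element of AND is all-ones (-1 in two's complement), while OR and XOR are neutral with respect to zero, which is easy to verify in plain Python.

# Standalone sketch: verify the identity elements chosen by MakeInitValue.
import functools
import operator
import random

random.seed(0)
row = [random.randint(-2**31, 2**31 - 1) for _ in range(16)]

# AND: starting from all-ones leaves the reduction unchanged.
assert functools.reduce(operator.and_, row, -1) == functools.reduce(operator.and_, row)
# OR / XOR: starting from zero leaves the reduction unchanged.
assert functools.reduce(operator.or_, row, 0) == functools.reduce(operator.or_, row)
assert functools.reduce(operator.xor, row, 0) == functools.reduce(operator.xor, row)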

} else {
LOG(FATAL) << "Unsupported reduce type: " << type->type;
}
@@ -91,6 +104,12 @@ PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &lhs,
return Min(lhs, rhs);
} else if (type->isAbsMax()) {
return Max(Max(lhs, rhs), -Min(lhs, rhs));
} else if (type->isBitAnd()) {
return lhs & rhs;
} else if (type->isBitOr()) {
return lhs | rhs;
} else if (type->isBitXor()) {
return lhs ^ rhs;
} else {
LOG(FATAL) << "Unsupported reduce type: " << type->type;
}
@@ -107,6 +126,12 @@ std::string ReduceOpNode::MakeCodegenReducer() const {
return "tl::MinOp";
} else if (type->isAbsMax()) {
return "tl::MaxOp";
} else if (type->isBitAnd()) {
return "tl::BitAndOp";
} else if (type->isBitOr()) {
return "tl::BitOrOp";
} else if (type->isBitXor()) {
return "tl::BitXorOp";
} else {
LOG(FATAL) << "Unsupported reduce type: " << type->type;
return "";
@@ -195,6 +220,12 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
require_init = true;
} else if (this->type->isAbsSum()) {
require_init = true;
} else if (this->type->isBitAnd()) {
require_init = true;
} else if (this->type->isBitOr()) {
require_init = true;
} else if (this->type->isBitXor()) {
require_init = true;
}

Buffer clear_buffer = dst_buffer;
@@ -203,6 +234,12 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
need_duplicate = true;
} else if (this->type->isAbsSum() && !this->clear) {
need_duplicate = true;
} else if (this->type->isBitAnd()) {
need_duplicate = true;
} else if (this->type->isBitOr() && !this->clear) {
need_duplicate = true;
} else if (this->type->isBitXor() && !this->clear) {
need_duplicate = true;
}

if (need_duplicate) {
@@ -213,9 +250,10 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
}

// make reduce-init stmt
if (require_init)
if (require_init) {
stmts.push_back(
BufferStore(clear_buffer, this->MakeInitValue(), dst_indices));
}

// make thread-local reduce
Array<PrimExpr> src_indice_compressed;
@@ -298,6 +336,29 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Add(BufferLoad(dst_buffer, dst_indices),
BufferLoad(clear_buffer, dst_indices)),
dst_indices));
} else if (this->type->isBitAnd()) {
if (!this->clear) {
stmts.push_back(
BufferStore(dst_buffer,
bitwise_and(BufferLoad(dst_buffer, dst_indices),
BufferLoad(clear_buffer, dst_indices)),
dst_indices));
} else {
stmts.push_back(BufferStore(
dst_buffer, BufferLoad(clear_buffer, dst_indices), dst_indices));
}
} else if (this->type->isBitOr()) {
stmts.push_back(
BufferStore(dst_buffer,
bitwise_or(BufferLoad(dst_buffer, dst_indices),
BufferLoad(clear_buffer, dst_indices)),
dst_indices));
} else if (this->type->isBitXor()) {
stmts.push_back(
BufferStore(dst_buffer,
bitwise_xor(BufferLoad(dst_buffer, dst_indices),
BufferLoad(clear_buffer, dst_indices)),
dst_indices));
} else {
ICHECK(false) << "Unsupported reduce type: " << this->type->type;
}
12 changes: 12 additions & 0 deletions src/op/reduce.h
@@ -21,6 +21,9 @@ enum class ReduceTypeEnum : uint8_t {
kMax, ///< Maximum value reduction
kMin, ///< Minimum value reduction
kAbsMax, ///< Maximum absolute value reduction
kBitAnd, ///< Bitwise and reduction
kBitOr, ///< Bitwise or reduction
kBitXor, ///< Bitwise xor reduction
};

/// Node class representing a reduction type
@@ -50,6 +53,9 @@ class ReduceTypeNode : public Object {
bool isMax() const { return type == int(ReduceTypeEnum::kMax); }
bool isMin() const { return type == int(ReduceTypeEnum::kMin); }
bool isAbsMax() const { return type == int(ReduceTypeEnum::kAbsMax); }
bool isBitAnd() const { return type == int(ReduceTypeEnum::kBitAnd); }
bool isBitOr() const { return type == int(ReduceTypeEnum::kBitOr); }
bool isBitXor() const { return type == int(ReduceTypeEnum::kBitXor); }
};

/// Wrapper class for reduction type with string-based construction
@@ -68,6 +74,12 @@ class ReduceType : public ObjectRef {
node->type = int(ReduceTypeEnum::kAbsMax);
} else if (type == "min") {
node->type = int(ReduceTypeEnum::kMin);
} else if (type == "bitand") {
node->type = int(ReduceTypeEnum::kBitAnd);
} else if (type == "bitor") {
node->type = int(ReduceTypeEnum::kBitOr);
} else if (type == "bitxor") {
node->type = int(ReduceTypeEnum::kBitXor);
} else {
LOG(FATAL) << "Invalid reduce type: " << type;
}
18 changes: 18 additions & 0 deletions src/tl_templates/cuda/reduce.h
@@ -22,6 +22,24 @@ struct MinOp {
}
};

struct BitAndOp {
template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
return x & y;
}
};

struct BitOrOp {
template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
return x | y;
}
};

struct BitXorOp {
template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
return x ^ y;
}
};

template <class Reducer, int threads, int scale, int thread_offset = 0,
int all_threads = threads>
struct AllReduce {
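
For context, the functors above are what gets passed to tl::AllReduce as the Reducer parameter. A rough host-side emulation is sketched below; the butterfly exchange pattern is an assumption about how the warp/block all-reduce combines per-thread partials, not the actual CUDA implementation.

# Host-side sketch of a butterfly all-reduce with a BitAndOp-style callable.
# Assumes a power-of-two "thread" count; illustrative only.
from typing import Callable, List

def all_reduce(partials: List[int], op: Callable[[int, int], int]) -> List[int]:
    vals = list(partials)
    offset = len(vals) // 2
    while offset > 0:
        # Each lane combines its value with the lane `offset` positions away.
        vals = [op(vals[i], vals[i ^ offset]) for i in range(len(vals))]
        offset //= 2
    return vals  # every lane ends up holding the full reduction

bit_and = lambda x, y: x & y
print(all_reduce([0b1110, 0b0111, 0b1111, 0b0110], bit_and))  # -> [6, 6, 6, 6]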
115 changes: 115 additions & 0 deletions testing/python/math/test_math_bitwise_reduce.py
@@ -0,0 +1,115 @@
import tilelang
import tilelang.language as T
import torch
import tilelang.testing


@tilelang.jit(
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
},
)
def bitwise_reduce(
M,
N,
block_M,
block_N,
name,
func,
clear=True,
):

@T.prim_func
def reduce_func(
A: T.Tensor((M, N), "int32"),
B: T.Tensor((M), "int32"),
Output: T.Tensor((M), "int32"),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), "int32")
A_fragment = T.alloc_fragment((block_M, block_N), "int32")
B_shared = T.alloc_shared((block_M,), "int32")
B_fragment = T.alloc_fragment((block_M), "int32")
T.copy(A[by * block_M, bx * block_N], A_shared)
T.copy(A_shared, A_fragment)
T.copy(B[by * block_M], B_shared)
T.copy(B_shared, B_fragment)
func(A_fragment, B_fragment, clear=clear)
T.copy(B_fragment, Output[by * block_M])

Comment on lines +29 to +40
⚠️ Potential issue | 🟠 Major

Block copies should slice regions, not scalar indices

These T.copy calls index scalars as sources/targets where block regions are intended. Slice the regions to make extents explicit and tail-safe:

-            T.copy(A[by * block_M, bx * block_N], A_shared)
+            T.copy(A[by * block_M:(by + 1) * block_M,
+                     bx * block_N:(bx + 1) * block_N], A_shared)
             T.copy(A_shared, A_fragment)
-            T.copy(B[by * block_M], B_shared)
+            T.copy(B[by * block_M:(by + 1) * block_M], B_shared)
             T.copy(B_shared, B_fragment)
             func(A_fragment, B_fragment, clear=clear)
-            T.copy(B_fragment, Output[by * block_M])
+            T.copy(B_fragment, Output[by * block_M:(by + 1) * block_M])

This mirrors the slicing used elsewhere in the codebase and avoids ambiguous pointer-style copies, in line with the coding guidelines.

return reduce_func


def run_single_bitwise_reduce(
name,
func,
clear=True,
):
M, N = 32, 32
block_M, block_N = 32, 32
kernel = bitwise_reduce(M, N, block_M, block_N, name, func, clear)

# Generate test data that exercises all bit patterns for robust bitwise reduce testing
a = torch.zeros((M, N), device="cuda", dtype=torch.int32)

# Fill with patterns that will produce meaningful results for bitwise operations:
# - Different bit patterns across rows/columns
# - Mix of 0s and 1s in various positions
# - Some all-1s and all-0s patterns for edge cases
for i in range(M):
for j in range(N):
# Create varied bit patterns:
# Row-based pattern: alternating bits based on row index
row_pattern = (i & 0xF) << (i % 4) # 4-bit patterns shifted by row

# Column-based pattern: different bit positions set based on column
col_pattern = (1 << (j % 31)) # Single bit set at different positions

# Combine patterns with XOR to create diverse bit distributions
# Add some deterministic "noise" based on position
position_factor = (i * N + j) % 256

# Final value combines all patterns
a[i, j] = (row_pattern ^ col_pattern ^ position_factor) & 0xFFFFFFFF

if i % 4 == 0:
a[i, j] &= ~(0x1 << (i // 4))
elif i % 2 == 0:
a[i, j] |= (0x1 << (i // 2))

if name == "reduce_bitand":
expected = torch.full((M,), -1, device="cuda", dtype=torch.int32)
elif name == "reduce_bitor" or name == "reduce_bitxor":
expected = torch.full((M,), 0, device="cuda", dtype=torch.int32)
else:
raise ValueError("Invalid name: {}".format(name))

output = kernel(a, expected)

for i in range(M):
for j in range(N):
if name == "reduce_bitand":
expected[i] = expected[i] & a[i, j]
elif name == "reduce_bitor":
expected[i] = expected[i] | a[i, j]
elif name == "reduce_bitxor":
expected[i] = expected[i] ^ a[i, j]
else:
raise ValueError("Invalid name: {}".format(name))
assert torch.all(output == expected)
print("✓ {} with clear={} test passed".format(name, clear))


@tilelang.testing.requires_cuda
def test_bitwise_reduce_ops():
run_single_bitwise_reduce("reduce_bitand", T.reduce_bitand, clear=True)
run_single_bitwise_reduce("reduce_bitor", T.reduce_bitor, clear=True)
run_single_bitwise_reduce("reduce_bitxor", T.reduce_bitxor, clear=True)
run_single_bitwise_reduce("reduce_bitand", T.reduce_bitand, clear=False)
run_single_bitwise_reduce("reduce_bitor", T.reduce_bitor, clear=False)
run_single_bitwise_reduce("reduce_bitxor", T.reduce_bitxor, clear=False)


if __name__ == "__main__":
tilelang.testing.main()
3 changes: 3 additions & 0 deletions tilelang/language/__init__.py
@@ -58,6 +58,9 @@
reduce_sum, # noqa: F401
reduce_abssum, # noqa: F401
reduce_absmax, # noqa: F401
reduce_bitand, # noqa: F401
reduce_bitor, # noqa: F401
reduce_bitxor, # noqa: F401
Comment on lines +61 to +63
⚠️ Potential issue | 🟡 Minor

Remove unused noqa markers.

These three exports are valid; the trailing “# noqa: F401” is unnecessary and flagged by Ruff (RUF100). Please drop them.

-    reduce_bitand,  # noqa: F401
-    reduce_bitor,  # noqa: F401
-    reduce_bitxor,  # noqa: F401
+    reduce_bitand,
+    reduce_bitor,
+    reduce_bitxor,
🧰 Tools
🪛 Ruff (0.14.0): lines 61–63 — Unused noqa directive (non-enabled: F401); remove unused noqa directive (RUF100).

cumsum, # noqa: F401
finalize_reducer, # noqa: F401
)
45 changes: 45 additions & 0 deletions tilelang/language/reduce.py
@@ -139,6 +139,51 @@ def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: boo
return reduce(buffer, out, "absmax", dim, clear)


def reduce_bitand(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
"""Perform reduce bitwise-and on input buffer, store the result to output buffer.

Args:
buffer (tir.Buffer): The input buffer
out (tir.Buffer): The output buffer
dim (int): The dimension to perform reduce on
clear (bool): If True, start the reduction from the AND identity (all bits set); if False, fold the existing contents of out into the result

Returns:
tir.Call: Handle to the reduction operation
"""
dim = _legalize_dim(buffer, dim)
return reduce(buffer, out, "bitand", dim, clear)


def reduce_bitor(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
"""Perform reduce bitwise-or on input buffer, store the result to output buffer.

Args:
buffer (tir.Buffer): The input buffer
out (tir.Buffer): The output buffer
dim (int): The dimension to perform reduce on
clear (bool): If True, start the reduction from the OR identity (zero); if False, fold the existing contents of out into the result

Returns:
tir.Call: Handle to the reduction operation
"""
dim = _legalize_dim(buffer, dim)
return reduce(buffer, out, "bitor", dim, clear)


def reduce_bitxor(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
"""Perform reduce bitwise-xor on input buffer, store the result to output buffer.

Args:
buffer (tir.Buffer): The input buffer
out (tir.Buffer): The output buffer
dim (int): The dimension to perform reduce on
clear (bool): If True, start the reduction from the XOR identity (zero); if False, fold the existing contents of out into the result

Returns:
tir.Call: Handle to the reduction operation
"""
dim = _legalize_dim(buffer, dim)
return reduce(buffer, out, "bitxor", dim, clear)


@macro
def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) -> tir.PrimExpr:
cumsum_smem = alloc_shared(src.shape, src.dtype, "shared.dyn")
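
For quick reference, a minimal kernel using the new frontend reducers could look like the sketch below. It mirrors the test added in this PR; the single-axis row tiling, block size, thread count, and int32 dtype are illustrative choices, not requirements of the API.

# Usage sketch (illustrative): per-row bitwise AND via T.reduce_bitand.
import tilelang
import tilelang.language as T


@tilelang.jit(out_idx=[-1])
def row_bitand(M, N, block_M):

    @T.prim_func
    def kernel(
            A: T.Tensor((M, N), "int32"),
            Out: T.Tensor((M,), "int32"),
    ):
        # One block per tile of block_M rows; each block reduces a full (block_M, N) tile.
        with T.Kernel(T.ceildiv(M, block_M), threads=128) as by:
            A_shared = T.alloc_shared((block_M, N), "int32")
            A_frag = T.alloc_fragment((block_M, N), "int32")
            out_frag = T.alloc_fragment((block_M,), "int32")
            T.copy(A[by * block_M:(by + 1) * block_M, 0:N], A_shared)
            T.copy(A_shared, A_frag)
            # clear=True starts from the all-ones AND identity; with clear=False
            # the existing contents of out_frag would be folded into the result.
            T.reduce_bitand(A_frag, out_frag, dim=1, clear=True)
            T.copy(out_frag, Out[by * block_M:(by + 1) * block_M])

    return kernel

Calling row_bitand(32, 32, 32)(a) on an int32 CUDA tensor a should then return the per-row bitwise AND, matching the loop-computed reference in the test above.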