-
Notifications
You must be signed in to change notification settings - Fork 333
[Feature] Support Reduce operators for bitwise and/or/xor #1074
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,115 @@ | ||
| import tilelang | ||
| import tilelang.language as T | ||
| import torch | ||
| import tilelang.testing | ||
|
|
||
|
|
||
| @tilelang.jit( | ||
| out_idx=[-1], | ||
| pass_configs={ | ||
| tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, | ||
| }, | ||
| ) | ||
| def bitwise_reduce( | ||
| M, | ||
| N, | ||
| block_M, | ||
| block_N, | ||
| name, | ||
| func, | ||
| clear=True, | ||
| ): | ||
|
|
||
| @T.prim_func | ||
| def reduce_func( | ||
| A: T.Tensor((M, N), "int32"), | ||
| B: T.Tensor((M), "int32"), | ||
| Output: T.Tensor((M), "int32"), | ||
| ): | ||
| with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): | ||
| A_shared = T.alloc_shared((block_M, block_N), "int32") | ||
| A_fragment = T.alloc_fragment((block_M, block_N), "int32") | ||
| B_shared = T.alloc_shared((block_M,), "int32") | ||
| B_fragment = T.alloc_fragment((block_M), "int32") | ||
| T.copy(A[by * block_M, bx * block_N], A_shared) | ||
| T.copy(A_shared, A_fragment) | ||
| T.copy(B[by * block_M], B_shared) | ||
| T.copy(B_shared, B_fragment) | ||
| func(A_fragment, B_fragment, clear=clear) | ||
| T.copy(B_fragment, Output[by * block_M]) | ||
|
|
||
|
Comment on lines
+29
to
+40
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Block copies should slice regions, not scalar indices. These `T.copy` calls pass scalar indices instead of region slices: - T.copy(A[by * block_M, bx * block_N], A_shared)
+ T.copy(A[by * block_M:(by + 1) * block_M,
+ bx * block_N:(bx + 1) * block_N], A_shared)
T.copy(A_shared, A_fragment)
- T.copy(B[by * block_M], B_shared)
+ T.copy(B[by * block_M:(by + 1) * block_M], B_shared)
T.copy(B_shared, B_fragment)
func(A_fragment, B_fragment, clear=clear)
- T.copy(B_fragment, Output[by * block_M])
+ T.copy(B_fragment, Output[by * block_M:(by + 1) * block_M])This mirrors slicing used elsewhere in the codebase and avoids ambiguous pointer-style copies. As per coding guidelines. 🤖 Prompt for AI Agents |
||
| return reduce_func | ||
|
|
||
|
|
||
| def run_single_bitwise_reduce( | ||
| name, | ||
| func, | ||
| clear=True, | ||
| ): | ||
| M, N = 32, 32 | ||
| block_M, block_N = 32, 32 | ||
| kernel = bitwise_reduce(M, N, block_M, block_N, name, func, clear) | ||
|
|
||
| # Generate test data that exercises all bit patterns for robust bitwise reduce testing | ||
| a = torch.zeros((M, N), device="cuda", dtype=torch.int32) | ||
|
|
||
| # Fill with patterns that will produce meaningful results for bitwise operations: | ||
| # - Different bit patterns across rows/columns | ||
| # - Mix of 0s and 1s in various positions | ||
| # - Some all-1s and all-0s patterns for edge cases | ||
| for i in range(M): | ||
| for j in range(N): | ||
| # Create varied bit patterns: | ||
| # Row-based pattern: alternating bits based on row index | ||
| row_pattern = (i & 0xF) << (i % 4) # 4-bit patterns shifted by row | ||
|
|
||
| # Column-based pattern: different bit positions set based on column | ||
| col_pattern = (1 << (j % 31)) # Single bit set at different positions | ||
|
|
||
| # Combine patterns with XOR to create diverse bit distributions | ||
| # Add some deterministic "noise" based on position | ||
| position_factor = (i * N + j) % 256 | ||
|
|
||
| # Final value combines all patterns | ||
| a[i, j] = (row_pattern ^ col_pattern ^ position_factor) & 0xFFFFFFFF | ||
|
|
||
| if i % 4 == 0: | ||
| a[i, j] &= ~(0x1 << (i // 4)) | ||
| elif i % 2 == 0: | ||
| a[i, j] |= (0x1 << (i // 2)) | ||
|
|
||
| if name == "reduce_bitand": | ||
| expected = torch.full((M,), -1, device="cuda", dtype=torch.int32) | ||
| elif name == "reduce_bitor" or name == "reduce_bitxor": | ||
| expected = torch.full((M,), 0, device="cuda", dtype=torch.int32) | ||
| else: | ||
| raise ValueError("Invalid name: {}".format(name)) | ||
|
|
||
| output = kernel(a, expected) | ||
|
|
||
| for i in range(M): | ||
| for j in range(N): | ||
| if name == "reduce_bitand": | ||
| expected[i] = expected[i] & a[i, j] | ||
| elif name == "reduce_bitor": | ||
| expected[i] = expected[i] | a[i, j] | ||
| elif name == "reduce_bitxor": | ||
| expected[i] = expected[i] ^ a[i, j] | ||
| else: | ||
| raise ValueError("Invalid name: {}".format(name)) | ||
| assert torch.all(output == expected) | ||
| print("✓ {} with clear={} test passed".format(name, clear)) | ||
|
|
||
|
|
||
| @tilelang.testing.requires_cuda | ||
| def test_bitwise_reduce_ops(): | ||
| run_single_bitwise_reduce("reduce_bitand", T.reduce_bitand, clear=True) | ||
| run_single_bitwise_reduce("reduce_bitor", T.reduce_bitor, clear=True) | ||
| run_single_bitwise_reduce("reduce_bitxor", T.reduce_bitxor, clear=True) | ||
| run_single_bitwise_reduce("reduce_bitand", T.reduce_bitand, clear=False) | ||
| run_single_bitwise_reduce("reduce_bitor", T.reduce_bitor, clear=False) | ||
| run_single_bitwise_reduce("reduce_bitxor", T.reduce_bitxor, clear=False) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| tilelang.testing.main() | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -58,6 +58,9 @@ | |||||||||||||
| reduce_sum, # noqa: F401 | ||||||||||||||
| reduce_abssum, # noqa: F401 | ||||||||||||||
| reduce_absmax, # noqa: F401 | ||||||||||||||
| reduce_bitand, # noqa: F401 | ||||||||||||||
| reduce_bitor, # noqa: F401 | ||||||||||||||
| reduce_bitxor, # noqa: F401 | ||||||||||||||
|
Comment on lines
+61
to
+63
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Remove unused noqa markers. These three exports are valid; the trailing “# noqa: F401” is unnecessary and flagged by Ruff (RUF100). Please drop them. - reduce_bitand, # noqa: F401
- reduce_bitor, # noqa: F401
- reduce_bitxor, # noqa: F401
+ reduce_bitand,
+ reduce_bitor,
+ reduce_bitxor,📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff (0.14.0)61-61: Unused Remove unused (RUF100) 62-62: Unused Remove unused (RUF100) 63-63: Unused Remove unused (RUF100) 🤖 Prompt for AI Agents |
||||||||||||||
| cumsum, # noqa: F401 | ||||||||||||||
| finalize_reducer, # noqa: F401 | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Bitwise init: guard dtype and avoid 64-bit shift UB; add bool handling.
`(1 << bits) - 1` overflows (undefined behavior) when bits = 64, and bool falls into the “-INFINITY” path. Fix by asserting integer/bool dtypes and computing the all-ones mask safely.
} else if (type->isAbsMax()) { return make_const(dst->dtype, 0); - } else if (type->isBitAnd()) { - if (is_int) { - return make_const(dst->dtype, -1); - } else if (is_uint) { - return make_const(dst->dtype, (1 << bits) - 1); - } else { - // Should not arrive here - return make_const(dst->dtype, -INFINITY); - } - } else if (type->isBitOr()) { - return make_zero(dst->dtype); - } else if (type->isBitXor()) { - return make_zero(dst->dtype); + } else if (type->isBitAnd()) { + bool is_bool = dst_dtype.is_bool(); + ICHECK(is_int || is_uint || is_bool) + << "Bitwise AND reduction requires integer or bool dtype, got " << dst_dtype; + if (is_int) { + return make_const(dst->dtype, -1); + } else if (is_uint) { + uint64_t mask = (bits >= 64) ? ~uint64_t(0) : ((uint64_t(1) << bits) - 1); + return make_const(dst->dtype, mask); + } else { // bool + return make_const(dst->dtype, true); + } + } else if (type->isBitOr()) { + bool is_bool = dst_dtype.is_bool(); + ICHECK(is_int || is_uint || is_bool) + << "Bitwise OR reduction requires integer or bool dtype, got " << dst_dtype; + return make_zero(dst->dtype); + } else if (type->isBitXor()) { + bool is_bool = dst_dtype.is_bool(); + ICHECK(is_int || is_uint || is_bool) + << "Bitwise XOR reduction requires integer or bool dtype, got " << dst_dtype; + return make_zero(dst->dtype);