Skip to content

Commit 0395ea1

Browse files
committed
[Lint]
1 parent ddf1a95 commit 0395ea1

File tree

4 files changed

+31
-29
lines changed

4 files changed

+31
-29
lines changed

examples/flash_attention/example_gqa_bwd_tma_reduce.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -443,7 +443,8 @@ def maybe_contiguous(x):
443443
dk = torch.empty(shape_k, dtype=torch.float16, device=q.device)
444444
dv = torch.empty(shape_v, dtype=torch.float16, device=q.device)
445445
kernel(q, k, v, do, lse, delta, dq, dk, dv)
446-
dq = mod_post(dq)
446+
dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32),
447+
torch.zeros_like(v, dtype=torch.float32))
447448
dk, dv = dk.sum(0), dv.sum(0)
448449

449450
return dq, dk, dv, None, None, None

src/op/reduce.cc

Lines changed: 17 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -338,25 +338,27 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
338338
dst_indices));
339339
} else if (this->type->isBitAnd()) {
340340
if (!this->clear) {
341-
stmts.push_back(BufferStore(dst_buffer,
342-
bitwise_and(BufferLoad(dst_buffer, dst_indices),
343-
BufferLoad(clear_buffer, dst_indices)),
344-
dst_indices));
341+
stmts.push_back(
342+
BufferStore(dst_buffer,
343+
bitwise_and(BufferLoad(dst_buffer, dst_indices),
344+
BufferLoad(clear_buffer, dst_indices)),
345+
dst_indices));
345346
} else {
346-
stmts.push_back(BufferStore(dst_buffer,
347-
BufferLoad(clear_buffer, dst_indices),
348-
dst_indices));
347+
stmts.push_back(BufferStore(
348+
dst_buffer, BufferLoad(clear_buffer, dst_indices), dst_indices));
349349
}
350350
} else if (this->type->isBitOr()) {
351-
stmts.push_back(BufferStore(dst_buffer,
352-
bitwise_or(BufferLoad(dst_buffer, dst_indices),
353-
BufferLoad(clear_buffer, dst_indices)),
354-
dst_indices));
351+
stmts.push_back(
352+
BufferStore(dst_buffer,
353+
bitwise_or(BufferLoad(dst_buffer, dst_indices),
354+
BufferLoad(clear_buffer, dst_indices)),
355+
dst_indices));
355356
} else if (this->type->isBitXor()) {
356-
stmts.push_back(BufferStore(dst_buffer,
357-
bitwise_xor(BufferLoad(dst_buffer, dst_indices),
358-
BufferLoad(clear_buffer, dst_indices)),
359-
dst_indices));
357+
stmts.push_back(
358+
BufferStore(dst_buffer,
359+
bitwise_xor(BufferLoad(dst_buffer, dst_indices),
360+
BufferLoad(clear_buffer, dst_indices)),
361+
dst_indices));
360362
} else {
361363
ICHECK(false) << "Unsupported reduce type: " << this->type->type;
362364
}

testing/python/math/test_math_bitwise_reduce.py

Lines changed: 11 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -3,12 +3,12 @@
33
import torch
44
import tilelang.testing
55

6+
67
@tilelang.jit(
78
out_idx=[-1],
89
pass_configs={
910
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
1011
},
11-
debug_root_path="./testing/python/math/"
1212
)
1313
def bitwise_reduce(
1414
M,
@@ -19,11 +19,12 @@ def bitwise_reduce(
1919
func,
2020
clear=True,
2121
):
22+
2223
@T.prim_func
2324
def reduce_func(
24-
A: T.Tensor((M, N), "int32"),
25-
B: T.Tensor((M), "int32"),
26-
Output: T.Tensor((M), "int32"),
25+
A: T.Tensor((M, N), "int32"),
26+
B: T.Tensor((M), "int32"),
27+
Output: T.Tensor((M), "int32"),
2728
):
2829
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
2930
A_shared = T.alloc_shared((block_M, block_N), "int32")
@@ -51,7 +52,7 @@ def run_single_bitwise_reduce(
5152

5253
# Generate test data that exercises all bit patterns for robust bitwise reduce testing
5354
a = torch.zeros((M, N), device="cuda", dtype=torch.int32)
54-
55+
5556
# Fill with patterns that will produce meaningful results for bitwise operations:
5657
# - Different bit patterns across rows/columns
5758
# - Mix of 0s and 1s in various positions
@@ -61,14 +62,14 @@ def run_single_bitwise_reduce(
6162
# Create varied bit patterns:
6263
# Row-based pattern: alternating bits based on row index
6364
row_pattern = (i & 0xF) << (i % 4) # 4-bit patterns shifted by row
64-
65+
6566
# Column-based pattern: different bit positions set based on column
6667
col_pattern = (1 << (j % 31)) # Single bit set at different positions
67-
68+
6869
# Combine patterns with XOR to create diverse bit distributions
6970
# Add some deterministic "noise" based on position
7071
position_factor = (i * N + j) % 256
71-
72+
7273
# Final value combines all patterns
7374
a[i, j] = (row_pattern ^ col_pattern ^ position_factor) & 0xFFFFFFFF
7475

@@ -79,13 +80,11 @@ def run_single_bitwise_reduce(
7980

8081
if name == "reduce_bitand":
8182
expected = torch.full((M,), -1, device="cuda", dtype=torch.int32)
82-
elif name == "reduce_bitor":
83-
expected = torch.full((M,), 0, device="cuda", dtype=torch.int32)
84-
elif name == "reduce_bitxor":
83+
elif name == "reduce_bitor" or name == "reduce_bitxor":
8584
expected = torch.full((M,), 0, device="cuda", dtype=torch.int32)
8685
else:
8786
raise ValueError("Invalid name: {}".format(name))
88-
87+
8988
output = kernel(a, expected)
9089

9190
for i in range(M):

tilelang/language/reduce.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -146,7 +146,7 @@ def reduce_bitand(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: boo
146146
buffer (tir.Buffer): The input buffer
147147
out (tir.Buffer): The output buffer
148148
dim (int): The dimension to perform reduce on
149-
149+
150150
Returns:
151151
tir.Call: Handle to the reduction operation
152152
"""

0 commit comments

Comments
 (0)