From 8e86da54987b3b0d9b0a44ff7b2a3c01a495bdda Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Fri, 25 Apr 2025 16:58:55 +0800
Subject: [PATCH 1/4] [Enhancement] Update reduce operations to support a
 clear option in sum and abs sum (#436)

* Modified the reduce_sum and reduce_absmax functions to take a clear
  parameter, allowing results to be accumulated onto existing values.

* Updated the ReduceOp::Lower method to handle initialization and buffer
  duplication based on the clear flag for the sum and abs-sum operations.

* Added tests for reduce_sum and reduce_max with the clear option to verify
  correctness in both modes.

* Extended the documentation of the reduce functions to clarify the behavior
  of the clear parameter.
---
 src/op/reduce.cc                              |  59 +++++++++--
 .../test_tilelang_language_reduce_max.py      |  37 +++++++
 .../test_tilelang_language_reduce_sum.py      | 100 ++++++++++++++++++
 tilelang/language/reduce.py                   |  18 +++-
 4 files changed, 201 insertions(+), 13 deletions(-)
 create mode 100644 testing/python/language/test_tilelang_language_reduce_sum.py

diff --git a/src/op/reduce.cc b/src/op/reduce.cc
index 9d3dd8357..dc21b6a7f 100644
--- a/src/op/reduce.cc
+++ b/src/op/reduce.cc
@@ -13,9 +13,11 @@
 #include 
 #include 
 
+#include "tir/transforms/ir_utils.h"
 #include "../layout/utils.h"
 #include "../transform/loop_partition.h"
 
+
 namespace tvm {
 namespace tl {
 
@@ -124,10 +126,34 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
 
   Array<Stmt> stmts;
 
+  bool require_init = this->clear;
+  // sum-type reductions must always start from the reduce identity
+  if (this->type == ReduceType::kSum) {
+    require_init = true;
+  } else if (this->type == ReduceType::kAbsSum) {
+    require_init = true;
+  }
+
+  Buffer clear_buffer = dst_buffer;
+  bool need_duplicate = false;
+  if (this->type == ReduceType::kSum && !this->clear) {
+    need_duplicate = true;
+  } else if (this->type == ReduceType::kAbsSum && !this->clear) {
+    need_duplicate = true;
+  }
+
+  if (need_duplicate) {
+    // Create a new buffer with the same shape and dtype as dst_buffer
+    clear_buffer = decl_buffer(dst_buffer->shape,
+                               dst_buffer->dtype,
+                               dst_buffer->name + "_clear",
+                               GetPtrStorageScope(dst_buffer->data));
+  }
+
   // make reduce-init stmt
-  if (this->clear)
+  if (require_init)
     stmts.push_back(
-        BufferStore(dst_buffer, this->MakeInitValue(), dst_indices));
+        BufferStore(clear_buffer, this->MakeInitValue(), dst_indices));
 
   // make thread-local reduce
   Array<PrimExpr> src_indice_compressed;
@@ -141,8 +167,8 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     src_var_compressed.push_back(var);
   }
   Stmt reduce_local = BufferStore(
-      dst_buffer,
-      this->MakeReduce(BufferLoad(dst_buffer, dst_indices),
+      clear_buffer,
+      this->MakeReduce(BufferLoad(clear_buffer, dst_indices),
                        BufferLoad(src_buffer, src_indice_compressed)),
       dst_indices);
   for (int i = src_layout->OutputDim() - 1; i >= 0; i--) {
@@ -179,20 +205,31 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
            << reducing_threads << ", " << (*scale) << ">::run";
       }
       Array<PrimExpr> thread_reduce_args = {
-          StringImm(ss.str()), BufferLoad(dst_buffer, dst_indices)};
+          StringImm(ss.str()), BufferLoad(clear_buffer, dst_indices)};
       if (reducing_threads >= 32) {
         PrimExpr workspace = T.AddWorkspace(
-            *as_const_int(T.thread_bounds->extent), dst_buffer->dtype);
+            *as_const_int(T.thread_bounds->extent), clear_buffer->dtype);
         thread_reduce_args.push_back(workspace);
       }
       auto call =
-          Call(dst_buffer->dtype, builtin::call_extern(), thread_reduce_args);
-      stmts.push_back(BufferStore(dst_buffer, call, dst_indices));
+          Call(clear_buffer->dtype, builtin::call_extern(), thread_reduce_args);
+      stmts.push_back(BufferStore(clear_buffer, call, dst_indices));
     }
   }
   Stmt reduce_interthread =
-      BufferStore(dst_buffer, BufferLoad(dst_buffer, dst_indices), dst_indices);
+      BufferStore(clear_buffer, BufferLoad(clear_buffer, dst_indices), dst_indices);
+  // accumulate clear_buffer back into dst_buffer
+  if (need_duplicate) {
+    // for sum-type reductions, add the partial result in clear_buffer onto dst_buffer
+    if (this->type == ReduceType::kSum) {
+      stmts.push_back(BufferStore(dst_buffer, Add(BufferLoad(dst_buffer, dst_indices), BufferLoad(clear_buffer, dst_indices)), dst_indices));
+    } else if (this->type == ReduceType::kAbsSum) {
+      stmts.push_back(BufferStore(dst_buffer, Add(BufferLoad(dst_buffer, dst_indices), BufferLoad(clear_buffer, dst_indices)), dst_indices));
+    } else {
+      ICHECK(false) << "Unsupported reduce type: " << (int)this->type;
+    }
+  }
 
   // make the outer spatial loop
   Stmt body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0];
   for (int i = dst_layout->InputDim() - 1; i >= 0; i--) {
@@ -201,6 +238,10 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   }
   body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer, dst_layout);
+  if (need_duplicate) {
+    body = Allocate(clear_buffer->data, clear_buffer->dtype, clear_buffer->shape,
+                    const_true(), body);
+  }
   return body;
 }
 
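Note: for intuition, the clear=False path lowered above behaves like the
following NumPy sketch (illustrative only; reduce_sum_accumulate is a
hypothetical name, not part of the TileLang API):

    import numpy as np

    def reduce_sum_accumulate(src, dst):
        # equivalent of T.reduce_sum(src, dst, dim=1, clear=False)
        tmp = np.zeros_like(dst)   # the "_clear" scratch buffer, set to the identity
        tmp += src.sum(axis=1)     # thread-local + inter-thread reduction into tmp
        dst += tmp                 # one final accumulate back into dst
        return dst

The scratch buffer guarantees the old value of dst is added exactly once,
rather than once per thread participating in the warp reduction.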
diff --git a/testing/python/language/test_tilelang_language_reduce_max.py b/testing/python/language/test_tilelang_language_reduce_max.py
index 6a9156d5f..f12fd2d8a 100644
--- a/testing/python/language/test_tilelang_language_reduce_max.py
+++ b/testing/python/language/test_tilelang_language_reduce_max.py
@@ -49,6 +49,43 @@ def test_reduce_max():
     run_reduce_max(256, 256, "float32")
     run_reduce_max(256, 256, "float16")
 
+def reduce_max_test_clear(M, N, dtype="float16"):
+    import tilelang.language as T
+
+    @T.prim_func
+    def main(
+            A: T.Tensor((M, N), dtype),
+            B: T.Tensor((M,), dtype),
+    ):
+        with T.Kernel(1, threads=32) as _:
+            A_local = T.alloc_fragment((M, N), dtype)
+            B_local = T.alloc_fragment((M,), dtype)
+            
+            T.copy(A, A_local)
+            T.fill(B_local, -T.infinity(dtype))
+            T.reduce_max(A_local, B_local, dim=1, clear=False)
+            T.copy(B_local, B)
+
+    return main
+
+
+def run_reduce_max_clear(M, N, dtype="float16"):
+    program = reduce_max_test_clear(M, N, dtype)
+    jit_kernel = tl.compile(program, out_idx=-1)
+    print(jit_kernel.get_kernel_source())
+    def ref_program(A):
+        return A.max(dim=1).values
+
+    import torch
+    dummy_A = torch.randn((M, N), dtype=getattr(torch, dtype)).cuda()
+    ref_out = ref_program(dummy_A)
+    tl_out = jit_kernel(dummy_A)
+    print(tl_out)
+    print(ref_out)
+    torch.testing.assert_close(tl_out, ref_out, atol=1e-2, rtol=1e-2)
+
+def test_reduce_max_clear():
+    run_reduce_max_clear(256, 256, "float16")
 
 if __name__ == "__main__":
     tilelang.testing.main()
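Note: with clear=False a max reduction simply folds the existing output into
the result (max is idempotent, so no scratch buffer is needed). A reference
for a pre-filled accumulator looks like this sketch (ref_reduce_max_accumulate
is an illustrative name, not part of the test suite):

    import torch

    def ref_reduce_max_accumulate(A, init):
        # T.reduce_max(A, init, dim=1, clear=False): running max against init
        return torch.maximum(init, A.max(dim=1).values)

The test above fills the accumulator with -inf, so the result reduces to a
plain row-wise max.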
diff --git a/testing/python/language/test_tilelang_language_reduce_sum.py b/testing/python/language/test_tilelang_language_reduce_sum.py
new file mode 100644
index 000000000..bbbde0682
--- /dev/null
+++ b/testing/python/language/test_tilelang_language_reduce_sum.py
@@ -0,0 +1,100 @@
+# Copyright (c) Tile-AI Corporation.
+# Licensed under the MIT License.
+
+from tilelang import tvm as tvm
+import tilelang.testing
+import tilelang as tl
+
+tilelang.testing.set_random_seed()
+tilelang.disable_cache()
+
+def reduce_sum_test(M, N, dtype="float16"):
+    import tilelang.language as T
+
+    @T.prim_func
+    def main(
+            A: T.Tensor((M, N), dtype),
+            B: T.Tensor((M,), dtype),
+    ):
+        with T.Kernel(1) as _:
+            A_local = T.alloc_fragment((M, N), dtype)
+            B_local = T.alloc_fragment((M,), dtype)
+
+            # Copy input to local
+            T.copy(A, A_local)
+            # Perform reduce_sum operation
+            T.reduce_sum(A_local, B_local, dim=1)
+            # Copy result back
+            T.copy(B_local, B)
+
+    return main
+
+
+def run_reduce_sum(M, N, dtype="float16"):
+    program = reduce_sum_test(M, N, dtype)
+    jit_kernel = tl.compile(program, out_idx=-1)
+    profiler = jit_kernel.get_profiler()
+
+    def ref_program(A):
+        return A.sum(dim=1)
+
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+
+
+def test_reduce_sum():
+    # Test different sizes
+    run_reduce_sum(256, 256)
+    run_reduce_sum(512, 128)
+    run_reduce_sum(128, 512)
+
+    # Test different dtypes
+    run_reduce_sum(256, 256, "float32")
+    run_reduce_sum(256, 256, "float16")
+
+def reduce_sum_test_clear(M, N, dtype="float16"):
+    import tilelang.language as T
+
+    @T.prim_func
+    def main(
+            A: T.Tensor((M, N), dtype),
+            B: T.Tensor((M,), dtype),
+    ):
+        with T.Kernel(1, threads=32) as _:
+            A_local = T.alloc_fragment((M, N), dtype)
+            B_local = T.alloc_fragment((M,), dtype)
+            
+            T.copy(A, A_local)
+            T.fill(B_local, 1)
+            T.reduce_sum(A_local, B_local, dim=1, clear=False)
+            T.copy(B_local, B)
+
+    return main
+
+
+def run_reduce_sum_clear(M, N, dtype="float16"):
+    program = reduce_sum_test_clear(M, N, dtype)
+    jit_kernel = tl.compile(program, out_idx=-1, pass_configs={
+        "tl.disable_tma_lower": True,
+        "tl.disable_warp_specialized": True,
+    })
+    print(jit_kernel.get_kernel_source())
+
+    def ref_program(A):
+        return A.sum(dim=1) + 1
+
+    import torch
+    dummy_A = torch.randn((M, N), dtype=getattr(torch, dtype)).cuda()
+    ref_out = ref_program(dummy_A)
+    tl_out = jit_kernel(dummy_A)
+    print(tl_out)
+    print(ref_out)
+    torch.testing.assert_close(tl_out, ref_out, atol=1e-2, rtol=1e-2)
+
+
+def test_reduce_sum_clear():
+    run_reduce_sum_clear(256, 256, "float16")
+    run_reduce_sum_clear(512, 128, "float16")
+    run_reduce_sum_clear(128, 512, "float16")
+
+if __name__ == "__main__":
+    tilelang.testing.main()
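Note: reduce_sum_test_clear pre-fills B_local with 1, so with clear=False the
kernel must compute A.sum(dim=1) + 1, which is exactly what ref_program checks.
A minimal statement of that contract (illustrative helper name only):

    import torch

    def ref_reduce_sum_accumulate(A, init):
        # T.reduce_sum(A, init, dim=1, clear=False): accumulate onto init
        return init + A.sum(dim=1)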
diff --git a/tilelang/language/reduce.py b/tilelang/language/reduce.py
index 83601786d..b634b77e6 100644
--- a/tilelang/language/reduce.py
+++ b/tilelang/language/reduce.py
@@ -68,18 +68,28 @@ def reduce_min(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True):
     return reduce(buffer, out, "min", dim, clear)
 
 
-def reduce_sum(buffer: tir.Buffer, out: tir.Buffer, dim: int):
+def reduce_sum(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True):
     """Perform reduce sum on input buffer, store the result to output buffer.
 
     Args:
         buffer (tir.Buffer): The input buffer
        out (tir.Buffer): The output buffer
         dim (int): The dimension to perform reduce on
+        clear (bool, optional): If True, the output buffer is cleared before the
+            reduction. If False, results are accumulated onto the existing values.
+            Defaults to True.
+    Note: When clear=False, reduce_sum does not accumulate directly into the output
+        buffer: during warp reduction the existing output value would be folded in
+        once per participating thread. The clear=False implementation therefore:
+        1. allocates a temp buffer with the same shape and dtype as out
+        2. initializes the temp buffer to zero
+        3. reduces the input into the temp buffer
+        4. adds the temp buffer to out
 
     Returns:
         tir.Call: Handle to the reduction operation
     """
-    return reduce(buffer, out, "sum", dim, True)
+    return reduce(buffer, out, "sum", dim, clear)
 
 
 def reduce_abssum(buffer: tir.Buffer, out: tir.Buffer, dim: int):
@@ -96,7 +106,7 @@ def reduce_abssum(buffer: tir.Buffer, out: tir.Buffer, dim: int):
     return reduce(buffer, out, "abssum", dim, True)
 
 
-def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int):
+def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True):
     """Perform reduce absolute max on input buffer, store the result to output buffer.
 
     Args:
@@ -107,7 +117,7 @@ def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int):
     Returns:
         tir.Call: Handle to the reduction operation
     """
-    return reduce(buffer, out, "absmax", dim, True)
+    return reduce(buffer, out, "absmax", dim, clear)
 
 
 @macro
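Note: taken together, this patch gives sum-type reductions an accumulate mode.
A usage sketch of the updated API (S is a 2-D fragment and acc a 1-D fragment,
both assumed already allocated in a kernel body):

    T.reduce_sum(S, acc, dim=1, clear=True)      # acc  = S.sum(dim=1)
    T.reduce_sum(S, acc, dim=1, clear=False)     # acc += S.sum(dim=1)
    T.reduce_absmax(S, acc, dim=1, clear=False)  # acc  = max(acc, abs(S).max(dim=1))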
From 57904934e1b645fe7b01b300e90310a8a2527950 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Fri, 25 Apr 2025 16:59:07 +0800
Subject: [PATCH 2/4] lint fix

---
 src/op/reduce.cc                              | 30 +++++++++++--------
 .../test_tilelang_language_reduce_max.py      |  6 +++-
 .../test_tilelang_language_reduce_sum.py      | 16 ++++++----
 3 files changed, 33 insertions(+), 19 deletions(-)

diff --git a/src/op/reduce.cc b/src/op/reduce.cc
index dc21b6a7f..6bca83d8b 100644
--- a/src/op/reduce.cc
+++ b/src/op/reduce.cc
@@ -13,10 +13,9 @@
 #include 
 #include 
 
-#include "tir/transforms/ir_utils.h"
 #include "../layout/utils.h"
 #include "../transform/loop_partition.h"
-
+#include "tir/transforms/ir_utils.h"
 
 namespace tvm {
 namespace tl {
@@ -133,7 +132,7 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   } else if (this->type == ReduceType::kAbsSum) {
     require_init = true;
   }
-  
+
   Buffer clear_buffer = dst_buffer;
   bool need_duplicate = false;
   if (this->type == ReduceType::kSum && !this->clear) {
@@ -144,10 +143,9 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
 
   if (need_duplicate) {
     // Create a new buffer with the same shape and dtype as dst_buffer
-    clear_buffer = decl_buffer(dst_buffer->shape,
-                               dst_buffer->dtype,
-                               dst_buffer->name + "_clear",
-                               GetPtrStorageScope(dst_buffer->data));
+    clear_buffer = decl_buffer(dst_buffer->shape, dst_buffer->dtype,
+                               dst_buffer->name + "_clear",
+                               GetPtrStorageScope(dst_buffer->data));
   }
 
   // make reduce-init stmt
@@ -216,16 +214,22 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
       stmts.push_back(BufferStore(clear_buffer, call, dst_indices));
     }
   }
-  Stmt reduce_interthread =
-      BufferStore(clear_buffer, BufferLoad(clear_buffer, dst_indices), dst_indices);
+  Stmt reduce_interthread = BufferStore(
+      clear_buffer, BufferLoad(clear_buffer, dst_indices), dst_indices);
   // accumulate clear_buffer back into dst_buffer
   if (need_duplicate) {
     // for sum-type reductions, add the partial result in clear_buffer onto dst_buffer
     if (this->type == ReduceType::kSum) {
-      stmts.push_back(BufferStore(dst_buffer, Add(BufferLoad(dst_buffer, dst_indices), BufferLoad(clear_buffer, dst_indices)), dst_indices));
+      stmts.push_back(BufferStore(dst_buffer,
+                                  Add(BufferLoad(dst_buffer, dst_indices),
+                                      BufferLoad(clear_buffer, dst_indices)),
+                                  dst_indices));
     } else if (this->type == ReduceType::kAbsSum) {
-      stmts.push_back(BufferStore(dst_buffer, Add(BufferLoad(dst_buffer, dst_indices), BufferLoad(clear_buffer, dst_indices)), dst_indices));
+      stmts.push_back(BufferStore(dst_buffer,
+                                  Add(BufferLoad(dst_buffer, dst_indices),
+                                      BufferLoad(clear_buffer, dst_indices)),
+                                  dst_indices));
     } else {
       ICHECK(false) << "Unsupported reduce type: " << (int)this->type;
@@ -239,8 +243,8 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer, dst_layout);
   if (need_duplicate) {
-    body = Allocate(clear_buffer->data, clear_buffer->dtype, clear_buffer->shape,
-                    const_true(), body);
+    body = Allocate(clear_buffer->data, clear_buffer->dtype,
+                    clear_buffer->shape, const_true(), body);
   }
   return body;
 }
diff --git a/testing/python/language/test_tilelang_language_reduce_max.py b/testing/python/language/test_tilelang_language_reduce_max.py
index f12fd2d8a..a004f816f 100644
--- a/testing/python/language/test_tilelang_language_reduce_max.py
+++ b/testing/python/language/test_tilelang_language_reduce_max.py
@@ -49,6 +49,7 @@ def test_reduce_max():
     run_reduce_max(256, 256, "float32")
     run_reduce_max(256, 256, "float16")
 
+
 def reduce_max_test_clear(M, N, dtype="float16"):
     import tilelang.language as T
 
@@ -60,7 +61,7 @@ def main(
         with T.Kernel(1, threads=32) as _:
             A_local = T.alloc_fragment((M, N), dtype)
             B_local = T.alloc_fragment((M,), dtype)
-            
+
             T.copy(A, A_local)
             T.fill(B_local, -T.infinity(dtype))
             T.reduce_max(A_local, B_local, dim=1, clear=False)
@@ -73,6 +74,7 @@ def run_reduce_max_clear(M, N, dtype="float16"):
     program = reduce_max_test_clear(M, N, dtype)
     jit_kernel = tl.compile(program, out_idx=-1)
     print(jit_kernel.get_kernel_source())
+
     def ref_program(A):
         return A.max(dim=1).values
 
@@ -84,8 +86,10 @@ def ref_program(A):
     print(ref_out)
     torch.testing.assert_close(tl_out, ref_out, atol=1e-2, rtol=1e-2)
 
+
 def test_reduce_max_clear():
     run_reduce_max_clear(256, 256, "float16")
 
+
 if __name__ == "__main__":
     tilelang.testing.main()
diff --git a/testing/python/language/test_tilelang_language_reduce_sum.py b/testing/python/language/test_tilelang_language_reduce_sum.py
index bbbde0682..a8cd68ff7 100644
--- a/testing/python/language/test_tilelang_language_reduce_sum.py
+++ b/testing/python/language/test_tilelang_language_reduce_sum.py
@@ -8,6 +8,7 @@
 tilelang.testing.set_random_seed()
 tilelang.disable_cache()
 
+
 def reduce_sum_test(M, N, dtype="float16"):
     import tilelang.language as T
 
@@ -51,6 +52,7 @@ def test_reduce_sum():
     run_reduce_sum(256, 256, "float32")
     run_reduce_sum(256, 256, "float16")
 
+
 def reduce_sum_test_clear(M, N, dtype="float16"):
     import tilelang.language as T
 
@@ -62,7 +64,7 @@ def main(
         with T.Kernel(1, threads=32) as _:
             A_local = T.alloc_fragment((M, N), dtype)
             B_local = T.alloc_fragment((M,), dtype)
-            
+
             T.copy(A, A_local)
             T.fill(B_local, 1)
             T.reduce_sum(A_local, B_local, dim=1, clear=False)
@@ -73,10 +75,13 @@ def main(
 
 def run_reduce_sum_clear(M, N, dtype="float16"):
     program = reduce_sum_test_clear(M, N, dtype)
-    jit_kernel = tl.compile(program, out_idx=-1, pass_configs={
-        "tl.disable_tma_lower": True,
-        "tl.disable_warp_specialized": True,
-    })
+    jit_kernel = tl.compile(
+        program,
+        out_idx=-1,
+        pass_configs={
+            "tl.disable_tma_lower": True,
+            "tl.disable_warp_specialized": True,
+        })
     print(jit_kernel.get_kernel_source())
 
     def ref_program(A):
@@ -96,5 +101,6 @@ def ref_program(A):
     run_reduce_sum_clear(512, 128, "float16")
     run_reduce_sum_clear(128, 512, "float16")
 
+
 if __name__ == "__main__":
     tilelang.testing.main()
From bdcc58116bf916e41ae85d42833286cd078ac07c Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Sat, 26 Apr 2025 11:58:33 +0000
Subject: [PATCH 3/4] Treat prepare_for_condition stages like copy stages in
 pipeline planning; update annotations in
 test_tilelang_transform_annotate_device_regions.py from Buffer to Tensor
---
 src/transform/pipeline_planning.cc                     | 5 ++---
 .../test_tilelang_transform_annotate_device_regions.py | 8 ++++----
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/src/transform/pipeline_planning.cc b/src/transform/pipeline_planning.cc
index 19b70a719..8e12c96db 100644
--- a/src/transform/pipeline_planning.cc
+++ b/src/transform/pipeline_planning.cc
@@ -389,7 +389,6 @@ class PipelinePlanner : public StmtExprMutator {
     // Handle trailing unassigned copy stages:
     // These are typically final copy operations needing post-main-stage
     // insertion
-
     auto &head_pinfo = pipeline_stage_infos.at(0);
 
     int unassigned_order_elem = -1;
@@ -422,7 +421,7 @@ class PipelinePlanner : public StmtExprMutator {
     int copy_order_min = pipeline_stage_infos.size();
     int non_copy_order_max = 0;
     for (auto &pinfo : pipeline_stage_infos) {
-      if (pinfo.copy_stage) {
+      if (pinfo.copy_stage || pinfo.prepare_for_condition) {
         copy_stage_cnt++;
         copy_order_min = std::min(copy_order_min, pinfo.order);
       } else {
@@ -437,7 +436,7 @@ class PipelinePlanner : public StmtExprMutator {
     for (auto &pinfo : pipeline_stage_infos) {
       // move copy to the beginning
       pinfo.order = (pinfo.order + copy_stage_at_end) % pipeline_stage_infos.size();
-      if (!pinfo.copy_stage)
+      if (!pinfo.copy_stage && !pinfo.prepare_for_condition)
         pinfo.stage--;
     }
   }
diff --git a/testing/python/transform/test_tilelang_transform_annotate_device_regions.py b/testing/python/transform/test_tilelang_transform_annotate_device_regions.py
index 2732c4239..c8cbd9e23 100644
--- a/testing/python/transform/test_tilelang_transform_annotate_device_regions.py
+++ b/testing/python/transform/test_tilelang_transform_annotate_device_regions.py
@@ -13,12 +13,12 @@ class BaseCompare(tilelang.testing.CompareBeforeAfter):
 class TestAnnotateThreadExtent(BaseCompare):
     """Annotation inserted at the "thread_extent" attribute"""
 
-    def before(A: T.Buffer(16, "float32")):
+    def before(A: T.Tensor(16, "float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
         i = T.launch_thread("threadIdx.x", 16)
         A[i] = 0.0
 
-    def expected(A: T.Buffer(16, "float32")):
+    def expected(A: T.Tensor(16, "float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
         T.attr(T.target("cuda"), "target", 0)
         i = T.launch_thread("threadIdx.x", 16)
@@ -28,12 +28,12 @@ def expected(A: T.Buffer(16, "float32")):
 class TestAnnotateDeviceScope(BaseCompare):
     """Annotation inserted at the "device_scope" attribute"""
 
-    def before(A: T.Buffer(1, "float32")):
+    def before(A: T.Tensor(1, "float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
         T.attr(0, "device_scope", 0)
         A[0] = 0.0
 
-    def expected(A: T.Buffer(1, "float32")):
+    def expected(A: T.Tensor(1, "float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
         T.attr(T.target("cuda"), "target", 0)
         T.attr(0, "device_scope", 0)
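Note: a Python paraphrase of the reorder rule after the pipeline_planning.cc
change above (illustrative pseudo-code; StageInfo mirrors only the fields this
loop touches and is not a real TileLang class):

    from dataclasses import dataclass

    @dataclass
    class StageInfo:
        order: int
        stage: int
        copy_stage: bool
        prepare_for_condition: bool

    def reorder_stages(stage_infos, copy_stage_at_end):
        n = len(stage_infos)
        for pinfo in stage_infos:
            # rotate so copy and condition-preparation stages issue first
            pinfo.order = (pinfo.order + copy_stage_at_end) % n
            # only genuine compute stages move one pipeline stage earlier
            if not (pinfo.copy_stage or pinfo.prepare_for_condition):
                pinfo.stage -= 1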
From 22cef554778da30e16c5722061a101ca60802b83 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Sat, 26 Apr 2025 12:52:56 +0000
Subject: [PATCH 4/4] Use float32 in the reduce_sum clear tests for better
 accumulation precision and drop the leftover disable_cache call
---
 .../python/language/test_tilelang_language_reduce_sum.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/testing/python/language/test_tilelang_language_reduce_sum.py b/testing/python/language/test_tilelang_language_reduce_sum.py
index a8cd68ff7..4958aab8d 100644
--- a/testing/python/language/test_tilelang_language_reduce_sum.py
+++ b/testing/python/language/test_tilelang_language_reduce_sum.py
@@ -6,7 +6,6 @@
 import tilelang as tl
 
 tilelang.testing.set_random_seed()
-tilelang.disable_cache()
 
 
 def reduce_sum_test(M, N, dtype="float16"):
@@ -97,9 +96,9 @@ def ref_program(A):
 
 def test_reduce_sum_clear():
-    run_reduce_sum_clear(256, 256, "float16")
-    run_reduce_sum_clear(512, 128, "float16")
-    run_reduce_sum_clear(128, 512, "float16")
+    run_reduce_sum_clear(256, 256, "float32")
+    run_reduce_sum_clear(512, 128, "float32")
+    run_reduce_sum_clear(128, 512, "float32")
 
 
 if __name__ == "__main__":
    tilelang.testing.main()
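Note: the dtype bump is about accumulation error, not about the clear logic
itself; summing a few hundred float16 values drifts noticeably from the
float32 result, which is why the tolerances were hard to meet. For example:

    import torch

    x = torch.randn(512)
    print(x.sum().item())                 # float32 accumulation
    print(x.half().sum().float().item())  # float16 accumulation drifts from it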