From 9d20a9622786df440adf0f7f1c294f9d8ffdc0dc Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 23 Oct 2025 12:16:37 +0800 Subject: [PATCH 1/3] [Refactor] Improve scalar handling in CopyNode and update loop partition dtype logic * Refactored CopyNode::MakeSIMTLoop to handle scalar cases more efficiently by moving the scalar check to the end of the function. * Updated loop_partition.cc to set a default DataType for thread and vector extents, ensuring compatibility when loop_vars_ is empty. --- src/op/copy.cc | 10 +++++----- src/transform/loop_partition.cc | 6 ++++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/op/copy.cc b/src/op/copy.cc index a16d09dad..4da6548d2 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -299,10 +299,6 @@ PrimExpr CopyNode::MakePredicate(arith::Analyzer *analyzer, For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { Array loop_vars = MakeIterVars(); bool is_scalar = loop_vars.empty(); - if (is_scalar) { - return For(Var("i"), 0, 1, ForKind::kSerial, - BufferStore(dst, BufferLoad(src, {0}), {0})); - } for (const auto &iv : loop_vars) analyzer->Bind(iv->var, iv->dom); @@ -332,6 +328,9 @@ For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { Stmt body = BufferStore(dst, value, dst_indices); if (dst_predicate.defined()) body = IfThenElse(dst_predicate, body); + if (is_scalar) { + return For(Var("i"), 0, 1, ForKind::kSerial, body); + } for (int i = loop_vars.size() - 1; i >= 0; i--) { Map annotations = {}; if (coalesced_width.defined()) { @@ -844,6 +843,7 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T, arith::Analyzer *analyzer) const { bool is_cpu_target = T.target->GetTargetDeviceType() == kDLCPU; auto simt_loop = MakeSIMTLoop(analyzer); + LOG(INFO) << "[LowerNormalCopy] simt_loop " << simt_loop; auto fused_loop = Downcast(ParallelLoopFuser::Fuse(simt_loop)); auto transformed_loop = @@ -1979,4 +1979,4 @@ TVM_FFI_STATIC_INIT_BLOCK({ Conv2DIm2ColOpNode::RegisterReflection(); }); } // namespace tl -} // namespace tvm \ No newline at end of file +} // namespace tvm diff --git a/src/transform/loop_partition.cc b/src/transform/loop_partition.cc index 24168677e..81c32b0b8 100644 --- a/src/transform/loop_partition.cc +++ b/src/transform/loop_partition.cc @@ -189,8 +189,10 @@ class LoopPartitioner : public StmtExprVisitor { Fragment Partition(const For &op, int num_thread, int vectorize_size) { this->VisitStmt(op); - ICHECK(!loop_vars_.empty()); - DataType dtype = loop_vars_[0]->var.dtype(); + DataType dtype = DataType::Int(32); + if (!loop_vars_.empty()){ + dtype = loop_vars_.back()->var.dtype(); + } PrimExpr flattened = make_const(dtype, 0); PrimExpr vector_extent = make_const(dtype, vectorize_size); PrimExpr thread_extent_const = make_const(dtype, num_thread); From d3aa225f828c27df89213e844dc3c2ae2972e2c8 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 23 Oct 2025 12:17:04 +0800 Subject: [PATCH 2/3] lint fix --- src/transform/loop_partition.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transform/loop_partition.cc b/src/transform/loop_partition.cc index 81c32b0b8..e9930310a 100644 --- a/src/transform/loop_partition.cc +++ b/src/transform/loop_partition.cc @@ -190,7 +190,7 @@ class LoopPartitioner : public StmtExprVisitor { Fragment Partition(const For &op, int num_thread, int vectorize_size) { this->VisitStmt(op); DataType dtype = DataType::Int(32); - if (!loop_vars_.empty()){ + if (!loop_vars_.empty()) { dtype = loop_vars_.back()->var.dtype(); } PrimExpr flattened = make_const(dtype, 0); From 6d5b3e1a1a3e50bbaa3ffd0e990e2acb0684128f Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 23 Oct 2025 12:19:09 +0800 Subject: [PATCH 3/3] remove debug print --- src/op/copy.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/op/copy.cc b/src/op/copy.cc index 4da6548d2..754dd7336 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -843,7 +843,6 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T, arith::Analyzer *analyzer) const { bool is_cpu_target = T.target->GetTargetDeviceType() == kDLCPU; auto simt_loop = MakeSIMTLoop(analyzer); - LOG(INFO) << "[LowerNormalCopy] simt_loop " << simt_loop; auto fused_loop = Downcast(ParallelLoopFuser::Fuse(simt_loop)); auto transformed_loop =