diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 074bcdd3f533..ac660bfb7461 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -861,6 +861,53 @@ class For : public Stmt { TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode); }; +/*! + * \brief A While loop + * + * \code + * + * while (condition) + * body + * + * \endcode + */ +class WhileNode : public StmtNode { + public: + /*! \brief The termination condition. */ + PrimExpr condition; + /*! \brief The body of the while loop. */ + Stmt body; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("condition", &condition); + v->Visit("body", &body); + v->Visit("span", &span); + } + + bool SEqualReduce(const WhileNode* other, SEqualReducer equal) const { + return equal.DefEqual(condition, other->condition) && equal.DefEqual(body, other->body); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce.DefHash(condition); + hash_reduce.DefHash(body); + } + + static constexpr const char* _type_key = "tir.While"; + TVM_DECLARE_FINAL_OBJECT_INFO(WhileNode, StmtNode); +}; + +/*! + * \brief Managed reference to WhileNode. + * \sa WhileNode + */ +class While : public Stmt { + public: + TVM_DLL While(PrimExpr condition, Stmt body, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(While, Stmt, WhileNode); +}; + /*! * \brief A prefetch hint for a buffer */ diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index e53b02d73e1d..ceebbbb305ce 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -86,6 +86,7 @@ class StmtFunctor { virtual R VisitStmt_(const AttrStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const IfThenElseNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const WhileNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; @@ -111,6 +112,7 @@ class StmtFunctor { IR_STMT_FUNCTOR_DISPATCH(AttrStmtNode); IR_STMT_FUNCTOR_DISPATCH(IfThenElseNode); IR_STMT_FUNCTOR_DISPATCH(ForNode); + IR_STMT_FUNCTOR_DISPATCH(WhileNode); IR_STMT_FUNCTOR_DISPATCH(AllocateNode); IR_STMT_FUNCTOR_DISPATCH(StoreNode); IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode); @@ -152,6 +154,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor { void VisitStmt_(const IfThenElseNode* op) override; void VisitStmt_(const LetStmtNode* op) override; void VisitStmt_(const ForNode* op) override; + void VisitStmt_(const WhileNode* op) override; void VisitStmt_(const AllocateNode* op) override; void VisitStmt_(const StoreNode* op) override; void VisitStmt_(const BufferStoreNode* op) override; @@ -245,6 +248,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor { Stmt VisitStmt_(const IfThenElseNode* op) override; Stmt VisitStmt_(const LetStmtNode* op) override; Stmt VisitStmt_(const ForNode* op) override; + Stmt VisitStmt_(const WhileNode* op) override; Stmt VisitStmt_(const AllocateNode* op) override; Stmt VisitStmt_(const StoreNode* op) override; Stmt VisitStmt_(const BufferStoreNode* op) override; diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 437e8f6610f4..2ecbdeda8371 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -263,6 +263,35 @@ def _exit_cb(): return WithScope(loop_var, _exit_cb) + def while_loop(self, condition): + """Create a while loop scope. + + Parameters + ---------- + condition : Expr + The termination condition. + + Returns + ------- + loop_scope : With.Scope of Var + The while scope. + + Examples + -------- + .. code-block:: python + + ib = tvm.tir.ir_builder.create() + iterations = ib.allocate("int32", (1,), name="iterations", scope="local") + with ib.while_loop(iterations[0] < 10): + iterations[0] += 1 + """ + self._seq_stack.append([]) + + def _exit_cb(): + self.emit(_stmt.While(condition, self._pop_seq())) + + return WithScope(None, _exit_cb) + def if_scope(self, cond): """Create an if scope. diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index e4f1ac924a83..47462066c364 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -159,6 +159,31 @@ def __init__( ) +@tvm._ffi.register_object("tir.While") +class While(Stmt): + """While node. + + Parameters + ---------- + condition : PrimExpr + The termination condition. + + body : Stmt + The body statement. + + span : Optional[Span] + The location of this itervar in the source code. + """ + + def __init__(self, condition, body, span=None): + self.__init_handle_by_constructor__( + _ffi_api.While, + condition, + body, + span, + ) + + @tvm._ffi.register_object("tir.Store") class Store(Stmt): """Store node. diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 152b1bd15987..83b538554ed4 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -521,7 +521,7 @@ def nms_inner_loop(ib, j): offset_j = j * 4 num_iter_per_thread = ceil_div(nkeep - (j + 1), nthread_tx) - with ib.for_range(0, num_iter_per_thread) as _k: + with ib.for_range(0, num_iter_per_thread, name="_k") as _k: k = j + 1 + _k * nthread_tx + tx offset_k = k * 4 @@ -555,16 +555,22 @@ def nms_inner_loop(ib, j): with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): # Apply nms - with ib.for_range(0, nkeep) as j: - # Proceed to the inner loop if the box j is still valid - with ib.if_scope(out_scores[i, j] > -1.0): - with ib.if_scope(max_output_size > 0): - # No need to do more iteration if we have already reached max_output_size - # boxes - # TODO(masahi): Add TIR while loop to realize early exit from the outer loop - with ib.if_scope(num_valid_boxes_local[0] < max_output_size): - nms_inner_loop(ib, j) - with ib.else_scope(): + with ib.if_scope(max_output_size > 0): + # No need to do more iteration if we have already reached max_output_size boxes + box_idx = ib.allocate("int32", (1,), name="box_idx", scope="local") + box_idx[0] = 0 + with ib.while_loop( + tvm.tir.all(box_idx[0] < nkeep, num_valid_boxes_local[0] < max_output_size) + ): + # Proceed to the inner loop if the box with id box_idx is still valid + with ib.if_scope(out_scores[i, box_idx[0]] > -1.0): + nms_inner_loop(ib, box_idx[0]) + box_idx[0] += 1 + + with ib.else_scope(): + with ib.for_range(0, nkeep, name="j") as j: + # Proceed to the inner loop if the box j is still valid + with ib.if_scope(out_scores[i, j] > -1.0): nms_inner_loop(ib, j) with ib.if_scope(tx + 0 == 0): diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 9a24fe65b4b1..6ec32a9e104c 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -308,6 +308,7 @@ class TIRTextPrinter : public StmtFunctor, Doc VisitStmt_(const SeqStmtNode* op) override; Doc VisitStmt_(const EvaluateNode* op) override; Doc VisitStmt_(const ForNode* op) override; + Doc VisitStmt_(const WhileNode* op) override; Doc VisitStmt_(const PrefetchNode* op) override; Doc VisitStmtDefault_(const Object* op) override; diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 711af2a8fd08..8d5bba5e5bb0 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -494,6 +494,13 @@ Doc TIRTextPrinter::VisitStmt_(const ForNode* op) { return doc; } +Doc TIRTextPrinter::VisitStmt_(const WhileNode* op) { + Doc doc; + doc << "while (" << Print(op->condition) << ")"; + doc << PrintBody(op->body); + return doc; +} + Doc TIRTextPrinter::VisitStmt_(const PrefetchNode* op) { Doc doc; doc << "prefetch(" << Print(op->buffer) << ", " << Print(op->bounds) << ")"; diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 1dd76f6b9d51..d5140677d45a 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1328,6 +1328,20 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) { llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body); } +void CodeGenLLVM::VisitStmt_(const WhileNode* op) { + using llvm::BasicBlock; + BasicBlock* while_cond = BasicBlock::Create(*ctx_, "while_cond", function_); + BasicBlock* while_body = BasicBlock::Create(*ctx_, "while_body", function_); + BasicBlock* while_merge = BasicBlock::Create(*ctx_, "while_merge", function_); + builder_->CreateBr(while_cond); + builder_->SetInsertPoint(while_cond); + builder_->CreateCondBr(MakeValue(op->condition), while_body, while_merge); + builder_->SetInsertPoint(while_body); + this->VisitStmt(op->body); + builder_->CreateBr(while_cond); + builder_->SetInsertPoint(while_merge); +} + void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) { using llvm::BasicBlock; llvm::Value* cond = MakeValue(op->condition); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 71583708da2c..e56a6de6d914 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -152,6 +152,7 @@ class CodeGenLLVM : public ExprFunctor, // stmt void VisitStmt_(const StoreNode* op) override; void VisitStmt_(const ForNode* op) override; + void VisitStmt_(const WhileNode* op) override; void VisitStmt_(const IfThenElseNode* op) override; void VisitStmt_(const AllocateNode* op) override; void VisitStmt_(const AttrStmtNode* op) override; diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index af175c7f2208..55db59f8d842 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -728,7 +728,6 @@ void CodeGenC::VisitStmt_(const StoreNode* op) { ICHECK(is_one(op->predicate)) << "Predicated store is not supported"; arith::PVar base; - if (arith::ramp(base, 1, t.lanes()).Match(op->index)) { std::string value = this->PrintExpr(op->value); this->PrintVecStore(op->buffer_var.get(), t, base.Eval(), value); @@ -899,6 +898,16 @@ void CodeGenC::VisitStmt_(const ForNode* op) { stream << "}\n"; } +void CodeGenC::VisitStmt_(const WhileNode* op) { + PrintIndent(); + stream << "while (" << PrintExpr(op->condition) << ") {\n"; + int while_scope = BeginScope(); + PrintStmt(op->body); + this->EndScope(while_scope); + PrintIndent(); + stream << "}\n"; +} + void CodeGenC::VisitStmt_(const IfThenElseNode* op) { std::string cond = PrintExpr(op->condition); PrintIndent(); diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index c1b566c064a4..76e6a9bc7197 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -150,6 +150,7 @@ class CodeGenC : public ExprFunctor, void VisitStmt_(const LetStmtNode* op) override; void VisitStmt_(const StoreNode* op) override; void VisitStmt_(const ForNode* op) override; + void VisitStmt_(const WhileNode* op) override; void VisitStmt_(const IfThenElseNode* op) override; void VisitStmt_(const AllocateNode* op) override; void VisitStmt_(const AttrStmtNode* op) override; diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index e54be4347c8e..2aeaae3eb592 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -197,6 +197,38 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "}\n"; }); +// While +While::While(PrimExpr condition, Stmt body, Span span) { + ICHECK(condition.defined()); + ICHECK(condition.dtype().is_scalar()); + ICHECK(condition.as() == nullptr) << "The condition should not be trivial."; + ICHECK(body.defined()); + + ObjectPtr node = make_object(); + node->condition = std::move(condition); + node->body = std::move(body); + node->span = std::move(span); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.While").set_body_typed([](PrimExpr condition, Stmt body, Span span) { + return While(condition, body, span); +}); + +TVM_REGISTER_NODE_TYPE(WhileNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "while(" << op->condition << "){\n"; + p->indent += 2; + p->Print(op->body); + p->indent -= 2; + p->PrintIndent(); + p->stream << "}\n"; + }); + // Store Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate, Span span) { ICHECK(value.defined()); diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index f05dc7116494..639d38db0a81 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -45,6 +45,11 @@ void StmtVisitor::VisitStmt_(const ForNode* op) { this->VisitStmt(op->body); } +void StmtVisitor::VisitStmt_(const WhileNode* op) { + this->VisitExpr(op->condition); + this->VisitStmt(op->body); +} + void StmtVisitor::VisitStmt_(const AllocateNode* op) { VisitArray(op->extents, [this](const PrimExpr& e) { this->VisitExpr(e); }); this->VisitStmt(op->body); @@ -283,6 +288,19 @@ Stmt StmtMutator::VisitStmt_(const ForNode* op) { } } +Stmt StmtMutator::VisitStmt_(const WhileNode* op) { + PrimExpr condition = this->VisitExpr(op->condition); + Stmt body = this->VisitStmt(op->body); + if (condition.same_as(op->condition) && body.same_as(op->body)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->condition = std::move(condition); + n->body = std::move(body); + return Stmt(n); + } +} + Stmt StmtMutator::VisitStmt_(const AllocateNode* op) { Array extents = Internal::Mutate(this, op->extents); Stmt body = this->VisitStmt(op->body); diff --git a/src/tir/transforms/coproc_sync.cc b/src/tir/transforms/coproc_sync.cc index f9245442d268..424a1bbb0ae6 100644 --- a/src/tir/transforms/coproc_sync.cc +++ b/src/tir/transforms/coproc_sync.cc @@ -429,6 +429,11 @@ class CoProcInstDepDetector : public StmtVisitor { } } + void VisitStmt_(const WhileNode* op) final { + // TODO(masahi): Do we need a special handling for While nodes? + LOG(FATAL) << "WhileNode not supported in CoProcSync."; + } + // insert before is stored in reverse order // the first element is closest to the node. std::unordered_map > insert_before_; diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index b24a0e95cd53..4ef10f326bb0 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -333,6 +333,13 @@ class VTInjector : public StmtExprMutator { } } + // While + Stmt VisitStmt_(const WhileNode* op) final { + // TODO(masahi): What should we do for While nodes? + LOG(FATAL) << "WhileNode in InjectVirtualThread not supported yet"; + return Stmt(); + } + // Seq Stmt VisitStmt_(const SeqStmtNode* op) final { ICHECK_EQ(max_loop_depth_, 0); diff --git a/src/tir/transforms/lift_attr_scope.cc b/src/tir/transforms/lift_attr_scope.cc index 27dd583b8b42..40d152b3b3b6 100644 --- a/src/tir/transforms/lift_attr_scope.cc +++ b/src/tir/transforms/lift_attr_scope.cc @@ -157,6 +157,12 @@ class AttrScopeLifter : public StmtMutator { } } + Stmt VisitStmt_(const WhileNode* op) final { + // TODO(masahi): Do we need a special handling for While nodes? + LOG(FATAL) << "WhileNode not supported in LiftAttrScope."; + return Stmt(); + } + private: // value comparison that also compares content of int constant static bool ValueSame(const PrimExpr& a, const PrimExpr& b) { diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc index be20724ae207..38143c14b021 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -180,6 +180,19 @@ void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) { --condition_counter_; } +void StorageAccessVisitor::VisitStmt_(const WhileNode* op) { + ++condition_counter_; + this->VisitExpr(op->condition); + scope_.push_back(std::vector()); + this->VisitStmt(op->body); + StmtEntry s; + s.stmt = op; + s.access = Summarize(std::move(scope_.back()), nullptr); + scope_.pop_back(); + scope_.back().emplace_back(std::move(s)); + --condition_counter_; +} + void StorageAccessVisitor::VisitExpr_(const CallNode* op) { if (op->op.same_as(builtin::address_of())) { const LoadNode* l = op->args[0].as(); diff --git a/src/tir/transforms/storage_access.h b/src/tir/transforms/storage_access.h index 80bbff4c1fe4..663c570fd15c 100644 --- a/src/tir/transforms/storage_access.h +++ b/src/tir/transforms/storage_access.h @@ -84,6 +84,7 @@ class StorageAccessVisitor : public StmtExprVisitor { void VisitStmt_(const AttrStmtNode* op) final; void VisitStmt_(const ForNode* op) final; void VisitStmt_(const IfThenElseNode* op) final; + void VisitStmt_(const WhileNode* op) final; void VisitExpr_(const CallNode* op) final; protected: diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 0b1429ca7efa..36eeddb17d89 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -192,6 +192,8 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { void VisitStmt_(const ForNode* op) final { VisitNewScope(op); } + void VisitStmt_(const WhileNode* op) final { VisitNewScope(op); } + void VisitStmt_(const AssertStmtNode* op) final { VisitNewScope(op); } // linearized access sequence. @@ -244,6 +246,8 @@ class InplaceOpVerifier : public StmtExprVisitor { VisitStmt_(static_cast(stmt)); } else if (stmt->IsInstance()) { VisitStmt_(static_cast(stmt)); + } else if (stmt->IsInstance()) { + VisitStmt_(static_cast(stmt)); } else if (stmt->IsInstance()) { VisitStmt_(static_cast(stmt)); } else { @@ -350,16 +354,7 @@ class StoragePlanRewriter : public StmtExprMutator { // start rewrite stmt = operator()(std::move(stmt)); if (attach_map_.count(nullptr)) { - std::vector nest; - for (StorageEntry* e : attach_map_.at(nullptr)) { - // ICHECK_EQ(e->scope.rank, 0); - if (e->new_alloc.defined()) { - nest.emplace_back(AttrStmt(e->alloc_var, attr::storage_scope, - StringImm(e->scope.to_string()), Evaluate(0))); - nest.push_back(e->new_alloc); - } - } - stmt = MergeNest(nest, stmt); + return MakeAttach(attach_map_.at(nullptr), stmt); } return stmt; } @@ -437,6 +432,7 @@ class StoragePlanRewriter : public StmtExprMutator { return StmtExprMutator::VisitStmt_(op); } } + Stmt VisitStmt_(const ForNode* op) final { ICHECK(op->kind != ForKind::kVectorized) << "VectorizeLoop before LiftStorageAlloc"; // remake all the allocation at the attach scope. diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 66f4ae329f69..64956bc8ee54 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -388,6 +388,11 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->value); @@ -441,7 +446,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor 1: + if a % 2 == 1: + a = 3 * a + 1 + else: + a = a >> 1 + i += 1 + return i + + def collatz(ib, n, C): + i = ib.allocate("int32", (1,), name="i", scope="local") + a = ib.allocate("int32", (1,), name="a", scope="local") + i[0] = 0 + a[0] = n + with ib.while_loop(a[0] > 1): + with ib.if_scope(tvm.tir.floormod(a[0], 2) == 1): + a[0] = 3 * a[0] + 1 + with ib.else_scope(): + a[0] = a[0] >> 1 + i[0] += 1 + + C[n] = i[0] + + def collatz_ir_cpu(C): + ib = tvm.tir.ir_builder.create() + n = C.shape[0] + C = ib.buffer_ptr(C) + + with ib.for_range(0, n, name="i", kind="parallel") as i: + collatz(ib, i, C) + + body = ib.get() + + return body + + n = 30 + + def check_target(target, ir): + C = te.extern( + (n,), + [], + lambda ins, outs: ir(outs[0]), + name="collatz", + dtype="int32", + ) + s = te.create_schedule(C.op) + + with tvm.transform.PassContext(opt_level=3): + func = tvm.build(s, [C], target) + + ctx = tvm.context(target, 0) + c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx) + func(c) + ref = np.array([collatz_ref(i) for i in range(n)]) + tvm.testing.assert_allclose(c.asnumpy(), ref) + + check_target("llvm", collatz_ir_cpu) + + +def test_while_mandel(): + n = 160 + shape = (n * 2, n) + t = 300 + + def mandel_ref(): + def complex_sqr(z): + return np.array([z[0] ** 2 - z[1] ** 2, z[1] * z[0] * 2]) + + pixels = np.zeros(shape) + + for i in range(pixels.shape[0]): + for j in range(pixels.shape[1]): + c = np.array([-0.8, np.cos(t) * 0.2]) + z = np.array([i / n - 1, j / n - 0.5]) * 2 + iterations = 0 + + while np.linalg.norm(z) < 20 and iterations < 50: + z = complex_sqr(z) + c + iterations += 1 + + pixels[i, j] = 1 - iterations * 0.02 + + return pixels + + def mandel(ib, i, j, pixels): + z = ib.allocate("float32", (2,), name="z", scope="local") + tmp = ib.allocate("float32", (1,), name="tmp", scope="local") + iterations = ib.allocate("int32", (1,), name="iterations", scope="local") + + z[0] = (i / float(n) - 1) * 2 + z[1] = (j / float(n) - 0.5) * 2 + iterations[0] = 0 + c = [-0.8, float(np.cos(t)) * 0.2] + + def norm(z): + return tvm.tir.sqrt(z[0] * z[0] + z[1] * z[1]) + + with ib.while_loop(tvm.tir.all(norm(z) < 20, iterations[0] < 50)): + tmp[0] = z[0] + z[0] = z[0] * z[0] - z[1] * z[1] + c[0] + z[1] = z[1] * tmp[0] * 2 + c[1] + iterations[0] += 1 + + pixels[i, j] = 1 - iterations[0] * 0.02 + + def mandel_ir_cpu(C): + ib = tvm.tir.ir_builder.create() + ny = C.shape[0] + nx = C.shape[1] + C = ib.buffer_ptr(C) + + with ib.for_range(0, ny, name="i", kind="parallel") as i: + with ib.for_range(0, nx, name="j") as j: + mandel(ib, i, j, C) + + body = ib.get() + + return body + + def mandel_ir_gpu(C): + ib = tvm.tir.ir_builder.create() + ny = C.shape[0] + nx = C.shape[1] + C = ib.buffer_ptr(C) + + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + by = te.thread_axis("blockIdx.y") + ty = te.thread_axis("threadIdx.y") + + max_threads = 16 + ib.scope_attr(bx, "thread_extent", tvm.tir.indexdiv(nx + max_threads - 1, max_threads)) + ib.scope_attr(tx, "thread_extent", max_threads) + ib.scope_attr(by, "thread_extent", tvm.tir.indexdiv(ny + max_threads - 1, max_threads)) + ib.scope_attr(ty, "thread_extent", max_threads) + + tidx = bx * max_threads + tx + tidy = by * max_threads + ty + + with ib.if_scope(tvm.tir.all(tidx < nx, tidy < ny)): + mandel(ib, tidy, tidx, C) + + body = ib.get() + + return body + + ref = mandel_ref() + + def check_target(target, ir): + if not tvm.testing.device_enabled(target): + return + + C = te.extern( + shape, + [], + lambda ins, outs: ir(outs[0]), + name="mandel_ir", + dtype="float32", + ) + s = te.create_schedule(C.op) + + with tvm.transform.PassContext(opt_level=3): + func = tvm.build(s, [C], target) + + ctx = tvm.context(target, 0) + c = tvm.nd.array(np.zeros(shape, dtype=C.dtype), ctx) + func(c) + tvm.testing.assert_allclose(c.asnumpy(), ref, rtol=1e-5, atol=1e-5) + + check_target("llvm", mandel_ir_cpu) + check_target("npvtx", mandel_ir_gpu) + check_target("cuda", mandel_ir_gpu) + + +def test_while_binary_search(): + def binary_search(ib, n, i, Aptr, Bptr, Cptr): + lo = ib.allocate("int32", (1,), name="lo", scope="local") + hi = ib.allocate("int32", (1,), name="hi", scope="local") + + lo[0] = 0 + hi[0] = n + v = Bptr[i] + + with ib.while_loop(lo[0] < hi[0]): + mid = lo[0] + (hi[0] - lo[0] >> 1) + with ib.if_scope(Aptr[mid] < v): + lo[0] = mid + 1 + with ib.else_scope(): + hi[0] = mid + + Cptr[i] = lo[0] + + def searchsorted_ir_cpu(A, B, C, n): + ib = tvm.tir.ir_builder.create() + Aptr = ib.buffer_ptr(A) + Bptr = ib.buffer_ptr(B) + Cptr = ib.buffer_ptr(C) + + with ib.for_range(0, n, name="i", kind="parallel") as i: + binary_search(ib, n, i, Aptr, Bptr, Cptr) + + body = ib.get() + + return body + + def searchsorted_ir_gpu(A, B, C, n): + ib = tvm.tir.ir_builder.create() + Aptr = ib.buffer_ptr(A) + Bptr = ib.buffer_ptr(B) + Cptr = ib.buffer_ptr(C) + + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + max_threads = 32 + ib.scope_attr(bx, "thread_extent", tvm.tir.indexdiv(n + max_threads - 1, max_threads)) + ib.scope_attr(tx, "thread_extent", max_threads) + tid = bx * max_threads + tx + + with ib.if_scope(tid < n): + binary_search(ib, n, tid, Aptr, Bptr, Cptr) + + body = ib.get() + + return body + + n = 1024 + dtype = "float32" + A = te.placeholder((n,), name="A", dtype=dtype) + B = te.placeholder((n,), name="B", dtype=dtype) + + def check_target(target, ir): + if not tvm.testing.device_enabled(target): + return + + C = te.extern( + A.shape, + [A, B], + lambda ins, outs: ir(ins[0], ins[1], outs[0], n), + name="searchsorted_ir", + dtype="int32", + ) + s = te.create_schedule(C.op) + + with tvm.transform.PassContext(opt_level=3): + func = tvm.build(s, [A, B, C], target) + + ctx = tvm.context(target, 0) + a_np = np.random.uniform(size=n).astype(A.dtype) + b_np = np.random.uniform(size=n).astype(B.dtype) + a_np = np.sort(a_np) + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(b_np, ctx) + c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx) + func(a, b, c) + ref = np.searchsorted(a_np, b_np) + tvm.testing.assert_allclose(c.asnumpy(), ref) + + check_target("llvm", searchsorted_ir_cpu) + check_target("cuda", searchsorted_ir_gpu) + check_target("nvptx", searchsorted_ir_gpu) + + if __name__ == "__main__": test_prefetch() test_if() test_for() test_cpu() test_gpu() + test_while_vectorize() + test_while_collatz() + test_while_mandel() + test_while_binary_search() diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py index 49adcfb568a7..dbe7e04700d9 100644 --- a/tests/python/unittest/test_tir_transform_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -298,6 +298,76 @@ def test_parallel_alloc(): assert isinstance(body.body.body.body.body, tvm.tir.Allocate) +def test_while_alloc(): + def get_mod(kind="serial"): + ib = tvm.tir.ir_builder.create() + n = te.var("n") + with ib.for_range(0, n, name="i", kind=kind) as i: + j = ib.allocate("int32", 1, name="j", scope="global") + j[0] = 0 + with ib.while_loop(j[0] < 10): + A = ib.allocate("float32", n, name="A", scope="global") + A[j[0]] = A[j[0]] + 2 + j[0] += j[0] + 1 + + body = ib.get() + return tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) + + mod = get_mod(kind="parallel") + # parallel (i, 0, n) { + # // attr [j] storage_scope = "global" + # allocate j[int32 * 1] + # j[0] = 0 + # while((j[0] < 10)){ + # // attr [A] storage_scope = "global" + # allocate A[float32 * n] + # A[j[0]] = (A[j[0]] + 2f) + # j[0] = (j[0] + (j[0] + 1)) + # } + # } + body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + # parallel (i, 0, n) { + # // attr [j] storage_scope = "global" + # allocate j[int32 * 1] + # // attr [A] storage_scope = "global" + # allocate A[float32 * n] + # j[0] = 0 + # while((j[0] < 10)){ + # A[j[0]] = (A[j[0]] + 2f) + # j[0] = (j[0] + (j[0] + 1)) + # } + # } + assert isinstance(body.body.body, tvm.tir.Allocate) # j + assert isinstance(body.body.body.body.body, tvm.tir.Allocate) # A + + mod = get_mod(kind="serial") + # for (i, 0, n) { + # // attr [j] storage_scope = "global" + # allocate j[int32 * 1] + # j[0] = 0 + # while((j[0] < 10)){ + # // attr [A] storage_scope = "global" + # allocate A[float32 * n] + # A[j[0]] = (A[j[0]] + 2f) + # j[0] = (j[0] + (j[0] + 1)) + # } + # } + body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + # // attr [j] storage_scope = "global" + # allocate j[int32 * 1] + # // attr [A] storage_scope = "global" + # allocate A[float32 * n] + # for (i, 0, n) { + # j[0] = 0 + # while((j[0] < 10)){ + # A[j[0]] = (A[j[0]] + 2f) + # j[0] = (j[0] + (j[0] + 1)) + # } + # } + assert isinstance(body.body, tvm.tir.Allocate) # j + assert isinstance(body.body.body.body, tvm.tir.Allocate) # A + + def test_inplace_rule2(scope_tb="local_TB2", max_bits=1024 * 1024 * 1024): # Test Buffer register_mem(scope_tb, max_bits) @@ -576,6 +646,7 @@ def verify(n): test_alloc_different_dtypes() test_inplace_rule() test_parallel_alloc() + test_while_alloc() test_storage_combine() test_storage_share_gpu() test_inplace_rule2() diff --git a/tests/python/unittest/test_tir_transform_vectorize.py b/tests/python/unittest/test_tir_transform_vectorize.py index 5ae47e01f681..b1e580957b24 100644 --- a/tests/python/unittest/test_tir_transform_vectorize.py +++ b/tests/python/unittest/test_tir_transform_vectorize.py @@ -158,6 +158,53 @@ def test_vectorize_if_then_else(): assert isinstance(stmt.body.value.args[2], tvm.tir.Broadcast) +def test_vectorize_while_fail(): + """A while loop inside a vectorized loop should fail.""" + + n = 64 + num_iter = 10 + + def test_ir(A, B, C): + ib = tvm.tir.ir_builder.create() + n = C.shape[0] + A = ib.buffer_ptr(A) + B = ib.buffer_ptr(B) + C = ib.buffer_ptr(C) + i = ib.allocate("int32", (1,), name="i", scope="local") + i[0] = 0 + + with ib.for_range(0, n) as j: + C[j] = 0.0 + + with ib.for_range(0, n, kind="vectorize") as j: + with ib.while_loop(i[0] < num_iter): + C[j] += A[j] + B[j] + i[0] += 1 + + return ib.get() + + dtype = "float32" + A = te.placeholder((n,), name="A", dtype=dtype) + B = te.placeholder((n,), name="B", dtype=dtype) + + C = te.extern( + (n,), + [A, B], + lambda ins, outs: test_ir(ins[0], ins[1], outs[0]), + name="while_vectorize", + dtype=dtype, + ) + s = te.create_schedule(C.op) + + try: + tvm.lower(s, [A, B, C], "llvm") + assert False + except tvm.error.TVMError as e: + error_msg = str(e).split("\n")[-1] + expected = "A while loop inside a vectorized loop not supported" + assert expected in error_msg + + if __name__ == "__main__": test_vectorize_vector() test_vectorize_with_if() @@ -166,3 +213,4 @@ def test_vectorize_if_then_else(): test_vectorize_with_le_cond() test_vectorize_with_ge_cond() test_vectorize_let() + test_vectorize_while_fail()