diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 8cef462b0257..92a5af43461e 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -49,6 +49,14 @@ TVM_DLL const Op& ret(); * \brief Return from a GPU thread. */ TVM_DLL const Op& thread_return(); +/*! + * \brief Loop continue. + */ +TVM_DLL const Op& continue_loop(); +/*! + * \brief Loop break. + */ +TVM_DLL const Op& break_loop(); /*! * \brief Reinterpret the value using the target type. */ diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index e1be6834fe2b..6a0f427b807d 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -99,6 +99,20 @@ TVM_DLL PrimExpr ret(PrimExpr value, Span span = Span()); */ TVM_DLL PrimExpr thread_return(Span span = Span()); +/*! + * \brief Continue current loop. + * \param span The location of this operation in the source. + * \return The continue loop expression. + */ +TVM_DLL PrimExpr continue_loop(Span span = Span()); + +/*! + * \brief Break current loop. + * \param span The location of this operation in the source. + * \return The break loop expression. + */ +TVM_DLL PrimExpr break_loop(Span span = Span()); + /*! * Query the maximum possible value of dtype. * \param dtype The data type. diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index aa827d96bd15..1b8041e36cc1 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1310,6 +1310,9 @@ constexpr const char* explicit_read_region = "explicit_read_region"; */ constexpr const char* explicit_write_region = "explicit_write_region"; +/*! \brief Mark a ForNode that represents an irregular loop with non-structural control flow edges. */ +constexpr const char* irregular_loop_mark = "irregular_loop_mark"; + /*!
* \brief Check if attr_key is a pragma key extension * \param attr_key The attr key to be compared diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index ed41ac9bfb56..6d746d73b1be 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1917,6 +1917,8 @@ def wrapped(*args, **kwargs): q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift) q_multiply_shift_per_axis = _op_wrapper(_tir_op.q_multiply_shift_per_axis) ret = _op_wrapper(_tir_op.ret) +continue_loop = _op_wrapper(_tir_op.continue_loop) +break_loop = _op_wrapper(_tir_op.break_loop) round = _op_wrapper(_tir_op.round) # pylint: disable=redefined-builtin rsqrt = _op_wrapper(_tir_op.rsqrt) shift_left = _op_wrapper(_tir_op.shift_left) @@ -2195,6 +2197,8 @@ def wrapped(*args, **kwargs): "q_multiply_shift", "q_multiply_shift_per_axis", "ret", + "continue_loop", + "break_loop", "reinterpret", "round", "rsqrt", diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index 80d272899345..e81ff0657f8b 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -872,6 +872,36 @@ def visit_Return(self, node: doc.Return) -> Any: # pylint: disable=invalid-name """ return _dispatch(self, "Return")(self, node) + def visit_Continue(self, node: doc.Continue) -> Any: # pylint: disable=invalid-name + """The general continue visiting method. + + Parameters + ---------- + node : doc.Continue + The doc AST continue node. + + Returns + ------- + res : Any + The visiting result. + """ + return _dispatch(self, "Continue")(self, node) + + def visit_Break(self, node: doc.Break) -> Any: # pylint: disable=invalid-name + """The general break visiting method. + + Parameters + ---------- + node : doc.Break + The doc AST break node. + + Returns + ------- + res : Any + The visiting result. 
+ """ + return _dispatch(self, "Break")(self, node) + def visit_Nonlocal(self, node: doc.Nonlocal) -> Any: # pylint: disable=invalid-name """The general nonlocal visiting method. diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index f6141404fa40..85ab1982f384 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -353,7 +353,8 @@ def visit_with(self: Parser, node: doc.With) -> None: frame = self.eval_expr(item.context_expr) if not isinstance(frame, Frame): self.report_error( - item.context_expr, "Invalid context expression in the with-statement." + item.context_expr, + "Invalid context expression in the with-statement.", ) rhs = stack.enter_context(frame) if item.optional_vars is not None: @@ -498,7 +499,8 @@ def visit_if(self: Parser, node: doc.If) -> None: self.visit_body(node.orelse) else: self.report_error( - node.test, f"If condition must be a boolean expression, but got {predicate}" + node.test, + f"If condition must be a boolean expression, but got {predicate}", ) @@ -539,6 +541,36 @@ def visit_return(self: Parser, node: doc.Return) -> None: T.evaluate(tvm.tir.ret(value)) +@dispatch.register(token="tir", type_name="Continue") +def visit_continue(self: Parser, node: doc.Continue) -> None: # pylint:disable=unused-argument + """The continue visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.Continue + The doc AST continue node. + """ + T.evaluate(tvm.tir.continue_loop()) + + +@dispatch.register(token="tir", type_name="Break") +def visit_break(self: Parser, node: doc.Break) -> None: # pylint:disable=unused-argument + """The break visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.Break + The doc AST break node.
+ """ + T.evaluate(tvm.tir.break_loop()) + + @dispatch.register(token="tir", type_name="tvm_declare_function") def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar: """The function declaration step for tir diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 120d652dd817..0a598e5e9bb9 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -50,7 +50,13 @@ from .op import tvm_stack_alloca, tvm_stack_make_shape, tvm_stack_make_array from .op import tvm_tuple, handle_add_byte_offset, tvm_struct_get, tvm_struct_set from .op import address_of, lookup_param, assume, undef -from .op import tvm_thread_allreduce, type_annotation, tvm_access_ptr, tvm_throw_last_error +from .op import continue_loop, break_loop +from .op import ( + tvm_thread_allreduce, + type_annotation, + tvm_access_ptr, + tvm_throw_last_error, +) from .op import ( tvm_load_matrix_sync, tvm_store_matrix_sync, @@ -86,7 +92,18 @@ from .op import tan, tanh, atan, atan2, atanh from .op import bitwise_and, bitwise_not, bitwise_or, bitwise_xor from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot -from .op import trunc, abs, round, nextafter, nearbyint, power, pow, popcount, fmod, if_then_else +from .op import ( + trunc, + abs, + round, + nextafter, + nearbyint, + power, + pow, + popcount, + fmod, + if_then_else, +) from .op import likely, isnan, isnullptr, isfinite, isinf, copysign from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv, logaddexp from .op import comm_reducer, min, max, sum diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index fcbc47961625..9a912bbb6b63 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -1884,8 +1884,7 @@ def ret(val, span=None): def thread_return(span=None): - """Return from a GPU thread. 
- + """Return from a GPU thread Parameters ---------- span : Optional[Span] @@ -1900,6 +1899,40 @@ def thread_return(span=None): return _ffi_api.thread_return(span) +def continue_loop(span=None): + """Create a tir intrinsic call to represent continue expression + + Parameters + ---------- + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + ret : PrimExpr + The continue expression + """ + + return _ffi_api.continue_loop(span) + + +def break_loop(span=None): + """Create a tir intrinsic call to represent break expression + + Parameters + ---------- + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + ret : PrimExpr + The break expression + """ + + return _ffi_api.break_loop(span) + + def any(*args, span=None): """Create a new experssion of the union of all conditions in the arguments diff --git a/python/tvm/tir/pipeline.py b/python/tvm/tir/pipeline.py index ae78b0573822..22cec3033497 100644 --- a/python/tvm/tir/pipeline.py +++ b/python/tvm/tir/pipeline.py @@ -43,6 +43,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I tir.transform.LowerMatchBuffer(), tir.transform.Simplify(), tir.transform.InjectPermutedLayout(), + tir.transform.AnnotateIrregularLoop(), tir.transform.InjectSoftwarePipeline(), tir.transform.TransformMmaBufferLayout(), tir.transform.LowerOpaqueBlock(), diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index bf02529194e3..de11d30fbc6e 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -430,6 +430,19 @@ def AnnotateDeviceRegions(): return _ffi_api.AnnotateDeviceRegions() # type: ignore +def AnnotateIrregularLoop(): + """Annotate irregular loop mark. Loop transformations like + peeling, partition, unroll, etc. are not allowed on irregular + loops that contain internal loop continues and breaks.
+ + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.AnnotateIrregularLoop() # type: ignore + + def SplitHostDevice(): """Split the function into a host function and device functions. diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 895cdae23107..d9ee9723216c 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -511,6 +511,7 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { std::swap(analyzer_, parent_->analyzer_); std::swap(var_map_, parent_->var_map_); std::swap(di_subprogram_, parent_->di_subprogram_); + std::swap(loop_frame_jump_tgts_, parent_->loop_frame_jump_tgts_); } void ExitWithScope() { @@ -518,11 +519,13 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { std::swap(analyzer_, parent_->analyzer_); std::swap(var_map_, parent_->var_map_); std::swap(di_subprogram_, parent_->di_subprogram_); + std::swap(loop_frame_jump_tgts_, parent_->loop_frame_jump_tgts_); } llvm::Function* function_{nullptr}; llvm::DISubprogram* di_subprogram_{nullptr}; std::unordered_map var_map_; + std::vector> loop_frame_jump_tgts_; std::unique_ptr analyzer_{std::make_unique()}; CodeGenCPU* parent_; }; diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 48d576f12efa..bdb0c6b7389f 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -775,6 +775,12 @@ std::unique_ptr CodeGenLLVM::CreateDebugInfo(llvm::Modul return debug_info; } +void CodeGenLLVM::PushLoopFrame(llvm::BasicBlock* backedge_tgt, llvm::BasicBlock* exit_tgt) { + loop_frame_jump_tgts_.emplace_back(backedge_tgt, exit_tgt); +} + +void CodeGenLLVM::PopLoopFrame() { loop_frame_jump_tgts_.pop_back(); } + llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) { int num_elems = GetVectorNumElements(vec); if (extent == num_elems && begin == 0) return vec; @@ -878,6 +884,7 @@ void 
CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va auto* for_begin = llvm::BasicBlock::Create(*ctx, "for_begin_" + loop_var_name, function_); auto* for_body = llvm::BasicBlock::Create(*ctx, "for_body_" + loop_var_name, function_); auto* for_end = llvm::BasicBlock::Create(*ctx, "for_end_" + loop_var_name, function_); + auto* for_next = llvm::BasicBlock::Create(*ctx, "for_next_" + loop_var_name, function_); builder_->CreateBr(for_begin); builder_->SetInsertPoint(for_begin); @@ -892,8 +899,13 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va builder_->SetInsertPoint(for_body); EmitDebugLocation(body->span); + PushLoopFrame(for_next, for_end); this->VisitStmt(body); + PopLoopFrame(); var_map_.erase(loop_var.get()); + + builder_->CreateBr(for_next); + builder_->SetInsertPoint(for_next); llvm::Value* loop_next = CreateAdd(loop_var.dtype(), loop_value, stride); loop_value->addIncoming(loop_next, builder_->GetInsertBlock()); builder_->CreateBr(for_begin); @@ -1466,6 +1478,26 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { llvm::BasicBlock::Create(*llvm_target_->GetContext(), "ret_dummy", function_); builder_->SetInsertPoint(ret_dummy); return ret_dummy; + } else if (op->op.same_as(builtin::continue_loop())) { + ICHECK(!loop_frame_jump_tgts_.empty()) + << "the tir.continue_loop should be inserted under at least one For or While stmts."; + builder_->CreateBr(loop_frame_jump_tgts_.back().first); + // LLVM allows exactly one terminator in a single basic block + // append a new dummy basic block to avoid error. 
+ llvm::BasicBlock* post_dummy = + llvm::BasicBlock::Create(*llvm_target_->GetContext(), "post_cont_dummy", function_); + builder_->SetInsertPoint(post_dummy); + return post_dummy; + } else if (op->op.same_as(builtin::break_loop())) { + ICHECK(!loop_frame_jump_tgts_.empty()) + << "the tir.break_loop should be inserted under at least one For or While stmts."; + builder_->CreateBr(loop_frame_jump_tgts_.back().second); + // LLVM allows exactly one terminator in a single basic block + // append a new dummy basic block to avoid error. + llvm::BasicBlock* post_dummy = + llvm::BasicBlock::Create(*llvm_target_->GetContext(), "post_break_dummy", function_); + builder_->SetInsertPoint(post_dummy); + return post_dummy; } else if (op->op.same_as(builtin::reinterpret())) { llvm::Type* target = DTypeToLLVMType(op->dtype); return builder_->CreateBitCast(MakeValue(op->args[0]), target); @@ -2010,7 +2042,9 @@ void CodeGenLLVM::VisitStmt_(const WhileNode* op) { builder_->SetInsertPoint(while_cond); builder_->CreateCondBr(MakeValue(op->condition), while_body, while_merge); builder_->SetInsertPoint(while_body); + PushLoopFrame(while_cond, while_merge); this->VisitStmt(op->body); + PopLoopFrame(); builder_->CreateBr(while_cond); builder_->SetInsertPoint(while_merge); } diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index cdaac859e430..5cf053cf7103 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -617,6 +617,13 @@ class CodeGenLLVM : public ExprFunctor, * initializes file and compilation_unit_ to TVM defaults. */ static std::unique_ptr CreateDebugInfo(llvm::Module* module); + + void PushLoopFrame(llvm::BasicBlock* backedge_tgt, llvm::BasicBlock* exit_tgt); + void PopLoopFrame(); + + // loop frame's jump target for continue and break generation + // store basic block pair (blk to backedge, blk to exit) for each frame. 
+ std::vector> loop_frame_jump_tgts_; }; inline int CodeGenLLVM::GetVectorNumElements(llvm::Value* vec) { diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index ddd904c555a2..8ebd41645aa2 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -612,6 +612,10 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) if (op->op.same_as(builtin::ret())) { os << "return "; PrintExpr(op->args[0], os); + } else if (op->op.same_as(builtin::continue_loop())) { + os << "continue;"; + } else if (op->op.same_as(builtin::break_loop())) { + os << "break;"; } else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { ICHECK_GE(op->args.size(), 1U); auto func = Downcast(op->args[0]); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 81aeffe46a9d..d33a01340b96 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -216,7 +216,6 @@ std::ostream& operator<<(std::ostream& out, ForKind type) { // NOLINT(*) While::While(PrimExpr condition, Stmt body, Span span) { ICHECK(condition.defined()); ICHECK(condition.dtype().is_scalar()); - ICHECK(condition.as() == nullptr) << "The condition should not be trivial."; ICHECK(body.defined()); ObjectPtr node = ffi::make_object(); diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index fe095dbaa593..f04842f40e53 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -52,6 +52,14 @@ TIR_DEFINE_BUILTIN_FUNC(thread_return) .set_attr("TCallEffectKind", Integer(CallEffectKind::kControlJump)) .set_num_inputs(0); +TIR_DEFINE_BUILTIN_FUNC(continue_loop) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kControlJump)) + .set_num_inputs(0); + +TIR_DEFINE_BUILTIN_FUNC(break_loop) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kControlJump)) + .set_num_inputs(0); + TIR_DEFINE_BUILTIN_FUNC(likely) .set_num_inputs(1) .set_attr("TCallEffectKind", Integer(CallEffectKind::kExprAnnotation)) diff --git 
a/src/tir/op/op.cc b/src/tir/op/op.cc index 700bc5f0e486..935f9928a508 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -246,19 +246,26 @@ PrimExpr ret(PrimExpr value, Span span) { return tir::Call(value.dtype(), tir::builtin::ret(), {value}, span); } -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.ret", ret); -} - PrimExpr thread_return(Span span) { return tir::Call(DataType::Void(), tir::builtin::thread_return(), {}, span); } +PrimExpr continue_loop(Span span) { + return tir::Call(DataType::Void(), tir::builtin::continue_loop(), {}, span); +} + +PrimExpr break_loop(Span span) { + return tir::Call(DataType::Void(), tir::builtin::break_loop(), {}, span); +} + TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.thread_return", thread_return); -} + refl::GlobalDef() + .def("tir.ret", ret) + .def("tir.thread_return", thread_return) + .def("tir.continue_loop", continue_loop) + .def("tir.break_loop", break_loop); +}; // maximum and min limits PrimExpr max_value(const DataType& dtype, Span span) { diff --git a/src/tir/transforms/annotate_irregular_loop.cc b/src/tir/transforms/annotate_irregular_loop.cc new file mode 100644 index 000000000000..c715922d60b3 --- /dev/null +++ b/src/tir/transforms/annotate_irregular_loop.cc @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace tir { + +class IrregularLoopAnnotator : public StmtMutator { + public: + static Stmt Annotate(const Stmt& body) { return IrregularLoopAnnotator().VisitStmt(body); } + + private: + IrregularLoopAnnotator() = default; + + Stmt VisitStmt_(const ForNode* op) final { + bool cur_has_jump = has_jump_; + has_jump_ = false; + For res = Downcast(StmtMutator::VisitStmt_(op)); + if (has_jump_) { + CHECK(op->kind == ForKind::kSerial) + << "Loop kind " << op->kind << " is invalid for irregular loop " << op->loop_var; + for (const char* key : {attr::pragma_auto_unroll_max_step, attr::pragma_unroll_explicit, + attr::pragma_loop_partition_hint, attr::software_pipeline_stage}) { + CHECK(!res->annotations.count(key)) + << "Annotation `" << key << "` is invalid for irregular loop " << op->loop_var; + } + res.CopyOnWrite()->annotations.Set(attr::irregular_loop_mark, 1); + } + std::swap(cur_has_jump, has_jump_); + return res; + } + + Stmt VisitStmt_(const WhileNode* op) final { + bool cur_has_jump = has_jump_; + has_jump_ = false; + Stmt res = StmtMutator::VisitStmt_(op); + std::swap(cur_has_jump, has_jump_); + return res; + } + + Stmt VisitStmt_(const EvaluateNode* op) final { + if (const CallNode* call = op->value.as()) { + if (call->op.same_as(builtin::continue_loop()) || call->op.same_as(builtin::break_loop())) { + has_jump_ = true; + } + } + return ffi::GetRef(op); + } + + bool has_jump_{false}; +}; + +namespace transform { + +Pass AnnotateIrregularLoop() { + 
auto pass_func = [](PrimFunc func, IRModule mod, PassContext ctx) -> PrimFunc { + func.CopyOnWrite()->body = IrregularLoopAnnotator::Annotate(func->body); + return func; + }; + + return CreatePrimFuncPass(pass_func, 0, "tir.AnnotateIrregularLoop", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.AnnotateIrregularLoop", AnnotateIrregularLoop); +} + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/lower_opaque_block.cc b/src/tir/transforms/lower_opaque_block.cc index bbe550fe35e4..2e53e89667cc 100644 --- a/src/tir/transforms/lower_opaque_block.cc +++ b/src/tir/transforms/lower_opaque_block.cc @@ -90,8 +90,10 @@ class OpaqueBlockLower : public StmtExprMutator { // handling unit loop unit_loop_vars_[op->loop_var] = min; } + // Step 2. Visit recursively Stmt body = this->VisitStmt(op->body); + // Step 3. Handle annotations std::vector> pragma_attrs; ffi::Map new_annotations = @@ -102,7 +104,8 @@ class OpaqueBlockLower : public StmtExprMutator { ICHECK(op->thread_binding.defined()); ffi::String thread_tag = op->thread_binding.value()->thread_tag; body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body); - } else if (is_one(extent) && op->annotations.empty()) { + } else if (is_one(extent) && op->annotations.empty() && + !op->annotations.count(attr::irregular_loop_mark)) { // Case 2. Unit loop return body; } else { diff --git a/tests/python/tir-base/test_tir_base.py b/tests/python/tir-base/test_tir_base.py index d204ebfb6084..b23c600b15b8 100644 --- a/tests/python/tir-base/test_tir_base.py +++ b/tests/python/tir-base/test_tir_base.py @@ -14,11 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+import numpy as np import tvm import pytest from tvm import tir from tvm.base import TVMError from tvm.ir.transform import PassContext +from tvm.script import tir as T import itertools import pytest @@ -113,6 +115,61 @@ def test_control_flow_jump(): assert out == 1.0 +def test_break_loop(): + @T.prim_func + def func(In: T.Buffer[(2,), "int32"], Out: T.Buffer[(2,), "int32"]): + Out[0] = 0 + Out[1] = 1 + for i in range(10): + for j in range(10): + if i * 10 + j == In[0]: + Out[0] = i + j + break + if Out[0] > 0: + break + while Out[1] > 0: + Out[1] = Out[1] + 1 + if Out[1] > In[1]: + break + + func = build_tir_func(func) + a = np.asarray([49, 8], "int32") + b = np.zeros([2], "int32") + if not hasattr(b, "__dlpack__"): + return + func(a, b) + assert b[0] == 13 + assert b[1] == 9 + + +def test_continue_loop(): + @T.prim_func + def func(Out: T.Buffer[(2,), "int32"]): + T.func_attr({"global_symbol": "main"}) + Out[0] = 0 + Out[1] = 0 + for i in range(10): + for j in range(10): + if (i * 10 + j) % 3 != 0: + continue + Out[0] = Out[0] + 1 + k = T.decl_buffer([], "int32") + k[()] = 0 + while k[()] < Out[0]: + k[()] = k[()] + 1 + if k[()] % 6 == 0: + Out[1] = Out[1] + 1 + continue + + func = build_tir_func(func) + b = np.zeros([2], "int32") + if not hasattr(b, "__dlpack__"): + return + func(b) + assert b[0] == 34 + assert b[1] == 5 # 6, 12, 18, 24, 30 + + def test_exception(): with pytest.raises(TypeError): x = tir.Var(name=1, dtype="int") diff --git a/tests/python/tir-transform/test_tir_transform_annotate_irregular_loop.py b/tests/python/tir-transform/test_tir_transform_annotate_irregular_loop.py new file mode 100644 index 000000000000..fa46ef36403c --- /dev/null +++ b/tests/python/tir-transform/test_tir_transform_annotate_irregular_loop.py @@ -0,0 +1,203 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests for AnnotateIrregularLoop""" + +import tvm +import tvm.testing +from tvm import tir +from tvm.script import tir as T + + +def test_handle_irrgular_unit_loop(): + """Dedicated testcase to check the unitloop with loop jump not simplified""" + + @T.prim_func + def before(A: T.Buffer((10,), "int32")): + for i in T.serial(1): + if A[i] > 5: + break + A[i] = A[i] + 1 + for j in T.serial(1): + if A[j] > 5: + continue + A[j] = A[j] + 1 + for k in T.serial(1): + A[k] = A[k] + 1 + + @T.prim_func + def expected(A: T.Buffer((10,), "int32")): + for i in T.serial(1, annotations={"irregular_loop_mark": 1}): + if A[i] > 5: + break + A[i] = A[i] + 1 + for j in T.serial(1, annotations={"irregular_loop_mark": 1}): + if A[j] > 5: + continue + A[j] = A[j] + 1 + A[0] = A[0] + 1 + + mod = tvm.IRModule.from_expr(before) + mod = tvm.tir.transform.AnnotateIrregularLoop()(mod) + mod = tvm.tir.transform.LowerOpaqueBlock()(mod) + tvm.ir.assert_structural_equal(mod["before"].with_attr("global_symbol", "expected"), expected) + + +class BaseCompare(tvm.testing.CompareBeforeAfter): + transform = tir.transform.AnnotateIrregularLoop() + + +class TestAnnotateLoopWithBreak(BaseCompare): + """Test that loops containing break statements are annotated as irregular.""" + + def before(A: T.Buffer((10,), "int32")): + for i in T.serial(10): + if A[i] > 5: + break + A[i] = A[i] + 1 + + def expected(A: T.Buffer((10,), "int32")): + 
for i in T.serial(10, annotations={"irregular_loop_mark": 1}): + if A[i] > 5: + break + A[i] = A[i] + 1 + + +class TestAnnotateLoopWithContinue(BaseCompare): + """Test that loops containing continue statements are annotated as irregular.""" + + def before(A: T.Buffer((10,), "int32")): + for i in T.serial(10): + if A[i] < 0: + continue + A[i] = A[i] * 2 + + def expected(A: T.Buffer((10,), "int32")): + for i in T.serial(10, annotations={"irregular_loop_mark": 1}): + if A[i] < 0: + continue + A[i] = A[i] * 2 + + +class TestNestedIrregularBothLoops(BaseCompare): + """Test nested loops where both loops have break/continue.""" + + def before(A: T.Buffer((10, 10), "int32")): + for i in T.serial(10): + if i > 7: + break + for j in T.serial(10): + if A[i, j] < 0: + continue + A[i, j] = A[i, j] + 1 + + def expected(A: T.Buffer((10, 10), "int32")): + for i in T.serial(10, annotations={"irregular_loop_mark": 1}): + if i > 7: + break + for j in T.serial(10, annotations={"irregular_loop_mark": 1}): + if A[i, j] < 0: + continue + A[i, j] = A[i, j] + 1 + + +class TestWhileLoopWithBreak(BaseCompare): + """Test that while loops with break/continue are not annotated (while loops don't have annotations).""" + + def before(A: T.Buffer((10,), "int32")): + i = T.int32(0) + while i < 10: + if A[i] > 5: + break + A[i] = A[i] + 1 + i = i + 1 + + def expected(A: T.Buffer((10,), "int32")): + i = T.int32(0) + while i < 10: + if A[i] > 5: + break + A[i] = A[i] + 1 + i = i + 1 + + +class TestBreakInNestedConditional(BaseCompare): + """Test break statement deeply nested in conditional blocks.""" + + def before(A: T.Buffer((10,), "int32"), flag1: T.int32, flag2: T.int32): + for i in T.serial(10): + if flag1 > 0: + if flag2 > 0: + if A[i] > 5: + break + A[i] = A[i] + 1 + + def expected(A: T.Buffer((10,), "int32"), flag1: T.int32, flag2: T.int32): + for i in T.serial(10, annotations={"irregular_loop_mark": 1}): + if flag1 > 0: + if flag2 > 0: + if A[i] > 5: + break + A[i] = A[i] + 1 + + +class 
TestWhileLoopWithBreakStandalone(BaseCompare): + """Test that while loops with break/continue are not annotated (while loops don't have annotations).""" + + def before(A: T.Buffer((10,), "int32")): + i = T.int32(0) + while i < 10: + if A[i] > 5: + break + A[i] = A[i] + 1 + i = i + 1 + + def expected(A: T.Buffer((10,), "int32")): + i = T.int32(0) + while i < 10: + if A[i] > 5: + break + A[i] = A[i] + 1 + i = i + 1 + + +class TestNestedIrregularLoopStandalone(BaseCompare): + """Test deeply nested loops with irregular control flow only in innermost loop.""" + + def before(A: T.Buffer((5, 5, 5), "int32")): + for i in T.serial(5): + for j in T.serial(5): + for k in T.serial(5): + if A[i, j, k] > 10: + break + if A[i, j, k] < 0: + continue + A[i, j, k] = A[i, j, k] + 1 + + def expected(A: T.Buffer((5, 5, 5), "int32")): + for i in T.serial(5): + for j in T.serial(5): + for k in T.serial(5, annotations={"irregular_loop_mark": 1}): + if A[i, j, k] > 10: + break + if A[i, j, k] < 0: + continue + A[i, j, k] = A[i, j, k] + 1 + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index be8b03357dde..fc7deacd980d 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -1046,5 +1046,34 @@ def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")): _assert_print(main, expected_output) +def test_func_with_loop_jumps(): + from tvm.script import tir as T + + @T.prim_func + def main(a: T.handle, b: T.handle): + A = T.match_buffer(a, (4,), "float32") + B = T.match_buffer(b, (4,), "float32") + for i in range(1000): + if i % 13 == 0: + A[1] = A[1] + 1 + continue + if A[0] >= B[0]: + break + + expected_output = """ +# from tvm.script import tir as T + +@T.prim_func +def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")): + for i in range(1000): + if i % 13 == 0: + A[1] = A[1] + 
T.float32(1.0) + T.continue_loop() + if A[0] >= B[0]: + T.break_loop() + """ + _assert_print(main, expected_output) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index 2be2e2e98d81..1954ca773f14 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -4002,6 +4002,22 @@ def func( return func +def func_with_loop_jumps(): + @T.prim_func + def func(In: T.Buffer((1,), "int32"), Out: T.Buffer((2,), "int32")): + Out[0] = 0 + Out[1] = 0 + for i in range(1000): + if i % 13 == 0: + Out[1] = Out[1] + 1 + continue + Out[0] = Out[0] + 1 + if Out[0] >= In[0]: + break + + return func + + def op_of_literal(): op_list = [ (T.exp, 0), @@ -4220,6 +4236,7 @@ def func(A: R.Tensor(["N"], "float16"), _: R.Prim(value="threshold")): return_zero_private, return_zero_private_with_attr, func_attr_with_list, + func_with_loop_jumps, *op_of_literal(), *relax_match_cast_struct_info_proxy(), relax_symbolic_size_var, diff --git a/tests/python/tvmscript/test_tvmscript_syntax_sugar.py b/tests/python/tvmscript/test_tvmscript_syntax_sugar.py index 33880539eb5f..df8675704b67 100644 --- a/tests/python/tvmscript/test_tvmscript_syntax_sugar.py +++ b/tests/python/tvmscript/test_tvmscript_syntax_sugar.py @@ -506,5 +506,27 @@ def implicit(): assert_structural_equal_ignore_global_symbol(implicit, explicit) +def test_loop_jump_statement(): + """`break` and `continue` evaluates to TIR intrinsics""" + + @T.prim_func + def explicit(): + for i in range(16): + if i % 2 == 0: + T.evaluate(T.continue_loop()) + if i < 15: + T.evaluate(T.break_loop()) + + @T.prim_func + def implicit(): + for i in range(16): + if i % 2 == 0: + continue + if i < 15: + break + + assert_structural_equal_ignore_global_symbol(implicit, explicit) + + if __name__ == "__main__": tvm.testing.main()