Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ TVM_DLL const Op& ret();
* \brief Return from a GPU thread.
*/
TVM_DLL const Op& thread_return();
/*!
* \brief Loop continue.
*/
TVM_DLL const Op& continue_loop();
/*!
* \brief Loop break.
*/
TVM_DLL const Op& break_loop();
/*!
* \brief Reinterpret the value using the target type.
*/
Expand Down
14 changes: 14 additions & 0 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,20 @@ TVM_DLL PrimExpr ret(PrimExpr value, Span span = Span());
*/
TVM_DLL PrimExpr thread_return(Span span = Span());

/*!
* \brief Continue current loop.
* \param span The location of this operation in the source.
* \return The continue loop expression.
*/
TVM_DLL PrimExpr continue_loop(Span span = Span());

/*!
* \brief Break current loop.
* \param span The location of this operation in the source.
* \return The break loop expression.
*/
TVM_DLL PrimExpr break_loop(Span span = Span());

/*!
* Query the maximum possible value of dtype.
* \param dtype The data type.
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1310,6 +1310,9 @@ constexpr const char* explicit_read_region = "explicit_read_region";
*/
constexpr const char* explicit_write_region = "explicit_write_region";

/*! \brief ,ark a ForNode represent an irregular loop of non-structural control flow edges. */
constexpr const char* irregular_loop_mark = "irregular_loop_mark";

/*!
* \brief Check if attr_key is a pragma key extension
* \param attr_key The attr key to be compared
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1917,6 +1917,8 @@ def wrapped(*args, **kwargs):
q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift)
q_multiply_shift_per_axis = _op_wrapper(_tir_op.q_multiply_shift_per_axis)
ret = _op_wrapper(_tir_op.ret)
continue_loop = _op_wrapper(_tir_op.continue_loop)
break_loop = _op_wrapper(_tir_op.break_loop)
round = _op_wrapper(_tir_op.round) # pylint: disable=redefined-builtin
rsqrt = _op_wrapper(_tir_op.rsqrt)
shift_left = _op_wrapper(_tir_op.shift_left)
Expand Down Expand Up @@ -2195,6 +2197,8 @@ def wrapped(*args, **kwargs):
"q_multiply_shift",
"q_multiply_shift_per_axis",
"ret",
"continue_loop",
"break_loop",
"reinterpret",
"round",
"rsqrt",
Expand Down
30 changes: 30 additions & 0 deletions python/tvm/script/parser/core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,36 @@ def visit_Return(self, node: doc.Return) -> Any: # pylint: disable=invalid-name
"""
return _dispatch(self, "Return")(self, node)

def visit_Continue(self, node: doc.Continue) -> Any: # pylint: disable=invalid-name
"""The general continue visiting method.

Parameters
----------
node : doc.Continue
The doc AST continue node.

Returns
-------
res : Any
The visiting result.
"""
return _dispatch(self, "Continue")(self, node)

def visit_Break(self, node: doc.Break) -> Any: # pylint: disable=invalid-name
"""The general break visiting method.

Parameters
----------
node : doc.Break
The doc AST break node.

Returns
-------
res : Any
The visiting result.
"""
return _dispatch(self, "Break")(self, node)

def visit_Nonlocal(self, node: doc.Nonlocal) -> Any: # pylint: disable=invalid-name
"""The general nonlocal visiting method.

Expand Down
36 changes: 34 additions & 2 deletions python/tvm/script/parser/tir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,8 @@ def visit_with(self: Parser, node: doc.With) -> None:
frame = self.eval_expr(item.context_expr)
if not isinstance(frame, Frame):
self.report_error(
item.context_expr, "Invalid context expression in the with-statement."
item.context_expr,
"Invalid context expression in the with-statement.",
)
rhs = stack.enter_context(frame)
if item.optional_vars is not None:
Expand Down Expand Up @@ -498,7 +499,8 @@ def visit_if(self: Parser, node: doc.If) -> None:
self.visit_body(node.orelse)
else:
self.report_error(
node.test, f"If condition must be a boolean expression, but got {predicate}"
node.test,
f"If condition must be a boolean expression, but got {predicate}",
)


Expand Down Expand Up @@ -539,6 +541,36 @@ def visit_return(self: Parser, node: doc.Return) -> None:
T.evaluate(tvm.tir.ret(value))


@dispatch.register(token="tir", type_name="Continue")
def visit_continue(self: Parser, node: doc.Continue) -> None: # pylint:disable=unused-argument
"""The continue visiting method for tir.

Parameters
----------
self : Parser
The visiting parser.

node : doc.Continue
The doc AST continue node.
"""
T.evaluate(tvm.tir.continue_loop())


@dispatch.register(token="tir", type_name="Break")
def visit_break(self: Parser, node: doc.Break) -> None: # pylint:disable=unused-argument
"""The continue visiting method for tir.

Parameters
----------
self : Parser
The visiting parser.

node : doc.Break
The doc AST break node.
"""
T.evaluate(tvm.tir.break_loop())


@dispatch.register(token="tir", type_name="tvm_declare_function")
def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar:
"""The function declaration step for tir
Expand Down
21 changes: 19 additions & 2 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,13 @@
from .op import tvm_stack_alloca, tvm_stack_make_shape, tvm_stack_make_array
from .op import tvm_tuple, handle_add_byte_offset, tvm_struct_get, tvm_struct_set
from .op import address_of, lookup_param, assume, undef
from .op import tvm_thread_allreduce, type_annotation, tvm_access_ptr, tvm_throw_last_error
from .op import continue_loop, break_loop
from .op import (
tvm_thread_allreduce,
type_annotation,
tvm_access_ptr,
tvm_throw_last_error,
)
from .op import (
tvm_load_matrix_sync,
tvm_store_matrix_sync,
Expand Down Expand Up @@ -86,7 +92,18 @@
from .op import tan, tanh, atan, atan2, atanh
from .op import bitwise_and, bitwise_not, bitwise_or, bitwise_xor
from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot
from .op import trunc, abs, round, nextafter, nearbyint, power, pow, popcount, fmod, if_then_else
from .op import (
trunc,
abs,
round,
nextafter,
nearbyint,
power,
pow,
popcount,
fmod,
if_then_else,
)
from .op import likely, isnan, isnullptr, isfinite, isinf, copysign
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv, logaddexp
from .op import comm_reducer, min, max, sum
Expand Down
37 changes: 35 additions & 2 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1884,8 +1884,7 @@ def ret(val, span=None):


def thread_return(span=None):
"""Return from a GPU thread.

"""Return from a GPU thread
Parameters
----------
span : Optional[Span]
Expand All @@ -1900,6 +1899,40 @@ def thread_return(span=None):
return _ffi_api.thread_return(span)


def continue_loop(span=None):
"""Create a tir intrinsic call to represent continue expression

Parameters
----------
span : Optional[Span]
The location of this operator in the source code.

Returns
-------
ret : PrimExpr
The continue expression
"""

return _ffi_api.continue_loop(span)


def break_loop(span=None):
"""Create a tir intrinsic call to represent break expression

Parameters
----------
span : Optional[Span]
The location of this operator in the source code.

Returns
-------
ret : PrimExpr
The break expression
"""

return _ffi_api.break_loop(span)


def any(*args, span=None):
"""Create a new experssion of the union of all conditions in the arguments

Expand Down
1 change: 1 addition & 0 deletions python/tvm/tir/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
tir.transform.LowerMatchBuffer(),
tir.transform.Simplify(),
tir.transform.InjectPermutedLayout(),
tir.transform.AnnotateIrregularLoop(),
tir.transform.InjectSoftwarePipeline(),
tir.transform.TransformMmaBufferLayout(),
tir.transform.LowerOpaqueBlock(),
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,19 @@ def AnnotateDeviceRegions():
return _ffi_api.AnnotateDeviceRegions() # type: ignore


def AnnotateIrregularLoop():
"""Annotate irregular loop mark. Loop transformations like
peeling, partition, unroll, etc is not allowed on irregular
loop with internal loop continuation and breaks.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.AnnotateIrregularLoop() # type: ignore


def SplitHostDevice():
"""Split the function into a host function and device functions.
Expand Down
3 changes: 3 additions & 0 deletions src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -511,18 +511,21 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) {
std::swap(analyzer_, parent_->analyzer_);
std::swap(var_map_, parent_->var_map_);
std::swap(di_subprogram_, parent_->di_subprogram_);
std::swap(loop_frame_jump_tgts_, parent_->loop_frame_jump_tgts_);
}

void ExitWithScope() {
std::swap(function_, parent_->function_);
std::swap(analyzer_, parent_->analyzer_);
std::swap(var_map_, parent_->var_map_);
std::swap(di_subprogram_, parent_->di_subprogram_);
std::swap(loop_frame_jump_tgts_, parent_->loop_frame_jump_tgts_);
}

llvm::Function* function_{nullptr};
llvm::DISubprogram* di_subprogram_{nullptr};
std::unordered_map<const VarNode*, llvm::Value*> var_map_;
std::vector<std::pair<llvm::BasicBlock*, llvm::BasicBlock*>> loop_frame_jump_tgts_;
std::unique_ptr<arith::Analyzer> analyzer_{std::make_unique<arith::Analyzer>()};
CodeGenCPU* parent_;
};
Expand Down
34 changes: 34 additions & 0 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,12 @@ std::unique_ptr<CodeGenLLVM::DebugInfo> CodeGenLLVM::CreateDebugInfo(llvm::Modul
return debug_info;
}

void CodeGenLLVM::PushLoopFrame(llvm::BasicBlock* backedge_tgt, llvm::BasicBlock* exit_tgt) {
loop_frame_jump_tgts_.emplace_back(backedge_tgt, exit_tgt);
}

void CodeGenLLVM::PopLoopFrame() { loop_frame_jump_tgts_.pop_back(); }

llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
int num_elems = GetVectorNumElements(vec);
if (extent == num_elems && begin == 0) return vec;
Expand Down Expand Up @@ -878,6 +884,7 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va
auto* for_begin = llvm::BasicBlock::Create(*ctx, "for_begin_" + loop_var_name, function_);
auto* for_body = llvm::BasicBlock::Create(*ctx, "for_body_" + loop_var_name, function_);
auto* for_end = llvm::BasicBlock::Create(*ctx, "for_end_" + loop_var_name, function_);
auto* for_next = llvm::BasicBlock::Create(*ctx, "for_next_" + loop_var_name, function_);
builder_->CreateBr(for_begin);
builder_->SetInsertPoint(for_begin);

Expand All @@ -892,8 +899,13 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va
builder_->SetInsertPoint(for_body);
EmitDebugLocation(body->span);

PushLoopFrame(for_next, for_end);
this->VisitStmt(body);
PopLoopFrame();
var_map_.erase(loop_var.get());

builder_->CreateBr(for_next);
builder_->SetInsertPoint(for_next);
llvm::Value* loop_next = CreateAdd(loop_var.dtype(), loop_value, stride);
loop_value->addIncoming(loop_next, builder_->GetInsertBlock());
builder_->CreateBr(for_begin);
Expand Down Expand Up @@ -1466,6 +1478,26 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
llvm::BasicBlock::Create(*llvm_target_->GetContext(), "ret_dummy", function_);
builder_->SetInsertPoint(ret_dummy);
return ret_dummy;
} else if (op->op.same_as(builtin::continue_loop())) {
ICHECK(!loop_frame_jump_tgts_.empty())
<< "the tir.continue_loop should be inserted under at least one For or While stmts.";
builder_->CreateBr(loop_frame_jump_tgts_.back().first);
// LLVM allows exactly one terminator in a single basic block
// append a new dummy basic block to avoid error.
llvm::BasicBlock* post_dummy =
llvm::BasicBlock::Create(*llvm_target_->GetContext(), "post_cont_dummy", function_);
builder_->SetInsertPoint(post_dummy);
return post_dummy;
} else if (op->op.same_as(builtin::break_loop())) {
ICHECK(!loop_frame_jump_tgts_.empty())
<< "the tir.break_loop should be inserted under at least one For or While stmts.";
builder_->CreateBr(loop_frame_jump_tgts_.back().second);
// LLVM allows exactly one terminator in a single basic block
// append a new dummy basic block to avoid error.
llvm::BasicBlock* post_dummy =
llvm::BasicBlock::Create(*llvm_target_->GetContext(), "post_break_dummy", function_);
builder_->SetInsertPoint(post_dummy);
return post_dummy;
} else if (op->op.same_as(builtin::reinterpret())) {
llvm::Type* target = DTypeToLLVMType(op->dtype);
return builder_->CreateBitCast(MakeValue(op->args[0]), target);
Expand Down Expand Up @@ -2010,7 +2042,9 @@ void CodeGenLLVM::VisitStmt_(const WhileNode* op) {
builder_->SetInsertPoint(while_cond);
builder_->CreateCondBr(MakeValue(op->condition), while_body, while_merge);
builder_->SetInsertPoint(while_body);
PushLoopFrame(while_cond, while_merge);
this->VisitStmt(op->body);
PopLoopFrame();
builder_->CreateBr(while_cond);
builder_->SetInsertPoint(while_merge);
}
Expand Down
7 changes: 7 additions & 0 deletions src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,13 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
* initializes file and compilation_unit_ to TVM defaults.
*/
static std::unique_ptr<DebugInfo> CreateDebugInfo(llvm::Module* module);

void PushLoopFrame(llvm::BasicBlock* backedge_tgt, llvm::BasicBlock* exit_tgt);
void PopLoopFrame();

// loop frame's jump target for continue and break generation
// store basic block pair (blk to backedge, blk to exit) for each frame.
std::vector<std::pair<llvm::BasicBlock*, llvm::BasicBlock*>> loop_frame_jump_tgts_;
};

inline int CodeGenLLVM::GetVectorNumElements(llvm::Value* vec) {
Expand Down
4 changes: 4 additions & 0 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,10 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
if (op->op.same_as(builtin::ret())) {
os << "return ";
PrintExpr(op->args[0], os);
} else if (op->op.same_as(builtin::continue_loop())) {
os << "continue;";
} else if (op->op.same_as(builtin::break_loop())) {
os << "break;";
} else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
ICHECK_GE(op->args.size(), 1U);
auto func = Downcast<StringImm>(op->args[0]);
Expand Down
1 change: 0 additions & 1 deletion src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ std::ostream& operator<<(std::ostream& out, ForKind type) { // NOLINT(*)
While::While(PrimExpr condition, Stmt body, Span span) {
ICHECK(condition.defined());
ICHECK(condition.dtype().is_scalar());
ICHECK(condition.as<tir::IntImmNode>() == nullptr) << "The condition should not be trivial.";
ICHECK(body.defined());

ObjectPtr<WhileNode> node = ffi::make_object<WhileNode>();
Expand Down
Loading
Loading