Skip to content

Commit

Permalink
[TIR][ANALYSIS] Refine side effect analysis. (#5954)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Jun 29, 2020
1 parent 78d7992 commit ef804b7
Show file tree
Hide file tree
Showing 10 changed files with 57 additions and 38 deletions.
8 changes: 5 additions & 3 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <tvm/ir/transform.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt.h>

#include <string>
Expand Down Expand Up @@ -64,11 +65,12 @@ struct ExprDeepEqual {
TVM_DLL Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);

/*!
* \brief Whether the expression have side effect.
* \brief Analyze the side effect
* \param expr The expression to be checked.
* \return whether expression have side effect
*
* \return CallEffectKind, can be kPure, kReadState or kUpdateState
*/
TVM_DLL bool HasSideEffect(const PrimExpr& expr);
TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);

/*!
* \brief Whether e expression used any var in variable set..
Expand Down
5 changes: 3 additions & 2 deletions src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1018,8 +1018,9 @@ PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op)

// components which have side effects should also be preserved
for (size_t i = 0; i < used.size(); ++i) {
if (HasSideEffect(op->source[i]) || HasSideEffect(op->combiner->identity_element[i]) ||
HasSideEffect(op->combiner->result[i])) {
if (SideEffect(op->source[i]) > CallEffectKind::kReadState ||
SideEffect(op->combiner->identity_element[i]) > CallEffectKind::kReadState ||
SideEffect(op->combiner->result[i]) > CallEffectKind::kReadState) {
mark_used(i);
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/arith/ir_mutator_with_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) {

Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* op) {
PrimExpr value = this->VisitExpr(op->value);
if (!tir::HasSideEffect(value)) {
if (SideEffect(value) <= CallEffectKind::kPure) {
analyzer_->Bind(op->var, value);
}
// We keep the let-binding here
Expand Down Expand Up @@ -154,7 +154,7 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) {

PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const LetNode* op) {
PrimExpr value = this->VisitExpr(op->value);
if (!tir::HasSideEffect(value)) {
if (SideEffect(value) <= CallEffectKind::kPure) {
analyzer_->Bind(op->var, value);
}
// We keep the let-binding here
Expand Down
2 changes: 1 addition & 1 deletion src/te/schedule/operation_inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class OperationInliner final : public StmtExprMutator {

bool has_side_effect = false;
for (size_t i = 0; i < op->indices.size(); ++i) {
if (HasSideEffect(op->indices[i])) has_side_effect = true;
if (SideEffect(op->indices[i]) > CallEffectKind::kReadState) has_side_effect = true;
}
if (has_side_effect) {
for (size_t i = 0; i < args_.size(); ++i) {
Expand Down
2 changes: 1 addition & 1 deletion src/te/schedule/schedule_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class InjectScanStep : public StmtMutator {
class SchedulePostProc : public StmtExprMutator {
public:
Stmt VisitStmt_(const LetStmtNode* op) final {
if (!HasSideEffect(op->value)) {
if (SideEffect(op->value) <= CallEffectKind::kPure) {
var_value_[op->var.get()] = this->VisitExpr(op->value);
return this->VisitStmt(op->body);
} else {
Expand Down
43 changes: 28 additions & 15 deletions src/tir/analysis/side_effect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,34 +33,47 @@ namespace tir {
class ExprSideEffect : public ExprVisitor {
public:
void VisitExpr(const PrimExpr& e) final {
if (has_side_effect_) return;
if (kind_ == CallEffectKind::kUpdateState) return;
ExprVisitor::VisitExpr(e);
}

void VisitExpr_(const LoadNode* op) final {
this->UpdateEffect(CallEffectKind::kReadState);
ExprVisitor::VisitExpr_(op);
}

void VisitExpr_(const BufferLoadNode* op) final {
this->UpdateEffect(CallEffectKind::kReadState);
ExprVisitor::VisitExpr_(op);
}

void VisitExpr_(const CallNode* op) final {
static auto op_call_effect = Op::GetAttrMap<TCallEffectKind>("TCallEffectKind");

if (auto* ptr_op = op->op.as<OpNode>()) {
auto effect_kind = op_call_effect[GetRef<Op>(ptr_op)];
if (effect_kind != CallEffectKind::kPure && effect_kind != CallEffectKind::kExprAnnotation) {
has_side_effect_ = true;
return;
} else {
ExprVisitor::VisitExpr_(op);
}
this->UpdateEffect(static_cast<CallEffectKind>(op_call_effect[GetRef<Op>(ptr_op)]->value));
} else {
has_side_effect_ = true;
return;
this->UpdateEffect(CallEffectKind::kOpaque);
}
ExprVisitor::VisitExpr_(op);
}

void UpdateEffect(CallEffectKind effect_kind) {
if (effect_kind > CallEffectKind::kUpdateState) {
effect_kind = CallEffectKind::kUpdateState;
}
if (effect_kind > kind_) {
kind_ = effect_kind;
}
}

bool has_side_effect_{false};
CallEffectKind kind_{CallEffectKind::kPure};
};

bool HasSideEffect(const PrimExpr& e) {
ExprSideEffect v;
v(e);
return v.has_side_effect_;
CallEffectKind SideEffect(const PrimExpr& e) {
ExprSideEffect visitor;
visitor(e);
return visitor.kind_;
}

} // namespace tir
Expand Down
6 changes: 3 additions & 3 deletions src/tir/transforms/remove_no_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class NoOpRemover : public StmtMutator {
return is_no_op(op->body) ? op->body : stmt;
}
Stmt VisitStmt_(const EvaluateNode* op) final {
if (HasSideEffect(op->value)) return GetRef<Stmt>(op);
if (SideEffect(op->value) > CallEffectKind::kReadState) return GetRef<Stmt>(op);
return Evaluate(0);
}

Expand Down Expand Up @@ -127,7 +127,7 @@ class NoOpRemover : public StmtMutator {

private:
Stmt MakeEvaluate(PrimExpr value) {
if (HasSideEffect(value)) {
if (SideEffect(value) > CallEffectKind::kReadState) {
return Evaluate(value);
} else {
return Evaluate(0);
Expand All @@ -136,7 +136,7 @@ class NoOpRemover : public StmtMutator {
Stmt MakeEvaluate(const Array<PrimExpr>& values) {
Stmt stmt;
for (PrimExpr e : values) {
if (HasSideEffect(e)) {
if (SideEffect(e) > CallEffectKind::kReadState) {
if (stmt.defined()) {
stmt = SeqStmt({stmt, Evaluate(e)});
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
// Won't face the deep expression explosion problem as in Let expression.
// attempt to inline as much as possible if the value integer type(can be index).
if (!op->value.dtype().is_int()) return false;
return !tir::HasSideEffect(op->value);
return SideEffect(op->value) <= CallEffectKind::kPure;
}

Stmt VisitStmt_(const LetStmtNode* op) {
Expand Down
6 changes: 4 additions & 2 deletions src/tir/transforms/split_host_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ class VarUseDefAnalysis : public StmtExprMutator {
this->HandleDef(op->var.get());
Stmt body = this->VisitStmt(op->body);
// eliminate unreferenced let
if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value) && simplify_let_) {
if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState &&
simplify_let_) {
return body;
} else {
PrimExpr value = this->VisitExpr(op->value);
Expand Down Expand Up @@ -101,7 +102,8 @@ class VarUseDefAnalysis : public StmtExprMutator {
this->HandleDef(op->var.get());
PrimExpr body = this->VisitExpr(op->body);
// eliminate unreferenced let
if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value) && simplify_let_) {
if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState &&
simplify_let_) {
return body;
} else {
PrimExpr value = this->VisitExpr(op->value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,17 @@
#include <gtest/gtest.h>
#include <tvm/te/operation.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>

TEST(SimplePasses, HasSideEffect) {
TEST(SimplePasses, SideEffect) {
using namespace tvm;
auto n = te::var("n");
Array<PrimExpr> shape;
shape.push_back(n);

auto A = te::placeholder(shape, DataType::Float(32), "A");

CHECK(!tvm::tir::HasSideEffect(A[0]));
auto A = tir::Var("A", DataType::Handle());
auto i = tir::Var("i", DataType::Int(32));
CHECK(tir::SideEffect(tir::Load(DataType::Float(32), A, i, tir::const_true(1))) ==
tir::CallEffectKind::kReadState);
CHECK(tir::SideEffect(exp(tir::Cast(DataType::Float(32), i + 1))) == tir::CallEffectKind::kPure);
CHECK(tir::SideEffect(tir::Call(DataType::Handle(), tir::builtin::tvm_storage_sync(), {})) ==
tir::CallEffectKind::kUpdateState);
}

int main(int argc, char** argv) {
Expand Down

0 comments on commit ef804b7

Please sign in to comment.