Extract common as<T>() casts
alexeyr committed Feb 1, 2019
1 parent ef48466 commit 8743399
Showing 17 changed files with 56 additions and 59 deletions.
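
Every hunk below applies the same cleanup: where an IR node was downcast with as<T>() and the result was immediately downcast again to be used, the commit binds the pointer once, either in the if-condition or in a named local. A minimal sketch of the before/after pattern, using hypothetical stand-ins (Node, StringImm, and the as<T> helper below are illustrative, not TVM's real definitions):

#include <iostream>
#include <string>

// Illustrative stand-ins for TVM's IR node types; not the real API.
struct Node {
  virtual ~Node() = default;
};

struct StringImm : Node {
  std::string value;
  explicit StringImm(std::string v) : value(std::move(v)) {}
};

// Mirrors the as<T>() used throughout the diff: a checked downcast that
// yields nullptr when the node is not actually a T.
template <typename T>
const T* as(const Node* n) { return dynamic_cast<const T*>(n); }

void EmitAssert(const Node* message) {
  // Before: the cast is repeated at every use site.
  //   if (as<StringImm>(message)) {
  //     std::cout << as<StringImm>(message)->value << "\n";
  //   }
  // After: cast once, bind the result, and scope it to the taken branch.
  if (const auto* str = as<StringImm>(message)) {
    std::cout << "CHECK failed: " << str->value << "\n";
  } else {
    std::cout << "assert failed\n";
  }
}

int main() {
  StringImm msg("shape mismatch");
  EmitAssert(&msg);  // prints: CHECK failed: shape mismatch
  return 0;
}

Besides removing the duplicate cast, binding in the condition scopes the pointer to the branch where the cast succeeded, so a failed cast cannot be dereferenced further down.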
5 changes: 2 additions & 3 deletions src/codegen/codegen_c.cc
@@ -791,10 +791,9 @@ void CodeGenC::VisitStmt_(const AttrStmt* op) {
 void CodeGenC::VisitStmt_(const AssertStmt* op) {
   std::string cond = PrintExpr(op->condition);
   PrintIndent();
-  if (op->message.as<StringImm>()) {
+  if (const auto* str = op->message.as<StringImm>()) {
     // GLOG style check
-    stream << "CHECK(" << cond << ") << \""
-           << op->message.as<StringImm>()->value << "\";\n";
+    stream << "CHECK(" << cond << ") << \"" << str->value << "\";\n";
   } else {
     stream << "assert(" << cond << ");\n";
   }
4 changes: 2 additions & 2 deletions src/codegen/stackvm/codegen_stackvm.cc
@@ -470,8 +470,8 @@ void CodeGenStackVM::VisitExpr_(const Select *op) {
 }
 
 void CodeGenStackVM::VisitStmt_(const AssertStmt *op) {
-  if (op->message.as<StringImm>()) {
-    int sid = this->GetStrID(op->message.as<StringImm>()->value);
+  if (const auto* str = op->message.as<StringImm>()) {
+    int sid = this->GetStrID(str->value);
     this->Push(op->condition);
     this->PushOp(StackVM::ASSERT, sid);
   }
2 changes: 1 addition & 1 deletion src/op/hybrid_op.cc
@@ -435,7 +435,7 @@ Stmt ApplySchedule(const Stage &stage,
   // Gather rebased variables
   std::unordered_map<IterVar, IterVar> rebased;
   for (auto rel : stage->relations) {
-    if (auto rebase = rel.as<RebaseNode>()) {
+    if (const auto* rebase = rel.as<RebaseNode>()) {
       rebased[rebase->rebased] = rebase->parent;
       CHECK(rebase->parent->dom.defined());
       CHECK(dom_map.count(rebase->rebased));
28 changes: 14 additions & 14 deletions src/pass/ir_util.cc
@@ -12,39 +12,39 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
   // use reverse iteration
   for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
     Stmt s = *ri;
-    if (s.as<For>()) {
-      auto n = make_node<For>(*s.as<For>());
+    if (const auto* for_ = s.as<For>()) {
+      auto n = make_node<For>(*for_);
       CHECK(is_no_op(n->body));
       n->body = body;
       body = Stmt(n);
-    } else if (s.as<LetStmt>()) {
-      auto n = make_node<LetStmt>(*s.as<LetStmt>());
+    } else if (const auto* let = s.as<LetStmt>()) {
+      auto n = make_node<LetStmt>(*let);
       CHECK(is_no_op(n->body));
       n->body = body;
       body = Stmt(n);
-    } else if (s.as<AttrStmt>()) {
-      auto n = make_node<AttrStmt>(*s.as<AttrStmt>());
+    } else if (const auto* attr = s.as<AttrStmt>()) {
+      auto n = make_node<AttrStmt>(*attr);
       CHECK(is_no_op(n->body));
       n->body = body;
       body = Stmt(n);
-    } else if (s.as<IfThenElse>()) {
-      auto n = make_node<IfThenElse>(*s.as<IfThenElse>());
+    } else if (const auto* ite = s.as<IfThenElse>()) {
+      auto n = make_node<IfThenElse>(*ite);
       CHECK(is_no_op(n->then_case));
       CHECK(!n->else_case.defined());
       n->then_case = body;
       body = Stmt(n);
-    } else if (s.as<Block>()) {
-      auto n = make_node<Block>(*s.as<Block>());
+    } else if (const auto* block = s.as<Block>()) {
+      auto n = make_node<Block>(*block);
       CHECK(is_no_op(n->rest));
       n->rest = body;
       body = Stmt(n);
-    } else if (s.as<AssertStmt>()) {
-      auto n = make_node<AssertStmt>(*s.as<AssertStmt>());
+    } else if (const auto* assert_ = s.as<AssertStmt>()) {
+      auto n = make_node<AssertStmt>(*assert_);
       CHECK(is_no_op(n->body));
       n->body = body;
       body = Stmt(n);
-    } else if (s.as<Allocate>()) {
-      auto n = make_node<Allocate>(*s.as<Allocate>());
+    } else if (const auto* alloc = s.as<Allocate>()) {
+      auto n = make_node<Allocate>(*alloc);
       CHECK(is_no_op(n->body));
       n->body = body;
       body = Stmt(n);
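
Alongside the extracted casts, the MergeNest hunk above leans on TVM's copy-and-mutate idiom: make_node<T>(*ptr) copy-constructs a fresh node from the one just matched, and only the body (or rest/then_case) field is swapped. A rough sketch of that idiom, where make_node and ForNode are hypothetical stand-ins rather than TVM's definitions:

#include <memory>
#include <utility>

// Hypothetical stand-in for TVM's make_node<T>: allocates a node,
// forwarding arguments -- here the copy constructor of an existing node.
template <typename T, typename... Args>
std::shared_ptr<T> make_node(Args&&... args) {
  return std::make_shared<T>(std::forward<Args>(args)...);
}

// Simplified loop node with a body pointer, standing in for For.
struct ForNode {
  int extent;
  std::shared_ptr<ForNode> body;
};

// Rebuild a loop around a new body without mutating the original node:
// copy the old node wholesale, then overwrite the one field that changes.
std::shared_ptr<ForNode> WithBody(const ForNode* old_loop,
                                  std::shared_ptr<ForNode> new_body) {
  auto n = make_node<ForNode>(*old_loop);  // copy-construct from *old_loop
  n->body = std::move(new_body);
  return n;
}

int main() {
  ForNode outer{10, nullptr};
  auto inner = make_node<ForNode>();
  auto merged = WithBody(&outer, inner);  // copy of outer now wraps inner
  return merged->extent == 10 ? 0 : 1;
}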
5 changes: 3 additions & 2 deletions src/pass/loop_partition.cc
@@ -326,7 +326,8 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
 
   Expr body_begin;
   Stmt pre_stmt;
-  if (true_itrv.as<arith::IntervalSet>()->i.has_lower_bound()) {
+  arith::Interval true_itrv_i = true_itrv.as<arith::IntervalSet>()->i;
+  if (true_itrv_i.has_lower_bound()) {
     body_begin = ir::Simplify(true_itrv.min());
     if (!can_prove(body_begin == min)) {
       Expr cond = (body_begin - min >= 0);
@@ -347,7 +348,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
 
   Expr post_doubt_begin;
   Stmt post_stmt;
-  if (true_itrv.as<arith::IntervalSet>()->i.has_upper_bound()) {
+  if (true_itrv_i.has_upper_bound()) {
     post_doubt_begin = ir::Simplify(true_itrv.max() + 1);
     if (!can_prove(true_itrv.max() == max)) {
       // require the extent to be non-negative
18 changes: 10 additions & 8 deletions src/pass/verify_gpu_code.cc
@@ -73,9 +73,10 @@ class GPUCodeVerifier : public IRVisitor {
 
   void Visit_(const AttrStmt *op) {
     if (op->attr_key == attr::storage_scope) {
-      if (op->value.as<StringImm>()->value == "local") {
+      std::string opValue = op->value.as<StringImm>()->value;
+      if (opValue == "local") {
         visited_local_buffers_.insert(op->node.as<tvm::Variable>());
-      } else if (op->value.as<StringImm>()->value == "shared") {
+      } else if (opValue == "shared") {
         visited_shared_buffers_.insert(op->node.as<tvm::Variable>());
       }
     } else if (op->attr_key == attr::thread_extent) {
@@ -159,18 +159,19 @@ bool VerifyGPUCode(Stmt stmt,
   int64_t max_thread_z = INT64_MAX;
 
   for (auto iter : constraints) {
+    const IntImm* val = iter.second.as<IntImm>();
     if (iter.first == "max_local_memory_per_block")
-      max_local_memory_per_block = (iter.second).as<IntImm>()->value;
+      max_local_memory_per_block = val->value;
     else if (iter.first == "max_shared_memory_per_block")
-      max_shared_memory_per_block = (iter.second).as<IntImm>()->value;
+      max_shared_memory_per_block = val->value;
     else if (iter.first == "max_threads_per_block")
-      max_threads_per_block = (iter.second).as<IntImm>()->value;
+      max_threads_per_block = val->value;
     else if (iter.first == "max_thread_x")
-      max_thread_x = (iter.second).as<IntImm>()->value;
+      max_thread_x = val->value;
     else if (iter.first == "max_thread_y")
-      max_thread_y = (iter.second).as<IntImm>()->value;
+      max_thread_y = val->value;
     else if (iter.first == "max_thread_z")
-      max_thread_z = (iter.second).as<IntImm>()->value;
+      max_thread_z = val->value;
    else
       LOG(FATAL) << "Invalid check item: " << iter.first;
   }
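
This file shows the other variant of the cleanup: when the casted value feeds several branches rather than guarding one, the hunks hoist it into a named local (opValue, val) before the branch chain. A compact sketch with Node and IntImm as illustrative stand-ins; like the original loop, it assumes every constraint value really is an IntImm:

#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Illustrative stand-ins for TVM's boxed integer node; not the real API.
struct Node { virtual ~Node() = default; };
struct IntImm : Node {
  int64_t value;
  explicit IntImm(int64_t v) : value(v) {}
};

void ReadConstraints(const std::map<std::string, std::shared_ptr<Node>>& constraints) {
  int64_t max_threads_per_block = INT64_MAX;
  int64_t max_thread_x = INT64_MAX;
  for (const auto& iter : constraints) {
    // Hoisted once per iteration instead of repeating the cast in every
    // branch, as the diff does with `const IntImm* val = ...`. As in the
    // original, a non-IntImm value would leave val null.
    const IntImm* val = dynamic_cast<const IntImm*>(iter.second.get());
    if (iter.first == "max_threads_per_block")
      max_threads_per_block = val->value;
    else if (iter.first == "max_thread_x")
      max_thread_x = val->value;
    else
      std::cerr << "Invalid check item: " << iter.first << "\n";
  }
  std::cout << max_threads_per_block << " " << max_thread_x << "\n";
}

int main() {
  std::map<std::string, std::shared_ptr<Node>> c;
  c["max_thread_x"] = std::make_shared<IntImm>(1024);
  ReadConstraints(c);  // prints: 9223372036854775807 1024
  return 0;
}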
2 changes: 1 addition & 1 deletion src/relay/backend/interpreter.cc
@@ -379,7 +379,7 @@ class Interpreter :
     //
     // We have some functions containing chunks of operators
     // which will be loaded into operator map.
-    if (auto op_node = call->op.as<OpNode>()) {
+    if (const auto* op_node = call->op.as<OpNode>()) {
       LOG(FATAL) << "found " << op_node->name
                  << "; operators should be removed by future passes; try "
                     "fusing and lowering";
2 changes: 1 addition & 1 deletion src/relay/op/type_relations.cc
@@ -15,7 +15,7 @@ namespace tvm {
 namespace relay {
 
 TensorType ToTensorType(const Type& t) {
-  if (auto tt_node = t.as<TensorTypeNode>()) {
+  if (const auto* tt_node = t.as<TensorTypeNode>()) {
     return GetRef<TensorType>(tt_node);
   } else {
     return TensorType(nullptr);
2 changes: 1 addition & 1 deletion src/relay/pass/gradient.cc
@@ -61,7 +61,7 @@ Type WithGradientType(const Type& t) {
 
 //! \brief if the expression is a GlobalVar, transform to its expression.
 Expr DeGlobal(const Module& mod, const Expr& e) {
-  if (auto x = e.as<GlobalVarNode>()) {
+  if (const auto* x = e.as<GlobalVarNode>()) {
     return mod->Lookup(GetRef<GlobalVar>(x))->body;
   } else {
     return e;
2 changes: 1 addition & 1 deletion src/relay/pass/to_anf.cc
@@ -385,7 +385,7 @@ Expr ToANFAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
 }
 
 Expr ToANF(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
-  if (auto f = e.as<FunctionNode>()) {
+  if (const auto* f = e.as<FunctionNode>()) {
     return FunctionNode::make(f->params,
                               ToANFAux(f->body, m, gv),
                               f->ret_type,
2 changes: 1 addition & 1 deletion src/relay/pass/type_infer.cc
@@ -386,7 +386,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
     }
 
     for (auto cs : fn_ty->type_constraints) {
-      if (auto tr = cs.as<TypeRelationNode>()) {
+      if (const auto* tr = cs.as<TypeRelationNode>()) {
        solver_.AddConstraint(
            TypeRelationNode::make(tr->func, tr->args, tr->num_inputs, call->attrs),
            GetRef<Call>(call));
2 changes: 1 addition & 1 deletion src/relay/pass/type_solver.cc
@@ -376,7 +376,7 @@ void TypeSolver::ReportError(const Error& err, const NodeRef& location) {
 
 // Add type constraint to the solver.
 void TypeSolver::AddConstraint(const TypeConstraint& constraint, const NodeRef& loc) {
-  if (auto *op = constraint.as<TypeRelationNode>()) {
+  if (const auto* op = constraint.as<TypeRelationNode>()) {
     // create a new relation node.
     RelationNode* rnode = arena_.make<RelationNode>();
     rnode->location = loc;
24 changes: 12 additions & 12 deletions src/schedule/graph.cc
@@ -223,9 +223,9 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
   }
 
   for (Operation op : ops) {
-    if (op.as<ScanOpNode>()) {
-      const auto& update = op.as<ScanOpNode>()->update;
-      const auto& init = op.as<ScanOpNode>()->init;
+    if (const auto* scan_op = op.as<ScanOpNode>()) {
+      const auto& update = scan_op->update;
+      const auto& init = scan_op->init;
       for (size_t i = 0; i < update.size(); ++i) {
         Tensor t = op.output(i);
         for (int k = 1; k < static_cast<int>(update[i]->shape.size()); ++k) {
@@ -235,9 +235,9 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
               TensorDimKey(init[i], k));
         }
       }
-    } else if (op.as<ComputeOpNode>()) {
+    } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
       std::unordered_map<const Node*, TensorDimKey> vmap;
-      const auto& axis = op.as<ComputeOpNode>()->axis;
+      const auto& axis = compute_op->axis;
       Tensor t = op.output(0);
       for (size_t i = 0; i < axis.size(); ++i) {
         vmap[axis[i]->var.get()] = TensorDimKey(t, i);
@@ -260,7 +260,7 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
           }
         }
       };
-      for (auto& e : op.as<ComputeOpNode>()->body) {
+      for (auto& e : compute_op->body) {
        ir::PostOrderVisit(e, fvisit);
       }
     }
@@ -312,19 +312,19 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
   // prop exact reach back.
   for (size_t i = 0; i < body.size(); ++i) {
     const Operation& op = body[i];
-    if (op.as<ScanOpNode>()) {
-      const auto& update = op.as<ScanOpNode>()->update;
-      const auto& init = op.as<ScanOpNode>()->init;
+    if (const auto* scan_op = op.as<ScanOpNode>()) {
+      const auto& update = scan_op->update;
+      const auto& init = scan_op->init;
       for (size_t i = 0; i < update.size(); ++i) {
         Tensor t = op.output(i);
         for (size_t k = 1; k < update[i]->shape.size(); ++k) {
           f_merge_key(TensorDimKey(t, k), TensorDimKey(update[i], k));
           f_merge_key(TensorDimKey(t, k), TensorDimKey(init[i], k));
         }
       }
-    } else if (op.as<ComputeOpNode>()) {
+    } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
       std::unordered_map<const Node*, std::vector<TensorDimKey> > vmap;
-      const auto& axis = op.as<ComputeOpNode>()->axis;
+      const auto& axis = compute_op->axis;
       for (size_t i = 0; i < axis.size(); ++i) {
         std::vector<TensorDimKey> keys;
         for (int j = 0; j < op->num_outputs(); ++j) {
@@ -352,7 +352,7 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
           }
         }
       };
-      for (auto& e : op.as<ComputeOpNode>()->body) {
+      for (auto& e : compute_op->body) {
        ir::PostOrderVisit(e, fvisit);
       }
     }
9 changes: 3 additions & 6 deletions src/schedule/message_passing.cc
@@ -419,8 +419,7 @@ void PassUpBoundCheck(const Stage& s,
   using HalideIR::Internal::can_prove;
   for (size_t i = s->relations.size(); i != 0; --i) {
     IterVarRelation rel = s->relations[i - 1];
-    if (rel.as<SplitNode>()) {
-      const SplitNode* s = rel.as<SplitNode>();
+    if (const SplitNode* s = rel.as<SplitNode>()) {
       bool outer = state.at(s->outer);
       bool inner = state.at(s->inner);
 
@@ -439,13 +438,11 @@ void PassUpBoundCheck(const Stage& s,
       } else {
         state[s->parent] = true;
       }
-    } else if (rel.as<FuseNode>()) {
-      const FuseNode* s = rel.as<FuseNode>();
+    } else if (const FuseNode* s = rel.as<FuseNode>()) {
      bool fused = state.at(s->fused);
      state[s->outer] = fused;
      state[s->inner] = fused;
-    } else if (rel.as<RebaseNode>()) {
-      const RebaseNode* s = rel.as<RebaseNode>();
+    } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
      state[s->parent] = state.at(s->rebased);
     } else if (rel.as<SingletonNode>()) {
       // nop
2 changes: 1 addition & 1 deletion src/schedule/schedule_dataflow_rewrite.cc
@@ -544,7 +544,7 @@ void InjectInline(ScheduleNode* sch) {
     const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
     if (compute) {
       if (!new_body[j].size()) {
-        new_body[j] = s->op.as<ComputeOpNode>()->body;
+        new_body[j] = compute->body;
       }
       if (new_body[j][0]->is_type<ir::Reduce>()) {
         // specially handle reduction inline for multiple reductions.
3 changes: 1 addition & 2 deletions src/schedule/schedule_lang.cc
@@ -710,8 +710,7 @@ Schedule ScheduleNode::make(Array<Operation> ops) {
     n->stages.push_back(stage);
     n->stage_map.Set(op, stage);
     // mark scan updates.
-    if (op.as<ScanOpNode>()) {
-      const ScanOpNode* scan = op.as<ScanOpNode>();
+    if (const ScanOpNode* scan = op.as<ScanOpNode>()) {
       Array<Tensor> inputs;
       for (Tensor t : scan->state_placeholder) {
         inputs.push_back(t);
3 changes: 1 addition & 2 deletions src/schedule/schedule_ops.cc
@@ -304,8 +304,7 @@ class SchedulePostProc : public IRMutator {
       }
     }
     // Specially add replacements for scan op.
-    if (s->op.as<ScanOpNode>()) {
-      const ScanOpNode* scan = s->op.as<ScanOpNode>();
+    if (const ScanOpNode* scan = s->op.as<ScanOpNode>()) {
       for (size_t i = 0; i < scan->update.size(); ++i) {
         Tensor t = s->origin_op.output(i);
         AddReplace(scan->init[i], t);
