Skip to content

Commit

Permalink
Remove duplicate as Checks and CHECK value (apache#2531)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexeyr authored and merrymercy committed Feb 17, 2019
1 parent 11ec873 commit 2dc545a
Show file tree
Hide file tree
Showing 25 changed files with 88 additions and 94 deletions.
5 changes: 2 additions & 3 deletions src/codegen/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -791,10 +791,9 @@ void CodeGenC::VisitStmt_(const AttrStmt* op) {
void CodeGenC::VisitStmt_(const AssertStmt* op) {
std::string cond = PrintExpr(op->condition);
PrintIndent();
if (op->message.as<StringImm>()) {
if (const auto* str = op->message.as<StringImm>()) {
// GLOG style check
stream << "CHECK(" << cond << ") << \""
<< op->message.as<StringImm>()->value << "\";\n";
stream << "CHECK(" << cond << ") << \"" << str->value << "\";\n";
} else {
stream << "assert(" << cond << ");\n";
}
Expand Down
4 changes: 2 additions & 2 deletions src/codegen/stackvm/codegen_stackvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -470,8 +470,8 @@ void CodeGenStackVM::VisitExpr_(const Select *op) {
}

void CodeGenStackVM::VisitStmt_(const AssertStmt *op) {
if (op->message.as<StringImm>()) {
int sid = this->GetStrID(op->message.as<StringImm>()->value);
if (const auto* str = op->message.as<StringImm>()) {
int sid = this->GetStrID(str->value);
this->Push(op->condition);
this->PushOp(StackVM::ASSERT, sid);
}
Expand Down
4 changes: 2 additions & 2 deletions src/common/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ inline std::string GetHostName() {
}

/*!
* \brief Common data structure fornetwork address.
* \brief Common data structure for network address.
*/
struct SockAddr {
sockaddr_storage addr;
SockAddr() {}
/*!
* \brief construc address by url and port
* \brief construct address by url and port
* \param url The url of the address
* \param port The port of the address.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/op/hybrid_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ Stmt ApplySchedule(const Stage &stage,
// Gather rebased variables
std::unordered_map<IterVar, IterVar> rebased;
for (auto rel : stage->relations) {
if (auto rebase = rel.as<RebaseNode>()) {
if (const auto* rebase = rel.as<RebaseNode>()) {
rebased[rebase->rebased] = rebase->parent;
CHECK(rebase->parent->dom.defined());
CHECK(dom_map.count(rebase->rebased));
Expand Down
28 changes: 14 additions & 14 deletions src/pass/ir_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,39 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
// use reverse iteration
for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
Stmt s = *ri;
if (s.as<For>()) {
auto n = make_node<For>(*s.as<For>());
if (const auto* for_ = s.as<For>()) {
auto n = make_node<For>(*for_);
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else if (s.as<LetStmt>()) {
auto n = make_node<LetStmt>(*s.as<LetStmt>());
} else if (const auto* let = s.as<LetStmt>()) {
auto n = make_node<LetStmt>(*let);
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else if (s.as<AttrStmt>()) {
auto n = make_node<AttrStmt>(*s.as<AttrStmt>());
} else if (const auto* attr = s.as<AttrStmt>()) {
auto n = make_node<AttrStmt>(*attr);
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else if (s.as<IfThenElse>()) {
auto n = make_node<IfThenElse>(*s.as<IfThenElse>());
} else if (const auto* ite = s.as<IfThenElse>()) {
auto n = make_node<IfThenElse>(*ite);
CHECK(is_no_op(n->then_case));
CHECK(!n->else_case.defined());
n->then_case = body;
body = Stmt(n);
} else if (s.as<Block>()) {
auto n = make_node<Block>(*s.as<Block>());
} else if (const auto* block = s.as<Block>()) {
auto n = make_node<Block>(*block);
CHECK(is_no_op(n->rest));
n->rest = body;
body = Stmt(n);
} else if (s.as<AssertStmt>()) {
auto n = make_node<AssertStmt>(*s.as<AssertStmt>());
} else if (const auto* assert_ = s.as<AssertStmt>()) {
auto n = make_node<AssertStmt>(*assert_);
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else if (s.as<Allocate>()) {
auto n = make_node<Allocate>(*s.as<Allocate>());
} else if (const auto* alloc = s.as<Allocate>()) {
auto n = make_node<Allocate>(*alloc);
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
Expand Down
5 changes: 3 additions & 2 deletions src/pass/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,8 @@ Stmt LoopPartitioner::TryPartition(const Node* node,

Expr body_begin;
Stmt pre_stmt;
if (true_itrv.as<arith::IntervalSet>()->i.has_lower_bound()) {
arith::Interval true_itrv_i = true_itrv.as<arith::IntervalSet>()->i;
if (true_itrv_i.has_lower_bound()) {
body_begin = ir::Simplify(true_itrv.min());
if (!can_prove(body_begin == min)) {
Expr cond = (body_begin - min >= 0);
Expand All @@ -347,7 +348,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,

Expr post_doubt_begin;
Stmt post_stmt;
if (true_itrv.as<arith::IntervalSet>()->i.has_upper_bound()) {
if (true_itrv_i.has_upper_bound()) {
post_doubt_begin = ir::Simplify(true_itrv.max() + 1);
if (!can_prove(true_itrv.max() == max)) {
// require the extent to be non-negative
Expand Down
2 changes: 1 addition & 1 deletion src/pass/split_host_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class IRUseDefAnalysis : public IRMutator {
value = this->Mutate(value);
}
Stmt body = this->Mutate(op->body);
if (value.same_as(value) && body.same_as(body)) return s;
if (value.same_as(op->value) && body.same_as(op->body)) return s;
return AttrStmt::make(op->node, op->attr_key, value, body);
} else if (op->attr_key == attr::channel_write_scope ||
op->attr_key == attr::channel_read_scope) {
Expand Down
6 changes: 3 additions & 3 deletions src/pass/storage_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -718,10 +718,10 @@ class StoragePlanRewriter : public IRMutator {
src_entry->attach_scope_ == thread_scope_ &&
src_entry->elem_type == ae.alloc->type.element_of() &&
visitor.Check(s.stmt, var, src)) {
uint64_t const_nbits = static_cast<uint64_t>(
ae.alloc->constant_allocation_size() *
uint64_t const_nbits =
static_cast<uint64_t>(ae.alloc->constant_allocation_size()) *
ae.alloc->type.bits() *
ae.alloc->type.lanes());
ae.alloc->type.lanes();
if (src_entry->const_nbits == const_nbits && !inplace_found) {
// successfully inplace
dst_entry = src_entry;
Expand Down
18 changes: 10 additions & 8 deletions src/pass/verify_gpu_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,10 @@ class GPUCodeVerifier : public IRVisitor {

void Visit_(const AttrStmt *op) {
if (op->attr_key == attr::storage_scope) {
if (op->value.as<StringImm>()->value == "local") {
std::string op_value = op->value.as<StringImm>()->value;
if (op_value == "local") {
visited_local_buffers_.insert(op->node.as<tvm::Variable>());
} else if (op->value.as<StringImm>()->value == "shared") {
} else if (op_value == "shared") {
visited_shared_buffers_.insert(op->node.as<tvm::Variable>());
}
} else if (op->attr_key == attr::thread_extent) {
Expand Down Expand Up @@ -159,18 +160,19 @@ bool VerifyGPUCode(Stmt stmt,
int64_t max_thread_z = INT64_MAX;

for (auto iter : constraints) {
const IntImm* val = iter.second.as<IntImm>();
if (iter.first == "max_local_memory_per_block")
max_local_memory_per_block = (iter.second).as<IntImm>()->value;
max_local_memory_per_block = val->value;
else if (iter.first == "max_shared_memory_per_block")
max_shared_memory_per_block = (iter.second).as<IntImm>()->value;
max_shared_memory_per_block = val->value;
else if (iter.first == "max_threads_per_block")
max_threads_per_block = (iter.second).as<IntImm>()->value;
max_threads_per_block = val->value;
else if (iter.first == "max_thread_x")
max_thread_x = (iter.second).as<IntImm>()->value;
max_thread_x = val->value;
else if (iter.first == "max_thread_y")
max_thread_y = (iter.second).as<IntImm>()->value;
max_thread_y = val->value;
else if (iter.first == "max_thread_z")
max_thread_z = (iter.second).as<IntImm>()->value;
max_thread_z = val->value;
else
LOG(FATAL) << "Invalid check item: " << iter.first;
}
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ class Interpreter :
//
// We have some functions cotaining chunks of operators
// which will be loaded into operator map.
if (auto op_node = call->op.as<OpNode>()) {
if (const auto* op_node = call->op.as<OpNode>()) {
LOG(FATAL) << "found " << op_node->name
<< "; operators should be removed by future passes; try "
"fusing and lowering";
Expand Down
2 changes: 1 addition & 1 deletion src/relay/ir/base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ NodePtr<SourceNameNode> GetSourceNameNode(const std::string& name) {
auto sn = source_map.find(name);
if (sn == source_map.end()) {
NodePtr<SourceNameNode> n = make_node<SourceNameNode>();
n->name = std::move(name);
source_map[name] = n;
n->name = std::move(name);
return n;
} else {
return sn->second;
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/type_relations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace tvm {
namespace relay {

TensorType ToTensorType(const Type& t) {
if (auto tt_node = t.as<TensorTypeNode>()) {
if (const auto* tt_node = t.as<TensorTypeNode>()) {
return GetRef<TensorType>(tt_node);
} else {
return TensorType(nullptr);
Expand Down
2 changes: 1 addition & 1 deletion src/relay/pass/fold_scale_axis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ Expr AddSubForwardRewrite(const Call& ref_call,
rnode->scale = slhs->scale;
rnode->axes = slhs->axes;
} else {
CHECK(slhs != nullptr);
CHECK(srhs != nullptr);
CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes));
Expr scale = ExpandBiasToMatchAxis(
srhs->scale, trhs->shape.size(), srhs->axes);
Expand Down
2 changes: 1 addition & 1 deletion src/relay/pass/gradient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ Type WithGradientType(const Type& t) {

//! \brief if the expression is a GlobalVar, transform to it's expression.
Expr DeGlobal(const Module& mod, const Expr& e) {
if (auto x = e.as<GlobalVarNode>()) {
if (const auto* x = e.as<GlobalVarNode>()) {
return mod->Lookup(GetRef<GlobalVar>(x))->body;
} else {
return e;
Expand Down
2 changes: 1 addition & 1 deletion src/relay/pass/to_anf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ Expr ToANFAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
}

Expr ToANF(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
if (auto f = e.as<FunctionNode>()) {
if (const auto* f = e.as<FunctionNode>()) {
return FunctionNode::make(f->params,
ToANFAux(f->body, m, gv),
f->ret_type,
Expand Down
2 changes: 1 addition & 1 deletion src/relay/pass/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
}

for (auto cs : fn_ty->type_constraints) {
if (auto tr = cs.as<TypeRelationNode>()) {
if (const auto* tr = cs.as<TypeRelationNode>()) {
solver_.AddConstraint(
TypeRelationNode::make(tr->func, tr->args, tr->num_inputs, call->attrs),
GetRef<Call>(call));
Expand Down
2 changes: 1 addition & 1 deletion src/relay/pass/type_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ void TypeSolver::ReportError(const Error& err, const NodeRef& location) {

// Add type constraint to the solver.
void TypeSolver::AddConstraint(const TypeConstraint& constraint, const NodeRef& loc) {
if (auto *op = constraint.as<TypeRelationNode>()) {
if (const auto* op = constraint.as<TypeRelationNode>()) {
// create a new relation node.
RelationNode* rnode = arena_.make<RelationNode>();
rnode->location = loc;
Expand Down
45 changes: 22 additions & 23 deletions src/runtime/rpc/rpc_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -486,29 +486,28 @@ class RPCSession::EventHandler : public dmlc::Stream {
arg_recv_stage_ = 1;
this->RequestBytes(len);
break;
break;
}
case kArrayHandle: {
temp_array_.reset(new RPCDataArrayBuffer());
uint64_t handle;
this->Read(&handle);
DLTensor& tensor = temp_array_->tensor;
tensor.data = reinterpret_cast<void*>(handle);
this->Read(&(tensor.ctx));
this->Read(&(tensor.ndim));
this->Read(&(tensor.dtype));
temp_array_->shape.resize(tensor.ndim);
tensor.shape = temp_array_->shape.data();
arg_recv_stage_ = 1;
tensor.strides = nullptr;
tensor.byte_offset = 0;
this->RequestBytes(sizeof(int64_t) * tensor.ndim);
break;
}
default: {
LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode);
break;
}
}
case kArrayHandle: {
temp_array_.reset(new RPCDataArrayBuffer());
uint64_t handle;
this->Read(&handle);
DLTensor& tensor = temp_array_->tensor;
tensor.data = reinterpret_cast<void*>(handle);
this->Read(&(tensor.ctx));
this->Read(&(tensor.ndim));
this->Read(&(tensor.dtype));
temp_array_->shape.resize(tensor.ndim);
tensor.shape = temp_array_->shape.data();
arg_recv_stage_ = 1;
tensor.strides = nullptr;
tensor.byte_offset = 0;
this->RequestBytes(sizeof(int64_t) * tensor.ndim);
break;
}
default: {
LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode);
break;
}
}
} else {
CHECK_EQ(arg_recv_stage_, 1);
Expand Down
2 changes: 0 additions & 2 deletions src/runtime/stackvm/stackvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,6 @@ void StackVM::Run(State* s) const {
case intrinsic::kArrByteOffset: {
stack[sp].v_int64 = static_cast<int64_t>(
arr[index].byte_offset); break;
break;
}
case intrinsic::kArrDeviceId: {
stack[sp].v_int64 = arr[index].ctx.device_id; break;
Expand Down Expand Up @@ -531,7 +530,6 @@ const PackedFunc& StackVM::GetExtern(State* s, int fid) const {
if (f == nullptr) {
CHECK(s->mod_ctx != nullptr)
<< "No local context is set in stackvm";
CHECK(s->mod_ctx != nullptr);
const PackedFunc* pf = s->mod_ctx->GetFuncFromEnv(extern_func_name[fid]);
CHECK(pf != nullptr);
f = *pf;
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/stackvm/stackvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ class StackVM {
case EQ_I64: return EQ_F64;
case LT_I64: return LT_F64;
case LE_I64: return LE_F64;
case MOD_I64: LOG(FATAL) << "cannot handle mod for float";
case MOD_I64: LOG(FATAL) << "cannot handle mod for float"; return ADD_F64;
default: LOG(FATAL) << "cannot handle op " << code; return ADD_F64;
}
}
Expand Down
Loading

0 comments on commit 2dc545a

Please sign in to comment.