From e091c368c8a5445d859e9d7571b5b43a93968c16 Mon Sep 17 00:00:00 2001
From: Alexey Romanov
Date: Wed, 30 Jan 2019 14:00:15 +0300
Subject: [PATCH 1/3] Fix typos

---
 src/common/socket.h               |  4 +--
 src/pass/split_host_device.cc     |  2 +-
 src/relay/ir/base.cc              |  2 +-
 src/relay/pass/fold_scale_axis.cc |  2 +-
 src/runtime/rpc/rpc_session.cc    | 45 +++++++++++++++----------------
 src/runtime/stackvm/stackvm.cc    |  2 --
 src/runtime/stackvm/stackvm.h     |  2 +-
 src/schedule/graph.cc             |  2 +-
 8 files changed, 29 insertions(+), 32 deletions(-)

diff --git a/src/common/socket.h b/src/common/socket.h
index 181889b3fd4e..fafff97b2522 100644
--- a/src/common/socket.h
+++ b/src/common/socket.h
@@ -42,13 +42,13 @@ inline std::string GetHostName() {
 }
 
 /*!
- * \brief Common data structure fornetwork address.
+ * \brief Common data structure for network address.
  */
 struct SockAddr {
   sockaddr_storage addr;
   SockAddr() {}
   /*!
-   * \brief construc address by url and port
+   * \brief construct address by url and port
    * \param url The url of the address
    * \param port The port of the address.
    */
diff --git a/src/pass/split_host_device.cc b/src/pass/split_host_device.cc
index 4cfbc7c90d8c..5e0a2508c218 100644
--- a/src/pass/split_host_device.cc
+++ b/src/pass/split_host_device.cc
@@ -34,7 +34,7 @@ class IRUseDefAnalysis : public IRMutator {
       value = this->Mutate(value);
     }
     Stmt body = this->Mutate(op->body);
-    if (value.same_as(value) && body.same_as(body)) return s;
+    if (value.same_as(op->value) && body.same_as(op->body)) return s;
     return AttrStmt::make(op->node, op->attr_key, value, body);
   } else if (op->attr_key == attr::channel_write_scope ||
              op->attr_key == attr::channel_read_scope) {
diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc
index 8df54883616a..4b67909213e1 100644
--- a/src/relay/ir/base.cc
+++ b/src/relay/ir/base.cc
@@ -20,8 +20,8 @@ NodePtr<SourceNameNode> GetSourceNameNode(const std::string& name) {
   auto sn = source_map.find(name);
   if (sn == source_map.end()) {
     NodePtr<SourceNameNode> n = make_node<SourceNameNode>();
-    n->name = std::move(name);
     source_map[name] = n;
+    n->name = std::move(name);
     return n;
   } else {
     return sn->second;
diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc
index 0cd46ff330e1..270965886ab9 100644
--- a/src/relay/pass/fold_scale_axis.cc
+++ b/src/relay/pass/fold_scale_axis.cc
@@ -361,7 +361,7 @@ Expr AddSubForwardRewrite(const Call& ref_call,
     rnode->scale = slhs->scale;
     rnode->axes = slhs->axes;
   } else {
-    CHECK(slhs != nullptr);
+    CHECK(srhs != nullptr);
     CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes));
     Expr scale = ExpandBiasToMatchAxis(
         srhs->scale, trhs->shape.size(), srhs->axes);
diff --git a/src/runtime/rpc/rpc_session.cc b/src/runtime/rpc/rpc_session.cc
index adc27ab4aa6f..47701fd3f60d 100644
--- a/src/runtime/rpc/rpc_session.cc
+++ b/src/runtime/rpc/rpc_session.cc
@@ -486,29 +486,28 @@ class RPCSession::EventHandler : public dmlc::Stream {
         arg_recv_stage_ = 1;
         this->RequestBytes(len);
         break;
-        break;
-      }
-      case kArrayHandle: {
-        temp_array_.reset(new RPCDataArrayBuffer());
-        uint64_t handle;
-        this->Read(&handle);
-        DLTensor& tensor = temp_array_->tensor;
-        tensor.data = reinterpret_cast<void*>(handle);
-        this->Read(&(tensor.ctx));
-        this->Read(&(tensor.ndim));
-        this->Read(&(tensor.dtype));
-        temp_array_->shape.resize(tensor.ndim);
-        tensor.shape = temp_array_->shape.data();
-        arg_recv_stage_ = 1;
-        tensor.strides = nullptr;
-        tensor.byte_offset = 0;
-        this->RequestBytes(sizeof(int64_t) * tensor.ndim);
-        break;
-      }
-      default: {
-        LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode);
-        break;
-      }
+      }
+      case kArrayHandle: {
+        temp_array_.reset(new RPCDataArrayBuffer());
+        uint64_t handle;
+        this->Read(&handle);
+        DLTensor& tensor = temp_array_->tensor;
+        tensor.data = reinterpret_cast<void*>(handle);
+        this->Read(&(tensor.ctx));
+        this->Read(&(tensor.ndim));
+        this->Read(&(tensor.dtype));
+        temp_array_->shape.resize(tensor.ndim);
+        tensor.shape = temp_array_->shape.data();
+        arg_recv_stage_ = 1;
+        tensor.strides = nullptr;
+        tensor.byte_offset = 0;
+        this->RequestBytes(sizeof(int64_t) * tensor.ndim);
+        break;
+      }
+      default: {
+        LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode);
+        break;
+      }
     }
   } else {
     CHECK_EQ(arg_recv_stage_, 1);
diff --git a/src/runtime/stackvm/stackvm.cc b/src/runtime/stackvm/stackvm.cc
index f45d83027467..131c60704dc5 100644
--- a/src/runtime/stackvm/stackvm.cc
+++ b/src/runtime/stackvm/stackvm.cc
@@ -406,7 +406,6 @@ void StackVM::Run(State* s) const {
           case intrinsic::kArrByteOffset: {
             stack[sp].v_int64 = static_cast<int64_t>(
                 arr[index].byte_offset); break;
-            break;
           }
           case intrinsic::kArrDeviceId: {
             stack[sp].v_int64 = arr[index].ctx.device_id; break;
@@ -531,7 +530,6 @@ const PackedFunc& StackVM::GetExtern(State* s, int fid) const {
   if (f == nullptr) {
     CHECK(s->mod_ctx != nullptr)
         << "No local context is set in stackvm";
-    CHECK(s->mod_ctx != nullptr);
     const PackedFunc* pf = s->mod_ctx->GetFuncFromEnv(extern_func_name[fid]);
     CHECK(pf != nullptr);
     f = *pf;
diff --git a/src/runtime/stackvm/stackvm.h b/src/runtime/stackvm/stackvm.h
index b2ce975b2c73..87c4582f6bc8 100644
--- a/src/runtime/stackvm/stackvm.h
+++ b/src/runtime/stackvm/stackvm.h
@@ -331,7 +331,7 @@ class StackVM {
       case EQ_I64: return EQ_F64;
       case LT_I64: return LT_F64;
      case LE_I64: return LE_F64;
-      case MOD_I64: LOG(FATAL) << "cannot handle mod for float";
+      case MOD_I64: LOG(FATAL) << "cannot handle mod for float"; return ADD_F64;
      default: LOG(FATAL) << "cannot handle op " << code; return ADD_F64;
     }
   }
diff --git a/src/schedule/graph.cc b/src/schedule/graph.cc
index d92e7730b313..5589853c0db8 100644
--- a/src/schedule/graph.cc
+++ b/src/schedule/graph.cc
@@ -317,7 +317,7 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
       const auto& init = op.as<ScanOpNode>()->init;
       for (size_t i = 0; i < update.size(); ++i) {
         Tensor t = op.output(i);
-        for (size_t k = 1; i < update[i]->shape.size(); ++k) {
+        for (size_t k = 1; k < update[i]->shape.size(); ++k) {
           f_merge_key(TensorDimKey(t, k), TensorDimKey(update[i], k));
           f_merge_key(TensorDimKey(t, k), TensorDimKey(init[i], k));
         }

From f8e4ccc247fc12903b4c369e9d825c0363bbecf7 Mon Sep 17 00:00:00 2001
From: Alexey Romanov
Date: Wed, 30 Jan 2019 14:03:20 +0300
Subject: [PATCH 2/3] Prevent possible overflow

---
 src/pass/storage_rewrite.cc | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc
index 54f5010f1461..331b60a865ed 100644
--- a/src/pass/storage_rewrite.cc
+++ b/src/pass/storage_rewrite.cc
@@ -718,10 +718,10 @@ class StoragePlanRewriter : public IRMutator {
           src_entry->attach_scope_ == thread_scope_ &&
           src_entry->elem_type == ae.alloc->type.element_of() &&
           visitor.Check(s.stmt, var, src)) {
-        uint64_t const_nbits = static_cast<uint64_t>(
-            ae.alloc->constant_allocation_size() *
+        uint64_t const_nbits =
+            static_cast<uint64_t>(ae.alloc->constant_allocation_size()) *
             ae.alloc->type.bits() *
-            ae.alloc->type.lanes());
+            ae.alloc->type.lanes();
         if (src_entry->const_nbits == const_nbits && !inplace_found) {
           // successfully inplace
           dst_entry = src_entry;
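A note on why PATCH 2/3 hoists the cast in the storage_rewrite.cc hunk above: in the old form the whole product was computed first and only the finished result was converted, so the multiplication ran in whatever narrower or signed type the operands have and could overflow before the uint64_t conversion ever happened; casting the first factor forces the entire multiplication into uint64_t. The snippet below is an illustrative sketch only, not part of the patch, and uses made-up stand-in values rather than anything from TVM:

    // overflow_sketch.cc -- standalone illustration of the cast-before-multiply fix
    #include <cstdint>
    #include <iostream>

    int main() {
      // Hypothetical allocation parameters, chosen so the product exceeds INT32_MAX.
      int32_t size  = 1 << 25;  // stand-in for constant_allocation_size()
      int32_t bits  = 64;       // stand-in for type.bits()
      int32_t lanes = 4;        // stand-in for type.lanes()

      // Old pattern: multiply in int32_t, then cast. The intermediate product
      // overflows (undefined behavior for signed types; typically it wraps).
      uint64_t before = static_cast<uint64_t>(size * bits * lanes);

      // New pattern: cast the first factor, then multiply, so the arithmetic
      // is carried out in uint64_t and cannot overflow for these values.
      uint64_t after = static_cast<uint64_t>(size) * bits * lanes;

      // On typical builds this prints "0 vs 8589934592".
      std::cout << before << " vs " << after << "\n";
      return 0;
    }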
From f0110ae90714fd8583dc226e2dfd126b064beae8 Mon Sep 17 00:00:00 2001
From: Alexey Romanov
Date: Wed, 30 Jan 2019 15:00:56 +0300
Subject: [PATCH 3/3] Extract common ases

---
 src/codegen/codegen_c.cc                  |  5 ++--
 src/codegen/stackvm/codegen_stackvm.cc    |  4 ++--
 src/op/hybrid_op.cc                       |  2 +-
 src/pass/ir_util.cc                       | 28 +++++++++++------------
 src/pass/loop_partition.cc                |  5 ++--
 src/pass/verify_gpu_code.cc               | 18 ++++++++-------
 src/relay/backend/interpreter.cc          |  2 +-
 src/relay/op/type_relations.cc            |  2 +-
 src/relay/pass/gradient.cc                |  2 +-
 src/relay/pass/to_anf.cc                  |  2 +-
 src/relay/pass/type_infer.cc              |  2 +-
 src/relay/pass/type_solver.cc             |  2 +-
 src/schedule/graph.cc                     | 24 +++++++++----------
 src/schedule/message_passing.cc           |  9 +++-----
 src/schedule/schedule_dataflow_rewrite.cc |  2 +-
 src/schedule/schedule_lang.cc             |  3 +--
 src/schedule/schedule_ops.cc              |  3 +--
 17 files changed, 56 insertions(+), 59 deletions(-)

diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc
index 3624dc0403aa..9b73e4e77a2a 100644
--- a/src/codegen/codegen_c.cc
+++ b/src/codegen/codegen_c.cc
@@ -791,10 +791,9 @@ void CodeGenC::VisitStmt_(const AttrStmt* op) {
 void CodeGenC::VisitStmt_(const AssertStmt* op) {
   std::string cond = PrintExpr(op->condition);
   PrintIndent();
-  if (op->message.as<StringImm>()) {
+  if (const auto* str = op->message.as<StringImm>()) {
     // GLOG style check
-    stream << "CHECK(" << cond << ") << \""
-           << op->message.as<StringImm>()->value << "\";\n";
+    stream << "CHECK(" << cond << ") << \"" << str->value << "\";\n";
   } else {
     stream << "assert(" << cond << ");\n";
   }
diff --git a/src/codegen/stackvm/codegen_stackvm.cc b/src/codegen/stackvm/codegen_stackvm.cc
index 0bede2dc0751..1ba8169aa3e7 100644
--- a/src/codegen/stackvm/codegen_stackvm.cc
+++ b/src/codegen/stackvm/codegen_stackvm.cc
@@ -470,8 +470,8 @@ void CodeGenStackVM::VisitExpr_(const Select *op) {
 }
 
 void CodeGenStackVM::VisitStmt_(const AssertStmt *op) {
-  if (op->message.as<StringImm>()) {
-    int sid = this->GetStrID(op->message.as<StringImm>()->value);
+  if (const auto* str = op->message.as<StringImm>()) {
+    int sid = this->GetStrID(str->value);
     this->Push(op->condition);
     this->PushOp(StackVM::ASSERT, sid);
   }
diff --git a/src/op/hybrid_op.cc b/src/op/hybrid_op.cc
index acd7b5737c5f..26daefa76d7f 100644
--- a/src/op/hybrid_op.cc
+++ b/src/op/hybrid_op.cc
@@ -435,7 +435,7 @@ Stmt ApplySchedule(const Stage &stage,
   // Gather rebased variables
   std::unordered_map<IterVar, IterVar> rebased;
   for (auto rel : stage->relations) {
-    if (auto rebase = rel.as<RebaseNode>()) {
+    if (const auto* rebase = rel.as<RebaseNode>()) {
       rebased[rebase->rebased] = rebase->parent;
       CHECK(rebase->parent->dom.defined());
       CHECK(dom_map.count(rebase->rebased));
diff --git a/src/pass/ir_util.cc b/src/pass/ir_util.cc
index 89426f982ba8..2658d5f3c307 100644
--- a/src/pass/ir_util.cc
+++ b/src/pass/ir_util.cc
@@ -12,39 +12,39 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
   // use reverse iteration
   for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
     Stmt s = *ri;
-    if (s.as<For>()) {
-      auto n = make_node<For>(*s.as<For>());
+    if (const auto* for_ = s.as<For>()) {
+      auto n = make_node<For>(*for_);
       CHECK(is_no_op(n->body));
       n->body = body;
       body = Stmt(n);
-    } else if (s.as<LetStmt>()) {
-      auto n = make_node<LetStmt>(*s.as<LetStmt>());
+    } else if (const auto* let = s.as<LetStmt>()) {
+      auto n = make_node<LetStmt>(*let);
       CHECK(is_no_op(n->body));
       n->body = body;
       body = Stmt(n);
-    } else if (s.as<AttrStmt>()) {
-      auto n = make_node<AttrStmt>(*s.as<AttrStmt>());
+    } else if (const auto* attr = s.as<AttrStmt>()) {
+      auto n = make_node<AttrStmt>(*attr);
       CHECK(is_no_op(n->body));
       n->body = body;
       body = Stmt(n);
-    } else if (s.as<IfThenElse>()) {
-      auto n = make_node<IfThenElse>(*s.as<IfThenElse>());
+    } else if (const auto* ite = s.as<IfThenElse>()) {
+      auto n = make_node<IfThenElse>(*ite);
       CHECK(is_no_op(n->then_case));
       CHECK(!n->else_case.defined());
       n->then_case = body;
       body = Stmt(n);
-    } else if (s.as<Block>()) {
-      auto n = make_node<Block>(*s.as<Block>());
+    } else if (const auto* block = s.as<Block>()) {
+      auto n = make_node<Block>(*block);
       CHECK(is_no_op(n->rest));
       n->rest = body;
       body = Stmt(n);
-    } else if (s.as<AssertStmt>()) {
-      auto n = make_node<AssertStmt>(*s.as<AssertStmt>());
+    } else if (const auto* assert_ = s.as<AssertStmt>()) {
+      auto n = make_node<AssertStmt>(*assert_);
       CHECK(is_no_op(n->body));
       n->body = body;
       body = Stmt(n);
-    } else if (s.as<Allocate>()) {
-      auto n = make_node<Allocate>(*s.as<Allocate>());
+    } else if (const auto* alloc = s.as<Allocate>()) {
+      auto n = make_node<Allocate>(*alloc);
       CHECK(is_no_op(n->body));
       n->body = body;
       body = Stmt(n);
diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc
index 95ce130785d7..5747ca3c6a40 100644
--- a/src/pass/loop_partition.cc
+++ b/src/pass/loop_partition.cc
@@ -326,7 +326,8 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
 
   Expr body_begin;
   Stmt pre_stmt;
-  if (true_itrv.as<arith::IntervalSet>()->i.has_lower_bound()) {
+  arith::Interval true_itrv_i = true_itrv.as<arith::IntervalSet>()->i;
+  if (true_itrv_i.has_lower_bound()) {
     body_begin = ir::Simplify(true_itrv.min());
     if (!can_prove(body_begin == min)) {
       Expr cond = (body_begin - min >= 0);
@@ -347,7 +348,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
 
   Expr post_doubt_begin;
   Stmt post_stmt;
-  if (true_itrv.as<arith::IntervalSet>()->i.has_upper_bound()) {
+  if (true_itrv_i.has_upper_bound()) {
     post_doubt_begin = ir::Simplify(true_itrv.max() + 1);
     if (!can_prove(true_itrv.max() == max)) {
       // require the extent to be non-negative
diff --git a/src/pass/verify_gpu_code.cc b/src/pass/verify_gpu_code.cc
index 70908eb43d6b..adc0df06e6dd 100644
--- a/src/pass/verify_gpu_code.cc
+++ b/src/pass/verify_gpu_code.cc
@@ -73,9 +73,10 @@ class GPUCodeVerifier : public IRVisitor {
 
   void Visit_(const AttrStmt *op) {
     if (op->attr_key == attr::storage_scope) {
-      if (op->value.as<StringImm>()->value == "local") {
+      std::string op_value = op->value.as<StringImm>()->value;
+      if (op_value == "local") {
         visited_local_buffers_.insert(op->node.as<tvm::Variable>());
-      } else if (op->value.as<StringImm>()->value == "shared") {
+      } else if (op_value == "shared") {
         visited_shared_buffers_.insert(op->node.as<tvm::Variable>());
       }
     } else if (op->attr_key == attr::thread_extent) {
@@ -159,18 +160,19 @@ bool VerifyGPUCode(Stmt stmt,
   int64_t max_thread_z = INT64_MAX;
 
   for (auto iter : constraints) {
+    const IntImm* val = iter.second.as<IntImm>();
     if (iter.first == "max_local_memory_per_block")
-      max_local_memory_per_block = (iter.second).as<IntImm>()->value;
+      max_local_memory_per_block = val->value;
     else if (iter.first == "max_shared_memory_per_block")
-      max_shared_memory_per_block = (iter.second).as<IntImm>()->value;
+      max_shared_memory_per_block = val->value;
     else if (iter.first == "max_threads_per_block")
-      max_threads_per_block = (iter.second).as<IntImm>()->value;
+      max_threads_per_block = val->value;
     else if (iter.first == "max_thread_x")
-      max_thread_x = (iter.second).as<IntImm>()->value;
+      max_thread_x = val->value;
     else if (iter.first == "max_thread_y")
-      max_thread_y = (iter.second).as<IntImm>()->value;
+      max_thread_y = val->value;
     else if (iter.first == "max_thread_z")
-      max_thread_z = (iter.second).as<IntImm>()->value;
+      max_thread_z = val->value;
     else
       LOG(FATAL) << "Invalid check item: " << iter.first;
   }
diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc
index 81f1cc6989f3..396ff907951d 100644
--- a/src/relay/backend/interpreter.cc
+++ b/src/relay/backend/interpreter.cc
@@ -379,7 +379,7 @@ class Interpreter :
     //
     // We have some functions cotaining chunks of operators
     // which will be loaded into operator map.
-    if (auto op_node = call->op.as<OpNode>()) {
+    if (const auto* op_node = call->op.as<OpNode>()) {
       LOG(FATAL) << "found " << op_node->name
                  << "; operators should be removed by future passes; try "
                     "fusing and lowering";
diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc
index 2618054a663d..9152f0677616 100644
--- a/src/relay/op/type_relations.cc
+++ b/src/relay/op/type_relations.cc
@@ -15,7 +15,7 @@ namespace tvm {
 namespace relay {
 
 TensorType ToTensorType(const Type& t) {
-  if (auto tt_node = t.as<TensorTypeNode>()) {
+  if (const auto* tt_node = t.as<TensorTypeNode>()) {
     return GetRef<TensorType>(tt_node);
   } else {
     return TensorType(nullptr);
diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc
index 251d7153e4e6..780490a45b0a 100644
--- a/src/relay/pass/gradient.cc
+++ b/src/relay/pass/gradient.cc
@@ -61,7 +61,7 @@ Type WithGradientType(const Type& t) {
 
 //! \brief if the expression is a GlobalVar, transform to it's expression.
 Expr DeGlobal(const Module& mod, const Expr& e) {
-  if (auto x = e.as<GlobalVarNode>()) {
+  if (const auto* x = e.as<GlobalVarNode>()) {
     return mod->Lookup(GetRef<GlobalVar>(x))->body;
   } else {
     return e;
diff --git a/src/relay/pass/to_anf.cc b/src/relay/pass/to_anf.cc
index 3880fd16a286..a724d5f2e855 100644
--- a/src/relay/pass/to_anf.cc
+++ b/src/relay/pass/to_anf.cc
@@ -385,7 +385,7 @@ Expr ToANFAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
 }
 
 Expr ToANF(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
-  if (auto f = e.as<FunctionNode>()) {
+  if (const auto* f = e.as<FunctionNode>()) {
     return FunctionNode::make(f->params,
                               ToANFAux(f->body, m, gv),
                               f->ret_type,
diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc
index 3135715f7691..b17c1c1f0439 100644
--- a/src/relay/pass/type_infer.cc
+++ b/src/relay/pass/type_infer.cc
@@ -386,7 +386,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
     }
 
     for (auto cs : fn_ty->type_constraints) {
-      if (auto tr = cs.as<TypeRelationNode>()) {
+      if (const auto* tr = cs.as<TypeRelationNode>()) {
        solver_.AddConstraint(
          TypeRelationNode::make(tr->func, tr->args, tr->num_inputs, call->attrs),
          GetRef<Call>(call));
diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc
index dafcaf56015a..617aafdc712c 100644
--- a/src/relay/pass/type_solver.cc
+++ b/src/relay/pass/type_solver.cc
@@ -376,7 +376,7 @@ void TypeSolver::ReportError(const Error& err, const NodeRef& location) {
 
 // Add type constraint to the solver.
 void TypeSolver::AddConstraint(const TypeConstraint& constraint, const NodeRef& loc) {
-  if (auto *op = constraint.as<TypeRelationNode>()) {
+  if (const auto* op = constraint.as<TypeRelationNode>()) {
     // create a new relation node.
     RelationNode* rnode = arena_.make<RelationNode>();
     rnode->location = loc;
diff --git a/src/schedule/graph.cc b/src/schedule/graph.cc
index 5589853c0db8..4adb78b56b46 100644
--- a/src/schedule/graph.cc
+++ b/src/schedule/graph.cc
@@ -223,9 +223,9 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
   }
 
   for (Operation op : ops) {
-    if (op.as<ScanOpNode>()) {
-      const auto& update = op.as<ScanOpNode>()->update;
-      const auto& init = op.as<ScanOpNode>()->init;
+    if (const auto* scan_op = op.as<ScanOpNode>()) {
+      const auto& update = scan_op->update;
+      const auto& init = scan_op->init;
       for (size_t i = 0; i < update.size(); ++i) {
         Tensor t = op.output(i);
         for (int k = 1; k < static_cast<int>(update[i]->shape.size()); ++k) {
@@ -235,9 +235,9 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
                       TensorDimKey(init[i], k));
         }
       }
-    } else if (op.as<ComputeOpNode>()) {
+    } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
       std::unordered_map<const Node*, TensorDimKey> vmap;
-      const auto& axis = op.as<ComputeOpNode>()->axis;
+      const auto& axis = compute_op->axis;
       Tensor t = op.output(0);
       for (size_t i = 0; i < axis.size(); ++i) {
         vmap[axis[i]->var.get()] = TensorDimKey(t, i);
@@ -260,7 +260,7 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
          }
        }
      };
-      for (auto& e : op.as<ComputeOpNode>()->body) {
+      for (auto& e : compute_op->body) {
        ir::PostOrderVisit(e, fvisit);
      }
    }
@@ -312,9 +312,9 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
   // prop exact reach back.
   for (size_t i = 0; i < body.size(); ++i) {
     const Operation& op = body[i];
-    if (op.as<ScanOpNode>()) {
-      const auto& update = op.as<ScanOpNode>()->update;
-      const auto& init = op.as<ScanOpNode>()->init;
+    if (const auto* scan_op = op.as<ScanOpNode>()) {
+      const auto& update = scan_op->update;
+      const auto& init = scan_op->init;
       for (size_t i = 0; i < update.size(); ++i) {
         Tensor t = op.output(i);
         for (size_t k = 1; k < update[i]->shape.size(); ++k) {
@@ -322,9 +322,9 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
           f_merge_key(TensorDimKey(t, k), TensorDimKey(init[i], k));
         }
       }
-    } else if (op.as<ComputeOpNode>()) {
+    } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
       std::unordered_map<const Node*, std::vector<TensorDimKey> > vmap;
-      const auto& axis = op.as<ComputeOpNode>()->axis;
+      const auto& axis = compute_op->axis;
       for (size_t i = 0; i < axis.size(); ++i) {
         std::vector<TensorDimKey> keys;
         for (int j = 0; j < op->num_outputs(); ++j) {
@@ -352,7 +352,7 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
        }
      }
    };
-      for (auto& e : op.as<ComputeOpNode>()->body) {
+      for (auto& e : compute_op->body) {
        ir::PostOrderVisit(e, fvisit);
      }
    }
diff --git a/src/schedule/message_passing.cc b/src/schedule/message_passing.cc
index dff2895cd42d..44614234bd55 100644
--- a/src/schedule/message_passing.cc
+++ b/src/schedule/message_passing.cc
@@ -419,8 +419,7 @@ void PassUpBoundCheck(const Stage& s,
   using HalideIR::Internal::can_prove;
   for (size_t i = s->relations.size(); i != 0; --i) {
     IterVarRelation rel = s->relations[i - 1];
-    if (rel.as<SplitNode>()) {
-      const SplitNode* s = rel.as<SplitNode>();
+    if (const SplitNode* s = rel.as<SplitNode>()) {
       bool outer = state.at(s->outer);
       bool inner = state.at(s->inner);
 
@@ -439,13 +438,11 @@ void PassUpBoundCheck(const Stage& s,
       } else {
         state[s->parent] = true;
       }
-    } else if (rel.as<FuseNode>()) {
-      const FuseNode* s = rel.as<FuseNode>();
+    } else if (const FuseNode* s = rel.as<FuseNode>()) {
       bool fused = state.at(s->fused);
       state[s->outer] = fused;
       state[s->inner] = fused;
-    } else if (rel.as<RebaseNode>()) {
-      const RebaseNode* s = rel.as<RebaseNode>();
+    } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
       state[s->parent] = state.at(s->rebased);
     } else if (rel.as()) {
       // nop
diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
index f1820d2a7fc6..774c623a5df2 100644
--- a/src/schedule/schedule_dataflow_rewrite.cc
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -544,7 +544,7 @@ void InjectInline(ScheduleNode* sch) {
       const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
       if (compute) {
         if (!new_body[j].size()) {
-          new_body[j] = s->op.as<ComputeOpNode>()->body;
+          new_body[j] = compute->body;
         }
         if (new_body[j][0]->is_type<ir::Reduce>()) {
           // specially handle reduction inline for multiplre reductions.
diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc
index 29265f2e94b8..bd703a211206 100644
--- a/src/schedule/schedule_lang.cc
+++ b/src/schedule/schedule_lang.cc
@@ -710,8 +710,7 @@ Schedule ScheduleNode::make(Array<Operation> ops) {
     n->stages.push_back(stage);
     n->stage_map.Set(op, stage);
     // mark scan updates.
-    if (op.as<ScanOpNode>()) {
-      const ScanOpNode* scan = op.as<ScanOpNode>();
+    if (const ScanOpNode* scan = op.as<ScanOpNode>()) {
       Array<Tensor> inputs;
       for (Tensor t : scan->state_placeholder) {
         inputs.push_back(t);
diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc
index 242423695464..ef76a2c1f28a 100644
--- a/src/schedule/schedule_ops.cc
+++ b/src/schedule/schedule_ops.cc
@@ -304,8 +304,7 @@ class SchedulePostProc : public IRMutator {
       }
     }
     // Specially add replacements for scan op.
-    if (s->op.as<ScanOpNode>()) {
-      const ScanOpNode* scan = s->op.as<ScanOpNode>();
+    if (const ScanOpNode* scan = s->op.as<ScanOpNode>()) {
       for (size_t i = 0; i < scan->update.size(); ++i) {
         Tensor t = s->origin_op.output(i);
         AddReplace(scan->init[i], t);