
[TIR] More hygienic TVM_SREF macros (apache#12607)
Previously, the `TVM_SREF_TO_BLOCK`, `TVM_SREF_TO_FOR`, and
`TVM_TYPE_AS` macros required both the input and the output variable
names as arguments. The input name is useful for improving the error
message returned, but the output variable name isn't necessary for
that purpose, and passing it prevented the macro from being used as
part of an expression.
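
For illustration, taking a call site in mutate_parallel.cc from the diff
below: the old form had to be its own statement, while the new form can
appear inside a larger expression.

  // Before: the output variable name is the first macro argument.
  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_srefs[0]);
  ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(GetRef<Block>(block));

  // After: the macro is itself an expression and can be used inline.
  ScopeBlockLoopInfo info =
      GetScopeBlockLoopInfo(GetRef<Block>(TVM_SREF_TO_BLOCK(block_srefs[0])));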

* Generate an immediately-invoked lambda expression to allow for an
  independently-scoped `result` variable.

* Use parentheses around the input argument, in case the sref is
  the result of an expression.

* Update all call sites to drop the now-redundant first argument (the
  output variable name).
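
As a rough, illustrative sketch of the pattern described above (not the
exact TVM definition; the stand-in types and the SREF_TO_BLOCK_SKETCH
name are hypothetical), a conversion macro built around an
immediately-invoked lambda looks roughly like this:

  #include <iostream>
  #include <stdexcept>

  // Hypothetical stand-ins for the TIR statement node and sref types.
  struct StmtNode { virtual ~StmtNode() = default; };
  struct BlockNode : StmtNode { const char* name_hint = "block"; };
  struct StmtSRef { const StmtNode* stmt; };

  // The immediately-invoked lambda gives `result` its own scope, the
  // parentheses around (SRef) keep expression arguments safe, and the
  // whole macro is an expression, so no output variable name is needed.
  #define SREF_TO_BLOCK_SKETCH(SRef)                                           \
    [&]() {                                                                    \
      const BlockNode* result = dynamic_cast<const BlockNode*>((SRef).stmt);   \
      if (result == nullptr) {                                                 \
        throw std::runtime_error("TypeError: sref `" #SRef                     \
                                 "` does not point to a Block");               \
      }                                                                        \
      return result;                                                           \
    }()

  int main() {
    BlockNode block;
    StmtSRef sref{&block};
    // Usable directly inside a larger expression.
    std::cout << SREF_TO_BLOCK_SKETCH(sref)->name_hint << "\n";
    return 0;
  }

The real TVM macros additionally route through the existing sref error
reporting, but the scoping and expression behavior follow this shape.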
Lunderberg authored and xinetzone committed Nov 25, 2022
1 parent f43baf4 commit 0dfbff3
Showing 30 changed files with 133 additions and 120 deletions.
4 changes: 2 additions & 2 deletions src/meta_schedule/mutator/mutate_parallel.cc
@@ -64,7 +64,7 @@ const BlockRVNode* GetInstGetBlockOutput(const Instruction& inst) {
return nullptr;
}
ICHECK_EQ(inst->outputs.size(), 1);
const BlockRVNode* block = TVM_TYPE_AS(block, inst->outputs[0], BlockRVNode);
const BlockRVNode* block = TVM_TYPE_AS(inst->outputs[0], BlockRVNode);
return block;
}

@@ -82,7 +82,7 @@ std::vector<std::vector<int64_t>> AnalyzeParallel(const ScheduleState& self,
Array<StmtSRef> block_srefs =
tir::GetBlocks(self, block_name, self->mod->GetGlobalVar(func_name));
ICHECK_EQ(block_srefs.size(), 1);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_srefs[0]);
const BlockNode* block = TVM_SREF_TO_BLOCK(block_srefs[0]);
ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(GetRef<Block>(block));
std::vector<std::vector<int64_t>> results;
results.reserve(info.realizes.size());
8 changes: 4 additions & 4 deletions src/meta_schedule/mutator/mutate_thread_binding.cc
@@ -109,25 +109,25 @@ std::vector<MutateThreadBindingNode::Candidate> MutateThreadBindingNode::FindCan
for (const Instruction& inst : trace->insts) {
if (inst->kind.same_as(inst_sample_categorical)) {
ICHECK_EQ(inst->outputs.size(), 1);
const PrimExprNode* var_rv = TVM_TYPE_AS(var_rv, inst->outputs[0], PrimExprNode);
const PrimExprNode* var_rv = TVM_TYPE_AS(inst->outputs[0], PrimExprNode);
sample_insts[var_rv] = inst.get();
} else if (is_split_by_sample(inst)) {
CHECK_EQ(inst->outputs.size(), 2);
// Only consider the inner loop, which can be bound to threadIdx.x
const tir::LoopRVNode* var_rv = TVM_TYPE_AS(var_rv, inst->outputs[1], tir::LoopRVNode);
const tir::LoopRVNode* var_rv = TVM_TYPE_AS(inst->outputs[1], tir::LoopRVNode);
sampled_split_insts[var_rv] = inst.get();
} else if (is_thread_binding_by_sample(inst)) {
bind_insts.push_back(inst.get());
}
}

for (const InstructionNode* bind_inst : bind_insts) {
const auto* loop_rv = TVM_TYPE_AS(loop_rv, bind_inst->inputs[0], tir::LoopRVNode);
const auto* loop_rv = TVM_TYPE_AS(bind_inst->inputs[0], tir::LoopRVNode);
auto split_it = sampled_split_insts.find(loop_rv);
ICHECK(split_it != sampled_split_insts.end());
const InstructionNode* split_inst = split_it->second;

const auto* expr_rv = TVM_TYPE_AS(expr_rv, split_inst->inputs[2], PrimExprNode);
const auto* expr_rv = TVM_TYPE_AS(split_inst->inputs[2], PrimExprNode);
auto sample_it = sample_insts.find(expr_rv);
ICHECK(sample_it != sample_insts.end());
const InstructionNode* sample_inst = sample_it->second;
4 changes: 2 additions & 2 deletions src/meta_schedule/mutator/mutate_tile_size.cc
@@ -34,7 +34,7 @@ using tir::Trace;
* \return The result of downcast
*/
std::vector<int64_t> DowncastTilingDecision(const ObjectRef& decision) {
const auto* arr = TVM_TYPE_AS(arr, decision, runtime::ArrayNode);
const auto* arr = TVM_TYPE_AS(decision, runtime::ArrayNode);
return support::AsVector<ObjectRef, int64_t>(GetRef<Array<ObjectRef>>(arr));
}

@@ -123,7 +123,7 @@ void FindSampleVectorize(const Trace& trace, std::vector<Instruction>* inst,
if (inst->kind.same_as(inst_sample_categorical)) {
ICHECK_EQ(inst->outputs.size(), 1);
if (annotated.count(inst->outputs[0].get())) {
const auto* d = TVM_TYPE_AS(d, decision, IntImmNode);
const auto* d = TVM_TYPE_AS(decision, IntImmNode);
instructions.push_back(inst);
decisions.push_back(d->value);
}
4 changes: 2 additions & 2 deletions src/meta_schedule/mutator/mutate_unroll.cc
@@ -91,7 +91,7 @@ bool FindUnrollDecision(const Trace& trace, TRandState* rand_state,
for (const Instruction& inst : trace->insts) {
if (inst->kind.same_as(inst_sample_categorical)) {
ICHECK_EQ(inst->outputs.size(), 1);
const PrimExprNode* var_rv = TVM_TYPE_AS(var_rv, inst->outputs[0], PrimExprNode);
const PrimExprNode* var_rv = TVM_TYPE_AS(inst->outputs[0], PrimExprNode);
sample_insts[var_rv] = inst.get();
} else if (IsAnnotateWithUnroll(inst)) {
ann_insts.push_back(inst.get());
@@ -103,7 +103,7 @@
}
const InstructionNode* ann_inst = ann_insts[tir::SampleInt(rand_state, 0, n_ann_insts)];
ICHECK_EQ(ann_inst->inputs.size(), 2);
const auto* var_rv = TVM_TYPE_AS(var_rv, ann_inst->inputs[1], PrimExprNode);
const auto* var_rv = TVM_TYPE_AS(ann_inst->inputs[1], PrimExprNode);
ICHECK(sample_insts.count(var_rv));
const InstructionNode* sample_inst = sample_insts.at(var_rv);
ICHECK_EQ(sample_inst->attrs.size(), 2);
@@ -233,7 +233,7 @@ void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv,
int64_t prod_extent = 1;
for (int i = 0; i < n_loops && loop_types[i] == IterVarType::kDataPar; ++i) {
const StmtSRef& loop_sref = loop_srefs[i];
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
if (HasAnnOrBinding(loop)) {
break;
}
@@ -262,7 +262,7 @@ void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv,
for (int i = n_loops - 1;
i >= 0 && loop_types[i] == IterVarType::kDataPar && num_fusible < max_fusible; --i) {
const StmtSRef& loop_sref = loop_srefs[i];
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
if (HasAnnOrBinding(loop)) {
break;
}
2 changes: 1 addition & 1 deletion src/meta_schedule/schedule_rule/auto_bind.cc
@@ -45,7 +45,7 @@ void BindBlockThreadIdx(const tir::Schedule& sch, const tir::BlockRV& block_rv,
int i_spatial_loop = -1;
for (int i = 0; i < n; ++i) {
const StmtSRef& loop_sref = loops[i];
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
runtime::ThreadScope thread_scope = GetThreadScope(loop);
if (IsBlockIdx(thread_scope)) {
if (i_block_idx == -1) {
2 changes: 1 addition & 1 deletion src/meta_schedule/schedule_rule/auto_inline.cc
@@ -96,7 +96,7 @@ inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch,
StmtSRef block_sref = sch->GetSRef(block_rv);
bool is_pure_sptial = IsInSpatialPrimFunc(sch, block_sref);
ScheduleState state = sch->state();
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
BlockRealize realize = GetBlockRealize(state, block_sref);
// Cond 1. The block has only one write buffer
if (block->writes.size() != 1) {
2 changes: 1 addition & 1 deletion src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -37,7 +37,7 @@ namespace tir {
* of multi-level tiling, so it's intentionally kept inside this file not in the analysis header
*/
std::vector<int> GetReadBufferNDims(const StmtSRef& block_sref) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
const BufferNode* write_buffer = block->writes[0]->buffer.get();
int n = block->reads.size();
std::vector<int> results(n, -1);
@@ -411,7 +411,7 @@ Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
tir::StmtSRef block_sref = state->sch->GetSRef(state->block_rv);

// Add reindex stages
const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
// Hold the reference of the block before reindex
const tir::Block block_before_reindex = GetRef<tir::Block>(block);
if (block->reads.size() != 2 || block->writes.size() != 1) {
@@ -488,7 +488,7 @@ Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
}
visited_buffers.insert(lhs_buffer);
// Refresh block pointer (block sref is not invalidated)
block = TVM_SREF_TO_BLOCK(block, block_sref);
block = TVM_SREF_TO_BLOCK(block_sref);
const tir::BufferRegion& reindexed_buffer_region = tir::GetNthAccessBufferRegion(
state->sch->state(), GetRef<tir::Block>(block), buffer_index, index_type);
auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region);
2 changes: 1 addition & 1 deletion src/meta_schedule/schedule_rule/random_compute_location.cc
@@ -60,7 +60,7 @@ class RandomComputeLocationNode : public ScheduleRuleNode {
private:
bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) const {
tir::StmtSRef block_sref = sch->GetSRef(block_rv);
const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
TVM_SREF_TO_BLOCK(block_sref);

// Cond 1. The block is not the root block.
if (block_sref->parent == nullptr) {
2 changes: 1 addition & 1 deletion src/meta_schedule/utils.h
@@ -238,7 +238,7 @@ inline std::string Concat(const Array<String>& strs, const std::string& delim) {
*/
inline tir::BlockRV GetRVFromSRef(const tir::Schedule& sch, const tir::StmtSRef& block_sref,
const String& global_var_name) {
const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
return sch->GetBlock(block->name_hint, global_var_name);
}

48 changes: 24 additions & 24 deletions src/tir/schedule/analysis/analysis.cc
@@ -150,7 +150,7 @@ Definition of a scope that is a stage pipeline:
if (require_stage_pipeline) {
bool stage_pipeline = self->GetBlockInfo(scope_root_sref).scope->stage_pipeline;
if (stage_pipeline == false) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, scope_root_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root_sref);
throw NotStagePipelineError(self->mod, GetRef<Block>(block));
}
}
@@ -229,7 +229,7 @@ bool IsDominantBlock(const ScheduleState& self, const StmtSRef& scope_root_sref,
}
}
// Check whether the input block is the only writer of its outputs
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
for (const BufferRegion& write_region : block->writes) {
if (buffer_writers.count(write_region->buffer)) {
if (buffer_writers.at(write_region->buffer).size() != 1) {
@@ -252,7 +252,7 @@ bool IsDominantBlock(const ScheduleState& self, const StmtSRef& scope_root_sref,
int CheckCompleteBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& scope_root_sref) {
// Cond 1. All block vars are data parallel
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
for (const IterVar& iter_var : block->iter_vars) {
if (iter_var->iter_type != kDataPar) {
return 1;
@@ -328,7 +328,7 @@ void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref,

int error_code = CheckCompleteBlockErrorCode(self, block_sref, scope_root_sref);
if (error_code != 0) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
throw IncompleteBlockError(self->mod, GetRef<Block>(block), error_code);
}
}
@@ -344,7 +344,7 @@ void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref,
*/
int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& scope_root_sref) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
// Cond 1. The block has the `init` statement.
if (!block->init.defined()) {
return 1;
@@ -394,7 +394,7 @@ void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,

int error_code = CheckReductionBlockErrorCode(self, block_sref, scope_root_sref);
if (error_code != 0) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
throw NotReductionBlockError(self->mod, GetRef<Block>(block), error_code);
}
}
@@ -441,7 +441,7 @@ void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& bl
if (reduction_block_error_code == 0) {
return;
}
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
throw NotCompleteOrReductionBlockError(self->mod, GetRef<Block>(block), complete_block_error_code,
reduction_block_error_code);
}
@@ -491,7 +491,7 @@ void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subt
int local_complete_block_code = CheckCompleteBlockErrorCode(self, block_sref, subtree_root),
local_reduction_block_code = CheckReductionBlockErrorCode(self, block_sref, subtree_root);
if (local_complete_block_code != 0 && local_reduction_block_code != 0) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
throw NotCompactDataFlowError(self->mod, GetRef<Stmt>(subtree_root->stmt),
GetRef<Block>(block), local_complete_block_code,
local_reduction_block_code);
@@ -501,8 +501,8 @@

bool IsOutputBlock(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& scope_root_sref) {
const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root, scope_root_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
std::unordered_set<const BufferNode*> scope_allocated;
scope_allocated.reserve(scope_root->alloc_buffers.size());
for (const Buffer& buffer : scope_root->alloc_buffers) {
@@ -532,7 +532,7 @@ void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref,
Block block_;
};
if (IsOutputBlock(self, block_sref, scope_root_sref)) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
throw OutputBlockError(self->mod, GetRef<Block>(block));
}
}
@@ -547,12 +547,12 @@ std::vector<IterVarType> GetBlockVarTypes(const BlockNode* block) {
}

std::vector<IterVarType> GetBlockVarTypes(const StmtSRef& block_sref) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
return GetBlockVarTypes(block);
}

bool IsWriteCache(const StmtSRef& block_sref) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
if (block->writes.size() != 1) {
return false;
}
@@ -751,7 +751,7 @@ void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sre
IRModule mod_;
For loop_;
};
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
if (!analyzer->CanProve(loop->min == 0)) {
throw LoopNotStartWithZeroError(self->mod, GetRef<For>(loop));
}
@@ -856,7 +856,7 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr
const BlockRealizeNode* result;
};

const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
if (block_sref->parent == nullptr) {
const PrimFuncNode* func = GetRootPrimFunc(self->mod, block, nullptr);
return Downcast<BlockRealize>(func->body);
@@ -870,7 +870,7 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr
}

IterVarType GetLoopIterType(const StmtSRef& loop_sref) {
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
const Var& loop_var = loop->loop_var;
int n_spatial = 0;
int n_reduce = 0;
@@ -1924,7 +1924,7 @@ void CheckStorageScope(const ScheduleState& self, String storage_scope) {
}

bool IsSpatial(const StmtSRef& block_sref) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
for (const IterVar& iter_var : block->iter_vars) {
if (iter_var->iter_type != IterVarType::kDataPar) {
return false;
@@ -1934,14 +1934,14 @@ bool IsSpatial(const StmtSRef& block_sref) {
}

bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
TVM_SREF_TO_BLOCK(block_sref);
Array<StmtSRef> loops = GetLoops(block_sref);
Array<PrimExpr> binds = GetBlockRealize(self, block_sref)->iter_values;
if (loops.size() != binds.size()) {
return false;
}
for (int i = 0, n = loops.size(); i < n; ++i) {
const ForNode* loop = TVM_SREF_TO_FOR(loop, loops[i]);
const ForNode* loop = TVM_SREF_TO_FOR(loops[i]);
if (binds[i].get() != loop->loop_var.get()) {
return false;
}
@@ -1953,7 +1953,7 @@ bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref
if (HasBeenMultiLevelTiled(block_sref)) {
return false;
}
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
if (block->writes.size() != 1 || block->reads.empty() || IsSpatial(block_sref) ||
!IsTrivialBinding(self, block_sref)) {
return false;
@@ -2065,7 +2065,7 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, //
const tir::StmtSRef& block_sref, //
int64_t max_parallel_extent, //
int64_t max_parallel_basic) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
Array<tir::StmtSRef> loops = tir::GetLoops(block_sref);

// Cond 1. The block has only one write buffer
@@ -2100,9 +2100,9 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, //
}

// Cond 5.
const ForNode* loop_i = TVM_SREF_TO_FOR(loop_i, loops[i]);
const ForNode* loop_i = TVM_SREF_TO_FOR(loops[i]);
if (i < loops.size() - 1) {
const ForNode* loop_i1 = TVM_SREF_TO_FOR(loop_i1, loops[i + 1]);
const ForNode* loop_i1 = TVM_SREF_TO_FOR(loops[i + 1]);
if (loop_i->body.get() != loop_i1) {
return false;
}
@@ -2194,7 +2194,7 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
TensorIntrinDescInfo desc_info = ExtractTensorIntrinDescInfo(&analyzer, desc_func);
// Step 2. Collect loops from block_sref
const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false);
const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
TVM_SREF_TO_BLOCK(scope_sref);
std::vector<const tir::ForNode*> block_loops;
std::unordered_set<const tir::VarNode*> block_loop_vars;
{
2 changes: 1 addition & 1 deletion src/tir/schedule/block_scope.cc
@@ -76,7 +76,7 @@ BlockScope::BlockScope(const Array<StmtSRef>& child_block_srefs) {
SMap<Buffer, Array<StmtSRef>> buffer_readers;
SMap<Buffer, Array<StmtSRef>>& buffer_writers = n->buffer_writers;
for (const StmtSRef& child_block_sref : child_block_srefs) {
const BlockNode* child_block = TVM_SREF_TO_BLOCK(child_block, child_block_sref);
const BlockNode* child_block = TVM_SREF_TO_BLOCK(child_block_sref);
// Step 1. Update `buffer_readers` and `buffer_writers` for each buffer
for (const BufferRegion& region : child_block->reads) {
buffer_readers[region->buffer].push_back(child_block_sref);
4 changes: 2 additions & 2 deletions src/tir/schedule/concrete_schedule.cc
@@ -269,7 +269,7 @@ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const Optional<String
: name_(name), mod_(mod), blocks_{} {
blocks_.reserve(blocks.size());
for (const StmtSRef& block_sref : blocks) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
blocks_.push_back(GetRef<Block>(block));
}
}
@@ -432,7 +432,7 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,

// Prepare for the splitting
StmtSRef loop_sref = this->GetSRef(loop_rv);
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
Array<PrimExpr> factors;
factors.reserve(factor_rvs.size());
int infer_index = -1;
