Skip to content

Commit

Permalink
[TIR][Schedule] Fix region_cover checking for cache related primitives
Browse files Browse the repository at this point in the history
  • Loading branch information
Min Chen committed Nov 10, 2022
1 parent f8cc0b1 commit ed2157c
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions src/tir/schedule/primitive/cache_read_write.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1078,7 +1078,7 @@ class ReIndexRewriter : public StmtExprMutator {
Region region_;
};

void CheckRegionCover(const ScheduleState& self, StmtSRef scope_root) {
void CheckRegionCover(const ScheduleState& self, StmtSRef scope_root, Buffer read_buffer) {
class NotRegionCoverError : public ScheduleError {
public:
explicit NotRegionCoverError(IRModule mod, Block block) : mod_(mod), block_(block) {}
Expand All @@ -1095,12 +1095,16 @@ The region cover property is required to hold for every one of its child blocks
IRModule mod_;
Block block_;
};
BlockScope scope = self->GetBlockScope(scope_root);
for (const auto& kv : scope->dst2deps) {
const StmtSRef& consumer_block_sref = kv.first;
if (!self->block_info.at(consumer_block_sref).region_cover) {
const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root);
throw NotRegionCoverError(self->mod, GetRef<Block>(block));

for (const auto& child_block_sref : tir::GetChildBlocks(self, scope_root)) {
const BlockNode* child_block = TVM_SREF_TO_BLOCK(child_block_sref);
for (const BufferRegion& region : child_block->reads) {
if (region->buffer.same_as(read_buffer)) {
if (!self->block_info.at(child_block_sref).region_cover) {
const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root);
throw NotRegionCoverError(self->mod, GetRef<Block>(block));
}
}
}
}
}
Expand Down Expand Up @@ -1129,7 +1133,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff
GetNthAccessBuffer(self, GetRef<Block>(block), read_buffer_index, BufferIndexType::kRead);
StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
// Check required region cover for cache_read
CheckRegionCover(self, scope_sref);
CheckRegionCover(self, scope_sref, read_buffer);
const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref);

// Step 2. Create CacheStageInfo
Expand Down Expand Up @@ -1281,7 +1285,7 @@ Array<StmtSRef> CacheInplace(ScheduleState self, const StmtSRef& block_sref, int
StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);

// Check 3. Check required region cover for cache_read
CheckRegionCover(self, scope_sref);
CheckRegionCover(self, scope_sref, buffer);

// Check 4. Check if target block both read & write target buffer.
const BlockNode* rw_block = TVM_SREF_TO_BLOCK(block_sref);
Expand Down Expand Up @@ -1318,6 +1322,8 @@ Array<StmtSRef> CacheInplace(ScheduleState self, const StmtSRef& block_sref, int
StmtSRef result_block_sref = self->stmt2ref.at(cache_read_stage.get());
BlockInfo& block_info_read = self->block_info[result_block_sref];
block_info_read.affine_binding = CalculateAffineFlag(self, result_block_sref);
block_info_read.region_cover = true;
block_info_read.scope->stage_pipeline = false;
results_block_sref.push_back(result_block_sref);

// Do cache write
Expand Down

0 comments on commit ed2157c

Please sign in to comment.