diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc
index 950bc88efd345..b3e0e8f1274e4 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -1078,7 +1078,7 @@ class ReIndexRewriter : public StmtExprMutator {
   Region region_;
 };
 
-void CheckRegionCover(const ScheduleState& self, StmtSRef scope_root) {
+void CheckRegionCover(const ScheduleState& self, StmtSRef scope_root, Buffer read_buffer) {
   class NotRegionCoverError : public ScheduleError {
    public:
     explicit NotRegionCoverError(IRModule mod, Block block) : mod_(mod), block_(block) {}
@@ -1095,12 +1095,16 @@ The region cover property require to hold for every of its child blocks
     IRModule mod_;
     Block block_;
   };
-  BlockScope scope = self->GetBlockScope(scope_root);
-  for (const auto& kv : scope->dst2deps) {
-    const StmtSRef& consumer_block_sref = kv.first;
-    if (!self->block_info.at(consumer_block_sref).region_cover) {
-      const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root);
-      throw NotRegionCoverError(self->mod, GetRef<Block>(block));
+
+  for (const auto& child_block_sref : tir::GetChildBlocks(self, scope_root)) {
+    const BlockNode* child_block = TVM_SREF_TO_BLOCK(child_block_sref);
+    for (const BufferRegion& region : child_block->reads) {
+      if (region->buffer.same_as(read_buffer)) {
+        if (!self->block_info.at(child_block_sref).region_cover) {
+          const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root);
+          throw NotRegionCoverError(self->mod, GetRef<Block>(block));
+        }
+      }
     }
   }
 }
@@ -1129,7 +1133,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index,
       GetNthAccessBuffer(self, GetRef<Block>(block), read_buffer_index, BufferIndexType::kRead);
   StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
   // Check required region cover for cache_read
-  CheckRegionCover(self, scope_sref);
+  CheckRegionCover(self, scope_sref, read_buffer);
   const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref);
 
   // Step 2. Create CacheStageInfo
@@ -1281,7 +1285,7 @@ Array<StmtSRef> CacheInplace(ScheduleState self, const StmtSRef& block_sref, int
   StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
 
   // Check 3. Check required region cover for cache_read
-  CheckRegionCover(self, scope_sref);
+  CheckRegionCover(self, scope_sref, buffer);
 
   // Check 4. Check if target block both read & write target buffer.
   const BlockNode* rw_block = TVM_SREF_TO_BLOCK(block_sref);
@@ -1318,6 +1322,8 @@ Array<StmtSRef> CacheInplace(ScheduleState self, const StmtSRef& block_sref, int
   StmtSRef result_block_sref = self->stmt2ref.at(cache_read_stage.get());
   BlockInfo& block_info_read = self->block_info[result_block_sref];
   block_info_read.affine_binding = CalculateAffineFlag(self, result_block_sref);
+  block_info_read.region_cover = true;
+  block_info_read.scope->stage_pipeline = false;
   results_block_sref.push_back(result_block_sref);
 
   // Do cache write