Skip to content

Commit 73ad4d3

Browse files
committed
lint fix
1 parent 9aeba45 commit 73ad4d3

File tree

1 file changed

+48
-47
lines changed

1 file changed

+48
-47
lines changed

src/op/parallel.cc

Lines changed: 48 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -307,8 +307,10 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
307307
// (const index frag_a interacts with non-const index frag_b)
308308
// - No propagation needed: shared_a[i] = frag_a[0]
309309
// (const index frag_a with non-fragment buffer)
310+
310311
bool allow_layout_propgate =
311-
fragment_buffers.size() > const_index_fragment_buffer.size();
312+
const_index_fragment_buffer.empty() ||
313+
(fragment_buffers.size() > const_index_fragment_buffer.size());
312314

313315
// Step 1: try to infer loop's partition from a source fragment
314316
Buffer source_buffer, read_source_buffer;
@@ -387,57 +389,46 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
387389
if (source_buffer.defined() && allow_layout_propgate) {
388390
loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
389391
} else if (level == InferLevel::kFree) {
392+
// For free layout inference
393+
// If replication exists and buffer has cross-thread shared memory access,
394+
// add predicate
395+
bool has_cross_thread_access = false;
396+
PostOrderVisit(root_, [&](const ObjectRef &obj) {
397+
if (const auto *store = obj.as<BufferStoreNode>()) {
398+
// check if scope is shared or global
399+
if (store->buffer.scope() == "shared" ||
400+
store->buffer.scope() == "shared.dyn" ||
401+
store->buffer.scope() == "global") {
402+
has_cross_thread_access = true;
403+
}
404+
} else if (const auto *load = obj.as<BufferLoadNode>()) {
405+
// check if scope is shared or global
406+
if (load->buffer.scope() == "shared" ||
407+
load->buffer.scope() == "shared.dyn" ||
408+
load->buffer.scope() == "global") {
409+
has_cross_thread_access = true;
410+
}
411+
}
412+
});
413+
414+
// check if loop body contains a "pure" buffer store (i.e., direct
415+
// assignment, not compound update)
416+
bool has_pure_buffer_store = false;
417+
PostOrderVisit(root_, [&](const ObjectRef &obj) {
418+
if (const auto *store = obj.as<BufferStoreNode>()) {
419+
// Check if the value is a direct load from another buffer (i.e., b[i]
420+
// = a[i])
421+
if (const auto *load = store->value.as<BufferLoadNode>()) {
422+
has_pure_buffer_store = true;
423+
}
424+
}
425+
});
426+
390427
if (read_source_buffer.defined() && allow_layout_propgate) {
391428
loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
392429
// // Loop don't need to be replicated.
393430
// if (!is_one(loop_layout_->ReplicateExtent()))
394431
// loop_layout_ = loop_layout_->DeReplicate();
395-
396-
// For free layout inference
397-
// If replication exists and buffer has cross-thread shared memory access,
398-
// add predicate
399-
bool has_cross_thread_access = false;
400-
PostOrderVisit(root_, [&](const ObjectRef &obj) {
401-
if (const auto *store = obj.as<BufferStoreNode>()) {
402-
// check if scope is shared or global
403-
if (store->buffer.scope() == "shared" ||
404-
store->buffer.scope() == "shared.dyn" ||
405-
store->buffer.scope() == "global") {
406-
has_cross_thread_access = true;
407-
}
408-
} else if (const auto *load = obj.as<BufferLoadNode>()) {
409-
// check if scope is shared or global
410-
if (load->buffer.scope() == "shared" ||
411-
load->buffer.scope() == "shared.dyn" ||
412-
load->buffer.scope() == "global") {
413-
has_cross_thread_access = true;
414-
}
415-
}
416-
});
417-
418-
// check if loop body contains a "pure" buffer store (i.e., direct
419-
// assignment, not compound update)
420-
bool has_pure_buffer_store = false;
421-
PostOrderVisit(root_, [&](const ObjectRef &obj) {
422-
if (const auto *store = obj.as<BufferStoreNode>()) {
423-
// Check if the value is a direct load from another buffer (i.e., b[i]
424-
// = a[i])
425-
if (const auto *load = store->value.as<BufferLoadNode>()) {
426-
has_pure_buffer_store = true;
427-
}
428-
}
429-
});
430-
431-
if (!is_one(loop_layout_->ReplicateExtent()) && has_cross_thread_access &&
432-
!has_pure_buffer_store) {
433-
auto inv = loop_layout_->Inverse();
434-
Array<PrimExpr> fwd;
435-
for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
436-
fwd.push_back(0);
437-
fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min);
438-
auto rep = inv->Forward(fwd).back();
439-
AddPredicate(EQ(rep, 0));
440-
}
441432
}
442433

443434
if (!loop_layout_.defined()) {
@@ -486,6 +477,16 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
486477
DLOG(INFO) << "[PlanLoopPartition] loop_layout_ = "
487478
<< loop_layout_->DebugOutput() << '\n';
488479
}
480+
if (!is_one(loop_layout_->ReplicateExtent()) && has_cross_thread_access &&
481+
!has_pure_buffer_store) {
482+
auto inv = loop_layout_->Inverse();
483+
Array<PrimExpr> fwd;
484+
for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
485+
fwd.push_back(0);
486+
fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min);
487+
auto rep = inv->Forward(fwd).back();
488+
AddPredicate(EQ(rep, 0));
489+
}
489490
} else {
490491
return {};
491492
}

0 commit comments

Comments
 (0)