@@ -307,8 +307,10 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
307307 // (const index frag_a interacts with non-const index frag_b)
308308 // - No propagation needed: shared_a[i] = frag_a[0]
309309 // (const index frag_a with non-fragment buffer)
310+
310311 bool allow_layout_propgate =
311- fragment_buffers.size () > const_index_fragment_buffer.size ();
312+ const_index_fragment_buffer.empty () ||
313+ (fragment_buffers.size () > const_index_fragment_buffer.size ());
312314
313315 // Step 1: try to infer loop's partition from a source fragment
314316 Buffer source_buffer, read_source_buffer;
@@ -387,57 +389,46 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
387389 if (source_buffer.defined () && allow_layout_propgate) {
388390 loop_layout_ = compute_loop_layout_from_buffer (source_buffer);
389391 } else if (level == InferLevel::kFree ) {
392+ // For free layout inference:
393+ // If replication exists and the loop body loads from or stores to a buffer
394+ // in shared or global scope, add a predicate.
395+ bool has_cross_thread_access = false ;
396+ PostOrderVisit (root_, [&](const ObjectRef &obj) {
397+ if (const auto *store = obj.as <BufferStoreNode>()) {
398+ // check if scope is shared or global
399+ if (store->buffer .scope () == " shared" ||
400+ store->buffer .scope () == " shared.dyn" ||
401+ store->buffer .scope () == " global" ) {
402+ has_cross_thread_access = true ;
403+ }
404+ } else if (const auto *load = obj.as <BufferLoadNode>()) {
405+ // check if scope is shared or global
406+ if (load->buffer .scope () == " shared" ||
407+ load->buffer .scope () == " shared.dyn" ||
408+ load->buffer .scope () == " global" ) {
409+ has_cross_thread_access = true ;
410+ }
411+ }
412+ });
413+
414+ // check if loop body contains a "pure" buffer store (i.e., direct
415+ // assignment, not compound update)
416+ bool has_pure_buffer_store = false ;
417+ PostOrderVisit (root_, [&](const ObjectRef &obj) {
418+ if (const auto *store = obj.as <BufferStoreNode>()) {
419+ // Check if the value is a direct load from another buffer (i.e., b[i]
420+ // = a[i])
421+ if (const auto *load = store->value .as <BufferLoadNode>()) {
422+ has_pure_buffer_store = true ;
423+ }
424+ }
425+ });
426+
390427 if (read_source_buffer.defined () && allow_layout_propgate) {
391428 loop_layout_ = compute_loop_layout_from_buffer (read_source_buffer);
392430 // // Loop doesn't need to be replicated.
393430 // if (!is_one(loop_layout_->ReplicateExtent()))
394431 // loop_layout_ = loop_layout_->DeReplicate();
395-
396- // For free layout inference
397- // If replication exists and buffer has cross-thread shared memory access,
398- // add predicate
399- bool has_cross_thread_access = false ;
400- PostOrderVisit (root_, [&](const ObjectRef &obj) {
401- if (const auto *store = obj.as <BufferStoreNode>()) {
402- // check if scope is shared or global
403- if (store->buffer .scope () == " shared" ||
404- store->buffer .scope () == " shared.dyn" ||
405- store->buffer .scope () == " global" ) {
406- has_cross_thread_access = true ;
407- }
408- } else if (const auto *load = obj.as <BufferLoadNode>()) {
409- // check if scope is shared or global
410- if (load->buffer .scope () == " shared" ||
411- load->buffer .scope () == " shared.dyn" ||
412- load->buffer .scope () == " global" ) {
413- has_cross_thread_access = true ;
414- }
415- }
416- });
417-
418- // check if loop body contains a "pure" buffer store (i.e., direct
419- // assignment, not compound update)
420- bool has_pure_buffer_store = false ;
421- PostOrderVisit (root_, [&](const ObjectRef &obj) {
422- if (const auto *store = obj.as <BufferStoreNode>()) {
423- // Check if the value is a direct load from another buffer (i.e., b[i]
424- // = a[i])
425- if (const auto *load = store->value .as <BufferLoadNode>()) {
426- has_pure_buffer_store = true ;
427- }
428- }
429- });
430-
431- if (!is_one (loop_layout_->ReplicateExtent ()) && has_cross_thread_access &&
432- !has_pure_buffer_store) {
433- auto inv = loop_layout_->Inverse ();
434- Array<PrimExpr> fwd;
435- for (size_t i = 0 ; i < loop_layout_->OutputDim (); i++)
436- fwd.push_back (0 );
437- fwd.push_back (InputPlaceholder (0 ) - T.thread_bounds ->min );
438- auto rep = inv->Forward (fwd).back ();
439- AddPredicate (EQ (rep, 0 ));
440- }
441432 }
442433
443434 if (!loop_layout_.defined ()) {
@@ -486,6 +477,16 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
486477 DLOG (INFO) << " [PlanLoopPartition] loop_layout_ = "
487478 << loop_layout_->DebugOutput () << ' \n ' ;
488479 }
480+ if (!is_one (loop_layout_->ReplicateExtent ()) && has_cross_thread_access &&
481+ !has_pure_buffer_store) {
482+ auto inv = loop_layout_->Inverse ();
483+ Array<PrimExpr> fwd;
484+ for (size_t i = 0 ; i < loop_layout_->OutputDim (); i++)
485+ fwd.push_back (0 );
486+ fwd.push_back (InputPlaceholder (0 ) - T.thread_bounds ->min );
487+ auto rep = inv->Forward (fwd).back ();
488+ AddPredicate (EQ (rep, 0 ));
489+ }
489490 } else {
490491 return {};
491492 }
0 commit comments