@@ -772,19 +772,133 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
     stride *= s;
   }
 
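+  // Mirror the shared-side offset computation for the global tensor: collect
+  // the per-dimension start indices and row-major strides so a linear element
+  // offset into global memory can be formed below.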
+  Array<PrimExpr> global_indices;
+  for (auto r : global_range) {
+    global_indices.push_back(r->min);
+  }
+  std::vector<PrimExpr> global_strides;
+  PrimExpr global_stride = 1;
+  for (size_t i = 0; i < global_tensor->shape.size(); i++) {
+    auto s = global_tensor->shape[global_tensor->shape.size() - i - 1];
+    global_strides.insert(global_strides.begin(), global_stride);
+    global_stride *= s;
+  }
+
   ICHECK(strides.size() == indices.size())
       << "strides.size() != indices.size()" << strides.size() << " "
       << indices.size();
   PrimExpr offset = 0;
   for (size_t i = 0; i < indices.size(); i++) {
     offset += indices[i] * strides[i];
   }
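+  // Linear element offset of the copied region's origin within the global
+  // tensor, analogous to `offset` for the shared buffer above.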
+  PrimExpr global_offset = 0;
+  for (size_t i = 0; i < global_indices.size(); i++) {
+    global_offset += global_indices[i] * global_strides[i];
+  }
+  auto shared_tensor_before_remap = shared_tensor;
   Layout shared_layout;
   if (T.layout_map.count(shared_tensor)) {
     shared_layout = T.layout_map[shared_tensor];
     shared_tensor = T.buffer_remap[shared_tensor];
   }
 
+  // Use a 1D TMA copy when the global and shared accesses are both contiguous
+  {
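+    // The 1D fast path requires all of the conditions computed below to hold:
+    // the shared access is contiguous and not swizzled, the global access is
+    // contiguous, both sides copy the same number of elements, and neither
+    // side touches memory out of bounds.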
+    // Check whether shared_tensor->name appears in T.buffer_var_gemm
+    // (Array<PrimExpr>) to avoid using the 1D TMA copy for swizzled layouts
+    bool shared_is_contiguous = true;
+    for (const auto &v : T.buffer_var_gemm) {
+      if (v->name_hint == shared_tensor->name) {
+        shared_is_contiguous = false;
+        break;
+      }
+    }
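+    // Walk the shared ranges from the innermost dimension outwards: trailing
+    // dimensions must cover the full buffer extent, and once a partially
+    // covered dimension is seen, every remaining outer dimension must have
+    // extent 1 for the region to stay contiguous. The same walk is applied to
+    // the global ranges further below.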
+    bool shared_not_full_dim_encounter = false;
+    for (ssize_t i = shared_range.size() - 1; i >= 0; --i) {
+      if (!shared_not_full_dim_encounter) {
+        if (!analyzer->CanProve(shared_range[i]->extent ==
+                                    shared_tensor_before_remap->shape[i] &&
+                                shared_range[i]->min == 0)) {
+          shared_not_full_dim_encounter = true;
+        }
+      } else {
+        if (!analyzer->CanProve(shared_range[i]->extent == 1)) {
+          shared_is_contiguous = false;
+          break;
+        }
+      }
+    }
+    // A global tensor with empty strides is compact, i.e. row-major contiguous
+    bool global_is_contiguous = global_tensor->strides.empty();
+    bool global_not_full_dim_encounter = false;
+    for (ssize_t i = global_range.size() - 1; i >= 0; --i) {
+      if (!global_not_full_dim_encounter) {
+        if (!analyzer->CanProve(global_range[i]->extent ==
+                                    global_tensor->shape[i] &&
+                                global_range[i]->min == 0)) {
+          global_not_full_dim_encounter = true;
+        }
+      } else {
+        if (!analyzer->CanProve(global_range[i]->extent == 1)) {
+          global_is_contiguous = false;
+          break;
+        }
+      }
+    }
+    // Ensure the element counts match and neither side goes out of bounds
+    PrimExpr shared_elements = 1;
+    for (size_t i = 0; i < shared_range.size(); i++) {
+      shared_elements *= shared_range[i]->extent;
+    }
+    PrimExpr global_elements = 1;
+    for (size_t i = 0; i < global_range.size(); i++) {
+      global_elements *= global_range[i]->extent;
+    }
+    bool element_match =
+        analyzer->CanProveEqual(shared_elements, global_elements);
+    bool no_oob = true;
+    for (size_t i = 0; i < shared_range.size(); i++) {
+      if (!analyzer->CanProve(shared_range[i]->min + shared_range[i]->extent <=
+                              shared_tensor_before_remap->shape[i])) {
+        no_oob = false;
+        break;
+      }
+    }
+    for (size_t i = 0; i < global_range.size(); i++) {
+      if (!analyzer->CanProve(global_range[i]->min + global_range[i]->extent <=
+                              global_tensor->shape[i])) {
+        no_oob = false;
+        break;
+      }
+    }
+    // Emit the 1D TMA copy only for loads
+    if (shared_is_contiguous && global_is_contiguous && element_match &&
+        no_oob && is_load) {
+      PrimExpr elements = analyzer->Simplify(shared_elements);
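+      // access_ptr masks: 1 requests read access, 2 requests write access, so
+      // for a load the shared buffer is written and the global buffer is read.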
+      PrimExpr shared_addr = shared_tensor_before_remap.access_ptr(
+          is_load ? 2 : 1, DataType::Handle(), 1, offset, elements);
+      PrimExpr global_addr = global_tensor.access_ptr(
+          is_load ? 1 : 2, DataType::Handle(), 1, global_offset, elements);
+      Stmt tma_copy;
+      if (is_load) {
+        // The 0 is a placeholder for the mbarrier id
+        tma_copy =
+            Evaluate(Call(DataType::Handle(), tma_load(),
+                          {shared_addr, global_addr, 0,
+                           elements * shared_tensor_before_remap->dtype.bytes(),
+                           this->eviction_policy}));
+      } else {
+        tma_copy =
+            Evaluate(Call(DataType::Handle(), tma_store(),
+                          {global_addr, shared_addr,
+                           elements * shared_tensor_before_remap->dtype.bytes(),
+                           this->eviction_policy}));
+      }
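+      // Guard so that only the first thread in the bounds issues the TMA
+      // instruction; the bulk copy itself is performed asynchronously by the
+      // hardware.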
+      tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
+      return tma_copy;
+    }
+  }
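+  // If the 1D fast path is not applicable, fall through to the
+  // descriptor-based multidimensional TMA lowering below.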
+
   TMADesc desc;
   // Verify copy rank
   desc.rank = global_tensor->shape.size();
@@ -1221,10 +1335,11 @@ Array<PrimExpr> TMAIm2ColDesc::EncodeCallArgs() const {
 
 // Register the Copy operation with TVM's TIR system
 // This makes the copy operation available for use in TVM programs
-// - Takes 4 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma
+// - Takes 5 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma,
+//   eviction_policy
 // - Marked as opaque since it has side effects (memory writes)
 TIR_REGISTER_TL_OP(Copy, copy)
-    .set_num_inputs(4)
+    .set_num_inputs(5)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));
 