Skip to content

Conversation

@clementval
Copy link
Contributor

When doing a data transfer with a dynamically sized array, we currently generate a data transfer between two descriptors. If the shape values can be provided, we can keep the data transfer between two references. This patch adds optional shape operands to the `cuf.data_transfer` operation.

This will be exploited in lowering in a follow-up patch.

When doing a data transfer with a dynamically sized array, we currently
generate a data transfer between two descriptors. If the shape
values can be provided, we can keep the data transfer between two
references. This patch adds the shape operands to the operation.
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Aug 16, 2024
@llvmbot
Copy link
Member

llvmbot commented Aug 16, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

When doing a data transfer with a dynamically sized array, we currently generate a data transfer between two descriptors. If the shape values can be provided, we can keep the data transfer between two references. This patch adds optional shape operands to the `cuf.data_transfer` operation.

This will be exploited in lowering in a follow-up patch.


Full diff: https://github.com/llvm/llvm-project/pull/104631.diff

4 Files Affected:

  • (modified) flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td (+2-1)
  • (modified) flang/lib/Lower/Bridge.cpp (+11-8)
  • (modified) flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp (+5)
  • (modified) flang/test/Fir/cuf-invalid.fir (+31)
diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
index e95af629ef32f1..3e2d897ff56156 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
@@ -161,10 +161,11 @@ def cuf_DataTransferOp : cuf_Op<"data_transfer", []> {
 
   let arguments = (ins Arg<AnyType, "", [MemRead]>:$src,
                        Arg<AnyRefOrBoxType, "", [MemWrite]>:$dst,
+                       Variadic<AnyIntegerType>:$shape,
                        cuf_DataTransferKindAttr:$transfer_kind);
 
   let assemblyFormat = [{
-    $src `to` $dst attr-dict `:` type(operands)
+    $src `to` $dst (`,` $shape^ `:` type($shape) )? attr-dict `:` type($src) `,` type($dst)
   }];
 
   let hasVerifier = 1;
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index ccbb481f472d81..3ab24bc163c7af 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -4272,18 +4272,19 @@ class FirConverter : public Fortran::lower::AbstractConverter {
           base = convertOp.getValue();
         // Special case if the rhs is a constant.
         if (matchPattern(base.getDefiningOp(), mlir::m_Constant())) {
-          builder.create<cuf::DataTransferOp>(loc, base, lhsVal,
-                                              transferKindAttr);
+          builder.create<cuf::DataTransferOp>(
+              loc, base, lhsVal, mlir::ValueRange{}, transferKindAttr);
         } else {
           auto associate = hlfir::genAssociateExpr(
               loc, builder, rhs, rhs.getType(), ".cuf_host_tmp");
           builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhsVal,
+                                              mlir::ValueRange{},
                                               transferKindAttr);
           builder.create<hlfir::EndAssociateOp>(loc, associate);
         }
       } else {
-        builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
-                                            transferKindAttr);
+        builder.create<cuf::DataTransferOp>(
+            loc, rhsVal, lhsVal, mlir::ValueRange{}, transferKindAttr);
       }
       return;
     }
@@ -4293,7 +4294,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       auto transferKindAttr = cuf::DataTransferKindAttr::get(
           builder.getContext(), cuf::DataTransferKind::DeviceHost);
       builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
-                                          transferKindAttr);
+                                          mlir::ValueRange{}, transferKindAttr);
       return;
     }
 
@@ -4303,7 +4304,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       auto transferKindAttr = cuf::DataTransferKindAttr::get(
           builder.getContext(), cuf::DataTransferKind::DeviceDevice);
       builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
-                                          transferKindAttr);
+                                          mlir::ValueRange{}, transferKindAttr);
       return;
     }
     llvm_unreachable("Unhandled CUDA data transfer");
@@ -4346,8 +4347,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
           addSymbol(sym,
                     hlfir::translateToExtendedValue(loc, builder, temp).first,
                     /*forced=*/true);
-          builder.create<cuf::DataTransferOp>(loc, addr, temp,
-                                              transferKindAttr);
+          builder.create<cuf::DataTransferOp>(
+              loc, addr, temp, mlir::ValueRange{}, transferKindAttr);
           ++nbDeviceResidentObject;
         }
       }
@@ -4444,7 +4445,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         !userDefinedAssignment) {
       Fortran::lower::StatementContext localStmtCtx;
       hlfir::Entity rhs = evaluateRhs(localStmtCtx);
+      llvm::errs() << rhs << "\n";
       hlfir::Entity lhs = evaluateLhs(localStmtCtx);
+      llvm::errs() << lhs << "\n";
       if (isCUDATransfer && !hasCUDAImplicitTransfer)
         genCUDADataTransfer(builder, loc, assign, lhs, rhs);
       else
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
index f7b36b208a7deb..d02c5d752dc5a6 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
@@ -99,6 +99,11 @@ llvm::LogicalResult cuf::AllocateOp::verify() {
 llvm::LogicalResult cuf::DataTransferOp::verify() {
   mlir::Type srcTy = getSrc().getType();
   mlir::Type dstTy = getDst().getType();
+  if (!getShape().empty()) {
+    if (!fir::isa_ref_type(srcTy) || fir::isa_ref_type(dstTy))
+      return emitOpError()
+             << "shape can only be specified on data transfer with references";
+  }
   if ((fir::isa_ref_type(srcTy) && fir::isa_ref_type(dstTy)) ||
       (fir::isa_box_type(srcTy) && fir::isa_box_type(dstTy)) ||
       (fir::isa_ref_type(srcTy) && fir::isa_box_type(dstTy)) ||
diff --git a/flang/test/Fir/cuf-invalid.fir b/flang/test/Fir/cuf-invalid.fir
index 06e08d14b2435c..add864b5bea354 100644
--- a/flang/test/Fir/cuf-invalid.fir
+++ b/flang/test/Fir/cuf-invalid.fir
@@ -94,3 +94,34 @@ func.func @_QPsub1() {
   cuf.free %0 : !fir.ref<f32> {data_attr = #cuf.cuda<constant>}
   return
 }
+
+// -----
+
+func.func @_QPsub1(%arg0: !fir.ref<!fir.array<?xf32>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "adev"}, %arg1: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "ahost"}, %arg2: !fir.ref<i32> {fir.bindc_name = "n"}, %arg3: !fir.ref<i32> {fir.bindc_name = "m"}) {
+  %0 = fir.dummy_scope : !fir.dscope
+  %1:2 = hlfir.declare %arg2 dummy_scope %0 {uniq_name = "_QFsub1En"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %2:2 = hlfir.declare %arg3 dummy_scope %0 {uniq_name = "_QFsub1Em"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %3 = fir.load %1#0 : !fir.ref<i32>
+  %4 = fir.load %2#0 : !fir.ref<i32>
+  %5 = arith.muli %3, %4 : i32
+  %6 = fir.convert %5 : (i32) -> i64
+  %7 = fir.convert %6 : (i64) -> index
+  %c0 = arith.constant 0 : index
+  %8 = arith.cmpi sgt, %7, %c0 : index
+  %9 = arith.select %8, %7, %c0 : index
+  %10 = fir.shape %9 : (index) -> !fir.shape<1>
+  %11:2 = hlfir.declare %arg0(%10) dummy_scope %0 {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Eadev"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>, !fir.dscope) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
+  %12 = fir.load %1#0 : !fir.ref<i32>
+  %13 = fir.load %2#0 : !fir.ref<i32>
+  %14 = arith.muli %12, %13 : i32
+  %15 = fir.convert %14 : (i32) -> i64
+  %16 = fir.convert %15 : (i64) -> index
+  %c0_0 = arith.constant 0 : index
+  %17 = arith.cmpi sgt, %16, %c0_0 : index
+  %18 = arith.select %17, %16, %c0_0 : index
+  %19 = fir.shape %18 : (index) -> !fir.shape<1>
+  %20:2 = hlfir.declare %arg1(%19) dummy_scope %0 {uniq_name = "_QFsub1Eahost"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>, !fir.dscope) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
+  // expected-error@+1{{'cuf.data_transfer' op shape can only be specified on data transfer with references}}
+  cuf.data_transfer %20#0 to %11#0, %18 : index {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.box<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>
+  return
+}

@clementval clementval merged commit 7af61d5 into llvm:main Aug 26, 2024
@clementval clementval deleted the cuf_add_shape branch August 26, 2024 16:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

flang:fir-hlfir flang Flang issues not falling into any other category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants