From 3f8a0bd6f3a6e6c26558f7ec4bb256b59a728663 Mon Sep 17 00:00:00 2001 From: Valentin Clement Date: Wed, 13 Nov 2024 17:53:05 -0800 Subject: [PATCH] [flang][cuda] Add conversion after CUFGetDeviceAddress to avoid problem when emboxing --- .../Optimizer/Transforms/CUFOpConversion.cpp | 5 ++-- flang/test/Fir/CUDA/cuda-data-transfer.fir | 23 +++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp index 58a3cdc905d36..bca0a09c5bff6 100644 --- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp @@ -140,8 +140,9 @@ mlir::Value getDeviceAddress(mlir::PatternRewriter &rewriter, llvm::SmallVector args{fir::runtime::createArguments( builder, loc, fTy, inputArg, sourceFile, sourceLine)}; auto call = rewriter.create(loc, callee, args); - - return call->getResult(0); + mlir::Value cast = createConvertOp( + rewriter, loc, declareOp.getMemref().getType(), call->getResult(0)); + return cast; } template diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir index 491d417271ce7..9c6d9e0c10012 100644 --- a/flang/test/Fir/CUDA/cuda-data-transfer.fir +++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir @@ -202,7 +202,9 @@ func.func @_QPsub8() attributes {fir.bindc_name = "t"} { // CHECK: %[[DECL:.*]] = fir.declare %[[GBL]] // CHECK: %[[HOST:.*]] = fir.convert %[[DECL]] : (!fir.ref>) -> !fir.llvm_ptr // CHECK: %[[SRC:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.ref, i32) -> !fir.llvm_ptr +// CHECK: %[[SRC_CONV:.*]] = fir.convert %[[SRC]] : (!fir.llvm_ptr) -> !fir.ref> // CHECK: %[[DST:.*]] = fir.convert %[[LOCAL]] : (!fir.ref>) -> !fir.llvm_ptr +// CHECK: %[[SRC:.*]] = fir.convert %[[SRC_CONV]] : (!fir.ref>) -> !fir.llvm_ptr // CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.llvm_ptr, i64, i32, !fir.ref, i32) -> none @@ -224,6 +226,8 @@ func.func @_QPsub9() { // CHECK: %[[DECL:.*]] = fir.declare %[[GBL]] // CHECK: %[[HOST:.*]] = fir.convert %[[DECL]] : (!fir.ref>) -> !fir.llvm_ptr // CHECK: %[[DST:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.ref, i32) -> !fir.llvm_ptr +// CHECK: %[[DST_CONV:.*]] = fir.convert %[[DST]] : (!fir.llvm_ptr) -> !fir.ref> +// CHECK: %[[DST:.*]] = fir.convert %[[DST_CONV]] : (!fir.ref>) -> !fir.llvm_ptr // CHECK: %[[SRC:.*]] = fir.convert %[[LOCAL]] : (!fir.ref>) -> !fir.llvm_ptr // CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.llvm_ptr, i64, i32, !fir.ref, i32) -> none @@ -361,5 +365,24 @@ func.func @_QPshape_shift2() { // CHECK: %[[BYTES:.*]] = arith.muli %[[C10]], %c4{{.*}} : i64 // CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%{{.*}}, %{{.*}}, %[[BYTES]], %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.llvm_ptr, i64, i32, !fir.ref, i32) -> none +fir.global @_QMmod1Ea_dev {data_attr = #cuf.cuda} : !fir.array<4xf32> { + %0 = fir.zero_bits !fir.array<4xf32> + fir.has_value %0 : !fir.array<4xf32> +} +func.func @_QPdevice_addr_conv() { + %cst = arith.constant 4.200000e+01 : f32 + %c4 = arith.constant 4 : index + %0 = fir.address_of(@_QMmod1Ea_dev) : !fir.ref> + %1 = fir.shape %c4 : (index) -> !fir.shape<1> + %2 = fir.declare %0(%1) {data_attr = #cuf.cuda, uniq_name = "_QMmod1Ea_dev"} : (!fir.ref>, !fir.shape<1>) -> !fir.ref> + cuf.data_transfer %cst to %2 {transfer_kind = #cuf.cuda_transfer} : f32, !fir.ref> + return +} + +// CHECK-LABEL: func.func @_QPdevice_addr_conv() +// CHECK: %[[DEV_ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.ref, i32) -> !fir.llvm_ptr +// CHECK: %[[DEV_ADDR_CONV:.*]] = fir.convert %[[DEV_ADDR]] : (!fir.llvm_ptr) -> !fir.ref> +// CHECK: fir.embox %[[DEV_ADDR_CONV]](%{{.*}}) : (!fir.ref>, !fir.shape<1>) -> !fir.box> +// CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc } // end of module