diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index e88bb05d0b0b9a..4a923fac76c884 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -367,7 +367,9 @@ static void emitCpAsyncOpZfillAsm(Location loc, Value dstPtr, Value srcPtr, ConversionPatternRewriter &rewriter) { auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(), LLVM::AsmDialect::AD_ATT); - const char *asmStr = "cp.async.cg.shared.global [$0], [$1], $2, $3;\n"; + + const char *cpAsyncCgStr = "cp.async.cg.shared.global [$0], [$1], $2, $3;\n"; + const char *cpAsyncCaStr = "cp.async.ca.shared.global [$0], [$1], $2, $3;\n"; const char *asmConstraints = "r,l,n,r"; Value c3I32 = rewriter.create( @@ -382,6 +384,19 @@ static void emitCpAsyncOpZfillAsm(Location loc, Value dstPtr, Value srcPtr, SmallVector asmVals{dstPtr, srcPtr, dstBytes, srcBytes}; + // Pick the right asm string based on the dstBytes which is a compile-time + // constant. + auto dstByteConstOp = + dyn_cast(dstBytes.getDefiningOp()); + auto dstByteAttr = dstByteConstOp.getValue().dyn_cast(); + int64_t dstByteVal = dstByteAttr.getValue().getSExtValue(); + + assert((dstByteVal == 4 || dstByteVal == 8 || dstByteVal == 16) && + "cp.async byte copy size must be 4, 8 or 16"); + // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than + // 16 dst bytes. + const char *asmStr = (dstByteVal == 16) ? cpAsyncCgStr : cpAsyncCaStr; + rewriter.create( loc, LLVM::LLVMVoidType::get(rewriter.getContext()), /*operands=*/asmVals, diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir index dfbbd15393b046..54b71389d8ee56 100644 --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -295,12 +295,12 @@ func.func @async_cp_i4( // ----- -// CHECK-LABEL: @async_cp_zfill( +// CHECK-LABEL: @async_cp_zfill_f32_align4( // CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index) -func.func @async_cp_zfill( +func.func @async_cp_zfill_f32_align4( %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) { - - // CHECK-DAG: lvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %[[DSTPTR:.*]], %[[SRCPTR:.*]], %[[DSTBYTES:.*]], %[[SRCBYTES:.*]] : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> !llvm.void + // CHECK-DAG: %[[DSTBYTES:.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK-DAG: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %[[DSTPTR:.*]], %[[SRCPTR:.*]], %[[DSTBYTES]], %[[SRCBYTES:.*]] : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> !llvm.void %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4, %srcElements {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3> // CHECK: nvvm.cp.async.commit.group %1 = nvgpu.device_async_create_group %0 @@ -312,6 +312,24 @@ func.func @async_cp_zfill( // ----- +// CHECK-LABEL: @async_cp_zfill_f32_align1( +// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index) +func.func @async_cp_zfill_f32_align1( + %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) { + // CHECK-DAG: %[[DSTBYTES:.*]] = llvm.mlir.constant(4 : i32) : i32 + // CHECK-DAG: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.ca.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %[[DSTPTR:.*]], %[[SRCPTR:.*]], %[[DSTBYTES]], %[[SRCBYTES:.*]] : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> !llvm.void + %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 1, %srcElements {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3> + // CHECK: nvvm.cp.async.commit.group + %1 = nvgpu.device_async_create_group %0 + // CHECK: nvvm.cp.async.wait.group 1 + nvgpu.device_async_wait %1 { numGroups = 1 : i32 } + + return +} + +// ----- + + // CHECK-LABEL: func @mma_sp_sync_f16_16832( func.func @mma_sp_sync_f16_16832(%arg0: vector<4x2xf16>, %arg1: vector<4x2xf16>,