-
Notifications
You must be signed in to change notification settings - Fork 12k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[flang][cuda] Support malloc and free conversion in gpu module #116112
Conversation
@llvm/pr-subscribers-flang-fir-hlfir Author: Valentin Clement (バレンタイン クレメン) (clementval) ChangesUpdate Full diff: https://github.com/llvm/llvm-project/pull/116112.diff 3 Files Affected:
diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index 646621cb01c157..f47d11875f04db 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -23,6 +23,7 @@ add_flang_library(FIRCodeGen
FIRSupport
MLIRComplexToLLVM
MLIRComplexToStandard
+ MLIRGPUDialect
MLIRMathToFuncs
MLIRMathToLLVM
MLIRMathToLibm
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index d038efcb2eb42c..3452a662f7a194 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -41,6 +41,7 @@
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
@@ -920,17 +921,19 @@ struct EmboxCharOpConversion : public fir::FIROpConversion<fir::EmboxCharOp> {
};
} // namespace
-/// Return the LLVMFuncOp corresponding to the standard malloc call.
+template <typename ModuleOp>
static mlir::SymbolRefAttr
-getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
+getMallocInModule(ModuleOp mod, fir::AllocMemOp op,
+ mlir::ConversionPatternRewriter &rewriter) {
static constexpr char mallocName[] = "malloc";
- auto module = op->getParentOfType<mlir::ModuleOp>();
- if (auto mallocFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(mallocName))
+ if (auto mallocFunc =
+ mod.template lookupSymbol<mlir::LLVM::LLVMFuncOp>(mallocName))
return mlir::SymbolRefAttr::get(mallocFunc);
- if (auto userMalloc = module.lookupSymbol<mlir::func::FuncOp>(mallocName))
+ if (auto userMalloc =
+ mod.template lookupSymbol<mlir::func::FuncOp>(mallocName))
return mlir::SymbolRefAttr::get(userMalloc);
- mlir::OpBuilder moduleBuilder(
- op->getParentOfType<mlir::ModuleOp>().getBodyRegion());
+
+ mlir::OpBuilder moduleBuilder(mod.getBodyRegion());
auto indexType = mlir::IntegerType::get(op.getContext(), 64);
auto mallocDecl = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
op.getLoc(), mallocName,
@@ -940,6 +943,15 @@ getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
return mlir::SymbolRefAttr::get(mallocDecl);
}
+/// Return the LLVMFuncOp corresponding to the standard malloc call.
+static mlir::SymbolRefAttr
+getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
+ if (auto mod = op->getParentOfType<mlir::gpu::GPUModuleOp>())
+ return getMallocInModule(mod, op, rewriter);
+ auto mod = op->getParentOfType<mlir::ModuleOp>();
+ return getMallocInModule(mod, op, rewriter);
+}
+
/// Helper function for generating the LLVM IR that computes the distance
/// in bytes between adjacent elements pointed to by a pointer
/// of type \p ptrTy. The result is returned as a value of \p idxTy integer
@@ -1016,18 +1028,20 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
} // namespace
/// Return the LLVMFuncOp corresponding to the standard free call.
-static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
- mlir::ConversionPatternRewriter &rewriter) {
+template <typename ModuleOp>
+static mlir::SymbolRefAttr
+getFreeInModule(ModuleOp mod, fir::FreeMemOp op,
+ mlir::ConversionPatternRewriter &rewriter) {
static constexpr char freeName[] = "free";
- auto module = op->getParentOfType<mlir::ModuleOp>();
// Check if free already defined in the module.
- if (auto freeFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(freeName))
+ if (auto freeFunc =
+ mod.template lookupSymbol<mlir::LLVM::LLVMFuncOp>(freeName))
return mlir::SymbolRefAttr::get(freeFunc);
if (auto freeDefinedByUser =
- module.lookupSymbol<mlir::func::FuncOp>(freeName))
+ mod.template lookupSymbol<mlir::func::FuncOp>(freeName))
return mlir::SymbolRefAttr::get(freeDefinedByUser);
// Create llvm declaration for free.
- mlir::OpBuilder moduleBuilder(module.getBodyRegion());
+ mlir::OpBuilder moduleBuilder(mod.getBodyRegion());
auto voidType = mlir::LLVM::LLVMVoidType::get(op.getContext());
auto freeDecl = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
rewriter.getUnknownLoc(), freeName,
@@ -1037,6 +1051,14 @@ static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
return mlir::SymbolRefAttr::get(freeDecl);
}
+static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
+ mlir::ConversionPatternRewriter &rewriter) {
+ if (auto mod = op->getParentOfType<mlir::gpu::GPUModuleOp>())
+ return getFreeInModule(mod, op, rewriter);
+ auto mod = op->getParentOfType<mlir::ModuleOp>();
+ return getFreeInModule(mod, op, rewriter);
+}
+
static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) {
unsigned result = 1;
for (auto eleTy =
@@ -3730,6 +3752,7 @@ class FIRToLLVMLowering
mlir::configureOpenMPToLLVMConversionLegality(target, typeConverter);
target.addLegalDialect<mlir::omp::OpenMPDialect>();
target.addLegalDialect<mlir::acc::OpenACCDialect>();
+ target.addLegalDialect<mlir::gpu::GPUDialect>();
// required NOPs for applying a full conversion
target.addLegalOp<mlir::ModuleOp>();
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index fa391fa6cc7a7d..4c9f965e1241a0 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -2776,3 +2776,19 @@ func.func @coordinate_array_unknown_size_1d(%arg0: !fir.ptr<!fir.array<? x i32>>
fir.global common @c_(dense<0> : vector<4294967296xi8>) : !fir.array<4294967296xi8>
// CHECK: llvm.mlir.global common @c_(dense<0> : vector<4294967296xi8>) {addr_space = 0 : i32} : !llvm.array<4294967296 x i8>
+
+// -----
+
+gpu.module @cuda_device_mod {
+ gpu.func @test_alloc_and_freemem_one() {
+ %z0 = fir.allocmem i32
+ fir.freemem %z0 : !fir.heap<i32>
+ gpu.return
+ }
+}
+
+// CHECK: gpu.module @cuda_device_mod {
+// CHECK: llvm.func @free(!llvm.ptr)
+// CHECK: llvm.func @malloc(i64) -> !llvm.ptr
+// CHECK: llvm.call @malloc
+// CHECK: lvm.call @free
|
@llvm/pr-subscribers-flang-codegen Author: Valentin Clement (バレンタイン クレメン) (clementval) ChangesUpdate Full diff: https://github.com/llvm/llvm-project/pull/116112.diff 3 Files Affected:
diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index 646621cb01c157..f47d11875f04db 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -23,6 +23,7 @@ add_flang_library(FIRCodeGen
FIRSupport
MLIRComplexToLLVM
MLIRComplexToStandard
+ MLIRGPUDialect
MLIRMathToFuncs
MLIRMathToLLVM
MLIRMathToLibm
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index d038efcb2eb42c..3452a662f7a194 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -41,6 +41,7 @@
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
@@ -920,17 +921,19 @@ struct EmboxCharOpConversion : public fir::FIROpConversion<fir::EmboxCharOp> {
};
} // namespace
-/// Return the LLVMFuncOp corresponding to the standard malloc call.
+template <typename ModuleOp>
static mlir::SymbolRefAttr
-getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
+getMallocInModule(ModuleOp mod, fir::AllocMemOp op,
+ mlir::ConversionPatternRewriter &rewriter) {
static constexpr char mallocName[] = "malloc";
- auto module = op->getParentOfType<mlir::ModuleOp>();
- if (auto mallocFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(mallocName))
+ if (auto mallocFunc =
+ mod.template lookupSymbol<mlir::LLVM::LLVMFuncOp>(mallocName))
return mlir::SymbolRefAttr::get(mallocFunc);
- if (auto userMalloc = module.lookupSymbol<mlir::func::FuncOp>(mallocName))
+ if (auto userMalloc =
+ mod.template lookupSymbol<mlir::func::FuncOp>(mallocName))
return mlir::SymbolRefAttr::get(userMalloc);
- mlir::OpBuilder moduleBuilder(
- op->getParentOfType<mlir::ModuleOp>().getBodyRegion());
+
+ mlir::OpBuilder moduleBuilder(mod.getBodyRegion());
auto indexType = mlir::IntegerType::get(op.getContext(), 64);
auto mallocDecl = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
op.getLoc(), mallocName,
@@ -940,6 +943,15 @@ getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
return mlir::SymbolRefAttr::get(mallocDecl);
}
+/// Return the LLVMFuncOp corresponding to the standard malloc call.
+static mlir::SymbolRefAttr
+getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
+ if (auto mod = op->getParentOfType<mlir::gpu::GPUModuleOp>())
+ return getMallocInModule(mod, op, rewriter);
+ auto mod = op->getParentOfType<mlir::ModuleOp>();
+ return getMallocInModule(mod, op, rewriter);
+}
+
/// Helper function for generating the LLVM IR that computes the distance
/// in bytes between adjacent elements pointed to by a pointer
/// of type \p ptrTy. The result is returned as a value of \p idxTy integer
@@ -1016,18 +1028,20 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
} // namespace
/// Return the LLVMFuncOp corresponding to the standard free call.
-static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
- mlir::ConversionPatternRewriter &rewriter) {
+template <typename ModuleOp>
+static mlir::SymbolRefAttr
+getFreeInModule(ModuleOp mod, fir::FreeMemOp op,
+ mlir::ConversionPatternRewriter &rewriter) {
static constexpr char freeName[] = "free";
- auto module = op->getParentOfType<mlir::ModuleOp>();
// Check if free already defined in the module.
- if (auto freeFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(freeName))
+ if (auto freeFunc =
+ mod.template lookupSymbol<mlir::LLVM::LLVMFuncOp>(freeName))
return mlir::SymbolRefAttr::get(freeFunc);
if (auto freeDefinedByUser =
- module.lookupSymbol<mlir::func::FuncOp>(freeName))
+ mod.template lookupSymbol<mlir::func::FuncOp>(freeName))
return mlir::SymbolRefAttr::get(freeDefinedByUser);
// Create llvm declaration for free.
- mlir::OpBuilder moduleBuilder(module.getBodyRegion());
+ mlir::OpBuilder moduleBuilder(mod.getBodyRegion());
auto voidType = mlir::LLVM::LLVMVoidType::get(op.getContext());
auto freeDecl = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
rewriter.getUnknownLoc(), freeName,
@@ -1037,6 +1051,14 @@ static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
return mlir::SymbolRefAttr::get(freeDecl);
}
+static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
+ mlir::ConversionPatternRewriter &rewriter) {
+ if (auto mod = op->getParentOfType<mlir::gpu::GPUModuleOp>())
+ return getFreeInModule(mod, op, rewriter);
+ auto mod = op->getParentOfType<mlir::ModuleOp>();
+ return getFreeInModule(mod, op, rewriter);
+}
+
static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) {
unsigned result = 1;
for (auto eleTy =
@@ -3730,6 +3752,7 @@ class FIRToLLVMLowering
mlir::configureOpenMPToLLVMConversionLegality(target, typeConverter);
target.addLegalDialect<mlir::omp::OpenMPDialect>();
target.addLegalDialect<mlir::acc::OpenACCDialect>();
+ target.addLegalDialect<mlir::gpu::GPUDialect>();
// required NOPs for applying a full conversion
target.addLegalOp<mlir::ModuleOp>();
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index fa391fa6cc7a7d..4c9f965e1241a0 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -2776,3 +2776,19 @@ func.func @coordinate_array_unknown_size_1d(%arg0: !fir.ptr<!fir.array<? x i32>>
fir.global common @c_(dense<0> : vector<4294967296xi8>) : !fir.array<4294967296xi8>
// CHECK: llvm.mlir.global common @c_(dense<0> : vector<4294967296xi8>) {addr_space = 0 : i32} : !llvm.array<4294967296 x i8>
+
+// -----
+
+gpu.module @cuda_device_mod {
+ gpu.func @test_alloc_and_freemem_one() {
+ %z0 = fir.allocmem i32
+ fir.freemem %z0 : !fir.heap<i32>
+ gpu.return
+ }
+}
+
+// CHECK: gpu.module @cuda_device_mod {
+// CHECK: llvm.func @free(!llvm.ptr)
+// CHECK: llvm.func @malloc(i64) -> !llvm.ptr
+// CHECK: llvm.call @malloc
+// CHECK: lvm.call @free
|
Update
getMalloc
andgetFree
to work with the enclosing module (ModuleOp or GPUModuleOp) so we can convertfir.allocmem
andfir.freemem
in device code.