Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][cuda] Support malloc and free conversion in gpu module #116112

Merged
merged 1 commit into from
Nov 14, 2024

Conversation

clementval
Copy link
Contributor

@clementval clementval commented Nov 13, 2024

Update getMalloc and getFree to work with the enclosing module (ModuleOp or GPUModuleOp) so we can convert fir.allocmem and fir.freemem in device code.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:codegen labels Nov 13, 2024
@llvmbot
Copy link

llvmbot commented Nov 13, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Update getMalloc and getFree to work with the enclosing module (ModuleOp or GPUModuleOp) so we can convert fir.allocmem and fir.freemem in device code.


Full diff: https://github.com/llvm/llvm-project/pull/116112.diff

3 Files Affected:

  • (modified) flang/lib/Optimizer/CodeGen/CMakeLists.txt (+1)
  • (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+36-13)
  • (modified) flang/test/Fir/convert-to-llvm.fir (+16)
diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index 646621cb01c157..f47d11875f04db 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -23,6 +23,7 @@ add_flang_library(FIRCodeGen
   FIRSupport
   MLIRComplexToLLVM
   MLIRComplexToStandard
+  MLIRGPUDialect
   MLIRMathToFuncs
   MLIRMathToLLVM
   MLIRMathToLibm
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index d038efcb2eb42c..3452a662f7a194 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -41,6 +41,7 @@
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
@@ -920,17 +921,19 @@ struct EmboxCharOpConversion : public fir::FIROpConversion<fir::EmboxCharOp> {
 };
 } // namespace
 
-/// Return the LLVMFuncOp corresponding to the standard malloc call.
+template <typename ModuleOp>
 static mlir::SymbolRefAttr
-getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
+getMallocInModule(ModuleOp mod, fir::AllocMemOp op,
+                  mlir::ConversionPatternRewriter &rewriter) {
   static constexpr char mallocName[] = "malloc";
-  auto module = op->getParentOfType<mlir::ModuleOp>();
-  if (auto mallocFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(mallocName))
+  if (auto mallocFunc =
+          mod.template lookupSymbol<mlir::LLVM::LLVMFuncOp>(mallocName))
     return mlir::SymbolRefAttr::get(mallocFunc);
-  if (auto userMalloc = module.lookupSymbol<mlir::func::FuncOp>(mallocName))
+  if (auto userMalloc =
+          mod.template lookupSymbol<mlir::func::FuncOp>(mallocName))
     return mlir::SymbolRefAttr::get(userMalloc);
-  mlir::OpBuilder moduleBuilder(
-      op->getParentOfType<mlir::ModuleOp>().getBodyRegion());
+
+  mlir::OpBuilder moduleBuilder(mod.getBodyRegion());
   auto indexType = mlir::IntegerType::get(op.getContext(), 64);
   auto mallocDecl = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
       op.getLoc(), mallocName,
@@ -940,6 +943,15 @@ getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
   return mlir::SymbolRefAttr::get(mallocDecl);
 }
 
+/// Return the LLVMFuncOp corresponding to the standard malloc call.
+static mlir::SymbolRefAttr
+getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
+  if (auto mod = op->getParentOfType<mlir::gpu::GPUModuleOp>())
+    return getMallocInModule(mod, op, rewriter);
+  auto mod = op->getParentOfType<mlir::ModuleOp>();
+  return getMallocInModule(mod, op, rewriter);
+}
+
 /// Helper function for generating the LLVM IR that computes the distance
 /// in bytes between adjacent elements pointed to by a pointer
 /// of type \p ptrTy. The result is returned as a value of \p idxTy integer
@@ -1016,18 +1028,20 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
 } // namespace
 
 /// Return the LLVMFuncOp corresponding to the standard free call.
-static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
-                                   mlir::ConversionPatternRewriter &rewriter) {
+template <typename ModuleOp>
+static mlir::SymbolRefAttr
+getFreeInModule(ModuleOp mod, fir::FreeMemOp op,
+                mlir::ConversionPatternRewriter &rewriter) {
   static constexpr char freeName[] = "free";
-  auto module = op->getParentOfType<mlir::ModuleOp>();
   // Check if free already defined in the module.
-  if (auto freeFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(freeName))
+  if (auto freeFunc =
+          mod.template lookupSymbol<mlir::LLVM::LLVMFuncOp>(freeName))
     return mlir::SymbolRefAttr::get(freeFunc);
   if (auto freeDefinedByUser =
-          module.lookupSymbol<mlir::func::FuncOp>(freeName))
+          mod.template lookupSymbol<mlir::func::FuncOp>(freeName))
     return mlir::SymbolRefAttr::get(freeDefinedByUser);
   // Create llvm declaration for free.
-  mlir::OpBuilder moduleBuilder(module.getBodyRegion());
+  mlir::OpBuilder moduleBuilder(mod.getBodyRegion());
   auto voidType = mlir::LLVM::LLVMVoidType::get(op.getContext());
   auto freeDecl = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
       rewriter.getUnknownLoc(), freeName,
@@ -1037,6 +1051,14 @@ static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
   return mlir::SymbolRefAttr::get(freeDecl);
 }
 
+static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
+                                   mlir::ConversionPatternRewriter &rewriter) {
+  if (auto mod = op->getParentOfType<mlir::gpu::GPUModuleOp>())
+    return getFreeInModule(mod, op, rewriter);
+  auto mod = op->getParentOfType<mlir::ModuleOp>();
+  return getFreeInModule(mod, op, rewriter);
+}
+
 static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) {
   unsigned result = 1;
   for (auto eleTy =
@@ -3730,6 +3752,7 @@ class FIRToLLVMLowering
     mlir::configureOpenMPToLLVMConversionLegality(target, typeConverter);
     target.addLegalDialect<mlir::omp::OpenMPDialect>();
     target.addLegalDialect<mlir::acc::OpenACCDialect>();
+    target.addLegalDialect<mlir::gpu::GPUDialect>();
 
     // required NOPs for applying a full conversion
     target.addLegalOp<mlir::ModuleOp>();
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index fa391fa6cc7a7d..4c9f965e1241a0 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -2776,3 +2776,19 @@ func.func @coordinate_array_unknown_size_1d(%arg0: !fir.ptr<!fir.array<? x i32>>
 fir.global common @c_(dense<0> : vector<4294967296xi8>) : !fir.array<4294967296xi8>
 
 // CHECK: llvm.mlir.global common @c_(dense<0> : vector<4294967296xi8>) {addr_space = 0 : i32} : !llvm.array<4294967296 x i8>
+
+// -----
+
+gpu.module @cuda_device_mod {
+  gpu.func @test_alloc_and_freemem_one() {
+    %z0 = fir.allocmem i32
+    fir.freemem %z0 : !fir.heap<i32>
+    gpu.return
+  }
+}
+
+// CHECK: gpu.module @cuda_device_mod {
+// CHECK: llvm.func @free(!llvm.ptr)
+// CHECK: llvm.func @malloc(i64) -> !llvm.ptr
+// CHECK: llvm.call @malloc
+// CHECK: lvm.call @free

@llvmbot
Copy link

llvmbot commented Nov 13, 2024

@llvm/pr-subscribers-flang-codegen

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Update getMalloc and getFree to work with the enclosing module (ModuleOp or GPUModuleOp) so we can convert fir.allocmem and fir.freemem in device code.


Full diff: https://github.com/llvm/llvm-project/pull/116112.diff

3 Files Affected:

  • (modified) flang/lib/Optimizer/CodeGen/CMakeLists.txt (+1)
  • (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+36-13)
  • (modified) flang/test/Fir/convert-to-llvm.fir (+16)
diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index 646621cb01c157..f47d11875f04db 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -23,6 +23,7 @@ add_flang_library(FIRCodeGen
   FIRSupport
   MLIRComplexToLLVM
   MLIRComplexToStandard
+  MLIRGPUDialect
   MLIRMathToFuncs
   MLIRMathToLLVM
   MLIRMathToLibm
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index d038efcb2eb42c..3452a662f7a194 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -41,6 +41,7 @@
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
@@ -920,17 +921,19 @@ struct EmboxCharOpConversion : public fir::FIROpConversion<fir::EmboxCharOp> {
 };
 } // namespace
 
-/// Return the LLVMFuncOp corresponding to the standard malloc call.
+template <typename ModuleOp>
 static mlir::SymbolRefAttr
-getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
+getMallocInModule(ModuleOp mod, fir::AllocMemOp op,
+                  mlir::ConversionPatternRewriter &rewriter) {
   static constexpr char mallocName[] = "malloc";
-  auto module = op->getParentOfType<mlir::ModuleOp>();
-  if (auto mallocFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(mallocName))
+  if (auto mallocFunc =
+          mod.template lookupSymbol<mlir::LLVM::LLVMFuncOp>(mallocName))
     return mlir::SymbolRefAttr::get(mallocFunc);
-  if (auto userMalloc = module.lookupSymbol<mlir::func::FuncOp>(mallocName))
+  if (auto userMalloc =
+          mod.template lookupSymbol<mlir::func::FuncOp>(mallocName))
     return mlir::SymbolRefAttr::get(userMalloc);
-  mlir::OpBuilder moduleBuilder(
-      op->getParentOfType<mlir::ModuleOp>().getBodyRegion());
+
+  mlir::OpBuilder moduleBuilder(mod.getBodyRegion());
   auto indexType = mlir::IntegerType::get(op.getContext(), 64);
   auto mallocDecl = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
       op.getLoc(), mallocName,
@@ -940,6 +943,15 @@ getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
   return mlir::SymbolRefAttr::get(mallocDecl);
 }
 
+/// Return the LLVMFuncOp corresponding to the standard malloc call.
+static mlir::SymbolRefAttr
+getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
+  if (auto mod = op->getParentOfType<mlir::gpu::GPUModuleOp>())
+    return getMallocInModule(mod, op, rewriter);
+  auto mod = op->getParentOfType<mlir::ModuleOp>();
+  return getMallocInModule(mod, op, rewriter);
+}
+
 /// Helper function for generating the LLVM IR that computes the distance
 /// in bytes between adjacent elements pointed to by a pointer
 /// of type \p ptrTy. The result is returned as a value of \p idxTy integer
@@ -1016,18 +1028,20 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
 } // namespace
 
 /// Return the LLVMFuncOp corresponding to the standard free call.
-static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
-                                   mlir::ConversionPatternRewriter &rewriter) {
+template <typename ModuleOp>
+static mlir::SymbolRefAttr
+getFreeInModule(ModuleOp mod, fir::FreeMemOp op,
+                mlir::ConversionPatternRewriter &rewriter) {
   static constexpr char freeName[] = "free";
-  auto module = op->getParentOfType<mlir::ModuleOp>();
   // Check if free already defined in the module.
-  if (auto freeFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(freeName))
+  if (auto freeFunc =
+          mod.template lookupSymbol<mlir::LLVM::LLVMFuncOp>(freeName))
     return mlir::SymbolRefAttr::get(freeFunc);
   if (auto freeDefinedByUser =
-          module.lookupSymbol<mlir::func::FuncOp>(freeName))
+          mod.template lookupSymbol<mlir::func::FuncOp>(freeName))
     return mlir::SymbolRefAttr::get(freeDefinedByUser);
   // Create llvm declaration for free.
-  mlir::OpBuilder moduleBuilder(module.getBodyRegion());
+  mlir::OpBuilder moduleBuilder(mod.getBodyRegion());
   auto voidType = mlir::LLVM::LLVMVoidType::get(op.getContext());
   auto freeDecl = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
       rewriter.getUnknownLoc(), freeName,
@@ -1037,6 +1051,14 @@ static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
   return mlir::SymbolRefAttr::get(freeDecl);
 }
 
+static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
+                                   mlir::ConversionPatternRewriter &rewriter) {
+  if (auto mod = op->getParentOfType<mlir::gpu::GPUModuleOp>())
+    return getFreeInModule(mod, op, rewriter);
+  auto mod = op->getParentOfType<mlir::ModuleOp>();
+  return getFreeInModule(mod, op, rewriter);
+}
+
 static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) {
   unsigned result = 1;
   for (auto eleTy =
@@ -3730,6 +3752,7 @@ class FIRToLLVMLowering
     mlir::configureOpenMPToLLVMConversionLegality(target, typeConverter);
     target.addLegalDialect<mlir::omp::OpenMPDialect>();
     target.addLegalDialect<mlir::acc::OpenACCDialect>();
+    target.addLegalDialect<mlir::gpu::GPUDialect>();
 
     // required NOPs for applying a full conversion
     target.addLegalOp<mlir::ModuleOp>();
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index fa391fa6cc7a7d..4c9f965e1241a0 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -2776,3 +2776,19 @@ func.func @coordinate_array_unknown_size_1d(%arg0: !fir.ptr<!fir.array<? x i32>>
 fir.global common @c_(dense<0> : vector<4294967296xi8>) : !fir.array<4294967296xi8>
 
 // CHECK: llvm.mlir.global common @c_(dense<0> : vector<4294967296xi8>) {addr_space = 0 : i32} : !llvm.array<4294967296 x i8>
+
+// -----
+
+gpu.module @cuda_device_mod {
+  gpu.func @test_alloc_and_freemem_one() {
+    %z0 = fir.allocmem i32
+    fir.freemem %z0 : !fir.heap<i32>
+    gpu.return
+  }
+}
+
+// CHECK: gpu.module @cuda_device_mod {
+// CHECK: llvm.func @free(!llvm.ptr)
+// CHECK: llvm.func @malloc(i64) -> !llvm.ptr
+// CHECK: llvm.call @malloc
+// CHECK: lvm.call @free

@clementval clementval merged commit e5092c3 into llvm:main Nov 14, 2024
12 checks passed
@clementval clementval deleted the cuf_alloc_free branch November 14, 2024 01:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:codegen flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants