Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][GPUToNVVM] support fastMath and other non-supported mathOp #99890

Merged
merged 1 commit into from
Jul 25, 2024

Conversation

runseny
Copy link
Contributor

@runseny runseny commented Jul 22, 2024

Support fastMath and other non-supported mathOp which only require float operands and call libdevice function directly to nvvm.

  1. lowering mathOp with fastMath attribute to correct libdevice intrinsic.
  2. some mathOp in math dialect has been lowered to libdevice now, but it doesn't cover all mathOp. so this mr lowers all the remaining mathOp which only require float operands.

@llvmbot
Copy link
Member

llvmbot commented Jul 22, 2024

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: None (runseny)

Changes

Support fastMath and other non-supported mathOp which only require float operands and call libdevice function directly to nvvm.

  1. lowering mathOp with fastMath attribute to correct libdevice intrinsic.
  2. some mathOp in math dialect has been lowered to libdevice now, but it doesn't cover all mathOp. so this mr lowers all the remaining mathOp which only require float operands.

Patch is 25.48 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/99890.diff

4 Files Affected:

  • (modified) mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h (+12-6)
  • (modified) mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp (+29-17)
  • (modified) mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp (+2-2)
  • (modified) mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir (+221-21)
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index ebce2d77310ae..7ce17a69d7e4d 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -31,9 +31,9 @@ template <typename SourceOp>
 struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
 public:
   explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func,
-                                StringRef f64Func)
+                                StringRef f64Func, StringRef f32FastFunc)
       : ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
-        f64Func(f64Func) {}
+        f64Func(f64Func), f32FastFunc(f32FastFunc) {}
 
   LogicalResult
   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -55,7 +55,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
     Type resultType = castedOperands.front().getType();
     Type funcType = getFunctionType(resultType, castedOperands);
     StringRef funcName =
-        getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType());
+        getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType(), op.getFastmath());
     if (funcName.empty())
       return failure();
 
@@ -90,9 +90,13 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
     return LLVM::LLVMFunctionType::get(resultType, operandTypes);
   }
 
-  StringRef getFunctionName(Type type) const {
-    if (isa<Float32Type>(type))
-      return f32Func;
+  StringRef getFunctionName(Type type, arith::FastMathFlags flag) const {
+    if (isa<Float32Type>(type)) {
+      if (arith::FastMathFlags::fast == flag && !f32FastFunc.empty())
+        return f32FastFunc;
+      else
+        return f32Func;
+    }
     if (isa<Float64Type>(type))
       return f64Func;
     return "";
@@ -113,8 +117,10 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
 
   const std::string f32Func;
   const std::string f64Func;
+  const std::string f32FastFunc;
 };
 
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index fea8a0ddc7f06..9cfad02538c98 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -309,10 +309,11 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
   target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
   target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
   target.addIllegalDialect<gpu::GPUDialect>();
-  target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
-                      LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
-                      LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
-                      LLVM::SqrtOp>();
+  target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
+                      LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp, LLVM::FRemOp, LLVM::LogOp,
+                      LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::RoundEvenOp,
+                      LLVM::RoundOp, LLVM::SinOp, LLVM::SqrtOp>();
+
 
   // TODO: Remove once we support replacing non-root ops.
   target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
@@ -321,9 +322,9 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
 template <typename OpTy>
 static void populateOpPatterns(LLVMTypeConverter &converter,
                                RewritePatternSet &patterns, StringRef f32Func,
-                               StringRef f64Func) {
+                               StringRef f64Func, StringRef f32FastFunc = "") {
   patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
-  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, f32FastFunc);
 }
 
 void mlir::populateGpuSubgroupReduceOpLoweringPattern(
@@ -370,42 +371,53 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
       StringAttr::get(&converter.getContext(),
                       NVVM::NVVMDialect::getMaxntidAttrName()));
 
+  populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
+                                    "__nv_fmod");
   populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
                                    "__nv_fabs");
+  populateOpPatterns<math::AcosOp>(converter, patterns, "__nv_acosf", "__nv_acos");
+  populateOpPatterns<math::AcoshOp>(converter, patterns, "__nv_acoshf", "__nv_acosh");
+  populateOpPatterns<math::AsinOp>(converter, patterns, "__nv_asinf", "__nv_asin");
+  populateOpPatterns<math::AsinhOp>(converter, patterns, "__nv_asinhf", "__nv_asinh");
   populateOpPatterns<math::AtanOp>(converter, patterns, "__nv_atanf",
                                    "__nv_atan");
   populateOpPatterns<math::Atan2Op>(converter, patterns, "__nv_atan2f",
                                     "__nv_atan2");
+  populateOpPatterns<math::AtanhOp>(converter, patterns, "__nv_atanhf", "__nv_atanh");
   populateOpPatterns<math::CbrtOp>(converter, patterns, "__nv_cbrtf",
                                    "__nv_cbrt");
   populateOpPatterns<math::CeilOp>(converter, patterns, "__nv_ceilf",
                                    "__nv_ceil");
-  populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos");
+  populateOpPatterns<math::CopySignOp>(converter, patterns, "__nv_copysignf", "__nv_copysign");
+  populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos", "__nv_fast_cosf");
+  populateOpPatterns<math::CoshOp>(converter, patterns, "__nv_coshf", "__nv_cosh");
   populateOpPatterns<math::ErfOp>(converter, patterns, "__nv_erff", "__nv_erf");
-  populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp");
+  populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp", "__nv_fast_expf");
   populateOpPatterns<math::Exp2Op>(converter, patterns, "__nv_exp2f",
                                    "__nv_exp2");
   populateOpPatterns<math::ExpM1Op>(converter, patterns, "__nv_expm1f",
                                     "__nv_expm1");
   populateOpPatterns<math::FloorOp>(converter, patterns, "__nv_floorf",
                                     "__nv_floor");
-  populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
-                                    "__nv_fmod");
-  populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log");
+  populateOpPatterns<math::FmaOp>(converter, patterns, "__nv_fmaf", "__nv_fma");
+  populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log", "__nv_fast_logf");
+  populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
+                                    "__nv_log10", "__nv_fast_log10f");
   populateOpPatterns<math::Log1pOp>(converter, patterns, "__nv_log1pf",
                                     "__nv_log1p");
-  populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
-                                    "__nv_log10");
   populateOpPatterns<math::Log2Op>(converter, patterns, "__nv_log2f",
-                                   "__nv_log2");
+                                   "__nv_log2", "__nv_fast_log2f");
   populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf",
-                                   "__nv_pow");
+                                   "__nv_pow", "__nv_fast_powf");
+  populateOpPatterns<math::RoundOp>(converter, patterns, "__nv_roundf", "__nv_round");
+  populateOpPatterns<math::RoundEvenOp>(converter, patterns, "__nv_rintf", "__nv_rint");
   populateOpPatterns<math::RsqrtOp>(converter, patterns, "__nv_rsqrtf",
                                     "__nv_rsqrt");
-  populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin");
+  populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin", "__nv_fast_sinf");
+  populateOpPatterns<math::SinhOp>(converter, patterns, "__nv_sinhf", "__nv_sinh");
   populateOpPatterns<math::SqrtOp>(converter, patterns, "__nv_sqrtf",
                                    "__nv_sqrt");
+  populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan", "__nv_fast_tanf");
   populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
                                    "__nv_tanh");
-  populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan");
 }
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 03c7ce5dac0d1..4344fdc142cd2 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -38,9 +38,9 @@ using namespace mlir;
 template <typename OpTy>
 static void populateOpPatterns(LLVMTypeConverter &converter,
                                RewritePatternSet &patterns, StringRef f32Func,
-                               StringRef f64Func) {
+                               StringRef f64Func, StringRef f32FastFunc = "") {
   patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
-  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, f32FastFunc);
 }
 
 void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index d914790c05fe0..a3b79ae2561e1 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -254,13 +254,16 @@ gpu.module @test_module_9 {
 gpu.module @test_module_10 {
   // CHECK: llvm.func @__nv_cosf(f32) -> f32
   // CHECK: llvm.func @__nv_cos(f64) -> f64
+  // CHECK: llvm.func @__nv_fast_cosf(f32) -> f32
   // CHECK-LABEL: func @gpu_cos
-  func.func @gpu_cos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_cos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
     %result32 = math.cos %arg_f32 : f32
     // CHECK: llvm.call @__nv_cosf(%{{.*}}) : (f32) -> f32
     %result64 = math.cos %arg_f64 : f64
     // CHECK: llvm.call @__nv_cos(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    %result32Fast = math.cos %arg_f32 fastmath<fast> : f32
+    // CHECK: llvm.call @__nv_fast_cosf(%{{.*}}) : (f32) -> f32
+    func.return %result32, %result64, %result32Fast : f32, f64, f32
   }
 }
 
@@ -268,13 +271,16 @@ gpu.module @test_module_10 {
 gpu.module @test_module_11 {
   // CHECK: llvm.func @__nv_expf(f32) -> f32
   // CHECK: llvm.func @__nv_exp(f64) -> f64
+  // CHECK: llvm.func @__nv_fast_expf(f32) -> f32
   // CHECK-LABEL: func @gpu_exp
-  func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
     %result32 = math.exp %arg_f32 : f32
     // CHECK: llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
     %result64 = math.exp %arg_f64 : f64
     // CHECK: llvm.call @__nv_exp(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    %result32Fast = math.exp %arg_f32 fastmath<fast> : f32
+    // CHECK: llvm.call @__nv_fast_expf(%{{.*}}) : (f32) -> f32
+    func.return %result32, %result64, %result32Fast : f32, f64, f32
   }
 }
 
@@ -297,13 +303,16 @@ gpu.module @test_module_12 {
 gpu.module @test_module_13 {
   // CHECK: llvm.func @__nv_logf(f32) -> f32
   // CHECK: llvm.func @__nv_log(f64) -> f64
+  // CHECK: llvm.func @__nv_fast_logf(f32) -> f32
   // CHECK-LABEL: func @gpu_log
-  func.func @gpu_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
     %result32 = math.log %arg_f32 : f32
     // CHECK: llvm.call @__nv_logf(%{{.*}}) : (f32) -> f32
     %result64 = math.log %arg_f64 : f64
     // CHECK: llvm.call @__nv_log(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    %result32Fast = math.log %arg_f32 fastmath<fast> : f32
+    // CHECK: llvm.call @__nv_fast_logf(%{{.*}}) : (f32) -> f32
+    func.return %result32, %result64, %result32Fast : f32, f64, f32
   }
 }
 
@@ -312,13 +321,16 @@ gpu.module @test_module_13 {
 gpu.module @test_module_14 {
   // CHECK: llvm.func @__nv_log10f(f32) -> f32
   // CHECK: llvm.func @__nv_log10(f64) -> f64
+  // CHECK: llvm.func @__nv_fast_log10f(f32) -> f32
   // CHECK-LABEL: func @gpu_log10
-  func.func @gpu_log10(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_log10(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
     %result32 = math.log10 %arg_f32 : f32
     // CHECK: llvm.call @__nv_log10f(%{{.*}}) : (f32) -> f32
     %result64 = math.log10 %arg_f64 : f64
     // CHECK: llvm.call @__nv_log10(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    %result32Fast = math.log10 %arg_f32 fastmath<fast> : f32
+    // CHECK: llvm.call @__nv_fast_log10f(%{{.*}}) : (f32) -> f32
+    func.return %result32, %result64, %result32Fast : f32, f64, f32
   }
 }
 
@@ -342,13 +354,16 @@ gpu.module @test_module_15 {
 gpu.module @test_module_16 {
   // CHECK: llvm.func @__nv_log2f(f32) -> f32
   // CHECK: llvm.func @__nv_log2(f64) -> f64
+  // CHECK: llvm.func @__nv_fast_log2f(f32) -> f32
   // CHECK-LABEL: func @gpu_log2
-  func.func @gpu_log2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_log2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
     %result32 = math.log2 %arg_f32 : f32
     // CHECK: llvm.call @__nv_log2f(%{{.*}}) : (f32) -> f32
     %result64 = math.log2 %arg_f64 : f64
     // CHECK: llvm.call @__nv_log2(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    %result32Fast = math.log2 %arg_f32 fastmath<fast> : f32
+    // CHECK: llvm.call @__nv_fast_log2f(%{{.*}}) : (f32) -> f32
+    func.return %result32, %result64, %result32Fast : f32, f64, f32
   }
 }
 
@@ -357,13 +372,16 @@ gpu.module @test_module_16 {
 gpu.module @test_module_17 {
   // CHECK: llvm.func @__nv_sinf(f32) -> f32
   // CHECK: llvm.func @__nv_sin(f64) -> f64
+  // CHECK: llvm.func @__nv_fast_sinf(f32) -> f32
   // CHECK-LABEL: func @gpu_sin
-  func.func @gpu_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
     %result32 = math.sin %arg_f32 : f32
     // CHECK: llvm.call @__nv_sinf(%{{.*}}) : (f32) -> f32
     %result64 = math.sin %arg_f64 : f64
     // CHECK: llvm.call @__nv_sin(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    %result32Fast = math.sin %arg_f32 fastmath<fast> : f32
+    // CHECK: llvm.call @__nv_fast_sinf(%{{.*}}) : (f32) -> f32
+    func.return %result32, %result64, %result32Fast : f32, f64, f32
   }
 }
 
@@ -372,8 +390,9 @@ gpu.module @test_module_17 {
 gpu.module @test_module_18 {
   // CHECK: llvm.func @__nv_tanf(f32) -> f32
   // CHECK: llvm.func @__nv_tan(f64) -> f64
+  // CHECK: llvm.func @__nv_fast_tanf(f32) -> f32
   // CHECK-LABEL: func @gpu_tan
-  func.func @gpu_tan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+  func.func @gpu_tan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64, f32) {
     %result16 = math.tan %arg_f16 : f16
     // CHECK: llvm.fpext %{{.*}} : f16 to f32
     // CHECK-NEXT: llvm.call @__nv_tanf(%{{.*}}) : (f32) -> f32
@@ -382,7 +401,9 @@ gpu.module @test_module_18 {
     // CHECK: llvm.call @__nv_tanf(%{{.*}}) : (f32) -> f32
     %result64 = math.tan %arg_f64 : f64
     // CHECK: llvm.call @__nv_tan(%{{.*}}) : (f64) -> f64
-    func.return %result16, %result32, %result64 : f16, f32, f64
+    %result32Fast = math.tan %arg_f32 fastmath<fast> : f32
+    // CHECK: llvm.call @__nv_fast_tanf(%{{.*}}) : (f32) -> f32
+    func.return %result16, %result32, %result64, %result32Fast : f16, f32, f64, f32
   }
 }
 
@@ -494,13 +515,16 @@ gpu.module @test_module_24 {
   // CHECK: test.symbol_scope
   // CHECK: llvm.func @__nv_expf(f32) -> f32
   // CHECK: llvm.func @__nv_exp(f64) -> f64
+  // CHECK: llvm.func @__nv_fast_expf(f32) -> f32
   // CHECK-LABEL: func @gpu_exp
-    func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
       %result32 = math.exp %arg_f32 : f32
       // CHECK: llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
       %result64 = math.exp %arg_f64 : f64
       // CHECK: llvm.call @__nv_exp(%{{.*}}) : (f64) -> f64
-      func.return %result32, %result64 : f32, f64
+      %result32Fast = math.exp %arg_f32 fastmath<fast> : f32
+      // CHECK: llvm.call @__nv_fast_expf(%{{.*}}) : (f32) -> f32
+      func.return %result32, %result64, %result32Fast : f32, f64, f32
     }
     "test.finish" () : () -> ()
   }) : () -> ()
@@ -526,13 +550,16 @@ gpu.module @test_module_25 {
 gpu.module @test_module_26 {
   // CHECK: llvm.func @__nv_powf(f32, f32) -> f32
   // CHECK: llvm.func @__nv_pow(f64, f64) -> f64
+  // CHECK: llvm.func @__nv_fast_powf(f32, f32) -> f32
   // CHECK-LABEL: func @gpu_pow
-  func.func @gpu_pow(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_pow(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
     %result32 = math.powf %arg_f32, %arg_f32 : f32
     // CHECK: llvm.call @__nv_powf(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
     %result64 = math.powf %arg_f64, %arg_f64 : f64
     // CHECK: llvm.call @__nv_pow(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    %result32Fast = math.powf %arg_f32, %arg_f32 fastmath<fast> : f32
+    // CHECK: llvm.call @__nv_fast_powf(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+    func.return %result32, %result64, %result32Fast : f32, f64, f32
   }
 }
 
@@ -701,6 +728,179 @@ gpu.module @test_module_34 {
   }
 }
 
+gpu.module @test_module_35 {
+  // CHECK: llvm.func @__nv_acosf(f32) -> f32
+  // CHECK: llvm.func @__nv_acos(f64) -> f64
+  // CHECK-LABEL: func @gpu_acos
+  func.func @gpu_acos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.acos %arg_f32 : f32
+    // CHECK: llvm.call @__nv_acosf(%{{.*}}) : (f32) -> f32
+    %result64 = math.acos %arg_f64 : f64
+    // CHECK: llvm.call @__nv_acos(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+gpu.module @test_module_36 {
+  // CHECK: llvm.func @__nv_acoshf(f32) -> f32
+  // CHECK: llvm.func @__nv_acosh(f64) -> f64
+  // CHECK-LABEL: func @gpu_acosh
+  func.func @gpu_acosh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.acosh %arg_f32 : f32
+    // CHECK: llvm.call @__nv_acoshf(%{{.*}}) : (f32) -> f32
+    %result64 = math.acosh %arg_f64 : f64
+    // CHECK: llvm.call @__nv_acosh(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+gpu.module @test_module_37 {
+  // CHECK: llvm.func @__nv_asinf(f32) -> f32
+  // CHECK: llvm.func @__nv_asin(f64) -> f64
+  // CHECK-LABEL: func @gpu_asin
+  func.func @gpu_asin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.asin %arg_f32 : f32
+    // CHECK: llvm.call @__nv_asinf(%{{.*}}) : (f32) -> f32
+    %result64 = math.asin %arg_f64 : f64
+    // CHECK: llvm.call @__nv_asin(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+gpu.module @test_module_38 {
+  // CHECK: llvm.func @__nv_asinhf(f32) -> f32
+  // CHECK: llvm.func @__nv_asinh(f64) -> f64
+  // CHECK-LABEL: func @gpu_asinh
+  func.func @gpu_asinh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.asinh %arg_f32 : f32
+    // CHECK: llvm.call @__nv_asinhf(%{{.*}}) : (f32) -> f32
+    %result64 = math.asinh %arg_f64 : f64
+    // CHECK: llvm.call @__nv_asinh(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+gpu.module @test_module_39 {
+  // CHECK: llvm.func @__nv_atanhf(f32) -> f32
+  // CHECK: llvm.func @__nv_atanh(f64) -> f64
+  // CHECK-LABEL: func @gpu_atanh
+  func.func @gpu_atanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64)
+      -> (f16, f32, f64) {
+    %result16 = math.atanh %arg_f16 : f16
+    // CHECK: llvm.fpext %{{.*}} : f16 to f32
+    // CHECK-NEXT: llvm.call @__nv_atanhf(%{{.*}}) : (f32) -> f3...
[truncated]

Copy link

github-actions bot commented Jul 22, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@runseny runseny force-pushed the gpu2nnvm_more_math branch 2 times, most recently from 4eff225 to 586d3fd Compare July 22, 2024 16:13
: ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
f64Func(f64Func) {}
f64Func(f64Func), f32FastFunc(f32FastFunc) {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you update the class documentation please?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, updated, the diff please see : link

return f32Func;
StringRef getFunctionName(Type type, arith::FastMathFlags flag) const {
if (isa<Float32Type>(type)) {
if (arith::FastMathFlags::fast == flag && !f32FastFunc.empty())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (arith::FastMathFlags::fast == flag && !f32FastFunc.empty())
if ((arith::FastMathFlags::afn & flag) && !f32FastFunc.empty())

We only need afn I believe (please add such test)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, modified, the code diff please see : link
the related test case please see : link

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still use arith::FastMathFlags::afn == flag instead of arith::FastMathFlags::afn & flag. The reason is that the binary of afn is 1000000 (decimal is 64) , if flag is fast whose binary is 1111111 (decimal is 127) , the result of & is still true.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps i get your point, fast should also imply afn, so & is necessary. Modified to &, please see : link

Related test case please see : link and link

please take a look again :)

@runseny runseny force-pushed the gpu2nnvm_more_math branch 2 times, most recently from a0c136b to 6fcf649 Compare July 23, 2024 14:48

Verified

This commit was signed with the committer’s verified signature.
pentschev Peter Andreas Entschev
@runseny runseny force-pushed the gpu2nnvm_more_math branch from 6fcf649 to d5cb1ee Compare July 23, 2024 15:11
Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG, thanks!

@runseny
Copy link
Contributor Author

runseny commented Jul 25, 2024

Hi, @joker-eph Mehdi, could you help merge this mr? I don't have permission to merge it. Thanks!

@joker-eph joker-eph merged commit f6431f0 into llvm:main Jul 25, 2024
7 checks passed
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024

Verified

This commit was signed with the committer’s verified signature.
pentschev Peter Andreas Entschev
…9890)

Summary:
Support fastMath and other non-supported mathOp which only require float
operands and call libdevice function directly to nvvm.

1. lowering mathOp with fastMath attribute to correct libdevice
intrinsic.
2. some mathOp in math dialect has been lowered to libdevice now, but it
doesn't cover all mathOp. so this mr lowers all the remaining mathOp
which only require float operands.

Test Plan: 

Reviewers: 

Subscribers: 

Tasks: 

Tags: 


Differential Revision: https://phabricator.intern.facebook.com/D60250617
bjacob added a commit to iree-org/llvm-project that referenced this pull request Jul 28, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
…thOp (llvm#99890)"

This reverts commit f6431f0.
bjacob added a commit to iree-org/iree that referenced this pull request Jul 29, 2024
Advance to llvm/llvm-project@f4be6812

New local patch:
* Reverting llvm/llvm-project@f6431f0c
(llvm/llvm-project#99890). Discussion:
https://discord.com/channels/689900678990135345/1266421307285966859
* This caused new compile failures in our CUDA flow. We suspect our
usage of upstream passes/patterns need some updates in this repo.
Reverting for now to give ourselves some time to investigate.

Existing local patches carried over:
* Still carrying a revert of
llvm/llvm-project@fa06668
* Still carrying a revert of
llvm/llvm-project@bbd4af5

Dropped local patches:
* No longer carrying a local patch of
llvm/llvm-project@4f80508

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
bjacob added a commit to iree-org/llvm-project that referenced this pull request Jul 29, 2024
bjacob added a commit to iree-org/llvm-project that referenced this pull request Jul 29, 2024
hanhanW pushed a commit to iree-org/llvm-project that referenced this pull request Jul 29, 2024
bjacob added a commit to iree-org/llvm-project that referenced this pull request Jul 30, 2024
hanhanW pushed a commit to iree-org/llvm-project that referenced this pull request Jul 30, 2024
bjacob added a commit to iree-org/llvm-project that referenced this pull request Jul 31, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants