-
Notifications
You must be signed in to change notification settings - Fork 12.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][GPUToNVVM] support fastMath and other non-supported mathOp #99890
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir Author: None (runseny) ChangesSupport fastMath and other non-supported mathOp which only require float operands and call libdevice function directly to nvvm.
Patch is 25.48 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/99890.diff 4 Files Affected:
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index ebce2d77310ae..7ce17a69d7e4d 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -31,9 +31,9 @@ template <typename SourceOp>
struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
public:
explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func,
- StringRef f64Func)
+ StringRef f64Func, StringRef f32FastFunc)
: ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
- f64Func(f64Func) {}
+ f64Func(f64Func), f32FastFunc(f32FastFunc) {}
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -55,7 +55,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
Type resultType = castedOperands.front().getType();
Type funcType = getFunctionType(resultType, castedOperands);
StringRef funcName =
- getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType());
+ getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType(), op.getFastmath());
if (funcName.empty())
return failure();
@@ -90,9 +90,13 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
}
- StringRef getFunctionName(Type type) const {
- if (isa<Float32Type>(type))
- return f32Func;
+ StringRef getFunctionName(Type type, arith::FastMathFlags flag) const {
+ if (isa<Float32Type>(type)) {
+ if (arith::FastMathFlags::fast == flag && !f32FastFunc.empty())
+ return f32FastFunc;
+ else
+ return f32Func;
+ }
if (isa<Float64Type>(type))
return f64Func;
return "";
@@ -113,8 +117,10 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
const std::string f32Func;
const std::string f64Func;
+ const std::string f32FastFunc;
};
+
} // namespace mlir
#endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index fea8a0ddc7f06..9cfad02538c98 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -309,10 +309,11 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
target.addIllegalDialect<gpu::GPUDialect>();
- target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
- LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
- LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
- LLVM::SqrtOp>();
+ target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
+ LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp, LLVM::FRemOp, LLVM::LogOp,
+ LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::RoundEvenOp,
+ LLVM::RoundOp, LLVM::SinOp, LLVM::SqrtOp>();
+
// TODO: Remove once we support replacing non-root ops.
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
@@ -321,9 +322,9 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns, StringRef f32Func,
- StringRef f64Func) {
+ StringRef f64Func, StringRef f32FastFunc = "") {
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
- patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+ patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, f32FastFunc);
}
void mlir::populateGpuSubgroupReduceOpLoweringPattern(
@@ -370,42 +371,53 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
StringAttr::get(&converter.getContext(),
NVVM::NVVMDialect::getMaxntidAttrName()));
+ populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
+ "__nv_fmod");
populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
"__nv_fabs");
+ populateOpPatterns<math::AcosOp>(converter, patterns, "__nv_acosf", "__nv_acos");
+ populateOpPatterns<math::AcoshOp>(converter, patterns, "__nv_acoshf", "__nv_acosh");
+ populateOpPatterns<math::AsinOp>(converter, patterns, "__nv_asinf", "__nv_asin");
+ populateOpPatterns<math::AsinhOp>(converter, patterns, "__nv_asinhf", "__nv_asinh");
populateOpPatterns<math::AtanOp>(converter, patterns, "__nv_atanf",
"__nv_atan");
populateOpPatterns<math::Atan2Op>(converter, patterns, "__nv_atan2f",
"__nv_atan2");
+ populateOpPatterns<math::AtanhOp>(converter, patterns, "__nv_atanhf", "__nv_atanh");
populateOpPatterns<math::CbrtOp>(converter, patterns, "__nv_cbrtf",
"__nv_cbrt");
populateOpPatterns<math::CeilOp>(converter, patterns, "__nv_ceilf",
"__nv_ceil");
- populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos");
+ populateOpPatterns<math::CopySignOp>(converter, patterns, "__nv_copysignf", "__nv_copysign");
+ populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos", "__nv_fast_cosf");
+ populateOpPatterns<math::CoshOp>(converter, patterns, "__nv_coshf", "__nv_cosh");
populateOpPatterns<math::ErfOp>(converter, patterns, "__nv_erff", "__nv_erf");
- populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp");
+ populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp", "__nv_fast_expf");
populateOpPatterns<math::Exp2Op>(converter, patterns, "__nv_exp2f",
"__nv_exp2");
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__nv_expm1f",
"__nv_expm1");
populateOpPatterns<math::FloorOp>(converter, patterns, "__nv_floorf",
"__nv_floor");
- populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
- "__nv_fmod");
- populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log");
+ populateOpPatterns<math::FmaOp>(converter, patterns, "__nv_fmaf", "__nv_fma");
+ populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log", "__nv_fast_logf");
+ populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
+ "__nv_log10", "__nv_fast_log10f");
populateOpPatterns<math::Log1pOp>(converter, patterns, "__nv_log1pf",
"__nv_log1p");
- populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
- "__nv_log10");
populateOpPatterns<math::Log2Op>(converter, patterns, "__nv_log2f",
- "__nv_log2");
+ "__nv_log2", "__nv_fast_log2f");
populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf",
- "__nv_pow");
+ "__nv_pow", "__nv_fast_powf");
+ populateOpPatterns<math::RoundOp>(converter, patterns, "__nv_roundf", "__nv_round");
+ populateOpPatterns<math::RoundEvenOp>(converter, patterns, "__nv_rintf", "__nv_rint");
populateOpPatterns<math::RsqrtOp>(converter, patterns, "__nv_rsqrtf",
"__nv_rsqrt");
- populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin");
+ populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin", "__nv_fast_sinf");
+ populateOpPatterns<math::SinhOp>(converter, patterns, "__nv_sinhf", "__nv_sinh");
populateOpPatterns<math::SqrtOp>(converter, patterns, "__nv_sqrtf",
"__nv_sqrt");
+ populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan", "__nv_fast_tanf");
populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
"__nv_tanh");
- populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan");
}
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 03c7ce5dac0d1..4344fdc142cd2 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -38,9 +38,9 @@ using namespace mlir;
template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns, StringRef f32Func,
- StringRef f64Func) {
+ StringRef f64Func, StringRef f32FastFunc = "") {
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
- patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+ patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, f32FastFunc);
}
void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index d914790c05fe0..a3b79ae2561e1 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -254,13 +254,16 @@ gpu.module @test_module_9 {
gpu.module @test_module_10 {
// CHECK: llvm.func @__nv_cosf(f32) -> f32
// CHECK: llvm.func @__nv_cos(f64) -> f64
+ // CHECK: llvm.func @__nv_fast_cosf(f32) -> f32
// CHECK-LABEL: func @gpu_cos
- func.func @gpu_cos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_cos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
%result32 = math.cos %arg_f32 : f32
// CHECK: llvm.call @__nv_cosf(%{{.*}}) : (f32) -> f32
%result64 = math.cos %arg_f64 : f64
// CHECK: llvm.call @__nv_cos(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ %result32Fast = math.cos %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_cosf(%{{.*}}) : (f32) -> f32
+ func.return %result32, %result64, %result32Fast : f32, f64, f32
}
}
@@ -268,13 +271,16 @@ gpu.module @test_module_10 {
gpu.module @test_module_11 {
// CHECK: llvm.func @__nv_expf(f32) -> f32
// CHECK: llvm.func @__nv_exp(f64) -> f64
+ // CHECK: llvm.func @__nv_fast_expf(f32) -> f32
// CHECK-LABEL: func @gpu_exp
- func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
%result32 = math.exp %arg_f32 : f32
// CHECK: llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
%result64 = math.exp %arg_f64 : f64
// CHECK: llvm.call @__nv_exp(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ %result32Fast = math.exp %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_expf(%{{.*}}) : (f32) -> f32
+ func.return %result32, %result64, %result32Fast : f32, f64, f32
}
}
@@ -297,13 +303,16 @@ gpu.module @test_module_12 {
gpu.module @test_module_13 {
// CHECK: llvm.func @__nv_logf(f32) -> f32
// CHECK: llvm.func @__nv_log(f64) -> f64
+ // CHECK: llvm.func @__nv_fast_logf(f32) -> f32
// CHECK-LABEL: func @gpu_log
- func.func @gpu_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
%result32 = math.log %arg_f32 : f32
// CHECK: llvm.call @__nv_logf(%{{.*}}) : (f32) -> f32
%result64 = math.log %arg_f64 : f64
// CHECK: llvm.call @__nv_log(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ %result32Fast = math.log %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_logf(%{{.*}}) : (f32) -> f32
+ func.return %result32, %result64, %result32Fast : f32, f64, f32
}
}
@@ -312,13 +321,16 @@ gpu.module @test_module_13 {
gpu.module @test_module_14 {
// CHECK: llvm.func @__nv_log10f(f32) -> f32
// CHECK: llvm.func @__nv_log10(f64) -> f64
+ // CHECK: llvm.func @__nv_fast_log10f(f32) -> f32
// CHECK-LABEL: func @gpu_log10
- func.func @gpu_log10(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_log10(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
%result32 = math.log10 %arg_f32 : f32
// CHECK: llvm.call @__nv_log10f(%{{.*}}) : (f32) -> f32
%result64 = math.log10 %arg_f64 : f64
// CHECK: llvm.call @__nv_log10(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ %result32Fast = math.log10 %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_log10f(%{{.*}}) : (f32) -> f32
+ func.return %result32, %result64, %result32Fast : f32, f64, f32
}
}
@@ -342,13 +354,16 @@ gpu.module @test_module_15 {
gpu.module @test_module_16 {
// CHECK: llvm.func @__nv_log2f(f32) -> f32
// CHECK: llvm.func @__nv_log2(f64) -> f64
+ // CHECK: llvm.func @__nv_fast_log2f(f32) -> f32
// CHECK-LABEL: func @gpu_log2
- func.func @gpu_log2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_log2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
%result32 = math.log2 %arg_f32 : f32
// CHECK: llvm.call @__nv_log2f(%{{.*}}) : (f32) -> f32
%result64 = math.log2 %arg_f64 : f64
// CHECK: llvm.call @__nv_log2(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ %result32Fast = math.log2 %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_log2f(%{{.*}}) : (f32) -> f32
+ func.return %result32, %result64, %result32Fast : f32, f64, f32
}
}
@@ -357,13 +372,16 @@ gpu.module @test_module_16 {
gpu.module @test_module_17 {
// CHECK: llvm.func @__nv_sinf(f32) -> f32
// CHECK: llvm.func @__nv_sin(f64) -> f64
+ // CHECK: llvm.func @__nv_fast_sinf(f32) -> f32
// CHECK-LABEL: func @gpu_sin
- func.func @gpu_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
%result32 = math.sin %arg_f32 : f32
// CHECK: llvm.call @__nv_sinf(%{{.*}}) : (f32) -> f32
%result64 = math.sin %arg_f64 : f64
// CHECK: llvm.call @__nv_sin(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ %result32Fast = math.sin %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_sinf(%{{.*}}) : (f32) -> f32
+ func.return %result32, %result64, %result32Fast : f32, f64, f32
}
}
@@ -372,8 +390,9 @@ gpu.module @test_module_17 {
gpu.module @test_module_18 {
// CHECK: llvm.func @__nv_tanf(f32) -> f32
// CHECK: llvm.func @__nv_tan(f64) -> f64
+ // CHECK: llvm.func @__nv_fast_tanf(f32) -> f32
// CHECK-LABEL: func @gpu_tan
- func.func @gpu_tan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ func.func @gpu_tan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64, f32) {
%result16 = math.tan %arg_f16 : f16
// CHECK: llvm.fpext %{{.*}} : f16 to f32
// CHECK-NEXT: llvm.call @__nv_tanf(%{{.*}}) : (f32) -> f32
@@ -382,7 +401,9 @@ gpu.module @test_module_18 {
// CHECK: llvm.call @__nv_tanf(%{{.*}}) : (f32) -> f32
%result64 = math.tan %arg_f64 : f64
// CHECK: llvm.call @__nv_tan(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+ %result32Fast = math.tan %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_tanf(%{{.*}}) : (f32) -> f32
+ func.return %result16, %result32, %result64, %result32Fast : f16, f32, f64, f32
}
}
@@ -494,13 +515,16 @@ gpu.module @test_module_24 {
// CHECK: test.symbol_scope
// CHECK: llvm.func @__nv_expf(f32) -> f32
// CHECK: llvm.func @__nv_exp(f64) -> f64
+ // CHECK: llvm.func @__nv_fast_expf(f32) -> f32
// CHECK-LABEL: func @gpu_exp
- func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
%result32 = math.exp %arg_f32 : f32
// CHECK: llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
%result64 = math.exp %arg_f64 : f64
// CHECK: llvm.call @__nv_exp(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ %result32Fast = math.exp %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_expf(%{{.*}}) : (f32) -> f32
+ func.return %result32, %result64, %result32Fast : f32, f64, f32
}
"test.finish" () : () -> ()
}) : () -> ()
@@ -526,13 +550,16 @@ gpu.module @test_module_25 {
gpu.module @test_module_26 {
// CHECK: llvm.func @__nv_powf(f32, f32) -> f32
// CHECK: llvm.func @__nv_pow(f64, f64) -> f64
+ // CHECK: llvm.func @__nv_fast_powf(f32, f32) -> f32
// CHECK-LABEL: func @gpu_pow
- func.func @gpu_pow(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_pow(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
%result32 = math.powf %arg_f32, %arg_f32 : f32
// CHECK: llvm.call @__nv_powf(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
%result64 = math.powf %arg_f64, %arg_f64 : f64
// CHECK: llvm.call @__nv_pow(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
- func.return %result32, %result64 : f32, f64
+ %result32Fast = math.powf %arg_f32, %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_powf(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+ func.return %result32, %result64, %result32Fast : f32, f64, f32
}
}
@@ -701,6 +728,179 @@ gpu.module @test_module_34 {
}
}
+gpu.module @test_module_35 {
+ // CHECK: llvm.func @__nv_acosf(f32) -> f32
+ // CHECK: llvm.func @__nv_acos(f64) -> f64
+ // CHECK-LABEL: func @gpu_acos
+ func.func @gpu_acos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.acos %arg_f32 : f32
+ // CHECK: llvm.call @__nv_acosf(%{{.*}}) : (f32) -> f32
+ %result64 = math.acos %arg_f64 : f64
+ // CHECK: llvm.call @__nv_acos(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+gpu.module @test_module_36 {
+ // CHECK: llvm.func @__nv_acoshf(f32) -> f32
+ // CHECK: llvm.func @__nv_acosh(f64) -> f64
+ // CHECK-LABEL: func @gpu_acosh
+ func.func @gpu_acosh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.acosh %arg_f32 : f32
+ // CHECK: llvm.call @__nv_acoshf(%{{.*}}) : (f32) -> f32
+ %result64 = math.acosh %arg_f64 : f64
+ // CHECK: llvm.call @__nv_acosh(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+gpu.module @test_module_37 {
+ // CHECK: llvm.func @__nv_asinf(f32) -> f32
+ // CHECK: llvm.func @__nv_asin(f64) -> f64
+ // CHECK-LABEL: func @gpu_asin
+ func.func @gpu_asin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.asin %arg_f32 : f32
+ // CHECK: llvm.call @__nv_asinf(%{{.*}}) : (f32) -> f32
+ %result64 = math.asin %arg_f64 : f64
+ // CHECK: llvm.call @__nv_asin(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+gpu.module @test_module_38 {
+ // CHECK: llvm.func @__nv_asinhf(f32) -> f32
+ // CHECK: llvm.func @__nv_asinh(f64) -> f64
+ // CHECK-LABEL: func @gpu_asinh
+ func.func @gpu_asinh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.asinh %arg_f32 : f32
+ // CHECK: llvm.call @__nv_asinhf(%{{.*}}) : (f32) -> f32
+ %result64 = math.asinh %arg_f64 : f64
+ // CHECK: llvm.call @__nv_asinh(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+gpu.module @test_module_39 {
+ // CHECK: llvm.func @__nv_atanhf(f32) -> f32
+ // CHECK: llvm.func @__nv_atanh(f64) -> f64
+ // CHECK-LABEL: func @gpu_atanh
+ func.func @gpu_atanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64)
+ -> (f16, f32, f64) {
+ %result16 = math.atanh %arg_f16 : f16
+ // CHECK: llvm.fpext %{{.*}} : f16 to f32
+ // CHECK-NEXT: llvm.call @__nv_atanhf(%{{.*}}) : (f32) -> f3...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
4eff225
to
586d3fd
Compare
: ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func), | ||
f64Func(f64Func) {} | ||
f64Func(f64Func), f32FastFunc(f32FastFunc) {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you update the class documentation please?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, updated, the diff please see : link
return f32Func; | ||
StringRef getFunctionName(Type type, arith::FastMathFlags flag) const { | ||
if (isa<Float32Type>(type)) { | ||
if (arith::FastMathFlags::fast == flag && !f32FastFunc.empty()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (arith::FastMathFlags::fast == flag && !f32FastFunc.empty()) | |
if ((arith::FastMathFlags::afn & flag) && !f32FastFunc.empty()) |
We only need afn I believe (please add such test)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I still use arith::FastMathFlags::afn == flag
instead of arith::FastMathFlags::afn & flag
. The reason is that the binary of afn
is 1000000
(decimal is 64) , if flag is fast
whose binary is 1111111
(decimal is 127) , the result of &
is still true.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a0c136b
to
6fcf649
Compare
6fcf649
to
d5cb1ee
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG, thanks!
Hi, @joker-eph Mehdi, could you help merge this mr? I don't have permission to merge it. Thanks! |
…9890) Summary: Support fastMath and other non-supported mathOp which only require float operands and call libdevice function directly to nvvm. 1. lowering mathOp with fastMath attribute to correct libdevice intrinsic. 2. some mathOp in math dialect has been lowered to libdevice now, but it doesn't cover all mathOp. so this mr lowers all the remaining mathOp which only require float operands. Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60250617
…thOp (llvm#99890)" This reverts commit f6431f0.
Advance to llvm/llvm-project@f4be6812 New local patch: * Reverting llvm/llvm-project@f6431f0c (llvm/llvm-project#99890). Discussion: https://discord.com/channels/689900678990135345/1266421307285966859 * This caused new compile failures in our CUDA flow. We suspect our usage of upstream passes/patterns need some updates in this repo. Reverting for now to give ourselves some time to investigate. Existing local patches carried over: * Still carrying a revert of llvm/llvm-project@fa06668 * Still carrying a revert of llvm/llvm-project@bbd4af5 Dropped local patches: * No longer carrying a local patch of llvm/llvm-project@4f80508 Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
…thOp (llvm#99890)" This reverts commit f6431f0.
…thOp (llvm#99890)" This reverts commit f6431f0.
…thOp (llvm#99890)" This reverts commit f6431f0.
…thOp (llvm#99890)" This reverts commit f6431f0.
…thOp (llvm#99890)" This reverts commit f6431f0.
…thOp (llvm#99890)" This reverts commit f6431f0.
Support fastMath and other non-supported mathOp which only require float operands and call libdevice function directly to nvvm.