diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index 29926719129dc..fc3e1fc4f9d0c 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -26,6 +26,7 @@ #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/MathToROCDL/MathToROCDL.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" @@ -290,6 +291,7 @@ struct LowerGpuOpsToROCDLOpsPass populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns, *maybeChipset); populateVectorToLLVMConversionPatterns(converter, llvmPatterns); + populateMathToLLVMConversionPatterns(converter, llvmPatterns); cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns); populateFuncToLLVMConversionPatterns(converter, llvmPatterns); populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns); @@ -332,7 +334,11 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) { target.addIllegalOp(); - + // These ops are legal for f16 and f32 types. + target.addDynamicallyLegalOp([](Operation *op) { + return any_of(op->getOperandTypes(), + llvm::IsaPred); + }); // TODO: Remove once we support replacing non-root ops. 
target.addLegalOp(); } diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir index b6fb08522ae1f..c0b62b46dcf2c 100644 --- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir +++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir @@ -131,6 +131,68 @@ gpu.module @test_module { // ----- +gpu.module @test_module { + // CHECK-LABEL: func @gpu_sqrt + func.func @gpu_sqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.sqrt %arg_f16 : f16 + // CHECK: llvm.intr.sqrt(%{{.*}}) : (f16) -> f16 + %result32 = math.sqrt %arg_f32 : f32 + // CHECK: llvm.intr.sqrt(%{{.*}}) : (f32) -> f32 + %result64 = math.sqrt %arg_f64 : f64 + // CHECK: llvm.intr.sqrt(%{{.*}}) : (f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 + } +} + +// ----- + +gpu.module @test_module { + // CHECK-LABEL: func @gpu_fabs + func.func @gpu_fabs(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.absf %arg_f16 : f16 + // CHECK: llvm.intr.fabs(%{{.*}}) : (f16) -> f16 + %result32 = math.absf %arg_f32 : f32 + // CHECK: llvm.intr.fabs(%{{.*}}) : (f32) -> f32 + %result64 = math.absf %arg_f64 : f64 + // CHECK: llvm.intr.fabs(%{{.*}}) : (f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 + } +} + +// ----- + +gpu.module @test_module { + // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64 + // CHECK-LABEL: func @gpu_exp + func.func @gpu_exp(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.exp %arg_f16 : f16 + // CHECK: llvm.intr.exp(%{{.*}}) : (f16) -> f16 + %result32 = math.exp %arg_f32 : f32 + // CHECK: llvm.intr.exp(%{{.*}}) : (f32) -> f32 + %result64 = math.exp %arg_f64 : f64 + // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 + } +} + +// ----- + +gpu.module @test_module { + // CHECK: llvm.func @__ocml_log_f64(f64) -> f64 + // 
CHECK-LABEL: func @gpu_log + func.func @gpu_log(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.log %arg_f16 : f16 + // CHECK: llvm.intr.log(%{{.*}}) : (f16) -> f16 + %result32 = math.log %arg_f32 : f32 + // CHECK: llvm.intr.log(%{{.*}}) : (f32) -> f32 + %result64 = math.log %arg_f64 : f64 + // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 + } +} + +// ----- + gpu.module @test_module { // CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32 // CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64 diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index c931898ed98e3..4124897722d23 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6004,6 +6004,7 @@ cc_library( ":LLVMCommonConversion", ":LLVMDialect", ":MathDialect", + ":MathToLLVM", ":MathToROCDL", ":MemRefDialect", ":MemRefToLLVM",