Skip to content

Commit

Permalink
Revert "[MLIR][GPUToNVVM] support fastMath and other non-supported mathOp (llvm#99890)"
Browse files Browse the repository at this point in the history

This reverts commit f6431f0.
  • Loading branch information
bjacob authored and hanhanW committed Jul 29, 2024
1 parent 97c52c0 commit f9f4932
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 312 deletions.
34 changes: 9 additions & 25 deletions mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

namespace mlir {

/// Rewriting that replaces SourceOp with a CallOp to `f32Func`, `f64Func`, or
/// `f32ApproxFunc`, depending on the element type and the fastMathFlag of that
/// Op. The function declaration is added in case it was not added before.
/// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func`,
/// depending on the element type that the Op operates upon. The function
/// declaration is added in case it was not added before.
///
/// If the input values are of f16 type, the value is first cast to f32, the
/// function is called, and then the result is cast back.
Expand All @@ -28,22 +28,13 @@ namespace mlir {
///
/// will be transformed into
/// llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
///
/// If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers
/// to the approximate calculation function.
///
/// Also example with NVVM:
/// %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32
///
/// will be transformed into
/// llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
template <typename SourceOp>
struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
public:
explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func,
StringRef f64Func, StringRef f32ApproxFunc)
StringRef f64Func)
: ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
f64Func(f64Func), f32ApproxFunc(f32ApproxFunc) {}
f64Func(f64Func) {}

LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
Expand All @@ -65,8 +56,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
Type resultType = castedOperands.front().getType();
Type funcType = getFunctionType(resultType, castedOperands);
StringRef funcName =
getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType(),
op.getFastmath());
getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType());
if (funcName.empty())
return failure();

Expand Down Expand Up @@ -101,14 +91,9 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
}

StringRef getFunctionName(Type type, arith::FastMathFlags flag) const {
if (isa<Float32Type>(type)) {
if (((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&
!f32ApproxFunc.empty())
return f32ApproxFunc;
else
return f32Func;
}
StringRef getFunctionName(Type type) const {
if (isa<Float32Type>(type))
return f32Func;
if (isa<Float64Type>(type))
return f64Func;
return "";
Expand All @@ -129,7 +114,6 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {

const std::string f32Func;
const std::string f64Func;
const std::string f32ApproxFunc;
};

} // namespace mlir
Expand Down
65 changes: 18 additions & 47 deletions mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,11 +309,10 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
target.addIllegalDialect<gpu::GPUDialect>();
target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp,
LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op,
LLVM::PowOp, LLVM::RoundEvenOp, LLVM::RoundOp,
LLVM::SinOp, LLVM::SqrtOp>();
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
LLVM::SqrtOp>();

// TODO: Remove once we support replacing non-root ops.
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
Expand All @@ -322,11 +321,9 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns, StringRef f32Func,
StringRef f64Func,
StringRef f32ApproxFunc = "") {
StringRef f64Func) {
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
f32ApproxFunc);
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
}

void mlir::populateGpuSubgroupReduceOpLoweringPattern(
Expand Down Expand Up @@ -373,68 +370,42 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
StringAttr::get(&converter.getContext(),
NVVM::NVVMDialect::getMaxntidAttrName()));

populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
"__nv_fmod");
populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
"__nv_fabs");
populateOpPatterns<math::AcosOp>(converter, patterns, "__nv_acosf",
"__nv_acos");
populateOpPatterns<math::AcoshOp>(converter, patterns, "__nv_acoshf",
"__nv_acosh");
populateOpPatterns<math::AsinOp>(converter, patterns, "__nv_asinf",
"__nv_asin");
populateOpPatterns<math::AsinhOp>(converter, patterns, "__nv_asinhf",
"__nv_asinh");
populateOpPatterns<math::AtanOp>(converter, patterns, "__nv_atanf",
"__nv_atan");
populateOpPatterns<math::Atan2Op>(converter, patterns, "__nv_atan2f",
"__nv_atan2");
populateOpPatterns<math::AtanhOp>(converter, patterns, "__nv_atanhf",
"__nv_atanh");
populateOpPatterns<math::CbrtOp>(converter, patterns, "__nv_cbrtf",
"__nv_cbrt");
populateOpPatterns<math::CeilOp>(converter, patterns, "__nv_ceilf",
"__nv_ceil");
populateOpPatterns<math::CopySignOp>(converter, patterns, "__nv_copysignf",
"__nv_copysign");
populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos",
"__nv_fast_cosf");
populateOpPatterns<math::CoshOp>(converter, patterns, "__nv_coshf",
"__nv_cosh");
populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos");
populateOpPatterns<math::ErfOp>(converter, patterns, "__nv_erff", "__nv_erf");
populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp",
"__nv_fast_expf");
populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp");
populateOpPatterns<math::Exp2Op>(converter, patterns, "__nv_exp2f",
"__nv_exp2");
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__nv_expm1f",
"__nv_expm1");
populateOpPatterns<math::FloorOp>(converter, patterns, "__nv_floorf",
"__nv_floor");
populateOpPatterns<math::FmaOp>(converter, patterns, "__nv_fmaf", "__nv_fma");
populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log",
"__nv_fast_logf");
populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
"__nv_log10", "__nv_fast_log10f");
populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
"__nv_fmod");
populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log");
populateOpPatterns<math::Log1pOp>(converter, patterns, "__nv_log1pf",
"__nv_log1p");
populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
"__nv_log10");
populateOpPatterns<math::Log2Op>(converter, patterns, "__nv_log2f",
"__nv_log2", "__nv_fast_log2f");
populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf", "__nv_pow",
"__nv_fast_powf");
populateOpPatterns<math::RoundOp>(converter, patterns, "__nv_roundf",
"__nv_round");
populateOpPatterns<math::RoundEvenOp>(converter, patterns, "__nv_rintf",
"__nv_rint");
"__nv_log2");
populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf",
"__nv_pow");
populateOpPatterns<math::RsqrtOp>(converter, patterns, "__nv_rsqrtf",
"__nv_rsqrt");
populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin",
"__nv_fast_sinf");
populateOpPatterns<math::SinhOp>(converter, patterns, "__nv_sinhf",
"__nv_sinh");
populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin");
populateOpPatterns<math::SqrtOp>(converter, patterns, "__nv_sqrtf",
"__nv_sqrt");
populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan",
"__nv_fast_tanf");
populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
"__nv_tanh");
populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan");
}
6 changes: 2 additions & 4 deletions mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,9 @@ using namespace mlir;
template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns, StringRef f32Func,
StringRef f64Func,
StringRef f32ApproxFunc = "") {
StringRef f64Func) {
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
f32ApproxFunc);
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
}

void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
Expand Down
Loading

0 comments on commit f9f4932

Please sign in to comment.