From aca885b1edaa21849bc63045e2bac243f3d24e78 Mon Sep 17 00:00:00 2001 From: zhoujingya <104264072+zhoujingya@users.noreply.github.com> Date: Fri, 26 Apr 2024 06:12:32 +0800 Subject: [PATCH] [CIR][Lowering] Add MLIR lowering support for CIR cos operations (#565) #563 This PR add cir.cos lowering to MLIR math dialect, now it only surpport single and double float types, I add an assertation for the long double and other unimplemented types --------- Signed-off-by: zhoujing --- .../Lowering/ThroughMLIR/LowerCIRToMLIR.cpp | 26 ++++++++++++++++--- clang/test/CIR/Lowering/ThroughMLIR/cos.cir | 22 ++++++++++++++++ 2 files changed, 45 insertions(+), 3 deletions(-) create mode 100644 clang/test/CIR/Lowering/ThroughMLIR/cos.cir diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp index d413307ce7ba..005f11e9b7a7 100644 --- a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp +++ b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp @@ -24,6 +24,7 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Passes.h" @@ -68,7 +69,7 @@ struct ConvertCIRToMLIRPass registry.insert(); + mlir::scf::SCFDialect, mlir::math::MathDialect>(); } void runOnOperation() final; @@ -140,6 +141,18 @@ class CIRStoreOpLowering } }; +class CIRCosOpLowering : public mlir::OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(mlir::cir::CosOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getSrc()); + return mlir::LogicalResult::success(); + } +}; + class CIRConstantOpLowering : public mlir::OpConversionPattern { public: @@ -153,6 +166,11 @@ class CIRConstantOpLowering if (mlir::isa(op.getType())) { auto boolValue = mlir::cast(op.getValue()); value = rewriter.getIntegerAttr(ty, boolValue.getValue()); + } else if (op.getType().isa()) { + assert(ty.isF32() || ty.isF64() && "NYI"); + value = rewriter.getFloatAttr( + typeConverter->convertType(op.getType()), + op.getValue().cast().getValue()); } else { auto cirIntAttr = mlir::dyn_cast(op.getValue()); assert(cirIntAttr && "NYI non cir.int attr"); @@ -612,7 +630,8 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns, CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering, CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering, CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering, - CIRYieldOpLowering>(converter, patterns.getContext()); + CIRYieldOpLowering, CIRCosOpLowering>(converter, + patterns.getContext()); } static mlir::TypeConverter prepareTypeConverter() { @@ -666,7 +685,8 @@ void ConvertCIRToMLIRPass::runOnOperation() { target.addLegalOp(); target.addLegalDialect(); + mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect, + mlir::math::MathDialect>(); target.addIllegalDialect(); if (failed(applyPartialConversion(module, target, std::move(patterns)))) diff --git a/clang/test/CIR/Lowering/ThroughMLIR/cos.cir b/clang/test/CIR/Lowering/ThroughMLIR/cos.cir new file mode 100644 index 000000000000..05e913866202 --- /dev/null +++ b/clang/test/CIR/Lowering/ThroughMLIR/cos.cir @@ -0,0 +1,22 @@ +// RUN: cir-opt %s -cir-to-mlir -o %t.mlir +// RUN: FileCheck %s --input-file %t.mlir + +module { + cir.func @foo() { + %1 = cir.const(#cir.fp<1.0> : !cir.float) : !cir.float + %2 = cir.const(#cir.fp<1.0> : !cir.double) : !cir.double + %3 = cir.cos %1 : !cir.float + %4 = cir.cos %2 : !cir.double + cir.return + } +} + +//CHECK: module { +//CHECK: func.func @foo() { +//CHECK: %cst = arith.constant 1.000000e+00 : f32 +//CHECK: %cst_0 = arith.constant 1.000000e+00 : f64 +//CHECK: %0 = math.cos %cst : f32 +//CHECK: %1 = math.cos %cst_0 : f64 +//CHECK: return +//CHECK: } +//CHECK: }