Skip to content

Commit ae8f207

Browse files
zhoujingyalanza
authored andcommitted
[CIR][Lowering] Add MLIR lowering support for CIR cos operations (llvm#565)
llvm#563 This PR add cir.cos lowering to MLIR math dialect, now it only surpport single and double float types, I add an assertation for the long double and other unimplemented types --------- Signed-off-by: zhoujing <jing.zhou@terapines.com>
1 parent 3a9a13d commit ae8f207

File tree

2 files changed

+45
-3
lines changed

2 files changed

+45
-3
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
2525
#include "mlir/Dialect/Func/IR/FuncOps.h"
2626
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
27+
#include "mlir/Dialect/Math/IR/Math.h"
2728
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2829
#include "mlir/Dialect/SCF/IR/SCF.h"
2930
#include "mlir/Dialect/SCF/Transforms/Passes.h"
@@ -68,7 +69,7 @@ struct ConvertCIRToMLIRPass
6869
registry.insert<mlir::BuiltinDialect, mlir::func::FuncDialect,
6970
mlir::affine::AffineDialect, mlir::memref::MemRefDialect,
7071
mlir::arith::ArithDialect, mlir::cf::ControlFlowDialect,
71-
mlir::scf::SCFDialect>();
72+
mlir::scf::SCFDialect, mlir::math::MathDialect>();
7273
}
7374
void runOnOperation() final;
7475

@@ -140,6 +141,18 @@ class CIRStoreOpLowering
140141
}
141142
};
142143

144+
class CIRCosOpLowering : public mlir::OpConversionPattern<mlir::cir::CosOp> {
145+
public:
146+
using OpConversionPattern<mlir::cir::CosOp>::OpConversionPattern;
147+
148+
mlir::LogicalResult
149+
matchAndRewrite(mlir::cir::CosOp op, OpAdaptor adaptor,
150+
mlir::ConversionPatternRewriter &rewriter) const override {
151+
rewriter.replaceOpWithNewOp<mlir::math::CosOp>(op, adaptor.getSrc());
152+
return mlir::LogicalResult::success();
153+
}
154+
};
155+
143156
class CIRConstantOpLowering
144157
: public mlir::OpConversionPattern<mlir::cir::ConstantOp> {
145158
public:
@@ -153,6 +166,11 @@ class CIRConstantOpLowering
153166
if (mlir::isa<mlir::cir::BoolType>(op.getType())) {
154167
auto boolValue = mlir::cast<mlir::cir::BoolAttr>(op.getValue());
155168
value = rewriter.getIntegerAttr(ty, boolValue.getValue());
169+
} else if (op.getType().isa<mlir::cir::CIRFPTypeInterface>()) {
170+
assert(ty.isF32() || ty.isF64() && "NYI");
171+
value = rewriter.getFloatAttr(
172+
typeConverter->convertType(op.getType()),
173+
op.getValue().cast<mlir::cir::FPAttr>().getValue());
156174
} else {
157175
auto cirIntAttr = mlir::dyn_cast<mlir::cir::IntAttr>(op.getValue());
158176
assert(cirIntAttr && "NYI non cir.int attr");
@@ -612,7 +630,8 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
612630
CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
613631
CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
614632
CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
615-
CIRYieldOpLowering>(converter, patterns.getContext());
633+
CIRYieldOpLowering, CIRCosOpLowering>(converter,
634+
patterns.getContext());
616635
}
617636

618637
static mlir::TypeConverter prepareTypeConverter() {
@@ -666,7 +685,8 @@ void ConvertCIRToMLIRPass::runOnOperation() {
666685
target.addLegalOp<mlir::ModuleOp>();
667686
target.addLegalDialect<mlir::affine::AffineDialect, mlir::arith::ArithDialect,
668687
mlir::memref::MemRefDialect, mlir::func::FuncDialect,
669-
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect>();
688+
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect,
689+
mlir::math::MathDialect>();
670690
target.addIllegalDialect<mlir::cir::CIRDialect>();
671691

672692
if (failed(applyPartialConversion(module, target, std::move(patterns))))
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: cir-opt %s -cir-to-mlir -o %t.mlir
2+
// RUN: FileCheck %s --input-file %t.mlir
3+
4+
module {
5+
cir.func @foo() {
6+
%1 = cir.const(#cir.fp<1.0> : !cir.float) : !cir.float
7+
%2 = cir.const(#cir.fp<1.0> : !cir.double) : !cir.double
8+
%3 = cir.cos %1 : !cir.float
9+
%4 = cir.cos %2 : !cir.double
10+
cir.return
11+
}
12+
}
13+
14+
//CHECK: module {
15+
//CHECK: func.func @foo() {
16+
//CHECK: %cst = arith.constant 1.000000e+00 : f32
17+
//CHECK: %cst_0 = arith.constant 1.000000e+00 : f64
18+
//CHECK: %0 = math.cos %cst : f32
19+
//CHECK: %1 = math.cos %cst_0 : f64
20+
//CHECK: return
21+
//CHECK: }
22+
//CHECK: }

0 commit comments

Comments
 (0)