24
24
#include " mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
25
25
#include " mlir/Dialect/Func/IR/FuncOps.h"
26
26
#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
27
+ #include " mlir/Dialect/Math/IR/Math.h"
27
28
#include " mlir/Dialect/MemRef/IR/MemRef.h"
28
29
#include " mlir/Dialect/SCF/IR/SCF.h"
29
30
#include " mlir/Dialect/SCF/Transforms/Passes.h"
@@ -68,7 +69,7 @@ struct ConvertCIRToMLIRPass
68
69
registry.insert <mlir::BuiltinDialect, mlir::func::FuncDialect,
69
70
mlir::affine::AffineDialect, mlir::memref::MemRefDialect,
70
71
mlir::arith::ArithDialect, mlir::cf::ControlFlowDialect,
71
- mlir::scf::SCFDialect>();
72
+ mlir::scf::SCFDialect, mlir::math::MathDialect >();
72
73
}
73
74
void runOnOperation () final ;
74
75
@@ -140,6 +141,18 @@ class CIRStoreOpLowering
140
141
}
141
142
};
142
143
144
+ class CIRCosOpLowering : public mlir ::OpConversionPattern<mlir::cir::CosOp> {
145
+ public:
146
+ using OpConversionPattern<mlir::cir::CosOp>::OpConversionPattern;
147
+
148
+ mlir::LogicalResult
149
+ matchAndRewrite (mlir::cir::CosOp op, OpAdaptor adaptor,
150
+ mlir::ConversionPatternRewriter &rewriter) const override {
151
+ rewriter.replaceOpWithNewOp <mlir::math::CosOp>(op, adaptor.getSrc ());
152
+ return mlir::LogicalResult::success ();
153
+ }
154
+ };
155
+
143
156
class CIRConstantOpLowering
144
157
: public mlir::OpConversionPattern<mlir::cir::ConstantOp> {
145
158
public:
@@ -153,6 +166,11 @@ class CIRConstantOpLowering
153
166
if (mlir::isa<mlir::cir::BoolType>(op.getType ())) {
154
167
auto boolValue = mlir::cast<mlir::cir::BoolAttr>(op.getValue ());
155
168
value = rewriter.getIntegerAttr (ty, boolValue.getValue ());
169
+ } else if (op.getType ().isa <mlir::cir::CIRFPTypeInterface>()) {
170
+ assert (ty.isF32 () || ty.isF64 () && " NYI" );
171
+ value = rewriter.getFloatAttr (
172
+ typeConverter->convertType (op.getType ()),
173
+ op.getValue ().cast <mlir::cir::FPAttr>().getValue ());
156
174
} else {
157
175
auto cirIntAttr = mlir::dyn_cast<mlir::cir::IntAttr>(op.getValue ());
158
176
assert (cirIntAttr && " NYI non cir.int attr" );
@@ -612,7 +630,8 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
612
630
CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
613
631
CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
614
632
CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
615
- CIRYieldOpLowering>(converter, patterns.getContext ());
633
+ CIRYieldOpLowering, CIRCosOpLowering>(converter,
634
+ patterns.getContext ());
616
635
}
617
636
618
637
static mlir::TypeConverter prepareTypeConverter () {
@@ -666,7 +685,8 @@ void ConvertCIRToMLIRPass::runOnOperation() {
666
685
target.addLegalOp <mlir::ModuleOp>();
667
686
target.addLegalDialect <mlir::affine::AffineDialect, mlir::arith::ArithDialect,
668
687
mlir::memref::MemRefDialect, mlir::func::FuncDialect,
669
- mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect>();
688
+ mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect,
689
+ mlir::math::MathDialect>();
670
690
target.addIllegalDialect <mlir::cir::CIRDialect>();
671
691
672
692
if (failed (applyPartialConversion (module , target, std::move (patterns))))
0 commit comments