Skip to content

Commit

Permalink
[CIR][Lowering] Add MLIR lowering support for CIR cos operations (#565)
Browse files Browse the repository at this point in the history
#563 This PR add cir.cos lowering to MLIR math dialect, now it only
surpport single and double float types, I add an assertation for the
long double and other unimplemented types

---------

Signed-off-by: zhoujing <jing.zhou@terapines.com>
  • Loading branch information
zhoujingya authored and lanza committed Apr 29, 2024
1 parent bdd4aa4 commit 123cc2f
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 3 deletions.
26 changes: 23 additions & 3 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
Expand Down Expand Up @@ -68,7 +69,7 @@ struct ConvertCIRToMLIRPass
registry.insert<mlir::BuiltinDialect, mlir::func::FuncDialect,
mlir::affine::AffineDialect, mlir::memref::MemRefDialect,
mlir::arith::ArithDialect, mlir::cf::ControlFlowDialect,
mlir::scf::SCFDialect>();
mlir::scf::SCFDialect, mlir::math::MathDialect>();
}
void runOnOperation() final;

Expand Down Expand Up @@ -140,6 +141,18 @@ class CIRStoreOpLowering
}
};

class CIRCosOpLowering : public mlir::OpConversionPattern<mlir::cir::CosOp> {
public:
using OpConversionPattern<mlir::cir::CosOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::CosOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<mlir::math::CosOp>(op, adaptor.getSrc());
return mlir::LogicalResult::success();
}
};

class CIRConstantOpLowering
: public mlir::OpConversionPattern<mlir::cir::ConstantOp> {
public:
Expand All @@ -153,6 +166,11 @@ class CIRConstantOpLowering
if (mlir::isa<mlir::cir::BoolType>(op.getType())) {
auto boolValue = mlir::cast<mlir::cir::BoolAttr>(op.getValue());
value = rewriter.getIntegerAttr(ty, boolValue.getValue());
} else if (op.getType().isa<mlir::cir::CIRFPTypeInterface>()) {
assert(ty.isF32() || ty.isF64() && "NYI");
value = rewriter.getFloatAttr(
typeConverter->convertType(op.getType()),
op.getValue().cast<mlir::cir::FPAttr>().getValue());
} else {
auto cirIntAttr = mlir::dyn_cast<mlir::cir::IntAttr>(op.getValue());
assert(cirIntAttr && "NYI non cir.int attr");
Expand Down Expand Up @@ -612,7 +630,8 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
CIRYieldOpLowering>(converter, patterns.getContext());
CIRYieldOpLowering, CIRCosOpLowering>(converter,
patterns.getContext());
}

static mlir::TypeConverter prepareTypeConverter() {
Expand Down Expand Up @@ -666,7 +685,8 @@ void ConvertCIRToMLIRPass::runOnOperation() {
target.addLegalOp<mlir::ModuleOp>();
target.addLegalDialect<mlir::affine::AffineDialect, mlir::arith::ArithDialect,
mlir::memref::MemRefDialect, mlir::func::FuncDialect,
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect>();
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect,
mlir::math::MathDialect>();
target.addIllegalDialect<mlir::cir::CIRDialect>();

if (failed(applyPartialConversion(module, target, std::move(patterns))))
Expand Down
22 changes: 22 additions & 0 deletions clang/test/CIR/Lowering/ThroughMLIR/cos.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// RUN: cir-opt %s -cir-to-mlir -o %t.mlir
// RUN: FileCheck %s --input-file %t.mlir

module {
cir.func @foo() {
%1 = cir.const(#cir.fp<1.0> : !cir.float) : !cir.float
%2 = cir.const(#cir.fp<1.0> : !cir.double) : !cir.double
%3 = cir.cos %1 : !cir.float
%4 = cir.cos %2 : !cir.double
cir.return
}
}

//CHECK: module {
//CHECK: func.func @foo() {
//CHECK: %cst = arith.constant 1.000000e+00 : f32
//CHECK: %cst_0 = arith.constant 1.000000e+00 : f64
//CHECK: %0 = math.cos %cst : f32
//CHECK: %1 = math.cos %cst_0 : f64
//CHECK: return
//CHECK: }
//CHECK: }

0 comments on commit 123cc2f

Please sign in to comment.