Skip to content

Commit

Permalink
[MooreToCore] Support pows and powu op
Browse files Browse the repository at this point in the history
  • Loading branch information
Max-astro authored and hailongSun2000 committed Dec 6, 2024
1 parent 4a73177 commit 70baf49
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 3 deletions.
87 changes: 84 additions & 3 deletions lib/Conversion/MooreToCore/MooreToCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1188,6 +1188,84 @@ struct ShrOpConversion : public OpConversionPattern<ShrOp> {
}
};

struct PowUOpConversion : public OpConversionPattern<PowUOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(PowUOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type resultType = typeConverter->convertType(op.getResult().getType());

Location loc = op.getLoc();
auto intType = cast<IntType>(op.getRhs().getType());

// transform a ** b into scf.for 0 to b step 1 { init *= a }, init = 1
Type integerType = rewriter.getIntegerType(intType.getWidth());
Value lowerBound = rewriter.create<hw::ConstantOp>(loc, integerType, 0);
Value upperBound =
rewriter.create<ConversionOp>(loc, integerType, op.getRhs());
Value step = rewriter.create<hw::ConstantOp>(loc, integerType, 1);

Value initVal = rewriter.create<hw::ConstantOp>(loc, resultType, 1);
Value lhsVal = rewriter.create<ConversionOp>(loc, resultType, op.getLhs());

auto forOp = rewriter.create<scf::ForOp>(
loc, lowerBound, upperBound, step, ValueRange(initVal),
[&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) {
Value loopVar = iterArgs.front();
Value mul = rewriter.create<comb::MulOp>(loc, lhsVal, loopVar);
rewriter.create<scf::YieldOp>(loc, ValueRange(mul));
});

rewriter.replaceOp(op, forOp.getResult(0));

return success();
}
};

struct PowSOpConversion : public OpConversionPattern<PowSOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(PowSOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type resultType = typeConverter->convertType(op.getResult().getType());

Location loc = op.getLoc();
auto intType = cast<IntType>(op.getRhs().getType());
// transform a ** b into scf.for 0 to b step 1 { init *= a }, init = 1
Type integerType = rewriter.getIntegerType(intType.getWidth());
Value lhsVal = rewriter.create<ConversionOp>(loc, resultType, op.getLhs());
Value rhsVal = rewriter.create<ConversionOp>(loc, integerType, op.getRhs());
Value constZero = rewriter.create<hw::ConstantOp>(loc, integerType, 0);
Value constZeroResult = rewriter.create<hw::ConstantOp>(loc, resultType, 0);
Value isNegative = rewriter.create<comb::ICmpOp>(loc, ICmpPredicate::slt,
rhsVal, constZero);

// if the exponent is negative, return 0
lhsVal =
rewriter.create<comb::MuxOp>(loc, isNegative, constZeroResult, lhsVal);
Value upperBound =
rewriter.create<comb::MuxOp>(loc, isNegative, constZero, rhsVal);

Value lowerBound = constZero;
Value step = rewriter.create<hw::ConstantOp>(loc, integerType, 1);
Value initVal = rewriter.create<hw::ConstantOp>(loc, resultType, 1);

auto forOp = rewriter.create<scf::ForOp>(
loc, lowerBound, upperBound, step, ValueRange(initVal),
[&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) {
auto loopVar = iterArgs.front();
auto mul = rewriter.create<comb::MulOp>(loc, lhsVal, loopVar);
rewriter.create<scf::YieldOp>(loc, ValueRange(mul));
});

rewriter.replaceOp(op, forOp.getResult(0));

return success();
}
};

struct AShrOpConversion : public OpConversionPattern<AShrOp> {
using OpConversionPattern::OpConversionPattern;

Expand Down Expand Up @@ -1430,9 +1508,9 @@ static void populateLegality(ConversionTarget &target,
target.addLegalOp<debug::ScopeOp>();

target.addDynamicallyLegalOp<
cf::CondBranchOp, cf::BranchOp, scf::IfOp, scf::YieldOp, func::CallOp,
func::ReturnOp, UnrealizedConversionCastOp, hw::OutputOp, hw::InstanceOp,
debug::ArrayOp, debug::StructOp, debug::VariableOp>(
cf::CondBranchOp, cf::BranchOp, scf::IfOp, scf::ForOp, scf::YieldOp,
func::CallOp, func::ReturnOp, UnrealizedConversionCastOp, hw::OutputOp,
hw::InstanceOp, debug::ArrayOp, debug::StructOp, debug::VariableOp>(
[&](Operation *op) { return converter.isLegal(op); });

target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
Expand Down Expand Up @@ -1590,6 +1668,9 @@ static void populateOpConversion(RewritePatternSet &patterns,
BinaryOpConversion<OrOp, comb::OrOp>,
BinaryOpConversion<XorOp, comb::XorOp>,

// Patterns of power operations.
PowUOpConversion, PowSOpConversion,

// Patterns of relational operations.
ICmpOpConversion<UltOp, ICmpPredicate::ult>,
ICmpOpConversion<SltOp, ICmpPredicate::slt>,
Expand Down
22 changes: 22 additions & 0 deletions test/Conversion/MooreToCore/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -985,3 +985,25 @@ func.func @Conversions(%arg0: !moore.i16, %arg1: !moore.l16) {

return
}

// CHECK-LABEL: func.func @PowUOp
func.func @PowUOp(%arg0: !moore.l32, %arg1: !moore.l32) {
// CHECK: %{{.*}} = scf.for %{{.*}} = %{{.*}} to %arg1 step %{{.*}} iter_args([[VAR:%.+]] = %{{.*}}) -> (i32) : i32 {
// CHECK: [[MUL:%.+]] = comb.mul %arg0, [[VAR]] : i32
// CHECK: scf.yield [[MUL]] : i32
%0 = moore.powu %arg0, %arg1 : l32
return
}

// CHECK-LABEL: func.func @PowSOp
func.func @PowSOp(%arg0: !moore.i32, %arg1: !moore.i32) {
// CHECK: [[COND:%.+]] = comb.icmp slt %arg1, %{{.*}} : i32
// CHECK: [[BASE:%.+]] = comb.mux [[COND]], %{{.*}}, %arg0 : i32
// CHECK: [[EXP:%.+]] = comb.mux [[COND]], %{{.*}}, %arg1 : i32

// CHECK: %{{.*}} = scf.for %{{.*}} = %{{.*}} to [[EXP]] step %{{.*}} iter_args([[VAR:%.+]] = %{{.*}}) -> (i32) : i32 {
// CHECK: [[MUL:%.+]] = comb.mul [[BASE]], [[VAR]] : i32
// CHECK: scf.yield [[MUL]] : i32
%0 = moore.pows %arg0, %arg1 : i32
return
}

1 comment on commit 70baf49

@hailongSun2000
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The related PR: [MooreToCore] Support pows and powu op #7899.

Please sign in to comment.