diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index ae8a6ef350ce1..3892e8fa0a32f 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -39,26 +39,26 @@ using namespace mlir::arith; static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, - function_ref binFn) { - return builder.getIntegerAttr(res.getType(), - binFn(llvm::cast(lhs).getInt(), - llvm::cast(rhs).getInt())); + function_ref binFn) { + APInt lhsVal = llvm::cast(lhs).getValue(); + APInt rhsVal = llvm::cast(rhs).getValue(); + APInt value = binFn(lhsVal, rhsVal); + return IntegerAttr::get(res.getType(), value); } static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs) { - return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus()); + return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus()); } static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs) { - return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus()); + return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus()); } static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs) { - return applyToIntegerAttrs(builder, res, lhs, rhs, - std::multiplies()); + return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies()); } /// Invert an integer comparison predicate. diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index f697f3d01458e..5e4476a21df04 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -909,6 +909,18 @@ func.func @tripleMulIMulII32(%arg0: i32) -> i32 { return %mul2 : i32 } +// CHECK-LABEL: @tripleMulLargeInt +// CHECK: %[[cres:.+]] = arith.constant 3618502788666131213697322783095070105623107215331596699973092056135872020482 : i256 +// CHECK: %[[addi:.+]] = arith.addi %arg0, %[[cres]] : i256 +// CHECK: return %[[addi]] +func.func @tripleMulLargeInt(%arg0: i256) -> i256 { + %0 = arith.constant 3618502788666131213697322783095070105623107215331596699973092056135872020481 : i256 + %1 = arith.constant 1 : i256 + %2 = arith.addi %arg0, %0 : i256 + %3 = arith.addi %2, %1 : i256 + return %3 : i256 +} + // CHECK-LABEL: @addiMuliToSubiRhsI32 // CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32) // CHECK: %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : i32