diff --git a/tests/dialects/test_builtin.py b/tests/dialects/test_builtin.py index c946bd3162..bccbcff5cc 100644 --- a/tests/dialects/test_builtin.py +++ b/tests/dialects/test_builtin.py @@ -119,6 +119,26 @@ def test_IntegerType_normalized(): assert ui8.normalized_value(255) == 255 +def test_IntegerType_truncated(): + si8 = IntegerType(8, Signedness.SIGNED) + ui8 = IntegerType(8, Signedness.UNSIGNED) + + assert i8.normalized_value(-1, truncate_bits=True) == -1 + assert i8.normalized_value(1, truncate_bits=True) == 1 + assert i8.normalized_value(255, truncate_bits=True) == -1 + assert i8.normalized_value(256, truncate_bits=True) == 0 + + assert si8.normalized_value(-1, truncate_bits=True) == -1 + assert si8.normalized_value(1, truncate_bits=True) == 1 + assert si8.normalized_value(255, truncate_bits=True) == -1 + assert si8.normalized_value(256, truncate_bits=True) == 0 + + assert ui8.normalized_value(-1, truncate_bits=True) == 255 + assert ui8.normalized_value(1, truncate_bits=True) == 1 + assert ui8.normalized_value(255, truncate_bits=True) == 255 + assert ui8.normalized_value(256, truncate_bits=True) == 0 + + def test_IntegerAttr_normalize(): """ Test that the value within the accepted signless range is normalized to signed diff --git a/tests/filecheck/transforms/individual_rewrite/add-same.mlir b/tests/filecheck/transforms/individual_rewrite/add-same.mlir new file mode 100644 index 0000000000..5c4813755e --- /dev/null +++ b/tests/filecheck/transforms/individual_rewrite/add-same.mlir @@ -0,0 +1,18 @@ +// RUN:xdsl-opt %s --split-input-file -p 'apply-individual-rewrite{matched_operation_index=2 operation_name="arith.addi" pattern_name="AdditionOfSameVariablesToMultiplyByTwo"}'| filecheck %s + + +// CHECK: %v = "test.op"() : () -> i32 +// CHECK-NEXT: %[[#two:]] = arith.constant 2 : i32 +// CHECK-NEXT: %{{.*}} = arith.muli %v, %[[#two]] : i32 + +%v = "test.op"() : () -> (i32) +%1 = arith.addi %v, %v : i32 + +// ----- + +// CHECK: %v = "test.op"() : () -> i1 +// CHECK-NEXT: %[[#zero:]] = arith.constant false +// CHECK-NEXT: %{{.*}} = arith.muli %v, %[[#zero]] : i1 + +%v = "test.op"() : () -> (i1) +%1 = arith.addi %v, %v : i1 diff --git a/tests/filecheck/transforms/individual_rewrite.mlir b/tests/filecheck/transforms/individual_rewrite/riscv.mlir similarity index 100% rename from tests/filecheck/transforms/individual_rewrite.mlir rename to tests/filecheck/transforms/individual_rewrite/riscv.mlir diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index 54315b7f14..918859c9f7 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -489,7 +489,9 @@ def verify_value(self, value: int): f"values in the range [{min_value}, {max_value})" ) - def normalized_value(self, value: int) -> int | None: + def normalized_value( + self, value: int, *, truncate_bits: bool = False + ) -> int | None: """ Signless values can represent integers from both the signed and unsigned ranges for a given bitwidth. @@ -497,13 +499,16 @@ def normalized_value(self, value: int) -> int | None: to the signed version (meaning ambiguous values will always be negative). For example, the bitpattern of all ones will always be represented as `-1` at runtime. - If the input value is outside of the valid range, return `None`. + If the input value is outside of the valid range, return `None` if `truncate_bits` + is false, otherwise returns a value in range by truncating the bits of the input. """ min_value, max_value = self.value_range() if not (min_value <= value < max_value): - return None + if not truncate_bits: + return None + value = value % (2**self.bitwidth) - if self.signedness.data == Signedness.SIGNLESS: + if self.signedness.data != Signedness.UNSIGNED: signed_ub = signed_upper_bound(self.bitwidth) unsigned_ub = unsigned_upper_bound(self.bitwidth) if signed_ub <= value: @@ -620,22 +625,34 @@ def __init__( self, value: int | IntAttr, value_type: _IntegerAttrType, + *, + truncate_bits: bool = False, ) -> None: ... @overload def __init__( - self: IntegerAttr[IntegerType], value: int | IntAttr, value_type: int + self: IntegerAttr[IntegerType], + value: int | IntAttr, + value_type: int, + *, + truncate_bits: bool = False, ) -> None: ... def __init__( - self, value: int | IntAttr, value_type: int | IntegerType | IndexType + self, + value: int | IntAttr, + value_type: int | IntegerType | IndexType, + *, + truncate_bits: bool = False, ) -> None: if isinstance(value_type, int): value_type = IntegerType(value_type) if isinstance(value, IntAttr): value = value.data if not isinstance(value_type, IndexType): - normalized_value = value_type.normalized_value(value) + normalized_value = value_type.normalized_value( + value, truncate_bits=truncate_bits + ) if normalized_value is not None: value = normalized_value super().__init__([IntAttr(value), value_type]) diff --git a/xdsl/transforms/individual_rewrite.py b/xdsl/transforms/individual_rewrite.py index 5b461c379b..aa122c5555 100644 --- a/xdsl/transforms/individual_rewrite.py +++ b/xdsl/transforms/individual_rewrite.py @@ -17,10 +17,10 @@ class AdditionOfSameVariablesToMultiplyByTwo(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: arith.AddiOp, rewriter: PatternRewriter) -> None: if op.lhs == op.rhs: - assert isinstance(op.lhs.type, IntegerType | IndexType) + assert isinstance(type := op.lhs.type, IntegerType | IndexType) rewriter.replace_matched_op( [ - li_op := arith.ConstantOp(IntegerAttr(2, op.lhs.type)), + li_op := arith.ConstantOp(IntegerAttr(2, type, truncate_bits=True)), arith.MuliOp(op.lhs, li_op), ] )