Skip to content

Commit

Permalink
dialects: (llvm) added overflow flags to trunc (#3422)
Browse files Browse the repository at this point in the history
  • Loading branch information
lfrenot authored Nov 11, 2024
1 parent e03689e commit 5df40db
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 14 deletions.
3 changes: 3 additions & 0 deletions tests/filecheck/dialects/llvm/arithmetic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@
%trunc = llvm.trunc %arg0 : i32 to i16
// CHECK: %trunc = llvm.trunc %arg0 : i32 to i16

%trunc_overflow = llvm.trunc %arg0 overflow<nsw> : i32 to i16
// CHECK: %trunc_overflow = llvm.trunc %arg0 overflow<nsw> : i32 to i16

%sext = llvm.sext %arg0 : i32 to i64
// CHECK: %sext = llvm.sext %arg0 : i32 to i64

Expand Down
72 changes: 58 additions & 14 deletions xdsl/dialects/llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,17 @@ class OverflowAttrBase(BitEnumAttribute[OverflowFlag]):
class OverflowAttr(OverflowAttrBase):
name = "llvm.overflow"

@classmethod
def parse(cls, parser: Parser) -> OverflowAttr:
if parser.parse_optional_keyword("overflow") is not None:
return OverflowAttr(OverflowAttr.parse_parameter(parser))
return OverflowAttr("none")

def print(self, printer: Printer):
if self.flags:
printer.print(" overflow")
self.print_parameter(printer)


class ArithmeticBinOpOverflow(IRDLOperation, ABC):
"""Class for arithmetic binary operations that use overflow flags."""
Expand Down Expand Up @@ -443,23 +454,12 @@ def __init__(
},
)

@classmethod
def parse_overflow(cls, parser: Parser) -> OverflowAttr:
if parser.parse_optional_keyword("overflow") is not None:
return OverflowAttr(OverflowAttr.parse_parameter(parser))
return OverflowAttr("none")

def print_overflow(self, printer: Printer) -> None:
if self.overflowFlags and self.overflowFlags.flags:
printer.print(" overflow")
self.overflowFlags.print_parameter(printer)

@classmethod
def parse(cls, parser: Parser):
lhs = parser.parse_unresolved_operand()
parser.parse_characters(",")
rhs = parser.parse_unresolved_operand()
overflowFlags = cls.parse_overflow(parser)
overflowFlags = OverflowAttr.parse(parser)
attributes = parser.parse_optional_attr_dict()
parser.parse_characters(":")
type = parser.parse_type()
Expand All @@ -468,7 +468,8 @@ def parse(cls, parser: Parser):

def print(self, printer: Printer) -> None:
printer.print(" ", self.lhs, ", ", self.rhs)
self.print_overflow(printer)
if self.overflowFlags:
self.overflowFlags.print(printer)
printer.print_op_attributes(self.attributes)
printer.print(" : ")
printer.print(self.lhs.type)
Expand Down Expand Up @@ -589,6 +590,49 @@ def __init__(
)


class IntegerConversionOpOverflow(IRDLOperation, ABC):
arg = operand_def(IntegerType)
res = result_def(IntegerType)
overflowFlags = opt_prop_def(OverflowAttr)
traits = traits_def(NoMemoryEffect())

def __init__(
self,
arg: SSAValue,
res_type: Attribute,
attributes: dict[str, Attribute] = {},
overflow: OverflowAttr = OverflowAttr(None),
):
super().__init__(
operands=(arg,),
attributes=attributes,
result_types=(res_type,),
properties={
"overflowFlags": overflow,
},
)

@classmethod
def parse(cls, parser: Parser):
arg = parser.parse_unresolved_operand()
overflowFlags = OverflowAttr.parse(parser)
attributes = parser.parse_optional_attr_dict()
parser.parse_characters(":")
arg_type = parser.parse_type()
parser.parse_characters("to")
res_type = parser.parse_type()
operands = parser.resolve_operands([arg], [arg_type], parser.pos)
return cls(operands[0], res_type, attributes, overflowFlags)

def print(self, printer: Printer):
printer.print(" ", self.arg)
if self.overflowFlags:
self.overflowFlags.print(printer)
printer.print_op_attributes(self.attributes)
printer.print(" : ")
printer.print(self.arg.type, " to ", self.res.type)


@irdl_op_definition
class AddOp(ArithmeticBinOpOverflow):
name = "llvm.add"
Expand Down Expand Up @@ -655,7 +699,7 @@ class AShrOp(ArithmeticBinOpExact):


@irdl_op_definition
class TruncOp(IntegerConversionOp):
class TruncOp(IntegerConversionOpOverflow):
name = "llvm.trunc"

def verify(self, verify_nested_ops: bool = True):
Expand Down

0 comments on commit 5df40db

Please sign in to comment.