Skip to content

Commit

Permalink
dialects: (scf) add assembly format to scf.condition
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh committed Nov 11, 2024
1 parent 97cfc1f commit e55400b
Showing 1 changed file with 7 additions and 44 deletions.
51 changes: 7 additions & 44 deletions xdsl/dialects/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,56 +639,19 @@ def __init__(self, result: SSAValue | Operation):
@irdl_op_definition
class Condition(IRDLOperation):
name = "scf.condition"
cond = operand_def(IntegerType(1))
arguments = var_operand_def()
condition = operand_def(IntegerType(1))
args = var_operand_def()

traits = traits_def(HasParent(While), IsTerminator(), Pure())

assembly_format = "`(` $condition `)` attr-dict ($args^ `:` type($args))?"

def __init__(
self,
cond: SSAValue | Operation,
*output_ops: SSAValue | Operation,
condition: SSAValue | Operation,
*args: SSAValue | Operation,
):
super().__init__(operands=[cond, [output for output in output_ops]])

def print(self, printer: Printer):
printer.print("(", self.cond, ")")
if self.attributes:
printer.print_op_attributes(self.attributes)
if self.arguments:
printer.print(" ")
printer.print_list(self.arguments, printer.print_ssa_value)
printer.print_string(" : ")
printer.print_list(
self.arguments, lambda val: printer.print_attribute(val.type)
)

@classmethod
def parse(cls, parser: Parser) -> Self:
parser.parse_punctuation("(")
unresolved_cond = parser.parse_unresolved_operand("cond expected")
parser.parse_punctuation(")")
cond = parser.resolve_operand(unresolved_cond, IntegerType(1))
attrs = parser.parse_optional_attr_dict()

# scf.condition is a terminator, so the list of arguments cannot be confused with
# the results of a hypothetical operation on the next line.
pos = parser.pos
unresolved_arguments = parser.parse_optional_undelimited_comma_separated_list(
parser.parse_optional_unresolved_operand, parser.parse_unresolved_operand
)
if unresolved_arguments is not None:
parser.parse_punctuation(":")
types = parser.parse_comma_separated_list(
parser.Delimiter.NONE, parser.parse_type
)
arguments = parser.resolve_operands(unresolved_arguments, types, pos)
else:
arguments: Sequence[SSAValue] = ()

op = cls(cond, *arguments)
op.attributes = attrs
return op
super().__init__(operands=(condition, args))


@irdl_op_definition
Expand Down

0 comments on commit e55400b

Please sign in to comment.