diff --git a/tests/irdl/test_declarative_assembly_format.py b/tests/irdl/test_declarative_assembly_format.py index 50d9a01ac2..06b37a3557 100644 --- a/tests/irdl/test_declarative_assembly_format.py +++ b/tests/irdl/test_declarative_assembly_format.py @@ -37,12 +37,15 @@ opt_prop_def, opt_region_def, opt_result_def, + opt_successor_def, prop_def, region_def, result_def, + successor_def, var_operand_def, var_region_def, var_result_def, + var_successor_def, ) from xdsl.parser import Parser from xdsl.printer import Printer @@ -1329,6 +1332,182 @@ class OptionalOperandsOp(IRDLOperation): # pyright: ignore[reportUnusedClass] assembly_format = "attr-dict-with-keyword $region1 $region2" +################################################################################ +# Successors # +################################################################################ + + +def test_missing_successor(): + """Test that successors should be parsed.""" + with pytest.raises(PyRDLOpDefinitionError, match="successor 'successor' not found"): + + @irdl_op_definition + class NoSuccessorOp(IRDLOperation): # pyright: ignore[reportUnusedClass] + name = "test.no_successor_op" + successor = successor_def() + + assembly_format = "attr-dict-with-keyword" + + +def test_successors(): + """Test the parsing of successors""" + + program = textwrap.dedent( + """\ + "test.op"() ({ + "test.op"() [^0] : () -> () + ^0: + test.two_successors ^0 ^0 + }) : () -> ()""" + ) + + generic_program = textwrap.dedent( + """\ + "test.op"() ({ + "test.op"() [^0] : () -> () + ^0: + "test.two_successors"() [^0, ^0] : () -> () + }) : () -> ()""" + ) + + @irdl_op_definition + class TwoSuccessorsOp(IRDLOperation): + name = "test.two_successors" + fst = successor_def() + snd = successor_def() + + assembly_format = "$fst $snd attr-dict" + + ctx = MLContext() + ctx.load_op(TwoSuccessorsOp) + ctx.load_dialect(Test) + + check_roundtrip(program, ctx) + check_equivalence(program, generic_program, ctx) + + +@pytest.mark.parametrize( + "program, generic_program", + [ + ( + '"test.op"() ({\n "test.op"() [^0] : () -> ()\n^0:\n test.var_successor \n}) : () -> ()', + textwrap.dedent( + """\ + "test.op"() ({ + "test.op"() [^0] : () -> () + ^0: + "test.var_successor"() : () -> () + }) : () -> ()""" + ), + ), + ( + textwrap.dedent( + """\ + "test.op"() ({ + "test.op"() [^0] : () -> () + ^0: + test.var_successor ^0 + }) : () -> ()""" + ), + textwrap.dedent( + """\ + "test.op"() ({ + "test.op"() [^0] : () -> () + ^0: + "test.var_successor"() [^0] : () -> () + }) : () -> ()""" + ), + ), + ( + textwrap.dedent( + """\ + "test.op"() ({ + "test.op"() [^0] : () -> () + ^0: + test.var_successor ^0 ^0 + }) : () -> ()""" + ), + textwrap.dedent( + """\ + "test.op"() ({ + "test.op"() [^0] : () -> () + ^0: + "test.var_successor"() [^0, ^0] : () -> () + }) : () -> ()""" + ), + ), + ], +) +def test_variadic_successor(program: str, generic_program: str): + """Test the parsing of successors""" + + @irdl_op_definition + class VarSuccessorOp(IRDLOperation): + name = "test.var_successor" + succ = var_successor_def() + + assembly_format = "$succ attr-dict" + + ctx = MLContext() + ctx.load_op(VarSuccessorOp) + ctx.load_dialect(Test) + + check_roundtrip(program, ctx) + check_equivalence(program, generic_program, ctx) + + +@pytest.mark.parametrize( + "program, generic_program", + [ + ( + '"test.op"() ({\n "test.op"() [^0] : () -> ()\n^0:\n test.opt_successor \n}) : () -> ()', + textwrap.dedent( + """\ + "test.op"() ({ + "test.op"() [^0] : () -> () + ^0: + "test.opt_successor"() : () -> () + }) : () -> ()""" + ), + ), + ( + textwrap.dedent( + """\ + "test.op"() ({ + "test.op"() [^0] : () -> () + ^0: + test.opt_successor ^0 + }) : () -> ()""" + ), + textwrap.dedent( + """\ + "test.op"() ({ + "test.op"() [^0] : () -> () + ^0: + "test.opt_successor"() [^0] : () -> () + }) : () -> ()""" + ), + ), + ], +) +def test_optional_successor(program: str, generic_program: str): + """Test the parsing of successors""" + + @irdl_op_definition + class OptSuccessorOp(IRDLOperation): + name = "test.opt_successor" + succ = opt_successor_def() + + assembly_format = "$succ attr-dict" + + ctx = MLContext() + ctx.load_op(OptSuccessorOp) + ctx.load_dialect(Test) + + check_roundtrip(program, ctx) + check_equivalence(program, generic_program, ctx) + + ################################################################################ # Inference # ################################################################################ diff --git a/xdsl/irdl/declarative_assembly_format.py b/xdsl/irdl/declarative_assembly_format.py index 40d3cbc121..241417ad99 100644 --- a/xdsl/irdl/declarative_assembly_format.py +++ b/xdsl/irdl/declarative_assembly_format.py @@ -26,6 +26,7 @@ IRDLOperationInvT, OpDef, OptionalDef, + Successor, VariadicDef, VarIRConstruct, ) @@ -50,19 +51,17 @@ class ParsingState: operand_types: list[Attribute | None | list[Attribute | None]] result_types: list[Attribute | None | list[Attribute | None]] regions: list[Region | None | list[Region]] + successors: list[Successor | None | list[Successor]] attributes: dict[str, Attribute] properties: dict[str, Attribute] constraint_context: ConstraintContext def __init__(self, op_def: OpDef): - if op_def.successors: - raise NotImplementedError( - "Operation definitions with successors are not yet supported" - ) self.operands = [None] * len(op_def.operands) self.operand_types = [None] * len(op_def.operands) self.result_types = [None] * len(op_def.results) self.regions = [None] * len(op_def.regions) + self.successors = [None] * len(op_def.successors) self.attributes = {} self.properties = {} self.constraint_context = ConstraintContext() @@ -167,6 +166,7 @@ def parse( attributes=state.attributes, properties=properties, regions=state.regions, + successors=state.successors, ) def assign_constraint_variables( @@ -805,6 +805,80 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No state.should_emit_space = True +class SuccessorVariable(VariableDirective, OptionallyParsableDirective): + """ + A successor variable, with the following format: + successor-directive ::= dollar-ident + The directive will request a space to be printed after. + """ + + def parse_optional(self, parser: Parser, state: ParsingState) -> bool: + successor = parser.parse_optional_successor() + + state.successors[self.index] = successor + + return successor is not None + + def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + if state.should_emit_space or not state.last_was_punctuation: + printer.print(" ") + printer.print_block_name(getattr(op, self.name)) + state.last_was_punctuation = False + state.should_emit_space = True + + +class VariadicSuccessorVariable(VariadicVariable, OptionallyParsableDirective): + """ + A variadic successor variable, with the following format: + successor-directive ::= dollar-ident + The directive will request a space to be printed after. + """ + + def parse_optional(self, parser: Parser, state: ParsingState) -> bool: + successors: list[Successor] = [] + current_successor = parser.parse_optional_successor() + while current_successor is not None: + successors.append(current_successor) + current_successor = parser.parse_optional_successor() + + state.successors[self.index] = successors + + return bool(successors) + + def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + if state.should_emit_space or not state.last_was_punctuation: + printer.print(" ") + successor = getattr(op, self.name) + if successor: + printer.print_list(successor, printer.print_block_name, delimiter=" ") + state.last_was_punctuation = False + state.should_emit_space = True + + +class OptionalSuccessorVariable(OptionalVariable, OptionallyParsableDirective): + """ + An optional successor variable, with the following format: + successor-directive ::= dollar-ident + The directive will request a space to be printed after. + """ + + def parse_optional(self, parser: Parser, state: ParsingState) -> bool: + successor = parser.parse_optional_successor() + if successor is None: + successor = list[Successor]() + state.successors[self.index] = successor + return bool(successor) + + def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + if state.should_emit_space or not state.last_was_punctuation: + printer.print(" ") + successor = getattr(op, self.name) + if successor: + printer.print_block_name(successor) + state.last_was_punctuation = False + state.should_emit_space = True + + @dataclass(frozen=True) class AttributeVariable(FormatDirective): """ diff --git a/xdsl/irdl/declarative_assembly_format_parser.py b/xdsl/irdl/declarative_assembly_format_parser.py index 33b7c95b10..c5cf492967 100644 --- a/xdsl/irdl/declarative_assembly_format_parser.py +++ b/xdsl/irdl/declarative_assembly_format_parser.py @@ -23,12 +23,14 @@ OptOperandDef, OptRegionDef, OptResultDef, + OptSuccessorDef, ParamAttrConstraint, ParsePropInAttrDict, VariadicDef, VarOperandDef, VarRegionDef, VarResultDef, + VarSuccessorDef, ) from xdsl.irdl.declarative_assembly_format import ( AnchorableDirective, @@ -48,11 +50,13 @@ OptionalRegionVariable, OptionalResultTypeDirective, OptionalResultVariable, + OptionalSuccessorVariable, OptionalUnitAttrVariable, PunctuationDirective, RegionVariable, ResultTypeDirective, ResultVariable, + SuccessorVariable, VariableDirective, VariadicLikeFormatDirective, VariadicLikeTypeDirective, @@ -62,6 +66,7 @@ VariadicRegionVariable, VariadicResultTypeDirective, VariadicResultVariable, + VariadicSuccessorVariable, WhitespaceDirective, ) from xdsl.parser import BaseParser, ParserState @@ -130,6 +135,8 @@ class FormatParser(BaseParser): """The properties that are already parsed.""" seen_regions: list[bool] """The region variables that are already parsed.""" + seen_successors: list[bool] + """The successor variables that are already parsed.""" has_attr_dict: bool = field(default=False) """True if the attribute dictionary has already been parsed.""" context: ParsingContext = field(default=ParsingContext.TopLevel) @@ -149,6 +156,7 @@ def __init__(self, input: str, op_def: OpDef): self.seen_attributes = set[str]() self.seen_properties = set[str]() self.seen_regions = [False] * len(op_def.regions) + self.seen_successors = [False] * len(op_def.successors) self.type_resolutions = {} def parse_format(self) -> FormatProgram: @@ -170,6 +178,7 @@ def parse_format(self) -> FormatProgram: self.verify_operands(seen_variables) self.verify_results(seen_variables) self.verify_regions() + self.verify_successors() return FormatProgram(elements) def verify_directives(self, elements: list[FormatDirective]): @@ -323,6 +332,25 @@ def verify_regions(self): "directive to the custom assembly format." ) + def verify_successors(self): + """ + Check that all successors are present. + """ + for ( + seen_successor, + (successor_name, _), + ) in zip( + self.seen_successors, + self.op_def.successors, + strict=True, + ): + if not seen_successor: + self.raise_error( + f"successor '{successor_name}' " + f"not found, consider adding a '${successor_name}' " + "directive to the custom assembly format." + ) + def parse_optional_variable( self, ) -> VariableDirective | AttributeVariable | None: @@ -389,6 +417,19 @@ def parse_optional_variable( case _: return RegionVariable(variable_name, idx) + # Check if the variable is a successor + for idx, (successor_name, successor_def) in enumerate(self.op_def.successors): + if variable_name != successor_name: + continue + self.seen_successors[idx] = True + match successor_def: + case OptSuccessorDef(): + return OptionalSuccessorVariable(variable_name, idx) + case VarSuccessorDef(): + return VariadicSuccessorVariable(variable_name, idx) + case _: + return SuccessorVariable(variable_name, idx) + attr_or_prop_by_name = { attr_name: attr_or_prop for attr_name, attr_or_prop in self.op_def.accessor_names.values()