Skip to content

Commit

Permalink
core: declarative format variable fix (#2313)
Browse files Browse the repository at this point in the history
  • Loading branch information
PapyChacal authored Mar 7, 2024
1 parent 40b4a68 commit 44af987
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 4 deletions.
39 changes: 38 additions & 1 deletion tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest

from xdsl.dialects.builtin import ModuleOp
from xdsl.dialects.test import Test
from xdsl.dialects.test import Test, TestType
from xdsl.ir import (
Attribute,
MLContext,
Expand Down Expand Up @@ -1318,3 +1318,40 @@ class WrongOptionalGroupOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
mandatory_arg = operand_def()

assembly_format = format


@pytest.mark.parametrize(
"program, generic_program",
[
(
'%0 = "test.op"() : () -> !test.type<"index">\n' "test.mixed %0()",
'%0 = "test.op"() : () -> !test.type<"index">\n'
'"test.mixed"(%0) : (!test.type<"index">) -> ()',
),
(
'%0 = "test.op"() : () -> !test.type<"index">\n' "test.mixed %0(%0)",
'%0 = "test.op"() : () -> !test.type<"index">\n'
'"test.mixed"(%0, %0) : (!test.type<"index">, !test.type<"index">) -> ()',
),
(
'%0 = "test.op"() : () -> !test.type<"index">\n' "test.mixed %0(%0, %0)",
'%0 = "test.op"() : () -> !test.type<"index">\n'
'"test.mixed"(%0, %0, %0) : (!test.type<"index">, !test.type<"index">, !test.type<"index">) -> ()',
),
],
)
def test_variadic_and_single_mixed(program: str, generic_program: str):
@irdl_op_definition
class MixedOp(IRDLOperation):
name = "test.mixed"
var = var_operand_def(TestType("index"))
sin = operand_def(TestType("index"))

assembly_format = "$sin `(` $var `)` attr-dict"

ctx = MLContext()
ctx.load_op(MixedOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)
check_equivalence(program, generic_program, ctx)
6 changes: 3 additions & 3 deletions xdsl/irdl/declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def parse(self, parser: Parser, state: ParsingState) -> None:
def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
printer.print(" ")
printer.print_ssa_value(op.operands[self.index])
printer.print_ssa_value(getattr(op, self.name))
state.last_was_punctuation = False
state.should_emit_space = True

Expand Down Expand Up @@ -502,7 +502,7 @@ def parse(self, parser: Parser, state: ParsingState) -> None:
def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
printer.print(" ")
printer.print_attribute(op.operands[self.index].type)
printer.print_attribute(getattr(op, self.name).type)
state.last_was_punctuation = False
state.should_emit_space = True

Expand Down Expand Up @@ -636,7 +636,7 @@ def parse(self, parser: Parser, state: ParsingState) -> None:
def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
printer.print(" ")
printer.print_attribute(op.results[self.index].type)
printer.print_attribute(getattr(op, self.name).type)
state.last_was_punctuation = False
state.should_emit_space = True

Expand Down

0 comments on commit 44af987

Please sign in to comment.