diff --git a/src/xdsl/printer.py b/src/xdsl/printer.py index b3fdb4b9d2..ac89040b47 100644 --- a/src/xdsl/printer.py +++ b/src/xdsl/printer.py @@ -70,22 +70,22 @@ def _print_message(self, """ Print a message. This is expected to be called on the beginning of a new line, and expect to create a new line at the end. + [begin_pos, end_pos) """ indent = self._indent if indent is None else indent - self.print(" " * indent * indentNumSpaces) indent_size = indent * indentNumSpaces - message_end_pos = max(len(line) for line in message.split("\n")) + 2 + self.print(" " * indent_size) + message_end_pos = max(map(len, message.split("\n"))) + indent_size + 2 first_line = (begin_pos - indent_size) * "-" + ( - end_pos - begin_pos + 1) * "^" + (max(message_end_pos, end_pos) - - end_pos) * "-" + end_pos - begin_pos) * "^" + (max(message_end_pos, end_pos) - + end_pos) * "-" self.print(first_line) self._print_new_line(indent=indent, print_message=False) - message_lines = message.split("\n") - for message_line in message_lines: + for message_line in message.split("\n"): self.print("| ") self.print(message_line) self._print_new_line(indent=indent, print_message=False) - self.print("-" * (max(message_end_pos, end_pos) - indent_size + 1)) + self.print("-" * (max(message_end_pos, end_pos) - indent_size)) self._print_new_line(indent=0, print_message=False) T = TypeVar('T') @@ -103,10 +103,9 @@ def _print_new_line(self, indent=None, print_message=True) -> None: indent = self._indent if indent is None else indent self.print("\n") if print_message: - while len(self._next_line_callback) != 0: - callback = self._next_line_callback[0] - self._next_line_callback = self._next_line_callback[1:] + for callback in self._next_line_callback: callback() + self._next_line_callback = [] self.print(" " * indent * indentNumSpaces) def _get_new_valid_name_id(self) -> str: @@ -156,10 +155,15 @@ def _print_results(self, op: Operation) -> None: self.print(") = ") def print_ssa_value(self, value: SSAValue) -> None: - if (self._ssa_values.get(value) == None): - raise KeyError("SSAValue is not part of the IR, are you sure" - " all operations are added before their uses?") - self.print(f"%{self._ssa_values[value]}") + if ssa_val := self._ssa_values.get(value): + self.print(f"%{ssa_val}") + else: + begin_pos = self._current_column + self.print("%") + end_pos = self._current_column + self._add_message_on_next_line( + "ERROR: SSAValue is not part of the IR, are you sure all operations are added before their uses?", + begin_pos, end_pos) def _print_operand(self, operand: SSAValue) -> None: self.print_ssa_value(operand) @@ -314,7 +318,7 @@ def _print_op(self, op: Operation) -> None: self.print(f'"{op.name}"') else: self.print(op.name) - end_op_pos = self._current_column - 1 + end_op_pos = self._current_column if op in self.diagnostic.op_messages: for message in self.diagnostic.op_messages[op]: self._add_message_on_next_line(message, begin_op_pos, diff --git a/tests/printer_test.py b/tests/printer_test.py index 62ce946597..540a39598a 100644 --- a/tests/printer_test.py +++ b/tests/printer_test.py @@ -4,12 +4,12 @@ from xdsl.printer import Printer from xdsl.parser import Parser -from xdsl.dialects.builtin import Builtin +from xdsl.dialects.builtin import Builtin, ModuleOp from xdsl.dialects.arith import * from xdsl.diagnostic import Diagnostic -def test_forgotten_op(): +def test_simple_forgotten_op(): """Test that the parsing of an undefined operand raises an exception.""" ctx = MLContext() arith = Arith(ctx) @@ -18,13 +18,55 @@ def test_forgotten_op(): add = Addi.get(lit, lit) add.verify() - try: - printer = Printer() - printer.print_op(add) - except KeyError: - return - assert False, "Exception expected" + expected = \ +""" +%0 : !i32 = arith.addi(% : !i32, % : !i32) +-----------------------^^^^^^^^^^---------------------------------------------------------------- +| ERROR: SSAValue is not part of the IR, are you sure all operations are added before their uses? +------------------------------------------------------------------------------------------------- +------------------------------------------^^^^^^^^^^--------------------------------------------- +| ERROR: SSAValue is not part of the IR, are you sure all operations are added before their uses? +------------------------------------------------------------------------------------------------- +""" + + file = StringIO("") + printer = Printer(stream=file) + printer.print_op(add) + + assert file.getvalue().strip() == expected.strip() + + +def test_forgotten_op_non_fail(): + """Test that the parsing of an undefined operand raises an exception.""" + ctx = MLContext() + arith = Arith(ctx) + + lit = Constant.from_int_constant(42, 32) + add = Addi.get(lit, lit) + add2 = Addi.get(add, add) + mod = ModuleOp.from_region_or_ops([add, add2]) + mod.verify() + + expected = \ +""" +module() { + %0 : !i32 = arith.addi(% : !i32, % : !i32) + -----------------------^^^^^^^^^^---------------------------------------------------------------- + | ERROR: SSAValue is not part of the IR, are you sure all operations are added before their uses? + ------------------------------------------------------------------------------------------------- + ------------------------------------------^^^^^^^^^^--------------------------------------------- + | ERROR: SSAValue is not part of the IR, are you sure all operations are added before their uses? + ------------------------------------------------------------------------------------------------- + %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) +} +""" + + file = StringIO("") + printer = Printer(stream=file) + printer.print_op(mod) + + assert file.getvalue().strip() == expected.strip() # ____ _ _ _ @@ -45,13 +87,15 @@ def test_op_message(): }""" expected = \ -"""module() { +""" +module() { %0 : !i32 = arith.constant() ["value" = 42 : !i32] ^^^^^^^^^^^^^^^^^^^^^^^^^^ | Test message -------------------------- %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) -}""" +} +""" ctx = MLContext() arith = Arith(ctx) @@ -151,9 +195,9 @@ def test_op_message_with_region(): expected = \ """\ module() { -^^^^^^- +^^^^^^ | Test -------- +------ %0 : !i32 = arith.constant() ["value" = 42 : !i32] %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) }""" @@ -187,9 +231,9 @@ def test_op_message_with_region_and_overflow(): expected = \ """\ module() { -^^^^^^--------- +^^^^^^-------- | Test message ---------------- +-------------- %0 : !i32 = arith.constant() ["value" = 42 : !i32] %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) }""" @@ -404,4 +448,4 @@ def test_custom_format(): file = StringIO("") printer = Printer(stream=file, print_generic_format=True) printer.print_op(module) - assert file.getvalue().strip() == expected.strip() \ No newline at end of file + assert file.getvalue().strip() == expected.strip()