Skip to content

Commit

Permalink
Printer: print <UNKNOWN> instead of abort (#119)
Browse files Browse the repository at this point in the history
  • Loading branch information
webmiche authored Jun 23, 2022
1 parent 50ab2fc commit 88a8241
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 30 deletions.
34 changes: 19 additions & 15 deletions src/xdsl/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,22 +70,22 @@ def _print_message(self,
"""
Print a message.
This is expected to be called on the beginning of a new line, and expect to create a new line at the end.
[begin_pos, end_pos)
"""
indent = self._indent if indent is None else indent
self.print(" " * indent * indentNumSpaces)
indent_size = indent * indentNumSpaces
message_end_pos = max(len(line) for line in message.split("\n")) + 2
self.print(" " * indent_size)
message_end_pos = max(map(len, message.split("\n"))) + indent_size + 2
first_line = (begin_pos - indent_size) * "-" + (
end_pos - begin_pos + 1) * "^" + (max(message_end_pos, end_pos) -
end_pos) * "-"
end_pos - begin_pos) * "^" + (max(message_end_pos, end_pos) -
end_pos) * "-"
self.print(first_line)
self._print_new_line(indent=indent, print_message=False)
message_lines = message.split("\n")
for message_line in message_lines:
for message_line in message.split("\n"):
self.print("| ")
self.print(message_line)
self._print_new_line(indent=indent, print_message=False)
self.print("-" * (max(message_end_pos, end_pos) - indent_size + 1))
self.print("-" * (max(message_end_pos, end_pos) - indent_size))
self._print_new_line(indent=0, print_message=False)

T = TypeVar('T')
Expand All @@ -103,10 +103,9 @@ def _print_new_line(self, indent=None, print_message=True) -> None:
indent = self._indent if indent is None else indent
self.print("\n")
if print_message:
while len(self._next_line_callback) != 0:
callback = self._next_line_callback[0]
self._next_line_callback = self._next_line_callback[1:]
for callback in self._next_line_callback:
callback()
self._next_line_callback = []
self.print(" " * indent * indentNumSpaces)

def _get_new_valid_name_id(self) -> str:
Expand Down Expand Up @@ -156,10 +155,15 @@ def _print_results(self, op: Operation) -> None:
self.print(") = ")

def print_ssa_value(self, value: SSAValue) -> None:
if (self._ssa_values.get(value) == None):
raise KeyError("SSAValue is not part of the IR, are you sure"
" all operations are added before their uses?")
self.print(f"%{self._ssa_values[value]}")
if ssa_val := self._ssa_values.get(value):
self.print(f"%{ssa_val}")
else:
begin_pos = self._current_column
self.print("%<UNKNOWN>")
end_pos = self._current_column
self._add_message_on_next_line(
"ERROR: SSAValue is not part of the IR, are you sure all operations are added before their uses?",
begin_pos, end_pos)

def _print_operand(self, operand: SSAValue) -> None:
self.print_ssa_value(operand)
Expand Down Expand Up @@ -314,7 +318,7 @@ def _print_op(self, op: Operation) -> None:
self.print(f'"{op.name}"')
else:
self.print(op.name)
end_op_pos = self._current_column - 1
end_op_pos = self._current_column
if op in self.diagnostic.op_messages:
for message in self.diagnostic.op_messages[op]:
self._add_message_on_next_line(message, begin_op_pos,
Expand Down
74 changes: 59 additions & 15 deletions tests/printer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

from xdsl.printer import Printer
from xdsl.parser import Parser
from xdsl.dialects.builtin import Builtin
from xdsl.dialects.builtin import Builtin, ModuleOp
from xdsl.dialects.arith import *
from xdsl.diagnostic import Diagnostic


def test_forgotten_op():
def test_simple_forgotten_op():
"""Test that the parsing of an undefined operand raises an exception."""
ctx = MLContext()
arith = Arith(ctx)
Expand All @@ -18,13 +18,55 @@ def test_forgotten_op():
add = Addi.get(lit, lit)

add.verify()
try:
printer = Printer()
printer.print_op(add)
except KeyError:
return

assert False, "Exception expected"
expected = \
"""
%0 : !i32 = arith.addi(%<UNKNOWN> : !i32, %<UNKNOWN> : !i32)
-----------------------^^^^^^^^^^----------------------------------------------------------------
| ERROR: SSAValue is not part of the IR, are you sure all operations are added before their uses?
-------------------------------------------------------------------------------------------------
------------------------------------------^^^^^^^^^^---------------------------------------------
| ERROR: SSAValue is not part of the IR, are you sure all operations are added before their uses?
-------------------------------------------------------------------------------------------------
"""

file = StringIO("")
printer = Printer(stream=file)
printer.print_op(add)

assert file.getvalue().strip() == expected.strip()


def test_forgotten_op_non_fail():
"""Test that the parsing of an undefined operand raises an exception."""
ctx = MLContext()
arith = Arith(ctx)

lit = Constant.from_int_constant(42, 32)
add = Addi.get(lit, lit)
add2 = Addi.get(add, add)
mod = ModuleOp.from_region_or_ops([add, add2])
mod.verify()

expected = \
"""
module() {
%0 : !i32 = arith.addi(%<UNKNOWN> : !i32, %<UNKNOWN> : !i32)
-----------------------^^^^^^^^^^----------------------------------------------------------------
| ERROR: SSAValue is not part of the IR, are you sure all operations are added before their uses?
-------------------------------------------------------------------------------------------------
------------------------------------------^^^^^^^^^^---------------------------------------------
| ERROR: SSAValue is not part of the IR, are you sure all operations are added before their uses?
-------------------------------------------------------------------------------------------------
%1 : !i32 = arith.addi(%0 : !i32, %0 : !i32)
}
"""

file = StringIO("")
printer = Printer(stream=file)
printer.print_op(mod)

assert file.getvalue().strip() == expected.strip()


# ____ _ _ _
Expand All @@ -45,13 +87,15 @@ def test_op_message():
}"""

expected = \
"""module() {
"""
module() {
%0 : !i32 = arith.constant() ["value" = 42 : !i32]
^^^^^^^^^^^^^^^^^^^^^^^^^^
| Test message
--------------------------
%1 : !i32 = arith.addi(%0 : !i32, %0 : !i32)
}"""
}
"""

ctx = MLContext()
arith = Arith(ctx)
Expand Down Expand Up @@ -151,9 +195,9 @@ def test_op_message_with_region():
expected = \
"""\
module() {
^^^^^^-
^^^^^^
| Test
-------
------
%0 : !i32 = arith.constant() ["value" = 42 : !i32]
%1 : !i32 = arith.addi(%0 : !i32, %0 : !i32)
}"""
Expand Down Expand Up @@ -187,9 +231,9 @@ def test_op_message_with_region_and_overflow():
expected = \
"""\
module() {
^^^^^^---------
^^^^^^--------
| Test message
---------------
--------------
%0 : !i32 = arith.constant() ["value" = 42 : !i32]
%1 : !i32 = arith.addi(%0 : !i32, %0 : !i32)
}"""
Expand Down Expand Up @@ -404,4 +448,4 @@ def test_custom_format():
file = StringIO("")
printer = Printer(stream=file, print_generic_format=True)
printer.print_op(module)
assert file.getvalue().strip() == expected.strip()
assert file.getvalue().strip() == expected.strip()

0 comments on commit 88a8241

Please sign in to comment.