Skip to content

Commit

Permalink
Printer: add support to print types in MLIR style
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiasgrosser committed Jul 12, 2022
1 parent 3818617 commit db13ce5
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
33 changes: 29 additions & 4 deletions src/xdsl/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,23 @@
from xdsl.diagnostic import *
from typing import TypeVar
from dataclasses import dataclass
from enum import Enum

indentNumSpaces = 2


@dataclass(eq=False, repr=False)
class Printer:

class TypeLocation(Enum):
NONE = 1
INLINE = 2
AFTER = 3

stream: Optional[Any] = field(default=None)
print_generic_format: bool = field(default=False)
print_operand_types: bool = field(default=True)
print_result_types: bool = field(default=True)
print_operand_types: TypeLocation = field(default=TypeLocation.INLINE)
print_result_types: TypeLocation = field(default=TypeLocation.INLINE)
diagnostic: Diagnostic = field(default_factory=Diagnostic)
_indent: int = field(default=0, init=False)
_ssa_values: Dict[SSAValue, str] = field(default_factory=dict, init=False)
Expand Down Expand Up @@ -130,7 +136,7 @@ def _print_result_value(self, op: Operation, idx: int) -> None:
name = self._get_new_valid_name_id()
self._ssa_values[val] = name
self.print("%s" % name)
if self.print_result_types:
if self.print_result_types == Printer.TypeLocation.INLINE:
self.print(" : ")
self.print_attribute(val.typ)

Expand Down Expand Up @@ -168,7 +174,7 @@ def print_ssa_value(self, value: SSAValue) -> None:
def _print_operand(self, operand: SSAValue) -> None:
self.print_ssa_value(operand)

if self.print_operand_types:
if self.print_operand_types == Printer.TypeLocation.INLINE:
self.print(" : ")
self.print_attribute(operand.typ)

Expand Down Expand Up @@ -311,6 +317,25 @@ def print_op_with_default_format(self, op: Operation) -> None:
self._print_op_attributes(op.attributes)
self.print_regions(op.regions)

if self.print_operand_types == Printer.TypeLocation.AFTER and self.print_result_types == Printer.TypeLocation.AFTER:
self.print(" : (")
first = True
for operand in op.operands:
if first:
first = False
else:
self.print(", ")
self.print_attribute(operand.typ)
self.print(") -> (")
first = True
for result in op.results:
if first:
first = False
else:
self.print(", ")
self.print_attribute(result.typ)
self.print(")")

def _print_op(self, op: Operation) -> None:
begin_op_pos = self._current_column
self._print_results(op)
Expand Down
7 changes: 7 additions & 0 deletions src/xdsl/xdsl_opt_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,19 @@ def _output_xdsl(prog: ModuleOp, output: IOBase):
printer = Printer(stream=output)
printer.print_op(prog)

def _output_xdsl_mlir(prog: ModuleOp, output: IOBase):
printer = Printer(stream=output,
print_operand_types=Printer.TypeLocation.AFTER,
print_result_types=Printer.TypeLocation.AFTER)
printer.print_op(prog)

def _output_mlir(prog: ModuleOp, output: IOBase):
converter = MLIRConverter(self.ctx)
mlir_module = converter.convert_module(prog)
print(mlir_module, file=output)

self.available_targets['xdsl'] = _output_xdsl
self.available_targets['xdsl-mlir'] = _output_xdsl_mlir
try:
from xdsl.mlir_converter import MLIRConverter
self.available_targets['mlir'] = _output_mlir
Expand Down

0 comments on commit db13ce5

Please sign in to comment.