diff --git a/src/xdsl/printer.py b/src/xdsl/printer.py index ac89040b47..43c1755a1f 100644 --- a/src/xdsl/printer.py +++ b/src/xdsl/printer.py @@ -3,6 +3,7 @@ from xdsl.diagnostic import * from typing import TypeVar from dataclasses import dataclass +from enum import Enum indentNumSpaces = 2 @@ -10,10 +11,15 @@ @dataclass(eq=False, repr=False) class Printer: + class TypeLocation(Enum): + NONE = 1 + INLINE = 2 + AFTER = 3 + stream: Optional[Any] = field(default=None) print_generic_format: bool = field(default=False) - print_operand_types: bool = field(default=True) - print_result_types: bool = field(default=True) + print_operand_types: TypeLocation = field(default=TypeLocation.INLINE) + print_result_types: TypeLocation = field(default=TypeLocation.INLINE) diagnostic: Diagnostic = field(default_factory=Diagnostic) _indent: int = field(default=0, init=False) _ssa_values: Dict[SSAValue, str] = field(default_factory=dict, init=False) @@ -130,7 +136,7 @@ def _print_result_value(self, op: Operation, idx: int) -> None: name = self._get_new_valid_name_id() self._ssa_values[val] = name self.print("%s" % name) - if self.print_result_types: + if self.print_result_types == Printer.TypeLocation.INLINE: self.print(" : ") self.print_attribute(val.typ) @@ -168,7 +174,7 @@ def print_ssa_value(self, value: SSAValue) -> None: def _print_operand(self, operand: SSAValue) -> None: self.print_ssa_value(operand) - if self.print_operand_types: + if self.print_operand_types == Printer.TypeLocation.INLINE: self.print(" : ") self.print_attribute(operand.typ) @@ -311,6 +317,25 @@ def print_op_with_default_format(self, op: Operation) -> None: self._print_op_attributes(op.attributes) self.print_regions(op.regions) + if self.print_operand_types == Printer.TypeLocation.AFTER and self.print_result_types == Printer.TypeLocation.AFTER: + self.print(" : (") + first = True + for operand in op.operands: + if first: + first = False + else: + self.print(", ") + self.print_attribute(operand.typ) + self.print(") -> (") + first = True + for result in op.results: + if first: + first = False + else: + self.print(", ") + self.print_attribute(result.typ) + self.print(")") + def _print_op(self, op: Operation) -> None: begin_op_pos = self._current_column self._print_results(op) diff --git a/src/xdsl/xdsl_opt_main.py b/src/xdsl/xdsl_opt_main.py index 2cd1a6597c..df8847103b 100644 --- a/src/xdsl/xdsl_opt_main.py +++ b/src/xdsl/xdsl_opt_main.py @@ -187,12 +187,19 @@ def _output_xdsl(prog: ModuleOp, output: IOBase): printer = Printer(stream=output) printer.print_op(prog) + def _output_xdsl_mlir(prog: ModuleOp, output: IOBase): + printer = Printer(stream=output, + print_operand_types=Printer.TypeLocation.AFTER, + print_result_types=Printer.TypeLocation.AFTER) + printer.print_op(prog) + def _output_mlir(prog: ModuleOp, output: IOBase): converter = MLIRConverter(self.ctx) mlir_module = converter.convert_module(prog) print(mlir_module, file=output) self.available_targets['xdsl'] = _output_xdsl + self.available_targets['xdsl-mlir'] = _output_xdsl_mlir try: from xdsl.mlir_converter import MLIRConverter self.available_targets['mlir'] = _output_mlir