From b09e94ec178b82029c7e7542805fdb41fc7cd1a5 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Mon, 23 Jan 2023 18:28:48 +0000 Subject: [PATCH] xDSL: Rewrite Parser (#262) --- .../parser-printer/float_parsing.xdsl | 8 +- tests/test_attribute_builder.py | 12 +- tests/test_attribute_definition.py | 27 +- tests/test_ir.py | 10 +- tests/test_irdl.py | 6 +- tests/test_mlir_converter.py | 4 +- tests/test_mlir_printer.py | 19 +- tests/test_parser.py | 58 +- tests/test_parser_error.py | 41 +- tests/test_pattern_rewriter.py | 2 + tests/test_printer.py | 75 +- tests/test_ssa_value.py | 32 + tests/xdsl_opt/test_xdsl_opt.py | 15 +- xdsl/dialects/builtin.py | 71 +- xdsl/dialects/irdl.py | 4 +- xdsl/dialects/llvm.py | 13 +- xdsl/ir.py | 29 +- xdsl/parser.py | 2919 ++++++++++------- xdsl/printer.py | 3 +- xdsl/tools/xdsl-opt | 21 +- xdsl/utils/exceptions.py | 67 +- xdsl/xdsl_opt_main.py | 37 +- 22 files changed, 2061 insertions(+), 1412 deletions(-) create mode 100644 tests/test_ssa_value.py diff --git a/tests/filecheck/parser-printer/float_parsing.xdsl b/tests/filecheck/parser-printer/float_parsing.xdsl index 1810fffb41..cf08385de1 100644 --- a/tests/filecheck/parser-printer/float_parsing.xdsl +++ b/tests/filecheck/parser-printer/float_parsing.xdsl @@ -8,16 +8,16 @@ builtin.module() { %1 : !f32 = arith.constant() ["value" = -42.0 : !f32] // CHECK-NEXT: %{{.*}} : !f32 = arith.constant() ["value" = -42.0 : !f32] - %2 : !f32 = arith.constant() ["value" = 34e0 : !f32] + %2 : !f32 = arith.constant() ["value" = 34.e0 : !f32] // CHECK-NEXT: %{{.*}} : !f32 = arith.constant() ["value" = 34.0 : !f32] - %3 : !f32 = arith.constant() ["value" = 34e-23 : !f32] + %3 : !f32 = arith.constant() ["value" = 34.e-23 : !f32] // CHECK-NEXT: %{{.*}} : !f32 = arith.constant() ["value" = 3.4e-22 : !f32] - %4 : !f32 = arith.constant() ["value" = 34e12 : !f32] + %4 : !f32 = arith.constant() ["value" = 34.e12 : !f32] // CHECK-NEXT: %{{.*}} : !f32 = arith.constant() ["value" = 34000000000000.0 : !f32] - %5 
: !f32 = arith.constant() ["value" = -34e-12 : !f32] + %5 : !f32 = arith.constant() ["value" = -34.e-12 : !f32] // CHECK-NEXT: %{{.*}} : !f32 = arith.constant() ["value" = -3.4e-11 : !f32] func.return() diff --git a/tests/test_attribute_builder.py b/tests/test_attribute_builder.py index f89794fffe..cda2add81b 100644 --- a/tests/test_attribute_builder.py +++ b/tests/test_attribute_builder.py @@ -4,7 +4,7 @@ from xdsl.ir import ParametrizedAttribute, Data from xdsl.irdl import irdl_attr_definition, builder -from xdsl.parser import Parser +from xdsl.parser import BaseParser from xdsl.printer import Printer from xdsl.utils.exceptions import BuilderNotFoundException @@ -34,7 +34,7 @@ def from_int(data: int) -> OneBuilderAttr: return OneBuilderAttr(str(data)) @staticmethod - def parse_parameter(parser: Parser) -> str: + def parse_parameter(parser: BaseParser) -> str: raise NotImplementedError() @staticmethod @@ -72,7 +72,7 @@ def from_int(data1: int, data2: str) -> OneBuilderAttr: return OneBuilderAttr(str(data1) + data2) @staticmethod - def parse_parameter(parser: Parser) -> str: + def parse_parameter(parser: BaseParser) -> str: raise NotImplementedError() @staticmethod @@ -105,7 +105,7 @@ def from_str(s: str) -> TwoBuildersAttr: return TwoBuildersAttr(s) @staticmethod - def parse_parameter(parser: Parser) -> str: + def parse_parameter(parser: BaseParser) -> str: raise NotImplementedError() @staticmethod @@ -145,7 +145,7 @@ def from_int(data1: int, return BuilderDefaultArgAttr(f"{data1}, {data2}, {data3}") @staticmethod - def parse_parameter(parser: Parser) -> str: + def parse_parameter(parser: BaseParser) -> str: raise NotImplementedError() @staticmethod @@ -188,7 +188,7 @@ def from_int(data: str | int) -> BuilderUnionArgAttr: return BuilderUnionArgAttr(str(data)) @staticmethod - def parse_parameter(parser: Parser) -> str: + def parse_parameter(parser: BaseParser) -> str: raise NotImplementedError() @staticmethod diff --git a/tests/test_attribute_definition.py 
b/tests/test_attribute_definition.py index cb78352cb9..f7436c75ed 100644 --- a/tests/test_attribute_definition.py +++ b/tests/test_attribute_definition.py @@ -13,7 +13,7 @@ from xdsl.irdl import (AttrConstraint, GenericData, ParameterDef, irdl_attr_definition, builder, irdl_to_attr_constraint, AnyAttr, BaseAttr, ParamAttrDef) -from xdsl.parser import Parser +from xdsl.parser import BaseParser from xdsl.printer import Printer from xdsl.utils.exceptions import VerifyException @@ -31,14 +31,13 @@ class BoolData(Data[bool]): name = "bool" @staticmethod - def parse_parameter(parser: Parser) -> bool: - val = parser.parse_optional_ident() - if val == "True": + def parse_parameter(parser: BaseParser) -> bool: + val = parser.tokenizer.next_token_of_pattern('(True|False)') + if val is None or val.text not in ('True', 'False'): + parser.raise_error("Expected True or False literal") + if val.text == "True": return True - elif val == "False": - return False - else: - raise Exception("Wrong argument passed to BoolAttr.") + return False @staticmethod def print_parameter(data: bool, printer: Printer): @@ -51,7 +50,7 @@ class IntData(Data[int]): name = "int" @staticmethod - def parse_parameter(parser: Parser) -> int: + def parse_parameter(parser: BaseParser) -> int: return parser.parse_int_literal() @staticmethod @@ -65,7 +64,7 @@ class StringData(Data[str]): name = "str" @staticmethod - def parse_parameter(parser: Parser) -> str: + def parse_parameter(parser: BaseParser) -> str: return parser.parse_str_literal() @staticmethod @@ -102,7 +101,7 @@ class IntListMissingVerifierData(Data[list[int]]): name = "missing_verifier_data" @staticmethod - def parse_parameter(parser: Parser) -> list[int]: + def parse_parameter(parser: BaseParser) -> list[int]: raise NotImplementedError() @staticmethod @@ -134,7 +133,7 @@ class IntListData(Data[list[int]]): name = "int_list" @staticmethod - def parse_parameter(parser: Parser) -> list[int]: + def parse_parameter(parser: BaseParser) -> list[int]: 
raise NotImplementedError() @staticmethod @@ -431,7 +430,7 @@ class MissingGenericDataData(Data[_MissingGenericDataData]): name = "missing_genericdata" @staticmethod - def parse_parameter(parser: Parser) -> _MissingGenericDataData: + def parse_parameter(parser: BaseParser) -> _MissingGenericDataData: raise NotImplementedError() @staticmethod @@ -484,7 +483,7 @@ class ListData(GenericData[list[A]]): name = "list" @staticmethod - def parse_parameter(parser: Parser) -> list[A]: + def parse_parameter(parser: BaseParser) -> list[A]: raise NotImplementedError() @staticmethod diff --git a/tests/test_ir.py b/tests/test_ir.py index 1b690e4ecf..47fbe88eaa 100644 --- a/tests/test_ir.py +++ b/tests/test_ir.py @@ -4,7 +4,7 @@ from xdsl.dialects.arith import Addi, Subi, Constant from xdsl.dialects.builtin import i32, IntegerAttr, ModuleOp from xdsl.dialects.scf import If -from xdsl.parser import Parser +from xdsl.parser import XDSLParser from xdsl.dialects.builtin import Builtin from xdsl.dialects.func import Func from xdsl.dialects.arith import Arith @@ -203,10 +203,10 @@ def test_is_structurally_equivalent(args: list[str], expected_result: bool): ctx.register_dialect(Arith) ctx.register_dialect(Cf) - parser = Parser(ctx, args[0]) + parser = XDSLParser(ctx, args[0]) lhs: Operation = parser.parse_op() - parser = Parser(ctx, args[1]) + parser = XDSLParser(ctx, args[1]) rhs: Operation = parser.parse_op() assert lhs.is_structurally_equivalent(rhs) == expected_result @@ -231,8 +231,8 @@ def test_is_structurally_equivalent_incompatible_ir_nodes(): ctx.register_dialect(Arith) ctx.register_dialect(Cf) - parser = Parser(ctx, program_func) - program: ModuleOp = parser.parse_op() + parser = XDSLParser(ctx, program_func) + program: ModuleOp = parser.parse_operation() assert program.is_structurally_equivalent(program.regions[0]) == False assert program.is_structurally_equivalent( diff --git a/tests/test_irdl.py b/tests/test_irdl.py index c28c4b2b1a..9bd1248a2f 100644 --- 
a/tests/test_irdl.py +++ b/tests/test_irdl.py @@ -5,7 +5,7 @@ from xdsl.ir import Attribute, Data, ParametrizedAttribute from xdsl.irdl import AllOf, AnyAttr, AnyOf, AttrConstraint, BaseAttr, EqAttrConstraint, ParamAttrConstraint, ParameterDef, irdl_attr_definition -from xdsl.parser import Parser +from xdsl.parser import BaseParser from xdsl.printer import Printer from xdsl.utils.exceptions import VerifyException @@ -16,7 +16,7 @@ class BoolData(Data[bool]): name = "bool" @staticmethod - def parse_parameter(parser: Parser) -> bool: + def parse_parameter(parser: BaseParser) -> bool: raise NotImplementedError() @staticmethod @@ -30,7 +30,7 @@ class IntData(Data[int]): name = "int" @staticmethod - def parse_parameter(parser: Parser) -> int: + def parse_parameter(parser: BaseParser) -> int: return parser.parse_int_literal() @staticmethod diff --git a/tests/test_mlir_converter.py b/tests/test_mlir_converter.py index c68f609c59..465dfb51f6 100644 --- a/tests/test_mlir_converter.py +++ b/tests/test_mlir_converter.py @@ -9,7 +9,7 @@ from xdsl.dialects.affine import Affine from xdsl.dialects.arith import Arith -from xdsl.parser import Parser +from xdsl.parser import XDSLParser from xdsl.ir import MLContext from xdsl.dialects.builtin import Builtin @@ -23,7 +23,7 @@ def convert_and_verify(test_prog: str): ctx.register_dialect(Scf) ctx.register_dialect(MemRef) - parser = Parser(ctx, test_prog) + parser = XDSLParser(ctx, test_prog) module = parser.parse_op() module.verify() diff --git a/tests/test_mlir_printer.py b/tests/test_mlir_printer.py index 38aed4e260..0fd69d769a 100644 --- a/tests/test_mlir_printer.py +++ b/tests/test_mlir_printer.py @@ -1,14 +1,11 @@ +import re from io import StringIO from typing import Annotated -import re -from xdsl.dialects.builtin import Builtin -from xdsl.dialects.memref import MemRef -from xdsl.dialects.func import Func from xdsl.ir import Attribute, Data, MLContext, MLIRType, Operation, ParametrizedAttribute -from xdsl.irdl import (AnyAttr, 
ParameterDef, RegionDef, VarOpResult, - VarOperand, irdl_attr_definition, irdl_op_definition) -from xdsl.parser import Parser +from xdsl.irdl import (AnyAttr, ParameterDef, RegionDef, irdl_attr_definition, + irdl_op_definition, VarOperand, VarOpResult) +from xdsl.parser import BaseParser, XDSLParser from xdsl.printer import Printer @@ -33,7 +30,7 @@ class DataAttr(Data[int]): name = "data_attr" @staticmethod - def parse_parameter(parser: Parser) -> int: + def parse_parameter(parser: BaseParser) -> int: return parser.parse_int_literal() @staticmethod @@ -47,7 +44,7 @@ class DataType(Data[int], MLIRType): name = "data_type" @staticmethod - def parse_parameter(parser: Parser) -> int: + def parse_parameter(parser: BaseParser) -> int: return parser.parse_int_literal() @staticmethod @@ -92,8 +89,8 @@ def print_as_mlir_and_compare(test_prog: str, expected: str): ctx.register_attr(ParamAttrWithParam) ctx.register_attr(ParamAttrWithCustomFormat) - parser = Parser(ctx, test_prog) - module = parser.parse_op() + parser = XDSLParser(ctx, test_prog) + module = parser.parse_operation() res = StringIO() printer = Printer(target=Printer.Target.MLIR, stream=res) diff --git a/tests/test_parser.py b/tests/test_parser.py index 8e3c90707d..7d3d186c5c 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1,7 +1,11 @@ +from io import StringIO + import pytest -from xdsl.ir import MLContext -from xdsl.parser import Parser +from xdsl.printer import Printer +from xdsl.ir import MLContext, Attribute +from xdsl.parser import XDSLParser +from xdsl.dialects.builtin import IntAttr, DictionaryAttr, StringAttr, ArrayAttr, Builtin @pytest.mark.parametrize("input,expected", [("0, 1, 1", [0, 1, 1]), @@ -9,29 +13,31 @@ ("1, 1, 0", [1, 1, 0])]) def test_int_list_parser(input: str, expected: list[int]): ctx = MLContext() - parser = Parser(ctx, input) - - int_list = parser.parse_list(parser.parse_int_literal) - assert int_list == expected - - -@pytest.mark.parametrize("input,expected", 
[('{"A"=0, "B"=1, "C"=2}', { - "A": 0, - "B": 1, - "C": 2 -}), ('{"MA"=10, "BR"=7, "Z"=3}', { - "MA": 10, - "BR": 7, - "Z": 3 -}), ('{"Q"=77, "VV"=12, "AA"=-8}', { - "Q": 77, - "VV": 12, - "AA": -8 -})]) -def test_int_dictionary_parser(input: str, expected: dict[str, int]): + parser = XDSLParser(ctx, input) + + int_list = parser.parse_list_of(parser.try_parse_integer_literal, '') + assert [int(span.text) for span in int_list] == expected + + +@pytest.mark.parametrize('data', [ + dict(a=IntAttr.from_int(1), b=IntAttr.from_int(2), c=IntAttr.from_int(3)), + dict(a=StringAttr.from_str('hello'), + b=IntAttr.from_int(2), + c=ArrayAttr.from_list( + [IntAttr.from_int(2), + StringAttr.from_str('world')])), + dict(), +]) +def test_dictionary_attr(data: dict[str, Attribute]): + attr = DictionaryAttr.from_dict(data) + + with StringIO() as io: + Printer(io).print(attr) + text = io.getvalue() + ctx = MLContext() - parser = Parser(ctx, input) + ctx.register_dialect(Builtin) + + attr = XDSLParser(ctx, text).parse_attribute() - int_dict = parser.parse_dictionary(parser.parse_str_literal, - parser.parse_int_literal) - assert int_dict == expected + assert attr.data == data diff --git a/tests/test_parser_error.py b/tests/test_parser_error.py index 46dab27b1e..de67c51b8f 100644 --- a/tests/test_parser_error.py +++ b/tests/test_parser_error.py @@ -1,11 +1,12 @@ -from __future__ import annotations from typing import Annotated -from xdsl.ir import MLContext -from xdsl.irdl import AnyAttr, VarOpResult, VarOperand, irdl_op_definition, Operation -from xdsl.parser import Parser, ParserError from pytest import raises +from xdsl.ir import MLContext +from xdsl.irdl import AnyAttr, irdl_op_definition, Operation, VarOperand, VarOpResult +from xdsl.parser import XDSLParser +from xdsl.utils.exceptions import ParseError + @irdl_op_definition class UnkownOp(Operation): @@ -18,14 +19,19 @@ def check_error(prog: str, line: int, column: int, message: str): ctx = MLContext() ctx.register_op(UnkownOp) - 
parser = Parser(ctx, prog) - with raises(ParserError) as e: - parser.parse_op() + parser = XDSLParser(ctx, prog) + with raises(ParseError) as e: + parser.parse_operation() + + assert e.value.span - assert e.value.pos - assert e.value.pos.line is line - assert e.value.pos.column is column - assert e.value.message == message + for err in e.value.history.iterate(): + if message in err.error.msg: + assert err.error.span.get_line_col() == (line, column) + break + else: + assert False, "'{}' not found in an error message {}!".format( + message, e.value.args) def test_parser_missing_equal(): @@ -39,7 +45,8 @@ def test_parser_missing_equal(): %0 : !i32 unknown() } """ - check_error(prog, 3, 13, "'=' expected, got 'u'") + check_error(prog, 3, 12, + "Operation definitions expect an `=` after op-result-list!") def test_parser_redefined_value(): @@ -54,7 +61,7 @@ def test_parser_redefined_value(): %val : !i32 = unknown() } """ - check_error(prog, 4, 3, "SSA value val is already defined") + check_error(prog, 4, 2, "SSA value %val is already defined") def test_parser_missing_operation_name(): @@ -68,11 +75,11 @@ def test_parser_missing_operation_name(): %val : !i32 = } """ - check_error(prog, 4, 1, "operation name expected") + check_error(prog, 4, 0, "Expected an operation name here") -def test_parser_missing_attribute(): - """Test a missing attribute error.""" +def test_parser_malformed_type(): + """Test a missing type error.""" ctx = MLContext() ctx.register_op(UnkownOp) @@ -82,4 +89,4 @@ def test_parser_missing_attribute(): %val : i32 = unknown() } """ - check_error(prog, 3, 10, "attribute expected") + check_error(prog, 3, 9, "Expected type of value-id here!") diff --git a/tests/test_pattern_rewriter.py b/tests/test_pattern_rewriter.py index 510e0929e9..d34ba36726 100644 --- a/tests/test_pattern_rewriter.py +++ b/tests/test_pattern_rewriter.py @@ -22,6 +22,8 @@ def rewrite_and_compare(prog: str, expected_prog: str, parser = Parser(ctx, prog) module = parser.parse_op() + 
assert isinstance(module, ModuleOp) + walker.rewrite_module(module) file = StringIO("") printer = Printer(stream=file) diff --git a/tests/test_printer.py b/tests/test_printer.py index 93f2ec7e5a..3e2b9b5d6a 100644 --- a/tests/test_printer.py +++ b/tests/test_printer.py @@ -1,17 +1,17 @@ from __future__ import annotations +import re from io import StringIO from typing import List, Annotated -from xdsl.dialects.func import Func, FuncOp -from xdsl.dialects.builtin import Builtin, IntAttr, ModuleOp, IntegerType, UnitAttr from xdsl.dialects.arith import Arith, Addi, Constant - -from xdsl.ir import Attribute, MLContext, OpResult, ParametrizedAttribute +from xdsl.dialects.builtin import Builtin, IntAttr, ModuleOp, IntegerType, UnitAttr +from xdsl.dialects.func import Func +from xdsl.ir import Attribute, MLContext, OpResult, ParametrizedAttribute, SSAValue from xdsl.irdl import (ParameterDef, irdl_attr_definition, irdl_op_definition, Operation, Operand, OptAttributeDef) +from xdsl.parser import Parser, BaseParser, Span, XDSLParser from xdsl.printer import Printer -from xdsl.parser import Parser from xdsl.utils.diagnostic import Diagnostic @@ -149,7 +149,7 @@ def test_op_message(): ctx.register_dialect(Arith) ctx.register_dialect(Builtin) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog) module = parser.parse_op() file = StringIO("") @@ -184,8 +184,8 @@ def test_two_different_op_messages(): ctx.register_dialect(Arith) ctx.register_dialect(Builtin) - parser = Parser(ctx, prog) - module = parser.parse_op() + parser = XDSLParser(ctx, prog) + module = parser.parse_module() file = StringIO("") diagnostic = Diagnostic() @@ -220,7 +220,7 @@ def test_two_same_op_messages(): ctx.register_dialect(Arith) ctx.register_dialect(Builtin) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog) module = parser.parse_op() file = StringIO("") @@ -254,7 +254,7 @@ def test_op_message_with_region(): ctx.register_dialect(Arith) ctx.register_dialect(Builtin) - parser = 
Parser(ctx, prog) + parser = XDSLParser(ctx, prog) module = parser.parse_op() file = StringIO("") @@ -290,7 +290,7 @@ def test_op_message_with_region_and_overflow(): ctx.register_dialect(Arith) ctx.register_dialect(Builtin) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog) module = parser.parse_op() file = StringIO("") @@ -316,7 +316,7 @@ def test_diagnostic(): ctx.register_dialect(Arith) ctx.register_dialect(Builtin) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog) module = parser.parse_op() diag = Diagnostic() @@ -356,7 +356,7 @@ def test_print_custom_name(): ctx.register_dialect(Arith) ctx.register_dialect(Builtin) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog) module = parser.parse_op() file = StringIO("") @@ -382,13 +382,19 @@ class PlusCustomFormatOp(Operation): @classmethod def parse(cls, result_types: List[Attribute], - parser: Parser) -> PlusCustomFormatOp: - lhs = parser.parse_ssa_value() - parser.skip_white_space() - parser.parse_char("+") - rhs = parser.parse_ssa_value() - return PlusCustomFormatOp.create(operands=[lhs, rhs], - result_types=result_types) + parser: BaseParser) -> PlusCustomFormatOp: + + lhs = parser.expect(parser.try_parse_value_id, + 'Expected SSA Value name here!') + parser.parse_characters("+", + "Malformed operation format, expected `+`!") + rhs = parser.expect(parser.try_parse_value_id, + 'Expected SSA Value name here!') + + return PlusCustomFormatOp.create( + operands=[parser.get_ssa_val(lhs), + parser.get_ssa_val(rhs)], + result_types=result_types) def print(self, printer: Printer): printer.print(" ", self.lhs, " + ", self.rhs) @@ -416,7 +422,7 @@ def test_generic_format(): ctx.register_dialect(Builtin) ctx.register_op(PlusCustomFormatOp) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog) module = parser.parse_op() file = StringIO("") @@ -447,7 +453,7 @@ def test_custom_format(): ctx.register_dialect(Builtin) ctx.register_op(PlusCustomFormatOp) - parser = Parser(ctx, prog) + 
parser = XDSLParser(ctx, prog) module = parser.parse_op() file = StringIO("") @@ -478,7 +484,7 @@ def test_custom_format_II(): ctx.register_dialect(Builtin) ctx.register_op(PlusCustomFormatOp) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog) module = parser.parse_op() file = StringIO("") @@ -494,13 +500,14 @@ class CustomFormatAttr(ParametrizedAttribute): attr: ParameterDef[IntAttr] @staticmethod - def parse_parameters(parser: Parser) -> list[Attribute]: + def parse_parameters(parser: BaseParser) -> list[Attribute]: parser.parse_char("<") - value = parser.parse_alpha_num(skip_white_space=False) - if value == "zero": + value = parser.tokenizer.next_token_of_pattern( + re.compile('(zero|one)')) + if value and value.text == "zero": parser.parse_char(">") return [IntAttr.from_int(0)] - if value == "one": + if value and value.text == "one": parser.parse_char(">") return [IntAttr.from_int(1)] assert False @@ -535,7 +542,7 @@ def test_custom_format_attr(): ctx.register_op(AnyOp) ctx.register_attr(CustomFormatAttr) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog) module = parser.parse_op() file = StringIO("") @@ -550,7 +557,7 @@ def test_parse_generic_format_attr(): """ prog = \ """builtin.module() { - any() ["attr" = !"custom">] + any() ["attr" = #"custom"<#int<0>>] }""" expected = \ @@ -564,7 +571,7 @@ def test_parse_generic_format_attr(): ctx.register_op(AnyOp) ctx.register_attr(CustomFormatAttr) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog) module = parser.parse_op() file = StringIO("") @@ -593,7 +600,7 @@ def test_parse_generic_format_attr_II(): ctx.register_op(AnyOp) ctx.register_attr(CustomFormatAttr) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog) module = parser.parse_op() file = StringIO("") @@ -648,7 +655,7 @@ def test_parse_dense_xdsl(): ctx.register_dialect(Builtin) ctx.register_dialect(Arith) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog) module = parser.parse_op() file = StringIO("") @@ 
-696,7 +703,7 @@ def test_foo_string(): ctx.register_op(AnyOp) ctx.register_attr(CustomFormatAttr) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog) try: parser.parse_op() assert False @@ -715,7 +722,7 @@ def test_dictionary_attr(): ctx.register_dialect(Builtin) ctx.register_dialect(Func) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog) parsed = parser.parse_op() file = StringIO("") diff --git a/tests/test_ssa_value.py b/tests/test_ssa_value.py new file mode 100644 index 0000000000..2847a49233 --- /dev/null +++ b/tests/test_ssa_value.py @@ -0,0 +1,32 @@ +from io import StringIO +from typing import Callable + +import pytest + +from xdsl.dialects.arith import Arith, Constant, Addi +from xdsl.dialects.builtin import ModuleOp, Builtin, i32 +from xdsl.dialects.scf import Scf, Yield +from xdsl.dialects.func import Func +from xdsl.ir import MLContext, Block, SSAValue, OpResult, BlockArgument +from xdsl.parser import Parser +from xdsl.printer import Printer +from xdsl.rewriter import Rewriter + + +@pytest.mark.parametrize("name,result", [ + ('a', 'a'), + ('test', 'test'), + ('test1', None), + ('1', None), +]) +def test_ssa_value_name_hints(name, result): + """ + The rewriter assumes, that ssa value name hints (their .name field) does not end in a numeric value. If it does, + it will generate broken rewrites that potentially assign twice to an SSA value. + + Therefore, the SSAValue class prevents the setting of names ending in a number. 
+ """ + val = BlockArgument(i32, Block(), 0) + + val.name = name + assert val.name == result diff --git a/tests/xdsl_opt/test_xdsl_opt.py b/tests/xdsl_opt/test_xdsl_opt.py index 55dcf6896b..4f71ccc983 100644 --- a/tests/xdsl_opt/test_xdsl_opt.py +++ b/tests/xdsl_opt/test_xdsl_opt.py @@ -27,20 +27,19 @@ def test_empty_program(): assert f.getvalue().strip() == expected.strip() -@pytest.mark.parametrize("args, expected_error", - [(['tests/xdsl_opt/not_module.xdsl'], - "Expected module or program as toplevel operation"), - (['tests/xdsl_opt/not_module.mlir'], - "Expected module or program as toplevel operation"), - (['tests/xdsl_opt/empty_program.wrong' - ], "Unrecognized file extension 'wrong'")]) +@pytest.mark.parametrize( + "args, expected_error", + [(['tests/xdsl_opt/not_module.xdsl'], "Expected ModuleOp at top level!"), + (['tests/xdsl_opt/not_module.mlir'], "Expected ModuleOp at top level!"), + (['tests/xdsl_opt/empty_program.wrong' + ], "Unrecognized file extension 'wrong'")]) def test_error_on_run(args, expected_error): opt = xDSLOptMain(args=args) with pytest.raises(Exception) as e: opt.run() - assert e.value.args[0] == expected_error + assert expected_error in e.value.args[0] @pytest.mark.parametrize( diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index 7038d540fa..a48dd903f0 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -6,7 +6,7 @@ TYPE_CHECKING, Any, TypeVar) from xdsl.ir import (Data, MLIRType, ParametrizedAttribute, Operation, - SSAValue, Region, Attribute, Dialect) + SSAValue, Region, Attribute, Dialect, MLContext) from xdsl.irdl import (AttributeDef, VarOpResult, VarOperand, VarRegionDef, irdl_attr_definition, attr_constr_coercion, irdl_data_definition, irdl_to_attr_constraint, @@ -16,7 +16,8 @@ from xdsl.utils.exceptions import VerifyException if TYPE_CHECKING: - from xdsl.parser import Parser, ParserError + from xdsl.parser import BaseParser + from utils.exceptions import ParseError from xdsl.printer import 
Printer @@ -25,7 +26,7 @@ class StringAttr(Data[str]): name = "string" @staticmethod - def parse_parameter(parser: Parser) -> str: + def parse_parameter(parser: BaseParser) -> str: data = parser.parse_str_literal() return data @@ -81,7 +82,7 @@ class IntAttr(Data[int]): name = "int" @staticmethod - def parse_parameter(parser: Parser) -> int: + def parse_parameter(parser: BaseParser) -> int: data = parser.parse_int_literal() return data @@ -110,14 +111,14 @@ class SignednessAttr(Data[Signedness]): name = "signedness" @staticmethod - def parse_parameter(parser: Parser) -> Signedness: + def parse_parameter(parser: BaseParser) -> Signedness: if parser.parse_optional_string("signless") is not None: return Signedness.SIGNLESS elif parser.parse_optional_string("signed") is not None: return Signedness.SIGNED elif parser.parse_optional_string("unsigned") is not None: return Signedness.UNSIGNED - raise ParserError(parser.get_pos(), "Expected signedness") + raise ParseError(parser.get_pos(), "Expected signedness") @staticmethod def print_parameter(data: Signedness, printer: Printer) -> None: @@ -226,7 +227,7 @@ class FloatData(Data[float]): name = "float_data" @staticmethod - def parse_parameter(parser: Parser) -> float: + def parse_parameter(parser: BaseParser) -> float: return parser.parse_float_literal() @staticmethod @@ -299,7 +300,7 @@ class ArrayAttr(GenericData[List[_ArrayAttrT]]): name = "array" @staticmethod - def parse_parameter(parser: Parser) -> List[_ArrayAttrT]: + def parse_parameter(parser: BaseParser) -> List[_ArrayAttrT]: parser.parse_char("[") data = parser.parse_list(parser.parse_optional_attribute) parser.parse_char("]") @@ -348,10 +349,10 @@ class DictionaryAttr(GenericData[dict[str, Attribute]]): name = "dictionary" @staticmethod - def parse_parameter(parser: Parser) -> dict[str, Attribute]: - data = parser.parse_dictionary(parser.parse_str_literal, - parser.parse_attribute) - return data + def parse_parameter(parser: BaseParser) -> dict[str, Attribute]: 
+ # force MLIR style parsing of attribute + from xdsl.parser import MLIRParser + return MLIRParser.parse_optional_attr_dict(parser) @staticmethod def print_parameter(data: dict[str, Attribute], printer: Printer) -> None: @@ -385,16 +386,15 @@ def verify(self) -> None: def from_dict(data: dict[str | StringAttr, Attribute]) -> DictionaryAttr: to_add_data: dict[str, Attribute] = {} for k, v in data.items(): + # try to coerce keys into StringAttr + if isinstance(k, StringAttr): + k = k.data + # if coercion fails, raise KeyError! if not isinstance(k, str): - if isinstance(k, StringAttr): - to_add_data[k.data] = v - else: - raise TypeError( - f"Attribute DictionaryAttr expects keys to" - f" be of type StringAttr or str, but {type(k)} provided" - ) - else: - to_add_data[k] = v + raise TypeError( + f"DictionaryAttr.from_dict expects keys to" + f" be of type str or StringAttr, but {type(k)} provided") + to_add_data[k] = v return DictionaryAttr(to_add_data) @@ -695,20 +695,21 @@ class UnregisteredOp(Operation): def op_name(self) -> StringAttr: return self.op_name__ # type: ignore - @staticmethod - def from_name(name: str | StringAttr, - args: list[SSAValue | Operation] = [], - res: list[Attribute] = [], - regs: list[Region] = [], - attrs: dict[str, Attribute] = {}) -> UnregisteredOp: - if "op_name__" in attrs: - raise Exception( - "Cannot create an unregistered op with an __op_name attribute") - attrs["op_name__"] = StringAttr.build(name) - return UnregisteredOp.build(operands=args, - result_types=res, - regions=regs, - attributes=attrs) + @classmethod + def with_name(cls, name: str, ctx: MLContext) -> type[Operation]: + if name in ctx.registered_unregistered_ops: + return ctx.registered_unregistered_ops[name] # type: ignore + + class UnregisteredOpWithName(UnregisteredOp): + + @classmethod + def create(cls, **kwargs): + op = super().create(**kwargs) + op.attributes['op_name__'] = StringAttr.build(name) + return op + + ctx.registered_unregistered_ops[name] = 
UnregisteredOpWithName + return UnregisteredOpWithName @irdl_op_definition diff --git a/xdsl/dialects/irdl.py b/xdsl/dialects/irdl.py index d37462a98a..672da4b145 100644 --- a/xdsl/dialects/irdl.py +++ b/xdsl/dialects/irdl.py @@ -6,7 +6,7 @@ from xdsl.irdl import (ParameterDef, AnyAttr, AttributeDef, SingleBlockRegionDef, irdl_op_definition, irdl_attr_definition) -from xdsl.parser import Parser +from xdsl.parser import BaseParser from xdsl.printer import Printer @@ -61,7 +61,7 @@ class NamedTypeConstraintAttr(ParametrizedAttribute): params_constraints: ParameterDef[Attribute] @staticmethod - def parse_parameters(parser: Parser) -> list[Attribute]: + def parse_parameters(parser: BaseParser) -> list[Attribute]: parser.parse_char("<") type_name = parser.parse_str_literal() parser.parse_char(":") diff --git a/xdsl/dialects/llvm.py b/xdsl/dialects/llvm.py index cd154ffcc4..749787ca2c 100644 --- a/xdsl/dialects/llvm.py +++ b/xdsl/dialects/llvm.py @@ -9,7 +9,7 @@ from xdsl.dialects.builtin import StringAttr, ArrayOfConstraint, ArrayAttr if TYPE_CHECKING: - from xdsl.parser import Parser + from xdsl.parser import BaseParser from xdsl.printer import Printer @@ -38,10 +38,13 @@ def print_parameters(self, printer: Printer) -> None: printer.print(")>") @staticmethod - def parse_parameters(parser: Parser) -> list[Attribute]: - parser.parse_string("<(") - params = parser.parse_list(parser.parse_optional_attribute) - parser.parse_string(")>") + def parse_parameters(parser: BaseParser) -> list[Attribute]: + parser.parse_characters("<(", "LLVM Struct must start with `<(`") + params = parser.parse_list_of( + parser.try_parse_type, + "Malformed LLVM struct, expected attribute definition here!") + parser.parse_characters( + ")>", "Unexpected input, expected end of LLVM struct!") return [StringAttr.from_str(""), ArrayAttr.from_list(params)] diff --git a/xdsl/ir.py b/xdsl/ir.py index e6672ca46c..6fcfc8d3c8 100644 --- a/xdsl/ir.py +++ b/xdsl/ir.py @@ -1,16 +1,17 @@ from __future__ import 
annotations +import re from abc import ABC, abstractmethod from dataclasses import dataclass, field from frozenlist import FrozenList from io import StringIO from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, Protocol, - Sequence, TypeVar, cast, Iterator, Union) + Sequence, TypeVar, cast, Iterator, Union, ClassVar) import sys # Used for cyclic dependencies in type hints if TYPE_CHECKING: - from xdsl.parser import Parser + from xdsl.parser import Parser, BaseParser from xdsl.printer import Printer from xdsl.irdl import OpDef, ParamAttrDef @@ -49,6 +50,8 @@ class MLContext: """Contains structures for operations/attributes registration.""" _registeredOps: dict[str, type[Operation]] = field(default_factory=dict) _registeredAttrs: dict[str, type[Attribute]] = field(default_factory=dict) + registered_unregistered_ops: dict[str, type[Operation]] = field( + default_factory=dict) def register_dialect(self, dialect: Dialect): """Register a dialect. Operation and Attribute names should be unique""" @@ -118,7 +121,19 @@ class SSAValue(ABC): uses: set[Use] = field(init=False, default_factory=set, repr=False) """All uses of the value.""" - name: str | None = field(init=False, default=None) + _name: str | None = field(init=False, default=None) + + _name_regex: ClassVar[re.Pattern] = re.compile( + r'[A-Za-z0-9._$-]*[A-Za-z._$-]') + + @property + def name(self) -> str | None: + return self._name + + @name.setter + def name(self, name: str): + if self._name_regex.fullmatch(name): + self._name = name @staticmethod def get(arg: SSAValue | Operation) -> SSAValue: @@ -284,7 +299,7 @@ class Data(Generic[DataElement], Attribute, ABC): @staticmethod @abstractmethod - def parse_parameter(parser: Parser) -> DataElement: + def parse_parameter(parser: BaseParser) -> DataElement: """Parse the attribute parameter.""" @staticmethod @@ -299,7 +314,7 @@ class ParametrizedAttribute(Attribute): parameters: list[Attribute] = field(default_factory=list) @staticmethod - def 
parse_parameters(parser: Parser) -> list[Attribute]: + def parse_parameters(parser: BaseParser) -> list[Attribute]: """Parse the attribute parameters.""" return parser.parse_paramattr_parameters() @@ -497,7 +512,7 @@ def verify_(self) -> None: @classmethod def parse(cls: type[_OperationType], result_types: list[Attribute], - parser: Parser) -> _OperationType: + parser: BaseParser) -> _OperationType: return parser.parse_op_with_default_format(cls, result_types) def print(self, printer: Printer): @@ -623,6 +638,8 @@ def irdl_definition(cls) -> OpDef: class Block(IRNode): """A sequence of operations""" + declared_at: 'Span' | None = None + _args: FrozenList[BlockArgument] = field(default_factory=FrozenList, init=False) """The basic block arguments.""" diff --git a/xdsl/parser.py b/xdsl/parser.py index ad01f0ebb4..85e32fc41b 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -1,1336 +1,1881 @@ from __future__ import annotations +import ast +import contextlib +import functools +import itertools +import re +import sys +import traceback +from abc import ABC, abstractmethod +from collections import defaultdict from dataclasses import dataclass, field from enum import Enum -from typing import Any, TypeVar +from io import StringIO +from typing import TypeVar, Iterable +from xdsl.utils.exceptions import ParseError, MultipleSpansParseError from xdsl.dialects.memref import MemRefType, UnrankedMemrefType +from xdsl.dialects.builtin import ( + AnyTensorType, AnyVectorType, Float16Type, Float32Type, Float64Type, + FloatAttr, FunctionType, IndexType, IntegerType, Signedness, StringAttr, + IntegerAttr, ArrayAttr, TensorType, UnrankedTensorType, VectorType, + FlatSymbolRefAttr, DenseIntOrFPElementsAttr, UnregisteredOp, OpaqueAttr, + NoneAttr, ModuleOp, UnitAttr, i64) from xdsl.ir import (SSAValue, Block, Callable, Attribute, Operation, Region, - BlockArgument, MLContext, ParametrizedAttribute) + BlockArgument, MLContext, ParametrizedAttribute, Data) -from xdsl.dialects.builtin 
@dataclass
class BacktrackingHistory:
    """
    Linked list of past errors encountered during backtracking in the parser.

    Each node records one ParseError together with the (optional) name of the
    grammar region that was being parsed and how far into the input that
    attempt got. ``parent`` points at the previously recorded error, so the
    chain as a whole describes the last "cascade" of failures — e.g. an
    outermost "Expected type of value-id here!" whose parent is the inner
    "'invalid' is not a known attribute".
    """
    error: ParseError
    parent: BacktrackingHistory | None
    region_name: str | None
    pos: int

    def print_unroll(self, file=sys.stderr):
        """Print this error and all its ancestors, deepest failure first."""
        if self.parent:
            if self.parent.get_farthest_point() > self.pos:
                # The parent attempt got further into the input: show it first.
                self.parent.print_unroll(file)
                self.print(file)
            else:
                self.print(file)
                self.parent.print_unroll(file)
        else:
            # Base case: a single error without recorded ancestors.
            self.print(file)

    def print(self, file=sys.stderr):
        """Print only this entry (not its ancestors) to *file*."""
        print("Parsing of {} failed:".format(self.region_name or ""),
              file=file)
        self.error.print_pretty(file=file)

    @functools.cache
    def get_farthest_point(self) -> int:
        """Return the farthest input position any entry in this chain reached."""
        # NOTE(review): functools.cache on a method keeps self alive for the
        # cache's lifetime; acceptable here as histories are short-lived.
        if self.parent:
            return max(self.pos, self.parent.get_farthest_point())
        return self.pos

    def iterate(self) -> Iterable[BacktrackingHistory]:
        """Yield this entry followed by every ancestor, outermost first."""
        yield self
        if self.parent:
            yield from self.parent.iterate()

    def __hash__(self):
        # Identity hash, so instances work with functools.cache above.
        return id(self)


@dataclass(frozen=True)
class Span:
    """
    A contiguous region of the input, identified by global byte offsets.

    Parts of the input are always passed around as spans, so every token and
    error knows exactly where it originated.
    """

    start: int
    """Start of the token's location: global byte offset into the input."""

    end: int
    """End (exclusive) of the token's location: global byte offset."""

    input: Input
    """The input this span refers into."""

    def __len__(self):
        return self.len

    @property
    def len(self):
        return self.end - self.start

    @property
    def text(self):
        return self.input.content[self.start:self.end]

    def get_line_col(self) -> tuple[int, int]:
        """Return (1-based line, 0-based column), or (-1, -1) if unknown."""
        info = self.input.get_lines_containing(self)
        if info is None:
            return -1, -1
        lines, offset_of_first_line, line_no = info
        return line_no, self.start - offset_of_first_line

    def print_with_context(self, msg: str | None = None) -> str:
        """
        Return a string with the source lines covered by this span, the
        span's extent highlighted with up-carets (`^`), and *msg* (if any)
        printed beneath the highlight.
        """
        info = self.input.get_lines_containing(self)
        if info is None:
            # Fixed: the format string previously had a single placeholder
            # for two arguments, silently discarding the message.
            return "Unknown location of span {}. Error: {}".format(self, msg)
        lines, offset_of_first_line, line_no = info
        # Offset relative to the first line:
        offset = self.start - offset_of_first_line
        remaining_len = max(self.len, 1)
        capture = StringIO()
        print("{}:{}:{}".format(self.input.name, line_no, offset),
              file=capture)
        for line in lines:
            print(line, file=capture)
            if remaining_len < 0:
                continue
            len_on_this_line = min(remaining_len, len(line) - offset)
            remaining_len -= len_on_this_line
            print("{}{}".format(" " * offset, "^" * max(len_on_this_line, 1)),
                  file=capture)
            if msg is not None:
                print("{}{}".format(" " * offset, msg), file=capture)
                msg = None
            # Continuation lines of a multi-line span start at column 0.
            offset = 0
        if msg is not None:
            print(msg, file=capture)
        return capture.getvalue()

    def __repr__(self):
        return "{}[{}:{}](text='{}')".format(self.__class__.__name__,
                                             self.start, self.end, self.text)


@dataclass(frozen=True, repr=False)
class StringLiteral(Span):
    """A span that is known to be a double-quoted string literal."""

    def __post_init__(self):
        # Must at least be the two quote characters.
        if len(self) < 2 or self.text[0] != '"' or self.text[-1] != '"':
            raise ParseError(self, "Invalid string literal!")

    @classmethod
    def from_span(cls, span: Span | None) -> StringLiteral | None:
        """
        Convert a normal span into a StringLiteral, to facilitate parsing.

        If the argument is None, returns None.
        """
        if span is None:
            return None
        return cls(span.start, span.end, span.input)

    @property
    def string_contents(self):
        """The unescaped contents of the literal, without the quotes."""
        # TODO: is this a hack-job? literal_eval handles the escape sequences
        # accepted by Python, which is a superset of what we lex.
        return ast.literal_eval(self.text)


@dataclass(frozen=True)
class Input:
    """
    A named piece of source text. Used to keep track of the input and to map
    spans back to lines for diagnostics.
    """
    content: str = field(repr=False)
    name: str

    @property
    def len(self):
        return len(self.content)

    def __len__(self):
        return self.len

    def get_lines_containing(self,
                             span: Span) -> tuple[list[str], int, int] | None:
        """
        Return (lines covered by span, byte offset of the first such line,
        1-based number of that line), or None if the span lies outside the
        input.
        """
        # A pointer to the start of the first line
        start = 0
        line_no = 0
        source = self.content
        while True:
            next_start = source.find('\n', start)
            line_no += 1
            # Handle eof
            if next_start == -1:
                if span.start > len(source):
                    return None
                return [source[start:]], start, line_no
            # As long as the next newline comes before the span's start we
            # can keep scanning forward.
            if next_start < span.start:
                start = next_start + 1
                continue
            # If the whole span is on one line, we are good as well
            if next_start >= span.end:
                return [source[start:next_start]], start, line_no
            # Multi-line span: extend to the newline at / after span.end.
            while next_start != -1 and next_start < span.end:
                next_start = source.find('\n', next_start + 1)
            if next_start == -1:
                # Fixed: span runs to EOF without a trailing newline; the
                # previous loop never terminated in this case.
                return source[start:].split('\n'), start, line_no
            return source[start:next_start].split('\n'), start, line_no

    def at(self, i: int):
        """Return the character at offset *i*, raising EOFError past the end."""
        if i >= self.len:
            raise EOFError()
        return self.content[i]


save_t = tuple[int, tuple[str, ...]]
"""Saved tokenizer state: (position, break_on configuration)."""


@dataclass
class Tokenizer:
    """
    This class is used to tokenize an Input.

    It provides an interface for backtracking, so you can use:

        with tokenizer.backtracking():
            # Try stuff
            raise ParseError(...)

    and not worry about manually resetting the input position. Backtracking
    will also record errors that happen during backtracking to provide a
    richer error reporting experience.

    It also provides the following methods to inspect the input:

    - next_token(peek) is used to get the next token (which just breaks the
      input as per the rules defined in break_on). peek=True doesn't advance
      the position in the file.
    - next_token_of_pattern(pattern, peek) can be used to get a next token if
      it conforms to a specific pattern. If a literal string is given, it'll
      check if the next characters match. If a regex is given, it will check
      the regex.
    - starts_with(pattern) checks if the input starts with a literal string
      or regex pattern.
    """

    input: Input

    pos: int = field(init=False, default=0)
    """Position in the input; points at the first unconsumed character."""

    break_on: tuple[str, ...] = ('.', '%', ' ', '(', ')', '[', ']', '{', '}',
                                 '<', '>', ':', '=', '@', '?', '|', '->', '-',
                                 '//', '\n', '\t', '#', '"', "'", ',', '!')
    """Characters the tokenizer should break tokens on."""

    history: BacktrackingHistory | None = field(init=False,
                                                default=None,
                                                repr=False)

    last_token: Span | None = field(init=False, default=None, repr=False)

    def __post_init__(self):
        try:
            self.last_token = self.next_token(peek=True)
        except EOFError:
            # Fixed: empty (or whitespace/comment-only) input previously made
            # construction itself raise EOFError.
            self.last_token = None

    def save(self) -> save_t:
        """
        Create a checkpoint in the parsing process, useful for backtracking.
        """
        return self.pos, self.break_on

    def resume_from(self, save: save_t):
        """
        Resume from a previously saved position.

        Restores the state of the tokenizer to the exact previous position.
        """
        self.pos, self.break_on = save

    @contextlib.contextmanager
    def backtracking(self, region_name: str | None = None):
        """
        Context manager marking a backtracking region.

        When an error is thrown during backtracking, it is recorded and
        stored together with some meta information in the history attribute.

        The backtracker accepts the following exceptions:
        - ParseError: the region could not be parsed because of (unexpected)
          syntax errors
        - AssertionError: should probably be phased out in favour of the two
          others
        - EOFError: EOF was reached unexpectedly

        Any other error is printed to stderr, but backtracking continues as
        normal (the broad `except` below is deliberate for that reason).
        """
        save = self.save()
        starting_position = self.pos
        try:
            yield
            # Clear error history when something doesn't fail.
            # We are only interested in the last "cascade" of failures: if a
            # backtracking() completes without failure, something has been
            # parsed (we assume).
            if self.pos > starting_position and self.history is not None:
                self.history = None
        except Exception as ex:
            how_far_we_got = self.pos

            # If we have no error history, start recording!
            if not self.history:
                self.history = self._history_entry_from_exception(
                    ex, region_name, how_far_we_got)

            # If we got further than on previous attempts
            elif how_far_we_got > self.history.get_farthest_point():
                # Throw away history and generate a new entry.
                self.history = None
                self.history = self._history_entry_from_exception(
                    ex, region_name, how_far_we_got)

            # Otherwise, add to exception, if we are in a named region
            elif region_name is not None and how_far_we_got - starting_position > 0:
                self.history = self._history_entry_from_exception(
                    ex, region_name, how_far_we_got)

            self.resume_from(save)

    def _history_entry_from_exception(self, ex: Exception,
                                      region: str | None,
                                      pos: int) -> BacktrackingHistory:
        """
        Given an exception generated inside a backtracking attempt, generate
        a BacktrackingHistory object with the relevant information in it.

        If an unexpected exception type is encountered, print a traceback to
        stderr.
        """
        if isinstance(ex, ParseError):
            return BacktrackingHistory(ex, self.history, region, pos)
        elif isinstance(ex, AssertionError):
            reason = [
                "Generic assertion failure",
                *(reason for reason in ex.args if isinstance(reason, str)),
            ]
            # We assume that assertions fail because of the last read-in
            # token; a bare assert carries no message, so attach a traceback.
            if len(reason) == 1:
                tb = StringIO()
                traceback.print_exc(file=tb)
                reason[0] += "\n" + tb.getvalue()

            return BacktrackingHistory(
                ParseError(self.last_token, reason[-1], self.history),
                self.history,
                region,
                pos,
            )
        elif isinstance(ex, EOFError):
            return BacktrackingHistory(
                ParseError(self.last_token, "Encountered EOF", self.history),
                self.history,
                region,
                pos,
            )

        print("Warning: Unexpected error in backtracking:", file=sys.stderr)
        traceback.print_exception(ex, file=sys.stderr)

        return BacktrackingHistory(
            ParseError(self.last_token, "Unexpected exception: {}".format(ex),
                       self.history),
            self.history,
            region,
            pos,
        )

    def next_token(self, peek: bool = False) -> Span:
        """
        Return a Span of the next token, according to the self.break_on rules.

        Can be modified using:
        - peek: don't advance the position, only "peek" at the input

        This will skip over line comments, meaning it will skip the rest of
        the line whenever it encounters '//'.
        """
        i = self.next_pos()
        # Construct the span:
        span = Span(i, self._find_token_end(i), self.input)
        # Advance pointer if not peeking
        if not peek:
            self.pos = span.end

        # Save last token
        self.last_token = span
        return span

    def next_token_of_pattern(self,
                              pattern: re.Pattern | str,
                              peek: bool = False) -> Span | None:
        """
        Return a span that matched the pattern, or None. You can choose not
        to consume the span by passing peek=True.
        """
        try:
            start = self.next_pos()
        except EOFError:
            return None

        # Handle search for string literal
        if isinstance(pattern, str):
            if self.starts_with(pattern):
                span = Span(start, start + len(pattern), self.input)
                if not peek:
                    self.pos = span.end
                # Fixed: keep last_token in sync, as the regex branch does.
                self.last_token = span
                return span
            return None

        # Handle regex logic
        match = pattern.match(self.input.content, start)
        if match is None:
            return None

        if not peek:
            self.pos = match.end()

        # Save last token
        self.last_token = Span(start, match.end(), self.input)
        return self.last_token

    def consume_peeked(self, peeked_span: Span):
        """Advance past a span previously obtained with peek=True."""
        if peeked_span.start != self.next_pos():
            raise ParseError(peeked_span, "This is not the peeked span!")
        self.pos = peeked_span.end

    def _find_token_end(self, start: int | None = None) -> int:
        """
        Find the point (optionally starting from start) where the token ends.
        """
        i = self.next_pos() if start is None else start
        # Search for literal breaks
        for part in self.break_on:
            if self.input.content.startswith(part, i):
                return i + len(part)
        # Otherwise return the start of the next break. Fixed: when no break
        # character follows (last token of input without a trailing newline),
        # min() over an empty sequence raised ValueError; the token now
        # extends to the end of the input instead.
        return min((end
                    for end in (self.input.content.find(part, i)
                                for part in self.break_on) if end >= 0),
                   default=len(self.input.content))

    def next_pos(self, i: int | None = None) -> int:
        """
        Find the next starting position (optionally starting from i),
        skipping whitespace and line comments. Raises EOFError when only
        whitespace/comments remain.
        """
        i = self.pos if i is None else i
        # Skip whitespaces
        while self.input.at(i).isspace():
            i += 1

        # Skip comments as well
        if self.input.content.startswith("//", i):
            newline = self.input.content.find("\n", i)
            if newline == -1:
                # Fixed: a trailing comment with no newline previously made
                # find() return -1 and restarted the scan at position 0.
                raise EOFError()
            return self.next_pos(newline + 1)

        return i

    def is_eof(self):
        """
        Check if the end of the input was reached.
        """
        try:
            self.next_pos()
            return False
        except EOFError:
            return True

    @contextlib.contextmanager
    def configured(self, break_on: tuple[str, ...]):
        """
        Helper to express a temporary change in config, allowing you to write:

            # Parsing double-quoted string now
            string_content = ""
            with tokenizer.configured(break_on=('"', '\\\\'),):
                # Use tokenizer

            # Now old config is restored automatically
        """
        save = self.save()

        if break_on is not None:
            self.break_on = break_on

        try:
            yield self
        finally:
            self.break_on = save[1]

    def starts_with(self, text: str | re.Pattern) -> bool:
        """Check whether the next token-start position begins with *text*."""
        try:
            start = self.next_pos()
            if isinstance(text, re.Pattern):
                # Fixed: previously returned `... is None`, inverting the
                # result for regex patterns.
                return text.match(self.input.content, start) is not None
            return self.input.content.startswith(text, start)
        except EOFError:
            return False
res.append(one) - while self.parse_optional_char(delimiter) if len( - delimiter) == 1 else True: - one = parse_optional_one() - if one is None: - return res - res.append(one) - return res +class ParserCommons: + """ + Collection of common things used in parsing MLIR/IRDL - K = TypeVar('K') - V = TypeVar('V') - - def parse_dictionary(self, - parse_key: Callable[[], K], - parse_value: Callable[[], V], - delimiter: str = ",", - skip_white_space: bool = True) -> dict[K, V]: - if skip_white_space: - self.skip_white_space() - assert (len(delimiter) <= 1) - if len(delimiter): - parse_delimiter = lambda: self.parse_char(delimiter) - else: - parse_delimiter = lambda: True + """ - self.parse_char("{") - if self.peek_char("}"): - return {} + integer_literal = re.compile(r"[+-]?([0-9]+|0x[0-9A-Fa-f]+)") + decimal_literal = re.compile(r"[+-]?([1-9][0-9]*)") + string_literal = re.compile(r'"(\\[nfvtr"\\]|[^\n\f\v\r"\\])*"') + float_literal = re.compile(r"[-+]?[0-9]+\.[0-9]*([eE][-+]?[0-9]+)?") + bare_id = re.compile(r"[A-Za-z_][\w$.]+") + value_id = re.compile(r"%([0-9]+|([A-Za-z_$.-][\w$.-]*))") + suffix_id = re.compile(r"([0-9]+|([A-Za-z_$.-][\w$.-]*))") + """ + suffix-id ::= (digit+ | ((letter|id-punct) (letter|id-punct|digit)*)) + id-punct ::= [$._-] + """ + block_id = re.compile(r"\^([0-9]+|([A-Za-z_$.-][\w$.-]*))") + type_alias = re.compile(r"![A-Za-z_][\w$.]+") + attribute_alias = re.compile(r"#[A-Za-z_][\w$.]+") + boolean_literal = re.compile(r"(true|false)") + # A list of names that are builtin types + _builtin_type_names = ( + r"[su]?i\d+", r"f\d+", "tensor", "vector", "memref", "complex", + "opaque", "tuple", "index", "dense" + # TODO: add all the Float8E4M3FNType, Float8E5M2Type, and BFloat16Type + ) + builtin_attr_names = ('dense', 'opaque', 'affine_map', 'array', + 'dense_resource', 'sparse') + builtin_type = re.compile("(({}))".format(")|(".join(_builtin_type_names))) + builtin_type_xdsl = re.compile("!(({}))".format( + ")|(".join(_builtin_type_names))) + 
double_colon = re.compile("::") + comma = re.compile(",") + + +class BaseParser(ABC): + """ + Basic recursive descent parser. - key, value = self.parse_dict_entry(parse_key, parse_value) - res = {key: value} - while not self.peek_char("}"): - parse_delimiter() - key, value = self.parse_dict_entry(parse_key, parse_value) - res[key] = value + methods marked try_... will attempt to parse, and return None if they failed. If they return None + they must make sure to restore all state. - self.parse_char("}") + methods marked parse_... will do "greedy" parsing, meaning they consume as much as they can. They will + also throw an error if the think they should still be parsing. e.g. when parsing a list of numbers + separated by '::', the following input will trigger an exception: + 1::2:: + Due to the '::' present after the last element. This is useful for parsing lists, as a trailing + separator is usually considered a syntax error there. - return res + must_ type parsers are preferred because they are explicit about their failure modes. + """ - def parse_dict_entry( - self, - parse_key: Callable[[], K], - parse_value: Callable[[], V], - ) -> tuple[K, V]: - key = parse_key() - self.parse_char("=") - value = parse_value() - return key, value - - def parse_optional_block_argument( - self, - skip_white_space: bool = True) -> tuple[str, Attribute] | None: - name = self.parse_optional_ssa_name(skip_white_space=skip_white_space) - if name is None: - return None - self.parse_char(":") - typ = self.parse_attribute() - # TODO how to get the id? 
- return name, typ - - def parse_optional_named_block(self, - skip_white_space: bool = True - ) -> Block | None: - if self.parse_optional_char("^", - skip_white_space=skip_white_space) is None: - return None - block_name = self.parse_alpha_num(skip_white_space=False) - if block_name in self._blocks: - block = self._blocks[block_name] + ctx: MLContext + """xDSL context.""" + + ssaValues: dict[str, SSAValue] + blocks: dict[str, Block] + forward_block_references: dict[str, list[Span]] + """ + Blocks we encountered references to before the definition (must be empty after parsing of region completes) + """ + + T_ = TypeVar("T_") + """ + Type var used for handling function that return single or multiple Spans. Basically the output type + of all try_parse functions is T_ | None + """ + + allow_unregistered_ops: bool + + def __init__(self, + ctx: MLContext, + input: str, + name: str = '', + allow_unregistered_ops=False): + self.tokenizer = Tokenizer(Input(input, name)) + self.ctx = ctx + self.ssaValues = dict() + self.blocks = dict() + self.forward_block_references = dict() + self.allow_unregistered_ops = allow_unregistered_ops + + def parse_module(self) -> ModuleOp: + op = self.try_parse_operation() + + if op is None: + self.raise_error("Could not parse entire input!") + + if isinstance(op, ModuleOp): + return op else: - block = Block() - self._blocks[block_name] = block - - if self.parse_optional_char("("): - tuple_list = self.parse_list(self.parse_optional_block_argument) - # Register the BlockArguments as ssa values and add them to - # the block - for (idx, (arg_name, arg_type)) in enumerate(tuple_list): - if arg_name in self._ssaValues: - raise ParserError( - self._pos, f"SSA value {arg_name} is already defined") - arg = BlockArgument(arg_type, block, idx) - self._ssaValues[arg_name] = arg - block.args.append(arg) - - self.parse_char(")") - self.parse_char(":") - for op in self.parse_list(self.parse_optional_op, delimiter=""): - block.add_op(op) - return block + 
self.tokenizer.pos = 0 + self.raise_error("Expected ModuleOp at top level!", + self.tokenizer.next_token()) - def parse_optional_region(self, - skip_white_space: bool = True) -> Region | None: - if not self.parse_optional_char("{", - skip_white_space=skip_white_space): - return None - region = Region() - oldSSAVals = self._ssaValues.copy() - oldBBNames = self._blocks.copy() - self._blocks = dict[str, Block]() + def get_ssa_val(self, name: Span) -> SSAValue: + if name.text not in self.ssaValues: + self.raise_error('SSA Value used before assignment', name) + return self.ssaValues[name.text] - if self.peek_char('^'): - for block in self.parse_list(self.parse_optional_named_block, - delimiter=""): - region.add_block(block) + def _get_block_from_name(self, block_name: Span) -> Block: + """ + This function takes a span containing a block id (like `^42`) and returns a block. + + If the block definition was not seen yet, we create a forward declaration. + """ + name = block_name.text + if name not in self.blocks: + self.forward_block_references[name].append(block_name) + self.blocks[name] = Block() + return self.blocks[name] + + def parse_block(self) -> Block: + block_id, args = self._parse_optional_block_label() + + if block_id is None: + block = Block(self.tokenizer.last_token) + elif self.forward_block_references.pop(block_id.text, + None) is not None: + block = self.blocks[block_id.text] + block.declared_at = block_id else: - region.add_block(Block()) - for op in self.parse_list(self.parse_optional_op, delimiter=""): - region.blocks[0].add_op(op) - self.parse_char("}") - - self._ssaValues = oldSSAVals - self._blocks = oldBBNames - return region - - def parse_optional_ssa_name(self, - skip_white_space: bool = True) -> str | None: - if self.parse_optional_char("%", - skip_white_space=skip_white_space) is None: - return None - name = self.parse_alpha_num() - return name - - def parse_optional_ssa_value(self, - skip_white_space: bool = True - ) -> SSAValue | None: - if 
skip_white_space: - self.skip_white_space() - start_pos = self._pos - name = self.parse_optional_ssa_name() - if name is None: - return None - if name not in self._ssaValues: - raise ParserError(start_pos, - f"name {name} does not refer to a SSA value") - return self._ssaValues[name] + if block_id.text in self.blocks: + raise MultipleSpansParseError( + block_id, + "Re-declaration of block {}".format(block_id.text), + "Originally declared here:", + [(self.blocks[block_id.text].declared_at, None)], + self.tokenizer.history, + ) + block = Block(block_id) + self.blocks[block_id.text] = block + + for i, (name, type) in enumerate(args): + arg = BlockArgument(type, block, i) + self.ssaValues[name.text] = arg + block.args.append(arg) + + while (next_op := self.try_parse_operation()) is not None: + block.add_op(next_op) - def parse_ssa_value(self, skip_white_space: bool = True) -> SSAValue: - res = self.parse_optional_ssa_value(skip_white_space=skip_white_space) - if res is None: - raise ParserError(self._pos, "SSA value expected") - return res + return block - def parse_optional_results(self, - skip_white_space: bool = True - ) -> list[str] | None: - res = self.parse_list(self.parse_optional_ssa_name, - skip_white_space=skip_white_space) - if len(res) == 0: - return None - self.parse_char("=") - return res + def _parse_optional_block_label( + self) -> tuple[Span | None, list[tuple[Span, Attribute]]]: + """ + A block label consists of block-id ( `(` block-arg `,` ... `)` )? 
+ """ + block_id = self.try_parse_block_id() + arg_list = list() - def parse_optional_typed_result( - self, - skip_white_space: bool = True) -> tuple[str, Attribute] | None: - name = self.parse_optional_ssa_name(skip_white_space=skip_white_space) - if name is None: + if block_id is not None: + if self.tokenizer.starts_with('('): + arg_list = self._parse_block_arg_list() + + self.parse_characters(':', 'Block label must end in a `:`!') + + return block_id, arg_list + + def _parse_block_arg_list(self) -> list[tuple[Span, Attribute]]: + self.parse_characters('(', 'Block arguments must start with `(`') + + args = self.parse_list_of(self.try_parse_value_id_and_type, + "Expected value-id and type here!") + + self.parse_characters(')', 'Expected closing of block arguments!') + + return args + + def try_parse_single_reference(self) -> Span | None: + with self.tokenizer.backtracking('part of a reference'): + self.parse_characters('@', "references must start with `@`") + if (reference := self.try_parse_string_literal()) is not None: + return reference + if (reference := self.try_parse_suffix_id()) is not None: + return reference + self.raise_error( + "References must conform to `@` (string-literal | suffix-id)") + + def parse_reference(self) -> list[Span]: + return self.parse_list_of( + self.try_parse_single_reference, + 'Expected reference here in the format of `@` (suffix-id | string-literal)', + ParserCommons.double_colon, + allow_empty=False) + + def parse_list_of(self, + try_parse: Callable[[], T_ | None], + error_msg: str, + separator_pattern: re.Pattern = ParserCommons.comma, + allow_empty: bool = True) -> list[T_]: + """ + This is a greedy list-parser. It accepts input only in these cases: + + - If the separator isn't encountered, which signals the end of the list + - If an empty list is allowed, it accepts when the first try_parse fails + - If an empty separator is given, it instead sees a failed try_parse as the end of the list. 
+ + This means, that the setup will not accept the input and instead raise an error: + try_parse = parse_integer_literal + separator = 'x' + input = 3x4x4xi32 + as it will read [3,4,4], then see another separator, and expects the next try_parse call to succeed + (which won't as i32 is not a valid integer literal) + """ + items = list() + first_item = try_parse() + if first_item is None: + if allow_empty: + return items + self.raise_error(error_msg) + + items.append(first_item) + + while (match := self.tokenizer.next_token_of_pattern(separator_pattern) + ) is not None: + next_item = try_parse() + if next_item is None: + # If the separator is emtpy, we are good here + if separator_pattern.pattern == '': + return items + self.raise_error(error_msg + + ' because was able to match next separator {}' + .format(match.text)) + items.append(next_item) + + return items + + def try_parse_integer_literal(self) -> Span | None: + return self.tokenizer.next_token_of_pattern( + ParserCommons.integer_literal) + + def try_parse_decimal_literal(self) -> Span | None: + return self.tokenizer.next_token_of_pattern( + ParserCommons.decimal_literal) + + def try_parse_string_literal(self) -> StringLiteral | None: + return StringLiteral.from_span( + self.tokenizer.next_token_of_pattern(ParserCommons.string_literal)) + + def try_parse_float_literal(self) -> Span | None: + return self.tokenizer.next_token_of_pattern( + ParserCommons.float_literal) + + def try_parse_bare_id(self) -> Span | None: + return self.tokenizer.next_token_of_pattern(ParserCommons.bare_id) + + def try_parse_value_id(self) -> Span | None: + return self.tokenizer.next_token_of_pattern(ParserCommons.value_id) + + def try_parse_suffix_id(self) -> Span | None: + return self.tokenizer.next_token_of_pattern(ParserCommons.suffix_id) + + def try_parse_block_id(self) -> Span | None: + return self.tokenizer.next_token_of_pattern(ParserCommons.block_id) + + def try_parse_boolean_literal(self) -> Span | None: + return 
self.tokenizer.next_token_of_pattern( + ParserCommons.boolean_literal) + + def try_parse_value_id_and_type(self) -> tuple[Span, Attribute] | None: + with self.tokenizer.backtracking("value id and type"): + value_id = self.try_parse_value_id() + + if value_id is None: + self.raise_error("Invalid value-id format!") + + self.parse_characters(':', + 'Expected expression (value-id `:` type)') + + type = self.try_parse_type() + + if type is None: + self.raise_error("Expected type of value-id here!") + return value_id, type + + def try_parse_type(self) -> Attribute | None: + if (builtin_type := self.try_parse_builtin_type()) is not None: + return builtin_type + if (dialect_type := self.try_parse_dialect_type()) is not None: + return dialect_type + return None + + def try_parse_dialect_type_or_attribute(self) -> Attribute | None: + """ + Parse a type or an attribute. + """ + kind = self.tokenizer.next_token_of_pattern(re.compile('[!#]'), + peek=True) + + if kind is None: return None - self.parse_char(":") - typ = self.parse_attribute() - return name, typ - def parse_optional_typed_results( - self, - skip_white_space: bool = True - ) -> list[tuple[str, Attribute]] | None: - res = self.parse_list(lambda: self.parse_optional_typed_result( - skip_white_space=skip_white_space)) - if len(res) == 0: + with self.tokenizer.backtracking("dialect attribute or type"): + self.tokenizer.consume_peeked(kind) + if kind.text == '!': + return self._parse_dialect_type_or_attribute_inner('type') + else: + return self._parse_dialect_type_or_attribute_inner('attribute') + + def try_parse_dialect_type(self): + """ + Parse a dialect type (something prefixed by `!`, defined by a dialect) + """ + if not self.tokenizer.starts_with('!'): return None - elif len(res) == 1 and res[0] is None: + with self.tokenizer.backtracking("dialect type"): + self.parse_characters('!', "Dialect type must start with a `!`") + return self._parse_dialect_type_or_attribute_inner('type') + + def 
try_parse_dialect_attr(self): + """ + Parse a dialect attribute (something prefixed by `#`, defined by a dialect) + """ + if not self.tokenizer.starts_with('#'): return None + with self.tokenizer.backtracking("dialect attribute"): + self.parse_characters('#', + "Dialect attribute must start with a `#`") + return self._parse_dialect_type_or_attribute_inner('attribute') + + def _parse_dialect_type_or_attribute_inner(self, kind: str): + type_name = self.tokenizer.next_token_of_pattern(ParserCommons.bare_id) + + if type_name is None: + self.raise_error("Expected dialect {} name here!".format(kind)) + + type_def = self.ctx.get_optional_attr(type_name.text) + if type_def is None: + self.raise_error( + "'{}' is not a know attribute!".format(type_name.text), + type_name) + + # Pass the task of parsing parameters on to the attribute/type definition + if issubclass(type_def, ParametrizedAttribute): + param_list = type_def.parse_parameters(self) + elif issubclass(type_def, Data): + self.parse_characters("<", "This attribute must be parametrized!") + param_list = type_def.parse_parameter(self) + self.parse_characters( + ">", "Invalid attribute parametrization, expected `>`!") else: - self.parse_char("=") - return res - - def parse_optional_operand(self, - skip_white_space: bool = True - ) -> SSAValue | None: - value = self.parse_optional_ssa_value( - skip_white_space=skip_white_space) - if value is None: - return None - if self.source == self.Source.XDSL: - self.parse_char(":") - typ = self.parse_attribute() - if value.typ != typ: - raise ParserError( - self._pos, f"type mismatch between {typ} and {value.typ}") - return value + assert False, "Mathieu said this cannot be." 
+ return type_def(param_list) - def parse_operands(self, skip_white_space: bool = True) -> list[SSAValue]: - self.parse_char("(", skip_white_space=skip_white_space) - res = self.parse_list(lambda: self.parse_optional_operand()) - self.parse_char(")") - return res + @abstractmethod + def try_parse_builtin_type(self) -> Attribute | None: + """ + parse a builtin-type like i32, index, vector etc. + """ + raise NotImplementedError("Subclasses must implement this method!") - def parse_paramattr_parameters( - self, - expect_brackets: bool = False, - skip_white_space: bool = True) -> list[Attribute]: - if expect_brackets: - self.parse_char("<", skip_white_space=skip_white_space) - elif self.parse_optional_char( - "<", skip_white_space=skip_white_space) is None: - return [] + def _parse_builtin_parametrized_type(self, + name: Span) -> ParametrizedAttribute: + """ + This function is called after we parse the name of a parameterized type such as vector. + """ + + def unimplemented() -> ParametrizedAttribute: + raise ParseError(name, + "Builtin {} not supported yet!".format(name.text)) + + builtin_parsers: dict[str, Callable[[], ParametrizedAttribute]] = { + "vector": self.parse_vector_attrs, + "memref": self.parse_memref_attrs, + "tensor": self.parse_tensor_attrs, + "complex": self.parse_complex_attrs, + "tuple": unimplemented, + } + + self.parse_characters("<", "Expected parameter list here!") + # Get the parser for the type, falling back to the unimplemented warning + res = builtin_parsers.get(name.text, unimplemented)() + self.parse_characters(">", "Expected end of parameter list here!") - res = self.parse_list(self.parse_optional_attribute) - self.parse_char(">") return res - def parse_optional_boolean_attribute( - self, - skip_white_space: bool = True) -> IntegerAttr[IntegerType] | None: - if self.parse_optional_string( - "true", skip_white_space=skip_white_space) is not None: - return IntegerAttr.from_int_and_width(1, 1) - if self.parse_optional_string( - "false", 
skip_white_space=skip_white_space) is not None: - return IntegerAttr.from_int_and_width(0, 1) - - def parse_optional_xdsl_builtin_attribute(self, - skip_white_space: bool = True - ) -> Attribute | None: - # Shorthand for StringAttr - string_lit = self.parse_optional_str_literal( - skip_white_space=skip_white_space) - if string_lit is not None: - return StringAttr.from_str(string_lit) - - # Shorthand for FloatAttr - float_lit = self.parse_optional_float_literal() - if float_lit is not None: - if self.parse_optional_char(":"): - typ = self.parse_attribute() - else: - typ = Float32Type() - return FloatAttr.from_value(float_lit, typ) - - # Shorthand for boolean literals (IntegerAttr of width 1) - if (bool_attr := self.parse_optional_boolean_attribute( - skip_white_space=skip_white_space)): - return bool_attr - - # Shorthand for IntegerAttr - integer_lit = self.parse_optional_int_literal() - if integer_lit is not None: - if self.parse_optional_char(":"): - typ = self.parse_attribute() + def parse_complex_attrs(self): + self.raise_error("ComplexType is unimplemented!") + + def parse_memref_attrs(self) -> MemRefType | UnrankedMemrefType: + dims = self._parse_tensor_or_memref_dims() + type = self.try_parse_type() + if dims is None: + return UnrankedMemrefType.from_type(type) + return MemRefType.from_element_type_and_shape(type, dims) + + def try_parse_numerical_dims(self, + accept_closing_bracket: bool = False, + lower_bound: int = 1) -> Iterable[int]: + while (shape_arg := + self._try_parse_shape_element(lower_bound)) is not None: + yield shape_arg + # Look out for the closing bracket for scalable vector dims + if accept_closing_bracket and self.tokenizer.starts_with("]"): + break + self.parse_characters("x", + "Unexpected end of dimension parameters!") + + def parse_vector_attrs(self) -> AnyVectorType: + # Also break on 'x' characters as they are separators in dimension parameters + with self.tokenizer.configured(break_on=self.tokenizer.break_on + + ("x", )): + shape = 
list[int](self.try_parse_numerical_dims()) + scaling_shape: list[int] | None = None + + if self.tokenizer.next_token_of_pattern("[") is not None: + # We now need to parse the scalable dimensions + scaling_shape = list(self.try_parse_numerical_dims()) + self.parse_characters( + "]", "Expected end of scalable vector dimensions here!") + self.parse_characters( + "x", "Expected end of scalable vector dimensions here!") + + if scaling_shape is not None: + # TODO: handle scaling vectors! + self.raise_error("Warning: scaling vectors not supported!") + pass + + type = self.try_parse_type() + if type is None: + self.raise_error( + "Expected a type at the end of the vector parameters!") + + return VectorType.from_element_type_and_shape(type, shape) + + def _parse_tensor_or_memref_dims(self) -> list[int] | None: + with self.tokenizer.configured(break_on=self.tokenizer.break_on + + ('x', )): + # Check for unranked-ness + if self.tokenizer.next_token_of_pattern('*') is not None: + # Consume `x` + self.parse_characters( + 'x', + 'Unranked tensors must follow format (`<*x` type `>`)') else: - typ = IntegerType.from_width(64) - return IntegerAttr.from_params(integer_lit, typ) - - # Shorthand for ArrayAttr - parse_bracket = self.parse_optional_char("[") - if parse_bracket: - array = self.parse_list(self.parse_optional_attribute) - self.parse_char("]") - return ArrayAttr.from_list(array) - - # Shorthand for DictionaryAttr - if self.peek_char("{"): - dictionary = self.parse_dictionary(self.parse_str_literal, - self.parse_attribute) - return DictionaryAttr.from_dict(dictionary) - - # Shorthand for FlatSymbolRefAttr - parse_at = self.parse_optional_char("@") - if parse_at: - symbol_name = self.parse_alpha_num(skip_white_space=False) - return FlatSymbolRefAttr.from_str(symbol_name) - - def parse_integer_type(): - self.parse_char("!", skip_white_space=skip_white_space) - return self.parse_mlir_integer_type( - skip_white_space=skip_white_space) - - if int_type := 
self.try_parse(parse_integer_type): - return int_type + # Parse rank: + return list(self.try_parse_numerical_dims(lower_bound=0)) + + def parse_tensor_attrs(self) -> AnyTensorType: + shape = self._parse_tensor_or_memref_dims() + type = self.try_parse_type() + + if type is None: + self.raise_error("Expected tensor type here!") + + if self.tokenizer.starts_with(','): + # TODO: add tensor encoding! + raise self.raise_error("Parsing tensor encoding is not supported!") + if shape is None and self.tokenizer.starts_with(','): + raise self.raise_error("Unranked tensors don't have an encoding!") + + if shape is not None: + return TensorType.from_type_and_list(type, shape) + + return UnrankedTensorType.from_type(type) + + def _try_parse_shape_element(self, lower_bound: int = 1) -> int | None: + """ + Parse a shape element, either a decimal integer immediate or a `?`, which evaluates to -1 + + immediate cannot be smaller than lower_bound (defaults to 1) (is 0 for tensors and memrefs) + """ + int_lit = self.try_parse_decimal_literal() + + if int_lit is not None: + value = int(int_lit.text) + if value < lower_bound: + # TODO: this is ugly, it's a raise inside a try_ type function, which should instead just give up + raise ParseError( + int_lit, + "Shape element literal cannot be negative or zero!") + return value + + if self.tokenizer.next_token_of_pattern('?') is not None: + return -1 return None - def parse_optional_attribute(self, - skip_white_space: bool = True - ) -> Attribute | None: - # If we are parsing an MLIR file, we first try to parse builtin - # attributes, which have a different format. - if self.source == self.Source.MLIR: - if attr := self.parse_optional_mlir_attribute( - skip_white_space=skip_white_space): - return attr - - # If we are parsing an xDSL file, we first try to parse builtin - # attributes, which have a different format. 
- if self.source == self.Source.XDSL: - if attr := self.parse_optional_xdsl_builtin_attribute( - skip_white_space=skip_white_space): - return attr - - # Then, we parse attributes/types with the generic format. - - if self.parse_optional_char("!") is None: - if self.source == self.Source.MLIR: - if self.parse_optional_char("#") is None: - return None - else: - return None - - parse_with_default_format = False - # Attribute with default format - if self.parse_optional_char('"'): - attr_def_name = self.parse_alpha_num(skip_white_space=False) - self.parse_char('"') - parse_with_default_format = True + def _parse_type_params(self) -> list[Attribute]: + # Consume opening bracket + self.parse_characters('<', 'Type must be parameterized!') + + params = self.parse_list_of(self.try_parse_type, + 'Expected a type here!') + + self.parse_characters('>', + 'Expected end of type parameterization here!') + + return params + + def expect(self, try_parse: Callable[[], T_ | None], + error_message: str) -> T_: + """ + Used to force completion of a try_parse function. Will throw a parse error if it can't + """ + res = try_parse() + if res is None: + self.raise_error(error_message) + return res + + def raise_error(self, msg: str, at_position: Span | None = None): + """ + Helper for raising exceptions, provides as much context as possible to them. 
+ + This will, for example, include backtracking errors, if any occurred previously + """ + if at_position is None: + at_position = self.tokenizer.next_token(peek=True) + + raise ParseError(at_position, msg, self.tokenizer.history) + + def parse_characters(self, text: str, msg: str) -> Span: + if (match := self.tokenizer.next_token_of_pattern(text)) is None: + self.raise_error(msg) + return match + + @abstractmethod + def _parse_op_result_list( + self) -> tuple[list[Span], list[Attribute] | None]: + raise NotImplementedError() + + def try_parse_operation(self) -> Operation | None: + with self.tokenizer.backtracking("operation"): + return self.parse_operation() + + def parse_operation(self) -> Operation: + result_list, ret_types = self._parse_op_result_list() + if len(result_list) > 0: + self.parse_characters( + '=', + 'Operation definitions expect an `=` after op-result-list!') + + # Check for custom op format + op_name = self.try_parse_bare_id() + if op_name is not None: + op_type = self._get_op_by_name(op_name) + op = op_type.parse(ret_types, self) else: - attr_def_name = self.parse_alpha_num(skip_white_space=True) + # Check for basic op format + op_name = self.try_parse_string_literal() + if op_name is None: + self.raise_error( + "Expected an operation name here, either a bare-id, or a string literal!" 
+ ) + + args, successors, attrs, regions, func_type = self._parse_operation_details( + ) + + if ret_types is None: + assert func_type is not None + ret_types = func_type.outputs.data + + op_type = self._get_op_by_name(op_name) + + op = op_type.create( + operands=[self.ssaValues[span.text] for span in args], + result_types=ret_types, + attributes=attrs, + successors=[ + self.blocks[block_name.text] for block_name in successors + ], + regions=regions) + + # Register the result SSA value names in the parser + for idx, res in enumerate(result_list): + ssa_val_name = res.text + if ssa_val_name in self.ssaValues: + self.raise_error( + f"SSA value {ssa_val_name} is already defined", res) + self.ssaValues[ssa_val_name] = op.results[idx] + self.ssaValues[ssa_val_name].name = ssa_val_name.lstrip('%') + + return op + + def _get_op_by_name(self, span: Span) -> type[Operation]: + if isinstance(span, StringLiteral): + op_name = span.string_contents + else: + op_name = span.text + + op_type = self.ctx.get_optional_op(op_name) + + if op_type is not None: + return op_type + + if self.allow_unregistered_ops: + return UnregisteredOp.with_name(op_name, self.ctx) + + self.raise_error(f'Unknown operation {op_name}!', span) + + def parse_region(self) -> Region: + oldSSAVals = self.ssaValues.copy() + oldBBNames = self.blocks + oldForwardRefs = self.forward_block_references + self.blocks = dict() + self.forward_block_references = defaultdict(list) + + region = Region() - if (self.source == self.Source.MLIR) and parse_with_default_format: - raise ParserError(self._pos, "cannot parse generic MLIR attribute") + try: + self.parse_characters("{", "Regions begin with `{`") + if self.tokenizer.starts_with("}"): + region.add_block(Block()) + else: + # Parse first block + block = self.parse_block() + region.add_block(block) + + while self.tokenizer.starts_with("^"): + region.add_block(self.parse_block()) + + end = self.parse_characters( + "}", "Reached end of region, expected `}`!") + + if 
len(self.forward_block_references) > 0: + raise MultipleSpansParseError( + end, + "Region ends with missing block declarations for block(s) {}!" + .format(', '.join(self.forward_block_references.keys())), + 'The following block references are dangling:', + [(span, "Reference to block \"{}\" without implementation!" + .format(span.text)) for span in itertools.chain( + *self.forward_block_references.values())], + self.tokenizer.history) + + return region + finally: + self.ssaValues = oldSSAVals + self.blocks = oldBBNames + self.forward_block_references = oldForwardRefs + + def _try_parse_op_name(self) -> Span | None: + if (str_lit := self.try_parse_string_literal()) is not None: + return str_lit + return self.try_parse_bare_id() + + def _parse_attribute_entry(self) -> tuple[Span, Attribute]: + """ + Parse entry in attribute dict. Of format: + + attribute_entry := (bare-id | string-literal) `=` attribute + attribute := dialect-attribute | builtin-attribute + """ + if (name := self.try_parse_bare_id()) is None: + name = self.try_parse_string_literal() - attr_def = self.ctx.get_attr(attr_def_name) + if name is None: + self.raise_error( + "Expected bare-id or string-literal here as part of attribute entry!" 
+ ) - # Attribute with default format - if parse_with_default_format: - if not issubclass(attr_def, ParametrizedAttribute): - raise ParserError( - self._pos, - f"{attr_def_name} is not a parameterized attribute, and " - "thus cannot be parsed with a generic format.") - params = self.parse_paramattr_parameters() - return attr_def(params) # type: ignore + if not self.tokenizer.starts_with('='): + return name, UnitAttr() - if issubclass(attr_def, Data): - self.parse_char("<") - attr: Any = attr_def.parse_parameter(self) - self.parse_char(">") - return attr_def(attr) # type: ignore + self.parse_characters( + "=", "Attribute entries must be of format name `=` attribute!") - assert issubclass(attr_def, ParametrizedAttribute) - param_list = attr_def.parse_parameters(self) - return attr_def(param_list) # type: ignore + return name, self.parse_attribute() - def parse_optional_dim(self, skip_white_space: bool = True) -> int | None: + @abstractmethod + def parse_attribute(self) -> Attribute: """ - Parse an optional dimension. - The dimension is either a non-negative integer, or -1 for dynamic dimensions. + Parse attribute (either builtin or dialect) + + This is different in xDSL and MLIR, so the actuall implementation is provided by the subclass """ - if self.parse_optional_char("?", skip_white_space=skip_white_space): - return -1 - if (dim := self.parse_optional_int_literal()) is not None: - return dim - return None + raise NotImplementedError() + + def try_parse_attribute(self) -> Attribute | None: + with self.tokenizer.backtracking("attribute"): + return self.parse_attribute() - def parse_dim(self, skip_white_space: bool = True) -> int: + def _parse_attribute_type(self) -> Attribute: """ - Parse a dimension. - The dimension is either a non-negative integer, - or -1 for dynamic dimensions, represented by `?`. 
+ Parses `:` type and returns the type """ - dim = self.parse_optional_dim(skip_white_space=skip_white_space) - if dim is not None: - return dim - raise ParserError(self._pos, "dimension expected") + self.parse_characters( + ":", "Expected attribute type definition here ( `:` type )") + return self.expect( + self.try_parse_type, + "Expected attribute type definition here ( `:` type )") - def parse_optional_shape( - self, - skip_white_space: bool = True - ) -> tuple[list[int], Attribute] | None: + def try_parse_builtin_attr(self) -> Attribute | None: + """ + Tries to parse a builtin attribute, e.g. a string literal, int, array, etc.. + """ + next_token = self.tokenizer.next_token(peek=True) + if next_token.text == '"': + return self.try_parse_builtin_str_attr() + elif next_token.text == "[": + return self.try_parse_builtin_arr_attr() + elif next_token.text == "@": + return self.try_parse_ref_attr() + elif next_token.text == '{': + return self.try_parse_builtin_dict_attr() + elif next_token.text == '(': + return self.try_parse_function_type() + elif next_token.text in ParserCommons.builtin_attr_names: + return self.try_parse_builtin_named_attr() + # Order here is important! 
+ attrs = (self.try_parse_builtin_float_attr, + self.try_parse_builtin_int_attr, self.try_parse_builtin_type) + + for attr_parser in attrs: + if (val := attr_parser()) is not None: + return val + + def try_parse_builtin_named_attr(self) -> Attribute | None: + name = self.tokenizer.next_token(peek=True) + with self.tokenizer.backtracking("Builtin attribute {}".format( + name.text)): + self.tokenizer.consume_peeked(name) + parsers = { + 'dense': self._parse_builtin_dense_attr, + 'opaque': self._parse_builtin_opaque_attr, + } + + def not_implemented(): + raise NotImplementedError() + + return parsers.get(name.text, not_implemented)() + + def _parse_builtin_dense_attr(self) -> Attribute | None: + err_msg = "Malformed dense attribute, format must be (`dense<` array-attr `>:` type)" + self.parse_characters("<", err_msg) + info = list(self._parse_builtin_dense_attr_args()) + self.parse_characters(">", err_msg) + self.parse_characters(":", err_msg) + type = self.expect(self.try_parse_type, + "Dense attribute must be typed!") + return DenseIntOrFPElementsAttr.from_list(type, info) + + def _parse_builtin_opaque_attr(self): + self.parse_characters("<", "Opaque attribute must be parametrized") + str_lit_list = self.parse_list_of(self.try_parse_string_literal, + 'Expected opaque attr here!') + + if len(str_lit_list) != 2: + self.raise_error('Opaque expects 2 string literal parameters!') + + self.parse_characters( + ">", "Unexpected parameters for opaque attr, expected `>`!") + + type = NoneAttr() + if self.tokenizer.starts_with(':'): + self.parse_characters(":", "opaque attribute must be typed!") + type = self.expect(self.try_parse_type, + "opaque attribute must be typed!") + + return OpaqueAttr.from_strings(*(span.string_contents + for span in str_lit_list), + type=type) + + def _parse_builtin_dense_attr_args(self) -> Iterable[int | float]: """ - Parse a shape, with the format `dim0 x dim1 x ... x dimN x type`. 
+ dense attribute params must be: + + dense-attr-params := float-literal | int-literal | list-of-dense-attrs-params + list-of-dense-attrs-params := `[` dense-attr-params (`,` dense-attr-params)* `]` """ - dims = list[int]() - if skip_white_space: - self.skip_white_space() + def try_parse_int_or_float(): + if (literal := self.try_parse_float_literal()) is not None: + return float(literal.text) + if (literal := self.try_parse_integer_literal()) is not None: + return int(literal.text) + self.raise_error('Expected int or float literal here!') + + if not self.tokenizer.starts_with('['): + yield try_parse_int_or_float() + return + + self.parse_characters('[', '') + while not self.tokenizer.starts_with(']'): + yield from self._parse_builtin_dense_attr_args() + if self.tokenizer.next_token_of_pattern(',') is None: + break + self.parse_characters(']', '') - def parse_optional_dim_and_x(): - if (dim := self.parse_optional_dim( - skip_white_space=False)) is not None: - self.parse_char("x", skip_white_space=False) - return dim + def try_parse_ref_attr(self) -> FlatSymbolRefAttr | None: + if not self.tokenizer.starts_with("@"): return None - dims = self.parse_list(parse_optional_dim_and_x, delimiter="") - typ = self.parse_attribute() + ref = self.parse_reference() - return dims, typ + if len(ref) > 1: + self.raise_error("Nested refs are not supported yet!", ref[1]) - def parse_shape( - self, - skip_white_space: bool = True) -> tuple[list[int], Attribute]: + return FlatSymbolRefAttr.from_str(ref[0].text) + + def try_parse_builtin_int_attr(self) -> IntegerAttr | None: + bool = self.try_parse_builtin_boolean_attr() + if bool is not None: + return bool + + with self.tokenizer.backtracking("built in int attribute"): + value = self.expect( + self.try_parse_integer_literal, + 'Integer attribute must start with an integer literal!') + if self.tokenizer.next_token(peek=True).text != ':': + return IntegerAttr.from_params(int(value.text), i64) + type = self._parse_attribute_type() + 
return IntegerAttr.from_params(int(value.text), type) + + def try_parse_builtin_float_attr(self) -> FloatAttr | None: + with self.tokenizer.backtracking("float literal"): + value = self.expect( + self.try_parse_float_literal, + "Float attribute must start with a float literal!", + ) + # If we don't see a ':' indicating a type signature + if not self.tokenizer.starts_with(":"): + return FloatAttr.from_value(float(value.text)) + + type = self._parse_attribute_type() + return FloatAttr.from_value(float(value.text), type) + + def try_parse_builtin_boolean_attr(self) -> IntegerAttr | None: + span = self.try_parse_boolean_literal() + + if span is None: + return None + + int_val = ["false", "true"].index(span.text) + return IntegerAttr.from_params(int_val, IntegerType.from_width(1)) + + def try_parse_builtin_str_attr(self): + if not self.tokenizer.starts_with('"'): + return None + + with self.tokenizer.backtracking("string literal"): + literal = self.try_parse_string_literal() + if literal is None: + self.raise_error("Invalid string literal") + return StringAttr.from_str(literal.string_contents) + + def try_parse_builtin_arr_attr(self) -> ArrayAttr | None: + if not self.tokenizer.starts_with("["): + return None + with self.tokenizer.backtracking("array literal"): + self.parse_characters("[", "Array literals must start with `[`") + attrs = self.parse_list_of(self.try_parse_attribute, + "Expected array entry!") + self.parse_characters( + "]", "Malformed array contents (expected end of array here!") + return ArrayAttr.from_list(attrs) + + @abstractmethod + def parse_optional_attr_dict(self) -> dict[str, Attribute]: + raise NotImplementedError() + + def _attr_dict_from_tuple_list( + self, tuple_list: list[tuple[Span, + Attribute]]) -> dict[str, Attribute]: """ - Parse a shape, with the format `dim0 x dim1 x ... x dimN x type`. + Convert a list of tuples (Span, Attribute) to a dictionary. 
+ + This function converts the span to a string, trimming quotes from string literals """ - shape = self.parse_optional_shape(skip_white_space=skip_white_space) - if shape is not None: - return shape - raise ParserError(self._pos, "shape expected") - def parse_optional_mlir_tensor( + def span_to_str(span: Span) -> str: + if isinstance(span, StringLiteral): + return span.string_contents + return span.text + + return dict((span_to_str(span), attr) for span, attr in tuple_list) + + def parse_function_type(self) -> FunctionType: + """ + Parses function-type: + + viable function types are: + (i32) -> () + () -> (i32, i32) + (i32, i32) -> () + () -> i32 + Non-viable types are: + i32 -> i32 + i32 -> () + + Uses type-or-type-list-parens internally + """ + self.parse_characters( + "(", "First group of function args must start with a `(`") + + args: list[Attribute] = self.parse_list_of(self.try_parse_type, + "Expected type here!") + + self.parse_characters( + ")", + "Malformed function type, expected closing brackets of argument types!" + ) + + self.parse_characters("->", "Malformed function type, expected `->`!") + + return FunctionType.from_lists(args, + self._parse_type_or_type_list_parens()) + + def _parse_type_or_type_list_parens(self) -> list[Attribute]: + """ + Parses type-or-type-list-parens, which is used in function-type. 
+ + type-or-type-list-parens ::= type | type-list-parens + type-list-parens ::= `(` `)` | `(` type-list-no-parens `)` + type-list-no-parens ::= type (`,` type)* + """ + if self.tokenizer.next_token_of_pattern("(") is not None: + args: list[Attribute] = self.parse_list_of(self.try_parse_type, + "Expected type here!") + self.parse_characters(")", "Unclosed function type argument list!") + else: + args = [self.try_parse_type()] + if args[0] is None: + self.raise_error( + "Function type must either be single type or list of types in" + " parenthesis!") + return args + + def try_parse_function_type(self) -> FunctionType | None: + if not self.tokenizer.starts_with("("): + return None + with self.tokenizer.backtracking("function type"): + return self.parse_function_type() + + def parse_region_list(self) -> list[Region]: + """ + Parses a sequence of regions for as long as there is a `{` in the input. + """ + regions = [] + while not self.tokenizer.is_eof() and self.tokenizer.starts_with("{"): + regions.append(self.parse_region()) + return regions + + def _parse_builtin_type_with_name(self, name: Span): + """ + Parses one of the builtin types like i42, vector, etc... 
+ """ + if name.text == "index": + return IndexType() + if (re_match := re.match(r"^[su]?i(\d+)$", name.text)) is not None: + signedness = { + "s": Signedness.SIGNED, + "u": Signedness.UNSIGNED, + "i": Signedness.SIGNLESS, + } + return IntegerType.from_width(int(re_match.group(1)), + signedness[name.text[0]]) + + if (re_match := re.match(r"^f(\d+)$", name.text)) is not None: + width = int(re_match.group(1)) + type = { + 16: Float16Type, + 32: Float32Type, + 64: Float64Type + }.get(width, None) + if type is None: + self.raise_error( + "Unsupported floating point width: {}".format(width)) + return type() + + return self._parse_builtin_parametrized_type(name) + + @abstractmethod + def _parse_operation_details( self, - skip_white_space: bool = True - ) -> AnyTensorType | AnyUnrankedTensorType | None: - if self.parse_optional_string("tensor", - skip_white_space=skip_white_space): - self.parse_char("<") - # Unranked tensor case - if self.parse_optional_char("*"): - self.parse_char("x") - typ = self.parse_attribute() - self.parse_char(">") - return UnrankedTensorType.from_type(typ) - dims, typ = self.parse_shape() - self.parse_char(">") - return TensorType.from_type_and_list(typ, dims) - return None + ) -> tuple[list[Span], list[Span], dict[str, Attribute], list[Region], + FunctionType | None]: + """ + Must return a tuple consisting of: + - a list of arguments to the operation + - a list of successor names + - the attributes attached to the OP + - the regions of the op + - An optional function type. 
If not supplied, parse_op_result_list must return a second value + containing the types of the returned SSAValues - def parse_optional_mlir_vector(self, - skip_white_space: bool = True - ) -> AnyVectorType | None: - if self.parse_optional_string("vector", - skip_white_space=skip_white_space): - self.parse_optional_char("<") - dims, typ = self.parse_shape() - self.parse_char(">") - return VectorType.from_element_type_and_shape(typ, dims) - return None + """ + raise NotImplementedError() - def parse_optional_mlir_memref( + @abstractmethod + def _parse_op_args_list(self) -> list[Span]: + raise NotImplementedError() + + # HERE STARTS A SOMEWHAT CURSED COMPATIBILITY LAYER: + # Since we don't want to rewrite all dialects currently, the new parser needs to expose the same + # Interface to the dialect definitions (to some extent). Here we implement that interface. + + _OperationType = TypeVar("_OperationType", bound=Operation) + + def parse_op_with_default_format( self, - skip_white_space: bool = True - ) -> MemRefType[Any] | UnrankedMemrefType[Any] | None: - if self.parse_optional_string("memref", - skip_white_space=skip_white_space): - self.parse_char("<") - # Unranked memref case - if self.parse_optional_char("*"): - self.parse_char("x") - typ = self.parse_attribute() - self.parse_char(">") - return UnrankedMemrefType.from_type(typ) - dims, typ = self.parse_shape() - self.parse_char(">") - return MemRefType.from_element_type_and_shape(typ, dims) - return None + op_type: type[_OperationType], + result_types: list[Attribute], + ) -> _OperationType: + """ + Compatibility wrapper so the new parser can be passed instead of the old one. Parses everything after the + operation name. 
- def parse_optional_mlir_index_type(self, - skip_white_space: bool = True - ) -> IndexType | None: - if self.parse_optional_string("index", - skip_white_space=skip_white_space): - return IndexType() - return None + This implicitly assumes XDSL format, and will fail on MLIR style operations + """ + # TODO: remove this function and restructure custom op / irdl parsing + assert isinstance(self, XDSLParser) + args, successors, attributes, regions, _ = self._parse_operation_details( + ) + + for x in args: + if x.text not in self.ssaValues: + self.raise_error( + "Unknown SSAValue name, known SSA Values are: {}".format( + ", ".join(self.ssaValues.keys())), x) + + return op_type.create( + operands=[self.ssaValues[span.text] for span in args], + result_types=result_types, + attributes=attributes, + successors=[ + self._get_block_from_name(span) for span in successors + ], + regions=regions) - def parse_mlir_index_type(self, - skip_white_space: bool = True) -> IndexType: - typ = self.parse_optional_mlir_index_type( - skip_white_space=skip_white_space) - if typ is not None: - return typ - raise ParserError(self._pos, "index type expected") - - def parse_mlir_integer_type(self, - skip_white_space: bool = True) -> IntegerType: - # Parse the optional signedness semantics - if self.parse_optional_string("si", skip_white_space=skip_white_space): - signedness = Signedness.SIGNED - elif self.parse_optional_string("ui", - skip_white_space=skip_white_space): - signedness = Signedness.UNSIGNED - elif self.parse_optional_string("i", - skip_white_space=skip_white_space): - signedness = Signedness.SIGNLESS - else: - raise ParserError(self._pos, "integer type expected") - - val = self.parse_int_literal(skip_white_space=False) - return IntegerType.from_width(val, signedness) - - def parse_optional_mlir_integer_type(self, - skip_white_space: bool = True - ) -> IntegerType | None: - return self.try_parse(self.parse_mlir_integer_type, - skip_white_space=skip_white_space) - - def 
parse_optional_mlir_float_type(self, - skip_white_space: bool = True - ) -> AnyFloat | None: - if self.parse_optional_string("f16") is not None: - return Float16Type() - if self.parse_optional_string("f32") is not None: - return Float32Type() - if self.parse_optional_string("f64") is not None: - return Float64Type() - return None + def parse_paramattr_parameters( + self, + expect_brackets: bool = False, + skip_white_space: bool = True) -> list[Attribute]: + opening_brackets = self.tokenizer.next_token_of_pattern('<') + if expect_brackets and opening_brackets is None: + self.raise_error("Expected start attribute parameters here (`<`)!") - def parse_mlir_float_type(self, skip_white_space: bool = True) -> AnyFloat: - typ = self.parse_optional_mlir_float_type( - skip_white_space=skip_white_space) - if typ is not None: - return typ - raise ParserError(self._pos, "float type expected") - - def parse_optional_mlir_attribute(self, - skip_white_space: bool = True - ) -> Attribute | None: - if skip_white_space: - self.skip_white_space() - - # index type - if (index_type := self.parse_optional_mlir_index_type()) is not None: - return index_type - - # integer type - if (int_type := self.parse_optional_mlir_integer_type()) is not None: - return int_type - - # float type - if (float_type := self.parse_optional_mlir_float_type()) is not None: - return float_type - - # float attribute - if (lit := self.parse_optional_float_literal()) is not None: - if self.parse_optional_char(":"): - if (typ := self.parse_optional_mlir_float_type()) is not None: - return FloatAttr.from_value(lit, typ) - raise ParserError(self._pos, "float type expected") - return FloatAttr.from_value(lit, Float64Type()) - - # Shorthand for boolean attributes (integer attributes of width 1) - if (bool_attr := self.parse_optional_boolean_attribute()) is not None: - return bool_attr - - # integer attribute - if (lit := self.parse_optional_int_literal()) is not None: - if self.parse_optional_char(":"): - if (typ := - 
self.parse_optional_mlir_integer_type()) is not None: - return IntegerAttr.from_params(lit, typ) - if (typ := self.parse_optional_mlir_index_type()) is not None: - return IntegerAttr.from_params(lit, typ) - raise ParserError(self._pos, "integer or index type expected") - return IntegerAttr.from_params(lit, IntegerType.from_width(64)) - - # string literal - str_literal = self.parse_optional_str_literal() - if str_literal is not None: - return StringAttr.from_str(str_literal) - - # Array attribute - if self.parse_optional_char("["): - contents = self.parse_list(self.parse_optional_attribute) - self.parse_char("]") - return ArrayAttr.from_list(contents) - - # Shorthand for DictionaryAttr - if self.peek_char("{"): - contents = self.parse_dictionary(self.parse_str_literal, - self.parse_attribute) - return DictionaryAttr.from_dict(contents) - - # FlatSymbolRefAttr - if self.parse_optional_char("@"): - symbol_name = self.parse_alpha_num(skip_white_space=False) - return FlatSymbolRefAttr.from_str(symbol_name) - - # tensor type - if (tensor := self.parse_optional_mlir_tensor()) is not None: - return tensor - - # vector type - if (vector := self.parse_optional_mlir_vector()) is not None: - return vector - - # dense attribute - if self.parse_optional_string("dense"): - self.parse_char("<") - - def parse_num() -> int | float | None: - if (f := self.parse_optional_float_literal()) is not None: - return f - if (i := self.parse_optional_int_literal()) is not None: - return i - return None - - value = self.parse_optional_nested_list(parse_num) - self.parse_char(">") - self.parse_char(":") - - # Parse the dense attribute type. It is either a tensor or a vector. 
- loc = self._pos - type_attr: AnyVectorType | AnyTensorType - if (vec := self.parse_optional_mlir_vector()) is not None: - type_attr = vec - elif (tensor := self.parse_optional_mlir_tensor()) is not None: - type_attr = tensor - else: - raise ParserError(loc, "expected a tensor or a vector type") - - return DenseIntOrFPElementsAttr.from_list(type_attr, value) - - # opaque attribute - if self.parse_optional_string("opaque") is not None: - self.parse_char("<") - name = self.parse_str_literal() - self.parse_char(",") - val: str = self.parse_str_literal() - self.parse_char(">") - if self.parse_optional_char(":") is not None: - typ = self.parse_attribute() - return OpaqueAttr.from_strings(name, val, typ) - return OpaqueAttr.from_strings(name, val) - - # function attribute - if self.parse_optional_char("(") is not None: - inputs = self.parse_list(self.parse_optional_attribute) - self.parse_char(")") - self.parse_string("->") - if self.parse_optional_char("("): - outputs = self.parse_list(self.parse_optional_attribute) - self.parse_char(")") - return FunctionType.from_lists(inputs, outputs) - output = self.parse_attribute() - return FunctionType.from_lists(inputs, [output]) - - # memref type - if (memref := self.parse_optional_mlir_memref()) is not None: - return memref + res = self.parse_list_of(self.try_parse_attribute, + 'Expected another attribute here!') - return None + if opening_brackets is not None and self.tokenizer.next_token_of_pattern( + '>') is None: + self.raise_error( + "Malformed parameter list, expected either another parameter or `>`!" + ) - def parse_attribute(self, skip_white_space: bool = True) -> Attribute: - res = self.parse_optional_attribute(skip_white_space=skip_white_space) - if res is None: - raise ParserError(self._pos, "attribute expected") return res - def parse_optional_named_attribute( - self, - skip_white_space: bool = True) -> tuple[str, Attribute] | None: - # The attribute name is either a string literal, or an identifier. 
- attr_name = self.parse_optional_str_literal( - skip_white_space=skip_white_space) - if attr_name is None: - attr_name = self.parse_optional_alpha_num( - skip_white_space=skip_white_space) - - if attr_name is None: - return None - if not self.peek_char("="): - return attr_name, UnitAttr([]) - self.parse_char("=") - attr = self.parse_attribute() - return attr_name, attr - - def parse_op_attributes(self, - skip_white_space: bool = True - ) -> dict[str, Attribute]: - if not self.parse_optional_char( - "[" if self.source == self.Source.XDSL else "{", - skip_white_space=skip_white_space): + def parse_char(self, text: str): + self.parse_characters(text, "Expected '{}' here!".format(text)) + + def parse_str_literal(self) -> str: + return self.expect(self.try_parse_string_literal, + 'Malformed string literal!').string_contents + + def parse_attribute(self) -> Attribute: + raise NotImplementedError() + + def parse_op(self) -> Operation: + return self.parse_operation() + + def parse_int_literal(self) -> int: + return int( + self.expect(self.try_parse_integer_literal, + 'Expected integer literal here').text) + + def try_parse_builtin_dict_attr(self): + attr_def = self.ctx.get_optional_attr('dictionary') + if attr_def is None: + self.raise_error( + "An attribute named `dictionary` must be available in the " + "context in order to parse dictionary attributes! Please make " + "sure the builtin dialect is available, or provide your own " + "replacement!") + param = attr_def.parse_parameter(self) + return attr_def(param) + + +class MLIRParser(BaseParser): + + def try_parse_builtin_type(self) -> Attribute | None: + """ + parse a builtin-type like i32, index, vector etc. 
+ """ + with self.tokenizer.backtracking("builtin type"): + name = self.tokenizer.next_token_of_pattern( + ParserCommons.builtin_type) + if name is None: + raise self.raise_error("Expected builtin name!") + + return self._parse_builtin_type_with_name(name) + + def parse_attribute(self) -> Attribute: + """ + Parse attribute (either builtin or dialect) + """ + # All dialect attrs must start with '#', so we check for that first (as it's easier) + if self.tokenizer.starts_with("#"): + value = self.try_parse_dialect_attr() + + # No value => error + if value is None: + self.raise_error( + "`#` must be followed by a valid dialect attribute or type!" + ) + + return value + + # If it isn't a dialect attr, parse builtin + builtin_val = self.try_parse_builtin_attr() + + if builtin_val is None: + self.raise_error( + "Unknown attribute (neither builtin nor dialect could be parsed)!" + ) + + return builtin_val + + def _parse_op_result_list( + self) -> tuple[list[Span], list[Attribute] | None]: + return ( + self.parse_list_of(self.try_parse_value_id, + "Expected op-result here!", + allow_empty=True), + None, + ) + + def parse_optional_attr_dict(self) -> dict[str, Attribute]: + if not self.tokenizer.starts_with("{"): return dict() - attrs_with_names = self.parse_list(self.parse_optional_named_attribute) - self.parse_char("]" if self.source == self.Source.XDSL else "}") - return {name: attr for (name, attr) in attrs_with_names} - - def parse_optional_successor(self, - skip_white_space: bool = True - ) -> Block | None: - parsed = self.parse_optional_char("^", - skip_white_space=skip_white_space) - if parsed is None: - return None - bb_name = self.parse_alpha_num(skip_white_space=False) - if bb_name in self._blocks: - block = self._blocks[bb_name] - pass - else: - block = Block() - self._blocks[bb_name] = block - return block - def parse_successors(self, skip_white_space: bool = True) -> list[Block]: - parsed = self.parse_optional_char( - "(" if self.source == self.Source.XDSL else 
"[", - skip_white_space=skip_white_space) - if parsed is None: + self.parse_characters( + "{", + "MLIR Attribute dictionary must be enclosed in curly brackets") + + attrs = [] + if not self.tokenizer.starts_with('}'): + attrs = self.parse_list_of(self._parse_attribute_entry, + "Expected attribute entry") + + self.parse_characters( + "}", + "MLIR Attribute dictionary must be enclosed in curly brackets") + + return self._attr_dict_from_tuple_list(attrs) + + def _parse_operation_details( + self, + ) -> tuple[list[Span], list[Span], dict[str, Attribute], list[Region], + FunctionType | None]: + args = self._parse_op_args_list() + succ = self._parse_optional_successor_list() + + regions = [] + if self.tokenizer.starts_with("("): + self.parse_characters("(", "Expected brackets enclosing regions!") + regions = self.parse_region_list() + self.parse_characters(")", "Expected brackets enclosing regions!") + + attrs = self.parse_optional_attr_dict() + + self.parse_characters( + ":", + "MLIR Operation definitions must end in a function type signature!" 
+ ) + func_type = self.parse_function_type() + + return args, succ, attrs, regions, func_type + + def _parse_optional_successor_list(self) -> list[Span]: + if not self.tokenizer.starts_with("["): return [] - res = self.parse_list(self.parse_optional_successor, delimiter=',') - self.parse_char(")" if self.source == self.Source.XDSL else "]") - return res + self.parse_characters("[", + "Successor list is enclosed in square brackets") + successors = self.parse_list_of(self.try_parse_block_id, + "Expected a block-id", + allow_empty=False) + self.parse_characters("]", + "Successor list is enclosed in square brackets") + return successors + + def _parse_op_args_list(self) -> list[Span]: + self.parse_characters( + "(", "Operation args list must be enclosed by brackets!") + args = self.parse_list_of(self.try_parse_value_id, + "Expected another bare-id here") + self.parse_characters( + ")", "Operation args list must be closed by a closing bracket") + # TODO: check if type is correct here! + return args + + +class XDSLParser(BaseParser): + + def try_parse_builtin_type(self) -> Attribute | None: + """ + parse a builtin-type like i32, index, vector etc. + """ + with self.tokenizer.backtracking("builtin type"): + name = self.tokenizer.next_token_of_pattern( + ParserCommons.builtin_type_xdsl) + if name is None: + self.raise_error("Expected builtin name!") + # xDSL builtin types have a '!' 
prefix, we strip that out here + name = Span(start=name.start + 1, end=name.end, input=name.input) - def is_valid_name(self, name: str) -> bool: - return not name[-1].isnumeric() + return self._parse_builtin_type_with_name(name) - _OperationType = TypeVar('_OperationType', bound='Operation') + def parse_attribute(self) -> Attribute: + """ + Parse attribute (either builtin or dialect) - def parse_op_with_default_format( - self, - op_type: type[_OperationType], - result_types: list[Attribute], - skip_white_space: bool = True) -> _OperationType: - operands = self.parse_operands(skip_white_space=skip_white_space) - successors = self.parse_successors() - attributes = self.parse_op_attributes() - regions = self.parse_list(self.parse_optional_region, delimiter="") - - return op_type.create(operands=operands, - result_types=result_types, - attributes=attributes, - successors=successors, - regions=regions) - - def _parse_optional_op_name(self, - skip_white_space: bool = True - ) -> tuple[str, bool] | None: - op_name = self.parse_optional_alpha_num( - skip_white_space=skip_white_space) - if op_name: - return op_name, False - op_name = self.parse_optional_str_literal() - if op_name: - return op_name, True - return None + xDSL allows types in places of attributes! 
That's why we parse types here as well + """ + value = self.try_parse_builtin_attr() - def _parse_op_name(self, - skip_white_space: bool = True) -> tuple[str, bool]: - op_name = self._parse_optional_op_name( - skip_white_space=skip_white_space) - if op_name is None: - raise ParserError(self._pos, "operation name expected") - return op_name - - def parse_optional_op(self, - skip_white_space: bool = True) -> Operation | None: - if self.source == self.Source.MLIR: - return self.parse_optional_mlir_op( - skip_white_space=skip_white_space) - - start_pos = self._pos - results = self.parse_optional_typed_results( - skip_white_space=skip_white_space) - if results is None: - op_name_and_generic = self._parse_optional_op_name() - if op_name_and_generic is None: - return None - op_name, is_generic_format = op_name_and_generic - results = [] - else: - op_name, is_generic_format = self._parse_op_name() + # xDSL: Allow both # and ! prefixes, as we allow both types and attrs + # TODO: phase out use of next_token(peek=True) in favour of starts_with + if value is None and self.tokenizer.next_token(peek=True).text in "#!": + # In MLIR # and ! are prefixes for dialect attrs/types, but in xDSL ! is also used for builtin types + value = self.try_parse_dialect_type_or_attribute() - result_types = [typ for (_, typ) in results] - op_type = self.ctx.get_optional_op(op_name) + if value is None: + self.raise_error( + "Unknown attribute (neither builtin nor dialect could be parsed)!" + ) - # If the operation is not registered, we create an UnregisteredOp instead, - # or fail. 
- if op_type is None: - if not self.allow_unregistered_ops: - raise ParserError(start_pos, f"unknown operation '{op_name}'") - if not is_generic_format: - raise ParserError( - start_pos, f"unknown operation '{op_name}' can " - "only be parsed using the generic format") - - op = self.parse_op_with_default_format(UnregisteredOp, - result_types) - op.attributes["op_name__"] = StringAttr.from_str(op_name) - else: - if not is_generic_format: - op = op_type.parse(result_types, self) - else: - op = self.parse_op_with_default_format(op_type, result_types) + return value - # Register the SSA value names in the parser - for (idx, res) in enumerate(results): - if res[0] in self._ssaValues: - raise ParserError(start_pos, - f"SSA value {res[0]} is already defined") - self._ssaValues[res[0]] = op.results[idx] - if self.is_valid_name(res[0]): - self._ssaValues[res[0]].name = res[0] + def _parse_op_result_list( + self) -> tuple[list[Span], list[Attribute] | None]: + if not self.tokenizer.starts_with("%"): + return list(), list() + results = self.parse_list_of( + self.try_parse_value_id_and_type, + "Expected (value-id `:` type) here!", + allow_empty=False, + ) + # TODO: this is hideous, make it cleaner + # zip(*results) works, but is barely readable :/ + return [name for name, _ in results], [type for _, type in results] + + def try_parse_builtin_attr(self) -> Attribute: + """ + Tries to parse a builtin attribute, e.g. a string literal, int, array, etc.. - return op + If the mode is xDSL, it also allows parsing of builtin types + """ + # In xdsl, two things are different here: + # 1. types are considered valid attributes + # 2. all types, builtins included, are prefixed with ! 
+ if self.tokenizer.starts_with("!"): + return self.try_parse_builtin_type() - def parse_op_type( - self, - skip_white_space: bool = True - ) -> tuple[list[Attribute], list[Attribute]]: - self.parse_char("(", skip_white_space=skip_white_space) - inputs = self.parse_list(self.parse_optional_attribute) - self.parse_char(")") - self.parse_string("->") - - # No or multiple result types - if self.parse_optional_char("("): - outputs = self.parse_list(self.parse_optional_attribute) - self.parse_char(")") - else: - outputs = [self.parse_attribute()] + return super().try_parse_builtin_attr() - return inputs, outputs + def parse_optional_attr_dict(self) -> dict[str, Attribute]: + if not self.tokenizer.starts_with("["): + return dict() - def parse_mlir_op_with_default_format( - self, - op_type: type[_OperationType], - num_results: int, - skip_white_space: bool = True) -> _OperationType: - operands = self.parse_operands(skip_white_space=skip_white_space) + self.parse_characters( + "[", + "xDSL Attribute dictionary must be enclosed in square brackets") - regions = [] - if self.parse_optional_char("(") is not None: - regions = self.parse_list(self.parse_optional_region) - self.parse_char(")") - - attributes = self.parse_op_attributes() - - self.parse_char(":") - operand_types, result_types = self.parse_op_type() - - if len(operand_types) != len(operands): - raise Exception( - "Operand types are not matching the number of operands.") - if len(result_types) != num_results: - raise Exception( - "Result types are not matching the number of results.") - for operand, operand_type in zip(operands, operand_types): - if operand.typ != operand_type: - raise Exception("Operation operand types are not matching " - "the types of its operands. 
Got operand with " - f"type {operand.typ}, but operation expect " - f"operand to be of type {operand_type}") - - return op_type.create(operands=operands, - result_types=result_types, - attributes=attributes, - regions=regions) - - def parse_optional_mlir_op(self, - skip_white_space: bool = True - ) -> Operation | None: - start_pos = self._pos - results = self.parse_optional_results( - skip_white_space=skip_white_space) - if results is None: - results = [] - op_name = self.parse_optional_str_literal() - if op_name is None: - return None - else: - op_name = self.parse_str_literal() + attrs = self.parse_list_of(self._parse_attribute_entry, + "Expected attribute entry") - op_type = self.ctx.get_optional_op(op_name) - if op_type is None: - if not self.allow_unregistered_ops: - raise ParserError(start_pos, f"unknown operation '{op_name}'") + self.parse_characters( + "]", + "xDSL Attribute dictionary must be enclosed in square brackets") - op_type = UnregisteredOp - op = self.parse_mlir_op_with_default_format(op_type, len(results)) - op.attributes["op_name__"] = StringAttr.from_str(op_name) - else: - op = self.parse_mlir_op_with_default_format(op_type, len(results)) + return self._attr_dict_from_tuple_list(attrs) - # Register the SSA value names in the parser - for (idx, res) in enumerate(results): - if res in self._ssaValues: - raise ParserError(start_pos, - f"SSA value {res} is already defined") - self._ssaValues[res] = op.results[idx] - if self.is_valid_name(res): - self._ssaValues[res].name = res + def _parse_operation_details( + self, + ) -> tuple[list[Span], list[Span], dict[str, Attribute], list[Region], + FunctionType | None]: + """ + Must return a tuple consisting of: + - a list of arguments to the operation + - a list of successor names + - the attributes attached to the OP + - the regions of the op + - An optional function type. 
If not supplied, parse_op_result_list must return a second value + containing the types of the returned SSAValues - return op + """ + args = self._parse_op_args_list() + succ = self._parse_optional_successor_list() + attrs = self.parse_optional_attr_dict() + regions = self.parse_region_list() - def parse_op(self, skip_white_space: bool = True) -> Operation: - res = self.parse_optional_op(skip_white_space=skip_white_space) - if res is None: - raise ParserError(self._pos, "operation expected") - return res + return args, succ, attrs, regions, None + + def _parse_optional_successor_list(self) -> list[Span]: + if not self.tokenizer.starts_with("("): + return [] + self.parse_characters("(", + "Successor list is enclosed in round brackets") + successors = self.parse_list_of(self.try_parse_block_id, + "Expected a block-id", + allow_empty=False) + self.parse_characters(")", + "Successor list is enclosed in round brackets") + return successors + + def _parse_dialect_type_or_attribute_inner(self, kind: str): + if self.tokenizer.starts_with('"'): + name = self.try_parse_string_literal() + if name is None: + self.raise_error( + "Expected string literal for an attribute in generic format here!" 
+ ) + return self._parse_generic_attribute_args(name) + return super()._parse_dialect_type_or_attribute_inner(kind) + + def _parse_generic_attribute_args(self, name: StringLiteral): + attr = self.ctx.get_optional_attr(name.string_contents) + if attr is None: + self.raise_error("Unknown attribute name!", name) + if not issubclass(attr, ParametrizedAttribute): + self.raise_error("Expected ParametrizedAttribute name here!", name) + self.parse_characters('<', + 'Expected generic attribute arguments here!') + args = self.parse_list_of(self.try_parse_attribute, + 'Unexpected end of attribute list!') + self.parse_characters( + '>', 'Malformed attribute arguments, reached end of args list!') + return attr(args) + + def _parse_op_args_list(self) -> list[Span]: + self.parse_characters( + "(", "Operation args list must be enclosed by brackets!") + args = self.parse_list_of(self.try_parse_value_id_and_type, + "Expected another bare-id here") + self.parse_characters( + ")", "Operation args list must be closed by a closing bracket") + # TODO: check if type is correct here! 
+ return [name for name, _ in args] + + def try_parse_type(self) -> Attribute | None: + return self.try_parse_attribute() + + +# COMPAT layer so parser_ng is a drop-in replacement for parser: + + +class Source(Enum): + XDSL = 1 + MLIR = 2 + + +def Parser(ctx: MLContext, + prog: str, + source: Source = Source.XDSL, + filename: str = '', + allow_unregistered_ops=False) -> BaseParser: + selected_parser = { + Source.XDSL: XDSLParser, + Source.MLIR: MLIRParser + }[source] + return selected_parser(ctx, prog, filename) + + +setattr(Parser, 'Source', Source) diff --git a/xdsl/printer.py b/xdsl/printer.py index 85e2ed7c67..9b275d8935 100644 --- a/xdsl/printer.py +++ b/xdsl/printer.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from dataclasses import dataclass, field from enum import Enum from frozenlist import FrozenList @@ -296,7 +297,7 @@ def print_paramattr_parameters( self.print(">") def print_string_literal(self, string: str): - self.print(f'"{string}"') + self.print(json.dumps(string)) def print_attribute(self, attribute: Attribute) -> None: if isinstance(attribute, UnitAttr): diff --git a/xdsl/tools/xdsl-opt b/xdsl/tools/xdsl-opt index 4923ccfe33..38206cfb69 100755 --- a/xdsl/tools/xdsl-opt +++ b/xdsl/tools/xdsl-opt @@ -1,25 +1,6 @@ #!/usr/bin/env python3 -import argparse from xdsl.xdsl_opt_main import xDSLOptMain - -class OptMain(xDSLOptMain): - - def register_all_dialects(self): - super().register_all_dialects() - - def register_all_passes(self): - super().register_all_passes() - - def register_all_arguments(self, arg_parser: argparse.ArgumentParser): - super().register_all_arguments(arg_parser) - - -def __main__(): - xdsl_main = OptMain() - xdsl_main.run() - - if __name__ == "__main__": - __main__() + xDSLOptMain().run() diff --git a/xdsl/utils/exceptions.py b/xdsl/utils/exceptions.py index c167fd472d..6bbb53358a 100644 --- a/xdsl/utils/exceptions.py +++ b/xdsl/utils/exceptions.py @@ -2,10 +2,17 @@ This module contains all custom exceptions 
used by xDSL. """ +from __future__ import annotations +import sys +import typing from dataclasses import dataclass +from io import StringIO from typing import Any -from xdsl.ir import Attribute + +if typing.TYPE_CHECKING: + from parser import Span, BacktrackingHistory + from xdsl.ir import Attribute class DiagnosticException(Exception): @@ -28,3 +35,61 @@ class BuilderNotFoundException(Exception): def __str__(self) -> str: return f"No builder found for attribute {self.attribute} with " \ f"arguments {self.args}" + + +class ParseError(Exception): + span: 'Span' + msg: str + history: 'BacktrackingHistory' | None + + def __init__(self, + span: 'Span', + msg: str, + history: 'BacktrackingHistory' | None = None): + preamble = "" + if history: + preamble = history.error.args[0] + '\n' + if span is None: + raise ValueError("Span can't be None!") + super().__init__(preamble + span.print_with_context(msg)) + self.span = span + self.msg = msg + self.history = history + + def print_pretty(self, file=sys.stderr): + print(self.span.print_with_context(self.msg), file=file) + + def print_with_history(self, file=sys.stderr): + if self.history is not None: + for h in sorted(self.history.iterate(), key=lambda h: -h.pos): + h.print() + else: + self.print_pretty(file) + + def __repr__(self): + io = StringIO() + self.print_with_history(io) + return "{}:\n{}".format(self.__class__.__name__, io.getvalue()) + + +class MultipleSpansParseError(ParseError): + ref_text: str | None + refs: list[tuple['Span', str]] + + def __init__( + self, + span: 'Span', + msg: str, + ref_text: str, + refs: list[tuple['Span', str | None]], + history: 'BacktrackingHistory' | None = None, + ): + super(MultipleSpansParseError, self).__init__(span, msg, history) + self.refs = refs + self.ref_text = ref_text + + def print_pretty(self, file=sys.stderr): + super(MultipleSpansParseError, self).print_pretty(file) + print(self.ref_text or "With respect to:", file=file) + for span, msg in self.refs: + 
print(span.print_with_context(msg), file=file) diff --git a/xdsl/xdsl_opt_main.py b/xdsl/xdsl_opt_main.py index 7c1190b359..9ce2de6f6e 100644 --- a/xdsl/xdsl_opt_main.py +++ b/xdsl/xdsl_opt_main.py @@ -3,9 +3,10 @@ import os from io import IOBase, StringIO import coverage +from typing.io import IO from xdsl.ir import MLContext -from xdsl.parser import Parser +from xdsl.parser import XDSLParser, MLIRParser from xdsl.printer import Printer from xdsl.dialects.func import Func from xdsl.dialects.scf import Scf @@ -216,30 +217,13 @@ def register_all_frontends(self): Add other/additional frontends by overloading this function. """ - def parse_xdsl(f: IOBase): - input_str = f.read() - parser = Parser( - self.ctx, - input_str, - allow_unregistered_ops=self.args.allow_unregistered_ops) - module = parser.parse_op() - if not (isinstance(module, ModuleOp)): - raise Exception( - "Expected module or program as toplevel operation") - return module - - def parse_mlir(f: IOBase): - input_str = f.read() - parser = Parser( - self.ctx, - input_str, - source=Parser.Source.MLIR, - allow_unregistered_ops=self.args.allow_unregistered_ops) - module = parser.parse_op() - if not (isinstance(module, ModuleOp)): - raise Exception( - "Expected module or program as toplevel operation") - return module + def parse_xdsl(io: IOBase): + return XDSLParser(self.ctx, io.read(), self.get_input_name(), + self.args.allow_unregistered_ops).parse_module() + + def parse_mlir(io: IOBase): + return MLIRParser(self.ctx, io.read(), self.get_input_name(), + self.args.allow_unregistered_ops).parse_module() self.available_frontends['xdsl'] = parse_xdsl self.available_frontends['mlir'] = parse_mlir @@ -352,3 +336,6 @@ def print_to_output_stream(self, contents: str): else: output_stream = open(self.args.output_file, 'w') output_stream.write(contents) + + def get_input_name(self): + return self.args.input_file or 'stdin'