From 483de42c3854f9b21330fa0dfd45969e02e9db30 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Thu, 8 Dec 2022 17:27:28 +0000 Subject: [PATCH 01/65] [parser] Initial design concept for a new parser that supports BNF-style subroutines --- xdsl/parser_ng.py | 874 ++++++++++++++++++++++++++++++++++++++++++++++ xdsl/utils/bnf.py | 224 ++++++++++++ 2 files changed, 1098 insertions(+) create mode 100644 xdsl/parser_ng.py create mode 100644 xdsl/utils/bnf.py diff --git a/xdsl/parser_ng.py b/xdsl/parser_ng.py new file mode 100644 index 0000000000..9930f4bcf8 --- /dev/null +++ b/xdsl/parser_ng.py @@ -0,0 +1,874 @@ +from __future__ import annotations + +import contextlib +from dataclasses import dataclass, field +import re +import ast +from typing import Any, TypeVar, Iterable, Literal +from enum import Enum + + +from xdsl.ir import (SSAValue, Block, Callable, Attribute, Operation, Region, + BlockArgument, MLContext, ParametrizedAttribute) + +import xdsl.utils.bnf as BNF + +from xdsl.dialects.builtin import ( + AnyFloat, AnyTensorType, AnyUnrankedTensorType, AnyVectorType, + DenseIntOrFPElementsAttr, Float16Type, Float32Type, Float64Type, FloatAttr, + FunctionType, IndexType, IntegerType, OpaqueAttr, Signedness, StringAttr, + FlatSymbolRefAttr, IntegerAttr, ArrayAttr, TensorType, UnitAttr, + UnrankedTensorType, UnregisteredOp, VectorType) + +from xdsl.irdl import Data + + +@dataclass +class ParseError(Exception): + span: Span + msg: str + + +class BacktrackingAbort(Exception): + reason: str | None + + def __init__(self, reason: str | None = None): + super("This message should never escape the parser, it's intended to signal a failed parsing attempt\n" + "It should never be used outside of a tokenizer.backtracking() block!\n" + "The reason for this abort was {}".format('not specified' if reason is None else reason)) + self.reason = reason + + +@dataclass(frozen=True) +class Span: + """ + Parts of the input are always passed around as spans so we know where they originated. + """ + + start: int + """ + Start of tokens location in source file, global byte offset in file + """ + end: int + """ + End of tokens location in source file, global byte offset in file + """ + input: Input + """ + The input being operated on + """ + + def __len__(self): + return self.len + + @property + def len(self): + return self.end - self.start + + @property + def text(self): + return self.input.content[self.start:self.end] + + def print_with_context(self, msg: str | None = None): + info = self.input.get_lines_containing(self) + assert info is not None + lines, offset, line_no = info + remaining_len = self.len + print("In {}:{}".format(self.input.name, line_no)) + for line in lines: + print(line) + if offset > len(line): + offset -= len(line) + continue + if remaining_len <= 0: + continue + len_on_this_line = min(remaining_len, len(line) - offset) + remaining_len -= len_on_this_line + print((" " * offset) + ("^" * len_on_this_line)) + if msg is not None: + print("{}{}".format(" " * offset, msg)) + msg = None + offset = 0 + + def __repr__(self): + return "Span[{}:{}](text='{}', input={})".format(self.start, self.end, self.text, self.input) + + +@dataclass(frozen=True) +class StringLiteral(Span): + def __post_init__(self): + if len(self) < 2 or self.text[0] != '"' or self.text[-1] != '"': + raise ParseError(self, "Invalid string literal!") + + T_ = TypeVar('T_', Span, None) + + @classmethod + def from_span(cls, span: T_) -> T_: + if span is None: + return None + return cls(span.start, span.end, span.input) + + @property + def string_contents(self): + # TODO: is this a hack-job? + return ast.literal_eval(self.text) + + +@dataclass(frozen=True) +class Input: + """ + This is a very simple class that is used to keep track of the input. + """ + name: str + content: str = field(repr=False) + + @property + def len(self): + return len(self.content) + + def __len__(self): + return self.len + + def get_nth_line_bounds(self, n: int): + start = 0 + for i in range(n): + next_start = self.content.find('\n', start) + if next_start == -1: + return None + start = next_start + 1 + return start, self.content.find('\n', start) + + def get_lines_containing(self, span: Span) -> tuple[list[str], int, int] | None: + start = 0 + line_no = 0 + source = self.content + while True: + next_start = source.find('\n', start) + line_no += 1 + # handle eof + if next_start == -1: + return None + # as long as the next newline comes before the spans start we are good + if next_start < span.start: + start = next_start + 1 + continue + # if the whole span is on one line, we are good as well + if next_start >= span.end: + return [source[start:next_start]], start, line_no + while next_start < span.end: + next_start = source.find('\n', next_start + 1) + return source[start:next_start].split('\n'), start, line_no + + def at(self, i: int): + if i >= self.len: + raise EOFError() + return self.content[i] + + +save_t = tuple[int, tuple[str, ...], bool] +parsed_type_t = tuple[Span, tuple[Span]] + +@dataclass +class Tokenizer: + input: Input + + pos: int = field(init=False, default=0) + """ + The position in the input. Points to the first unconsumed character. + """ + + break_on: tuple[str, ...] = ( + '.', '%', ' ', '(', ')', '[', ']', '{', '}', '<', '>', ':', '=', '@', '?', '|', '->', '-', '//', '\n', '\t', '#' + ) + """ + characters the tokenizer should break on + """ + + ignore_whitespace: bool = True + + last_error: ParseError | None = field(init=False, default=None) + last_token: Span | None = field(init=False, default=None) + + def save(self) -> save_t: + """ + Create a checkpoint in the parsing process, useful for backtracking + """ + return self.pos, self.break_on, self.ignore_whitespace + + def resume_from(self, save: save_t): + """ + Resume from a previously saved position. + + Restores the state of the tokenizer to the exact previous position + """ + self.pos, self.break_on, self.ignore_whitespace = save + + @contextlib.contextmanager + def backtracking(self): + """ + Used to create backtracking parsers. You can wrap you parse code into + + with tokenizer.backtracking(): + # do some stuff + assert x == 'array' + + All exceptions triggered in the body will abort the parsing attempt, but not escape further. + + The tokenizer state will not change. + + When backtracking occurred, the backtracker will save the last exception in last_error + """ + save = self.save() + try: + self.last_error = None + yield + except Exception as ex: + if isinstance(ex, BacktrackingAbort): + self.last_error = ParseError( + self.next_token(peek=True), + 'Backtracking aborted: {}'.format(ex.reason or 'unknown reason') + ) + elif isinstance(ex, AssertionError): + reason = ['Generic assertion failure', *(reason for reason in ex.args if isinstance(reason, str))] + # we assume that assertions fail because of the last read-in token + self.last_error = ParseError(self.last_token, reason[-1]) + elif isinstance(ex, ParseError): + self.last_error = ex + print("Warning: ParseError in backtracking: {}".format(ex)) + else: + print("Warning: Unexpected error in backtracking: {}".format(ex)) + self.resume_from(save) + + def next_token(self, start: int | None = None, skip: int = 0, peek: bool = False, + include_comments: bool = False) -> Span: + """ + Best effort guess at what the next token could be + """ + i = self.next_pos(start) + while skip > 0: + # skip whitespace if able + i = self.next_pos(self._find_token_end(i)) + skip -= 1 + # advance to the next position + if not peek: + self.pos = self._find_token_end(i) + + span = self.span_of(i, self.pos) + if not include_comments and span.text == '//': + while self.input.at(i) != '\n': + i += 1 + return self.next_token(i, 0, peek, include_comments) + + # save last token + self.last_token = span + return span + + def next_token_of_pattern(self, pattern: re.Pattern, peek: bool = False) -> Span | None: + """ + Return a span that matched the pattern, or nothing. You can choose not to consume the span. + """ + start = self.next_pos() + match = pattern.match(self.input.content, start) + if match is None: + return None + if not peek: + self.pos = match.end() + # save last token + self.last_token = self.span_of(start, match.end()) + return self.last_token + + def jump_back_to(self, span: Span): + """ + This can be used to "rewind" the tokenizer back to the point right before you consumed the token. + + This leaves everything except the position untouched + """ + self.pos = span.start + + def consume_peeked(self, peeked_span: Span): + if peeked_span.start != self.next_pos(): + raise ParseError(peeked_span, "This is not the peeked span!") + self.pos = peeked_span.end + + def _find_token_end(self, start: int | None = None) -> int: + """ + Find the point (optionally starting from start) where the token ends + """ + i = self.next_pos() if start is None else start + # search for literal breaks + for part in self.break_on: + if self.input.content.startswith(part, i): + return i + len(part) + # otherwise return the start of the next break + return min(filter(lambda x: x >= 0, (self.input.content.find(part, i) for part in self.break_on))) + + def next_pos(self, i: int | None = None) -> int: + """ + Find the next starting position (optionally starting from i), considering ignore_whitespaces + """ + i = self.pos if i is None else i + # skip whitespaces + if self.ignore_whitespace: + while self.input.at(i).isspace(): + i += 1 + return i + + def is_eof(self): + try: + i = self.pos + while self.input.at(i).isspace(): + i += 1 + return False + except EOFError: + return True + + def span_of(self, start: int, end: int) -> Span: + return Span(start, end, self.input) + + def consume_opt_whitespace(self) -> Span: + start = self.pos + while self.input.at(self.pos).isspace(): + self.pos += 1 + return self.span_of(start, self.pos) + + @contextlib.contextmanager + def configured(self, break_on: tuple[str, ...] | None = None, ignore_whitespace: bool | None = None): + """ + This is a helper class to allow expressing a temporary change in config, allowing you to write: + + # parsing double-quoted string now + string_content = "" + with tokenizer.configured(break_on=('"', '\\'), ignore_whitespace=False): + # use tokenizer + + # now old config is restored automatically + + """ + save = self.save() + + if break_on is not None: + self.break_on = break_on + if ignore_whitespace is not None: + self.ignore_whitespace = ignore_whitespace + + try: + yield self + finally: + self.break_on = save[1] + self.ignore_whitespace = save[2] + + +class ParserCommons: + """ + Colelction of common things used in parsing MLIR/IRDL + + """ + integer_literal = re.compile(r'[+-]?([0-9]+|0x[0-9A-f]+)') + decimal_literal = re.compile(r'[+-]?([1-9][0-9]*)') + string_literal = re.compile(r'"([^\n\f\v\r"]|\\[nfvr"])+"') + float_literal = re.compile(r'[-+]?[0-9]+\.[0-9]*([eE][-+]?[0-9]+)?') + bare_id = re.compile(r'[A-z_][A-z0-9_$.]+') + value_id = re.compile(r'%[A-z_][A-z0-9_$.]+') + suffix_id = re.compile(r'([0-9]+|([A-z_$.-][0-9A-z_$.-]*))') + block_id = re.compile(r'\^([0-9]+|([A-z_$.-][0-9A-z_$.-]*))') + type_alias = re.compile(r'![A-z_][A-z0-9_$.]+') + attribute_alias = re.compile(r'#[A-z_][A-z0-9_$.]+') + builtin_type = re.compile('({})'.format( + '|'.join(( + r'[su]?i\d+', 'tensor', 'vector', + 'memref', 'complex', 'opaque', + 'tuple', 'index', + # TODO: add all the FloatNtype, Float8E4M3FNType, Float8E5M2Type, and BFloat16Type + )) + )) + double_colon = re.compile('::') + comma = re.compile(',') + + class BNF: + """ + Collection of BNF trees. + """ + generic_operation = BNF.Group([ + BNF.Nonterminal('string-literal', bind="name"), + BNF.Literal('('), + BNF.ListOf(BNF.Nonterminal('value-id'), bind='args'), + BNF.Literal(')'), + BNF.OptionalGroup([ + BNF.Literal('['), + BNF.ListOf(BNF.Nonterminal('block-id'), allow_empty=False, bind='blocks'), # TODD: allow for block args here?! (accordin to spec) + BNF.Literal(']') + ], bind='blocks_group'), + BNF.ListOf(BNF.Nonterminal('region'), bind='regions'), + BNF.Nonterminal('attr-dict', bind='attributes'), + BNF.Literal(':'), + BNF.Nonterminal('function-type', bind='type_signature') + ]) + + +class MlirParser: + """ + Basic recursive descent parser. + + methods marked try_... will attempt to parse, and return None if they failed. If they return None + they must make sure to restore all state. + + methods marked must_... will do greedy parsing, meaning they consume as much as they can. They will + also throw an error if the think they should still be parsing. e.g. when parsing a list of numbers + separated by '::', the following input will trigger an exception: + 1::2:: + Due to the '::' present after the last element. This is useful for parsing lists, as a trailing + separator is usually considered a syntax error there. + + You can turn a try_ into a must_ by using expect(try_parse_..., error_msg) + + You can turn a must_ into a try_ by wrapping it inside of a tokenizer.backtracking() + + must_ type parsers are preferred because they are explicit about their failure modes. + """ + + class Accent(Enum): + XDSL = 'xDSL' + MLIR = 'MLIR' + + accent: Accent + + ctx: MLContext + """xDSL context.""" + + _ssaValues: dict[str, SSAValue] = field(init=False, default_factory=dict) + _blocks: dict[str, Block] = field(init=False, default_factory=dict) + + T_ = TypeVar('T_') + """ + Type var used for handling function that return single or multiple Spans. Basically the output type + of all try_parse functions is T_ | None + """ + + def __int__(self, input: str, name: str, ctx: MLContext, accent: str | Accent = 'xDSL'): + self.tokenizer = Tokenizer(Input(input, name)) + self.ctx = ctx + if isinstance(accent, str): + accent = MlirParser.Accent[accent] + self.accent = accent + + def begin_parse(self): + pass + + def must_parse_block(self) -> Block | None: + next_id = self.expect(self.try_parse_block_id, 'Blocks must start with a block id!') + + assert next_id.text not in self._blocks + + block = Block() + self._blocks[next_id.text] = block + + if self.tokenizer.next_token(peek=True).text == '(': + for i, (name, type) in enumerate(self.must_parse_block_arg_list()): + arg = BlockArgument(type, block, i) + self._ssaValues[name.text] = arg + block.args.append(arg) + + while (next_op := self.try_parse_op()) is not None: + block.ops.append(next_op) + + return block + + def get_or_create_block_arg(self, name: Span, type: Attribute): + if name.text in self._ssaValues: + val = self._ssaValues.get(name.text) + assert val.typ == type + return val + self._ssaValues[name.text] = BlockArgument(type, ) + + def must_parse_block_arg_list(self) -> list[tuple[Span, Attribute]]: + self.assert_eq(self.tokenizer.next_token(), '(', 'Block arguments must start with `(`') + + args = self.must_parse_list_of(self.try_parse_value_id_and_type, "Expected ") + + self.assert_eq(self.tokenizer.next_token(), ')', 'Expected closing of block arguments!') + + return args + + def try_parse_single_reference(self) -> Span | None: + with self.tokenizer.backtracking(): + self.must_parse_characters('@', "references must start with `@`") + if (reference := self.try_parse_string_literal()) is not None: + return reference + if (reference := self.try_parse_suffix_id()) is not None: + return reference + raise BacktrackingAbort("References must conform to `@` (string-literal | suffix-id)") + + def must_parse_reference(self) -> list[Span]: + return self.must_parse_list_of( + self.try_parse_single_reference, + 'Expected reference here in the format of `@` (suffix-id | string-literal)', + ParserCommons.double_colon, + allow_empty=False + ) + + def must_parse_list_of(self, try_parse: Callable[[], T_ | None], error_msg: str, + separator_pattern: re.Pattern = ParserCommons.comma, allow_empty: bool = True) -> list[T_]: + items = list() + first_item = try_parse() + if first_item is None: + if allow_empty: + return items + self.raise_error(error_msg) + + items.append(first_item) + + while self.tokenizer.next_token_of_pattern(separator_pattern) is not None: + next_item = try_parse() + if next_item is None: + self.raise_error(error_msg) + items.append(next_item) + + return items + + def try_parse_integer_literal(self) -> Span | None: + return self.tokenizer.next_token_of_pattern(ParserCommons.integer_literal) + + def try_parse_decimal_literal(self) -> Span | None: + return self.tokenizer.next_token_of_pattern(ParserCommons.decimal_literal) + + def try_parse_string_literal(self) -> StringLiteral | None: + return StringLiteral.from_span(self.tokenizer.next_token_of_pattern(ParserCommons.string_literal)) + + def try_parse_float_literal(self) -> Span | None: + return self.tokenizer.next_token_of_pattern(ParserCommons.float_literal) + + def try_parse_bare_id(self) -> Span | None: + return self.tokenizer.next_token_of_pattern(ParserCommons.bare_id) + + def try_parse_value_id(self) -> Span | None: + return self.tokenizer.next_token_of_pattern(ParserCommons.value_id) + + def try_parse_suffix_id(self) -> Span | None: + return self.tokenizer.next_token_of_pattern(ParserCommons.suffix_id) + + def try_parse_block_id(self) -> Span | None: + return self.tokenizer.next_token_of_pattern(ParserCommons.block_id) + + def try_parse_value_id_and_type(self) -> tuple[Span, Attribute] | None: + with self.tokenizer.backtracking(): + value_id = self.try_parse_value_id() + + if value_id is None: + raise BacktrackingAbort("Expected value id here!") + + self.must_parse_characters(':', 'Expected expression (value-id `:` type)') + + type = self.try_parse_type() + + if type is None: + raise BacktrackingAbort("Expected type of value-id here!") + return value_id, type + + def try_parse_type(self) -> Attribute | None: + if (builtin_type := self.try_parse_builtin_type()) is not None: + return builtin_type + if (dialect_type := self.try_parse_dialect_type_or_attribute('type')) is not None: + return dialect_type + return None + + def try_parse_dialect_type_or_attribute(self, kind: Literal['type', 'attr']) -> Attribute | None: + with self.tokenizer.backtracking(): + if kind == 'type': + self.must_parse_characters('!', "Dialect types must start with a `!`") + else: + self.must_parse_characters('#', "Dialect attributes must start with a `#`") + + type_name = self.tokenizer.next_token_of_pattern(ParserCommons.bare_id) + + if type_name is None: + raise BacktrackingAbort("Expected a type name") + + type_def = self.ctx.get_attr(type_name.text) + + # pass the task of parsing parameters on to the attribute/type definition + param_list = type_def.parse_parameters(self) + return type_def(param_list) + + def try_parse_builtin_type(self) -> Attribute | None: + """ + parse a builtin-type like i32, index, vector etc. + """ + with self.tokenizer.backtracking(): + name = self.tokenizer.next_token_of_pattern(ParserCommons.builtin_type) + if name is None: + raise BacktrackingAbort("Expected builtin name!") + if name.text == 'index': + return IndexType.build() + if (re_match := re.match(r'^([su]?i(\d)+)$', name.text)) is not None: + signedness = { + 's': Signedness.SIGNED, + 'u': Signedness.UNSIGNED, + 'i': Signedness.SIGNLESS + } + return IntegerType.from_width(int(re_match.group(1)), signedness[name.text[0]]) + + return self.must_parse_builtin_parametrized_type(name) + + def must_parse_builtin_parametrized_type(self, name: Span) -> ParametrizedAttribute: + def unimplemented() -> ParametrizedAttribute: + raise ParseError(self.tokenizer.next_token(), "Type not supported yet!") + + builtin_parsers: dict[str, Callable[[], ParametrizedAttribute]] = { + 'vector': self.must_parse_vector_attrs, + 'memref': unimplemented, + 'tensor': self.must_parse_tensor_attrs, + 'complex': self.must_parse_complex_attrs, + 'opaque': unimplemented, + 'tuple': unimplemented, + } + if name.text not in builtin_parsers: + raise ParseError(name, "Unknown builtin {}".format(name.text)) + + self.assert_eq(self.tokenizer.next_token(), '<', 'Expected parameter list here!') + res = builtin_parsers[name.text]() + self.assert_eq(self.tokenizer.next_token(), '>', 'Expected end of parameter list here!') + return res + + def must_parse_complex_attrs(self): + type = self.try_parse_type() + self.raise_error("ComplexType is unimplemented!") + + def try_parse_numerical_dims(self, accept_closing_bracket: bool = False, lower_bound: int = 1) -> Iterable[int]: + while (shape_arg := self.try_parse_shape_element(lower_bound)) is not None: + yield shape_arg + # look out for the closing bracket for scalable vector dims + if accept_closing_bracket and self.tokenizer.next_token(peek=True).text == ']': + break + self.assert_eq(self.tokenizer.next_token(), 'x', 'Unexpected end of dimension parameters!') + + def must_parse_vector_attrs(self) -> AnyVectorType: + # also break on 'x' characters as they are separators in dimension parameters + with self.tokenizer.configured(break_on=self.tokenizer.break_on + ('x',)): + shape = list[int](self.try_parse_numerical_dims()) + scaling_shape: list[int] | None = None + + if self.tokenizer.next_token(peek=True).text == '[': + self.tokenizer.next_token() + # we now need to parse the scalable dimensions + scaling_shape = list(self.try_parse_numerical_dims()) + self.assert_eq(self.tokenizer.next_token(), ']', 'Expected end of scalable vector dimensions here!') + self.assert_eq(self.tokenizer.next_token(), 'x', 'Expected end of scalable vector dimensions here!') + + if scaling_shape is not None: + # TODO: handle scaling vectors! + print("Warning: scaling vectors not supported!") + pass + + type = self.try_parse_type() + if type is None: + self.raise_error("Expected a type at the end of the vector parameters!") + + return VectorType.from_type_and_list(type, shape) + + def must_parse_tensor_or_memref_dims(self) -> list[int] | None: + with self.tokenizer.configured(break_on=self.tokenizer.break_on + ('x',)): + if self.tokenizer.next_token(peek=True).text == '*': + # consume `*` + self.tokenizer.next_token() + # consume `x` + self.assert_eq(self.tokenizer.next_token(), 'x', 'Unranked tensors must follow format (`<*x` type `>`)') + else: + # parse rank: + return list(self.try_parse_numerical_dims(lower_bound=0)) + + def must_parse_tensor_attrs(self) -> AnyTensorType: + shape = self.must_parse_tensor_or_memref_dims() + type = self.try_parse_type() + + if type is None: + self.raise_error("Expected tensor type here!") + + if self.tokenizer.next_token(peek=True).text == ',': + # TODO: add tensor encoding! + raise self.raise_error("Parsing tensor encoding is not supported!") + + if shape is None and self.tokenizer.next_token(peek=True).text == ',': + raise self.raise_error("Unranked tensors don't have an encoding!") + + if shape is not None: + return TensorType.from_type_and_list(type, shape) + + return UnrankedTensorType.from_type(type) + + def try_parse_shape_element(self, lower_bound: int = 1) -> int | None: + """ + Parse a shape element, either a decimal integer immediate or a `?`, which evaluates to -1 + + immediate cannot be smaller than lower_bound (defaults to 1) (is 0 for tensors and memrefs) + """ + int_lit = self.try_parse_decimal_literal() + + if int_lit is not None: + value = int(int_lit.text) + if value < lower_bound: + # TODO: this is ugly, it's a raise inside a try_ type function, which should instead just give up + raise ParseError(int_lit, "Shape element literal cannot be negative or zero!") + return value + + next_token = self.tokenizer.next_token(peek=True) + + if next_token.text == '?': + self.tokenizer.consume_peeked(next_token) + return -1 + return None + + def must_parse_type_params(self) -> list[parsed_type_t]: + # consume opening bracket + assert self.tokenizer.next_token().text == '<', 'Type must be parameterized!' + + params = self.must_parse_list_of( + self.try_parse_type, + 'Expected a type here!' + ) + + assert self.tokenizer.next_token().text == '>', 'Expected end of type parameterization here!' + + return params + + def expect(self, try_parse: Callable[[], T_ | None], error_message: str) -> T_: + """ + Used to force completion of a try_parse function. Will throw a parse error if it can't + """ + res = try_parse() + if res is None: + self.raise_error(error_message) + return res + + def raise_error(self, msg: str, at_position: Span | None = None): + """ + Helper for raising exceptions, provides as much context as possible to them. + + This will, for example, include backtracking errors, if any occured previously + """ + if at_position is None: + at_position = self.tokenizer.next_token(peek=True) + + # include backtracking exception if available + if self.tokenizer.last_error: + raise ParseError(at_position, msg) from self.tokenizer.last_error + + raise ParseError(at_position, msg) + + def assert_eq(self, a: Span, b: str, msg: str): + if a.text == b: + return + raise AssertionError("Assertion failed ({} == {}): {}".format(a.text, b, msg), a) + + def assert_neq(self, a: Span, b: str, msg: str): + if a.text != b: + return + raise AssertionError("Assertion failed ({} != {}): {}".format(a.text, b, msg), a) + + def assert_in(self, a: Span, b: tuple[str], msg: str): + if a.text in b: + return + raise AssertionError("Assertion failed ({} in {}): {}".format(a.text, b, msg), a) + + def must_parse_characters(self, text: str, msg: str): + self.assert_eq(self.tokenizer.next_token(), text, msg) + + def try_parse_op_result_list(self) -> list[tuple[Span, Attribute] | Span] | None: + inner_parser = (dict(( + (MlirParser.Accent.MLIR, self.try_parse_value_id), + (MlirParser.Accent.XDSL, self.try_parse_value_id_and_type) + )))[self.accent] + + return self.must_parse_list_of(inner_parser, 'Expected op-result here!', allow_empty=False) + + def try_parse_op(self): + with self.tokenizer.backtracking(): + result_list = self.try_parse_op_result_list() + self.must_parse_characters('=', 'Operation definitions expect an `=` after op-result-list!') + name = self.try_parse_op_name() + + # handle custom-operation parsing + if not isinstance(name, StringLiteral): + op_type = self.ctx.get_op(name.text) + # TODO: how do we pass result types if we are in xDSL format? + op_type.parse() + + op_type = self.ctx.get_op(name.string_contents) + + + def try_parse_op_name(self) -> Span | None: + if (str_lit := self.try_parse_string_literal()) is not None: + return str_lit + return self.try_parse_bare_id() + + + + + +""" +digit ::= [0-9] +hex_digit ::= [0-9a-fA-F] +letter ::= [a-zA-Z] +id-punct ::= [$._-] + +integer-literal ::= decimal-literal | hexadecimal-literal +decimal-literal ::= digit+ +hexadecimal-literal ::= `0x` hex_digit+ +float-literal ::= [-+]?[0-9]+[.][0-9]*([eE][-+]?[0-9]+)? +string-literal ::= `"` [^"\n\f\v\r]* `"` TODO: define escaping rules + +bare-id ::= (letter|[_]) (letter|digit|[_$.])* +bare-id-list ::= bare-id (`,` bare-id)* +value-id ::= `%` suffix-id +alias-name :: = bare-id +suffix-id ::= (digit+ | ((letter|id-punct) (letter|id-punct|digit)*)) + + +symbol-ref-id ::= `@` (suffix-id | string-literal) (`::` symbol-ref-id)? +value-id-list ::= value-id (`,` value-id)* + +// Uses of value, e.g. in an operand list to an operation. +value-use ::= value-id +value-use-list ::= value-use (`,` value-use)* + +operation ::= op-result-list? (generic-operation | custom-operation) + trailing-location? +generic-operation ::= string-literal `(` value-use-list? `)` successor-list? + region-list? dictionary-attribute? `:` function-type +custom-operation ::= bare-id custom-operation-format +op-result-list ::= op-result (`,` op-result)* `=` +op-result ::= value-id (`:` integer-literal) +successor-list ::= `[` successor (`,` successor)* `]` +successor ::= caret-id (`:` block-arg-list)? +region-list ::= `(` region (`,` region)* `)` +dictionary-attribute ::= `{` (attribute-entry (`,` attribute-entry)*)? `}` +trailing-location ::= (`loc` `(` location `)`)? + +block ::= block-label operation+ +block-label ::= block-id block-arg-list? `:` +block-id ::= caret-id +caret-id ::= `^` suffix-id +value-id-and-type ::= value-id `:` type + +// Non-empty list of names and types. +value-id-and-type-list ::= value-id-and-type (`,` value-id-and-type)* + +block-arg-list ::= `(` value-id-and-type-list? `)` + +type ::= type-alias | dialect-type | builtin-type + +type-list-no-parens ::= type (`,` type)* +type-list-parens ::= `(` `)` + | `(` type-list-no-parens `)` + +// This is a common way to refer to a value with a specified type. +ssa-use-and-type ::= ssa-use `:` type +ssa-use ::= value-use + +// Non-empty list of names and types. +ssa-use-and-type-list ::= ssa-use-and-type (`,` ssa-use-and-type)* + +function-type ::= (type | type-list-parens) `->` (type | type-list-parens) + +type-alias-def ::= '!' alias-name '=' type +type-alias ::= '!' alias-name +""" diff --git a/xdsl/utils/bnf.py b/xdsl/utils/bnf.py new file mode 100644 index 0000000000..972f3b8499 --- /dev/null +++ b/xdsl/utils/bnf.py @@ -0,0 +1,224 @@ +from __future__ import annotations +import functools +import re +import typing +from dataclasses import dataclass, field +from abc import abstractmethod +from typing import Any + +if typing.TYPE_CHECKING: + from xdsl.parser_ng import MlirParser, ParseError + +T = typing.TypeVar('T') + + +@dataclass(frozen=True) +class BNFToken(typing.Generic[T]): + bind: str | None = field(kw_only=True, init=False) + + @abstractmethod + def must_parse(self, parser: MlirParser) -> T: + raise NotImplemented() + + def try_parse(self, parser: MlirParser) -> T | None: + with parser.tokenizer.backtracking(): + return self.must_parse(parser) + + def collect(self, value, collection: dict): + if self.bind is None: + return + collection[self.bind] = value + + +@dataclass(frozen=True) +class Literal(BNFToken): + """ + Match a fixed input string + """ + string: str + bind: str | None = field(kw_only=True, default=None) + + def must_parse(self, parser: MlirParser): + return parser.must_parse_characters(self.string, 'Expected `{}`'.format(self.string)) + + def __repr__(self): + return '`{}`'.format(self.string) + + +@dataclass(frozen=True) +class Regex(BNFToken): + pattern: re.Pattern + bind: str | None = field(kw_only=True, default=None) + + def try_parse(self, parser: MlirParser) -> T | None: + return parser.tokenizer.next_token_of_pattern(self.pattern) + + def must_parse(self, parser: MlirParser) -> T: + res = self.try_parse(parser) + if res is None: + parser.raise_error('Expected token of form {}!'.format(self)) + + def __repr__(self): + return 're`{}`'.format(self.pattern.pattern) + + +@dataclass(frozen=True) +class Nonterminal(BNFToken): + """ + This is used as an "escape hatch" to switch from BNF to the python parsing code. + + It will look for must_parse_, or try_parse_ in the parse object. This can + probably be improved, idk. + """ + + name: str + """ + The symbol name of the nonterminal, e.g. string-lieral, tensor-attrs, etc... + """ + bind: str | None = field(kw_only=True, default=None) + + def must_parse(self, parser: MlirParser): + if hasattr(parser, 'must_parse_{}'.format(self.name.replace('-', '_'))): + return getattr(parser, 'must_parse_{}'.format(self.name.replace('-', '_')))(), self.bind + elif hasattr(parser, 'try_parse_{}'.format(self.name.replace('-', '_'))): + return parser.expect( + getattr(parser, 'try_parse_{}'.format(self.name.replace('-', '_'))), + 'Expected to parse {} here!'.format(self.name) + ), self.bind + else: + NotImplemented("Parser cannot parse {}".format(self.name)) + + def try_parse(self, parser: MlirParser) -> T | None: + if hasattr(parser, 'try_parse_{}'.format(self.name.replace('-', '_'))): + return getattr(parser, 'try_parse_{}'.format(self.name.replace('-', '_')))() + return super().try_parse(parser) + + def __repr__(self): + return self.name + + +@dataclass(frozen=True) +class Group(BNFToken): + tokens: list[BNFToken] + bind: str | None = field(kw_only=True, default=None) + + def must_parse(self, parser: MlirParser) -> T: + return [ + token.must_parse(parser) for token in self.tokens + ] + + def __repr__(self): + return '( {} )'.format(' '.join(repr(t) for t in self.tokens)) + + def collect(self, value, collection: dict): + for child, value in zip(self.tokens, value): + child.collect(value, collection) + if self.bind is not None: + collection[self.bind] = value + + +@dataclass(frozen=True) +class OneOrMoreOf(BNFToken): + wraps: BNFToken[T] + bind: str | None = field(kw_only=True, default=None) + + def must_parse(self, parser: MlirParser) -> list[T]: + res = list() + while True: + val = self.wraps.try_parse(parser) + if val is None: + if len(res) == 0: + raise AssertionError("Expected at least one of {}".format(self.wraps)) + return res + res.append(val) + + def __repr__(self): + return '{}+'.format(self.wraps) + + def children(self) -> typing.Iterable[BNFToken]: + return self.wraps, + + def collect(self, value, collection: dict): + for val in value: + self.wraps.collect(val, collection) + if self.bind is not None: + collection[self.bind] = value + + +@dataclass(frozen=True) +class ZeroOrMoreOf(BNFToken): + wraps: BNFToken[T] + bind: str | None = field(kw_only=True, default=None) + + def must_parse(self, parser: MlirParser) -> list[T]: + res = list() + while True: + val = self.wraps.try_parse(parser) + if val is None: + return res + res.append(val) + + def __repr__(self): + return '{}*'.format(self.wraps) + + def children(self) -> typing.Iterable[BNFToken]: + return self.wraps, + + def collect(self, values, collection: dict): + for value in values: + self.wraps.collect(value, collection) + if self.bind is not None: + collection[self.bind] = values + + +@dataclass(frozen=True) +class ListOf(BNFToken): + element: BNFToken + separator: re.Pattern = re.compile(',') + + allow_empty: bool = True + bind: str | None = field(kw_only=True, default=None) + + def try_parse(self, parser: MlirParser) -> T | None: + return parser.must_parse_list_of( + self.element.try_parse, + 'Expected {}!'.format(self.element), + separator_pattern=self.separator, + allow_empty=self.allow_empty + ) + + def __repr__(self): + if self.allow_empty: + return '( {elm} ( re`{sep}` {elm} )* )?'.format(elm=self.element, sep=self.separator.pattern) + return '{elm} ( re`{sep}` {elm} )*'.format(elm=self.element, sep=self.separator.pattern) + + def collect(self, values, collection: dict): + for value in values: + self.element.collect(value, collection) + if self.bind is not None: + collection[self.bind] = values + + +@dataclass(frozen=True) +class Optional(BNFToken[T | None]): + wraps: BNFToken[T] + bind: str | None = field(kw_only=True, default=None) + + def must_parse(self, parser: MlirParser) -> T | None: + return self.wraps.try_parse(parser) + + def try_parse(self, parser: MlirParser) -> T | None: + return self.wraps.try_parse(parser) + + def __repr__(self): + return '{}?'.format(self.wraps) + + def collect(self, value, collection: dict): + if value is not None: + self.wraps.collect(value, collection) + if self.bind is not None: + collection[self.bind] = value + + +def OptionalGroup(tokens: list[BNFToken], bind: str | None = None) -> Optional: + return Optional(Group(tokens), bind=bind) From 8f39dbd8789d04a4c3610fcb9e48ca52a9ebbc29 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Fri, 9 Dec 2022 11:36:01 +0000 Subject: [PATCH 02/65] [parser] fixes for some correctnes issues with the parsing and error message code --- xdsl/parser_ng.py | 90 ++++++++++++++++++++++++++++++++--------------- xdsl/utils/bnf.py | 10 +++--- 2 files changed, 67 insertions(+), 33 deletions(-) diff --git a/xdsl/parser_ng.py b/xdsl/parser_ng.py index 9930f4bcf8..739c7000a8 100644 --- a/xdsl/parser_ng.py +++ b/xdsl/parser_ng.py @@ -4,10 +4,10 @@ from dataclasses import dataclass, field import re import ast +from io import StringIO from typing import Any, TypeVar, Iterable, Literal from enum import Enum - from xdsl.ir import (SSAValue, Block, Callable, Attribute, Operation, Region, BlockArgument, MLContext, ParametrizedAttribute) @@ -23,11 +23,20 @@ from xdsl.irdl import Data -@dataclass class ParseError(Exception): span: Span msg: str + def __init__(self, span: Span, msg: str): + super().__init__(span.print_with_context(msg)) + self.span = span + self.msg = msg + + def print_pretty(self): + print( + self.span.print_with_context(self.msg) + ) + class BacktrackingAbort(Exception): reason: str | None @@ -72,11 +81,14 @@ def text(self): def print_with_context(self, msg: str | None = None): info = self.input.get_lines_containing(self) assert info is not None - lines, offset, line_no = info + lines, offset_of_first_line, line_no = info + # offset relative to the first line: + offset = self.start - offset_of_first_line remaining_len = self.len - print("In {}:{}".format(self.input.name, line_no)) + capture = StringIO() + print("file: {}:{}".format(self.input.name, line_no), file=capture) for line in lines: - print(line) + print(line, file=capture) if offset > len(line): offset -= len(line) continue @@ -84,11 +96,12 @@ def print_with_context(self, msg: str | None = None): continue len_on_this_line = min(remaining_len, len(line) - offset) remaining_len -= len_on_this_line - print((" " * offset) + ("^" * len_on_this_line)) + print("{}{}".format(" " * offset, "^" * max(len_on_this_line, 1)), file=capture) if msg is not None: - print("{}{}".format(" " * offset, msg)) + print("{}{}".format(" " * offset, msg), file=capture) msg = None offset = 0 + return capture.getvalue() def __repr__(self): return "Span[{}:{}](text='{}', input={})".format(self.start, self.end, self.text, self.input) @@ -139,6 +152,7 @@ def get_nth_line_bounds(self, n: int): return start, self.content.find('\n', start) def get_lines_containing(self, span: Span) -> tuple[list[str], int, int] | None: + # A pointer to the start of the first line start = 0 line_no = 0 source = self.content @@ -168,6 +182,7 @@ def at(self, i: int): save_t = tuple[int, tuple[str, ...], bool] parsed_type_t = tuple[Span, tuple[Span]] + @dataclass class Tokenizer: input: Input @@ -234,9 +249,9 @@ def backtracking(self): self.last_error = ParseError(self.last_token, reason[-1]) elif isinstance(ex, ParseError): self.last_error = ex - print("Warning: ParseError in backtracking: {}".format(ex)) + print("Warning: ParseError in backtracking:\n{}".format(ex)) else: - print("Warning: Unexpected error in backtracking: {}".format(ex)) + print("Warning: Unexpected error in backtracking:\n{}".format(ex)) self.resume_from(save) def next_token(self, start: int | None = None, skip: int = 0, peek: bool = False, @@ -395,14 +410,24 @@ class BNF: BNF.Literal(')'), BNF.OptionalGroup([ BNF.Literal('['), - BNF.ListOf(BNF.Nonterminal('block-id'), allow_empty=False, bind='blocks'), # TODD: allow for block args here?! (accordin to spec) + BNF.ListOf(BNF.Nonterminal('block-id'), allow_empty=False, bind='blocks'), + # TODD: allow for block args here?! (accordin to spec) BNF.Literal(']') ], bind='blocks_group'), - BNF.ListOf(BNF.Nonterminal('region'), bind='regions'), + BNF.OptionalGroup([ + BNF.Literal('('), + BNF.ListOf(BNF.Nonterminal('region'), bind='regions', allow_empty=False), + BNF.Literal(')') + ], bind='region_group'), BNF.Nonterminal('attr-dict', bind='attributes'), BNF.Literal(':'), BNF.Nonterminal('function-type', bind='type_signature') ]) + region = BNF.Group([ + BNF.Literal('{'), + BNF.ListOf(BNF.Nonterminal('operation'), separator=re.compile('')), + BNF.Literal('}'), + ]) class MlirParser: @@ -444,7 +469,7 @@ class Accent(Enum): of all try_parse functions is T_ | None """ - def __int__(self, input: str, name: str, ctx: MLContext, accent: str | Accent = 'xDSL'): + def __init__(self, input: str, name: str, ctx: MLContext, accent: str | Accent = Accent.XDSL): self.tokenizer = Tokenizer(Input(input, name)) self.ctx = ctx if isinstance(accent, str): @@ -508,6 +533,20 @@ def must_parse_reference(self) -> list[Span]: def must_parse_list_of(self, try_parse: Callable[[], T_ | None], error_msg: str, separator_pattern: re.Pattern = ParserCommons.comma, allow_empty: bool = True) -> list[T_]: + """ + This is a greedy list-parser. It accepts input only in these cases: + + - If the separator isn't encountered, which signals the end of the list + - If an empty list is allowed, it accepts when the first try_parse fails + - If an empty separator is given, it instead sees a failed try_parse as the end of the list. + + This means, that the setup will not accept the input and instead raise an error: + try_parse = parse_integer_literal + separator = 'x' + input = 3x4x4xi32 + as it will read [3,4,4], then see another separator, and expects the next try_parse call to succeed + (which won't as i32 is not a valid integer literal) + """ items = list() first_item = try_parse() if first_item is None: @@ -520,6 +559,9 @@ def must_parse_list_of(self, try_parse: Callable[[], T_ | None], error_msg: str, while self.tokenizer.next_token_of_pattern(separator_pattern) is not None: next_item = try_parse() if next_item is None: + # if the separator is emtpy, we are good here + if separator_pattern.pattern == '': + return items self.raise_error(error_msg) items.append(next_item) @@ -754,20 +796,10 @@ def raise_error(self, msg: str, at_position: Span | None = None): raise ParseError(at_position, msg) - def assert_eq(self, a: Span, b: str, msg: str): - if a.text == b: - return - raise AssertionError("Assertion failed ({} == {}): {}".format(a.text, b, msg), a) - - def assert_neq(self, a: Span, b: str, msg: str): - if a.text != b: - return - raise AssertionError("Assertion failed ({} != {}): {}".format(a.text, b, msg), a) - - def assert_in(self, a: Span, b: tuple[str], msg: str): - if a.text in b: + def assert_eq(self, got: Span, want: str, msg: str): + if got.text == want: return - raise AssertionError("Assertion failed ({} in {}): {}".format(a.text, b, msg), a) + raise AssertionError("Assertion failed (assert `{}` == `{}`): {}".format(got.text, want, msg), got) def must_parse_characters(self, text: str, msg: str): self.assert_eq(self.tokenizer.next_token(), text, msg) @@ -794,6 +826,11 @@ def try_parse_op(self): op_type = self.ctx.get_op(name.string_contents) + def try_parse_region(self): + return ParserCommons.BNF.region.try_parse(self) + + def must_parse_region(self): + return ParserCommons.BNF.region.must_parse(self) def try_parse_op_name(self) -> Span | None: if (str_lit := self.try_parse_string_literal()) is not None: @@ -801,9 +838,6 @@ def try_parse_op_name(self) -> Span | None: return self.try_parse_bare_id() - - - """ digit ::= [0-9] hex_digit ::= [0-9a-fA-F] diff --git a/xdsl/utils/bnf.py b/xdsl/utils/bnf.py index 972f3b8499..7633d14e49 100644 --- a/xdsl/utils/bnf.py +++ b/xdsl/utils/bnf.py @@ -3,7 +3,7 @@ import re import typing from dataclasses import dataclass, field -from abc import abstractmethod +from abc import abstractmethod, ABC from typing import Any if typing.TYPE_CHECKING: @@ -13,7 +13,7 @@ @dataclass(frozen=True) -class BNFToken(typing.Generic[T]): +class BNFToken(typing.Generic[T], ABC): bind: str | None = field(kw_only=True, init=False) @abstractmethod @@ -86,7 +86,7 @@ def must_parse(self, parser: MlirParser): 'Expected to parse {} here!'.format(self.name) ), self.bind else: - NotImplemented("Parser cannot parse {}".format(self.name)) + raise NotImplementedError("Parser cannot parse {}".format(self.name)) def try_parse(self, parser: MlirParser) -> T | None: if hasattr(parser, 'try_parse_{}'.format(self.name.replace('-', '_'))): @@ -179,9 +179,9 @@ class ListOf(BNFToken): allow_empty: bool = True bind: str | None = field(kw_only=True, default=None) - def try_parse(self, parser: MlirParser) -> T | None: + def must_parse(self, parser: MlirParser) -> T | None: return parser.must_parse_list_of( - self.element.try_parse, + lambda : self.element.try_parse(parser), 'Expected {}!'.format(self.element), separator_pattern=self.separator, allow_empty=self.allow_empty From 86fb17cd9c132884c8ec32c034f200460b81baf1 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Fri, 9 Dec 2022 23:25:40 +0000 Subject: [PATCH 03/65] [parser] basic parsing of operation working --- xdsl/parser_ng.py | 257 +++++++++++++++++++++++++++++++++++++++------- xdsl/utils/bnf.py | 47 ++++----- 2 files changed, 240 insertions(+), 64 deletions(-) diff --git a/xdsl/parser_ng.py b/xdsl/parser_ng.py index 739c7000a8..aef70679a5 100644 --- a/xdsl/parser_ng.py +++ b/xdsl/parser_ng.py @@ -5,7 +5,7 @@ import re import ast from io import StringIO -from typing import Any, TypeVar, Iterable, Literal +from typing import Any, TypeVar, Iterable, Literal, Optional from enum import Enum from xdsl.ir import (SSAValue, Block, Callable, Attribute, Operation, Region, @@ -42,9 +42,10 @@ class BacktrackingAbort(Exception): reason: str | None def __init__(self, reason: str | None = None): - super("This message should never escape the parser, it's intended to signal a failed parsing attempt\n" - "It should never be used outside of a tokenizer.backtracking() block!\n" - "The reason for this abort was {}".format('not specified' if reason is None else reason)) + super().__init__("This message should never escape the parser, it's intended to signal a failed parsing " + "attempt\n " + "It should never be used outside of a tokenizer.backtracking() block!\n" + "The reason for this abort was {}".format('not specified' if reason is None else reason)) self.reason = reason @@ -78,21 +79,23 @@ def len(self): def text(self): return self.input.content[self.start:self.end] - def print_with_context(self, msg: str | None = None): + def print_with_context(self, msg: str | None = None) -> str: + """ + returns a string containing lines relevant to the span. The Span's contents + are highlighted by up-carets beneath them (`^`). The message msg is printed + along these. + """ info = self.input.get_lines_containing(self) assert info is not None lines, offset_of_first_line, line_no = info # offset relative to the first line: offset = self.start - offset_of_first_line - remaining_len = self.len + remaining_len = max(self.len, 1) capture = StringIO() - print("file: {}:{}".format(self.input.name, line_no), file=capture) + print("{}:{}:{}".format(self.input.name, line_no, offset, remaining_len), file=capture) for line in lines: print(line, file=capture) - if offset > len(line): - offset -= len(line) - continue - if remaining_len <= 0: + if remaining_len < 0: continue len_on_this_line = min(remaining_len, len(line) - offset) remaining_len -= len_on_this_line @@ -101,10 +104,12 @@ def print_with_context(self, msg: str | None = None): print("{}{}".format(" " * offset, msg), file=capture) msg = None offset = 0 + if msg is not None: + print(msg, file=capture) return capture.getvalue() def __repr__(self): - return "Span[{}:{}](text='{}', input={})".format(self.start, self.end, self.text, self.input) + return "Span[{}:{}](text='{}')".format(self.start, self.end, self.text) @dataclass(frozen=True) @@ -126,6 +131,9 @@ def string_contents(self): # TODO: is this a hack-job? return ast.literal_eval(self.text) + def __repr__(self): + return "StringLiteral[{}:{}](text='{}')".format(self.start, self.end, self.text) + @dataclass(frozen=True) class Input: @@ -193,7 +201,8 @@ class Tokenizer: """ break_on: tuple[str, ...] = ( - '.', '%', ' ', '(', ')', '[', ']', '{', '}', '<', '>', ':', '=', '@', '?', '|', '->', '-', '//', '\n', '\t', '#' + '.', '%', ' ', '(', ')', '[', ']', '{', '}', '<', '>', ':', '=', '@', '?', '|', '->', '-', '//', '\n', '\t', + '#', '"', "'" ) """ characters the tokenizer should break on @@ -219,7 +228,7 @@ def resume_from(self, save: save_t): self.pos, self.break_on, self.ignore_whitespace = save @contextlib.contextmanager - def backtracking(self): + def backtracking(self, region_name: str | None = None): """ Used to create backtracking parsers. You can wrap you parse code into @@ -238,7 +247,10 @@ def backtracking(self): self.last_error = None yield except Exception as ex: + if region_name is not None: + print("Backtracking in region {}".format(region_name)) if isinstance(ex, BacktrackingAbort): + print(ex.reason) self.last_error = ParseError( self.next_token(peek=True), 'Backtracking aborted: {}'.format(ex.reason or 'unknown reason') @@ -247,9 +259,11 @@ def backtracking(self): reason = ['Generic assertion failure', *(reason for reason in ex.args if isinstance(reason, str))] # we assume that assertions fail because of the last read-in token self.last_error = ParseError(self.last_token, reason[-1]) + print(self.last_error.msg) elif isinstance(ex, ParseError): self.last_error = ex - print("Warning: ParseError in backtracking:\n{}".format(ex)) + print("Warning: ParseError in backtracking: {}".format(ex.msg)) + ex.print_pretty() else: print("Warning: Unexpected error in backtracking:\n{}".format(ex)) self.resume_from(save) @@ -268,7 +282,7 @@ def next_token(self, start: int | None = None, skip: int = 0, peek: bool = False if not peek: self.pos = self._find_token_end(i) - span = self.span_of(i, self.pos) + span = self.span_of(i, self._find_token_end(i)) if not include_comments and span.text == '//': while self.input.at(i) != '\n': i += 1 @@ -383,11 +397,12 @@ class ParserCommons: string_literal = re.compile(r'"([^\n\f\v\r"]|\\[nfvr"])+"') float_literal = re.compile(r'[-+]?[0-9]+\.[0-9]*([eE][-+]?[0-9]+)?') bare_id = re.compile(r'[A-z_][A-z0-9_$.]+') - value_id = re.compile(r'%[A-z_][A-z0-9_$.]+') + value_id = re.compile(r'%([0-9]+|([A-z_$.-][0-9A-z_$.-]*))') suffix_id = re.compile(r'([0-9]+|([A-z_$.-][0-9A-z_$.-]*))') block_id = re.compile(r'\^([0-9]+|([A-z_$.-][0-9A-z_$.-]*))') type_alias = re.compile(r'![A-z_][A-z0-9_$.]+') attribute_alias = re.compile(r'#[A-z_][A-z0-9_$.]+') + boolean_literal = re.compile(r'(true|false)') builtin_type = re.compile('({})'.format( '|'.join(( r'[su]?i\d+', 'tensor', 'vector', @@ -403,7 +418,7 @@ class BNF: """ Collection of BNF trees. """ - generic_operation = BNF.Group([ + generic_operation_body = BNF.Group([ BNF.Nonterminal('string-literal', bind="name"), BNF.Literal('('), BNF.ListOf(BNF.Nonterminal('value-id'), bind='args'), @@ -413,12 +428,12 @@ class BNF: BNF.ListOf(BNF.Nonterminal('block-id'), allow_empty=False, bind='blocks'), # TODD: allow for block args here?! (accordin to spec) BNF.Literal(']') - ], bind='blocks_group'), + ]), BNF.OptionalGroup([ BNF.Literal('('), BNF.ListOf(BNF.Nonterminal('region'), bind='regions', allow_empty=False), BNF.Literal(')') - ], bind='region_group'), + ]), BNF.Nonterminal('attr-dict', bind='attributes'), BNF.Literal(':'), BNF.Nonterminal('function-type', bind='type_signature') @@ -428,6 +443,11 @@ class BNF: BNF.ListOf(BNF.Nonterminal('operation'), separator=re.compile('')), BNF.Literal('}'), ]) + attr_dict = BNF.Group([ + BNF.Literal('{'), + BNF.ListOf(BNF.Nonterminal('attribute-entry'), bind='attributes'), + BNF.Literal('}') + ]) class MlirParser: @@ -469,7 +489,7 @@ class Accent(Enum): of all try_parse functions is T_ | None """ - def __init__(self, input: str, name: str, ctx: MLContext, accent: str | Accent = Accent.XDSL): + def __init__(self, input: str, name: str, ctx: MLContext, accent: str | Accent = Accent.MLIR): self.tokenizer = Tokenizer(Input(input, name)) self.ctx = ctx if isinstance(accent, str): @@ -493,7 +513,7 @@ def must_parse_block(self) -> Block | None: self._ssaValues[name.text] = arg block.args.append(arg) - while (next_op := self.try_parse_op()) is not None: + while (next_op := self.try_parse_operation()) is not None: block.ops.append(next_op) return block @@ -591,6 +611,9 @@ def try_parse_suffix_id(self) -> Span | None: def try_parse_block_id(self) -> Span | None: return self.tokenizer.next_token_of_pattern(ParserCommons.block_id) + def try_parse_boolean_literal(self) -> Span | None: + return self.tokenizer.next_token_of_pattern(ParserCommons.boolean_literal) + def try_parse_value_id_and_type(self) -> tuple[Span, Attribute] | None: with self.tokenizer.backtracking(): value_id = self.try_parse_value_id() @@ -641,7 +664,7 @@ def try_parse_builtin_type(self) -> Attribute | None: raise BacktrackingAbort("Expected builtin name!") if name.text == 'index': return IndexType.build() - if (re_match := re.match(r'^([su]?i(\d)+)$', name.text)) is not None: + if (re_match := re.match(r'^[su]?i(\d+)$', name.text)) is not None: signedness = { 's': Signedness.SIGNED, 'u': Signedness.UNSIGNED, @@ -759,7 +782,7 @@ def try_parse_shape_element(self, lower_bound: int = 1) -> int | None: return -1 return None - def must_parse_type_params(self) -> list[parsed_type_t]: + def must_parse_type_params(self) -> list[Attribute]: # consume opening bracket assert self.tokenizer.next_token().text == '<', 'Type must be parameterized!' @@ -804,39 +827,195 @@ def assert_eq(self, got: Span, want: str, msg: str): def must_parse_characters(self, text: str, msg: str): self.assert_eq(self.tokenizer.next_token(), text, msg) - def try_parse_op_result_list(self) -> list[tuple[Span, Attribute] | Span] | None: + def must_parse_op_result_list(self) -> list[tuple[Span, Attribute] | Span] | None: inner_parser = (dict(( (MlirParser.Accent.MLIR, self.try_parse_value_id), (MlirParser.Accent.XDSL, self.try_parse_value_id_and_type) )))[self.accent] - return self.must_parse_list_of(inner_parser, 'Expected op-result here!', allow_empty=False) + return self.must_parse_list_of(self.try_parse_value_id, 'Expected op-result here!', allow_empty=False) - def try_parse_op(self): + def try_parse_operation(self) -> Operation | None: with self.tokenizer.backtracking(): - result_list = self.try_parse_op_result_list() + result_list = self.must_parse_op_result_list() self.must_parse_characters('=', 'Operation definitions expect an `=` after op-result-list!') - name = self.try_parse_op_name() - # handle custom-operation parsing - if not isinstance(name, StringLiteral): - op_type = self.ctx.get_op(name.text) - # TODO: how do we pass result types if we are in xDSL format? - op_type.parse() + generic_op = ParserCommons.BNF.generic_operation_body.try_parse(self) + if generic_op is None: + self.raise_error("custom operations not supported as of yet!") - op_type = self.ctx.get_op(name.string_contents) - - def try_parse_region(self): - return ParserCommons.BNF.region.try_parse(self) + values = ParserCommons.BNF.generic_operation_body.collect(generic_op, dict()) + print("parsed op {} = {}".format(result_list, values)) + return result_list, values def must_parse_region(self): - return ParserCommons.BNF.region.must_parse(self) + self.must_parse_characters('{', 'Regions begin with `{`') + ops = self.must_parse_list_of(self.try_parse_operation, 'Expected Operation', separator_pattern=re.compile("")) + self.must_parse_characters('}', 'Regions end with `}`') + return ops def try_parse_op_name(self) -> Span | None: if (str_lit := self.try_parse_string_literal()) is not None: return str_lit return self.try_parse_bare_id() + def must_parse_attribute_entry(self) -> tuple[Span, Attribute]: + """ + Parse entry in attribute dict. Of format: + + attrbiute_entry := (bare-id | string-literal) `=` attribute + attrbiute := dialect-attribute | builtin-attribute + """ + if (name := self.try_parse_bare_id()) is None: + name = self.try_parse_string_literal() + + if name is None: + self.raise_error('Expected bare-id or string-literal here as part of attribute entry!') + + self.must_parse_characters('=', 'Attribute entries must be of format name `=` attribute!') + + return name, self.must_parse_attribute() + + def must_parse_attribute(self) -> Attribute: + """ + Parse attribute (either builtin or dialect) + """ + # all dialect attrs must start with '#', so we check for that first (as it's easier) + if self.tokenizer.next_token(peek=True).text == '#': + value = self.try_parse_dialect_type_or_attribute('attr') + if value is None: + self.raise_error('`#` must be followed by a valid builtin attribute!') + return value + + builtin_val = self.try_parse_builtin_attr() + + if builtin_val is None: + self.raise_error("Unknown attribute!") + + return builtin_val + + def must_parse_attribute_type(self) -> Attribute: + self.must_parse_characters(':', 'Expected attribute type definition here ( `:` type )') + return self.expect(self.try_parse_type, 'Expected attribute type definition here ( `:` type )') + + def try_parse_builtin_attr(self) -> Attribute: + attrs = ( + self.try_parse_builtin_int_attr, + self.try_parse_builtin_float_attr, + self.try_parse_builtin_str_attr, + self.try_parse_builtin_arr_attr + ) + + for attr_parser in attrs: + if (val := attr_parser()) is not None: + print("got attr {}".format(val)) + return val + + def try_parse_builtin_int_attr(self) -> IntegerAttr | None: + bool = self.try_parse_builtin_boolean_attr() + if bool is not None: + return bool + + with self.tokenizer.backtracking(): + value = self.expect(self.try_parse_integer_literal, 'Integer attribute must start with an integer literal!') + if self.tokenizer.next_token(peek=True).text != ':': + print(self.tokenizer.next_token(peek=True)) + return IntegerAttr.from_index_int_value(int(value.text)) + type = self.must_parse_attribute_type() + return IntegerAttr.from_params(int(value.text), type) + + def try_parse_builtin_float_attr(self) -> IntegerAttr | None: + with self.tokenizer.backtracking(): + value = self.expect(self.try_parse_float_literal, 'Integer attribute must start with an integer literal!') + if self.tokenizer.next_token(peek=True).text != ':': + return FloatAttr.from_value(float(value.text)) + type = self.must_parse_attribute_type() + return IntegerAttr.from_params(float(value.text), type) + + def try_parse_builtin_boolean_attr(self) -> IntegerAttr | None: + span = self.try_parse_boolean_literal() + + if span is None: + return None + + int_val = ['true', 'false'].index(span.text) + return IntegerAttr.from_params(int_val, IntegerType.from_width(1)) + + def try_parse_builtin_str_attr(self): + if self.tokenizer.next_token(peek=True).text != '"': + return None + + with self.tokenizer.backtracking(): + literal = self.try_parse_string_literal() + if self.tokenizer.next_token(peek=True).text != ':': + return StringAttr.from_str(literal.string_contents) + self.raise_error("Typed string literals are not supported!") + + def try_parse_builtin_arr_attr(self) -> list[Attribute] | None: + if self.tokenizer.next_token(peek=True).text != '[': + return None + with self.tokenizer.backtracking(): + self.must_parse_characters('[', 'Array literals must start with `[`') + attrs = self.must_parse_list_of(self.try_parse_builtin_attr, 'Expected array entry!') + self.must_parse_characters(']', 'Array literals must be enclosed by square brackets!') + return ArrayAttr.from_list(attrs) + + def must_parse_attr_dict(self) -> list[tuple[Span, Attribute]]: + res = ParserCommons.BNF.attr_dict.try_parse(self) + if res is None: + return [] + return ParserCommons.BNF.attr_dict.collect(res, dict()).get('attributes', list()) + + def try_parse_attr_dict(self) -> list[tuple[Span, Attribute]] | None: + res = ParserCommons.BNF.attr_dict.try_parse(self) + if res is None: + return None + return ParserCommons.BNF.attr_dict.collect(res, dict()).get('attributes', list()) + + def must_parse_function_type(self) -> tuple[list[Attribute], list[Attribute]]: + """ + Parses function-type: + + viable function types are: + (i32) -> () + i32 -> () + () -> (i32, i32) + i32 -> i32 + + Uses type-or-type-list-parens + """ + args = self.must_parse_type_or_type_list_parens() + + self.must_parse_characters('->', 'Function type!') + + return args, self.must_parse_type_or_type_list_parens() + + def must_parse_type_or_type_list_parens(self) -> list[Attribute]: + """ + Parses type-or-type-list-parens, which is used in function-type. + + type-or-type-list-parens ::= type | type-list-parens + type-list-parens ::= `(` `)` | `(` type-list-no-parens `)` + type-list-no-parens ::= type (`,` type)* + """ + if self.tokenizer.next_token(peek=True).text == '(': + self.must_parse_characters('(', 'Function type!') + args: list[Attribute] = self.must_parse_list_of(self.try_parse_type, 'Expected type here!') + self.must_parse_characters(')', "End of function type args") + else: + args = [ + self.try_parse_type() + ] + if args[0] is None: + self.raise_error("Function type must either be single type or list of types in parenthesis!") + return args + + def try_parse_function_type(self) -> tuple[list[Attribute], list[Attribute]] | None: + if self.tokenizer.next_token(peek=True).text != '(': + return None + with self.tokenizer.backtracking('Function type'): + return self.must_parse_function_type() + """ digit ::= [0-9] diff --git a/xdsl/utils/bnf.py b/xdsl/utils/bnf.py index 7633d14e49..0467194f48 100644 --- a/xdsl/utils/bnf.py +++ b/xdsl/utils/bnf.py @@ -13,7 +13,7 @@ @dataclass(frozen=True) -class BNFToken(typing.Generic[T], ABC): +class BNFToken: bind: str | None = field(kw_only=True, init=False) @abstractmethod @@ -21,13 +21,14 @@ def must_parse(self, parser: MlirParser) -> T: raise NotImplemented() def try_parse(self, parser: MlirParser) -> T | None: - with parser.tokenizer.backtracking(): + with parser.tokenizer.backtracking(repr(self)): return self.must_parse(parser) - def collect(self, value, collection: dict): + def collect(self, value, collection: dict) -> dict: if self.bind is None: - return + return collection collection[self.bind] = value + return collection @dataclass(frozen=True) @@ -57,6 +58,7 @@ def must_parse(self, parser: MlirParser) -> T: res = self.try_parse(parser) if res is None: parser.raise_error('Expected token of form {}!'.format(self)) + return res def __repr__(self): return 're`{}`'.format(self.pattern.pattern) @@ -79,12 +81,12 @@ class Nonterminal(BNFToken): def must_parse(self, parser: MlirParser): if hasattr(parser, 'must_parse_{}'.format(self.name.replace('-', '_'))): - return getattr(parser, 'must_parse_{}'.format(self.name.replace('-', '_')))(), self.bind + return getattr(parser, 'must_parse_{}'.format(self.name.replace('-', '_')))() elif hasattr(parser, 'try_parse_{}'.format(self.name.replace('-', '_'))): return parser.expect( getattr(parser, 'try_parse_{}'.format(self.name.replace('-', '_'))), 'Expected to parse {} here!'.format(self.name) - ), self.bind + ) else: raise NotImplementedError("Parser cannot parse {}".format(self.name)) @@ -110,16 +112,15 @@ def must_parse(self, parser: MlirParser) -> T: def __repr__(self): return '( {} )'.format(' '.join(repr(t) for t in self.tokens)) - def collect(self, value, collection: dict): + def collect(self, value, collection: dict) -> dict: for child, value in zip(self.tokens, value): child.collect(value, collection) - if self.bind is not None: - collection[self.bind] = value + return super().collect(value, collection) @dataclass(frozen=True) class OneOrMoreOf(BNFToken): - wraps: BNFToken[T] + wraps: BNFToken bind: str | None = field(kw_only=True, default=None) def must_parse(self, parser: MlirParser) -> list[T]: @@ -138,16 +139,15 @@ def __repr__(self): def children(self) -> typing.Iterable[BNFToken]: return self.wraps, - def collect(self, value, collection: dict): + def collect(self, value, collection: dict) -> dict: for val in value: self.wraps.collect(val, collection) - if self.bind is not None: - collection[self.bind] = value + return super().collect(value, collection) @dataclass(frozen=True) class ZeroOrMoreOf(BNFToken): - wraps: BNFToken[T] + wraps: BNFToken bind: str | None = field(kw_only=True, default=None) def must_parse(self, parser: MlirParser) -> list[T]: @@ -164,11 +164,10 @@ def __repr__(self): def children(self) -> typing.Iterable[BNFToken]: return self.wraps, - def collect(self, values, collection: dict): + def collect(self, values, collection: dict) -> dict: for value in values: self.wraps.collect(value, collection) - if self.bind is not None: - collection[self.bind] = values + return super().collect(values, collection) @dataclass(frozen=True) @@ -192,16 +191,15 @@ def __repr__(self): return '( {elm} ( re`{sep}` {elm} )* )?'.format(elm=self.element, sep=self.separator.pattern) return '{elm} ( re`{sep}` {elm} )*'.format(elm=self.element, sep=self.separator.pattern) - def collect(self, values, collection: dict): + def collect(self, values, collection: dict) -> dict: for value in values: self.element.collect(value, collection) - if self.bind is not None: - collection[self.bind] = values + return super().collect(values, collection) @dataclass(frozen=True) -class Optional(BNFToken[T | None]): - wraps: BNFToken[T] +class Optional(BNFToken): + wraps: BNFToken bind: str | None = field(kw_only=True, default=None) def must_parse(self, parser: MlirParser) -> T | None: @@ -213,11 +211,10 @@ def try_parse(self, parser: MlirParser) -> T | None: def __repr__(self): return '{}?'.format(self.wraps) - def collect(self, value, collection: dict): + def collect(self, value, collection: dict) -> dict: if value is not None: self.wraps.collect(value, collection) - if self.bind is not None: - collection[self.bind] = value + return super().collect(value, collection) def OptionalGroup(tokens: list[BNFToken], bind: str | None = None) -> Optional: From 7de6326a698161b7d559fd5c4a10bae53ff26323 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Sat, 10 Dec 2022 00:34:34 +0000 Subject: [PATCH 04/65] [parser] cleanup of backtracking error reporting --- xdsl/parser_ng.py | 120 ++++++++++++++++++++++++++++++---------------- xdsl/utils/bnf.py | 16 +++++-- 2 files changed, 91 insertions(+), 45 deletions(-) diff --git a/xdsl/parser_ng.py b/xdsl/parser_ng.py index aef70679a5..7c991e429f 100644 --- a/xdsl/parser_ng.py +++ b/xdsl/parser_ng.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextlib +import sys from dataclasses import dataclass, field import re import ast @@ -26,17 +27,32 @@ class ParseError(Exception): span: Span msg: str + history: BacktrackingHistory | None - def __init__(self, span: Span, msg: str): + def __init__(self, span: Span, msg: str, history: BacktrackingHistory | None = None): super().__init__(span.print_with_context(msg)) self.span = span self.msg = msg + self.history = history - def print_pretty(self): - print( - self.span.print_with_context(self.msg) - ) + def print_pretty(self, file=sys.stderr, print_history: bool = True): + if self.history and print_history: + self.history.print_unroll(file) + print(self.span.print_with_context(self.msg), file=file) + + +@dataclass +class BacktrackingHistory: + error: ParseError + parent: BacktrackingHistory | None + region_name: str | None + + def print_unroll(self, file=sys.stderr): + if self.parent: + self.parent.print_unroll(file) + print("Aborted parsing of {} because failure at:".format(self.region_name or ''), file=file) + self.error.print_pretty(file=file, print_history=False) class BacktrackingAbort(Exception): reason: str | None @@ -210,7 +226,8 @@ class Tokenizer: ignore_whitespace: bool = True - last_error: ParseError | None = field(init=False, default=None) + history: BacktrackingHistory | None = field(init=False, default=None) + last_token: Span | None = field(init=False, default=None) def save(self) -> save_t: @@ -244,28 +261,48 @@ def backtracking(self, region_name: str | None = None): """ save = self.save() try: - self.last_error = None yield + # clear error history when something doesn't fail + # this is because we are only interested in the last "cascade" of failures. + # if a backtracking() completes without failre, something has been parsed (we assume) + self.history = None except Exception as ex: - if region_name is not None: - print("Backtracking in region {}".format(region_name)) if isinstance(ex, BacktrackingAbort): - print(ex.reason) - self.last_error = ParseError( - self.next_token(peek=True), - 'Backtracking aborted: {}'.format(ex.reason or 'unknown reason') + self.history = BacktrackingHistory( + ParseError( + self.next_token(peek=True), + 'Backtracking aborted: {}'.format(ex.reason or 'unknown reason') + ), + self.history, + region_name ) elif isinstance(ex, AssertionError): reason = ['Generic assertion failure', *(reason for reason in ex.args if isinstance(reason, str))] # we assume that assertions fail because of the last read-in token - self.last_error = ParseError(self.last_token, reason[-1]) - print(self.last_error.msg) + self.history = BacktrackingHistory( + ParseError(self.last_token, reason[-1]), + self.history, + region_name + ) elif isinstance(ex, ParseError): - self.last_error = ex - print("Warning: ParseError in backtracking: {}".format(ex.msg)) - ex.print_pretty() + self.history = BacktrackingHistory( + ex, + self.history, + region_name + ) + elif isinstance(ex, EOFError): + self.history = BacktrackingHistory( + ParseError(self.last_token, "Encountered EOF"), + self.history, + region_name + ) else: - print("Warning: Unexpected error in backtracking:\n{}".format(ex)) + self.history = BacktrackingHistory( + ParseError(self.last_token, "Unexpected exception: {}".format(ex)), + self.history, + region_name + ) + print("Warning: Unexpected error in backtracking: {}".format(repr(ex))) self.resume_from(save) def next_token(self, start: int | None = None, skip: int = 0, peek: bool = False, @@ -428,26 +465,21 @@ class BNF: BNF.ListOf(BNF.Nonterminal('block-id'), allow_empty=False, bind='blocks'), # TODD: allow for block args here?! (accordin to spec) BNF.Literal(']') - ]), + ], debug_name="operations optional block id group"), BNF.OptionalGroup([ BNF.Literal('('), BNF.ListOf(BNF.Nonterminal('region'), bind='regions', allow_empty=False), BNF.Literal(')') - ]), - BNF.Nonterminal('attr-dict', bind='attributes'), + ], debug_name="operation regions"), + BNF.Nonterminal('attr-dict', bind='attributes', debug_name="attrbiute dictionary"), BNF.Literal(':'), BNF.Nonterminal('function-type', bind='type_signature') - ]) - region = BNF.Group([ - BNF.Literal('{'), - BNF.ListOf(BNF.Nonterminal('operation'), separator=re.compile('')), - BNF.Literal('}'), - ]) + ], debug_name="generic operation body") attr_dict = BNF.Group([ BNF.Literal('{'), - BNF.ListOf(BNF.Nonterminal('attribute-entry'), bind='attributes'), + BNF.ListOf(BNF.Nonterminal('attribute-entry', debug_name="attribute entry"), bind='attributes'), BNF.Literal('}') - ]) + ], debug_name="attrbute dictionary") class MlirParser: @@ -497,7 +529,13 @@ def __init__(self, input: str, name: str, ctx: MLContext, accent: str | Accent = self.accent = accent def begin_parse(self): - pass + ops = [] + while (op := self.try_parse_operation()) is not None: + ops.append(op) + if not self.tokenizer.is_eof(): + self.raise_error("Unfinished business!") + return ops + def must_parse_block(self) -> Block | None: next_id = self.expect(self.try_parse_block_id, 'Blocks must start with a block id!') @@ -813,11 +851,7 @@ def raise_error(self, msg: str, at_position: Span | None = None): if at_position is None: at_position = self.tokenizer.next_token(peek=True) - # include backtracking exception if available - if self.tokenizer.last_error: - raise ParseError(at_position, msg) from self.tokenizer.last_error - - raise ParseError(at_position, msg) + raise ParseError(at_position, msg, self.tokenizer.history) def assert_eq(self, got: Span, want: str, msg: str): if got.text == want: @@ -836,16 +870,19 @@ def must_parse_op_result_list(self) -> list[tuple[Span, Attribute] | Span] | Non return self.must_parse_list_of(self.try_parse_value_id, 'Expected op-result here!', allow_empty=False) def try_parse_operation(self) -> Operation | None: - with self.tokenizer.backtracking(): - result_list = self.must_parse_op_result_list() - self.must_parse_characters('=', 'Operation definitions expect an `=` after op-result-list!') + with self.tokenizer.backtracking("operation"): + if self.tokenizer.next_token(peek=True).text == '%': + result_list = self.must_parse_op_result_list() + self.must_parse_characters('=', 'Operation definitions expect an `=` after op-result-list!') + else: + result_list = [] generic_op = ParserCommons.BNF.generic_operation_body.try_parse(self) if generic_op is None: self.raise_error("custom operations not supported as of yet!") values = ParserCommons.BNF.generic_operation_body.collect(generic_op, dict()) - print("parsed op {} = {}".format(result_list, values)) + return result_list, values def must_parse_region(self): @@ -908,7 +945,6 @@ def try_parse_builtin_attr(self) -> Attribute: for attr_parser in attrs: if (val := attr_parser()) is not None: - print("got attr {}".format(val)) return val def try_parse_builtin_int_attr(self) -> IntegerAttr | None: @@ -916,7 +952,7 @@ def try_parse_builtin_int_attr(self) -> IntegerAttr | None: if bool is not None: return bool - with self.tokenizer.backtracking(): + with self.tokenizer.backtracking("built in int attribute"): value = self.expect(self.try_parse_integer_literal, 'Integer attribute must start with an integer literal!') if self.tokenizer.next_token(peek=True).text != ':': print(self.tokenizer.next_token(peek=True)) diff --git a/xdsl/utils/bnf.py b/xdsl/utils/bnf.py index 0467194f48..2ac7eed41f 100644 --- a/xdsl/utils/bnf.py +++ b/xdsl/utils/bnf.py @@ -15,13 +15,14 @@ @dataclass(frozen=True) class BNFToken: bind: str | None = field(kw_only=True, init=False) + debug_name: str | None = field(kw_only=True, init=False) @abstractmethod def must_parse(self, parser: MlirParser) -> T: raise NotImplemented() def try_parse(self, parser: MlirParser) -> T | None: - with parser.tokenizer.backtracking(repr(self)): + with parser.tokenizer.backtracking(self.debug_name or repr(self)): return self.must_parse(parser) def collect(self, value, collection: dict) -> dict: @@ -38,6 +39,7 @@ class Literal(BNFToken): """ string: str bind: str | None = field(kw_only=True, default=None) + debug_name: str | None = field(kw_only=True, default=None) def must_parse(self, parser: MlirParser): return parser.must_parse_characters(self.string, 'Expected `{}`'.format(self.string)) @@ -50,6 +52,7 @@ def __repr__(self): class Regex(BNFToken): pattern: re.Pattern bind: str | None = field(kw_only=True, default=None) + debug_name: str | None = field(kw_only=True, default=None) def try_parse(self, parser: MlirParser) -> T | None: return parser.tokenizer.next_token_of_pattern(self.pattern) @@ -79,6 +82,8 @@ class Nonterminal(BNFToken): """ bind: str | None = field(kw_only=True, default=None) + debug_name: str | None = field(kw_only=True, default=None) + def must_parse(self, parser: MlirParser): if hasattr(parser, 'must_parse_{}'.format(self.name.replace('-', '_'))): return getattr(parser, 'must_parse_{}'.format(self.name.replace('-', '_')))() @@ -103,6 +108,7 @@ def __repr__(self): class Group(BNFToken): tokens: list[BNFToken] bind: str | None = field(kw_only=True, default=None) + debug_name: str | None = field(kw_only=True, default=None) def must_parse(self, parser: MlirParser) -> T: return [ @@ -122,6 +128,7 @@ def collect(self, value, collection: dict) -> dict: class OneOrMoreOf(BNFToken): wraps: BNFToken bind: str | None = field(kw_only=True, default=None) + debug_name: str | None = field(kw_only=True, default=None) def must_parse(self, parser: MlirParser) -> list[T]: res = list() @@ -149,6 +156,7 @@ def collect(self, value, collection: dict) -> dict: class ZeroOrMoreOf(BNFToken): wraps: BNFToken bind: str | None = field(kw_only=True, default=None) + debug_name: str | None = field(kw_only=True, default=None) def must_parse(self, parser: MlirParser) -> list[T]: res = list() @@ -177,6 +185,7 @@ class ListOf(BNFToken): allow_empty: bool = True bind: str | None = field(kw_only=True, default=None) + debug_name: str | None = field(kw_only=True, default=None) def must_parse(self, parser: MlirParser) -> T | None: return parser.must_parse_list_of( @@ -201,6 +210,7 @@ def collect(self, values, collection: dict) -> dict: class Optional(BNFToken): wraps: BNFToken bind: str | None = field(kw_only=True, default=None) + debug_name: str | None = field(kw_only=True, default=None) def must_parse(self, parser: MlirParser) -> T | None: return self.wraps.try_parse(parser) @@ -217,5 +227,5 @@ def collect(self, value, collection: dict) -> dict: return super().collect(value, collection) -def OptionalGroup(tokens: list[BNFToken], bind: str | None = None) -> Optional: - return Optional(Group(tokens), bind=bind) +def OptionalGroup(tokens: list[BNFToken], bind: str | None = None, debug_name: str | None = None) -> Optional: + return Optional(Group(tokens), bind=bind, debug_name=debug_name) From 1f0cecb0b9288f77ff23810bfbb91f183cbfaec3 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Sat, 10 Dec 2022 09:25:34 +0000 Subject: [PATCH 05/65] [parser] improved region and attribute parsing, building The parse is now able to parse: - all supported float attributes - regions with multiple blocks - regions with named and unnamed blocks Furthermore, it now builds complete operations that can be printed by the printer! --- xdsl/parser_ng.py | 181 +++++++++++++++++++++++++++++++++------------- 1 file changed, 132 insertions(+), 49 deletions(-) diff --git a/xdsl/parser_ng.py b/xdsl/parser_ng.py index 7c991e429f..e4cf2474c5 100644 --- a/xdsl/parser_ng.py +++ b/xdsl/parser_ng.py @@ -2,6 +2,7 @@ import contextlib import sys +import traceback from dataclasses import dataclass, field import re import ast @@ -218,7 +219,7 @@ class Tokenizer: break_on: tuple[str, ...] = ( '.', '%', ' ', '(', ')', '[', ']', '{', '}', '<', '>', ':', '=', '@', '?', '|', '->', '-', '//', '\n', '\t', - '#', '"', "'" + '#', '"', "'", ',' ) """ characters the tokenizer should break on @@ -279,6 +280,11 @@ def backtracking(self, region_name: str | None = None): elif isinstance(ex, AssertionError): reason = ['Generic assertion failure', *(reason for reason in ex.args if isinstance(reason, str))] # we assume that assertions fail because of the last read-in token + if len(reason) == 1: + tb = StringIO() + traceback.print_exc(file=tb) + reason[0] += '\n' + tb.getvalue() + self.history = BacktrackingHistory( ParseError(self.last_token, reason[-1]), self.history, @@ -303,6 +309,7 @@ def backtracking(self, region_name: str | None = None): region_name ) print("Warning: Unexpected error in backtracking: {}".format(repr(ex))) + raise ex self.resume_from(save) def next_token(self, start: int | None = None, skip: int = 0, peek: bool = False, @@ -377,6 +384,10 @@ def next_pos(self, i: int | None = None) -> int: if self.ignore_whitespace: while self.input.at(i).isspace(): i += 1 + # skip comments as well + if self.input.content.startswith('//', i): + i = self.input.content.find('\n', i) + 1 + return self.next_pos(i) return i def is_eof(self): @@ -440,12 +451,14 @@ class ParserCommons: type_alias = re.compile(r'![A-z_][A-z0-9_$.]+') attribute_alias = re.compile(r'#[A-z_][A-z0-9_$.]+') boolean_literal = re.compile(r'(true|false)') - builtin_type = re.compile('({})'.format( - '|'.join(( - r'[su]?i\d+', 'tensor', 'vector', - 'memref', 'complex', 'opaque', - 'tuple', 'index', - # TODO: add all the FloatNtype, Float8E4M3FNType, Float8E5M2Type, and BFloat16Type + builtin_type = re.compile('(({}))'.format( + ')|('.join(( + r'[su]?i\d+', r'f\d+', + 'tensor', 'vector', + 'memref', 'complex', + 'opaque', 'tuple', + 'index', + # TODO: add all the Float8E4M3FNType, Float8E5M2Type, and BFloat16Type )) )) double_colon = re.compile('::') @@ -512,8 +525,8 @@ class Accent(Enum): ctx: MLContext """xDSL context.""" - _ssaValues: dict[str, SSAValue] = field(init=False, default_factory=dict) - _blocks: dict[str, Block] = field(init=False, default_factory=dict) + _ssaValues: dict[str, SSAValue] + _blocks: dict[str, Block] T_ = TypeVar('T_') """ @@ -527,6 +540,8 @@ def __init__(self, input: str, name: str, ctx: MLContext, accent: str | Accent = if isinstance(accent, str): accent = MlirParser.Accent[accent] self.accent = accent + self._ssaValues = dict() + self._blocks = dict() def begin_parse(self): ops = [] @@ -537,31 +552,37 @@ def begin_parse(self): return ops - def must_parse_block(self) -> Block | None: - next_id = self.expect(self.try_parse_block_id, 'Blocks must start with a block id!') - - assert next_id.text not in self._blocks + def must_parse_block(self) -> Block: + id, args = self.must_parse_optional_block_label() block = Block() - self._blocks[next_id.text] = block + if id is not None: + assert id.text not in self._blocks + self._blocks[id.text] = block - if self.tokenizer.next_token(peek=True).text == '(': - for i, (name, type) in enumerate(self.must_parse_block_arg_list()): - arg = BlockArgument(type, block, i) - self._ssaValues[name.text] = arg - block.args.append(arg) + for i, (name, type) in args: + arg = BlockArgument(type, block, i) + self._ssaValues[name.text] = arg + block.args.append(arg) while (next_op := self.try_parse_operation()) is not None: block.ops.append(next_op) return block - def get_or_create_block_arg(self, name: Span, type: Attribute): - if name.text in self._ssaValues: - val = self._ssaValues.get(name.text) - assert val.typ == type - return val - self._ssaValues[name.text] = BlockArgument(type, ) + def must_parse_optional_block_label(self): + next_id = self.try_parse_block_id() + arg_list = list() + + if next_id is not None: + assert next_id.text not in self._blocks, "Blocks cannot have the same ID!" + + if self.tokenizer.next_token(peek=True).text == '(': + arg_list = enumerate(self.must_parse_block_arg_list()) + + self.must_parse_characters(':', 'Block label must end in a `:`!') + + return next_id, arg_list def must_parse_block_arg_list(self) -> list[tuple[Span, Attribute]]: self.assert_eq(self.tokenizer.next_token(), '(', 'Block arguments must start with `(`') @@ -614,13 +635,13 @@ def must_parse_list_of(self, try_parse: Callable[[], T_ | None], error_msg: str, items.append(first_item) - while self.tokenizer.next_token_of_pattern(separator_pattern) is not None: + while (match := self.tokenizer.next_token_of_pattern(separator_pattern)) is not None: next_item = try_parse() if next_item is None: # if the separator is emtpy, we are good here if separator_pattern.pattern == '': return items - self.raise_error(error_msg) + self.raise_error(error_msg + ' because was able to match next separator {}'.format(match)) items.append(next_item) return items @@ -701,7 +722,7 @@ def try_parse_builtin_type(self) -> Attribute | None: if name is None: raise BacktrackingAbort("Expected builtin name!") if name.text == 'index': - return IndexType.build() + return IndexType() if (re_match := re.match(r'^[su]?i(\d+)$', name.text)) is not None: signedness = { 's': Signedness.SIGNED, @@ -710,6 +731,18 @@ def try_parse_builtin_type(self) -> Attribute | None: } return IntegerType.from_width(int(re_match.group(1)), signedness[name.text[0]]) + if (re_match := re.match(r'^f(\d+)$', name.text)) is not None: + width = int(re_match.group(1)) + type = { + 16: Float16Type, + 32: Float64Type, + 64: Float64Type + }.get(width, None) + if type is None: + self.raise_error("Unsupported floating point width: {}".format(width)) + return type() + + return self.must_parse_builtin_parametrized_type(name) def must_parse_builtin_parametrized_type(self, name: Span) -> ParametrizedAttribute: @@ -883,13 +916,49 @@ def try_parse_operation(self) -> Operation | None: values = ParserCommons.BNF.generic_operation_body.collect(generic_op, dict()) - return result_list, values + arg_types, ret_types = ([], []) + if 'type_signature' in values: + functype : FunctionType = values['type_signature'] + arg_types, ret_types = functype.inputs.data, functype.outputs.data - def must_parse_region(self): - self.must_parse_characters('{', 'Regions begin with `{`') - ops = self.must_parse_list_of(self.try_parse_operation, 'Expected Operation', separator_pattern=re.compile("")) - self.must_parse_characters('}', 'Regions end with `}`') - return ops + if len(ret_types) != len(result_list): + raise ParseError( + values['name'], + "Mismatch between type signature and result list for op!" + ) + + op_type = self.ctx.get_op(values['name'].string_contents) + return op_type.create( + [self._ssaValues[arg.text] for arg in values['args']], + ret_types, + values['attributes'], + [self._blocks[block_name.text] for block_name in values.get('blocks', [])], + values.get('regions', []) + ) + + def must_parse_region(self) -> Region: + oldSSAVals = self._ssaValues.copy() + oldBBNames = self._blocks.copy() + self._blocks = dict[str, Block]() + + region = Region() + + try: + self.must_parse_characters('{', 'Regions begin with `{`') + if self.tokenizer.next_token(peek=True).text != '}': + # parse first block + block = self.must_parse_block() + region.add_block(block) + + while self.tokenizer.next_token(peek=True).text == '^': + region.add_block(self.must_parse_block()) + + self.must_parse_characters('}', 'Reached end of region, expected `}`!') + + return region + finally: + self._ssaValues = oldSSAVals + self._blocks = oldBBNames def try_parse_op_name(self) -> Span | None: if (str_lit := self.try_parse_string_literal()) is not None: @@ -937,10 +1006,11 @@ def must_parse_attribute_type(self) -> Attribute: def try_parse_builtin_attr(self) -> Attribute: attrs = ( - self.try_parse_builtin_int_attr, self.try_parse_builtin_float_attr, + self.try_parse_builtin_int_attr, self.try_parse_builtin_str_attr, - self.try_parse_builtin_arr_attr + self.try_parse_builtin_arr_attr, + self.try_parse_function_type ) for attr_parser in attrs: @@ -960,13 +1030,14 @@ def try_parse_builtin_int_attr(self) -> IntegerAttr | None: type = self.must_parse_attribute_type() return IntegerAttr.from_params(int(value.text), type) - def try_parse_builtin_float_attr(self) -> IntegerAttr | None: + def try_parse_builtin_float_attr(self) -> FloatAttr | None: with self.tokenizer.backtracking(): - value = self.expect(self.try_parse_float_literal, 'Integer attribute must start with an integer literal!') + value = self.expect(self.try_parse_float_literal, 'Float attribute must start with a float literal!') if self.tokenizer.next_token(peek=True).text != ':': return FloatAttr.from_value(float(value.text)) + type = self.must_parse_attribute_type() - return IntegerAttr.from_params(float(value.text), type) + return FloatAttr.from_value(float(value.text), type) def try_parse_builtin_boolean_attr(self) -> IntegerAttr | None: span = self.try_parse_boolean_literal() @@ -996,17 +1067,25 @@ def try_parse_builtin_arr_attr(self) -> list[Attribute] | None: self.must_parse_characters(']', 'Array literals must be enclosed by square brackets!') return ArrayAttr.from_list(attrs) - def must_parse_attr_dict(self) -> list[tuple[Span, Attribute]]: + def must_parse_attr_dict(self) -> dict[str, Attribute]: res = ParserCommons.BNF.attr_dict.try_parse(self) if res is None: - return [] - return ParserCommons.BNF.attr_dict.collect(res, dict()).get('attributes', list()) + return dict() + return self.attr_dict_from_tuple_list(ParserCommons.BNF.attr_dict.collect(res, dict()).get('attributes', list())) + + def attr_dict_from_tuple_list(self, tuple_list: list[tuple[Span, Attribute]]): + return dict( + ( + (span.string_contents if isinstance(span, StringLiteral) else span.text), + attr + ) for span, attr in tuple_list + ) - def try_parse_attr_dict(self) -> list[tuple[Span, Attribute]] | None: + def try_parse_attr_dict(self) -> dict[str, Attribute] | None: res = ParserCommons.BNF.attr_dict.try_parse(self) if res is None: return None - return ParserCommons.BNF.attr_dict.collect(res, dict()).get('attributes', list()) + return self.attr_dict_from_tuple_list(ParserCommons.BNF.attr_dict.collect(res, dict()).get('attributes', list())) def must_parse_function_type(self) -> tuple[list[Attribute], list[Attribute]]: """ @@ -1014,17 +1093,21 @@ def must_parse_function_type(self) -> tuple[list[Attribute], list[Attribute]]: viable function types are: (i32) -> () - i32 -> () () -> (i32, i32) + (i32, i32) -> () + Non-viable types are: i32 -> i32 + i32 -> () - Uses type-or-type-list-parens + Uses type-or-type-list-parens internally """ - args = self.must_parse_type_or_type_list_parens() + self.must_parse_characters('(', 'First group of function args must start with a `(`') + args: list[Attribute] = self.must_parse_list_of(self.try_parse_type, 'Expected type here!') + self.must_parse_characters(')', "End of function type arguments") self.must_parse_characters('->', 'Function type!') - return args, self.must_parse_type_or_type_list_parens() + return FunctionType.from_lists(args, self.must_parse_type_or_type_list_parens()) def must_parse_type_or_type_list_parens(self) -> list[Attribute]: """ @@ -1046,7 +1129,7 @@ def must_parse_type_or_type_list_parens(self) -> list[Attribute]: self.raise_error("Function type must either be single type or list of types in parenthesis!") return args - def try_parse_function_type(self) -> tuple[list[Attribute], list[Attribute]] | None: + def try_parse_function_type(self) -> FunctionType | None: if self.tokenizer.next_token(peek=True).text != '(': return None with self.tokenizer.backtracking('Function type'): From 0560fde0bff987a7aeea41c88ac58fcede6a4e78 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Sat, 10 Dec 2022 09:44:52 +0000 Subject: [PATCH 06/65] [parser] fixes true/false switcheroo --- xdsl/parser_ng.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xdsl/parser_ng.py b/xdsl/parser_ng.py index e4cf2474c5..ceb1249d7d 100644 --- a/xdsl/parser_ng.py +++ b/xdsl/parser_ng.py @@ -1045,7 +1045,7 @@ def try_parse_builtin_boolean_attr(self) -> IntegerAttr | None: if span is None: return None - int_val = ['true', 'false'].index(span.text) + int_val = ['false', 'true'].index(span.text) return IntegerAttr.from_params(int_val, IntegerType.from_width(1)) def try_parse_builtin_str_attr(self): From 3bdf8aa5f180ad2d0f07d5795373bb0669861f7c Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Sat, 10 Dec 2022 09:50:02 +0000 Subject: [PATCH 07/65] [parser] minor fixes in integer and float type parsing --- xdsl/dialects/builtin.py | 2 +- xdsl/parser_ng.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index 7038d540fa..4f6cc7e306 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -202,7 +202,7 @@ def from_params( AnyIntegerAttr: TypeAlias = IntegerAttr[IntegerType | IndexType] - +DefaultIntegerAttrType = i64 @irdl_attr_definition class Float16Type(ParametrizedAttribute, MLIRType): diff --git a/xdsl/parser_ng.py b/xdsl/parser_ng.py index ceb1249d7d..b0e712706b 100644 --- a/xdsl/parser_ng.py +++ b/xdsl/parser_ng.py @@ -20,7 +20,7 @@ DenseIntOrFPElementsAttr, Float16Type, Float32Type, Float64Type, FloatAttr, FunctionType, IndexType, IntegerType, OpaqueAttr, Signedness, StringAttr, FlatSymbolRefAttr, IntegerAttr, ArrayAttr, TensorType, UnitAttr, - UnrankedTensorType, UnregisteredOp, VectorType) + UnrankedTensorType, UnregisteredOp, VectorType, DefaultIntegerAttrType) from xdsl.irdl import Data @@ -735,7 +735,7 @@ def try_parse_builtin_type(self) -> Attribute | None: width = int(re_match.group(1)) type = { 16: Float16Type, - 32: Float64Type, + 32: Float32Type, 64: Float64Type }.get(width, None) if type is None: @@ -1026,7 +1026,7 @@ def try_parse_builtin_int_attr(self) -> IntegerAttr | None: value = self.expect(self.try_parse_integer_literal, 'Integer attribute must start with an integer literal!') if self.tokenizer.next_token(peek=True).text != ':': print(self.tokenizer.next_token(peek=True)) - return IntegerAttr.from_index_int_value(int(value.text)) + return IntegerAttr.from_params(int(value.text), DefaultIntegerAttrType) type = self.must_parse_attribute_type() return IntegerAttr.from_params(int(value.text), type) @@ -1073,7 +1073,7 @@ def must_parse_attr_dict(self) -> dict[str, Attribute]: return dict() return self.attr_dict_from_tuple_list(ParserCommons.BNF.attr_dict.collect(res, dict()).get('attributes', list())) - def attr_dict_from_tuple_list(self, tuple_list: list[tuple[Span, Attribute]]): + def attr_dict_from_tuple_list(self, tuple_list: list[tuple[Span, Attribute]]) -> dict[str, Attribute]: return dict( ( (span.string_contents if isinstance(span, StringLiteral) else span.text), From b5123e89ea006aba80660969578010ec710da5b9 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Sat, 10 Dec 2022 11:01:51 +0000 Subject: [PATCH 08/65] [parser] yapf formatted to make CI happy --- xdsl/dialects/builtin.py | 1 + xdsl/parser_ng.py | 461 ++++++++++++++++++++++++--------------- xdsl/utils/bnf.py | 48 ++-- 3 files changed, 315 insertions(+), 195 deletions(-) diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index 4f6cc7e306..08cfef100f 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -204,6 +204,7 @@ def from_params( AnyIntegerAttr: TypeAlias = IntegerAttr[IntegerType | IndexType] DefaultIntegerAttrType = i64 + @irdl_attr_definition class Float16Type(ParametrizedAttribute, MLIRType): name = "f16" diff --git a/xdsl/parser_ng.py b/xdsl/parser_ng.py index b0e712706b..54f3863a3a 100644 --- a/xdsl/parser_ng.py +++ b/xdsl/parser_ng.py @@ -30,7 +30,10 @@ class ParseError(Exception): msg: str history: BacktrackingHistory | None - def __init__(self, span: Span, msg: str, history: BacktrackingHistory | None = None): + def __init__(self, + span: Span, + msg: str, + history: BacktrackingHistory | None = None): super().__init__(span.print_with_context(msg)) self.span = span self.msg = msg @@ -52,17 +55,22 @@ def print_unroll(self, file=sys.stderr): if self.parent: self.parent.print_unroll(file) - print("Aborted parsing of {} because failure at:".format(self.region_name or ''), file=file) + print("Aborted parsing of {} because failure at:".format( + self.region_name or ''), + file=file) self.error.print_pretty(file=file, print_history=False) + class BacktrackingAbort(Exception): reason: str | None def __init__(self, reason: str | None = None): - super().__init__("This message should never escape the parser, it's intended to signal a failed parsing " - "attempt\n " - "It should never be used outside of a tokenizer.backtracking() block!\n" - "The reason for this abort was {}".format('not specified' if reason is None else reason)) + super().__init__( + "This message should never escape the parser, it's intended to signal a failed parsing " + "attempt\n " + "It should never be used outside of a tokenizer.backtracking() block!\n" + "The reason for this abort was {}".format( + 'not specified' if reason is None else reason)) self.reason = reason @@ -109,14 +117,17 @@ def print_with_context(self, msg: str | None = None) -> str: offset = self.start - offset_of_first_line remaining_len = max(self.len, 1) capture = StringIO() - print("{}:{}:{}".format(self.input.name, line_no, offset, remaining_len), file=capture) + print("{}:{}:{}".format(self.input.name, line_no, offset, + remaining_len), + file=capture) for line in lines: print(line, file=capture) if remaining_len < 0: continue len_on_this_line = min(remaining_len, len(line) - offset) remaining_len -= len_on_this_line - print("{}{}".format(" " * offset, "^" * max(len_on_this_line, 1)), file=capture) + print("{}{}".format(" " * offset, "^" * max(len_on_this_line, 1)), + file=capture) if msg is not None: print("{}{}".format(" " * offset, msg), file=capture) msg = None @@ -131,6 +142,7 @@ def __repr__(self): @dataclass(frozen=True) class StringLiteral(Span): + def __post_init__(self): if len(self) < 2 or self.text[0] != '"' or self.text[-1] != '"': raise ParseError(self, "Invalid string literal!") @@ -149,7 +161,8 @@ def string_contents(self): return ast.literal_eval(self.text) def __repr__(self): - return "StringLiteral[{}:{}](text='{}')".format(self.start, self.end, self.text) + return "StringLiteral[{}:{}](text='{}')".format( + self.start, self.end, self.text) @dataclass(frozen=True) @@ -176,7 +189,8 @@ def get_nth_line_bounds(self, n: int): start = next_start + 1 return start, self.content.find('\n', start) - def get_lines_containing(self, span: Span) -> tuple[list[str], int, int] | None: + def get_lines_containing(self, + span: Span) -> tuple[list[str], int, int] | None: # A pointer to the start of the first line start = 0 line_no = 0 @@ -217,10 +231,9 @@ class Tokenizer: The position in the input. Points to the first unconsumed character. """ - break_on: tuple[str, ...] = ( - '.', '%', ' ', '(', ')', '[', ']', '{', '}', '<', '>', ':', '=', '@', '?', '|', '->', '-', '//', '\n', '\t', - '#', '"', "'", ',' - ) + break_on: tuple[str, ...] = ('.', '%', ' ', '(', ')', '[', ']', '{', '}', + '<', '>', ':', '=', '@', '?', '|', '->', '-', + '//', '\n', '\t', '#', '"', "'", ',') """ characters the tokenizer should break on """ @@ -272,13 +285,14 @@ def backtracking(self, region_name: str | None = None): self.history = BacktrackingHistory( ParseError( self.next_token(peek=True), - 'Backtracking aborted: {}'.format(ex.reason or 'unknown reason') - ), - self.history, - region_name - ) + 'Backtracking aborted: {}'.format( + ex.reason or 'unknown reason')), self.history, + region_name) elif isinstance(ex, AssertionError): - reason = ['Generic assertion failure', *(reason for reason in ex.args if isinstance(reason, str))] + reason = [ + 'Generic assertion failure', + *(reason for reason in ex.args if isinstance(reason, str)) + ] # we assume that assertions fail because of the last read-in token if len(reason) == 1: tb = StringIO() @@ -286,33 +300,29 @@ def backtracking(self, region_name: str | None = None): reason[0] += '\n' + tb.getvalue() self.history = BacktrackingHistory( - ParseError(self.last_token, reason[-1]), - self.history, - region_name - ) + ParseError(self.last_token, reason[-1]), self.history, + region_name) elif isinstance(ex, ParseError): - self.history = BacktrackingHistory( - ex, - self.history, - region_name - ) + self.history = BacktrackingHistory(ex, self.history, + region_name) elif isinstance(ex, EOFError): self.history = BacktrackingHistory( ParseError(self.last_token, "Encountered EOF"), - self.history, - region_name - ) + self.history, region_name) else: self.history = BacktrackingHistory( - ParseError(self.last_token, "Unexpected exception: {}".format(ex)), - self.history, - region_name - ) - print("Warning: Unexpected error in backtracking: {}".format(repr(ex))) + ParseError(self.last_token, + "Unexpected exception: {}".format(ex)), + self.history, region_name) + print("Warning: Unexpected error in backtracking: {}".format( + repr(ex))) raise ex self.resume_from(save) - def next_token(self, start: int | None = None, skip: int = 0, peek: bool = False, + def next_token(self, + start: int | None = None, + skip: int = 0, + peek: bool = False, include_comments: bool = False) -> Span: """ Best effort guess at what the next token could be @@ -336,7 +346,9 @@ def next_token(self, start: int | None = None, skip: int = 0, peek: bool = False self.last_token = span return span - def next_token_of_pattern(self, pattern: re.Pattern, peek: bool = False) -> Span | None: + def next_token_of_pattern(self, + pattern: re.Pattern, + peek: bool = False) -> Span | None: """ Return a span that matched the pattern, or nothing. You can choose not to consume the span. """ @@ -373,7 +385,9 @@ def _find_token_end(self, start: int | None = None) -> int: if self.input.content.startswith(part, i): return i + len(part) # otherwise return the start of the next break - return min(filter(lambda x: x >= 0, (self.input.content.find(part, i) for part in self.break_on))) + return min( + filter(lambda x: x >= 0, (self.input.content.find(part, i) + for part in self.break_on))) def next_pos(self, i: int | None = None) -> int: """ @@ -409,7 +423,9 @@ def consume_opt_whitespace(self) -> Span: return self.span_of(start, self.pos) @contextlib.contextmanager - def configured(self, break_on: tuple[str, ...] | None = None, ignore_whitespace: bool | None = None): + def configured(self, + break_on: tuple[str, ...] | None = None, + ignore_whitespace: bool | None = None): """ This is a helper class to allow expressing a temporary change in config, allowing you to write: @@ -451,16 +467,18 @@ class ParserCommons: type_alias = re.compile(r'![A-z_][A-z0-9_$.]+') attribute_alias = re.compile(r'#[A-z_][A-z0-9_$.]+') boolean_literal = re.compile(r'(true|false)') - builtin_type = re.compile('(({}))'.format( - ')|('.join(( - r'[su]?i\d+', r'f\d+', - 'tensor', 'vector', - 'memref', 'complex', - 'opaque', 'tuple', - 'index', - # TODO: add all the Float8E4M3FNType, Float8E5M2Type, and BFloat16Type - )) - )) + builtin_type = re.compile('(({}))'.format(')|('.join(( + r'[su]?i\d+', + r'f\d+', + 'tensor', + 'vector', + 'memref', + 'complex', + 'opaque', + 'tuple', + 'index', + # TODO: add all the Float8E4M3FNType, Float8E5M2Type, and BFloat16Type + )))) double_colon = re.compile('::') comma = re.compile(',') @@ -468,31 +486,45 @@ class BNF: """ Collection of BNF trees. """ - generic_operation_body = BNF.Group([ - BNF.Nonterminal('string-literal', bind="name"), - BNF.Literal('('), - BNF.ListOf(BNF.Nonterminal('value-id'), bind='args'), - BNF.Literal(')'), - BNF.OptionalGroup([ - BNF.Literal('['), - BNF.ListOf(BNF.Nonterminal('block-id'), allow_empty=False, bind='blocks'), - # TODD: allow for block args here?! (accordin to spec) - BNF.Literal(']') - ], debug_name="operations optional block id group"), - BNF.OptionalGroup([ + generic_operation_body = BNF.Group( + [ + BNF.Nonterminal('string-literal', bind="name"), BNF.Literal('('), - BNF.ListOf(BNF.Nonterminal('region'), bind='regions', allow_empty=False), - BNF.Literal(')') - ], debug_name="operation regions"), - BNF.Nonterminal('attr-dict', bind='attributes', debug_name="attrbiute dictionary"), - BNF.Literal(':'), - BNF.Nonterminal('function-type', bind='type_signature') - ], debug_name="generic operation body") + BNF.ListOf(BNF.Nonterminal('value-id'), bind='args'), + BNF.Literal(')'), + BNF.OptionalGroup( + [ + BNF.Literal('['), + BNF.ListOf(BNF.Nonterminal('block-id'), + allow_empty=False, + bind='blocks'), + # TODD: allow for block args here?! (accordin to spec) + BNF.Literal(']') + ], + debug_name="operations optional block id group"), + BNF.OptionalGroup([ + BNF.Literal('('), + BNF.ListOf(BNF.Nonterminal('region'), + bind='regions', + allow_empty=False), + BNF.Literal(')') + ], + debug_name="operation regions"), + BNF.Nonterminal('attr-dict', + bind='attributes', + debug_name="attrbiute dictionary"), + BNF.Literal(':'), + BNF.Nonterminal('function-type', bind='type_signature') + ], + debug_name="generic operation body") attr_dict = BNF.Group([ BNF.Literal('{'), - BNF.ListOf(BNF.Nonterminal('attribute-entry', debug_name="attribute entry"), bind='attributes'), + BNF.ListOf(BNF.Nonterminal('attribute-entry', + debug_name="attribute entry"), + bind='attributes'), BNF.Literal('}') - ], debug_name="attrbute dictionary") + ], + debug_name="attrbute dictionary") class MlirParser: @@ -534,7 +566,11 @@ class Accent(Enum): of all try_parse functions is T_ | None """ - def __init__(self, input: str, name: str, ctx: MLContext, accent: str | Accent = Accent.MLIR): + def __init__(self, + input: str, + name: str, + ctx: MLContext, + accent: str | Accent = Accent.MLIR): self.tokenizer = Tokenizer(Input(input, name)) self.ctx = ctx if isinstance(accent, str): @@ -551,7 +587,6 @@ def begin_parse(self): self.raise_error("Unfinished business!") return ops - def must_parse_block(self) -> Block: id, args = self.must_parse_optional_block_label() @@ -585,11 +620,14 @@ def must_parse_optional_block_label(self): return next_id, arg_list def must_parse_block_arg_list(self) -> list[tuple[Span, Attribute]]: - self.assert_eq(self.tokenizer.next_token(), '(', 'Block arguments must start with `(`') + self.assert_eq(self.tokenizer.next_token(), '(', + 'Block arguments must start with `(`') - args = self.must_parse_list_of(self.try_parse_value_id_and_type, "Expected ") + args = self.must_parse_list_of(self.try_parse_value_id_and_type, + "Expected ") - self.assert_eq(self.tokenizer.next_token(), ')', 'Expected closing of block arguments!') + self.assert_eq(self.tokenizer.next_token(), ')', + 'Expected closing of block arguments!') return args @@ -600,18 +638,21 @@ def try_parse_single_reference(self) -> Span | None: return reference if (reference := self.try_parse_suffix_id()) is not None: return reference - raise BacktrackingAbort("References must conform to `@` (string-literal | suffix-id)") + raise BacktrackingAbort( + "References must conform to `@` (string-literal | suffix-id)") def must_parse_reference(self) -> list[Span]: return self.must_parse_list_of( self.try_parse_single_reference, 'Expected reference here in the format of `@` (suffix-id | string-literal)', ParserCommons.double_colon, - allow_empty=False - ) + allow_empty=False) - def must_parse_list_of(self, try_parse: Callable[[], T_ | None], error_msg: str, - separator_pattern: re.Pattern = ParserCommons.comma, allow_empty: bool = True) -> list[T_]: + def must_parse_list_of(self, + try_parse: Callable[[], T_ | None], + error_msg: str, + separator_pattern: re.Pattern = ParserCommons.comma, + allow_empty: bool = True) -> list[T_]: """ This is a greedy list-parser. It accepts input only in these cases: @@ -635,28 +676,35 @@ def must_parse_list_of(self, try_parse: Callable[[], T_ | None], error_msg: str, items.append(first_item) - while (match := self.tokenizer.next_token_of_pattern(separator_pattern)) is not None: + while (match := self.tokenizer.next_token_of_pattern(separator_pattern) + ) is not None: next_item = try_parse() if next_item is None: # if the separator is emtpy, we are good here if separator_pattern.pattern == '': return items - self.raise_error(error_msg + ' because was able to match next separator {}'.format(match)) + self.raise_error(error_msg + + ' because was able to match next separator {}' + .format(match)) items.append(next_item) return items def try_parse_integer_literal(self) -> Span | None: - return self.tokenizer.next_token_of_pattern(ParserCommons.integer_literal) + return self.tokenizer.next_token_of_pattern( + ParserCommons.integer_literal) def try_parse_decimal_literal(self) -> Span | None: - return self.tokenizer.next_token_of_pattern(ParserCommons.decimal_literal) + return self.tokenizer.next_token_of_pattern( + ParserCommons.decimal_literal) def try_parse_string_literal(self) -> StringLiteral | None: - return StringLiteral.from_span(self.tokenizer.next_token_of_pattern(ParserCommons.string_literal)) + return StringLiteral.from_span( + self.tokenizer.next_token_of_pattern(ParserCommons.string_literal)) def try_parse_float_literal(self) -> Span | None: - return self.tokenizer.next_token_of_pattern(ParserCommons.float_literal) + return self.tokenizer.next_token_of_pattern( + ParserCommons.float_literal) def try_parse_bare_id(self) -> Span | None: return self.tokenizer.next_token_of_pattern(ParserCommons.bare_id) @@ -671,16 +719,18 @@ def try_parse_block_id(self) -> Span | None: return self.tokenizer.next_token_of_pattern(ParserCommons.block_id) def try_parse_boolean_literal(self) -> Span | None: - return self.tokenizer.next_token_of_pattern(ParserCommons.boolean_literal) + return self.tokenizer.next_token_of_pattern( + ParserCommons.boolean_literal) def try_parse_value_id_and_type(self) -> tuple[Span, Attribute] | None: with self.tokenizer.backtracking(): value_id = self.try_parse_value_id() if value_id is None: - raise BacktrackingAbort("Expected value id here!") + raise BacktrackingAbort("Invalid value-id format!") - self.must_parse_characters(':', 'Expected expression (value-id `:` type)') + self.must_parse_characters( + ':', 'Expected expression (value-id `:` type)') type = self.try_parse_type() @@ -691,18 +741,23 @@ def try_parse_value_id_and_type(self) -> tuple[Span, Attribute] | None: def try_parse_type(self) -> Attribute | None: if (builtin_type := self.try_parse_builtin_type()) is not None: return builtin_type - if (dialect_type := self.try_parse_dialect_type_or_attribute('type')) is not None: + if (dialect_type := + self.try_parse_dialect_type_or_attribute('type')) is not None: return dialect_type return None - def try_parse_dialect_type_or_attribute(self, kind: Literal['type', 'attr']) -> Attribute | None: + def try_parse_dialect_type_or_attribute( + self, kind: Literal['type', 'attr']) -> Attribute | None: with self.tokenizer.backtracking(): if kind == 'type': - self.must_parse_characters('!', "Dialect types must start with a `!`") + self.must_parse_characters( + '!', "Dialect types must start with a `!`") else: - self.must_parse_characters('#', "Dialect attributes must start with a `#`") + self.must_parse_characters( + '#', "Dialect attributes must start with a `#`") - type_name = self.tokenizer.next_token_of_pattern(ParserCommons.bare_id) + type_name = self.tokenizer.next_token_of_pattern( + ParserCommons.bare_id) if type_name is None: raise BacktrackingAbort("Expected a type name") @@ -718,7 +773,8 @@ def try_parse_builtin_type(self) -> Attribute | None: parse a builtin-type like i32, index, vector etc. """ with self.tokenizer.backtracking(): - name = self.tokenizer.next_token_of_pattern(ParserCommons.builtin_type) + name = self.tokenizer.next_token_of_pattern( + ParserCommons.builtin_type) if name is None: raise BacktrackingAbort("Expected builtin name!") if name.text == 'index': @@ -729,7 +785,8 @@ def try_parse_builtin_type(self) -> Attribute | None: 'u': Signedness.UNSIGNED, 'i': Signedness.SIGNLESS } - return IntegerType.from_width(int(re_match.group(1)), signedness[name.text[0]]) + return IntegerType.from_width(int(re_match.group(1)), + signedness[name.text[0]]) if (re_match := re.match(r'^f(\d+)$', name.text)) is not None: width = int(re_match.group(1)) @@ -739,15 +796,18 @@ def try_parse_builtin_type(self) -> Attribute | None: 64: Float64Type }.get(width, None) if type is None: - self.raise_error("Unsupported floating point width: {}".format(width)) + self.raise_error( + "Unsupported floating point width: {}".format(width)) return type() - return self.must_parse_builtin_parametrized_type(name) - def must_parse_builtin_parametrized_type(self, name: Span) -> ParametrizedAttribute: + def must_parse_builtin_parametrized_type( + self, name: Span) -> ParametrizedAttribute: + def unimplemented() -> ParametrizedAttribute: - raise ParseError(self.tokenizer.next_token(), "Type not supported yet!") + raise ParseError(self.tokenizer.next_token(), + "Type not supported yet!") builtin_parsers: dict[str, Callable[[], ParametrizedAttribute]] = { 'vector': self.must_parse_vector_attrs, @@ -760,26 +820,34 @@ def unimplemented() -> ParametrizedAttribute: if name.text not in builtin_parsers: raise ParseError(name, "Unknown builtin {}".format(name.text)) - self.assert_eq(self.tokenizer.next_token(), '<', 'Expected parameter list here!') + self.assert_eq(self.tokenizer.next_token(), '<', + 'Expected parameter list here!') res = builtin_parsers[name.text]() - self.assert_eq(self.tokenizer.next_token(), '>', 'Expected end of parameter list here!') + self.assert_eq(self.tokenizer.next_token(), '>', + 'Expected end of parameter list here!') return res def must_parse_complex_attrs(self): type = self.try_parse_type() self.raise_error("ComplexType is unimplemented!") - def try_parse_numerical_dims(self, accept_closing_bracket: bool = False, lower_bound: int = 1) -> Iterable[int]: - while (shape_arg := self.try_parse_shape_element(lower_bound)) is not None: + def try_parse_numerical_dims(self, + accept_closing_bracket: bool = False, + lower_bound: int = 1) -> Iterable[int]: + while (shape_arg := + self.try_parse_shape_element(lower_bound)) is not None: yield shape_arg # look out for the closing bracket for scalable vector dims - if accept_closing_bracket and self.tokenizer.next_token(peek=True).text == ']': + if accept_closing_bracket and self.tokenizer.next_token( + peek=True).text == ']': break - self.assert_eq(self.tokenizer.next_token(), 'x', 'Unexpected end of dimension parameters!') + self.assert_eq(self.tokenizer.next_token(), 'x', + 'Unexpected end of dimension parameters!') def must_parse_vector_attrs(self) -> AnyVectorType: # also break on 'x' characters as they are separators in dimension parameters - with self.tokenizer.configured(break_on=self.tokenizer.break_on + ('x',)): + with self.tokenizer.configured(break_on=self.tokenizer.break_on + + ('x', )): shape = list[int](self.try_parse_numerical_dims()) scaling_shape: list[int] | None = None @@ -787,8 +855,12 @@ def must_parse_vector_attrs(self) -> AnyVectorType: self.tokenizer.next_token() # we now need to parse the scalable dimensions scaling_shape = list(self.try_parse_numerical_dims()) - self.assert_eq(self.tokenizer.next_token(), ']', 'Expected end of scalable vector dimensions here!') - self.assert_eq(self.tokenizer.next_token(), 'x', 'Expected end of scalable vector dimensions here!') + self.assert_eq( + self.tokenizer.next_token(), ']', + 'Expected end of scalable vector dimensions here!') + self.assert_eq( + self.tokenizer.next_token(), 'x', + 'Expected end of scalable vector dimensions here!') if scaling_shape is not None: # TODO: handle scaling vectors! @@ -797,17 +869,21 @@ def must_parse_vector_attrs(self) -> AnyVectorType: type = self.try_parse_type() if type is None: - self.raise_error("Expected a type at the end of the vector parameters!") + self.raise_error( + "Expected a type at the end of the vector parameters!") return VectorType.from_type_and_list(type, shape) def must_parse_tensor_or_memref_dims(self) -> list[int] | None: - with self.tokenizer.configured(break_on=self.tokenizer.break_on + ('x',)): + with self.tokenizer.configured(break_on=self.tokenizer.break_on + + ('x', )): if self.tokenizer.next_token(peek=True).text == '*': # consume `*` self.tokenizer.next_token() # consume `x` - self.assert_eq(self.tokenizer.next_token(), 'x', 'Unranked tensors must follow format (`<*x` type `>`)') + self.assert_eq( + self.tokenizer.next_token(), 'x', + 'Unranked tensors must follow format (`<*x` type `>`)') else: # parse rank: return list(self.try_parse_numerical_dims(lower_bound=0)) @@ -843,7 +919,9 @@ def try_parse_shape_element(self, lower_bound: int = 1) -> int | None: value = int(int_lit.text) if value < lower_bound: # TODO: this is ugly, it's a raise inside a try_ type function, which should instead just give up - raise ParseError(int_lit, "Shape element literal cannot be negative or zero!") + raise ParseError( + int_lit, + "Shape element literal cannot be negative or zero!") return value next_token = self.tokenizer.next_token(peek=True) @@ -855,18 +933,19 @@ def try_parse_shape_element(self, lower_bound: int = 1) -> int | None: def must_parse_type_params(self) -> list[Attribute]: # consume opening bracket - assert self.tokenizer.next_token().text == '<', 'Type must be parameterized!' + assert self.tokenizer.next_token( + ).text == '<', 'Type must be parameterized!' - params = self.must_parse_list_of( - self.try_parse_type, - 'Expected a type here!' - ) + params = self.must_parse_list_of(self.try_parse_type, + 'Expected a type here!') - assert self.tokenizer.next_token().text == '>', 'Expected end of type parameterization here!' + assert self.tokenizer.next_token( + ).text == '>', 'Expected end of type parameterization here!' return params - def expect(self, try_parse: Callable[[], T_ | None], error_message: str) -> T_: + def expect(self, try_parse: Callable[[], T_ | None], + error_message: str) -> T_: """ Used to force completion of a try_parse function. Will throw a parse error if it can't """ @@ -889,52 +968,60 @@ def raise_error(self, msg: str, at_position: Span | None = None): def assert_eq(self, got: Span, want: str, msg: str): if got.text == want: return - raise AssertionError("Assertion failed (assert `{}` == `{}`): {}".format(got.text, want, msg), got) + raise AssertionError( + "Assertion failed (assert `{}` == `{}`): {}".format( + got.text, want, msg), got) def must_parse_characters(self, text: str, msg: str): self.assert_eq(self.tokenizer.next_token(), text, msg) - def must_parse_op_result_list(self) -> list[tuple[Span, Attribute] | Span] | None: - inner_parser = (dict(( - (MlirParser.Accent.MLIR, self.try_parse_value_id), - (MlirParser.Accent.XDSL, self.try_parse_value_id_and_type) - )))[self.accent] + def must_parse_op_result_list( + self) -> list[tuple[Span, Attribute] | Span] | None: + inner_parser = (dict( + ((MlirParser.Accent.MLIR, self.try_parse_value_id), + (MlirParser.Accent.XDSL, + self.try_parse_value_id_and_type))))[self.accent] - return self.must_parse_list_of(self.try_parse_value_id, 'Expected op-result here!', allow_empty=False) + return self.must_parse_list_of(self.try_parse_value_id, + 'Expected op-result here!', + allow_empty=False) def try_parse_operation(self) -> Operation | None: with self.tokenizer.backtracking("operation"): if self.tokenizer.next_token(peek=True).text == '%': result_list = self.must_parse_op_result_list() - self.must_parse_characters('=', 'Operation definitions expect an `=` after op-result-list!') + self.must_parse_characters( + '=', + 'Operation definitions expect an `=` after op-result-list!' + ) else: result_list = [] - generic_op = ParserCommons.BNF.generic_operation_body.try_parse(self) + generic_op = ParserCommons.BNF.generic_operation_body.try_parse( + self) if generic_op is None: self.raise_error("custom operations not supported as of yet!") - values = ParserCommons.BNF.generic_operation_body.collect(generic_op, dict()) + values = ParserCommons.BNF.generic_operation_body.collect( + generic_op, dict()) arg_types, ret_types = ([], []) if 'type_signature' in values: - functype : FunctionType = values['type_signature'] + functype: FunctionType = values['type_signature'] arg_types, ret_types = functype.inputs.data, functype.outputs.data if len(ret_types) != len(result_list): raise ParseError( values['name'], - "Mismatch between type signature and result list for op!" - ) + "Mismatch between type signature and result list for op!") op_type = self.ctx.get_op(values['name'].string_contents) return op_type.create( [self._ssaValues[arg.text] for arg in values['args']], - ret_types, - values['attributes'], - [self._blocks[block_name.text] for block_name in values.get('blocks', [])], - values.get('regions', []) - ) + ret_types, values['attributes'], [ + self._blocks[block_name.text] + for block_name in values.get('blocks', []) + ], values.get('regions', [])) def must_parse_region(self) -> Region: oldSSAVals = self._ssaValues.copy() @@ -953,7 +1040,8 @@ def must_parse_region(self) -> Region: while self.tokenizer.next_token(peek=True).text == '^': region.add_block(self.must_parse_block()) - self.must_parse_characters('}', 'Reached end of region, expected `}`!') + self.must_parse_characters('}', + 'Reached end of region, expected `}`!') return region finally: @@ -976,9 +1064,12 @@ def must_parse_attribute_entry(self) -> tuple[Span, Attribute]: name = self.try_parse_string_literal() if name is None: - self.raise_error('Expected bare-id or string-literal here as part of attribute entry!') + self.raise_error( + 'Expected bare-id or string-literal here as part of attribute entry!' + ) - self.must_parse_characters('=', 'Attribute entries must be of format name `=` attribute!') + self.must_parse_characters( + '=', 'Attribute entries must be of format name `=` attribute!') return name, self.must_parse_attribute() @@ -990,7 +1081,8 @@ def must_parse_attribute(self) -> Attribute: if self.tokenizer.next_token(peek=True).text == '#': value = self.try_parse_dialect_type_or_attribute('attr') if value is None: - self.raise_error('`#` must be followed by a valid builtin attribute!') + self.raise_error( + '`#` must be followed by a valid builtin attribute!') return value builtin_val = self.try_parse_builtin_attr() @@ -1001,17 +1093,17 @@ def must_parse_attribute(self) -> Attribute: return builtin_val def must_parse_attribute_type(self) -> Attribute: - self.must_parse_characters(':', 'Expected attribute type definition here ( `:` type )') - return self.expect(self.try_parse_type, 'Expected attribute type definition here ( `:` type )') + self.must_parse_characters( + ':', 'Expected attribute type definition here ( `:` type )') + return self.expect( + self.try_parse_type, + 'Expected attribute type definition here ( `:` type )') def try_parse_builtin_attr(self) -> Attribute: - attrs = ( - self.try_parse_builtin_float_attr, - self.try_parse_builtin_int_attr, - self.try_parse_builtin_str_attr, - self.try_parse_builtin_arr_attr, - self.try_parse_function_type - ) + attrs = (self.try_parse_builtin_float_attr, + self.try_parse_builtin_int_attr, + self.try_parse_builtin_str_attr, + self.try_parse_builtin_arr_attr, self.try_parse_function_type) for attr_parser in attrs: if (val := attr_parser()) is not None: @@ -1023,16 +1115,21 @@ def try_parse_builtin_int_attr(self) -> IntegerAttr | None: return bool with self.tokenizer.backtracking("built in int attribute"): - value = self.expect(self.try_parse_integer_literal, 'Integer attribute must start with an integer literal!') + value = self.expect( + self.try_parse_integer_literal, + 'Integer attribute must start with an integer literal!') if self.tokenizer.next_token(peek=True).text != ':': print(self.tokenizer.next_token(peek=True)) - return IntegerAttr.from_params(int(value.text), DefaultIntegerAttrType) + return IntegerAttr.from_params(int(value.text), + DefaultIntegerAttrType) type = self.must_parse_attribute_type() return IntegerAttr.from_params(int(value.text), type) def try_parse_builtin_float_attr(self) -> FloatAttr | None: with self.tokenizer.backtracking(): - value = self.expect(self.try_parse_float_literal, 'Float attribute must start with a float literal!') + value = self.expect( + self.try_parse_float_literal, + 'Float attribute must start with a float literal!') if self.tokenizer.next_token(peek=True).text != ':': return FloatAttr.from_value(float(value.text)) @@ -1062,32 +1159,40 @@ def try_parse_builtin_arr_attr(self) -> list[Attribute] | None: if self.tokenizer.next_token(peek=True).text != '[': return None with self.tokenizer.backtracking(): - self.must_parse_characters('[', 'Array literals must start with `[`') - attrs = self.must_parse_list_of(self.try_parse_builtin_attr, 'Expected array entry!') - self.must_parse_characters(']', 'Array literals must be enclosed by square brackets!') + self.must_parse_characters('[', + 'Array literals must start with `[`') + attrs = self.must_parse_list_of(self.try_parse_builtin_attr, + 'Expected array entry!') + self.must_parse_characters( + ']', 'Array literals must be enclosed by square brackets!') return ArrayAttr.from_list(attrs) def must_parse_attr_dict(self) -> dict[str, Attribute]: res = ParserCommons.BNF.attr_dict.try_parse(self) if res is None: return dict() - return self.attr_dict_from_tuple_list(ParserCommons.BNF.attr_dict.collect(res, dict()).get('attributes', list())) + return self.attr_dict_from_tuple_list( + ParserCommons.BNF.attr_dict.collect(res, dict()).get( + 'attributes', list())) - def attr_dict_from_tuple_list(self, tuple_list: list[tuple[Span, Attribute]]) -> dict[str, Attribute]: + def attr_dict_from_tuple_list( + self, tuple_list: list[tuple[Span, + Attribute]]) -> dict[str, Attribute]: return dict( - ( - (span.string_contents if isinstance(span, StringLiteral) else span.text), - attr - ) for span, attr in tuple_list - ) + ((span.string_contents if isinstance(span, StringLiteral + ) else span.text), attr) + for span, attr in tuple_list) def try_parse_attr_dict(self) -> dict[str, Attribute] | None: res = ParserCommons.BNF.attr_dict.try_parse(self) if res is None: return None - return self.attr_dict_from_tuple_list(ParserCommons.BNF.attr_dict.collect(res, dict()).get('attributes', list())) + return self.attr_dict_from_tuple_list( + ParserCommons.BNF.attr_dict.collect(res, dict()).get( + 'attributes', list())) - def must_parse_function_type(self) -> tuple[list[Attribute], list[Attribute]]: + def must_parse_function_type( + self) -> tuple[list[Attribute], list[Attribute]]: """ Parses function-type: @@ -1101,13 +1206,16 @@ def must_parse_function_type(self) -> tuple[list[Attribute], list[Attribute]]: Uses type-or-type-list-parens internally """ - self.must_parse_characters('(', 'First group of function args must start with a `(`') - args: list[Attribute] = self.must_parse_list_of(self.try_parse_type, 'Expected type here!') + self.must_parse_characters( + '(', 'First group of function args must start with a `(`') + args: list[Attribute] = self.must_parse_list_of( + self.try_parse_type, 'Expected type here!') self.must_parse_characters(')', "End of function type arguments") self.must_parse_characters('->', 'Function type!') - return FunctionType.from_lists(args, self.must_parse_type_or_type_list_parens()) + return FunctionType.from_lists( + args, self.must_parse_type_or_type_list_parens()) def must_parse_type_or_type_list_parens(self) -> list[Attribute]: """ @@ -1119,14 +1227,15 @@ def must_parse_type_or_type_list_parens(self) -> list[Attribute]: """ if self.tokenizer.next_token(peek=True).text == '(': self.must_parse_characters('(', 'Function type!') - args: list[Attribute] = self.must_parse_list_of(self.try_parse_type, 'Expected type here!') + args: list[Attribute] = self.must_parse_list_of( + self.try_parse_type, 'Expected type here!') self.must_parse_characters(')', "End of function type args") else: - args = [ - self.try_parse_type() - ] + args = [self.try_parse_type()] if args[0] is None: - self.raise_error("Function type must either be single type or list of types in parenthesis!") + self.raise_error( + "Function type must either be single type or list of types in parenthesis!" + ) return args def try_parse_function_type(self) -> FunctionType | None: diff --git a/xdsl/utils/bnf.py b/xdsl/utils/bnf.py index 2ac7eed41f..514d333920 100644 --- a/xdsl/utils/bnf.py +++ b/xdsl/utils/bnf.py @@ -42,7 +42,8 @@ class Literal(BNFToken): debug_name: str | None = field(kw_only=True, default=None) def must_parse(self, parser: MlirParser): - return parser.must_parse_characters(self.string, 'Expected `{}`'.format(self.string)) + return parser.must_parse_characters( + self.string, 'Expected `{}`'.format(self.string)) def __repr__(self): return '`{}`'.format(self.string) @@ -85,19 +86,26 @@ class Nonterminal(BNFToken): debug_name: str | None = field(kw_only=True, default=None) def must_parse(self, parser: MlirParser): - if hasattr(parser, 'must_parse_{}'.format(self.name.replace('-', '_'))): - return getattr(parser, 'must_parse_{}'.format(self.name.replace('-', '_')))() - elif hasattr(parser, 'try_parse_{}'.format(self.name.replace('-', '_'))): + if hasattr(parser, 'must_parse_{}'.format(self.name.replace('-', + '_'))): + return getattr(parser, + 'must_parse_{}'.format(self.name.replace('-', + '_')))() + elif hasattr(parser, 'try_parse_{}'.format(self.name.replace('-', + '_'))): return parser.expect( - getattr(parser, 'try_parse_{}'.format(self.name.replace('-', '_'))), - 'Expected to parse {} here!'.format(self.name) - ) + getattr(parser, + 'try_parse_{}'.format(self.name.replace('-', '_'))), + 'Expected to parse {} here!'.format(self.name)) else: - raise NotImplementedError("Parser cannot parse {}".format(self.name)) + raise NotImplementedError("Parser cannot parse {}".format( + self.name)) def try_parse(self, parser: MlirParser) -> T | None: if hasattr(parser, 'try_parse_{}'.format(self.name.replace('-', '_'))): - return getattr(parser, 'try_parse_{}'.format(self.name.replace('-', '_')))() + return getattr(parser, + 'try_parse_{}'.format(self.name.replace('-', + '_')))() return super().try_parse(parser) def __repr__(self): @@ -111,9 +119,7 @@ class Group(BNFToken): debug_name: str | None = field(kw_only=True, default=None) def must_parse(self, parser: MlirParser) -> T: - return [ - token.must_parse(parser) for token in self.tokens - ] + return [token.must_parse(parser) for token in self.tokens] def __repr__(self): return '( {} )'.format(' '.join(repr(t) for t in self.tokens)) @@ -136,7 +142,8 @@ def must_parse(self, parser: MlirParser) -> list[T]: val = self.wraps.try_parse(parser) if val is None: if len(res) == 0: - raise AssertionError("Expected at least one of {}".format(self.wraps)) + raise AssertionError("Expected at least one of {}".format( + self.wraps)) return res res.append(val) @@ -189,16 +196,17 @@ class ListOf(BNFToken): def must_parse(self, parser: MlirParser) -> T | None: return parser.must_parse_list_of( - lambda : self.element.try_parse(parser), + lambda: self.element.try_parse(parser), 'Expected {}!'.format(self.element), separator_pattern=self.separator, - allow_empty=self.allow_empty - ) + allow_empty=self.allow_empty) def __repr__(self): if self.allow_empty: - return '( {elm} ( re`{sep}` {elm} )* )?'.format(elm=self.element, sep=self.separator.pattern) - return '{elm} ( re`{sep}` {elm} )*'.format(elm=self.element, sep=self.separator.pattern) + return '( {elm} ( re`{sep}` {elm} )* )?'.format( + elm=self.element, sep=self.separator.pattern) + return '{elm} ( re`{sep}` {elm} )*'.format(elm=self.element, + sep=self.separator.pattern) def collect(self, values, collection: dict) -> dict: for value in values: @@ -227,5 +235,7 @@ def collect(self, value, collection: dict) -> dict: return super().collect(value, collection) -def OptionalGroup(tokens: list[BNFToken], bind: str | None = None, debug_name: str | None = None) -> Optional: +def OptionalGroup(tokens: list[BNFToken], + bind: str | None = None, + debug_name: str | None = None) -> Optional: return Optional(Group(tokens), bind=bind, debug_name=debug_name) From 57091bc8e01d75ee92cf3cbfe78107571646f2d8 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Mon, 12 Dec 2022 21:25:01 +0000 Subject: [PATCH 09/65] [parser] improve error logging and add docstrings to methods --- xdsl/parser_ng.py | 389 +++++++++++++++++++++++++++++----------------- xdsl/utils/bnf.py | 18 +-- 2 files changed, 252 insertions(+), 155 deletions(-) diff --git a/xdsl/parser_ng.py b/xdsl/parser_ng.py index 54f3863a3a..230aaaf8ac 100644 --- a/xdsl/parser_ng.py +++ b/xdsl/parser_ng.py @@ -10,6 +10,7 @@ from typing import Any, TypeVar, Iterable, Literal, Optional from enum import Enum +from .printer import Printer from xdsl.ir import (SSAValue, Block, Callable, Attribute, Operation, Region, BlockArgument, MLContext, ParametrizedAttribute) @@ -50,16 +51,24 @@ class BacktrackingHistory: error: ParseError parent: BacktrackingHistory | None region_name: str | None + pos: int def print_unroll(self, file=sys.stderr): if self.parent: self.parent.print_unroll(file) - print("Aborted parsing of {} because failure at:".format( - self.region_name or ''), + print("Parsing of {} failed:".format(self.region_name or ''), file=file) self.error.print_pretty(file=file, print_history=False) + def get_farthest_point(self) -> int: + """ + Find the farthest this history managed to parse + """ + if self.parent: + return max(self.pos, self.parent.get_farthest_point()) + return self.pos + class BacktrackingAbort(Exception): reason: str | None @@ -261,19 +270,21 @@ def resume_from(self, save: save_t): @contextlib.contextmanager def backtracking(self, region_name: str | None = None): """ - Used to create backtracking parsers. You can wrap you parse code into - - with tokenizer.backtracking(): - # do some stuff - assert x == 'array' + This context manager can be used to mark backtracking regions. - All exceptions triggered in the body will abort the parsing attempt, but not escape further. + When an error is thrown during backtracking, it is recorded and stored together + with some meta information in the history attribute. - The tokenizer state will not change. + The backtracker accepts the following exceptions: + - ParseError: signifies that the region could not be parsed because of (unexpected) syntax errors + - BacktrackingAbort: signifies that backtracking was aborted, not necessarily indicating a syntax error + - AssertionError: this error should probably be phased out in favour of the two above + - EOFError: signals that EOF was reached unexpectedly - When backtracking occurred, the backtracker will save the last exception in last_error + Any other error will be printed to stderr, but backtracking will continue as normal. """ save = self.save() + starting_position = self.pos try: yield # clear error history when something doesn't fail @@ -281,95 +292,125 @@ def backtracking(self, region_name: str | None = None): # if a backtracking() completes without failre, something has been parsed (we assume) self.history = None except Exception as ex: + how_far_we_got = self.pos + + # AssertionErrors act upon the consumed token, this means we only go to the start of the token if isinstance(ex, BacktrackingAbort): - self.history = BacktrackingHistory( - ParseError( - self.next_token(peek=True), - 'Backtracking aborted: {}'.format( - ex.reason or 'unknown reason')), self.history, - region_name) - elif isinstance(ex, AssertionError): - reason = [ - 'Generic assertion failure', - *(reason for reason in ex.args if isinstance(reason, str)) - ] - # we assume that assertions fail because of the last read-in token - if len(reason) == 1: - tb = StringIO() - traceback.print_exc(file=tb) - reason[0] += '\n' + tb.getvalue() - - self.history = BacktrackingHistory( - ParseError(self.last_token, reason[-1]), self.history, - region_name) - elif isinstance(ex, ParseError): - self.history = BacktrackingHistory(ex, self.history, - region_name) - elif isinstance(ex, EOFError): - self.history = BacktrackingHistory( - ParseError(self.last_token, "Encountered EOF"), - self.history, region_name) - else: - self.history = BacktrackingHistory( - ParseError(self.last_token, - "Unexpected exception: {}".format(ex)), - self.history, region_name) - print("Warning: Unexpected error in backtracking: {}".format( - repr(ex))) - raise ex + # TODO: skip space as well + how_far_we_got -= self.last_token.len + + # if we have no error history, start recording! + if not self.history: + self.history = self.history_entry_from_exception( + ex, region_name, how_far_we_got) + + # if we got further than on previous attempts + elif how_far_we_got > self.history.get_farthest_point(): + # throw away history + self.history = None + # generate new history entry, + self.history = self.history_entry_from_exception( + ex, region_name, how_far_we_got) + + # otherwise, add to exception, if we are in a named region + elif region_name is not None and how_far_we_got - starting_position > 0: + self.history = self.history_entry_from_exception( + ex, region_name, how_far_we_got) + self.resume_from(save) - def next_token(self, - start: int | None = None, - skip: int = 0, - peek: bool = False, - include_comments: bool = False) -> Span: + def history_entry_from_exception(self, ex: Exception, region: str, + pos: int) -> BacktrackingHistory: + """ + Given an exception generated inside a backtracking attempt, + generate a BacktrackingHistory object with the relevant information in it. + + If an unexpected exception type is encountered, print a traceback to stderr """ - Best effort guess at what the next token could be + if isinstance(ex, ParseError): + return BacktrackingHistory(ex, self.history, region, pos) + elif isinstance(ex, AssertionError): + reason = [ + 'Generic assertion failure', + *(reason for reason in ex.args if isinstance(reason, str)) + ] + # we assume that assertions fail because of the last read-in token + if len(reason) == 1: + tb = StringIO() + traceback.print_exc(file=tb) + reason[0] += '\n' + tb.getvalue() + + return BacktrackingHistory(ParseError(self.last_token, reason[-1]), + self.history, region, pos) + elif isinstance(ex, BacktrackingAbort): + return BacktrackingHistory( + ParseError( + self.next_token(peek=True), + 'Backtracking aborted: {}'.format(ex.reason + or 'unknown reason')), + self.history, region, pos) + elif isinstance(ex, EOFError): + return BacktrackingHistory( + ParseError(self.last_token, "Encountered EOF"), self.history, + region, pos) + + print("Warning: Unexpected error in backtracking:", file=sys.stderr) + traceback.print_exception(ex, file=sys.stderr) + + return BacktrackingHistory( + ParseError(self.last_token, "Unexpected exception: {}".format(ex)), + self.history, region, pos) + + def next_token(self, start: int | None = None, peek: bool = False) -> Span: + """ + Return a Span of the next token, according to the self.break_on rules. + + Can be modified using: + + - start: don't start at the current tokenizer position, instead start here (useful for skipping comments, etc) + - peek: don't advance the position, only "peek" at the input + + This will skip over line comments. Meaning it will skip the entire line if it encounters '//' """ i = self.next_pos(start) - while skip > 0: - # skip whitespace if able - i = self.next_pos(self._find_token_end(i)) - skip -= 1 - # advance to the next position + # construct the span: + span = Span(i, self._find_token_end(i), self.input) + # advance pointer if not peeking if not peek: - self.pos = self._find_token_end(i) - - span = self.span_of(i, self._find_token_end(i)) - if not include_comments and span.text == '//': - while self.input.at(i) != '\n': - i += 1 - return self.next_token(i, 0, peek, include_comments) + self.pos = span.end # save last token self.last_token = span return span def next_token_of_pattern(self, - pattern: re.Pattern, + pattern: re.Pattern | str, peek: bool = False) -> Span | None: """ Return a span that matched the pattern, or nothing. You can choose not to consume the span. """ start = self.next_pos() + + # handle search for string literal + if isinstance(pattern, str): + if self.starts_with(pattern): + if not peek: + self.pos = start + len(pattern) + return Span(start, start + len(pattern), self.input) + return None + + # handle regex logic match = pattern.match(self.input.content, start) if match is None: return None + if not peek: self.pos = match.end() + # save last token - self.last_token = self.span_of(start, match.end()) + self.last_token = Span(start, match.end(), self.input) return self.last_token - def jump_back_to(self, span: Span): - """ - This can be used to "rewind" the tokenizer back to the point right before you consumed the token. - - This leaves everything except the position untouched - """ - self.pos = span.start - def consume_peeked(self, peeked_span: Span): if peeked_span.start != self.next_pos(): raise ParseError(peeked_span, "This is not the peeked span!") @@ -392,6 +433,8 @@ def _find_token_end(self, start: int | None = None) -> int: def next_pos(self, i: int | None = None) -> int: """ Find the next starting position (optionally starting from i), considering ignore_whitespaces + + This will skip line comments! """ i = self.pos if i is None else i # skip whitespaces @@ -405,22 +448,19 @@ def next_pos(self, i: int | None = None) -> int: return i def is_eof(self): + """ + Check if the end of the input was reached. + """ try: - i = self.pos - while self.input.at(i).isspace(): - i += 1 - return False + self.next_pos() except EOFError: return True - def span_of(self, start: int, end: int) -> Span: - return Span(start, end, self.input) - def consume_opt_whitespace(self) -> Span: start = self.pos while self.input.at(self.pos).isspace(): self.pos += 1 - return self.span_of(start, self.pos) + return Span(start, self.pos, self.input) @contextlib.contextmanager def configured(self, @@ -450,6 +490,12 @@ def configured(self, self.break_on = save[1] self.ignore_whitespace = save[2] + def starts_with(self, text: str | re.Pattern): + start = self.next_pos() + if isinstance(text, re.Pattern): + return text.match(self.input.content, start) is None + return self.input.content.startswith(text, start) + class ParserCommons: """ @@ -506,6 +552,7 @@ class BNF: BNF.Literal('('), BNF.ListOf(BNF.Nonterminal('region'), bind='regions', + debug_name="regions", allow_empty=False), BNF.Literal(')') ], @@ -543,7 +590,7 @@ class MlirParser: You can turn a try_ into a must_ by using expect(try_parse_..., error_msg) - You can turn a must_ into a try_ by wrapping it inside of a tokenizer.backtracking() + You can turn a must_ into a try_ by wrapping it in tokenizer.backtracking() must_ type parsers are preferred because they are explicit about their failure modes. """ @@ -584,7 +631,7 @@ def begin_parse(self): while (op := self.try_parse_operation()) is not None: ops.append(op) if not self.tokenizer.is_eof(): - self.raise_error("Unfinished business!") + self.raise_error("Could not parse entire input!") return ops def must_parse_block(self) -> Block: @@ -605,40 +652,41 @@ def must_parse_block(self) -> Block: return block - def must_parse_optional_block_label(self): - next_id = self.try_parse_block_id() + def must_parse_optional_block_label( + self) -> tuple[Span, list[int, tuple[Span, Attribute]]]: + block_id = self.try_parse_block_id() arg_list = list() - if next_id is not None: - assert next_id.text not in self._blocks, "Blocks cannot have the same ID!" + if block_id is not None: + assert block_id.text not in self._blocks, "Blocks cannot have the same ID!" if self.tokenizer.next_token(peek=True).text == '(': - arg_list = enumerate(self.must_parse_block_arg_list()) + arg_list = list(enumerate(self.must_parse_block_arg_list())) self.must_parse_characters(':', 'Block label must end in a `:`!') - return next_id, arg_list + return block_id, arg_list def must_parse_block_arg_list(self) -> list[tuple[Span, Attribute]]: - self.assert_eq(self.tokenizer.next_token(), '(', - 'Block arguments must start with `(`') + self.must_parse_characters('(', 'Block arguments must start with `(`') args = self.must_parse_list_of(self.try_parse_value_id_and_type, "Expected ") - self.assert_eq(self.tokenizer.next_token(), ')', - 'Expected closing of block arguments!') + self.must_parse_characters(')', + 'Expected closing of block arguments!', + is_parse_error=True) return args def try_parse_single_reference(self) -> Span | None: - with self.tokenizer.backtracking(): + with self.tokenizer.backtracking('part of a reference'): self.must_parse_characters('@', "references must start with `@`") if (reference := self.try_parse_string_literal()) is not None: return reference if (reference := self.try_parse_suffix_id()) is not None: return reference - raise BacktrackingAbort( + self.raise_error( "References must conform to `@` (string-literal | suffix-id)") def must_parse_reference(self) -> list[Span]: @@ -685,7 +733,7 @@ def must_parse_list_of(self, return items self.raise_error(error_msg + ' because was able to match next separator {}' - .format(match)) + .format(match.text)) items.append(next_item) return items @@ -723,11 +771,11 @@ def try_parse_boolean_literal(self) -> Span | None: ParserCommons.boolean_literal) def try_parse_value_id_and_type(self) -> tuple[Span, Attribute] | None: - with self.tokenizer.backtracking(): + with self.tokenizer.backtracking("value id and type"): value_id = self.try_parse_value_id() if value_id is None: - raise BacktrackingAbort("Invalid value-id format!") + self.raise_error("Invalid value-id format!") self.must_parse_characters( ':', 'Expected expression (value-id `:` type)') @@ -735,7 +783,7 @@ def try_parse_value_id_and_type(self) -> tuple[Span, Attribute] | None: type = self.try_parse_type() if type is None: - raise BacktrackingAbort("Expected type of value-id here!") + self.raise_error("Expected type of value-id here!") return value_id, type def try_parse_type(self) -> Attribute | None: @@ -748,7 +796,7 @@ def try_parse_type(self) -> Attribute | None: def try_parse_dialect_type_or_attribute( self, kind: Literal['type', 'attr']) -> Attribute | None: - with self.tokenizer.backtracking(): + with self.tokenizer.backtracking("dialect " + kind): if kind == 'type': self.must_parse_characters( '!', "Dialect types must start with a `!`") @@ -760,7 +808,7 @@ def try_parse_dialect_type_or_attribute( ParserCommons.bare_id) if type_name is None: - raise BacktrackingAbort("Expected a type name") + self.raise_error("Expected a type name") type_def = self.ctx.get_attr(type_name.text) @@ -772,7 +820,7 @@ def try_parse_builtin_type(self) -> Attribute | None: """ parse a builtin-type like i32, index, vector etc. """ - with self.tokenizer.backtracking(): + with self.tokenizer.backtracking("builtin type"): name = self.tokenizer.next_token_of_pattern( ParserCommons.builtin_type) if name is None: @@ -820,11 +868,11 @@ def unimplemented() -> ParametrizedAttribute: if name.text not in builtin_parsers: raise ParseError(name, "Unknown builtin {}".format(name.text)) - self.assert_eq(self.tokenizer.next_token(), '<', - 'Expected parameter list here!') + self.must_parse_characters('<', 'Expected parameter list here!') res = builtin_parsers[name.text]() - self.assert_eq(self.tokenizer.next_token(), '>', - 'Expected end of parameter list here!') + self.must_parse_characters('>', + 'Expected end of parameter list here!', + is_parse_error=True) return res def must_parse_complex_attrs(self): @@ -841,8 +889,10 @@ def try_parse_numerical_dims(self, if accept_closing_bracket and self.tokenizer.next_token( peek=True).text == ']': break - self.assert_eq(self.tokenizer.next_token(), 'x', - 'Unexpected end of dimension parameters!') + self.must_parse_characters( + 'x', + 'Unexpected end of dimension parameters!', + is_parse_error=True) def must_parse_vector_attrs(self) -> AnyVectorType: # also break on 'x' characters as they are separators in dimension parameters @@ -851,16 +901,17 @@ def must_parse_vector_attrs(self) -> AnyVectorType: shape = list[int](self.try_parse_numerical_dims()) scaling_shape: list[int] | None = None - if self.tokenizer.next_token(peek=True).text == '[': - self.tokenizer.next_token() + if self.tokenizer.next_token_of_pattern('[') is not None: # we now need to parse the scalable dimensions scaling_shape = list(self.try_parse_numerical_dims()) - self.assert_eq( - self.tokenizer.next_token(), ']', - 'Expected end of scalable vector dimensions here!') - self.assert_eq( - self.tokenizer.next_token(), 'x', - 'Expected end of scalable vector dimensions here!') + self.must_parse_characters( + ']', + 'Expected end of scalable vector dimensions here!', + is_parse_error=True) + self.must_parse_characters( + 'x', + 'Expected end of scalable vector dimensions here!', + is_parse_error=True) if scaling_shape is not None: # TODO: handle scaling vectors! @@ -877,13 +928,13 @@ def must_parse_vector_attrs(self) -> AnyVectorType: def must_parse_tensor_or_memref_dims(self) -> list[int] | None: with self.tokenizer.configured(break_on=self.tokenizer.break_on + ('x', )): - if self.tokenizer.next_token(peek=True).text == '*': - # consume `*` - self.tokenizer.next_token() + # check for unranked-ness + if self.tokenizer.next_token_of_pattern('*') is not None: # consume `x` - self.assert_eq( - self.tokenizer.next_token(), 'x', - 'Unranked tensors must follow format (`<*x` type `>`)') + self.must_parse_characters( + 'x', + 'Unranked tensors must follow format (`<*x` type `>`)', + is_parse_error=True) else: # parse rank: return list(self.try_parse_numerical_dims(lower_bound=0)) @@ -965,15 +1016,15 @@ def raise_error(self, msg: str, at_position: Span | None = None): raise ParseError(at_position, msg, self.tokenizer.history) - def assert_eq(self, got: Span, want: str, msg: str): - if got.text == want: - return - raise AssertionError( - "Assertion failed (assert `{}` == `{}`): {}".format( - got.text, want, msg), got) - - def must_parse_characters(self, text: str, msg: str): - self.assert_eq(self.tokenizer.next_token(), text, msg) + def must_parse_characters(self, + text: str, + msg: str, + is_parse_error: bool = False) -> Span: + if (match := self.tokenizer.next_token_of_pattern(text)) is None: + if is_parse_error: + self.raise_error(msg) + raise AssertionError("Unexpected input: {}".format(msg)) + return match def must_parse_op_result_list( self) -> list[tuple[Span, Attribute] | Span] | None: @@ -1088,18 +1139,26 @@ def must_parse_attribute(self) -> Attribute: builtin_val = self.try_parse_builtin_attr() if builtin_val is None: - self.raise_error("Unknown attribute!") + self.raise_error( + "Unknown attribute (neither builtin nor dialect could be parsed)!" + ) return builtin_val def must_parse_attribute_type(self) -> Attribute: + """ + Parses `:` type and returns the type + """ self.must_parse_characters( - ':', 'Expected attribute type definition here ( `:` type )') + ':', 'Expected attribute type definition here ( `:` type )') return self.expect( self.try_parse_type, - 'Expected attribute type definition here ( `:` type )') + 'Expected attribute type definition here ( `:` type )') def try_parse_builtin_attr(self) -> Attribute: + """ + Tries to parse a bultin attribute, e.g. a string literal, int, array, etc.. + """ attrs = (self.try_parse_builtin_float_attr, self.try_parse_builtin_int_attr, self.try_parse_builtin_str_attr, @@ -1126,10 +1185,11 @@ def try_parse_builtin_int_attr(self) -> IntegerAttr | None: return IntegerAttr.from_params(int(value.text), type) def try_parse_builtin_float_attr(self) -> FloatAttr | None: - with self.tokenizer.backtracking(): + with self.tokenizer.backtracking("float literal"): value = self.expect( self.try_parse_float_literal, 'Float attribute must start with a float literal!') + # if we don't see a ':' indicating a type signature if self.tokenizer.next_token(peek=True).text != ':': return FloatAttr.from_value(float(value.text)) @@ -1149,7 +1209,7 @@ def try_parse_builtin_str_attr(self): if self.tokenizer.next_token(peek=True).text != '"': return None - with self.tokenizer.backtracking(): + with self.tokenizer.backtracking("string literal"): literal = self.try_parse_string_literal() if self.tokenizer.next_token(peek=True).text != ':': return StringAttr.from_str(literal.string_contents) @@ -1158,7 +1218,7 @@ def try_parse_builtin_str_attr(self): def try_parse_builtin_arr_attr(self) -> list[Attribute] | None: if self.tokenizer.next_token(peek=True).text != '[': return None - with self.tokenizer.backtracking(): + with self.tokenizer.backtracking("array literal"): self.must_parse_characters('[', 'Array literals must start with `[`') attrs = self.must_parse_list_of(self.try_parse_builtin_attr, @@ -1191,8 +1251,7 @@ def try_parse_attr_dict(self) -> dict[str, Attribute] | None: ParserCommons.BNF.attr_dict.collect(res, dict()).get( 'attributes', list())) - def must_parse_function_type( - self) -> tuple[list[Attribute], list[Attribute]]: + def must_parse_function_type(self) -> FunctionType: """ Parses function-type: @@ -1210,9 +1269,13 @@ def must_parse_function_type( '(', 'First group of function args must start with a `(`') args: list[Attribute] = self.must_parse_list_of( self.try_parse_type, 'Expected type here!') - self.must_parse_characters(')', "End of function type arguments") + self.must_parse_characters(')', + "Malformed function type!", + is_parse_error=True) - self.must_parse_characters('->', 'Function type!') + self.must_parse_characters('->', + 'Malformed function type!', + is_parse_error=True) return FunctionType.from_lists( args, self.must_parse_type_or_type_list_parens()) @@ -1225,11 +1288,12 @@ def must_parse_type_or_type_list_parens(self) -> list[Attribute]: type-list-parens ::= `(` `)` | `(` type-list-no-parens `)` type-list-no-parens ::= type (`,` type)* """ - if self.tokenizer.next_token(peek=True).text == '(': - self.must_parse_characters('(', 'Function type!') + if self.tokenizer.next_token_of_pattern('(') is not None: args: list[Attribute] = self.must_parse_list_of( self.try_parse_type, 'Expected type here!') - self.must_parse_characters(')', "End of function type args") + self.must_parse_characters(')', + "Unclosed function type argument list!", + is_parse_error=True) else: args = [self.try_parse_type()] if args[0] is None: @@ -1241,7 +1305,7 @@ def must_parse_type_or_type_list_parens(self) -> list[Attribute]: def try_parse_function_type(self) -> FunctionType | None: if self.tokenizer.next_token(peek=True).text != '(': return None - with self.tokenizer.backtracking('Function type'): + with self.tokenizer.backtracking('function type'): return self.must_parse_function_type() @@ -1313,3 +1377,38 @@ def try_parse_function_type(self) -> FunctionType | None: type-alias-def ::= '!' alias-name '=' type type-alias ::= '!' alias-name """ + +if __name__ == '__main__': + infile = sys.argv[-1] + from xdsl.dialects.affine import Affine + from xdsl.dialects.arith import Arith + from xdsl.dialects.builtin import Builtin + from xdsl.dialects.cf import Cf + from xdsl.dialects.cmath import CMath + from xdsl.dialects.func import Func + from xdsl.dialects.irdl import IRDL + from xdsl.dialects.llvm import LLVM + from xdsl.dialects.memref import MemRef + from xdsl.dialects.scf import Scf + import os + + ctx = MLContext() + ctx.register_dialect(Builtin) + ctx.register_dialect(Func) + ctx.register_dialect(Arith) + ctx.register_dialect(MemRef) + ctx.register_dialect(Affine) + ctx.register_dialect(Scf) + ctx.register_dialect(Cf) + ctx.register_dialect(CMath) + ctx.register_dialect(IRDL) + ctx.register_dialect(LLVM) + + p = MlirParser(infile, open(infile, 'r').read(), ctx) + + printer = Printer() + try: + for op in p.begin_parse(): + printer.print_op(op) + except ParseError as pe: + pe.print_pretty() diff --git a/xdsl/utils/bnf.py b/xdsl/utils/bnf.py index 514d333920..c72e6aaf5a 100644 --- a/xdsl/utils/bnf.py +++ b/xdsl/utils/bnf.py @@ -22,7 +22,7 @@ def must_parse(self, parser: MlirParser) -> T: raise NotImplemented() def try_parse(self, parser: MlirParser) -> T | None: - with parser.tokenizer.backtracking(self.debug_name or repr(self)): + with parser.tokenizer.backtracking(self.debug_name): return self.must_parse(parser) def collect(self, value, collection: dict) -> dict: @@ -85,17 +85,15 @@ class Nonterminal(BNFToken): debug_name: str | None = field(kw_only=True, default=None) + def parser_func_name(self, prefix: str): + return prefix + self.name.replace('-', '_') + def must_parse(self, parser: MlirParser): - if hasattr(parser, 'must_parse_{}'.format(self.name.replace('-', - '_'))): - return getattr(parser, - 'must_parse_{}'.format(self.name.replace('-', - '_')))() - elif hasattr(parser, 'try_parse_{}'.format(self.name.replace('-', - '_'))): + if hasattr(parser, self.parser_func_name('must_parse_')): + return getattr(parser, self.parser_func_name('must_parse_'))() + elif hasattr(parser, self.parser_func_name('try_parse_')): return parser.expect( - getattr(parser, - 'try_parse_{}'.format(self.name.replace('-', '_'))), + getattr(parser, self.parser_func_name('try_parse_')), 'Expected to parse {} here!'.format(self.name)) else: raise NotImplementedError("Parser cannot parse {}".format( From d416eff6f9264ea541c852374d23bf59f983f5e1 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Fri, 6 Jan 2023 11:40:31 +0100 Subject: [PATCH 10/65] [parser] Add basics for parsing xDSL --- xdsl/parser_ng.py | 297 +++++++++++++++++++++++++++++++++++----------- xdsl/utils/bnf.py | 6 +- 2 files changed, 227 insertions(+), 76 deletions(-) diff --git a/xdsl/parser_ng.py b/xdsl/parser_ng.py index 230aaaf8ac..b24cea0643 100644 --- a/xdsl/parser_ng.py +++ b/xdsl/parser_ng.py @@ -35,16 +35,23 @@ def __init__(self, span: Span, msg: str, history: BacktrackingHistory | None = None): + if span is None: + print(span, msg) + print("Huston, we have a problem") + super().__init__("None None None FUCK None NONE") + return super().__init__(span.print_with_context(msg)) self.span = span self.msg = msg self.history = history - def print_pretty(self, file=sys.stderr, print_history: bool = True): - if self.history and print_history: - self.history.print_unroll(file) + def print_pretty(self, file=sys.stderr): print(self.span.print_with_context(self.msg), file=file) + def print_with_history(self): + if self.history is not None: + self.history.print_unroll() + @dataclass class BacktrackingHistory: @@ -59,7 +66,7 @@ def print_unroll(self, file=sys.stderr): print("Parsing of {} failed:".format(self.region_name or ''), file=file) - self.error.print_pretty(file=file, print_history=False) + self.error.print_pretty(file=file) def get_farthest_point(self) -> int: """ @@ -86,7 +93,7 @@ def __init__(self, reason: str | None = None): @dataclass(frozen=True) class Span: """ - Parts of the input are always passed around as spans so we know where they originated. + Parts of the input are always passed around as spans, so we know where they originated. """ start: int @@ -242,7 +249,7 @@ class Tokenizer: break_on: tuple[str, ...] = ('.', '%', ' ', '(', ')', '[', ']', '{', '}', '<', '>', ':', '=', '@', '?', '|', '->', '-', - '//', '\n', '\t', '#', '"', "'", ',') + '//', '\n', '\t', '#', '"', "'", ',', '!') """ characters the tokenizer should break on """ @@ -290,7 +297,8 @@ def backtracking(self, region_name: str | None = None): # clear error history when something doesn't fail # this is because we are only interested in the last "cascade" of failures. # if a backtracking() completes without failre, something has been parsed (we assume) - self.history = None + if self.pos > starting_position and self.history is not None: + self.history = None except Exception as ex: how_far_we_got = self.pos @@ -340,25 +348,25 @@ def history_entry_from_exception(self, ex: Exception, region: str, traceback.print_exc(file=tb) reason[0] += '\n' + tb.getvalue() - return BacktrackingHistory(ParseError(self.last_token, reason[-1]), + return BacktrackingHistory(ParseError(self.last_token, reason[-1], self.history), self.history, region, pos) elif isinstance(ex, BacktrackingAbort): return BacktrackingHistory( ParseError( self.next_token(peek=True), 'Backtracking aborted: {}'.format(ex.reason - or 'unknown reason')), + or 'unknown reason'), self.history), self.history, region, pos) elif isinstance(ex, EOFError): return BacktrackingHistory( - ParseError(self.last_token, "Encountered EOF"), self.history, + ParseError(self.last_token, "Encountered EOF", self.history), self.history, region, pos) print("Warning: Unexpected error in backtracking:", file=sys.stderr) traceback.print_exception(ex, file=sys.stderr) return BacktrackingHistory( - ParseError(self.last_token, "Unexpected exception: {}".format(ex)), + ParseError(self.last_token, "Unexpected exception: {}".format(ex), self.history), self.history, region, pos) def next_token(self, start: int | None = None, peek: bool = False) -> Span: @@ -490,7 +498,7 @@ def configured(self, self.break_on = save[1] self.ignore_whitespace = save[2] - def starts_with(self, text: str | re.Pattern): + def starts_with(self, text: str | re.Pattern) -> bool: start = self.next_pos() if isinstance(text, re.Pattern): return text.match(self.input.content, start) is None @@ -502,16 +510,16 @@ class ParserCommons: Colelction of common things used in parsing MLIR/IRDL """ - integer_literal = re.compile(r'[+-]?([0-9]+|0x[0-9A-f]+)') + integer_literal = re.compile(r'[+-]?([0-9]+|0x[0-9A-Fa-f]+)') decimal_literal = re.compile(r'[+-]?([1-9][0-9]*)') string_literal = re.compile(r'"([^\n\f\v\r"]|\\[nfvr"])+"') float_literal = re.compile(r'[-+]?[0-9]+\.[0-9]*([eE][-+]?[0-9]+)?') - bare_id = re.compile(r'[A-z_][A-z0-9_$.]+') - value_id = re.compile(r'%([0-9]+|([A-z_$.-][0-9A-z_$.-]*))') - suffix_id = re.compile(r'([0-9]+|([A-z_$.-][0-9A-z_$.-]*))') - block_id = re.compile(r'\^([0-9]+|([A-z_$.-][0-9A-z_$.-]*))') - type_alias = re.compile(r'![A-z_][A-z0-9_$.]+') - attribute_alias = re.compile(r'#[A-z_][A-z0-9_$.]+') + bare_id = re.compile(r'[A-Za-z_][\w$.]+') + value_id = re.compile(r'%([0-9]+|([A-Za-z_$.-][\w$.-]*))') + suffix_id = re.compile(r'([0-9]+|([A-Za-z_$.-][\w$.-]*))') + block_id = re.compile(r'\^([0-9]+|([A-Za-z_$.-][\w$.-]*))') + type_alias = re.compile(r'![A-Za-z_][\w$.]+') + attribute_alias = re.compile(r'#[A-Za-z_][\w$.]+') boolean_literal = re.compile(r'(true|false)') builtin_type = re.compile('(({}))'.format(')|('.join(( r'[su]?i\d+', @@ -525,6 +533,18 @@ class ParserCommons: 'index', # TODO: add all the Float8E4M3FNType, Float8E5M2Type, and BFloat16Type )))) + builtin_type_xdsl = re.compile('!(({}))'.format(')|('.join(( + r'[su]?i\d+', + r'f\d+', + 'tensor', + 'vector', + 'memref', + 'complex', + 'opaque', + 'tuple', + 'index', + # TODO: add all the Float8E4M3FNType, Float8E5M2Type, and BFloat16Type + )))) double_colon = re.compile('::') comma = re.compile(',') @@ -544,7 +564,7 @@ class BNF: BNF.ListOf(BNF.Nonterminal('block-id'), allow_empty=False, bind='blocks'), - # TODD: allow for block args here?! (accordin to spec) + # TODD: allow for block args here?! (according to spec) BNF.Literal(']') ], debug_name="operations optional block id group"), @@ -557,21 +577,30 @@ class BNF: BNF.Literal(')') ], debug_name="operation regions"), - BNF.Nonterminal('attr-dict', + BNF.Nonterminal('optional-attr-dict', bind='attributes', debug_name="attrbiute dictionary"), BNF.Literal(':'), BNF.Nonterminal('function-type', bind='type_signature') ], debug_name="generic operation body") - attr_dict = BNF.Group([ + attr_dict_mlir = BNF.Group([ BNF.Literal('{'), BNF.ListOf(BNF.Nonterminal('attribute-entry', debug_name="attribute entry"), bind='attributes'), BNF.Literal('}') ], - debug_name="attrbute dictionary") + debug_name="attrbute dictionary") + + attr_dict_xdsl = BNF.Group([ + BNF.Literal('['), + BNF.ListOf(BNF.Nonterminal('attribute-entry', + debug_name="attribute entry"), + bind='attributes'), + BNF.Literal(']') + ], + debug_name="attrbute dictionary") class MlirParser: @@ -635,14 +664,14 @@ def begin_parse(self): return ops def must_parse_block(self) -> Block: - id, args = self.must_parse_optional_block_label() + block_id, args = self.must_parse_optional_block_label() block = Block() - if id is not None: - assert id.text not in self._blocks - self._blocks[id.text] = block + if block_id is not None: + assert block_id.text not in self._blocks + self._blocks[block_id.text] = block - for i, (name, type) in args: + for i, (name, type) in enumerate(args): arg = BlockArgument(type, block, i) self._ssaValues[name.text] = arg block.args.append(arg) @@ -653,7 +682,7 @@ def must_parse_block(self) -> Block: return block def must_parse_optional_block_label( - self) -> tuple[Span, list[int, tuple[Span, Attribute]]]: + self) -> tuple[Span | None, list[tuple[Span, Attribute]]]: block_id = self.try_parse_block_id() arg_list = list() @@ -661,7 +690,7 @@ def must_parse_optional_block_label( assert block_id.text not in self._blocks, "Blocks cannot have the same ID!" if self.tokenizer.next_token(peek=True).text == '(': - arg_list = list(enumerate(self.must_parse_block_arg_list())) + arg_list = self.must_parse_block_arg_list() self.must_parse_characters(':', 'Block label must end in a `:`!') @@ -671,7 +700,7 @@ def must_parse_block_arg_list(self) -> list[tuple[Span, Attribute]]: self.must_parse_characters('(', 'Block arguments must start with `(`') args = self.must_parse_list_of(self.try_parse_value_id_and_type, - "Expected ") + "Expected value-id and type here!") self.must_parse_characters(')', 'Expected closing of block arguments!', @@ -795,22 +824,35 @@ def try_parse_type(self) -> Attribute | None: return None def try_parse_dialect_type_or_attribute( - self, kind: Literal['type', 'attr']) -> Attribute | None: + self, kind: Literal['type', 'attr', + 'type or attr']) -> Attribute | None: with self.tokenizer.backtracking("dialect " + kind): - if kind == 'type': - self.must_parse_characters( - '!', "Dialect types must start with a `!`") - else: - self.must_parse_characters( - '#', "Dialect attributes must start with a `#`") + # check to see if we get the expected qualifiers (! for type, # for attr) + detected_kind = None + + if (match := self.tokenizer.next_token_of_pattern( + re.compile('[!#]'))) is not None: + detected_kind = {'!': 'type', '#': 'attr'}[match.text] + + if detected_kind is None or detected_kind not in kind: + self.raise_error( + "Dialect {} must start with {}!".format( + kind, { + 'type': '`!`', + 'attr': '`#`', + 'type or attr': 'either `!` or `#`' + }[kind]), match) type_name = self.tokenizer.next_token_of_pattern( ParserCommons.bare_id) if type_name is None: - self.raise_error("Expected a type name") + self.raise_error( + "Expected the name of the {} name".format(detected_kind)) - type_def = self.ctx.get_attr(type_name.text) + type_def = self.ctx.get_optional_attr(type_name.text) + if type_def is None: + self.raise_error("'{}' is not a know attribute!".format(type_name.text), type_name) # pass the task of parsing parameters on to the attribute/type definition param_list = type_def.parse_parameters(self) @@ -821,10 +863,17 @@ def try_parse_builtin_type(self) -> Attribute | None: parse a builtin-type like i32, index, vector etc. """ with self.tokenizer.backtracking("builtin type"): - name = self.tokenizer.next_token_of_pattern( - ParserCommons.builtin_type) + pattern = ParserCommons.builtin_type + if self.accent == MlirParser.Accent.XDSL: + pattern = ParserCommons.builtin_type_xdsl + name = self.tokenizer.next_token_of_pattern(pattern) if name is None: raise BacktrackingAbort("Expected builtin name!") + # if we are parsing xDSL, we have to skip the leading '!' + if self.accent == MlirParser.Accent.XDSL: + name = Span(start=name.start + 1, + end=name.end, + input=name.input) if name.text == 'index': return IndexType() if (re_match := re.match(r'^[su]?i(\d+)$', name.text)) is not None: @@ -854,8 +903,8 @@ def must_parse_builtin_parametrized_type( self, name: Span) -> ParametrizedAttribute: def unimplemented() -> ParametrizedAttribute: - raise ParseError(self.tokenizer.next_token(), - "Type not supported yet!") + raise ParseError(name, + "Builtin {} not supported yet!".format(name.text)) builtin_parsers: dict[str, Callable[[], ParametrizedAttribute]] = { 'vector': self.must_parse_vector_attrs, @@ -865,11 +914,10 @@ def unimplemented() -> ParametrizedAttribute: 'opaque': unimplemented, 'tuple': unimplemented, } - if name.text not in builtin_parsers: - raise ParseError(name, "Unknown builtin {}".format(name.text)) self.must_parse_characters('<', 'Expected parameter list here!') - res = builtin_parsers[name.text]() + # get the parser for the type, falling back to the unimplemented warning + res = builtin_parsers.get(name.text, unimplemented)() self.must_parse_characters('>', 'Expected end of parameter list here!', is_parse_error=True) @@ -1033,7 +1081,7 @@ def must_parse_op_result_list( (MlirParser.Accent.XDSL, self.try_parse_value_id_and_type))))[self.accent] - return self.must_parse_list_of(self.try_parse_value_id, + return self.must_parse_list_of(inner_parser, 'Expected op-result here!', allow_empty=False) @@ -1048,7 +1096,17 @@ def try_parse_operation(self) -> Operation | None: else: result_list = [] - generic_op = ParserCommons.BNF.generic_operation_body.try_parse( + if not self.tokenizer.starts_with('"'): + # parse custom op: + custom_op_name = self.try_parse_bare_id() + if custom_op_name is None: + self.raise_error( + "Expected an operation name here, either a bare-id, or a string literal!" + ) + op_type = self.ctx.get_op(custom_op_name.text) + return op_type.parse([type for _, type in result_list], self) + + generic_op = ParserCommons.BNF.generic_operation_body.must_parse( self) if generic_op is None: self.raise_error("custom operations not supported as of yet!") @@ -1068,11 +1126,14 @@ def try_parse_operation(self) -> Operation | None: op_type = self.ctx.get_op(values['name'].string_contents) return op_type.create( - [self._ssaValues[arg.text] for arg in values['args']], - ret_types, values['attributes'], [ + operands=[self._ssaValues[arg.text] for arg in values['args']], + result_types=ret_types, + attributes=values['attributes'], + successors=[ self._blocks[block_name.text] for block_name in values.get('blocks', []) - ], values.get('regions', [])) + ], + regions=values.get('regions', [])) def must_parse_region(self) -> Region: oldSSAVals = self._ssaValues.copy() @@ -1129,12 +1190,17 @@ def must_parse_attribute(self) -> Attribute: Parse attribute (either builtin or dialect) """ # all dialect attrs must start with '#', so we check for that first (as it's easier) - if self.tokenizer.next_token(peek=True).text == '#': - value = self.try_parse_dialect_type_or_attribute('attr') - if value is None: + if self.tokenizer.next_token(peek=True).text in '#!': + # in MLIR, # and ! are prefixes for dialext types/attrs + value = self.try_parse_dialect_type_or_attribute('type or attr') + # if we are in MLIR, and we get nothing, that's an error! + if not self.is_xdsl() and value is None: # if is_mlir and value is none, raise error self.raise_error( - '`#` must be followed by a valid builtin attribute!') - return value + '`#` or `!` must be followed by a valid dialect attribute or type!' + ) + # otherwise, if we have a value, return it. + if value is not None: + return value builtin_val = self.try_parse_builtin_attr() @@ -1158,7 +1224,15 @@ def must_parse_attribute_type(self) -> Attribute: def try_parse_builtin_attr(self) -> Attribute: """ Tries to parse a bultin attribute, e.g. a string literal, int, array, etc.. + + If the mode is xDSL, it also allows parsing of builtin types """ + # in xdsl, two things are different here: + # 1. types are considered valid attributes + # 2. all types, builtins included, are prefixed with ! + if self.is_xdsl() and self.tokenizer.starts_with('!'): + return self.try_parse_builtin_type() + attrs = (self.try_parse_builtin_float_attr, self.try_parse_builtin_int_attr, self.try_parse_builtin_str_attr, @@ -1221,20 +1295,29 @@ def try_parse_builtin_arr_attr(self) -> list[Attribute] | None: with self.tokenizer.backtracking("array literal"): self.must_parse_characters('[', 'Array literals must start with `[`') - attrs = self.must_parse_list_of(self.try_parse_builtin_attr, + attrs = self.must_parse_list_of(self.must_parse_attribute, 'Expected array entry!') self.must_parse_characters( ']', 'Array literals must be enclosed by square brackets!') return ArrayAttr.from_list(attrs) - def must_parse_attr_dict(self) -> dict[str, Attribute]: - res = ParserCommons.BNF.attr_dict.try_parse(self) - if res is None: + def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: + tree, prefix = { + MlirParser.Accent.MLIR: (ParserCommons.BNF.attr_dict_mlir, '{'), + MlirParser.Accent.XDSL: (ParserCommons.BNF.attr_dict_xdsl, '[') + }[self.accent] + + if self.tokenizer.next_token_of_pattern(prefix, peek=True) is None: return dict() + + res = tree.must_parse(self) + return self.attr_dict_from_tuple_list( - ParserCommons.BNF.attr_dict.collect(res, dict()).get( + ParserCommons.BNF.attr_dict_mlir.collect(res, dict()).get( 'attributes', list())) + + def attr_dict_from_tuple_list( self, tuple_list: list[tuple[Span, Attribute]]) -> dict[str, Attribute]: @@ -1243,14 +1326,6 @@ def attr_dict_from_tuple_list( ) else span.text), attr) for span, attr in tuple_list) - def try_parse_attr_dict(self) -> dict[str, Attribute] | None: - res = ParserCommons.BNF.attr_dict.try_parse(self) - if res is None: - return None - return self.attr_dict_from_tuple_list( - ParserCommons.BNF.attr_dict.collect(res, dict()).get( - 'attributes', list())) - def must_parse_function_type(self) -> FunctionType: """ Parses function-type: @@ -1259,6 +1334,7 @@ def must_parse_function_type(self) -> FunctionType: (i32) -> () () -> (i32, i32) (i32, i32) -> () + () -> i32 Non-viable types are: i32 -> i32 i32 -> () @@ -1308,6 +1384,78 @@ def try_parse_function_type(self) -> FunctionType | None: with self.tokenizer.backtracking('function type'): return self.must_parse_function_type() + def must_parse_region_list(self) -> list[Region]: + """ + Parses a sequence of regions for as long as there is a `{` in the input. + """ + regions = [] + while self.tokenizer.next_token(peek=True).text == '{': + regions.append(self.must_parse_region()) + return regions + + def is_xdsl(self) -> bool: + return self.accent == MlirParser.Accent.XDSL + + # HERE STARTS A SOMEWHAT CURSED COMPATIBILITY LAYER: + # since we don't want to rewrite all dialects currently, the new emulator needs to expose the same + # interface to the dialect definitions. Here we implement that interface. + + _OperationType = TypeVar('_OperationType', bound=Operation) + + def parse_op_with_default_format( + self, + op_type: type[_OperationType], + result_types: list[Attribute], + skip_white_space: bool = True) -> _OperationType: + """ + Compatibility wrapper so the new parser can be passed instead of the old one. Parses everything after the + operation name. + + This implicitly assumes XDSL format, and will fail on MLIR style operations + """ + # TODO: remove this function and restructure custom op / irdl parsing + assert self.accent == MlirParser.Accent.XDSL, "Function parse_op_with_default_format requires xDSL format" + + args = self.must_parse_block_arg_list() + successors: list[Span] = [] + if self.tokenizer.next_token_of_pattern('(') is not None: + successors = self.must_parse_list_of(self.try_parse_block_id, + 'Malformed block-id!') + self.must_parse_characters( + ')', + 'Expected either a block id or the end of the successor list here' + ) + + attributes = self.must_parse_optional_attr_dict() + + regions = self.must_parse_region_list() + + return op_type.create( + operands=[self._ssaValues[span.text] for span, _ in args], + result_types=result_types, + attributes=attributes, + successors=[self._blocks[span.text] for span in successors], + regions=regions) + + def parse_paramattr_parameters( + self, + expect_brackets: bool = False, + skip_white_space: bool = True) -> list[Attribute]: + if self.tokenizer.next_token_of_pattern( + '<') is None and expect_brackets: + self.raise_error("Expected start attribute parameters here (`<`)!") + + res = self.must_parse_list_of(self.must_parse_attribute, + 'Expected another attribute here!') + + if self.tokenizer.next_token_of_pattern( + '>') is None and expect_brackets: + self.raise_error( + "Malformed parameter list, expected either another parameter or `>`!" + ) + + return res + """ digit ::= [0-9] @@ -1404,11 +1552,16 @@ def try_parse_function_type(self) -> FunctionType | None: ctx.register_dialect(IRDL) ctx.register_dialect(LLVM) - p = MlirParser(infile, open(infile, 'r').read(), ctx) + dialect = {'xdsl': MlirParser.Accent.XDSL, 'mlir': MlirParser.Accent.MLIR} + + p = MlirParser(infile, + open(infile, 'r').read(), + ctx, + accent=dialect.get(infile.split('.')[-1])) printer = Printer() try: for op in p.begin_parse(): printer.print_op(op) except ParseError as pe: - pe.print_pretty() + pe.print_with_history() diff --git a/xdsl/utils/bnf.py b/xdsl/utils/bnf.py index c72e6aaf5a..2614d0b522 100644 --- a/xdsl/utils/bnf.py +++ b/xdsl/utils/bnf.py @@ -100,10 +100,8 @@ def must_parse(self, parser: MlirParser): self.name)) def try_parse(self, parser: MlirParser) -> T | None: - if hasattr(parser, 'try_parse_{}'.format(self.name.replace('-', '_'))): - return getattr(parser, - 'try_parse_{}'.format(self.name.replace('-', - '_')))() + if hasattr(parser, self.parser_func_name('try_parse_')): + return getattr(parser,self.parser_func_name('try_parse_'))() return super().try_parse(parser) def __repr__(self): From 48b92a1f0e5503f676c141f253952a12f51273bd Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Mon, 9 Jan 2023 11:41:43 +0100 Subject: [PATCH 11/65] [parser] separate MLIR and xDSL parsing into two separate classes --- xdsl/parser_ng.py | 580 +++++++++++++++++++++++++++++----------------- 1 file changed, 373 insertions(+), 207 deletions(-) diff --git a/xdsl/parser_ng.py b/xdsl/parser_ng.py index b24cea0643..bba1d02116 100644 --- a/xdsl/parser_ng.py +++ b/xdsl/parser_ng.py @@ -3,6 +3,7 @@ import contextlib import sys import traceback +from abc import ABC, abstractmethod from dataclasses import dataclass, field import re import ast @@ -35,11 +36,6 @@ def __init__(self, span: Span, msg: str, history: BacktrackingHistory | None = None): - if span is None: - print(span, msg) - print("Huston, we have a problem") - super().__init__("None None None FUCK None NONE") - return super().__init__(span.print_with_context(msg)) self.span = span self.msg = msg @@ -127,7 +123,8 @@ def print_with_context(self, msg: str | None = None) -> str: along these. """ info = self.input.get_lines_containing(self) - assert info is not None + if info is None: + return "Unknown location of span {}. Error: ".format(self, msg) lines, offset_of_first_line, line_no = info # offset relative to the first line: offset = self.start - offset_of_first_line @@ -215,7 +212,7 @@ def get_lines_containing(self, next_start = source.find('\n', start) line_no += 1 # handle eof - if next_start == -1: + if next_start == -1 : return None # as long as the next newline comes before the spans start we are good if next_start < span.start: @@ -576,7 +573,7 @@ class BNF: allow_empty=False), BNF.Literal(')') ], - debug_name="operation regions"), + debug_name="operation regions"), BNF.Nonterminal('optional-attr-dict', bind='attributes', debug_name="attrbiute dictionary"), @@ -591,7 +588,7 @@ class BNF: bind='attributes'), BNF.Literal('}') ], - debug_name="attrbute dictionary") + debug_name="attrbute dictionary") attr_dict_xdsl = BNF.Group([ BNF.Literal('['), @@ -600,10 +597,10 @@ class BNF: bind='attributes'), BNF.Literal(']') ], - debug_name="attrbute dictionary") + debug_name="attrbute dictionary") -class MlirParser: +class BaseParser(ABC): """ Basic recursive descent parser. @@ -624,17 +621,11 @@ class MlirParser: must_ type parsers are preferred because they are explicit about their failure modes. """ - class Accent(Enum): - XDSL = 'xDSL' - MLIR = 'MLIR' - - accent: Accent - ctx: MLContext """xDSL context.""" - _ssaValues: dict[str, SSAValue] - _blocks: dict[str, Block] + ssaValues: dict[str, SSAValue] + blocks: dict[str, Block] T_ = TypeVar('T_') """ @@ -645,15 +636,11 @@ class Accent(Enum): def __init__(self, input: str, name: str, - ctx: MLContext, - accent: str | Accent = Accent.MLIR): + ctx: MLContext, ): self.tokenizer = Tokenizer(Input(input, name)) self.ctx = ctx - if isinstance(accent, str): - accent = MlirParser.Accent[accent] - self.accent = accent - self._ssaValues = dict() - self._blocks = dict() + self.ssaValues = dict() + self.blocks = dict() def begin_parse(self): ops = [] @@ -668,12 +655,12 @@ def must_parse_block(self) -> Block: block = Block() if block_id is not None: - assert block_id.text not in self._blocks - self._blocks[block_id.text] = block + assert block_id.text not in self.blocks + self.blocks[block_id.text] = block for i, (name, type) in enumerate(args): arg = BlockArgument(type, block, i) - self._ssaValues[name.text] = arg + self.ssaValues[name.text] = arg block.args.append(arg) while (next_op := self.try_parse_operation()) is not None: @@ -687,7 +674,7 @@ def must_parse_optional_block_label( arg_list = list() if block_id is not None: - assert block_id.text not in self._blocks, "Blocks cannot have the same ID!" + assert block_id.text not in self.blocks, "Blocks cannot have the same ID!" if self.tokenizer.next_token(peek=True).text == '(': arg_list = self.must_parse_block_arg_list() @@ -754,7 +741,7 @@ def must_parse_list_of(self, items.append(first_item) while (match := self.tokenizer.next_token_of_pattern(separator_pattern) - ) is not None: + ) is not None: next_item = try_parse() if next_item is None: # if the separator is emtpy, we are good here @@ -819,85 +806,68 @@ def try_parse_type(self) -> Attribute | None: if (builtin_type := self.try_parse_builtin_type()) is not None: return builtin_type if (dialect_type := - self.try_parse_dialect_type_or_attribute('type')) is not None: + self.try_parse_dialect_type()) is not None: return dialect_type return None - def try_parse_dialect_type_or_attribute( - self, kind: Literal['type', 'attr', - 'type or attr']) -> Attribute | None: - with self.tokenizer.backtracking("dialect " + kind): - # check to see if we get the expected qualifiers (! for type, # for attr) - detected_kind = None + def try_parse_dialect_type_or_attribute(self) -> Attribute | None: + """ + Parse a type or an attribute. + """ + kind = self.tokenizer.next_token_of_pattern(re.compile('[!#]'), peek=True) - if (match := self.tokenizer.next_token_of_pattern( - re.compile('[!#]'))) is not None: - detected_kind = {'!': 'type', '#': 'attr'}[match.text] + if kind is None: + return None - if detected_kind is None or detected_kind not in kind: - self.raise_error( - "Dialect {} must start with {}!".format( - kind, { - 'type': '`!`', - 'attr': '`#`', - 'type or attr': 'either `!` or `#`' - }[kind]), match) + with self.tokenizer.backtracking("dialect attribute or type"): + self.tokenizer.consume_peeked(kind) + if kind.text == '!': + return self.must_parse_dialect_type_or_attribute_inner('type') + else: + return self.must_parse_dialect_type_or_attribute_inner('attribute') - type_name = self.tokenizer.next_token_of_pattern( - ParserCommons.bare_id) + def try_parse_dialect_type(self): + """ + Parse a dialect type (something prefixed by `!`, defined by a dialect) + """ + if self.tokenizer.next_token_of_pattern('!', peek=True) is None: + return None + with self.tokenizer.backtracking("dialect type"): + self.tokenizer.next_token_of_pattern('!') + return self.must_parse_dialect_type_or_attribute_inner('type') - if type_name is None: - self.raise_error( - "Expected the name of the {} name".format(detected_kind)) + def try_parse_dialect_attr(self): + """ + Parse a dialect attribute (something prefixed by `#`, defined by a dialect) + """ + if self.tokenizer.next_token_of_pattern('#', peek=True) is None: + return None + with self.tokenizer.backtracking("dialect attribute"): + self.tokenizer.next_token_of_pattern('#') + return self.must_parse_dialect_type_or_attribute_inner('attribute') + + def must_parse_dialect_type_or_attribute_inner(self, kind: str): + type_name = self.tokenizer.next_token_of_pattern( + ParserCommons.bare_id) - type_def = self.ctx.get_optional_attr(type_name.text) - if type_def is None: - self.raise_error("'{}' is not a know attribute!".format(type_name.text), type_name) + if type_name is None: + self.raise_error( + "Expected dialect {} name here!".format(kind)) - # pass the task of parsing parameters on to the attribute/type definition - param_list = type_def.parse_parameters(self) - return type_def(param_list) + type_def = self.ctx.get_optional_attr(type_name.text) + if type_def is None: + self.raise_error("'{}' is not a know attribute!".format(type_name.text), type_name) + # pass the task of parsing parameters on to the attribute/type definition + param_list = type_def.parse_parameters(self) + return type_def(param_list) + + @abstractmethod def try_parse_builtin_type(self) -> Attribute | None: """ parse a builtin-type like i32, index, vector etc. """ - with self.tokenizer.backtracking("builtin type"): - pattern = ParserCommons.builtin_type - if self.accent == MlirParser.Accent.XDSL: - pattern = ParserCommons.builtin_type_xdsl - name = self.tokenizer.next_token_of_pattern(pattern) - if name is None: - raise BacktrackingAbort("Expected builtin name!") - # if we are parsing xDSL, we have to skip the leading '!' - if self.accent == MlirParser.Accent.XDSL: - name = Span(start=name.start + 1, - end=name.end, - input=name.input) - if name.text == 'index': - return IndexType() - if (re_match := re.match(r'^[su]?i(\d+)$', name.text)) is not None: - signedness = { - 's': Signedness.SIGNED, - 'u': Signedness.UNSIGNED, - 'i': Signedness.SIGNLESS - } - return IntegerType.from_width(int(re_match.group(1)), - signedness[name.text[0]]) - - if (re_match := re.match(r'^f(\d+)$', name.text)) is not None: - width = int(re_match.group(1)) - type = { - 16: Float16Type, - 32: Float32Type, - 64: Float64Type - }.get(width, None) - if type is None: - self.raise_error( - "Unsupported floating point width: {}".format(width)) - return type() - - return self.must_parse_builtin_parametrized_type(name) + raise NotImplemented("Subclasses must implement this method!") def must_parse_builtin_parametrized_type( self, name: Span) -> ParametrizedAttribute: @@ -924,14 +894,13 @@ def unimplemented() -> ParametrizedAttribute: return res def must_parse_complex_attrs(self): - type = self.try_parse_type() self.raise_error("ComplexType is unimplemented!") def try_parse_numerical_dims(self, accept_closing_bracket: bool = False, lower_bound: int = 1) -> Iterable[int]: while (shape_arg := - self.try_parse_shape_element(lower_bound)) is not None: + self.try_parse_shape_element(lower_bound)) is not None: yield shape_arg # look out for the closing bracket for scalable vector dims if accept_closing_bracket and self.tokenizer.next_token( @@ -945,7 +914,7 @@ def try_parse_numerical_dims(self, def must_parse_vector_attrs(self) -> AnyVectorType: # also break on 'x' characters as they are separators in dimension parameters with self.tokenizer.configured(break_on=self.tokenizer.break_on + - ('x', )): + ('x',)): shape = list[int](self.try_parse_numerical_dims()) scaling_shape: list[int] | None = None @@ -975,7 +944,7 @@ def must_parse_vector_attrs(self) -> AnyVectorType: def must_parse_tensor_or_memref_dims(self) -> list[int] | None: with self.tokenizer.configured(break_on=self.tokenizer.break_on + - ('x', )): + ('x',)): # check for unranked-ness if self.tokenizer.next_token_of_pattern('*') is not None: # consume `x` @@ -1074,71 +1043,67 @@ def must_parse_characters(self, raise AssertionError("Unexpected input: {}".format(msg)) return match + @abstractmethod def must_parse_op_result_list( - self) -> list[tuple[Span, Attribute] | Span] | None: - inner_parser = (dict( - ((MlirParser.Accent.MLIR, self.try_parse_value_id), - (MlirParser.Accent.XDSL, - self.try_parse_value_id_and_type))))[self.accent] - - return self.must_parse_list_of(inner_parser, - 'Expected op-result here!', - allow_empty=False) + self) -> tuple[list[Span], list[Attribute] | None]: + raise NotImplemented() def try_parse_operation(self) -> Operation | None: with self.tokenizer.backtracking("operation"): - if self.tokenizer.next_token(peek=True).text == '%': - result_list = self.must_parse_op_result_list() + + result_list, ret_types = self.must_parse_op_result_list() + if len(result_list) > 0: self.must_parse_characters( '=', 'Operation definitions expect an `=` after op-result-list!' ) - else: - result_list = [] - if not self.tokenizer.starts_with('"'): - # parse custom op: - custom_op_name = self.try_parse_bare_id() - if custom_op_name is None: + # check for custom op format + op_name = self.try_parse_bare_id() + if op_name is not None: + op_type = self.ctx.get_op(op_name.text) + op = op_type.parse(ret_types, self) + else: + # check for basic op format + op_name = self.try_parse_string_literal() + if op_name is None: self.raise_error( "Expected an operation name here, either a bare-id, or a string literal!" ) - op_type = self.ctx.get_op(custom_op_name.text) - return op_type.parse([type for _, type in result_list], self) - generic_op = ParserCommons.BNF.generic_operation_body.must_parse( - self) - if generic_op is None: - self.raise_error("custom operations not supported as of yet!") + args, successors, attrs, regions, func_type = self.must_parse_operation_details() - values = ParserCommons.BNF.generic_operation_body.collect( - generic_op, dict()) + if ret_types is None: + assert func_type is not None + ret_types = func_type.outputs.data - arg_types, ret_types = ([], []) - if 'type_signature' in values: - functype: FunctionType = values['type_signature'] - arg_types, ret_types = functype.inputs.data, functype.outputs.data + op_type = self.ctx.get_op(op_name.string_contents) - if len(ret_types) != len(result_list): - raise ParseError( - values['name'], - "Mismatch between type signature and result list for op!") - - op_type = self.ctx.get_op(values['name'].string_contents) - return op_type.create( - operands=[self._ssaValues[arg.text] for arg in values['args']], - result_types=ret_types, - attributes=values['attributes'], - successors=[ - self._blocks[block_name.text] - for block_name in values.get('blocks', []) - ], - regions=values.get('regions', [])) + op = op_type.create( + operands=[self.ssaValues[span.text] for span in args], + result_types=ret_types, + attributes=attrs, + successors=[ + self.blocks[block_name.text] + for block_name in successors + ], + regions=regions) + + # Register the result SSA value names in the parser + for (idx, res) in enumerate(result_list): + ssa_val_name = res.text + if ssa_val_name in self.ssaValues: + self.raise_error(f"SSA value {ssa_val_name} is already defined", res) + self.ssaValues[ssa_val_name] = op.results[idx] + # TODO: check name? + self.ssaValues[ssa_val_name].name = ssa_val_name + + return op def must_parse_region(self) -> Region: - oldSSAVals = self._ssaValues.copy() - oldBBNames = self._blocks.copy() - self._blocks = dict[str, Block]() + oldSSAVals = self.ssaValues.copy() + oldBBNames = self.blocks.copy() + self.blocks = dict[str, Block]() region = Region() @@ -1157,8 +1122,8 @@ def must_parse_region(self) -> Region: return region finally: - self._ssaValues = oldSSAVals - self._blocks = oldBBNames + self.ssaValues = oldSSAVals + self.blocks = oldBBNames def try_parse_op_name(self) -> Span | None: if (str_lit := self.try_parse_string_literal()) is not None: @@ -1185,31 +1150,12 @@ def must_parse_attribute_entry(self) -> tuple[Span, Attribute]: return name, self.must_parse_attribute() + @abstractmethod def must_parse_attribute(self) -> Attribute: """ Parse attribute (either builtin or dialect) """ - # all dialect attrs must start with '#', so we check for that first (as it's easier) - if self.tokenizer.next_token(peek=True).text in '#!': - # in MLIR, # and ! are prefixes for dialext types/attrs - value = self.try_parse_dialect_type_or_attribute('type or attr') - # if we are in MLIR, and we get nothing, that's an error! - if not self.is_xdsl() and value is None: # if is_mlir and value is none, raise error - self.raise_error( - '`#` or `!` must be followed by a valid dialect attribute or type!' - ) - # otherwise, if we have a value, return it. - if value is not None: - return value - - builtin_val = self.try_parse_builtin_attr() - - if builtin_val is None: - self.raise_error( - "Unknown attribute (neither builtin nor dialect could be parsed)!" - ) - - return builtin_val + raise NotImplemented() def must_parse_attribute_type(self) -> Attribute: """ @@ -1224,15 +1170,8 @@ def must_parse_attribute_type(self) -> Attribute: def try_parse_builtin_attr(self) -> Attribute: """ Tries to parse a bultin attribute, e.g. a string literal, int, array, etc.. - - If the mode is xDSL, it also allows parsing of builtin types """ - # in xdsl, two things are different here: - # 1. types are considered valid attributes - # 2. all types, builtins included, are prefixed with ! - if self.is_xdsl() and self.tokenizer.starts_with('!'): - return self.try_parse_builtin_type() - + # order here is important! attrs = (self.try_parse_builtin_float_attr, self.try_parse_builtin_int_attr, self.try_parse_builtin_str_attr, @@ -1301,22 +1240,9 @@ def try_parse_builtin_arr_attr(self) -> list[Attribute] | None: ']', 'Array literals must be enclosed by square brackets!') return ArrayAttr.from_list(attrs) + @abstractmethod def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: - tree, prefix = { - MlirParser.Accent.MLIR: (ParserCommons.BNF.attr_dict_mlir, '{'), - MlirParser.Accent.XDSL: (ParserCommons.BNF.attr_dict_xdsl, '[') - }[self.accent] - - if self.tokenizer.next_token_of_pattern(prefix, peek=True) is None: - return dict() - - res = tree.must_parse(self) - - return self.attr_dict_from_tuple_list( - ParserCommons.BNF.attr_dict_mlir.collect(res, dict()).get( - 'attributes', list())) - - + raise NotImplementedError() def attr_dict_from_tuple_list( self, tuple_list: list[tuple[Span, @@ -1393,9 +1319,6 @@ def must_parse_region_list(self) -> list[Region]: regions.append(self.must_parse_region()) return regions - def is_xdsl(self) -> bool: - return self.accent == MlirParser.Accent.XDSL - # HERE STARTS A SOMEWHAT CURSED COMPATIBILITY LAYER: # since we don't want to rewrite all dialects currently, the new emulator needs to expose the same # interface to the dialect definitions. Here we implement that interface. @@ -1414,9 +1337,8 @@ def parse_op_with_default_format( This implicitly assumes XDSL format, and will fail on MLIR style operations """ # TODO: remove this function and restructure custom op / irdl parsing - assert self.accent == MlirParser.Accent.XDSL, "Function parse_op_with_default_format requires xDSL format" - args = self.must_parse_block_arg_list() + args = self.must_parse_op_args_list() successors: list[Span] = [] if self.tokenizer.next_token_of_pattern('(') is not None: successors = self.must_parse_list_of(self.try_parse_block_id, @@ -1430,11 +1352,18 @@ def parse_op_with_default_format( regions = self.must_parse_region_list() + for x in args: + if x.text not in self.ssaValues: + self.raise_error( + "Unknown SSAValue name, known SSA Values are: {}".format(", ".join(self.ssaValues.keys())), + x + ) + return op_type.create( - operands=[self._ssaValues[span.text] for span, _ in args], + operands=[self.ssaValues[span.text] for span in args], result_types=result_types, attributes=attributes, - successors=[self._blocks[span.text] for span in successors], + successors=[self.blocks[span.text] for span in successors], regions=regions) def parse_paramattr_parameters( @@ -1456,6 +1385,244 @@ def parse_paramattr_parameters( return res + # COMMON xDSL/MLIR code: + def must_parse_builtin_type_with_name(self, name: Span): + if name.text == 'index': + return IndexType() + if (re_match := re.match(r'^[su]?i(\d+)$', name.text)) is not None: + signedness = { + 's': Signedness.SIGNED, + 'u': Signedness.UNSIGNED, + 'i': Signedness.SIGNLESS + } + return IntegerType.from_width(int(re_match.group(1)), + signedness[name.text[0]]) + + if (re_match := re.match(r'^f(\d+)$', name.text)) is not None: + width = int(re_match.group(1)) + type = { + 16: Float16Type, + 32: Float32Type, + 64: Float64Type + }.get(width, None) + if type is None: + self.raise_error( + "Unsupported floating point width: {}".format(width)) + return type() + + return self.must_parse_builtin_parametrized_type(name) + + @abstractmethod + def must_parse_operation_details(self) -> tuple[ + list[Span], list[Span], dict[str, Attribute], list[Region], FunctionType | None]: + """ + Must return a tuple consisting of: + - a list of arguments to the operation + - a list of successor names + - the attributes attached to the OP + - the regions of the op + - An optional function type. If not supplied, must_parse_op_result_list must return a second value + containing the types of the returned SSAValues + + Your implementation should make use of the following functions: + - must_parse_op_args_list + - must_parse_optional_attr_dict + - must_parse_ + """ + raise NotImplementedError() + + + def must_parse_op_args_list(self) -> list[Span]: + self.must_parse_characters('(', 'Operation args list must be enclosed by brackets!') + args = self.must_parse_list_of(self.try_parse_value_id_and_type, 'Expected another bare-id here') + self.must_parse_characters(')', 'Operation args list must be closed by a closing bracket') + # TODO: check if type is correct here! + return [name for name, _ in args] + + @abstractmethod + def must_parse_optional_successor_list(self) -> list[Span]: + pass + +class MLIRParser(BaseParser): + + def try_parse_builtin_type(self) -> Attribute | None: + """ + parse a builtin-type like i32, index, vector etc. + """ + with self.tokenizer.backtracking("builtin type"): + name = self.tokenizer.next_token_of_pattern(ParserCommons.builtin_type) + if name is None: + raise BacktrackingAbort("Expected builtin name!") + + return self.must_parse_builtin_type_with_name(name) + + def must_parse_attribute(self) -> Attribute: + """ + Parse attribute (either builtin or dialect) + """ + # all dialect attrs must start with '#', so we check for that first (as it's easier) + if self.tokenizer.next_token(peek=True).text == '#': + value = self.try_parse_dialect_attr() + + # no value => error + if value is None: + self.raise_error( + '`#` must be followed by a valid dialect attribute or type!' + ) + + return value + + # if it isn't a dialect attr, parse builtin + builtin_val = self.try_parse_builtin_attr() + + if builtin_val is None: + self.raise_error( + "Unknown attribute (neither builtin nor dialect could be parsed)!" + ) + + return builtin_val + + def must_parse_op_result_list( + self) -> tuple[list[Span], list[Attribute] | None]: + return self.must_parse_list_of(self.try_parse_value_id, + 'Expected op-result here!', + allow_empty=True), None + + def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: + if self.tokenizer.next_token_of_pattern('{', peek=True) is None: + return dict() + + res = ParserCommons.BNF.attr_dict_mlir.must_parse(self) + + return self.attr_dict_from_tuple_list( + ParserCommons.BNF.attr_dict_mlir.collect(res, dict()).get( + 'attributes', list())) + + def must_parse_operation_details(self) -> tuple[ + list[Span], list[Span], dict[str, Attribute], list[Region], FunctionType | None]: + + args = self.must_parse_op_args_list() + succ = self.must_parse_optional_successor_list() + attrs = self.must_parse_optional_attr_dict() + + regions = [] + if self.tokenizer.starts_with('('): + self.must_parse_characters('(', 'Expected brackets enclosing regions!') + regions = self.must_parse_region_list() + self.must_parse_characters(')', 'Expected brackets enclosing regions!') + self.must_parse_characters(':', 'MLIR Operation defintions must end in a function type signature!') + + func_type = self.must_parse_function_type() + + return args, succ, attrs, regions, func_type + + def must_parse_optional_successor_list(self) -> list[Span]: + if not self.tokenizer.starts_with('['): + return [] + self.must_parse_characters('[', 'Successor list is enclosed in square brackets') + successors = self.must_parse_list_of(self.try_parse_block_id, 'Expected a block-id', allow_empty=False) + self.must_parse_characters(']', 'Successor list is enclosed in square brackets') + return successors + + +class XDSLParser(BaseParser): + + def try_parse_builtin_type(self) -> Attribute | None: + """ + parse a builtin-type like i32, index, vector etc. + """ + with self.tokenizer.backtracking("builtin type"): + name = self.tokenizer.next_token_of_pattern(ParserCommons.builtin_type_xdsl) + if name is None: + raise BacktrackingAbort("Expected builtin name!") + # xdsl builtin types have a '!' prefix, we strip that out here + name = Span(start=name.start + 1, + end=name.end, + input=name.input) + + return self.must_parse_builtin_type_with_name(name) + + def must_parse_attribute(self) -> Attribute: + """ + Parse attribute (either builtin or dialect) + + xDSL allows types in places of attributes! That's why we parse types here as well + """ + value = self.try_parse_builtin_attr() + + # xDSL: Allow both # and ! prefixes, as we allow both types and attrs + if value is None and self.tokenizer.next_token(peek=True).text in '#!': + # in MLIR # and ! are prefixes for dialect attrs/types, but in xDSL ! is also used for builtin types + value = self.try_parse_dialect_type_or_attribute() + + if value is None: + self.raise_error( + "Unknown attribute (neither builtin nor dialect could be parsed)!" + ) + + return value + + def must_parse_op_result_list( + self) -> tuple[list[Span], list[Attribute] | None]: + results = self.must_parse_list_of(self.try_parse_value_id_and_type, + 'Expected (value-id `:` type) here!', + allow_empty=True) + # TODO: this is hideous, make it cleaner + # zip(*results) works, but is barely readable :/ + return [name for name, _ in results], [type for _, type in results] + + def try_parse_builtin_attr(self) -> Attribute: + """ + Tries to parse a bultin attribute, e.g. a string literal, int, array, etc.. + + If the mode is xDSL, it also allows parsing of builtin types + """ + # in xdsl, two things are different here: + # 1. types are considered valid attributes + # 2. all types, builtins included, are prefixed with ! + if self.tokenizer.starts_with('!'): + return self.try_parse_builtin_type() + + return super().try_parse_builtin_attr() + + def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: + if self.tokenizer.next_token_of_pattern('[', peek=True) is None: + return dict() + + res = ParserCommons.BNF.attr_dict_xdsl.must_parse(self) + + return self.attr_dict_from_tuple_list( + ParserCommons.BNF.attr_dict_mlir.collect(res, dict()).get( + 'attributes', list())) + + def must_parse_operation_details(self) -> tuple[ + list[Span], list[Span], dict[str, Attribute], list[Region], FunctionType | None]: + """ + Must return a tuple consisting of: + - a list of arguments to the operation + - a list of successor names + - the attributes attached to the OP + - the regions of the op + - An optional function type. If not supplied, must_parse_op_result_list must return a second value + containing the types of the returned SSAValues + + """ + args = self.must_parse_op_args_list() + succ = self.must_parse_optional_successor_list() + attrs = self.must_parse_optional_attr_dict() + regions = self.must_parse_region_list() + + return args, succ, attrs, regions, None + + def must_parse_optional_successor_list(self) -> list[Span]: + if not self.tokenizer.starts_with('['): + return [] + self.must_parse_characters('[', 'Successor list is enclosed in square brackets') + successors = self.must_parse_list_of(self.try_parse_block_id, 'Expected a block-id', allow_empty=False) + self.must_parse_characters(']', 'Successor list is enclosed in square brackets') + return successors + + """ digit ::= [0-9] @@ -1522,8 +1689,6 @@ def parse_paramattr_parameters( function-type ::= (type | type-list-parens) `->` (type | type-list-parens) -type-alias-def ::= '!' alias-name '=' type -type-alias ::= '!' alias-name """ if __name__ == '__main__': @@ -1552,12 +1717,13 @@ def parse_paramattr_parameters( ctx.register_dialect(IRDL) ctx.register_dialect(LLVM) - dialect = {'xdsl': MlirParser.Accent.XDSL, 'mlir': MlirParser.Accent.MLIR} + parses_by_file_name = {'xdsl': XDSLParser, 'mlir': MLIRParser} + + parser = parses_by_file_name[infile.split('.')[-1]] - p = MlirParser(infile, - open(infile, 'r').read(), - ctx, - accent=dialect.get(infile.split('.')[-1])) + p = parser(infile, + open(infile, 'r').read(), + ctx) printer = Printer() try: From 2940b982552dc23b151fa6f488ffd63d89d62102 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Mon, 9 Jan 2023 12:28:02 +0100 Subject: [PATCH 12/65] [parser] fixed bug in EOL handling for span context printing + order in MLIR op parser --- xdsl/parser_ng.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/xdsl/parser_ng.py b/xdsl/parser_ng.py index bba1d02116..7e0bf342ed 100644 --- a/xdsl/parser_ng.py +++ b/xdsl/parser_ng.py @@ -206,15 +206,17 @@ def get_lines_containing(self, span: Span) -> tuple[list[str], int, int] | None: # A pointer to the start of the first line start = 0 - line_no = 0 + line_no = -1 source = self.content while True: next_start = source.find('\n', start) line_no += 1 # handle eof - if next_start == -1 : - return None - # as long as the next newline comes before the spans start we are good + if next_start == -1: + if span.start > len(source): + return None + return source[start:], start, line_no + # as long as the next newline comes before the spans start we can continue if next_start < span.start: start = next_start + 1 continue @@ -1503,15 +1505,16 @@ def must_parse_operation_details(self) -> tuple[ args = self.must_parse_op_args_list() succ = self.must_parse_optional_successor_list() - attrs = self.must_parse_optional_attr_dict() regions = [] if self.tokenizer.starts_with('('): self.must_parse_characters('(', 'Expected brackets enclosing regions!') regions = self.must_parse_region_list() self.must_parse_characters(')', 'Expected brackets enclosing regions!') - self.must_parse_characters(':', 'MLIR Operation defintions must end in a function type signature!') + attrs = self.must_parse_optional_attr_dict() + + self.must_parse_characters(':', 'MLIR Operation defintions must end in a function type signature!') func_type = self.must_parse_function_type() return args, succ, attrs, regions, func_type From 33d59fa1a13526436b012cb4e9006f45d5bcc1ae Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Tue, 10 Jan 2023 14:48:38 +0000 Subject: [PATCH 13/65] [parser] now able to parse some MLIR and xDSL code successfully --- xdsl/parser_ng.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xdsl/parser_ng.py b/xdsl/parser_ng.py index 7e0bf342ed..3654dc71bb 100644 --- a/xdsl/parser_ng.py +++ b/xdsl/parser_ng.py @@ -1092,13 +1092,13 @@ def try_parse_operation(self) -> Operation | None: regions=regions) # Register the result SSA value names in the parser - for (idx, res) in enumerate(result_list): + for idx, res in enumerate(result_list): ssa_val_name = res.text if ssa_val_name in self.ssaValues: self.raise_error(f"SSA value {ssa_val_name} is already defined", res) self.ssaValues[ssa_val_name] = op.results[idx] # TODO: check name? - self.ssaValues[ssa_val_name].name = ssa_val_name + self.ssaValues[ssa_val_name].name = ssa_val_name.lstrip('%') return op @@ -1317,7 +1317,7 @@ def must_parse_region_list(self) -> list[Region]: Parses a sequence of regions for as long as there is a `{` in the input. """ regions = [] - while self.tokenizer.next_token(peek=True).text == '{': + while not self.tokenizer.is_eof() and self.tokenizer.next_token(peek=True).text == '{': regions.append(self.must_parse_region()) return regions From 881f8d0b6025ffcac10de44b44d7d1f652a57802 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Tue, 10 Jan 2023 16:18:50 +0000 Subject: [PATCH 14/65] [parser] clean up lots of unused or unnecessary parts --- xdsl/parser_ng.py | 70 ++++++++--------------------------------------- 1 file changed, 12 insertions(+), 58 deletions(-) diff --git a/xdsl/parser_ng.py b/xdsl/parser_ng.py index 3654dc71bb..df017cba74 100644 --- a/xdsl/parser_ng.py +++ b/xdsl/parser_ng.py @@ -160,10 +160,8 @@ def __post_init__(self): if len(self) < 2 or self.text[0] != '"' or self.text[-1] != '"': raise ParseError(self, "Invalid string literal!") - T_ = TypeVar('T_', Span, None) - @classmethod - def from_span(cls, span: T_) -> T_: + def from_span(cls, span: Span | None) -> StringLiteral | None: if span is None: return None return cls(span.start, span.end, span.input) @@ -215,7 +213,7 @@ def get_lines_containing(self, if next_start == -1: if span.start > len(source): return None - return source[start:], start, line_no + return [source[start:]], start, line_no # as long as the next newline comes before the spans start we can continue if next_start < span.start: start = next_start + 1 @@ -233,7 +231,7 @@ def at(self, i: int): return self.content[i] -save_t = tuple[int, tuple[str, ...], bool] +save_t = tuple[int, tuple[str, ...]] parsed_type_t = tuple[Span, tuple[Span]] @@ -253,8 +251,6 @@ class Tokenizer: characters the tokenizer should break on """ - ignore_whitespace: bool = True - history: BacktrackingHistory | None = field(init=False, default=None) last_token: Span | None = field(init=False, default=None) @@ -263,7 +259,7 @@ def save(self) -> save_t: """ Create a checkpoint in the parsing process, useful for backtracking """ - return self.pos, self.break_on, self.ignore_whitespace + return self.pos, self.break_on def resume_from(self, save: save_t): """ @@ -271,7 +267,7 @@ def resume_from(self, save: save_t): Restores the state of the tokenizer to the exact previous position """ - self.pos, self.break_on, self.ignore_whitespace = save + self.pos, self.break_on = save @contextlib.contextmanager def backtracking(self, region_name: str | None = None): @@ -439,19 +435,20 @@ def _find_token_end(self, start: int | None = None) -> int: def next_pos(self, i: int | None = None) -> int: """ - Find the next starting position (optionally starting from i), considering ignore_whitespaces + Find the next starting position (optionally starting from i) This will skip line comments! """ i = self.pos if i is None else i # skip whitespaces - if self.ignore_whitespace: - while self.input.at(i).isspace(): - i += 1 + while self.input.at(i).isspace(): + i += 1 + # skip comments as well if self.input.content.startswith('//', i): i = self.input.content.find('\n', i) + 1 return self.next_pos(i) + return i def is_eof(self): @@ -463,22 +460,14 @@ def is_eof(self): except EOFError: return True - def consume_opt_whitespace(self) -> Span: - start = self.pos - while self.input.at(self.pos).isspace(): - self.pos += 1 - return Span(start, self.pos, self.input) - @contextlib.contextmanager - def configured(self, - break_on: tuple[str, ...] | None = None, - ignore_whitespace: bool | None = None): + def configured(self, break_on: tuple[str, ...]): """ This is a helper class to allow expressing a temporary change in config, allowing you to write: # parsing double-quoted string now string_content = "" - with tokenizer.configured(break_on=('"', '\\'), ignore_whitespace=False): + with tokenizer.configured(break_on=('"', '\\'),): # use tokenizer # now old config is restored automatically @@ -488,14 +477,11 @@ def configured(self, if break_on is not None: self.break_on = break_on - if ignore_whitespace is not None: - self.ignore_whitespace = ignore_whitespace try: yield self finally: self.break_on = save[1] - self.ignore_whitespace = save[2] def starts_with(self, text: str | re.Pattern) -> bool: start = self.next_pos() @@ -551,38 +537,6 @@ class BNF: """ Collection of BNF trees. """ - generic_operation_body = BNF.Group( - [ - BNF.Nonterminal('string-literal', bind="name"), - BNF.Literal('('), - BNF.ListOf(BNF.Nonterminal('value-id'), bind='args'), - BNF.Literal(')'), - BNF.OptionalGroup( - [ - BNF.Literal('['), - BNF.ListOf(BNF.Nonterminal('block-id'), - allow_empty=False, - bind='blocks'), - # TODD: allow for block args here?! (according to spec) - BNF.Literal(']') - ], - debug_name="operations optional block id group"), - BNF.OptionalGroup([ - BNF.Literal('('), - BNF.ListOf(BNF.Nonterminal('region'), - bind='regions', - debug_name="regions", - allow_empty=False), - BNF.Literal(')') - ], - debug_name="operation regions"), - BNF.Nonterminal('optional-attr-dict', - bind='attributes', - debug_name="attrbiute dictionary"), - BNF.Literal(':'), - BNF.Nonterminal('function-type', bind='type_signature') - ], - debug_name="generic operation body") attr_dict_mlir = BNF.Group([ BNF.Literal('{'), BNF.ListOf(BNF.Nonterminal('attribute-entry', From 280d7ee86af802dec26e48b09492c9456eb40d28 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Tue, 10 Jan 2023 16:18:50 +0000 Subject: [PATCH 15/65] [parser] clean up lots of unused or unnecessary parts --- xdsl/parser_ng.py | 68 ----------------------------------------------- 1 file changed, 68 deletions(-) diff --git a/xdsl/parser_ng.py b/xdsl/parser_ng.py index df017cba74..336b0fca77 100644 --- a/xdsl/parser_ng.py +++ b/xdsl/parser_ng.py @@ -1580,74 +1580,6 @@ def must_parse_optional_successor_list(self) -> list[Span]: return successors - -""" -digit ::= [0-9] -hex_digit ::= [0-9a-fA-F] -letter ::= [a-zA-Z] -id-punct ::= [$._-] - -integer-literal ::= decimal-literal | hexadecimal-literal -decimal-literal ::= digit+ -hexadecimal-literal ::= `0x` hex_digit+ -float-literal ::= [-+]?[0-9]+[.][0-9]*([eE][-+]?[0-9]+)? -string-literal ::= `"` [^"\n\f\v\r]* `"` TODO: define escaping rules - -bare-id ::= (letter|[_]) (letter|digit|[_$.])* -bare-id-list ::= bare-id (`,` bare-id)* -value-id ::= `%` suffix-id -alias-name :: = bare-id -suffix-id ::= (digit+ | ((letter|id-punct) (letter|id-punct|digit)*)) - - -symbol-ref-id ::= `@` (suffix-id | string-literal) (`::` symbol-ref-id)? -value-id-list ::= value-id (`,` value-id)* - -// Uses of value, e.g. in an operand list to an operation. -value-use ::= value-id -value-use-list ::= value-use (`,` value-use)* - -operation ::= op-result-list? (generic-operation | custom-operation) - trailing-location? -generic-operation ::= string-literal `(` value-use-list? `)` successor-list? - region-list? dictionary-attribute? `:` function-type -custom-operation ::= bare-id custom-operation-format -op-result-list ::= op-result (`,` op-result)* `=` -op-result ::= value-id (`:` integer-literal) -successor-list ::= `[` successor (`,` successor)* `]` -successor ::= caret-id (`:` block-arg-list)? -region-list ::= `(` region (`,` region)* `)` -dictionary-attribute ::= `{` (attribute-entry (`,` attribute-entry)*)? `}` -trailing-location ::= (`loc` `(` location `)`)? - -block ::= block-label operation+ -block-label ::= block-id block-arg-list? `:` -block-id ::= caret-id -caret-id ::= `^` suffix-id -value-id-and-type ::= value-id `:` type - -// Non-empty list of names and types. -value-id-and-type-list ::= value-id-and-type (`,` value-id-and-type)* - -block-arg-list ::= `(` value-id-and-type-list? `)` - -type ::= type-alias | dialect-type | builtin-type - -type-list-no-parens ::= type (`,` type)* -type-list-parens ::= `(` `)` - | `(` type-list-no-parens `)` - -// This is a common way to refer to a value with a specified type. -ssa-use-and-type ::= ssa-use `:` type -ssa-use ::= value-use - -// Non-empty list of names and types. -ssa-use-and-type-list ::= ssa-use-and-type (`,` ssa-use-and-type)* - -function-type ::= (type | type-list-parens) `->` (type | type-list-parens) - -""" - if __name__ == '__main__': infile = sys.argv[-1] from xdsl.dialects.affine import Affine From 30730392818e62252ee8713990134050cc428511 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Tue, 10 Jan 2023 17:21:16 +0000 Subject: [PATCH 16/65] parser: remove BNF stuff for now --- xdsl/parser_ng.py | 70 ++++++-------- xdsl/utils/bnf.py | 237 ---------------------------------------------- 2 files changed, 27 insertions(+), 280 deletions(-) delete mode 100644 xdsl/utils/bnf.py diff --git a/xdsl/parser_ng.py b/xdsl/parser_ng.py index 336b0fca77..257a5acfcb 100644 --- a/xdsl/parser_ng.py +++ b/xdsl/parser_ng.py @@ -15,8 +15,6 @@ from xdsl.ir import (SSAValue, Block, Callable, Attribute, Operation, Region, BlockArgument, MLContext, ParametrizedAttribute) -import xdsl.utils.bnf as BNF - from xdsl.dialects.builtin import ( AnyFloat, AnyTensorType, AnyUnrankedTensorType, AnyVectorType, DenseIntOrFPElementsAttr, Float16Type, Float32Type, Float64Type, FloatAttr, @@ -533,28 +531,6 @@ class ParserCommons: double_colon = re.compile('::') comma = re.compile(',') - class BNF: - """ - Collection of BNF trees. - """ - attr_dict_mlir = BNF.Group([ - BNF.Literal('{'), - BNF.ListOf(BNF.Nonterminal('attribute-entry', - debug_name="attribute entry"), - bind='attributes'), - BNF.Literal('}') - ], - debug_name="attrbute dictionary") - - attr_dict_xdsl = BNF.Group([ - BNF.Literal('['), - BNF.ListOf(BNF.Nonterminal('attribute-entry', - debug_name="attribute entry"), - bind='attributes'), - BNF.Literal(']') - ], - debug_name="attrbute dictionary") - class BaseParser(ABC): """ @@ -1203,10 +1179,17 @@ def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: def attr_dict_from_tuple_list( self, tuple_list: list[tuple[Span, Attribute]]) -> dict[str, Attribute]: - return dict( - ((span.string_contents if isinstance(span, StringLiteral - ) else span.text), attr) - for span, attr in tuple_list) + """ + Convert a list of tuples (Span, Attribute) to a dictionary. + + This function converts the span to a string, trimming quotes from string literals + """ + def span_to_str(span: Span) -> str: + if isinstance(span, StringLiteral): + return span.string_contents + return span.text + + return dict((span_to_str(span), attr) for span, attr in tuple_list) def must_parse_function_type(self) -> FunctionType: """ @@ -1228,11 +1211,11 @@ def must_parse_function_type(self) -> FunctionType: args: list[Attribute] = self.must_parse_list_of( self.try_parse_type, 'Expected type here!') self.must_parse_characters(')', - "Malformed function type!", + "Malformed function type, expected closing brackets of argument types!", is_parse_error=True) self.must_parse_characters('->', - 'Malformed function type!', + 'Malformed function type, expected `->`!', is_parse_error=True) return FunctionType.from_lists( @@ -1395,9 +1378,6 @@ def must_parse_op_args_list(self) -> list[Span]: # TODO: check if type is correct here! return [name for name, _ in args] - @abstractmethod - def must_parse_optional_successor_list(self) -> list[Span]: - pass class MLIRParser(BaseParser): @@ -1445,14 +1425,16 @@ def must_parse_op_result_list( allow_empty=True), None def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: - if self.tokenizer.next_token_of_pattern('{', peek=True) is None: + if not self.tokenizer.starts_with('{'): return dict() - res = ParserCommons.BNF.attr_dict_mlir.must_parse(self) + self.must_parse_characters('{', 'MLIR Attribute dictionary must be enclosed in curly brackets') + + attrs = self.must_parse_list_of(self.must_parse_attribute_entry, "Expected attribute entry") - return self.attr_dict_from_tuple_list( - ParserCommons.BNF.attr_dict_mlir.collect(res, dict()).get( - 'attributes', list())) + self.must_parse_characters('}', 'MLIR Attribute dictionary must be enclosed in curly brackets') + + return self.attr_dict_from_tuple_list(attrs) def must_parse_operation_details(self) -> tuple[ list[Span], list[Span], dict[str, Attribute], list[Region], FunctionType | None]: @@ -1543,14 +1525,16 @@ def try_parse_builtin_attr(self) -> Attribute: return super().try_parse_builtin_attr() def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: - if self.tokenizer.next_token_of_pattern('[', peek=True) is None: + if not self.tokenizer.starts_with('['): return dict() - res = ParserCommons.BNF.attr_dict_xdsl.must_parse(self) + self.must_parse_characters('[', 'xDSL Attribute dictionary must be enclosed in curly brackets') + + attrs = self.must_parse_list_of(self.must_parse_attribute_entry, "Expected attribute entry") + + self.must_parse_characters(']', 'xDSL Attribute dictionary must be enclosed in curly brackets') - return self.attr_dict_from_tuple_list( - ParserCommons.BNF.attr_dict_mlir.collect(res, dict()).get( - 'attributes', list())) + return self.attr_dict_from_tuple_list(attrs) def must_parse_operation_details(self) -> tuple[ list[Span], list[Span], dict[str, Attribute], list[Region], FunctionType | None]: diff --git a/xdsl/utils/bnf.py b/xdsl/utils/bnf.py deleted file mode 100644 index 2614d0b522..0000000000 --- a/xdsl/utils/bnf.py +++ /dev/null @@ -1,237 +0,0 @@ -from __future__ import annotations -import functools -import re -import typing -from dataclasses import dataclass, field -from abc import abstractmethod, ABC -from typing import Any - -if typing.TYPE_CHECKING: - from xdsl.parser_ng import MlirParser, ParseError - -T = typing.TypeVar('T') - - -@dataclass(frozen=True) -class BNFToken: - bind: str | None = field(kw_only=True, init=False) - debug_name: str | None = field(kw_only=True, init=False) - - @abstractmethod - def must_parse(self, parser: MlirParser) -> T: - raise NotImplemented() - - def try_parse(self, parser: MlirParser) -> T | None: - with parser.tokenizer.backtracking(self.debug_name): - return self.must_parse(parser) - - def collect(self, value, collection: dict) -> dict: - if self.bind is None: - return collection - collection[self.bind] = value - return collection - - -@dataclass(frozen=True) -class Literal(BNFToken): - """ - Match a fixed input string - """ - string: str - bind: str | None = field(kw_only=True, default=None) - debug_name: str | None = field(kw_only=True, default=None) - - def must_parse(self, parser: MlirParser): - return parser.must_parse_characters( - self.string, 'Expected `{}`'.format(self.string)) - - def __repr__(self): - return '`{}`'.format(self.string) - - -@dataclass(frozen=True) -class Regex(BNFToken): - pattern: re.Pattern - bind: str | None = field(kw_only=True, default=None) - debug_name: str | None = field(kw_only=True, default=None) - - def try_parse(self, parser: MlirParser) -> T | None: - return parser.tokenizer.next_token_of_pattern(self.pattern) - - def must_parse(self, parser: MlirParser) -> T: - res = self.try_parse(parser) - if res is None: - parser.raise_error('Expected token of form {}!'.format(self)) - return res - - def __repr__(self): - return 're`{}`'.format(self.pattern.pattern) - - -@dataclass(frozen=True) -class Nonterminal(BNFToken): - """ - This is used as an "escape hatch" to switch from BNF to the python parsing code. - - It will look for must_parse_, or try_parse_ in the parse object. This can - probably be improved, idk. - """ - - name: str - """ - The symbol name of the nonterminal, e.g. string-lieral, tensor-attrs, etc... - """ - bind: str | None = field(kw_only=True, default=None) - - debug_name: str | None = field(kw_only=True, default=None) - - def parser_func_name(self, prefix: str): - return prefix + self.name.replace('-', '_') - - def must_parse(self, parser: MlirParser): - if hasattr(parser, self.parser_func_name('must_parse_')): - return getattr(parser, self.parser_func_name('must_parse_'))() - elif hasattr(parser, self.parser_func_name('try_parse_')): - return parser.expect( - getattr(parser, self.parser_func_name('try_parse_')), - 'Expected to parse {} here!'.format(self.name)) - else: - raise NotImplementedError("Parser cannot parse {}".format( - self.name)) - - def try_parse(self, parser: MlirParser) -> T | None: - if hasattr(parser, self.parser_func_name('try_parse_')): - return getattr(parser,self.parser_func_name('try_parse_'))() - return super().try_parse(parser) - - def __repr__(self): - return self.name - - -@dataclass(frozen=True) -class Group(BNFToken): - tokens: list[BNFToken] - bind: str | None = field(kw_only=True, default=None) - debug_name: str | None = field(kw_only=True, default=None) - - def must_parse(self, parser: MlirParser) -> T: - return [token.must_parse(parser) for token in self.tokens] - - def __repr__(self): - return '( {} )'.format(' '.join(repr(t) for t in self.tokens)) - - def collect(self, value, collection: dict) -> dict: - for child, value in zip(self.tokens, value): - child.collect(value, collection) - return super().collect(value, collection) - - -@dataclass(frozen=True) -class OneOrMoreOf(BNFToken): - wraps: BNFToken - bind: str | None = field(kw_only=True, default=None) - debug_name: str | None = field(kw_only=True, default=None) - - def must_parse(self, parser: MlirParser) -> list[T]: - res = list() - while True: - val = self.wraps.try_parse(parser) - if val is None: - if len(res) == 0: - raise AssertionError("Expected at least one of {}".format( - self.wraps)) - return res - res.append(val) - - def __repr__(self): - return '{}+'.format(self.wraps) - - def children(self) -> typing.Iterable[BNFToken]: - return self.wraps, - - def collect(self, value, collection: dict) -> dict: - for val in value: - self.wraps.collect(val, collection) - return super().collect(value, collection) - - -@dataclass(frozen=True) -class ZeroOrMoreOf(BNFToken): - wraps: BNFToken - bind: str | None = field(kw_only=True, default=None) - debug_name: str | None = field(kw_only=True, default=None) - - def must_parse(self, parser: MlirParser) -> list[T]: - res = list() - while True: - val = self.wraps.try_parse(parser) - if val is None: - return res - res.append(val) - - def __repr__(self): - return '{}*'.format(self.wraps) - - def children(self) -> typing.Iterable[BNFToken]: - return self.wraps, - - def collect(self, values, collection: dict) -> dict: - for value in values: - self.wraps.collect(value, collection) - return super().collect(values, collection) - - -@dataclass(frozen=True) -class ListOf(BNFToken): - element: BNFToken - separator: re.Pattern = re.compile(',') - - allow_empty: bool = True - bind: str | None = field(kw_only=True, default=None) - debug_name: str | None = field(kw_only=True, default=None) - - def must_parse(self, parser: MlirParser) -> T | None: - return parser.must_parse_list_of( - lambda: self.element.try_parse(parser), - 'Expected {}!'.format(self.element), - separator_pattern=self.separator, - allow_empty=self.allow_empty) - - def __repr__(self): - if self.allow_empty: - return '( {elm} ( re`{sep}` {elm} )* )?'.format( - elm=self.element, sep=self.separator.pattern) - return '{elm} ( re`{sep}` {elm} )*'.format(elm=self.element, - sep=self.separator.pattern) - - def collect(self, values, collection: dict) -> dict: - for value in values: - self.element.collect(value, collection) - return super().collect(values, collection) - - -@dataclass(frozen=True) -class Optional(BNFToken): - wraps: BNFToken - bind: str | None = field(kw_only=True, default=None) - debug_name: str | None = field(kw_only=True, default=None) - - def must_parse(self, parser: MlirParser) -> T | None: - return self.wraps.try_parse(parser) - - def try_parse(self, parser: MlirParser) -> T | None: - return self.wraps.try_parse(parser) - - def __repr__(self): - return '{}?'.format(self.wraps) - - def collect(self, value, collection: dict) -> dict: - if value is not None: - self.wraps.collect(value, collection) - return super().collect(value, collection) - - -def OptionalGroup(tokens: list[BNFToken], - bind: str | None = None, - debug_name: str | None = None) -> Optional: - return Optional(Group(tokens), bind=bind, debug_name=debug_name) From 4464c323df452dfd672b711d4365ac23fa040278 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Tue, 10 Jan 2023 17:53:21 +0000 Subject: [PATCH 17/65] parser: unify code style and apply yapf --- xdsl/parser_ng.py | 224 +++++++++++++++++++++++++++------------------- 1 file changed, 130 insertions(+), 94 deletions(-) diff --git a/xdsl/parser_ng.py b/xdsl/parser_ng.py index 257a5acfcb..518cb089cf 100644 --- a/xdsl/parser_ng.py +++ b/xdsl/parser_ng.py @@ -230,7 +230,6 @@ def at(self, i: int): save_t = tuple[int, tuple[str, ...]] -parsed_type_t = tuple[Span, tuple[Span]] @dataclass @@ -249,9 +248,11 @@ class Tokenizer: characters the tokenizer should break on """ - history: BacktrackingHistory | None = field(init=False, default=None) + history: BacktrackingHistory | None = field(init=False, + default=None, + repr=False) - last_token: Span | None = field(init=False, default=None) + last_token: Span | None = field(init=False, default=None, repr=False) def save(self) -> save_t: """ @@ -341,26 +342,27 @@ def history_entry_from_exception(self, ex: Exception, region: str, traceback.print_exc(file=tb) reason[0] += '\n' + tb.getvalue() - return BacktrackingHistory(ParseError(self.last_token, reason[-1], self.history), - self.history, region, pos) + return BacktrackingHistory( + ParseError(self.last_token, reason[-1], self.history), + self.history, region, pos) elif isinstance(ex, BacktrackingAbort): return BacktrackingHistory( ParseError( self.next_token(peek=True), 'Backtracking aborted: {}'.format(ex.reason - or 'unknown reason'), self.history), - self.history, region, pos) + or 'unknown reason'), + self.history), self.history, region, pos) elif isinstance(ex, EOFError): return BacktrackingHistory( - ParseError(self.last_token, "Encountered EOF", self.history), self.history, - region, pos) + ParseError(self.last_token, "Encountered EOF", self.history), + self.history, region, pos) print("Warning: Unexpected error in backtracking:", file=sys.stderr) traceback.print_exception(ex, file=sys.stderr) return BacktrackingHistory( - ParseError(self.last_token, "Unexpected exception: {}".format(ex), self.history), - self.history, region, pos) + ParseError(self.last_token, "Unexpected exception: {}".format(ex), + self.history), self.history, region, pos) def next_token(self, start: int | None = None, peek: bool = False) -> Span: """ @@ -565,10 +567,12 @@ class BaseParser(ABC): of all try_parse functions is T_ | None """ - def __init__(self, - input: str, - name: str, - ctx: MLContext, ): + def __init__( + self, + input: str, + name: str, + ctx: MLContext, + ): self.tokenizer = Tokenizer(Input(input, name)) self.ctx = ctx self.ssaValues = dict() @@ -608,7 +612,7 @@ def must_parse_optional_block_label( if block_id is not None: assert block_id.text not in self.blocks, "Blocks cannot have the same ID!" - if self.tokenizer.next_token(peek=True).text == '(': + if self.tokenizer.starts_with('('): arg_list = self.must_parse_block_arg_list() self.must_parse_characters(':', 'Block label must end in a `:`!') @@ -673,7 +677,7 @@ def must_parse_list_of(self, items.append(first_item) while (match := self.tokenizer.next_token_of_pattern(separator_pattern) - ) is not None: + ) is not None: next_item = try_parse() if next_item is None: # if the separator is emtpy, we are good here @@ -737,8 +741,7 @@ def try_parse_value_id_and_type(self) -> tuple[Span, Attribute] | None: def try_parse_type(self) -> Attribute | None: if (builtin_type := self.try_parse_builtin_type()) is not None: return builtin_type - if (dialect_type := - self.try_parse_dialect_type()) is not None: + if (dialect_type := self.try_parse_dialect_type()) is not None: return dialect_type return None @@ -746,7 +749,8 @@ def try_parse_dialect_type_or_attribute(self) -> Attribute | None: """ Parse a type or an attribute. """ - kind = self.tokenizer.next_token_of_pattern(re.compile('[!#]'), peek=True) + kind = self.tokenizer.next_token_of_pattern(re.compile('[!#]'), + peek=True) if kind is None: return None @@ -756,39 +760,42 @@ def try_parse_dialect_type_or_attribute(self) -> Attribute | None: if kind.text == '!': return self.must_parse_dialect_type_or_attribute_inner('type') else: - return self.must_parse_dialect_type_or_attribute_inner('attribute') + return self.must_parse_dialect_type_or_attribute_inner( + 'attribute') def try_parse_dialect_type(self): """ Parse a dialect type (something prefixed by `!`, defined by a dialect) """ - if self.tokenizer.next_token_of_pattern('!', peek=True) is None: + if not self.tokenizer.starts_with('!'): return None with self.tokenizer.backtracking("dialect type"): - self.tokenizer.next_token_of_pattern('!') + self.must_parse_characters('!', + "Dialect type must start with a `!`") return self.must_parse_dialect_type_or_attribute_inner('type') def try_parse_dialect_attr(self): """ Parse a dialect attribute (something prefixed by `#`, defined by a dialect) """ - if self.tokenizer.next_token_of_pattern('#', peek=True) is None: + if not self.tokenizer.starts_with('#'): return None with self.tokenizer.backtracking("dialect attribute"): - self.tokenizer.next_token_of_pattern('#') + self.must_parse_characters( + '#', "Dialect attribute must start with a `#`") return self.must_parse_dialect_type_or_attribute_inner('attribute') def must_parse_dialect_type_or_attribute_inner(self, kind: str): - type_name = self.tokenizer.next_token_of_pattern( - ParserCommons.bare_id) + type_name = self.tokenizer.next_token_of_pattern(ParserCommons.bare_id) if type_name is None: - self.raise_error( - "Expected dialect {} name here!".format(kind)) + self.raise_error("Expected dialect {} name here!".format(kind)) type_def = self.ctx.get_optional_attr(type_name.text) if type_def is None: - self.raise_error("'{}' is not a know attribute!".format(type_name.text), type_name) + self.raise_error( + "'{}' is not a know attribute!".format(type_name.text), + type_name) # pass the task of parsing parameters on to the attribute/type definition param_list = type_def.parse_parameters(self) @@ -832,11 +839,10 @@ def try_parse_numerical_dims(self, accept_closing_bracket: bool = False, lower_bound: int = 1) -> Iterable[int]: while (shape_arg := - self.try_parse_shape_element(lower_bound)) is not None: + self.try_parse_shape_element(lower_bound)) is not None: yield shape_arg # look out for the closing bracket for scalable vector dims - if accept_closing_bracket and self.tokenizer.next_token( - peek=True).text == ']': + if accept_closing_bracket and self.tokenizer.starts_with(']'): break self.must_parse_characters( 'x', @@ -846,7 +852,7 @@ def try_parse_numerical_dims(self, def must_parse_vector_attrs(self) -> AnyVectorType: # also break on 'x' characters as they are separators in dimension parameters with self.tokenizer.configured(break_on=self.tokenizer.break_on + - ('x',)): + ('x', )): shape = list[int](self.try_parse_numerical_dims()) scaling_shape: list[int] | None = None @@ -864,7 +870,7 @@ def must_parse_vector_attrs(self) -> AnyVectorType: if scaling_shape is not None: # TODO: handle scaling vectors! - print("Warning: scaling vectors not supported!") + self.raise_error("Warning: scaling vectors not supported!") pass type = self.try_parse_type() @@ -876,7 +882,7 @@ def must_parse_vector_attrs(self) -> AnyVectorType: def must_parse_tensor_or_memref_dims(self) -> list[int] | None: with self.tokenizer.configured(break_on=self.tokenizer.break_on + - ('x',)): + ('x', )): # check for unranked-ness if self.tokenizer.next_token_of_pattern('*') is not None: # consume `x` @@ -895,11 +901,11 @@ def must_parse_tensor_attrs(self) -> AnyTensorType: if type is None: self.raise_error("Expected tensor type here!") - if self.tokenizer.next_token(peek=True).text == ',': + if self.tokenizer.starts_with(','): # TODO: add tensor encoding! raise self.raise_error("Parsing tensor encoding is not supported!") - if shape is None and self.tokenizer.next_token(peek=True).text == ',': + if shape is None and self.tokenizer.starts_with(','): raise self.raise_error("Unranked tensors don't have an encoding!") if shape is not None: @@ -924,10 +930,7 @@ def try_parse_shape_element(self, lower_bound: int = 1) -> int | None: "Shape element literal cannot be negative or zero!") return value - next_token = self.tokenizer.next_token(peek=True) - - if next_token.text == '?': - self.tokenizer.consume_peeked(next_token) + if self.tokenizer.next_token_of_pattern('?') is not None: return -1 return None @@ -1003,7 +1006,8 @@ def try_parse_operation(self) -> Operation | None: "Expected an operation name here, either a bare-id, or a string literal!" ) - args, successors, attrs, regions, func_type = self.must_parse_operation_details() + args, successors, attrs, regions, func_type = self.must_parse_operation_details( + ) if ret_types is None: assert func_type is not None @@ -1025,7 +1029,8 @@ def try_parse_operation(self) -> Operation | None: for idx, res in enumerate(result_list): ssa_val_name = res.text if ssa_val_name in self.ssaValues: - self.raise_error(f"SSA value {ssa_val_name} is already defined", res) + self.raise_error( + f"SSA value {ssa_val_name} is already defined", res) self.ssaValues[ssa_val_name] = op.results[idx] # TODO: check name? self.ssaValues[ssa_val_name].name = ssa_val_name.lstrip('%') @@ -1041,12 +1046,12 @@ def must_parse_region(self) -> Region: try: self.must_parse_characters('{', 'Regions begin with `{`') - if self.tokenizer.next_token(peek=True).text != '}': + if not self.tokenizer.starts_with('}'): # parse first block block = self.must_parse_block() region.add_block(block) - while self.tokenizer.next_token(peek=True).text == '^': + while self.tokenizer.starts_with('^'): region.add_block(self.must_parse_block()) self.must_parse_characters('}', @@ -1135,7 +1140,7 @@ def try_parse_builtin_float_attr(self) -> FloatAttr | None: self.try_parse_float_literal, 'Float attribute must start with a float literal!') # if we don't see a ':' indicating a type signature - if self.tokenizer.next_token(peek=True).text != ':': + if not self.tokenizer.starts_with(':'): return FloatAttr.from_value(float(value.text)) type = self.must_parse_attribute_type() @@ -1151,17 +1156,17 @@ def try_parse_builtin_boolean_attr(self) -> IntegerAttr | None: return IntegerAttr.from_params(int_val, IntegerType.from_width(1)) def try_parse_builtin_str_attr(self): - if self.tokenizer.next_token(peek=True).text != '"': + if not self.tokenizer.starts_with('"'): return None with self.tokenizer.backtracking("string literal"): literal = self.try_parse_string_literal() - if self.tokenizer.next_token(peek=True).text != ':': - return StringAttr.from_str(literal.string_contents) - self.raise_error("Typed string literals are not supported!") + if literal is None: + self.raise_error('Invalid string literal') + return StringAttr.from_str(literal.string_contents) def try_parse_builtin_arr_attr(self) -> list[Attribute] | None: - if self.tokenizer.next_token(peek=True).text != '[': + if not self.tokenizer.starts_with('['): return None with self.tokenizer.backtracking("array literal"): self.must_parse_characters('[', @@ -1184,6 +1189,7 @@ def attr_dict_from_tuple_list( This function converts the span to a string, trimming quotes from string literals """ + def span_to_str(span: Span) -> str: if isinstance(span, StringLiteral): return span.string_contents @@ -1208,11 +1214,14 @@ def must_parse_function_type(self) -> FunctionType: """ self.must_parse_characters( '(', 'First group of function args must start with a `(`') + args: list[Attribute] = self.must_parse_list_of( self.try_parse_type, 'Expected type here!') - self.must_parse_characters(')', - "Malformed function type, expected closing brackets of argument types!", - is_parse_error=True) + + self.must_parse_characters( + ')', + "Malformed function type, expected closing brackets of argument types!", + is_parse_error=True) self.must_parse_characters('->', 'Malformed function type, expected `->`!', @@ -1244,7 +1253,7 @@ def must_parse_type_or_type_list_parens(self) -> list[Attribute]: return args def try_parse_function_type(self) -> FunctionType | None: - if self.tokenizer.next_token(peek=True).text != '(': + if not self.tokenizer.starts_with('('): return None with self.tokenizer.backtracking('function type'): return self.must_parse_function_type() @@ -1254,7 +1263,7 @@ def must_parse_region_list(self) -> list[Region]: Parses a sequence of regions for as long as there is a `{` in the input. """ regions = [] - while not self.tokenizer.is_eof() and self.tokenizer.next_token(peek=True).text == '{': + while not self.tokenizer.is_eof() and self.tokenizer.starts_with('{'): regions.append(self.must_parse_region()) return regions @@ -1294,9 +1303,8 @@ def parse_op_with_default_format( for x in args: if x.text not in self.ssaValues: self.raise_error( - "Unknown SSAValue name, known SSA Values are: {}".format(", ".join(self.ssaValues.keys())), - x - ) + "Unknown SSAValue name, known SSA Values are: {}".format( + ", ".join(self.ssaValues.keys())), x) return op_type.create( operands=[self.ssaValues[span.text] for span in args], @@ -1352,8 +1360,10 @@ def must_parse_builtin_type_with_name(self, name: Span): return self.must_parse_builtin_parametrized_type(name) @abstractmethod - def must_parse_operation_details(self) -> tuple[ - list[Span], list[Span], dict[str, Attribute], list[Region], FunctionType | None]: + def must_parse_operation_details( + self + ) -> tuple[list[Span], list[Span], dict[str, Attribute], list[Region], + FunctionType | None]: """ Must return a tuple consisting of: - a list of arguments to the operation @@ -1370,11 +1380,13 @@ def must_parse_operation_details(self) -> tuple[ """ raise NotImplementedError() - def must_parse_op_args_list(self) -> list[Span]: - self.must_parse_characters('(', 'Operation args list must be enclosed by brackets!') - args = self.must_parse_list_of(self.try_parse_value_id_and_type, 'Expected another bare-id here') - self.must_parse_characters(')', 'Operation args list must be closed by a closing bracket') + self.must_parse_characters( + '(', 'Operation args list must be enclosed by brackets!') + args = self.must_parse_list_of(self.try_parse_value_id_and_type, + 'Expected another bare-id here') + self.must_parse_characters( + ')', 'Operation args list must be closed by a closing bracket') # TODO: check if type is correct here! return [name for name, _ in args] @@ -1386,7 +1398,8 @@ def try_parse_builtin_type(self) -> Attribute | None: parse a builtin-type like i32, index, vector etc. """ with self.tokenizer.backtracking("builtin type"): - name = self.tokenizer.next_token_of_pattern(ParserCommons.builtin_type) + name = self.tokenizer.next_token_of_pattern( + ParserCommons.builtin_type) if name is None: raise BacktrackingAbort("Expected builtin name!") @@ -1397,7 +1410,7 @@ def must_parse_attribute(self) -> Attribute: Parse attribute (either builtin or dialect) """ # all dialect attrs must start with '#', so we check for that first (as it's easier) - if self.tokenizer.next_token(peek=True).text == '#': + if self.tokenizer.starts_with('#'): value = self.try_parse_dialect_attr() # no value => error @@ -1428,29 +1441,40 @@ def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: if not self.tokenizer.starts_with('{'): return dict() - self.must_parse_characters('{', 'MLIR Attribute dictionary must be enclosed in curly brackets') + self.must_parse_characters( + '{', + 'MLIR Attribute dictionary must be enclosed in curly brackets') - attrs = self.must_parse_list_of(self.must_parse_attribute_entry, "Expected attribute entry") + attrs = self.must_parse_list_of(self.must_parse_attribute_entry, + "Expected attribute entry") - self.must_parse_characters('}', 'MLIR Attribute dictionary must be enclosed in curly brackets') + self.must_parse_characters( + '}', + 'MLIR Attribute dictionary must be enclosed in curly brackets') return self.attr_dict_from_tuple_list(attrs) - def must_parse_operation_details(self) -> tuple[ - list[Span], list[Span], dict[str, Attribute], list[Region], FunctionType | None]: + def must_parse_operation_details( + self + ) -> tuple[list[Span], list[Span], dict[str, Attribute], list[Region], + FunctionType | None]: args = self.must_parse_op_args_list() succ = self.must_parse_optional_successor_list() regions = [] if self.tokenizer.starts_with('('): - self.must_parse_characters('(', 'Expected brackets enclosing regions!') + self.must_parse_characters('(', + 'Expected brackets enclosing regions!') regions = self.must_parse_region_list() - self.must_parse_characters(')', 'Expected brackets enclosing regions!') + self.must_parse_characters(')', + 'Expected brackets enclosing regions!') attrs = self.must_parse_optional_attr_dict() - self.must_parse_characters(':', 'MLIR Operation defintions must end in a function type signature!') + self.must_parse_characters( + ':', + 'MLIR Operation defintions must end in a function type signature!') func_type = self.must_parse_function_type() return args, succ, attrs, regions, func_type @@ -1458,9 +1482,13 @@ def must_parse_operation_details(self) -> tuple[ def must_parse_optional_successor_list(self) -> list[Span]: if not self.tokenizer.starts_with('['): return [] - self.must_parse_characters('[', 'Successor list is enclosed in square brackets') - successors = self.must_parse_list_of(self.try_parse_block_id, 'Expected a block-id', allow_empty=False) - self.must_parse_characters(']', 'Successor list is enclosed in square brackets') + self.must_parse_characters( + '[', 'Successor list is enclosed in square brackets') + successors = self.must_parse_list_of(self.try_parse_block_id, + 'Expected a block-id', + allow_empty=False) + self.must_parse_characters( + ']', 'Successor list is enclosed in square brackets') return successors @@ -1471,13 +1499,12 @@ def try_parse_builtin_type(self) -> Attribute | None: parse a builtin-type like i32, index, vector etc. """ with self.tokenizer.backtracking("builtin type"): - name = self.tokenizer.next_token_of_pattern(ParserCommons.builtin_type_xdsl) + name = self.tokenizer.next_token_of_pattern( + ParserCommons.builtin_type_xdsl) if name is None: raise BacktrackingAbort("Expected builtin name!") # xdsl builtin types have a '!' prefix, we strip that out here - name = Span(start=name.start + 1, - end=name.end, - input=name.input) + name = Span(start=name.start + 1, end=name.end, input=name.input) return self.must_parse_builtin_type_with_name(name) @@ -1528,16 +1555,23 @@ def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: if not self.tokenizer.starts_with('['): return dict() - self.must_parse_characters('[', 'xDSL Attribute dictionary must be enclosed in curly brackets') + self.must_parse_characters( + '[', + 'xDSL Attribute dictionary must be enclosed in curly brackets') - attrs = self.must_parse_list_of(self.must_parse_attribute_entry, "Expected attribute entry") + attrs = self.must_parse_list_of(self.must_parse_attribute_entry, + "Expected attribute entry") - self.must_parse_characters(']', 'xDSL Attribute dictionary must be enclosed in curly brackets') + self.must_parse_characters( + ']', + 'xDSL Attribute dictionary must be enclosed in curly brackets') return self.attr_dict_from_tuple_list(attrs) - def must_parse_operation_details(self) -> tuple[ - list[Span], list[Span], dict[str, Attribute], list[Region], FunctionType | None]: + def must_parse_operation_details( + self + ) -> tuple[list[Span], list[Span], dict[str, Attribute], list[Region], + FunctionType | None]: """ Must return a tuple consisting of: - a list of arguments to the operation @@ -1558,9 +1592,13 @@ def must_parse_operation_details(self) -> tuple[ def must_parse_optional_successor_list(self) -> list[Span]: if not self.tokenizer.starts_with('['): return [] - self.must_parse_characters('[', 'Successor list is enclosed in square brackets') - successors = self.must_parse_list_of(self.try_parse_block_id, 'Expected a block-id', allow_empty=False) - self.must_parse_characters(']', 'Successor list is enclosed in square brackets') + self.must_parse_characters( + '[', 'Successor list is enclosed in square brackets') + successors = self.must_parse_list_of(self.try_parse_block_id, + 'Expected a block-id', + allow_empty=False) + self.must_parse_characters( + ']', 'Successor list is enclosed in square brackets') return successors @@ -1594,9 +1632,7 @@ def must_parse_optional_successor_list(self) -> list[Span]: parser = parses_by_file_name[infile.split('.')[-1]] - p = parser(infile, - open(infile, 'r').read(), - ctx) + p = parser(infile, open(infile, 'r').read(), ctx) printer = Printer() try: From cfd5c62a597a544c59cb94b5d905868484d45616 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Wed, 11 Jan 2023 14:34:05 +0000 Subject: [PATCH 18/65] parser: a lot more cleanup in various places and correctnes fixes --- xdsl/parser_ng.py | 104 ++++++++++++++++++++-------------------------- 1 file changed, 46 insertions(+), 58 deletions(-) diff --git a/xdsl/parser_ng.py b/xdsl/parser_ng.py index 518cb089cf..e49d4f044c 100644 --- a/xdsl/parser_ng.py +++ b/xdsl/parser_ng.py @@ -1,28 +1,23 @@ from __future__ import annotations +import ast import contextlib +import re import sys import traceback from abc import ABC, abstractmethod from dataclasses import dataclass, field -import re -import ast from io import StringIO -from typing import Any, TypeVar, Iterable, Literal, Optional -from enum import Enum +from typing import TypeVar, Iterable -from .printer import Printer +from xdsl.dialects.builtin import ( + AnyTensorType, AnyVectorType, + Float16Type, Float32Type, Float64Type, FloatAttr, + FunctionType, IndexType, IntegerType, Signedness, StringAttr, + IntegerAttr, ArrayAttr, TensorType, UnrankedTensorType, VectorType, DefaultIntegerAttrType, FlatSymbolRefAttr) from xdsl.ir import (SSAValue, Block, Callable, Attribute, Operation, Region, BlockArgument, MLContext, ParametrizedAttribute) - -from xdsl.dialects.builtin import ( - AnyFloat, AnyTensorType, AnyUnrankedTensorType, AnyVectorType, - DenseIntOrFPElementsAttr, Float16Type, Float32Type, Float64Type, FloatAttr, - FunctionType, IndexType, IntegerType, OpaqueAttr, Signedness, StringAttr, - FlatSymbolRefAttr, IntegerAttr, ArrayAttr, TensorType, UnitAttr, - UnrankedTensorType, UnregisteredOp, VectorType, DefaultIntegerAttrType) - -from xdsl.irdl import Data +from .printer import Printer class ParseError(Exception): @@ -148,10 +143,10 @@ def print_with_context(self, msg: str | None = None) -> str: return capture.getvalue() def __repr__(self): - return "Span[{}:{}](text='{}')".format(self.start, self.end, self.text) + return "{}[{}:{}](text='{}')".format(self.__class__.__name__, self.start, self.end, self.text) -@dataclass(frozen=True) +@dataclass(frozen=True, repr=False) class StringLiteral(Span): def __post_init__(self): @@ -169,10 +164,6 @@ def string_contents(self): # TODO: is this a hack-job? return ast.literal_eval(self.text) - def __repr__(self): - return "StringLiteral[{}:{}](text='{}')".format( - self.start, self.end, self.text) - @dataclass(frozen=True) class Input: @@ -506,19 +497,8 @@ class ParserCommons: type_alias = re.compile(r'![A-Za-z_][\w$.]+') attribute_alias = re.compile(r'#[A-Za-z_][\w$.]+') boolean_literal = re.compile(r'(true|false)') - builtin_type = re.compile('(({}))'.format(')|('.join(( - r'[su]?i\d+', - r'f\d+', - 'tensor', - 'vector', - 'memref', - 'complex', - 'opaque', - 'tuple', - 'index', - # TODO: add all the Float8E4M3FNType, Float8E5M2Type, and BFloat16Type - )))) - builtin_type_xdsl = re.compile('!(({}))'.format(')|('.join(( + # a list of + _builtin_type_names = ( r'[su]?i\d+', r'f\d+', 'tensor', @@ -529,7 +509,9 @@ class ParserCommons: 'tuple', 'index', # TODO: add all the Float8E4M3FNType, Float8E5M2Type, and BFloat16Type - )))) + ) + builtin_type = re.compile('(({}))'.format(')|('.join(_builtin_type_names))) + builtin_type_xdsl = re.compile('!(({}))'.format(')|('.join(_builtin_type_names))) double_colon = re.compile('::') comma = re.compile(',') @@ -610,7 +592,7 @@ def must_parse_optional_block_label( arg_list = list() if block_id is not None: - assert block_id.text not in self.blocks, "Blocks cannot have the same ID!" + assert block_id.text not in self.blocks, "two blocks cannot have the same ID!" if self.tokenizer.starts_with('('): arg_list = self.must_parse_block_arg_list() @@ -1091,9 +1073,15 @@ def must_parse_attribute_entry(self) -> tuple[Span, Attribute]: def must_parse_attribute(self) -> Attribute: """ Parse attribute (either builtin or dialect) + + This is different in xDSL and MLIR, so the actuall implementation is provided by the subclass """ raise NotImplemented() + def try_parse_attribute(self) -> Attribute | None: + with self.tokenizer.backtracking('attribute'): + return self.must_parse_attribute() + def must_parse_attribute_type(self) -> Attribute: """ Parses `:` type and returns the type @@ -1112,12 +1100,25 @@ def try_parse_builtin_attr(self) -> Attribute: attrs = (self.try_parse_builtin_float_attr, self.try_parse_builtin_int_attr, self.try_parse_builtin_str_attr, - self.try_parse_builtin_arr_attr, self.try_parse_function_type) + self.try_parse_builtin_arr_attr, + self.try_parse_function_type, + self.try_parse_ref_attr) for attr_parser in attrs: if (val := attr_parser()) is not None: return val + def try_parse_ref_attr(self) -> FlatSymbolRefAttr | None: + if not self.tokenizer.starts_with('@'): + return None + + ref = self.must_parse_reference() + + if len(ref) > 1: + self.raise_error("Nested refs are not supported yet!", ref[1]) + + return FlatSymbolRefAttr.from_str(ref[0].text[1:]) + def try_parse_builtin_int_attr(self) -> IntegerAttr | None: bool = self.try_parse_builtin_boolean_attr() if bool is not None: @@ -1171,7 +1172,7 @@ def try_parse_builtin_arr_attr(self) -> list[Attribute] | None: with self.tokenizer.backtracking("array literal"): self.must_parse_characters('[', 'Array literals must start with `[`') - attrs = self.must_parse_list_of(self.must_parse_attribute, + attrs = self.must_parse_list_of(self.try_parse_attribute, 'Expected array entry!') self.must_parse_characters( ']', 'Array literals must be enclosed by square brackets!') @@ -1285,20 +1286,8 @@ def parse_op_with_default_format( This implicitly assumes XDSL format, and will fail on MLIR style operations """ # TODO: remove this function and restructure custom op / irdl parsing - - args = self.must_parse_op_args_list() - successors: list[Span] = [] - if self.tokenizer.next_token_of_pattern('(') is not None: - successors = self.must_parse_list_of(self.try_parse_block_id, - 'Malformed block-id!') - self.must_parse_characters( - ')', - 'Expected either a block id or the end of the successor list here' - ) - - attributes = self.must_parse_optional_attr_dict() - - regions = self.must_parse_region_list() + assert isinstance(self, XDSLParser) + args, successors, attributes, regions, _ = self.must_parse_operation_details() for x in args: if x.text not in self.ssaValues: @@ -1321,7 +1310,7 @@ def parse_paramattr_parameters( '<') is None and expect_brackets: self.raise_error("Expected start attribute parameters here (`<`)!") - res = self.must_parse_list_of(self.must_parse_attribute, + res = self.must_parse_list_of(self.try_parse_attribute, 'Expected another attribute here!') if self.tokenizer.next_token_of_pattern( @@ -1557,14 +1546,14 @@ def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: self.must_parse_characters( '[', - 'xDSL Attribute dictionary must be enclosed in curly brackets') + 'xDSL Attribute dictionary must be enclosed in square brackets') attrs = self.must_parse_list_of(self.must_parse_attribute_entry, "Expected attribute entry") self.must_parse_characters( ']', - 'xDSL Attribute dictionary must be enclosed in curly brackets') + 'xDSL Attribute dictionary must be enclosed in square brackets') return self.attr_dict_from_tuple_list(attrs) @@ -1590,15 +1579,15 @@ def must_parse_operation_details( return args, succ, attrs, regions, None def must_parse_optional_successor_list(self) -> list[Span]: - if not self.tokenizer.starts_with('['): + if not self.tokenizer.starts_with('('): return [] self.must_parse_characters( - '[', 'Successor list is enclosed in square brackets') + '(', 'Successor list is enclosed in round brackets') successors = self.must_parse_list_of(self.try_parse_block_id, 'Expected a block-id', allow_empty=False) self.must_parse_characters( - ']', 'Successor list is enclosed in square brackets') + ')', 'Successor list is enclosed in round brackets') return successors @@ -1614,7 +1603,6 @@ def must_parse_optional_successor_list(self) -> list[Span]: from xdsl.dialects.llvm import LLVM from xdsl.dialects.memref import MemRef from xdsl.dialects.scf import Scf - import os ctx = MLContext() ctx.register_dialect(Builtin) From 7f410098ef705ce5a60aed7b7874886f3367093a Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Wed, 11 Jan 2023 16:24:33 +0000 Subject: [PATCH 19/65] parser: add proper handling of forward references to blocks --- xdsl/ir.py | 2 ++ xdsl/parser_ng.py | 75 +++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 68 insertions(+), 9 deletions(-) diff --git a/xdsl/ir.py b/xdsl/ir.py index e6672ca46c..d5440e30c1 100644 --- a/xdsl/ir.py +++ b/xdsl/ir.py @@ -623,6 +623,8 @@ def irdl_definition(cls) -> OpDef: class Block(IRNode): """A sequence of operations""" + delcared_at: 'Span' | None = None + _args: FrozenList[BlockArgument] = field(default_factory=FrozenList, init=False) """The basic block arguments.""" diff --git a/xdsl/parser_ng.py b/xdsl/parser_ng.py index e49d4f044c..fa7008c26c 100644 --- a/xdsl/parser_ng.py +++ b/xdsl/parser_ng.py @@ -2,10 +2,13 @@ import ast import contextlib +import functools +import itertools import re import sys import traceback from abc import ABC, abstractmethod +from collections import defaultdict from dataclasses import dataclass, field from io import StringIO from typing import TypeVar, Iterable @@ -42,6 +45,21 @@ def print_with_history(self): self.history.print_unroll() +class MultipleSpansParseError(ParseError): + ref_text: str | None + refs: list[tuple[Span, str]] + def __init__(self, span: Span, msg: str, ref_text: str, refs: list[tuple[Span, str | None]], history: BacktrackingHistory | None = None): + super(MultipleSpansParseError, self).__init__(span, msg, history) + self.refs = refs + self.ref_text = ref_text + + def print_pretty(self, file=sys.stderr): + super(MultipleSpansParseError, self).print_pretty(file) + print(self.ref_text or "With respect to:", file=file) + for span, msg in self.refs: + print(span.print_with_context(msg), file=file) + + @dataclass class BacktrackingHistory: error: ParseError @@ -542,6 +560,10 @@ class BaseParser(ABC): ssaValues: dict[str, SSAValue] blocks: dict[str, Block] + forward_block_references: dict[str, list[Span]] + """ + Blocks we encountered references to before the definition (must be empty after parsing of region completes) + """ T_ = TypeVar('T_') """ @@ -559,6 +581,7 @@ def __init__( self.ctx = ctx self.ssaValues = dict() self.blocks = dict() + self.forward_block_references = set() def begin_parse(self): ops = [] @@ -568,12 +591,36 @@ def begin_parse(self): self.raise_error("Could not parse entire input!") return ops + def get_block_from_name(self, block_name: Span): + """ + This function takes a span containing a block id (like `^42`) and returns a block. + + If the block defintion was not seen yet, we create a forward declaration. + """ + name = block_name.text + if name not in self.blocks: + self.forward_block_references[name].append(block_name) + self.blocks[name] = Block() + return self.blocks[name] + def must_parse_block(self) -> Block: block_id, args = self.must_parse_optional_block_label() - block = Block() - if block_id is not None: - assert block_id.text not in self.blocks + if block_id is None: + block = Block(self.tokenizer.last_token) + elif self.forward_block_references.pop(block_id.text, None) is not None: + block = self.blocks[block_id.text] + block.delcared_at = block_id + else: + if block_id.text in self.blocks: + raise MultipleSpansParseError( + block_id, + "Re-declaration of block {}".format(block_id.text), + 'Originally declared here:', + [(self.blocks[block_id.text].delcared_at, None)], + self.tokenizer.history + ) + block = Block(block_id) self.blocks[block_id.text] = block for i, (name, type) in enumerate(args): @@ -592,8 +639,6 @@ def must_parse_optional_block_label( arg_list = list() if block_id is not None: - assert block_id.text not in self.blocks, "two blocks cannot have the same ID!" - if self.tokenizer.starts_with('('): arg_list = self.must_parse_block_arg_list() @@ -1021,8 +1066,10 @@ def try_parse_operation(self) -> Operation | None: def must_parse_region(self) -> Region: oldSSAVals = self.ssaValues.copy() - oldBBNames = self.blocks.copy() - self.blocks = dict[str, Block]() + oldBBNames = self.blocks + oldForwardRefs = self.forward_block_references + self.blocks = dict() + self.forward_block_references = defaultdict(list) region = Region() @@ -1036,13 +1083,23 @@ def must_parse_region(self) -> Region: while self.tokenizer.starts_with('^'): region.add_block(self.must_parse_block()) - self.must_parse_characters('}', + end = self.must_parse_characters('}', 'Reached end of region, expected `}`!') + if len(self.forward_block_references) > 0: + raise MultipleSpansParseError( + end, + "Region ends with missing block declarations for block(s) {}!".format(', '.join(self.forward_block_references.keys())), + 'The following block references are dangling:', + [(span, "Reference to block \"{}\" without implementation!".format(span.text)) for span in itertools.chain(*self.forward_block_references.values())], + self.tokenizer.history + ) + return region finally: self.ssaValues = oldSSAVals self.blocks = oldBBNames + self.forward_block_references = oldForwardRefs def try_parse_op_name(self) -> Span | None: if (str_lit := self.try_parse_string_literal()) is not None: @@ -1299,7 +1356,7 @@ def parse_op_with_default_format( operands=[self.ssaValues[span.text] for span in args], result_types=result_types, attributes=attributes, - successors=[self.blocks[span.text] for span in successors], + successors=[self.get_block_from_name(span) for span in successors], regions=regions) def parse_paramattr_parameters( From 4404ffc27553db690f2905ae3c238d9e68bc3382 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Wed, 11 Jan 2023 16:31:26 +0000 Subject: [PATCH 20/65] parser: improved compatibility layer to support parsing of cmath.xdsl --- xdsl/parser_ng.py | 116 +++++++++++++++++++++++++--------------------- 1 file changed, 62 insertions(+), 54 deletions(-) diff --git a/xdsl/parser_ng.py b/xdsl/parser_ng.py index fa7008c26c..a1f21ec922 100644 --- a/xdsl/parser_ng.py +++ b/xdsl/parser_ng.py @@ -581,7 +581,7 @@ def __init__( self.ctx = ctx self.ssaValues = dict() self.blocks = dict() - self.forward_block_references = set() + self.forward_block_references = dict() def begin_parse(self): ops = [] @@ -1325,59 +1325,6 @@ def must_parse_region_list(self) -> list[Region]: regions.append(self.must_parse_region()) return regions - # HERE STARTS A SOMEWHAT CURSED COMPATIBILITY LAYER: - # since we don't want to rewrite all dialects currently, the new emulator needs to expose the same - # interface to the dialect definitions. Here we implement that interface. - - _OperationType = TypeVar('_OperationType', bound=Operation) - - def parse_op_with_default_format( - self, - op_type: type[_OperationType], - result_types: list[Attribute], - skip_white_space: bool = True) -> _OperationType: - """ - Compatibility wrapper so the new parser can be passed instead of the old one. Parses everything after the - operation name. - - This implicitly assumes XDSL format, and will fail on MLIR style operations - """ - # TODO: remove this function and restructure custom op / irdl parsing - assert isinstance(self, XDSLParser) - args, successors, attributes, regions, _ = self.must_parse_operation_details() - - for x in args: - if x.text not in self.ssaValues: - self.raise_error( - "Unknown SSAValue name, known SSA Values are: {}".format( - ", ".join(self.ssaValues.keys())), x) - - return op_type.create( - operands=[self.ssaValues[span.text] for span in args], - result_types=result_types, - attributes=attributes, - successors=[self.get_block_from_name(span) for span in successors], - regions=regions) - - def parse_paramattr_parameters( - self, - expect_brackets: bool = False, - skip_white_space: bool = True) -> list[Attribute]: - if self.tokenizer.next_token_of_pattern( - '<') is None and expect_brackets: - self.raise_error("Expected start attribute parameters here (`<`)!") - - res = self.must_parse_list_of(self.try_parse_attribute, - 'Expected another attribute here!') - - if self.tokenizer.next_token_of_pattern( - '>') is None and expect_brackets: - self.raise_error( - "Malformed parameter list, expected either another parameter or `>`!" - ) - - return res - # COMMON xDSL/MLIR code: def must_parse_builtin_type_with_name(self, name: Span): if name.text == 'index': @@ -1436,6 +1383,67 @@ def must_parse_op_args_list(self) -> list[Span]: # TODO: check if type is correct here! return [name for name, _ in args] + # HERE STARTS A SOMEWHAT CURSED COMPATIBILITY LAYER: + # since we don't want to rewrite all dialects currently, the new emulator needs to expose the same + # interface to the dialect definitions. Here we implement that interface. + + _OperationType = TypeVar('_OperationType', bound=Operation) + + def parse_op_with_default_format( + self, + op_type: type[_OperationType], + result_types: list[Attribute], + skip_white_space: bool = True) -> _OperationType: + """ + Compatibility wrapper so the new parser can be passed instead of the old one. Parses everything after the + operation name. + + This implicitly assumes XDSL format, and will fail on MLIR style operations + """ + # TODO: remove this function and restructure custom op / irdl parsing + assert isinstance(self, XDSLParser) + args, successors, attributes, regions, _ = self.must_parse_operation_details() + + for x in args: + if x.text not in self.ssaValues: + self.raise_error( + "Unknown SSAValue name, known SSA Values are: {}".format( + ", ".join(self.ssaValues.keys())), x) + + return op_type.create( + operands=[self.ssaValues[span.text] for span in args], + result_types=result_types, + attributes=attributes, + successors=[self.get_block_from_name(span) for span in successors], + regions=regions) + + def parse_paramattr_parameters( + self, + expect_brackets: bool = False, + skip_white_space: bool = True) -> list[Attribute]: + if self.tokenizer.next_token_of_pattern( + '<') is None and expect_brackets: + self.raise_error("Expected start attribute parameters here (`<`)!") + + res = self.must_parse_list_of(self.try_parse_attribute, + 'Expected another attribute here!') + + if self.tokenizer.next_token_of_pattern( + '>') is None and expect_brackets: + self.raise_error( + "Malformed parameter list, expected either another parameter or `>`!" + ) + + return res + + def parse_char(self, text: str) -> Span: + self.must_parse_characters(text, "Expected {} here!".format(text)) + + def parse_str_literal(self) -> str: + return self.expect(self.try_parse_string_literal, 'Malformed string literal!').string_contents + + def parse_attribute(self) -> Attribute: + return self.must_parse_attribute() class MLIRParser(BaseParser): From 3a06393df3bd18264d3fab08e99b2d6ee0ec68fe Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Thu, 12 Jan 2023 10:09:55 +0000 Subject: [PATCH 21/65] parser: fixed a bug with expect_brackets in parse_paramattr_parameters --- xdsl/parser_ng.py | 66 ++++++++++++++++++++++++++++------------------- 1 file changed, 39 insertions(+), 27 deletions(-) diff --git a/xdsl/parser_ng.py b/xdsl/parser_ng.py index a1f21ec922..2c28783813 100644 --- a/xdsl/parser_ng.py +++ b/xdsl/parser_ng.py @@ -14,10 +14,10 @@ from typing import TypeVar, Iterable from xdsl.dialects.builtin import ( - AnyTensorType, AnyVectorType, - Float16Type, Float32Type, Float64Type, FloatAttr, - FunctionType, IndexType, IntegerType, Signedness, StringAttr, - IntegerAttr, ArrayAttr, TensorType, UnrankedTensorType, VectorType, DefaultIntegerAttrType, FlatSymbolRefAttr) + AnyTensorType, AnyVectorType, Float16Type, Float32Type, Float64Type, + FloatAttr, FunctionType, IndexType, IntegerType, Signedness, StringAttr, + IntegerAttr, ArrayAttr, TensorType, UnrankedTensorType, VectorType, + DefaultIntegerAttrType, FlatSymbolRefAttr) from xdsl.ir import (SSAValue, Block, Callable, Attribute, Operation, Region, BlockArgument, MLContext, ParametrizedAttribute) from .printer import Printer @@ -48,7 +48,13 @@ def print_with_history(self): class MultipleSpansParseError(ParseError): ref_text: str | None refs: list[tuple[Span, str]] - def __init__(self, span: Span, msg: str, ref_text: str, refs: list[tuple[Span, str | None]], history: BacktrackingHistory | None = None): + + def __init__(self, + span: Span, + msg: str, + ref_text: str, + refs: list[tuple[Span, str | None]], + history: BacktrackingHistory | None = None): super(MultipleSpansParseError, self).__init__(span, msg, history) self.refs = refs self.ref_text = ref_text @@ -161,7 +167,8 @@ def print_with_context(self, msg: str | None = None) -> str: return capture.getvalue() def __repr__(self): - return "{}[{}:{}](text='{}')".format(self.__class__.__name__, self.start, self.end, self.text) + return "{}[{}:{}](text='{}')".format(self.__class__.__name__, + self.start, self.end, self.text) @dataclass(frozen=True, repr=False) @@ -529,7 +536,8 @@ class ParserCommons: # TODO: add all the Float8E4M3FNType, Float8E5M2Type, and BFloat16Type ) builtin_type = re.compile('(({}))'.format(')|('.join(_builtin_type_names))) - builtin_type_xdsl = re.compile('!(({}))'.format(')|('.join(_builtin_type_names))) + builtin_type_xdsl = re.compile('!(({}))'.format( + ')|('.join(_builtin_type_names))) double_colon = re.compile('::') comma = re.compile(',') @@ -608,7 +616,8 @@ def must_parse_block(self) -> Block: if block_id is None: block = Block(self.tokenizer.last_token) - elif self.forward_block_references.pop(block_id.text, None) is not None: + elif self.forward_block_references.pop(block_id.text, + None) is not None: block = self.blocks[block_id.text] block.delcared_at = block_id else: @@ -618,8 +627,7 @@ def must_parse_block(self) -> Block: "Re-declaration of block {}".format(block_id.text), 'Originally declared here:', [(self.blocks[block_id.text].delcared_at, None)], - self.tokenizer.history - ) + self.tokenizer.history) block = Block(block_id) self.blocks[block_id.text] = block @@ -1083,17 +1091,19 @@ def must_parse_region(self) -> Region: while self.tokenizer.starts_with('^'): region.add_block(self.must_parse_block()) - end = self.must_parse_characters('}', - 'Reached end of region, expected `}`!') + end = self.must_parse_characters( + '}', 'Reached end of region, expected `}`!') if len(self.forward_block_references) > 0: raise MultipleSpansParseError( end, - "Region ends with missing block declarations for block(s) {}!".format(', '.join(self.forward_block_references.keys())), + "Region ends with missing block declarations for block(s) {}!" + .format(', '.join(self.forward_block_references.keys())), 'The following block references are dangling:', - [(span, "Reference to block \"{}\" without implementation!".format(span.text)) for span in itertools.chain(*self.forward_block_references.values())], - self.tokenizer.history - ) + [(span, "Reference to block \"{}\" without implementation!" + .format(span.text)) for span in itertools.chain( + *self.forward_block_references.values())], + self.tokenizer.history) return region finally: @@ -1157,8 +1167,7 @@ def try_parse_builtin_attr(self) -> Attribute: attrs = (self.try_parse_builtin_float_attr, self.try_parse_builtin_int_attr, self.try_parse_builtin_str_attr, - self.try_parse_builtin_arr_attr, - self.try_parse_function_type, + self.try_parse_builtin_arr_attr, self.try_parse_function_type, self.try_parse_ref_attr) for attr_parser in attrs: @@ -1232,7 +1241,7 @@ def try_parse_builtin_arr_attr(self) -> list[Attribute] | None: attrs = self.must_parse_list_of(self.try_parse_attribute, 'Expected array entry!') self.must_parse_characters( - ']', 'Array literals must be enclosed by square brackets!') + ']', 'Malformed array contents (expected end of array here!') return ArrayAttr.from_list(attrs) @abstractmethod @@ -1402,7 +1411,8 @@ def parse_op_with_default_format( """ # TODO: remove this function and restructure custom op / irdl parsing assert isinstance(self, XDSLParser) - args, successors, attributes, regions, _ = self.must_parse_operation_details() + args, successors, attributes, regions, _ = self.must_parse_operation_details( + ) for x in args: if x.text not in self.ssaValues: @@ -1421,30 +1431,32 @@ def parse_paramattr_parameters( self, expect_brackets: bool = False, skip_white_space: bool = True) -> list[Attribute]: - if self.tokenizer.next_token_of_pattern( - '<') is None and expect_brackets: + opening_brackets = self.tokenizer.next_token_of_pattern('<') + if expect_brackets and opening_brackets is None: self.raise_error("Expected start attribute parameters here (`<`)!") res = self.must_parse_list_of(self.try_parse_attribute, 'Expected another attribute here!') - if self.tokenizer.next_token_of_pattern( - '>') is None and expect_brackets: + if opening_brackets is not None and self.tokenizer.next_token_of_pattern( + '>') is None: self.raise_error( "Malformed parameter list, expected either another parameter or `>`!" ) return res - def parse_char(self, text: str) -> Span: - self.must_parse_characters(text, "Expected {} here!".format(text)) + def parse_char(self, text: str): + self.must_parse_characters(text, "Expected '{}' here!".format(text)) def parse_str_literal(self) -> str: - return self.expect(self.try_parse_string_literal, 'Malformed string literal!').string_contents + return self.expect(self.try_parse_string_literal, + 'Malformed string literal!').string_contents def parse_attribute(self) -> Attribute: return self.must_parse_attribute() + class MLIRParser(BaseParser): def try_parse_builtin_type(self) -> Attribute | None: From 7b94084dd5fd8aed2f10ade8788b709d5cbc7e6a Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Thu, 12 Jan 2023 18:00:58 +0000 Subject: [PATCH 22/65] xdsl: disallow invalid ssa value name hints in the class --- tests/test_ssa_value.py | 32 ++++++++++++++++++++++++++++++++ xdsl/ir.py | 12 +++++++++++- 2 files changed, 43 insertions(+), 1 deletion(-) create mode 100644 tests/test_ssa_value.py diff --git a/tests/test_ssa_value.py b/tests/test_ssa_value.py new file mode 100644 index 0000000000..2847a49233 --- /dev/null +++ b/tests/test_ssa_value.py @@ -0,0 +1,32 @@ +from io import StringIO +from typing import Callable + +import pytest + +from xdsl.dialects.arith import Arith, Constant, Addi +from xdsl.dialects.builtin import ModuleOp, Builtin, i32 +from xdsl.dialects.scf import Scf, Yield +from xdsl.dialects.func import Func +from xdsl.ir import MLContext, Block, SSAValue, OpResult, BlockArgument +from xdsl.parser import Parser +from xdsl.printer import Printer +from xdsl.rewriter import Rewriter + + +@pytest.mark.parametrize("name,result", [ + ('a', 'a'), + ('test', 'test'), + ('test1', None), + ('1', None), +]) +def test_ssa_value_name_hints(name, result): + """ + The rewriter assumes, that ssa value name hints (their .name field) does not end in a numeric value. If it does, + it will generate broken rewrites that potentially assign twice to an SSA value. + + Therefore, the SSAValue class prevents the setting of names ending in a number. + """ + val = BlockArgument(i32, Block(), 0) + + val.name = name + assert val.name == result diff --git a/xdsl/ir.py b/xdsl/ir.py index d5440e30c1..07370288e1 100644 --- a/xdsl/ir.py +++ b/xdsl/ir.py @@ -118,7 +118,17 @@ class SSAValue(ABC): uses: set[Use] = field(init=False, default_factory=set, repr=False) """All uses of the value.""" - name: str | None = field(init=False, default=None) + _name: str | None = field(init=False, default=None) + + @property + def name(self) -> str | None: + return self._name + + @name.setter + def name(self, name: str): + if name[-1].isnumeric(): + return + self._name = name @staticmethod def get(arg: SSAValue | Operation) -> SSAValue: From ed3a997d529e666bf4561c03eeb262e4dbdc51e2 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Fri, 13 Jan 2023 11:43:20 +0000 Subject: [PATCH 23/65] parser: fixing tests to work with the new parser --- tests/test_mlir_printer.py | 18 +- tests/test_parser.py | 4 +- tests/test_parser_error.py | 21 +- tests/test_pattern_rewriter.py | 2 + tests/test_printer.py | 35 +- xdsl/dialects/builtin.py | 4 +- xdsl/parser.py | 2792 ++++++++++++++++++-------------- xdsl/parser_ng.py | 1707 ------------------- 8 files changed, 1649 insertions(+), 2934 deletions(-) delete mode 100644 xdsl/parser_ng.py diff --git a/tests/test_mlir_printer.py b/tests/test_mlir_printer.py index 38aed4e260..20f86e0f25 100644 --- a/tests/test_mlir_printer.py +++ b/tests/test_mlir_printer.py @@ -1,14 +1,11 @@ +import re from io import StringIO from typing import Annotated -import re -from xdsl.dialects.builtin import Builtin -from xdsl.dialects.memref import MemRef -from xdsl.dialects.func import Func from xdsl.ir import Attribute, Data, MLContext, MLIRType, Operation, ParametrizedAttribute -from xdsl.irdl import (AnyAttr, ParameterDef, RegionDef, VarOpResult, - VarOperand, irdl_attr_definition, irdl_op_definition) -from xdsl.parser import Parser +from xdsl.irdl import (AnyAttr, ParameterDef, RegionDef, irdl_attr_definition, irdl_op_definition, VarOperand, + VarOpResult) +from xdsl.parser import Parser, ParseError from xdsl.printer import Printer @@ -93,7 +90,12 @@ def print_as_mlir_and_compare(test_prog: str, expected: str): ctx.register_attr(ParamAttrWithCustomFormat) parser = Parser(ctx, test_prog) - module = parser.parse_op() + try: + module = parser.parse_op() + except ParseError as err: + io = StringIO() + err.print_with_history(file=io) + raise ParseError(err.span, io.getvalue(), None) res = StringIO() printer = Printer(target=Printer.Target.MLIR, stream=res) diff --git a/tests/test_parser.py b/tests/test_parser.py index 8e3c90707d..3322fac40e 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -11,8 +11,8 @@ def test_int_list_parser(input: str, expected: list[int]): ctx = MLContext() parser = Parser(ctx, input) - int_list = parser.parse_list(parser.parse_int_literal) - assert int_list == expected + int_list = parser.must_parse_list_of(parser.try_parse_integer_literal, '') + assert [int(span.text) for span in int_list] == expected @pytest.mark.parametrize("input,expected", [('{"A"=0, "B"=1, "C"=2}', { diff --git a/tests/test_parser_error.py b/tests/test_parser_error.py index 46dab27b1e..4f853f86a8 100644 --- a/tests/test_parser_error.py +++ b/tests/test_parser_error.py @@ -1,8 +1,8 @@ from __future__ import annotations from typing import Annotated -from xdsl.ir import MLContext -from xdsl.irdl import AnyAttr, VarOpResult, VarOperand, irdl_op_definition, Operation +from xdsl.ir import MLContext, OpResult, SSAValue +from xdsl.irdl import AnyAttr, VarOperandDef, VarResultDef, irdl_op_definition, Operation from xdsl.parser import Parser, ParserError from pytest import raises @@ -19,13 +19,12 @@ def check_error(prog: str, line: int, column: int, message: str): ctx.register_op(UnkownOp) parser = Parser(ctx, prog) - with raises(ParserError) as e: - parser.parse_op() + with raises(ParseError) as e: + parser.must_parse_operation() - assert e.value.pos - assert e.value.pos.line is line - assert e.value.pos.column is column - assert e.value.message == message + assert e.value.span + assert e.value.span.get_line_col() == (line, column) + assert any(message in ex.error.msg for ex in e.value.history.iterate()) def test_parser_missing_equal(): @@ -39,7 +38,7 @@ def test_parser_missing_equal(): %0 : !i32 unknown() } """ - check_error(prog, 3, 13, "'=' expected, got 'u'") + check_error(prog, 3, 13, "Operation definitions expect an `=` after op-result-list!") def test_parser_redefined_value(): @@ -54,7 +53,7 @@ def test_parser_redefined_value(): %val : !i32 = unknown() } """ - check_error(prog, 4, 3, "SSA value val is already defined") + check_error(prog, 4, 2, "SSA value %val is already defined") def test_parser_missing_operation_name(): @@ -68,7 +67,7 @@ def test_parser_missing_operation_name(): %val : !i32 = } """ - check_error(prog, 4, 1, "operation name expected") + check_error(prog, 3, 13, "Expected an operation name here") def test_parser_missing_attribute(): diff --git a/tests/test_pattern_rewriter.py b/tests/test_pattern_rewriter.py index 510e0929e9..d34ba36726 100644 --- a/tests/test_pattern_rewriter.py +++ b/tests/test_pattern_rewriter.py @@ -22,6 +22,8 @@ def rewrite_and_compare(prog: str, expected_prog: str, parser = Parser(ctx, prog) module = parser.parse_op() + assert isinstance(module, ModuleOp) + walker.rewrite_module(module) file = StringIO("") printer = Printer(stream=file) diff --git a/tests/test_printer.py b/tests/test_printer.py index 93f2ec7e5a..487f7d27d6 100644 --- a/tests/test_printer.py +++ b/tests/test_printer.py @@ -1,17 +1,17 @@ from __future__ import annotations +import re from io import StringIO from typing import List, Annotated -from xdsl.dialects.func import Func, FuncOp -from xdsl.dialects.builtin import Builtin, IntAttr, ModuleOp, IntegerType, UnitAttr from xdsl.dialects.arith import Arith, Addi, Constant - -from xdsl.ir import Attribute, MLContext, OpResult, ParametrizedAttribute +from xdsl.dialects.builtin import Builtin, IntAttr, ModuleOp, IntegerType, UnitAttr +from xdsl.dialects.func import Func +from xdsl.ir import Attribute, MLContext, OpResult, ParametrizedAttribute, SSAValue from xdsl.irdl import (ParameterDef, irdl_attr_definition, irdl_op_definition, Operation, Operand, OptAttributeDef) +from xdsl.parser import Parser, BaseParser, Span from xdsl.printer import Printer -from xdsl.parser import Parser from xdsl.utils.diagnostic import Diagnostic @@ -382,12 +382,17 @@ class PlusCustomFormatOp(Operation): @classmethod def parse(cls, result_types: List[Attribute], - parser: Parser) -> PlusCustomFormatOp: - lhs = parser.parse_ssa_value() - parser.skip_white_space() + parser: BaseParser) -> PlusCustomFormatOp: + def get_ssa_val(name: Span) -> SSAValue: + if name.text not in parser.ssaValues: + parser.raise_error('Unknown SSA Value name', name) + return parser.ssaValues[name.text] + + lhs = parser.expect(parser.try_parse_value_id, 'Expected SSA Value name here!') parser.parse_char("+") - rhs = parser.parse_ssa_value() - return PlusCustomFormatOp.create(operands=[lhs, rhs], + rhs = parser.expect(parser.try_parse_value_id, 'Expected SSA Value name here!') + + return PlusCustomFormatOp.create(operands=[get_ssa_val(name) for name in (lhs, rhs)], result_types=result_types) def print(self, printer: Printer): @@ -494,13 +499,13 @@ class CustomFormatAttr(ParametrizedAttribute): attr: ParameterDef[IntAttr] @staticmethod - def parse_parameters(parser: Parser) -> list[Attribute]: + def parse_parameters(parser: BaseParser) -> list[Attribute]: parser.parse_char("<") - value = parser.parse_alpha_num(skip_white_space=False) - if value == "zero": + value = parser.tokenizer.next_token_of_pattern(re.compile('(zero|one)')) + if value and value.text == "zero": parser.parse_char(">") return [IntAttr.from_int(0)] - if value == "one": + if value and value.text == "one": parser.parse_char(">") return [IntAttr.from_int(1)] assert False @@ -550,7 +555,7 @@ def test_parse_generic_format_attr(): """ prog = \ """builtin.module() { - any() ["attr" = !"custom">] + any() ["attr" = #"custom"<#int<0>>] }""" expected = \ diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index 08cfef100f..a35520c783 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -16,7 +16,7 @@ from xdsl.utils.exceptions import VerifyException if TYPE_CHECKING: - from xdsl.parser import Parser, ParserError + from xdsl.parser import Parser, ParseError from xdsl.printer import Printer @@ -117,7 +117,7 @@ def parse_parameter(parser: Parser) -> Signedness: return Signedness.SIGNED elif parser.parse_optional_string("unsigned") is not None: return Signedness.UNSIGNED - raise ParserError(parser.get_pos(), "Expected signedness") + raise ParseError(parser.get_pos(), "Expected signedness") @staticmethod def print_parameter(data: Signedness, printer: Printer) -> None: diff --git a/xdsl/parser.py b/xdsl/parser.py index ad01f0ebb4..fc8300f0fb 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -1,1336 +1,1750 @@ from __future__ import annotations +import ast +import contextlib +import functools +import itertools +import re +import sys +import traceback +from abc import ABC, abstractmethod +from collections import defaultdict from dataclasses import dataclass, field from enum import Enum -from typing import Any, TypeVar +from io import StringIO +from typing import TypeVar, Iterable -from xdsl.dialects.memref import MemRefType, UnrankedMemrefType +from xdsl.dialects.builtin import ( + AnyTensorType, AnyVectorType, Float16Type, Float32Type, Float64Type, + FloatAttr, FunctionType, IndexType, IntegerType, Signedness, StringAttr, + IntegerAttr, ArrayAttr, TensorType, UnrankedTensorType, VectorType, + DefaultIntegerAttrType, FlatSymbolRefAttr) from xdsl.ir import (SSAValue, Block, Callable, Attribute, Operation, Region, - BlockArgument, MLContext, ParametrizedAttribute) + BlockArgument, MLContext, ParametrizedAttribute, Data) +from .printer import Printer + + +class ParseError(Exception): + span: Span + msg: str + history: BacktrackingHistory | None + + def __init__(self, + span: Span, + msg: str, + history: BacktrackingHistory | None = None): + preamble = "" + if history: + io = StringIO() + history.print_unroll(io) + preamble = io.getvalue() + '\n' + + super().__init__(preamble + span.print_with_context(msg)) + self.span = span + self.msg = msg + self.history = history + + def print_pretty(self, file=sys.stderr): + print(self.span.print_with_context(self.msg), file=file) + + def print_with_history(self, file=sys.stderr): + if self.history is not None: + for h in sorted(self.history.iterate(), key=lambda h: -h.pos): + h.print() + else: + self.print_pretty(file) + + def __repr__(self): + io = StringIO() + self.print_with_history(io) + return "{}:\n{}".format( + self.__class__.__name__, + io.getvalue() + ) + + +class MultipleSpansParseError(ParseError): + ref_text: str | None + refs: list[tuple[Span, str]] + + def __init__(self, + span: Span, + msg: str, + ref_text: str, + refs: list[tuple[Span, str | None]], + history: BacktrackingHistory | None = None): + super(MultipleSpansParseError, self).__init__(span, msg, history) + self.refs = refs + self.ref_text = ref_text + + def print_pretty(self, file=sys.stderr): + super(MultipleSpansParseError, self).print_pretty(file) + print(self.ref_text or "With respect to:", file=file) + for span, msg in self.refs: + print(span.print_with_context(msg), file=file) -from xdsl.dialects.builtin import ( - AnyFloat, AnyTensorType, AnyUnrankedTensorType, AnyVectorType, - DenseIntOrFPElementsAttr, Float16Type, Float32Type, Float64Type, FloatAttr, - FunctionType, IndexType, IntegerType, OpaqueAttr, Signedness, StringAttr, - FlatSymbolRefAttr, IntegerAttr, ArrayAttr, TensorType, UnitAttr, - UnrankedTensorType, UnregisteredOp, VectorType, DictionaryAttr) -from xdsl.irdl import Data -indentNumSpaces = 2 +@dataclass +class BacktrackingHistory: + error: ParseError + parent: BacktrackingHistory | None + region_name: str | None + pos: int + + def print_unroll(self, file=sys.stderr): + if self.parent: + if self.parent.get_farthest_point() > self.pos: + self.parent.print_unroll(file) + self.print(file) + else: + self.print(file) + self.parent.print_unroll(file) + def print(self, file=sys.stderr): + print("Parsing of {} failed:".format(self.region_name or ''), + file=file) + self.error.print_pretty(file=file) -@dataclass(frozen=True) -class Position: - """A position in a file""" + @functools.cache + def get_farthest_point(self) -> int: + """ + Find the farthest this history managed to parse + """ + if self.parent: + return max(self.pos, self.parent.get_farthest_point()) + return self.pos + + def iterate(self) -> Iterable[BacktrackingHistory]: + yield self + if self.parent: + yield from self.parent.iterate() + + def __hash__(self): + return id(self) + + +class BacktrackingAbort(Exception): + reason: str | None - file: str + def __init__(self, reason: str | None = None): + super().__init__( + "This message should never escape the parser, it's intended to signal a failed parsing " + "attempt\n " + "It should never be used outside of a tokenizer.backtracking() block!\n" + "The reason for this abort was {}".format( + 'not specified' if reason is None else reason)) + self.reason = reason + + +@dataclass(frozen=True) +class Span: """ - A handle to the file contents. The position is relative to this file. + Parts of the input are always passed around as spans, so we know where they originated. """ - idx: int = field(default=0) + start: int + """ + Start of tokens location in source file, global byte offset in file + """ + end: int + """ + End of tokens location in source file, global byte offset in file """ - The character index in the entire file. - A line break is consider to be a character here. + input: Input """ + The input being operated on + """ + + def __len__(self): + return self.len - line: int = field(default=1) - """The line index.""" + @property + def len(self): + return self.end - self.start - column: int = field(default=1) - """The character index in the current line.""" + @property + def text(self): + return self.input.content[self.start:self.end] - def __str__(self): - return f"{self.line}:{self.column}" + def get_line_col(self) -> tuple[int, int]: + info = self.input.get_lines_containing(self) + if info is None: + return -1, -1 + lines, offset_of_first_line, line_no = info + return line_no, self.start - offset_of_first_line - def next_char_pos(self, n: int = 1) -> Position | None: - """Return the position of the next character in the string.""" - if self.idx >= len(self.file) - n: + def print_with_context(self, msg: str | None = None) -> str: + """ + returns a string containing lines relevant to the span. The Span's contents + are highlighted by up-carets beneath them (`^`). The message msg is printed + along these. + """ + info = self.input.get_lines_containing(self) + if info is None: + return "Unknown location of span {}. Error: ".format(self, msg) + lines, offset_of_first_line, line_no = info + # offset relative to the first line: + offset = self.start - offset_of_first_line + remaining_len = max(self.len, 1) + capture = StringIO() + print("{}:{}:{}".format(self.input.name, line_no, offset), file=capture) + for line in lines: + print(line, file=capture) + if remaining_len < 0: + continue + len_on_this_line = min(remaining_len, len(line) - offset) + remaining_len -= len_on_this_line + print("{}{}".format(" " * offset, "^" * max(len_on_this_line, 1)), + file=capture) + if msg is not None: + print("{}{}".format(" " * offset, msg), file=capture) + msg = None + offset = 0 + if msg is not None: + print(msg, file=capture) + return capture.getvalue() + + def __repr__(self): + return "{}[{}:{}](text='{}')".format(self.__class__.__name__, + self.start, self.end, self.text) + + +@dataclass(frozen=True, repr=False) +class StringLiteral(Span): + + def __post_init__(self): + if len(self) < 2 or self.text[0] != '"' or self.text[-1] != '"': + raise ParseError(self, "Invalid string literal!") + + @classmethod + def from_span(cls, span: Span | None) -> StringLiteral | None: + if span is None: return None - new_idx = self.idx - new_line = self.line - new_column = self.column - while n > 0: - if self.file[new_idx] == '\n': - new_line += 1 - new_column = 1 - else: - new_column += 1 - new_idx += 1 - n -= 1 - assert new_idx < len(self.file) - return Position(self.file, new_idx, new_line, new_column) - - def get_char(self) -> str: - """Return the character at the current position.""" - assert self.idx < len(self.file) - return self.file[self.idx] - - def get_current_line(self) -> str: - """Return the current line.""" - assert self.idx < len(self.file) - start_idx = self.idx - self.column + 1 - end_idx = self.idx - while self.file[end_idx] != '\n': - end_idx += 1 - return self.file[start_idx:end_idx] + return cls(span.start, span.end, span.input) + @property + def string_contents(self): + # TODO: is this a hack-job? + return ast.literal_eval(self.text) -@dataclass -class ParserError(Exception): - """An error triggered during parsing.""" - pos: Position | None - message: str +@dataclass(frozen=True) +class Input: + """ + This is a very simple class that is used to keep track of the input. + """ + content: str = field(repr=False) + name: str + + @property + def len(self): + return len(self.content) + + def __len__(self): + return self.len + + def get_nth_line_bounds(self, n: int): + start = 0 + for i in range(n): + next_start = self.content.find('\n', start) + if next_start == -1: + return None + start = next_start + 1 + return start, self.content.find('\n', start) + + def get_lines_containing(self, + span: Span) -> tuple[list[str], int, int] | None: + # A pointer to the start of the first line + start = 0 + line_no = 0 + source = self.content + while True: + next_start = source.find('\n', start) + line_no += 1 + # handle eof + if next_start == -1: + if span.start > len(source): + return None + return [source[start:]], start, line_no + # as long as the next newline comes before the spans start we can continue + if next_start < span.start: + start = next_start + 1 + continue + # if the whole span is on one line, we are good as well + if next_start >= span.end: + return [source[start:next_start]], start, line_no + while next_start < span.end: + next_start = source.find('\n', next_start + 1) + return source[start:next_start].split('\n'), start, line_no - def __str__(self): - if self.pos is None: - return f"Parsing error at end of file :{self.message}\n" - message = f"Parsing error at {self.pos}:\n" - message += self.pos.get_current_line() + '\n' - message += " " * (self.pos.column - 1) + "^\n" - message += self.message + '\n' - return message + def at(self, i: int): + if i >= self.len: + raise EOFError() + return self.content[i] + + +save_t = tuple[int, tuple[str, ...]] @dataclass -class Parser: +class Tokenizer: + input: Input - class Source(Enum): - XDSL = 1 - MLIR = 2 + pos: int = field(init=False, default=0) + """ + The position in the input. Points to the first unconsumed character. + """ - ctx: MLContext - """xDSL context.""" + break_on: tuple[str, ...] = ('.', '%', ' ', '(', ')', '[', ']', '{', '}', + '<', '>', ':', '=', '@', '?', '|', '->', '-', + '//', '\n', '\t', '#', '"', "'", ',', '!') + """ + characters the tokenizer should break on + """ - str: str - """The current file/input to parse.""" + history: BacktrackingHistory | None = field(init=False, + default=None, + repr=False) - source: Source = field(default=Source.XDSL, kw_only=True) - """The source language to parse.""" + last_token: Span | None = field(init=False, default=None, repr=False) - allow_unregistered_ops: bool = field(default=False, kw_only=True) - """Allow the parsing of unregistered ops.""" + def save(self) -> save_t: + """ + Create a checkpoint in the parsing process, useful for backtracking + """ + return self.pos, self.break_on - _pos: Position | None = field(init=False) - """Position in the file. None represent the end of the file.""" + def resume_from(self, save: save_t): + """ + Resume from a previously saved position. - _ssaValues: dict[str, SSAValue] = field(init=False, default_factory=dict) - """Associate SSA values with their names.""" + Restores the state of the tokenizer to the exact previous position + """ + self.pos, self.break_on = save - _blocks: dict[str, Block] = field(init=False, default_factory=dict) - """Associate blocks with their names.""" + @contextlib.contextmanager + def backtracking(self, region_name: str | None = None): + """ + This context manager can be used to mark backtracking regions. - def __post_init__(self): - if len(self.str) == 0: - self._pos = None - else: - self._pos = Position(self.str) - - def get_pos(self) -> Position | None: - """Return the current position.""" - return self._pos - - def get_char(self, - n: int = 1, - skip_white_space: bool = True) -> str | None: - """Get the next n characters (including the current one)""" - assert n >= 0 - if skip_white_space: - self.skip_white_space() - if self._pos is None: + When an error is thrown during backtracking, it is recorded and stored together + with some meta information in the history attribute. + + The backtracker accepts the following exceptions: + - ParseError: signifies that the region could not be parsed because of (unexpected) syntax errors + - BacktrackingAbort: signifies that backtracking was aborted, not necessarily indicating a syntax error + - AssertionError: this error should probably be phased out in favour of the two above + - EOFError: signals that EOF was reached unexpectedly + + Any other error will be printed to stderr, but backtracking will continue as normal. + """ + save = self.save() + starting_position = self.pos + try: + yield + # clear error history when something doesn't fail + # this is because we are only interested in the last "cascade" of failures. + # if a backtracking() completes without failre, something has been parsed (we assume) + if self.pos > starting_position and self.history is not None: + self.history = None + except Exception as ex: + how_far_we_got = self.pos + + # AssertionErrors act upon the consumed token, this means we only go to the start of the token + if isinstance(ex, BacktrackingAbort): + # TODO: skip space as well + how_far_we_got -= self.last_token.len + + # if we have no error history, start recording! + if not self.history: + self.history = self.history_entry_from_exception( + ex, region_name, how_far_we_got) + + # if we got further than on previous attempts + elif how_far_we_got > self.history.get_farthest_point(): + # throw away history + self.history = None + # generate new history entry, + self.history = self.history_entry_from_exception( + ex, region_name, how_far_we_got) + + # otherwise, add to exception, if we are in a named region + elif region_name is not None and how_far_we_got - starting_position > 0: + self.history = self.history_entry_from_exception( + ex, region_name, how_far_we_got) + + self.resume_from(save) + + def history_entry_from_exception(self, ex: Exception, region: str, + pos: int) -> BacktrackingHistory: + """ + Given an exception generated inside a backtracking attempt, + generate a BacktrackingHistory object with the relevant information in it. + + If an unexpected exception type is encountered, print a traceback to stderr + """ + if isinstance(ex, ParseError): + return BacktrackingHistory(ex, self.history, region, pos) + elif isinstance(ex, AssertionError): + reason = [ + 'Generic assertion failure', + *(reason for reason in ex.args if isinstance(reason, str)) + ] + # we assume that assertions fail because of the last read-in token + if len(reason) == 1: + tb = StringIO() + traceback.print_exc(file=tb) + reason[0] += '\n' + tb.getvalue() + + return BacktrackingHistory( + ParseError(self.last_token, reason[-1], self.history), + self.history, region, pos) + elif isinstance(ex, BacktrackingAbort): + return BacktrackingHistory( + ParseError( + self.next_token(peek=True), + 'Backtracking aborted: {}'.format(ex.reason + or 'unknown reason'), + self.history), self.history, region, pos) + elif isinstance(ex, EOFError): + return BacktrackingHistory( + ParseError(self.last_token, "Encountered EOF", self.history), + self.history, region, pos) + + print("Warning: Unexpected error in backtracking:", file=sys.stderr) + traceback.print_exception(ex, file=sys.stderr) + + return BacktrackingHistory( + ParseError(self.last_token, "Unexpected exception: {}".format(ex), + self.history), self.history, region, pos) + + def next_token(self, start: int | None = None, peek: bool = False) -> Span: + """ + Return a Span of the next token, according to the self.break_on rules. + + Can be modified using: + + - start: don't start at the current tokenizer position, instead start here (useful for skipping comments, etc) + - peek: don't advance the position, only "peek" at the input + + This will skip over line comments. Meaning it will skip the entire line if it encounters '//' + """ + i = self.next_pos(start) + # construct the span: + span = Span(i, self._find_token_end(i), self.input) + # advance pointer if not peeking + if not peek: + self.pos = span.end + + # save last token + self.last_token = span + return span + + def next_token_of_pattern(self, + pattern: re.Pattern | str, + peek: bool = False) -> Span | None: + """ + Return a span that matched the pattern, or nothing. You can choose not to consume the span. + """ + try: + start = self.next_pos() + except EOFError: return None - if self._pos.idx + n > len(self.str): + + # handle search for string literal + if isinstance(pattern, str): + if self.starts_with(pattern): + if not peek: + self.pos = start + len(pattern) + return Span(start, start + len(pattern), self.input) + return None + + # handle regex logic + match = pattern.match(self.input.content, start) + if match is None: return None - return self.str[self._pos.idx:self._pos.idx + n] - _T = TypeVar("_T") + if not peek: + self.pos = match.end() + + # save last token + self.last_token = Span(start, match.end(), self.input) + return self.last_token + + def consume_peeked(self, peeked_span: Span): + if peeked_span.start != self.next_pos(): + raise ParseError(peeked_span, "This is not the peeked span!") + self.pos = peeked_span.end + + def _find_token_end(self, start: int | None = None) -> int: + """ + Find the point (optionally starting from start) where the token ends + """ + i = self.next_pos() if start is None else start + # search for literal breaks + for part in self.break_on: + if self.input.content.startswith(part, i): + return i + len(part) + # otherwise return the start of the next break + return min( + filter(lambda x: x >= 0, (self.input.content.find(part, i) + for part in self.break_on))) + + def next_pos(self, i: int | None = None) -> int: + """ + Find the next starting position (optionally starting from i) + + This will skip line comments! + """ + i = self.pos if i is None else i + # skip whitespaces + while self.input.at(i).isspace(): + i += 1 + + # skip comments as well + if self.input.content.startswith('//', i): + i = self.input.content.find('\n', i) + 1 + return self.next_pos(i) + + return i - def try_parse(self, - parse_fn: Callable[[], _T | None], - skip_white_space: bool = True) -> _T | None: + def is_eof(self): """ - Wrap a parsing function. If the parsing fails, then return without - any change to the current position. + Check if the end of the input was reached. """ - if skip_white_space: - self.skip_white_space() - start_pos = self._pos try: - return parse_fn() - except ParserError: - pass - self._pos = start_pos - return None + self.next_pos() + except EOFError: + return True - def skip_white_space(self) -> None: - while pos := self._pos: - char = pos.get_char() - if char.isspace(): - self._pos = pos.next_char_pos() - elif self.get_char(2, skip_white_space=False) == "//": - self.parse_while(lambda x: x != '\n', False) - else: - return - - def parse_while(self, - cond: Callable[[str], bool], - skip_white_space: bool = True) -> str: - if skip_white_space: - self.skip_white_space() - start_pos = self._pos - if start_pos is None: - return "" - while self._pos: - char = self._pos.get_char() - if not cond(char): - return self.str[start_pos.idx:self._pos.idx] - self._pos = self._pos.next_char_pos() - return self.str[start_pos.idx:] - - # TODO why two different functions, no nums in ident? - def parse_optional_ident(self, - skip_white_space: bool = True) -> str | None: - res = self.parse_while(lambda x: x.isalpha() or x == "_" or x == ".", - skip_white_space=skip_white_space) - if len(res) == 0: - return None - return res + @contextlib.contextmanager + def configured(self, break_on: tuple[str, ...]): + """ + This is a helper class to allow expressing a temporary change in config, allowing you to write: - def parse_ident(self, skip_white_space: bool = True) -> str: - res = self.parse_optional_ident(skip_white_space=skip_white_space) - if res is None: - raise ParserError(self._pos, "ident expected") - return res + # parsing double-quoted string now + string_content = "" + with tokenizer.configured(break_on=('"', '\\'),): + # use tokenizer - def parse_optional_alpha_num(self, - skip_white_space: bool = True) -> str | None: - res = self.parse_while(lambda x: x.isalnum() or x == "_" or x == ".", - skip_white_space=skip_white_space) - if len(res) == 0: - return None - return res + # now old config is restored automatically - def parse_alpha_num(self, skip_white_space: bool = True) -> str: - res = self.parse_optional_alpha_num(skip_white_space=skip_white_space) - if res is None: - raise ParserError(self._pos, "alphanum expected") - return res + """ + save = self.save() - def parse_optional_str_literal(self, - skip_white_space: bool = True - ) -> str | None: - parsed = self.parse_optional_char('"', - skip_white_space=skip_white_space) - if parsed is None: - return None - start_pos = self._pos - if start_pos is None: - raise ParserError(None, "Unexpected end of file") - while self._pos: - pos = self._pos - char = pos.get_char() - if char == '\\': - if next_pos := pos.next_char_pos(): - escaped = next_pos.get_char() - if escaped in ['\\', 'n', 't', 'r', '"']: - self._pos = next_pos.next_char_pos() - continue - else: - raise ParserError( - next_pos, - f"Unrecognized escaped character: \\{escaped}") - else: - raise ParserError(None, "Unexpected end of file") - elif char == '"': - break - self._pos = pos.next_char_pos() - if self._pos is None: - res = self.str[start_pos.idx:] + if break_on is not None: + self.break_on = break_on + + try: + yield self + finally: + self.break_on = save[1] + + def starts_with(self, text: str | re.Pattern) -> bool: + try: + start = self.next_pos() + if isinstance(text, re.Pattern): + return text.match(self.input.content, start) is None + return self.input.content.startswith(text, start) + except EOFError: + return False + + +class ParserCommons: + """ + Colelction of common things used in parsing MLIR/IRDL + + """ + integer_literal = re.compile(r'[+-]?([0-9]+|0x[0-9A-Fa-f]+)') + decimal_literal = re.compile(r'[+-]?([1-9][0-9]*)') + string_literal = re.compile(r'"([^\n\f\v\r"]|\\[nfvr"])+"') + float_literal = re.compile(r'[-+]?[0-9]+\.[0-9]*([eE][-+]?[0-9]+)?') + bare_id = re.compile(r'[A-Za-z_][\w$.]+') + value_id = re.compile(r'%([0-9]+|([A-Za-z_$.-][\w$.-]*))') + suffix_id = re.compile(r'([0-9]+|([A-Za-z_$.-][\w$.-]*))') + block_id = re.compile(r'\^([0-9]+|([A-Za-z_$.-][\w$.-]*))') + type_alias = re.compile(r'![A-Za-z_][\w$.]+') + attribute_alias = re.compile(r'#[A-Za-z_][\w$.]+') + boolean_literal = re.compile(r'(true|false)') + # a list of + _builtin_type_names = ( + r'[su]?i\d+', + r'f\d+', + 'tensor', + 'vector', + 'memref', + 'complex', + 'opaque', + 'tuple', + 'index', + # TODO: add all the Float8E4M3FNType, Float8E5M2Type, and BFloat16Type + ) + builtin_type = re.compile('(({}))'.format(')|('.join(_builtin_type_names))) + builtin_type_xdsl = re.compile('!(({}))'.format( + ')|('.join(_builtin_type_names))) + double_colon = re.compile('::') + comma = re.compile(',') + + +class BaseParser(ABC): + """ + Basic recursive descent parser. + + methods marked try_... will attempt to parse, and return None if they failed. If they return None + they must make sure to restore all state. + + methods marked must_... will do greedy parsing, meaning they consume as much as they can. They will + also throw an error if the think they should still be parsing. e.g. when parsing a list of numbers + separated by '::', the following input will trigger an exception: + 1::2:: + Due to the '::' present after the last element. This is useful for parsing lists, as a trailing + separator is usually considered a syntax error there. + + You can turn a try_ into a must_ by using expect(try_parse_..., error_msg) + + You can turn a must_ into a try_ by wrapping it in tokenizer.backtracking() + + must_ type parsers are preferred because they are explicit about their failure modes. + """ + + ctx: MLContext + """xDSL context.""" + + ssaValues: dict[str, SSAValue] + blocks: dict[str, Block] + forward_block_references: dict[str, list[Span]] + """ + Blocks we encountered references to before the definition (must be empty after parsing of region completes) + """ + + T_ = TypeVar('T_') + """ + Type var used for handling function that return single or multiple Spans. Basically the output type + of all try_parse functions is T_ | None + """ + + def __init__( + self, + ctx: MLContext, + input: str, + name: str, + ): + self.tokenizer = Tokenizer(Input(input, name)) + self.ctx = ctx + self.ssaValues = dict() + self.blocks = dict() + self.forward_block_references = dict() + + def begin_parse(self): + ops = [] + while (op := self.try_parse_operation()) is not None: + ops.append(op) + if not self.tokenizer.is_eof(): + self.raise_error("Could not parse entire input!") + return ops + + def get_block_from_name(self, block_name: Span): + """ + This function takes a span containing a block id (like `^42`) and returns a block. + + If the block defintion was not seen yet, we create a forward declaration. + """ + name = block_name.text + if name not in self.blocks: + self.forward_block_references[name].append(block_name) + self.blocks[name] = Block() + return self.blocks[name] + + def must_parse_block(self) -> Block: + block_id, args = self.must_parse_optional_block_label() + + if block_id is None: + block = Block(self.tokenizer.last_token) + elif self.forward_block_references.pop(block_id.text, + None) is not None: + block = self.blocks[block_id.text] + block.delcared_at = block_id else: - res = self.str[start_pos.idx:self._pos.idx] - self.parse_char('"') - return res + if block_id.text in self.blocks: + raise MultipleSpansParseError( + block_id, + "Re-declaration of block {}".format(block_id.text), + 'Originally declared here:', + [(self.blocks[block_id.text].delcared_at, None)], + self.tokenizer.history) + block = Block(block_id) + self.blocks[block_id.text] = block + + for i, (name, type) in enumerate(args): + arg = BlockArgument(type, block, i) + self.ssaValues[name.text] = arg + block.args.append(arg) + + while (next_op := self.try_parse_operation()) is not None: + block.add_op(next_op) - def parse_str_literal(self, skip_white_space: bool = True) -> str: - res = self.parse_optional_str_literal( - skip_white_space=skip_white_space) - if res is None: - raise ParserError(self._pos, "string literal expected") - return res + return block - def parse_optional_int_literal(self, - skip_white_space: bool = True - ) -> int | None: - is_negative = self.parse_optional_char( - "-", skip_white_space=skip_white_space) - res = self.parse_while(lambda char: char.isnumeric(), - skip_white_space=False) - if len(res) == 0: - if is_negative: - raise ParserError(self._pos, "int literal expected") - return None - return int(res) if is_negative is None else -int(res) + def must_parse_optional_block_label( + self) -> tuple[Span | None, list[tuple[Span, Attribute]]]: + block_id = self.try_parse_block_id() + arg_list = list() - def parse_int_literal(self, skip_white_space: bool = True) -> int: - res = self.parse_optional_int_literal( - skip_white_space=skip_white_space) - if res is None: - raise ParserError(self._pos, "int literal expected") - return res + if block_id is not None: + if self.tokenizer.starts_with('('): + arg_list = self.must_parse_block_arg_list() - def parse_optional_float_literal(self, - skip_white_space: bool = True - ) -> float | None: - return self.try_parse(self.parse_float_literal, - skip_white_space=skip_white_space) - - def parse_float_literal(self, skip_white_space: bool = True) -> float: - # Parse the optional sign - value = "" - if self.parse_optional_char("+", skip_white_space=skip_white_space): - value += "+" - elif self.parse_optional_char("-", skip_white_space=False): - value += "-" - - # Parse the significant digits - digits = self.parse_while(lambda x: x.isdigit(), - skip_white_space=False) - if digits == "": - raise ParserError(self._pos, "float literal expected") - value += digits - - # Check that we are parsing a float, and not an integer - is_float = False - - # Parse the optional decimal point - if self.parse_optional_char(".", skip_white_space=False): - # Parse the fractional digits - value += "." - value += self.parse_while(lambda x: x.isdigit(), - skip_white_space=False) - is_float = True - - # Parse the optional exponent - if self.parse_optional_char( - "e", skip_white_space=False) or self.parse_optional_char( - "E", skip_white_space=False): - value += "e" - # Parse the optional exponent sign - if self.parse_optional_char("+", skip_white_space=False): - value += "+" - elif self.parse_optional_char("-", skip_white_space=False): - value += "-" - # Parse the exponent digits - value += self.parse_while(lambda x: x.isdigit(), - skip_white_space=False) - is_float = True - - if not is_float: - raise ParserError( - self._pos, - "float literal expected, but got an integer literal instead") - - return float(value) - - def peek_char(self, - char: str, - skip_white_space: bool = True) -> bool | None: - if skip_white_space: - self.skip_white_space() - if self.get_char() == char: - return True - return None + self.must_parse_characters(':', 'Block label must end in a `:`!') - def parse_optional_char(self, - char: str, - skip_white_space: bool = True) -> bool | None: - assert len(char) == 1 - if skip_white_space: - self.skip_white_space() - if self._pos is None: - return None - if self._pos.get_char() == char: - self._pos = self._pos.next_char_pos() - return True - return None + return block_id, arg_list - def parse_char(self, char: str, skip_white_space: bool = True) -> bool: - assert (len(char) == 1) - current_char = self.get_char() - res = self.parse_optional_char(char, skip_white_space=skip_white_space) - if res is None: - raise ParserError(self._pos, - f"'{char}' expected, got '{current_char}'") - return True - - def parse_string(self, - contents: str, - skip_white_space: bool = True) -> bool: - if skip_white_space: - self.skip_white_space() - chars = self.get_char(len(contents)) - if chars == contents: - assert self._pos - self._pos = self._pos.next_char_pos(len(contents)) - return True - raise ParserError(self._pos, f"'{contents}' expected") - - def parse_optional_string(self, - contents: str, - skip_white_space: bool = True) -> bool | None: - if skip_white_space: - self.skip_white_space() - chars = self.get_char(len(contents)) - if chars == contents: - assert self._pos is not None - self._pos = self._pos.next_char_pos(len(contents)) - return True + def must_parse_block_arg_list(self) -> list[tuple[Span, Attribute]]: + self.must_parse_characters('(', 'Block arguments must start with `(`') + + args = self.must_parse_list_of(self.try_parse_value_id_and_type, + "Expected value-id and type here!") + + self.must_parse_characters(')', + 'Expected closing of block arguments!') + + return args + + def try_parse_single_reference(self) -> Span | None: + with self.tokenizer.backtracking('part of a reference'): + self.must_parse_characters('@', "references must start with `@`") + if (reference := self.try_parse_string_literal()) is not None: + return reference + if (reference := self.try_parse_suffix_id()) is not None: + return reference + self.raise_error( + "References must conform to `@` (string-literal | suffix-id)") + + def must_parse_reference(self) -> list[Span]: + return self.must_parse_list_of( + self.try_parse_single_reference, + 'Expected reference here in the format of `@` (suffix-id | string-literal)', + ParserCommons.double_colon, + allow_empty=False) + + def must_parse_list_of(self, + try_parse: Callable[[], T_ | None], + error_msg: str, + separator_pattern: re.Pattern = ParserCommons.comma, + allow_empty: bool = True) -> list[T_]: + """ + This is a greedy list-parser. It accepts input only in these cases: + + - If the separator isn't encountered, which signals the end of the list + - If an empty list is allowed, it accepts when the first try_parse fails + - If an empty separator is given, it instead sees a failed try_parse as the end of the list. + + This means, that the setup will not accept the input and instead raise an error: + try_parse = parse_integer_literal + separator = 'x' + input = 3x4x4xi32 + as it will read [3,4,4], then see another separator, and expects the next try_parse call to succeed + (which won't as i32 is not a valid integer literal) + """ + items = list() + first_item = try_parse() + if first_item is None: + if allow_empty: + return items + self.raise_error(error_msg) + + items.append(first_item) + + while (match := self.tokenizer.next_token_of_pattern(separator_pattern) + ) is not None: + next_item = try_parse() + if next_item is None: + # if the separator is emtpy, we are good here + if separator_pattern.pattern == '': + return items + self.raise_error(error_msg + + ' because was able to match next separator {}' + .format(match.text)) + items.append(next_item) + + return items + + def try_parse_integer_literal(self) -> Span | None: + return self.tokenizer.next_token_of_pattern( + ParserCommons.integer_literal) + + def try_parse_decimal_literal(self) -> Span | None: + return self.tokenizer.next_token_of_pattern( + ParserCommons.decimal_literal) + + def try_parse_string_literal(self) -> StringLiteral | None: + return StringLiteral.from_span( + self.tokenizer.next_token_of_pattern(ParserCommons.string_literal)) + + def try_parse_float_literal(self) -> Span | None: + return self.tokenizer.next_token_of_pattern( + ParserCommons.float_literal) + + def try_parse_bare_id(self) -> Span | None: + return self.tokenizer.next_token_of_pattern(ParserCommons.bare_id) + + def try_parse_value_id(self) -> Span | None: + return self.tokenizer.next_token_of_pattern(ParserCommons.value_id) + + def try_parse_suffix_id(self) -> Span | None: + return self.tokenizer.next_token_of_pattern(ParserCommons.suffix_id) + + def try_parse_block_id(self) -> Span | None: + return self.tokenizer.next_token_of_pattern(ParserCommons.block_id) + + def try_parse_boolean_literal(self) -> Span | None: + return self.tokenizer.next_token_of_pattern( + ParserCommons.boolean_literal) + + def try_parse_value_id_and_type(self) -> tuple[Span, Attribute] | None: + with self.tokenizer.backtracking("value id and type"): + value_id = self.try_parse_value_id() + + if value_id is None: + self.raise_error("Invalid value-id format!") + + self.must_parse_characters( + ':', 'Expected expression (value-id `:` type)') + + type = self.try_parse_type() + + if type is None: + self.raise_error("Expected type of value-id here!") + return value_id, type + + def try_parse_type(self) -> Attribute | None: + if (builtin_type := self.try_parse_builtin_type()) is not None: + return builtin_type + if (dialect_type := self.try_parse_dialect_type()) is not None: + return dialect_type return None - T = TypeVar('T') + def try_parse_dialect_type_or_attribute(self) -> Attribute | None: + """ + Parse a type or an attribute. + """ + kind = self.tokenizer.next_token_of_pattern(re.compile('[!#]'), + peek=True) - def parse_optional_nested_list( - self, - parse_optional_one: Callable[[], T | None], - delimiter: str = ",", - brackets: str = "[]", - skip_white_space: bool = True) -> list[T] | None: - ''' - Parse and flatten a list of lists. The result is a list of elements, no matter the - rank of the input. - Delimiter must be length one, for example ",". - Brackets must be length two, for example "[]". - ''' - - assert len(delimiter) == 1 - assert len(brackets) == 2 - - open_bracket, close_bracket = brackets - if not self.parse_optional_char(open_bracket, - skip_white_space=skip_white_space): - # This is not a list that opens with the opening bracket + if kind is None: return None - indices = [0] + with self.tokenizer.backtracking("dialect attribute or type"): + self.tokenizer.consume_peeked(kind) + if kind.text == '!': + return self.must_parse_dialect_type_or_attribute_inner('type') + else: + return self.must_parse_dialect_type_or_attribute_inner( + 'attribute') - res = list[Any]() # Pyright does not let us use `T` here + def try_parse_dialect_type(self): + """ + Parse a dialect type (something prefixed by `!`, defined by a dialect) + """ + if not self.tokenizer.starts_with('!'): + return None + with self.tokenizer.backtracking("dialect type"): + self.must_parse_characters('!', + "Dialect type must start with a `!`") + return self.must_parse_dialect_type_or_attribute_inner('type') - while len(indices) > 0: - if self.parse_optional_char(close_bracket, - skip_white_space=skip_white_space): - # This is the end of a list - indices.pop() - if len(indices) > 0: - indices[-1] += 1 - continue + def try_parse_dialect_attr(self): + """ + Parse a dialect attribute (something prefixed by `#`, defined by a dialect) + """ + if not self.tokenizer.starts_with('#'): + return None + with self.tokenizer.backtracking("dialect attribute"): + self.must_parse_characters( + '#', "Dialect attribute must start with a `#`") + return self.must_parse_dialect_type_or_attribute_inner('attribute') + + def must_parse_dialect_type_or_attribute_inner(self, kind: str): + type_name = self.tokenizer.next_token_of_pattern(ParserCommons.bare_id) + + if type_name is None: + self.raise_error("Expected dialect {} name here!".format(kind)) + + type_def = self.ctx.get_optional_attr(type_name.text) + if type_def is None: + self.raise_error( + "'{}' is not a know attribute!".format(type_name.text), + type_name) + + # pass the task of parsing parameters on to the attribute/type definition + if issubclass(type_def, ParametrizedAttribute): + param_list = type_def.parse_parameters(self) + elif issubclass(type_def, Data): + self.must_parse_characters('<', 'This attribute must be parametrized!') + param_list = type_def.parse_parameter(self) + self.must_parse_characters('>', 'Invalid attribute parametrization, expected `>`!') + else: + assert False, "Mathieu said this cannot be." + return type_def(param_list) - if indices[-1]: - # If we're not at the end of the list, then it's a delimiter followed by - # the next eleement, which might be a nested list. - self.parse_char(delimiter, skip_white_space=skip_white_space) + @abstractmethod + def try_parse_builtin_type(self) -> Attribute | None: + """ + parse a builtin-type like i32, index, vector etc. + """ + raise NotImplemented("Subclasses must implement this method!") + + def must_parse_builtin_parametrized_type( + self, name: Span) -> ParametrizedAttribute: + + def unimplemented() -> ParametrizedAttribute: + raise ParseError(name, + "Builtin {} not supported yet!".format(name.text)) + + builtin_parsers: dict[str, Callable[[], ParametrizedAttribute]] = { + 'vector': self.must_parse_vector_attrs, + 'memref': unimplemented, + 'tensor': self.must_parse_tensor_attrs, + 'complex': self.must_parse_complex_attrs, + 'opaque': unimplemented, + 'tuple': unimplemented, + } + + self.must_parse_characters('<', 'Expected parameter list here!') + # get the parser for the type, falling back to the unimplemented warning + res = builtin_parsers.get(name.text, unimplemented)() + self.must_parse_characters('>', + 'Expected end of parameter list here!') + return res - if self.parse_optional_char(open_bracket, - skip_white_space=skip_white_space): - # A new nested list, reset the index - indices.append(0) + def must_parse_complex_attrs(self): + self.raise_error("ComplexType is unimplemented!") + + def try_parse_numerical_dims(self, + accept_closing_bracket: bool = False, + lower_bound: int = 1) -> Iterable[int]: + while (shape_arg := + self.try_parse_shape_element(lower_bound)) is not None: + yield shape_arg + # look out for the closing bracket for scalable vector dims + if accept_closing_bracket and self.tokenizer.starts_with(']'): + break + self.must_parse_characters( + 'x', + 'Unexpected end of dimension parameters!') + + def must_parse_vector_attrs(self) -> AnyVectorType: + # also break on 'x' characters as they are separators in dimension parameters + with self.tokenizer.configured(break_on=self.tokenizer.break_on + + ('x',)): + shape = list[int](self.try_parse_numerical_dims()) + scaling_shape: list[int] | None = None + + if self.tokenizer.next_token_of_pattern('[') is not None: + # we now need to parse the scalable dimensions + scaling_shape = list(self.try_parse_numerical_dims()) + self.must_parse_characters( + ']', + 'Expected end of scalable vector dimensions here!') + self.must_parse_characters( + 'x', + 'Expected end of scalable vector dimensions here!') + + if scaling_shape is not None: + # TODO: handle scaling vectors! + self.raise_error("Warning: scaling vectors not supported!") + pass + + type = self.try_parse_type() + if type is None: + self.raise_error( + "Expected a type at the end of the vector parameters!") + + return VectorType.from_type_and_list(type, shape) + + def must_parse_tensor_or_memref_dims(self) -> list[int] | None: + with self.tokenizer.configured(break_on=self.tokenizer.break_on + + ('x',)): + # check for unranked-ness + if self.tokenizer.next_token_of_pattern('*') is not None: + # consume `x` + self.must_parse_characters( + 'x', + 'Unranked tensors must follow format (`<*x` type `>`)') else: - # This must be a list element - one = parse_optional_one() - if one is None: - raise ParserError(self._pos, 'Expected list element') - res.append(one) - indices[-1] += 1 + # parse rank: + return list(self.try_parse_numerical_dims(lower_bound=0)) - return res + def must_parse_tensor_attrs(self) -> AnyTensorType: + shape = self.must_parse_tensor_or_memref_dims() + type = self.try_parse_type() - def parse_list(self, - parse_optional_one: Callable[[], T | None], - delimiter: str = ",", - skip_white_space: bool = True) -> list[T]: - if skip_white_space: - self.skip_white_space() - assert (len(delimiter) <= 1) - res = list[Any]() # Pyright do not let us use `T` here - one = parse_optional_one() - if one is not None: - res.append(one) - while self.parse_optional_char(delimiter) if len( - delimiter) == 1 else True: - one = parse_optional_one() - if one is None: - return res - res.append(one) - return res + if type is None: + self.raise_error("Expected tensor type here!") - K = TypeVar('K') - V = TypeVar('V') - - def parse_dictionary(self, - parse_key: Callable[[], K], - parse_value: Callable[[], V], - delimiter: str = ",", - skip_white_space: bool = True) -> dict[K, V]: - if skip_white_space: - self.skip_white_space() - assert (len(delimiter) <= 1) - if len(delimiter): - parse_delimiter = lambda: self.parse_char(delimiter) - else: - parse_delimiter = lambda: True + if self.tokenizer.starts_with(','): + # TODO: add tensor encoding! + raise self.raise_error("Parsing tensor encoding is not supported!") + + if shape is None and self.tokenizer.starts_with(','): + raise self.raise_error("Unranked tensors don't have an encoding!") + + if shape is not None: + return TensorType.from_type_and_list(type, shape) + + return UnrankedTensorType.from_type(type) + + def try_parse_shape_element(self, lower_bound: int = 1) -> int | None: + """ + Parse a shape element, either a decimal integer immediate or a `?`, which evaluates to -1 + + immediate cannot be smaller than lower_bound (defaults to 1) (is 0 for tensors and memrefs) + """ + int_lit = self.try_parse_decimal_literal() + + if int_lit is not None: + value = int(int_lit.text) + if value < lower_bound: + # TODO: this is ugly, it's a raise inside a try_ type function, which should instead just give up + raise ParseError( + int_lit, + "Shape element literal cannot be negative or zero!") + return value + + if self.tokenizer.next_token_of_pattern('?') is not None: + return -1 + return None - self.parse_char("{") - if self.peek_char("}"): - return {} + def must_parse_type_params(self) -> list[Attribute]: + # consume opening bracket + self.must_parse_characters('<', 'Type must be parameterized!') - key, value = self.parse_dict_entry(parse_key, parse_value) - res = {key: value} - while not self.peek_char("}"): - parse_delimiter() - key, value = self.parse_dict_entry(parse_key, parse_value) - res[key] = value + params = self.must_parse_list_of(self.try_parse_type, + 'Expected a type here!') - self.parse_char("}") + self.must_parse_characters('>', 'Expected end of type parameterization here!') + return params + + def expect(self, try_parse: Callable[[], T_ | None], + error_message: str) -> T_: + """ + Used to force completion of a try_parse function. Will throw a parse error if it can't + """ + res = try_parse() + if res is None: + self.raise_error(error_message) return res - def parse_dict_entry( - self, - parse_key: Callable[[], K], - parse_value: Callable[[], V], - ) -> tuple[K, V]: - key = parse_key() - self.parse_char("=") - value = parse_value() - return key, value - - def parse_optional_block_argument( - self, - skip_white_space: bool = True) -> tuple[str, Attribute] | None: - name = self.parse_optional_ssa_name(skip_white_space=skip_white_space) - if name is None: - return None - self.parse_char(":") - typ = self.parse_attribute() - # TODO how to get the id? - return name, typ - - def parse_optional_named_block(self, - skip_white_space: bool = True - ) -> Block | None: - if self.parse_optional_char("^", - skip_white_space=skip_white_space) is None: - return None - block_name = self.parse_alpha_num(skip_white_space=False) - if block_name in self._blocks: - block = self._blocks[block_name] + def raise_error(self, msg: str, at_position: Span | None = None): + """ + Helper for raising exceptions, provides as much context as possible to them. + + This will, for example, include backtracking errors, if any occured previously + """ + if at_position is None: + at_position = self.tokenizer.next_token(peek=True) + + raise ParseError(at_position, msg, self.tokenizer.history) + + def must_parse_characters(self, + text: str, + msg: str, ) -> Span: + if (match := self.tokenizer.next_token_of_pattern(text)) is None: + self.raise_error(msg) + return match + + @abstractmethod + def must_parse_op_result_list( + self) -> tuple[list[Span], list[Attribute] | None]: + raise NotImplemented() + + def try_parse_operation(self) -> Operation | None: + with self.tokenizer.backtracking("operation"): + return self.must_parse_operation() + + def must_parse_operation(self) -> Operation: + result_list, ret_types = self.must_parse_op_result_list() + if len(result_list) > 0: + self.must_parse_characters( + '=', + 'Operation definitions expect an `=` after op-result-list!' + ) + + # check for custom op format + op_name = self.try_parse_bare_id() + if op_name is not None: + op_type = self.ctx.get_op(op_name.text) + op = op_type.parse(ret_types, self) else: - block = Block() - self._blocks[block_name] = block - - if self.parse_optional_char("("): - tuple_list = self.parse_list(self.parse_optional_block_argument) - # Register the BlockArguments as ssa values and add them to - # the block - for (idx, (arg_name, arg_type)) in enumerate(tuple_list): - if arg_name in self._ssaValues: - raise ParserError( - self._pos, f"SSA value {arg_name} is already defined") - arg = BlockArgument(arg_type, block, idx) - self._ssaValues[arg_name] = arg - block.args.append(arg) - - self.parse_char(")") - self.parse_char(":") - for op in self.parse_list(self.parse_optional_op, delimiter=""): - block.add_op(op) - return block + # check for basic op format + op_name = self.try_parse_string_literal() + if op_name is None: + self.raise_error( + "Expected an operation name here, either a bare-id, or a string literal!" + ) + + args, successors, attrs, regions, func_type = self.must_parse_operation_details( + ) + + if ret_types is None: + assert func_type is not None + ret_types = func_type.outputs.data + + op_type = self.ctx.get_op(op_name.string_contents) + + op = op_type.create( + operands=[self.ssaValues[span.text] for span in args], + result_types=ret_types, + attributes=attrs, + successors=[ + self.blocks[block_name.text] + for block_name in successors + ], + regions=regions) + + # Register the result SSA value names in the parser + for idx, res in enumerate(result_list): + ssa_val_name = res.text + if ssa_val_name in self.ssaValues: + self.raise_error( + f"SSA value {ssa_val_name} is already defined", res) + self.ssaValues[ssa_val_name] = op.results[idx] + self.ssaValues[ssa_val_name].name = ssa_val_name.lstrip('%') + + return op + + def must_parse_region(self) -> Region: + oldSSAVals = self.ssaValues.copy() + oldBBNames = self.blocks + oldForwardRefs = self.forward_block_references + self.blocks = dict() + self.forward_block_references = defaultdict(list) - def parse_optional_region(self, - skip_white_space: bool = True) -> Region | None: - if not self.parse_optional_char("{", - skip_white_space=skip_white_space): - return None region = Region() - oldSSAVals = self._ssaValues.copy() - oldBBNames = self._blocks.copy() - self._blocks = dict[str, Block]() - if self.peek_char('^'): - for block in self.parse_list(self.parse_optional_named_block, - delimiter=""): + try: + self.must_parse_characters('{', 'Regions begin with `{`') + if not self.tokenizer.starts_with('}'): + # parse first block + block = self.must_parse_block() region.add_block(block) - else: - region.add_block(Block()) - for op in self.parse_list(self.parse_optional_op, delimiter=""): - region.blocks[0].add_op(op) - self.parse_char("}") - - self._ssaValues = oldSSAVals - self._blocks = oldBBNames - return region - - def parse_optional_ssa_name(self, - skip_white_space: bool = True) -> str | None: - if self.parse_optional_char("%", - skip_white_space=skip_white_space) is None: - return None - name = self.parse_alpha_num() - return name - - def parse_optional_ssa_value(self, - skip_white_space: bool = True - ) -> SSAValue | None: - if skip_white_space: - self.skip_white_space() - start_pos = self._pos - name = self.parse_optional_ssa_name() + + while self.tokenizer.starts_with('^'): + region.add_block(self.must_parse_block()) + + end = self.must_parse_characters( + '}', 'Reached end of region, expected `}`!') + + if len(self.forward_block_references) > 0: + raise MultipleSpansParseError( + end, + "Region ends with missing block declarations for block(s) {}!" + .format(', '.join(self.forward_block_references.keys())), + 'The following block references are dangling:', + [(span, "Reference to block \"{}\" without implementation!" + .format(span.text)) for span in itertools.chain( + *self.forward_block_references.values())], + self.tokenizer.history) + + return region + finally: + self.ssaValues = oldSSAVals + self.blocks = oldBBNames + self.forward_block_references = oldForwardRefs + + def try_parse_op_name(self) -> Span | None: + if (str_lit := self.try_parse_string_literal()) is not None: + return str_lit + return self.try_parse_bare_id() + + def must_parse_attribute_entry(self) -> tuple[Span, Attribute]: + """ + Parse entry in attribute dict. Of format: + + attrbiute_entry := (bare-id | string-literal) `=` attribute + attrbiute := dialect-attribute | builtin-attribute + """ + if (name := self.try_parse_bare_id()) is None: + name = self.try_parse_string_literal() + if name is None: - return None - if name not in self._ssaValues: - raise ParserError(start_pos, - f"name {name} does not refer to a SSA value") - return self._ssaValues[name] + self.raise_error( + 'Expected bare-id or string-literal here as part of attribute entry!' + ) - def parse_ssa_value(self, skip_white_space: bool = True) -> SSAValue: - res = self.parse_optional_ssa_value(skip_white_space=skip_white_space) - if res is None: - raise ParserError(self._pos, "SSA value expected") - return res + self.must_parse_characters( + '=', 'Attribute entries must be of format name `=` attribute!') + + return name, self.must_parse_attribute() + + @abstractmethod + def must_parse_attribute(self) -> Attribute: + """ + Parse attribute (either builtin or dialect) + + This is different in xDSL and MLIR, so the actuall implementation is provided by the subclass + """ + raise NotImplemented() - def parse_optional_results(self, - skip_white_space: bool = True - ) -> list[str] | None: - res = self.parse_list(self.parse_optional_ssa_name, - skip_white_space=skip_white_space) - if len(res) == 0: + def try_parse_attribute(self) -> Attribute | None: + with self.tokenizer.backtracking('attribute'): + return self.must_parse_attribute() + + def must_parse_attribute_type(self) -> Attribute: + """ + Parses `:` type and returns the type + """ + self.must_parse_characters( + ':', 'Expected attribute type definition here ( `:` type )') + return self.expect( + self.try_parse_type, + 'Expected attribute type definition here ( `:` type )') + + def try_parse_builtin_attr(self) -> Attribute: + """ + Tries to parse a bultin attribute, e.g. a string literal, int, array, etc.. + """ + # order here is important! + attrs = (self.try_parse_builtin_float_attr, + self.try_parse_builtin_int_attr, + self.try_parse_builtin_str_attr, + self.try_parse_builtin_arr_attr, self.try_parse_function_type, + self.try_parse_ref_attr) + + for attr_parser in attrs: + if (val := attr_parser()) is not None: + return val + + def try_parse_ref_attr(self) -> FlatSymbolRefAttr | None: + if not self.tokenizer.starts_with('@'): return None - self.parse_char("=") - return res - def parse_optional_typed_result( - self, - skip_white_space: bool = True) -> tuple[str, Attribute] | None: - name = self.parse_optional_ssa_name(skip_white_space=skip_white_space) - if name is None: + ref = self.must_parse_reference() + + if len(ref) > 1: + self.raise_error("Nested refs are not supported yet!", ref[1]) + + return FlatSymbolRefAttr.from_str(ref[0].text[1:]) + + def try_parse_builtin_int_attr(self) -> IntegerAttr | None: + bool = self.try_parse_builtin_boolean_attr() + if bool is not None: + return bool + + with self.tokenizer.backtracking("built in int attribute"): + value = self.expect( + self.try_parse_integer_literal, + 'Integer attribute must start with an integer literal!') + if self.tokenizer.next_token(peek=True).text != ':': + print(self.tokenizer.next_token(peek=True)) + return IntegerAttr.from_params(int(value.text), + DefaultIntegerAttrType) + type = self.must_parse_attribute_type() + return IntegerAttr.from_params(int(value.text), type) + + def try_parse_builtin_float_attr(self) -> FloatAttr | None: + with self.tokenizer.backtracking("float literal"): + value = self.expect( + self.try_parse_float_literal, + 'Float attribute must start with a float literal!') + # if we don't see a ':' indicating a type signature + if not self.tokenizer.starts_with(':'): + return FloatAttr.from_value(float(value.text)) + + type = self.must_parse_attribute_type() + return FloatAttr.from_value(float(value.text), type) + + def try_parse_builtin_boolean_attr(self) -> IntegerAttr | None: + span = self.try_parse_boolean_literal() + + if span is None: return None - self.parse_char(":") - typ = self.parse_attribute() - return name, typ - def parse_optional_typed_results( - self, - skip_white_space: bool = True - ) -> list[tuple[str, Attribute]] | None: - res = self.parse_list(lambda: self.parse_optional_typed_result( - skip_white_space=skip_white_space)) - if len(res) == 0: + int_val = ['false', 'true'].index(span.text) + return IntegerAttr.from_params(int_val, IntegerType.from_width(1)) + + def try_parse_builtin_str_attr(self): + if not self.tokenizer.starts_with('"'): return None - elif len(res) == 1 and res[0] is None: + + with self.tokenizer.backtracking("string literal"): + literal = self.try_parse_string_literal() + if literal is None: + self.raise_error('Invalid string literal') + return StringAttr.from_str(literal.string_contents) + + def try_parse_builtin_arr_attr(self) -> list[Attribute] | None: + if not self.tokenizer.starts_with('['): return None + with self.tokenizer.backtracking("array literal"): + self.must_parse_characters('[', + 'Array literals must start with `[`') + attrs = self.must_parse_list_of(self.try_parse_attribute, + 'Expected array entry!') + self.must_parse_characters( + ']', 'Malformed array contents (expected end of array here!') + return ArrayAttr.from_list(attrs) + + @abstractmethod + def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: + raise NotImplementedError() + + def attr_dict_from_tuple_list( + self, tuple_list: list[tuple[Span, + Attribute]]) -> dict[str, Attribute]: + """ + Convert a list of tuples (Span, Attribute) to a dictionary. + + This function converts the span to a string, trimming quotes from string literals + """ + + def span_to_str(span: Span) -> str: + if isinstance(span, StringLiteral): + return span.string_contents + return span.text + + return dict((span_to_str(span), attr) for span, attr in tuple_list) + + def must_parse_function_type(self) -> FunctionType: + """ + Parses function-type: + + viable function types are: + (i32) -> () + () -> (i32, i32) + (i32, i32) -> () + () -> i32 + Non-viable types are: + i32 -> i32 + i32 -> () + + Uses type-or-type-list-parens internally + """ + self.must_parse_characters( + '(', 'First group of function args must start with a `(`') + + args: list[Attribute] = self.must_parse_list_of( + self.try_parse_type, 'Expected type here!') + + self.must_parse_characters( + ')', + "Malformed function type, expected closing brackets of argument types!") + + self.must_parse_characters('->', + 'Malformed function type, expected `->`!') + + return FunctionType.from_lists( + args, self.must_parse_type_or_type_list_parens()) + + def must_parse_type_or_type_list_parens(self) -> list[Attribute]: + """ + Parses type-or-type-list-parens, which is used in function-type. + + type-or-type-list-parens ::= type | type-list-parens + type-list-parens ::= `(` `)` | `(` type-list-no-parens `)` + type-list-no-parens ::= type (`,` type)* + """ + if self.tokenizer.next_token_of_pattern('(') is not None: + args: list[Attribute] = self.must_parse_list_of( + self.try_parse_type, 'Expected type here!') + self.must_parse_characters(')', + "Unclosed function type argument list!") else: - self.parse_char("=") - return res - - def parse_optional_operand(self, - skip_white_space: bool = True - ) -> SSAValue | None: - value = self.parse_optional_ssa_value( - skip_white_space=skip_white_space) - if value is None: + args = [self.try_parse_type()] + if args[0] is None: + self.raise_error( + "Function type must either be single type or list of types in parenthesis!" + ) + return args + + def try_parse_function_type(self) -> FunctionType | None: + if not self.tokenizer.starts_with('('): return None - if self.source == self.Source.XDSL: - self.parse_char(":") - typ = self.parse_attribute() - if value.typ != typ: - raise ParserError( - self._pos, f"type mismatch between {typ} and {value.typ}") - return value + with self.tokenizer.backtracking('function type'): + return self.must_parse_function_type() - def parse_operands(self, skip_white_space: bool = True) -> list[SSAValue]: - self.parse_char("(", skip_white_space=skip_white_space) - res = self.parse_list(lambda: self.parse_optional_operand()) - self.parse_char(")") - return res + def must_parse_region_list(self) -> list[Region]: + """ + Parses a sequence of regions for as long as there is a `{` in the input. + """ + regions = [] + while not self.tokenizer.is_eof() and self.tokenizer.starts_with('{'): + regions.append(self.must_parse_region()) + return regions + + # COMMON xDSL/MLIR code: + def must_parse_builtin_type_with_name(self, name: Span): + if name.text == 'index': + return IndexType() + if (re_match := re.match(r'^[su]?i(\d+)$', name.text)) is not None: + signedness = { + 's': Signedness.SIGNED, + 'u': Signedness.UNSIGNED, + 'i': Signedness.SIGNLESS + } + return IntegerType.from_width(int(re_match.group(1)), + signedness[name.text[0]]) + + if (re_match := re.match(r'^f(\d+)$', name.text)) is not None: + width = int(re_match.group(1)) + type = { + 16: Float16Type, + 32: Float32Type, + 64: Float64Type + }.get(width, None) + if type is None: + self.raise_error( + "Unsupported floating point width: {}".format(width)) + return type() + + return self.must_parse_builtin_parametrized_type(name) + + @abstractmethod + def must_parse_operation_details( + self + ) -> tuple[list[Span], list[Span], dict[str, Attribute], list[Region], + FunctionType | None]: + """ + Must return a tuple consisting of: + - a list of arguments to the operation + - a list of successor names + - the attributes attached to the OP + - the regions of the op + - An optional function type. If not supplied, must_parse_op_result_list must return a second value + containing the types of the returned SSAValues + + Your implementation should make use of the following functions: + - must_parse_op_args_list + - must_parse_optional_attr_dict + - must_parse_ + """ + raise NotImplementedError() + + def must_parse_op_args_list(self) -> list[Span]: + self.must_parse_characters( + '(', 'Operation args list must be enclosed by brackets!') + args = self.must_parse_list_of(self.try_parse_value_id_and_type, + 'Expected another bare-id here') + self.must_parse_characters( + ')', 'Operation args list must be closed by a closing bracket') + # TODO: check if type is correct here! + return [name for name, _ in args] + + # HERE STARTS A SOMEWHAT CURSED COMPATIBILITY LAYER: + # since we don't want to rewrite all dialects currently, the new emulator needs to expose the same + # interface to the dialect definitions. Here we implement that interface. + + _OperationType = TypeVar('_OperationType', bound=Operation) + + def parse_op_with_default_format( + self, + op_type: type[_OperationType], + result_types: list[Attribute], + skip_white_space: bool = True) -> _OperationType: + """ + Compatibility wrapper so the new parser can be passed instead of the old one. Parses everything after the + operation name. + + This implicitly assumes XDSL format, and will fail on MLIR style operations + """ + # TODO: remove this function and restructure custom op / irdl parsing + assert isinstance(self, XDSLParser) + args, successors, attributes, regions, _ = self.must_parse_operation_details( + ) + + for x in args: + if x.text not in self.ssaValues: + self.raise_error( + "Unknown SSAValue name, known SSA Values are: {}".format( + ", ".join(self.ssaValues.keys())), x) + + return op_type.create( + operands=[self.ssaValues[span.text] for span in args], + result_types=result_types, + attributes=attributes, + successors=[self.get_block_from_name(span) for span in successors], + regions=regions) def parse_paramattr_parameters( self, expect_brackets: bool = False, skip_white_space: bool = True) -> list[Attribute]: - if expect_brackets: - self.parse_char("<", skip_white_space=skip_white_space) - elif self.parse_optional_char( - "<", skip_white_space=skip_white_space) is None: - return [] + opening_brackets = self.tokenizer.next_token_of_pattern('<') + if expect_brackets and opening_brackets is None: + self.raise_error("Expected start attribute parameters here (`<`)!") - res = self.parse_list(self.parse_optional_attribute) - self.parse_char(">") - return res + res = self.must_parse_list_of(self.try_parse_attribute, + 'Expected another attribute here!') - def parse_optional_boolean_attribute( - self, - skip_white_space: bool = True) -> IntegerAttr[IntegerType] | None: - if self.parse_optional_string( - "true", skip_white_space=skip_white_space) is not None: - return IntegerAttr.from_int_and_width(1, 1) - if self.parse_optional_string( - "false", skip_white_space=skip_white_space) is not None: - return IntegerAttr.from_int_and_width(0, 1) - - def parse_optional_xdsl_builtin_attribute(self, - skip_white_space: bool = True - ) -> Attribute | None: - # Shorthand for StringAttr - string_lit = self.parse_optional_str_literal( - skip_white_space=skip_white_space) - if string_lit is not None: - return StringAttr.from_str(string_lit) - - # Shorthand for FloatAttr - float_lit = self.parse_optional_float_literal() - if float_lit is not None: - if self.parse_optional_char(":"): - typ = self.parse_attribute() - else: - typ = Float32Type() - return FloatAttr.from_value(float_lit, typ) - - # Shorthand for boolean literals (IntegerAttr of width 1) - if (bool_attr := self.parse_optional_boolean_attribute( - skip_white_space=skip_white_space)): - return bool_attr - - # Shorthand for IntegerAttr - integer_lit = self.parse_optional_int_literal() - if integer_lit is not None: - if self.parse_optional_char(":"): - typ = self.parse_attribute() - else: - typ = IntegerType.from_width(64) - return IntegerAttr.from_params(integer_lit, typ) - - # Shorthand for ArrayAttr - parse_bracket = self.parse_optional_char("[") - if parse_bracket: - array = self.parse_list(self.parse_optional_attribute) - self.parse_char("]") - return ArrayAttr.from_list(array) - - # Shorthand for DictionaryAttr - if self.peek_char("{"): - dictionary = self.parse_dictionary(self.parse_str_literal, - self.parse_attribute) - return DictionaryAttr.from_dict(dictionary) - - # Shorthand for FlatSymbolRefAttr - parse_at = self.parse_optional_char("@") - if parse_at: - symbol_name = self.parse_alpha_num(skip_white_space=False) - return FlatSymbolRefAttr.from_str(symbol_name) - - def parse_integer_type(): - self.parse_char("!", skip_white_space=skip_white_space) - return self.parse_mlir_integer_type( - skip_white_space=skip_white_space) - - if int_type := self.try_parse(parse_integer_type): - return int_type + if opening_brackets is not None and self.tokenizer.next_token_of_pattern( + '>') is None: + self.raise_error( + "Malformed parameter list, expected either another parameter or `>`!" + ) - return None + return res - def parse_optional_attribute(self, - skip_white_space: bool = True - ) -> Attribute | None: - # If we are parsing an MLIR file, we first try to parse builtin - # attributes, which have a different format. - if self.source == self.Source.MLIR: - if attr := self.parse_optional_mlir_attribute( - skip_white_space=skip_white_space): - return attr - - # If we are parsing an xDSL file, we first try to parse builtin - # attributes, which have a different format. - if self.source == self.Source.XDSL: - if attr := self.parse_optional_xdsl_builtin_attribute( - skip_white_space=skip_white_space): - return attr - - # Then, we parse attributes/types with the generic format. - - if self.parse_optional_char("!") is None: - if self.source == self.Source.MLIR: - if self.parse_optional_char("#") is None: - return None - else: - return None + def parse_char(self, text: str): + self.must_parse_characters(text, "Expected '{}' here!".format(text)) - parse_with_default_format = False - # Attribute with default format - if self.parse_optional_char('"'): - attr_def_name = self.parse_alpha_num(skip_white_space=False) - self.parse_char('"') - parse_with_default_format = True - else: - attr_def_name = self.parse_alpha_num(skip_white_space=True) + def parse_str_literal(self) -> str: + return self.expect(self.try_parse_string_literal, + 'Malformed string literal!').string_contents - if (self.source == self.Source.MLIR) and parse_with_default_format: - raise ParserError(self._pos, "cannot parse generic MLIR attribute") + def parse_attribute(self) -> Attribute: + return self.must_parse_attribute() - attr_def = self.ctx.get_attr(attr_def_name) + def parse_op(self) -> Operation: + return self.must_parse_operation() - # Attribute with default format - if parse_with_default_format: - if not issubclass(attr_def, ParametrizedAttribute): - raise ParserError( - self._pos, - f"{attr_def_name} is not a parameterized attribute, and " - "thus cannot be parsed with a generic format.") - params = self.parse_paramattr_parameters() - return attr_def(params) # type: ignore + def parse_int_literal(self) -> int: + return int(self.expect(self.try_parse_integer_literal, 'Expected integer literal here').text) - if issubclass(attr_def, Data): - self.parse_char("<") - attr: Any = attr_def.parse_parameter(self) - self.parse_char(">") - return attr_def(attr) # type: ignore - assert issubclass(attr_def, ParametrizedAttribute) - param_list = attr_def.parse_parameters(self) - return attr_def(param_list) # type: ignore +class MLIRParser(BaseParser): - def parse_optional_dim(self, skip_white_space: bool = True) -> int | None: + def try_parse_builtin_type(self) -> Attribute | None: """ - Parse an optional dimension. - The dimension is either a non-negative integer, or -1 for dynamic dimensions. + parse a builtin-type like i32, index, vector etc. """ - if self.parse_optional_char("?", skip_white_space=skip_white_space): - return -1 - if (dim := self.parse_optional_int_literal()) is not None: - return dim - return None + with self.tokenizer.backtracking("builtin type"): + name = self.tokenizer.next_token_of_pattern( + ParserCommons.builtin_type) + if name is None: + raise BacktrackingAbort("Expected builtin name!") - def parse_dim(self, skip_white_space: bool = True) -> int: - """ - Parse a dimension. - The dimension is either a non-negative integer, - or -1 for dynamic dimensions, represented by `?`. - """ - dim = self.parse_optional_dim(skip_white_space=skip_white_space) - if dim is not None: - return dim - raise ParserError(self._pos, "dimension expected") + return self.must_parse_builtin_type_with_name(name) - def parse_optional_shape( - self, - skip_white_space: bool = True - ) -> tuple[list[int], Attribute] | None: + def must_parse_attribute(self) -> Attribute: """ - Parse a shape, with the format `dim0 x dim1 x ... x dimN x type`. + Parse attribute (either builtin or dialect) """ - dims = list[int]() + # all dialect attrs must start with '#', so we check for that first (as it's easier) + if self.tokenizer.starts_with('#'): + value = self.try_parse_dialect_attr() - if skip_white_space: - self.skip_white_space() + # no value => error + if value is None: + self.raise_error( + '`#` must be followed by a valid dialect attribute or type!' + ) - def parse_optional_dim_and_x(): - if (dim := self.parse_optional_dim( - skip_white_space=False)) is not None: - self.parse_char("x", skip_white_space=False) - return dim - return None + return value - dims = self.parse_list(parse_optional_dim_and_x, delimiter="") - typ = self.parse_attribute() + # if it isn't a dialect attr, parse builtin + builtin_val = self.try_parse_builtin_attr() - return dims, typ + if builtin_val is None: + self.raise_error( + "Unknown attribute (neither builtin nor dialect could be parsed)!" + ) - def parse_shape( - self, - skip_white_space: bool = True) -> tuple[list[int], Attribute]: + return builtin_val + + def must_parse_op_result_list( + self) -> tuple[list[Span], list[Attribute] | None]: + return self.must_parse_list_of(self.try_parse_value_id, + 'Expected op-result here!', + allow_empty=True), None + + def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: + if not self.tokenizer.starts_with('{'): + return dict() + + self.must_parse_characters( + '{', + 'MLIR Attribute dictionary must be enclosed in curly brackets') + + attrs = self.must_parse_list_of(self.must_parse_attribute_entry, + "Expected attribute entry") + + self.must_parse_characters( + '}', + 'MLIR Attribute dictionary must be enclosed in curly brackets') + + return self.attr_dict_from_tuple_list(attrs) + + def must_parse_operation_details( + self + ) -> tuple[list[Span], list[Span], dict[str, Attribute], list[Region], + FunctionType | None]: + + args = self.must_parse_op_args_list() + succ = self.must_parse_optional_successor_list() + + regions = [] + if self.tokenizer.starts_with('('): + self.must_parse_characters('(', + 'Expected brackets enclosing regions!') + regions = self.must_parse_region_list() + self.must_parse_characters(')', + 'Expected brackets enclosing regions!') + + attrs = self.must_parse_optional_attr_dict() + + self.must_parse_characters( + ':', + 'MLIR Operation defintions must end in a function type signature!') + func_type = self.must_parse_function_type() + + return args, succ, attrs, regions, func_type + + def must_parse_optional_successor_list(self) -> list[Span]: + if not self.tokenizer.starts_with('['): + return [] + self.must_parse_characters( + '[', 'Successor list is enclosed in square brackets') + successors = self.must_parse_list_of(self.try_parse_block_id, + 'Expected a block-id', + allow_empty=False) + self.must_parse_characters( + ']', 'Successor list is enclosed in square brackets') + return successors + + +class XDSLParser(BaseParser): + + def try_parse_builtin_type(self) -> Attribute | None: """ - Parse a shape, with the format `dim0 x dim1 x ... x dimN x type`. + parse a builtin-type like i32, index, vector etc. """ - shape = self.parse_optional_shape(skip_white_space=skip_white_space) - if shape is not None: - return shape - raise ParserError(self._pos, "shape expected") - - def parse_optional_mlir_tensor( - self, - skip_white_space: bool = True - ) -> AnyTensorType | AnyUnrankedTensorType | None: - if self.parse_optional_string("tensor", - skip_white_space=skip_white_space): - self.parse_char("<") - # Unranked tensor case - if self.parse_optional_char("*"): - self.parse_char("x") - typ = self.parse_attribute() - self.parse_char(">") - return UnrankedTensorType.from_type(typ) - dims, typ = self.parse_shape() - self.parse_char(">") - return TensorType.from_type_and_list(typ, dims) - return None + with self.tokenizer.backtracking("builtin type"): + name = self.tokenizer.next_token_of_pattern( + ParserCommons.builtin_type_xdsl) + if name is None: + raise BacktrackingAbort("Expected builtin name!") + # xdsl builtin types have a '!' prefix, we strip that out here + name = Span(start=name.start + 1, end=name.end, input=name.input) - def parse_optional_mlir_vector(self, - skip_white_space: bool = True - ) -> AnyVectorType | None: - if self.parse_optional_string("vector", - skip_white_space=skip_white_space): - self.parse_optional_char("<") - dims, typ = self.parse_shape() - self.parse_char(">") - return VectorType.from_element_type_and_shape(typ, dims) - return None + return self.must_parse_builtin_type_with_name(name) - def parse_optional_mlir_memref( - self, - skip_white_space: bool = True - ) -> MemRefType[Any] | UnrankedMemrefType[Any] | None: - if self.parse_optional_string("memref", - skip_white_space=skip_white_space): - self.parse_char("<") - # Unranked memref case - if self.parse_optional_char("*"): - self.parse_char("x") - typ = self.parse_attribute() - self.parse_char(">") - return UnrankedMemrefType.from_type(typ) - dims, typ = self.parse_shape() - self.parse_char(">") - return MemRefType.from_element_type_and_shape(typ, dims) - return None + def must_parse_attribute(self) -> Attribute: + """ + Parse attribute (either builtin or dialect) - def parse_optional_mlir_index_type(self, - skip_white_space: bool = True - ) -> IndexType | None: - if self.parse_optional_string("index", - skip_white_space=skip_white_space): - return IndexType() - return None + xDSL allows types in places of attributes! That's why we parse types here as well + """ + value = self.try_parse_builtin_attr() - def parse_mlir_index_type(self, - skip_white_space: bool = True) -> IndexType: - typ = self.parse_optional_mlir_index_type( - skip_white_space=skip_white_space) - if typ is not None: - return typ - raise ParserError(self._pos, "index type expected") - - def parse_mlir_integer_type(self, - skip_white_space: bool = True) -> IntegerType: - # Parse the optional signedness semantics - if self.parse_optional_string("si", skip_white_space=skip_white_space): - signedness = Signedness.SIGNED - elif self.parse_optional_string("ui", - skip_white_space=skip_white_space): - signedness = Signedness.UNSIGNED - elif self.parse_optional_string("i", - skip_white_space=skip_white_space): - signedness = Signedness.SIGNLESS - else: - raise ParserError(self._pos, "integer type expected") - - val = self.parse_int_literal(skip_white_space=False) - return IntegerType.from_width(val, signedness) - - def parse_optional_mlir_integer_type(self, - skip_white_space: bool = True - ) -> IntegerType | None: - return self.try_parse(self.parse_mlir_integer_type, - skip_white_space=skip_white_space) - - def parse_optional_mlir_float_type(self, - skip_white_space: bool = True - ) -> AnyFloat | None: - if self.parse_optional_string("f16") is not None: - return Float16Type() - if self.parse_optional_string("f32") is not None: - return Float32Type() - if self.parse_optional_string("f64") is not None: - return Float64Type() - return None + # xDSL: Allow both # and ! prefixes, as we allow both types and attrs + # TODO: phase out use of next_token(peek=True) in favour of starts_with + if value is None and self.tokenizer.next_token(peek=True).text in '#!': + # in MLIR # and ! are prefixes for dialect attrs/types, but in xDSL ! is also used for builtin types + value = self.try_parse_dialect_type_or_attribute() - def parse_mlir_float_type(self, skip_white_space: bool = True) -> AnyFloat: - typ = self.parse_optional_mlir_float_type( - skip_white_space=skip_white_space) - if typ is not None: - return typ - raise ParserError(self._pos, "float type expected") - - def parse_optional_mlir_attribute(self, - skip_white_space: bool = True - ) -> Attribute | None: - if skip_white_space: - self.skip_white_space() - - # index type - if (index_type := self.parse_optional_mlir_index_type()) is not None: - return index_type - - # integer type - if (int_type := self.parse_optional_mlir_integer_type()) is not None: - return int_type - - # float type - if (float_type := self.parse_optional_mlir_float_type()) is not None: - return float_type - - # float attribute - if (lit := self.parse_optional_float_literal()) is not None: - if self.parse_optional_char(":"): - if (typ := self.parse_optional_mlir_float_type()) is not None: - return FloatAttr.from_value(lit, typ) - raise ParserError(self._pos, "float type expected") - return FloatAttr.from_value(lit, Float64Type()) - - # Shorthand for boolean attributes (integer attributes of width 1) - if (bool_attr := self.parse_optional_boolean_attribute()) is not None: - return bool_attr - - # integer attribute - if (lit := self.parse_optional_int_literal()) is not None: - if self.parse_optional_char(":"): - if (typ := - self.parse_optional_mlir_integer_type()) is not None: - return IntegerAttr.from_params(lit, typ) - if (typ := self.parse_optional_mlir_index_type()) is not None: - return IntegerAttr.from_params(lit, typ) - raise ParserError(self._pos, "integer or index type expected") - return IntegerAttr.from_params(lit, IntegerType.from_width(64)) - - # string literal - str_literal = self.parse_optional_str_literal() - if str_literal is not None: - return StringAttr.from_str(str_literal) - - # Array attribute - if self.parse_optional_char("["): - contents = self.parse_list(self.parse_optional_attribute) - self.parse_char("]") - return ArrayAttr.from_list(contents) - - # Shorthand for DictionaryAttr - if self.peek_char("{"): - contents = self.parse_dictionary(self.parse_str_literal, - self.parse_attribute) - return DictionaryAttr.from_dict(contents) - - # FlatSymbolRefAttr - if self.parse_optional_char("@"): - symbol_name = self.parse_alpha_num(skip_white_space=False) - return FlatSymbolRefAttr.from_str(symbol_name) - - # tensor type - if (tensor := self.parse_optional_mlir_tensor()) is not None: - return tensor - - # vector type - if (vector := self.parse_optional_mlir_vector()) is not None: - return vector - - # dense attribute - if self.parse_optional_string("dense"): - self.parse_char("<") - - def parse_num() -> int | float | None: - if (f := self.parse_optional_float_literal()) is not None: - return f - if (i := self.parse_optional_int_literal()) is not None: - return i - return None + if value is None: + self.raise_error( + "Unknown attribute (neither builtin nor dialect could be parsed)!" + ) - value = self.parse_optional_nested_list(parse_num) - self.parse_char(">") - self.parse_char(":") - - # Parse the dense attribute type. It is either a tensor or a vector. - loc = self._pos - type_attr: AnyVectorType | AnyTensorType - if (vec := self.parse_optional_mlir_vector()) is not None: - type_attr = vec - elif (tensor := self.parse_optional_mlir_tensor()) is not None: - type_attr = tensor - else: - raise ParserError(loc, "expected a tensor or a vector type") - - return DenseIntOrFPElementsAttr.from_list(type_attr, value) - - # opaque attribute - if self.parse_optional_string("opaque") is not None: - self.parse_char("<") - name = self.parse_str_literal() - self.parse_char(",") - val: str = self.parse_str_literal() - self.parse_char(">") - if self.parse_optional_char(":") is not None: - typ = self.parse_attribute() - return OpaqueAttr.from_strings(name, val, typ) - return OpaqueAttr.from_strings(name, val) - - # function attribute - if self.parse_optional_char("(") is not None: - inputs = self.parse_list(self.parse_optional_attribute) - self.parse_char(")") - self.parse_string("->") - if self.parse_optional_char("("): - outputs = self.parse_list(self.parse_optional_attribute) - self.parse_char(")") - return FunctionType.from_lists(inputs, outputs) - output = self.parse_attribute() - return FunctionType.from_lists(inputs, [output]) - - # memref type - if (memref := self.parse_optional_mlir_memref()) is not None: - return memref + return value - return None + def must_parse_op_result_list( + self) -> tuple[list[Span], list[Attribute] | None]: + if not self.tokenizer.starts_with('%'): + return list(), list() + results = self.must_parse_list_of(self.try_parse_value_id_and_type, + 'Expected (value-id `:` type) here!', + allow_empty=False) + # TODO: this is hideous, make it cleaner + # zip(*results) works, but is barely readable :/ + return [name for name, _ in results], [type for _, type in results] + + def try_parse_builtin_attr(self) -> Attribute: + """ + Tries to parse a bultin attribute, e.g. a string literal, int, array, etc.. - def parse_attribute(self, skip_white_space: bool = True) -> Attribute: - res = self.parse_optional_attribute(skip_white_space=skip_white_space) - if res is None: - raise ParserError(self._pos, "attribute expected") - return res + If the mode is xDSL, it also allows parsing of builtin types + """ + # in xdsl, two things are different here: + # 1. types are considered valid attributes + # 2. all types, builtins included, are prefixed with ! + if self.tokenizer.starts_with('!'): + return self.try_parse_builtin_type() - def parse_optional_named_attribute( - self, - skip_white_space: bool = True) -> tuple[str, Attribute] | None: - # The attribute name is either a string literal, or an identifier. - attr_name = self.parse_optional_str_literal( - skip_white_space=skip_white_space) - if attr_name is None: - attr_name = self.parse_optional_alpha_num( - skip_white_space=skip_white_space) - - if attr_name is None: - return None - if not self.peek_char("="): - return attr_name, UnitAttr([]) - self.parse_char("=") - attr = self.parse_attribute() - return attr_name, attr - - def parse_op_attributes(self, - skip_white_space: bool = True - ) -> dict[str, Attribute]: - if not self.parse_optional_char( - "[" if self.source == self.Source.XDSL else "{", - skip_white_space=skip_white_space): + return super().try_parse_builtin_attr() + + def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: + if not self.tokenizer.starts_with('['): return dict() - attrs_with_names = self.parse_list(self.parse_optional_named_attribute) - self.parse_char("]" if self.source == self.Source.XDSL else "}") - return {name: attr for (name, attr) in attrs_with_names} - - def parse_optional_successor(self, - skip_white_space: bool = True - ) -> Block | None: - parsed = self.parse_optional_char("^", - skip_white_space=skip_white_space) - if parsed is None: - return None - bb_name = self.parse_alpha_num(skip_white_space=False) - if bb_name in self._blocks: - block = self._blocks[bb_name] - pass - else: - block = Block() - self._blocks[bb_name] = block - return block - def parse_successors(self, skip_white_space: bool = True) -> list[Block]: - parsed = self.parse_optional_char( - "(" if self.source == self.Source.XDSL else "[", - skip_white_space=skip_white_space) - if parsed is None: - return [] - res = self.parse_list(self.parse_optional_successor, delimiter=',') - self.parse_char(")" if self.source == self.Source.XDSL else "]") - return res + self.must_parse_characters( + '[', + 'xDSL Attribute dictionary must be enclosed in square brackets') - def is_valid_name(self, name: str) -> bool: - return not name[-1].isnumeric() + attrs = self.must_parse_list_of(self.must_parse_attribute_entry, + "Expected attribute entry") - _OperationType = TypeVar('_OperationType', bound='Operation') + self.must_parse_characters( + ']', + 'xDSL Attribute dictionary must be enclosed in square brackets') - def parse_op_with_default_format( - self, - op_type: type[_OperationType], - result_types: list[Attribute], - skip_white_space: bool = True) -> _OperationType: - operands = self.parse_operands(skip_white_space=skip_white_space) - successors = self.parse_successors() - attributes = self.parse_op_attributes() - regions = self.parse_list(self.parse_optional_region, delimiter="") - - return op_type.create(operands=operands, - result_types=result_types, - attributes=attributes, - successors=successors, - regions=regions) - - def _parse_optional_op_name(self, - skip_white_space: bool = True - ) -> tuple[str, bool] | None: - op_name = self.parse_optional_alpha_num( - skip_white_space=skip_white_space) - if op_name: - return op_name, False - op_name = self.parse_optional_str_literal() - if op_name: - return op_name, True - return None + return self.attr_dict_from_tuple_list(attrs) - def _parse_op_name(self, - skip_white_space: bool = True) -> tuple[str, bool]: - op_name = self._parse_optional_op_name( - skip_white_space=skip_white_space) - if op_name is None: - raise ParserError(self._pos, "operation name expected") - return op_name - - def parse_optional_op(self, - skip_white_space: bool = True) -> Operation | None: - if self.source == self.Source.MLIR: - return self.parse_optional_mlir_op( - skip_white_space=skip_white_space) - - start_pos = self._pos - results = self.parse_optional_typed_results( - skip_white_space=skip_white_space) - if results is None: - op_name_and_generic = self._parse_optional_op_name() - if op_name_and_generic is None: - return None - op_name, is_generic_format = op_name_and_generic - results = [] - else: - op_name, is_generic_format = self._parse_op_name() - - result_types = [typ for (_, typ) in results] - op_type = self.ctx.get_optional_op(op_name) - - # If the operation is not registered, we create an UnregisteredOp instead, - # or fail. - if op_type is None: - if not self.allow_unregistered_ops: - raise ParserError(start_pos, f"unknown operation '{op_name}'") - if not is_generic_format: - raise ParserError( - start_pos, f"unknown operation '{op_name}' can " - "only be parsed using the generic format") - - op = self.parse_op_with_default_format(UnregisteredOp, - result_types) - op.attributes["op_name__"] = StringAttr.from_str(op_name) - else: - if not is_generic_format: - op = op_type.parse(result_types, self) - else: - op = self.parse_op_with_default_format(op_type, result_types) + def must_parse_operation_details( + self + ) -> tuple[list[Span], list[Span], dict[str, Attribute], list[Region], + FunctionType | None]: + """ + Must return a tuple consisting of: + - a list of arguments to the operation + - a list of successor names + - the attributes attached to the OP + - the regions of the op + - An optional function type. If not supplied, must_parse_op_result_list must return a second value + containing the types of the returned SSAValues - # Register the SSA value names in the parser - for (idx, res) in enumerate(results): - if res[0] in self._ssaValues: - raise ParserError(start_pos, - f"SSA value {res[0]} is already defined") - self._ssaValues[res[0]] = op.results[idx] - if self.is_valid_name(res[0]): - self._ssaValues[res[0]].name = res[0] + """ + args = self.must_parse_op_args_list() + succ = self.must_parse_optional_successor_list() + attrs = self.must_parse_optional_attr_dict() + regions = self.must_parse_region_list() - return op + return args, succ, attrs, regions, None - def parse_op_type( - self, - skip_white_space: bool = True - ) -> tuple[list[Attribute], list[Attribute]]: - self.parse_char("(", skip_white_space=skip_white_space) - inputs = self.parse_list(self.parse_optional_attribute) - self.parse_char(")") - self.parse_string("->") - - # No or multiple result types - if self.parse_optional_char("("): - outputs = self.parse_list(self.parse_optional_attribute) - self.parse_char(")") - else: - outputs = [self.parse_attribute()] + def must_parse_optional_successor_list(self) -> list[Span]: + if not self.tokenizer.starts_with('('): + return [] + self.must_parse_characters( + '(', 'Successor list is enclosed in round brackets') + successors = self.must_parse_list_of(self.try_parse_block_id, + 'Expected a block-id', + allow_empty=False) + self.must_parse_characters( + ')', 'Successor list is enclosed in round brackets') + return successors - return inputs, outputs + def must_parse_dialect_type_or_attribute_inner(self, kind: str): + if self.tokenizer.starts_with('"'): + name = self.try_parse_string_literal() + if name is None: + self.raise_error("Expected string literal for an attribute in generic format here!") + return self.must_parse_generic_attribute_args(name) + return super().must_parse_dialect_type_or_attribute_inner(kind) - def parse_mlir_op_with_default_format( - self, - op_type: type[_OperationType], - num_results: int, - skip_white_space: bool = True) -> _OperationType: - operands = self.parse_operands(skip_white_space=skip_white_space) + def must_parse_generic_attribute_args(self, name: StringLiteral): + attr = self.ctx.get_optional_attr(name.string_contents) + if attr is None: + self.raise_error("Unknown attribute name!", name) + if not issubclass(attr, ParametrizedAttribute): + self.raise_error("Expected ParametrizedAttribute name here!", name) + self.must_parse_characters('<', 'Expected generic attribute arguments here!') + args = self.must_parse_list_of(self.try_parse_attribute, 'Unexpected end of attribute list!') + self.must_parse_characters('>', 'Malformed attribute arguments, reached end of args list!') + return attr(args) - regions = [] - if self.parse_optional_char("(") is not None: - regions = self.parse_list(self.parse_optional_region) - self.parse_char(")") - - attributes = self.parse_op_attributes() - - self.parse_char(":") - operand_types, result_types = self.parse_op_type() - - if len(operand_types) != len(operands): - raise Exception( - "Operand types are not matching the number of operands.") - if len(result_types) != num_results: - raise Exception( - "Result types are not matching the number of results.") - for operand, operand_type in zip(operands, operand_types): - if operand.typ != operand_type: - raise Exception("Operation operand types are not matching " - "the types of its operands. Got operand with " - f"type {operand.typ}, but operation expect " - f"operand to be of type {operand_type}") - - return op_type.create(operands=operands, - result_types=result_types, - attributes=attributes, - regions=regions) - - def parse_optional_mlir_op(self, - skip_white_space: bool = True - ) -> Operation | None: - start_pos = self._pos - results = self.parse_optional_results( - skip_white_space=skip_white_space) - if results is None: - results = [] - op_name = self.parse_optional_str_literal() - if op_name is None: - return None - else: - op_name = self.parse_str_literal() - op_type = self.ctx.get_optional_op(op_name) - if op_type is None: - if not self.allow_unregistered_ops: - raise ParserError(start_pos, f"unknown operation '{op_name}'") +# COMPAT layer so parser_ng is a drop-in replacement for parser: - op_type = UnregisteredOp - op = self.parse_mlir_op_with_default_format(op_type, len(results)) - op.attributes["op_name__"] = StringAttr.from_str(op_name) - else: - op = self.parse_mlir_op_with_default_format(op_type, len(results)) - # Register the SSA value names in the parser - for (idx, res) in enumerate(results): - if res in self._ssaValues: - raise ParserError(start_pos, - f"SSA value {res} is already defined") - self._ssaValues[res] = op.results[idx] - if self.is_valid_name(res): - self._ssaValues[res].name = res +class Source(Enum): + XDSL = 1 + MLIR = 2 - return op - def parse_op(self, skip_white_space: bool = True) -> Operation: - res = self.parse_optional_op(skip_white_space=skip_white_space) - if res is None: - raise ParserError(self._pos, "operation expected") - return res +def Parser(ctx: MLContext, prog: str, source: Source = Source.XDSL, filename: str = '') -> BaseParser: + selected_parser = {Source.XDSL: XDSLParser, Source.MLIR: MLIRParser}[source] + return selected_parser(ctx, prog, filename) + + +setattr(Parser, 'Source', Source) diff --git a/xdsl/parser_ng.py b/xdsl/parser_ng.py deleted file mode 100644 index 2c28783813..0000000000 --- a/xdsl/parser_ng.py +++ /dev/null @@ -1,1707 +0,0 @@ -from __future__ import annotations - -import ast -import contextlib -import functools -import itertools -import re -import sys -import traceback -from abc import ABC, abstractmethod -from collections import defaultdict -from dataclasses import dataclass, field -from io import StringIO -from typing import TypeVar, Iterable - -from xdsl.dialects.builtin import ( - AnyTensorType, AnyVectorType, Float16Type, Float32Type, Float64Type, - FloatAttr, FunctionType, IndexType, IntegerType, Signedness, StringAttr, - IntegerAttr, ArrayAttr, TensorType, UnrankedTensorType, VectorType, - DefaultIntegerAttrType, FlatSymbolRefAttr) -from xdsl.ir import (SSAValue, Block, Callable, Attribute, Operation, Region, - BlockArgument, MLContext, ParametrizedAttribute) -from .printer import Printer - - -class ParseError(Exception): - span: Span - msg: str - history: BacktrackingHistory | None - - def __init__(self, - span: Span, - msg: str, - history: BacktrackingHistory | None = None): - super().__init__(span.print_with_context(msg)) - self.span = span - self.msg = msg - self.history = history - - def print_pretty(self, file=sys.stderr): - print(self.span.print_with_context(self.msg), file=file) - - def print_with_history(self): - if self.history is not None: - self.history.print_unroll() - - -class MultipleSpansParseError(ParseError): - ref_text: str | None - refs: list[tuple[Span, str]] - - def __init__(self, - span: Span, - msg: str, - ref_text: str, - refs: list[tuple[Span, str | None]], - history: BacktrackingHistory | None = None): - super(MultipleSpansParseError, self).__init__(span, msg, history) - self.refs = refs - self.ref_text = ref_text - - def print_pretty(self, file=sys.stderr): - super(MultipleSpansParseError, self).print_pretty(file) - print(self.ref_text or "With respect to:", file=file) - for span, msg in self.refs: - print(span.print_with_context(msg), file=file) - - -@dataclass -class BacktrackingHistory: - error: ParseError - parent: BacktrackingHistory | None - region_name: str | None - pos: int - - def print_unroll(self, file=sys.stderr): - if self.parent: - self.parent.print_unroll(file) - - print("Parsing of {} failed:".format(self.region_name or ''), - file=file) - self.error.print_pretty(file=file) - - def get_farthest_point(self) -> int: - """ - Find the farthest this history managed to parse - """ - if self.parent: - return max(self.pos, self.parent.get_farthest_point()) - return self.pos - - -class BacktrackingAbort(Exception): - reason: str | None - - def __init__(self, reason: str | None = None): - super().__init__( - "This message should never escape the parser, it's intended to signal a failed parsing " - "attempt\n " - "It should never be used outside of a tokenizer.backtracking() block!\n" - "The reason for this abort was {}".format( - 'not specified' if reason is None else reason)) - self.reason = reason - - -@dataclass(frozen=True) -class Span: - """ - Parts of the input are always passed around as spans, so we know where they originated. - """ - - start: int - """ - Start of tokens location in source file, global byte offset in file - """ - end: int - """ - End of tokens location in source file, global byte offset in file - """ - input: Input - """ - The input being operated on - """ - - def __len__(self): - return self.len - - @property - def len(self): - return self.end - self.start - - @property - def text(self): - return self.input.content[self.start:self.end] - - def print_with_context(self, msg: str | None = None) -> str: - """ - returns a string containing lines relevant to the span. The Span's contents - are highlighted by up-carets beneath them (`^`). The message msg is printed - along these. - """ - info = self.input.get_lines_containing(self) - if info is None: - return "Unknown location of span {}. Error: ".format(self, msg) - lines, offset_of_first_line, line_no = info - # offset relative to the first line: - offset = self.start - offset_of_first_line - remaining_len = max(self.len, 1) - capture = StringIO() - print("{}:{}:{}".format(self.input.name, line_no, offset, - remaining_len), - file=capture) - for line in lines: - print(line, file=capture) - if remaining_len < 0: - continue - len_on_this_line = min(remaining_len, len(line) - offset) - remaining_len -= len_on_this_line - print("{}{}".format(" " * offset, "^" * max(len_on_this_line, 1)), - file=capture) - if msg is not None: - print("{}{}".format(" " * offset, msg), file=capture) - msg = None - offset = 0 - if msg is not None: - print(msg, file=capture) - return capture.getvalue() - - def __repr__(self): - return "{}[{}:{}](text='{}')".format(self.__class__.__name__, - self.start, self.end, self.text) - - -@dataclass(frozen=True, repr=False) -class StringLiteral(Span): - - def __post_init__(self): - if len(self) < 2 or self.text[0] != '"' or self.text[-1] != '"': - raise ParseError(self, "Invalid string literal!") - - @classmethod - def from_span(cls, span: Span | None) -> StringLiteral | None: - if span is None: - return None - return cls(span.start, span.end, span.input) - - @property - def string_contents(self): - # TODO: is this a hack-job? - return ast.literal_eval(self.text) - - -@dataclass(frozen=True) -class Input: - """ - This is a very simple class that is used to keep track of the input. - """ - name: str - content: str = field(repr=False) - - @property - def len(self): - return len(self.content) - - def __len__(self): - return self.len - - def get_nth_line_bounds(self, n: int): - start = 0 - for i in range(n): - next_start = self.content.find('\n', start) - if next_start == -1: - return None - start = next_start + 1 - return start, self.content.find('\n', start) - - def get_lines_containing(self, - span: Span) -> tuple[list[str], int, int] | None: - # A pointer to the start of the first line - start = 0 - line_no = -1 - source = self.content - while True: - next_start = source.find('\n', start) - line_no += 1 - # handle eof - if next_start == -1: - if span.start > len(source): - return None - return [source[start:]], start, line_no - # as long as the next newline comes before the spans start we can continue - if next_start < span.start: - start = next_start + 1 - continue - # if the whole span is on one line, we are good as well - if next_start >= span.end: - return [source[start:next_start]], start, line_no - while next_start < span.end: - next_start = source.find('\n', next_start + 1) - return source[start:next_start].split('\n'), start, line_no - - def at(self, i: int): - if i >= self.len: - raise EOFError() - return self.content[i] - - -save_t = tuple[int, tuple[str, ...]] - - -@dataclass -class Tokenizer: - input: Input - - pos: int = field(init=False, default=0) - """ - The position in the input. Points to the first unconsumed character. - """ - - break_on: tuple[str, ...] = ('.', '%', ' ', '(', ')', '[', ']', '{', '}', - '<', '>', ':', '=', '@', '?', '|', '->', '-', - '//', '\n', '\t', '#', '"', "'", ',', '!') - """ - characters the tokenizer should break on - """ - - history: BacktrackingHistory | None = field(init=False, - default=None, - repr=False) - - last_token: Span | None = field(init=False, default=None, repr=False) - - def save(self) -> save_t: - """ - Create a checkpoint in the parsing process, useful for backtracking - """ - return self.pos, self.break_on - - def resume_from(self, save: save_t): - """ - Resume from a previously saved position. - - Restores the state of the tokenizer to the exact previous position - """ - self.pos, self.break_on = save - - @contextlib.contextmanager - def backtracking(self, region_name: str | None = None): - """ - This context manager can be used to mark backtracking regions. - - When an error is thrown during backtracking, it is recorded and stored together - with some meta information in the history attribute. - - The backtracker accepts the following exceptions: - - ParseError: signifies that the region could not be parsed because of (unexpected) syntax errors - - BacktrackingAbort: signifies that backtracking was aborted, not necessarily indicating a syntax error - - AssertionError: this error should probably be phased out in favour of the two above - - EOFError: signals that EOF was reached unexpectedly - - Any other error will be printed to stderr, but backtracking will continue as normal. - """ - save = self.save() - starting_position = self.pos - try: - yield - # clear error history when something doesn't fail - # this is because we are only interested in the last "cascade" of failures. - # if a backtracking() completes without failre, something has been parsed (we assume) - if self.pos > starting_position and self.history is not None: - self.history = None - except Exception as ex: - how_far_we_got = self.pos - - # AssertionErrors act upon the consumed token, this means we only go to the start of the token - if isinstance(ex, BacktrackingAbort): - # TODO: skip space as well - how_far_we_got -= self.last_token.len - - # if we have no error history, start recording! - if not self.history: - self.history = self.history_entry_from_exception( - ex, region_name, how_far_we_got) - - # if we got further than on previous attempts - elif how_far_we_got > self.history.get_farthest_point(): - # throw away history - self.history = None - # generate new history entry, - self.history = self.history_entry_from_exception( - ex, region_name, how_far_we_got) - - # otherwise, add to exception, if we are in a named region - elif region_name is not None and how_far_we_got - starting_position > 0: - self.history = self.history_entry_from_exception( - ex, region_name, how_far_we_got) - - self.resume_from(save) - - def history_entry_from_exception(self, ex: Exception, region: str, - pos: int) -> BacktrackingHistory: - """ - Given an exception generated inside a backtracking attempt, - generate a BacktrackingHistory object with the relevant information in it. - - If an unexpected exception type is encountered, print a traceback to stderr - """ - if isinstance(ex, ParseError): - return BacktrackingHistory(ex, self.history, region, pos) - elif isinstance(ex, AssertionError): - reason = [ - 'Generic assertion failure', - *(reason for reason in ex.args if isinstance(reason, str)) - ] - # we assume that assertions fail because of the last read-in token - if len(reason) == 1: - tb = StringIO() - traceback.print_exc(file=tb) - reason[0] += '\n' + tb.getvalue() - - return BacktrackingHistory( - ParseError(self.last_token, reason[-1], self.history), - self.history, region, pos) - elif isinstance(ex, BacktrackingAbort): - return BacktrackingHistory( - ParseError( - self.next_token(peek=True), - 'Backtracking aborted: {}'.format(ex.reason - or 'unknown reason'), - self.history), self.history, region, pos) - elif isinstance(ex, EOFError): - return BacktrackingHistory( - ParseError(self.last_token, "Encountered EOF", self.history), - self.history, region, pos) - - print("Warning: Unexpected error in backtracking:", file=sys.stderr) - traceback.print_exception(ex, file=sys.stderr) - - return BacktrackingHistory( - ParseError(self.last_token, "Unexpected exception: {}".format(ex), - self.history), self.history, region, pos) - - def next_token(self, start: int | None = None, peek: bool = False) -> Span: - """ - Return a Span of the next token, according to the self.break_on rules. - - Can be modified using: - - - start: don't start at the current tokenizer position, instead start here (useful for skipping comments, etc) - - peek: don't advance the position, only "peek" at the input - - This will skip over line comments. Meaning it will skip the entire line if it encounters '//' - """ - i = self.next_pos(start) - # construct the span: - span = Span(i, self._find_token_end(i), self.input) - # advance pointer if not peeking - if not peek: - self.pos = span.end - - # save last token - self.last_token = span - return span - - def next_token_of_pattern(self, - pattern: re.Pattern | str, - peek: bool = False) -> Span | None: - """ - Return a span that matched the pattern, or nothing. You can choose not to consume the span. - """ - start = self.next_pos() - - # handle search for string literal - if isinstance(pattern, str): - if self.starts_with(pattern): - if not peek: - self.pos = start + len(pattern) - return Span(start, start + len(pattern), self.input) - return None - - # handle regex logic - match = pattern.match(self.input.content, start) - if match is None: - return None - - if not peek: - self.pos = match.end() - - # save last token - self.last_token = Span(start, match.end(), self.input) - return self.last_token - - def consume_peeked(self, peeked_span: Span): - if peeked_span.start != self.next_pos(): - raise ParseError(peeked_span, "This is not the peeked span!") - self.pos = peeked_span.end - - def _find_token_end(self, start: int | None = None) -> int: - """ - Find the point (optionally starting from start) where the token ends - """ - i = self.next_pos() if start is None else start - # search for literal breaks - for part in self.break_on: - if self.input.content.startswith(part, i): - return i + len(part) - # otherwise return the start of the next break - return min( - filter(lambda x: x >= 0, (self.input.content.find(part, i) - for part in self.break_on))) - - def next_pos(self, i: int | None = None) -> int: - """ - Find the next starting position (optionally starting from i) - - This will skip line comments! - """ - i = self.pos if i is None else i - # skip whitespaces - while self.input.at(i).isspace(): - i += 1 - - # skip comments as well - if self.input.content.startswith('//', i): - i = self.input.content.find('\n', i) + 1 - return self.next_pos(i) - - return i - - def is_eof(self): - """ - Check if the end of the input was reached. - """ - try: - self.next_pos() - except EOFError: - return True - - @contextlib.contextmanager - def configured(self, break_on: tuple[str, ...]): - """ - This is a helper class to allow expressing a temporary change in config, allowing you to write: - - # parsing double-quoted string now - string_content = "" - with tokenizer.configured(break_on=('"', '\\'),): - # use tokenizer - - # now old config is restored automatically - - """ - save = self.save() - - if break_on is not None: - self.break_on = break_on - - try: - yield self - finally: - self.break_on = save[1] - - def starts_with(self, text: str | re.Pattern) -> bool: - start = self.next_pos() - if isinstance(text, re.Pattern): - return text.match(self.input.content, start) is None - return self.input.content.startswith(text, start) - - -class ParserCommons: - """ - Colelction of common things used in parsing MLIR/IRDL - - """ - integer_literal = re.compile(r'[+-]?([0-9]+|0x[0-9A-Fa-f]+)') - decimal_literal = re.compile(r'[+-]?([1-9][0-9]*)') - string_literal = re.compile(r'"([^\n\f\v\r"]|\\[nfvr"])+"') - float_literal = re.compile(r'[-+]?[0-9]+\.[0-9]*([eE][-+]?[0-9]+)?') - bare_id = re.compile(r'[A-Za-z_][\w$.]+') - value_id = re.compile(r'%([0-9]+|([A-Za-z_$.-][\w$.-]*))') - suffix_id = re.compile(r'([0-9]+|([A-Za-z_$.-][\w$.-]*))') - block_id = re.compile(r'\^([0-9]+|([A-Za-z_$.-][\w$.-]*))') - type_alias = re.compile(r'![A-Za-z_][\w$.]+') - attribute_alias = re.compile(r'#[A-Za-z_][\w$.]+') - boolean_literal = re.compile(r'(true|false)') - # a list of - _builtin_type_names = ( - r'[su]?i\d+', - r'f\d+', - 'tensor', - 'vector', - 'memref', - 'complex', - 'opaque', - 'tuple', - 'index', - # TODO: add all the Float8E4M3FNType, Float8E5M2Type, and BFloat16Type - ) - builtin_type = re.compile('(({}))'.format(')|('.join(_builtin_type_names))) - builtin_type_xdsl = re.compile('!(({}))'.format( - ')|('.join(_builtin_type_names))) - double_colon = re.compile('::') - comma = re.compile(',') - - -class BaseParser(ABC): - """ - Basic recursive descent parser. - - methods marked try_... will attempt to parse, and return None if they failed. If they return None - they must make sure to restore all state. - - methods marked must_... will do greedy parsing, meaning they consume as much as they can. They will - also throw an error if the think they should still be parsing. e.g. when parsing a list of numbers - separated by '::', the following input will trigger an exception: - 1::2:: - Due to the '::' present after the last element. This is useful for parsing lists, as a trailing - separator is usually considered a syntax error there. - - You can turn a try_ into a must_ by using expect(try_parse_..., error_msg) - - You can turn a must_ into a try_ by wrapping it in tokenizer.backtracking() - - must_ type parsers are preferred because they are explicit about their failure modes. - """ - - ctx: MLContext - """xDSL context.""" - - ssaValues: dict[str, SSAValue] - blocks: dict[str, Block] - forward_block_references: dict[str, list[Span]] - """ - Blocks we encountered references to before the definition (must be empty after parsing of region completes) - """ - - T_ = TypeVar('T_') - """ - Type var used for handling function that return single or multiple Spans. Basically the output type - of all try_parse functions is T_ | None - """ - - def __init__( - self, - input: str, - name: str, - ctx: MLContext, - ): - self.tokenizer = Tokenizer(Input(input, name)) - self.ctx = ctx - self.ssaValues = dict() - self.blocks = dict() - self.forward_block_references = dict() - - def begin_parse(self): - ops = [] - while (op := self.try_parse_operation()) is not None: - ops.append(op) - if not self.tokenizer.is_eof(): - self.raise_error("Could not parse entire input!") - return ops - - def get_block_from_name(self, block_name: Span): - """ - This function takes a span containing a block id (like `^42`) and returns a block. - - If the block defintion was not seen yet, we create a forward declaration. - """ - name = block_name.text - if name not in self.blocks: - self.forward_block_references[name].append(block_name) - self.blocks[name] = Block() - return self.blocks[name] - - def must_parse_block(self) -> Block: - block_id, args = self.must_parse_optional_block_label() - - if block_id is None: - block = Block(self.tokenizer.last_token) - elif self.forward_block_references.pop(block_id.text, - None) is not None: - block = self.blocks[block_id.text] - block.delcared_at = block_id - else: - if block_id.text in self.blocks: - raise MultipleSpansParseError( - block_id, - "Re-declaration of block {}".format(block_id.text), - 'Originally declared here:', - [(self.blocks[block_id.text].delcared_at, None)], - self.tokenizer.history) - block = Block(block_id) - self.blocks[block_id.text] = block - - for i, (name, type) in enumerate(args): - arg = BlockArgument(type, block, i) - self.ssaValues[name.text] = arg - block.args.append(arg) - - while (next_op := self.try_parse_operation()) is not None: - block.ops.append(next_op) - - return block - - def must_parse_optional_block_label( - self) -> tuple[Span | None, list[tuple[Span, Attribute]]]: - block_id = self.try_parse_block_id() - arg_list = list() - - if block_id is not None: - if self.tokenizer.starts_with('('): - arg_list = self.must_parse_block_arg_list() - - self.must_parse_characters(':', 'Block label must end in a `:`!') - - return block_id, arg_list - - def must_parse_block_arg_list(self) -> list[tuple[Span, Attribute]]: - self.must_parse_characters('(', 'Block arguments must start with `(`') - - args = self.must_parse_list_of(self.try_parse_value_id_and_type, - "Expected value-id and type here!") - - self.must_parse_characters(')', - 'Expected closing of block arguments!', - is_parse_error=True) - - return args - - def try_parse_single_reference(self) -> Span | None: - with self.tokenizer.backtracking('part of a reference'): - self.must_parse_characters('@', "references must start with `@`") - if (reference := self.try_parse_string_literal()) is not None: - return reference - if (reference := self.try_parse_suffix_id()) is not None: - return reference - self.raise_error( - "References must conform to `@` (string-literal | suffix-id)") - - def must_parse_reference(self) -> list[Span]: - return self.must_parse_list_of( - self.try_parse_single_reference, - 'Expected reference here in the format of `@` (suffix-id | string-literal)', - ParserCommons.double_colon, - allow_empty=False) - - def must_parse_list_of(self, - try_parse: Callable[[], T_ | None], - error_msg: str, - separator_pattern: re.Pattern = ParserCommons.comma, - allow_empty: bool = True) -> list[T_]: - """ - This is a greedy list-parser. It accepts input only in these cases: - - - If the separator isn't encountered, which signals the end of the list - - If an empty list is allowed, it accepts when the first try_parse fails - - If an empty separator is given, it instead sees a failed try_parse as the end of the list. - - This means, that the setup will not accept the input and instead raise an error: - try_parse = parse_integer_literal - separator = 'x' - input = 3x4x4xi32 - as it will read [3,4,4], then see another separator, and expects the next try_parse call to succeed - (which won't as i32 is not a valid integer literal) - """ - items = list() - first_item = try_parse() - if first_item is None: - if allow_empty: - return items - self.raise_error(error_msg) - - items.append(first_item) - - while (match := self.tokenizer.next_token_of_pattern(separator_pattern) - ) is not None: - next_item = try_parse() - if next_item is None: - # if the separator is emtpy, we are good here - if separator_pattern.pattern == '': - return items - self.raise_error(error_msg + - ' because was able to match next separator {}' - .format(match.text)) - items.append(next_item) - - return items - - def try_parse_integer_literal(self) -> Span | None: - return self.tokenizer.next_token_of_pattern( - ParserCommons.integer_literal) - - def try_parse_decimal_literal(self) -> Span | None: - return self.tokenizer.next_token_of_pattern( - ParserCommons.decimal_literal) - - def try_parse_string_literal(self) -> StringLiteral | None: - return StringLiteral.from_span( - self.tokenizer.next_token_of_pattern(ParserCommons.string_literal)) - - def try_parse_float_literal(self) -> Span | None: - return self.tokenizer.next_token_of_pattern( - ParserCommons.float_literal) - - def try_parse_bare_id(self) -> Span | None: - return self.tokenizer.next_token_of_pattern(ParserCommons.bare_id) - - def try_parse_value_id(self) -> Span | None: - return self.tokenizer.next_token_of_pattern(ParserCommons.value_id) - - def try_parse_suffix_id(self) -> Span | None: - return self.tokenizer.next_token_of_pattern(ParserCommons.suffix_id) - - def try_parse_block_id(self) -> Span | None: - return self.tokenizer.next_token_of_pattern(ParserCommons.block_id) - - def try_parse_boolean_literal(self) -> Span | None: - return self.tokenizer.next_token_of_pattern( - ParserCommons.boolean_literal) - - def try_parse_value_id_and_type(self) -> tuple[Span, Attribute] | None: - with self.tokenizer.backtracking("value id and type"): - value_id = self.try_parse_value_id() - - if value_id is None: - self.raise_error("Invalid value-id format!") - - self.must_parse_characters( - ':', 'Expected expression (value-id `:` type)') - - type = self.try_parse_type() - - if type is None: - self.raise_error("Expected type of value-id here!") - return value_id, type - - def try_parse_type(self) -> Attribute | None: - if (builtin_type := self.try_parse_builtin_type()) is not None: - return builtin_type - if (dialect_type := self.try_parse_dialect_type()) is not None: - return dialect_type - return None - - def try_parse_dialect_type_or_attribute(self) -> Attribute | None: - """ - Parse a type or an attribute. - """ - kind = self.tokenizer.next_token_of_pattern(re.compile('[!#]'), - peek=True) - - if kind is None: - return None - - with self.tokenizer.backtracking("dialect attribute or type"): - self.tokenizer.consume_peeked(kind) - if kind.text == '!': - return self.must_parse_dialect_type_or_attribute_inner('type') - else: - return self.must_parse_dialect_type_or_attribute_inner( - 'attribute') - - def try_parse_dialect_type(self): - """ - Parse a dialect type (something prefixed by `!`, defined by a dialect) - """ - if not self.tokenizer.starts_with('!'): - return None - with self.tokenizer.backtracking("dialect type"): - self.must_parse_characters('!', - "Dialect type must start with a `!`") - return self.must_parse_dialect_type_or_attribute_inner('type') - - def try_parse_dialect_attr(self): - """ - Parse a dialect attribute (something prefixed by `#`, defined by a dialect) - """ - if not self.tokenizer.starts_with('#'): - return None - with self.tokenizer.backtracking("dialect attribute"): - self.must_parse_characters( - '#', "Dialect attribute must start with a `#`") - return self.must_parse_dialect_type_or_attribute_inner('attribute') - - def must_parse_dialect_type_or_attribute_inner(self, kind: str): - type_name = self.tokenizer.next_token_of_pattern(ParserCommons.bare_id) - - if type_name is None: - self.raise_error("Expected dialect {} name here!".format(kind)) - - type_def = self.ctx.get_optional_attr(type_name.text) - if type_def is None: - self.raise_error( - "'{}' is not a know attribute!".format(type_name.text), - type_name) - - # pass the task of parsing parameters on to the attribute/type definition - param_list = type_def.parse_parameters(self) - return type_def(param_list) - - @abstractmethod - def try_parse_builtin_type(self) -> Attribute | None: - """ - parse a builtin-type like i32, index, vector etc. - """ - raise NotImplemented("Subclasses must implement this method!") - - def must_parse_builtin_parametrized_type( - self, name: Span) -> ParametrizedAttribute: - - def unimplemented() -> ParametrizedAttribute: - raise ParseError(name, - "Builtin {} not supported yet!".format(name.text)) - - builtin_parsers: dict[str, Callable[[], ParametrizedAttribute]] = { - 'vector': self.must_parse_vector_attrs, - 'memref': unimplemented, - 'tensor': self.must_parse_tensor_attrs, - 'complex': self.must_parse_complex_attrs, - 'opaque': unimplemented, - 'tuple': unimplemented, - } - - self.must_parse_characters('<', 'Expected parameter list here!') - # get the parser for the type, falling back to the unimplemented warning - res = builtin_parsers.get(name.text, unimplemented)() - self.must_parse_characters('>', - 'Expected end of parameter list here!', - is_parse_error=True) - return res - - def must_parse_complex_attrs(self): - self.raise_error("ComplexType is unimplemented!") - - def try_parse_numerical_dims(self, - accept_closing_bracket: bool = False, - lower_bound: int = 1) -> Iterable[int]: - while (shape_arg := - self.try_parse_shape_element(lower_bound)) is not None: - yield shape_arg - # look out for the closing bracket for scalable vector dims - if accept_closing_bracket and self.tokenizer.starts_with(']'): - break - self.must_parse_characters( - 'x', - 'Unexpected end of dimension parameters!', - is_parse_error=True) - - def must_parse_vector_attrs(self) -> AnyVectorType: - # also break on 'x' characters as they are separators in dimension parameters - with self.tokenizer.configured(break_on=self.tokenizer.break_on + - ('x', )): - shape = list[int](self.try_parse_numerical_dims()) - scaling_shape: list[int] | None = None - - if self.tokenizer.next_token_of_pattern('[') is not None: - # we now need to parse the scalable dimensions - scaling_shape = list(self.try_parse_numerical_dims()) - self.must_parse_characters( - ']', - 'Expected end of scalable vector dimensions here!', - is_parse_error=True) - self.must_parse_characters( - 'x', - 'Expected end of scalable vector dimensions here!', - is_parse_error=True) - - if scaling_shape is not None: - # TODO: handle scaling vectors! - self.raise_error("Warning: scaling vectors not supported!") - pass - - type = self.try_parse_type() - if type is None: - self.raise_error( - "Expected a type at the end of the vector parameters!") - - return VectorType.from_type_and_list(type, shape) - - def must_parse_tensor_or_memref_dims(self) -> list[int] | None: - with self.tokenizer.configured(break_on=self.tokenizer.break_on + - ('x', )): - # check for unranked-ness - if self.tokenizer.next_token_of_pattern('*') is not None: - # consume `x` - self.must_parse_characters( - 'x', - 'Unranked tensors must follow format (`<*x` type `>`)', - is_parse_error=True) - else: - # parse rank: - return list(self.try_parse_numerical_dims(lower_bound=0)) - - def must_parse_tensor_attrs(self) -> AnyTensorType: - shape = self.must_parse_tensor_or_memref_dims() - type = self.try_parse_type() - - if type is None: - self.raise_error("Expected tensor type here!") - - if self.tokenizer.starts_with(','): - # TODO: add tensor encoding! - raise self.raise_error("Parsing tensor encoding is not supported!") - - if shape is None and self.tokenizer.starts_with(','): - raise self.raise_error("Unranked tensors don't have an encoding!") - - if shape is not None: - return TensorType.from_type_and_list(type, shape) - - return UnrankedTensorType.from_type(type) - - def try_parse_shape_element(self, lower_bound: int = 1) -> int | None: - """ - Parse a shape element, either a decimal integer immediate or a `?`, which evaluates to -1 - - immediate cannot be smaller than lower_bound (defaults to 1) (is 0 for tensors and memrefs) - """ - int_lit = self.try_parse_decimal_literal() - - if int_lit is not None: - value = int(int_lit.text) - if value < lower_bound: - # TODO: this is ugly, it's a raise inside a try_ type function, which should instead just give up - raise ParseError( - int_lit, - "Shape element literal cannot be negative or zero!") - return value - - if self.tokenizer.next_token_of_pattern('?') is not None: - return -1 - return None - - def must_parse_type_params(self) -> list[Attribute]: - # consume opening bracket - assert self.tokenizer.next_token( - ).text == '<', 'Type must be parameterized!' - - params = self.must_parse_list_of(self.try_parse_type, - 'Expected a type here!') - - assert self.tokenizer.next_token( - ).text == '>', 'Expected end of type parameterization here!' - - return params - - def expect(self, try_parse: Callable[[], T_ | None], - error_message: str) -> T_: - """ - Used to force completion of a try_parse function. Will throw a parse error if it can't - """ - res = try_parse() - if res is None: - self.raise_error(error_message) - return res - - def raise_error(self, msg: str, at_position: Span | None = None): - """ - Helper for raising exceptions, provides as much context as possible to them. - - This will, for example, include backtracking errors, if any occured previously - """ - if at_position is None: - at_position = self.tokenizer.next_token(peek=True) - - raise ParseError(at_position, msg, self.tokenizer.history) - - def must_parse_characters(self, - text: str, - msg: str, - is_parse_error: bool = False) -> Span: - if (match := self.tokenizer.next_token_of_pattern(text)) is None: - if is_parse_error: - self.raise_error(msg) - raise AssertionError("Unexpected input: {}".format(msg)) - return match - - @abstractmethod - def must_parse_op_result_list( - self) -> tuple[list[Span], list[Attribute] | None]: - raise NotImplemented() - - def try_parse_operation(self) -> Operation | None: - with self.tokenizer.backtracking("operation"): - - result_list, ret_types = self.must_parse_op_result_list() - if len(result_list) > 0: - self.must_parse_characters( - '=', - 'Operation definitions expect an `=` after op-result-list!' - ) - - # check for custom op format - op_name = self.try_parse_bare_id() - if op_name is not None: - op_type = self.ctx.get_op(op_name.text) - op = op_type.parse(ret_types, self) - else: - # check for basic op format - op_name = self.try_parse_string_literal() - if op_name is None: - self.raise_error( - "Expected an operation name here, either a bare-id, or a string literal!" - ) - - args, successors, attrs, regions, func_type = self.must_parse_operation_details( - ) - - if ret_types is None: - assert func_type is not None - ret_types = func_type.outputs.data - - op_type = self.ctx.get_op(op_name.string_contents) - - op = op_type.create( - operands=[self.ssaValues[span.text] for span in args], - result_types=ret_types, - attributes=attrs, - successors=[ - self.blocks[block_name.text] - for block_name in successors - ], - regions=regions) - - # Register the result SSA value names in the parser - for idx, res in enumerate(result_list): - ssa_val_name = res.text - if ssa_val_name in self.ssaValues: - self.raise_error( - f"SSA value {ssa_val_name} is already defined", res) - self.ssaValues[ssa_val_name] = op.results[idx] - # TODO: check name? - self.ssaValues[ssa_val_name].name = ssa_val_name.lstrip('%') - - return op - - def must_parse_region(self) -> Region: - oldSSAVals = self.ssaValues.copy() - oldBBNames = self.blocks - oldForwardRefs = self.forward_block_references - self.blocks = dict() - self.forward_block_references = defaultdict(list) - - region = Region() - - try: - self.must_parse_characters('{', 'Regions begin with `{`') - if not self.tokenizer.starts_with('}'): - # parse first block - block = self.must_parse_block() - region.add_block(block) - - while self.tokenizer.starts_with('^'): - region.add_block(self.must_parse_block()) - - end = self.must_parse_characters( - '}', 'Reached end of region, expected `}`!') - - if len(self.forward_block_references) > 0: - raise MultipleSpansParseError( - end, - "Region ends with missing block declarations for block(s) {}!" - .format(', '.join(self.forward_block_references.keys())), - 'The following block references are dangling:', - [(span, "Reference to block \"{}\" without implementation!" - .format(span.text)) for span in itertools.chain( - *self.forward_block_references.values())], - self.tokenizer.history) - - return region - finally: - self.ssaValues = oldSSAVals - self.blocks = oldBBNames - self.forward_block_references = oldForwardRefs - - def try_parse_op_name(self) -> Span | None: - if (str_lit := self.try_parse_string_literal()) is not None: - return str_lit - return self.try_parse_bare_id() - - def must_parse_attribute_entry(self) -> tuple[Span, Attribute]: - """ - Parse entry in attribute dict. Of format: - - attrbiute_entry := (bare-id | string-literal) `=` attribute - attrbiute := dialect-attribute | builtin-attribute - """ - if (name := self.try_parse_bare_id()) is None: - name = self.try_parse_string_literal() - - if name is None: - self.raise_error( - 'Expected bare-id or string-literal here as part of attribute entry!' - ) - - self.must_parse_characters( - '=', 'Attribute entries must be of format name `=` attribute!') - - return name, self.must_parse_attribute() - - @abstractmethod - def must_parse_attribute(self) -> Attribute: - """ - Parse attribute (either builtin or dialect) - - This is different in xDSL and MLIR, so the actuall implementation is provided by the subclass - """ - raise NotImplemented() - - def try_parse_attribute(self) -> Attribute | None: - with self.tokenizer.backtracking('attribute'): - return self.must_parse_attribute() - - def must_parse_attribute_type(self) -> Attribute: - """ - Parses `:` type and returns the type - """ - self.must_parse_characters( - ':', 'Expected attribute type definition here ( `:` type )') - return self.expect( - self.try_parse_type, - 'Expected attribute type definition here ( `:` type )') - - def try_parse_builtin_attr(self) -> Attribute: - """ - Tries to parse a bultin attribute, e.g. a string literal, int, array, etc.. - """ - # order here is important! - attrs = (self.try_parse_builtin_float_attr, - self.try_parse_builtin_int_attr, - self.try_parse_builtin_str_attr, - self.try_parse_builtin_arr_attr, self.try_parse_function_type, - self.try_parse_ref_attr) - - for attr_parser in attrs: - if (val := attr_parser()) is not None: - return val - - def try_parse_ref_attr(self) -> FlatSymbolRefAttr | None: - if not self.tokenizer.starts_with('@'): - return None - - ref = self.must_parse_reference() - - if len(ref) > 1: - self.raise_error("Nested refs are not supported yet!", ref[1]) - - return FlatSymbolRefAttr.from_str(ref[0].text[1:]) - - def try_parse_builtin_int_attr(self) -> IntegerAttr | None: - bool = self.try_parse_builtin_boolean_attr() - if bool is not None: - return bool - - with self.tokenizer.backtracking("built in int attribute"): - value = self.expect( - self.try_parse_integer_literal, - 'Integer attribute must start with an integer literal!') - if self.tokenizer.next_token(peek=True).text != ':': - print(self.tokenizer.next_token(peek=True)) - return IntegerAttr.from_params(int(value.text), - DefaultIntegerAttrType) - type = self.must_parse_attribute_type() - return IntegerAttr.from_params(int(value.text), type) - - def try_parse_builtin_float_attr(self) -> FloatAttr | None: - with self.tokenizer.backtracking("float literal"): - value = self.expect( - self.try_parse_float_literal, - 'Float attribute must start with a float literal!') - # if we don't see a ':' indicating a type signature - if not self.tokenizer.starts_with(':'): - return FloatAttr.from_value(float(value.text)) - - type = self.must_parse_attribute_type() - return FloatAttr.from_value(float(value.text), type) - - def try_parse_builtin_boolean_attr(self) -> IntegerAttr | None: - span = self.try_parse_boolean_literal() - - if span is None: - return None - - int_val = ['false', 'true'].index(span.text) - return IntegerAttr.from_params(int_val, IntegerType.from_width(1)) - - def try_parse_builtin_str_attr(self): - if not self.tokenizer.starts_with('"'): - return None - - with self.tokenizer.backtracking("string literal"): - literal = self.try_parse_string_literal() - if literal is None: - self.raise_error('Invalid string literal') - return StringAttr.from_str(literal.string_contents) - - def try_parse_builtin_arr_attr(self) -> list[Attribute] | None: - if not self.tokenizer.starts_with('['): - return None - with self.tokenizer.backtracking("array literal"): - self.must_parse_characters('[', - 'Array literals must start with `[`') - attrs = self.must_parse_list_of(self.try_parse_attribute, - 'Expected array entry!') - self.must_parse_characters( - ']', 'Malformed array contents (expected end of array here!') - return ArrayAttr.from_list(attrs) - - @abstractmethod - def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: - raise NotImplementedError() - - def attr_dict_from_tuple_list( - self, tuple_list: list[tuple[Span, - Attribute]]) -> dict[str, Attribute]: - """ - Convert a list of tuples (Span, Attribute) to a dictionary. - - This function converts the span to a string, trimming quotes from string literals - """ - - def span_to_str(span: Span) -> str: - if isinstance(span, StringLiteral): - return span.string_contents - return span.text - - return dict((span_to_str(span), attr) for span, attr in tuple_list) - - def must_parse_function_type(self) -> FunctionType: - """ - Parses function-type: - - viable function types are: - (i32) -> () - () -> (i32, i32) - (i32, i32) -> () - () -> i32 - Non-viable types are: - i32 -> i32 - i32 -> () - - Uses type-or-type-list-parens internally - """ - self.must_parse_characters( - '(', 'First group of function args must start with a `(`') - - args: list[Attribute] = self.must_parse_list_of( - self.try_parse_type, 'Expected type here!') - - self.must_parse_characters( - ')', - "Malformed function type, expected closing brackets of argument types!", - is_parse_error=True) - - self.must_parse_characters('->', - 'Malformed function type, expected `->`!', - is_parse_error=True) - - return FunctionType.from_lists( - args, self.must_parse_type_or_type_list_parens()) - - def must_parse_type_or_type_list_parens(self) -> list[Attribute]: - """ - Parses type-or-type-list-parens, which is used in function-type. - - type-or-type-list-parens ::= type | type-list-parens - type-list-parens ::= `(` `)` | `(` type-list-no-parens `)` - type-list-no-parens ::= type (`,` type)* - """ - if self.tokenizer.next_token_of_pattern('(') is not None: - args: list[Attribute] = self.must_parse_list_of( - self.try_parse_type, 'Expected type here!') - self.must_parse_characters(')', - "Unclosed function type argument list!", - is_parse_error=True) - else: - args = [self.try_parse_type()] - if args[0] is None: - self.raise_error( - "Function type must either be single type or list of types in parenthesis!" - ) - return args - - def try_parse_function_type(self) -> FunctionType | None: - if not self.tokenizer.starts_with('('): - return None - with self.tokenizer.backtracking('function type'): - return self.must_parse_function_type() - - def must_parse_region_list(self) -> list[Region]: - """ - Parses a sequence of regions for as long as there is a `{` in the input. - """ - regions = [] - while not self.tokenizer.is_eof() and self.tokenizer.starts_with('{'): - regions.append(self.must_parse_region()) - return regions - - # COMMON xDSL/MLIR code: - def must_parse_builtin_type_with_name(self, name: Span): - if name.text == 'index': - return IndexType() - if (re_match := re.match(r'^[su]?i(\d+)$', name.text)) is not None: - signedness = { - 's': Signedness.SIGNED, - 'u': Signedness.UNSIGNED, - 'i': Signedness.SIGNLESS - } - return IntegerType.from_width(int(re_match.group(1)), - signedness[name.text[0]]) - - if (re_match := re.match(r'^f(\d+)$', name.text)) is not None: - width = int(re_match.group(1)) - type = { - 16: Float16Type, - 32: Float32Type, - 64: Float64Type - }.get(width, None) - if type is None: - self.raise_error( - "Unsupported floating point width: {}".format(width)) - return type() - - return self.must_parse_builtin_parametrized_type(name) - - @abstractmethod - def must_parse_operation_details( - self - ) -> tuple[list[Span], list[Span], dict[str, Attribute], list[Region], - FunctionType | None]: - """ - Must return a tuple consisting of: - - a list of arguments to the operation - - a list of successor names - - the attributes attached to the OP - - the regions of the op - - An optional function type. If not supplied, must_parse_op_result_list must return a second value - containing the types of the returned SSAValues - - Your implementation should make use of the following functions: - - must_parse_op_args_list - - must_parse_optional_attr_dict - - must_parse_ - """ - raise NotImplementedError() - - def must_parse_op_args_list(self) -> list[Span]: - self.must_parse_characters( - '(', 'Operation args list must be enclosed by brackets!') - args = self.must_parse_list_of(self.try_parse_value_id_and_type, - 'Expected another bare-id here') - self.must_parse_characters( - ')', 'Operation args list must be closed by a closing bracket') - # TODO: check if type is correct here! - return [name for name, _ in args] - - # HERE STARTS A SOMEWHAT CURSED COMPATIBILITY LAYER: - # since we don't want to rewrite all dialects currently, the new emulator needs to expose the same - # interface to the dialect definitions. Here we implement that interface. - - _OperationType = TypeVar('_OperationType', bound=Operation) - - def parse_op_with_default_format( - self, - op_type: type[_OperationType], - result_types: list[Attribute], - skip_white_space: bool = True) -> _OperationType: - """ - Compatibility wrapper so the new parser can be passed instead of the old one. Parses everything after the - operation name. - - This implicitly assumes XDSL format, and will fail on MLIR style operations - """ - # TODO: remove this function and restructure custom op / irdl parsing - assert isinstance(self, XDSLParser) - args, successors, attributes, regions, _ = self.must_parse_operation_details( - ) - - for x in args: - if x.text not in self.ssaValues: - self.raise_error( - "Unknown SSAValue name, known SSA Values are: {}".format( - ", ".join(self.ssaValues.keys())), x) - - return op_type.create( - operands=[self.ssaValues[span.text] for span in args], - result_types=result_types, - attributes=attributes, - successors=[self.get_block_from_name(span) for span in successors], - regions=regions) - - def parse_paramattr_parameters( - self, - expect_brackets: bool = False, - skip_white_space: bool = True) -> list[Attribute]: - opening_brackets = self.tokenizer.next_token_of_pattern('<') - if expect_brackets and opening_brackets is None: - self.raise_error("Expected start attribute parameters here (`<`)!") - - res = self.must_parse_list_of(self.try_parse_attribute, - 'Expected another attribute here!') - - if opening_brackets is not None and self.tokenizer.next_token_of_pattern( - '>') is None: - self.raise_error( - "Malformed parameter list, expected either another parameter or `>`!" - ) - - return res - - def parse_char(self, text: str): - self.must_parse_characters(text, "Expected '{}' here!".format(text)) - - def parse_str_literal(self) -> str: - return self.expect(self.try_parse_string_literal, - 'Malformed string literal!').string_contents - - def parse_attribute(self) -> Attribute: - return self.must_parse_attribute() - - -class MLIRParser(BaseParser): - - def try_parse_builtin_type(self) -> Attribute | None: - """ - parse a builtin-type like i32, index, vector etc. - """ - with self.tokenizer.backtracking("builtin type"): - name = self.tokenizer.next_token_of_pattern( - ParserCommons.builtin_type) - if name is None: - raise BacktrackingAbort("Expected builtin name!") - - return self.must_parse_builtin_type_with_name(name) - - def must_parse_attribute(self) -> Attribute: - """ - Parse attribute (either builtin or dialect) - """ - # all dialect attrs must start with '#', so we check for that first (as it's easier) - if self.tokenizer.starts_with('#'): - value = self.try_parse_dialect_attr() - - # no value => error - if value is None: - self.raise_error( - '`#` must be followed by a valid dialect attribute or type!' - ) - - return value - - # if it isn't a dialect attr, parse builtin - builtin_val = self.try_parse_builtin_attr() - - if builtin_val is None: - self.raise_error( - "Unknown attribute (neither builtin nor dialect could be parsed)!" - ) - - return builtin_val - - def must_parse_op_result_list( - self) -> tuple[list[Span], list[Attribute] | None]: - return self.must_parse_list_of(self.try_parse_value_id, - 'Expected op-result here!', - allow_empty=True), None - - def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: - if not self.tokenizer.starts_with('{'): - return dict() - - self.must_parse_characters( - '{', - 'MLIR Attribute dictionary must be enclosed in curly brackets') - - attrs = self.must_parse_list_of(self.must_parse_attribute_entry, - "Expected attribute entry") - - self.must_parse_characters( - '}', - 'MLIR Attribute dictionary must be enclosed in curly brackets') - - return self.attr_dict_from_tuple_list(attrs) - - def must_parse_operation_details( - self - ) -> tuple[list[Span], list[Span], dict[str, Attribute], list[Region], - FunctionType | None]: - - args = self.must_parse_op_args_list() - succ = self.must_parse_optional_successor_list() - - regions = [] - if self.tokenizer.starts_with('('): - self.must_parse_characters('(', - 'Expected brackets enclosing regions!') - regions = self.must_parse_region_list() - self.must_parse_characters(')', - 'Expected brackets enclosing regions!') - - attrs = self.must_parse_optional_attr_dict() - - self.must_parse_characters( - ':', - 'MLIR Operation defintions must end in a function type signature!') - func_type = self.must_parse_function_type() - - return args, succ, attrs, regions, func_type - - def must_parse_optional_successor_list(self) -> list[Span]: - if not self.tokenizer.starts_with('['): - return [] - self.must_parse_characters( - '[', 'Successor list is enclosed in square brackets') - successors = self.must_parse_list_of(self.try_parse_block_id, - 'Expected a block-id', - allow_empty=False) - self.must_parse_characters( - ']', 'Successor list is enclosed in square brackets') - return successors - - -class XDSLParser(BaseParser): - - def try_parse_builtin_type(self) -> Attribute | None: - """ - parse a builtin-type like i32, index, vector etc. - """ - with self.tokenizer.backtracking("builtin type"): - name = self.tokenizer.next_token_of_pattern( - ParserCommons.builtin_type_xdsl) - if name is None: - raise BacktrackingAbort("Expected builtin name!") - # xdsl builtin types have a '!' prefix, we strip that out here - name = Span(start=name.start + 1, end=name.end, input=name.input) - - return self.must_parse_builtin_type_with_name(name) - - def must_parse_attribute(self) -> Attribute: - """ - Parse attribute (either builtin or dialect) - - xDSL allows types in places of attributes! That's why we parse types here as well - """ - value = self.try_parse_builtin_attr() - - # xDSL: Allow both # and ! prefixes, as we allow both types and attrs - if value is None and self.tokenizer.next_token(peek=True).text in '#!': - # in MLIR # and ! are prefixes for dialect attrs/types, but in xDSL ! is also used for builtin types - value = self.try_parse_dialect_type_or_attribute() - - if value is None: - self.raise_error( - "Unknown attribute (neither builtin nor dialect could be parsed)!" - ) - - return value - - def must_parse_op_result_list( - self) -> tuple[list[Span], list[Attribute] | None]: - results = self.must_parse_list_of(self.try_parse_value_id_and_type, - 'Expected (value-id `:` type) here!', - allow_empty=True) - # TODO: this is hideous, make it cleaner - # zip(*results) works, but is barely readable :/ - return [name for name, _ in results], [type for _, type in results] - - def try_parse_builtin_attr(self) -> Attribute: - """ - Tries to parse a bultin attribute, e.g. a string literal, int, array, etc.. - - If the mode is xDSL, it also allows parsing of builtin types - """ - # in xdsl, two things are different here: - # 1. types are considered valid attributes - # 2. all types, builtins included, are prefixed with ! - if self.tokenizer.starts_with('!'): - return self.try_parse_builtin_type() - - return super().try_parse_builtin_attr() - - def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: - if not self.tokenizer.starts_with('['): - return dict() - - self.must_parse_characters( - '[', - 'xDSL Attribute dictionary must be enclosed in square brackets') - - attrs = self.must_parse_list_of(self.must_parse_attribute_entry, - "Expected attribute entry") - - self.must_parse_characters( - ']', - 'xDSL Attribute dictionary must be enclosed in square brackets') - - return self.attr_dict_from_tuple_list(attrs) - - def must_parse_operation_details( - self - ) -> tuple[list[Span], list[Span], dict[str, Attribute], list[Region], - FunctionType | None]: - """ - Must return a tuple consisting of: - - a list of arguments to the operation - - a list of successor names - - the attributes attached to the OP - - the regions of the op - - An optional function type. If not supplied, must_parse_op_result_list must return a second value - containing the types of the returned SSAValues - - """ - args = self.must_parse_op_args_list() - succ = self.must_parse_optional_successor_list() - attrs = self.must_parse_optional_attr_dict() - regions = self.must_parse_region_list() - - return args, succ, attrs, regions, None - - def must_parse_optional_successor_list(self) -> list[Span]: - if not self.tokenizer.starts_with('('): - return [] - self.must_parse_characters( - '(', 'Successor list is enclosed in round brackets') - successors = self.must_parse_list_of(self.try_parse_block_id, - 'Expected a block-id', - allow_empty=False) - self.must_parse_characters( - ')', 'Successor list is enclosed in round brackets') - return successors - - -if __name__ == '__main__': - infile = sys.argv[-1] - from xdsl.dialects.affine import Affine - from xdsl.dialects.arith import Arith - from xdsl.dialects.builtin import Builtin - from xdsl.dialects.cf import Cf - from xdsl.dialects.cmath import CMath - from xdsl.dialects.func import Func - from xdsl.dialects.irdl import IRDL - from xdsl.dialects.llvm import LLVM - from xdsl.dialects.memref import MemRef - from xdsl.dialects.scf import Scf - - ctx = MLContext() - ctx.register_dialect(Builtin) - ctx.register_dialect(Func) - ctx.register_dialect(Arith) - ctx.register_dialect(MemRef) - ctx.register_dialect(Affine) - ctx.register_dialect(Scf) - ctx.register_dialect(Cf) - ctx.register_dialect(CMath) - ctx.register_dialect(IRDL) - ctx.register_dialect(LLVM) - - parses_by_file_name = {'xdsl': XDSLParser, 'mlir': MLIRParser} - - parser = parses_by_file_name[infile.split('.')[-1]] - - p = parser(infile, open(infile, 'r').read(), ctx) - - printer = Printer() - try: - for op in p.begin_parse(): - printer.print_op(op) - except ParseError as pe: - pe.print_with_history() From 34157f57f6c0a4154a5ddddfc1586b0fbf5516fe Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Fri, 13 Jan 2023 15:40:47 +0000 Subject: [PATCH 24/65] parser: fixed all remaining tests --- tests/test_parser_error.py | 18 +- tests/test_printer.py | 4 +- xdsl/parser.py | 428 ++++++++++++++++++++++--------------- 3 files changed, 266 insertions(+), 184 deletions(-) diff --git a/tests/test_parser_error.py b/tests/test_parser_error.py index 4f853f86a8..c0abffbea6 100644 --- a/tests/test_parser_error.py +++ b/tests/test_parser_error.py @@ -23,8 +23,14 @@ def check_error(prog: str, line: int, column: int, message: str): parser.must_parse_operation() assert e.value.span - assert e.value.span.get_line_col() == (line, column) - assert any(message in ex.error.msg for ex in e.value.history.iterate()) + msgs = [err.error.msg for err in e.value.history.iterate()] + + for err in e.value.history.iterate(): + if message in err.error.msg: + assert err.error.span.get_line_col() == (line, column) + break + else: + assert False, "'{}' not found in an error message {}!".format(message, e.value.args) def test_parser_missing_equal(): @@ -38,7 +44,7 @@ def test_parser_missing_equal(): %0 : !i32 unknown() } """ - check_error(prog, 3, 13, "Operation definitions expect an `=` after op-result-list!") + check_error(prog, 3, 12, "Operation definitions expect an `=` after op-result-list!") def test_parser_redefined_value(): @@ -67,10 +73,10 @@ def test_parser_missing_operation_name(): %val : !i32 = } """ - check_error(prog, 3, 13, "Expected an operation name here") + check_error(prog, 4, 0, "Expected an operation name here") -def test_parser_missing_attribute(): +def test_parser_malformed_type(): """Test a missing attribute error.""" ctx = MLContext() ctx.register_op(UnkownOp) @@ -81,4 +87,4 @@ def test_parser_missing_attribute(): %val : i32 = unknown() } """ - check_error(prog, 3, 10, "attribute expected") + check_error(prog, 3, 9, "Expected type of value-id here!") diff --git a/tests/test_printer.py b/tests/test_printer.py index 487f7d27d6..edf3f8333d 100644 --- a/tests/test_printer.py +++ b/tests/test_printer.py @@ -385,11 +385,11 @@ def parse(cls, result_types: List[Attribute], parser: BaseParser) -> PlusCustomFormatOp: def get_ssa_val(name: Span) -> SSAValue: if name.text not in parser.ssaValues: - parser.raise_error('Unknown SSA Value name', name) + parser.raise_error('SSA Value used before assignment', name) return parser.ssaValues[name.text] lhs = parser.expect(parser.try_parse_value_id, 'Expected SSA Value name here!') - parser.parse_char("+") + parser.must_parse_characters("+", "Malformed operation format, expected `+`!") rhs = parser.expect(parser.try_parse_value_id, 'Expected SSA Value name here!') return PlusCustomFormatOp.create(operands=[get_ssa_val(name) for name in (lhs, rhs)], diff --git a/xdsl/parser.py b/xdsl/parser.py index fc8300f0fb..5a98f58034 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -18,10 +18,9 @@ AnyTensorType, AnyVectorType, Float16Type, Float32Type, Float64Type, FloatAttr, FunctionType, IndexType, IntegerType, Signedness, StringAttr, IntegerAttr, ArrayAttr, TensorType, UnrankedTensorType, VectorType, - DefaultIntegerAttrType, FlatSymbolRefAttr) + DefaultIntegerAttrType, FlatSymbolRefAttr, DenseIntOrFPElementsAttr) from xdsl.ir import (SSAValue, Block, Callable, Attribute, Operation, Region, BlockArgument, MLContext, ParametrizedAttribute, Data) -from .printer import Printer class ParseError(Exception): @@ -57,22 +56,21 @@ def print_with_history(self, file=sys.stderr): def __repr__(self): io = StringIO() self.print_with_history(io) - return "{}:\n{}".format( - self.__class__.__name__, - io.getvalue() - ) + return "{}:\n{}".format(self.__class__.__name__, io.getvalue()) class MultipleSpansParseError(ParseError): ref_text: str | None refs: list[tuple[Span, str]] - def __init__(self, - span: Span, - msg: str, - ref_text: str, - refs: list[tuple[Span, str | None]], - history: BacktrackingHistory | None = None): + def __init__( + self, + span: Span, + msg: str, + ref_text: str, + refs: list[tuple[Span, str | None]], + history: BacktrackingHistory | None = None, + ): super(MultipleSpansParseError, self).__init__(span, msg, history) self.refs = refs self.ref_text = ref_text @@ -101,7 +99,7 @@ def print_unroll(self, file=sys.stderr): self.parent.print_unroll(file) def print(self, file=sys.stderr): - print("Parsing of {} failed:".format(self.region_name or ''), + print("Parsing of {} failed:".format(self.region_name or ""), file=file) self.error.print_pretty(file=file) @@ -387,36 +385,51 @@ def history_entry_from_exception(self, ex: Exception, region: str, return BacktrackingHistory(ex, self.history, region, pos) elif isinstance(ex, AssertionError): reason = [ - 'Generic assertion failure', - *(reason for reason in ex.args if isinstance(reason, str)) + "Generic assertion failure", + *(reason for reason in ex.args if isinstance(reason, str)), ] # we assume that assertions fail because of the last read-in token if len(reason) == 1: tb = StringIO() traceback.print_exc(file=tb) - reason[0] += '\n' + tb.getvalue() + reason[0] += "\n" + tb.getvalue() return BacktrackingHistory( ParseError(self.last_token, reason[-1], self.history), - self.history, region, pos) + self.history, + region, + pos, + ) elif isinstance(ex, BacktrackingAbort): return BacktrackingHistory( ParseError( self.next_token(peek=True), - 'Backtracking aborted: {}'.format(ex.reason - or 'unknown reason'), - self.history), self.history, region, pos) + "Backtracking aborted: {}".format(ex.reason + or "unknown reason"), + self.history, + ), + self.history, + region, + pos, + ) elif isinstance(ex, EOFError): return BacktrackingHistory( ParseError(self.last_token, "Encountered EOF", self.history), - self.history, region, pos) + self.history, + region, + pos, + ) print("Warning: Unexpected error in backtracking:", file=sys.stderr) traceback.print_exception(ex, file=sys.stderr) return BacktrackingHistory( ParseError(self.last_token, "Unexpected exception: {}".format(ex), - self.history), self.history, region, pos) + self.history), + self.history, + region, + pos, + ) def next_token(self, start: int | None = None, peek: bool = False) -> Span: """ @@ -487,8 +500,10 @@ def _find_token_end(self, start: int | None = None) -> int: return i + len(part) # otherwise return the start of the next break return min( - filter(lambda x: x >= 0, (self.input.content.find(part, i) - for part in self.break_on))) + filter( + lambda x: x >= 0, + (self.input.content.find(part, i) for part in self.break_on), + )) def next_pos(self, i: int | None = None) -> int: """ @@ -502,8 +517,8 @@ def next_pos(self, i: int | None = None) -> int: i += 1 # skip comments as well - if self.input.content.startswith('//', i): - i = self.input.content.find('\n', i) + 1 + if self.input.content.startswith("//", i): + i = self.input.content.find("\n", i) + 1 return self.next_pos(i) return i @@ -555,35 +570,29 @@ class ParserCommons: Colelction of common things used in parsing MLIR/IRDL """ - integer_literal = re.compile(r'[+-]?([0-9]+|0x[0-9A-Fa-f]+)') - decimal_literal = re.compile(r'[+-]?([1-9][0-9]*)') + + integer_literal = re.compile(r"[+-]?([0-9]+|0x[0-9A-Fa-f]+)") + decimal_literal = re.compile(r"[+-]?([1-9][0-9]*)") string_literal = re.compile(r'"([^\n\f\v\r"]|\\[nfvr"])+"') - float_literal = re.compile(r'[-+]?[0-9]+\.[0-9]*([eE][-+]?[0-9]+)?') - bare_id = re.compile(r'[A-Za-z_][\w$.]+') - value_id = re.compile(r'%([0-9]+|([A-Za-z_$.-][\w$.-]*))') - suffix_id = re.compile(r'([0-9]+|([A-Za-z_$.-][\w$.-]*))') - block_id = re.compile(r'\^([0-9]+|([A-Za-z_$.-][\w$.-]*))') - type_alias = re.compile(r'![A-Za-z_][\w$.]+') - attribute_alias = re.compile(r'#[A-Za-z_][\w$.]+') - boolean_literal = re.compile(r'(true|false)') + float_literal = re.compile(r"[-+]?[0-9]+\.[0-9]*([eE][-+]?[0-9]+)?") + bare_id = re.compile(r"[A-Za-z_][\w$.]+") + value_id = re.compile(r"%([0-9]+|([A-Za-z_$.-][\w$.-]*))") + suffix_id = re.compile(r"([0-9]+|([A-Za-z_$.-][\w$.-]*))") + block_id = re.compile(r"\^([0-9]+|([A-Za-z_$.-][\w$.-]*))") + type_alias = re.compile(r"![A-Za-z_][\w$.]+") + attribute_alias = re.compile(r"#[A-Za-z_][\w$.]+") + boolean_literal = re.compile(r"(true|false)") # a list of _builtin_type_names = ( - r'[su]?i\d+', - r'f\d+', - 'tensor', - 'vector', - 'memref', - 'complex', - 'opaque', - 'tuple', - 'index', + r"[su]?i\d+", r"f\d+", "tensor", "vector", "memref", "complex", + "opaque", "tuple", "index", "dense" # TODO: add all the Float8E4M3FNType, Float8E5M2Type, and BFloat16Type ) - builtin_type = re.compile('(({}))'.format(')|('.join(_builtin_type_names))) - builtin_type_xdsl = re.compile('!(({}))'.format( - ')|('.join(_builtin_type_names))) - double_colon = re.compile('::') - comma = re.compile(',') + builtin_type = re.compile("(({}))".format(")|(".join(_builtin_type_names))) + builtin_type_xdsl = re.compile("!(({}))".format( + ")|(".join(_builtin_type_names))) + double_colon = re.compile("::") + comma = re.compile(",") class BaseParser(ABC): @@ -617,17 +626,17 @@ class BaseParser(ABC): Blocks we encountered references to before the definition (must be empty after parsing of region completes) """ - T_ = TypeVar('T_') + T_ = TypeVar("T_") """ Type var used for handling function that return single or multiple Spans. Basically the output type of all try_parse functions is T_ | None """ def __init__( - self, - ctx: MLContext, - input: str, - name: str, + self, + ctx: MLContext, + input: str, + name: str, ): self.tokenizer = Tokenizer(Input(input, name)) self.ctx = ctx @@ -669,9 +678,10 @@ def must_parse_block(self) -> Block: raise MultipleSpansParseError( block_id, "Re-declaration of block {}".format(block_id.text), - 'Originally declared here:', + "Originally declared here:", [(self.blocks[block_id.text].delcared_at, None)], - self.tokenizer.history) + self.tokenizer.history, + ) block = Block(block_id) self.blocks[block_id.text] = block @@ -879,9 +889,11 @@ def must_parse_dialect_type_or_attribute_inner(self, kind: str): if issubclass(type_def, ParametrizedAttribute): param_list = type_def.parse_parameters(self) elif issubclass(type_def, Data): - self.must_parse_characters('<', 'This attribute must be parametrized!') + self.must_parse_characters("<", + "This attribute must be parametrized!") param_list = type_def.parse_parameter(self) - self.must_parse_characters('>', 'Invalid attribute parametrization, expected `>`!') + self.must_parse_characters( + ">", "Invalid attribute parametrization, expected `>`!") else: assert False, "Mathieu said this cannot be." return type_def(param_list) @@ -901,21 +913,37 @@ def unimplemented() -> ParametrizedAttribute: "Builtin {} not supported yet!".format(name.text)) builtin_parsers: dict[str, Callable[[], ParametrizedAttribute]] = { - 'vector': self.must_parse_vector_attrs, - 'memref': unimplemented, - 'tensor': self.must_parse_tensor_attrs, - 'complex': self.must_parse_complex_attrs, - 'opaque': unimplemented, - 'tuple': unimplemented, + "vector": self.must_parse_vector_attrs, + "memref": unimplemented, + "tensor": self.must_parse_tensor_attrs, + "complex": self.must_parse_complex_attrs, + "opaque": unimplemented, + "tuple": unimplemented, } - self.must_parse_characters('<', 'Expected parameter list here!') + self.must_parse_characters("<", "Expected parameter list here!") # get the parser for the type, falling back to the unimplemented warning res = builtin_parsers.get(name.text, unimplemented)() - self.must_parse_characters('>', - 'Expected end of parameter list here!') + self.must_parse_characters(">", "Expected end of parameter list here!") + + if name in ("dense", ): + self.must_parse_characters( + ":", + "Attribute {} must be followed by (`:` type)!".format(name)) + type = self.expect( + self.try_parse_type(), + "Attribute {} must be followed by (`:` type)!".format(name), + ) + return res + def must_parse_dense_type_attrs(self): + arr = self.expect( + self.try_parse_builtin_arr_attr(), + "dense attribute must be parametrized by Array", + ) + DenseIntOrFPElementsAttr.from_list(arr) + def must_parse_complex_attrs(self): self.raise_error("ComplexType is unimplemented!") @@ -923,31 +951,28 @@ def try_parse_numerical_dims(self, accept_closing_bracket: bool = False, lower_bound: int = 1) -> Iterable[int]: while (shape_arg := - self.try_parse_shape_element(lower_bound)) is not None: + self.try_parse_shape_element(lower_bound)) is not None: yield shape_arg # look out for the closing bracket for scalable vector dims - if accept_closing_bracket and self.tokenizer.starts_with(']'): + if accept_closing_bracket and self.tokenizer.starts_with("]"): break self.must_parse_characters( - 'x', - 'Unexpected end of dimension parameters!') + "x", "Unexpected end of dimension parameters!") def must_parse_vector_attrs(self) -> AnyVectorType: # also break on 'x' characters as they are separators in dimension parameters with self.tokenizer.configured(break_on=self.tokenizer.break_on + - ('x',)): + ("x", )): shape = list[int](self.try_parse_numerical_dims()) scaling_shape: list[int] | None = None - if self.tokenizer.next_token_of_pattern('[') is not None: + if self.tokenizer.next_token_of_pattern("[") is not None: # we now need to parse the scalable dimensions scaling_shape = list(self.try_parse_numerical_dims()) self.must_parse_characters( - ']', - 'Expected end of scalable vector dimensions here!') + "]", "Expected end of scalable vector dimensions here!") self.must_parse_characters( - 'x', - 'Expected end of scalable vector dimensions here!') + "x", "Expected end of scalable vector dimensions here!") if scaling_shape is not None: # TODO: handle scaling vectors! @@ -1046,9 +1071,7 @@ def raise_error(self, msg: str, at_position: Span | None = None): raise ParseError(at_position, msg, self.tokenizer.history) - def must_parse_characters(self, - text: str, - msg: str, ) -> Span: + def must_parse_characters(self, text: str, msg: str) -> Span: if (match := self.tokenizer.next_token_of_pattern(text)) is None: self.raise_error(msg) return match @@ -1123,17 +1146,19 @@ def must_parse_region(self) -> Region: region = Region() try: - self.must_parse_characters('{', 'Regions begin with `{`') - if not self.tokenizer.starts_with('}'): + self.must_parse_characters("{", "Regions begin with `{`") + if self.tokenizer.starts_with("}"): + region.add_block(Block()) + else: # parse first block block = self.must_parse_block() region.add_block(block) - while self.tokenizer.starts_with('^'): + while self.tokenizer.starts_with("^"): region.add_block(self.must_parse_block()) end = self.must_parse_characters( - '}', 'Reached end of region, expected `}`!') + "}", "Reached end of region, expected `}`!") if len(self.forward_block_references) > 0: raise MultipleSpansParseError( @@ -1169,11 +1194,11 @@ def must_parse_attribute_entry(self) -> tuple[Span, Attribute]: if name is None: self.raise_error( - 'Expected bare-id or string-literal here as part of attribute entry!' + "Expected bare-id or string-literal here as part of attribute entry!" ) self.must_parse_characters( - '=', 'Attribute entries must be of format name `=` attribute!') + "=", "Attribute entries must be of format name `=` attribute!") return name, self.must_parse_attribute() @@ -1187,7 +1212,7 @@ def must_parse_attribute(self) -> Attribute: raise NotImplemented() def try_parse_attribute(self) -> Attribute | None: - with self.tokenizer.backtracking('attribute'): + with self.tokenizer.backtracking("attribute"): return self.must_parse_attribute() def must_parse_attribute_type(self) -> Attribute: @@ -1195,28 +1220,71 @@ def must_parse_attribute_type(self) -> Attribute: Parses `:` type and returns the type """ self.must_parse_characters( - ':', 'Expected attribute type definition here ( `:` type )') + ":", "Expected attribute type definition here ( `:` type )") return self.expect( self.try_parse_type, - 'Expected attribute type definition here ( `:` type )') + "Expected attribute type definition here ( `:` type )") - def try_parse_builtin_attr(self) -> Attribute: + def try_parse_builtin_attr(self) -> Attribute | None: """ Tries to parse a bultin attribute, e.g. a string literal, int, array, etc.. """ + next_token = self.tokenizer.next_token(peek=True) + if next_token.text == '"': + return self.try_parse_builtin_str_attr() + elif next_token.text == "[": + return self.try_parse_builtin_arr_attr() + elif next_token.text == "@": + return self.try_parse_ref_attr() + elif next_token.text == "dense": + return self.try_parse_builtin_dense_attr() + # order here is important! attrs = (self.try_parse_builtin_float_attr, - self.try_parse_builtin_int_attr, - self.try_parse_builtin_str_attr, - self.try_parse_builtin_arr_attr, self.try_parse_function_type, - self.try_parse_ref_attr) + self.try_parse_builtin_int_attr) for attr_parser in attrs: if (val := attr_parser()) is not None: return val + def try_parse_builtin_dense_attr(self) -> Attribute | None: + with self.tokenizer.backtracking("dense attribute"): + self.must_parse_characters("dense", "builtin dense attribute must start with `dense`") + err_msg = "Malformed dense attribute, format must be (`dense<` array-attr `>:` type)" + self.must_parse_characters("<", err_msg) + info = list(self.must_parse_builtin_dense_attr_args()) + self.must_parse_characters(">", err_msg) + self.must_parse_characters(":", err_msg) + type = self.expect(self.try_parse_type, "Dense attribute must be typed!") + return DenseIntOrFPElementsAttr.from_list(type, info) + + def must_parse_builtin_dense_attr_args(self) -> Iterable[int | float]: + """ + dense attribute params must be: + + dense-attr-params := float-literal | int-literal | list-of-dense-attrs-params + list-of-dense-attrs-params := `[` dense-attr-params (`,` dense-attr-params)* `]` + """ + def try_parse_int_or_float(): + if (literal := self.try_parse_float_literal()) is not None: + return float(literal.text) + if (literal := self.try_parse_integer_literal()) is not None: + return int(literal.text) + self.raise_error('Expected int or float literal here!') + if not self.tokenizer.starts_with('['): + yield try_parse_int_or_float() + return + + self.must_parse_characters('[', '') + while not self.tokenizer.starts_with(']'): + yield from self.must_parse_builtin_dense_attr_args() + if self.tokenizer.next_token_of_pattern(',') is None: + break + self.must_parse_characters(']', '') + + def try_parse_ref_attr(self) -> FlatSymbolRefAttr | None: - if not self.tokenizer.starts_with('@'): + if not self.tokenizer.starts_with("@"): return None ref = self.must_parse_reference() @@ -1246,9 +1314,10 @@ def try_parse_builtin_float_attr(self) -> FloatAttr | None: with self.tokenizer.backtracking("float literal"): value = self.expect( self.try_parse_float_literal, - 'Float attribute must start with a float literal!') + "Float attribute must start with a float literal!", + ) # if we don't see a ':' indicating a type signature - if not self.tokenizer.starts_with(':'): + if not self.tokenizer.starts_with(":"): return FloatAttr.from_value(float(value.text)) type = self.must_parse_attribute_type() @@ -1260,7 +1329,7 @@ def try_parse_builtin_boolean_attr(self) -> IntegerAttr | None: if span is None: return None - int_val = ['false', 'true'].index(span.text) + int_val = ["false", "true"].index(span.text) return IntegerAttr.from_params(int_val, IntegerType.from_width(1)) def try_parse_builtin_str_attr(self): @@ -1270,19 +1339,19 @@ def try_parse_builtin_str_attr(self): with self.tokenizer.backtracking("string literal"): literal = self.try_parse_string_literal() if literal is None: - self.raise_error('Invalid string literal') + self.raise_error("Invalid string literal") return StringAttr.from_str(literal.string_contents) - def try_parse_builtin_arr_attr(self) -> list[Attribute] | None: - if not self.tokenizer.starts_with('['): + def try_parse_builtin_arr_attr(self) -> ArrayAttr | None: + if not self.tokenizer.starts_with("["): return None with self.tokenizer.backtracking("array literal"): - self.must_parse_characters('[', - 'Array literals must start with `[`') + self.must_parse_characters("[", + "Array literals must start with `[`") attrs = self.must_parse_list_of(self.try_parse_attribute, - 'Expected array entry!') + "Expected array entry!") self.must_parse_characters( - ']', 'Malformed array contents (expected end of array here!') + "]", "Malformed array contents (expected end of array here!") return ArrayAttr.from_list(attrs) @abstractmethod @@ -1321,17 +1390,18 @@ def must_parse_function_type(self) -> FunctionType: Uses type-or-type-list-parens internally """ self.must_parse_characters( - '(', 'First group of function args must start with a `(`') + "(", "First group of function args must start with a `(`") args: list[Attribute] = self.must_parse_list_of( - self.try_parse_type, 'Expected type here!') + self.try_parse_type, "Expected type here!") self.must_parse_characters( - ')', - "Malformed function type, expected closing brackets of argument types!") + ")", + "Malformed function type, expected closing brackets of argument types!" + ) - self.must_parse_characters('->', - 'Malformed function type, expected `->`!') + self.must_parse_characters("->", + "Malformed function type, expected `->`!") return FunctionType.from_lists( args, self.must_parse_type_or_type_list_parens()) @@ -1344,23 +1414,23 @@ def must_parse_type_or_type_list_parens(self) -> list[Attribute]: type-list-parens ::= `(` `)` | `(` type-list-no-parens `)` type-list-no-parens ::= type (`,` type)* """ - if self.tokenizer.next_token_of_pattern('(') is not None: + if self.tokenizer.next_token_of_pattern("(") is not None: args: list[Attribute] = self.must_parse_list_of( - self.try_parse_type, 'Expected type here!') - self.must_parse_characters(')', - "Unclosed function type argument list!") + self.try_parse_type, "Expected type here!") + self.must_parse_characters( + ")", "Unclosed function type argument list!") else: args = [self.try_parse_type()] if args[0] is None: self.raise_error( - "Function type must either be single type or list of types in parenthesis!" - ) + "Function type must either be single type or list of types in" + " parenthesis!") return args def try_parse_function_type(self) -> FunctionType | None: - if not self.tokenizer.starts_with('('): + if not self.tokenizer.starts_with("("): return None - with self.tokenizer.backtracking('function type'): + with self.tokenizer.backtracking("function type"): return self.must_parse_function_type() def must_parse_region_list(self) -> list[Region]: @@ -1368,24 +1438,24 @@ def must_parse_region_list(self) -> list[Region]: Parses a sequence of regions for as long as there is a `{` in the input. """ regions = [] - while not self.tokenizer.is_eof() and self.tokenizer.starts_with('{'): + while not self.tokenizer.is_eof() and self.tokenizer.starts_with("{"): regions.append(self.must_parse_region()) return regions # COMMON xDSL/MLIR code: def must_parse_builtin_type_with_name(self, name: Span): - if name.text == 'index': + if name.text == "index": return IndexType() - if (re_match := re.match(r'^[su]?i(\d+)$', name.text)) is not None: + if (re_match := re.match(r"^[su]?i(\d+)$", name.text)) is not None: signedness = { - 's': Signedness.SIGNED, - 'u': Signedness.UNSIGNED, - 'i': Signedness.SIGNLESS + "s": Signedness.SIGNED, + "u": Signedness.UNSIGNED, + "i": Signedness.SIGNLESS, } return IntegerType.from_width(int(re_match.group(1)), signedness[name.text[0]]) - if (re_match := re.match(r'^f(\d+)$', name.text)) is not None: + if (re_match := re.match(r"^f(\d+)$", name.text)) is not None: width = int(re_match.group(1)) type = { 16: Float16Type, @@ -1401,7 +1471,7 @@ def must_parse_builtin_type_with_name(self, name: Span): @abstractmethod def must_parse_operation_details( - self + self, ) -> tuple[list[Span], list[Span], dict[str, Attribute], list[Region], FunctionType | None]: """ @@ -1422,11 +1492,11 @@ def must_parse_operation_details( def must_parse_op_args_list(self) -> list[Span]: self.must_parse_characters( - '(', 'Operation args list must be enclosed by brackets!') + "(", "Operation args list must be enclosed by brackets!") args = self.must_parse_list_of(self.try_parse_value_id_and_type, - 'Expected another bare-id here') + "Expected another bare-id here") self.must_parse_characters( - ')', 'Operation args list must be closed by a closing bracket') + ")", "Operation args list must be closed by a closing bracket") # TODO: check if type is correct here! return [name for name, _ in args] @@ -1434,13 +1504,13 @@ def must_parse_op_args_list(self) -> list[Span]: # since we don't want to rewrite all dialects currently, the new emulator needs to expose the same # interface to the dialect definitions. Here we implement that interface. - _OperationType = TypeVar('_OperationType', bound=Operation) + _OperationType = TypeVar("_OperationType", bound=Operation) def parse_op_with_default_format( - self, - op_type: type[_OperationType], - result_types: list[Attribute], - skip_white_space: bool = True) -> _OperationType: + self, + op_type: type[_OperationType], + result_types: list[Attribute], + ) -> _OperationType: """ Compatibility wrapper so the new parser can be passed instead of the old one. Parses everything after the operation name. @@ -1520,13 +1590,13 @@ def must_parse_attribute(self) -> Attribute: Parse attribute (either builtin or dialect) """ # all dialect attrs must start with '#', so we check for that first (as it's easier) - if self.tokenizer.starts_with('#'): + if self.tokenizer.starts_with("#"): value = self.try_parse_dialect_attr() # no value => error if value is None: self.raise_error( - '`#` must be followed by a valid dialect attribute or type!' + "`#` must be followed by a valid dialect attribute or type!" ) return value @@ -1543,62 +1613,64 @@ def must_parse_attribute(self) -> Attribute: def must_parse_op_result_list( self) -> tuple[list[Span], list[Attribute] | None]: - return self.must_parse_list_of(self.try_parse_value_id, - 'Expected op-result here!', - allow_empty=True), None + return ( + self.must_parse_list_of(self.try_parse_value_id, + "Expected op-result here!", + allow_empty=True), + None, + ) def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: - if not self.tokenizer.starts_with('{'): + if not self.tokenizer.starts_with("{"): return dict() self.must_parse_characters( - '{', - 'MLIR Attribute dictionary must be enclosed in curly brackets') + "{", + "MLIR Attribute dictionary must be enclosed in curly brackets") attrs = self.must_parse_list_of(self.must_parse_attribute_entry, "Expected attribute entry") self.must_parse_characters( - '}', - 'MLIR Attribute dictionary must be enclosed in curly brackets') + "}", + "MLIR Attribute dictionary must be enclosed in curly brackets") return self.attr_dict_from_tuple_list(attrs) def must_parse_operation_details( - self + self, ) -> tuple[list[Span], list[Span], dict[str, Attribute], list[Region], FunctionType | None]: - args = self.must_parse_op_args_list() succ = self.must_parse_optional_successor_list() regions = [] - if self.tokenizer.starts_with('('): - self.must_parse_characters('(', - 'Expected brackets enclosing regions!') + if self.tokenizer.starts_with("("): + self.must_parse_characters("(", + "Expected brackets enclosing regions!") regions = self.must_parse_region_list() - self.must_parse_characters(')', - 'Expected brackets enclosing regions!') + self.must_parse_characters(")", + "Expected brackets enclosing regions!") attrs = self.must_parse_optional_attr_dict() self.must_parse_characters( - ':', - 'MLIR Operation defintions must end in a function type signature!') + ":", + "MLIR Operation defintions must end in a function type signature!") func_type = self.must_parse_function_type() return args, succ, attrs, regions, func_type def must_parse_optional_successor_list(self) -> list[Span]: - if not self.tokenizer.starts_with('['): + if not self.tokenizer.starts_with("["): return [] self.must_parse_characters( - '[', 'Successor list is enclosed in square brackets') + "[", "Successor list is enclosed in square brackets") successors = self.must_parse_list_of(self.try_parse_block_id, - 'Expected a block-id', + "Expected a block-id", allow_empty=False) self.must_parse_characters( - ']', 'Successor list is enclosed in square brackets') + "]", "Successor list is enclosed in square brackets") return successors @@ -1628,7 +1700,7 @@ def must_parse_attribute(self) -> Attribute: # xDSL: Allow both # and ! prefixes, as we allow both types and attrs # TODO: phase out use of next_token(peek=True) in favour of starts_with - if value is None and self.tokenizer.next_token(peek=True).text in '#!': + if value is None and self.tokenizer.next_token(peek=True).text in "#!": # in MLIR # and ! are prefixes for dialect attrs/types, but in xDSL ! is also used for builtin types value = self.try_parse_dialect_type_or_attribute() @@ -1641,11 +1713,13 @@ def must_parse_attribute(self) -> Attribute: def must_parse_op_result_list( self) -> tuple[list[Span], list[Attribute] | None]: - if not self.tokenizer.starts_with('%'): + if not self.tokenizer.starts_with("%"): return list(), list() - results = self.must_parse_list_of(self.try_parse_value_id_and_type, - 'Expected (value-id `:` type) here!', - allow_empty=False) + results = self.must_parse_list_of( + self.try_parse_value_id_and_type, + "Expected (value-id `:` type) here!", + allow_empty=False, + ) # TODO: this is hideous, make it cleaner # zip(*results) works, but is barely readable :/ return [name for name, _ in results], [type for _, type in results] @@ -1659,30 +1733,30 @@ def try_parse_builtin_attr(self) -> Attribute: # in xdsl, two things are different here: # 1. types are considered valid attributes # 2. all types, builtins included, are prefixed with ! - if self.tokenizer.starts_with('!'): + if self.tokenizer.starts_with("!"): return self.try_parse_builtin_type() return super().try_parse_builtin_attr() def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: - if not self.tokenizer.starts_with('['): + if not self.tokenizer.starts_with("["): return dict() self.must_parse_characters( - '[', - 'xDSL Attribute dictionary must be enclosed in square brackets') + "[", + "xDSL Attribute dictionary must be enclosed in square brackets") attrs = self.must_parse_list_of(self.must_parse_attribute_entry, "Expected attribute entry") self.must_parse_characters( - ']', - 'xDSL Attribute dictionary must be enclosed in square brackets') + "]", + "xDSL Attribute dictionary must be enclosed in square brackets") return self.attr_dict_from_tuple_list(attrs) def must_parse_operation_details( - self + self, ) -> tuple[list[Span], list[Span], dict[str, Attribute], list[Region], FunctionType | None]: """ @@ -1703,22 +1777,24 @@ def must_parse_operation_details( return args, succ, attrs, regions, None def must_parse_optional_successor_list(self) -> list[Span]: - if not self.tokenizer.starts_with('('): + if not self.tokenizer.starts_with("("): return [] self.must_parse_characters( - '(', 'Successor list is enclosed in round brackets') + "(", "Successor list is enclosed in round brackets") successors = self.must_parse_list_of(self.try_parse_block_id, - 'Expected a block-id', + "Expected a block-id", allow_empty=False) self.must_parse_characters( - ')', 'Successor list is enclosed in round brackets') + ")", "Successor list is enclosed in round brackets") return successors def must_parse_dialect_type_or_attribute_inner(self, kind: str): if self.tokenizer.starts_with('"'): name = self.try_parse_string_literal() if name is None: - self.raise_error("Expected string literal for an attribute in generic format here!") + self.raise_error( + "Expected string literal for an attribute in generic format here!" + ) return self.must_parse_generic_attribute_args(name) return super().must_parse_dialect_type_or_attribute_inner(kind) From 649e43a5c6b0646006411f6d616ffe6855ae6eaf Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Fri, 13 Jan 2023 15:42:06 +0000 Subject: [PATCH 25/65] parser: yapf formatting run --- xdsl/parser.py | 53 +++++++++++++++++++++++++++++++------------------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/xdsl/parser.py b/xdsl/parser.py index 5a98f58034..a2ebfb2fc9 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -185,7 +185,8 @@ def print_with_context(self, msg: str | None = None) -> str: offset = self.start - offset_of_first_line remaining_len = max(self.len, 1) capture = StringIO() - print("{}:{}:{}".format(self.input.name, line_no, offset), file=capture) + print("{}:{}:{}".format(self.input.name, line_no, offset), + file=capture) for line in lines: print(line, file=capture) if remaining_len < 0: @@ -714,8 +715,7 @@ def must_parse_block_arg_list(self) -> list[tuple[Span, Attribute]]: args = self.must_parse_list_of(self.try_parse_value_id_and_type, "Expected value-id and type here!") - self.must_parse_characters(')', - 'Expected closing of block arguments!') + self.must_parse_characters(')', 'Expected closing of block arguments!') return args @@ -765,7 +765,7 @@ def must_parse_list_of(self, items.append(first_item) while (match := self.tokenizer.next_token_of_pattern(separator_pattern) - ) is not None: + ) is not None: next_item = try_parse() if next_item is None: # if the separator is emtpy, we are good here @@ -988,7 +988,7 @@ def must_parse_vector_attrs(self) -> AnyVectorType: def must_parse_tensor_or_memref_dims(self) -> list[int] | None: with self.tokenizer.configured(break_on=self.tokenizer.break_on + - ('x',)): + ('x', )): # check for unranked-ness if self.tokenizer.next_token_of_pattern('*') is not None: # consume `x` @@ -1046,7 +1046,8 @@ def must_parse_type_params(self) -> list[Attribute]: params = self.must_parse_list_of(self.try_parse_type, 'Expected a type here!') - self.must_parse_characters('>', 'Expected end of type parameterization here!') + self.must_parse_characters( + '>', 'Expected end of type parameterization here!') return params @@ -1090,8 +1091,7 @@ def must_parse_operation(self) -> Operation: if len(result_list) > 0: self.must_parse_characters( '=', - 'Operation definitions expect an `=` after op-result-list!' - ) + 'Operation definitions expect an `=` after op-result-list!') # check for custom op format op_name = self.try_parse_bare_id() @@ -1120,8 +1120,7 @@ def must_parse_operation(self) -> Operation: result_types=ret_types, attributes=attrs, successors=[ - self.blocks[block_name.text] - for block_name in successors + self.blocks[block_name.text] for block_name in successors ], regions=regions) @@ -1168,7 +1167,7 @@ def must_parse_region(self) -> Region: 'The following block references are dangling:', [(span, "Reference to block \"{}\" without implementation!" .format(span.text)) for span in itertools.chain( - *self.forward_block_references.values())], + *self.forward_block_references.values())], self.tokenizer.history) return region @@ -1249,13 +1248,15 @@ def try_parse_builtin_attr(self) -> Attribute | None: def try_parse_builtin_dense_attr(self) -> Attribute | None: with self.tokenizer.backtracking("dense attribute"): - self.must_parse_characters("dense", "builtin dense attribute must start with `dense`") + self.must_parse_characters( + "dense", "builtin dense attribute must start with `dense`") err_msg = "Malformed dense attribute, format must be (`dense<` array-attr `>:` type)" self.must_parse_characters("<", err_msg) info = list(self.must_parse_builtin_dense_attr_args()) self.must_parse_characters(">", err_msg) self.must_parse_characters(":", err_msg) - type = self.expect(self.try_parse_type, "Dense attribute must be typed!") + type = self.expect(self.try_parse_type, + "Dense attribute must be typed!") return DenseIntOrFPElementsAttr.from_list(type, info) def must_parse_builtin_dense_attr_args(self) -> Iterable[int | float]: @@ -1265,12 +1266,14 @@ def must_parse_builtin_dense_attr_args(self) -> Iterable[int | float]: dense-attr-params := float-literal | int-literal | list-of-dense-attrs-params list-of-dense-attrs-params := `[` dense-attr-params (`,` dense-attr-params)* `]` """ + def try_parse_int_or_float(): if (literal := self.try_parse_float_literal()) is not None: return float(literal.text) if (literal := self.try_parse_integer_literal()) is not None: return int(literal.text) self.raise_error('Expected int or float literal here!') + if not self.tokenizer.starts_with('['): yield try_parse_int_or_float() return @@ -1282,7 +1285,6 @@ def try_parse_int_or_float(): break self.must_parse_characters(']', '') - def try_parse_ref_attr(self) -> FlatSymbolRefAttr | None: if not self.tokenizer.starts_with("@"): return None @@ -1568,7 +1570,9 @@ def parse_op(self) -> Operation: return self.must_parse_operation() def parse_int_literal(self) -> int: - return int(self.expect(self.try_parse_integer_literal, 'Expected integer literal here').text) + return int( + self.expect(self.try_parse_integer_literal, + 'Expected integer literal here').text) class MLIRParser(BaseParser): @@ -1804,9 +1808,12 @@ def must_parse_generic_attribute_args(self, name: StringLiteral): self.raise_error("Unknown attribute name!", name) if not issubclass(attr, ParametrizedAttribute): self.raise_error("Expected ParametrizedAttribute name here!", name) - self.must_parse_characters('<', 'Expected generic attribute arguments here!') - args = self.must_parse_list_of(self.try_parse_attribute, 'Unexpected end of attribute list!') - self.must_parse_characters('>', 'Malformed attribute arguments, reached end of args list!') + self.must_parse_characters( + '<', 'Expected generic attribute arguments here!') + args = self.must_parse_list_of(self.try_parse_attribute, + 'Unexpected end of attribute list!') + self.must_parse_characters( + '>', 'Malformed attribute arguments, reached end of args list!') return attr(args) @@ -1818,8 +1825,14 @@ class Source(Enum): MLIR = 2 -def Parser(ctx: MLContext, prog: str, source: Source = Source.XDSL, filename: str = '') -> BaseParser: - selected_parser = {Source.XDSL: XDSLParser, Source.MLIR: MLIRParser}[source] +def Parser(ctx: MLContext, + prog: str, + source: Source = Source.XDSL, + filename: str = '') -> BaseParser: + selected_parser = { + Source.XDSL: XDSLParser, + Source.MLIR: MLIRParser + }[source] return selected_parser(ctx, prog, filename) From a4bb0498e23a05a1bee7a6424e9f7d6d7e78b73a Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Fri, 13 Jan 2023 16:13:49 +0000 Subject: [PATCH 26/65] parser: fixed xdsl_opt tests --- tests/test_parser_error.py | 8 +++++--- xdsl/parser.py | 3 ++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/test_parser_error.py b/tests/test_parser_error.py index c0abffbea6..0f7bb6bc47 100644 --- a/tests/test_parser_error.py +++ b/tests/test_parser_error.py @@ -1,11 +1,13 @@ from __future__ import annotations + from typing import Annotated -from xdsl.ir import MLContext, OpResult, SSAValue -from xdsl.irdl import AnyAttr, VarOperandDef, VarResultDef, irdl_op_definition, Operation -from xdsl.parser import Parser, ParserError from pytest import raises +from xdsl.ir import MLContext +from xdsl.irdl import AnyAttr, irdl_op_definition, Operation, VarOperand, VarOpResult +from xdsl.parser import Parser, ParseError + @irdl_op_definition class UnkownOp(Operation): diff --git a/xdsl/parser.py b/xdsl/parser.py index a2ebfb2fc9..30b76541dc 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -1828,7 +1828,8 @@ class Source(Enum): def Parser(ctx: MLContext, prog: str, source: Source = Source.XDSL, - filename: str = '') -> BaseParser: + filename: str = '', + allow_unregistered_ops = False) -> BaseParser: selected_parser = { Source.XDSL: XDSLParser, Source.MLIR: MLIRParser From af8cebfca23a4a2a63ab402a189d813186356912 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Tue, 17 Jan 2023 13:56:25 +0000 Subject: [PATCH 27/65] parser: fix all tests failing after rebase --- tests/test_attribute_builder.py | 12 ++++---- tests/test_attribute_definition.py | 27 +++++++++--------- tests/test_ir.py | 10 +++---- tests/test_irdl.py | 6 ++-- tests/test_mlir_converter.py | 4 +-- tests/test_mlir_printer.py | 10 +++---- tests/test_parser.py | 46 ++++++++++++++++-------------- tests/test_parser_error.py | 4 +-- tests/test_printer.py | 34 +++++++++++----------- xdsl/dialects/builtin.py | 44 ++++++++++++++-------------- xdsl/dialects/irdl.py | 4 +-- xdsl/dialects/llvm.py | 10 +++---- xdsl/ir.py | 8 +++--- xdsl/parser.py | 17 +++++++++-- xdsl/xdsl_opt_main.py | 9 +++--- 15 files changed, 131 insertions(+), 114 deletions(-) diff --git a/tests/test_attribute_builder.py b/tests/test_attribute_builder.py index f89794fffe..cda2add81b 100644 --- a/tests/test_attribute_builder.py +++ b/tests/test_attribute_builder.py @@ -4,7 +4,7 @@ from xdsl.ir import ParametrizedAttribute, Data from xdsl.irdl import irdl_attr_definition, builder -from xdsl.parser import Parser +from xdsl.parser import BaseParser from xdsl.printer import Printer from xdsl.utils.exceptions import BuilderNotFoundException @@ -34,7 +34,7 @@ def from_int(data: int) -> OneBuilderAttr: return OneBuilderAttr(str(data)) @staticmethod - def parse_parameter(parser: Parser) -> str: + def parse_parameter(parser: BaseParser) -> str: raise NotImplementedError() @staticmethod @@ -72,7 +72,7 @@ def from_int(data1: int, data2: str) -> OneBuilderAttr: return OneBuilderAttr(str(data1) + data2) @staticmethod - def parse_parameter(parser: Parser) -> str: + def parse_parameter(parser: BaseParser) -> str: raise NotImplementedError() @staticmethod @@ -105,7 +105,7 @@ def from_str(s: str) -> TwoBuildersAttr: return TwoBuildersAttr(s) @staticmethod - def parse_parameter(parser: Parser) -> str: + def parse_parameter(parser: BaseParser) -> str: raise NotImplementedError() @staticmethod @@ -145,7 +145,7 @@ def from_int(data1: int, return BuilderDefaultArgAttr(f"{data1}, {data2}, {data3}") @staticmethod - def parse_parameter(parser: Parser) -> str: + def parse_parameter(parser: BaseParser) -> str: raise NotImplementedError() @staticmethod @@ -188,7 +188,7 @@ def from_int(data: str | int) -> BuilderUnionArgAttr: return BuilderUnionArgAttr(str(data)) @staticmethod - def parse_parameter(parser: Parser) -> str: + def parse_parameter(parser: BaseParser) -> str: raise NotImplementedError() @staticmethod diff --git a/tests/test_attribute_definition.py b/tests/test_attribute_definition.py index cb78352cb9..f7436c75ed 100644 --- a/tests/test_attribute_definition.py +++ b/tests/test_attribute_definition.py @@ -13,7 +13,7 @@ from xdsl.irdl import (AttrConstraint, GenericData, ParameterDef, irdl_attr_definition, builder, irdl_to_attr_constraint, AnyAttr, BaseAttr, ParamAttrDef) -from xdsl.parser import Parser +from xdsl.parser import BaseParser from xdsl.printer import Printer from xdsl.utils.exceptions import VerifyException @@ -31,14 +31,13 @@ class BoolData(Data[bool]): name = "bool" @staticmethod - def parse_parameter(parser: Parser) -> bool: - val = parser.parse_optional_ident() - if val == "True": + def parse_parameter(parser: BaseParser) -> bool: + val = parser.tokenizer.next_token_of_pattern('(True|False)') + if val is None or val.text not in ('True', 'False'): + parser.raise_error("Expected True or False literal") + if val.text == "True": return True - elif val == "False": - return False - else: - raise Exception("Wrong argument passed to BoolAttr.") + return False @staticmethod def print_parameter(data: bool, printer: Printer): @@ -51,7 +50,7 @@ class IntData(Data[int]): name = "int" @staticmethod - def parse_parameter(parser: Parser) -> int: + def parse_parameter(parser: BaseParser) -> int: return parser.parse_int_literal() @staticmethod @@ -65,7 +64,7 @@ class StringData(Data[str]): name = "str" @staticmethod - def parse_parameter(parser: Parser) -> str: + def parse_parameter(parser: BaseParser) -> str: return parser.parse_str_literal() @staticmethod @@ -102,7 +101,7 @@ class IntListMissingVerifierData(Data[list[int]]): name = "missing_verifier_data" @staticmethod - def parse_parameter(parser: Parser) -> list[int]: + def parse_parameter(parser: BaseParser) -> list[int]: raise NotImplementedError() @staticmethod @@ -134,7 +133,7 @@ class IntListData(Data[list[int]]): name = "int_list" @staticmethod - def parse_parameter(parser: Parser) -> list[int]: + def parse_parameter(parser: BaseParser) -> list[int]: raise NotImplementedError() @staticmethod @@ -431,7 +430,7 @@ class MissingGenericDataData(Data[_MissingGenericDataData]): name = "missing_genericdata" @staticmethod - def parse_parameter(parser: Parser) -> _MissingGenericDataData: + def parse_parameter(parser: BaseParser) -> _MissingGenericDataData: raise NotImplementedError() @staticmethod @@ -484,7 +483,7 @@ class ListData(GenericData[list[A]]): name = "list" @staticmethod - def parse_parameter(parser: Parser) -> list[A]: + def parse_parameter(parser: BaseParser) -> list[A]: raise NotImplementedError() @staticmethod diff --git a/tests/test_ir.py b/tests/test_ir.py index 1b690e4ecf..d33cea4837 100644 --- a/tests/test_ir.py +++ b/tests/test_ir.py @@ -4,7 +4,7 @@ from xdsl.dialects.arith import Addi, Subi, Constant from xdsl.dialects.builtin import i32, IntegerAttr, ModuleOp from xdsl.dialects.scf import If -from xdsl.parser import Parser +from xdsl.parser import XDSLParser from xdsl.dialects.builtin import Builtin from xdsl.dialects.func import Func from xdsl.dialects.arith import Arith @@ -203,10 +203,10 @@ def test_is_structurally_equivalent(args: list[str], expected_result: bool): ctx.register_dialect(Arith) ctx.register_dialect(Cf) - parser = Parser(ctx, args[0]) + parser = XDSLParser(ctx, args[0]) lhs: Operation = parser.parse_op() - parser = Parser(ctx, args[1]) + parser = XDSLParser(ctx, args[1]) rhs: Operation = parser.parse_op() assert lhs.is_structurally_equivalent(rhs) == expected_result @@ -231,8 +231,8 @@ def test_is_structurally_equivalent_incompatible_ir_nodes(): ctx.register_dialect(Arith) ctx.register_dialect(Cf) - parser = Parser(ctx, program_func) - program: ModuleOp = parser.parse_op() + parser = XDSLParser(ctx, program_func) + program: ModuleOp = parser.must_parse_operation() assert program.is_structurally_equivalent(program.regions[0]) == False assert program.is_structurally_equivalent( diff --git a/tests/test_irdl.py b/tests/test_irdl.py index c28c4b2b1a..9bd1248a2f 100644 --- a/tests/test_irdl.py +++ b/tests/test_irdl.py @@ -5,7 +5,7 @@ from xdsl.ir import Attribute, Data, ParametrizedAttribute from xdsl.irdl import AllOf, AnyAttr, AnyOf, AttrConstraint, BaseAttr, EqAttrConstraint, ParamAttrConstraint, ParameterDef, irdl_attr_definition -from xdsl.parser import Parser +from xdsl.parser import BaseParser from xdsl.printer import Printer from xdsl.utils.exceptions import VerifyException @@ -16,7 +16,7 @@ class BoolData(Data[bool]): name = "bool" @staticmethod - def parse_parameter(parser: Parser) -> bool: + def parse_parameter(parser: BaseParser) -> bool: raise NotImplementedError() @staticmethod @@ -30,7 +30,7 @@ class IntData(Data[int]): name = "int" @staticmethod - def parse_parameter(parser: Parser) -> int: + def parse_parameter(parser: BaseParser) -> int: return parser.parse_int_literal() @staticmethod diff --git a/tests/test_mlir_converter.py b/tests/test_mlir_converter.py index c68f609c59..465dfb51f6 100644 --- a/tests/test_mlir_converter.py +++ b/tests/test_mlir_converter.py @@ -9,7 +9,7 @@ from xdsl.dialects.affine import Affine from xdsl.dialects.arith import Arith -from xdsl.parser import Parser +from xdsl.parser import XDSLParser from xdsl.ir import MLContext from xdsl.dialects.builtin import Builtin @@ -23,7 +23,7 @@ def convert_and_verify(test_prog: str): ctx.register_dialect(Scf) ctx.register_dialect(MemRef) - parser = Parser(ctx, test_prog) + parser = XDSLParser(ctx, test_prog) module = parser.parse_op() module.verify() diff --git a/tests/test_mlir_printer.py b/tests/test_mlir_printer.py index 20f86e0f25..d23d3ea54b 100644 --- a/tests/test_mlir_printer.py +++ b/tests/test_mlir_printer.py @@ -5,7 +5,7 @@ from xdsl.ir import Attribute, Data, MLContext, MLIRType, Operation, ParametrizedAttribute from xdsl.irdl import (AnyAttr, ParameterDef, RegionDef, irdl_attr_definition, irdl_op_definition, VarOperand, VarOpResult) -from xdsl.parser import Parser, ParseError +from xdsl.parser import ParseError, BaseParser, XDSLParser from xdsl.printer import Printer @@ -30,7 +30,7 @@ class DataAttr(Data[int]): name = "data_attr" @staticmethod - def parse_parameter(parser: Parser) -> int: + def parse_parameter(parser: BaseParser) -> int: return parser.parse_int_literal() @staticmethod @@ -44,7 +44,7 @@ class DataType(Data[int], MLIRType): name = "data_type" @staticmethod - def parse_parameter(parser: Parser) -> int: + def parse_parameter(parser: BaseParser) -> int: return parser.parse_int_literal() @staticmethod @@ -89,9 +89,9 @@ def print_as_mlir_and_compare(test_prog: str, expected: str): ctx.register_attr(ParamAttrWithParam) ctx.register_attr(ParamAttrWithCustomFormat) - parser = Parser(ctx, test_prog) + parser = XDSLParser(ctx, test_prog) try: - module = parser.parse_op() + module = parser.must_parse_operation() except ParseError as err: io = StringIO() err.print_with_history(file=io) diff --git a/tests/test_parser.py b/tests/test_parser.py index 3322fac40e..7be6a22c76 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1,7 +1,11 @@ +from io import StringIO + import pytest -from xdsl.ir import MLContext -from xdsl.parser import Parser +from printer import Printer +from xdsl.ir import MLContext, Attribute +from xdsl.parser import XDSLParser +from xdsl.dialects.builtin import IntAttr, DictionaryAttr, StringAttr, FloatAttr, ArrayAttr, Builtin @pytest.mark.parametrize("input,expected", [("0, 1, 1", [0, 1, 1]), @@ -9,29 +13,29 @@ ("1, 1, 0", [1, 1, 0])]) def test_int_list_parser(input: str, expected: list[int]): ctx = MLContext() - parser = Parser(ctx, input) + parser = XDSLParser(ctx, input, '') int_list = parser.must_parse_list_of(parser.try_parse_integer_literal, '') assert [int(span.text) for span in int_list] == expected -@pytest.mark.parametrize("input,expected", [('{"A"=0, "B"=1, "C"=2}', { - "A": 0, - "B": 1, - "C": 2 -}), ('{"MA"=10, "BR"=7, "Z"=3}', { - "MA": 10, - "BR": 7, - "Z": 3 -}), ('{"Q"=77, "VV"=12, "AA"=-8}', { - "Q": 77, - "VV": 12, - "AA": -8 -})]) -def test_int_dictionary_parser(input: str, expected: dict[str, int]): +@pytest.mark.parametrize('data', [ + dict(a=IntAttr.from_int(1), b=IntAttr.from_int(2), c=IntAttr.from_int(3)), + dict(a=StringAttr.from_str('hello'), b=IntAttr.from_int(2), c=ArrayAttr.from_list([IntAttr.from_int(2), StringAttr.from_str('world')])), + dict(), +]) +def test_dictionary_attr(data: dict[str, Attribute]): + attr = DictionaryAttr.from_dict(data) + + with StringIO() as io: + Printer(io).print(attr) + text = io.getvalue() + ctx = MLContext() - parser = Parser(ctx, input) + ctx.register_dialect(Builtin) + + attr = XDSLParser(ctx, text).must_parse_attribute() + + assert attr.data == data + - int_dict = parser.parse_dictionary(parser.parse_str_literal, - parser.parse_int_literal) - assert int_dict == expected diff --git a/tests/test_parser_error.py b/tests/test_parser_error.py index 0f7bb6bc47..0e3718608c 100644 --- a/tests/test_parser_error.py +++ b/tests/test_parser_error.py @@ -6,7 +6,7 @@ from xdsl.ir import MLContext from xdsl.irdl import AnyAttr, irdl_op_definition, Operation, VarOperand, VarOpResult -from xdsl.parser import Parser, ParseError +from xdsl.parser import Parser, ParseError, XDSLParser @irdl_op_definition @@ -20,7 +20,7 @@ def check_error(prog: str, line: int, column: int, message: str): ctx = MLContext() ctx.register_op(UnkownOp) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog) with raises(ParseError) as e: parser.must_parse_operation() diff --git a/tests/test_printer.py b/tests/test_printer.py index edf3f8333d..ce370a7dfc 100644 --- a/tests/test_printer.py +++ b/tests/test_printer.py @@ -10,7 +10,7 @@ from xdsl.ir import Attribute, MLContext, OpResult, ParametrizedAttribute, SSAValue from xdsl.irdl import (ParameterDef, irdl_attr_definition, irdl_op_definition, Operation, Operand, OptAttributeDef) -from xdsl.parser import Parser, BaseParser, Span +from xdsl.parser import Parser, BaseParser, Span, XDSLParser from xdsl.printer import Printer from xdsl.utils.diagnostic import Diagnostic @@ -149,7 +149,7 @@ def test_op_message(): ctx.register_dialect(Arith) ctx.register_dialect(Builtin) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog, '') module = parser.parse_op() file = StringIO("") @@ -184,7 +184,7 @@ def test_two_different_op_messages(): ctx.register_dialect(Arith) ctx.register_dialect(Builtin) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog, '') module = parser.parse_op() file = StringIO("") @@ -220,7 +220,7 @@ def test_two_same_op_messages(): ctx.register_dialect(Arith) ctx.register_dialect(Builtin) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog, '') module = parser.parse_op() file = StringIO("") @@ -254,7 +254,7 @@ def test_op_message_with_region(): ctx.register_dialect(Arith) ctx.register_dialect(Builtin) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog, '') module = parser.parse_op() file = StringIO("") @@ -290,7 +290,7 @@ def test_op_message_with_region_and_overflow(): ctx.register_dialect(Arith) ctx.register_dialect(Builtin) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog, '') module = parser.parse_op() file = StringIO("") @@ -316,7 +316,7 @@ def test_diagnostic(): ctx.register_dialect(Arith) ctx.register_dialect(Builtin) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog, '') module = parser.parse_op() diag = Diagnostic() @@ -356,7 +356,7 @@ def test_print_custom_name(): ctx.register_dialect(Arith) ctx.register_dialect(Builtin) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog, '') module = parser.parse_op() file = StringIO("") @@ -421,7 +421,7 @@ def test_generic_format(): ctx.register_dialect(Builtin) ctx.register_op(PlusCustomFormatOp) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog, '') module = parser.parse_op() file = StringIO("") @@ -452,7 +452,7 @@ def test_custom_format(): ctx.register_dialect(Builtin) ctx.register_op(PlusCustomFormatOp) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog, '') module = parser.parse_op() file = StringIO("") @@ -483,7 +483,7 @@ def test_custom_format_II(): ctx.register_dialect(Builtin) ctx.register_op(PlusCustomFormatOp) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog, '') module = parser.parse_op() file = StringIO("") @@ -540,7 +540,7 @@ def test_custom_format_attr(): ctx.register_op(AnyOp) ctx.register_attr(CustomFormatAttr) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog, '') module = parser.parse_op() file = StringIO("") @@ -569,7 +569,7 @@ def test_parse_generic_format_attr(): ctx.register_op(AnyOp) ctx.register_attr(CustomFormatAttr) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog, '') module = parser.parse_op() file = StringIO("") @@ -598,7 +598,7 @@ def test_parse_generic_format_attr_II(): ctx.register_op(AnyOp) ctx.register_attr(CustomFormatAttr) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog, '') module = parser.parse_op() file = StringIO("") @@ -653,7 +653,7 @@ def test_parse_dense_xdsl(): ctx.register_dialect(Builtin) ctx.register_dialect(Arith) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog, '') module = parser.parse_op() file = StringIO("") @@ -701,7 +701,7 @@ def test_foo_string(): ctx.register_op(AnyOp) ctx.register_attr(CustomFormatAttr) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog, '') try: parser.parse_op() assert False @@ -720,7 +720,7 @@ def test_dictionary_attr(): ctx.register_dialect(Builtin) ctx.register_dialect(Func) - parser = Parser(ctx, prog) + parser = XDSLParser(ctx, prog, '') parsed = parser.parse_op() file = StringIO("") diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index a35520c783..ffc72a4dc2 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -16,7 +16,7 @@ from xdsl.utils.exceptions import VerifyException if TYPE_CHECKING: - from xdsl.parser import Parser, ParseError + from xdsl.parser import BaseParser, ParseError from xdsl.printer import Printer @@ -25,7 +25,7 @@ class StringAttr(Data[str]): name = "string" @staticmethod - def parse_parameter(parser: Parser) -> str: + def parse_parameter(parser: BaseParser) -> str: data = parser.parse_str_literal() return data @@ -81,7 +81,7 @@ class IntAttr(Data[int]): name = "int" @staticmethod - def parse_parameter(parser: Parser) -> int: + def parse_parameter(parser: BaseParser) -> int: data = parser.parse_int_literal() return data @@ -110,7 +110,7 @@ class SignednessAttr(Data[Signedness]): name = "signedness" @staticmethod - def parse_parameter(parser: Parser) -> Signedness: + def parse_parameter(parser: BaseParser) -> Signedness: if parser.parse_optional_string("signless") is not None: return Signedness.SIGNLESS elif parser.parse_optional_string("signed") is not None: @@ -227,7 +227,7 @@ class FloatData(Data[float]): name = "float_data" @staticmethod - def parse_parameter(parser: Parser) -> float: + def parse_parameter(parser: BaseParser) -> float: return parser.parse_float_literal() @staticmethod @@ -300,7 +300,7 @@ class ArrayAttr(GenericData[List[_ArrayAttrT]]): name = "array" @staticmethod - def parse_parameter(parser: Parser) -> List[_ArrayAttrT]: + def parse_parameter(parser: BaseParser) -> List[_ArrayAttrT]: parser.parse_char("[") data = parser.parse_list(parser.parse_optional_attribute) parser.parse_char("]") @@ -345,17 +345,17 @@ def from_list(data: List[_ArrayAttrT]) -> ArrayAttr[_ArrayAttrT]: @irdl_attr_definition -class DictionaryAttr(GenericData[dict[str, Attribute]]): +class DictionaryAttr(GenericData[dict[StringAttr, Attribute]]): name = "dictionary" @staticmethod - def parse_parameter(parser: Parser) -> dict[str, Attribute]: - data = parser.parse_dictionary(parser.parse_str_literal, - parser.parse_attribute) - return data + def parse_parameter(parser: BaseParser) -> dict[str, Attribute]: + # force MLIR style parsing of attribute + from xdsl.parser import MLIRParser + return MLIRParser.must_parse_optional_attr_dict(parser) @staticmethod - def print_parameter(data: dict[str, Attribute], printer: Printer) -> None: + def print_parameter(data: dict[StringAttr, Attribute], printer: Printer) -> None: printer.print_string("{") printer.print_dictionary(data, printer.print_string_literal, printer.print_attribute) @@ -384,18 +384,18 @@ def verify(self) -> None: @staticmethod @builder def from_dict(data: dict[str | StringAttr, Attribute]) -> DictionaryAttr: - to_add_data: dict[str, Attribute] = {} + to_add_data = {} for k, v in data.items(): + # try to coerce keys into StringAttr + if isinstance(k, StringAttr): + k = k.data + # if coercion fails, raise KeyError! if not isinstance(k, str): - if isinstance(k, StringAttr): - to_add_data[k.data] = v - else: - raise TypeError( - f"Attribute DictionaryAttr expects keys to" - f" be of type StringAttr or str, but {type(k)} provided" - ) - else: - to_add_data[k] = v + raise TypeError( + f"Attribute DictionaryAttr expects keys to" + f" be of type str or str, but {type(k)} provided" + ) + to_add_data[k] = v return DictionaryAttr(to_add_data) diff --git a/xdsl/dialects/irdl.py b/xdsl/dialects/irdl.py index d37462a98a..672da4b145 100644 --- a/xdsl/dialects/irdl.py +++ b/xdsl/dialects/irdl.py @@ -6,7 +6,7 @@ from xdsl.irdl import (ParameterDef, AnyAttr, AttributeDef, SingleBlockRegionDef, irdl_op_definition, irdl_attr_definition) -from xdsl.parser import Parser +from xdsl.parser import BaseParser from xdsl.printer import Printer @@ -61,7 +61,7 @@ class NamedTypeConstraintAttr(ParametrizedAttribute): params_constraints: ParameterDef[Attribute] @staticmethod - def parse_parameters(parser: Parser) -> list[Attribute]: + def parse_parameters(parser: BaseParser) -> list[Attribute]: parser.parse_char("<") type_name = parser.parse_str_literal() parser.parse_char(":") diff --git a/xdsl/dialects/llvm.py b/xdsl/dialects/llvm.py index cd154ffcc4..b4e203d8e6 100644 --- a/xdsl/dialects/llvm.py +++ b/xdsl/dialects/llvm.py @@ -9,7 +9,7 @@ from xdsl.dialects.builtin import StringAttr, ArrayOfConstraint, ArrayAttr if TYPE_CHECKING: - from xdsl.parser import Parser + from xdsl.parser import BaseParser from xdsl.printer import Printer @@ -38,10 +38,10 @@ def print_parameters(self, printer: Printer) -> None: printer.print(")>") @staticmethod - def parse_parameters(parser: Parser) -> list[Attribute]: - parser.parse_string("<(") - params = parser.parse_list(parser.parse_optional_attribute) - parser.parse_string(")>") + def parse_parameters(parser: BaseParser) -> list[Attribute]: + parser.must_parse_characters("<(", "LLVM Struct must start with `<(`") + params = parser.must_parse_list_of(parser.try_parse_attribute, "Malformed LLVM struct, expected attribute definition here!") + parser.must_parse_characters(")>", "Unexpected input, expected end of LLVM struct!") return [StringAttr.from_str(""), ArrayAttr.from_list(params)] diff --git a/xdsl/ir.py b/xdsl/ir.py index 07370288e1..9b2eb7a6b6 100644 --- a/xdsl/ir.py +++ b/xdsl/ir.py @@ -10,7 +10,7 @@ # Used for cyclic dependencies in type hints if TYPE_CHECKING: - from xdsl.parser import Parser + from xdsl.parser import Parser, BaseParser from xdsl.printer import Printer from xdsl.irdl import OpDef, ParamAttrDef @@ -294,7 +294,7 @@ class Data(Generic[DataElement], Attribute, ABC): @staticmethod @abstractmethod - def parse_parameter(parser: Parser) -> DataElement: + def parse_parameter(parser: BaseParser) -> DataElement: """Parse the attribute parameter.""" @staticmethod @@ -309,7 +309,7 @@ class ParametrizedAttribute(Attribute): parameters: list[Attribute] = field(default_factory=list) @staticmethod - def parse_parameters(parser: Parser) -> list[Attribute]: + def parse_parameters(parser: BaseParser) -> list[Attribute]: """Parse the attribute parameters.""" return parser.parse_paramattr_parameters() @@ -507,7 +507,7 @@ def verify_(self) -> None: @classmethod def parse(cls: type[_OperationType], result_types: list[Attribute], - parser: Parser) -> _OperationType: + parser: BaseParser) -> _OperationType: return parser.parse_op_with_default_format(cls, result_types) def print(self, printer: Printer): diff --git a/xdsl/parser.py b/xdsl/parser.py index 30b76541dc..1b95d70e5a 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -637,7 +637,8 @@ def __init__( self, ctx: MLContext, input: str, - name: str, + name: str = '', + allow_unregistered_ops = False ): self.tokenizer = Tokenizer(Input(input, name)) self.ctx = ctx @@ -1237,6 +1238,8 @@ def try_parse_builtin_attr(self) -> Attribute | None: return self.try_parse_ref_attr() elif next_token.text == "dense": return self.try_parse_builtin_dense_attr() + elif next_token.text == '{': + return self.try_parse_builtin_dict_attr() # order here is important! attrs = (self.try_parse_builtin_float_attr, @@ -1574,6 +1577,13 @@ def parse_int_literal(self) -> int: self.expect(self.try_parse_integer_literal, 'Expected integer literal here').text) + def try_parse_builtin_dict_attr(self): + attr_def = self.ctx.get_optional_attr('dictionary') + if attr_def is None: + self.raise_error("An attribute named `dictionary` must be available in the context in order to parse dictionary attributes! Please make sure the builtin dialect is available, or provide your own replacement!") + param = attr_def.parse_parameter(self) + return attr_def(param) + class MLIRParser(BaseParser): @@ -1624,6 +1634,7 @@ def must_parse_op_result_list( None, ) + def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: if not self.tokenizer.starts_with("{"): return dict() @@ -1632,7 +1643,9 @@ def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: "{", "MLIR Attribute dictionary must be enclosed in curly brackets") - attrs = self.must_parse_list_of(self.must_parse_attribute_entry, + attrs = [] + if not self.tokenizer.starts_with('}'): + attrs = self.must_parse_list_of(self.must_parse_attribute_entry, "Expected attribute entry") self.must_parse_characters( diff --git a/xdsl/xdsl_opt_main.py b/xdsl/xdsl_opt_main.py index 7c1190b359..4781e3ad4d 100644 --- a/xdsl/xdsl_opt_main.py +++ b/xdsl/xdsl_opt_main.py @@ -5,7 +5,7 @@ import coverage from xdsl.ir import MLContext -from xdsl.parser import Parser +from xdsl.parser import Parser, XDSLParser, MLIRParser from xdsl.printer import Printer from xdsl.dialects.func import Func from xdsl.dialects.scf import Scf @@ -218,9 +218,10 @@ def register_all_frontends(self): def parse_xdsl(f: IOBase): input_str = f.read() - parser = Parser( + parser = XDSLParser( self.ctx, input_str, + self.args.input_file or '', allow_unregistered_ops=self.args.allow_unregistered_ops) module = parser.parse_op() if not (isinstance(module, ModuleOp)): @@ -230,10 +231,10 @@ def parse_xdsl(f: IOBase): def parse_mlir(f: IOBase): input_str = f.read() - parser = Parser( + parser = MLIRParser( self.ctx, input_str, - source=Parser.Source.MLIR, + self.args.input_file or '', allow_unregistered_ops=self.args.allow_unregistered_ops) module = parser.parse_op() if not (isinstance(module, ModuleOp)): From 64ac94e3f1fa540b1307f0ce3ea5beaed4fa6dbf Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Tue, 17 Jan 2023 13:58:59 +0000 Subject: [PATCH 28/65] parser: run yapf one last time before PR is done --- tests/test_mlir_printer.py | 4 ++-- tests/test_parser.py | 12 +++++++----- tests/test_parser_error.py | 6 ++++-- tests/test_printer.py | 18 ++++++++++++------ xdsl/dialects/builtin.py | 6 +++--- xdsl/dialects/llvm.py | 7 +++++-- xdsl/parser.py | 21 ++++++++++----------- 7 files changed, 43 insertions(+), 31 deletions(-) diff --git a/tests/test_mlir_printer.py b/tests/test_mlir_printer.py index d23d3ea54b..4e046f0892 100644 --- a/tests/test_mlir_printer.py +++ b/tests/test_mlir_printer.py @@ -3,8 +3,8 @@ from typing import Annotated from xdsl.ir import Attribute, Data, MLContext, MLIRType, Operation, ParametrizedAttribute -from xdsl.irdl import (AnyAttr, ParameterDef, RegionDef, irdl_attr_definition, irdl_op_definition, VarOperand, - VarOpResult) +from xdsl.irdl import (AnyAttr, ParameterDef, RegionDef, irdl_attr_definition, + irdl_op_definition, VarOperand, VarOpResult) from xdsl.parser import ParseError, BaseParser, XDSLParser from xdsl.printer import Printer diff --git a/tests/test_parser.py b/tests/test_parser.py index 7be6a22c76..b49d5ff175 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -2,10 +2,10 @@ import pytest -from printer import Printer +from xdsl.printer import Printer from xdsl.ir import MLContext, Attribute from xdsl.parser import XDSLParser -from xdsl.dialects.builtin import IntAttr, DictionaryAttr, StringAttr, FloatAttr, ArrayAttr, Builtin +from xdsl.dialects.builtin import IntAttr, DictionaryAttr, StringAttr, ArrayAttr, Builtin @pytest.mark.parametrize("input,expected", [("0, 1, 1", [0, 1, 1]), @@ -21,7 +21,11 @@ def test_int_list_parser(input: str, expected: list[int]): @pytest.mark.parametrize('data', [ dict(a=IntAttr.from_int(1), b=IntAttr.from_int(2), c=IntAttr.from_int(3)), - dict(a=StringAttr.from_str('hello'), b=IntAttr.from_int(2), c=ArrayAttr.from_list([IntAttr.from_int(2), StringAttr.from_str('world')])), + dict(a=StringAttr.from_str('hello'), + b=IntAttr.from_int(2), + c=ArrayAttr.from_list( + [IntAttr.from_int(2), + StringAttr.from_str('world')])), dict(), ]) def test_dictionary_attr(data: dict[str, Attribute]): @@ -37,5 +41,3 @@ def test_dictionary_attr(data: dict[str, Attribute]): attr = XDSLParser(ctx, text).must_parse_attribute() assert attr.data == data - - diff --git a/tests/test_parser_error.py b/tests/test_parser_error.py index 0e3718608c..98fdbdcc3e 100644 --- a/tests/test_parser_error.py +++ b/tests/test_parser_error.py @@ -32,7 +32,8 @@ def check_error(prog: str, line: int, column: int, message: str): assert err.error.span.get_line_col() == (line, column) break else: - assert False, "'{}' not found in an error message {}!".format(message, e.value.args) + assert False, "'{}' not found in an error message {}!".format( + message, e.value.args) def test_parser_missing_equal(): @@ -46,7 +47,8 @@ def test_parser_missing_equal(): %0 : !i32 unknown() } """ - check_error(prog, 3, 12, "Operation definitions expect an `=` after op-result-list!") + check_error(prog, 3, 12, + "Operation definitions expect an `=` after op-result-list!") def test_parser_redefined_value(): diff --git a/tests/test_printer.py b/tests/test_printer.py index ce370a7dfc..aa9bfd20bc 100644 --- a/tests/test_printer.py +++ b/tests/test_printer.py @@ -383,17 +383,22 @@ class PlusCustomFormatOp(Operation): @classmethod def parse(cls, result_types: List[Attribute], parser: BaseParser) -> PlusCustomFormatOp: + def get_ssa_val(name: Span) -> SSAValue: if name.text not in parser.ssaValues: parser.raise_error('SSA Value used before assignment', name) return parser.ssaValues[name.text] - lhs = parser.expect(parser.try_parse_value_id, 'Expected SSA Value name here!') - parser.must_parse_characters("+", "Malformed operation format, expected `+`!") - rhs = parser.expect(parser.try_parse_value_id, 'Expected SSA Value name here!') + lhs = parser.expect(parser.try_parse_value_id, + 'Expected SSA Value name here!') + parser.must_parse_characters( + "+", "Malformed operation format, expected `+`!") + rhs = parser.expect(parser.try_parse_value_id, + 'Expected SSA Value name here!') - return PlusCustomFormatOp.create(operands=[get_ssa_val(name) for name in (lhs, rhs)], - result_types=result_types) + return PlusCustomFormatOp.create( + operands=[get_ssa_val(name) for name in (lhs, rhs)], + result_types=result_types) def print(self, printer: Printer): printer.print(" ", self.lhs, " + ", self.rhs) @@ -501,7 +506,8 @@ class CustomFormatAttr(ParametrizedAttribute): @staticmethod def parse_parameters(parser: BaseParser) -> list[Attribute]: parser.parse_char("<") - value = parser.tokenizer.next_token_of_pattern(re.compile('(zero|one)')) + value = parser.tokenizer.next_token_of_pattern( + re.compile('(zero|one)')) if value and value.text == "zero": parser.parse_char(">") return [IntAttr.from_int(0)] diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index ffc72a4dc2..ba8b9a6bdb 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -355,7 +355,8 @@ def parse_parameter(parser: BaseParser) -> dict[str, Attribute]: return MLIRParser.must_parse_optional_attr_dict(parser) @staticmethod - def print_parameter(data: dict[StringAttr, Attribute], printer: Printer) -> None: + def print_parameter(data: dict[StringAttr, Attribute], + printer: Printer) -> None: printer.print_string("{") printer.print_dictionary(data, printer.print_string_literal, printer.print_attribute) @@ -393,8 +394,7 @@ def from_dict(data: dict[str | StringAttr, Attribute]) -> DictionaryAttr: if not isinstance(k, str): raise TypeError( f"Attribute DictionaryAttr expects keys to" - f" be of type str or str, but {type(k)} provided" - ) + f" be of type str or str, but {type(k)} provided") to_add_data[k] = v return DictionaryAttr(to_add_data) diff --git a/xdsl/dialects/llvm.py b/xdsl/dialects/llvm.py index b4e203d8e6..ec1fcad79d 100644 --- a/xdsl/dialects/llvm.py +++ b/xdsl/dialects/llvm.py @@ -40,8 +40,11 @@ def print_parameters(self, printer: Printer) -> None: @staticmethod def parse_parameters(parser: BaseParser) -> list[Attribute]: parser.must_parse_characters("<(", "LLVM Struct must start with `<(`") - params = parser.must_parse_list_of(parser.try_parse_attribute, "Malformed LLVM struct, expected attribute definition here!") - parser.must_parse_characters(")>", "Unexpected input, expected end of LLVM struct!") + params = parser.must_parse_list_of( + parser.try_parse_attribute, + "Malformed LLVM struct, expected attribute definition here!") + parser.must_parse_characters( + ")>", "Unexpected input, expected end of LLVM struct!") return [StringAttr.from_str(""), ArrayAttr.from_list(params)] diff --git a/xdsl/parser.py b/xdsl/parser.py index 1b95d70e5a..42b0e5a1f8 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -633,13 +633,11 @@ class BaseParser(ABC): of all try_parse functions is T_ | None """ - def __init__( - self, - ctx: MLContext, - input: str, - name: str = '', - allow_unregistered_ops = False - ): + def __init__(self, + ctx: MLContext, + input: str, + name: str = '', + allow_unregistered_ops=False): self.tokenizer = Tokenizer(Input(input, name)) self.ctx = ctx self.ssaValues = dict() @@ -1580,7 +1578,9 @@ def parse_int_literal(self) -> int: def try_parse_builtin_dict_attr(self): attr_def = self.ctx.get_optional_attr('dictionary') if attr_def is None: - self.raise_error("An attribute named `dictionary` must be available in the context in order to parse dictionary attributes! Please make sure the builtin dialect is available, or provide your own replacement!") + self.raise_error( + "An attribute named `dictionary` must be available in the context in order to parse dictionary attributes! Please make sure the builtin dialect is available, or provide your own replacement!" + ) param = attr_def.parse_parameter(self) return attr_def(param) @@ -1634,7 +1634,6 @@ def must_parse_op_result_list( None, ) - def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: if not self.tokenizer.starts_with("{"): return dict() @@ -1646,7 +1645,7 @@ def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: attrs = [] if not self.tokenizer.starts_with('}'): attrs = self.must_parse_list_of(self.must_parse_attribute_entry, - "Expected attribute entry") + "Expected attribute entry") self.must_parse_characters( "}", @@ -1842,7 +1841,7 @@ def Parser(ctx: MLContext, prog: str, source: Source = Source.XDSL, filename: str = '', - allow_unregistered_ops = False) -> BaseParser: + allow_unregistered_ops=False) -> BaseParser: selected_parser = { Source.XDSL: XDSLParser, Source.MLIR: MLIRParser From c51e7a9dac62ef04f2ef1d84993c4bd5a0d965c4 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Tue, 17 Jan 2023 21:59:52 +0000 Subject: [PATCH 29/65] tests: fix broken float parsing filecheck --- tests/filecheck/parser-printer/float_parsing.xdsl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/filecheck/parser-printer/float_parsing.xdsl b/tests/filecheck/parser-printer/float_parsing.xdsl index 1810fffb41..cf08385de1 100644 --- a/tests/filecheck/parser-printer/float_parsing.xdsl +++ b/tests/filecheck/parser-printer/float_parsing.xdsl @@ -8,16 +8,16 @@ builtin.module() { %1 : !f32 = arith.constant() ["value" = -42.0 : !f32] // CHECK-NEXT: %{{.*}} : !f32 = arith.constant() ["value" = -42.0 : !f32] - %2 : !f32 = arith.constant() ["value" = 34e0 : !f32] + %2 : !f32 = arith.constant() ["value" = 34.e0 : !f32] // CHECK-NEXT: %{{.*}} : !f32 = arith.constant() ["value" = 34.0 : !f32] - %3 : !f32 = arith.constant() ["value" = 34e-23 : !f32] + %3 : !f32 = arith.constant() ["value" = 34.e-23 : !f32] // CHECK-NEXT: %{{.*}} : !f32 = arith.constant() ["value" = 3.4e-22 : !f32] - %4 : !f32 = arith.constant() ["value" = 34e12 : !f32] + %4 : !f32 = arith.constant() ["value" = 34.e12 : !f32] // CHECK-NEXT: %{{.*}} : !f32 = arith.constant() ["value" = 34000000000000.0 : !f32] - %5 : !f32 = arith.constant() ["value" = -34e-12 : !f32] + %5 : !f32 = arith.constant() ["value" = -34.e-12 : !f32] // CHECK-NEXT: %{{.*}} : !f32 = arith.constant() ["value" = -3.4e-11 : !f32] func.return() From 7782654ceac34e981e85555d1d62c9954a149727 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Tue, 17 Jan 2023 22:02:36 +0000 Subject: [PATCH 30/65] parser: fix xdsl-opt command to properly use new parser class --- xdsl/parser.py | 2 ++ xdsl/tools/xdsl-opt | 21 +-------------------- xdsl/xdsl_opt_main.py | 40 ++++++++-------------------------------- 3 files changed, 11 insertions(+), 52 deletions(-) diff --git a/xdsl/parser.py b/xdsl/parser.py index 42b0e5a1f8..6a4909e365 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -530,6 +530,7 @@ def is_eof(self): """ try: self.next_pos() + return False except EOFError: return True @@ -643,6 +644,7 @@ def __init__(self, self.ssaValues = dict() self.blocks = dict() self.forward_block_references = dict() + self.allow_unregistered_ops = allow_unregistered_ops def begin_parse(self): ops = [] diff --git a/xdsl/tools/xdsl-opt b/xdsl/tools/xdsl-opt index 4923ccfe33..38206cfb69 100755 --- a/xdsl/tools/xdsl-opt +++ b/xdsl/tools/xdsl-opt @@ -1,25 +1,6 @@ #!/usr/bin/env python3 -import argparse from xdsl.xdsl_opt_main import xDSLOptMain - -class OptMain(xDSLOptMain): - - def register_all_dialects(self): - super().register_all_dialects() - - def register_all_passes(self): - super().register_all_passes() - - def register_all_arguments(self, arg_parser: argparse.ArgumentParser): - super().register_all_arguments(arg_parser) - - -def __main__(): - xdsl_main = OptMain() - xdsl_main.run() - - if __name__ == "__main__": - __main__() + xDSLOptMain().run() diff --git a/xdsl/xdsl_opt_main.py b/xdsl/xdsl_opt_main.py index 4781e3ad4d..dace34b7d6 100644 --- a/xdsl/xdsl_opt_main.py +++ b/xdsl/xdsl_opt_main.py @@ -5,7 +5,7 @@ import coverage from xdsl.ir import MLContext -from xdsl.parser import Parser, XDSLParser, MLIRParser +from xdsl.parser import Parser, XDSLParser, MLIRParser, BaseParser from xdsl.printer import Printer from xdsl.dialects.func import Func from xdsl.dialects.scf import Scf @@ -33,7 +33,7 @@ class xDSLOptMain: attributes. """ - available_frontends: Dict[str, Callable[[IOBase], ModuleOp]] + available_frontends: Dict[str, type[BaseParser]] """ A mapping from file extension to a frontend that can handle this file type. @@ -215,35 +215,8 @@ def register_all_frontends(self): Add other/additional frontends by overloading this function. """ - - def parse_xdsl(f: IOBase): - input_str = f.read() - parser = XDSLParser( - self.ctx, - input_str, - self.args.input_file or '', - allow_unregistered_ops=self.args.allow_unregistered_ops) - module = parser.parse_op() - if not (isinstance(module, ModuleOp)): - raise Exception( - "Expected module or program as toplevel operation") - return module - - def parse_mlir(f: IOBase): - input_str = f.read() - parser = MLIRParser( - self.ctx, - input_str, - self.args.input_file or '', - allow_unregistered_ops=self.args.allow_unregistered_ops) - module = parser.parse_op() - if not (isinstance(module, ModuleOp)): - raise Exception( - "Expected module or program as toplevel operation") - return module - - self.available_frontends['xdsl'] = parse_xdsl - self.available_frontends['mlir'] = parse_mlir + self.available_frontends['xdsl'] = XDSLParser + self.available_frontends['mlir'] = MLIRParser def register_all_passes(self): """ @@ -319,7 +292,10 @@ def parse_input(self) -> ModuleOp: if file_extension not in self.available_frontends: raise Exception(f"Unrecognized file extension '{file_extension}'") - return self.available_frontends[file_extension](f) + parser = self.available_frontends[file_extension]( + self.ctx, f.read(), self.args.input_file or 'stdin', + self.args.allow_unregistered_ops) + return parser.begin_parse() def apply_passes(self, prog: ModuleOp): """Apply passes in order.""" From d77fa27efc3bc7e68a6062912c111042cb16b1ae Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Tue, 17 Jan 2023 22:02:58 +0000 Subject: [PATCH 31/65] parser: fix UnregisteredOp interaction with new parser --- xdsl/dialects/builtin.py | 33 ++++++++++--------- xdsl/ir.py | 2 ++ xdsl/parser.py | 70 +++++++++++++++++++++++++++++----------- 3 files changed, 72 insertions(+), 33 deletions(-) diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index ba8b9a6bdb..c241be7770 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -6,7 +6,7 @@ TYPE_CHECKING, Any, TypeVar) from xdsl.ir import (Data, MLIRType, ParametrizedAttribute, Operation, - SSAValue, Region, Attribute, Dialect) + SSAValue, Region, Attribute, Dialect, MLContext) from xdsl.irdl import (AttributeDef, VarOpResult, VarOperand, VarRegionDef, irdl_attr_definition, attr_constr_coercion, irdl_data_definition, irdl_to_attr_constraint, @@ -692,24 +692,27 @@ class UnregisteredOp(Operation): res: Annotated[VarOpResult, AnyAttr()] regs = VarRegionDef() + __registered_unregistered_ops: dict[str, type['UnregisteredOp']] = dict() + @property def op_name(self) -> StringAttr: return self.op_name__ # type: ignore - @staticmethod - def from_name(name: str | StringAttr, - args: list[SSAValue | Operation] = [], - res: list[Attribute] = [], - regs: list[Region] = [], - attrs: dict[str, Attribute] = {}) -> UnregisteredOp: - if "op_name__" in attrs: - raise Exception( - "Cannot create an unregistered op with an __op_name attribute") - attrs["op_name__"] = StringAttr.build(name) - return UnregisteredOp.build(operands=args, - result_types=res, - regions=regs, - attributes=attrs) + @classmethod + def with_name(cls, name: str, ctx: MLContext) -> type[UnregisteredOp]: + if name in ctx.registered_unregistered_ops: + return ctx.registered_unregistered_ops[name] # type: ignore + + class UnregisteredOpWithName(UnregisteredOp): + + @classmethod + def create(cls, **kwargs): + op = super().create(**kwargs) + op.attributes['op_name__'] = StringAttr.build(name) + return op + + ctx.registered_unregistered_ops[name] = UnregisteredOpWithName + return UnregisteredOpWithName @irdl_op_definition diff --git a/xdsl/ir.py b/xdsl/ir.py index 9b2eb7a6b6..fdb5cf8cf1 100644 --- a/xdsl/ir.py +++ b/xdsl/ir.py @@ -49,6 +49,8 @@ class MLContext: """Contains structures for operations/attributes registration.""" _registeredOps: dict[str, type[Operation]] = field(default_factory=dict) _registeredAttrs: dict[str, type[Attribute]] = field(default_factory=dict) + registered_unregistered_ops: dict[str, type[Operation]] = field( + default_factory=dict) def register_dialect(self, dialect: Dialect): """Register a dialect. Operation and Attribute names should be unique""" diff --git a/xdsl/parser.py b/xdsl/parser.py index 6a4909e365..82ae7435c5 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -18,7 +18,8 @@ AnyTensorType, AnyVectorType, Float16Type, Float32Type, Float64Type, FloatAttr, FunctionType, IndexType, IntegerType, Signedness, StringAttr, IntegerAttr, ArrayAttr, TensorType, UnrankedTensorType, VectorType, - DefaultIntegerAttrType, FlatSymbolRefAttr, DenseIntOrFPElementsAttr) + DefaultIntegerAttrType, FlatSymbolRefAttr, DenseIntOrFPElementsAttr, + UnregisteredOp) from xdsl.ir import (SSAValue, Block, Callable, Attribute, Operation, Region, BlockArgument, MLContext, ParametrizedAttribute, Data) @@ -37,7 +38,8 @@ def __init__(self, io = StringIO() history.print_unroll(io) preamble = io.getvalue() + '\n' - + if span is None: + raise ValueError("Span can't be None!") super().__init__(preamble + span.print_with_context(msg)) self.span = span self.msg = msg @@ -307,6 +309,9 @@ class Tokenizer: last_token: Span | None = field(init=False, default=None, repr=False) + def __post_init__(self): + self.last_token = self.next_token(peek=True) + def save(self) -> save_t: """ Create a checkpoint in the parsing process, useful for backtracking @@ -634,6 +639,8 @@ class BaseParser(ABC): of all try_parse functions is T_ | None """ + allow_unregistered_ops: bool + def __init__(self, ctx: MLContext, input: str, @@ -647,12 +654,10 @@ def __init__(self, self.allow_unregistered_ops = allow_unregistered_ops def begin_parse(self): - ops = [] - while (op := self.try_parse_operation()) is not None: - ops.append(op) - if not self.tokenizer.is_eof(): + op = self.try_parse_operation() + if not op: self.raise_error("Could not parse entire input!") - return ops + return op def get_block_from_name(self, block_name: Span): """ @@ -1097,7 +1102,7 @@ def must_parse_operation(self) -> Operation: # check for custom op format op_name = self.try_parse_bare_id() if op_name is not None: - op_type = self.ctx.get_op(op_name.text) + op_type = self._get_op_by_name(op_name) op = op_type.parse(ret_types, self) else: # check for basic op format @@ -1114,7 +1119,7 @@ def must_parse_operation(self) -> Operation: assert func_type is not None ret_types = func_type.outputs.data - op_type = self.ctx.get_op(op_name.string_contents) + op_type = self._get_op_by_name(op_name) op = op_type.create( operands=[self.ssaValues[span.text] for span in args], @@ -1136,6 +1141,22 @@ def must_parse_operation(self) -> Operation: return op + def _get_op_by_name(self, span: Span) -> type[Operation]: + if isinstance(span, StringLiteral): + op_name = span.string_contents + else: + op_name = span.text + + op_type = self.ctx.get_optional_op(op_name) + + if op_type is not None: + return op_type + + if self.allow_unregistered_ops: + return UnregisteredOp.with_name(op_name, self.ctx) + + self.raise_error(f'Unknown operation {op_name}!', span) + def must_parse_region(self) -> Region: oldSSAVals = self.ssaValues.copy() oldBBNames = self.blocks @@ -1309,7 +1330,6 @@ def try_parse_builtin_int_attr(self) -> IntegerAttr | None: self.try_parse_integer_literal, 'Integer attribute must start with an integer literal!') if self.tokenizer.next_token(peek=True).text != ':': - print(self.tokenizer.next_token(peek=True)) return IntegerAttr.from_params(int(value.text), DefaultIntegerAttrType) type = self.must_parse_attribute_type() @@ -1495,15 +1515,9 @@ def must_parse_operation_details( """ raise NotImplementedError() + @abstractmethod def must_parse_op_args_list(self) -> list[Span]: - self.must_parse_characters( - "(", "Operation args list must be enclosed by brackets!") - args = self.must_parse_list_of(self.try_parse_value_id_and_type, - "Expected another bare-id here") - self.must_parse_characters( - ")", "Operation args list must be closed by a closing bracket") - # TODO: check if type is correct here! - return [name for name, _ in args] + raise NotImplementedError() # HERE STARTS A SOMEWHAT CURSED COMPATIBILITY LAYER: # since we don't want to rewrite all dialects currently, the new emulator needs to expose the same @@ -1691,6 +1705,16 @@ def must_parse_optional_successor_list(self) -> list[Span]: "]", "Successor list is enclosed in square brackets") return successors + def must_parse_op_args_list(self) -> list[Span]: + self.must_parse_characters( + "(", "Operation args list must be enclosed by brackets!") + args = self.must_parse_list_of(self.try_parse_value_id, + "Expected another bare-id here") + self.must_parse_characters( + ")", "Operation args list must be closed by a closing bracket") + # TODO: check if type is correct here! + return args + class XDSLParser(BaseParser): @@ -1830,6 +1854,16 @@ def must_parse_generic_attribute_args(self, name: StringLiteral): '>', 'Malformed attribute arguments, reached end of args list!') return attr(args) + def must_parse_op_args_list(self) -> list[Span]: + self.must_parse_characters( + "(", "Operation args list must be enclosed by brackets!") + args = self.must_parse_list_of(self.try_parse_value_id_and_type, + "Expected another bare-id here") + self.must_parse_characters( + ")", "Operation args list must be closed by a closing bracket") + # TODO: check if type is correct here! + return [name for name, _ in args] + # COMPAT layer so parser_ng is a drop-in replacement for parser: From b415f6ac490a0da943d47008694c226ae5c4b82d Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Tue, 17 Jan 2023 22:13:10 +0000 Subject: [PATCH 32/65] parser: add function_type as possible value for parse_attribute --- xdsl/parser.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xdsl/parser.py b/xdsl/parser.py index 82ae7435c5..63583cf5df 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -580,7 +580,7 @@ class ParserCommons: integer_literal = re.compile(r"[+-]?([0-9]+|0x[0-9A-Fa-f]+)") decimal_literal = re.compile(r"[+-]?([1-9][0-9]*)") - string_literal = re.compile(r'"([^\n\f\v\r"]|\\[nfvr"])+"') + string_literal = re.compile(r'"(\\[nfvr"\\]|[^\n\f\v\r"\\])*"') float_literal = re.compile(r"[-+]?[0-9]+\.[0-9]*([eE][-+]?[0-9]+)?") bare_id = re.compile(r"[A-Za-z_][\w$.]+") value_id = re.compile(r"%([0-9]+|([A-Za-z_$.-][\w$.-]*))") @@ -1261,6 +1261,8 @@ def try_parse_builtin_attr(self) -> Attribute | None: return self.try_parse_builtin_dense_attr() elif next_token.text == '{': return self.try_parse_builtin_dict_attr() + elif next_token.text == '(': + return self.try_parse_function_type() # order here is important! attrs = (self.try_parse_builtin_float_attr, From ba392077cca0d0241cba6257320efb0063e72e3a Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Tue, 17 Jan 2023 22:15:47 +0000 Subject: [PATCH 33/65] parser: fix parsing of symbol reference We used to cut off the first letter of a symbol reference, we stopped doing that now. --- xdsl/parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xdsl/parser.py b/xdsl/parser.py index 63583cf5df..552d59f61d 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -1320,7 +1320,7 @@ def try_parse_ref_attr(self) -> FlatSymbolRefAttr | None: if len(ref) > 1: self.raise_error("Nested refs are not supported yet!", ref[1]) - return FlatSymbolRefAttr.from_str(ref[0].text[1:]) + return FlatSymbolRefAttr.from_str(ref[0].text) def try_parse_builtin_int_attr(self) -> IntegerAttr | None: bool = self.try_parse_builtin_boolean_attr() From 1e9ccc817cf4cadb598e7fbd641a9f53d6f9640e Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Tue, 17 Jan 2023 22:36:57 +0000 Subject: [PATCH 34/65] printer: fix printing of strings containing special characters --- xdsl/printer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xdsl/printer.py b/xdsl/printer.py index 85e2ed7c67..9b275d8935 100644 --- a/xdsl/printer.py +++ b/xdsl/printer.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from dataclasses import dataclass, field from enum import Enum from frozenlist import FrozenList @@ -296,7 +297,7 @@ def print_paramattr_parameters( self.print(">") def print_string_literal(self, string: str): - self.print(f'"{string}"') + self.print(json.dumps(string)) def print_attribute(self, attribute: Attribute) -> None: if isinstance(attribute, UnitAttr): From e21d7a04bb012f86613b4f536843d08fb5b35fe8 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Tue, 17 Jan 2023 22:51:15 +0000 Subject: [PATCH 35/65] parser: fix some more failing tests --- xdsl/dialects/llvm.py | 2 +- xdsl/parser.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/xdsl/dialects/llvm.py b/xdsl/dialects/llvm.py index ec1fcad79d..76c0bf4372 100644 --- a/xdsl/dialects/llvm.py +++ b/xdsl/dialects/llvm.py @@ -41,7 +41,7 @@ def print_parameters(self, printer: Printer) -> None: def parse_parameters(parser: BaseParser) -> list[Attribute]: parser.must_parse_characters("<(", "LLVM Struct must start with `<(`") params = parser.must_parse_list_of( - parser.try_parse_attribute, + parser.try_parse_type, "Malformed LLVM struct, expected attribute definition here!") parser.must_parse_characters( ")>", "Unexpected input, expected end of LLVM struct!") diff --git a/xdsl/parser.py b/xdsl/parser.py index 552d59f61d..ef7628f02f 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -580,7 +580,7 @@ class ParserCommons: integer_literal = re.compile(r"[+-]?([0-9]+|0x[0-9A-Fa-f]+)") decimal_literal = re.compile(r"[+-]?([1-9][0-9]*)") - string_literal = re.compile(r'"(\\[nfvr"\\]|[^\n\f\v\r"\\])*"') + string_literal = re.compile(r'"(\\[nfvtr"\\]|[^\n\f\v\r"\\])*"') float_literal = re.compile(r"[-+]?[0-9]+\.[0-9]*([eE][-+]?[0-9]+)?") bare_id = re.compile(r"[A-Za-z_][\w$.]+") value_id = re.compile(r"%([0-9]+|([A-Za-z_$.-][\w$.-]*))") @@ -990,7 +990,7 @@ def must_parse_vector_attrs(self) -> AnyVectorType: self.raise_error( "Expected a type at the end of the vector parameters!") - return VectorType.from_type_and_list(type, shape) + return VectorType.from_element_type_and_shape(type, shape) def must_parse_tensor_or_memref_dims(self) -> list[int] | None: with self.tokenizer.configured(break_on=self.tokenizer.break_on + @@ -1867,6 +1867,9 @@ def must_parse_op_args_list(self) -> list[Span]: return [name for name, _ in args] + def try_parse_type(self) -> Attribute | None: + return self.try_parse_attribute() + # COMPAT layer so parser_ng is a drop-in replacement for parser: From 9365029483b3012350fe776515909d47c0880e86 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Wed, 18 Jan 2023 17:15:03 +0000 Subject: [PATCH 36/65] parser: add opaque attribute and fix lots of small issues --- xdsl/parser.py | 83 ++++++++++++++++++++++++++++++++------------------ 1 file changed, 53 insertions(+), 30 deletions(-) diff --git a/xdsl/parser.py b/xdsl/parser.py index ef7628f02f..1ffadc536a 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -19,7 +19,7 @@ FloatAttr, FunctionType, IndexType, IntegerType, Signedness, StringAttr, IntegerAttr, ArrayAttr, TensorType, UnrankedTensorType, VectorType, DefaultIntegerAttrType, FlatSymbolRefAttr, DenseIntOrFPElementsAttr, - UnregisteredOp) + UnregisteredOp, OpaqueAttr, NoneAttr) from xdsl.ir import (SSAValue, Block, Callable, Attribute, Operation, Region, BlockArgument, MLContext, ParametrizedAttribute, Data) @@ -35,9 +35,7 @@ def __init__(self, history: BacktrackingHistory | None = None): preamble = "" if history: - io = StringIO() - history.print_unroll(io) - preamble = io.getvalue() + '\n' + preamble = history.error.args[0] + '\n' if span is None: raise ValueError("Span can't be None!") super().__init__(preamble + span.print_with_context(msg)) @@ -595,6 +593,8 @@ class ParserCommons: "opaque", "tuple", "index", "dense" # TODO: add all the Float8E4M3FNType, Float8E5M2Type, and BFloat16Type ) + builtin_attr_names = ('dense', 'opaque', 'affine_map', 'array', + 'dense_resource', 'sparse') builtin_type = re.compile("(({}))".format(")|(".join(_builtin_type_names))) builtin_type_xdsl = re.compile("!(({}))".format( ")|(".join(_builtin_type_names))) @@ -923,7 +923,6 @@ def unimplemented() -> ParametrizedAttribute: "memref": unimplemented, "tensor": self.must_parse_tensor_attrs, "complex": self.must_parse_complex_attrs, - "opaque": unimplemented, "tuple": unimplemented, } @@ -932,15 +931,6 @@ def unimplemented() -> ParametrizedAttribute: res = builtin_parsers.get(name.text, unimplemented)() self.must_parse_characters(">", "Expected end of parameter list here!") - if name in ("dense", ): - self.must_parse_characters( - ":", - "Attribute {} must be followed by (`:` type)!".format(name)) - type = self.expect( - self.try_parse_type(), - "Attribute {} must be followed by (`:` type)!".format(name), - ) - return res def must_parse_dense_type_attrs(self): @@ -1257,33 +1247,66 @@ def try_parse_builtin_attr(self) -> Attribute | None: return self.try_parse_builtin_arr_attr() elif next_token.text == "@": return self.try_parse_ref_attr() - elif next_token.text == "dense": - return self.try_parse_builtin_dense_attr() elif next_token.text == '{': return self.try_parse_builtin_dict_attr() elif next_token.text == '(': return self.try_parse_function_type() - + elif next_token.text in ParserCommons.builtin_attr_names: + return self.try_parse_builtin_named_attr() # order here is important! attrs = (self.try_parse_builtin_float_attr, - self.try_parse_builtin_int_attr) + self.try_parse_builtin_int_attr, self.try_parse_builtin_type) for attr_parser in attrs: if (val := attr_parser()) is not None: return val - def try_parse_builtin_dense_attr(self) -> Attribute | None: - with self.tokenizer.backtracking("dense attribute"): - self.must_parse_characters( - "dense", "builtin dense attribute must start with `dense`") - err_msg = "Malformed dense attribute, format must be (`dense<` array-attr `>:` type)" - self.must_parse_characters("<", err_msg) - info = list(self.must_parse_builtin_dense_attr_args()) - self.must_parse_characters(">", err_msg) - self.must_parse_characters(":", err_msg) + def try_parse_builtin_named_attr(self) -> Attribute | None: + name = self.tokenizer.next_token(peek=True) + with self.tokenizer.backtracking("Builtin attribute {}".format( + name.text)): + self.tokenizer.consume_peeked(name) + parsers = { + 'dense': self.must_parse_builtin_dense_attr, + 'opaque': self.must_parse_builtin_opaque_attr, + } + + def not_implemented(): + raise NotImplementedError() + + return parsers.get(name.text, not_implemented)() + + def must_parse_builtin_dense_attr(self) -> Attribute | None: + err_msg = "Malformed dense attribute, format must be (`dense<` array-attr `>:` type)" + self.must_parse_characters("<", err_msg) + info = list(self.must_parse_builtin_dense_attr_args()) + self.must_parse_characters(">", err_msg) + self.must_parse_characters(":", err_msg) + type = self.expect(self.try_parse_type, + "Dense attribute must be typed!") + return DenseIntOrFPElementsAttr.from_list(type, info) + + def must_parse_builtin_opaque_attr(self): + self.must_parse_characters("<", + "Opaque attribute must be parametrized") + str_lit_list = self.must_parse_list_of(self.try_parse_string_literal, + 'Expected opaque attr here!') + + if len(str_lit_list) != 2: + self.raise_error('Opaque expects 2 string literal parameters!') + + self.must_parse_characters( + ">", "Unexpected parameters for opaque attr, expected `>`!") + + type = NoneAttr() + if self.tokenizer.starts_with(':'): + self.must_parse_characters(":", "opaque attribute must be typed!") type = self.expect(self.try_parse_type, - "Dense attribute must be typed!") - return DenseIntOrFPElementsAttr.from_list(type, info) + "opaque attribute must be typed!") + + return OpaqueAttr.from_strings(*(span.string_contents + for span in str_lit_list), + type=type) def must_parse_builtin_dense_attr_args(self) -> Iterable[int | float]: """ @@ -1866,10 +1889,10 @@ def must_parse_op_args_list(self) -> list[Span]: # TODO: check if type is correct here! return [name for name, _ in args] - def try_parse_type(self) -> Attribute | None: return self.try_parse_attribute() + # COMPAT layer so parser_ng is a drop-in replacement for parser: From 684c809c241975aea01c31f6d5db2a58bdfa8cae Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Wed, 18 Jan 2023 17:25:38 +0000 Subject: [PATCH 37/65] parser+tests: fix module-op at top level error --- tests/xdsl_opt/test_xdsl_opt.py | 6 +++--- xdsl/parser.py | 5 ++++- xdsl/xdsl_opt_main.py | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/xdsl_opt/test_xdsl_opt.py b/tests/xdsl_opt/test_xdsl_opt.py index 55dcf6896b..2a2de11ace 100644 --- a/tests/xdsl_opt/test_xdsl_opt.py +++ b/tests/xdsl_opt/test_xdsl_opt.py @@ -29,9 +29,9 @@ def test_empty_program(): @pytest.mark.parametrize("args, expected_error", [(['tests/xdsl_opt/not_module.xdsl'], - "Expected module or program as toplevel operation"), + "Expected ModuleOp at top level!"), (['tests/xdsl_opt/not_module.mlir'], - "Expected module or program as toplevel operation"), + "Expected ModuleOp at top level!"), (['tests/xdsl_opt/empty_program.wrong' ], "Unrecognized file extension 'wrong'")]) def test_error_on_run(args, expected_error): @@ -40,7 +40,7 @@ def test_error_on_run(args, expected_error): with pytest.raises(Exception) as e: opt.run() - assert e.value.args[0] == expected_error + assert expected_error in e.value.args[0] @pytest.mark.parametrize( diff --git a/xdsl/parser.py b/xdsl/parser.py index 1ffadc536a..24fed1ce9e 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -19,7 +19,7 @@ FloatAttr, FunctionType, IndexType, IntegerType, Signedness, StringAttr, IntegerAttr, ArrayAttr, TensorType, UnrankedTensorType, VectorType, DefaultIntegerAttrType, FlatSymbolRefAttr, DenseIntOrFPElementsAttr, - UnregisteredOp, OpaqueAttr, NoneAttr) + UnregisteredOp, OpaqueAttr, NoneAttr, ModuleOp) from xdsl.ir import (SSAValue, Block, Callable, Attribute, Operation, Region, BlockArgument, MLContext, ParametrizedAttribute, Data) @@ -655,6 +655,9 @@ def __init__(self, def begin_parse(self): op = self.try_parse_operation() + if not isinstance(op, ModuleOp): + self.tokenizer.pos = 0 + self.raise_error("Expected ModuleOp at top level!", self.tokenizer.next_token()) if not op: self.raise_error("Could not parse entire input!") return op diff --git a/xdsl/xdsl_opt_main.py b/xdsl/xdsl_opt_main.py index dace34b7d6..fccf99517b 100644 --- a/xdsl/xdsl_opt_main.py +++ b/xdsl/xdsl_opt_main.py @@ -299,7 +299,7 @@ def parse_input(self) -> ModuleOp: def apply_passes(self, prog: ModuleOp): """Apply passes in order.""" - assert isinstance(prog, ModuleOp) + assert isinstance(prog, ModuleOp), "Expected top-level module!" if not self.args.disable_verify: prog.verify() for pass_name, p in self.pipeline: From c49fc2905ea7b22419f64eae9d5f597b5b390d8f Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Wed, 18 Jan 2023 17:43:48 +0000 Subject: [PATCH 38/65] parser: added memref support --- tests/xdsl_opt/test_xdsl_opt.py | 13 ++++++------- xdsl/parser.py | 20 +++++++++++--------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/tests/xdsl_opt/test_xdsl_opt.py b/tests/xdsl_opt/test_xdsl_opt.py index 2a2de11ace..4f71ccc983 100644 --- a/tests/xdsl_opt/test_xdsl_opt.py +++ b/tests/xdsl_opt/test_xdsl_opt.py @@ -27,13 +27,12 @@ def test_empty_program(): assert f.getvalue().strip() == expected.strip() -@pytest.mark.parametrize("args, expected_error", - [(['tests/xdsl_opt/not_module.xdsl'], - "Expected ModuleOp at top level!"), - (['tests/xdsl_opt/not_module.mlir'], - "Expected ModuleOp at top level!"), - (['tests/xdsl_opt/empty_program.wrong' - ], "Unrecognized file extension 'wrong'")]) +@pytest.mark.parametrize( + "args, expected_error", + [(['tests/xdsl_opt/not_module.xdsl'], "Expected ModuleOp at top level!"), + (['tests/xdsl_opt/not_module.mlir'], "Expected ModuleOp at top level!"), + (['tests/xdsl_opt/empty_program.wrong' + ], "Unrecognized file extension 'wrong'")]) def test_error_on_run(args, expected_error): opt = xDSLOptMain(args=args) diff --git a/xdsl/parser.py b/xdsl/parser.py index 24fed1ce9e..e09d2fa976 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -14,6 +14,7 @@ from io import StringIO from typing import TypeVar, Iterable +from xdsl.dialects.memref import MemRefType, UnrankedMemrefType from xdsl.dialects.builtin import ( AnyTensorType, AnyVectorType, Float16Type, Float32Type, Float64Type, FloatAttr, FunctionType, IndexType, IntegerType, Signedness, StringAttr, @@ -657,7 +658,8 @@ def begin_parse(self): op = self.try_parse_operation() if not isinstance(op, ModuleOp): self.tokenizer.pos = 0 - self.raise_error("Expected ModuleOp at top level!", self.tokenizer.next_token()) + self.raise_error("Expected ModuleOp at top level!", + self.tokenizer.next_token()) if not op: self.raise_error("Could not parse entire input!") return op @@ -923,7 +925,7 @@ def unimplemented() -> ParametrizedAttribute: builtin_parsers: dict[str, Callable[[], ParametrizedAttribute]] = { "vector": self.must_parse_vector_attrs, - "memref": unimplemented, + "memref": self.must_parse_memref_attrs, "tensor": self.must_parse_tensor_attrs, "complex": self.must_parse_complex_attrs, "tuple": unimplemented, @@ -936,16 +938,16 @@ def unimplemented() -> ParametrizedAttribute: return res - def must_parse_dense_type_attrs(self): - arr = self.expect( - self.try_parse_builtin_arr_attr(), - "dense attribute must be parametrized by Array", - ) - DenseIntOrFPElementsAttr.from_list(arr) - def must_parse_complex_attrs(self): self.raise_error("ComplexType is unimplemented!") + def must_parse_memref_attrs(self) -> MemRefType | UnrankedMemrefType: + dims = self.must_parse_tensor_or_memref_dims() + type = self.try_parse_type() + if dims is None: + return UnrankedMemrefType.from_type(type) + return MemRefType.from_element_type_and_shape(type, dims) + def try_parse_numerical_dims(self, accept_closing_bracket: bool = False, lower_bound: int = 1) -> Iterable[int]: From 0420b5621beec7b1996fb096f07a9d07b5ea7a18 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Wed, 18 Jan 2023 18:02:24 +0000 Subject: [PATCH 39/65] parser: change type annotations to make python3.11 happy --- xdsl/dialects/builtin.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index c241be7770..4cbe7d11f8 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -692,14 +692,12 @@ class UnregisteredOp(Operation): res: Annotated[VarOpResult, AnyAttr()] regs = VarRegionDef() - __registered_unregistered_ops: dict[str, type['UnregisteredOp']] = dict() - @property def op_name(self) -> StringAttr: return self.op_name__ # type: ignore @classmethod - def with_name(cls, name: str, ctx: MLContext) -> type[UnregisteredOp]: + def with_name(cls, name: str, ctx: MLContext) -> type[Operation]: if name in ctx.registered_unregistered_ops: return ctx.registered_unregistered_ops[name] # type: ignore From b58143a6dd06d9f7c208f0d2302ec1b962665bbe Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Wed, 18 Jan 2023 18:12:18 +0000 Subject: [PATCH 40/65] parser: fixed special attribute-entry parsing for UnitAttr --- xdsl/parser.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/xdsl/parser.py b/xdsl/parser.py index e09d2fa976..80f5106c8d 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -20,7 +20,7 @@ FloatAttr, FunctionType, IndexType, IntegerType, Signedness, StringAttr, IntegerAttr, ArrayAttr, TensorType, UnrankedTensorType, VectorType, DefaultIntegerAttrType, FlatSymbolRefAttr, DenseIntOrFPElementsAttr, - UnregisteredOp, OpaqueAttr, NoneAttr, ModuleOp) + UnregisteredOp, OpaqueAttr, NoneAttr, ModuleOp, UnitAttr) from xdsl.ir import (SSAValue, Block, Callable, Attribute, Operation, Region, BlockArgument, MLContext, ParametrizedAttribute, Data) @@ -1213,6 +1213,9 @@ def must_parse_attribute_entry(self) -> tuple[Span, Attribute]: "Expected bare-id or string-literal here as part of attribute entry!" ) + if not self.tokenizer.starts_with('='): + return name, UnitAttr() + self.must_parse_characters( "=", "Attribute entries must be of format name `=` attribute!") From 58339bc285a8e952417accf414eb06c942d0b2d5 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Fri, 20 Jan 2023 11:01:54 +0000 Subject: [PATCH 41/65] xdsl: fix how the parser is used in tests --- tests/test_mlir_printer.py | 7 +------ tests/test_printer.py | 2 +- xdsl/parser.py | 13 ++++++++----- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/tests/test_mlir_printer.py b/tests/test_mlir_printer.py index 4e046f0892..0903835e89 100644 --- a/tests/test_mlir_printer.py +++ b/tests/test_mlir_printer.py @@ -90,12 +90,7 @@ def print_as_mlir_and_compare(test_prog: str, expected: str): ctx.register_attr(ParamAttrWithCustomFormat) parser = XDSLParser(ctx, test_prog) - try: - module = parser.must_parse_operation() - except ParseError as err: - io = StringIO() - err.print_with_history(file=io) - raise ParseError(err.span, io.getvalue(), None) + module = parser.begin_parse() res = StringIO() printer = Printer(target=Printer.Target.MLIR, stream=res) diff --git a/tests/test_printer.py b/tests/test_printer.py index aa9bfd20bc..06c9406f01 100644 --- a/tests/test_printer.py +++ b/tests/test_printer.py @@ -185,7 +185,7 @@ def test_two_different_op_messages(): ctx.register_dialect(Builtin) parser = XDSLParser(ctx, prog, '') - module = parser.parse_op() + module = parser.begin_parse() file = StringIO("") diagnostic = Diagnostic() diff --git a/xdsl/parser.py b/xdsl/parser.py index 80f5106c8d..d9a89f1e89 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -654,15 +654,18 @@ def __init__(self, self.forward_block_references = dict() self.allow_unregistered_ops = allow_unregistered_ops - def begin_parse(self): + def begin_parse(self) -> ModuleOp: op = self.try_parse_operation() - if not isinstance(op, ModuleOp): + + if op is None: + self.raise_error("Could not parse entire input!") + + if isinstance(op, ModuleOp): + return op + else: self.tokenizer.pos = 0 self.raise_error("Expected ModuleOp at top level!", self.tokenizer.next_token()) - if not op: - self.raise_error("Could not parse entire input!") - return op def get_block_from_name(self, block_name: Span): """ From d0b5a5b30a27ed67e0780f9c678d003096748dd1 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Fri, 20 Jan 2023 11:04:28 +0000 Subject: [PATCH 42/65] xdsl: fix typo in Block.delcared_at --- tests/test_mlir_printer.py | 2 +- xdsl/ir.py | 2 +- xdsl/parser.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_mlir_printer.py b/tests/test_mlir_printer.py index 0903835e89..2a83d58816 100644 --- a/tests/test_mlir_printer.py +++ b/tests/test_mlir_printer.py @@ -5,7 +5,7 @@ from xdsl.ir import Attribute, Data, MLContext, MLIRType, Operation, ParametrizedAttribute from xdsl.irdl import (AnyAttr, ParameterDef, RegionDef, irdl_attr_definition, irdl_op_definition, VarOperand, VarOpResult) -from xdsl.parser import ParseError, BaseParser, XDSLParser +from xdsl.parser import BaseParser, XDSLParser from xdsl.printer import Printer diff --git a/xdsl/ir.py b/xdsl/ir.py index fdb5cf8cf1..16024666c2 100644 --- a/xdsl/ir.py +++ b/xdsl/ir.py @@ -635,7 +635,7 @@ def irdl_definition(cls) -> OpDef: class Block(IRNode): """A sequence of operations""" - delcared_at: 'Span' | None = None + declared_at: 'Span' | None = None _args: FrozenList[BlockArgument] = field(default_factory=FrozenList, init=False) diff --git a/xdsl/parser.py b/xdsl/parser.py index d9a89f1e89..b0b1c9009a 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -687,14 +687,14 @@ def must_parse_block(self) -> Block: elif self.forward_block_references.pop(block_id.text, None) is not None: block = self.blocks[block_id.text] - block.delcared_at = block_id + block.declared_at = block_id else: if block_id.text in self.blocks: raise MultipleSpansParseError( block_id, "Re-declaration of block {}".format(block_id.text), "Originally declared here:", - [(self.blocks[block_id.text].delcared_at, None)], + [(self.blocks[block_id.text].declared_at, None)], self.tokenizer.history, ) block = Block(block_id) From 41049fe394d3a822ae36cff2a8233dde7cbf1075 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Fri, 20 Jan 2023 15:04:30 +0000 Subject: [PATCH 43/65] xdsl: fixed errorneous type hints on DictionaryAttr --- xdsl/dialects/builtin.py | 11 +++++------ xdsl/ir.py | 11 +++++++---- xdsl/parser.py | 4 ++++ 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index 4cbe7d11f8..1f9cb3dd5f 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -345,7 +345,7 @@ def from_list(data: List[_ArrayAttrT]) -> ArrayAttr[_ArrayAttrT]: @irdl_attr_definition -class DictionaryAttr(GenericData[dict[StringAttr, Attribute]]): +class DictionaryAttr(GenericData[dict[str, Attribute]]): name = "dictionary" @staticmethod @@ -355,8 +355,7 @@ def parse_parameter(parser: BaseParser) -> dict[str, Attribute]: return MLIRParser.must_parse_optional_attr_dict(parser) @staticmethod - def print_parameter(data: dict[StringAttr, Attribute], - printer: Printer) -> None: + def print_parameter(data: dict[str, Attribute], printer: Printer) -> None: printer.print_string("{") printer.print_dictionary(data, printer.print_string_literal, printer.print_attribute) @@ -385,7 +384,7 @@ def verify(self) -> None: @staticmethod @builder def from_dict(data: dict[str | StringAttr, Attribute]) -> DictionaryAttr: - to_add_data = {} + to_add_data: dict[str, Attribute] = {} for k, v in data.items(): # try to coerce keys into StringAttr if isinstance(k, StringAttr): @@ -393,8 +392,8 @@ def from_dict(data: dict[str | StringAttr, Attribute]) -> DictionaryAttr: # if coercion fails, raise KeyError! if not isinstance(k, str): raise TypeError( - f"Attribute DictionaryAttr expects keys to" - f" be of type str or str, but {type(k)} provided") + f"DictionaryAttr.from_dict expects keys to" + f" be of type str or StringAttr, but {type(k)} provided") to_add_data[k] = v return DictionaryAttr(to_add_data) diff --git a/xdsl/ir.py b/xdsl/ir.py index 16024666c2..6fcfc8d3c8 100644 --- a/xdsl/ir.py +++ b/xdsl/ir.py @@ -1,11 +1,12 @@ from __future__ import annotations +import re from abc import ABC, abstractmethod from dataclasses import dataclass, field from frozenlist import FrozenList from io import StringIO from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, Protocol, - Sequence, TypeVar, cast, Iterator, Union) + Sequence, TypeVar, cast, Iterator, Union, ClassVar) import sys # Used for cyclic dependencies in type hints @@ -122,15 +123,17 @@ class SSAValue(ABC): _name: str | None = field(init=False, default=None) + _name_regex: ClassVar[re.Pattern] = re.compile( + r'[A-Za-z0-9._$-]*[A-Za-z._$-]') + @property def name(self) -> str | None: return self._name @name.setter def name(self, name: str): - if name[-1].isnumeric(): - return - self._name = name + if self._name_regex.fullmatch(name): + self._name = name @staticmethod def get(arg: SSAValue | Operation) -> SSAValue: diff --git a/xdsl/parser.py b/xdsl/parser.py index b0b1c9009a..3352ed9fa7 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -584,6 +584,10 @@ class ParserCommons: bare_id = re.compile(r"[A-Za-z_][\w$.]+") value_id = re.compile(r"%([0-9]+|([A-Za-z_$.-][\w$.-]*))") suffix_id = re.compile(r"([0-9]+|([A-Za-z_$.-][\w$.-]*))") + """ + suffix-id ::= (digit+ | ((letter|id-punct) (letter|id-punct|digit)*)) + id-punct ::= [$._-] + """ block_id = re.compile(r"\^([0-9]+|([A-Za-z_$.-][\w$.-]*))") type_alias = re.compile(r"![A-Za-z_][\w$.]+") attribute_alias = re.compile(r"#[A-Za-z_][\w$.]+") From 7fda7e07bedb8e484ae418232a565e7580221a8a Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Fri, 20 Jan 2023 15:22:39 +0000 Subject: [PATCH 44/65] tests: fix tests that don't wrap their input in a builtin.module --- tests/test_mlir_printer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_mlir_printer.py b/tests/test_mlir_printer.py index 2a83d58816..e9e60ca454 100644 --- a/tests/test_mlir_printer.py +++ b/tests/test_mlir_printer.py @@ -90,7 +90,7 @@ def print_as_mlir_and_compare(test_prog: str, expected: str): ctx.register_attr(ParamAttrWithCustomFormat) parser = XDSLParser(ctx, test_prog) - module = parser.begin_parse() + module = parser.must_parse_operation() res = StringIO() printer = Printer(target=Printer.Target.MLIR, stream=res) From 70ecada28efce1bfb00d6bae5811f57c98c53cfd Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Fri, 20 Jan 2023 15:26:21 +0000 Subject: [PATCH 45/65] parser: fix a typo Co-authored-by: Fehr Mathieu --- xdsl/parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xdsl/parser.py b/xdsl/parser.py index 3352ed9fa7..15d39cb273 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -347,7 +347,7 @@ def backtracking(self, region_name: str | None = None): yield # clear error history when something doesn't fail # this is because we are only interested in the last "cascade" of failures. - # if a backtracking() completes without failre, something has been parsed (we assume) + # if a backtracking() completes without failure, something has been parsed (we assume) if self.pos > starting_position and self.history is not None: self.history = None except Exception as ex: From 8e5ce881e551493db2294a6b88b7599d785e7eaa Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Fri, 20 Jan 2023 15:52:40 +0000 Subject: [PATCH 46/65] parser: add docstring to tokenizer --- xdsl/parser.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/xdsl/parser.py b/xdsl/parser.py index 15d39cb273..7c5583a9fb 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -288,6 +288,32 @@ def at(self, i: int): @dataclass class Tokenizer: + """ + This class is used to tokenize an Input. + + It provides an interface for backtracking, so you can use: + + with tokenizer.backtracking(): + # try stuff + raise BacktrackingAbort(...) + + and not worry about manually resetting the input position. Backtracking will also + record errors that happen during backtracking to provide a richer error reporting + experience. + + It also provides the following methods to inspect the input: + + - next_token(peek) is used to get the next token + (which just breaks the input as per the rules defined in break_on) + peek=True doesn't advance the position in the file. + - next_token_of_pattern(pattern, peek) can be used to get a next token if it + conforms to a specific pattern. If a literal string is given, it'll check + if the next characters match. If a regex is given, it will check + the regex. + - starts_with(pattern) checks if the input starts with a literal string or + regex pattern + """ + input: Input pos: int = field(init=False, default=0) @@ -436,18 +462,16 @@ def history_entry_from_exception(self, ex: Exception, region: str, pos, ) - def next_token(self, start: int | None = None, peek: bool = False) -> Span: + def next_token(self, peek: bool = False) -> Span: """ Return a Span of the next token, according to the self.break_on rules. Can be modified using: - - - start: don't start at the current tokenizer position, instead start here (useful for skipping comments, etc) - peek: don't advance the position, only "peek" at the input This will skip over line comments. Meaning it will skip the entire line if it encounters '//' """ - i = self.next_pos(start) + i = self.next_pos() # construct the span: span = Span(i, self._find_token_end(i), self.input) # advance pointer if not peeking From 696861905b2fe1e45b5d7f8a15282812dd8a95f2 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Fri, 20 Jan 2023 17:22:01 +0000 Subject: [PATCH 47/65] parser: remove BacktrackingAbort --- xdsl/parser.py | 37 +++---------------------------------- 1 file changed, 3 insertions(+), 34 deletions(-) diff --git a/xdsl/parser.py b/xdsl/parser.py index 7c5583a9fb..eb48d8fe17 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -122,19 +122,6 @@ def __hash__(self): return id(self) -class BacktrackingAbort(Exception): - reason: str | None - - def __init__(self, reason: str | None = None): - super().__init__( - "This message should never escape the parser, it's intended to signal a failed parsing " - "attempt\n " - "It should never be used outside of a tokenizer.backtracking() block!\n" - "The reason for this abort was {}".format( - 'not specified' if reason is None else reason)) - self.reason = reason - - @dataclass(frozen=True) class Span: """ @@ -295,7 +282,7 @@ class Tokenizer: with tokenizer.backtracking(): # try stuff - raise BacktrackingAbort(...) + raise ParseError(...) and not worry about manually resetting the input position. Backtracking will also record errors that happen during backtracking to provide a richer error reporting @@ -361,7 +348,6 @@ def backtracking(self, region_name: str | None = None): The backtracker accepts the following exceptions: - ParseError: signifies that the region could not be parsed because of (unexpected) syntax errors - - BacktrackingAbort: signifies that backtracking was aborted, not necessarily indicating a syntax error - AssertionError: this error should probably be phased out in favour of the two above - EOFError: signals that EOF was reached unexpectedly @@ -379,11 +365,6 @@ def backtracking(self, region_name: str | None = None): except Exception as ex: how_far_we_got = self.pos - # AssertionErrors act upon the consumed token, this means we only go to the start of the token - if isinstance(ex, BacktrackingAbort): - # TODO: skip space as well - how_far_we_got -= self.last_token.len - # if we have no error history, start recording! if not self.history: self.history = self.history_entry_from_exception( @@ -431,18 +412,6 @@ def history_entry_from_exception(self, ex: Exception, region: str, region, pos, ) - elif isinstance(ex, BacktrackingAbort): - return BacktrackingHistory( - ParseError( - self.next_token(peek=True), - "Backtracking aborted: {}".format(ex.reason - or "unknown reason"), - self.history, - ), - self.history, - region, - pos, - ) elif isinstance(ex, EOFError): return BacktrackingHistory( ParseError(self.last_token, "Encountered EOF", self.history), @@ -1675,7 +1644,7 @@ def try_parse_builtin_type(self) -> Attribute | None: name = self.tokenizer.next_token_of_pattern( ParserCommons.builtin_type) if name is None: - raise BacktrackingAbort("Expected builtin name!") + raise self.raise_error("Expected builtin name!") return self.must_parse_builtin_type_with_name(name) @@ -1790,7 +1759,7 @@ def try_parse_builtin_type(self) -> Attribute | None: name = self.tokenizer.next_token_of_pattern( ParserCommons.builtin_type_xdsl) if name is None: - raise BacktrackingAbort("Expected builtin name!") + self.raise_error("Expected builtin name!") # xdsl builtin types have a '!' prefix, we strip that out here name = Span(start=name.start + 1, end=name.end, input=name.input) From ebbc7bca330ff301681b306f21d3d2b6f3077f17 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Fri, 20 Jan 2023 17:32:27 +0000 Subject: [PATCH 48/65] parser: add docstring for BacktrackingHistory --- xdsl/parser.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/xdsl/parser.py b/xdsl/parser.py index eb48d8fe17..6025095cc2 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -85,6 +85,24 @@ def print_pretty(self, file=sys.stderr): @dataclass class BacktrackingHistory: + """ + This class holds on to past errors encountered during parsing. + + Given the following error message: + :2:12 + %0 : !invalid = arith.constant() ["value" = 1 : !i32] + ^^^^^^^ + 'invalid' is not a known attribute + + :2:7 + %0 : !invalid = arith.constant() ["value" = 1 : !i32] + ^ + Expected type of value-id here! + + The BacktrackingHistory will contain the outermost error (expected type of value-id here) + It's parent will be the next error message (not a known attribute). + Some errors happen in named regions (e.g. "parsing of operation") + """ error: ParseError parent: BacktrackingHistory | None region_name: str | None From 56490fd7f3c8bdcfe4754f9a4aecb181e58b1985 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Fri, 20 Jan 2023 17:35:49 +0000 Subject: [PATCH 49/65] parser: renamed begin_parse to parse_module --- tests/test_printer.py | 2 +- xdsl/parser.py | 2 +- xdsl/xdsl_opt_main.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_printer.py b/tests/test_printer.py index 06c9406f01..6e4b1ef5b9 100644 --- a/tests/test_printer.py +++ b/tests/test_printer.py @@ -185,7 +185,7 @@ def test_two_different_op_messages(): ctx.register_dialect(Builtin) parser = XDSLParser(ctx, prog, '') - module = parser.begin_parse() + module = parser.parse_module() file = StringIO("") diagnostic = Diagnostic() diff --git a/xdsl/parser.py b/xdsl/parser.py index 6025095cc2..b35b918853 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -669,7 +669,7 @@ def __init__(self, self.forward_block_references = dict() self.allow_unregistered_ops = allow_unregistered_ops - def begin_parse(self) -> ModuleOp: + def parse_module(self) -> ModuleOp: op = self.try_parse_operation() if op is None: diff --git a/xdsl/xdsl_opt_main.py b/xdsl/xdsl_opt_main.py index fccf99517b..2fac8e3559 100644 --- a/xdsl/xdsl_opt_main.py +++ b/xdsl/xdsl_opt_main.py @@ -295,7 +295,7 @@ def parse_input(self) -> ModuleOp: parser = self.available_frontends[file_extension]( self.ctx, f.read(), self.args.input_file or 'stdin', self.args.allow_unregistered_ops) - return parser.begin_parse() + return parser.parse_module() def apply_passes(self, prog: ModuleOp): """Apply passes in order.""" From 932e40309da729bb7274d0833915e5d801d430df Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Fri, 20 Jan 2023 17:36:02 +0000 Subject: [PATCH 50/65] parser: added return type to get_block_from_name --- xdsl/parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xdsl/parser.py b/xdsl/parser.py index b35b918853..8fa4d0fe32 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -682,7 +682,7 @@ def parse_module(self) -> ModuleOp: self.raise_error("Expected ModuleOp at top level!", self.tokenizer.next_token()) - def get_block_from_name(self, block_name: Span): + def get_block_from_name(self, block_name: Span) -> Block: """ This function takes a span containing a block id (like `^42`) and returns a block. From 4e985cd6dc68c4b9e41373cc4ad1733d86b744f0 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Mon, 23 Jan 2023 12:45:40 +0000 Subject: [PATCH 51/65] parser: fix typos and alignement issues --- xdsl/parser.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/xdsl/parser.py b/xdsl/parser.py index 8fa4d0fe32..9419cc2056 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -1518,7 +1518,6 @@ def must_parse_region_list(self) -> list[Region]: regions.append(self.must_parse_region()) return regions - # COMMON xDSL/MLIR code: def must_parse_builtin_type_with_name(self, name: Span): if name.text == "index": return IndexType() @@ -1571,8 +1570,8 @@ def must_parse_op_args_list(self) -> list[Span]: raise NotImplementedError() # HERE STARTS A SOMEWHAT CURSED COMPATIBILITY LAYER: - # since we don't want to rewrite all dialects currently, the new emulator needs to expose the same - # interface to the dialect definitions. Here we implement that interface. + # since we don't want to rewrite all dialects currently, the new parser needs to expose the same + # interface to the dialect definitions (to some extent). Here we implement that interface. _OperationType = TypeVar("_OperationType", bound=Operation) @@ -1646,8 +1645,10 @@ def try_parse_builtin_dict_attr(self): attr_def = self.ctx.get_optional_attr('dictionary') if attr_def is None: self.raise_error( - "An attribute named `dictionary` must be available in the context in order to parse dictionary attributes! Please make sure the builtin dialect is available, or provide your own replacement!" - ) + "An attribute named `dictionary` must be available in the " + "context in order to parse dictionary attributes! Please make " + "sure the builtin dialect is available, or provide your own " + "replacement!") param = attr_def.parse_parameter(self) return attr_def(param) From ea9a75bce3ec643da591966601a78f17029cb522 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Mon, 23 Jan 2023 12:52:13 +0000 Subject: [PATCH 52/65] parser: removed get_nth_line_bounds - unused function --- xdsl/parser.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/xdsl/parser.py b/xdsl/parser.py index 9419cc2056..1b11049509 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -248,15 +248,6 @@ def len(self): def __len__(self): return self.len - def get_nth_line_bounds(self, n: int): - start = 0 - for i in range(n): - next_start = self.content.find('\n', start) - if next_start == -1: - return None - start = next_start + 1 - return start, self.content.find('\n', start) - def get_lines_containing(self, span: Span) -> tuple[list[str], int, int] | None: # A pointer to the start of the first line From db7c4cb7794f25785cc3b02b1cd64ea338969c33 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Mon, 23 Jan 2023 13:08:21 +0000 Subject: [PATCH 53/65] parser: fix minor nitpicks in tests --- tests/test_printer.py | 31 +++++++++++++------------------ xdsl/parser.py | 5 +++++ 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/test_printer.py b/tests/test_printer.py index 6e4b1ef5b9..0b8735f71a 100644 --- a/tests/test_printer.py +++ b/tests/test_printer.py @@ -184,7 +184,7 @@ def test_two_different_op_messages(): ctx.register_dialect(Arith) ctx.register_dialect(Builtin) - parser = XDSLParser(ctx, prog, '') + parser = XDSLParser(ctx, prog) module = parser.parse_module() file = StringIO("") @@ -220,7 +220,7 @@ def test_two_same_op_messages(): ctx.register_dialect(Arith) ctx.register_dialect(Builtin) - parser = XDSLParser(ctx, prog, '') + parser = XDSLParser(ctx, prog) module = parser.parse_op() file = StringIO("") @@ -254,7 +254,7 @@ def test_op_message_with_region(): ctx.register_dialect(Arith) ctx.register_dialect(Builtin) - parser = XDSLParser(ctx, prog, '') + parser = XDSLParser(ctx, prog) module = parser.parse_op() file = StringIO("") @@ -290,7 +290,7 @@ def test_op_message_with_region_and_overflow(): ctx.register_dialect(Arith) ctx.register_dialect(Builtin) - parser = XDSLParser(ctx, prog, '') + parser = XDSLParser(ctx, prog) module = parser.parse_op() file = StringIO("") @@ -316,7 +316,7 @@ def test_diagnostic(): ctx.register_dialect(Arith) ctx.register_dialect(Builtin) - parser = XDSLParser(ctx, prog, '') + parser = XDSLParser(ctx, prog) module = parser.parse_op() diag = Diagnostic() @@ -356,7 +356,7 @@ def test_print_custom_name(): ctx.register_dialect(Arith) ctx.register_dialect(Builtin) - parser = XDSLParser(ctx, prog, '') + parser = XDSLParser(ctx, prog) module = parser.parse_op() file = StringIO("") @@ -384,11 +384,6 @@ class PlusCustomFormatOp(Operation): def parse(cls, result_types: List[Attribute], parser: BaseParser) -> PlusCustomFormatOp: - def get_ssa_val(name: Span) -> SSAValue: - if name.text not in parser.ssaValues: - parser.raise_error('SSA Value used before assignment', name) - return parser.ssaValues[name.text] - lhs = parser.expect(parser.try_parse_value_id, 'Expected SSA Value name here!') parser.must_parse_characters( @@ -397,7 +392,7 @@ def get_ssa_val(name: Span) -> SSAValue: 'Expected SSA Value name here!') return PlusCustomFormatOp.create( - operands=[get_ssa_val(name) for name in (lhs, rhs)], + operands=[parser.get_ssa_val(lhs), parser.get_ssa_val(rhs)], result_types=result_types) def print(self, printer: Printer): @@ -426,7 +421,7 @@ def test_generic_format(): ctx.register_dialect(Builtin) ctx.register_op(PlusCustomFormatOp) - parser = XDSLParser(ctx, prog, '') + parser = XDSLParser(ctx, prog) module = parser.parse_op() file = StringIO("") @@ -457,7 +452,7 @@ def test_custom_format(): ctx.register_dialect(Builtin) ctx.register_op(PlusCustomFormatOp) - parser = XDSLParser(ctx, prog, '') + parser = XDSLParser(ctx, prog) module = parser.parse_op() file = StringIO("") @@ -488,7 +483,7 @@ def test_custom_format_II(): ctx.register_dialect(Builtin) ctx.register_op(PlusCustomFormatOp) - parser = XDSLParser(ctx, prog, '') + parser = XDSLParser(ctx, prog) module = parser.parse_op() file = StringIO("") @@ -546,7 +541,7 @@ def test_custom_format_attr(): ctx.register_op(AnyOp) ctx.register_attr(CustomFormatAttr) - parser = XDSLParser(ctx, prog, '') + parser = XDSLParser(ctx, prog) module = parser.parse_op() file = StringIO("") @@ -575,7 +570,7 @@ def test_parse_generic_format_attr(): ctx.register_op(AnyOp) ctx.register_attr(CustomFormatAttr) - parser = XDSLParser(ctx, prog, '') + parser = XDSLParser(ctx, prog) module = parser.parse_op() file = StringIO("") @@ -604,7 +599,7 @@ def test_parse_generic_format_attr_II(): ctx.register_op(AnyOp) ctx.register_attr(CustomFormatAttr) - parser = XDSLParser(ctx, prog, '') + parser = XDSLParser(ctx, prog) module = parser.parse_op() file = StringIO("") diff --git a/xdsl/parser.py b/xdsl/parser.py index 1b11049509..7f200d7bc1 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -673,6 +673,11 @@ def parse_module(self) -> ModuleOp: self.raise_error("Expected ModuleOp at top level!", self.tokenizer.next_token()) + def get_ssa_val(self, name: Span) -> SSAValue: + if name.text not in self.ssaValues: + self.raise_error('SSA Value used before assignment', name) + return self.ssaValues[name.text] + def get_block_from_name(self, block_name: Span) -> Block: """ This function takes a span containing a block id (like `^42`) and returns a block. From 4e27ede088c756747438f81aa1184f7044ac4dc6 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Mon, 23 Jan 2023 13:38:45 +0000 Subject: [PATCH 54/65] xdsl: revert back to a callable interface for xdsl-opt frontends --- tests/test_printer.py | 3 ++- xdsl/xdsl_opt_main.py | 29 +++++++++++++++++++++-------- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/tests/test_printer.py b/tests/test_printer.py index 0b8735f71a..911e62f5c7 100644 --- a/tests/test_printer.py +++ b/tests/test_printer.py @@ -392,7 +392,8 @@ def parse(cls, result_types: List[Attribute], 'Expected SSA Value name here!') return PlusCustomFormatOp.create( - operands=[parser.get_ssa_val(lhs), parser.get_ssa_val(rhs)], + operands=[parser.get_ssa_val(lhs), + parser.get_ssa_val(rhs)], result_types=result_types) def print(self, printer: Printer): diff --git a/xdsl/xdsl_opt_main.py b/xdsl/xdsl_opt_main.py index 2fac8e3559..25ca9c8421 100644 --- a/xdsl/xdsl_opt_main.py +++ b/xdsl/xdsl_opt_main.py @@ -3,9 +3,10 @@ import os from io import IOBase, StringIO import coverage +from typing.io import IO from xdsl.ir import MLContext -from xdsl.parser import Parser, XDSLParser, MLIRParser, BaseParser +from xdsl.parser import XDSLParser, MLIRParser from xdsl.printer import Printer from xdsl.dialects.func import Func from xdsl.dialects.scf import Scf @@ -33,7 +34,7 @@ class xDSLOptMain: attributes. """ - available_frontends: Dict[str, type[BaseParser]] + available_frontends: Dict[str, Callable[[IOBase], ModuleOp]] """ A mapping from file extension to a frontend that can handle this file type. @@ -215,8 +216,20 @@ def register_all_frontends(self): Add other/additional frontends by overloading this function. """ - self.available_frontends['xdsl'] = XDSLParser - self.available_frontends['mlir'] = MLIRParser + def parse_xdsl(io: IOBase): + return XDSLParser( + self.ctx, io.read(), self.get_input_name(), + self.args.allow_unregistered_ops + ).parse_module() + + def parse_mlir(io: IOBase): + return MLIRParser( + self.ctx, io.read(), self.get_input_name(), + self.args.allow_unregistered_ops + ).parse_module() + + self.available_frontends['xdsl'] = parse_xdsl + self.available_frontends['mlir'] = parse_mlir def register_all_passes(self): """ @@ -292,10 +305,7 @@ def parse_input(self) -> ModuleOp: if file_extension not in self.available_frontends: raise Exception(f"Unrecognized file extension '{file_extension}'") - parser = self.available_frontends[file_extension]( - self.ctx, f.read(), self.args.input_file or 'stdin', - self.args.allow_unregistered_ops) - return parser.parse_module() + return self.available_frontends[file_extension](f) def apply_passes(self, prog: ModuleOp): """Apply passes in order.""" @@ -329,3 +339,6 @@ def print_to_output_stream(self, contents: str): else: output_stream = open(self.args.output_file, 'w') output_stream.write(contents) + + def get_input_name(self): + return self.args.input_file or 'stdin' From 7a40bf251b3d09408e59ff4a3f8fa7e624dedcad Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Mon, 23 Jan 2023 13:40:51 +0000 Subject: [PATCH 55/65] tests: removed a bunch of unneeded arguments for the parser --- tests/test_parser.py | 2 +- tests/test_printer.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_parser.py b/tests/test_parser.py index b49d5ff175..abad6663c2 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -13,7 +13,7 @@ ("1, 1, 0", [1, 1, 0])]) def test_int_list_parser(input: str, expected: list[int]): ctx = MLContext() - parser = XDSLParser(ctx, input, '') + parser = XDSLParser(ctx, input) int_list = parser.must_parse_list_of(parser.try_parse_integer_literal, '') assert [int(span.text) for span in int_list] == expected diff --git a/tests/test_printer.py b/tests/test_printer.py index 911e62f5c7..b051a73210 100644 --- a/tests/test_printer.py +++ b/tests/test_printer.py @@ -149,7 +149,7 @@ def test_op_message(): ctx.register_dialect(Arith) ctx.register_dialect(Builtin) - parser = XDSLParser(ctx, prog, '') + parser = XDSLParser(ctx, prog) module = parser.parse_op() file = StringIO("") @@ -655,7 +655,7 @@ def test_parse_dense_xdsl(): ctx.register_dialect(Builtin) ctx.register_dialect(Arith) - parser = XDSLParser(ctx, prog, '') + parser = XDSLParser(ctx, prog) module = parser.parse_op() file = StringIO("") @@ -703,7 +703,7 @@ def test_foo_string(): ctx.register_op(AnyOp) ctx.register_attr(CustomFormatAttr) - parser = XDSLParser(ctx, prog, '') + parser = XDSLParser(ctx, prog) try: parser.parse_op() assert False @@ -722,7 +722,7 @@ def test_dictionary_attr(): ctx.register_dialect(Builtin) ctx.register_dialect(Func) - parser = XDSLParser(ctx, prog, '') + parser = XDSLParser(ctx, prog) parsed = parser.parse_op() file = StringIO("") From 46fee686dda09a09082cfe2732f3d3926c811a60 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Mon, 23 Jan 2023 14:04:56 +0000 Subject: [PATCH 56/65] parser: move stuff around, fix formatting --- tests/test_parser_error.py | 3 +- xdsl/dialects/builtin.py | 4 +-- xdsl/parser.py | 71 +++++--------------------------------- xdsl/utils/exceptions.py | 67 ++++++++++++++++++++++++++++++++++- xdsl/xdsl_opt_main.py | 13 +++---- 5 files changed, 84 insertions(+), 74 deletions(-) diff --git a/tests/test_parser_error.py b/tests/test_parser_error.py index 98fdbdcc3e..71938f8319 100644 --- a/tests/test_parser_error.py +++ b/tests/test_parser_error.py @@ -6,7 +6,8 @@ from xdsl.ir import MLContext from xdsl.irdl import AnyAttr, irdl_op_definition, Operation, VarOperand, VarOpResult -from xdsl.parser import Parser, ParseError, XDSLParser +from xdsl.parser import XDSLParser +from xdsl.utils.exceptions import ParseError @irdl_op_definition diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index 1f9cb3dd5f..0b4589540c 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -16,7 +16,8 @@ from xdsl.utils.exceptions import VerifyException if TYPE_CHECKING: - from xdsl.parser import BaseParser, ParseError + from xdsl.parser import BaseParser + from utils.exceptions import ParseError from xdsl.printer import Printer @@ -202,7 +203,6 @@ def from_params( AnyIntegerAttr: TypeAlias = IntegerAttr[IntegerType | IndexType] -DefaultIntegerAttrType = i64 @irdl_attr_definition diff --git a/xdsl/parser.py b/xdsl/parser.py index 7f200d7bc1..008b0fd7af 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -14,75 +14,18 @@ from io import StringIO from typing import TypeVar, Iterable +from xdsl.utils.exceptions import ParseError, MultipleSpansParseError from xdsl.dialects.memref import MemRefType, UnrankedMemrefType from xdsl.dialects.builtin import ( AnyTensorType, AnyVectorType, Float16Type, Float32Type, Float64Type, FloatAttr, FunctionType, IndexType, IntegerType, Signedness, StringAttr, IntegerAttr, ArrayAttr, TensorType, UnrankedTensorType, VectorType, - DefaultIntegerAttrType, FlatSymbolRefAttr, DenseIntOrFPElementsAttr, - UnregisteredOp, OpaqueAttr, NoneAttr, ModuleOp, UnitAttr) + FlatSymbolRefAttr, DenseIntOrFPElementsAttr, UnregisteredOp, OpaqueAttr, + NoneAttr, ModuleOp, UnitAttr, i64) from xdsl.ir import (SSAValue, Block, Callable, Attribute, Operation, Region, BlockArgument, MLContext, ParametrizedAttribute, Data) -class ParseError(Exception): - span: Span - msg: str - history: BacktrackingHistory | None - - def __init__(self, - span: Span, - msg: str, - history: BacktrackingHistory | None = None): - preamble = "" - if history: - preamble = history.error.args[0] + '\n' - if span is None: - raise ValueError("Span can't be None!") - super().__init__(preamble + span.print_with_context(msg)) - self.span = span - self.msg = msg - self.history = history - - def print_pretty(self, file=sys.stderr): - print(self.span.print_with_context(self.msg), file=file) - - def print_with_history(self, file=sys.stderr): - if self.history is not None: - for h in sorted(self.history.iterate(), key=lambda h: -h.pos): - h.print() - else: - self.print_pretty(file) - - def __repr__(self): - io = StringIO() - self.print_with_history(io) - return "{}:\n{}".format(self.__class__.__name__, io.getvalue()) - - -class MultipleSpansParseError(ParseError): - ref_text: str | None - refs: list[tuple[Span, str]] - - def __init__( - self, - span: Span, - msg: str, - ref_text: str, - refs: list[tuple[Span, str | None]], - history: BacktrackingHistory | None = None, - ): - super(MultipleSpansParseError, self).__init__(span, msg, history) - self.refs = refs - self.ref_text = ref_text - - def print_pretty(self, file=sys.stderr): - super(MultipleSpansParseError, self).print_pretty(file) - print(self.ref_text or "With respect to:", file=file) - for span, msg in self.refs: - print(span.print_with_context(msg), file=file) - - @dataclass class BacktrackingHistory: """ @@ -223,6 +166,11 @@ def __post_init__(self): @classmethod def from_span(cls, span: Span | None) -> StringLiteral | None: + """ + Convert a normal span into a StringLiteral, to facilitate parsing. + + If argument is None, returns None. + """ if span is None: return None return cls(span.start, span.end, span.input) @@ -1377,8 +1325,7 @@ def try_parse_builtin_int_attr(self) -> IntegerAttr | None: self.try_parse_integer_literal, 'Integer attribute must start with an integer literal!') if self.tokenizer.next_token(peek=True).text != ':': - return IntegerAttr.from_params(int(value.text), - DefaultIntegerAttrType) + return IntegerAttr.from_params(int(value.text), i64) type = self.must_parse_attribute_type() return IntegerAttr.from_params(int(value.text), type) diff --git a/xdsl/utils/exceptions.py b/xdsl/utils/exceptions.py index c167fd472d..6bbb53358a 100644 --- a/xdsl/utils/exceptions.py +++ b/xdsl/utils/exceptions.py @@ -2,10 +2,17 @@ This module contains all custom exceptions used by xDSL. """ +from __future__ import annotations +import sys +import typing from dataclasses import dataclass +from io import StringIO from typing import Any -from xdsl.ir import Attribute + +if typing.TYPE_CHECKING: + from parser import Span, BacktrackingHistory + from xdsl.ir import Attribute class DiagnosticException(Exception): @@ -28,3 +35,61 @@ class BuilderNotFoundException(Exception): def __str__(self) -> str: return f"No builder found for attribute {self.attribute} with " \ f"arguments {self.args}" + + +class ParseError(Exception): + span: 'Span' + msg: str + history: 'BacktrackingHistory' | None + + def __init__(self, + span: 'Span', + msg: str, + history: 'BacktrackingHistory' | None = None): + preamble = "" + if history: + preamble = history.error.args[0] + '\n' + if span is None: + raise ValueError("Span can't be None!") + super().__init__(preamble + span.print_with_context(msg)) + self.span = span + self.msg = msg + self.history = history + + def print_pretty(self, file=sys.stderr): + print(self.span.print_with_context(self.msg), file=file) + + def print_with_history(self, file=sys.stderr): + if self.history is not None: + for h in sorted(self.history.iterate(), key=lambda h: -h.pos): + h.print() + else: + self.print_pretty(file) + + def __repr__(self): + io = StringIO() + self.print_with_history(io) + return "{}:\n{}".format(self.__class__.__name__, io.getvalue()) + + +class MultipleSpansParseError(ParseError): + ref_text: str | None + refs: list[tuple['Span', str]] + + def __init__( + self, + span: 'Span', + msg: str, + ref_text: str, + refs: list[tuple['Span', str | None]], + history: 'BacktrackingHistory' | None = None, + ): + super(MultipleSpansParseError, self).__init__(span, msg, history) + self.refs = refs + self.ref_text = ref_text + + def print_pretty(self, file=sys.stderr): + super(MultipleSpansParseError, self).print_pretty(file) + print(self.ref_text or "With respect to:", file=file) + for span, msg in self.refs: + print(span.print_with_context(msg), file=file) diff --git a/xdsl/xdsl_opt_main.py b/xdsl/xdsl_opt_main.py index 25ca9c8421..c9c0e45fc4 100644 --- a/xdsl/xdsl_opt_main.py +++ b/xdsl/xdsl_opt_main.py @@ -216,17 +216,14 @@ def register_all_frontends(self): Add other/additional frontends by overloading this function. """ + def parse_xdsl(io: IOBase): - return XDSLParser( - self.ctx, io.read(), self.get_input_name(), - self.args.allow_unregistered_ops - ).parse_module() + return XDSLParser(self.ctx, io.read(), self.get_input_name(), + self.args.allow_unregistered_ops).parse_module() def parse_mlir(io: IOBase): - return MLIRParser( - self.ctx, io.read(), self.get_input_name(), - self.args.allow_unregistered_ops - ).parse_module() + return MLIRParser(self.ctx, io.read(), self.get_input_name(), + self.args.allow_unregistered_ops).parse_module() self.available_frontends['xdsl'] = parse_xdsl self.available_frontends['mlir'] = parse_mlir From 2f09883c86491f3495a0f452edc7fc0ca40b4d55 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Mon, 23 Jan 2023 14:13:59 +0000 Subject: [PATCH 57/65] parser: uppercase comments --- xdsl/parser.py | 102 ++++++++++++++++++++++++------------------------- 1 file changed, 51 insertions(+), 51 deletions(-) diff --git a/xdsl/parser.py b/xdsl/parser.py index 008b0fd7af..2976dd634e 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -130,7 +130,7 @@ def print_with_context(self, msg: str | None = None) -> str: if info is None: return "Unknown location of span {}. Error: ".format(self, msg) lines, offset_of_first_line, line_no = info - # offset relative to the first line: + # Offset relative to the first line: offset = self.start - offset_of_first_line remaining_len = max(self.len, 1) capture = StringIO() @@ -205,16 +205,16 @@ def get_lines_containing(self, while True: next_start = source.find('\n', start) line_no += 1 - # handle eof + # Handle eof if next_start == -1: if span.start > len(source): return None return [source[start:]], start, line_no - # as long as the next newline comes before the spans start we can continue + # As long as the next newline comes before the spans start we can continue if next_start < span.start: start = next_start + 1 continue - # if the whole span is on one line, we are good as well + # If the whole span is on one line, we are good as well if next_start >= span.end: return [source[start:next_start]], start, line_no while next_start < span.end: @@ -238,7 +238,7 @@ class Tokenizer: It provides an interface for backtracking, so you can use: with tokenizer.backtracking(): - # try stuff + # Try stuff raise ParseError(...) and not worry about manually resetting the input position. Backtracking will also @@ -314,28 +314,28 @@ def backtracking(self, region_name: str | None = None): starting_position = self.pos try: yield - # clear error history when something doesn't fail - # this is because we are only interested in the last "cascade" of failures. - # if a backtracking() completes without failure, something has been parsed (we assume) + # Clear error history when something doesn't fail + # Lhis is because we are only interested in the last "cascade" of failures. + # If a backtracking() completes without failure, something has been parsed (we assume) if self.pos > starting_position and self.history is not None: self.history = None except Exception as ex: how_far_we_got = self.pos - # if we have no error history, start recording! + # If we have no error history, start recording! if not self.history: self.history = self.history_entry_from_exception( ex, region_name, how_far_we_got) - # if we got further than on previous attempts + # If we got further than on previous attempts elif how_far_we_got > self.history.get_farthest_point(): - # throw away history + # Throw away history self.history = None - # generate new history entry, + # Generate new history entry, self.history = self.history_entry_from_exception( ex, region_name, how_far_we_got) - # otherwise, add to exception, if we are in a named region + # Otherwise, add to exception, if we are in a named region elif region_name is not None and how_far_we_got - starting_position > 0: self.history = self.history_entry_from_exception( ex, region_name, how_far_we_got) @@ -357,7 +357,7 @@ def history_entry_from_exception(self, ex: Exception, region: str, "Generic assertion failure", *(reason for reason in ex.args if isinstance(reason, str)), ] - # we assume that assertions fail because of the last read-in token + # We assume that assertions fail because of the last read-in token if len(reason) == 1: tb = StringIO() traceback.print_exc(file=tb) @@ -398,13 +398,13 @@ def next_token(self, peek: bool = False) -> Span: This will skip over line comments. Meaning it will skip the entire line if it encounters '//' """ i = self.next_pos() - # construct the span: + # Construct the span: span = Span(i, self._find_token_end(i), self.input) - # advance pointer if not peeking + # Advance pointer if not peeking if not peek: self.pos = span.end - # save last token + # Save last token self.last_token = span return span @@ -419,7 +419,7 @@ def next_token_of_pattern(self, except EOFError: return None - # handle search for string literal + # Handle search for string literal if isinstance(pattern, str): if self.starts_with(pattern): if not peek: @@ -427,7 +427,7 @@ def next_token_of_pattern(self, return Span(start, start + len(pattern), self.input) return None - # handle regex logic + # Handle regex logic match = pattern.match(self.input.content, start) if match is None: return None @@ -435,7 +435,7 @@ def next_token_of_pattern(self, if not peek: self.pos = match.end() - # save last token + # Save last token self.last_token = Span(start, match.end(), self.input) return self.last_token @@ -449,11 +449,11 @@ def _find_token_end(self, start: int | None = None) -> int: Find the point (optionally starting from start) where the token ends """ i = self.next_pos() if start is None else start - # search for literal breaks + # Search for literal breaks for part in self.break_on: if self.input.content.startswith(part, i): return i + len(part) - # otherwise return the start of the next break + # Otherwise return the start of the next break return min( filter( lambda x: x >= 0, @@ -467,11 +467,11 @@ def next_pos(self, i: int | None = None) -> int: This will skip line comments! """ i = self.pos if i is None else i - # skip whitespaces + # Skip whitespaces while self.input.at(i).isspace(): i += 1 - # skip comments as well + # Skip comments as well if self.input.content.startswith("//", i): i = self.input.content.find("\n", i) + 1 return self.next_pos(i) @@ -493,12 +493,12 @@ def configured(self, break_on: tuple[str, ...]): """ This is a helper class to allow expressing a temporary change in config, allowing you to write: - # parsing double-quoted string now + # Parsing double-quoted string now string_content = "" with tokenizer.configured(break_on=('"', '\\'),): - # use tokenizer + # Use tokenizer - # now old config is restored automatically + # Now old config is restored automatically """ save = self.save() @@ -542,7 +542,7 @@ class ParserCommons: type_alias = re.compile(r"![A-Za-z_][\w$.]+") attribute_alias = re.compile(r"#[A-Za-z_][\w$.]+") boolean_literal = re.compile(r"(true|false)") - # a list of + # A list of names that are builtin types _builtin_type_names = ( r"[su]?i\d+", r"f\d+", "tensor", "vector", "memref", "complex", "opaque", "tuple", "index", "dense" @@ -741,7 +741,7 @@ def must_parse_list_of(self, ) is not None: next_item = try_parse() if next_item is None: - # if the separator is emtpy, we are good here + # If the separator is emtpy, we are good here if separator_pattern.pattern == '': return items self.raise_error(error_msg + @@ -858,7 +858,7 @@ def must_parse_dialect_type_or_attribute_inner(self, kind: str): "'{}' is not a know attribute!".format(type_name.text), type_name) - # pass the task of parsing parameters on to the attribute/type definition + # Pass the task of parsing parameters on to the attribute/type definition if issubclass(type_def, ParametrizedAttribute): param_list = type_def.parse_parameters(self) elif issubclass(type_def, Data): @@ -894,7 +894,7 @@ def unimplemented() -> ParametrizedAttribute: } self.must_parse_characters("<", "Expected parameter list here!") - # get the parser for the type, falling back to the unimplemented warning + # Get the parser for the type, falling back to the unimplemented warning res = builtin_parsers.get(name.text, unimplemented)() self.must_parse_characters(">", "Expected end of parameter list here!") @@ -916,21 +916,21 @@ def try_parse_numerical_dims(self, while (shape_arg := self.try_parse_shape_element(lower_bound)) is not None: yield shape_arg - # look out for the closing bracket for scalable vector dims + # Look out for the closing bracket for scalable vector dims if accept_closing_bracket and self.tokenizer.starts_with("]"): break self.must_parse_characters( "x", "Unexpected end of dimension parameters!") def must_parse_vector_attrs(self) -> AnyVectorType: - # also break on 'x' characters as they are separators in dimension parameters + # Also break on 'x' characters as they are separators in dimension parameters with self.tokenizer.configured(break_on=self.tokenizer.break_on + ("x", )): shape = list[int](self.try_parse_numerical_dims()) scaling_shape: list[int] | None = None if self.tokenizer.next_token_of_pattern("[") is not None: - # we now need to parse the scalable dimensions + # We now need to parse the scalable dimensions scaling_shape = list(self.try_parse_numerical_dims()) self.must_parse_characters( "]", "Expected end of scalable vector dimensions here!") @@ -952,14 +952,14 @@ def must_parse_vector_attrs(self) -> AnyVectorType: def must_parse_tensor_or_memref_dims(self) -> list[int] | None: with self.tokenizer.configured(break_on=self.tokenizer.break_on + ('x', )): - # check for unranked-ness + # Check for unranked-ness if self.tokenizer.next_token_of_pattern('*') is not None: - # consume `x` + # Consume `x` self.must_parse_characters( 'x', 'Unranked tensors must follow format (`<*x` type `>`)') else: - # parse rank: + # Parse rank: return list(self.try_parse_numerical_dims(lower_bound=0)) def must_parse_tensor_attrs(self) -> AnyTensorType: @@ -1003,7 +1003,7 @@ def try_parse_shape_element(self, lower_bound: int = 1) -> int | None: return None def must_parse_type_params(self) -> list[Attribute]: - # consume opening bracket + # Consume opening bracket self.must_parse_characters('<', 'Type must be parameterized!') params = self.must_parse_list_of(self.try_parse_type, @@ -1056,13 +1056,13 @@ def must_parse_operation(self) -> Operation: '=', 'Operation definitions expect an `=` after op-result-list!') - # check for custom op format + # Check for custom op format op_name = self.try_parse_bare_id() if op_name is not None: op_type = self._get_op_by_name(op_name) op = op_type.parse(ret_types, self) else: - # check for basic op format + # Check for basic op format op_name = self.try_parse_string_literal() if op_name is None: self.raise_error( @@ -1128,7 +1128,7 @@ def must_parse_region(self) -> Region: if self.tokenizer.starts_with("}"): region.add_block(Block()) else: - # parse first block + # Parse first block block = self.must_parse_block() region.add_block(block) @@ -1223,7 +1223,7 @@ def try_parse_builtin_attr(self) -> Attribute | None: return self.try_parse_function_type() elif next_token.text in ParserCommons.builtin_attr_names: return self.try_parse_builtin_named_attr() - # order here is important! + # Order here is important! attrs = (self.try_parse_builtin_float_attr, self.try_parse_builtin_int_attr, self.try_parse_builtin_type) @@ -1335,7 +1335,7 @@ def try_parse_builtin_float_attr(self) -> FloatAttr | None: self.try_parse_float_literal, "Float attribute must start with a float literal!", ) - # if we don't see a ':' indicating a type signature + # If we don't see a ':' indicating a type signature if not self.tokenizer.starts_with(":"): return FloatAttr.from_value(float(value.text)) @@ -1513,8 +1513,8 @@ def must_parse_op_args_list(self) -> list[Span]: raise NotImplementedError() # HERE STARTS A SOMEWHAT CURSED COMPATIBILITY LAYER: - # since we don't want to rewrite all dialects currently, the new parser needs to expose the same - # interface to the dialect definitions (to some extent). Here we implement that interface. + # Since we don't want to rewrite all dialects currently, the new parser needs to expose the same + # Interface to the dialect definitions (to some extent). Here we implement that interface. _OperationType = TypeVar("_OperationType", bound=Operation) @@ -1614,11 +1614,11 @@ def must_parse_attribute(self) -> Attribute: """ Parse attribute (either builtin or dialect) """ - # all dialect attrs must start with '#', so we check for that first (as it's easier) + # All dialect attrs must start with '#', so we check for that first (as it's easier) if self.tokenizer.starts_with("#"): value = self.try_parse_dialect_attr() - # no value => error + # No value => error if value is None: self.raise_error( "`#` must be followed by a valid dialect attribute or type!" @@ -1626,7 +1626,7 @@ def must_parse_attribute(self) -> Attribute: return value - # if it isn't a dialect attr, parse builtin + # If it isn't a dialect attr, parse builtin builtin_val = self.try_parse_builtin_attr() if builtin_val is None: @@ -1722,7 +1722,7 @@ def try_parse_builtin_type(self) -> Attribute | None: ParserCommons.builtin_type_xdsl) if name is None: self.raise_error("Expected builtin name!") - # xdsl builtin types have a '!' prefix, we strip that out here + # xDSL builtin types have a '!' prefix, we strip that out here name = Span(start=name.start + 1, end=name.end, input=name.input) return self.must_parse_builtin_type_with_name(name) @@ -1738,7 +1738,7 @@ def must_parse_attribute(self) -> Attribute: # xDSL: Allow both # and ! prefixes, as we allow both types and attrs # TODO: phase out use of next_token(peek=True) in favour of starts_with if value is None and self.tokenizer.next_token(peek=True).text in "#!": - # in MLIR # and ! are prefixes for dialect attrs/types, but in xDSL ! is also used for builtin types + # In MLIR # and ! are prefixes for dialect attrs/types, but in xDSL ! is also used for builtin types value = self.try_parse_dialect_type_or_attribute() if value is None: @@ -1767,7 +1767,7 @@ def try_parse_builtin_attr(self) -> Attribute: If the mode is xDSL, it also allows parsing of builtin types """ - # in xdsl, two things are different here: + # In xdsl, two things are different here: # 1. types are considered valid attributes # 2. all types, builtins included, are prefixed with ! if self.tokenizer.starts_with("!"): From 91a8d12197aaeb3df520f1125dfdc80c29726203 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Mon, 23 Jan 2023 14:42:15 +0000 Subject: [PATCH 58/65] parser: make a bunch of methods private --- xdsl/parser.py | 148 ++++++++++++++++++++++++++----------------------- 1 file changed, 80 insertions(+), 68 deletions(-) diff --git a/xdsl/parser.py b/xdsl/parser.py index 2976dd634e..b86ffc4ff2 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -324,7 +324,7 @@ def backtracking(self, region_name: str | None = None): # If we have no error history, start recording! if not self.history: - self.history = self.history_entry_from_exception( + self.history = self._history_entry_from_exception( ex, region_name, how_far_we_got) # If we got further than on previous attempts @@ -332,18 +332,18 @@ def backtracking(self, region_name: str | None = None): # Throw away history self.history = None # Generate new history entry, - self.history = self.history_entry_from_exception( + self.history = self._history_entry_from_exception( ex, region_name, how_far_we_got) # Otherwise, add to exception, if we are in a named region elif region_name is not None and how_far_we_got - starting_position > 0: - self.history = self.history_entry_from_exception( + self.history = self._history_entry_from_exception( ex, region_name, how_far_we_got) self.resume_from(save) - def history_entry_from_exception(self, ex: Exception, region: str, - pos: int) -> BacktrackingHistory: + def _history_entry_from_exception(self, ex: Exception, region: str, + pos: int) -> BacktrackingHistory: """ Given an exception generated inside a backtracking attempt, generate a BacktrackingHistory object with the relevant information in it. @@ -626,7 +626,7 @@ def get_ssa_val(self, name: Span) -> SSAValue: self.raise_error('SSA Value used before assignment', name) return self.ssaValues[name.text] - def get_block_from_name(self, block_name: Span) -> Block: + def _get_block_from_name(self, block_name: Span) -> Block: """ This function takes a span containing a block id (like `^42`) and returns a block. @@ -639,7 +639,7 @@ def get_block_from_name(self, block_name: Span) -> Block: return self.blocks[name] def must_parse_block(self) -> Block: - block_id, args = self.must_parse_optional_block_label() + block_id, args = self._must_parse_optional_block_label() if block_id is None: block = Block(self.tokenizer.last_token) @@ -669,20 +669,23 @@ def must_parse_block(self) -> Block: return block - def must_parse_optional_block_label( + def _must_parse_optional_block_label( self) -> tuple[Span | None, list[tuple[Span, Attribute]]]: + """ + A block label consists of block-id ( `(` block-arg `,` ... `)` )? + """ block_id = self.try_parse_block_id() arg_list = list() if block_id is not None: if self.tokenizer.starts_with('('): - arg_list = self.must_parse_block_arg_list() + arg_list = self._must_parse_block_arg_list() self.must_parse_characters(':', 'Block label must end in a `:`!') return block_id, arg_list - def must_parse_block_arg_list(self) -> list[tuple[Span, Attribute]]: + def _must_parse_block_arg_list(self) -> list[tuple[Span, Attribute]]: self.must_parse_characters('(', 'Block arguments must start with `(`') args = self.must_parse_list_of(self.try_parse_value_id_and_type, @@ -819,9 +822,9 @@ def try_parse_dialect_type_or_attribute(self) -> Attribute | None: with self.tokenizer.backtracking("dialect attribute or type"): self.tokenizer.consume_peeked(kind) if kind.text == '!': - return self.must_parse_dialect_type_or_attribute_inner('type') + return self._must_parse_dialect_type_or_attribute_inner('type') else: - return self.must_parse_dialect_type_or_attribute_inner( + return self._must_parse_dialect_type_or_attribute_inner( 'attribute') def try_parse_dialect_type(self): @@ -833,7 +836,7 @@ def try_parse_dialect_type(self): with self.tokenizer.backtracking("dialect type"): self.must_parse_characters('!', "Dialect type must start with a `!`") - return self.must_parse_dialect_type_or_attribute_inner('type') + return self._must_parse_dialect_type_or_attribute_inner('type') def try_parse_dialect_attr(self): """ @@ -844,9 +847,10 @@ def try_parse_dialect_attr(self): with self.tokenizer.backtracking("dialect attribute"): self.must_parse_characters( '#', "Dialect attribute must start with a `#`") - return self.must_parse_dialect_type_or_attribute_inner('attribute') + return self._must_parse_dialect_type_or_attribute_inner( + 'attribute') - def must_parse_dialect_type_or_attribute_inner(self, kind: str): + def _must_parse_dialect_type_or_attribute_inner(self, kind: str): type_name = self.tokenizer.next_token_of_pattern(ParserCommons.bare_id) if type_name is None: @@ -878,8 +882,11 @@ def try_parse_builtin_type(self) -> Attribute | None: """ raise NotImplemented("Subclasses must implement this method!") - def must_parse_builtin_parametrized_type( + def _must_parse_builtin_parametrized_type( self, name: Span) -> ParametrizedAttribute: + """ + This function is called after we parse the name of a paremetrized type such as vector. + """ def unimplemented() -> ParametrizedAttribute: raise ParseError(name, @@ -904,7 +911,7 @@ def must_parse_complex_attrs(self): self.raise_error("ComplexType is unimplemented!") def must_parse_memref_attrs(self) -> MemRefType | UnrankedMemrefType: - dims = self.must_parse_tensor_or_memref_dims() + dims = self._must_parse_tensor_or_memref_dims() type = self.try_parse_type() if dims is None: return UnrankedMemrefType.from_type(type) @@ -914,7 +921,7 @@ def try_parse_numerical_dims(self, accept_closing_bracket: bool = False, lower_bound: int = 1) -> Iterable[int]: while (shape_arg := - self.try_parse_shape_element(lower_bound)) is not None: + self._try_parse_shape_element(lower_bound)) is not None: yield shape_arg # Look out for the closing bracket for scalable vector dims if accept_closing_bracket and self.tokenizer.starts_with("]"): @@ -949,7 +956,7 @@ def must_parse_vector_attrs(self) -> AnyVectorType: return VectorType.from_element_type_and_shape(type, shape) - def must_parse_tensor_or_memref_dims(self) -> list[int] | None: + def _must_parse_tensor_or_memref_dims(self) -> list[int] | None: with self.tokenizer.configured(break_on=self.tokenizer.break_on + ('x', )): # Check for unranked-ness @@ -963,7 +970,7 @@ def must_parse_tensor_or_memref_dims(self) -> list[int] | None: return list(self.try_parse_numerical_dims(lower_bound=0)) def must_parse_tensor_attrs(self) -> AnyTensorType: - shape = self.must_parse_tensor_or_memref_dims() + shape = self._must_parse_tensor_or_memref_dims() type = self.try_parse_type() if type is None: @@ -981,7 +988,7 @@ def must_parse_tensor_attrs(self) -> AnyTensorType: return UnrankedTensorType.from_type(type) - def try_parse_shape_element(self, lower_bound: int = 1) -> int | None: + def _try_parse_shape_element(self, lower_bound: int = 1) -> int | None: """ Parse a shape element, either a decimal integer immediate or a `?`, which evaluates to -1 @@ -1002,7 +1009,7 @@ def try_parse_shape_element(self, lower_bound: int = 1) -> int | None: return -1 return None - def must_parse_type_params(self) -> list[Attribute]: + def _must_parse_type_params(self) -> list[Attribute]: # Consume opening bracket self.must_parse_characters('<', 'Type must be parameterized!') @@ -1041,7 +1048,7 @@ def must_parse_characters(self, text: str, msg: str) -> Span: return match @abstractmethod - def must_parse_op_result_list( + def _must_parse_op_result_list( self) -> tuple[list[Span], list[Attribute] | None]: raise NotImplemented() @@ -1050,7 +1057,7 @@ def try_parse_operation(self) -> Operation | None: return self.must_parse_operation() def must_parse_operation(self) -> Operation: - result_list, ret_types = self.must_parse_op_result_list() + result_list, ret_types = self._must_parse_op_result_list() if len(result_list) > 0: self.must_parse_characters( '=', @@ -1069,7 +1076,7 @@ def must_parse_operation(self) -> Operation: "Expected an operation name here, either a bare-id, or a string literal!" ) - args, successors, attrs, regions, func_type = self.must_parse_operation_details( + args, successors, attrs, regions, func_type = self._must_parse_operation_details( ) if ret_types is None: @@ -1155,12 +1162,12 @@ def must_parse_region(self) -> Region: self.blocks = oldBBNames self.forward_block_references = oldForwardRefs - def try_parse_op_name(self) -> Span | None: + def _try_parse_op_name(self) -> Span | None: if (str_lit := self.try_parse_string_literal()) is not None: return str_lit return self.try_parse_bare_id() - def must_parse_attribute_entry(self) -> tuple[Span, Attribute]: + def _must_parse_attribute_entry(self) -> tuple[Span, Attribute]: """ Parse entry in attribute dict. Of format: @@ -1196,7 +1203,7 @@ def try_parse_attribute(self) -> Attribute | None: with self.tokenizer.backtracking("attribute"): return self.must_parse_attribute() - def must_parse_attribute_type(self) -> Attribute: + def _must_parse_attribute_type(self) -> Attribute: """ Parses `:` type and returns the type """ @@ -1237,8 +1244,8 @@ def try_parse_builtin_named_attr(self) -> Attribute | None: name.text)): self.tokenizer.consume_peeked(name) parsers = { - 'dense': self.must_parse_builtin_dense_attr, - 'opaque': self.must_parse_builtin_opaque_attr, + 'dense': self._must_parse_builtin_dense_attr, + 'opaque': self._must_parse_builtin_opaque_attr, } def not_implemented(): @@ -1246,17 +1253,17 @@ def not_implemented(): return parsers.get(name.text, not_implemented)() - def must_parse_builtin_dense_attr(self) -> Attribute | None: + def _must_parse_builtin_dense_attr(self) -> Attribute | None: err_msg = "Malformed dense attribute, format must be (`dense<` array-attr `>:` type)" self.must_parse_characters("<", err_msg) - info = list(self.must_parse_builtin_dense_attr_args()) + info = list(self._must_parse_builtin_dense_attr_args()) self.must_parse_characters(">", err_msg) self.must_parse_characters(":", err_msg) type = self.expect(self.try_parse_type, "Dense attribute must be typed!") return DenseIntOrFPElementsAttr.from_list(type, info) - def must_parse_builtin_opaque_attr(self): + def _must_parse_builtin_opaque_attr(self): self.must_parse_characters("<", "Opaque attribute must be parametrized") str_lit_list = self.must_parse_list_of(self.try_parse_string_literal, @@ -1278,7 +1285,7 @@ def must_parse_builtin_opaque_attr(self): for span in str_lit_list), type=type) - def must_parse_builtin_dense_attr_args(self) -> Iterable[int | float]: + def _must_parse_builtin_dense_attr_args(self) -> Iterable[int | float]: """ dense attribute params must be: @@ -1299,7 +1306,7 @@ def try_parse_int_or_float(): self.must_parse_characters('[', '') while not self.tokenizer.starts_with(']'): - yield from self.must_parse_builtin_dense_attr_args() + yield from self._must_parse_builtin_dense_attr_args() if self.tokenizer.next_token_of_pattern(',') is None: break self.must_parse_characters(']', '') @@ -1326,7 +1333,7 @@ def try_parse_builtin_int_attr(self) -> IntegerAttr | None: 'Integer attribute must start with an integer literal!') if self.tokenizer.next_token(peek=True).text != ':': return IntegerAttr.from_params(int(value.text), i64) - type = self.must_parse_attribute_type() + type = self._must_parse_attribute_type() return IntegerAttr.from_params(int(value.text), type) def try_parse_builtin_float_attr(self) -> FloatAttr | None: @@ -1339,7 +1346,7 @@ def try_parse_builtin_float_attr(self) -> FloatAttr | None: if not self.tokenizer.starts_with(":"): return FloatAttr.from_value(float(value.text)) - type = self.must_parse_attribute_type() + type = self._must_parse_attribute_type() return FloatAttr.from_value(float(value.text), type) def try_parse_builtin_boolean_attr(self) -> IntegerAttr | None: @@ -1377,7 +1384,7 @@ def try_parse_builtin_arr_attr(self) -> ArrayAttr | None: def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: raise NotImplementedError() - def attr_dict_from_tuple_list( + def _attr_dict_from_tuple_list( self, tuple_list: list[tuple[Span, Attribute]]) -> dict[str, Attribute]: """ @@ -1423,9 +1430,9 @@ def must_parse_function_type(self) -> FunctionType: "Malformed function type, expected `->`!") return FunctionType.from_lists( - args, self.must_parse_type_or_type_list_parens()) + args, self._must_parse_type_or_type_list_parens()) - def must_parse_type_or_type_list_parens(self) -> list[Attribute]: + def _must_parse_type_or_type_list_parens(self) -> list[Attribute]: """ Parses type-or-type-list-parens, which is used in function-type. @@ -1461,7 +1468,10 @@ def must_parse_region_list(self) -> list[Region]: regions.append(self.must_parse_region()) return regions - def must_parse_builtin_type_with_name(self, name: Span): + def _must_parse_builtin_type_with_name(self, name: Span): + """ + Parses one of the builtin types like i42, vector, etc... + """ if name.text == "index": return IndexType() if (re_match := re.match(r"^[su]?i(\d+)$", name.text)) is not None: @@ -1485,10 +1495,10 @@ def must_parse_builtin_type_with_name(self, name: Span): "Unsupported floating point width: {}".format(width)) return type() - return self.must_parse_builtin_parametrized_type(name) + return self._must_parse_builtin_parametrized_type(name) @abstractmethod - def must_parse_operation_details( + def _must_parse_operation_details( self, ) -> tuple[list[Span], list[Span], dict[str, Attribute], list[Region], FunctionType | None]: @@ -1509,7 +1519,7 @@ def must_parse_operation_details( raise NotImplementedError() @abstractmethod - def must_parse_op_args_list(self) -> list[Span]: + def _must_parse_op_args_list(self) -> list[Span]: raise NotImplementedError() # HERE STARTS A SOMEWHAT CURSED COMPATIBILITY LAYER: @@ -1531,7 +1541,7 @@ def parse_op_with_default_format( """ # TODO: remove this function and restructure custom op / irdl parsing assert isinstance(self, XDSLParser) - args, successors, attributes, regions, _ = self.must_parse_operation_details( + args, successors, attributes, regions, _ = self._must_parse_operation_details( ) for x in args: @@ -1544,7 +1554,9 @@ def parse_op_with_default_format( operands=[self.ssaValues[span.text] for span in args], result_types=result_types, attributes=attributes, - successors=[self.get_block_from_name(span) for span in successors], + successors=[ + self._get_block_from_name(span) for span in successors + ], regions=regions) def parse_paramattr_parameters( @@ -1608,7 +1620,7 @@ def try_parse_builtin_type(self) -> Attribute | None: if name is None: raise self.raise_error("Expected builtin name!") - return self.must_parse_builtin_type_with_name(name) + return self._must_parse_builtin_type_with_name(name) def must_parse_attribute(self) -> Attribute: """ @@ -1636,7 +1648,7 @@ def must_parse_attribute(self) -> Attribute: return builtin_val - def must_parse_op_result_list( + def _must_parse_op_result_list( self) -> tuple[list[Span], list[Attribute] | None]: return ( self.must_parse_list_of(self.try_parse_value_id, @@ -1655,21 +1667,21 @@ def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: attrs = [] if not self.tokenizer.starts_with('}'): - attrs = self.must_parse_list_of(self.must_parse_attribute_entry, + attrs = self.must_parse_list_of(self._must_parse_attribute_entry, "Expected attribute entry") self.must_parse_characters( "}", "MLIR Attribute dictionary must be enclosed in curly brackets") - return self.attr_dict_from_tuple_list(attrs) + return self._attr_dict_from_tuple_list(attrs) - def must_parse_operation_details( + def _must_parse_operation_details( self, ) -> tuple[list[Span], list[Span], dict[str, Attribute], list[Region], FunctionType | None]: - args = self.must_parse_op_args_list() - succ = self.must_parse_optional_successor_list() + args = self._must_parse_op_args_list() + succ = self._must_parse_optional_successor_list() regions = [] if self.tokenizer.starts_with("("): @@ -1688,7 +1700,7 @@ def must_parse_operation_details( return args, succ, attrs, regions, func_type - def must_parse_optional_successor_list(self) -> list[Span]: + def _must_parse_optional_successor_list(self) -> list[Span]: if not self.tokenizer.starts_with("["): return [] self.must_parse_characters( @@ -1700,7 +1712,7 @@ def must_parse_optional_successor_list(self) -> list[Span]: "]", "Successor list is enclosed in square brackets") return successors - def must_parse_op_args_list(self) -> list[Span]: + def _must_parse_op_args_list(self) -> list[Span]: self.must_parse_characters( "(", "Operation args list must be enclosed by brackets!") args = self.must_parse_list_of(self.try_parse_value_id, @@ -1725,7 +1737,7 @@ def try_parse_builtin_type(self) -> Attribute | None: # xDSL builtin types have a '!' prefix, we strip that out here name = Span(start=name.start + 1, end=name.end, input=name.input) - return self.must_parse_builtin_type_with_name(name) + return self._must_parse_builtin_type_with_name(name) def must_parse_attribute(self) -> Attribute: """ @@ -1748,7 +1760,7 @@ def must_parse_attribute(self) -> Attribute: return value - def must_parse_op_result_list( + def _must_parse_op_result_list( self) -> tuple[list[Span], list[Attribute] | None]: if not self.tokenizer.starts_with("%"): return list(), list() @@ -1783,16 +1795,16 @@ def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: "[", "xDSL Attribute dictionary must be enclosed in square brackets") - attrs = self.must_parse_list_of(self.must_parse_attribute_entry, + attrs = self.must_parse_list_of(self._must_parse_attribute_entry, "Expected attribute entry") self.must_parse_characters( "]", "xDSL Attribute dictionary must be enclosed in square brackets") - return self.attr_dict_from_tuple_list(attrs) + return self._attr_dict_from_tuple_list(attrs) - def must_parse_operation_details( + def _must_parse_operation_details( self, ) -> tuple[list[Span], list[Span], dict[str, Attribute], list[Region], FunctionType | None]: @@ -1806,14 +1818,14 @@ def must_parse_operation_details( containing the types of the returned SSAValues """ - args = self.must_parse_op_args_list() - succ = self.must_parse_optional_successor_list() + args = self._must_parse_op_args_list() + succ = self._must_parse_optional_successor_list() attrs = self.must_parse_optional_attr_dict() regions = self.must_parse_region_list() return args, succ, attrs, regions, None - def must_parse_optional_successor_list(self) -> list[Span]: + def _must_parse_optional_successor_list(self) -> list[Span]: if not self.tokenizer.starts_with("("): return [] self.must_parse_characters( @@ -1825,17 +1837,17 @@ def must_parse_optional_successor_list(self) -> list[Span]: ")", "Successor list is enclosed in round brackets") return successors - def must_parse_dialect_type_or_attribute_inner(self, kind: str): + def _must_parse_dialect_type_or_attribute_inner(self, kind: str): if self.tokenizer.starts_with('"'): name = self.try_parse_string_literal() if name is None: self.raise_error( "Expected string literal for an attribute in generic format here!" ) - return self.must_parse_generic_attribute_args(name) - return super().must_parse_dialect_type_or_attribute_inner(kind) + return self._must_parse_generic_attribute_args(name) + return super()._must_parse_dialect_type_or_attribute_inner(kind) - def must_parse_generic_attribute_args(self, name: StringLiteral): + def _must_parse_generic_attribute_args(self, name: StringLiteral): attr = self.ctx.get_optional_attr(name.string_contents) if attr is None: self.raise_error("Unknown attribute name!", name) @@ -1849,7 +1861,7 @@ def must_parse_generic_attribute_args(self, name: StringLiteral): '>', 'Malformed attribute arguments, reached end of args list!') return attr(args) - def must_parse_op_args_list(self) -> list[Span]: + def _must_parse_op_args_list(self) -> list[Span]: self.must_parse_characters( "(", "Operation args list must be enclosed by brackets!") args = self.must_parse_list_of(self.try_parse_value_id_and_type, From 85810fc1424fb3774bb8a2c7bd0545840757e08d Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Mon, 23 Jan 2023 14:46:26 +0000 Subject: [PATCH 59/65] tests: clean up imports and unused vars in test_parse_error --- tests/test_parser_error.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_parser_error.py b/tests/test_parser_error.py index 71938f8319..4d2599f8e0 100644 --- a/tests/test_parser_error.py +++ b/tests/test_parser_error.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from typing import Annotated from pytest import raises @@ -26,7 +24,6 @@ def check_error(prog: str, line: int, column: int, message: str): parser.must_parse_operation() assert e.value.span - msgs = [err.error.msg for err in e.value.history.iterate()] for err in e.value.history.iterate(): if message in err.error.msg: From 54a7788abf56ef2af5e1405288fbe8be83e4478b Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Mon, 23 Jan 2023 14:56:40 +0000 Subject: [PATCH 60/65] tools: removed assertion text - off topic --- xdsl/xdsl_opt_main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xdsl/xdsl_opt_main.py b/xdsl/xdsl_opt_main.py index c9c0e45fc4..9ce2de6f6e 100644 --- a/xdsl/xdsl_opt_main.py +++ b/xdsl/xdsl_opt_main.py @@ -306,7 +306,7 @@ def parse_input(self) -> ModuleOp: def apply_passes(self, prog: ModuleOp): """Apply passes in order.""" - assert isinstance(prog, ModuleOp), "Expected top-level module!" + assert isinstance(prog, ModuleOp) if not self.args.disable_verify: prog.verify() for pass_name, p in self.pipeline: From 9db3804be5f89834a8bc9431b9be0aba41ba5734 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Mon, 23 Jan 2023 15:08:13 +0000 Subject: [PATCH 61/65] tests: fix docstring --- tests/test_parser_error.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_parser_error.py b/tests/test_parser_error.py index 4d2599f8e0..829240e756 100644 --- a/tests/test_parser_error.py +++ b/tests/test_parser_error.py @@ -79,7 +79,7 @@ def test_parser_missing_operation_name(): def test_parser_malformed_type(): - """Test a missing attribute error.""" + """Test a missing type error.""" ctx = MLContext() ctx.register_op(UnkownOp) From 868e509da01d9fecc92f23011b35d771c5615559 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Mon, 23 Jan 2023 16:36:09 +0000 Subject: [PATCH 62/65] xdsl: removed must_ prefix from parser methods --- tests/test_ir.py | 2 +- tests/test_mlir_printer.py | 2 +- tests/test_parser.py | 4 +- tests/test_parser_error.py | 2 +- tests/test_printer.py | 2 +- xdsl/dialects/builtin.py | 2 +- xdsl/dialects/llvm.py | 6 +- xdsl/parser.py | 336 ++++++++++++++++++------------------- 8 files changed, 174 insertions(+), 182 deletions(-) diff --git a/tests/test_ir.py b/tests/test_ir.py index d33cea4837..47fbe88eaa 100644 --- a/tests/test_ir.py +++ b/tests/test_ir.py @@ -232,7 +232,7 @@ def test_is_structurally_equivalent_incompatible_ir_nodes(): ctx.register_dialect(Cf) parser = XDSLParser(ctx, program_func) - program: ModuleOp = parser.must_parse_operation() + program: ModuleOp = parser.parse_operation() assert program.is_structurally_equivalent(program.regions[0]) == False assert program.is_structurally_equivalent( diff --git a/tests/test_mlir_printer.py b/tests/test_mlir_printer.py index e9e60ca454..0fd69d769a 100644 --- a/tests/test_mlir_printer.py +++ b/tests/test_mlir_printer.py @@ -90,7 +90,7 @@ def print_as_mlir_and_compare(test_prog: str, expected: str): ctx.register_attr(ParamAttrWithCustomFormat) parser = XDSLParser(ctx, test_prog) - module = parser.must_parse_operation() + module = parser.parse_operation() res = StringIO() printer = Printer(target=Printer.Target.MLIR, stream=res) diff --git a/tests/test_parser.py b/tests/test_parser.py index abad6663c2..7d3d186c5c 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -15,7 +15,7 @@ def test_int_list_parser(input: str, expected: list[int]): ctx = MLContext() parser = XDSLParser(ctx, input) - int_list = parser.must_parse_list_of(parser.try_parse_integer_literal, '') + int_list = parser.parse_list_of(parser.try_parse_integer_literal, '') assert [int(span.text) for span in int_list] == expected @@ -38,6 +38,6 @@ def test_dictionary_attr(data: dict[str, Attribute]): ctx = MLContext() ctx.register_dialect(Builtin) - attr = XDSLParser(ctx, text).must_parse_attribute() + attr = XDSLParser(ctx, text).parse_attribute() assert attr.data == data diff --git a/tests/test_parser_error.py b/tests/test_parser_error.py index 829240e756..de67c51b8f 100644 --- a/tests/test_parser_error.py +++ b/tests/test_parser_error.py @@ -21,7 +21,7 @@ def check_error(prog: str, line: int, column: int, message: str): parser = XDSLParser(ctx, prog) with raises(ParseError) as e: - parser.must_parse_operation() + parser.parse_operation() assert e.value.span diff --git a/tests/test_printer.py b/tests/test_printer.py index b051a73210..ea0cfbe0d1 100644 --- a/tests/test_printer.py +++ b/tests/test_printer.py @@ -386,7 +386,7 @@ def parse(cls, result_types: List[Attribute], lhs = parser.expect(parser.try_parse_value_id, 'Expected SSA Value name here!') - parser.must_parse_characters( + parser.parse_characters( "+", "Malformed operation format, expected `+`!") rhs = parser.expect(parser.try_parse_value_id, 'Expected SSA Value name here!') diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index 0b4589540c..a48dd903f0 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -352,7 +352,7 @@ class DictionaryAttr(GenericData[dict[str, Attribute]]): def parse_parameter(parser: BaseParser) -> dict[str, Attribute]: # force MLIR style parsing of attribute from xdsl.parser import MLIRParser - return MLIRParser.must_parse_optional_attr_dict(parser) + return MLIRParser.parse_optional_attr_dict(parser) @staticmethod def print_parameter(data: dict[str, Attribute], printer: Printer) -> None: diff --git a/xdsl/dialects/llvm.py b/xdsl/dialects/llvm.py index 76c0bf4372..749787ca2c 100644 --- a/xdsl/dialects/llvm.py +++ b/xdsl/dialects/llvm.py @@ -39,11 +39,11 @@ def print_parameters(self, printer: Printer) -> None: @staticmethod def parse_parameters(parser: BaseParser) -> list[Attribute]: - parser.must_parse_characters("<(", "LLVM Struct must start with `<(`") - params = parser.must_parse_list_of( + parser.parse_characters("<(", "LLVM Struct must start with `<(`") + params = parser.parse_list_of( parser.try_parse_type, "Malformed LLVM struct, expected attribute definition here!") - parser.must_parse_characters( + parser.parse_characters( ")>", "Unexpected input, expected end of LLVM struct!") return [StringAttr.from_str(""), ArrayAttr.from_list(params)] diff --git a/xdsl/parser.py b/xdsl/parser.py index b86ffc4ff2..b30b33dda9 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -564,17 +564,13 @@ class BaseParser(ABC): methods marked try_... will attempt to parse, and return None if they failed. If they return None they must make sure to restore all state. - methods marked must_... will do greedy parsing, meaning they consume as much as they can. They will + methods marked parse_... will do "greedy" parsing, meaning they consume as much as they can. They will also throw an error if the think they should still be parsing. e.g. when parsing a list of numbers separated by '::', the following input will trigger an exception: 1::2:: Due to the '::' present after the last element. This is useful for parsing lists, as a trailing separator is usually considered a syntax error there. - You can turn a try_ into a must_ by using expect(try_parse_..., error_msg) - - You can turn a must_ into a try_ by wrapping it in tokenizer.backtracking() - must_ type parsers are preferred because they are explicit about their failure modes. """ @@ -638,8 +634,8 @@ def _get_block_from_name(self, block_name: Span) -> Block: self.blocks[name] = Block() return self.blocks[name] - def must_parse_block(self) -> Block: - block_id, args = self._must_parse_optional_block_label() + def parse_block(self) -> Block: + block_id, args = self._parse_optional_block_label() if block_id is None: block = Block(self.tokenizer.last_token) @@ -669,7 +665,7 @@ def must_parse_block(self) -> Block: return block - def _must_parse_optional_block_label( + def _parse_optional_block_label( self) -> tuple[Span | None, list[tuple[Span, Attribute]]]: """ A block label consists of block-id ( `(` block-arg `,` ... `)` )? @@ -679,25 +675,25 @@ def _must_parse_optional_block_label( if block_id is not None: if self.tokenizer.starts_with('('): - arg_list = self._must_parse_block_arg_list() + arg_list = self._parse_block_arg_list() - self.must_parse_characters(':', 'Block label must end in a `:`!') + self.parse_characters(':', 'Block label must end in a `:`!') return block_id, arg_list - def _must_parse_block_arg_list(self) -> list[tuple[Span, Attribute]]: - self.must_parse_characters('(', 'Block arguments must start with `(`') + def _parse_block_arg_list(self) -> list[tuple[Span, Attribute]]: + self.parse_characters('(', 'Block arguments must start with `(`') - args = self.must_parse_list_of(self.try_parse_value_id_and_type, + args = self.parse_list_of(self.try_parse_value_id_and_type, "Expected value-id and type here!") - self.must_parse_characters(')', 'Expected closing of block arguments!') + self.parse_characters(')', 'Expected closing of block arguments!') return args def try_parse_single_reference(self) -> Span | None: with self.tokenizer.backtracking('part of a reference'): - self.must_parse_characters('@', "references must start with `@`") + self.parse_characters('@', "references must start with `@`") if (reference := self.try_parse_string_literal()) is not None: return reference if (reference := self.try_parse_suffix_id()) is not None: @@ -705,14 +701,14 @@ def try_parse_single_reference(self) -> Span | None: self.raise_error( "References must conform to `@` (string-literal | suffix-id)") - def must_parse_reference(self) -> list[Span]: - return self.must_parse_list_of( + def parse_reference(self) -> list[Span]: + return self.parse_list_of( self.try_parse_single_reference, 'Expected reference here in the format of `@` (suffix-id | string-literal)', ParserCommons.double_colon, allow_empty=False) - def must_parse_list_of(self, + def parse_list_of(self, try_parse: Callable[[], T_ | None], error_msg: str, separator_pattern: re.Pattern = ParserCommons.comma, @@ -793,7 +789,7 @@ def try_parse_value_id_and_type(self) -> tuple[Span, Attribute] | None: if value_id is None: self.raise_error("Invalid value-id format!") - self.must_parse_characters( + self.parse_characters( ':', 'Expected expression (value-id `:` type)') type = self.try_parse_type() @@ -822,9 +818,9 @@ def try_parse_dialect_type_or_attribute(self) -> Attribute | None: with self.tokenizer.backtracking("dialect attribute or type"): self.tokenizer.consume_peeked(kind) if kind.text == '!': - return self._must_parse_dialect_type_or_attribute_inner('type') + return self._parse_dialect_type_or_attribute_inner('type') else: - return self._must_parse_dialect_type_or_attribute_inner( + return self._parse_dialect_type_or_attribute_inner( 'attribute') def try_parse_dialect_type(self): @@ -834,9 +830,9 @@ def try_parse_dialect_type(self): if not self.tokenizer.starts_with('!'): return None with self.tokenizer.backtracking("dialect type"): - self.must_parse_characters('!', + self.parse_characters('!', "Dialect type must start with a `!`") - return self._must_parse_dialect_type_or_attribute_inner('type') + return self._parse_dialect_type_or_attribute_inner('type') def try_parse_dialect_attr(self): """ @@ -845,12 +841,12 @@ def try_parse_dialect_attr(self): if not self.tokenizer.starts_with('#'): return None with self.tokenizer.backtracking("dialect attribute"): - self.must_parse_characters( + self.parse_characters( '#', "Dialect attribute must start with a `#`") - return self._must_parse_dialect_type_or_attribute_inner( + return self._parse_dialect_type_or_attribute_inner( 'attribute') - def _must_parse_dialect_type_or_attribute_inner(self, kind: str): + def _parse_dialect_type_or_attribute_inner(self, kind: str): type_name = self.tokenizer.next_token_of_pattern(ParserCommons.bare_id) if type_name is None: @@ -866,10 +862,10 @@ def _must_parse_dialect_type_or_attribute_inner(self, kind: str): if issubclass(type_def, ParametrizedAttribute): param_list = type_def.parse_parameters(self) elif issubclass(type_def, Data): - self.must_parse_characters("<", + self.parse_characters("<", "This attribute must be parametrized!") param_list = type_def.parse_parameter(self) - self.must_parse_characters( + self.parse_characters( ">", "Invalid attribute parametrization, expected `>`!") else: assert False, "Mathieu said this cannot be." @@ -882,7 +878,7 @@ def try_parse_builtin_type(self) -> Attribute | None: """ raise NotImplemented("Subclasses must implement this method!") - def _must_parse_builtin_parametrized_type( + def _parse_builtin_parametrized_type( self, name: Span) -> ParametrizedAttribute: """ This function is called after we parse the name of a paremetrized type such as vector. @@ -893,25 +889,25 @@ def unimplemented() -> ParametrizedAttribute: "Builtin {} not supported yet!".format(name.text)) builtin_parsers: dict[str, Callable[[], ParametrizedAttribute]] = { - "vector": self.must_parse_vector_attrs, - "memref": self.must_parse_memref_attrs, - "tensor": self.must_parse_tensor_attrs, - "complex": self.must_parse_complex_attrs, + "vector": self.parse_vector_attrs, + "memref": self.parse_memref_attrs, + "tensor": self.parse_tensor_attrs, + "complex": self.parse_complex_attrs, "tuple": unimplemented, } - self.must_parse_characters("<", "Expected parameter list here!") + self.parse_characters("<", "Expected parameter list here!") # Get the parser for the type, falling back to the unimplemented warning res = builtin_parsers.get(name.text, unimplemented)() - self.must_parse_characters(">", "Expected end of parameter list here!") + self.parse_characters(">", "Expected end of parameter list here!") return res - def must_parse_complex_attrs(self): + def parse_complex_attrs(self): self.raise_error("ComplexType is unimplemented!") - def must_parse_memref_attrs(self) -> MemRefType | UnrankedMemrefType: - dims = self._must_parse_tensor_or_memref_dims() + def parse_memref_attrs(self) -> MemRefType | UnrankedMemrefType: + dims = self._parse_tensor_or_memref_dims() type = self.try_parse_type() if dims is None: return UnrankedMemrefType.from_type(type) @@ -926,10 +922,10 @@ def try_parse_numerical_dims(self, # Look out for the closing bracket for scalable vector dims if accept_closing_bracket and self.tokenizer.starts_with("]"): break - self.must_parse_characters( + self.parse_characters( "x", "Unexpected end of dimension parameters!") - def must_parse_vector_attrs(self) -> AnyVectorType: + def parse_vector_attrs(self) -> AnyVectorType: # Also break on 'x' characters as they are separators in dimension parameters with self.tokenizer.configured(break_on=self.tokenizer.break_on + ("x", )): @@ -939,9 +935,9 @@ def must_parse_vector_attrs(self) -> AnyVectorType: if self.tokenizer.next_token_of_pattern("[") is not None: # We now need to parse the scalable dimensions scaling_shape = list(self.try_parse_numerical_dims()) - self.must_parse_characters( + self.parse_characters( "]", "Expected end of scalable vector dimensions here!") - self.must_parse_characters( + self.parse_characters( "x", "Expected end of scalable vector dimensions here!") if scaling_shape is not None: @@ -956,21 +952,21 @@ def must_parse_vector_attrs(self) -> AnyVectorType: return VectorType.from_element_type_and_shape(type, shape) - def _must_parse_tensor_or_memref_dims(self) -> list[int] | None: + def _parse_tensor_or_memref_dims(self) -> list[int] | None: with self.tokenizer.configured(break_on=self.tokenizer.break_on + ('x', )): # Check for unranked-ness if self.tokenizer.next_token_of_pattern('*') is not None: # Consume `x` - self.must_parse_characters( + self.parse_characters( 'x', 'Unranked tensors must follow format (`<*x` type `>`)') else: # Parse rank: return list(self.try_parse_numerical_dims(lower_bound=0)) - def must_parse_tensor_attrs(self) -> AnyTensorType: - shape = self._must_parse_tensor_or_memref_dims() + def parse_tensor_attrs(self) -> AnyTensorType: + shape = self._parse_tensor_or_memref_dims() type = self.try_parse_type() if type is None: @@ -1009,14 +1005,14 @@ def _try_parse_shape_element(self, lower_bound: int = 1) -> int | None: return -1 return None - def _must_parse_type_params(self) -> list[Attribute]: + def _parse_type_params(self) -> list[Attribute]: # Consume opening bracket - self.must_parse_characters('<', 'Type must be parameterized!') + self.parse_characters('<', 'Type must be parameterized!') - params = self.must_parse_list_of(self.try_parse_type, + params = self.parse_list_of(self.try_parse_type, 'Expected a type here!') - self.must_parse_characters( + self.parse_characters( '>', 'Expected end of type parameterization here!') return params @@ -1042,24 +1038,24 @@ def raise_error(self, msg: str, at_position: Span | None = None): raise ParseError(at_position, msg, self.tokenizer.history) - def must_parse_characters(self, text: str, msg: str) -> Span: + def parse_characters(self, text: str, msg: str) -> Span: if (match := self.tokenizer.next_token_of_pattern(text)) is None: self.raise_error(msg) return match @abstractmethod - def _must_parse_op_result_list( + def _parse_op_result_list( self) -> tuple[list[Span], list[Attribute] | None]: raise NotImplemented() def try_parse_operation(self) -> Operation | None: with self.tokenizer.backtracking("operation"): - return self.must_parse_operation() + return self.parse_operation() - def must_parse_operation(self) -> Operation: - result_list, ret_types = self._must_parse_op_result_list() + def parse_operation(self) -> Operation: + result_list, ret_types = self._parse_op_result_list() if len(result_list) > 0: - self.must_parse_characters( + self.parse_characters( '=', 'Operation definitions expect an `=` after op-result-list!') @@ -1076,7 +1072,7 @@ def must_parse_operation(self) -> Operation: "Expected an operation name here, either a bare-id, or a string literal!" ) - args, successors, attrs, regions, func_type = self._must_parse_operation_details( + args, successors, attrs, regions, func_type = self._parse_operation_details( ) if ret_types is None: @@ -1121,7 +1117,7 @@ def _get_op_by_name(self, span: Span) -> type[Operation]: self.raise_error(f'Unknown operation {op_name}!', span) - def must_parse_region(self) -> Region: + def parse_region(self) -> Region: oldSSAVals = self.ssaValues.copy() oldBBNames = self.blocks oldForwardRefs = self.forward_block_references @@ -1131,18 +1127,18 @@ def must_parse_region(self) -> Region: region = Region() try: - self.must_parse_characters("{", "Regions begin with `{`") + self.parse_characters("{", "Regions begin with `{`") if self.tokenizer.starts_with("}"): region.add_block(Block()) else: # Parse first block - block = self.must_parse_block() + block = self.parse_block() region.add_block(block) while self.tokenizer.starts_with("^"): - region.add_block(self.must_parse_block()) + region.add_block(self.parse_block()) - end = self.must_parse_characters( + end = self.parse_characters( "}", "Reached end of region, expected `}`!") if len(self.forward_block_references) > 0: @@ -1167,7 +1163,7 @@ def _try_parse_op_name(self) -> Span | None: return str_lit return self.try_parse_bare_id() - def _must_parse_attribute_entry(self) -> tuple[Span, Attribute]: + def _parse_attribute_entry(self) -> tuple[Span, Attribute]: """ Parse entry in attribute dict. Of format: @@ -1185,13 +1181,13 @@ def _must_parse_attribute_entry(self) -> tuple[Span, Attribute]: if not self.tokenizer.starts_with('='): return name, UnitAttr() - self.must_parse_characters( + self.parse_characters( "=", "Attribute entries must be of format name `=` attribute!") - return name, self.must_parse_attribute() + return name, self.parse_attribute() @abstractmethod - def must_parse_attribute(self) -> Attribute: + def parse_attribute(self) -> Attribute: """ Parse attribute (either builtin or dialect) @@ -1201,13 +1197,13 @@ def must_parse_attribute(self) -> Attribute: def try_parse_attribute(self) -> Attribute | None: with self.tokenizer.backtracking("attribute"): - return self.must_parse_attribute() + return self.parse_attribute() - def _must_parse_attribute_type(self) -> Attribute: + def _parse_attribute_type(self) -> Attribute: """ Parses `:` type and returns the type """ - self.must_parse_characters( + self.parse_characters( ":", "Expected attribute type definition here ( `:` type )") return self.expect( self.try_parse_type, @@ -1244,8 +1240,8 @@ def try_parse_builtin_named_attr(self) -> Attribute | None: name.text)): self.tokenizer.consume_peeked(name) parsers = { - 'dense': self._must_parse_builtin_dense_attr, - 'opaque': self._must_parse_builtin_opaque_attr, + 'dense': self._parse_builtin_dense_attr, + 'opaque': self._parse_builtin_opaque_attr, } def not_implemented(): @@ -1253,31 +1249,31 @@ def not_implemented(): return parsers.get(name.text, not_implemented)() - def _must_parse_builtin_dense_attr(self) -> Attribute | None: + def _parse_builtin_dense_attr(self) -> Attribute | None: err_msg = "Malformed dense attribute, format must be (`dense<` array-attr `>:` type)" - self.must_parse_characters("<", err_msg) - info = list(self._must_parse_builtin_dense_attr_args()) - self.must_parse_characters(">", err_msg) - self.must_parse_characters(":", err_msg) + self.parse_characters("<", err_msg) + info = list(self._parse_builtin_dense_attr_args()) + self.parse_characters(">", err_msg) + self.parse_characters(":", err_msg) type = self.expect(self.try_parse_type, "Dense attribute must be typed!") return DenseIntOrFPElementsAttr.from_list(type, info) - def _must_parse_builtin_opaque_attr(self): - self.must_parse_characters("<", + def _parse_builtin_opaque_attr(self): + self.parse_characters("<", "Opaque attribute must be parametrized") - str_lit_list = self.must_parse_list_of(self.try_parse_string_literal, + str_lit_list = self.parse_list_of(self.try_parse_string_literal, 'Expected opaque attr here!') if len(str_lit_list) != 2: self.raise_error('Opaque expects 2 string literal parameters!') - self.must_parse_characters( + self.parse_characters( ">", "Unexpected parameters for opaque attr, expected `>`!") type = NoneAttr() if self.tokenizer.starts_with(':'): - self.must_parse_characters(":", "opaque attribute must be typed!") + self.parse_characters(":", "opaque attribute must be typed!") type = self.expect(self.try_parse_type, "opaque attribute must be typed!") @@ -1285,7 +1281,7 @@ def _must_parse_builtin_opaque_attr(self): for span in str_lit_list), type=type) - def _must_parse_builtin_dense_attr_args(self) -> Iterable[int | float]: + def _parse_builtin_dense_attr_args(self) -> Iterable[int | float]: """ dense attribute params must be: @@ -1304,18 +1300,18 @@ def try_parse_int_or_float(): yield try_parse_int_or_float() return - self.must_parse_characters('[', '') + self.parse_characters('[', '') while not self.tokenizer.starts_with(']'): - yield from self._must_parse_builtin_dense_attr_args() + yield from self._parse_builtin_dense_attr_args() if self.tokenizer.next_token_of_pattern(',') is None: break - self.must_parse_characters(']', '') + self.parse_characters(']', '') def try_parse_ref_attr(self) -> FlatSymbolRefAttr | None: if not self.tokenizer.starts_with("@"): return None - ref = self.must_parse_reference() + ref = self.parse_reference() if len(ref) > 1: self.raise_error("Nested refs are not supported yet!", ref[1]) @@ -1333,7 +1329,7 @@ def try_parse_builtin_int_attr(self) -> IntegerAttr | None: 'Integer attribute must start with an integer literal!') if self.tokenizer.next_token(peek=True).text != ':': return IntegerAttr.from_params(int(value.text), i64) - type = self._must_parse_attribute_type() + type = self._parse_attribute_type() return IntegerAttr.from_params(int(value.text), type) def try_parse_builtin_float_attr(self) -> FloatAttr | None: @@ -1346,7 +1342,7 @@ def try_parse_builtin_float_attr(self) -> FloatAttr | None: if not self.tokenizer.starts_with(":"): return FloatAttr.from_value(float(value.text)) - type = self._must_parse_attribute_type() + type = self._parse_attribute_type() return FloatAttr.from_value(float(value.text), type) def try_parse_builtin_boolean_attr(self) -> IntegerAttr | None: @@ -1372,16 +1368,16 @@ def try_parse_builtin_arr_attr(self) -> ArrayAttr | None: if not self.tokenizer.starts_with("["): return None with self.tokenizer.backtracking("array literal"): - self.must_parse_characters("[", + self.parse_characters("[", "Array literals must start with `[`") - attrs = self.must_parse_list_of(self.try_parse_attribute, + attrs = self.parse_list_of(self.try_parse_attribute, "Expected array entry!") - self.must_parse_characters( + self.parse_characters( "]", "Malformed array contents (expected end of array here!") return ArrayAttr.from_list(attrs) @abstractmethod - def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: + def parse_optional_attr_dict(self) -> dict[str, Attribute]: raise NotImplementedError() def _attr_dict_from_tuple_list( @@ -1400,7 +1396,7 @@ def span_to_str(span: Span) -> str: return dict((span_to_str(span), attr) for span, attr in tuple_list) - def must_parse_function_type(self) -> FunctionType: + def parse_function_type(self) -> FunctionType: """ Parses function-type: @@ -1415,24 +1411,24 @@ def must_parse_function_type(self) -> FunctionType: Uses type-or-type-list-parens internally """ - self.must_parse_characters( + self.parse_characters( "(", "First group of function args must start with a `(`") - args: list[Attribute] = self.must_parse_list_of( + args: list[Attribute] = self.parse_list_of( self.try_parse_type, "Expected type here!") - self.must_parse_characters( + self.parse_characters( ")", "Malformed function type, expected closing brackets of argument types!" ) - self.must_parse_characters("->", + self.parse_characters("->", "Malformed function type, expected `->`!") return FunctionType.from_lists( - args, self._must_parse_type_or_type_list_parens()) + args, self._parse_type_or_type_list_parens()) - def _must_parse_type_or_type_list_parens(self) -> list[Attribute]: + def _parse_type_or_type_list_parens(self) -> list[Attribute]: """ Parses type-or-type-list-parens, which is used in function-type. @@ -1441,9 +1437,9 @@ def _must_parse_type_or_type_list_parens(self) -> list[Attribute]: type-list-no-parens ::= type (`,` type)* """ if self.tokenizer.next_token_of_pattern("(") is not None: - args: list[Attribute] = self.must_parse_list_of( + args: list[Attribute] = self.parse_list_of( self.try_parse_type, "Expected type here!") - self.must_parse_characters( + self.parse_characters( ")", "Unclosed function type argument list!") else: args = [self.try_parse_type()] @@ -1457,18 +1453,18 @@ def try_parse_function_type(self) -> FunctionType | None: if not self.tokenizer.starts_with("("): return None with self.tokenizer.backtracking("function type"): - return self.must_parse_function_type() + return self.parse_function_type() - def must_parse_region_list(self) -> list[Region]: + def parse_region_list(self) -> list[Region]: """ Parses a sequence of regions for as long as there is a `{` in the input. """ regions = [] while not self.tokenizer.is_eof() and self.tokenizer.starts_with("{"): - regions.append(self.must_parse_region()) + regions.append(self.parse_region()) return regions - def _must_parse_builtin_type_with_name(self, name: Span): + def _parse_builtin_type_with_name(self, name: Span): """ Parses one of the builtin types like i42, vector, etc... """ @@ -1495,10 +1491,10 @@ def _must_parse_builtin_type_with_name(self, name: Span): "Unsupported floating point width: {}".format(width)) return type() - return self._must_parse_builtin_parametrized_type(name) + return self._parse_builtin_parametrized_type(name) @abstractmethod - def _must_parse_operation_details( + def _parse_operation_details( self, ) -> tuple[list[Span], list[Span], dict[str, Attribute], list[Region], FunctionType | None]: @@ -1508,18 +1504,14 @@ def _must_parse_operation_details( - a list of successor names - the attributes attached to the OP - the regions of the op - - An optional function type. If not supplied, must_parse_op_result_list must return a second value + - An optional function type. If not supplied, parse_op_result_list must return a second value containing the types of the returned SSAValues - Your implementation should make use of the following functions: - - must_parse_op_args_list - - must_parse_optional_attr_dict - - must_parse_ """ raise NotImplementedError() @abstractmethod - def _must_parse_op_args_list(self) -> list[Span]: + def _parse_op_args_list(self) -> list[Span]: raise NotImplementedError() # HERE STARTS A SOMEWHAT CURSED COMPATIBILITY LAYER: @@ -1541,7 +1533,7 @@ def parse_op_with_default_format( """ # TODO: remove this function and restructure custom op / irdl parsing assert isinstance(self, XDSLParser) - args, successors, attributes, regions, _ = self._must_parse_operation_details( + args, successors, attributes, regions, _ = self._parse_operation_details( ) for x in args: @@ -1567,7 +1559,7 @@ def parse_paramattr_parameters( if expect_brackets and opening_brackets is None: self.raise_error("Expected start attribute parameters here (`<`)!") - res = self.must_parse_list_of(self.try_parse_attribute, + res = self.parse_list_of(self.try_parse_attribute, 'Expected another attribute here!') if opening_brackets is not None and self.tokenizer.next_token_of_pattern( @@ -1579,17 +1571,17 @@ def parse_paramattr_parameters( return res def parse_char(self, text: str): - self.must_parse_characters(text, "Expected '{}' here!".format(text)) + self.parse_characters(text, "Expected '{}' here!".format(text)) def parse_str_literal(self) -> str: return self.expect(self.try_parse_string_literal, 'Malformed string literal!').string_contents def parse_attribute(self) -> Attribute: - return self.must_parse_attribute() + return self.parse_attribute() def parse_op(self) -> Operation: - return self.must_parse_operation() + return self.parse_operation() def parse_int_literal(self) -> int: return int( @@ -1620,9 +1612,9 @@ def try_parse_builtin_type(self) -> Attribute | None: if name is None: raise self.raise_error("Expected builtin name!") - return self._must_parse_builtin_type_with_name(name) + return self._parse_builtin_type_with_name(name) - def must_parse_attribute(self) -> Attribute: + def parse_attribute(self) -> Attribute: """ Parse attribute (either builtin or dialect) """ @@ -1648,76 +1640,76 @@ def must_parse_attribute(self) -> Attribute: return builtin_val - def _must_parse_op_result_list( + def _parse_op_result_list( self) -> tuple[list[Span], list[Attribute] | None]: return ( - self.must_parse_list_of(self.try_parse_value_id, + self.parse_list_of(self.try_parse_value_id, "Expected op-result here!", allow_empty=True), None, ) - def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: + def parse_optional_attr_dict(self) -> dict[str, Attribute]: if not self.tokenizer.starts_with("{"): return dict() - self.must_parse_characters( + self.parse_characters( "{", "MLIR Attribute dictionary must be enclosed in curly brackets") attrs = [] if not self.tokenizer.starts_with('}'): - attrs = self.must_parse_list_of(self._must_parse_attribute_entry, + attrs = self.parse_list_of(self._parse_attribute_entry, "Expected attribute entry") - self.must_parse_characters( + self.parse_characters( "}", "MLIR Attribute dictionary must be enclosed in curly brackets") return self._attr_dict_from_tuple_list(attrs) - def _must_parse_operation_details( + def _parse_operation_details( self, ) -> tuple[list[Span], list[Span], dict[str, Attribute], list[Region], FunctionType | None]: - args = self._must_parse_op_args_list() - succ = self._must_parse_optional_successor_list() + args = self._parse_op_args_list() + succ = self._parse_optional_successor_list() regions = [] if self.tokenizer.starts_with("("): - self.must_parse_characters("(", + self.parse_characters("(", "Expected brackets enclosing regions!") - regions = self.must_parse_region_list() - self.must_parse_characters(")", + regions = self.parse_region_list() + self.parse_characters(")", "Expected brackets enclosing regions!") - attrs = self.must_parse_optional_attr_dict() + attrs = self.parse_optional_attr_dict() - self.must_parse_characters( + self.parse_characters( ":", "MLIR Operation defintions must end in a function type signature!") - func_type = self.must_parse_function_type() + func_type = self.parse_function_type() return args, succ, attrs, regions, func_type - def _must_parse_optional_successor_list(self) -> list[Span]: + def _parse_optional_successor_list(self) -> list[Span]: if not self.tokenizer.starts_with("["): return [] - self.must_parse_characters( + self.parse_characters( "[", "Successor list is enclosed in square brackets") - successors = self.must_parse_list_of(self.try_parse_block_id, + successors = self.parse_list_of(self.try_parse_block_id, "Expected a block-id", allow_empty=False) - self.must_parse_characters( + self.parse_characters( "]", "Successor list is enclosed in square brackets") return successors - def _must_parse_op_args_list(self) -> list[Span]: - self.must_parse_characters( + def _parse_op_args_list(self) -> list[Span]: + self.parse_characters( "(", "Operation args list must be enclosed by brackets!") - args = self.must_parse_list_of(self.try_parse_value_id, + args = self.parse_list_of(self.try_parse_value_id, "Expected another bare-id here") - self.must_parse_characters( + self.parse_characters( ")", "Operation args list must be closed by a closing bracket") # TODO: check if type is correct here! return args @@ -1737,9 +1729,9 @@ def try_parse_builtin_type(self) -> Attribute | None: # xDSL builtin types have a '!' prefix, we strip that out here name = Span(start=name.start + 1, end=name.end, input=name.input) - return self._must_parse_builtin_type_with_name(name) + return self._parse_builtin_type_with_name(name) - def must_parse_attribute(self) -> Attribute: + def parse_attribute(self) -> Attribute: """ Parse attribute (either builtin or dialect) @@ -1760,11 +1752,11 @@ def must_parse_attribute(self) -> Attribute: return value - def _must_parse_op_result_list( + def _parse_op_result_list( self) -> tuple[list[Span], list[Attribute] | None]: if not self.tokenizer.starts_with("%"): return list(), list() - results = self.must_parse_list_of( + results = self.parse_list_of( self.try_parse_value_id_and_type, "Expected (value-id `:` type) here!", allow_empty=False, @@ -1787,24 +1779,24 @@ def try_parse_builtin_attr(self) -> Attribute: return super().try_parse_builtin_attr() - def must_parse_optional_attr_dict(self) -> dict[str, Attribute]: + def parse_optional_attr_dict(self) -> dict[str, Attribute]: if not self.tokenizer.starts_with("["): return dict() - self.must_parse_characters( + self.parse_characters( "[", "xDSL Attribute dictionary must be enclosed in square brackets") - attrs = self.must_parse_list_of(self._must_parse_attribute_entry, + attrs = self.parse_list_of(self._parse_attribute_entry, "Expected attribute entry") - self.must_parse_characters( + self.parse_characters( "]", "xDSL Attribute dictionary must be enclosed in square brackets") return self._attr_dict_from_tuple_list(attrs) - def _must_parse_operation_details( + def _parse_operation_details( self, ) -> tuple[list[Span], list[Span], dict[str, Attribute], list[Region], FunctionType | None]: @@ -1814,59 +1806,59 @@ def _must_parse_operation_details( - a list of successor names - the attributes attached to the OP - the regions of the op - - An optional function type. If not supplied, must_parse_op_result_list must return a second value + - An optional function type. If not supplied, parse_op_result_list must return a second value containing the types of the returned SSAValues """ - args = self._must_parse_op_args_list() - succ = self._must_parse_optional_successor_list() - attrs = self.must_parse_optional_attr_dict() - regions = self.must_parse_region_list() + args = self._parse_op_args_list() + succ = self._parse_optional_successor_list() + attrs = self.parse_optional_attr_dict() + regions = self.parse_region_list() return args, succ, attrs, regions, None - def _must_parse_optional_successor_list(self) -> list[Span]: + def _parse_optional_successor_list(self) -> list[Span]: if not self.tokenizer.starts_with("("): return [] - self.must_parse_characters( + self.parse_characters( "(", "Successor list is enclosed in round brackets") - successors = self.must_parse_list_of(self.try_parse_block_id, + successors = self.parse_list_of(self.try_parse_block_id, "Expected a block-id", allow_empty=False) - self.must_parse_characters( + self.parse_characters( ")", "Successor list is enclosed in round brackets") return successors - def _must_parse_dialect_type_or_attribute_inner(self, kind: str): + def _parse_dialect_type_or_attribute_inner(self, kind: str): if self.tokenizer.starts_with('"'): name = self.try_parse_string_literal() if name is None: self.raise_error( "Expected string literal for an attribute in generic format here!" ) - return self._must_parse_generic_attribute_args(name) - return super()._must_parse_dialect_type_or_attribute_inner(kind) + return self._parse_generic_attribute_args(name) + return super()._parse_dialect_type_or_attribute_inner(kind) - def _must_parse_generic_attribute_args(self, name: StringLiteral): + def _parse_generic_attribute_args(self, name: StringLiteral): attr = self.ctx.get_optional_attr(name.string_contents) if attr is None: self.raise_error("Unknown attribute name!", name) if not issubclass(attr, ParametrizedAttribute): self.raise_error("Expected ParametrizedAttribute name here!", name) - self.must_parse_characters( + self.parse_characters( '<', 'Expected generic attribute arguments here!') - args = self.must_parse_list_of(self.try_parse_attribute, + args = self.parse_list_of(self.try_parse_attribute, 'Unexpected end of attribute list!') - self.must_parse_characters( + self.parse_characters( '>', 'Malformed attribute arguments, reached end of args list!') return attr(args) - def _must_parse_op_args_list(self) -> list[Span]: - self.must_parse_characters( + def _parse_op_args_list(self) -> list[Span]: + self.parse_characters( "(", "Operation args list must be enclosed by brackets!") - args = self.must_parse_list_of(self.try_parse_value_id_and_type, + args = self.parse_list_of(self.try_parse_value_id_and_type, "Expected another bare-id here") - self.must_parse_characters( + self.parse_characters( ")", "Operation args list must be closed by a closing bracket") # TODO: check if type is correct here! return [name for name, _ in args] From de5fa1256f237a1f3e3706d16fcefbd4c92c341f Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Mon, 23 Jan 2023 16:39:18 +0000 Subject: [PATCH 63/65] formatting: run yapf --- tests/test_printer.py | 4 +- xdsl/parser.py | 122 +++++++++++++++++++----------------------- 2 files changed, 58 insertions(+), 68 deletions(-) diff --git a/tests/test_printer.py b/tests/test_printer.py index ea0cfbe0d1..3e2b9b5d6a 100644 --- a/tests/test_printer.py +++ b/tests/test_printer.py @@ -386,8 +386,8 @@ def parse(cls, result_types: List[Attribute], lhs = parser.expect(parser.try_parse_value_id, 'Expected SSA Value name here!') - parser.parse_characters( - "+", "Malformed operation format, expected `+`!") + parser.parse_characters("+", + "Malformed operation format, expected `+`!") rhs = parser.expect(parser.try_parse_value_id, 'Expected SSA Value name here!') diff --git a/xdsl/parser.py b/xdsl/parser.py index b30b33dda9..53dc790b41 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -685,7 +685,7 @@ def _parse_block_arg_list(self) -> list[tuple[Span, Attribute]]: self.parse_characters('(', 'Block arguments must start with `(`') args = self.parse_list_of(self.try_parse_value_id_and_type, - "Expected value-id and type here!") + "Expected value-id and type here!") self.parse_characters(')', 'Expected closing of block arguments!') @@ -709,10 +709,10 @@ def parse_reference(self) -> list[Span]: allow_empty=False) def parse_list_of(self, - try_parse: Callable[[], T_ | None], - error_msg: str, - separator_pattern: re.Pattern = ParserCommons.comma, - allow_empty: bool = True) -> list[T_]: + try_parse: Callable[[], T_ | None], + error_msg: str, + separator_pattern: re.Pattern = ParserCommons.comma, + allow_empty: bool = True) -> list[T_]: """ This is a greedy list-parser. It accepts input only in these cases: @@ -789,8 +789,8 @@ def try_parse_value_id_and_type(self) -> tuple[Span, Attribute] | None: if value_id is None: self.raise_error("Invalid value-id format!") - self.parse_characters( - ':', 'Expected expression (value-id `:` type)') + self.parse_characters(':', + 'Expected expression (value-id `:` type)') type = self.try_parse_type() @@ -820,8 +820,7 @@ def try_parse_dialect_type_or_attribute(self) -> Attribute | None: if kind.text == '!': return self._parse_dialect_type_or_attribute_inner('type') else: - return self._parse_dialect_type_or_attribute_inner( - 'attribute') + return self._parse_dialect_type_or_attribute_inner('attribute') def try_parse_dialect_type(self): """ @@ -830,8 +829,7 @@ def try_parse_dialect_type(self): if not self.tokenizer.starts_with('!'): return None with self.tokenizer.backtracking("dialect type"): - self.parse_characters('!', - "Dialect type must start with a `!`") + self.parse_characters('!', "Dialect type must start with a `!`") return self._parse_dialect_type_or_attribute_inner('type') def try_parse_dialect_attr(self): @@ -841,10 +839,9 @@ def try_parse_dialect_attr(self): if not self.tokenizer.starts_with('#'): return None with self.tokenizer.backtracking("dialect attribute"): - self.parse_characters( - '#', "Dialect attribute must start with a `#`") - return self._parse_dialect_type_or_attribute_inner( - 'attribute') + self.parse_characters('#', + "Dialect attribute must start with a `#`") + return self._parse_dialect_type_or_attribute_inner('attribute') def _parse_dialect_type_or_attribute_inner(self, kind: str): type_name = self.tokenizer.next_token_of_pattern(ParserCommons.bare_id) @@ -862,8 +859,7 @@ def _parse_dialect_type_or_attribute_inner(self, kind: str): if issubclass(type_def, ParametrizedAttribute): param_list = type_def.parse_parameters(self) elif issubclass(type_def, Data): - self.parse_characters("<", - "This attribute must be parametrized!") + self.parse_characters("<", "This attribute must be parametrized!") param_list = type_def.parse_parameter(self) self.parse_characters( ">", "Invalid attribute parametrization, expected `>`!") @@ -878,8 +874,8 @@ def try_parse_builtin_type(self) -> Attribute | None: """ raise NotImplemented("Subclasses must implement this method!") - def _parse_builtin_parametrized_type( - self, name: Span) -> ParametrizedAttribute: + def _parse_builtin_parametrized_type(self, + name: Span) -> ParametrizedAttribute: """ This function is called after we parse the name of a paremetrized type such as vector. """ @@ -922,8 +918,8 @@ def try_parse_numerical_dims(self, # Look out for the closing bracket for scalable vector dims if accept_closing_bracket and self.tokenizer.starts_with("]"): break - self.parse_characters( - "x", "Unexpected end of dimension parameters!") + self.parse_characters("x", + "Unexpected end of dimension parameters!") def parse_vector_attrs(self) -> AnyVectorType: # Also break on 'x' characters as they are separators in dimension parameters @@ -1010,10 +1006,10 @@ def _parse_type_params(self) -> list[Attribute]: self.parse_characters('<', 'Type must be parameterized!') params = self.parse_list_of(self.try_parse_type, - 'Expected a type here!') + 'Expected a type here!') - self.parse_characters( - '>', 'Expected end of type parameterization here!') + self.parse_characters('>', + 'Expected end of type parameterization here!') return params @@ -1260,10 +1256,9 @@ def _parse_builtin_dense_attr(self) -> Attribute | None: return DenseIntOrFPElementsAttr.from_list(type, info) def _parse_builtin_opaque_attr(self): - self.parse_characters("<", - "Opaque attribute must be parametrized") + self.parse_characters("<", "Opaque attribute must be parametrized") str_lit_list = self.parse_list_of(self.try_parse_string_literal, - 'Expected opaque attr here!') + 'Expected opaque attr here!') if len(str_lit_list) != 2: self.raise_error('Opaque expects 2 string literal parameters!') @@ -1368,10 +1363,9 @@ def try_parse_builtin_arr_attr(self) -> ArrayAttr | None: if not self.tokenizer.starts_with("["): return None with self.tokenizer.backtracking("array literal"): - self.parse_characters("[", - "Array literals must start with `[`") + self.parse_characters("[", "Array literals must start with `[`") attrs = self.parse_list_of(self.try_parse_attribute, - "Expected array entry!") + "Expected array entry!") self.parse_characters( "]", "Malformed array contents (expected end of array here!") return ArrayAttr.from_list(attrs) @@ -1414,19 +1408,18 @@ def parse_function_type(self) -> FunctionType: self.parse_characters( "(", "First group of function args must start with a `(`") - args: list[Attribute] = self.parse_list_of( - self.try_parse_type, "Expected type here!") + args: list[Attribute] = self.parse_list_of(self.try_parse_type, + "Expected type here!") self.parse_characters( ")", "Malformed function type, expected closing brackets of argument types!" ) - self.parse_characters("->", - "Malformed function type, expected `->`!") + self.parse_characters("->", "Malformed function type, expected `->`!") - return FunctionType.from_lists( - args, self._parse_type_or_type_list_parens()) + return FunctionType.from_lists(args, + self._parse_type_or_type_list_parens()) def _parse_type_or_type_list_parens(self) -> list[Attribute]: """ @@ -1437,10 +1430,9 @@ def _parse_type_or_type_list_parens(self) -> list[Attribute]: type-list-no-parens ::= type (`,` type)* """ if self.tokenizer.next_token_of_pattern("(") is not None: - args: list[Attribute] = self.parse_list_of( - self.try_parse_type, "Expected type here!") - self.parse_characters( - ")", "Unclosed function type argument list!") + args: list[Attribute] = self.parse_list_of(self.try_parse_type, + "Expected type here!") + self.parse_characters(")", "Unclosed function type argument list!") else: args = [self.try_parse_type()] if args[0] is None: @@ -1560,7 +1552,7 @@ def parse_paramattr_parameters( self.raise_error("Expected start attribute parameters here (`<`)!") res = self.parse_list_of(self.try_parse_attribute, - 'Expected another attribute here!') + 'Expected another attribute here!') if opening_brackets is not None and self.tokenizer.next_token_of_pattern( '>') is None: @@ -1644,8 +1636,8 @@ def _parse_op_result_list( self) -> tuple[list[Span], list[Attribute] | None]: return ( self.parse_list_of(self.try_parse_value_id, - "Expected op-result here!", - allow_empty=True), + "Expected op-result here!", + allow_empty=True), None, ) @@ -1660,7 +1652,7 @@ def parse_optional_attr_dict(self) -> dict[str, Attribute]: attrs = [] if not self.tokenizer.starts_with('}'): attrs = self.parse_list_of(self._parse_attribute_entry, - "Expected attribute entry") + "Expected attribute entry") self.parse_characters( "}", @@ -1677,11 +1669,9 @@ def _parse_operation_details( regions = [] if self.tokenizer.starts_with("("): - self.parse_characters("(", - "Expected brackets enclosing regions!") + self.parse_characters("(", "Expected brackets enclosing regions!") regions = self.parse_region_list() - self.parse_characters(")", - "Expected brackets enclosing regions!") + self.parse_characters(")", "Expected brackets enclosing regions!") attrs = self.parse_optional_attr_dict() @@ -1695,20 +1685,20 @@ def _parse_operation_details( def _parse_optional_successor_list(self) -> list[Span]: if not self.tokenizer.starts_with("["): return [] - self.parse_characters( - "[", "Successor list is enclosed in square brackets") + self.parse_characters("[", + "Successor list is enclosed in square brackets") successors = self.parse_list_of(self.try_parse_block_id, - "Expected a block-id", - allow_empty=False) - self.parse_characters( - "]", "Successor list is enclosed in square brackets") + "Expected a block-id", + allow_empty=False) + self.parse_characters("]", + "Successor list is enclosed in square brackets") return successors def _parse_op_args_list(self) -> list[Span]: self.parse_characters( "(", "Operation args list must be enclosed by brackets!") args = self.parse_list_of(self.try_parse_value_id, - "Expected another bare-id here") + "Expected another bare-id here") self.parse_characters( ")", "Operation args list must be closed by a closing bracket") # TODO: check if type is correct here! @@ -1788,7 +1778,7 @@ def parse_optional_attr_dict(self) -> dict[str, Attribute]: "xDSL Attribute dictionary must be enclosed in square brackets") attrs = self.parse_list_of(self._parse_attribute_entry, - "Expected attribute entry") + "Expected attribute entry") self.parse_characters( "]", @@ -1820,13 +1810,13 @@ def _parse_operation_details( def _parse_optional_successor_list(self) -> list[Span]: if not self.tokenizer.starts_with("("): return [] - self.parse_characters( - "(", "Successor list is enclosed in round brackets") + self.parse_characters("(", + "Successor list is enclosed in round brackets") successors = self.parse_list_of(self.try_parse_block_id, - "Expected a block-id", - allow_empty=False) - self.parse_characters( - ")", "Successor list is enclosed in round brackets") + "Expected a block-id", + allow_empty=False) + self.parse_characters(")", + "Successor list is enclosed in round brackets") return successors def _parse_dialect_type_or_attribute_inner(self, kind: str): @@ -1845,10 +1835,10 @@ def _parse_generic_attribute_args(self, name: StringLiteral): self.raise_error("Unknown attribute name!", name) if not issubclass(attr, ParametrizedAttribute): self.raise_error("Expected ParametrizedAttribute name here!", name) - self.parse_characters( - '<', 'Expected generic attribute arguments here!') + self.parse_characters('<', + 'Expected generic attribute arguments here!') args = self.parse_list_of(self.try_parse_attribute, - 'Unexpected end of attribute list!') + 'Unexpected end of attribute list!') self.parse_characters( '>', 'Malformed attribute arguments, reached end of args list!') return attr(args) @@ -1857,7 +1847,7 @@ def _parse_op_args_list(self) -> list[Span]: self.parse_characters( "(", "Operation args list must be enclosed by brackets!") args = self.parse_list_of(self.try_parse_value_id_and_type, - "Expected another bare-id here") + "Expected another bare-id here") self.parse_characters( ")", "Operation args list must be closed by a closing bracket") # TODO: check if type is correct here! From cc6761581f12114aa82e62ada9d08bce044992fa Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Mon, 23 Jan 2023 16:43:55 +0000 Subject: [PATCH 64/65] formatting: fixed typos and other minor issues --- xdsl/parser.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/xdsl/parser.py b/xdsl/parser.py index 53dc790b41..24df8c7532 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -315,7 +315,7 @@ def backtracking(self, region_name: str | None = None): try: yield # Clear error history when something doesn't fail - # Lhis is because we are only interested in the last "cascade" of failures. + # This is because we are only interested in the last "cascade" of failures. # If a backtracking() completes without failure, something has been parsed (we assume) if self.pos > starting_position and self.history is not None: self.history = None @@ -523,7 +523,7 @@ def starts_with(self, text: str | re.Pattern) -> bool: class ParserCommons: """ - Colelction of common things used in parsing MLIR/IRDL + Collection of common things used in parsing MLIR/IRDL """ @@ -626,7 +626,7 @@ def _get_block_from_name(self, block_name: Span) -> Block: """ This function takes a span containing a block id (like `^42`) and returns a block. - If the block defintion was not seen yet, we create a forward declaration. + If the block definition was not seen yet, we create a forward declaration. """ name = block_name.text if name not in self.blocks: @@ -872,12 +872,12 @@ def try_parse_builtin_type(self) -> Attribute | None: """ parse a builtin-type like i32, index, vector etc. """ - raise NotImplemented("Subclasses must implement this method!") + raise NotImplementedError("Subclasses must implement this method!") def _parse_builtin_parametrized_type(self, name: Span) -> ParametrizedAttribute: """ - This function is called after we parse the name of a paremetrized type such as vector. + This function is called after we parse the name of a parameterized type such as vector. """ def unimplemented() -> ParametrizedAttribute: @@ -1027,7 +1027,7 @@ def raise_error(self, msg: str, at_position: Span | None = None): """ Helper for raising exceptions, provides as much context as possible to them. - This will, for example, include backtracking errors, if any occured previously + This will, for example, include backtracking errors, if any occurred previously """ if at_position is None: at_position = self.tokenizer.next_token(peek=True) @@ -1042,7 +1042,7 @@ def parse_characters(self, text: str, msg: str) -> Span: @abstractmethod def _parse_op_result_list( self) -> tuple[list[Span], list[Attribute] | None]: - raise NotImplemented() + raise NotImplementedError() def try_parse_operation(self) -> Operation | None: with self.tokenizer.backtracking("operation"): @@ -1163,8 +1163,8 @@ def _parse_attribute_entry(self) -> tuple[Span, Attribute]: """ Parse entry in attribute dict. Of format: - attrbiute_entry := (bare-id | string-literal) `=` attribute - attrbiute := dialect-attribute | builtin-attribute + attribute_entry := (bare-id | string-literal) `=` attribute + attribute := dialect-attribute | builtin-attribute """ if (name := self.try_parse_bare_id()) is None: name = self.try_parse_string_literal() @@ -1189,7 +1189,7 @@ def parse_attribute(self) -> Attribute: This is different in xDSL and MLIR, so the actuall implementation is provided by the subclass """ - raise NotImplemented() + raise NotImplementedError() def try_parse_attribute(self) -> Attribute | None: with self.tokenizer.backtracking("attribute"): @@ -1207,7 +1207,7 @@ def _parse_attribute_type(self) -> Attribute: def try_parse_builtin_attr(self) -> Attribute | None: """ - Tries to parse a bultin attribute, e.g. a string literal, int, array, etc.. + Tries to parse a builtin attribute, e.g. a string literal, int, array, etc.. """ next_token = self.tokenizer.next_token(peek=True) if next_token.text == '"': @@ -1570,7 +1570,7 @@ def parse_str_literal(self) -> str: 'Malformed string literal!').string_contents def parse_attribute(self) -> Attribute: - return self.parse_attribute() + raise NotImplementedError() def parse_op(self) -> Operation: return self.parse_operation() @@ -1677,7 +1677,7 @@ def _parse_operation_details( self.parse_characters( ":", - "MLIR Operation defintions must end in a function type signature!") + "MLIR Operation definitions must end in a function type signature!") func_type = self.parse_function_type() return args, succ, attrs, regions, func_type @@ -1757,7 +1757,7 @@ def _parse_op_result_list( def try_parse_builtin_attr(self) -> Attribute: """ - Tries to parse a bultin attribute, e.g. a string literal, int, array, etc.. + Tries to parse a builtin attribute, e.g. a string literal, int, array, etc.. If the mode is xDSL, it also allows parsing of builtin types """ From 3c4349a92aa53c9a4ad9c08f816c62c600f53872 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Mon, 23 Jan 2023 16:45:49 +0000 Subject: [PATCH 65/65] formatting: yapf run --- xdsl/parser.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xdsl/parser.py b/xdsl/parser.py index 24df8c7532..85e32fc41b 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -1677,7 +1677,8 @@ def _parse_operation_details( self.parse_characters( ":", - "MLIR Operation definitions must end in a function type signature!") + "MLIR Operation definitions must end in a function type signature!" + ) func_type = self.parse_function_type() return args, succ, attrs, regions, func_type