Skip to content

Commit

Permalink
xDSL: Rewrite Parser (#262)
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonLydike authored Jan 23, 2023
1 parent b6678ee commit b09e94e
Show file tree
Hide file tree
Showing 22 changed files with 2,061 additions and 1,412 deletions.
8 changes: 4 additions & 4 deletions tests/filecheck/parser-printer/float_parsing.xdsl
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@ builtin.module() {
%1 : !f32 = arith.constant() ["value" = -42.0 : !f32]
// CHECK-NEXT: %{{.*}} : !f32 = arith.constant() ["value" = -42.0 : !f32]

%2 : !f32 = arith.constant() ["value" = 34e0 : !f32]
%2 : !f32 = arith.constant() ["value" = 34.e0 : !f32]
// CHECK-NEXT: %{{.*}} : !f32 = arith.constant() ["value" = 34.0 : !f32]

%3 : !f32 = arith.constant() ["value" = 34e-23 : !f32]
%3 : !f32 = arith.constant() ["value" = 34.e-23 : !f32]
// CHECK-NEXT: %{{.*}} : !f32 = arith.constant() ["value" = 3.4e-22 : !f32]

%4 : !f32 = arith.constant() ["value" = 34e12 : !f32]
%4 : !f32 = arith.constant() ["value" = 34.e12 : !f32]
// CHECK-NEXT: %{{.*}} : !f32 = arith.constant() ["value" = 34000000000000.0 : !f32]

%5 : !f32 = arith.constant() ["value" = -34e-12 : !f32]
%5 : !f32 = arith.constant() ["value" = -34.e-12 : !f32]
// CHECK-NEXT: %{{.*}} : !f32 = arith.constant() ["value" = -3.4e-11 : !f32]

func.return()
Expand Down
12 changes: 6 additions & 6 deletions tests/test_attribute_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from xdsl.ir import ParametrizedAttribute, Data
from xdsl.irdl import irdl_attr_definition, builder
from xdsl.parser import Parser
from xdsl.parser import BaseParser
from xdsl.printer import Printer
from xdsl.utils.exceptions import BuilderNotFoundException

Expand Down Expand Up @@ -34,7 +34,7 @@ def from_int(data: int) -> OneBuilderAttr:
return OneBuilderAttr(str(data))

@staticmethod
def parse_parameter(parser: Parser) -> str:
def parse_parameter(parser: BaseParser) -> str:
raise NotImplementedError()

@staticmethod
Expand Down Expand Up @@ -72,7 +72,7 @@ def from_int(data1: int, data2: str) -> OneBuilderAttr:
return OneBuilderAttr(str(data1) + data2)

@staticmethod
def parse_parameter(parser: Parser) -> str:
def parse_parameter(parser: BaseParser) -> str:
raise NotImplementedError()

@staticmethod
Expand Down Expand Up @@ -105,7 +105,7 @@ def from_str(s: str) -> TwoBuildersAttr:
return TwoBuildersAttr(s)

@staticmethod
def parse_parameter(parser: Parser) -> str:
def parse_parameter(parser: BaseParser) -> str:
raise NotImplementedError()

@staticmethod
Expand Down Expand Up @@ -145,7 +145,7 @@ def from_int(data1: int,
return BuilderDefaultArgAttr(f"{data1}, {data2}, {data3}")

@staticmethod
def parse_parameter(parser: Parser) -> str:
def parse_parameter(parser: BaseParser) -> str:
raise NotImplementedError()

@staticmethod
Expand Down Expand Up @@ -188,7 +188,7 @@ def from_int(data: str | int) -> BuilderUnionArgAttr:
return BuilderUnionArgAttr(str(data))

@staticmethod
def parse_parameter(parser: Parser) -> str:
def parse_parameter(parser: BaseParser) -> str:
raise NotImplementedError()

@staticmethod
Expand Down
27 changes: 13 additions & 14 deletions tests/test_attribute_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from xdsl.irdl import (AttrConstraint, GenericData, ParameterDef,
irdl_attr_definition, builder, irdl_to_attr_constraint,
AnyAttr, BaseAttr, ParamAttrDef)
from xdsl.parser import Parser
from xdsl.parser import BaseParser
from xdsl.printer import Printer
from xdsl.utils.exceptions import VerifyException

Expand All @@ -31,14 +31,13 @@ class BoolData(Data[bool]):
name = "bool"

@staticmethod
def parse_parameter(parser: Parser) -> bool:
val = parser.parse_optional_ident()
if val == "True":
def parse_parameter(parser: BaseParser) -> bool:
val = parser.tokenizer.next_token_of_pattern('(True|False)')
if val is None or val.text not in ('True', 'False'):
parser.raise_error("Expected True or False literal")
if val.text == "True":
return True
elif val == "False":
return False
else:
raise Exception("Wrong argument passed to BoolAttr.")
return False

@staticmethod
def print_parameter(data: bool, printer: Printer):
Expand All @@ -51,7 +50,7 @@ class IntData(Data[int]):
name = "int"

@staticmethod
def parse_parameter(parser: Parser) -> int:
def parse_parameter(parser: BaseParser) -> int:
return parser.parse_int_literal()

@staticmethod
Expand All @@ -65,7 +64,7 @@ class StringData(Data[str]):
name = "str"

@staticmethod
def parse_parameter(parser: Parser) -> str:
def parse_parameter(parser: BaseParser) -> str:
return parser.parse_str_literal()

@staticmethod
Expand Down Expand Up @@ -102,7 +101,7 @@ class IntListMissingVerifierData(Data[list[int]]):
name = "missing_verifier_data"

@staticmethod
def parse_parameter(parser: Parser) -> list[int]:
def parse_parameter(parser: BaseParser) -> list[int]:
raise NotImplementedError()

@staticmethod
Expand Down Expand Up @@ -134,7 +133,7 @@ class IntListData(Data[list[int]]):
name = "int_list"

@staticmethod
def parse_parameter(parser: Parser) -> list[int]:
def parse_parameter(parser: BaseParser) -> list[int]:
raise NotImplementedError()

@staticmethod
Expand Down Expand Up @@ -431,7 +430,7 @@ class MissingGenericDataData(Data[_MissingGenericDataData]):
name = "missing_genericdata"

@staticmethod
def parse_parameter(parser: Parser) -> _MissingGenericDataData:
def parse_parameter(parser: BaseParser) -> _MissingGenericDataData:
raise NotImplementedError()

@staticmethod
Expand Down Expand Up @@ -484,7 +483,7 @@ class ListData(GenericData[list[A]]):
name = "list"

@staticmethod
def parse_parameter(parser: Parser) -> list[A]:
def parse_parameter(parser: BaseParser) -> list[A]:
raise NotImplementedError()

@staticmethod
Expand Down
10 changes: 5 additions & 5 deletions tests/test_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from xdsl.dialects.arith import Addi, Subi, Constant
from xdsl.dialects.builtin import i32, IntegerAttr, ModuleOp
from xdsl.dialects.scf import If
from xdsl.parser import Parser
from xdsl.parser import XDSLParser
from xdsl.dialects.builtin import Builtin
from xdsl.dialects.func import Func
from xdsl.dialects.arith import Arith
Expand Down Expand Up @@ -203,10 +203,10 @@ def test_is_structurally_equivalent(args: list[str], expected_result: bool):
ctx.register_dialect(Arith)
ctx.register_dialect(Cf)

parser = Parser(ctx, args[0])
parser = XDSLParser(ctx, args[0])
lhs: Operation = parser.parse_op()

parser = Parser(ctx, args[1])
parser = XDSLParser(ctx, args[1])
rhs: Operation = parser.parse_op()

assert lhs.is_structurally_equivalent(rhs) == expected_result
Expand All @@ -231,8 +231,8 @@ def test_is_structurally_equivalent_incompatible_ir_nodes():
ctx.register_dialect(Arith)
ctx.register_dialect(Cf)

parser = Parser(ctx, program_func)
program: ModuleOp = parser.parse_op()
parser = XDSLParser(ctx, program_func)
program: ModuleOp = parser.parse_operation()

assert program.is_structurally_equivalent(program.regions[0]) == False
assert program.is_structurally_equivalent(
Expand Down
6 changes: 3 additions & 3 deletions tests/test_irdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from xdsl.ir import Attribute, Data, ParametrizedAttribute
from xdsl.irdl import AllOf, AnyAttr, AnyOf, AttrConstraint, BaseAttr, EqAttrConstraint, ParamAttrConstraint, ParameterDef, irdl_attr_definition
from xdsl.parser import Parser
from xdsl.parser import BaseParser
from xdsl.printer import Printer
from xdsl.utils.exceptions import VerifyException

Expand All @@ -16,7 +16,7 @@ class BoolData(Data[bool]):
name = "bool"

@staticmethod
def parse_parameter(parser: Parser) -> bool:
def parse_parameter(parser: BaseParser) -> bool:
raise NotImplementedError()

@staticmethod
Expand All @@ -30,7 +30,7 @@ class IntData(Data[int]):
name = "int"

@staticmethod
def parse_parameter(parser: Parser) -> int:
def parse_parameter(parser: BaseParser) -> int:
return parser.parse_int_literal()

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions tests/test_mlir_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from xdsl.dialects.affine import Affine
from xdsl.dialects.arith import Arith

from xdsl.parser import Parser
from xdsl.parser import XDSLParser
from xdsl.ir import MLContext
from xdsl.dialects.builtin import Builtin

Expand All @@ -23,7 +23,7 @@ def convert_and_verify(test_prog: str):
ctx.register_dialect(Scf)
ctx.register_dialect(MemRef)

parser = Parser(ctx, test_prog)
parser = XDSLParser(ctx, test_prog)
module = parser.parse_op()
module.verify()

Expand Down
19 changes: 8 additions & 11 deletions tests/test_mlir_printer.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import re
from io import StringIO
from typing import Annotated
import re

from xdsl.dialects.builtin import Builtin
from xdsl.dialects.memref import MemRef
from xdsl.dialects.func import Func
from xdsl.ir import Attribute, Data, MLContext, MLIRType, Operation, ParametrizedAttribute
from xdsl.irdl import (AnyAttr, ParameterDef, RegionDef, VarOpResult,
VarOperand, irdl_attr_definition, irdl_op_definition)
from xdsl.parser import Parser
from xdsl.irdl import (AnyAttr, ParameterDef, RegionDef, irdl_attr_definition,
irdl_op_definition, VarOperand, VarOpResult)
from xdsl.parser import BaseParser, XDSLParser
from xdsl.printer import Printer


Expand All @@ -33,7 +30,7 @@ class DataAttr(Data[int]):
name = "data_attr"

@staticmethod
def parse_parameter(parser: Parser) -> int:
def parse_parameter(parser: BaseParser) -> int:
return parser.parse_int_literal()

@staticmethod
Expand All @@ -47,7 +44,7 @@ class DataType(Data[int], MLIRType):
name = "data_type"

@staticmethod
def parse_parameter(parser: Parser) -> int:
def parse_parameter(parser: BaseParser) -> int:
return parser.parse_int_literal()

@staticmethod
Expand Down Expand Up @@ -92,8 +89,8 @@ def print_as_mlir_and_compare(test_prog: str, expected: str):
ctx.register_attr(ParamAttrWithParam)
ctx.register_attr(ParamAttrWithCustomFormat)

parser = Parser(ctx, test_prog)
module = parser.parse_op()
parser = XDSLParser(ctx, test_prog)
module = parser.parse_operation()

res = StringIO()
printer = Printer(target=Printer.Target.MLIR, stream=res)
Expand Down
58 changes: 32 additions & 26 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,43 @@
from io import StringIO

import pytest

from xdsl.ir import MLContext
from xdsl.parser import Parser
from xdsl.printer import Printer
from xdsl.ir import MLContext, Attribute
from xdsl.parser import XDSLParser
from xdsl.dialects.builtin import IntAttr, DictionaryAttr, StringAttr, ArrayAttr, Builtin


@pytest.mark.parametrize("input,expected", [("0, 1, 1", [0, 1, 1]),
("1, 0, 1", [1, 0, 1]),
("1, 1, 0", [1, 1, 0])])
def test_int_list_parser(input: str, expected: list[int]):
ctx = MLContext()
parser = Parser(ctx, input)

int_list = parser.parse_list(parser.parse_int_literal)
assert int_list == expected


@pytest.mark.parametrize("input,expected", [('{"A"=0, "B"=1, "C"=2}', {
"A": 0,
"B": 1,
"C": 2
}), ('{"MA"=10, "BR"=7, "Z"=3}', {
"MA": 10,
"BR": 7,
"Z": 3
}), ('{"Q"=77, "VV"=12, "AA"=-8}', {
"Q": 77,
"VV": 12,
"AA": -8
})])
def test_int_dictionary_parser(input: str, expected: dict[str, int]):
parser = XDSLParser(ctx, input)

int_list = parser.parse_list_of(parser.try_parse_integer_literal, '')
assert [int(span.text) for span in int_list] == expected


@pytest.mark.parametrize('data', [
dict(a=IntAttr.from_int(1), b=IntAttr.from_int(2), c=IntAttr.from_int(3)),
dict(a=StringAttr.from_str('hello'),
b=IntAttr.from_int(2),
c=ArrayAttr.from_list(
[IntAttr.from_int(2),
StringAttr.from_str('world')])),
dict(),
])
def test_dictionary_attr(data: dict[str, Attribute]):
attr = DictionaryAttr.from_dict(data)

with StringIO() as io:
Printer(io).print(attr)
text = io.getvalue()

ctx = MLContext()
parser = Parser(ctx, input)
ctx.register_dialect(Builtin)

attr = XDSLParser(ctx, text).parse_attribute()

int_dict = parser.parse_dictionary(parser.parse_str_literal,
parser.parse_int_literal)
assert int_dict == expected
assert attr.data == data
Loading

0 comments on commit b09e94e

Please sign in to comment.