Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

misc: Add Dialect.split_name method #3003

Merged
merged 6 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion tests/test_dialect_utils.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we've added an error, it would be good to add a pytest that tracks the error message. The more I think of it, the more it seems that raising ValueError instead of asserting would be a little cleaner. Sorry about the churn. There are a bunch of examples in the tests of how to catch errors and test for contents.

Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
parse_dynamic_index_without_type,
print_dynamic_index_list,
)
from xdsl.ir import SSAValue
from xdsl.ir import Dialect, SSAValue
from xdsl.parser import Parser, UnresolvedOperand
from xdsl.printer import Printer
from xdsl.utils.test_value import TestSSAValue
Expand Down Expand Up @@ -152,3 +152,23 @@ def test_parse_dynamic_index_list_with_custom_delimiter():
assert values[0] is test_values[0]
assert values[1] is test_values[1]
assert tuple(indices) == (dynamic_index, 42, dynamic_index)


@pytest.mark.parametrize(
"name,expected_1,expected_2",
[
("dialect.op_name", "dialect", "op_name"),
("dialect.op.name", "dialect", "op.name"),
],
)
def test_split_name(name: str, expected_1: str, expected_2: str):
result_1, result_2 = Dialect.split_name(name)
assert result_1 == expected_1
assert result_2 == expected_2


def test_split_name_failure():
with pytest.raises(ValueError) as e:
Dialect.split_name("test")

assert e.value.args[0] == ("Invalid operation or attribute name test.")
24 changes: 12 additions & 12 deletions tests/test_parser_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@


@irdl_op_definition
class UnkownOp(IRDLOperation):
name = "unknown"
class UnknownOp(IRDLOperation):
name = "test.unknown"
ops: VarOperand = var_operand_def(AnyAttr())
res: VarOpResult = var_result_def(AnyAttr())


def check_error(prog: str, line: int, column: int, message: str):
ctx = MLContext()
ctx.load_op(UnkownOp)
ctx.load_op(UnknownOp)

parser = Parser(ctx, prog)
with pytest.raises(ParseError, match=message) as e:
Expand All @@ -35,11 +35,11 @@ def check_error(prog: str, line: int, column: int, message: str):
def test_parser_missing_equal():
"""Test a missing equal sign error."""
ctx = MLContext()
ctx.load_op(UnkownOp)
ctx.load_op(UnknownOp)

prog = """
"unknown"() ({
%0 "unknown"() : () -> !i32
"test.unknown"() ({
%0 "test.unknown"() : () -> !i32
}) : () -> ()
"""
check_error(prog, 3, 5, "Expected '=' after operation result list")
Expand All @@ -48,12 +48,12 @@ def test_parser_missing_equal():
def test_parser_redefined_value():
"""Test an SSA value redefinition error."""
ctx = MLContext()
ctx.load_op(UnkownOp)
ctx.load_op(UnknownOp)

prog = """
"unknown"() ({
%val = "unknown"() : () -> i32
%val = "unknown"() : () -> i32
"test.unknown"() ({
%val = "test.unknown"() : () -> i32
%val = "test.unknown"() : () -> i32
}) : () -> ()
"""
check_error(prog, 4, 2, "SSA value %val is already defined")
Expand All @@ -62,10 +62,10 @@ def test_parser_redefined_value():
def test_parser_missing_operation_name():
"""Test a missing operation name error."""
ctx = MLContext()
ctx.load_op(UnkownOp)
ctx.load_op(UnknownOp)

prog = """
"unknown"() ({
"test.unknown"() ({
%val =
}) : () -> ()
"""
Expand Down
6 changes: 3 additions & 3 deletions tests/test_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ def print_parameters(self, printer: Printer) -> None:

@irdl_op_definition
class AnyOp(IRDLOperation):
name = "any"
name = "test.any"


def test_custom_format_attr():
Expand All @@ -711,13 +711,13 @@ def test_custom_format_attr():
"""
prog = """\
"builtin.module"() ({
"any"() {"attr" = #test.custom<zero>} : () -> ()
"test.any"() {"attr" = #test.custom<zero>} : () -> ()
}) : () -> ()
"""

expected = """\
"builtin.module"() ({
"any"() {"attr" = #test.custom<zero>} : () -> ()
"test.any"() {"attr" = #test.custom<zero>} : () -> ()
}) : () -> ()"""

ctx = MLContext()
Expand Down
8 changes: 5 additions & 3 deletions xdsl/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

from xdsl.ir import Dialect

if TYPE_CHECKING:
from xdsl.ir import Attribute, Dialect, Operation
from xdsl.ir import Attribute, Operation


@dataclass
Expand Down Expand Up @@ -122,7 +124,7 @@ def get_optional_op(self, name: str) -> "type[Operation] | None":

# Otherwise, check if the operation dialect is registered.
if "." in name:
dialect_name, _ = name.split(".", 1)
dialect_name, _ = Dialect.split_name(name)
if (
dialect_name in self._registered_dialects
and dialect_name not in self._loaded_dialects
Expand Down Expand Up @@ -168,7 +170,7 @@ def get_optional_attr(
return self._loaded_attrs[name]

# Otherwise, check if the attribute dialect is registered.
dialect_name, _ = name.split(".", 1)
dialect_name, _ = Dialect.split_name(name)
if (
dialect_name in self._registered_dialects
and dialect_name not in self._loaded_dialects
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ def assembly_instruction_name(self) -> str:
By default, the name of the instruction is the same as the name of the operation.
"""

return self.name.split(".", 1)[-1]
return Dialect.split_name(self.name)[1]

def assembly_line(self) -> str | None:
# default assembly code generator
Expand Down
3 changes: 2 additions & 1 deletion xdsl/dialects/x86/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from xdsl.dialects.func import FuncOp
from xdsl.ir import (
Attribute,
Dialect,
Operation,
SSAValue,
)
Expand Down Expand Up @@ -173,7 +174,7 @@ def assembly_instruction_name(self) -> str:
By default, the name of the instruction is the same as the name of the operation.
"""

return self.name.split(".", 1)[-1]
return Dialect.split_name(self.name)[1]

def assembly_line(self) -> str | None:
# default assembly code generator
Expand Down
8 changes: 6 additions & 2 deletions xdsl/interpreters/irdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ def get_op(interpreter: Interpreter, name: str) -> type[Operation]:
"""
Get an operation type by name from the interpreter's state
"""
ops = IRDLFunctions.get_dialect(interpreter, name.split(".", 1)[0]).operations
ops = IRDLFunctions.get_dialect(
interpreter, Dialect.split_name(name)[0]
).operations
for op in ops:
if op.name == name:
return op
Expand All @@ -61,7 +63,9 @@ def get_attr(interpreter: Interpreter, name: str) -> type[ParametrizedAttribute]
"""
Get an attribute type by name from the interpreter's state
"""
attrs = IRDLFunctions.get_dialect(interpreter, name.split(".", 1)[0]).attributes
attrs = IRDLFunctions.get_dialect(
interpreter, Dialect.split_name(name)[0]
).attributes
for attr in attrs:
if attr.name == name:
if not issubclass(attr, ParametrizedAttribute):
Expand Down
11 changes: 10 additions & 1 deletion xdsl/ir/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@ def attributes(self) -> Iterator[type[Attribute]]:
def name(self) -> str:
return self._name

@staticmethod
def split_name(name: str) -> tuple[str, str]:
try:
names = name.split(".", 1)
first, second = names
return (first, second)
except ValueError as e:
raise ValueError(f"Invalid operation or attribute name {name}.") from e


@dataclass(frozen=True)
class Use:
Expand Down Expand Up @@ -1064,7 +1073,7 @@ def emit_error(

@classmethod
def dialect_name(cls) -> str:
return cls.name.split(".")[0]
return Dialect.split_name(cls.name)[0]

def __eq__(self, other: object) -> bool:
return self is other
Expand Down
Loading