diff --git a/tests/test_dialect_utils.py b/tests/test_dialect_utils.py index 093f2c39ae..171d95373f 100644 --- a/tests/test_dialect_utils.py +++ b/tests/test_dialect_utils.py @@ -11,7 +11,7 @@ parse_dynamic_index_without_type, print_dynamic_index_list, ) -from xdsl.ir import SSAValue +from xdsl.ir import Dialect, SSAValue from xdsl.parser import Parser, UnresolvedOperand from xdsl.printer import Printer from xdsl.utils.test_value import TestSSAValue @@ -152,3 +152,23 @@ def test_parse_dynamic_index_list_with_custom_delimiter(): assert values[0] is test_values[0] assert values[1] is test_values[1] assert tuple(indices) == (dynamic_index, 42, dynamic_index) + + +@pytest.mark.parametrize( + "name,expected_1,expected_2", + [ + ("dialect.op_name", "dialect", "op_name"), + ("dialect.op.name", "dialect", "op.name"), + ], +) +def test_split_name(name: str, expected_1: str, expected_2: str): + result_1, result_2 = Dialect.split_name(name) + assert result_1 == expected_1 + assert result_2 == expected_2 + + +def test_split_name_failure(): + with pytest.raises(ValueError) as e: + Dialect.split_name("test") + + assert e.value.args[0] == ("Invalid operation or attribute name test.") diff --git a/tests/test_parser_error.py b/tests/test_parser_error.py index 6bd254c4ff..dd67b2c905 100644 --- a/tests/test_parser_error.py +++ b/tests/test_parser_error.py @@ -15,15 +15,15 @@ @irdl_op_definition -class UnkownOp(IRDLOperation): - name = "unknown" +class UnknownOp(IRDLOperation): + name = "test.unknown" ops: VarOperand = var_operand_def(AnyAttr()) res: VarOpResult = var_result_def(AnyAttr()) def check_error(prog: str, line: int, column: int, message: str): ctx = MLContext() - ctx.load_op(UnkownOp) + ctx.load_op(UnknownOp) parser = Parser(ctx, prog) with pytest.raises(ParseError, match=message) as e: @@ -35,11 +35,11 @@ def check_error(prog: str, line: int, column: int, message: str): def test_parser_missing_equal(): """Test a missing equal sign error.""" ctx = MLContext() - ctx.load_op(UnkownOp) + ctx.load_op(UnknownOp) prog = """ -"unknown"() ({ - %0 "unknown"() : () -> !i32 +"test.unknown"() ({ + %0 "test.unknown"() : () -> !i32 }) : () -> () """ check_error(prog, 3, 5, "Expected '=' after operation result list") @@ -48,12 +48,12 @@ def test_parser_missing_equal(): def test_parser_redefined_value(): """Test an SSA value redefinition error.""" ctx = MLContext() - ctx.load_op(UnkownOp) + ctx.load_op(UnknownOp) prog = """ -"unknown"() ({ - %val = "unknown"() : () -> i32 - %val = "unknown"() : () -> i32 +"test.unknown"() ({ + %val = "test.unknown"() : () -> i32 + %val = "test.unknown"() : () -> i32 }) : () -> () """ check_error(prog, 4, 2, "SSA value %val is already defined") @@ -62,10 +62,10 @@ def test_parser_redefined_value(): def test_parser_missing_operation_name(): """Test a missing operation name error.""" ctx = MLContext() - ctx.load_op(UnkownOp) + ctx.load_op(UnknownOp) prog = """ -"unknown"() ({ +"test.unknown"() ({ %val = }) : () -> () """ diff --git a/tests/test_printer.py b/tests/test_printer.py index 659e1943f0..6c3eb2eed7 100644 --- a/tests/test_printer.py +++ b/tests/test_printer.py @@ -702,7 +702,7 @@ def print_parameters(self, printer: Printer) -> None: @irdl_op_definition class AnyOp(IRDLOperation): - name = "any" + name = "test.any" def test_custom_format_attr(): @@ -711,13 +711,13 @@ def test_custom_format_attr(): """ prog = """\ "builtin.module"() ({ - "any"() {"attr" = #test.custom} : () -> () + "test.any"() {"attr" = #test.custom} : () -> () }) : () -> () """ expected = """\ "builtin.module"() ({ - "any"() {"attr" = #test.custom} : () -> () + "test.any"() {"attr" = #test.custom} : () -> () }) : () -> ()""" ctx = MLContext() diff --git a/xdsl/context.py b/xdsl/context.py index 7fb02f6255..decd1e1d68 100644 --- a/xdsl/context.py +++ b/xdsl/context.py @@ -2,8 +2,10 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING +from xdsl.ir import Dialect + if TYPE_CHECKING: - from xdsl.ir import Attribute, Dialect, Operation + from xdsl.ir import Attribute, Operation @dataclass @@ -122,7 +124,7 @@ def get_optional_op(self, name: str) -> "type[Operation] | None": # Otherwise, check if the operation dialect is registered. if "." in name: - dialect_name, _ = name.split(".", 1) + dialect_name, _ = Dialect.split_name(name) if ( dialect_name in self._registered_dialects and dialect_name not in self._loaded_dialects @@ -168,7 +170,7 @@ def get_optional_attr( return self._loaded_attrs[name] # Otherwise, check if the attribute dialect is registered. - dialect_name, _ = name.split(".", 1) + dialect_name, _ = Dialect.split_name(name) if ( dialect_name in self._registered_dialects and dialect_name not in self._loaded_dialects diff --git a/xdsl/dialects/riscv.py b/xdsl/dialects/riscv.py index 61400c0475..4254df9400 100644 --- a/xdsl/dialects/riscv.py +++ b/xdsl/dialects/riscv.py @@ -471,7 +471,7 @@ def assembly_instruction_name(self) -> str: By default, the name of the instruction is the same as the name of the operation. """ - return self.name.split(".", 1)[-1] + return Dialect.split_name(self.name)[1] def assembly_line(self) -> str | None: # default assembly code generator diff --git a/xdsl/dialects/x86/ops.py b/xdsl/dialects/x86/ops.py index 0ef74ed7db..f3329bf54b 100644 --- a/xdsl/dialects/x86/ops.py +++ b/xdsl/dialects/x86/ops.py @@ -18,6 +18,7 @@ from xdsl.dialects.func import FuncOp from xdsl.ir import ( Attribute, + Dialect, Operation, SSAValue, ) @@ -173,7 +174,7 @@ def assembly_instruction_name(self) -> str: By default, the name of the instruction is the same as the name of the operation. """ - return self.name.split(".", 1)[-1] + return Dialect.split_name(self.name)[1] def assembly_line(self) -> str | None: # default assembly code generator diff --git a/xdsl/interpreters/irdl.py b/xdsl/interpreters/irdl.py index 42bf220c08..76e7531161 100644 --- a/xdsl/interpreters/irdl.py +++ b/xdsl/interpreters/irdl.py @@ -50,7 +50,9 @@ def get_op(interpreter: Interpreter, name: str) -> type[Operation]: """ Get an operation type by name from the interpreter's state """ - ops = IRDLFunctions.get_dialect(interpreter, name.split(".", 1)[0]).operations + ops = IRDLFunctions.get_dialect( + interpreter, Dialect.split_name(name)[0] + ).operations for op in ops: if op.name == name: return op @@ -61,7 +63,9 @@ def get_attr(interpreter: Interpreter, name: str) -> type[ParametrizedAttribute] """ Get an attribute type by name from the interpreter's state """ - attrs = IRDLFunctions.get_dialect(interpreter, name.split(".", 1)[0]).attributes + attrs = IRDLFunctions.get_dialect( + interpreter, Dialect.split_name(name)[0] + ).attributes for attr in attrs: if attr.name == name: if not issubclass(attr, ParametrizedAttribute): diff --git a/xdsl/ir/core.py b/xdsl/ir/core.py index be9a37f0c5..e0607efe90 100644 --- a/xdsl/ir/core.py +++ b/xdsl/ir/core.py @@ -61,6 +61,15 @@ def attributes(self) -> Iterator[type[Attribute]]: def name(self) -> str: return self._name + @staticmethod + def split_name(name: str) -> tuple[str, str]: + try: + names = name.split(".", 1) + first, second = names + return (first, second) + except ValueError as e: + raise ValueError(f"Invalid operation or attribute name {name}.") from e + @dataclass(frozen=True) class Use: @@ -1064,7 +1073,7 @@ def emit_error( @classmethod def dialect_name(cls) -> str: - return cls.name.split(".")[0] + return Dialect.split_name(cls.name)[0] def __eq__(self, other: object) -> bool: return self is other