diff --git a/tests/dialects/test_hw.py b/tests/dialects/test_hw.py index b29b9ba90e..403a20e05d 100644 --- a/tests/dialects/test_hw.py +++ b/tests/dialects/test_hw.py @@ -9,10 +9,10 @@ from xdsl.dialects.builtin import StringAttr from xdsl.dialects.hw import ( InnerRefAttr, - InnerRefNamespace, - InnerRefUserOpInterface, - InnerSymbolTable, + InnerRefNamespaceTrait, + InnerRefUserOpInterfaceTrait, InnerSymbolTableCollection, + InnerSymbolTableTrait, InnerSymTarget, ) from xdsl.dialects.test import TestOp @@ -58,7 +58,7 @@ class ModuleOp(IRDLOperation): name = "module" region = region_def() sym_name = attr_def(StringAttr) - traits = frozenset({InnerSymbolTable(), SymbolOpInterface()}) + traits = frozenset({InnerSymbolTableTrait(), SymbolOpInterface()}) @irdl_op_definition @@ -73,7 +73,11 @@ class CircuitOp(IRDLOperation): region: Region | None = opt_region_def() sym_name = attr_def(StringAttr) traits = frozenset( - {InnerRefNamespace(), SymbolTable(), SingleBlockImplicitTerminator(OutputOp)} + { + InnerRefNamespaceTrait(), + SymbolTable(), + SingleBlockImplicitTerminator(OutputOp), + } ) def __post_init__(self): @@ -85,12 +89,12 @@ def __post_init__(self): class WireOp(IRDLOperation): name = "wire" sym_name = attr_def(StringAttr) - traits = frozenset({InnerRefUserOpInterface()}) + traits = frozenset({InnerRefUserOpInterfaceTrait()}) def test_inner_symbol_table_interface(): """ - Test operations that conform to InnerSymbolTable + Test operations that conform to InnerSymbolTableTrait """ mod = ModuleOp( attributes={"sym_name": StringAttr("symbol_name")}, regions=[[OutputOp()]] @@ -104,7 +108,7 @@ def test_inner_symbol_table_interface(): ) with pytest.raises( VerifyException, - match="Operation module with trait InnerSymbolTable must have a parent with trait SymbolOpInterface", + match="Operation module with trait InnerSymbolTableTrait must have a parent with trait SymbolOpInterface", ): mod_no_parent.verify() @@ -114,12 +118,12 @@ def test_inner_symbol_table_interface(): no_trait_circ = TestOp(regions=[[mod_no_trait_circ, OutputOp()]]) with pytest.raises( VerifyException, - match="Operation module with trait InnerSymbolTable must have a parent with trait InnerRefNamespace", + match="Operation module with trait InnerSymbolTableTrait must have a parent with trait InnerRefNamespaceTrait", ): mod_no_trait_circ.verify() with pytest.raises( VerifyException, - match="Operation module with trait InnerSymbolTable must have a parent with trait InnerRefNamespace", + match="Operation module with trait InnerSymbolTableTrait must have a parent with trait InnerRefNamespaceTrait", ): no_trait_circ.verify() @@ -128,7 +132,7 @@ class MissingTraitModuleOp(IRDLOperation): name = "module" region = region_def() sym_name = attr_def(StringAttr) - traits = frozenset({InnerSymbolTable()}) + traits = frozenset({InnerSymbolTableTrait()}) mod_missing_trait = MissingTraitModuleOp( attributes={"sym_name": StringAttr("symbol_name")}, regions=[[OutputOp()]] @@ -147,7 +151,7 @@ class MissingTraitModuleOp(IRDLOperation): class MissingAttrModuleOp(IRDLOperation): name = "module" region = region_def() - traits = frozenset({InnerSymbolTable(), SymbolOpInterface()}) + traits = frozenset({InnerSymbolTableTrait(), SymbolOpInterface()}) mod_missing_trait_parent = ModuleOp(regions=[[OutputOp()]]) MissingAttrModuleOp(regions=[[mod_missing_trait_parent, OutputOp()]]) @@ -160,7 +164,7 @@ class MissingAttrModuleOp(IRDLOperation): def test_inner_ref_namespace_interface(): """ - Test operations that conform to InnerRefNamespace + Test operations that conform to InnerRefNamespaceTrait """ @irdl_op_definition @@ -169,7 +173,7 @@ class MissingTraitCircuitOp(IRDLOperation): region: Region | None = opt_region_def() sym_name = attr_def(StringAttr) traits = frozenset( - {InnerRefNamespace(), SingleBlockImplicitTerminator(OutputOp)} + {InnerRefNamespaceTrait(), SingleBlockImplicitTerminator(OutputOp)} ) wire0 = WireOp(attributes={"sym_name": StringAttr("wire0")}) @@ -197,9 +201,11 @@ class MissingTraitCircuitOp(IRDLOperation): attributes={"sym_name": StringAttr("circuit")}, regions=[[mod1, mod2]] ) - # InnerRefUserOpInterface.verify_inner_refs() does not do anything, so just mock + # InnerRefUserOpInterfaceTrait.verify_inner_refs() does not do anything, so just mock # to check it is called - with patch.object(InnerRefUserOpInterface, "verify_inner_refs") as inner_ref_verif: + with patch.object( + InnerRefUserOpInterfaceTrait, "verify_inner_refs" + ) as inner_ref_verif: circuit.verify() inner_ref_verif.assert_any_call(wire1, ANY) @@ -226,7 +232,7 @@ def test_inner_symbol_table_collection(): inner_sym_tables = InnerSymbolTableCollection(op=circuit) with pytest.raises( - VerifyException, match="Operation wire should have InnerSymbolTable trait" + VerifyException, match="Operation wire should have InnerSymbolTableTrait trait" ): inner_sym_tables.get_inner_symbol_table(wire1) @@ -234,14 +240,14 @@ def test_inner_symbol_table_collection(): sym_table2 = inner_sym_tables.get_inner_symbol_table(mod2) assert ( sym_table1 is not sym_table2 - ), "Different InnerSymbolTable objects must return different instances of inner symbol tables" + ), "Different InnerSymbolTableTrait objects must return different instances of inner symbol tables" unpopulated_inner_sym_tables = InnerSymbolTableCollection() sym_table3 = unpopulated_inner_sym_tables.get_inner_symbol_table(mod1) sym_table4 = unpopulated_inner_sym_tables.get_inner_symbol_table(mod2) assert ( sym_table3 is not sym_table4 - ), "InnerSymbolTable still behave as expected when created on the fly" + ), "InnerSymbolTableTrait still behave as expected when created on the fly" def test_inner_ref_attr(): diff --git a/xdsl/dialects/hw.py b/xdsl/dialects/hw.py index 8cb2ae76f5..e749ff6180 100644 --- a/xdsl/dialects/hw.py +++ b/xdsl/dialects/hw.py @@ -120,20 +120,8 @@ def print_parameters(self, printer: Printer) -> None: @dataclass(frozen=True) -class InnerSymbolTable(OpTrait): - """A trait for inner symbol table functionality on an operation. - Merges the upstream table of inner symbols and their resolutions and the op trait. - """ - - symbol_table: dict[StringAttr, InnerSymTarget] = field( - default_factory=dict, compare=False - ) - op: InitVar[Operation | None] = None - - def __post_init__(self, op: Operation | None = None) -> None: - if op is None: - return - # Here will populate self.symbol_table +class InnerSymbolTableTrait(OpTrait): + """A trait for inner symbol table functionality on an operation.""" def verify(self, op: Operation): # Insist that ops with InnerSymbolTable's provide a Symbol, this is @@ -143,29 +131,43 @@ def verify(self, op: Operation): f"Operation {op.name} must have trait {trait.__name__}" ) - # InnerSymbolTable's must be directly nested within an InnerRefNamespace, + # InnerSymbolTable's must be directly nested within an InnerRefNamespaceTrait, # however don’t test InnerRefNamespace’s symbol lookups parent = op.parent_op() if ( parent is None - or len(parent.get_traits_of_type(trait := InnerRefNamespace)) != 1 + or len(parent.get_traits_of_type(trait := InnerRefNamespaceTrait)) != 1 ): raise VerifyException( f"Operation {op.name} with trait {type(self).__name__} must have a parent with trait {trait.__name__}" ) +@dataclass +class InnerSymbolTable: + """A class for lookups in inner symbol tables. Called InnerSymbolTable in upstream (name clash with trait).""" + + op: InitVar[Operation | None] = None + symbol_table: dict[StringAttr, InnerSymTarget] = field(default_factory=dict) + + def __post_init__(self, op: Operation | None = None) -> None: + pass + # Here will populate self.symbol_table + + @dataclass class InnerSymbolTableCollection: - """This class represents an InnerSymbolTable collection.""" + """This class represents a collection of InnerSymbolTable.""" - symbol_tables: dict[Operation, InnerSymbolTable] = field(default_factory=dict) + symbol_tables: dict[Operation, InnerSymbolTable] = field( + default_factory=dict, init=False + ) op: InitVar[Operation | None] = None def __post_init__(self, op: Operation | None = None) -> None: if op is None: return - if not op.has_trait(trait := InnerRefNamespace): + if not op.has_trait(trait := InnerRefNamespaceTrait): raise VerifyException( f"Operation {op.name} should have {trait.__name__} trait" ) @@ -173,7 +175,7 @@ def __post_init__(self, op: Operation | None = None) -> None: def get_inner_symbol_table(self, op: Operation) -> InnerSymbolTable: """Returns the InnerSymolTable trait, ensuring `op` is in the collection""" - if not op.has_trait(trait := InnerSymbolTable): + if not op.has_trait(trait := InnerSymbolTableTrait): raise VerifyException( f"Operation {op.name} should have {trait.__name__} trait" ) @@ -185,7 +187,7 @@ def populate_and_verify_tables(self, inner_ref_ns_op: Operation): """Populate tables for all InnerSymbolTable operations in the given InnerRefNamespace operation, verifying each.""" # Gather top-level operations that have the InnerSymbolTable trait. inner_sym_table_ops = ( - op for op in inner_ref_ns_op.walk() if op.has_trait(InnerSymbolTable) + op for op in inner_ref_ns_op.walk() if op.has_trait(InnerSymbolTableTrait) ) # Construct the tables @@ -197,7 +199,7 @@ def populate_and_verify_tables(self, inner_ref_ns_op: Operation): self.symbol_tables[op] = InnerSymbolTable(op) -class InnerRefUserOpInterface(OpTrait): +class InnerRefUserOpInterfaceTrait(OpTrait): """This interface describes an operation that may use a `InnerRef`. This interface allows for users of inner symbols to hook into verification and other inner symbol related utilities that are either costly or otherwise @@ -209,21 +211,8 @@ def verify_inner_refs(self, op: Operation, namespace: "InnerRefNamespace"): @dataclass(frozen=True) -class InnerRefNamespace(OpTrait): - """Defines a new scope for InnerRef’s. Operations with this trait myst be a SymbolTable. - Combines InnerSymbolTableCollection with a SymbolTable for resolution of InnerRefAttrs, used during verification. - - Inner symbols are more costly than normal symbols, with tricker verification. For this reason, - verification is driven as a trait verifier on InnerRefNamespace which constructs and verifies InnerSymbolTables in parallel. - See: https://circt.llvm.org/docs/RationaleSymbols/#innerrefnamespace - """ - - symbol_table: SymbolTable = field( - default_factory=lambda: SymbolTable(), compare=False - ) - inner_sym_tables: InnerSymbolTableCollection = field( - default_factory=lambda: InnerSymbolTableCollection(), compare=False - ) +class InnerRefNamespaceTrait(OpTrait): + """Trait for operations defining a new scope for InnerRef’s. Operations with this trait must be a SymbolTable.""" def verify(self, op: Operation): if not op.has_trait(trait := SymbolTable): @@ -237,18 +226,31 @@ def verify(self, op: Operation): if len(op.regions[0].blocks) != 1: raise VerifyException(f"Operation {op.name} must have a single block") - inner_sym_tables = InnerSymbolTableCollection(op=op) - symbol_table = SymbolTable(op) - namespace = InnerRefNamespace( - symbol_table=symbol_table, inner_sym_tables=inner_sym_tables - ) + namespace = InnerRefNamespace(op) for inner_op in op.walk(): - inner_ref_user_op_trait = inner_op.get_trait(InnerRefUserOpInterface) + inner_ref_user_op_trait = inner_op.get_trait(InnerRefUserOpInterfaceTrait) if inner_ref_user_op_trait is not None: inner_ref_user_op_trait.verify_inner_refs(inner_op, namespace) +@dataclass +class InnerRefNamespace: + """Class to perform symbol lookups within a InnerRef namespace, used during verification. + Combines InnerSymbolTableCollection with a SymbolTable for resolution of InnerRefAttrs. + + Inner symbols are more costly than normal symbols, with tricker verification. For this reason, + verification is driven as a trait verifier on InnerRefNamespace which constructs and verifies InnerSymbolTables in parallel. + See: https://circt.llvm.org/docs/RationaleSymbols/#innerrefnamespace + """ + + inner_sym_tables: InnerSymbolTableCollection = field(init=False) + inner_ref_ns_op: InitVar[Operation] + + def __init__(self, inner_ref_ns_op: Operation): + self.inner_sym_tables = InnerSymbolTableCollection(inner_ref_ns_op) + + HW = Dialect( "hw", [],