Skip to content

Commit

Permalink
Split traits into trait and state-holding classes
Browse files Browse the repository at this point in the history
- Classes are named the same upstream (but in different namespaces?).
  However state-holding classes that do lookups and traits/interfaces do
  not seem to really by the same objects. Split accordingly.
- Rename HW traits with -Trait suffix (also InnerRefUserOpInterface for consistency).
  • Loading branch information
lucjaulmes committed Feb 5, 2024
1 parent 025eb87 commit e3eec8b
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 62 deletions.
44 changes: 25 additions & 19 deletions tests/dialects/test_hw.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from xdsl.dialects.builtin import StringAttr
from xdsl.dialects.hw import (
InnerRefAttr,
InnerRefNamespace,
InnerRefUserOpInterface,
InnerSymbolTable,
InnerRefNamespaceTrait,
InnerRefUserOpInterfaceTrait,
InnerSymbolTableCollection,
InnerSymbolTableTrait,
InnerSymTarget,
)
from xdsl.dialects.test import TestOp
Expand Down Expand Up @@ -58,7 +58,7 @@ class ModuleOp(IRDLOperation):
name = "module"
region = region_def()
sym_name = attr_def(StringAttr)
traits = frozenset({InnerSymbolTable(), SymbolOpInterface()})
traits = frozenset({InnerSymbolTableTrait(), SymbolOpInterface()})


@irdl_op_definition
Expand All @@ -73,7 +73,11 @@ class CircuitOp(IRDLOperation):
region: Region | None = opt_region_def()
sym_name = attr_def(StringAttr)
traits = frozenset(
{InnerRefNamespace(), SymbolTable(), SingleBlockImplicitTerminator(OutputOp)}
{
InnerRefNamespaceTrait(),
SymbolTable(),
SingleBlockImplicitTerminator(OutputOp),
}
)

def __post_init__(self):
Expand All @@ -85,12 +89,12 @@ def __post_init__(self):
class WireOp(IRDLOperation):
name = "wire"
sym_name = attr_def(StringAttr)
traits = frozenset({InnerRefUserOpInterface()})
traits = frozenset({InnerRefUserOpInterfaceTrait()})


def test_inner_symbol_table_interface():
"""
Test operations that conform to InnerSymbolTable
Test operations that conform to InnerSymbolTableTrait
"""
mod = ModuleOp(
attributes={"sym_name": StringAttr("symbol_name")}, regions=[[OutputOp()]]
Expand All @@ -104,7 +108,7 @@ def test_inner_symbol_table_interface():
)
with pytest.raises(
VerifyException,
match="Operation module with trait InnerSymbolTable must have a parent with trait SymbolOpInterface",
match="Operation module with trait InnerSymbolTableTrait must have a parent with trait SymbolOpInterface",
):
mod_no_parent.verify()

Expand All @@ -114,12 +118,12 @@ def test_inner_symbol_table_interface():
no_trait_circ = TestOp(regions=[[mod_no_trait_circ, OutputOp()]])
with pytest.raises(
VerifyException,
match="Operation module with trait InnerSymbolTable must have a parent with trait InnerRefNamespace",
match="Operation module with trait InnerSymbolTableTrait must have a parent with trait InnerRefNamespaceTrait",
):
mod_no_trait_circ.verify()
with pytest.raises(
VerifyException,
match="Operation module with trait InnerSymbolTable must have a parent with trait InnerRefNamespace",
match="Operation module with trait InnerSymbolTableTrait must have a parent with trait InnerRefNamespaceTrait",
):
no_trait_circ.verify()

Expand All @@ -128,7 +132,7 @@ class MissingTraitModuleOp(IRDLOperation):
name = "module"
region = region_def()
sym_name = attr_def(StringAttr)
traits = frozenset({InnerSymbolTable()})
traits = frozenset({InnerSymbolTableTrait()})

mod_missing_trait = MissingTraitModuleOp(
attributes={"sym_name": StringAttr("symbol_name")}, regions=[[OutputOp()]]
Expand All @@ -147,7 +151,7 @@ class MissingTraitModuleOp(IRDLOperation):
class MissingAttrModuleOp(IRDLOperation):
name = "module"
region = region_def()
traits = frozenset({InnerSymbolTable(), SymbolOpInterface()})
traits = frozenset({InnerSymbolTableTrait(), SymbolOpInterface()})

mod_missing_trait_parent = ModuleOp(regions=[[OutputOp()]])
MissingAttrModuleOp(regions=[[mod_missing_trait_parent, OutputOp()]])
Expand All @@ -160,7 +164,7 @@ class MissingAttrModuleOp(IRDLOperation):

def test_inner_ref_namespace_interface():
"""
Test operations that conform to InnerRefNamespace
Test operations that conform to InnerRefNamespaceTrait
"""

@irdl_op_definition
Expand All @@ -169,7 +173,7 @@ class MissingTraitCircuitOp(IRDLOperation):
region: Region | None = opt_region_def()
sym_name = attr_def(StringAttr)
traits = frozenset(
{InnerRefNamespace(), SingleBlockImplicitTerminator(OutputOp)}
{InnerRefNamespaceTrait(), SingleBlockImplicitTerminator(OutputOp)}
)

wire0 = WireOp(attributes={"sym_name": StringAttr("wire0")})
Expand Down Expand Up @@ -197,9 +201,11 @@ class MissingTraitCircuitOp(IRDLOperation):
attributes={"sym_name": StringAttr("circuit")}, regions=[[mod1, mod2]]
)

# InnerRefUserOpInterface.verify_inner_refs() does not do anything, so just mock
# InnerRefUserOpInterfaceTrait.verify_inner_refs() does not do anything, so just mock
# to check it is called
with patch.object(InnerRefUserOpInterface, "verify_inner_refs") as inner_ref_verif:
with patch.object(
InnerRefUserOpInterfaceTrait, "verify_inner_refs"
) as inner_ref_verif:
circuit.verify()

inner_ref_verif.assert_any_call(wire1, ANY)
Expand All @@ -226,22 +232,22 @@ def test_inner_symbol_table_collection():
inner_sym_tables = InnerSymbolTableCollection(op=circuit)

with pytest.raises(
VerifyException, match="Operation wire should have InnerSymbolTable trait"
VerifyException, match="Operation wire should have InnerSymbolTableTrait trait"
):
inner_sym_tables.get_inner_symbol_table(wire1)

sym_table1 = inner_sym_tables.get_inner_symbol_table(mod1)
sym_table2 = inner_sym_tables.get_inner_symbol_table(mod2)
assert (
sym_table1 is not sym_table2
), "Different InnerSymbolTable objects must return different instances of inner symbol tables"
), "Different InnerSymbolTableTrait objects must return different instances of inner symbol tables"

unpopulated_inner_sym_tables = InnerSymbolTableCollection()
sym_table3 = unpopulated_inner_sym_tables.get_inner_symbol_table(mod1)
sym_table4 = unpopulated_inner_sym_tables.get_inner_symbol_table(mod2)
assert (
sym_table3 is not sym_table4
), "InnerSymbolTable still behave as expected when created on the fly"
), "InnerSymbolTableTrait still behave as expected when created on the fly"


def test_inner_ref_attr():
Expand Down
88 changes: 45 additions & 43 deletions xdsl/dialects/hw.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,20 +120,8 @@ def print_parameters(self, printer: Printer) -> None:


@dataclass(frozen=True)
class InnerSymbolTable(OpTrait):
"""A trait for inner symbol table functionality on an operation.
Merges the upstream table of inner symbols and their resolutions and the op trait.
"""

symbol_table: dict[StringAttr, InnerSymTarget] = field(
default_factory=dict, compare=False
)
op: InitVar[Operation | None] = None

def __post_init__(self, op: Operation | None = None) -> None:
if op is None:
return
# Here will populate self.symbol_table
class InnerSymbolTableTrait(OpTrait):
"""A trait for inner symbol table functionality on an operation."""

def verify(self, op: Operation):
# Insist that ops with InnerSymbolTable's provide a Symbol, this is
Expand All @@ -143,37 +131,51 @@ def verify(self, op: Operation):
f"Operation {op.name} must have trait {trait.__name__}"
)

# InnerSymbolTable's must be directly nested within an InnerRefNamespace,
# InnerSymbolTable's must be directly nested within an InnerRefNamespaceTrait,
# however don’t test InnerRefNamespace’s symbol lookups
parent = op.parent_op()
if (
parent is None
or len(parent.get_traits_of_type(trait := InnerRefNamespace)) != 1
or len(parent.get_traits_of_type(trait := InnerRefNamespaceTrait)) != 1
):
raise VerifyException(
f"Operation {op.name} with trait {type(self).__name__} must have a parent with trait {trait.__name__}"
)


@dataclass
class InnerSymbolTable:
"""A class for lookups in inner symbol tables. Called InnerSymbolTable in upstream (name clash with trait)."""

op: InitVar[Operation | None] = None
symbol_table: dict[StringAttr, InnerSymTarget] = field(default_factory=dict)

def __post_init__(self, op: Operation | None = None) -> None:
pass
# Here will populate self.symbol_table


@dataclass
class InnerSymbolTableCollection:
"""This class represents an InnerSymbolTable collection."""
"""This class represents a collection of InnerSymbolTable."""

symbol_tables: dict[Operation, InnerSymbolTable] = field(default_factory=dict)
symbol_tables: dict[Operation, InnerSymbolTable] = field(
default_factory=dict, init=False
)
op: InitVar[Operation | None] = None

def __post_init__(self, op: Operation | None = None) -> None:
if op is None:
return
if not op.has_trait(trait := InnerRefNamespace):
if not op.has_trait(trait := InnerRefNamespaceTrait):
raise VerifyException(
f"Operation {op.name} should have {trait.__name__} trait"
)
self.populate_and_verify_tables(op)

def get_inner_symbol_table(self, op: Operation) -> InnerSymbolTable:
"""Returns the InnerSymolTable trait, ensuring `op` is in the collection"""
if not op.has_trait(trait := InnerSymbolTable):
if not op.has_trait(trait := InnerSymbolTableTrait):
raise VerifyException(
f"Operation {op.name} should have {trait.__name__} trait"
)
Expand All @@ -185,7 +187,7 @@ def populate_and_verify_tables(self, inner_ref_ns_op: Operation):
"""Populate tables for all InnerSymbolTable operations in the given InnerRefNamespace operation, verifying each."""
# Gather top-level operations that have the InnerSymbolTable trait.
inner_sym_table_ops = (
op for op in inner_ref_ns_op.walk() if op.has_trait(InnerSymbolTable)
op for op in inner_ref_ns_op.walk() if op.has_trait(InnerSymbolTableTrait)
)

# Construct the tables
Expand All @@ -197,7 +199,7 @@ def populate_and_verify_tables(self, inner_ref_ns_op: Operation):
self.symbol_tables[op] = InnerSymbolTable(op)


class InnerRefUserOpInterface(OpTrait):
class InnerRefUserOpInterfaceTrait(OpTrait):
"""This interface describes an operation that may use a `InnerRef`. This
interface allows for users of inner symbols to hook into verification and
other inner symbol related utilities that are either costly or otherwise
Expand All @@ -209,21 +211,8 @@ def verify_inner_refs(self, op: Operation, namespace: "InnerRefNamespace"):


@dataclass(frozen=True)
class InnerRefNamespace(OpTrait):
"""Defines a new scope for InnerRef’s. Operations with this trait myst be a SymbolTable.
Combines InnerSymbolTableCollection with a SymbolTable for resolution of InnerRefAttrs, used during verification.
Inner symbols are more costly than normal symbols, with tricker verification. For this reason,
verification is driven as a trait verifier on InnerRefNamespace which constructs and verifies InnerSymbolTables in parallel.
See: https://circt.llvm.org/docs/RationaleSymbols/#innerrefnamespace
"""

symbol_table: SymbolTable = field(
default_factory=lambda: SymbolTable(), compare=False
)
inner_sym_tables: InnerSymbolTableCollection = field(
default_factory=lambda: InnerSymbolTableCollection(), compare=False
)
class InnerRefNamespaceTrait(OpTrait):
"""Trait for operations defining a new scope for InnerRef’s. Operations with this trait must be a SymbolTable."""

def verify(self, op: Operation):
if not op.has_trait(trait := SymbolTable):
Expand All @@ -237,18 +226,31 @@ def verify(self, op: Operation):
if len(op.regions[0].blocks) != 1:
raise VerifyException(f"Operation {op.name} must have a single block")

inner_sym_tables = InnerSymbolTableCollection(op=op)
symbol_table = SymbolTable(op)
namespace = InnerRefNamespace(
symbol_table=symbol_table, inner_sym_tables=inner_sym_tables
)
namespace = InnerRefNamespace(op)

for inner_op in op.walk():
inner_ref_user_op_trait = inner_op.get_trait(InnerRefUserOpInterface)
inner_ref_user_op_trait = inner_op.get_trait(InnerRefUserOpInterfaceTrait)
if inner_ref_user_op_trait is not None:
inner_ref_user_op_trait.verify_inner_refs(inner_op, namespace)


@dataclass
class InnerRefNamespace:
"""Class to perform symbol lookups within a InnerRef namespace, used during verification.
Combines InnerSymbolTableCollection with a SymbolTable for resolution of InnerRefAttrs.
Inner symbols are more costly than normal symbols, with tricker verification. For this reason,
verification is driven as a trait verifier on InnerRefNamespace which constructs and verifies InnerSymbolTables in parallel.
See: https://circt.llvm.org/docs/RationaleSymbols/#innerrefnamespace
"""

inner_sym_tables: InnerSymbolTableCollection = field(init=False)
inner_ref_ns_op: InitVar[Operation]

def __init__(self, inner_ref_ns_op: Operation):
self.inner_sym_tables = InnerSymbolTableCollection(inner_ref_ns_op)


HW = Dialect(
"hw",
[],
Expand Down

0 comments on commit e3eec8b

Please sign in to comment.