diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index 339d7dd4..eefb4480 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -22,13 +22,12 @@ import sys import textwrap import typing -from collections import OrderedDict from collections.abc import ( Collection, Hashable, Iterable, Iterator, - MutableMapping, + Mapping, MutableSequence, Sequence, ) @@ -1325,7 +1324,7 @@ def __init__( domain: str, op_type: str, inputs: Iterable[Value | None], - attributes: Iterable[Attr] = (), + attributes: Iterable[Attr] | Mapping[str, Attr] = (), *, overload: str = "", num_outputs: int | None = None, @@ -1371,15 +1370,10 @@ def __init__( self._inputs: tuple[Value | None, ...] = tuple(inputs) # Values belong to their defining nodes. The values list is immutable self._outputs: tuple[Value, ...] = self._create_outputs(num_outputs, outputs) - attributes = tuple(attributes) - if attributes and not isinstance(attributes[0], Attr): - raise TypeError( - f"Expected the attributes to be Attr, got {type(attributes[0])}. " - "If you are copying the attributes from another node, make sure you call " - "node.attributes.values() because it is a dictionary." - ) - self._attributes: OrderedDict[str, Attr] = OrderedDict( - (attr.name, attr) for attr in attributes + if isinstance(attributes, Mapping): + attributes = tuple(attributes.values()) + self._attributes: _graph_containers.Attributes = _graph_containers.Attributes( + attributes ) self._overload: str = overload # TODO(justinchuby): Potentially support a version range @@ -1637,7 +1631,7 @@ def outputs(self, _: Sequence[Value]) -> None: raise AttributeError("outputs is immutable. Please create a new node instead.") @property - def attributes(self) -> OrderedDict[str, Attr]: + def attributes(self) -> _graph_containers.Attributes: """The attributes of the node.""" return self._attributes @@ -2201,17 +2195,9 @@ def __init__( # Private fields that are not to be accessed by any other classes self._inputs = _graph_containers.GraphInputs(self, inputs) self._outputs = _graph_containers.GraphOutputs(self, outputs) - self._initializers = _graph_containers.GraphInitializers(self) - for initializer in initializers: - if isinstance(initializer, str): - raise TypeError( - "Initializer must be a Value, not a string. " - "If you are copying the initializers from another graph, " - "make sure you call graph.initializers.values() because it is a dictionary." - ) - if initializer.name is None: - raise ValueError(f"Initializer must have a name: {initializer}") - self._initializers[initializer.name] = initializer + self._initializers = _graph_containers.GraphInitializers( + self, {initializer.name: initializer for initializer in initializers} + ) self._doc_string = doc_string self._opset_imports = opset_imports or {} self._metadata: _metadata.MetadataStore | None = None @@ -2234,7 +2220,19 @@ def outputs(self) -> MutableSequence[Value]: return self._outputs @property - def initializers(self) -> MutableMapping[str, Value]: + def initializers(self) -> _graph_containers.GraphInitializers: + """The initializers of the graph as a ``MutableMapping[str, Value]``. + + The keys are the names of the initializers. The values are the :class:`Value` objects. + + This property additionally supports the ``add`` method, which takes a :class:`Value` + and adds it to the initializers if it is not already present. + + .. note:: + When setting an initializer with ``graph.initializers[key] = value``, + if the value does not have a name, it will be assigned ``key`` as its name. + + """ return self._initializers def register_initializer(self, value: Value) -> None: @@ -2263,15 +2261,11 @@ def register_initializer(self, value: Value) -> None: " it is not the same object: existing={self._initializers[value.name]!r}," f" new={value!r}" ) - if value.producer() is not None: - raise ValueError( - f"Value '{value!r}' is produced by a node and cannot be an initializer." - ) if value.const_value is None: raise ValueError( f"Value '{value!r}' must have its const_value set to be an initializer." ) - self._initializers[value.name] = value + self._initializers.add(value) @property def doc_string(self) -> str | None: @@ -2701,7 +2695,7 @@ def __init__( outputs: Sequence[Value], *, nodes: Iterable[Node], - initializers: Sequence[_protocols.ValueProtocol] = (), + initializers: Sequence[Value] = (), doc_string: str | None = None, opset_imports: dict[str, int] | None = None, name: str | None = None, @@ -2710,10 +2704,7 @@ def __init__( self.name = name self.inputs = tuple(inputs) self.outputs = tuple(outputs) - for initializer in initializers: - if initializer.name is None: - raise ValueError(f"Initializer must have a name: {initializer}") - self.initializers = {tensor.name: tensor for tensor in initializers} + self.initializers = {initializer.name: initializer for initializer in initializers} self.doc_string = doc_string self.opset_imports = opset_imports or {} self._metadata: _metadata.MetadataStore | None = None @@ -2927,13 +2918,15 @@ def __init__( # Ensure the inputs and outputs of the function belong to a graph # and not from an outer scope graph: Graph, - attributes: Sequence[Attr], + attributes: Iterable[Attr] | Mapping[str, Attr], ) -> None: self._domain = domain self._name = name self._overload = overload self._graph = graph - self._attributes = OrderedDict((attr.name, attr) for attr in attributes) + if isinstance(attributes, Mapping): + attributes = tuple(attributes.values()) + self._attributes = _graph_containers.Attributes(attributes) def identifier(self) -> _protocols.OperatorIdentifier: return self.domain, self.name, self.overload @@ -2971,7 +2964,7 @@ def outputs(self) -> MutableSequence[Value]: return self._graph.outputs @property - def attributes(self) -> OrderedDict[str, Attr]: + def attributes(self) -> _graph_containers.Attributes: return self._attributes @typing.overload diff --git a/src/onnx_ir/_core_test.py b/src/onnx_ir/_core_test.py index 1a956dfe..b77efaae 100644 --- a/src/onnx_ir/_core_test.py +++ b/src/onnx_ir/_core_test.py @@ -861,6 +861,117 @@ def test_domain_normalizes_ai_onnx(self): node.domain = "ai.onnx" self.assertEqual(node.domain, "") + def test_attributes_add(self): + node = _core.Node("ai.onnx", "TestOp", inputs=()) + node.attributes.add(_core.AttrInt64("test_attr", 1)) + self.assertIn("test_attr", node.attributes) + self.assertEqual(node.attributes["test_attr"].value, 1) + + def test_attributes_set_raise_with_type_error(self): + node = _core.Node("ai.onnx", "TestOp", inputs=()) + with self.assertRaises(TypeError): + node.attributes["test_attr"] = 1 + with self.assertRaises(TypeError): + node.attributes[1] = _core.AttrInt64("test_attr", 1) + + def test_init_accepts_attribute_mapping(self): + node = _core.Node( + "ai.onnx", "TestOp", inputs=(), attributes=[_core.AttrInt64("test_attr", 1)] + ) + new_node = _core.Node("", "OtherOp", inputs=(), attributes=node.attributes) + self.assertEqual(new_node.attributes, node.attributes) + + def test_attributes_get_int(self): + node = _core.Node( + "ai.onnx", "TestOp", inputs=(), attributes=[_core.AttrInt64("test_attr", 1)] + ) + self.assertEqual(node.attributes.get_int("test_attr"), 1) + self.assertIsNone(node.attributes.get_int("non_existent_attr")) + self.assertEqual(node.attributes.get_int("non_existent_attr", 42), 42) + + def test_attributes_get_float(self): + node = _core.Node( + "ai.onnx", "TestOp", inputs=(), attributes=[_core.AttrFloat32("test_attr", 1.0)] + ) + self.assertEqual(node.attributes.get_float("test_attr"), 1.0) + self.assertIsNone(node.attributes.get_float("non_existent_attr")) + self.assertEqual(node.attributes.get_float("non_existent_attr", 42.0), 42.0) + + def test_attributes_get_string(self): + node = _core.Node( + "ai.onnx", "TestOp", inputs=(), attributes=[_core.AttrString("test_attr", "value")] + ) + self.assertEqual(node.attributes.get_string("test_attr"), "value") + self.assertIsNone(node.attributes.get_string("non_existent_attr")) + self.assertEqual(node.attributes.get_string("non_existent_attr", "default"), "default") + + def test_attributes_get_tensor(self): + tensor = ir.Tensor(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + node = _core.Node( + "ai.onnx", "TestOp", inputs=(), attributes=[_core.AttrTensor("test_attr", tensor)] + ) + np.testing.assert_equal( + node.attributes.get_tensor("test_attr").numpy(), tensor.numpy() + ) + self.assertIsNone(node.attributes.get_tensor("non_existent_attr")) + np.testing.assert_equal( + node.attributes.get_tensor("non_existent_attr", tensor).numpy(), tensor.numpy() + ) + + def test_attributes_get_ints(self): + node = _core.Node( + "ai.onnx", + "TestOp", + inputs=(), + attributes=[_core.AttrInt64s("test_attr", [1, 2, 3])], + ) + self.assertEqual(node.attributes.get_ints("test_attr"), [1, 2, 3]) + self.assertIsNone(node.attributes.get_ints("non_existent_attr")) + self.assertEqual(node.attributes.get_ints("non_existent_attr", [42]), [42]) + + def test_attributes_get_floats(self): + node = _core.Node( + "ai.onnx", + "TestOp", + inputs=(), + attributes=[_core.AttrFloat32s("test_attr", [1.0, 2.0, 3.0])], + ) + self.assertEqual(node.attributes.get_floats("test_attr"), [1.0, 2.0, 3.0]) + self.assertIsNone(node.attributes.get_floats("non_existent_attr")) + self.assertEqual(node.attributes.get_floats("non_existent_attr", [42.0]), [42.0]) + + def test_attributes_get_strings(self): + node = _core.Node( + "ai.onnx", + "TestOp", + inputs=(), + attributes=[_core.AttrStrings("test_attr", ["a", "b", "c"])], + ) + self.assertEqual(node.attributes.get_strings("test_attr"), ["a", "b", "c"]) + self.assertIsNone(node.attributes.get_strings("non_existent_attr")) + self.assertEqual( + node.attributes.get_strings("non_existent_attr", ["default"]), ["default"] + ) + + def test_attributes_get_tensors(self): + tensor1 = ir.Tensor(np.array([1.0, 2.0], dtype=np.float32)) + tensor2 = ir.Tensor(np.array([3.0, 4.0], dtype=np.float32)) + node = _core.Node( + "ai.onnx", + "TestOp", + inputs=(), + attributes=[_core.AttrTensors("test_attr", [tensor1, tensor2])], + ) + tensors = node.attributes.get_tensors("test_attr") + self.assertIsNotNone(tensors) + self.assertEqual(len(tensors), 2) + np.testing.assert_equal(tensors[0].numpy(), tensor1.numpy()) + np.testing.assert_equal(tensors[1].numpy(), tensor2.numpy()) + self.assertIsNone(node.attributes.get_tensors("non_existent_attr")) + np.testing.assert_equal( + node.attributes.get_tensors("non_existent_attr", [tensor1]), [tensor1] + ) + # TODO(justinchuby): Test all methods @@ -1453,7 +1564,7 @@ def test_outputs_copy(self): self.assertNotIn(self.value3, self.graph.outputs) self.assertIn(self.value3, outputs_copy) - def test_set_initializers(self): + def test_initializers_setitem(self): self.graph.initializers["initializer1"] = self.value3 self.assertIn("initializer1", self.graph.initializers) self.assertTrue(self.value3.is_initializer()) @@ -1467,11 +1578,11 @@ def test_set_initializers(self): self.assertFalse(self.value3.is_initializer()) self.assertIsNone(self.value3.graph) - def test_set_initializers_raises_when_key_does_not_match(self): + def test_initializers_setitem_raises_when_key_does_not_match(self): with self.assertRaisesRegex(ValueError, "does not match the name of the value"): self.graph.initializers["some_key"] = self.value3 - def test_set_initializers_raises_when_it_belongs_to_another_graph(self): + def test_initializers_setitem_raises_when_it_belongs_to_another_graph(self): other_graph = _core.Graph(inputs=(), outputs=(), nodes=()) other_graph.initializers["initializer1"] = self.value3 with self.assertRaisesRegex( @@ -1485,11 +1596,51 @@ def test_set_initializers_raises_when_it_belongs_to_another_graph(self): self.assertTrue(self.value3.is_initializer()) self.assertIs(self.value3.graph, self.graph) - def test_set_initializers_raises_when_value_does_not_have_a_name(self): + def test_initializers_setitem_raises_when_value_does_not_have_a_name(self): self.value3.name = None with self.assertRaises(TypeError): self.graph.initializers[None] = self.value3 + with self.assertRaisesRegex(ValueError, "cannot be an empty string"): + self.graph.initializers[""] = _core.Value(name="") + + def test_initializers_setitem_checks_value_name_match(self): + with self.assertRaisesRegex(ValueError, "does not match"): + self.graph.initializers["some_name"] = _core.Value(name="some_other_name") + + def test_initializers_setitem_assigns_key_to_value_name_if_not_set(self): + value = _core.Value(name=None) + self.graph.initializers["some_name"] = value + self.assertEqual(value.name, "some_name") + self.assertIs(value, self.graph.initializers["some_name"]) + + value = _core.Value(name="") + self.graph.initializers["some_other_name"] = value + self.assertEqual(value.name, "some_other_name") + self.assertIs(value, self.graph.initializers["some_other_name"]) + + def test_initializers_setitem_checks_value_type(self): + with self.assertRaisesRegex(TypeError, "must be a Value object"): + self.graph.initializers["some_name"] = ir.tensor([1, 2, 3], name="some_tensor") + + def test_initializers_setitem_raises_when_value_is_node_output(self): + node = ir.node("SomeOp", inputs=[]) + with self.assertRaisesRegex(ValueError, "produced by a node"): + self.graph.initializers["some_name"] = node.outputs[0] + + def test_initializers_add_checks_value_name(self): + # Initializers should always have a name + with self.assertRaisesRegex(ValueError, "cannot be an empty string"): + self.graph.initializers.add(_core.Value(name="")) + + with self.assertRaisesRegex(TypeError, "must be a string"): + self.graph.initializers.add(_core.Value(name=None)) + + def test_initializers_add_checks_value_type(self): + # Initializers should be of type Value + with self.assertRaisesRegex(TypeError, "must be a Value object"): + self.graph.initializers.add(ir.tensor([1, 2, 3], name="some_tensor")) + def test_delete_initializer(self): self.graph.initializers["initializer1"] = self.value3 del self.graph.initializers["initializer1"] diff --git a/src/onnx_ir/_graph_containers.py b/src/onnx_ir/_graph_containers.py index 16ba97bd..76dfd6c2 100644 --- a/src/onnx_ir/_graph_containers.py +++ b/src/onnx_ir/_graph_containers.py @@ -12,13 +12,16 @@ ] import collections -from collections.abc import Iterable -from typing import TYPE_CHECKING, SupportsIndex +import logging +from collections.abc import Iterable, Sequence +from typing import SupportsIndex, TypeVar import onnx_ir +from onnx_ir import _core, _protocols -if TYPE_CHECKING: - from onnx_ir import _core +T = TypeVar("T") + +logger = logging.getLogger(__name__) class _GraphIO(collections.UserList["_core.Value"]): @@ -248,12 +251,23 @@ def _maybe_unset_graph(self, value: _core.Value) -> None: def __setitem__(self, key: str, value: _core.Value) -> None: """Set an initializer for the graph.""" - if key != value.name: + if not isinstance(value, _core.Value): + raise TypeError(f"value must be a Value object, not {type(value)}") + if not isinstance(key, str): + raise TypeError(f"Value name must be a string, not {type(key)}") + if key == "": + raise ValueError("Value name cannot be an empty string") + if not value.name: + logger.info("Value %s does not have a name, setting it to '%s'", value, key) + value.name = key + elif key != value.name: raise ValueError( - f"Key '{key}' does not match the name of the value '{value.name}'" + f"Key '{key}' does not match the name of the value '{value.name}'. Please use the value.name as the key." + ) + if value.producer() is not None: + raise ValueError( + f"Value '{value}' is produced by a node and cannot be a graph initializer" ) - if not isinstance(key, str): - raise TypeError(f"Key must be a string, not {type(key)}") if key in self.data: # If the key already exists, unset the old value old_value = self.data[key] @@ -270,3 +284,90 @@ def __delitem__(self, key: str) -> None: # the dictionary is not modified self._maybe_unset_graph(value) super().__delitem__(key) + + def add(self, value: _core.Value) -> None: + """Add an initializer to the graph.""" + self[value.name] = value # type: ignore[index] + + +class Attributes(collections.UserDict[str, "_core.Attr"]): + """The attributes of a Node.""" + + def __init__(self, attrs: Iterable[_core.Attr]): + super().__init__({attr.name: attr for attr in attrs}) + + def __setitem__(self, key: str, value: _core.Attr) -> None: + """Set an attribute for the node.""" + if type(key) is not str: + raise TypeError(f"Key must be a string, not {type(key)}") + if not isinstance(value, _core.Attr): + raise TypeError(f"Value must be an Attr, not {type(value)}") + super().__setitem__(key, value) + + def add(self, value: _core.Attr) -> None: + """Add an attribute to the node.""" + self[value.name] = value + + def get_int(self, key: str, default: T = None) -> int | T: # type: ignore[assignment] + """Get the integer value of the attribute.""" + if key in self: + return self[key].as_int() + return default + + def get_float(self, key: str, default: T = None) -> float | T: # type: ignore[assignment] + """Get the float value of the attribute.""" + if key in self: + return self[key].as_float() + return default + + def get_string(self, key: str, default: T = None) -> str | T: # type: ignore[assignment] + """Get the string value of the attribute.""" + if key in self: + return self[key].as_string() + return default + + def get_tensor(self, key: str, default: T = None) -> _protocols.TensorProtocol | T: # type: ignore[assignment] + """Get the tensor value of the attribute.""" + if key in self: + return self[key].as_tensor() + return default + + def get_graph(self, key: str, default: T = None) -> _core.Graph | T: # type: ignore[assignment] + """Get the graph value of the attribute.""" + if key in self: + return self[key].as_graph() + return default + + def get_ints(self, key: str, default: T = None) -> Sequence[int] | T: # type: ignore[assignment] + """Get the Sequence of integers from the attribute.""" + if key in self: + return self[key].as_ints() + return default + + def get_floats(self, key: str, default: T = None) -> Sequence[float] | T: # type: ignore[assignment] + """Get the Sequence of floats from the attribute.""" + if key in self: + return self[key].as_floats() + return default + + def get_strings(self, key: str, default: T = None) -> Sequence[str] | T: # type: ignore[assignment] + """Get the Sequence of strings from the attribute.""" + if key in self: + return self[key].as_strings() + return default + + def get_tensors( + self, + key: str, + default: T = None, # type: ignore[assignment] + ) -> Sequence[_protocols.TensorProtocol] | T: + """Get the Sequence of tensors from the attribute.""" + if key in self: + return self[key].as_tensors() + return default + + def get_graphs(self, key: str, default: T = None) -> Sequence[_core.Graph] | T: # type: ignore[assignment] + """Get the Sequence of graphs from the attribute.""" + if key in self: + return self[key].as_graphs() + return default diff --git a/src/onnx_ir/passes/common/inliner.py b/src/onnx_ir/passes/common/inliner.py index c74317a6..e2c834f7 100644 --- a/src/onnx_ir/passes/common/inliner.py +++ b/src/onnx_ir/passes/common/inliner.py @@ -9,7 +9,7 @@ __all__ = ["InlinePass", "InlinePassResult"] from collections import defaultdict -from collections.abc import Iterable, Sequence +from collections.abc import Iterable, Mapping, Sequence import onnx_ir as ir import onnx_ir.convenience as _ir_convenience @@ -52,7 +52,7 @@ class _CopyReplace: def __init__( self, inliner: InlinePass, - attr_map: dict[str, ir.Attr], + attr_map: Mapping[str, ir.Attr], value_map: dict[ir.Value, ir.Value | None], metadata_props: dict[str, str], call_stack: CallStack, @@ -96,6 +96,7 @@ def clone_attr(self, key: str, attr: ir.Attr) -> ir.Attr | None: return attr assert attr.is_ref() ref_attr_name = attr.ref_attr_name + assert ref_attr_name is not None, "Reference attribute must have a name" if ref_attr_name in self._attr_map: ref_attr = self._attr_map[ref_attr_name] if not ref_attr.is_ref(): @@ -237,7 +238,7 @@ def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeRepl ) # Identify substitutions for both inputs and attributes of the function: - attributes: dict[str, ir.Attr] = node.attributes + attributes: Mapping[str, ir.Attr] = node.attributes default_attr_values = { attr.name: attr for attr in function.attributes.values()