Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 31 additions & 38 deletions src/onnx_ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,12 @@
import sys
import textwrap
import typing
from collections import OrderedDict
from collections.abc import (
Collection,
Hashable,
Iterable,
Iterator,
MutableMapping,
Mapping,
MutableSequence,
Sequence,
)
Expand Down Expand Up @@ -1325,7 +1324,7 @@
domain: str,
op_type: str,
inputs: Iterable[Value | None],
attributes: Iterable[Attr] = (),
attributes: Iterable[Attr] | Mapping[str, Attr] = (),
*,
overload: str = "",
num_outputs: int | None = None,
Expand Down Expand Up @@ -1371,15 +1370,10 @@
self._inputs: tuple[Value | None, ...] = tuple(inputs)
# Values belong to their defining nodes. The values list is immutable
self._outputs: tuple[Value, ...] = self._create_outputs(num_outputs, outputs)
attributes = tuple(attributes)
if attributes and not isinstance(attributes[0], Attr):
raise TypeError(
f"Expected the attributes to be Attr, got {type(attributes[0])}. "
"If you are copying the attributes from another node, make sure you call "
"node.attributes.values() because it is a dictionary."
)
self._attributes: OrderedDict[str, Attr] = OrderedDict(
(attr.name, attr) for attr in attributes
if isinstance(attributes, Mapping):
attributes = tuple(attributes.values())
self._attributes: _graph_containers.Attributes = _graph_containers.Attributes(
attributes
)
self._overload: str = overload
# TODO(justinchuby): Potentially support a version range
Expand Down Expand Up @@ -1637,7 +1631,7 @@
raise AttributeError("outputs is immutable. Please create a new node instead.")

@property
def attributes(self) -> OrderedDict[str, Attr]:
def attributes(self) -> _graph_containers.Attributes:
"""The attributes of the node."""
return self._attributes

Expand Down Expand Up @@ -2201,17 +2195,9 @@
# Private fields that are not to be accessed by any other classes
self._inputs = _graph_containers.GraphInputs(self, inputs)
self._outputs = _graph_containers.GraphOutputs(self, outputs)
self._initializers = _graph_containers.GraphInitializers(self)
for initializer in initializers:
if isinstance(initializer, str):
raise TypeError(
"Initializer must be a Value, not a string. "
"If you are copying the initializers from another graph, "
"make sure you call graph.initializers.values() because it is a dictionary."
)
if initializer.name is None:
raise ValueError(f"Initializer must have a name: {initializer}")
self._initializers[initializer.name] = initializer
self._initializers = _graph_containers.GraphInitializers(
self, {initializer.name: initializer for initializer in initializers}
)
self._doc_string = doc_string
self._opset_imports = opset_imports or {}
self._metadata: _metadata.MetadataStore | None = None
Expand All @@ -2234,7 +2220,19 @@
return self._outputs

@property
def initializers(self) -> MutableMapping[str, Value]:
def initializers(self) -> _graph_containers.GraphInitializers:
"""The initializers of the graph as a ``MutableMapping[str, Value]``.
The keys are the names of the initializers. The values are the :class:`Value` objects.
This property additionally supports the ``add`` method, which takes a :class:`Value`
and adds it to the initializers if it is not already present.
.. note::
When setting an initializer with ``graph.initializers[key] = value``,
if the value does not have a name, it will be assigned ``key`` as its name.
"""
return self._initializers

def register_initializer(self, value: Value) -> None:
Expand Down Expand Up @@ -2263,15 +2261,11 @@
" it is not the same object: existing={self._initializers[value.name]!r},"
f" new={value!r}"
)
if value.producer() is not None:
raise ValueError(
f"Value '{value!r}' is produced by a node and cannot be an initializer."
)
if value.const_value is None:
raise ValueError(
f"Value '{value!r}' must have its const_value set to be an initializer."
)
self._initializers[value.name] = value
self._initializers.add(value)

@property
def doc_string(self) -> str | None:
Expand Down Expand Up @@ -2701,7 +2695,7 @@
outputs: Sequence[Value],
*,
nodes: Iterable[Node],
initializers: Sequence[_protocols.ValueProtocol] = (),
initializers: Sequence[Value] = (),
doc_string: str | None = None,
opset_imports: dict[str, int] | None = None,
name: str | None = None,
Expand All @@ -2710,10 +2704,7 @@
self.name = name
self.inputs = tuple(inputs)
self.outputs = tuple(outputs)
for initializer in initializers:
if initializer.name is None:
raise ValueError(f"Initializer must have a name: {initializer}")
self.initializers = {tensor.name: tensor for tensor in initializers}
self.initializers = {initializer.name: initializer for initializer in initializers}
self.doc_string = doc_string
self.opset_imports = opset_imports or {}
self._metadata: _metadata.MetadataStore | None = None
Expand Down Expand Up @@ -2927,13 +2918,15 @@
# Ensure the inputs and outputs of the function belong to a graph
# and not from an outer scope
graph: Graph,
attributes: Sequence[Attr],
attributes: Iterable[Attr] | Mapping[str, Attr],
) -> None:
self._domain = domain
self._name = name
self._overload = overload
self._graph = graph
self._attributes = OrderedDict((attr.name, attr) for attr in attributes)
if isinstance(attributes, Mapping):
attributes = tuple(attributes.values())

Check warning on line 2928 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L2928

Added line #L2928 was not covered by tests
self._attributes = _graph_containers.Attributes(attributes)

def identifier(self) -> _protocols.OperatorIdentifier:
return self.domain, self.name, self.overload
Expand Down Expand Up @@ -2971,7 +2964,7 @@
return self._graph.outputs

@property
def attributes(self) -> OrderedDict[str, Attr]:
def attributes(self) -> _graph_containers.Attributes:
return self._attributes

@typing.overload
Expand Down
159 changes: 155 additions & 4 deletions src/onnx_ir/_core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,117 @@ def test_domain_normalizes_ai_onnx(self):
node.domain = "ai.onnx"
self.assertEqual(node.domain, "")

def test_attributes_add(self):
node = _core.Node("ai.onnx", "TestOp", inputs=())
node.attributes.add(_core.AttrInt64("test_attr", 1))
self.assertIn("test_attr", node.attributes)
self.assertEqual(node.attributes["test_attr"].value, 1)

def test_attributes_set_raise_with_type_error(self):
node = _core.Node("ai.onnx", "TestOp", inputs=())
with self.assertRaises(TypeError):
node.attributes["test_attr"] = 1
with self.assertRaises(TypeError):
node.attributes[1] = _core.AttrInt64("test_attr", 1)

def test_init_accepts_attribute_mapping(self):
node = _core.Node(
"ai.onnx", "TestOp", inputs=(), attributes=[_core.AttrInt64("test_attr", 1)]
)
new_node = _core.Node("", "OtherOp", inputs=(), attributes=node.attributes)
self.assertEqual(new_node.attributes, node.attributes)

def test_attributes_get_int(self):
node = _core.Node(
"ai.onnx", "TestOp", inputs=(), attributes=[_core.AttrInt64("test_attr", 1)]
)
self.assertEqual(node.attributes.get_int("test_attr"), 1)
self.assertIsNone(node.attributes.get_int("non_existent_attr"))
self.assertEqual(node.attributes.get_int("non_existent_attr", 42), 42)

def test_attributes_get_float(self):
node = _core.Node(
"ai.onnx", "TestOp", inputs=(), attributes=[_core.AttrFloat32("test_attr", 1.0)]
)
self.assertEqual(node.attributes.get_float("test_attr"), 1.0)
self.assertIsNone(node.attributes.get_float("non_existent_attr"))
self.assertEqual(node.attributes.get_float("non_existent_attr", 42.0), 42.0)

def test_attributes_get_string(self):
node = _core.Node(
"ai.onnx", "TestOp", inputs=(), attributes=[_core.AttrString("test_attr", "value")]
)
self.assertEqual(node.attributes.get_string("test_attr"), "value")
self.assertIsNone(node.attributes.get_string("non_existent_attr"))
self.assertEqual(node.attributes.get_string("non_existent_attr", "default"), "default")

def test_attributes_get_tensor(self):
tensor = ir.Tensor(np.array([1.0, 2.0, 3.0], dtype=np.float32))
node = _core.Node(
"ai.onnx", "TestOp", inputs=(), attributes=[_core.AttrTensor("test_attr", tensor)]
)
np.testing.assert_equal(
node.attributes.get_tensor("test_attr").numpy(), tensor.numpy()
)
self.assertIsNone(node.attributes.get_tensor("non_existent_attr"))
np.testing.assert_equal(
node.attributes.get_tensor("non_existent_attr", tensor).numpy(), tensor.numpy()
)

def test_attributes_get_ints(self):
node = _core.Node(
"ai.onnx",
"TestOp",
inputs=(),
attributes=[_core.AttrInt64s("test_attr", [1, 2, 3])],
)
self.assertEqual(node.attributes.get_ints("test_attr"), [1, 2, 3])
self.assertIsNone(node.attributes.get_ints("non_existent_attr"))
self.assertEqual(node.attributes.get_ints("non_existent_attr", [42]), [42])

def test_attributes_get_floats(self):
node = _core.Node(
"ai.onnx",
"TestOp",
inputs=(),
attributes=[_core.AttrFloat32s("test_attr", [1.0, 2.0, 3.0])],
)
self.assertEqual(node.attributes.get_floats("test_attr"), [1.0, 2.0, 3.0])
self.assertIsNone(node.attributes.get_floats("non_existent_attr"))
self.assertEqual(node.attributes.get_floats("non_existent_attr", [42.0]), [42.0])

def test_attributes_get_strings(self):
node = _core.Node(
"ai.onnx",
"TestOp",
inputs=(),
attributes=[_core.AttrStrings("test_attr", ["a", "b", "c"])],
)
self.assertEqual(node.attributes.get_strings("test_attr"), ["a", "b", "c"])
self.assertIsNone(node.attributes.get_strings("non_existent_attr"))
self.assertEqual(
node.attributes.get_strings("non_existent_attr", ["default"]), ["default"]
)

def test_attributes_get_tensors(self):
tensor1 = ir.Tensor(np.array([1.0, 2.0], dtype=np.float32))
tensor2 = ir.Tensor(np.array([3.0, 4.0], dtype=np.float32))
node = _core.Node(
"ai.onnx",
"TestOp",
inputs=(),
attributes=[_core.AttrTensors("test_attr", [tensor1, tensor2])],
)
tensors = node.attributes.get_tensors("test_attr")
self.assertIsNotNone(tensors)
self.assertEqual(len(tensors), 2)
np.testing.assert_equal(tensors[0].numpy(), tensor1.numpy())
np.testing.assert_equal(tensors[1].numpy(), tensor2.numpy())
self.assertIsNone(node.attributes.get_tensors("non_existent_attr"))
np.testing.assert_equal(
node.attributes.get_tensors("non_existent_attr", [tensor1]), [tensor1]
)

# TODO(justinchuby): Test all methods


Expand Down Expand Up @@ -1453,7 +1564,7 @@ def test_outputs_copy(self):
self.assertNotIn(self.value3, self.graph.outputs)
self.assertIn(self.value3, outputs_copy)

def test_set_initializers(self):
def test_initializers_setitem(self):
self.graph.initializers["initializer1"] = self.value3
self.assertIn("initializer1", self.graph.initializers)
self.assertTrue(self.value3.is_initializer())
Expand All @@ -1467,11 +1578,11 @@ def test_set_initializers(self):
self.assertFalse(self.value3.is_initializer())
self.assertIsNone(self.value3.graph)

def test_set_initializers_raises_when_key_does_not_match(self):
def test_initializers_setitem_raises_when_key_does_not_match(self):
with self.assertRaisesRegex(ValueError, "does not match the name of the value"):
self.graph.initializers["some_key"] = self.value3

def test_set_initializers_raises_when_it_belongs_to_another_graph(self):
def test_initializers_setitem_raises_when_it_belongs_to_another_graph(self):
other_graph = _core.Graph(inputs=(), outputs=(), nodes=())
other_graph.initializers["initializer1"] = self.value3
with self.assertRaisesRegex(
Expand All @@ -1485,11 +1596,51 @@ def test_set_initializers_raises_when_it_belongs_to_another_graph(self):
self.assertTrue(self.value3.is_initializer())
self.assertIs(self.value3.graph, self.graph)

def test_set_initializers_raises_when_value_does_not_have_a_name(self):
def test_initializers_setitem_raises_when_value_does_not_have_a_name(self):
self.value3.name = None
with self.assertRaises(TypeError):
self.graph.initializers[None] = self.value3

with self.assertRaisesRegex(ValueError, "cannot be an empty string"):
self.graph.initializers[""] = _core.Value(name="")

def test_initializers_setitem_checks_value_name_match(self):
with self.assertRaisesRegex(ValueError, "does not match"):
self.graph.initializers["some_name"] = _core.Value(name="some_other_name")

def test_initializers_setitem_assigns_key_to_value_name_if_not_set(self):
value = _core.Value(name=None)
self.graph.initializers["some_name"] = value
self.assertEqual(value.name, "some_name")
self.assertIs(value, self.graph.initializers["some_name"])

value = _core.Value(name="")
self.graph.initializers["some_other_name"] = value
self.assertEqual(value.name, "some_other_name")
self.assertIs(value, self.graph.initializers["some_other_name"])

def test_initializers_setitem_checks_value_type(self):
with self.assertRaisesRegex(TypeError, "must be a Value object"):
self.graph.initializers["some_name"] = ir.tensor([1, 2, 3], name="some_tensor")

def test_initializers_setitem_raises_when_value_is_node_output(self):
node = ir.node("SomeOp", inputs=[])
with self.assertRaisesRegex(ValueError, "produced by a node"):
self.graph.initializers["some_name"] = node.outputs[0]

def test_initializers_add_checks_value_name(self):
# Initializers should always have a name
with self.assertRaisesRegex(ValueError, "cannot be an empty string"):
self.graph.initializers.add(_core.Value(name=""))

with self.assertRaisesRegex(TypeError, "must be a string"):
self.graph.initializers.add(_core.Value(name=None))

def test_initializers_add_checks_value_type(self):
# Initializers should be of type Value
with self.assertRaisesRegex(TypeError, "must be a Value object"):
self.graph.initializers.add(ir.tensor([1, 2, 3], name="some_tensor"))

def test_delete_initializer(self):
self.graph.initializers["initializer1"] = self.value3
del self.graph.initializers["initializer1"]
Expand Down
Loading