73 changes: 42 additions & 31 deletions src/onnx_ir/_convenience/__init__.py
@@ -58,44 +58,50 @@ def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType:
return _enums.AttributeType.STRING
if isinstance(attr, _core.Attr):
return attr.type
if isinstance(attr, Sequence) and all(isinstance(x, int) for x in attr):
return _enums.AttributeType.INTS
if isinstance(attr, Sequence) and all(isinstance(x, float) for x in attr):
return _enums.AttributeType.FLOATS
if isinstance(attr, Sequence) and all(isinstance(x, str) for x in attr):
return _enums.AttributeType.STRINGS
if isinstance(attr, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)):
return _enums.AttributeType.GRAPH
if isinstance(attr, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol)):
# Be sure to check TensorProtocol last because isinstance checking on Protocols can be slower
return _enums.AttributeType.TENSOR
if isinstance(attr, Sequence) and all(
isinstance(x, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol))
for x in attr
):
return _enums.AttributeType.TENSORS
if isinstance(attr, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)):
return _enums.AttributeType.GRAPH
if isinstance(attr, Sequence) and all(
isinstance(x, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)) for x in attr
):
return _enums.AttributeType.GRAPHS
if isinstance(
attr,
(_core.TensorType, _core.SequenceType, _core.OptionalType, _protocols.TypeProtocol),
):
return _enums.AttributeType.TYPE_PROTO
if isinstance(attr, Sequence) and all(
isinstance(
x,
(
_core.TensorType,
_core.SequenceType,
_core.OptionalType,
_protocols.TypeProtocol,
),
)
for x in attr
):
return _enums.AttributeType.TYPE_PROTOS
if isinstance(attr, Sequence):
if not attr:
raise ValueError(
"Cannot infer type of empty sequence. Please create an Attr with an explicit type."
)
if all(isinstance(x, int) for x in attr):
return _enums.AttributeType.INTS
if all(isinstance(x, float) for x in attr):
return _enums.AttributeType.FLOATS
if all(isinstance(x, str) for x in attr):
return _enums.AttributeType.STRINGS
if all(
isinstance(x, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol))
for x in attr
):
return _enums.AttributeType.TENSORS
if all(
isinstance(x, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol))
for x in attr
):
return _enums.AttributeType.GRAPHS
if all(
isinstance(
x,
(
_core.TensorType,
_core.SequenceType,
_core.OptionalType,
_protocols.TypeProtocol,
),
)
for x in attr
):
return _enums.AttributeType.TYPE_PROTOS
raise TypeError(f"Unsupported attribute type: '{type(attr)}'")
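
For context on the restructuring: with the old ordering, an empty sequence slipped through the per-element checks because `all(...)` over an empty iterable is vacuously true, which is apparently also why the doctest below changes `Attr('graph', INTS, ...)` to `Attr('graph', GRAPH, ...)` — a graph that iterates to zero nodes matched the INTS branch first. A minimal sketch of the new behavior, assuming the package is importable as `onnx_ir` and re-exports `Attr`, `AttributeType`, and `convenience.convert_attributes` (these access paths are assumptions, not confirmed by the diff):

```python
import onnx_ir as ir  # assumed package alias; the module patched above is onnx_ir._convenience

# Non-empty sequences are still inferred from their element type.
[ints_attr] = ir.convenience.convert_attributes({"ints": [1, 2, 3]})
assert ints_attr.type == ir.AttributeType.INTS

# An empty sequence is ambiguous (INTS? FLOATS? STRINGS?), so inference now raises
# instead of silently guessing a type.
try:
    ir.convenience.convert_attributes({"ints": []})
except ValueError as e:
    print(e)  # "Cannot infer type of empty sequence. ..."

# The documented workaround: construct the Attr with an explicit attribute type.
empty_ints = ir.Attr("ints", ir.AttributeType.INTS, [])
assert empty_ints.type == ir.AttributeType.INTS
```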


@@ -218,7 +224,7 @@ def convert_attributes(
... "type_protos": [ir.TensorType(ir.DataType.FLOAT), ir.TensorType(ir.DataType.FLOAT)],
... }
>>> convert_attributes(attrs)
[Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, [1, 2, 3]), Attr('floats', FLOATS, [1.0, 2.0, 3.0]), Attr('strings', STRINGS, ['hello', 'world']), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph', INTS, Graph(
[Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, [1, 2, 3]), Attr('floats', FLOATS, [1.0, 2.0, 3.0]), Attr('strings', STRINGS, ['hello', 'world']), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph', GRAPH, Graph(
name='graph0',
inputs=(
<BLANKLINE>
@@ -252,6 +258,11 @@ def convert_attributes(

Returns:
A list of _core.Attr objects.

Raises:
ValueError: If an attribute is an empty sequence. Construct the Attr object
with an explicit attribute type instead.
TypeError: If an attribute type is not supported.
"""
attributes: list[_core.Attr] = []
for name, attr in attrs.items():
22 changes: 15 additions & 7 deletions src/onnx_ir/passes/common/constant_manipulation.py
@@ -148,6 +148,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
if graph is model.graph:
continue
for name in tuple(graph.initializers):
assert name is not None
initializer = graph.initializers[name]
if initializer.is_graph_input():
# Skip the ones that are also graph inputs
@@ -156,17 +157,24 @@
initializer.name,
)
continue
if initializer.is_graph_output():
logger.debug(
"Initializer '%s' is used as output, so it can't be lifted",
initializer.name,
)
continue
# Remove the initializer from the subgraph
graph.initializers.pop(name)
# To avoid name conflicts, we need to rename the initializer
# to a unique name in the main graph
if name in registered_initializer_names:
name_count = registered_initializer_names[name]
initializer.name = f"{name}_{name_count}"
registered_initializer_names[name] = name_count + 1
else:
assert initializer.name is not None
registered_initializer_names[initializer.name] = 1
new_name = name
while new_name in model.graph.initializers:
if name in registered_initializer_names:
registered_initializer_names[name] += 1
else:
registered_initializer_names[name] = 1
new_name = f"{name}_{registered_initializer_names[name]}"
initializer.name = new_name
model.graph.register_initializer(initializer)
count += 1
logger.debug(
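
Two behavioral points fall out of this hunk: initializers that are also subgraph outputs are now left in place (lifting them would remove a value the subgraph must still produce), and lifted initializers are renamed by probing the main graph's registry until a free name is found, rather than appending a counter unconditionally. Below is a small self-contained mirror of that collision-avoidance loop, using hypothetical helper and variable names (the pass itself works directly on `model.graph.initializers` and `registered_initializer_names`):

```python
def unique_initializer_name(name: str, taken: set[str], counters: dict[str, int]) -> str:
    """Bump a per-name counter until the candidate name is not already taken."""
    new_name = name
    while new_name in taken:
        counters[name] = counters.get(name, 0) + 1
        new_name = f"{name}_{counters[name]}"
    return new_name


# Two different subgraphs both carry an initializer called "weight":
taken = {"weight"}                  # names already registered in the main graph
counters: dict[str, int] = {}
first = unique_initializer_name("weight", taken, counters)   # -> "weight_1"
taken.add(first)
second = unique_initializer_name("weight", taken, counters)  # -> "weight_2"
print(first, second)
```

Note that the counter is tracked per original name and persists across subgraphs, so repeated collisions on the same base name yield `_1`, `_2`, ... rather than reusing a suffix.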
42 changes: 42 additions & 0 deletions src/onnx_ir/passes/common/constant_manipulation_test.py
@@ -399,6 +399,48 @@ def test_pass_does_not_lift_initialized_inputs_in_subgraph(
for value, tensor in zip(main_graph.initializers.values(), [else_initializer_tensor]):
self.assertIs(value.const_value, tensor)

def test_pass_does_not_lift_initialized_outputs_in_subgraph(self):
input = ir.Value(
name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))
)
output = ir.Value(
name="output", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))
)

subgraph_output = ir.Value(
name="subgraph_output",
type=ir.TensorType(ir.DataType.FLOAT),
shape=ir.Shape((1, 2)),
const_value=ir.tensor(np.random.rand(1, 2).astype(np.float32)),
)

graph_node = ir.node(
"OpWithSubgraph",
inputs=[input],
attributes={
"subgraph": ir.Graph(
inputs=[],
outputs=[subgraph_output],
nodes=[],
initializers=[subgraph_output],
)
},
outputs=[output],
)
model = ir.Model(
graph=ir.Graph(
inputs=[input],
outputs=[output],
nodes=[graph_node],
opset_imports={"": 20},
),
ir_version=10,
)

result = constant_manipulation.LiftSubgraphInitializersToMainGraphPass()(model)
self.assertFalse(result.modified)
self.assertIs(subgraph_output.graph, graph_node.attributes["subgraph"].as_graph())


class TestRemoveInitializersFromInputsPass(unittest.TestCase):
def test_remove_initializers_from_inputs(self):