diff --git a/src/onnx_ir/_convenience/__init__.py b/src/onnx_ir/_convenience/__init__.py
index 98ba1ead..5c27a2c5 100644
--- a/src/onnx_ir/_convenience/__init__.py
+++ b/src/onnx_ir/_convenience/__init__.py
@@ -58,44 +58,50 @@ def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType:
         return _enums.AttributeType.STRING
     if isinstance(attr, _core.Attr):
         return attr.type
-    if isinstance(attr, Sequence) and all(isinstance(x, int) for x in attr):
-        return _enums.AttributeType.INTS
-    if isinstance(attr, Sequence) and all(isinstance(x, float) for x in attr):
-        return _enums.AttributeType.FLOATS
-    if isinstance(attr, Sequence) and all(isinstance(x, str) for x in attr):
-        return _enums.AttributeType.STRINGS
+    if isinstance(attr, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)):
+        return _enums.AttributeType.GRAPH
     if isinstance(attr, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol)):
         # Be sure to check TensorProtocol last because isinstance checking on Protocols can be slower
         return _enums.AttributeType.TENSOR
-    if isinstance(attr, Sequence) and all(
-        isinstance(x, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol))
-        for x in attr
-    ):
-        return _enums.AttributeType.TENSORS
-    if isinstance(attr, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)):
-        return _enums.AttributeType.GRAPH
-    if isinstance(attr, Sequence) and all(
-        isinstance(x, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)) for x in attr
-    ):
-        return _enums.AttributeType.GRAPHS
     if isinstance(
         attr,
         (_core.TensorType, _core.SequenceType, _core.OptionalType, _protocols.TypeProtocol),
     ):
         return _enums.AttributeType.TYPE_PROTO
-    if isinstance(attr, Sequence) and all(
-        isinstance(
-            x,
-            (
-                _core.TensorType,
-                _core.SequenceType,
-                _core.OptionalType,
-                _protocols.TypeProtocol,
-            ),
-        )
-        for x in attr
-    ):
-        return _enums.AttributeType.TYPE_PROTOS
+    if isinstance(attr, Sequence):
+        if not attr:
+            raise ValueError(
+                "Cannot infer type of empty sequence. Please create an Attr with an explicit type."
+            )
+        if all(isinstance(x, int) for x in attr):
+            return _enums.AttributeType.INTS
+        if all(isinstance(x, float) for x in attr):
+            return _enums.AttributeType.FLOATS
+        if all(isinstance(x, str) for x in attr):
+            return _enums.AttributeType.STRINGS
+        if all(
+            isinstance(x, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol))
+            for x in attr
+        ):
+            return _enums.AttributeType.TENSORS
+        if all(
+            isinstance(x, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol))
+            for x in attr
+        ):
+            return _enums.AttributeType.GRAPHS
+        if all(
+            isinstance(
+                x,
+                (
+                    _core.TensorType,
+                    _core.SequenceType,
+                    _core.OptionalType,
+                    _protocols.TypeProtocol,
+                ),
+            )
+            for x in attr
+        ):
+            return _enums.AttributeType.TYPE_PROTOS
     raise TypeError(f"Unsupported attribute type: '{type(attr)}'")
 
 
@@ -218,7 +224,7 @@ def convert_attributes(
         ...     "type_protos": [ir.TensorType(ir.DataType.FLOAT), ir.TensorType(ir.DataType.FLOAT)],
         ... }
         >>> convert_attributes(attrs)
-        [Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, [1, 2, 3]), Attr('floats', FLOATS, [1.0, 2.0, 3.0]), Attr('strings', STRINGS, ['hello', 'world']), Attr('tensor', TENSOR, Tensor(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph', INTS, Graph(
+        [Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, [1, 2, 3]), Attr('floats', FLOATS, [1.0, 2.0, 3.0]), Attr('strings', STRINGS, ['hello', 'world']), Attr('tensor', TENSOR, Tensor(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph', GRAPH, Graph(
             name='graph0',
             inputs=(
 
@@ -252,6 +258,11 @@ def convert_attributes(
 
     Returns:
         A list of _core.Attr objects.
+
+    Raises:
+        ValueError: If an attribute is an empty sequence. It should be created with an
+            explicit type by initializing an Attr object with an attribute type.
+        TypeError: If an attribute type is not supported.
     """
     attributes: list[_core.Attr] = []
    for name, attr in attrs.items():
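Reviewer note on the _convenience change above: folding the per-element sequence checks into a single isinstance(attr, Sequence) branch also changes the empty-sequence case, which previously matched the first vacuously-true all(...) check and was silently inferred as INTS. A minimal sketch of the new behavior, assuming the usual `import onnx_ir as ir` alias, the private module path from this diff, and ir.Attr's (name, type, value) constructor; the attribute names are illustrative:

    import onnx_ir as ir
    from onnx_ir._convenience import convert_attributes

    # Non-empty homogeneous sequences are inferred as before.
    print(convert_attributes({"ints": [1, 2, 3]}))  # [Attr('ints', INTS, [1, 2, 3])]

    # An empty sequence is ambiguous (INTS? FLOATS? ...), so it now raises.
    try:
        convert_attributes({"empty": []})
    except ValueError as e:
        print(e)  # Cannot infer type of empty sequence. ...

    # The documented fix: construct the Attr with an explicit type instead.
    explicit = ir.Attr("empty", ir.AttributeType.FLOATS, [])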
diff --git a/src/onnx_ir/passes/common/constant_manipulation.py b/src/onnx_ir/passes/common/constant_manipulation.py
index da104118..369be3fd 100644
--- a/src/onnx_ir/passes/common/constant_manipulation.py
+++ b/src/onnx_ir/passes/common/constant_manipulation.py
@@ -148,6 +148,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
             if graph is model.graph:
                 continue
             for name in tuple(graph.initializers):
+                assert name is not None
                 initializer = graph.initializers[name]
                 if initializer.is_graph_input():
                     # Skip the ones that are also graph inputs
@@ -156,17 +157,24 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
                         initializer.name,
                     )
                     continue
+                if initializer.is_graph_output():
+                    logger.debug(
+                        "Initializer '%s' is used as output, so it can't be lifted",
+                        initializer.name,
+                    )
+                    continue
                 # Remove the initializer from the subgraph
                 graph.initializers.pop(name)
                 # To avoid name conflicts, we need to rename the initializer
                 # to a unique name in the main graph
-                if name in registered_initializer_names:
-                    name_count = registered_initializer_names[name]
-                    initializer.name = f"{name}_{name_count}"
-                    registered_initializer_names[name] = name_count + 1
-                else:
-                    assert initializer.name is not None
-                    registered_initializer_names[initializer.name] = 1
+                new_name = name
+                while new_name in model.graph.initializers:
+                    if name in registered_initializer_names:
+                        registered_initializer_names[name] += 1
+                    else:
+                        registered_initializer_names[name] = 1
+                    new_name = f"{name}_{registered_initializer_names[name]}"
+                initializer.name = new_name
                 model.graph.register_initializer(initializer)
                 count += 1
                 logger.debug(
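Note on the renaming rewrite above: the old code suffixed a lifted initializer whenever its base name had been seen before, even if nothing in the main graph actually clashed; the new loop probes model.graph.initializers and only appends _<n> until the name is free. (The is_graph_output() guard is a separate fix, mirroring the existing graph-input guard.) A standalone sketch of the probing logic, using a plain set and dict so it runs without onnx_ir; all names here are illustrative:

    def unique_name(name: str, taken: set[str], counts: dict[str, int]) -> str:
        # Probe name, name_1, name_2, ... until one is not taken.
        new_name = name
        while new_name in taken:
            counts[name] = counts.get(name, 0) + 1
            new_name = f"{name}_{counts[name]}"
        return new_name

    counts: dict[str, int] = {}
    taken = {"w"}  # the main graph already has an initializer named "w"
    assert unique_name("w", taken, counts) == "w_1"  # collision -> first suffix
    taken.add("w_1")
    assert unique_name("w", taken, counts) == "w_2"  # next collision -> next suffix
    assert unique_name("b", taken, counts) == "b"    # no collision -> name unchanged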
diff --git a/src/onnx_ir/passes/common/constant_manipulation_test.py b/src/onnx_ir/passes/common/constant_manipulation_test.py
index f3862917..f903e263 100644
--- a/src/onnx_ir/passes/common/constant_manipulation_test.py
+++ b/src/onnx_ir/passes/common/constant_manipulation_test.py
@@ -399,6 +399,48 @@ def test_pass_does_not_lift_initialized_inputs_in_subgraph(
         for value, tensor in zip(main_graph.initializers.values(), [else_initializer_tensor]):
             self.assertIs(value.const_value, tensor)
 
+    def test_pass_does_not_lift_initialized_outputs_in_subgraph(self):
+        input = ir.Value(
+            name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))
+        )
+        output = ir.Value(
+            name="output", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))
+        )
+
+        subgraph_output = ir.Value(
+            name="subgraph_output",
+            type=ir.TensorType(ir.DataType.FLOAT),
+            shape=ir.Shape((1, 2)),
+            const_value=ir.tensor(np.random.rand(1, 2).astype(np.float32)),
+        )
+
+        graph_node = ir.node(
+            "OpWithSubgraph",
+            inputs=[input],
+            attributes={
+                "subgraph": ir.Graph(
+                    inputs=[],
+                    outputs=[subgraph_output],
+                    nodes=[],
+                    initializers=[subgraph_output],
+                )
+            },
+            outputs=[output],
+        )
+        model = ir.Model(
+            graph=ir.Graph(
+                inputs=[input],
+                outputs=[output],
+                nodes=[graph_node],
+                opset_imports={"": 20},
+            ),
+            ir_version=10,
+        )
+
+        result = constant_manipulation.LiftSubgraphInitializersToMainGraphPass()(model)
+        self.assertFalse(result.modified)
+        self.assertIs(subgraph_output.graph, graph_node.attributes["subgraph"].as_graph())
+
 
 class TestRemoveInitializersFromInputsPass(unittest.TestCase):
     def test_remove_initializers_from_inputs(self):
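For completeness, a minimal usage sketch of the pass the new test exercises, using only constructors that already appear in the test above. A model without subgraphs gives the pass nothing to lift, so it should report modified=False; this is a sketch under those assumptions, not part of the patch:

    import onnx_ir as ir
    from onnx_ir.passes.common import constant_manipulation

    # No subgraphs -> no subgraph initializers -> the pass is a no-op.
    graph = ir.Graph(inputs=[], outputs=[], nodes=[], opset_imports={"": 20})
    model = ir.Model(graph=graph, ir_version=10)
    result = constant_manipulation.LiftSubgraphInitializersToMainGraphPass()(model)
    assert not result.modified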