Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 37 additions & 20 deletions onnxscript/ir/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,32 +627,43 @@

# Initialize the values dictionary for this graph scope with the inputs and initializers
values: dict[str, _core.Value] = {v.name: v for v in inputs} # type: ignore[misc]

# Enter the graph scope by pushing the values for this scope to the stack
scoped_values.append(values)

initializer_values = []
for tensor in initializer_tensors:
if tensor.name in values:
for i, tensor in enumerate(initializer_tensors):
initializer_name = tensor.name
if not initializer_name:
logger.warning(

Check warning on line 638 in onnxscript/ir/serde.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/serde.py#L638

Added line #L638 was not covered by tests
"Initializer tensor must have a name but the %s-th initializer does not. Skipping this initializer.",
i,
)
continue

Check warning on line 642 in onnxscript/ir/serde.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/serde.py#L642

Added line #L642 was not covered by tests
if initializer_name in values:
# The initializer is for an input
initializer_value = values[tensor.name]
initializer_value = values[initializer_name]
initializer_value.const_value = tensor
else:
# The initializer is for some other value. Create this value first
initializer_value = _core.Value(
None,
index=None,
name=tensor.name,
# TODO(justinchuby): Fix type hinting for shape and dtype
shape=tensor.shape, # type: ignore
name=initializer_name,
# Include shape and type even if the shape or type is not provided as ValueInfoProto.
# Users expect initialized values to have shape and type information.
type=_core.TensorType(tensor.dtype),
shape=tensor.shape, # type: ignore[arg-type]
const_value=tensor,
)
if initializer_value.name in quantization_annotations:
_deserialize_quantization_annotation(
quantization_annotations[initializer_value.name], initializer_value
)
values[tensor.name] = initializer_value # type: ignore[index]
values[initializer_name] = initializer_value
initializer_values.append(initializer_value)

# Add ValueInfos for this graph scope
# Build the value info dictionary to allow for quick lookup for this graph scope
value_info = {info.name: info for info in proto.value_info}

# Deserialize nodes with all known values
Expand All @@ -663,7 +674,10 @@

# Fill in values for graph outputs
outputs = [deserialize_value_info_proto(info, values[info.name]) for info in proto.output]

# Exit the graph scope by popping the values for this scope from the stack
scoped_values.pop()

return _core.Graph(
inputs,
outputs,
Expand Down Expand Up @@ -1204,24 +1218,24 @@
opset_ids.add(domain=domain, version=version)


def _serialize_metadata_props_into(
def _serialize_string_string_maps(
string_string_entries: proto_containers.RepeatedCompositeFieldContainer[
onnx.StringStringEntryProto
],
from_: Mapping[str, str],
) -> None:
"""Serialize metadata properties into a repeated field of string-string entries.
"""Serialize a <str, str> mapping into a repeated field of string-string entries.
Args:
string_string_entries: The repeated field to serialize into.
from_: The mapping of metadata properties to serialize.
from_: The mapping of a <str, str> mapping to serialize.
"""
# Sort names for deterministic serialization
for key in sorted(from_):
string_string_entries.add(key=key, value=from_[key])


_serialize_string_string_maps = _serialize_metadata_props_into
_serialize_metadata_props_into = _serialize_string_string_maps


def _maybe_add_quantization_annotation(
Expand Down Expand Up @@ -1284,18 +1298,21 @@
# TODO(justinchuby): We should add a method is_initializer() on Value when
# the initializer list is tracked
_maybe_add_quantization_annotation(graph_proto, input_)
input_names = {input_.name for input_ in from_.inputs}
# TODO(justinchuby): Support sparse_initializer
for initializer in from_.initializers.values():
_maybe_add_quantization_annotation(graph_proto, initializer)
if initializer.const_value is None:
for value in from_.initializers.values():
_maybe_add_quantization_annotation(graph_proto, value)
if _should_create_value_info_for_value(value) and value.name not in input_names:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are the cases that initializers are model inputs? Does that mean the inputs are constants?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's like a parameter having a default value. Any input can be initialized if an initializer of the same name is in the graph. Users can choose to overwrite the initializer by providing their own input.

# Serialize information about all initializers into value_info,
# except for those that are also graph inputs
serialize_value_into(graph_proto.value_info.add(), value)
if value.const_value is None:
# Skip initializers without constant values
logger.warning(
"Initializer '%s' does not have a constant value set.", initializer.name
)
logger.warning("Initializer '%s' does not have a constant value set.", value.name)

Check warning on line 1311 in onnxscript/ir/serde.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/serde.py#L1311

Added line #L1311 was not covered by tests
continue
# Make sure the tensor's name is the same as the value's name
initializer.const_value.name = initializer.name
serialize_tensor_into(graph_proto.initializer.add(), from_=initializer.const_value)
value.const_value.name = value.name
serialize_tensor_into(graph_proto.initializer.add(), from_=value.const_value)
for node in from_:
serialize_node_into(graph_proto.node.add(), from_=node)
for node_output in node.outputs:
Expand Down
Git LFS file not shown
Git LFS file not shown
4 changes: 2 additions & 2 deletions testdata/e2e_models/resnet18/dynamo/resnet18_dynamo.onnx
Git LFS file not shown
4 changes: 2 additions & 2 deletions testdata/e2e_models/torchscript_model/torchscript_model.onnx
Git LFS file not shown
Loading