Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/onnx_ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,11 @@ def shape(self) -> Shape:
"""The shape of the tensor. Immutable."""
return self._shape

@property
def nbytes(self) -> int:
    """Total number of bytes occupied by the tensor's string contents."""
    # String elements have no fixed per-element width, so the size is the
    # sum of each element's byte length.
    total = 0
    for element in self.string_data():
        total += len(element)
    return total

@property
def raw(self) -> Sequence[bytes] | npt.NDArray[np.bytes_]:
"""Backing data of the tensor. Immutable."""
Expand Down
22 changes: 22 additions & 0 deletions src/onnx_ir/_core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2485,5 +2485,27 @@ def test_integration_with_regular_tensor_operations(self):
self.assertEqual(result.sum(), 10) # 1+2+3+4 = 10


class StringTensorTest(unittest.TestCase):
    """Tests for the ``nbytes`` property of ``_core.StringTensor``."""

    def _nbytes_of(self, values):
        # Build a StringTensor around a numpy bytes array and report nbytes.
        return _core.StringTensor(np.array(values)).nbytes

    def test_nbytes(self):
        self.assertEqual(self._nbytes_of([b"A", b"BC", b"D"]), 4)

    def test_nbytes_2d(self):
        rows = [[b"A", b"BC", b"D"], [b"EFG", b"H", b"I"]]
        self.assertEqual(self._nbytes_of(rows), 9)

    def test_nbytes_empty(self):
        self.assertEqual(self._nbytes_of([]), 0)

    def test_nbytes_single(self):
        self.assertEqual(self._nbytes_of([b"ABC"]), 3)


if __name__ == "__main__":
    # Allow running this test module directly (python _core_test.py).
    unittest.main()
6 changes: 5 additions & 1 deletion src/onnx_ir/_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def from_numpy(cls, dtype: np.dtype) -> DataType:
if dtype in _NP_TYPE_TO_DATA_TYPE:
return cls(_NP_TYPE_TO_DATA_TYPE[dtype])

if np.issubdtype(dtype, np.str_):
if np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.bytes_):
return DataType.STRING

# Special cases for handling custom dtypes defined in ONNX (as of onnx 1.18)
Expand Down Expand Up @@ -215,6 +215,10 @@ def is_signed(self) -> bool:
DataType.FLOAT8E8M0,
}

def is_string(self) -> bool:
    """Returns True if the data type is a string type."""
    # TODO(review): add a ``.. versionadded::`` line to the docstring, as
    # requested in review (deferred to #177).
    return self == DataType.STRING

def __repr__(self) -> str:
    # Render members as the bare member name (e.g. "STRING") instead of
    # the default enum repr.
    return self.name

Expand Down
32 changes: 22 additions & 10 deletions src/onnx_ir/passes/common/initializer_deduplication.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import hashlib
import logging

import numpy as np

import onnx_ir as ir

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -42,17 +44,27 @@ def _should_skip_initializer(initializer: ir.Value, size_limit: int) -> bool:
size_limit,
)
return True

if const_val.dtype == ir.DataType.STRING:
# Skip string initializers as they don't have a bytes representation
logger.warning(
"Skipped deduplication of string initializer '%s' (unsupported yet)",
initializer.name,
)
return True
return False


def _tobytes(val):
"""StringTensor does not support tobytes. Use 'string_data' instead.

However, 'string_data' yields a list of bytes which cannot be hashed, i.e.,
cannot be used to index into a dict. To generate keys for identifying
tensors in initializer deduplication the following converts the list of
bytes to an array of fixed-length strings which can be flattened into a
bytes-string. This, together with the tensor shape, is sufficient for
identifying tensors for deduplication, but it differs from the
representation used for serializing tensors (that is string_data) by adding
padding bytes so that each string occupies the same number of consecutive
bytes in the flattened .tobytes representation.
"""
if val.dtype.is_string():
return np.array(val.string_data()).tobytes()
return val.tobytes()


class DeduplicateInitializersPass(ir.passes.InPlacePass):
"""Remove duplicated initializer tensors from the main graph and all subgraphs.

Expand Down Expand Up @@ -84,7 +96,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
const_val = initializer.const_value
assert const_val is not None

key = (const_val.dtype, tuple(const_val.shape), const_val.tobytes())
key = (const_val.dtype, tuple(const_val.shape), _tobytes(const_val))
if key in initializers:
modified = True
initializer_to_keep = initializers[key] # type: ignore[index]
Expand Down Expand Up @@ -143,7 +155,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
key = (const_val.dtype, tensor_dims, tensor_digest)

if key in initializers:
if initializers[key].const_value.tobytes() != const_val.tobytes():
if _tobytes(initializers[key].const_value) != _tobytes(const_val):
logger.warning(
"Initializer deduplication failed: "
"hashes match but values differ with values %s and %s",
Expand Down
37 changes: 37 additions & 0 deletions src/onnx_ir/passes/common/initializer_deduplication_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,19 @@ def test_deduplicates_identical_initializers(self):
add_node = new_model.graph[0]
self.assertEqual(add_node.inputs[0], add_node.inputs[1])

def test_deduplicates_identical_string_initializers(self):
    # Two string initializers with identical dtype, shape, and element
    # values must be merged into a single initializer by the pass.
    model = ir.from_onnx_text(
        """
        <ir_version: 10, opset_import: ["" : 17]>
        agraph () => ()
        <string[2] s1 = {"A", "B"}, string[2] s2 = {"A", "B"}> {
        }
        """
    )
    self.assertEqual(len(model.graph.initializers), 2)
    new_model = self.apply_pass(model)
    self.assertEqual(len(new_model.graph.initializers), 1)

def test_initializers_with_different_shapes_not_deduplicated(self):
model = ir.from_onnx_text(
"""
Expand All @@ -60,6 +73,30 @@ def test_initializers_with_different_shapes_not_deduplicated(self):
new_model = self.apply_pass(model)
self.assertEqual(len(new_model.graph.initializers), 2)

def test_string_initializers_with_different_shapes_not_deduplicated(self):
    # Same element values but different shapes ([2] vs [1,2]) — the
    # shape is part of the dedup key, so both initializers must survive.
    model = ir.from_onnx_text(
        """
        <ir_version: 10, opset_import: ["" : 17]>
        agraph () => ()
        <string[2] s1 = {"A", "B"}, string[1,2] s2 = {"A", "B"}> {
        }
        """
    )
    new_model = self.apply_pass(model)
    self.assertEqual(len(new_model.graph.initializers), 2)

def test_string_initializers_with_same_bytes_but_different_grouping_not_deduplicated(self):
    # {"AB", "C"} and {"A", "BC"} concatenate to the same byte stream
    # ("ABC") but are different tensors; the byte-key encoding must keep
    # them distinct so they are not wrongly deduplicated.
    model = ir.from_onnx_text(
        """
        <ir_version: 10, opset_import: ["" : 17]>
        agraph () => ()
        <string[2] s1 = {"AB", "C"}, string[2] s2 = {"A", "BC"}> {
        }
        """
    )
    new_model = self.apply_pass(model)
    self.assertEqual(len(new_model.graph.initializers), 2)

def test_initializers_with_different_dtypes_not_deduplicated(self):
model = ir.from_onnx_text(
"""
Expand Down
Loading