Skip to content

Commit a8f56c2

Browse files
authored
[IR] Refactor TensorBase to simplify implementation (#2081)
Move name, doc_string, meta and metadata fields to the base class and simplify implementation.
1 parent db414d7 commit a8f56c2

File tree

2 files changed

+61
-120
lines changed

2 files changed

+61
-120
lines changed

onnxscript/ir/_core.py

Lines changed: 56 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,23 @@ def _compatible_with_dlpack(obj: Any) -> TypeGuard[_protocols.DLPackCompatible]:
9898
class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable):
9999
"""Convenience Shared methods for classes implementing TensorProtocol."""
100100

101-
__slots__ = ()
101+
__slots__ = (
102+
"_doc_string",
103+
"_metadata",
104+
"_metadata_props",
105+
"_name",
106+
)
107+
108+
def __init__(
109+
self,
110+
name: str | None = None,
111+
doc_string: str | None = None,
112+
metadata_props: dict[str, str] | None = None,
113+
) -> None:
114+
self._metadata: _metadata.MetadataStore | None = None
115+
self._metadata_props: dict[str, str] | None = metadata_props
116+
self._name: str | None = name
117+
self._doc_string: str | None = doc_string
102118

103119
def _printable_type_shape(self) -> str:
104120
"""Return a string representation of the shape and data type."""
@@ -111,6 +127,24 @@ def _repr_base(self) -> str:
111127
"""
112128
return f"{self.__class__.__name__}<{self._printable_type_shape()}>"
113129

130+
@property
131+
def name(self) -> str | None:
132+
"""The name of the tensor."""
133+
return self._name
134+
135+
@name.setter
136+
def name(self, value: str | None) -> None:
137+
self._name = value
138+
139+
@property
140+
def doc_string(self) -> str | None:
141+
"""The documentation string."""
142+
return self._doc_string
143+
144+
@doc_string.setter
145+
def doc_string(self, value: str | None) -> None:
146+
self._doc_string = value
147+
114148
@property
115149
def size(self) -> int:
116150
"""The number of elements in the tensor."""
@@ -122,6 +156,23 @@ def nbytes(self) -> int:
122156
# Use math.ceil because when dtype is INT4, the itemsize is 0.5
123157
return math.ceil(self.dtype.itemsize * self.size)
124158

159+
@property
160+
def metadata_props(self) -> dict[str, str]:
161+
if self._metadata_props is None:
162+
self._metadata_props = {}
163+
return self._metadata_props
164+
165+
@property
166+
def meta(self) -> _metadata.MetadataStore:
167+
"""The metadata store for intermediate analysis.
168+
169+
Write to the :attr:`metadata_props` if you would like the metadata to be serialized
170+
to the ONNX proto.
171+
"""
172+
if self._metadata is None:
173+
self._metadata = _metadata.MetadataStore()
174+
return self._metadata
175+
125176
def display(self, *, page: bool = False) -> None:
126177
rich = _display.require_rich()
127178

@@ -310,12 +361,8 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
310361

311362
__slots__ = (
312363
"_dtype",
313-
"_metadata",
314-
"_metadata_props",
315364
"_raw",
316365
"_shape",
317-
"doc_string",
318-
"name",
319366
)
320367

321368
def __init__(
@@ -348,6 +395,7 @@ def __init__(
348395
ValueError: If the shape is not specified and the value does not have a shape attribute.
349396
ValueError: If the dtype is not specified and the value is not a numpy array.
350397
"""
398+
super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
351399
# NOTE: We should not do any copying here for performance reasons
352400
if not _compatible_with_numpy(value) and not _compatible_with_dlpack(value):
353401
raise TypeError(f"Expected an array compatible object, got {type(value)}")
@@ -382,10 +430,6 @@ def __init__(
382430
value = _maybe_view_np_array_with_ml_dtypes(value, self._dtype) # type: ignore[assignment]
383431

384432
self._raw = value
385-
self.name = name
386-
self.doc_string = doc_string
387-
self._metadata: _metadata.MetadataStore | None = None
388-
self._metadata_props = metadata_props
389433

390434
def __array__(self, dtype: Any = None) -> np.ndarray:
391435
if isinstance(self._raw, np.ndarray) or _compatible_with_numpy(self._raw):
@@ -459,23 +503,6 @@ def tobytes(self) -> bytes:
459503
array = array.view(array.dtype.newbyteorder("<"))
460504
return array.tobytes()
461505

462-
@property
463-
def metadata_props(self) -> dict[str, str]:
464-
if self._metadata_props is None:
465-
self._metadata_props = {}
466-
return self._metadata_props
467-
468-
@property
469-
def meta(self) -> _metadata.MetadataStore:
470-
"""The metadata store for intermediate analysis.
471-
472-
Write to the :attr:`metadata_props` if you would like the metadata to be serialized
473-
to the ONNX proto.
474-
"""
475-
if self._metadata is None:
476-
self._metadata = _metadata.MetadataStore()
477-
return self._metadata
478-
479506

480507
class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors
481508
"""An immutable concrete tensor with its data store on disk.
@@ -516,13 +543,9 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
516543
"_dtype",
517544
"_length",
518545
"_location",
519-
"_metadata",
520-
"_metadata_props",
521546
"_offset",
522547
"_shape",
523548
"_valid",
524-
"doc_string",
525-
"name",
526549
"raw",
527550
)
528551

@@ -552,6 +575,7 @@ def __init__(
552575
metadata_props: The metadata properties.
553576
base_dir: The base directory for the external data. It is used to resolve relative paths.
554577
"""
578+
super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
555579
# NOTE: Do not verify the location by default. This is because the location field
556580
# in the tensor proto can be anything and we would like deserialization from
557581
# proto to IR to not fail.
@@ -729,34 +753,13 @@ def release(self) -> None:
729753
self.raw.close()
730754
self.raw = None
731755

732-
@property
733-
def metadata_props(self) -> dict[str, str]:
734-
if self._metadata_props is None:
735-
self._metadata_props = {}
736-
return self._metadata_props
737-
738-
@property
739-
def meta(self) -> _metadata.MetadataStore:
740-
"""The metadata store for intermediate analysis.
741-
742-
Write to the :attr:`metadata_props` if you would like the metadata to be serialized
743-
to the ONNX proto.
744-
"""
745-
if self._metadata is None:
746-
self._metadata = _metadata.MetadataStore()
747-
return self._metadata
748-
749756

750757
class StringTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors
751758
"""Multidimensional array of strings (as binary data to match the string_data field in TensorProto)."""
752759

753760
__slots__ = (
754-
"_metadata",
755-
"_metadata_props",
756761
"_raw",
757762
"_shape",
758-
"doc_string",
759-
"name",
760763
)
761764

762765
def __init__(
@@ -777,6 +780,7 @@ def __init__(
777780
doc_string: The documentation string.
778781
metadata_props: The metadata properties.
779782
"""
783+
super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
780784
if shape is None:
781785
if not hasattr(value, "shape"):
782786
raise ValueError(
@@ -788,10 +792,6 @@ def __init__(
788792
self._shape = shape
789793
self._shape.freeze()
790794
self._raw = value
791-
self.name = name
792-
self.doc_string = doc_string
793-
self._metadata: _metadata.MetadataStore | None = None
794-
self._metadata_props = metadata_props
795795

796796
def __array__(self, dtype: Any = None) -> np.ndarray:
797797
if isinstance(self._raw, np.ndarray):
@@ -839,23 +839,6 @@ def string_data(self) -> Sequence[bytes]:
839839
return self._raw.flatten().tolist()
840840
return self._raw
841841

842-
@property
843-
def metadata_props(self) -> dict[str, str]:
844-
if self._metadata_props is None:
845-
self._metadata_props = {}
846-
return self._metadata_props
847-
848-
@property
849-
def meta(self) -> _metadata.MetadataStore:
850-
"""The metadata store for intermediate analysis.
851-
852-
Write to the :attr:`metadata_props` if you would like the metadata to be serialized
853-
to the ONNX proto.
854-
"""
855-
if self._metadata is None:
856-
self._metadata = _metadata.MetadataStore()
857-
return self._metadata
858-
859842

860843
class LazyTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors
861844
"""A tensor that lazily evaluates a function to get the actual tensor.
@@ -893,13 +876,9 @@ class LazyTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-
893876
__slots__ = (
894877
"_dtype",
895878
"_func",
896-
"_metadata",
897-
"_metadata_props",
898879
"_shape",
899880
"_tensor",
900881
"cache",
901-
"doc_string",
902-
"name",
903882
)
904883

905884
def __init__(
@@ -924,15 +903,12 @@ def __init__(
924903
doc_string: The documentation string.
925904
metadata_props: The metadata properties.
926905
"""
906+
super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
927907
self._func = func
928908
self._dtype = dtype
929909
self._shape = shape
930910
self._tensor: _protocols.TensorProtocol | None = None
931911
self.cache = cache
932-
self.name = name
933-
self.doc_string = doc_string
934-
self._metadata: _metadata.MetadataStore | None = None
935-
self._metadata_props = metadata_props
936912

937913
def _evaluate(self) -> _protocols.TensorProtocol:
938914
"""Evaluate the function to get the actual tensor."""
@@ -978,23 +954,6 @@ def tobytes(self) -> bytes:
978954
"""Return the bytes of the tensor."""
979955
return self._evaluate().tobytes()
980956

981-
@property
982-
def metadata_props(self) -> dict[str, str]:
983-
if self._metadata_props is None:
984-
self._metadata_props = {}
985-
return self._metadata_props
986-
987-
@property
988-
def meta(self) -> _metadata.MetadataStore:
989-
"""The metadata store for intermediate analysis.
990-
991-
Write to the :attr:`metadata_props` if you would like the metadata to be serialized
992-
to the ONNX proto.
993-
"""
994-
if self._metadata is None:
995-
self._metadata = _metadata.MetadataStore()
996-
return self._metadata
997-
998957

999958
class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
1000959
"""Immutable symbolic dimension that can be shared across multiple shapes."""

onnxscript/ir/serde.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
import onnx
6868
import onnx.external_data_helper
6969

70-
from onnxscript.ir import _core, _enums, _metadata, _protocols, _type_casting
70+
from onnxscript.ir import _core, _enums, _protocols, _type_casting
7171

7272
if typing.TYPE_CHECKING:
7373
import google.protobuf.internal.containers as proto_containers
@@ -243,12 +243,11 @@ def to_proto(ir_object: object) -> object:
243243
class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
244244
"""A tensor initialized from a tensor proto."""
245245

246+
__slots__ = ("_proto",)
247+
246248
def __init__(self, proto: onnx.TensorProto) -> None:
249+
super().__init__(metadata_props=deserialize_metadata_props(proto.metadata_props))
247250
self._proto = proto
248-
self._metadata_props: dict[str, str] | None = deserialize_metadata_props(
249-
proto.metadata_props
250-
)
251-
self._metadata: _metadata.MetadataStore | None = None
252251

253252
@property
254253
def name(self) -> str:
@@ -269,7 +268,7 @@ def shape(self) -> _core.Shape:
269268
def dtype(self) -> _enums.DataType:
270269
return _enums.DataType(self._proto.data_type)
271270

272-
@property
271+
@property # type: ignore[misc]
273272
def doc_string(self) -> str:
274273
return self._proto.doc_string
275274

@@ -440,23 +439,6 @@ def tobytes(self) -> bytes:
440439
# For example, int32_data can be empty and still be a valid tensor.
441440
return b""
442441

443-
@property
444-
def meta(self) -> _metadata.MetadataStore:
445-
"""The metadata store for intermediate analysis.
446-
447-
Write to the :attr:`metadata_props` if you would like the metadata to be serialized
448-
to the ONNX proto.
449-
"""
450-
if self._metadata is None:
451-
self._metadata = _metadata.MetadataStore()
452-
return self._metadata
453-
454-
@property
455-
def metadata_props(self) -> dict[str, str]:
456-
if self._metadata_props is None:
457-
self._metadata_props = {}
458-
return self._metadata_props
459-
460442

461443
def _get_field(proto: Any, field: str) -> Any:
462444
if proto.HasField(field):

0 commit comments

Comments
 (0)