@@ -98,7 +98,23 @@ def _compatible_with_dlpack(obj: Any) -> TypeGuard[_protocols.DLPackCompatible]:
9898class TensorBase (abc .ABC , _protocols .TensorProtocol , _display .PrettyPrintable ):
9999 """Convenience Shared methods for classes implementing TensorProtocol."""
100100
101- __slots__ = ()
101+ __slots__ = (
102+ "_doc_string" ,
103+ "_metadata" ,
104+ "_metadata_props" ,
105+ "_name" ,
106+ )
107+
108+ def __init__ (
109+ self ,
110+ name : str | None = None ,
111+ doc_string : str | None = None ,
112+ metadata_props : dict [str , str ] | None = None ,
113+ ) -> None :
114+ self ._metadata : _metadata .MetadataStore | None = None
115+ self ._metadata_props : dict [str , str ] | None = metadata_props
116+ self ._name : str | None = name
117+ self ._doc_string : str | None = doc_string
102118
103119 def _printable_type_shape (self ) -> str :
104120 """Return a string representation of the shape and data type."""
@@ -111,6 +127,24 @@ def _repr_base(self) -> str:
111127 """
112128 return f"{ self .__class__ .__name__ } <{ self ._printable_type_shape ()} >"
113129
130+ @property
131+ def name (self ) -> str | None :
132+ """The name of the tensor."""
133+ return self ._name
134+
135+ @name .setter
136+ def name (self , value : str | None ) -> None :
137+ self ._name = value
138+
139+ @property
140+ def doc_string (self ) -> str | None :
141+ """The documentation string."""
142+ return self ._doc_string
143+
144+ @doc_string .setter
145+ def doc_string (self , value : str | None ) -> None :
146+ self ._doc_string = value
147+
114148 @property
115149 def size (self ) -> int :
116150 """The number of elements in the tensor."""
@@ -122,6 +156,23 @@ def nbytes(self) -> int:
122156 # Use math.ceil because when dtype is INT4, the itemsize is 0.5
123157 return math .ceil (self .dtype .itemsize * self .size )
124158
159+ @property
160+ def metadata_props (self ) -> dict [str , str ]:
161+ if self ._metadata_props is None :
162+ self ._metadata_props = {}
163+ return self ._metadata_props
164+
165+ @property
166+ def meta (self ) -> _metadata .MetadataStore :
167+ """The metadata store for intermediate analysis.
168+
169+ Write to the :attr:`metadata_props` if you would like the metadata to be serialized
170+ to the ONNX proto.
171+ """
172+ if self ._metadata is None :
173+ self ._metadata = _metadata .MetadataStore ()
174+ return self ._metadata
175+
125176 def display (self , * , page : bool = False ) -> None :
126177 rich = _display .require_rich ()
127178
@@ -310,12 +361,8 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
310361
311362 __slots__ = (
312363 "_dtype" ,
313- "_metadata" ,
314- "_metadata_props" ,
315364 "_raw" ,
316365 "_shape" ,
317- "doc_string" ,
318- "name" ,
319366 )
320367
321368 def __init__ (
@@ -348,6 +395,7 @@ def __init__(
348395 ValueError: If the shape is not specified and the value does not have a shape attribute.
349396 ValueError: If the dtype is not specified and the value is not a numpy array.
350397 """
398+ super ().__init__ (name = name , doc_string = doc_string , metadata_props = metadata_props )
351399 # NOTE: We should not do any copying here for performance reasons
352400 if not _compatible_with_numpy (value ) and not _compatible_with_dlpack (value ):
353401 raise TypeError (f"Expected an array compatible object, got { type (value )} " )
@@ -382,10 +430,6 @@ def __init__(
382430 value = _maybe_view_np_array_with_ml_dtypes (value , self ._dtype ) # type: ignore[assignment]
383431
384432 self ._raw = value
385- self .name = name
386- self .doc_string = doc_string
387- self ._metadata : _metadata .MetadataStore | None = None
388- self ._metadata_props = metadata_props
389433
390434 def __array__ (self , dtype : Any = None ) -> np .ndarray :
391435 if isinstance (self ._raw , np .ndarray ) or _compatible_with_numpy (self ._raw ):
@@ -459,23 +503,6 @@ def tobytes(self) -> bytes:
459503 array = array .view (array .dtype .newbyteorder ("<" ))
460504 return array .tobytes ()
461505
462- @property
463- def metadata_props (self ) -> dict [str , str ]:
464- if self ._metadata_props is None :
465- self ._metadata_props = {}
466- return self ._metadata_props
467-
468- @property
469- def meta (self ) -> _metadata .MetadataStore :
470- """The metadata store for intermediate analysis.
471-
472- Write to the :attr:`metadata_props` if you would like the metadata to be serialized
473- to the ONNX proto.
474- """
475- if self ._metadata is None :
476- self ._metadata = _metadata .MetadataStore ()
477- return self ._metadata
478-
479506
480507class ExternalTensor (TensorBase , _protocols .TensorProtocol ): # pylint: disable=too-many-ancestors
481508 """An immutable concrete tensor with its data store on disk.
@@ -516,13 +543,9 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
516543 "_dtype" ,
517544 "_length" ,
518545 "_location" ,
519- "_metadata" ,
520- "_metadata_props" ,
521546 "_offset" ,
522547 "_shape" ,
523548 "_valid" ,
524- "doc_string" ,
525- "name" ,
526549 "raw" ,
527550 )
528551
@@ -552,6 +575,7 @@ def __init__(
552575 metadata_props: The metadata properties.
553576 base_dir: The base directory for the external data. It is used to resolve relative paths.
554577 """
578+ super ().__init__ (name = name , doc_string = doc_string , metadata_props = metadata_props )
555579 # NOTE: Do not verify the location by default. This is because the location field
556580 # in the tensor proto can be anything and we would like deserialization from
557581 # proto to IR to not fail.
@@ -729,34 +753,13 @@ def release(self) -> None:
729753 self .raw .close ()
730754 self .raw = None
731755
732- @property
733- def metadata_props (self ) -> dict [str , str ]:
734- if self ._metadata_props is None :
735- self ._metadata_props = {}
736- return self ._metadata_props
737-
738- @property
739- def meta (self ) -> _metadata .MetadataStore :
740- """The metadata store for intermediate analysis.
741-
742- Write to the :attr:`metadata_props` if you would like the metadata to be serialized
743- to the ONNX proto.
744- """
745- if self ._metadata is None :
746- self ._metadata = _metadata .MetadataStore ()
747- return self ._metadata
748-
749756
750757class StringTensor (TensorBase , _protocols .TensorProtocol ): # pylint: disable=too-many-ancestors
751758 """Multidimensional array of strings (as binary data to match the string_data field in TensorProto)."""
752759
753760 __slots__ = (
754- "_metadata" ,
755- "_metadata_props" ,
756761 "_raw" ,
757762 "_shape" ,
758- "doc_string" ,
759- "name" ,
760763 )
761764
762765 def __init__ (
@@ -777,6 +780,7 @@ def __init__(
777780 doc_string: The documentation string.
778781 metadata_props: The metadata properties.
779782 """
783+ super ().__init__ (name = name , doc_string = doc_string , metadata_props = metadata_props )
780784 if shape is None :
781785 if not hasattr (value , "shape" ):
782786 raise ValueError (
@@ -788,10 +792,6 @@ def __init__(
788792 self ._shape = shape
789793 self ._shape .freeze ()
790794 self ._raw = value
791- self .name = name
792- self .doc_string = doc_string
793- self ._metadata : _metadata .MetadataStore | None = None
794- self ._metadata_props = metadata_props
795795
796796 def __array__ (self , dtype : Any = None ) -> np .ndarray :
797797 if isinstance (self ._raw , np .ndarray ):
@@ -839,23 +839,6 @@ def string_data(self) -> Sequence[bytes]:
839839 return self ._raw .flatten ().tolist ()
840840 return self ._raw
841841
842- @property
843- def metadata_props (self ) -> dict [str , str ]:
844- if self ._metadata_props is None :
845- self ._metadata_props = {}
846- return self ._metadata_props
847-
848- @property
849- def meta (self ) -> _metadata .MetadataStore :
850- """The metadata store for intermediate analysis.
851-
852- Write to the :attr:`metadata_props` if you would like the metadata to be serialized
853- to the ONNX proto.
854- """
855- if self ._metadata is None :
856- self ._metadata = _metadata .MetadataStore ()
857- return self ._metadata
858-
859842
860843class LazyTensor (TensorBase , _protocols .TensorProtocol ): # pylint: disable=too-many-ancestors
861844 """A tensor that lazily evaluates a function to get the actual tensor.
@@ -893,13 +876,9 @@ class LazyTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-
893876 __slots__ = (
894877 "_dtype" ,
895878 "_func" ,
896- "_metadata" ,
897- "_metadata_props" ,
898879 "_shape" ,
899880 "_tensor" ,
900881 "cache" ,
901- "doc_string" ,
902- "name" ,
903882 )
904883
905884 def __init__ (
@@ -924,15 +903,12 @@ def __init__(
924903 doc_string: The documentation string.
925904 metadata_props: The metadata properties.
926905 """
906+ super ().__init__ (name = name , doc_string = doc_string , metadata_props = metadata_props )
927907 self ._func = func
928908 self ._dtype = dtype
929909 self ._shape = shape
930910 self ._tensor : _protocols .TensorProtocol | None = None
931911 self .cache = cache
932- self .name = name
933- self .doc_string = doc_string
934- self ._metadata : _metadata .MetadataStore | None = None
935- self ._metadata_props = metadata_props
936912
937913 def _evaluate (self ) -> _protocols .TensorProtocol :
938914 """Evaluate the function to get the actual tensor."""
@@ -978,23 +954,6 @@ def tobytes(self) -> bytes:
978954 """Return the bytes of the tensor."""
979955 return self ._evaluate ().tobytes ()
980956
981- @property
982- def metadata_props (self ) -> dict [str , str ]:
983- if self ._metadata_props is None :
984- self ._metadata_props = {}
985- return self ._metadata_props
986-
987- @property
988- def meta (self ) -> _metadata .MetadataStore :
989- """The metadata store for intermediate analysis.
990-
991- Write to the :attr:`metadata_props` if you would like the metadata to be serialized
992- to the ONNX proto.
993- """
994- if self ._metadata is None :
995- self ._metadata = _metadata .MetadataStore ()
996- return self ._metadata
997-
998957
999958class SymbolicDim (_protocols .SymbolicDimProtocol , _display .PrettyPrintable ):
1000959 """Immutable symbolic dimension that can be shared across multiple shapes."""
0 commit comments