@@ -98,7 +98,23 @@ def _compatible_with_dlpack(obj: Any) -> TypeGuard[_protocols.DLPackCompatible]:
98
98
class TensorBase (abc .ABC , _protocols .TensorProtocol , _display .PrettyPrintable ):
99
99
"""Convenience Shared methods for classes implementing TensorProtocol."""
100
100
101
- __slots__ = ()
101
+ __slots__ = (
102
+ "_doc_string" ,
103
+ "_metadata" ,
104
+ "_metadata_props" ,
105
+ "_name" ,
106
+ )
107
+
108
+ def __init__ (
109
+ self ,
110
+ name : str | None = None ,
111
+ doc_string : str | None = None ,
112
+ metadata_props : dict [str , str ] | None = None ,
113
+ ) -> None :
114
+ self ._metadata : _metadata .MetadataStore | None = None
115
+ self ._metadata_props : dict [str , str ] | None = metadata_props
116
+ self ._name : str | None = name
117
+ self ._doc_string : str | None = doc_string
102
118
103
119
def _printable_type_shape (self ) -> str :
104
120
"""Return a string representation of the shape and data type."""
@@ -111,6 +127,24 @@ def _repr_base(self) -> str:
111
127
"""
112
128
return f"{ self .__class__ .__name__ } <{ self ._printable_type_shape ()} >"
113
129
130
+ @property
131
+ def name (self ) -> str | None :
132
+ """The name of the tensor."""
133
+ return self ._name
134
+
135
+ @name .setter
136
+ def name (self , value : str | None ) -> None :
137
+ self ._name = value
138
+
139
+ @property
140
+ def doc_string (self ) -> str | None :
141
+ """The documentation string."""
142
+ return self ._doc_string
143
+
144
+ @doc_string .setter
145
+ def doc_string (self , value : str | None ) -> None :
146
+ self ._doc_string = value
147
+
114
148
@property
115
149
def size (self ) -> int :
116
150
"""The number of elements in the tensor."""
@@ -122,6 +156,23 @@ def nbytes(self) -> int:
122
156
# Use math.ceil because when dtype is INT4, the itemsize is 0.5
123
157
return math .ceil (self .dtype .itemsize * self .size )
124
158
159
+ @property
160
+ def metadata_props (self ) -> dict [str , str ]:
161
+ if self ._metadata_props is None :
162
+ self ._metadata_props = {}
163
+ return self ._metadata_props
164
+
165
+ @property
166
+ def meta (self ) -> _metadata .MetadataStore :
167
+ """The metadata store for intermediate analysis.
168
+
169
+ Write to the :attr:`metadata_props` if you would like the metadata to be serialized
170
+ to the ONNX proto.
171
+ """
172
+ if self ._metadata is None :
173
+ self ._metadata = _metadata .MetadataStore ()
174
+ return self ._metadata
175
+
125
176
def display (self , * , page : bool = False ) -> None :
126
177
rich = _display .require_rich ()
127
178
@@ -310,12 +361,8 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
310
361
311
362
__slots__ = (
312
363
"_dtype" ,
313
- "_metadata" ,
314
- "_metadata_props" ,
315
364
"_raw" ,
316
365
"_shape" ,
317
- "doc_string" ,
318
- "name" ,
319
366
)
320
367
321
368
def __init__ (
@@ -348,6 +395,7 @@ def __init__(
348
395
ValueError: If the shape is not specified and the value does not have a shape attribute.
349
396
ValueError: If the dtype is not specified and the value is not a numpy array.
350
397
"""
398
+ super ().__init__ (name = name , doc_string = doc_string , metadata_props = metadata_props )
351
399
# NOTE: We should not do any copying here for performance reasons
352
400
if not _compatible_with_numpy (value ) and not _compatible_with_dlpack (value ):
353
401
raise TypeError (f"Expected an array compatible object, got { type (value )} " )
@@ -382,10 +430,6 @@ def __init__(
382
430
value = _maybe_view_np_array_with_ml_dtypes (value , self ._dtype ) # type: ignore[assignment]
383
431
384
432
self ._raw = value
385
- self .name = name
386
- self .doc_string = doc_string
387
- self ._metadata : _metadata .MetadataStore | None = None
388
- self ._metadata_props = metadata_props
389
433
390
434
def __array__ (self , dtype : Any = None ) -> np .ndarray :
391
435
if isinstance (self ._raw , np .ndarray ) or _compatible_with_numpy (self ._raw ):
@@ -459,23 +503,6 @@ def tobytes(self) -> bytes:
459
503
array = array .view (array .dtype .newbyteorder ("<" ))
460
504
return array .tobytes ()
461
505
462
- @property
463
- def metadata_props (self ) -> dict [str , str ]:
464
- if self ._metadata_props is None :
465
- self ._metadata_props = {}
466
- return self ._metadata_props
467
-
468
- @property
469
- def meta (self ) -> _metadata .MetadataStore :
470
- """The metadata store for intermediate analysis.
471
-
472
- Write to the :attr:`metadata_props` if you would like the metadata to be serialized
473
- to the ONNX proto.
474
- """
475
- if self ._metadata is None :
476
- self ._metadata = _metadata .MetadataStore ()
477
- return self ._metadata
478
-
479
506
480
507
class ExternalTensor (TensorBase , _protocols .TensorProtocol ): # pylint: disable=too-many-ancestors
481
508
"""An immutable concrete tensor with its data store on disk.
@@ -516,13 +543,9 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
516
543
"_dtype" ,
517
544
"_length" ,
518
545
"_location" ,
519
- "_metadata" ,
520
- "_metadata_props" ,
521
546
"_offset" ,
522
547
"_shape" ,
523
548
"_valid" ,
524
- "doc_string" ,
525
- "name" ,
526
549
"raw" ,
527
550
)
528
551
@@ -552,6 +575,7 @@ def __init__(
552
575
metadata_props: The metadata properties.
553
576
base_dir: The base directory for the external data. It is used to resolve relative paths.
554
577
"""
578
+ super ().__init__ (name = name , doc_string = doc_string , metadata_props = metadata_props )
555
579
# NOTE: Do not verify the location by default. This is because the location field
556
580
# in the tensor proto can be anything and we would like deserialization from
557
581
# proto to IR to not fail.
@@ -729,34 +753,13 @@ def release(self) -> None:
729
753
self .raw .close ()
730
754
self .raw = None
731
755
732
- @property
733
- def metadata_props (self ) -> dict [str , str ]:
734
- if self ._metadata_props is None :
735
- self ._metadata_props = {}
736
- return self ._metadata_props
737
-
738
- @property
739
- def meta (self ) -> _metadata .MetadataStore :
740
- """The metadata store for intermediate analysis.
741
-
742
- Write to the :attr:`metadata_props` if you would like the metadata to be serialized
743
- to the ONNX proto.
744
- """
745
- if self ._metadata is None :
746
- self ._metadata = _metadata .MetadataStore ()
747
- return self ._metadata
748
-
749
756
750
757
class StringTensor (TensorBase , _protocols .TensorProtocol ): # pylint: disable=too-many-ancestors
751
758
"""Multidimensional array of strings (as binary data to match the string_data field in TensorProto)."""
752
759
753
760
__slots__ = (
754
- "_metadata" ,
755
- "_metadata_props" ,
756
761
"_raw" ,
757
762
"_shape" ,
758
- "doc_string" ,
759
- "name" ,
760
763
)
761
764
762
765
def __init__ (
@@ -777,6 +780,7 @@ def __init__(
777
780
doc_string: The documentation string.
778
781
metadata_props: The metadata properties.
779
782
"""
783
+ super ().__init__ (name = name , doc_string = doc_string , metadata_props = metadata_props )
780
784
if shape is None :
781
785
if not hasattr (value , "shape" ):
782
786
raise ValueError (
@@ -788,10 +792,6 @@ def __init__(
788
792
self ._shape = shape
789
793
self ._shape .freeze ()
790
794
self ._raw = value
791
- self .name = name
792
- self .doc_string = doc_string
793
- self ._metadata : _metadata .MetadataStore | None = None
794
- self ._metadata_props = metadata_props
795
795
796
796
def __array__ (self , dtype : Any = None ) -> np .ndarray :
797
797
if isinstance (self ._raw , np .ndarray ):
@@ -839,23 +839,6 @@ def string_data(self) -> Sequence[bytes]:
839
839
return self ._raw .flatten ().tolist ()
840
840
return self ._raw
841
841
842
- @property
843
- def metadata_props (self ) -> dict [str , str ]:
844
- if self ._metadata_props is None :
845
- self ._metadata_props = {}
846
- return self ._metadata_props
847
-
848
- @property
849
- def meta (self ) -> _metadata .MetadataStore :
850
- """The metadata store for intermediate analysis.
851
-
852
- Write to the :attr:`metadata_props` if you would like the metadata to be serialized
853
- to the ONNX proto.
854
- """
855
- if self ._metadata is None :
856
- self ._metadata = _metadata .MetadataStore ()
857
- return self ._metadata
858
-
859
842
860
843
class LazyTensor (TensorBase , _protocols .TensorProtocol ): # pylint: disable=too-many-ancestors
861
844
"""A tensor that lazily evaluates a function to get the actual tensor.
@@ -893,13 +876,9 @@ class LazyTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-
893
876
__slots__ = (
894
877
"_dtype" ,
895
878
"_func" ,
896
- "_metadata" ,
897
- "_metadata_props" ,
898
879
"_shape" ,
899
880
"_tensor" ,
900
881
"cache" ,
901
- "doc_string" ,
902
- "name" ,
903
882
)
904
883
905
884
def __init__ (
@@ -924,15 +903,12 @@ def __init__(
924
903
doc_string: The documentation string.
925
904
metadata_props: The metadata properties.
926
905
"""
906
+ super ().__init__ (name = name , doc_string = doc_string , metadata_props = metadata_props )
927
907
self ._func = func
928
908
self ._dtype = dtype
929
909
self ._shape = shape
930
910
self ._tensor : _protocols .TensorProtocol | None = None
931
911
self .cache = cache
932
- self .name = name
933
- self .doc_string = doc_string
934
- self ._metadata : _metadata .MetadataStore | None = None
935
- self ._metadata_props = metadata_props
936
912
937
913
def _evaluate (self ) -> _protocols .TensorProtocol :
938
914
"""Evaluate the function to get the actual tensor."""
@@ -978,23 +954,6 @@ def tobytes(self) -> bytes:
978
954
"""Return the bytes of the tensor."""
979
955
return self ._evaluate ().tobytes ()
980
956
981
- @property
982
- def metadata_props (self ) -> dict [str , str ]:
983
- if self ._metadata_props is None :
984
- self ._metadata_props = {}
985
- return self ._metadata_props
986
-
987
- @property
988
- def meta (self ) -> _metadata .MetadataStore :
989
- """The metadata store for intermediate analysis.
990
-
991
- Write to the :attr:`metadata_props` if you would like the metadata to be serialized
992
- to the ONNX proto.
993
- """
994
- if self ._metadata is None :
995
- self ._metadata = _metadata .MetadataStore ()
996
- return self ._metadata
997
-
998
957
999
958
class SymbolicDim (_protocols .SymbolicDimProtocol , _display .PrettyPrintable ):
1000
959
"""Immutable symbolic dimension that can be shared across multiple shapes."""
0 commit comments