diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 32073c5b9..c4b159563 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -98,7 +98,23 @@ def _compatible_with_dlpack(obj: Any) -> TypeGuard[_protocols.DLPackCompatible]: class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable): """Convenience Shared methods for classes implementing TensorProtocol.""" - __slots__ = () + __slots__ = ( + "_doc_string", + "_metadata", + "_metadata_props", + "_name", + ) + + def __init__( + self, + name: str | None = None, + doc_string: str | None = None, + metadata_props: dict[str, str] | None = None, + ) -> None: + self._metadata: _metadata.MetadataStore | None = None + self._metadata_props: dict[str, str] | None = metadata_props + self._name: str | None = name + self._doc_string: str | None = doc_string def _printable_type_shape(self) -> str: """Return a string representation of the shape and data type.""" @@ -111,6 +127,24 @@ def _repr_base(self) -> str: """ return f"{self.__class__.__name__}<{self._printable_type_shape()}>" + @property + def name(self) -> str | None: + """The name of the tensor.""" + return self._name + + @name.setter + def name(self, value: str | None) -> None: + self._name = value + + @property + def doc_string(self) -> str | None: + """The documentation string.""" + return self._doc_string + + @doc_string.setter + def doc_string(self, value: str | None) -> None: + self._doc_string = value + @property def size(self) -> int: """The number of elements in the tensor.""" @@ -122,6 +156,23 @@ def nbytes(self) -> int: # Use math.ceil because when dtype is INT4, the itemsize is 0.5 return math.ceil(self.dtype.itemsize * self.size) + @property + def metadata_props(self) -> dict[str, str]: + if self._metadata_props is None: + self._metadata_props = {} + return self._metadata_props + + @property + def meta(self) -> _metadata.MetadataStore: + """The metadata store for intermediate analysis. + + Write to the :attr:`metadata_props` if you would like the metadata to be serialized + to the ONNX proto. + """ + if self._metadata is None: + self._metadata = _metadata.MetadataStore() + return self._metadata + def display(self, *, page: bool = False) -> None: rich = _display.require_rich() @@ -310,12 +361,8 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]): __slots__ = ( "_dtype", - "_metadata", - "_metadata_props", "_raw", "_shape", - "doc_string", - "name", ) def __init__( @@ -348,6 +395,7 @@ def __init__( ValueError: If the shape is not specified and the value does not have a shape attribute. ValueError: If the dtype is not specified and the value is not a numpy array. """ + super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props) # NOTE: We should not do any copying here for performance reasons if not _compatible_with_numpy(value) and not _compatible_with_dlpack(value): raise TypeError(f"Expected an array compatible object, got {type(value)}") @@ -382,10 +430,6 @@ def __init__( value = _maybe_view_np_array_with_ml_dtypes(value, self._dtype) # type: ignore[assignment] self._raw = value - self.name = name - self.doc_string = doc_string - self._metadata: _metadata.MetadataStore | None = None - self._metadata_props = metadata_props def __array__(self, dtype: Any = None) -> np.ndarray: if isinstance(self._raw, np.ndarray) or _compatible_with_numpy(self._raw): @@ -456,23 +500,6 @@ def tobytes(self) -> bytes: array = array.view(array.dtype.newbyteorder("<")) return array.tobytes() - @property - def metadata_props(self) -> dict[str, str]: - if self._metadata_props is None: - self._metadata_props = {} - return self._metadata_props - - @property - def meta(self) -> _metadata.MetadataStore: - """The metadata store for intermediate analysis. - - Write to the :attr:`metadata_props` if you would like the metadata to be serialized - to the ONNX proto. - """ - if self._metadata is None: - self._metadata = _metadata.MetadataStore() - return self._metadata - class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors """An immutable concrete tensor with its data store on disk. @@ -513,13 +540,9 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable= "_dtype", "_length", "_location", - "_metadata", - "_metadata_props", "_offset", "_shape", "_valid", - "doc_string", - "name", "raw", ) @@ -549,6 +572,7 @@ def __init__( metadata_props: The metadata properties. base_dir: The base directory for the external data. It is used to resolve relative paths. """ + super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props) # NOTE: Do not verify the location by default. This is because the location field # in the tensor proto can be anything and we would like deserialization from # proto to IR to not fail. @@ -726,34 +750,13 @@ def release(self) -> None: self.raw.close() self.raw = None - @property - def metadata_props(self) -> dict[str, str]: - if self._metadata_props is None: - self._metadata_props = {} - return self._metadata_props - - @property - def meta(self) -> _metadata.MetadataStore: - """The metadata store for intermediate analysis. - - Write to the :attr:`metadata_props` if you would like the metadata to be serialized - to the ONNX proto. - """ - if self._metadata is None: - self._metadata = _metadata.MetadataStore() - return self._metadata - class StringTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors """Multidimensional array of strings (as binary data to match the string_data field in TensorProto).""" __slots__ = ( - "_metadata", - "_metadata_props", "_raw", "_shape", - "doc_string", - "name", ) def __init__( @@ -774,6 +777,7 @@ def __init__( doc_string: The documentation string. metadata_props: The metadata properties. """ + super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props) if shape is None: if not hasattr(value, "shape"): raise ValueError( @@ -785,10 +789,6 @@ def __init__( self._shape = shape self._shape.freeze() self._raw = value - self.name = name - self.doc_string = doc_string - self._metadata: _metadata.MetadataStore | None = None - self._metadata_props = metadata_props def __array__(self, dtype: Any = None) -> np.ndarray: if isinstance(self._raw, np.ndarray): @@ -836,23 +836,6 @@ def string_data(self) -> Sequence[bytes]: return self._raw.flatten().tolist() return self._raw - @property - def metadata_props(self) -> dict[str, str]: - if self._metadata_props is None: - self._metadata_props = {} - return self._metadata_props - - @property - def meta(self) -> _metadata.MetadataStore: - """The metadata store for intermediate analysis. - - Write to the :attr:`metadata_props` if you would like the metadata to be serialized - to the ONNX proto. - """ - if self._metadata is None: - self._metadata = _metadata.MetadataStore() - return self._metadata - class LazyTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors """A tensor that lazily evaluates a function to get the actual tensor. @@ -890,13 +873,9 @@ class LazyTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too- __slots__ = ( "_dtype", "_func", - "_metadata", - "_metadata_props", "_shape", "_tensor", "cache", - "doc_string", - "name", ) def __init__( @@ -921,15 +900,12 @@ def __init__( doc_string: The documentation string. metadata_props: The metadata properties. """ + super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props) self._func = func self._dtype = dtype self._shape = shape self._tensor: _protocols.TensorProtocol | None = None self.cache = cache - self.name = name - self.doc_string = doc_string - self._metadata: _metadata.MetadataStore | None = None - self._metadata_props = metadata_props def _evaluate(self) -> _protocols.TensorProtocol: """Evaluate the function to get the actual tensor.""" @@ -975,23 +951,6 @@ def tobytes(self) -> bytes: """Return the bytes of the tensor.""" return self._evaluate().tobytes() - @property - def metadata_props(self) -> dict[str, str]: - if self._metadata_props is None: - self._metadata_props = {} - return self._metadata_props - - @property - def meta(self) -> _metadata.MetadataStore: - """The metadata store for intermediate analysis. - - Write to the :attr:`metadata_props` if you would like the metadata to be serialized - to the ONNX proto. - """ - if self._metadata is None: - self._metadata = _metadata.MetadataStore() - return self._metadata - class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable): """Immutable symbolic dimension that can be shared across multiple shapes.""" diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index 64703b2ba..6276121d0 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -67,7 +67,7 @@ import onnx import onnx.external_data_helper -from onnxscript.ir import _core, _enums, _metadata, _protocols, _type_casting +from onnxscript.ir import _core, _enums, _protocols, _type_casting if typing.TYPE_CHECKING: import google.protobuf.internal.containers as proto_containers @@ -243,12 +243,11 @@ def to_proto(ir_object: object) -> object: class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors """A tensor initialized from a tensor proto.""" + __slots__ = ("_proto",) + def __init__(self, proto: onnx.TensorProto) -> None: + super().__init__(metadata_props=deserialize_metadata_props(proto.metadata_props)) self._proto = proto - self._metadata_props: dict[str, str] | None = deserialize_metadata_props( - proto.metadata_props - ) - self._metadata: _metadata.MetadataStore | None = None @property def name(self) -> str: @@ -269,7 +268,7 @@ def shape(self) -> _core.Shape: def dtype(self) -> _enums.DataType: return _enums.DataType(self._proto.data_type) - @property + @property # type: ignore[misc] def doc_string(self) -> str: return self._proto.doc_string @@ -439,23 +438,6 @@ def tobytes(self) -> bytes: # For example, int32_data can be empty and still be a valid tensor. return b"" - @property - def meta(self) -> _metadata.MetadataStore: - """The metadata store for intermediate analysis. - - Write to the :attr:`metadata_props` if you would like the metadata to be serialized - to the ONNX proto. - """ - if self._metadata is None: - self._metadata = _metadata.MetadataStore() - return self._metadata - - @property - def metadata_props(self) -> dict[str, str]: - if self._metadata_props is None: - self._metadata_props = {} - return self._metadata_props - def _get_field(proto: Any, field: str) -> Any: if proto.HasField(field):