diff --git a/docs/intermediate_representation/tensors.md b/docs/intermediate_representation/tensors.md
index 67d9eee85a..5cd12a2eca 100644
--- a/docs/intermediate_representation/tensors.md
+++ b/docs/intermediate_representation/tensors.md
@@ -192,56 +192,80 @@ To fully support arrays from other frameworks, it is usually a good idea to crea
     import ctypes
     from typing import Any
 
+    import numpy.typing as npt
     import torch
+
     from onnxscript import ir
 
-    # Define utilities to convert PyTorch data types so users do not need to specify manually
-    _TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
-        torch.bfloat16: ir.DataType.BFLOAT16,
-        torch.bool: ir.DataType.BOOL,
-        torch.complex128: ir.DataType.COMPLEX128,
-        torch.complex64: ir.DataType.COMPLEX64,
-        torch.float16: ir.DataType.FLOAT16,
-        torch.float32: ir.DataType.FLOAT,
-        torch.float64: ir.DataType.DOUBLE,
-        torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
-        torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
-        torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
-        torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
-        torch.int16: ir.DataType.INT16,
-        torch.int32: ir.DataType.INT32,
-        torch.int64: ir.DataType.INT64,
-        torch.int8: ir.DataType.INT8,
-        torch.uint8: ir.DataType.UINT8,
-    }
-
-
-    def _torch_dtype_to_onnx_dtype(dtype: torch.dtype) -> ir.DataType:
-        return _TORCH_DTYPE_TO_ONNX[dtype]
 
     class TorchTensor(ir.Tensor):
-        def __init__(self, tensor: torch.Tensor):
+        def __init__(
+            self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None
+        ):
             # Pass the tensor as the raw data to ir.Tensor's constructor
-            super().__init__(tensor, dtype=_torch_dtype_to_onnx_dtype(tensor.dtype))
-
-        def __array__(self, dtype: Any = None) -> "np.ndarray":
-            # numpy() calls __array__ in ir.Tensor
+            _TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
+                torch.bfloat16: ir.DataType.BFLOAT16,
+                torch.bool: ir.DataType.BOOL,
+                torch.complex128: ir.DataType.COMPLEX128,
+                torch.complex64: ir.DataType.COMPLEX64,
+                torch.float16: ir.DataType.FLOAT16,
+                torch.float32: ir.DataType.FLOAT,
+                torch.float64: ir.DataType.DOUBLE,
+                torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
+                torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
+                torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
+                torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
+                torch.int16: ir.DataType.INT16,
+                torch.int32: ir.DataType.INT32,
+                torch.int64: ir.DataType.INT64,
+                torch.int8: ir.DataType.INT8,
+                torch.uint8: ir.DataType.UINT8,
+                torch.uint16: ir.DataType.UINT16,
+                torch.uint32: ir.DataType.UINT32,
+                torch.uint64: ir.DataType.UINT64,
+            }
+            super().__init__(
+                tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name, doc_string=doc_string
+            )
+
+        def numpy(self) -> npt.NDArray:
+            self.raw: torch.Tensor
             if self.dtype == ir.DataType.BFLOAT16:
-                return self.raw.view(torch.uint16).__array__(dtype)
+                return self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy())
             if self.dtype in {
                 ir.DataType.FLOAT8E4M3FN,
                 ir.DataType.FLOAT8E4M3FNUZ,
                 ir.DataType.FLOAT8E5M2,
-                ir.DataType.FLOAT8E5M2FNUZ
+                ir.DataType.FLOAT8E5M2FNUZ,
             }:
-                return self.raw.view(torch.uint8).__array__(dtype)
-            return self.raw.__array__(dtype)
+                return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
+
+            return self.raw.numpy(force=True)
+
+        def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray:
+            del copy  # Unused, but needed for the signature
+            if dtype is None:
+                return self.numpy()
+            return self.numpy().__array__(dtype)
 
         def tobytes(self) -> bytes:
             # Implement tobytes to support native PyTorch types so we can use types like bfloat16
             # Reading from memory directly is also more efficient because
             # it avoids copying to a NumPy array
-            tensor = self.raw.detach().cpu().contiguous()
+            import torch._subclasses.fake_tensor
+
+            with torch._subclasses.fake_tensor.unset_fake_temporarily():  # pylint: disable=protected-access
+                # Disable any fake mode so calling detach() etc. will return a real tensor
+                tensor = self.raw.detach().cpu().contiguous()
+
+            if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):  # pylint: disable=protected-access
+                raise TypeError(
+                    f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor "
+                    "with a tensor backed by real data using ONNXProgram.apply_weights() "
+                    "or save the model without initializers by setting include_initializers=False."
+                )
+
             return bytes(
                 (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
                     tensor.data_ptr()
@@ -249,7 +273,7 @@ To fully support arrays from other frameworks, it is usually a good idea to crea
             )
 
     # Test the implementation
-    torch_tensor = torch.tensor([1,2,3], dtype=torch.bfloat16)
+    torch_tensor = torch.tensor([1, 2, 3], dtype=torch.bfloat16)
     tensor = TorchTensor(torch_tensor)
     print("tensor: ", tensor)
    print("numpy: ", tensor.numpy())
diff --git a/onnxscript/ir/tensor_adapters.py b/onnxscript/ir/tensor_adapters.py
index e24bce026e..0a74e0a74c 100644
--- a/onnxscript/ir/tensor_adapters.py
+++ b/onnxscript/ir/tensor_adapters.py
@@ -81,15 +81,15 @@
     def numpy(self) -> npt.NDArray:
         self.raw: torch.Tensor
         if self.dtype == ir.DataType.BFLOAT16:
-            return self.raw.view(torch.uint16).numpy(force=True)
+            return self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy())
         if self.dtype in {
             ir.DataType.FLOAT8E4M3FN,
             ir.DataType.FLOAT8E4M3FNUZ,
             ir.DataType.FLOAT8E5M2,
             ir.DataType.FLOAT8E5M2FNUZ,
         }:
-            # TODO: Use ml_dtypes
-            return self.raw.view(torch.uint8).numpy(force=True)
+            return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
+
         return self.raw.numpy(force=True)
 
     def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray:
diff --git a/onnxscript/ir/tensor_adapters_test.py b/onnxscript/ir/tensor_adapters_test.py
index 34034ac51f..8295bbe876 100644
--- a/onnxscript/ir/tensor_adapters_test.py
+++ b/onnxscript/ir/tensor_adapters_test.py
@@ -7,6 +7,7 @@
 import importlib.util
 import unittest
 
+import ml_dtypes
 import numpy as np
 import parameterized
 import torch
@@ -25,17 +26,17 @@ def skip_if_no(module_name: str):
 class TorchTensorTest(unittest.TestCase):
     @parameterized.parameterized.expand(
         [
-            (torch.bfloat16, np.uint16),
+            (torch.bfloat16, ml_dtypes.bfloat16),
             (torch.bool, np.bool_),
             (torch.complex128, np.complex128),
             (torch.complex64, np.complex64),
             (torch.float16, np.float16),
             (torch.float32, np.float32),
             (torch.float64, np.float64),
-            (torch.float8_e4m3fn, np.uint8),
-            (torch.float8_e4m3fnuz, np.uint8),
-            (torch.float8_e5m2, np.uint8),
-            (torch.float8_e5m2fnuz, np.uint8),
+            (torch.float8_e4m3fn, ml_dtypes.float8_e4m3fn),
+            (torch.float8_e4m3fnuz, ml_dtypes.float8_e4m3fnuz),
+            (torch.float8_e5m2, ml_dtypes.float8_e5m2),
+            (torch.float8_e5m2fnuz, ml_dtypes.float8_e5m2fnuz),
             (torch.int16, np.int16),
             (torch.int32, np.int32),
             (torch.int64, np.int64),
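
Note (not part of the diff): the core technique in the new `numpy()` is a bit-preserving reinterpretation. The bfloat16/float8 payload is viewed as uint16/uint8 on the torch side, materialized as a NumPy array, then the same bits are retagged with the matching ml_dtypes dtype via `self.dtype.numpy()`. A minimal standalone sketch of that round trip, assuming `torch` and `ml_dtypes` are installed; the variable names are illustrative only:

    import ml_dtypes
    import numpy as np
    import torch

    # A bfloat16 tensor; every value below is exactly representable in bfloat16.
    t = torch.tensor([1.0, 2.5, -3.0], dtype=torch.bfloat16)

    # view(torch.uint16) reinterprets the 2-byte payload without converting values;
    # numpy(force=True) materializes it as a NumPy array; the final .view() retags
    # the identical bits as ml_dtypes.bfloat16.
    arr = t.view(torch.uint16).numpy(force=True).view(ml_dtypes.bfloat16)

    assert arr.dtype == np.dtype(ml_dtypes.bfloat16)
    assert np.array_equal(arr.astype(np.float32), np.array([1.0, 2.5, -3.0], dtype=np.float32))

Viewing rather than casting keeps the exact bit pattern, so the round trip is lossless even for types NumPy does not support natively, which is what the updated tests check against the ml_dtypes dtypes.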