90 changes: 57 additions & 33 deletions docs/intermediate_representation/tensors.md

```diff
@@ -192,64 +192,88 @@ To fully support arrays from other frameworks, it is usually a good idea to create
 import ctypes
 from typing import Any
 
+import numpy.typing as npt
 import torch
 
 from onnxscript import ir
 
-# Define utilities to convert PyTorch data types so users do not need to specify manually
-_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
-    torch.bfloat16: ir.DataType.BFLOAT16,
-    torch.bool: ir.DataType.BOOL,
-    torch.complex128: ir.DataType.COMPLEX128,
-    torch.complex64: ir.DataType.COMPLEX64,
-    torch.float16: ir.DataType.FLOAT16,
-    torch.float32: ir.DataType.FLOAT,
-    torch.float64: ir.DataType.DOUBLE,
-    torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
-    torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
-    torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
-    torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
-    torch.int16: ir.DataType.INT16,
-    torch.int32: ir.DataType.INT32,
-    torch.int64: ir.DataType.INT64,
-    torch.int8: ir.DataType.INT8,
-    torch.uint8: ir.DataType.UINT8,
-}
-
-
-def _torch_dtype_to_onnx_dtype(dtype: torch.dtype) -> ir.DataType:
-    return _TORCH_DTYPE_TO_ONNX[dtype]
 
 class TorchTensor(ir.Tensor):
-    def __init__(self, tensor: torch.Tensor):
+    def __init__(
+        self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None
+    ):
         # Pass the tensor as the raw data to ir.Tensor's constructor
-        super().__init__(tensor, dtype=_torch_dtype_to_onnx_dtype(tensor.dtype))
-
-    def __array__(self, dtype: Any = None) -> "np.ndarray":
-        # numpy() calls __array__ in ir.Tensor
+        _TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
+            torch.bfloat16: ir.DataType.BFLOAT16,
+            torch.bool: ir.DataType.BOOL,
+            torch.complex128: ir.DataType.COMPLEX128,
+            torch.complex64: ir.DataType.COMPLEX64,
+            torch.float16: ir.DataType.FLOAT16,
+            torch.float32: ir.DataType.FLOAT,
+            torch.float64: ir.DataType.DOUBLE,
+            torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
+            torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
+            torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
+            torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
+            torch.int16: ir.DataType.INT16,
+            torch.int32: ir.DataType.INT32,
+            torch.int64: ir.DataType.INT64,
+            torch.int8: ir.DataType.INT8,
+            torch.uint8: ir.DataType.UINT8,
+            torch.uint16: ir.DataType.UINT16,
+            torch.uint32: ir.DataType.UINT32,
+            torch.uint64: ir.DataType.UINT64,
+        }
+        super().__init__(
+            tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name, doc_string=doc_string
+        )
+
+    def numpy(self) -> npt.NDArray:
         self.raw: torch.Tensor
         if self.dtype == ir.DataType.BFLOAT16:
-            return self.raw.view(torch.uint16).__array__(dtype)
+            return self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy())
         if self.dtype in {
             ir.DataType.FLOAT8E4M3FN,
             ir.DataType.FLOAT8E4M3FNUZ,
             ir.DataType.FLOAT8E5M2,
-            ir.DataType.FLOAT8E5M2FNUZ
+            ir.DataType.FLOAT8E5M2FNUZ,
         }:
-            return self.raw.view(torch.uint8).__array__(dtype)
-        return self.raw.__array__(dtype)
+            return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
+
+        return self.raw.numpy(force=True)
+
+    def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray:
+        del copy  # Unused, but needed for the signature
+        if dtype is None:
+            return self.numpy()
+        return self.numpy().__array__(dtype)
 
     def tobytes(self) -> bytes:
         # Implement tobytes to support native PyTorch types so we can use types like bfloat16
         # Reading from memory directly is also more efficient because
         # it avoids copying to a NumPy array
-        tensor = self.raw.detach().cpu().contiguous()
+        import torch._subclasses.fake_tensor
+
+        with torch._subclasses.fake_tensor.unset_fake_temporarily():  # pylint: disable=protected-access
+            # Disable any fake mode so calling detach() etc. will return a real tensor
+            tensor = self.raw.detach().cpu().contiguous()
+
+        if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):  # pylint: disable=protected-access
+            raise TypeError(
+                f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor "
+                "with a tensor backed by real data using ONNXProgram.apply_weights() "
+                "or save the model without initializers by setting include_initializers=False."
+            )
+
         return bytes(
             (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
                 tensor.data_ptr()
             )
         )
 
 # Test the implementation
-torch_tensor = torch.tensor([1,2,3], dtype=torch.bfloat16)
+torch_tensor = torch.tensor([1, 2, 3], dtype=torch.bfloat16)
 tensor = TorchTensor(torch_tensor)
 print("tensor: ", tensor)
 print("numpy: ", tensor.numpy())
```
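To make the docs change concrete, here is a small usage sketch (not part of the diff). It assumes the updated `TorchTensor` class above is in scope and that the `ml_dtypes` package is installed, since `ir.DataType.numpy()` relies on it for the non-native NumPy types:

```python
# Usage sketch, assuming TorchTensor from the updated docs example above
# and ml_dtypes installed (it backs ir.DataType.numpy()).
import torch

t = TorchTensor(torch.tensor([1.0, 2.0], dtype=torch.bfloat16), name="w")
arr = t.numpy()
# Before this change the caller got raw uint16 bit patterns; now the same
# bits come back reinterpreted as a bfloat16 NumPy array with values 1.0, 2.0.
print(arr.dtype)  # bfloat16
print(arr)
```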
6 changes: 3 additions & 3 deletions onnxscript/ir/tensor_adapters.py

```diff
@@ -81,15 +81,15 @@ def numpy(self) -> npt.NDArray:
 
         self.raw: torch.Tensor
         if self.dtype == ir.DataType.BFLOAT16:
-            return self.raw.view(torch.uint16).numpy(force=True)
+            return self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy())
         if self.dtype in {
             ir.DataType.FLOAT8E4M3FN,
             ir.DataType.FLOAT8E4M3FNUZ,
             ir.DataType.FLOAT8E5M2,
             ir.DataType.FLOAT8E5M2FNUZ,
         }:
-            # TODO: Use ml_dtypes
-            return self.raw.view(torch.uint8).numpy(force=True)
+            return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
 
         return self.raw.numpy(force=True)
 
     def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray:
```
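The appended `.view(self.dtype.numpy())` is the substantive fix in `tensor_adapters.py`: PyTorch cannot export bfloat16 or float8 tensors to NumPy directly, so the tensor is first bit-cast to an integer type of the same width, exported, and then reinterpreted as the matching ml_dtypes type. A standalone sketch of that trick, under the assumption that torch >= 1.13 (for `numpy(force=True)`) and ml_dtypes are installed:

```python
# Standalone sketch of the bit-reinterpretation used in numpy() above.
import ml_dtypes
import numpy as np
import torch

t = torch.tensor([1.5, -2.0], dtype=torch.bfloat16)
bits = t.view(torch.uint16).numpy(force=True)  # same memory, raw 16-bit patterns
arr = bits.view(ml_dtypes.bfloat16)            # same bits, bfloat16 semantics
assert np.array_equal(arr.astype(np.float32), np.array([1.5, -2.0], dtype=np.float32))
```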
11 changes: 6 additions & 5 deletions onnxscript/ir/tensor_adapters_test.py

```diff
@@ -7,6 +7,7 @@
 import importlib.util
 import unittest
 
+import ml_dtypes
 import numpy as np
 import parameterized
 import torch
@@ -25,17 +26,17 @@ def skip_if_no(module_name: str):
 class TorchTensorTest(unittest.TestCase):
     @parameterized.parameterized.expand(
         [
-            (torch.bfloat16, np.uint16),
+            (torch.bfloat16, ml_dtypes.bfloat16),
             (torch.bool, np.bool_),
             (torch.complex128, np.complex128),
             (torch.complex64, np.complex64),
             (torch.float16, np.float16),
             (torch.float32, np.float32),
             (torch.float64, np.float64),
-            (torch.float8_e4m3fn, np.uint8),
-            (torch.float8_e4m3fnuz, np.uint8),
-            (torch.float8_e5m2, np.uint8),
-            (torch.float8_e5m2fnuz, np.uint8),
+            (torch.float8_e4m3fn, ml_dtypes.float8_e4m3fn),
+            (torch.float8_e4m3fnuz, ml_dtypes.float8_e4m3fnuz),
+            (torch.float8_e5m2, ml_dtypes.float8_e5m2),
+            (torch.float8_e5m2fnuz, ml_dtypes.float8_e5m2fnuz),
             (torch.int16, np.int16),
             (torch.int32, np.int32),
             (torch.int64, np.int64),
```
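The expected dtypes in these parameterized cases move from raw carrier types (`np.uint16`, `np.uint8`) to the real ml_dtypes types. The test body itself is collapsed in this view; a hypothetical shape of such a case, with the method and module names assumed rather than taken from the diff, might look like:

```python
# Hypothetical test body for illustration only; the real one is collapsed in this diff view.
def test_numpy_returns_expected_dtype(self, torch_dtype: torch.dtype, numpy_dtype):
    tensor = tensor_adapters.TorchTensor(torch.tensor([1], dtype=torch_dtype))
    self.assertEqual(tensor.numpy().dtype, numpy_dtype)
```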