Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Better dtype coverage #834

Merged
merged 2 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 77 additions & 4 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
_is_non_tensor,
_is_tensorclass,
_KEY_ERROR,
_make_dtype_promotion,
_proc_init,
_prune_selected_keys,
_set_max_batch_size,
Expand Down Expand Up @@ -7758,10 +7759,6 @@ def half(self):
r"""Casts all tensors to ``torch.half``."""
return self._fast_apply(lambda x: x.half(), propagate_lock=True)

def bfloat16(self):
r"""Casts all tensors to ``torch.bfloat16``."""
return self._fast_apply(lambda x: x.bfloat16(), propagate_lock=True)

def type(self, dst_type):
r"""Casts all tensors to :attr:`dst_type`.

Expand Down Expand Up @@ -7799,6 +7796,82 @@ def detach(self) -> T:
propagate_lock=True,
)

@_make_dtype_promotion
def bfloat16(self):
...

@_make_dtype_promotion
def complex128(self):
...

@_make_dtype_promotion
def complex32(self):
...

@_make_dtype_promotion
def complex64(self):
...

@_make_dtype_promotion
def float16(self):
...

@_make_dtype_promotion
def float32(self):
...

@_make_dtype_promotion
def float64(self):
...

@_make_dtype_promotion
def int16(self):
...

@_make_dtype_promotion
def int32(self):
...

@_make_dtype_promotion
def int64(self):
...

@_make_dtype_promotion
def int8(self):
...

@_make_dtype_promotion
def qint32(self):
...

@_make_dtype_promotion
def qint8(self):
...

@_make_dtype_promotion
def quint4x2(self):
...

@_make_dtype_promotion
def quint8(self):
...

@_make_dtype_promotion
def uint16(self):
...

@_make_dtype_promotion
def uint32(self):
...

@_make_dtype_promotion
def uint64(self):
...

@_make_dtype_promotion
def uint8(self):
...


_ACCEPTED_CLASSES = (
Tensor,
Expand Down
82 changes: 47 additions & 35 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,28 +105,29 @@ def dims(self, *args, **kwargs):

T = TypeVar("T", bound="TensorDictBase")

_STRDTYPE2DTYPE = {
str(dtype): dtype
for dtype in (
torch.float32,
torch.float64,
torch.float16,
torch.bfloat16,
torch.complex32,
torch.complex64,
torch.complex128,
torch.uint8,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.bool,
torch.quint8,
torch.qint8,
torch.qint32,
torch.quint4x2,
)
}
_TORCH_DTYPES = (
torch.bfloat16,
torch.bool,
torch.complex128,
torch.complex32,
torch.complex64,
torch.float16,
torch.float32,
torch.float64,
torch.int16,
torch.int32,
torch.int64,
torch.int8,
torch.qint32,
torch.qint8,
torch.quint4x2,
torch.quint8,
torch.uint16,
torch.uint32,
torch.uint64,
torch.uint8,
)
_STRDTYPE2DTYPE = {str(dtype): dtype for dtype in _TORCH_DTYPES}

IndexType = Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]]
DeviceType = Union[torch.device, str, int]
Expand Down Expand Up @@ -414,19 +415,19 @@ def expand_right(tensor: Tensor, shape: Sequence[int]) -> Tensor:
return tensor_expand


NUMPY_TO_TORCH_DTYPE_DICT = {
np.dtype("bool"): torch.bool,
np.dtype("uint8"): torch.uint8,
np.dtype("int8"): torch.int8,
np.dtype("int16"): torch.int16,
np.dtype("int32"): torch.int32,
np.dtype("int64"): torch.int64,
np.dtype("float16"): torch.float16,
np.dtype("float32"): torch.float32,
np.dtype("float64"): torch.float64,
np.dtype("complex64"): torch.complex64,
np.dtype("complex128"): torch.complex128,
}
def _populate_np_dtypes():
d = {}
for dtype in _TORCH_DTYPES:
dtype_str = str(dtype).split(".")[-1]
try:
d[np.dtype(dtype_str)] = dtype
except TypeError:
continue
return d


NUMPY_TO_TORCH_DTYPE_DICT = _populate_np_dtypes()

TORCH_TO_NUMPY_DTYPE_DICT = {
value: key for key, value in NUMPY_TO_TORCH_DTYPE_DICT.items()
}
Expand Down Expand Up @@ -2309,3 +2310,14 @@ def is_namedtuple_class(cls):
"""Check if a class is a namedtuple class."""
base_attrs = {"_fields", "_replace", "_asdict"}
return all(hasattr(cls, attr) for attr in base_attrs)


def _make_dtype_promotion(func):
dtype = getattr(torch, func.__name__)

@wraps(func)
def new_func(self):
return self._fast_apply(lambda x: x.to(dtype), propagate_lock=True)

new_func.__doc__ = rf"""Casts all tensors to ``{str(dtype)}``."""
return new_func
Loading