diff --git a/tensordict/csrc/pybind.cpp b/tensordict/csrc/pybind.cpp
index 4e31b629c..0a0fd6107 100644
--- a/tensordict/csrc/pybind.cpp
+++ b/tensordict/csrc/pybind.cpp
@@ -5,6 +5,8 @@
 #include
 #include
+#include
+#include
 #include
@@ -18,4 +20,6 @@ PYBIND11_MODULE(_tensordict, m) {
   m.def("_unravel_key_to_tuple", &_unravel_key_to_tuple, py::arg("key"));
   m.def("unravel_key_list", py::overload_cast<const py::list&>(&unravel_key_list), py::arg("keys"));
   m.def("unravel_key_list", py::overload_cast<const py::tuple&>(&unravel_key_list), py::arg("keys"));
+  m.def("_populate_index", &_populate_index, "Expands per-entry lengths and start offsets into a flat index tensor.");
+  m.def("_as_shape", &_as_shape, "Converts a het shape to a shape with -1 for het dims.");
 }
diff --git a/tensordict/csrc/utils.h b/tensordict/csrc/utils.h
index c8cbfd861..4b9fe49df 100644
--- a/tensordict/csrc/utils.h
+++ b/tensordict/csrc/utils.h
@@ -5,9 +5,10 @@
 #include
 #include
+#include

 namespace py = pybind11;
-
+using namespace torch::indexing;

 py::tuple _unravel_key_to_tuple(const py::object& key) {
   bool is_tuple = py::isinstance<py::tuple>(key);
@@ -76,3 +77,63 @@ py::list unravel_key_list(const py::list& keys) {
 py::list unravel_key_list(const py::tuple& keys) {
   return unravel_key_list(py::list(keys));
 }
+
+torch::Tensor _populate_index(torch::Tensor offsets, torch::Tensor offsets_cs) {
+  int64_t total = offsets.sum().item<int64_t>();
+  torch::Tensor out = torch::empty({total}, torch::dtype(torch::kLong));
+
+  int64_t* out_data = out.data_ptr<int64_t>();
+  int64_t cur_offset;
+  int64_t count = -1;
+  int64_t maxcount = -1;
+  int64_t cur = -1;
+  int64_t n = offsets.numel();
+  for (int64_t i = 0; i < total; ++i) {
+    if (cur < n && count == maxcount) {
+      cur++;
+      count = -1;
+      maxcount = offsets[cur].item<int64_t>() - 1;
+      cur_offset = offsets_cs[cur].item<int64_t>();
+    }
+    count++;
+    out_data[i] = cur_offset + count;
+  }
+  return out;
+}
+py::list _as_shape(torch::Tensor shape_tensor) {
+//  torch::Tensor shape_tensor_view = shape_tensor.reshape({-1, shape_tensor.size(-1)});
+  torch::Tensor out = shape_tensor;
+  for (int64_t i = 0; i < shape_tensor.ndimension() - 1; ++i) {
+    out = out[0];
+  }
+  out = out.clone();
+  torch::Tensor not_unique = shape_tensor != out;
+  for (int64_t i = 0; i < shape_tensor.ndimension() - 1; ++i) {
+    not_unique = not_unique.any(0);
+  }
+  out.masked_fill_(not_unique, -1);
+  std::vector<int64_t> shape_vector(shape_tensor.sizes().begin(), shape_tensor.sizes().end() - 1);
+  // Extend 'shape_vector' with the values from 'out'.
+  auto out_accessor = out.accessor<int64_t, 1>();
+  for (int64_t i = 0; i < out_accessor.size(0); ++i) {
+    shape_vector.push_back(out_accessor[i]);
+  }
+
+  py::list shape = py::cast(shape_vector);
+  return shape;
+}
+//py::list _as_shape(torch::Tensor shape_tensor) {
+//  torch::Tensor shape_tensor_view = shape_tensor.reshape({-1, shape_tensor.size(-1)});
+//  torch::Tensor out = shape_tensor_view[0];
+//  auto not_unique = (shape_tensor_view != out).any(0);
+//  out.masked_fill_(not_unique, -1);
+//  std::vector<int64_t> shape_vector;
+//  shape_vector.reserve(shape_tensor.ndimension() + shape_tensor.size(-1) - 1); // Reserve capacity to avoid reallocations.
+// shape_vector.insert(shape_vector.end(), shape_tensor.sizes().begin(), shape_tensor.sizes().end() - 1); +// auto out_accessor = out.accessor(); +// for (int64_t i = 0; i < out_accessor.size(0); ++i) { +// shape_vector.push_back(out_accessor[i]); +// } +// py::list shape = py::cast(shape_vector); +// return shape; +//} diff --git a/tensordict/tensorstack.py b/tensordict/tensorstack.py new file mode 100644 index 000000000..cc06b9876 --- /dev/null +++ b/tensordict/tensorstack.py @@ -0,0 +1,505 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import torch +from tensordict._tensordict import _as_shape, _populate_index +from torch import Tensor +from torch.utils._pytree import tree_flatten, tree_map + + +def _lazy_init(func): + """A caching helper.""" + name = "_" + func.__name__ + + def setter(self, value): + setattr(self, name, value) + + def new_func(self): + if not hasattr(self, name): + r = func(self) + setter(self, r) + return r + return getattr(self, name) + + return property(new_func, setter) + + +def _broadcast_shapes(*shapes): + """A modified version of torch.broadcast_shapes that accepts -1.""" + max_len = 0 + for shape in shapes: + if isinstance(shape, int): + if max_len < 1: + max_len = 1 + elif isinstance(shape, (tuple, list)): + s = len(shape) + if max_len < s: + max_len = s + result = [1] * max_len + for shape in shapes: + if isinstance(shape, int): + shape = (shape,) + if isinstance(shape, (tuple, list)): + for i in range(-1, -1 - len(shape), -1): + cur_shape = shape[i] + if cur_shape == -1: + cur_shape = None # in double we use None as placeholder, which equals nothing + if cur_shape == 1 or cur_shape == result[i]: + continue + if result[i] == -1: + # in this case, we consider this as het dim + continue + if result[i] != 1: + raise RuntimeError( + "Shape mismatch: objects cannot be broadcast to a single shape" + ) + result[i] = shape[i] + else: + raise RuntimeError( + "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", + shape, + ) + return torch.Size(result) + + +class _NestedShape: + def __new__(cls, shapes): + # TODO: if a tensor with nume() == tensor.shape[-1], then return a regular tensor + if isinstance(shapes, _NestedShape): + return shapes + return super().__new__(cls) + + def __init__(self, shapes): + if not isinstance(shapes, torch.Tensor): + shapes = torch.tensor(shapes) + self._shapes = shapes + + @_lazy_init + def _offsets(self): + common_shape = self.common_shape + shapes = self._shapes + if common_shape: + shapes = shapes[..., : -len(common_shape)] + return shapes.prod(-1) + + @_lazy_init + def _offsets_cs(self): + common_shape = self.common_shape + shapes = self._shapes + if common_shape: + shapes = shapes[..., : -len(common_shape)] + cs = shapes.prod(-1).reshape(-1).cumsum(0) + cs_pad = torch.nn.functional.pad(cs[:-1], [1, 0]) + return torch.stack( + [ + cs_pad.view(shapes.shape[:-1]), + cs.view(shapes.shape[:-1]), + ] + ) + + def unfold(self): + """Converts the shape to the maximum-indexable format. 
+ + Examples: + >>> ns = _NestedShape(([11, 2, 3], [11, 5, 3])) + >>> print(ns.batch_dim) + torch.Size([2]) + >>> print(ns.unfold().batch_dim) + torch.Size([2, 11]) + """ + out = _NestedShape(self._shapes.clone()) + is_unique, val = out.is_unique(out._shapes.ndim - 1) + while is_unique: + out._shapes = ( + out._shapes[..., 1:] + .unsqueeze(-2) + .expand(*out._shapes.shape[:-1], val, -1) + ) + is_unique, val = out.is_unique(out._shapes.ndim - 1) + return out + + @_lazy_init + def ndim(self): + return self._shapes.ndim - 1 + self._shapes.shape[-1] + + def is_unique(self, dim): + if dim < 0: + dim = self.ndim + dim + if dim < 0 or dim >= self.ndim: + raise RuntimeError + if dim < self._shapes.ndim - 1: + return (True, self._shapes.shape[dim]) + v = self.as_shape[dim - self._shapes.ndim + 1] + return v != -1, v + + @_lazy_init + def het_dims(self): + as_shape = self.as_shape + return [dim for dim, s in enumerate(as_shape) if s == -1] + + def numel(self): + return ( + self._offsets_cs[(1,) + (-1,) * (self._shapes.ndim - 1)] + * self.common_shape.numel() + ) + + @property + def batch_dim(self): + return self._shapes.shape[:-1] + + @_lazy_init + def common_shape(self): + shape = [] + + for v in reversed(self.as_shape): + if v != -1: + shape.append(v) + else: + break + return torch.Size(reversed(shape)) + + @classmethod + def broadcast_shape(cls, shape: torch.Size, nested_shape: _NestedShape): + broadcast_shape = _broadcast_shapes(shape, nested_shape.as_shape) + return nested_shape.expand(broadcast_shape) + + def expand(self, *broadcast_shape): + if len(broadcast_shape) == 1 and isinstance(broadcast_shape[0], (tuple, list)): + broadcast_shape = broadcast_shape[0] + as_shape = self.as_shape + if len(broadcast_shape) == len(as_shape) and all( + s1 == s2 or s1 == -1 or s2 == -1 + for (s1, s2) in zip(broadcast_shape, as_shape) + ): + return self + + # trailing dims, ie dims that are registered + broadcast_shape_trailing = broadcast_shape[self._shapes.shape[-1] :] + broadcast_shape_trailing = _broadcast_shapes( + broadcast_shape_trailing, as_shape[len(self.batch_dim) :] + ) + # replace trailing dims + shapes = self._shapes.clone() + for i in range(-1, len(broadcast_shape_trailing) - 1, -1): + if as_shape[i] != -1: + shapes[..., i] = broadcast_shape_trailing[i] + + # leading dims, ie dims that are not explicitely registered + broadcast_shape_leading = broadcast_shape[: -self.ndim] + + # find first -1 in broadcast_shape + if not len(broadcast_shape_leading): + return _NestedShape(shapes) + + return _NestedShape( + shapes.expand(*broadcast_shape_leading, *self._shapes.shape) + ) + + @_lazy_init + def is_plain(self): + return self._shapes.ndim == 1 or not self.het_dims + + def __getitem__(self, item): + try: + return _NestedShape(self._shapes[item]) + except IndexError as err: + if "too many indices" in str(err): + raise IndexError( + "Cannot index along dimensions on the right of the heterogeneous dimension." 
+ ) + + @_lazy_init + def as_shape(self): + shape_cpp = torch.Size(_as_shape(self._shapes)) + return shape_cpp + # first_shape = self._shapes[(0,) * (self._shapes.ndim - 1)].clone() + # unique = (self._shapes == first_shape).view(-1, self._shapes.shape[-1]).all(0) + # first_shape[~unique] = -1 + # shape = list(self._shapes.shape[:-1]) + list(first_shape) + # return torch.Size(shape) + + def __repr__(self): + return str(self.as_shape) + + def __eq__(self, other): + return (self._shapes == other).all() + + def __ne__(self, other): + return (self._shapes != other).any() + + +def get_parent_class(f): + return f.__globals__.get(f.__qualname__.split(".")[0], None) + + +def _copy_shapes(func): + def new_func(self, *args, **kwargs): + out = getattr(torch.Tensor, func.__name__)(self, *args, **kwargs) + _shapes = self._shapes + out = TensorStack(out, shapes=_shapes) + return out + + return new_func + + +def _broadcast(func): + def new_func(self, other): + other_shape = getattr(other, "shape", torch.Size([])) + shapes = _NestedShape.broadcast_shape(other_shape, self._shapes) + compact = self._compact + if isinstance(other, TensorStack): + other = other._compact + if shapes != self._shapes: + raise RuntimeError("broadcast between TensorStack not implemented yet.") + # other = other.unsqueeze(-2) + elif isinstance(other, Tensor) and shapes != self._shapes: + # we need to squash + other = other.reshape( + *other.shape[: -self.ndim], -1, *other.shape[-len(compact.shape[1:]) :] + ) + out = getattr(torch.Tensor, func.__name__)(compact, other) + return TensorStack(out, shapes=shapes) + + return new_func + + +class TensorStack(torch.Tensor): + def __new__(cls, tensor, *, shapes): + if shapes.is_plain: + return tensor.reshape(shapes.as_shape) + return super().__new__(cls, tensor) + + def __init__(self, tensor, *, shapes, unfold=False): + super(TensorStack, self).__init__() + if not isinstance(shapes, _NestedShape): + raise ValueError("shapes must be a _NestedShape instance") + if unfold: + shapes = shapes.unfold() + self._shapes = shapes + + @classmethod + def from_tensors(cls, tensors): + if not len(tensors): + raise RuntimeError + shapes = _NestedShape(tree_map(lambda x: x.shape, tensors)) + return TensorStack( + torch.cat([t.view(-1) for t in tree_flatten(tensors)[0]]), shapes=shapes + ) + + def numel(self): + return self._shapes.numel() + + @property + def shape(self): + return self._shapes.as_shape + + def unstack(self): + raise NotImplementedError + + @_lazy_init + def _flat(self): + # represents the tensor as a flat one + return super().view(-1) + + @_lazy_init + def _compact(self): + # represents the tensor with a compact structure (rightmost consistent dims-wise) + return torch.Tensor(super().view(-1, *self._shapes.common_shape)) + + def view(self, *shape): + if len(shape) == 1 and isinstance(shape[0], (tuple, list)): + shape = shape[0] + if isinstance(shape, _NestedShape): + if shape.numel() != self.numel(): + raise ValueError + out = TensorStack(self, shapes=shape) + return out + if len(shape) == 1 and shape[0] == -1: + return self._flat + n = self.numel() + shape = torch.Size(shape) + common_shape = self._shapes.common_shape + compact_shape = torch.Size([n // common_shape.numel(), *common_shape]) + if shape in (torch.Size([-1, *common_shape]), compact_shape): + return self._compact + raise RuntimeError(shape) + + @property + def ndim(self): + return len(self.shape) + + def ndimension(self): + return self.ndim + + def __getitem__(self, index): + if isinstance(index, (int,)): + idx_beg = 
self._shapes._offsets_cs[0, index] + shapes = self._shapes[index] + if idx_beg.numel() <= 1: + idx_end = self._shapes._offsets_cs[1, index] + out = self._compact.__getitem__(slice(idx_beg, idx_end)) + else: + elts = _populate_index( + self._shapes._offsets[index].view(-1), + idx_beg.view(-1), + ) + out = self._compact[elts] + out = TensorStack(out, shapes=shapes) + return out + shapes = self._shapes[index] + if not isinstance(index, tuple): + index = (index,) + # TODO: capture wrong indexing + elts = _populate_index( + self._shapes._offsets[index].view(-1), + self._shapes._offsets_cs[(0, *index)].view(-1), + ) + tensor = self._compact[elts] + return TensorStack(tensor, shapes=shapes) + + def __repr__(self): + return f"{self.__class__.__name__}(shape={self.shape}, dtype={self.dtype}, device={self.device})" + + def permute(self, *dims): + if len(dims) == 1 and isinstance(dims[0], (list, tuple)): + dims = tuple(dims[0]) + n_batch_dms = len(self._shapes.batch_dim) + last_dims = [d - n_batch_dms for d in dims[n_batch_dms:]] + if last_dims != list(range(self.ndim - n_batch_dms)): + raise RuntimeError + dims = dims[:n_batch_dms] + out = TensorStack(self, shapes=_NestedShape(self._shapes._shapes)) + out._shapes._shapes = self._shapes._shapes.permute(*dims, n_batch_dms)[ + ..., last_dims + ] + out._shapes._offsets_cs = self._shapes._offsets_cs.permute( + 0, *[dim + 1 for dim in dims] + ) + out._shapes._offsets = self._shapes._offsets.permute(*dims) + return out + + def transpose(self, dim0, dim1): + out = TensorStack(self, shapes=_NestedShape(self._shapes._shapes)) + out._shapes._shapes = self._shapes._shapes.transpose(dim0, dim1) + out._shapes._offsets_cs = self._shapes._offsets_cs.transpose(dim0 + 1, dim1 + 1) + out._shapes._offsets = self._shapes._offsets.transpose(dim0, dim1) + return out + + @_copy_shapes + def to(self, *args, **kwargs): + ... + + @_copy_shapes + def cpu(self): + ... + + @_copy_shapes + def bool(self): + ... + + @_copy_shapes + def float(self): + ... + + @_copy_shapes + def double(self): + ... + + @_copy_shapes + def int(self): + ... + + @_copy_shapes + def cuda(self): + ... + + @_copy_shapes + def __neg__(self): + ... + + @_copy_shapes + def __abs__(self): + ... + + @_copy_shapes + def __inv__(self): + ... + + @_copy_shapes + def __invert__(self): + ... + + @_broadcast + def add(self, other): + ... + + @_broadcast + def div(self, other): + ... + + @_broadcast + def rdiv(self, other): + ... + + @_broadcast + def __add__(self, other): + ... + + @_broadcast + def __mod__(self, other): + ... + + @_broadcast + def __pow__(self, other): + ... + + @_broadcast + def __sub__(self, other): + ... + + @_broadcast + def __truediv__(self, other): + ... + + @_broadcast + def __eq__(self, other): + ... + + @_broadcast + def __ne__(self, other): + ... + + @_broadcast + def __div__(self, other): + ... + + @_broadcast + def __floordiv__(self, other): + ... + + @_broadcast + def __lt__(self, other): + ... + + @_broadcast + def __le__(self, other): + ... + + @_broadcast + def __ge__(self, other): + ... + + @_broadcast + def __gt__(self, other): + ... + + @_broadcast + def __rdiv__(self, other): + ... + + @_broadcast + def __mul__(self, other): + ... diff --git a/test/test_tensorstack.py b/test/test_tensorstack.py new file mode 100644 index 000000000..a96324858 --- /dev/null +++ b/test/test_tensorstack.py @@ -0,0 +1,136 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import argparse + +import pytest +import torch + +from tensordict.tensorstack import TensorStack + + +@pytest.fixture +def _tensorstack(): + torch.manual_seed(0) + x = torch.randint(10, (3, 1, 5)) + y = torch.randint(10, (3, 2, 5)) + z = torch.randint(10, (3, 3, 5)) + t = TensorStack.from_tensors([x, y, z]) + return t, (x, y, z) + + +class TestTensorStack: + def test_indexing_int(self, _tensorstack): + t, (x, y, z) = _tensorstack + assert (t[0] == x).all() + assert (t[1] == y).all() + assert (t[2] == z).all() + + def test_indexing_slice(self, _tensorstack): + t, (x, y, z) = _tensorstack + + assert (t[:3][0] == x).all() + assert (t[:3][1] == y).all() + assert (t[:3][2] == z).all() + assert (t[-3:][0] == x).all() + assert (t[-3:][1] == y).all() + assert (t[-3:][2] == z).all() + # this breaks because the shape backend is a tensor, which cannot be indexed with neg steps + # assert (t[::-1][0] == z).all() + # assert (t[::-1][1] == y).all() + # assert (t[::-1][2] == x).all() + + def test_indexing_range(self, _tensorstack): + t, (x, y, z) = _tensorstack + assert (t[range(3)][0] == x).all() + assert (t[range(3)][1] == y).all() + assert (t[range(3)][2] == z).all() + assert (t[range(1, 3)][0] == y).all() + assert (t[range(1, 3)][1] == z).all() + + def test_indexing_tensor(self, _tensorstack): + t, (x, y, z) = _tensorstack + assert (t[torch.tensor([0, 2])][0] == x).all() + assert (t[torch.tensor([0, 2])][1] == z).all() + assert (t[torch.tensor([0, 2, 0, 2])][2] == x).all() + assert (t[torch.tensor([0, 2, 0, 2])][3] == z).all() + + assert (t[torch.tensor([[0, 2], [0, 2]])][0, 0] == x).all() + assert (t[torch.tensor([[0, 2], [0, 2]])][0, 1] == z).all() + assert (t[torch.tensor([[0, 2], [0, 2]])][1, 0] == x).all() + assert (t[torch.tensor([[0, 2], [0, 2]])][1, 1] == z).all() + + def test_indexing_composite(self, _tensorstack): + _, (x, y, z) = _tensorstack + t = TensorStack.from_tensors([[x, y, z], [x, y, z]]) + assert (t[0, 0] == x).all() + assert (t[torch.tensor([0]), torch.tensor([0])] == x).all() + assert (t[torch.tensor([0]), torch.tensor([1])] == y).all() + assert (t[torch.tensor([0]), torch.tensor([2])] == z).all() + assert (t[:, torch.tensor([0])] == x).all() + assert (t[:, torch.tensor([1])] == y).all() + assert (t[:, torch.tensor([2])] == z).all() + assert ( + t[torch.tensor([0]), torch.tensor([1, 2])] + == TensorStack.from_tensors([y, z]) + ).all() + with pytest.raises(IndexError, match="Cannot index along"): + assert ( + t[..., torch.tensor([1, 2]), :, :, :] + == TensorStack.from_tensors([y, z]) + ).all() + + @pytest.mark.parametrize( + "op", + ["__add__", "__truediv__", "__mul__", "__sub__", "__mod__", "__eq__", "__ne__"], + ) + def test_elementwise(self, _tensorstack, op): + t, (x, y, z) = _tensorstack + t2 = getattr(t, op)(2) + torch.testing.assert_close(t2[0], getattr(x, op)(2)) + torch.testing.assert_close(t2[1], getattr(y, op)(2)) + torch.testing.assert_close(t2[2], getattr(z, op)(2)) + t2 = getattr(t, op)(torch.ones(5) * 2) + torch.testing.assert_close(t2[0], getattr(x, op)(torch.ones(5) * 2)) + torch.testing.assert_close(t2[1], getattr(y, op)(torch.ones(5) * 2)) + torch.testing.assert_close(t2[2], getattr(z, op)(torch.ones(5) * 2)) + # check broadcasting + assert t2[0].shape == x.shape + v = torch.ones(2, 1, 1, 1, 5) * 2 + t2 = getattr(t, op)(v) + assert t2.shape == torch.Size([2, 3, 3, -1, 5]) + torch.testing.assert_close(t2[:, 0], getattr(x, op)(v[:, 
0])) + torch.testing.assert_close(t2[:, 1], getattr(y, op)(v[:, 0])) + torch.testing.assert_close(t2[:, 2], getattr(z, op)(v[:, 0])) + # check broadcasting + assert t2[:, 0].shape == torch.Size((2, *x.shape)) + + def test_permute(self): + w = torch.randint(10, (3, 5, 5)) + x = torch.randint(10, (3, 4, 5)) + y = torch.randint(10, (3, 5, 5)) + z = torch.randint(10, (3, 4, 5)) + ts = TensorStack.from_tensors([[w, x], [y, z]]) + tst = ts.permute(1, 0, 2, 3, 4) + assert (tst[0, 1] == ts[1, 0]).all() + assert (tst[1, 0] == ts[0, 1]).all() + assert (tst[1, 1] == ts[1, 1]).all() + assert (tst[0, 0] == ts[0, 0]).all() + + def test_transpose(self): + w = torch.randint(10, (3, 5, 5)) + x = torch.randint(10, (3, 4, 5)) + y = torch.randint(10, (3, 5, 5)) + z = torch.randint(10, (3, 4, 5)) + ts = TensorStack.from_tensors([[w, x], [y, z]]) + tst = ts.transpose(1, 0) + assert (tst[0, 1] == ts[1, 0]).all() + assert (tst[1, 0] == ts[0, 1]).all() + assert (tst[1, 1] == ts[1, 1]).all() + assert (tst[0, 0] == ts[0, 0]).all() + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
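For reference, here is a minimal pure-Python sketch of what the new _populate_index binding computes. It assumes offsets holds the flattened per-entry slice lengths and offsets_cs the matching start offsets into the compact storage (i.e. the flattened _NestedShape._offsets and _NestedShape._offsets_cs[0]), both 1-D torch.long tensors; the name populate_index_reference is illustrative only and not part of the diff:

import torch

def populate_index_reference(offsets: torch.Tensor, offsets_cs: torch.Tensor) -> torch.Tensor:
    # For every entry j, emit the contiguous run
    # offsets_cs[j], offsets_cs[j] + 1, ..., offsets_cs[j] + offsets[j] - 1,
    # concatenated into one flat index tensor of length offsets.sum().
    chunks = [
        torch.arange(start, start + length, dtype=torch.long)
        for length, start in zip(offsets.tolist(), offsets_cs.tolist())
    ]
    if not chunks:
        return torch.empty(0, dtype=torch.long)
    return torch.cat(chunks)

# Example: stacking tensors of shapes (1, 5), (2, 5) and (3, 5) gives a compact
# storage of 1 + 2 + 3 = 6 rows of 5 elements each.
offsets = torch.tensor([1, 2, 3])     # rows contributed by each stacked entry
offsets_cs = torch.tensor([0, 1, 3])  # start row of each entry (padded cumsum)
print(populate_index_reference(offsets, offsets_cs))  # tensor([0, 1, 2, 3, 4, 5])

TensorStack.__getitem__ follows the same path: it calls _populate_index on the selected offsets, gathers those rows from the _compact view, and rewraps the result with the sliced _NestedShape.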