From 20f4fbc2f5fd11a85988eb3bf85efaa4ba9c4550 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 3 Feb 2023 10:31:35 +0000 Subject: [PATCH 1/6] init --- torchrl/data/tensor_specs.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index fe8e14ddd8f..9c09049b499 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1959,6 +1959,11 @@ def expand(self, *shape): return out +class LazyStackedCompositeSpec(CompositeSpec): + def __init__(self, *composite_specs, dim): + self.composite_spec = composite_specs + self.dim = dim + def _keys_to_empty_composite_spec(keys): if not len(keys): return From 3d3725583c3ea0f14844c52272ec8ec970450a91 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 3 Feb 2023 12:00:57 +0000 Subject: [PATCH 2/6] amend --- torchrl/data/tensor_specs.py | 167 ++++++++++++++++++++++++++++++++++- 1 file changed, 165 insertions(+), 2 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 9c09049b499..de4c76bb444 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1960,9 +1960,172 @@ def expand(self, *shape): class LazyStackedCompositeSpec(CompositeSpec): - def __init__(self, *composite_specs, dim): - self.composite_spec = composite_specs + """A lazy representation of a stack of composite specs. + + Stacks composite specs together along one dimension. + When random samples are drawn, a LazyStackedTensorDict is returned. + + Indexing is allowed but only along the stack dimension. + + This class is aimed to be used in multi-task and multi-agent settings, where + heterogeneous specs may occur (same semantic but different shape). + + """ + + def __init__(self, *composite_specs: CompositeSpec, dim): + self.composite_specs = composite_specs self.dim = dim + if self.dim < 0: + self.dim = len(self.shape) + self.dim + + def __getitem__(self, item): + is_key = isinstance(item, str) or ( + isinstance(item, tuple) and all(isinstance(_item, str) for _item in item) + ) + if is_key: + return torch.stack( + [composite_spec[item] for composite_spec in self.composite_specs] + ) + elif isinstance(item, tuple): + # quick check that the index is along the stacked dim + # case 1: index is a tuple, and the first arg is an ellipsis. Then dim must be the last dim of all composite_specs + if item[0] is Ellipsis: + if len(item) == 0: + return self + elif self.dim == len(self.shape) - 1: + # we can return + return self.composite_specs[item[1]] + else: + raise IndexError( + "Indexing a LazyStackedCompositeSpec with [..., idx] is only permitted if the stack dimension is the last dimension. " + f"Got self.dim={self.dim} and self.shape={self.shape}." + ) + elif len(item) == 2 and item[1] is Ellipsis: + return self[item[0]] + elif any(_item is Ellipsis for _item in item): + raise IndexError("Cannot index along multiple dimensions.") + # Ellipsis is now ruled out + elif any(_item is None for _item in item): + raise IndexError( + "Cannot index a LazyStackedCompositeSpec with None values" + ) + # Must be an index with slices then + else: + for i, _item in enumerate(item): + if i == self.dim: + return torch.stack(list(self.composite_specs)[_item], self.dim) + elif isinstance(_item, slice): + # then the slice must be trivial + if not (_item.step is _item.start is _item.stop is None): + raise IndexError( + f"Got a non-trivial index at dim {i} when only the dim {self.dim} could be indexed." 
+ ) + else: + return self + else: + if not self.dim == 0: + raise IndexError( + f"Trying to index a LazyStackedCompositeSpec along dimension 0 when the stack dimension is {self.dim}." + ) + return torch.stack(list(self.composite_specs)[item], 0) + + @property + def shape(self): + shape = list(self.composite_specs[0].shape) + dim = self.dim + if dim < 0: + dim = len(shape) + dim + 1 + shape.insert(dim, len(self.composite_specs)) + return torch.Size(shape) + + def clone(self) -> CompositeSpec: + pass + + def expand(self, *shape): + pass + + def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> None: + pass + + def __eq__(self, other): + pass + + def zero(self, shape=None) -> TensorDictBase: + pass + + def rand(self, shape=None) -> TensorDictBase: + pass + + def to_numpy(self, val: TensorDict, safe: bool = True) -> dict: + pass + + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + pass + + def __len__(self): + pass + + def values(self) -> ValuesView: + pass + + def items(self) -> ItemsView: + pass + + def keys( + self, yield_nesting_keys: bool = False, nested_keys: bool = True + ) -> KeysView: + pass + + def project(self, val: TensorDictBase) -> TensorDictBase: + pass + + def is_in(self, val: Union[dict, TensorDictBase]) -> bool: + pass + + def type_check( + self, + value: Union[torch.Tensor, TensorDictBase], + selected_keys: Union[str, Optional[Sequence[str]]] = None, + ): + pass + + def __repr__(self): + pass + + def encode(self, vals: Dict[str, Any]) -> Dict[str, torch.Tensor]: + pass + + def __delitem__(self, key): + pass + + def __iter__(self): + pass + + def __setitem__(self, key, value): + pass + + @property + def device(self) -> DEVICE_TYPING: + pass + + @property + def ndim(self): + return self.ndimension() + + def ndimension(self): + return len(self.shape) + + def set(self, name, spec): + if spec is not None: + shape = spec.shape + if shape[: self.ndim] != self.shape: + raise ValueError( + "The shape of the spec and the CompositeSpec mismatch: the first " + f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and " + f"CompositeSpec.shape={self.shape}." 
+ ) + self._specs[name] = spec + def _keys_to_empty_composite_spec(keys): if not len(keys): From f503249b84f2e64f071fa866ca39f3b7ad9ce1bf Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 6 Feb 2023 21:07:55 +0000 Subject: [PATCH 3/6] amend --- test/test_specs.py | 601 ++++++++++++++++++++++++++++++++++- torchrl/data/__init__.py | 2 + torchrl/data/tensor_specs.py | 566 ++++++++++++++++++++++++++++----- 3 files changed, 1076 insertions(+), 93 deletions(-) diff --git a/test/test_specs.py b/test/test_specs.py index 6e77096fbea..3cf210006b5 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -10,13 +10,15 @@ import torchrl.data.tensor_specs from _utils_internal import get_available_devices, set_global_var from scipy.stats import chisquare -from tensordict.tensordict import TensorDict, TensorDictBase +from tensordict.tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase from torchrl.data.tensor_specs import ( _keys_to_empty_composite_spec, BinaryDiscreteTensorSpec, BoundedTensorSpec, CompositeSpec, DiscreteTensorSpec, + LazyStackedCompositeSpec, + LazyStackedTensorSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, @@ -994,11 +996,6 @@ def test_equality_composite(self): assert ts != ts_other -if __name__ == "__main__": - args, unknown = argparse.ArgumentParser().parse_known_args() - pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) - - class TestSpec: @pytest.mark.parametrize( "action_spec_cls", [OneHotDiscreteTensorSpec, DiscreteTensorSpec] @@ -1672,3 +1669,595 @@ def test_unboundeddiscrete( spec = UnboundedDiscreteTensorSpec(shape=shape1, device="cpu", dtype=torch.long) assert spec == spec.clone() assert spec is not spec.clone() + + +@pytest.mark.parametrize( + "shape,stack_dim", + [[(), 0], [(2,), 0], [(2,), 1], [(2, 3), 0], [(2, 3), 1], [(2, 3), 2]], +) +class TestStack: + def test_stack_binarydiscrete(self, shape, stack_dim): + n = 5 + shape = (*shape, n) + c1 = BinaryDiscreteTensorSpec(n=n, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + assert isinstance(c, LazyStackedTensorSpec) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + assert c.shape == torch.Size(shape) + + def test_stack_binarydiscrete_expand(self, shape, stack_dim): + n = 5 + shape = (*shape, n) + c1 = BinaryDiscreteTensorSpec(n=n, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + cexpand = c.expand(3, 2, *shape) + assert cexpand.shape == torch.Size([3, 2, *shape]) + + def test_stack_binarydiscrete_rand(self, shape, stack_dim): + n = 5 + shape = (*shape, n) + c1 = BinaryDiscreteTensorSpec(n=n, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.rand() + assert r.shape == c.shape + + def test_stack_binarydiscrete_zero(self, shape, stack_dim): + n = 5 + shape = (*shape, n) + c1 = BinaryDiscreteTensorSpec(n=n, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.zero() + assert r.shape == c.shape + + def test_stack_bounded(self, shape, stack_dim): + mini = -1 + maxi = 1 + shape = (*shape,) + c1 = BoundedTensorSpec(mini, maxi, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + assert isinstance(c, LazyStackedTensorSpec) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + assert c.shape == torch.Size(shape) + + def 
test_stack_bounded_expand(self, shape, stack_dim): + mini = -1 + maxi = 1 + shape = (*shape,) + c1 = BoundedTensorSpec(mini, maxi, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + cexpand = c.expand(3, 2, *shape) + assert cexpand.shape == torch.Size([3, 2, *shape]) + + def test_stack_bounded_rand(self, shape, stack_dim): + mini = -1 + maxi = 1 + shape = (*shape,) + c1 = BoundedTensorSpec(mini, maxi, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.rand() + assert r.shape == c.shape + + def test_stack_bounded_zero(self, shape, stack_dim): + mini = -1 + maxi = 1 + shape = (*shape,) + c1 = BoundedTensorSpec(mini, maxi, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.zero() + assert r.shape == c.shape + + def test_stack_discrete(self, shape, stack_dim): + n = 4 + shape = (*shape,) + c1 = DiscreteTensorSpec(n, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + assert isinstance(c, LazyStackedTensorSpec) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + assert c.shape == torch.Size(shape) + + def test_stack_discrete_expand(self, shape, stack_dim): + n = 4 + shape = (*shape,) + c1 = DiscreteTensorSpec(n, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + cexpand = c.expand(3, 2, *shape) + assert cexpand.shape == torch.Size([3, 2, *shape]) + + def test_stack_discrete_rand(self, shape, stack_dim): + n = 4 + shape = (*shape,) + c1 = DiscreteTensorSpec(n, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.rand() + assert r.shape == c.shape + + def test_stack_discrete_zero(self, shape, stack_dim): + n = 4 + shape = (*shape,) + c1 = DiscreteTensorSpec(n, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.zero() + assert r.shape == c.shape + + def test_stack_multidiscrete(self, shape, stack_dim): + nvec = [4, 5] + shape = (*shape, 2) + c1 = MultiDiscreteTensorSpec(nvec, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + assert isinstance(c, LazyStackedTensorSpec) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + assert c.shape == torch.Size(shape) + + def test_stack_multidiscrete_expand(self, shape, stack_dim): + nvec = [4, 5] + shape = (*shape, 2) + c1 = MultiDiscreteTensorSpec(nvec, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + cexpand = c.expand(3, 2, *shape) + assert cexpand.shape == torch.Size([3, 2, *shape]) + + def test_stack_multidiscrete_rand(self, shape, stack_dim): + nvec = [4, 5] + shape = (*shape, 2) + c1 = MultiDiscreteTensorSpec(nvec, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.rand() + assert r.shape == c.shape + + def test_stack_multidiscrete_zero(self, shape, stack_dim): + nvec = [4, 5] + shape = (*shape, 2) + c1 = MultiDiscreteTensorSpec(nvec, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.zero() + assert r.shape == c.shape + + def test_stack_multionehot(self, shape, stack_dim): + nvec = [4, 5] + shape = (*shape, 9) + c1 = MultiOneHotDiscreteTensorSpec(nvec, shape=shape) + c2 = c1.clone() + c 
= torch.stack([c1, c2], stack_dim) + assert isinstance(c, LazyStackedTensorSpec) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + assert c.shape == torch.Size(shape) + + def test_stack_multionehot_expand(self, shape, stack_dim): + nvec = [4, 5] + shape = (*shape, 9) + c1 = MultiOneHotDiscreteTensorSpec(nvec, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + cexpand = c.expand(3, 2, *shape) + assert cexpand.shape == torch.Size([3, 2, *shape]) + + def test_stack_multionehot_rand(self, shape, stack_dim): + nvec = [4, 5] + shape = (*shape, 9) + c1 = MultiOneHotDiscreteTensorSpec(nvec, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.rand() + assert r.shape == c.shape + + def test_stack_multionehot_zero(self, shape, stack_dim): + nvec = [4, 5] + shape = (*shape, 9) + c1 = MultiOneHotDiscreteTensorSpec(nvec, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.zero() + assert r.shape == c.shape + + def test_stack_onehot(self, shape, stack_dim): + n = 5 + shape = (*shape, 5) + c1 = OneHotDiscreteTensorSpec(n, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + assert isinstance(c, LazyStackedTensorSpec) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + assert c.shape == torch.Size(shape) + + def test_stack_onehot_expand(self, shape, stack_dim): + n = 5 + shape = (*shape, 5) + c1 = OneHotDiscreteTensorSpec(n, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + cexpand = c.expand(3, 2, *shape) + assert cexpand.shape == torch.Size([3, 2, *shape]) + + def test_stack_onehot_rand(self, shape, stack_dim): + n = 5 + shape = (*shape, 5) + c1 = OneHotDiscreteTensorSpec(n, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.rand() + assert r.shape == c.shape + + def test_stack_onehot_zero(self, shape, stack_dim): + n = 5 + shape = (*shape, 5) + c1 = OneHotDiscreteTensorSpec(n, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.zero() + assert r.shape == c.shape + + def test_stack_unboundedcont(self, shape, stack_dim): + shape = (*shape,) + c1 = UnboundedContinuousTensorSpec(shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + assert isinstance(c, LazyStackedTensorSpec) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + assert c.shape == torch.Size(shape) + + def test_stack_unboundedcont_expand(self, shape, stack_dim): + shape = (*shape,) + c1 = UnboundedContinuousTensorSpec(shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + cexpand = c.expand(3, 2, *shape) + assert cexpand.shape == torch.Size([3, 2, *shape]) + + def test_stack_unboundedcont_rand(self, shape, stack_dim): + shape = (*shape,) + c1 = UnboundedContinuousTensorSpec(shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.rand() + assert r.shape == c.shape + + def test_stack_unboundedcont_zero(self, shape, stack_dim): + shape = (*shape,) + c1 = UnboundedDiscreteTensorSpec(shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = 
c.zero() + assert r.shape == c.shape + + def test_stack_unboundeddiscrete(self, shape, stack_dim): + shape = (*shape,) + c1 = UnboundedDiscreteTensorSpec(shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + assert isinstance(c, LazyStackedTensorSpec) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + assert c.shape == torch.Size(shape) + + def test_stack_unboundeddiscrete_expand(self, shape, stack_dim): + shape = (*shape,) + c1 = UnboundedDiscreteTensorSpec(shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + cexpand = c.expand(3, 2, *shape) + assert cexpand.shape == torch.Size([3, 2, *shape]) + + def test_stack_unboundeddiscrete_rand(self, shape, stack_dim): + shape = (*shape,) + c1 = UnboundedDiscreteTensorSpec(shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.rand() + assert r.shape == c.shape + + def test_stack_unboundeddiscrete_zero(self, shape, stack_dim): + shape = (*shape,) + c1 = UnboundedDiscreteTensorSpec(shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.zero() + assert r.shape == c.shape + + +class TestStackComposite: + def test_stack(self): + c1 = CompositeSpec(a=UnboundedContinuousTensorSpec()) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + assert isinstance(c, LazyStackedCompositeSpec) + + def test_stack_index(self): + c1 = CompositeSpec(a=UnboundedContinuousTensorSpec()) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + assert c.shape == torch.Size([2]) + assert c[0] is c1 + assert c[1] is c2 + assert c[..., 0] is c1 + assert c[..., 1] is c2 + assert c[0, ...] is c1 + assert c[1, ...] is c2 + assert isinstance(c[:], LazyStackedCompositeSpec) + + @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) + def test_stack_index_multdim(self, stack_dim): + c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + if stack_dim in (0, -3): + assert isinstance(c[:], LazyStackedCompositeSpec) + assert c.shape == torch.Size([2, 1, 3]) + assert c[0] is c1 + assert c[1] is c2 + with pytest.raises( + IndexError, + match="only permitted if the stack dimension is the last dimension", + ): + assert c[..., 0] is c1 + with pytest.raises( + IndexError, + match="only permitted if the stack dimension is the last dimension", + ): + assert c[..., 1] is c2 + assert c[0, ...] is c1 + assert c[1, ...] is c2 + elif stack_dim == (1, -2): + assert isinstance(c[:, :], LazyStackedCompositeSpec) + assert c.shape == torch.Size([1, 2, 3]) + assert c[:, 0] is c1 + assert c[:, 1] is c2 + with pytest.raises( + IndexError, match="along dimension 0 when the stack dimension is 1." + ): + assert c[0] is c1 + with pytest.raises( + IndexError, match="along dimension 0 when the stack dimension is 1." + ): + assert c[1] is c1 + with pytest.raises( + IndexError, + match="only permitted if the stack dimension is the last dimension", + ): + assert c[..., 0] is c1 + with pytest.raises( + IndexError, + match="only permitted if the stack dimension is the last dimension", + ): + assert c[..., 1] is c2 + assert c[..., 0, :] is c1 + assert c[..., 1, :] is c2 + assert c[:, 0, ...] is c1 + assert c[:, 1, ...] is c2 + elif stack_dim == (2, -1): + assert isinstance(c[:, :, :], LazyStackedCompositeSpec) + with pytest.raises( + IndexError, match="along dimension 0 when the stack dimension is 2." 
+ ): + assert c[0] is c1 + with pytest.raises( + IndexError, match="along dimension 0 when the stack dimension is 2." + ): + assert c[1] is c1 + assert c.shape == torch.Size([1, 3, 2]) + assert c[:, :, 0] is c1 + assert c[:, :, 1] is c2 + assert c[..., 0] is c1 + assert c[..., 1] is c2 + assert c[:, :, 0, ...] is c1 + assert c[:, :, 1, ...] is c2 + + @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) + def test_stack_expand_one(self, stack_dim): + c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) + c = torch.stack([c1], stack_dim) + if stack_dim in (0, -3): + c_expand = c.expand([4, 2, 1, 3]) + assert c_expand.shape == torch.Size([4, 2, 1, 3]) + assert c_expand.dim == 1 + elif stack_dim in (1, -2): + c_expand = c.expand([4, 1, 2, 3]) + assert c_expand.shape == torch.Size([4, 1, 2, 3]) + assert c_expand.dim == 2 + elif stack_dim in (2, -1): + c_expand = c.expand( + [ + 4, + 1, + 3, + 2, + ] + ) + assert c_expand.shape == torch.Size([4, 1, 3, 2]) + assert c_expand.dim == 3 + else: + raise NotImplementedError + + @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) + def test_stack_expand_multi(self, stack_dim): + c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + if stack_dim in (0, -3): + c_expand = c.expand([4, 2, 1, 3]) + assert c_expand.shape == torch.Size([4, 2, 1, 3]) + assert c_expand.dim == 1 + elif stack_dim in (1, -2): + c_expand = c.expand([4, 1, 2, 3]) + assert c_expand.shape == torch.Size([4, 1, 2, 3]) + assert c_expand.dim == 2 + elif stack_dim in (2, -1): + c_expand = c.expand( + [ + 4, + 1, + 3, + 2, + ] + ) + assert c_expand.shape == torch.Size([4, 1, 3, 2]) + assert c_expand.dim == 3 + else: + raise NotImplementedError + + @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) + def test_stack_rand(self, stack_dim): + c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + r = c.rand() + assert isinstance(r, LazyStackedTensorDict) + if stack_dim in (0, -3): + assert r.shape == torch.Size([2, 1, 3]) + assert r["a"].shape == torch.Size([2, 1, 3]) # access tensor + elif stack_dim in (1, -2): + assert r.shape == torch.Size([1, 2, 3]) + assert r["a"].shape == torch.Size([1, 2, 3]) # access tensor + elif stack_dim in (2, -1): + assert r.shape == torch.Size([1, 3, 2]) + assert r["a"].shape == torch.Size([1, 3, 2]) # access tensor + assert (r["a"] != 0).all() + + @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) + def test_stack_rand_shape(self, stack_dim): + c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + shape = [5, 6] + r = c.rand(shape) + assert isinstance(r, LazyStackedTensorDict) + if stack_dim in (0, -3): + assert r.shape == torch.Size([*shape, 2, 1, 3]) + assert r["a"].shape == torch.Size([*shape, 2, 1, 3]) # access tensor + elif stack_dim in (1, -2): + assert r.shape == torch.Size([*shape, 1, 2, 3]) + assert r["a"].shape == torch.Size([*shape, 1, 2, 3]) # access tensor + elif stack_dim in (2, -1): + assert r.shape == torch.Size([*shape, 1, 3, 2]) + assert r["a"].shape == torch.Size([*shape, 1, 3, 2]) # access tensor + assert (r["a"] != 0).all() + + @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) + def test_stack_zero(self, stack_dim): + c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) + c2 = c1.clone() + c = 
torch.stack([c1, c2], stack_dim) + r = c.zero() + assert isinstance(r, LazyStackedTensorDict) + if stack_dim in (0, -3): + assert r.shape == torch.Size([2, 1, 3]) + assert r["a"].shape == torch.Size([2, 1, 3]) # access tensor + elif stack_dim in (1, -2): + assert r.shape == torch.Size([1, 2, 3]) + assert r["a"].shape == torch.Size([1, 2, 3]) # access tensor + elif stack_dim in (2, -1): + assert r.shape == torch.Size([1, 3, 2]) + assert r["a"].shape == torch.Size([1, 3, 2]) # access tensor + assert (r["a"] == 0).all() + + @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) + def test_stack_zero_shape(self, stack_dim): + c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + shape = [5, 6] + r = c.zero(shape) + assert isinstance(r, LazyStackedTensorDict) + if stack_dim in (0, -3): + assert r.shape == torch.Size([*shape, 2, 1, 3]) + assert r["a"].shape == torch.Size([*shape, 2, 1, 3]) # access tensor + elif stack_dim in (1, -2): + assert r.shape == torch.Size([*shape, 1, 2, 3]) + assert r["a"].shape == torch.Size([*shape, 1, 2, 3]) # access tensor + elif stack_dim in (2, -1): + assert r.shape == torch.Size([*shape, 1, 3, 2]) + assert r["a"].shape == torch.Size([*shape, 1, 3, 2]) # access tensor + assert (r["a"] == 0).all() + + @pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda") + def test_to(self): + c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + cdevice = c.to("cuda:0") + assert cdevice.device != c.device + assert cdevice.device == torch.device("cuda:0") + assert cdevice[0].device == torch.device("cuda:0") + + def test_clone(self): + c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + cclone = c.clone() + assert cclone[0] is not c[0] + assert cclone[0] == c[0] + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 4a0ac554218..788a2cce27d 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -21,6 +21,8 @@ CompositeSpec, DEVICE_TYPING, DiscreteTensorSpec, + LazyStackedCompositeSpec, + LazyStackedTensorSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 51c1b5d6ed3..3bbe6128813 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -9,6 +9,7 @@ import warnings from copy import deepcopy from dataclasses import dataclass +from functools import wraps from textwrap import indent from typing import ( Any, @@ -233,6 +234,19 @@ class TensorSpec: dtype: torch.dtype = torch.float domain: str = "" + SPEC_HANDLED_FUNCTIONS = {} + + @classmethod + def implements_for_spec(cls, torch_function: Callable) -> Callable: + """Register a torch function override for TensorSpec.""" + + @wraps(torch_function) + def decorator(func): + cls.SPEC_HANDLED_FUNCTIONS[torch_function] = func + return func + + return decorator + def encode(self, val: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: """Encodes a value given the specified spec, and return the corresponding tensor. 
@@ -436,6 +450,259 @@ def __repr__(self): string = f"{self.__class__.__name__}(\n {sub_string})" return string + @classmethod + def __torch_function__( + cls, + func: Callable, + types, + args: Tuple = (), + kwargs: Optional[dict] = None, + ) -> Callable: + if kwargs is None: + kwargs = {} + if func not in cls.SPEC_HANDLED_FUNCTIONS or not all( + issubclass(t, (TensorSpec,)) for t in types + ): + return NotImplemented( + f"func {func} for spec {cls} with handles {cls.SPEC_HANDLED_FUNCTIONS}" + ) + return cls.SPEC_HANDLED_FUNCTIONS[func](*args, **kwargs) + + +class LazyStackedTensorSpec(TensorSpec): + """A lazy representation of a stack of tensor specs. + + Stacks tensor-specs together along one dimension. + When random samples are drawn, a stack of samples is returned if possible. + If not, an error is thrown. + + Indexing is allowed but only along the stack dimension. + + This class is aimed to be used in multi-task and multi-agent settings, where + heterogeneous specs may occur (same semantic but different shape). + + """ + + def __init__(self, *specs: TensorSpec, dim): + self._specs = specs + self.dim = dim + if self.dim < 0: + self.dim = len(self.shape) + self.dim + + def __getitem__(self, item): + is_key = isinstance(item, str) or ( + isinstance(item, tuple) and all(isinstance(_item, str) for _item in item) + ) + if is_key: + return torch.stack([spec[item] for spec in self._specs]) + elif isinstance(item, tuple): + # quick check that the index is along the stacked dim + # case 1: index is a tuple, and the first arg is an ellipsis. Then dim must be the last dim of all composite_specs + if item[0] is Ellipsis: + if len(item) == 1: + return self + elif self.dim == len(self.shape) - 1 and len(item) == 2: + # we can return + return self._specs[item[1]] + elif len(item) > 2: + # check that there is only one non-slice index + assigned = False + dim_idx = self.dim + for i, _item in enumerate(item[1:]): + if ( + isinstance(_item, slice) + and not ( + _item.start is None + and _item.stop is None + and _item.step is None + ) + ) or not isinstance(_item, slice): + if assigned: + raise RuntimeError( + "Found more than one meaningful index in a stacked composite spec." + ) + item = _item + dim_idx = i + 1 + assigned = True + if not assigned: + return self + if dim_idx != self.dim: + raise RuntimeError( + f"Indexing occured along dimension {dim_idx} but stacking was done along dim {self.dim}." + ) + out = self._specs[item] + if isinstance(out, TensorSpec): + return out + return torch.stack(list(out), 0) + else: + raise IndexError( + f"Indexing a {self.__class__.__name__} with [..., idx] is only permitted if the stack dimension is the last dimension. " + f"Got self.dim={self.dim} and self.shape={self.shape}." + ) + elif len(item) >= 2 and item[-1] is Ellipsis: + return self[item[:-1]] + elif any(_item is Ellipsis for _item in item): + raise IndexError("Cannot index along multiple dimensions.") + # Ellipsis is now ruled out + elif any(_item is None for _item in item): + raise IndexError( + f"Cannot index a {self.__class__.__name__} with None values" + ) + # Must be an index with slices then + else: + for i, _item in enumerate(item): + if i == self.dim: + out = self._specs[_item] + if isinstance(out, TensorSpec): + return out + return torch.stack(list(out), 0) + elif isinstance(_item, slice): + # then the slice must be trivial + if not (_item.step is _item.start is _item.stop is None): + raise IndexError( + f"Got a non-trivial index at dim {i} when only the dim {self.dim} could be indexed." 
+ ) + else: + return self + else: + if not self.dim == 0: + raise IndexError( + f"Trying to index a {self.__class__.__name__} along dimension 0 when the stack dimension is {self.dim}." + ) + out = self._specs[item] + if isinstance(out, TensorSpec): + return out + return torch.stack(list(out), 0) + + @property + def shape(self): + shape = list(self._specs[0].shape) + dim = self.dim + if dim < 0: + dim = len(shape) + dim + 1 + shape.insert(dim, len(self._specs)) + return torch.Size(shape) + + def clone(self) -> CompositeSpec: + return torch.stack([spec.clone() for spec in self._specs], 0) + + def expand(self, *shape): + if len(shape) == 1 and not isinstance(shape[0], (int,)): + return self.expand(*shape[0]) + expand_shape = shape[: -len(self.shape)] + existing_shape = self.shape + shape_check = shape[-len(self.shape) :] + for _i, (size1, size2) in enumerate(zip(existing_shape, shape_check)): + if size1 != size2 and size1 != 1: + raise RuntimeError( + f"Expanding a non-singletom dimension: existing shape={size1} vs expand={size2}" + ) + elif size1 != size2 and size1 == 1 and _i == self.dim: + # if we're expanding along the stack dim we just need to clone the existing spec + return torch.stack( + [self._specs[0].clone() for _ in range(size2)], self.dim + ).expand(*shape) + if _i != len(self.shape) - 1: + raise RuntimeError( + f"Trying to expand non-congruent shapes: received {shape} when the shape is {self.shape}." + ) + # remove the stack dim from the expanded shape, which we know to match + unstack_shape = list(expand_shape) + [ + s for i, s in enumerate(shape_check) if i != self.dim + ] + return torch.stack( + [spec.expand(unstack_shape) for spec in self._specs], + self.dim + len(expand_shape), + ) + + def zero(self, shape=None) -> TensorDictBase: + if shape is not None: + dim = self.dim + len(shape) + else: + dim = self.dim + return torch.stack([spec.zero(shape) for spec in self._specs], dim) + + def rand(self, shape=None) -> TensorDictBase: + if shape is not None: + dim = self.dim + len(shape) + else: + dim = self.dim + return torch.stack([spec.rand(shape) for spec in self._specs], dim) + + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + return torch.stack([spec.to(dest) for spec in self._specs], self.dim) + + def __eq__(self, other): + # requires unbind to be implemented + pass + + def to_numpy(self, val: TensorDict, safe: bool = True) -> dict: + pass + + def __len__(self): + pass + + def values(self) -> ValuesView: + pass + + def items(self) -> ItemsView: + pass + + def keys( + self, yield_nesting_keys: bool = False, nested_keys: bool = True + ) -> KeysView: + pass + + def project(self, val: TensorDictBase) -> TensorDictBase: + pass + + def is_in(self, val: Union[dict, TensorDictBase]) -> bool: + pass + + def type_check( + self, + value: Union[torch.Tensor, TensorDictBase], + selected_keys: Union[str, Optional[Sequence[str]]] = None, + ): + pass + + def __repr__(self): + pass + + def encode(self, vals: Dict[str, Any]) -> Dict[str, torch.Tensor]: + pass + + def __delitem__(self, key): + pass + + def __iter__(self): + pass + + def __setitem__(self, key, value): + pass + + @property + def device(self) -> DEVICE_TYPING: + pass + + @property + def ndim(self): + return self.ndimension() + + def ndimension(self): + return len(self.shape) + + def set(self, name, spec): + if spec is not None: + shape = spec.shape + if shape[: self.ndim] != self.shape: + raise ValueError( + "The shape of the spec and the CompositeSpec mismatch: the first " + f"{self.ndim} dimensions 
should match but got spec.shape={spec.shape} and " + f"CompositeSpec.shape={self.shape}." + ) + self._specs[name] = spec + @dataclass(repr=False) class OneHotDiscreteTensorSpec(TensorSpec): @@ -476,6 +743,8 @@ class OneHotDiscreteTensorSpec(TensorSpec): dtype: torch.dtype = torch.float domain: str = "" + # SPEC_HANDLED_FUNCTIONS = {} + def __init__( self, n: int, @@ -636,6 +905,8 @@ class BoundedTensorSpec(TensorSpec): """ + # SPEC_HANDLED_FUNCTIONS = {} + def __init__( self, minimum: Union[float, torch.Tensor, np.ndarray], @@ -679,17 +950,17 @@ def __init__( if shape is not None and shape != maximum.shape: raise RuntimeError(err_msg) shape = maximum.shape - minimum = minimum.expand(*shape).clone() + minimum = minimum.expand(shape).clone() elif minimum.ndimension(): if shape is not None and shape != minimum.shape: raise RuntimeError(err_msg) shape = minimum.shape - maximum = maximum.expand(*shape).clone() + maximum = maximum.expand(shape).clone() elif shape is None: raise RuntimeError(err_msg) else: - minimum = minimum.expand(*shape).clone() - maximum = maximum.expand(*shape).clone() + minimum = minimum.expand(shape).clone() + maximum = maximum.expand(shape).clone() if minimum.numel() > maximum.numel(): maximum = maximum.expand_as(minimum).clone() @@ -830,6 +1101,8 @@ class UnboundedContinuousTensorSpec(TensorSpec): (should be an floating point dtype such as float, double etc.) """ + # SPEC_HANDLED_FUNCTIONS = {} + def __init__( self, shape: Union[torch.Size, int] = _DEFAULT_SHAPE, @@ -899,6 +1172,8 @@ class UnboundedDiscreteTensorSpec(TensorSpec): (should be an integer dtype such as long, uint8 etc.) """ + # SPEC_HANDLED_FUNCTIONS = {} + def __init__( self, shape: Union[torch.Size, int] = _DEFAULT_SHAPE, @@ -994,6 +1269,8 @@ class BinaryDiscreteTensorSpec(TensorSpec): dtype: torch.dtype = torch.float domain: str = "" + # SPEC_HANDLED_FUNCTIONS = {} + def __init__( self, n: int, @@ -1003,7 +1280,7 @@ def __init__( ): dtype, device = _default_dtype_and_device(dtype, device) box = BinaryBox(n) - if shape is None: + if shape is None or not len(shape): shape = torch.Size((n,)) else: shape = torch.Size(shape) @@ -1093,6 +1370,8 @@ class MultiOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec): """ + # SPEC_HANDLED_FUNCTIONS = {} + def __init__( self, nvec: Sequence[int], @@ -1276,6 +1555,8 @@ class DiscreteTensorSpec(TensorSpec): dtype: torch.dtype = torch.float domain: str = "" + # SPEC_HANDLED_FUNCTIONS = {} + def __init__( self, n: int, @@ -1389,6 +1670,8 @@ class MultiDiscreteTensorSpec(DiscreteTensorSpec): False """ + # SPEC_HANDLED_FUNCTIONS = {} + def __init__( self, nvec: Union[Sequence[int], torch.Tensor, int], @@ -1591,6 +1874,8 @@ class CompositeSpec(TensorSpec): shape: torch.Size domain: str = "composite" + SPEC_HANDLED_FUNCTIONS = {} + @classmethod def __new__(cls, *args, **kwargs): cls._device = torch.device("cpu") @@ -1826,7 +2111,7 @@ def rand(self, shape=None) -> TensorDictBase: } return TensorDict( _dict, - batch_size=shape, + batch_size=[*shape, *self.shape], device=self._device, ) @@ -1950,7 +2235,7 @@ def expand(self, *shape): return out -class LazyStackedCompositeSpec(CompositeSpec): +class LazyStackedCompositeSpec(LazyStackedTensorSpec): """A lazy representation of a stack of composite specs. Stacks composite specs together along one dimension. 
@@ -1963,77 +2248,154 @@ class LazyStackedCompositeSpec(CompositeSpec): """ - def __init__(self, *composite_specs: CompositeSpec, dim): - self.composite_specs = composite_specs - self.dim = dim - if self.dim < 0: - self.dim = len(self.shape) + self.dim - - def __getitem__(self, item): - is_key = isinstance(item, str) or ( - isinstance(item, tuple) and all(isinstance(_item, str) for _item in item) - ) - if is_key: - return torch.stack( - [composite_spec[item] for composite_spec in self.composite_specs] - ) - elif isinstance(item, tuple): - # quick check that the index is along the stacked dim - # case 1: index is a tuple, and the first arg is an ellipsis. Then dim must be the last dim of all composite_specs - if item[0] is Ellipsis: - if len(item) == 0: - return self - elif self.dim == len(self.shape) - 1: - # we can return - return self.composite_specs[item[1]] - else: - raise IndexError( - "Indexing a LazyStackedCompositeSpec with [..., idx] is only permitted if the stack dimension is the last dimension. " - f"Got self.dim={self.dim} and self.shape={self.shape}." - ) - elif len(item) == 2 and item[1] is Ellipsis: - return self[item[0]] - elif any(_item is Ellipsis for _item in item): - raise IndexError("Cannot index along multiple dimensions.") - # Ellipsis is now ruled out - elif any(_item is None for _item in item): - raise IndexError( - "Cannot index a LazyStackedCompositeSpec with None values" - ) - # Must be an index with slices then - else: - for i, _item in enumerate(item): - if i == self.dim: - return torch.stack(list(self.composite_specs)[_item], self.dim) - elif isinstance(_item, slice): - # then the slice must be trivial - if not (_item.step is _item.start is _item.stop is None): - raise IndexError( - f"Got a non-trivial index at dim {i} when only the dim {self.dim} could be indexed." - ) - else: - return self - else: - if not self.dim == 0: - raise IndexError( - f"Trying to index a LazyStackedCompositeSpec along dimension 0 when the stack dimension is {self.dim}." - ) - return torch.stack(list(self.composite_specs)[item], 0) - - @property - def shape(self): - shape = list(self.composite_specs[0].shape) - dim = self.dim - if dim < 0: - dim = len(shape) + dim + 1 - shape.insert(dim, len(self.composite_specs)) - return torch.Size(shape) - - def clone(self) -> CompositeSpec: - pass - - def expand(self, *shape): - pass + # def __init__(self, *composite_specs: CompositeSpec, dim): + # self._specs = composite_specs + # self.dim = dim + # if self.dim < 0: + # self.dim = len(self.shape) + self.dim + # + # def __getitem__(self, item): + # is_key = isinstance(item, str) or ( + # isinstance(item, tuple) and all(isinstance(_item, str) for _item in item) + # ) + # if is_key: + # return torch.stack([composite_spec[item] for composite_spec in self._specs]) + # elif isinstance(item, tuple): + # # quick check that the index is along the stacked dim + # # case 1: index is a tuple, and the first arg is an ellipsis. 
Then dim must be the last dim of all composite_specs + # if item[0] is Ellipsis: + # if len(item) == 1: + # return self + # elif self.dim == len(self.shape) - 1 and len(item) == 2: + # # we can return + # return self._specs[item[1]] + # elif len(item) > 2: + # # check that there is only one non-slice index + # assigned = False + # dim_idx = self.dim + # for i, _item in enumerate(item[1:]): + # if ( + # isinstance(_item, slice) + # and not ( + # _item.start is None + # and _item.stop is None + # and _item.step is None + # ) + # ) or not isinstance(_item, slice): + # if assigned: + # raise RuntimeError( + # "Found more than one meaningful index in a stacked composite spec." + # ) + # item = _item + # dim_idx = i + 1 + # assigned = True + # if not assigned: + # return self + # if dim_idx != self.dim: + # raise RuntimeError( + # f"Indexing occured along dimension {dim_idx} but stacking was done along dim {self.dim}." + # ) + # out = self._specs[item] + # if isinstance(out, TensorSpec): + # return out + # return torch.stack(list(out), 0) + # else: + # raise IndexError( + # f"Indexing a {self.__class__.__name__} with [..., idx] is only permitted if the stack dimension is the last dimension. " + # f"Got self.dim={self.dim} and self.shape={self.shape}." + # ) + # elif len(item) >= 2 and item[-1] is Ellipsis: + # return self[item[:-1]] + # elif any(_item is Ellipsis for _item in item): + # raise IndexError("Cannot index along multiple dimensions.") + # # Ellipsis is now ruled out + # elif any(_item is None for _item in item): + # raise IndexError( + # f"Cannot index a {self.__class__.__name__} with None values" + # ) + # # Must be an index with slices then + # else: + # for i, _item in enumerate(item): + # if i == self.dim: + # out = self._specs[_item] + # if isinstance(out, TensorSpec): + # return out + # return torch.stack(list(out), 0) + # elif isinstance(_item, slice): + # # then the slice must be trivial + # if not (_item.step is _item.start is _item.stop is None): + # raise IndexError( + # f"Got a non-trivial index at dim {i} when only the dim {self.dim} could be indexed." + # ) + # else: + # return self + # else: + # if not self.dim == 0: + # raise IndexError( + # f"Trying to index a {self.__class__.__name__} along dimension 0 when the stack dimension is {self.dim}." 
+ # ) + # out = self._specs[item] + # if isinstance(out, TensorSpec): + # return out + # return torch.stack(list(out), 0) + # + # @property + # def shape(self): + # shape = list(self._specs[0].shape) + # dim = self.dim + # if dim < 0: + # dim = len(shape) + dim + 1 + # shape.insert(dim, len(self._specs)) + # return torch.Size(shape) + # + # def clone(self) -> CompositeSpec: + # return torch.stack([spec.clone() for spec in self._specs], 0) + # + # def expand(self, *shape): + # if len(shape) == 1 and not isinstance(shape[0], (int,)): + # return self.expand(*shape[0]) + # expand_shape = shape[: -len(self.shape)] + # existing_shape = self.shape + # shape_check = shape[-len(self.shape) :] + # for _i, (size1, size2) in enumerate(zip(existing_shape, shape_check)): + # if size1 != size2 and size1 != 1: + # raise RuntimeError( + # f"Expanding a non-singletom dimension: existing shape={size1} vs expand={size2}" + # ) + # elif size1 != size2 and size1 == 1 and _i == self.dim: + # # if we're expanding along the stack dim we just need to clone the existing spec + # return torch.stack( + # [self._specs[0].clone() for _ in range(size2)], self.dim + # ).expand(*shape) + # if _i != len(self.shape) - 1: + # raise RuntimeError( + # f"Trying to expand non-congruent shapes: received {shape} when the shape is {self.shape}." + # ) + # # remove the stack dim from the expanded shape, which we know to match + # unstack_shape = list(expand_shape) + [ + # s for i, s in enumerate(shape_check) if i != self.dim + # ] + # return torch.stack( + # [spec.expand(unstack_shape) for spec in self._specs], + # self.dim + len(expand_shape), + # ) + # + # def zero(self, shape=None) -> TensorDictBase: + # if shape is not None: + # dim = self.dim + len(shape) + # else: + # dim = self.dim + # return torch.stack([spec.zero(shape) for spec in self._specs], dim) + # + # def rand(self, shape=None) -> TensorDictBase: + # if shape is not None: + # dim = self.dim + len(shape) + # else: + # dim = self.dim + # return torch.stack([spec.rand(shape) for spec in self._specs], dim) + # + # def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + # return torch.stack([spec.to(dest) for spec in self._specs], self.dim) def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> None: pass @@ -2041,18 +2403,9 @@ def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> N def __eq__(self, other): pass - def zero(self, shape=None) -> TensorDictBase: - pass - - def rand(self, shape=None) -> TensorDictBase: - pass - def to_numpy(self, val: TensorDict, safe: bool = True) -> dict: pass - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: - pass - def __len__(self): pass @@ -2118,6 +2471,45 @@ def set(self, name, spec): self._specs[name] = spec +# for SPEC_CLASS in [BinaryDiscreteTensorSpec, BoundedTensorSpec, DiscreteTensorSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec]: +@TensorSpec.implements_for_spec(torch.stack) +def _stack_specs(list_of_spec, dim, out=None): + if out is not None: + raise NotImplementedError( + "In-place spec modification is not a feature of torchrl, hence " + "torch.stack(list_of_specs, dim, out=spec) is not implemented." 
+ ) + if not len(list_of_spec): + raise ValueError("Cannot stack an empty list of specs.") + if isinstance(list_of_spec[0], TensorSpec): + if not all(isinstance(spec, TensorSpec) for spec in list_of_spec): + raise RuntimeError( + "Stacking specs cannot occur: Found more than one type of specs in the list." + ) + return LazyStackedTensorSpec(*list_of_spec, dim=dim) + else: + raise NotImplementedError + + +@CompositeSpec.implements_for_spec(torch.stack) +def _stack_composite_specs(list_of_spec, dim, out=None): + if out is not None: + raise NotImplementedError( + "In-place spec modification is not a feature of torchrl, hence " + "torch.stack(list_of_specs, dim, out=spec) is not implemented." + ) + if not len(list_of_spec): + raise ValueError("Cannot stack an empty list of specs.") + if isinstance(list_of_spec[0], CompositeSpec): + if not all(isinstance(spec, CompositeSpec) for spec in list_of_spec): + raise RuntimeError( + "Stacking specs cannot occur: Found more than one type of specs in the list." + ) + return LazyStackedCompositeSpec(*list_of_spec, dim=dim) + else: + raise NotImplementedError + + def _keys_to_empty_composite_spec(keys): if not len(keys): return From a912a2e7119fa23af3c5aa69ec91d55d2c5c6161 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 9 Mar 2023 13:57:18 +0000 Subject: [PATCH 4/6] [Upgrading] Stack specs (#959) Co-authored-by: Tom Begley --- torchrl/data/tensor_specs.py | 307 +++++++++++------------------------ 1 file changed, 94 insertions(+), 213 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 15eb96ab7a7..32cdfc3ab9f 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -14,12 +14,14 @@ from typing import ( Any, Dict, + Generic, ItemsView, KeysView, List, Optional, Sequence, Tuple, + TypeVar, Union, ValuesView, ) @@ -469,21 +471,11 @@ def __torch_function__( return cls.SPEC_HANDLED_FUNCTIONS[func](*args, **kwargs) -class LazyStackedTensorSpec(TensorSpec): - """A lazy representation of a stack of tensor specs. - - Stacks tensor-specs together along one dimension. - When random samples are drawn, a stack of samples is returned if possible. - If not, an error is thrown. +T = TypeVar("T") - Indexing is allowed but only along the stack dimension. - This class is aimed to be used in multi-task and multi-agent settings, where - heterogeneous specs may occur (same semantic but different shape). - - """ - - def __init__(self, *specs: TensorSpec, dim): +class _LazyStackedMixin(Generic[T]): + def __init__(self, *specs: tuple[T, ...], dim: int) -> None: self._specs = specs self.dim = dim if self.dim < 0: @@ -494,7 +486,9 @@ def __getitem__(self, item): isinstance(item, tuple) and all(isinstance(_item, str) for _item in item) ) if is_key: - return torch.stack([spec[item] for spec in self._specs]) + return torch.stack( + [composite_spec[item] for composite_spec in self._specs], dim=self.dim + ) elif isinstance(item, tuple): # quick check that the index is along the stacked dim # case 1: index is a tuple, and the first arg is an ellipsis. 
Then dim must be the last dim of all composite_specs @@ -583,7 +577,7 @@ def shape(self): shape.insert(dim, len(self._specs)) return torch.Size(shape) - def clone(self) -> CompositeSpec: + def clone(self) -> T: return torch.stack([spec.clone() for spec in self._specs], 0) def expand(self, *shape): @@ -629,9 +623,33 @@ def rand(self, shape=None) -> TensorDictBase: dim = self.dim return torch.stack([spec.rand(shape) for spec in self._specs], dim) - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> T: return torch.stack([spec.to(dest) for spec in self._specs], self.dim) + +class LazyStackedTensorSpec(_LazyStackedMixin[TensorSpec], TensorSpec): + """A lazy representation of a stack of tensor specs. + + Stacks tensor-specs together along one dimension. + When random samples are drawn, a stack of samples is returned if possible. + If not, an error is thrown. + + Indexing is allowed but only along the stack dimension. + + This class is aimed to be used in multi-task and multi-agent settings, where + heterogeneous specs may occur (same semantic but different shape). + + """ + + @property + def space(self): + return self._specs[0].space + if shape is not None: + dim = self.dim + len(shape) + else: + dim = self.dim + return torch.stack([spec.rand(shape) for spec in self._specs], dim) + def __eq__(self, other): # requires unbind to be implemented pass @@ -667,7 +685,16 @@ def type_check( pass def __repr__(self): - pass + shape_str = "shape=" + str(self.shape) + space_str = "space=" + str(self._specs[0].space) + device_str = "device=" + str(self.device) + dtype_str = "dtype=" + str(self.dtype) + domain_str = "domain=" + str(self._specs[0].domain) + sub_string = ", ".join( + [shape_str, space_str, device_str, dtype_str, domain_str] + ) + string = f"{self.__class__.__name__}(\n {sub_string})" + return string def encode(self, vals: Dict[str, Any]) -> Dict[str, torch.Tensor]: pass @@ -756,9 +783,7 @@ def __init__( dtype, device = _default_dtype_and_device(dtype, device) self.use_register = use_register - space = DiscreteBox( - n, - ) + space = DiscreteBox(n) if shape is None: shape = torch.Size((space.n,)) else: @@ -2005,7 +2030,7 @@ def __init__(self, *args, shape=None, device=None, **kwargs): for key, value in kwargs.items(): self.set(key, value) - _device = device + _device = torch.device(device) if device is not None else device if len(kwargs): for key, item in self.items(): if item is None: @@ -2020,12 +2045,12 @@ def __init__(self, *args, shape=None, device=None, **kwargs): else: raise err - if _device is None: - _device = item_device - elif item_device != _device: + if _device is not None and item_device != _device: raise RuntimeError( - f"Setting a new attribute ({key}) on another device ({item.device} against {_device}). " - f"All devices of CompositeSpec must match." + f"Setting a new attribute ({key}) on another device (" + f"{item.device} against {_device}). If the device of " + "CompositeSpec has been defined, then all devices of its " + "entries must match that device." ) self._device = _device if len(args): @@ -2039,27 +2064,11 @@ def __init__(self, *args, shape=None, device=None, **kwargs): f"Expected a dictionary of specs, but got an argument of type {type(argdict)}." 
) for k, item in argdict.items(): - if item is None: - continue - if self._device is None: - self._device = item.device - self[k] = item + if item is not None: + self[k] = item @property def device(self) -> DEVICE_TYPING: - if self._device is None: - # try to replace device by the true device - _device = None - for value in self.values(): - if value is not None: - _device = value.device - if _device is None: - raise RuntimeError( - "device of empty CompositeSpec is not defined. " - "You can set it directly by calling " - "`spec.device = device`." - ) - self._device = _device return self._device @device.setter @@ -2091,7 +2100,11 @@ def __setitem__(self, key, value): if key in {"shape", "device", "dtype", "space"}: raise AttributeError(f"CompositeSpec[{key}] cannot be set") try: - if value is not None and value.device != self.device: + if ( + value is not None + and self.device is not None + and value.device != self.device + ): raise RuntimeError( f"Setting a new attribute ({key}) on another device ({value.device} against {self.device}). " f"All devices of CompositeSpec must match." @@ -2332,7 +2345,7 @@ def expand(self, *shape): return out -class LazyStackedCompositeSpec(LazyStackedTensorSpec): +class LazyStackedCompositeSpec(_LazyStackedMixin[CompositeSpec], CompositeSpec): """A lazy representation of a stack of composite specs. Stacks composite specs together along one dimension. @@ -2345,155 +2358,6 @@ class LazyStackedCompositeSpec(LazyStackedTensorSpec): """ - # def __init__(self, *composite_specs: CompositeSpec, dim): - # self._specs = composite_specs - # self.dim = dim - # if self.dim < 0: - # self.dim = len(self.shape) + self.dim - # - # def __getitem__(self, item): - # is_key = isinstance(item, str) or ( - # isinstance(item, tuple) and all(isinstance(_item, str) for _item in item) - # ) - # if is_key: - # return torch.stack([composite_spec[item] for composite_spec in self._specs]) - # elif isinstance(item, tuple): - # # quick check that the index is along the stacked dim - # # case 1: index is a tuple, and the first arg is an ellipsis. Then dim must be the last dim of all composite_specs - # if item[0] is Ellipsis: - # if len(item) == 1: - # return self - # elif self.dim == len(self.shape) - 1 and len(item) == 2: - # # we can return - # return self._specs[item[1]] - # elif len(item) > 2: - # # check that there is only one non-slice index - # assigned = False - # dim_idx = self.dim - # for i, _item in enumerate(item[1:]): - # if ( - # isinstance(_item, slice) - # and not ( - # _item.start is None - # and _item.stop is None - # and _item.step is None - # ) - # ) or not isinstance(_item, slice): - # if assigned: - # raise RuntimeError( - # "Found more than one meaningful index in a stacked composite spec." - # ) - # item = _item - # dim_idx = i + 1 - # assigned = True - # if not assigned: - # return self - # if dim_idx != self.dim: - # raise RuntimeError( - # f"Indexing occured along dimension {dim_idx} but stacking was done along dim {self.dim}." - # ) - # out = self._specs[item] - # if isinstance(out, TensorSpec): - # return out - # return torch.stack(list(out), 0) - # else: - # raise IndexError( - # f"Indexing a {self.__class__.__name__} with [..., idx] is only permitted if the stack dimension is the last dimension. " - # f"Got self.dim={self.dim} and self.shape={self.shape}." 
- # ) - # elif len(item) >= 2 and item[-1] is Ellipsis: - # return self[item[:-1]] - # elif any(_item is Ellipsis for _item in item): - # raise IndexError("Cannot index along multiple dimensions.") - # # Ellipsis is now ruled out - # elif any(_item is None for _item in item): - # raise IndexError( - # f"Cannot index a {self.__class__.__name__} with None values" - # ) - # # Must be an index with slices then - # else: - # for i, _item in enumerate(item): - # if i == self.dim: - # out = self._specs[_item] - # if isinstance(out, TensorSpec): - # return out - # return torch.stack(list(out), 0) - # elif isinstance(_item, slice): - # # then the slice must be trivial - # if not (_item.step is _item.start is _item.stop is None): - # raise IndexError( - # f"Got a non-trivial index at dim {i} when only the dim {self.dim} could be indexed." - # ) - # else: - # return self - # else: - # if not self.dim == 0: - # raise IndexError( - # f"Trying to index a {self.__class__.__name__} along dimension 0 when the stack dimension is {self.dim}." - # ) - # out = self._specs[item] - # if isinstance(out, TensorSpec): - # return out - # return torch.stack(list(out), 0) - # - # @property - # def shape(self): - # shape = list(self._specs[0].shape) - # dim = self.dim - # if dim < 0: - # dim = len(shape) + dim + 1 - # shape.insert(dim, len(self._specs)) - # return torch.Size(shape) - # - # def clone(self) -> CompositeSpec: - # return torch.stack([spec.clone() for spec in self._specs], 0) - # - # def expand(self, *shape): - # if len(shape) == 1 and not isinstance(shape[0], (int,)): - # return self.expand(*shape[0]) - # expand_shape = shape[: -len(self.shape)] - # existing_shape = self.shape - # shape_check = shape[-len(self.shape) :] - # for _i, (size1, size2) in enumerate(zip(existing_shape, shape_check)): - # if size1 != size2 and size1 != 1: - # raise RuntimeError( - # f"Expanding a non-singletom dimension: existing shape={size1} vs expand={size2}" - # ) - # elif size1 != size2 and size1 == 1 and _i == self.dim: - # # if we're expanding along the stack dim we just need to clone the existing spec - # return torch.stack( - # [self._specs[0].clone() for _ in range(size2)], self.dim - # ).expand(*shape) - # if _i != len(self.shape) - 1: - # raise RuntimeError( - # f"Trying to expand non-congruent shapes: received {shape} when the shape is {self.shape}." 
- # ) - # # remove the stack dim from the expanded shape, which we know to match - # unstack_shape = list(expand_shape) + [ - # s for i, s in enumerate(shape_check) if i != self.dim - # ] - # return torch.stack( - # [spec.expand(unstack_shape) for spec in self._specs], - # self.dim + len(expand_shape), - # ) - # - # def zero(self, shape=None) -> TensorDictBase: - # if shape is not None: - # dim = self.dim + len(shape) - # else: - # dim = self.dim - # return torch.stack([spec.zero(shape) for spec in self._specs], dim) - # - # def rand(self, shape=None) -> TensorDictBase: - # if shape is not None: - # dim = self.dim + len(shape) - # else: - # dim = self.dim - # return torch.stack([spec.rand(shape) for spec in self._specs], dim) - # - # def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: - # return torch.stack([spec.to(dest) for spec in self._specs], self.dim) - def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> None: pass @@ -2506,16 +2370,20 @@ def to_numpy(self, val: TensorDict, safe: bool = True) -> dict: def __len__(self): pass - def values(self) -> ValuesView: - pass + def values(self): + for key in self.keys(): + yield self[key] - def items(self) -> ItemsView: - pass + def items(self): + for key in self.keys(): + yield key, self[key] def keys( self, yield_nesting_keys: bool = False, nested_keys: bool = True ) -> KeysView: - pass + return self._specs[0].keys( + yield_nesting_keys=yield_nesting_keys, nested_keys=nested_keys + ) def project(self, val: TensorDictBase) -> TensorDictBase: pass @@ -2530,8 +2398,14 @@ def type_check( ): pass - def __repr__(self): - pass + def __repr__(self) -> str: + sub_str = ",\n".join( + [indent(f"{k}: {repr(item)}", 4 * " ") for k, item in self.items()] + ) + device_str = f"device={self._specs[0].device}" + shape_str = f"shape={self.shape}" + sub_str = ", ".join([sub_str, device_str, shape_str]) + return f"CompositeSpec(\n{', '.join([sub_str, device_str, shape_str])})" def encode(self, vals: Dict[str, Any]) -> Dict[str, torch.Tensor]: pass @@ -2579,10 +2453,14 @@ def _stack_specs(list_of_spec, dim, out=None): if not len(list_of_spec): raise ValueError("Cannot stack an empty list of specs.") if isinstance(list_of_spec[0], TensorSpec): - if not all(isinstance(spec, TensorSpec) for spec in list_of_spec): - raise RuntimeError( - "Stacking specs cannot occur: Found more than one type of specs in the list." - ) + device = list_of_spec[0].device + for spec in list_of_spec: + if not isinstance(spec, TensorSpec): + raise RuntimeError( + "Stacking specs cannot occur: Found more than one type of specs in the list." + ) + if device != spec.device: + raise RuntimeError(f"Devices differ, got {device} and {spec.device}") return LazyStackedTensorSpec(*list_of_spec, dim=dim) else: raise NotImplementedError @@ -2598,10 +2476,15 @@ def _stack_composite_specs(list_of_spec, dim, out=None): if not len(list_of_spec): raise ValueError("Cannot stack an empty list of specs.") if isinstance(list_of_spec[0], CompositeSpec): - if not all(isinstance(spec, CompositeSpec) for spec in list_of_spec): - raise RuntimeError( - "Stacking specs cannot occur: Found more than one type of specs in the list." - ) + device = list_of_spec[0].device + for spec in list_of_spec: + if not isinstance(spec, CompositeSpec): + raise RuntimeError( + "Stacking specs cannot occur: Found more than one type of spec in " + "the list." 
+ ) + if device != spec.device: + raise RuntimeError(f"Devices differ, got {device} and {spec.device}") return LazyStackedCompositeSpec(*list_of_spec, dim=dim) else: raise NotImplementedError @@ -2644,9 +2527,7 @@ def __init__( self.leaves_only = leaves_only self.include_nested = include_nested - def __iter__( - self, - ): + def __iter__(self): for key, item in self.composite.items(): if self.include_nested and isinstance(item, CompositeSpec): for subkey in item.keys( From c1acefd4657e9fe979589e9ff49648ece570e133 Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Mon, 13 Mar 2023 15:50:26 +0000 Subject: [PATCH 5/6] [Feature] Contiguous stacking of matching specs (#960) --- test/test_specs.py | 106 ++++++----- torchrl/data/tensor_specs.py | 359 ++++++++++++++++++++++++++++++++--- 2 files changed, 388 insertions(+), 77 deletions(-) diff --git a/test/test_specs.py b/test/test_specs.py index 46fe2645181..7472a2b8e08 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -18,7 +18,6 @@ CompositeSpec, DiscreteTensorSpec, LazyStackedCompositeSpec, - LazyStackedTensorSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, @@ -1716,7 +1715,7 @@ def test_stack_binarydiscrete(self, shape, stack_dim): c1 = BinaryDiscreteTensorSpec(n=n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) - assert isinstance(c, LazyStackedTensorSpec) + assert isinstance(c, BinaryDiscreteTensorSpec) shape = list(shape) if stack_dim < 0: stack_dim = len(shape) + stack_dim + 1 @@ -1761,7 +1760,7 @@ def test_stack_bounded(self, shape, stack_dim): c1 = BoundedTensorSpec(mini, maxi, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) - assert isinstance(c, LazyStackedTensorSpec) + assert isinstance(c, BoundedTensorSpec) shape = list(shape) if stack_dim < 0: stack_dim = len(shape) + stack_dim + 1 @@ -1808,7 +1807,7 @@ def test_stack_discrete(self, shape, stack_dim): c1 = DiscreteTensorSpec(n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) - assert isinstance(c, LazyStackedTensorSpec) + assert isinstance(c, DiscreteTensorSpec) shape = list(shape) if stack_dim < 0: stack_dim = len(shape) + stack_dim + 1 @@ -1852,7 +1851,7 @@ def test_stack_multidiscrete(self, shape, stack_dim): c1 = MultiDiscreteTensorSpec(nvec, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) - assert isinstance(c, LazyStackedTensorSpec) + assert isinstance(c, MultiDiscreteTensorSpec) shape = list(shape) if stack_dim < 0: stack_dim = len(shape) + stack_dim + 1 @@ -1896,7 +1895,7 @@ def test_stack_multionehot(self, shape, stack_dim): c1 = MultiOneHotDiscreteTensorSpec(nvec, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) - assert isinstance(c, LazyStackedTensorSpec) + assert isinstance(c, MultiOneHotDiscreteTensorSpec) shape = list(shape) if stack_dim < 0: stack_dim = len(shape) + stack_dim + 1 @@ -1940,7 +1939,7 @@ def test_stack_onehot(self, shape, stack_dim): c1 = OneHotDiscreteTensorSpec(n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) - assert isinstance(c, LazyStackedTensorSpec) + assert isinstance(c, OneHotDiscreteTensorSpec) shape = list(shape) if stack_dim < 0: stack_dim = len(shape) + stack_dim + 1 @@ -1983,7 +1982,7 @@ def test_stack_unboundedcont(self, shape, stack_dim): c1 = UnboundedContinuousTensorSpec(shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) - assert isinstance(c, LazyStackedTensorSpec) + assert isinstance(c, UnboundedContinuousTensorSpec) shape = list(shape) if stack_dim < 0: stack_dim = 
len(shape) + stack_dim + 1 @@ -2023,7 +2022,7 @@ def test_stack_unboundeddiscrete(self, shape, stack_dim): c1 = UnboundedDiscreteTensorSpec(shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) - assert isinstance(c, LazyStackedTensorSpec) + assert isinstance(c, UnboundedDiscreteTensorSpec) shape = list(shape) if stack_dim < 0: stack_dim = len(shape) + stack_dim + 1 @@ -2064,11 +2063,13 @@ def test_stack(self): c1 = CompositeSpec(a=UnboundedContinuousTensorSpec()) c2 = c1.clone() c = torch.stack([c1, c2], 0) - assert isinstance(c, LazyStackedCompositeSpec) + assert isinstance(c, CompositeSpec) def test_stack_index(self): c1 = CompositeSpec(a=UnboundedContinuousTensorSpec()) - c2 = c1.clone() + c2 = CompositeSpec( + a=UnboundedContinuousTensorSpec(), b=UnboundedDiscreteTensorSpec() + ) c = torch.stack([c1, c2], 0) assert c.shape == torch.Size([2]) assert c[0] is c1 @@ -2082,7 +2083,11 @@ def test_stack_index(self): @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) def test_stack_index_multdim(self, stack_dim): c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) - c2 = c1.clone() + c2 = CompositeSpec( + a=UnboundedContinuousTensorSpec(shape=(1, 3)), + b=UnboundedDiscreteTensorSpec(shape=(1, 3)), + shape=(1, 3), + ) c = torch.stack([c1, c2], stack_dim) if stack_dim in (0, -3): assert isinstance(c[:], LazyStackedCompositeSpec) @@ -2146,36 +2151,14 @@ def test_stack_index_multdim(self, stack_dim): assert c[:, :, 0, ...] is c1 assert c[:, :, 1, ...] is c2 - @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) - def test_stack_expand_one(self, stack_dim): - c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) - c = torch.stack([c1], stack_dim) - if stack_dim in (0, -3): - c_expand = c.expand([4, 2, 1, 3]) - assert c_expand.shape == torch.Size([4, 2, 1, 3]) - assert c_expand.dim == 1 - elif stack_dim in (1, -2): - c_expand = c.expand([4, 1, 2, 3]) - assert c_expand.shape == torch.Size([4, 1, 2, 3]) - assert c_expand.dim == 2 - elif stack_dim in (2, -1): - c_expand = c.expand( - [ - 4, - 1, - 3, - 2, - ] - ) - assert c_expand.shape == torch.Size([4, 1, 3, 2]) - assert c_expand.dim == 3 - else: - raise NotImplementedError - @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) def test_stack_expand_multi(self, stack_dim): c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) - c2 = c1.clone() + c2 = CompositeSpec( + a=UnboundedContinuousTensorSpec(shape=(1, 3)), + b=UnboundedDiscreteTensorSpec(shape=(1, 3)), + shape=(1, 3), + ) c = torch.stack([c1, c2], stack_dim) if stack_dim in (0, -3): c_expand = c.expand([4, 2, 1, 3]) @@ -2202,7 +2185,11 @@ def test_stack_expand_multi(self, stack_dim): @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) def test_stack_rand(self, stack_dim): c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) - c2 = c1.clone() + c2 = CompositeSpec( + a=UnboundedContinuousTensorSpec(shape=(1, 3)), + b=UnboundedDiscreteTensorSpec(shape=(1, 3)), + shape=(1, 3), + ) c = torch.stack([c1, c2], stack_dim) r = c.rand() assert isinstance(r, LazyStackedTensorDict) @@ -2220,7 +2207,11 @@ def test_stack_rand(self, stack_dim): @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) def test_stack_rand_shape(self, stack_dim): c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) - c2 = c1.clone() + c2 = CompositeSpec( + a=UnboundedContinuousTensorSpec(shape=(1, 3)), + b=UnboundedDiscreteTensorSpec(shape=(1, 
3)), + shape=(1, 3), + ) c = torch.stack([c1, c2], stack_dim) shape = [5, 6] r = c.rand(shape) @@ -2239,7 +2230,11 @@ def test_stack_rand_shape(self, stack_dim): @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) def test_stack_zero(self, stack_dim): c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) - c2 = c1.clone() + c2 = CompositeSpec( + a=UnboundedContinuousTensorSpec(shape=(1, 3)), + b=UnboundedDiscreteTensorSpec(shape=(1, 3)), + shape=(1, 3), + ) c = torch.stack([c1, c2], stack_dim) r = c.zero() assert isinstance(r, LazyStackedTensorDict) @@ -2257,7 +2252,11 @@ def test_stack_zero(self, stack_dim): @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) def test_stack_zero_shape(self, stack_dim): c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) - c2 = c1.clone() + c2 = CompositeSpec( + a=UnboundedContinuousTensorSpec(shape=(1, 3)), + b=UnboundedDiscreteTensorSpec(shape=(1, 3)), + shape=(1, 3), + ) c = torch.stack([c1, c2], stack_dim) shape = [5, 6] r = c.zero(shape) @@ -2274,18 +2273,31 @@ def test_stack_zero_shape(self, stack_dim): assert (r["a"] == 0).all() @pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda") - def test_to(self): + @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) + def test_to(self, stack_dim): c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) - c2 = c1.clone() + c2 = CompositeSpec( + a=UnboundedContinuousTensorSpec(shape=(1, 3)), + b=UnboundedDiscreteTensorSpec(shape=(1, 3)), + shape=(1, 3), + ) c = torch.stack([c1, c2], stack_dim) + assert isinstance(c, LazyStackedCompositeSpec) cdevice = c.to("cuda:0") assert cdevice.device != c.device assert cdevice.device == torch.device("cuda:0") - assert cdevice[0].device == torch.device("cuda:0") + if stack_dim < 0: + stack_dim += 3 + index = (slice(None),) * stack_dim + (0,) + assert cdevice[index].device == torch.device("cuda:0") def test_clone(self): c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) - c2 = c1.clone() + c2 = CompositeSpec( + a=UnboundedContinuousTensorSpec(shape=(1, 3)), + b=UnboundedDiscreteTensorSpec(shape=(1, 3)), + shape=(1, 3), + ) c = torch.stack([c1, c2], 0) cclone = c.clone() assert cclone[0] is not c[0] diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 32cdfc3ab9f..0f36db422be 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -13,6 +13,7 @@ from textwrap import indent from typing import ( Any, + Callable, Dict, Generic, ItemsView, @@ -345,6 +346,24 @@ def expand(self, *shape): """ raise NotImplementedError + def squeeze(self, dim: int | None = None): + """Returns a new Spec with all the dimensions of size ``1`` removed. + + When ``dim`` is given, a squeeze operation is done only in that dimension. 
+ + Args: + dim (int or None): the dimension to apply the squeeze operation to + + """ + shape = _squeezed_shape(self.shape, dim) + if shape is None: + return self + return self.__class__(shape=shape, device=self.device, dtype=self.dtype) + + def unsqueeze(self, dim: int): + shape = _unsqueezed_shape(self.shape, dim) + return self.__class__(shape=shape, device=self.device, dtype=self.dtype) + def _project(self, val: torch.Tensor) -> torch.Tensor: raise NotImplementedError @@ -644,11 +663,6 @@ class LazyStackedTensorSpec(_LazyStackedMixin[TensorSpec], TensorSpec): @property def space(self): return self._specs[0].space - if shape is not None: - dim = self.dim + len(shape) - else: - dim = self.dim - return torch.stack([spec.rand(shape) for spec in self._specs], dim) def __eq__(self, other): # requires unbind to be implemented @@ -667,7 +681,9 @@ def items(self) -> ItemsView: pass def keys( - self, yield_nesting_keys: bool = False, nested_keys: bool = True + self, + include_nested: bool = False, + leaves_only: bool = False, ) -> KeysView: pass @@ -837,6 +853,39 @@ def expand(self, *shape): n=shape[-1], shape=shape, device=self.device, dtype=self.dtype ) + def squeeze(self, dim=None): + if self.shape[-1] == 1 and dim in (len(self.shape), -1, None): + raise ValueError( + "Final dimension of OneHotDiscreteTensorSpec must remain unchanged" + ) + + shape = _squeezed_shape(self.shape, dim) + if shape is None: + return self + + return self.__class__( + n=shape[-1], + shape=shape, + device=self.device, + dtype=self.dtype, + use_register=self.use_register, + ) + + def unsqueeze(self, dim: int): + if dim in (len(self.shape), -1): + raise ValueError( + "Final dimension of OneHotDiscreteTensorSpec must remain unchanged" + ) + + shape = _unsqueezed_shape(self.shape, dim) + return self.__class__( + n=shape[-1], + shape=shape, + device=self.device, + dtype=self.dtype, + use_register=self.use_register, + ) + def rand(self, shape=None) -> torch.Tensor: if shape is None: shape = self.shape[:-1] @@ -1052,6 +1101,36 @@ def expand(self, *shape): dtype=self.dtype, ) + def squeeze(self, dim: int | None = None): + shape = _squeezed_shape(self.shape, dim) + if shape is None: + return self + + if dim is None: + minimum = self.space.minimum.squeeze().clone() + maximum = self.space.maximum.squeeze().clone() + else: + minimum = self.space.minimum.squeeze(dim).clone() + maximum = self.space.maximum.squeeze(dim).clone() + + return self.__class__( + minimum=minimum, + maximum=maximum, + shape=shape, + device=self.device, + dtype=self.dtype, + ) + + def unsqueeze(self, dim: int): + shape = _unsqueezed_shape(self.shape, dim) + return self.__class__( + minimum=self.space.minimum.unsqueeze(dim).clone(), + maximum=self.space.maximum.unsqueeze(dim).clone(), + shape=shape, + device=self.device, + dtype=self.dtype, + ) + def rand(self, shape=None) -> torch.Tensor: if shape is None: shape = torch.Size([]) @@ -1379,6 +1458,28 @@ def expand(self, *shape): n=shape[-1], shape=shape, device=self.device, dtype=self.dtype ) + def squeeze(self, dim: int | None = None): + if self.shape[-1] == 1 and dim in (len(self.shape), -1, None): + raise ValueError( + "Final dimension of BinaryDiscreteTensorSpec must remain unchanged" + ) + shape = _squeezed_shape(self.shape, dim) + if shape is None: + return self + return self.__class__( + n=shape[-1], shape=shape, device=self.device, dtype=self.dtype + ) + + def unsqueeze(self, dim: int): + if dim in (len(self.shape), -1): + raise ValueError( + "Final dimension of BinaryDiscreteTensorSpec must remain 
unchanged" + ) + shape = _unsqueezed_shape(self.shape, dim) + return self.__class__( + n=shape[-1], shape=shape, device=self.device, dtype=self.dtype + ) + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: if isinstance(dest, torch.dtype): dest_dtype = dest @@ -1589,6 +1690,29 @@ def expand(self, *shape): nvec=nvecs, shape=shape, device=self.device, dtype=self.dtype ) + def squeeze(self, dim=None): + if self.shape[-1] == 1 and dim in (len(self.shape), -1, None): + raise ValueError( + "Final dimension of MultiOneHotDiscreteTensorSpec must remain unchanged" + ) + + shape = _squeezed_shape(self.shape, dim) + if shape is None: + return self + return self.__class__( + nvec=self.nvec, shape=shape, device=self.device, dtype=self.dtype + ) + + def unsqueeze(self, dim: int): + if dim in (len(self.shape), -1): + raise ValueError( + "Final dimension of MultiOneHotDiscreteTensorSpec must remain unchanged" + ) + shape = _unsqueezed_shape(self.shape, dim) + return self.__class__( + nvec=self.nvec, shape=shape, device=self.device, dtype=self.dtype + ) + class DiscreteTensorSpec(TensorSpec): """A discrete tensor spec. @@ -1705,6 +1829,20 @@ def expand(self, *shape): n=self.space.n, shape=shape, device=self.device, dtype=self.dtype ) + def squeeze(self, dim=None): + shape = _squeezed_shape(self.shape, dim) + if shape is None: + return self + return self.__class__( + n=self.space.n, shape=shape, device=self.device, dtype=self.dtype + ) + + def unsqueeze(self, dim: int): + shape = _unsqueezed_shape(self.shape, dim) + return self.__class__( + n=self.space.n, shape=shape, device=self.device, dtype=self.dtype + ) + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: if isinstance(dest, torch.dtype): dest_dtype = dest @@ -1910,6 +2048,36 @@ def expand(self, *shape): nvec=self.nvec, shape=shape, device=self.device, dtype=self.dtype ) + def squeeze(self, dim: int | None = None): + if self.shape[-1] == 1 and dim in (len(self.shape), -1, None): + raise ValueError( + "Final dimension of MultiDiscreteTensorSpec must remain unchanged" + ) + + shape = _squeezed_shape(self.shape, dim) + if shape is None: + return self + + if dim is None: + nvec = self.nvec.squeeze() + else: + nvec = self.nvec.squeeze(dim) + + return self.__class__( + nvec=nvec, shape=shape, device=self.device, dtype=self.dtype + ) + + def unsqueeze(self, dim: int): + if dim in (len(self.shape), -1): + raise ValueError( + "Final dimension of MultiDiscreteTensorSpec must remain unchanged" + ) + shape = _unsqueezed_shape(self.shape, dim) + nvec = self.nvec.unsqueeze(dim) + return self.__class__( + nvec=nvec, shape=shape, device=self.device, dtype=self.dtype + ) + class CompositeSpec(TensorSpec): """A composition of TensorSpecs. @@ -2045,12 +2213,13 @@ def __init__(self, *args, shape=None, device=None, **kwargs): else: raise err - if _device is not None and item_device != _device: + if _device is None: + _device = item_device + elif item_device != _device: raise RuntimeError( - f"Setting a new attribute ({key}) on another device (" - f"{item.device} against {_device}). If the device of " - "CompositeSpec has been defined, then all devices of its " - "entries must match that device." + f"Setting a new attribute ({key}) on another device " + f"({item.device} against {_device}). All devices of " + "CompositeSpec must match." 
) self._device = _device if len(args): @@ -2065,10 +2234,25 @@ def __init__(self, *args, shape=None, device=None, **kwargs): ) for k, item in argdict.items(): if item is not None: + if self._device is None: + self._device = item.device self[k] = item @property def device(self) -> DEVICE_TYPING: + if self._device is None: + # try to replace device by the true device + _device = None + for value in self.values(): + if value is not None: + _device = value.device + if _device is None: + raise RuntimeError( + "device of empty CompositeSpec is not defined. " + "You can set it directly by calling " + "`spec.device = device`." + ) + self._device = _device return self._device @device.setter @@ -2100,11 +2284,7 @@ def __setitem__(self, key, value): if key in {"shape", "device", "dtype", "space"}: raise AttributeError(f"CompositeSpec[{key}] cannot be set") try: - if ( - value is not None - and self.device is not None - and value.device != self.device - ): + if value is not None and value.device != self.device: raise RuntimeError( f"Setting a new attribute ({key}) on another device ({value.device} against {self.device}). " f"All devices of CompositeSpec must match." @@ -2227,9 +2407,7 @@ def keys( Default is ``False``. """ return _CompositeSpecKeysView( - self, - include_nested=include_nested, - leaves_only=leaves_only, + self, include_nested=include_nested, leaves_only=leaves_only ) def items(self) -> ItemsView: @@ -2304,7 +2482,9 @@ def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> N continue try: if isinstance(item, TensorSpec) and item.device != self.device: - item = deepcopy(item).to(self.device) + item = deepcopy(item) + if self.device is not None: + item = item.to(self.device) except RuntimeError as err: if DEVICE_ERR_MSG in str(err): try: @@ -2344,6 +2524,52 @@ def expand(self, *shape): ) return out + def squeeze(self, dim: int | None = None): + if dim is not None: + if dim < 0: + dim += len(self.shape) + + shape = _squeezed_shape(self.shape, dim) + if shape is None: + return self + + try: + device = self.device + except RuntimeError: + device = self._device + + return CompositeSpec( + {key: value.squeeze(dim) for key, value in self.items()}, + shape=shape, + device=device, + ) + + if self.shape.count(1) == 0: + return self + + # we can't just recursively apply squeeze with dim=None because we don't want + # to squeeze non-batch dims of the values. Instead we find the first dim in + # the batch dims with size 1, squeeze that, then recurse on the root spec + out = self.squeeze(self.shape.index(1)) + return out.squeeze() + + def unsqueeze(self, dim: int): + if dim < 0: + dim += len(self.shape) + + shape = _unsqueezed_shape(self.shape, dim) + + try: + device = self.device + except RuntimeError: + device = self._device + + return CompositeSpec( + {key: value.unsqueeze(dim) for key, value in self.items()}, + shape=shape, + device=device, + ) + class LazyStackedCompositeSpec(_LazyStackedMixin[CompositeSpec], CompositeSpec): """A lazy representation of a stack of composite specs. 
@@ -2379,10 +2605,12 @@ def items(self): yield key, self[key] def keys( - self, yield_nesting_keys: bool = False, nested_keys: bool = True + self, + include_nested: bool = False, + leaves_only: bool = False, ) -> KeysView: return self._specs[0].keys( - yield_nesting_keys=yield_nesting_keys, nested_keys=nested_keys + include_nested=include_nested, leaves_only=leaves_only ) def project(self, val: TensorDictBase) -> TensorDictBase: @@ -2405,7 +2633,9 @@ def __repr__(self) -> str: device_str = f"device={self._specs[0].device}" shape_str = f"shape={self.shape}" sub_str = ", ".join([sub_str, device_str, shape_str]) - return f"CompositeSpec(\n{', '.join([sub_str, device_str, shape_str])})" + return ( + f"LazyStackedCompositeSpec(\n{', '.join([sub_str, device_str, shape_str])})" + ) def encode(self, vals: Dict[str, Any]) -> Dict[str, torch.Tensor]: pass @@ -2421,7 +2651,7 @@ def __setitem__(self, key, value): @property def device(self) -> DEVICE_TYPING: - pass + return self._specs[0].device @property def ndim(self): @@ -2452,15 +2682,24 @@ def _stack_specs(list_of_spec, dim, out=None): ) if not len(list_of_spec): raise ValueError("Cannot stack an empty list of specs.") - if isinstance(list_of_spec[0], TensorSpec): - device = list_of_spec[0].device - for spec in list_of_spec: + spec0 = list_of_spec[0] + if isinstance(spec0, TensorSpec): + device = spec0.device + all_equal = True + for spec in list_of_spec[1:]: if not isinstance(spec, TensorSpec): raise RuntimeError( "Stacking specs cannot occur: Found more than one type of specs in the list." ) if device != spec.device: raise RuntimeError(f"Devices differ, got {device} and {spec.device}") + all_equal = all_equal and spec == spec0 + if all_equal: + shape = list(spec0.shape) + if dim < 0: + dim += len(shape) + 1 + shape.insert(dim, len(list_of_spec)) + return spec0.clone().unsqueeze(dim).expand(shape) return LazyStackedTensorSpec(*list_of_spec, dim=dim) else: raise NotImplementedError @@ -2475,9 +2714,11 @@ def _stack_composite_specs(list_of_spec, dim, out=None): ) if not len(list_of_spec): raise ValueError("Cannot stack an empty list of specs.") - if isinstance(list_of_spec[0], CompositeSpec): - device = list_of_spec[0].device - for spec in list_of_spec: + spec0 = list_of_spec[0] + if isinstance(spec0, CompositeSpec): + device = spec0.device + all_equal = True + for spec in list_of_spec[1:]: if not isinstance(spec, CompositeSpec): raise RuntimeError( "Stacking specs cannot occur: Found more than one type of spec in " @@ -2485,11 +2726,38 @@ def _stack_composite_specs(list_of_spec, dim, out=None): ) if device != spec.device: raise RuntimeError(f"Devices differ, got {device} and {spec.device}") + all_equal = all_equal and spec == spec0 + if all_equal: + shape = list(spec0.shape) + if dim < 0: + dim += len(shape) + 1 + shape.insert(dim, len(list_of_spec)) + return spec0.clone().unsqueeze(dim).expand(shape) return LazyStackedCompositeSpec(*list_of_spec, dim=dim) else: raise NotImplementedError +@TensorSpec.implements_for_spec(torch.squeeze) +def _squeeze_spec(spec: TensorSpec, *args, **kwargs) -> TensorSpec: + return spec.squeeze(*args, **kwargs) + + +@CompositeSpec.implements_for_spec(torch.squeeze) +def _squeeze_composite_spec(spec: CompositeSpec, *args, **kwargs) -> CompositeSpec: + return spec.squeeze(*args, **kwargs) + + +@TensorSpec.implements_for_spec(torch.unsqueeze) +def _unsqueeze_spec(spec: TensorSpec, *args, **kwargs) -> TensorSpec: + return spec.unsqueeze(*args, **kwargs) + + +@CompositeSpec.implements_for_spec(torch.unsqueeze) +def 
_unsqueeze_composite_spec(spec: CompositeSpec, *args, **kwargs) -> CompositeSpec: + return spec.unsqueeze(*args, **kwargs) + + def _keys_to_empty_composite_spec(keys): """Given a list of keys, creates a CompositeSpec tree where each leaf is assigned a None value.""" if not len(keys): @@ -2514,6 +2782,37 @@ def _keys_to_empty_composite_spec(keys): return c +def _squeezed_shape(shape: torch.Size, dim: int | None) -> torch.Size | None: + if dim is None: + if len(shape) == 1 or shape.count(1) == 0: + return None + new_shape = torch.Size([s for s in shape if s != 1]) + else: + if dim < 0: + dim += len(shape) + + if shape[dim] != 1: + return None + + new_shape = torch.Size([s for i, s in enumerate(shape) if i != dim]) + return new_shape + + +def _unsqueezed_shape(shape: torch.Size, dim: int) -> torch.Size: + n = len(shape) + if dim < -(n + 1) or dim > n: + raise ValueError( + f"Dimension out of range, expected value in the range [{-(n+1)}, {n}], but " + f"got {dim}" + ) + if dim < 0: + dim += n + 1 + + new_shape = list(shape) + new_shape.insert(dim, 1) + return torch.Size(new_shape) + + class _CompositeSpecKeysView: """Wrapper class that enables richer behaviour of `key in tensordict.keys()`.""" From 4bf6c5f950f61deb513a8fe29d22d2cd2c329127 Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Tue, 14 Mar 2023 17:30:09 +0000 Subject: [PATCH 6/6] [Feature] Various improvements to LazyStacked specs (#965) --- test/test_specs.py | 40 +++++++++++++++++++++++++-- torchrl/data/tensor_specs.py | 53 ++++++++++++++---------------------- 2 files changed, 58 insertions(+), 35 deletions(-) diff --git a/test/test_specs.py b/test/test_specs.py index 7472a2b8e08..aaf30a0906e 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -2045,7 +2045,7 @@ def test_stack_unboundeddiscrete_rand(self, shape, stack_dim): shape = (*shape,) c1 = UnboundedDiscreteTensorSpec(shape=shape) c2 = c1.clone() - c = torch.stack([c1, c2], 0) + c = torch.stack([c1, c2], stack_dim) r = c.rand() assert r.shape == c.shape @@ -2053,10 +2053,28 @@ def test_stack_unboundeddiscrete_zero(self, shape, stack_dim): shape = (*shape,) c1 = UnboundedDiscreteTensorSpec(shape=shape) c2 = c1.clone() - c = torch.stack([c1, c2], 0) + c = torch.stack([c1, c2], stack_dim) r = c.zero() assert r.shape == c.shape + def test_to_numpy(self, shape, stack_dim): + c1 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float64) + c2 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float32) + c = torch.stack([c1, c2], stack_dim) + + shape = list(shape) + shape.insert(stack_dim, 2) + shape = tuple(shape) + + val = 2 * torch.rand(torch.Size(shape)) - 1 + + val_np = c.to_numpy(val) + assert isinstance(val_np, np.ndarray) + assert (val.numpy() == val_np).all() + + with pytest.raises(AssertionError): + c.to_numpy(val + 1) + class TestStackComposite: def test_stack(self): @@ -2303,6 +2321,24 @@ def test_clone(self): assert cclone[0] is not c[0] assert cclone[0] == c[0] + def test_to_numpy(self): + c1 = CompositeSpec(a=BoundedTensorSpec(-1, 1, shape=(1, 3)), shape=(1, 3)) + c2 = CompositeSpec( + a=BoundedTensorSpec(-1, 1, shape=(1, 3)), + b=UnboundedDiscreteTensorSpec(shape=(1, 3)), + shape=(1, 3), + ) + c = torch.stack([c1, c2], 0) + for _ in range(100): + r = c.rand() + for key, value in c.to_numpy(r).items(): + spec = c[key] + assert (spec.to_numpy(r[key]) == value).all() + + td_fail = TensorDict({"a": torch.rand((2, 1, 3)) + 1}, [2, 1, 3]) + with pytest.raises(AssertionError): + c.to_numpy(td_fail) + if __name__ == "__main__": args, unknown = 
argparse.ArgumentParser().parse_known_args() diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 0f36db422be..54ea17fb4e2 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -668,38 +668,23 @@ def __eq__(self, other): # requires unbind to be implemented pass - def to_numpy(self, val: TensorDict, safe: bool = True) -> dict: - pass + def to_numpy(self, val: torch.Tensor, safe: bool = True) -> dict: + if safe: + if val.shape[self.dim] != len(self._specs): + raise ValueError( + "Size of LazyStackedTensorSpec and val differ along the stacking " + "dimension" + ) + for spec, v in zip(self._specs, torch.unbind(val, dim=self.dim)): + spec.assert_is_in(v) + return val.detach().cpu().numpy() def __len__(self): pass - def values(self) -> ValuesView: - pass - - def items(self) -> ItemsView: - pass - - def keys( - self, - include_nested: bool = False, - leaves_only: bool = False, - ) -> KeysView: - pass - def project(self, val: TensorDictBase) -> TensorDictBase: pass - def is_in(self, val: Union[dict, TensorDictBase]) -> bool: - pass - - def type_check( - self, - value: Union[torch.Tensor, TensorDictBase], - selected_keys: Union[str, Optional[Sequence[str]]] = None, - ): - pass - def __repr__(self): shape_str = "shape=" + str(self.shape) space_str = "space=" + str(self._specs[0].space) @@ -712,12 +697,6 @@ def __repr__(self): string = f"{self.__class__.__name__}(\n {sub_string})" return string - def encode(self, vals: Dict[str, Any]) -> Dict[str, torch.Tensor]: - pass - - def __delitem__(self, key): - pass - def __iter__(self): pass @@ -726,7 +705,7 @@ def __setitem__(self, key, value): @property def device(self) -> DEVICE_TYPING: - pass + return self._specs[0].device @property def ndim(self): @@ -2591,7 +2570,15 @@ def __eq__(self, other): pass def to_numpy(self, val: TensorDict, safe: bool = True) -> dict: - pass + if safe: + if val.shape[self.dim] != len(self._specs): + raise ValueError( + "Size of LazyStackedCompositeSpec and val differ along the " + "stacking dimension" + ) + for spec, v in zip(self._specs, torch.unbind(val, dim=self.dim)): + spec.assert_is_in(v) + return {key: self[key].to_numpy(val) for key, val in val.items()} def __len__(self): pass
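Taken together, the stacking changes in this series behave roughly as sketched below (commentary only, assuming a torchrl build with these patches applied; key names and shapes are illustrative): matching specs now stack into a contiguous spec, heterogeneous composite specs fall back to a lazy stack, and the to_numpy implementations above validate a stacked sample before conversion.

import numpy as np
import torch
from torchrl.data import (
    CompositeSpec,
    LazyStackedCompositeSpec,
    UnboundedContinuousTensorSpec,
    UnboundedDiscreteTensorSpec,
)

# Matching leaf specs: torch.stack returns the same spec class, expanded along
# the new dimension, instead of a lazy stack.
leaf = UnboundedContinuousTensorSpec(shape=(1, 3))
contiguous = torch.stack([leaf, leaf.clone()], 0)
assert isinstance(contiguous, UnboundedContinuousTensorSpec)
assert contiguous.shape == torch.Size([2, 1, 3])

# Composite specs with the same batch shape but different keys stack lazily.
c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
c2 = CompositeSpec(
    a=UnboundedContinuousTensorSpec(shape=(1, 3)),
    b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
    shape=(1, 3),
)
stacked = torch.stack([c1, c2], 0)
assert isinstance(stacked, LazyStackedCompositeSpec)
assert stacked.shape == torch.Size([2, 1, 3])

# rand() draws a lazily stacked tensordict; to_numpy() checks each slice against
# its sub-spec and converts the shared keys to numpy arrays.
sample = stacked.rand()
arrays = stacked.to_numpy(sample)
assert isinstance(arrays["a"], np.ndarray)
assert arrays["a"].shape == (2, 1, 3)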