class TestCompositeNames:
    """Exercise dimension-name handling on ``Composite`` specs."""

    @staticmethod
    def _obs_composite(batch_shape, names=None, event_shape=(3, 4)):
        """Build a Composite holding a single bounded ``"obs"`` leaf.

        The leaf has shape ``batch_shape + event_shape`` and the Composite has
        shape ``batch_shape``; ``names`` is forwarded to the constructor as-is.
        """
        leaf = Bounded(low=-1, high=1, shape=(*batch_shape, *event_shape))
        return Composite({"obs": leaf}, shape=batch_shape, names=names)

    def test_names_property_basic(self):
        """The ``names`` property reflects what was given at construction."""
        # Named spec: names are returned verbatim.
        named = self._obs_composite((10, 5), names=["batch", "time"])
        assert named.names == ["batch", "time"]
        assert named._has_names() is True

        # Unnamed spec: one ``None`` per batch dimension.
        unnamed = self._obs_composite((10, 5))
        assert unnamed.names == [None, None]
        assert unnamed._has_names() is False

    def test_names_setter(self):
        """Names can be assigned and then cleared through the setter."""
        target = self._obs_composite((10, 5))

        # Assign names.
        target.names = ["batch", "time"]
        assert target.names == ["batch", "time"]
        assert target._has_names() is True

        # Assigning ``None`` clears them again.
        target.names = None
        assert target.names == [None, None]
        assert target._has_names() is False

    def test_names_setter_validation(self):
        """The setter rejects a name list whose length differs from ndim."""
        target = self._obs_composite((10, 5))

        with pytest.raises(ValueError, match="Expected 2 names, but got 3 names"):
            target.names = ["batch", "time", "extra"]

    def test_refine_names_basic(self):
        """``refine_names`` can name a previously unnamed dimension."""
        base = self._obs_composite((10, 5, 3), event_shape=(4,))

        # Nothing is named to begin with.
        assert base.names == [None, None, None]
        assert base._has_names() is False

        # Name only the last dimension.
        refined = base.refine_names(None, None, "feature")
        assert refined.names == [None, None, "feature"]
        assert refined._has_names() is True

    def test_refine_names_ellipsis(self):
        """``refine_names`` accepts an Ellipsis for the trailing dimensions."""
        base = self._obs_composite(
            (10, 5, 3), names=["batch", None, None], event_shape=(4,)
        )

        # The Ellipsis stands in for the two remaining dimensions.
        refined = base.refine_names("batch", ...)
        assert refined.names == ["batch", None, None]

    def test_refine_names_validation(self):
        """Refining a named dimension to a different name is an error."""
        base = self._obs_composite((10, 5), names=["batch", "time"])

        with pytest.raises(RuntimeError, match="cannot coerce Composite names"):
            base.refine_names("batch", "different")

    def test_expand_preserves_names(self):
        """``expand`` keeps existing names; the new leading dim is unnamed."""
        base = self._obs_composite((10,), names=["batch"])

        grown = base.expand(5, 10)
        assert grown.names == [None, "batch"]
        assert grown.shape == torch.Size([5, 10])

    def test_squeeze_preserves_names(self):
        """``squeeze(dim)`` drops the name of the removed dimension only."""
        base = self._obs_composite((10, 1, 5), names=["batch", "dummy", "time"])

        # Dimension 1 is the singleton one.
        shrunk = base.squeeze(1)
        assert shrunk.names == ["batch", "time"]
        assert shrunk.shape == torch.Size([10, 5])

    def test_squeeze_all_ones_clears_names(self):
        """Squeezing every dimension away leaves no names at all."""
        base = self._obs_composite((1, 1), names=["dummy1", "dummy2"])

        shrunk = base.squeeze()
        # No dimensions remain, hence an empty name list.
        assert shrunk.names == []
        assert shrunk.shape == torch.Size([])

    def test_unsqueeze_preserves_names(self):
        """``unsqueeze`` inserts an unnamed dim and keeps the other names."""
        base = self._obs_composite((10, 5), names=["batch", "time"])

        widened = base.unsqueeze(1)
        assert widened.names == ["batch", None, "time"]
        assert widened.shape == torch.Size([10, 1, 5])

    def test_unbind_preserves_names(self):
        """Each piece produced by ``unbind`` keeps the surviving names."""
        base = self._obs_composite((3, 5), names=["batch", "time"])

        pieces = base.unbind(0)
        assert len(pieces) == 3
        for piece in pieces:
            assert piece.names == ["time"]
            assert piece.shape == torch.Size([5])

    def test_clone_preserves_names(self):
        """``clone`` returns a distinct object carrying the same names."""
        base = self._obs_composite((10,), names=["batch"])

        duplicate = base.clone()
        assert duplicate.names == ["batch"]
        assert duplicate.shape == base.shape
        # A clone must be a new object, not an alias.
        assert duplicate is not base

    def test_to_preserves_names(self):
        """``to(device)`` carries the names over to the returned spec."""
        base = self._obs_composite((10,), names=["batch"])

        on_cpu = base.to("cpu")
        assert on_cpu.names == ["batch"]
        assert on_cpu.device == torch.device("cpu")

    def test_indexing_preserves_names(self):
        """Integer indexing drops a name; slicing keeps all of them."""
        base = self._obs_composite((10, 5), names=["batch", "time"])

        # An integer index removes the leading dimension.
        picked = base[0]
        assert picked.names == ["time"]
        assert picked.shape == torch.Size([5])

        # A slice keeps the dimension, and therefore its name.
        window = base[0:5]
        assert window.names == ["batch", "time"]
        assert window.shape == torch.Size([5, 5])

    def test_nested_composite_names_propagation(self):
        """Names given to the parent are pushed down to nested Composites."""
        inner = Composite(
            {"inner": Bounded(low=-1, high=1, shape=(10, 3, 2))}, shape=(10, 3)
        )
        nested_spec = Composite({"outer": inner}, shape=(10,), names=["batch"])

        assert nested_spec.names == ["batch"]
        # The nested spec has one extra dimension, which stays unnamed.
        assert nested_spec["outer"].names == ["batch", None]

    def test_erase_names(self):
        """``_erase_names`` removes every name from the spec."""
        base = self._obs_composite((10,), names=["batch"])

        assert base._has_names() is True
        base._erase_names()
        assert base._has_names() is False
        assert base.names == [None]

    def test_names_with_different_shapes(self):
        """Leaves with different trailing shapes share the batch names."""
        leaves = {
            "obs": Bounded(low=-1, high=1, shape=(10, 5, 3, 4)),
            "action": Bounded(low=0, high=1, shape=(10, 5, 2)),
        }
        base = Composite(leaves, shape=(10, 5), names=["batch", "time"])

        assert base.names == ["batch", "time"]
        assert base["obs"].shape == torch.Size([10, 5, 3, 4])
        assert base["action"].shape == torch.Size([10, 5, 2])

    def test_names_constructor_parameter(self):
        """The ``names`` constructor argument is optional and honoured."""
        # With the argument.
        named = self._obs_composite((10, 5), names=["batch", "time"])
        assert named.names == ["batch", "time"]

        # Without the argument.
        unnamed = self._obs_composite((10, 5))
        assert unnamed.names == [None, None]

    def test_names_with_empty_composite(self):
        """A Composite with no leaves can still carry names."""
        hollow = Composite({}, shape=(10,), names=["batch"])
        assert hollow.names == ["batch"]
        assert hollow._has_names() is True

    def test_names_equality(self):
        """Two specs differing only in names still compare equal."""
        with_names = self._obs_composite((10,), names=["batch"])
        without_names = self._obs_composite((10,))

        # Names are not part of the equality check.
        assert with_names == without_names

    def test_names_repr(self):
        """``repr`` works on a named spec and mentions its content."""
        base = self._obs_composite((10,), names=["batch"])

        # Building the repr must not raise.
        text = repr(base)
        assert "Composite" in text
        assert "obs" in text
def _has_names(self):
    """Return ``True`` when dimension names are set on this Composite."""
    return self._td_dim_names is not None


def _erase_names(self):
    """Drop the dimension names stored on this Composite.

    Only this spec's own ``_td_dim_names`` is reset; nested specs are not
    visited.
    """
    self._td_dim_names = None


def _propagate_names_to_nested(self):
    """Push this spec's names onto the leading dims of nested Composites.

    Does nothing when no names are set. A nested Composite is updated only
    when it has at least as many dimensions as ``self``. Its leading names
    are replaced by ours, while any names it already carries on its extra
    trailing dimensions are kept (they used to be reset to ``None``).
    """
    if not self._has_names():
        return
    own_names = self.names
    own_ndim = self.ndim
    for spec in self._specs.values():
        if isinstance(spec, Composite) and spec.ndim >= own_ndim:
            # Assigning through the setter recurses into deeper levels.
            spec.names = own_names + spec.names[own_ndim:]


@property
def names(self):
    """Names of the dimensions of this Composite.

    Returns:
        A fresh list with one entry per dimension; every entry is ``None``
        when no names are set.
    """
    stored = self._td_dim_names
    if stored is None:
        return [None] * self.ndim
    # Return a copy but don't use copy to make dynamo happy
    return list(stored)


@names.setter
def names(self, value):
    """Set the names of the dimensions of this Composite.

    Args:
        value: one name (or ``None``) per dimension, or ``None`` to clear
            the names of this spec.

    Raises:
        ValueError: if ``len(value)`` differs from ``self.ndim``.
    """
    if value is None:
        self._td_dim_names = None
        return
    if len(value) != self.ndim:
        raise ValueError(
            f"Expected {self.ndim} names, but got {len(value)} names: {value}"
        )
    self._td_dim_names = list(value)
    # Nested Composite specs share our leading dimensions.
    self._propagate_names_to_nested()


def refine_names(self, *names):
    """Refines the dimension names of self according to names.

    Refining is a special case of renaming that "lifts" unnamed dimensions.
    A None dim can be refined to have any name; a named dim can only be
    refined to have the same name.

    Because named specs can coexist with unnamed specs, refining names
    gives a nice way to write named-spec-aware code that works with both
    named and unnamed specs.

    names may contain up to one Ellipsis (...). The Ellipsis is expanded
    greedily; it is expanded in-place to fill names to the same length as
    self.ndim using names from the corresponding indices of self.names.

    Returns: the same composite spec with dimensions named according to the input.

    Raises:
        ValueError: if more than one Ellipsis is given, or if the number of
            names (after expanding the Ellipsis) differs from ``self.ndim``.
        RuntimeError: if an already-named dimension would get another name.

    Examples:
        >>> spec = Composite({}, shape=[3, 4, 5, 6])
        >>> spec_refined = spec.refine_names(None, None, None, "d")
        >>> assert spec_refined.names == [None, None, None, "d"]
        >>> spec_refined = spec.refine_names("a", None, None, "d")
        >>> assert spec_refined.names == ["a", None, None, "d"]
        >>> spec_refined = spec.refine_names("a", ...)
        >>> assert spec_refined.names == ["a", None, None, "d"]

    """
    requested = names  # kept verbatim for error messages
    curr_names = self.names
    ndim = self.ndim

    ellipsis_at = [i for i, name in enumerate(names) if name is Ellipsis]
    if len(ellipsis_at) > 1:
        raise ValueError(
            f"refine_names: names may contain at most one Ellipsis, but got {requested}."
        )
    if ellipsis_at:
        pos = ellipsis_at[0]
        # The Ellipsis stands for every dimension not listed explicitly and
        # takes the *current* names of those dimensions, so already-named
        # dimensions hidden behind it are never flagged as conflicts.
        n_fill = max(ndim - len(names) + 1, 0)
        new_names = (
            list(names[:pos]) + curr_names[pos : pos + n_fill] + list(names[pos + 1 :])
        )
    else:
        new_names = list(names)

    # Check the length up front: too many names would otherwise surface as
    # an IndexError from the comparison loop below.
    if len(new_names) != ndim:
        raise ValueError(
            f"Expected {ndim} names, but got {len(new_names)} names: {requested}"
        )

    # check that the names that are set are either None or identical
    for current, new in zip(curr_names, new_names):
        if current is not None and current != new:
            raise RuntimeError(
                f"refine_names: cannot coerce Composite names {self.names} with {requested}."
            )
    self.names = new_names
    return self


def _get_names_idx(self, idx):
    """Helper method to get names after indexing.

    Walks the index left to right while tracking which source dimension
    each entry consumes, so removing a dimension never shifts the position
    of the ones that follow.

    Args:
        idx: a single index or a tuple of indices. Integers drop their
            dimension, ``None`` inserts a new unnamed dimension, ``Ellipsis``
            stands for all dimensions not addressed explicitly, and any
            other entry (slice, tensor, list, ...) keeps the name of the
            dimension it addresses.
            NOTE(review): multi-dimensional advanced indices are treated as
            addressing a single dimension -- confirm against the shape
            computed in ``__getitem__``.

    Returns:
        The list of names of the indexed spec, or ``None`` when the spec has
        no names or when no named dimension survives (the same normalisation
        ``squeeze`` and ``unbind`` apply).
    """
    if not self._has_names():
        return None

    src_names = self.names
    entries = idx if isinstance(idx, tuple) else (idx,)
    # ``None`` and ``Ellipsis`` do not address a dimension by themselves.
    n_explicit = sum(
        1 for entry in entries if entry is not None and entry is not Ellipsis
    )

    kept = []
    dim = 0
    for entry in entries:
        if entry is None:
            kept.append(None)
        elif entry is Ellipsis:
            n_skip = max(len(src_names) - n_explicit, 0)
            kept.extend(src_names[dim : dim + n_skip])
            dim += n_skip
        elif isinstance(entry, int):
            # Integer index: the dimension disappears, and so does its name.
            dim += 1
        else:
            kept.extend(src_names[dim : dim + 1])
            dim += 1
    # Dimensions the index does not reach are left untouched.
    kept.extend(src_names[dim:])

    if all(name is None for name in kept):
        return None
    return kept
| torch.Size) -> Composite: else None for key, value in tuple(self.items()) } + names = None + if self._has_names(): + names = [None] * (len(shape) - self.ndim) + self.names + out = Composite( specs, shape=shape, device=device, data_cls=self.data_cls, step_mdp_static=self.step_mdp_static, + names=names, ) return out @@ -5965,12 +6113,21 @@ def squeeze(self, dim: int | None = None) -> Composite: except RuntimeError: device = self._device + names = None + if self._has_names(): + names = copy(self.names) + names.pop(dim) + # If all names are None after popping, set to None + if all(name is None for name in names): + names = None + return self.__class__( {key: value.squeeze(dim) for key, value in self.items()}, shape=shape, device=device, data_cls=self.data_cls, step_mdp_static=self.step_mdp_static, + names=names, ) if self.shape.count(1) == 0: @@ -5993,6 +6150,11 @@ def unsqueeze(self, dim: int) -> Composite: except RuntimeError: device = self._device + names = None + if self._has_names(): + names = copy(self.names) + names.insert(dim, None) + return self.__class__( { key: value.unsqueeze(dim) if value is not None else None @@ -6002,6 +6164,7 @@ def unsqueeze(self, dim: int) -> Composite: device=device, data_cls=self.data_cls, step_mdp_static=self.step_mdp_static, + names=names, ) def unbind(self, dim: int = 0) -> tuple[Composite, ...]: @@ -6012,8 +6175,17 @@ def unbind(self, dim: int = 0) -> tuple[Composite, ...]: raise ValueError( f"Cannot unbind along dim {orig_dim} with shape {self.shape}." 
) - shape = (s for i, s in enumerate(self.shape) if i != dim) + shape = tuple(s for i, s in enumerate(self.shape) if i != dim) unbound_vals = {key: val.unbind(dim) for key, val in self.items()} + + names = None + if self._has_names(): + names = copy(self.names) + names.pop(dim) + # If all names are None after popping, set to None + if all(name is None for name in names): + names = None + return tuple( self.__class__( {key: val[i] for key, val in unbound_vals.items()}, @@ -6021,6 +6193,7 @@ def unbind(self, dim: int = 0) -> tuple[Composite, ...]: device=self.device, data_cls=self.data_cls, step_mdp_static=self.step_mdp_static, + names=names, ) for i in range(self.shape[dim]) )