From 87bc004ff914d41417999dd4a5b6366997e4648b Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 26 Jul 2023 17:30:25 +0100 Subject: [PATCH 1/4] amend --- tensordict/__init__.py | 2 + tensordict/nn/__init__.py | 1 + tensordict/nn/params.py | 539 ++++++++++++++++++++++++++++++++++++++ tensordict/persistent.py | 2 + tensordict/tensordict.py | 53 ++-- test/_utils_internal.py | 4 + test/test_tensordict.py | 326 ++++++++++++++++++----- 7 files changed, 850 insertions(+), 77 deletions(-) create mode 100644 tensordict/nn/params.py diff --git a/tensordict/__init__.py b/tensordict/__init__.py index 54b8d1dce..784f05f1b 100644 --- a/tensordict/__init__.py +++ b/tensordict/__init__.py @@ -45,3 +45,5 @@ "PersistentTensorDict", "tensorclass", ] + +# from tensordict._pytree import * diff --git a/tensordict/nn/__init__.py b/tensordict/nn/__init__.py index d21e7e4a3..575d14567 100644 --- a/tensordict/nn/__init__.py +++ b/tensordict/nn/__init__.py @@ -18,6 +18,7 @@ make_functional, repopulate_module, ) +from tensordict.nn.params import TensorDictParams from tensordict.nn.probabilistic import ( InteractionType, ProbabilisticTensorDictModule, diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py new file mode 100644 index 000000000..068f97a5e --- /dev/null +++ b/tensordict/nn/params.py @@ -0,0 +1,539 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import functools +import inspect +import numbers +import re +from copy import copy +from functools import wraps +from typing import Any, Callable, Sequence + +import torch + +from tensordict import TensorDictBase +from tensordict.tensordict import ( + CompatibleType, + lock_blocked, + NO_DEFAULT, + TD_HANDLED_FUNCTIONS, + TensorDict, +) +from tensordict.utils import DeviceType, erase_cache, IndexType, NestedKey +from torch import nn, Tensor +from torch.utils._pytree import tree_map + + +def _get_args_dict(func, args, kwargs): + signature = inspect.signature(func) + bound_arguments = signature.bind(*args, **kwargs) + bound_arguments.apply_defaults() + + args_dict = dict(bound_arguments.arguments) + return args_dict + + +def _maybe_make_param(tensor): + if ( + isinstance(tensor, Tensor) + and not isinstance(tensor, nn.Parameter) + and tensor.dtype in (torch.float, torch.double, torch.half) + ): + tensor = nn.Parameter(tensor) + return tensor + + +class _unlock_and_set: + def __new__(cls, *args, **kwargs): + if len(args) and callable(args[0]): + return cls(**kwargs)(args[0]) + return super().__new__(cls) + + def __init__(self, **only_for_kwargs): + self.only_for_kwargs = only_for_kwargs + + def __call__(self, func): + name = func.__name__ + + @wraps(func) + def new_func(_self, *args, **kwargs): + if self.only_for_kwargs: + arg_dict = _get_args_dict(func, (_self, *args), kwargs) + for kwarg, exp_value in self.only_for_kwargs.items(): + cur_val = arg_dict.get(kwarg, NO_DEFAULT) + if cur_val != exp_value: + # escape + meth = getattr(_self._param_td, name) + out = meth(*args, **kwargs) + return out + args = tree_map(_maybe_make_param, args) + kwargs = tree_map(_maybe_make_param, kwargs) + with _self._param_td.unlock_(): + meth = getattr(_self._param_td, name) + out = meth(*args, **kwargs) + _self.__dict__["_parameters"] = _self._param_td.flatten_keys("_").to_dict() + if out is _self._param_td: + return _self + return out + + return new_func + + +def _fallback(func): + name = func.__name__ + 
+ @wraps(func) + def new_func(self, *args, **kwargs): + out = getattr(self._param_td, name)(*args, **kwargs) + if out is self._param_td: + return self + return out + + return new_func + + +def _fallback_property(func): + name = func.__name__ + + @wraps(func) + def new_func(self): + out = getattr(self._param_td, name) + if out is self._param_td: + return self + return out + + return property(new_func) + + +def _replace(func): + name = func.__name__ + + @wraps(func) + def new_func(self, *args, **kwargs): + out = getattr(self._param_td, name)(*args, **kwargs) + if out is self._param_td: + return self + self._param_td = out + return self + + return new_func + + +def _carry_over(func): + name = func.__name__ + + @wraps(func) + def new_func(self, *args, **kwargs): + out = getattr(self._param_td, name)(*args, **kwargs) + return TensorDictParams(out, no_convert=True) + + return new_func + + +class TensorDictParams(TensorDictBase, nn.Module): + r"""Holds a TensorDictBase instance full of parameters. + + This class exposes the contained parameters to a parent nn.Module + such that iterating over the parameters of the module also iterates over + the leaves of the tensordict. + + Indexing works exactly as the indexing of the wrapped tensordict. + TODO: Parameter names + + Any operation that sets a tensor in the tensordict will be augmented by + a :class:`torch.nn.Parameter` conversion. + """ + + def __init__(self, parameters: TensorDictBase, no_convert=False): + super().__init__() + self._param_td = parameters + if not no_convert: + self._param_td = self._param_td.apply( + lambda x: _maybe_make_param(x) + ).lock_() + self._parameters = parameters.flatten_keys("_").to_dict() + self._is_locked = False + self._locked_tensordicts = [] + self.__last_op_queue = None + + @classmethod + def __torch_function__( + cls, + func: Callable, + types: tuple[type, ...], + args: tuple[Any, ...] = (), + kwargs: dict[str, Any] | None = None, + ) -> Callable: + if kwargs is None: + kwargs = {} + if func not in TDPARAM_HANDLED_FUNCTIONS or not all( + issubclass(t, (Tensor, TensorDictBase)) for t in types + ): + return NotImplemented + return TDPARAM_HANDLED_FUNCTIONS[func](*args, **kwargs) + + @classmethod + def _flatten_key(cls, key): + def make_valid_identifier(s): + # Replace invalid characters with underscores + s = re.sub(r"\W|^(?=\d)", "_", s) + + # Ensure the string starts with a letter or underscore + if not s[0].isalpha() and s[0] != "_": + s = "_" + s + + return s + + key_flat = "_".join(key) + if not key_flat.isidentifier(): + key_flat = make_valid_identifier(key_flat) + return key_flat + + @lock_blocked + @_unlock_and_set + def __setitem__( + self, + index: IndexType, + value: TensorDictBase | dict | numbers.Number | CompatibleType, + ) -> None: + ... + + @lock_blocked + @_unlock_and_set + def set( + self, key: NestedKey, item: CompatibleType, inplace: bool = False, **kwargs: Any + ) -> TensorDictBase: + ... 
+ + def update( + self, + input_dict_or_td: dict[str, CompatibleType] | TensorDictBase, + clone: bool = False, + inplace: bool = False, + ) -> TensorDictBase: + if isinstance(input_dict_or_td, TensorDictBase): + input_dict_or_td = input_dict_or_td.apply(_maybe_make_param) + else: + input_dict_or_td = tree_map(_maybe_make_param, input_dict_or_td) + with self._param_td.unlock_(): + TensorDictBase.update(self, input_dict_or_td, clone=clone, inplace=inplace) + return self + + @lock_blocked + @_unlock_and_set + def pop( + self, key: NestedKey, default: str | CompatibleType = NO_DEFAULT + ) -> CompatibleType: + ... + + @lock_blocked + @_unlock_and_set + def rename_key_( + self, old_key: str, new_key: str, safe: bool = False + ) -> TensorDictBase: + ... + + @_unlock_and_set + def apply_(self, fn: Callable, *others) -> TensorDictBase: + ... + + @_unlock_and_set(inplace=True) + def apply( + self, + fn: Callable, + *others: TensorDictBase, + batch_size: Sequence[int] | None = None, + device: torch.device | None = None, + names: Sequence[str] | None = None, + inplace: bool = False, + **constructor_kwargs, + ) -> TensorDictBase: + ... + + @_fallback + def get( + self, key: NestedKey, default: str | CompatibleType = NO_DEFAULT + ) -> CompatibleType: + ... + + @_fallback + def __getitem__(self, index: IndexType) -> TensorDictBase: + ... + + @_replace + def to(self, dest: DeviceType | type | torch.Size, **kwargs) -> TensorDictBase: + ... + + @_replace + def cpu(self): + ... + + @_replace + def cuda(self): + ... + + def clone(self, recurse: bool = True) -> TensorDictBase: + return TensorDictParams(self._param_td.clone(recurse=recurse)) + + @_fallback + def chunk(self, chunks: int, dim: int = 0) -> tuple[TensorDictBase, ...]: + ... + + @_fallback + def unbind(self, dim: int) -> tuple[TensorDictBase, ...]: + ... + + @_fallback + def to_tensordict(self): + ... + + @_fallback + def to_h5( + self, + filename, + **kwargs, + ): + ... + + def __hash__(self): + return hash((id(self), id(self._param_td))) + + @_fallback + def __eq__(self, other: object) -> TensorDictBase: + ... + + @_fallback + def __ne__(self, other: object) -> TensorDictBase: + ... + + def __getattr__(self, item: str) -> Any: + try: + return getattr(self._param_td, item) + except AttributeError: + return super().__getattr__(item) + + @_fallback + def _change_batch_size(self, *args, **kwargs): + ... + + @_fallback + def _erase_names(self, *args, **kwargs): + ... + + # @_unlock_and_set # we need this as one sub-module could call _get_str, get a td and want to modify it + @_fallback + def _get_str(self, *args, **kwargs): + ... + + # @_unlock_and_set + @_fallback + def _get_tuple(self, *args, **kwargs): + ... + + @_fallback + def _has_names(self, *args, **kwargs): + ... + + @_unlock_and_set + def _rename_subtds(self, *args, **kwargs): + ... + + @_unlock_and_set + def _set_at_str(self, *args, **kwargs): + ... + + @_fallback + def _set_at_tuple(self, *args, **kwargs): + ... + + @_unlock_and_set + def _set_str(self, *args, **kwargs): + ... + + @_unlock_and_set + def _set_tuple(self, *args, **kwargs): + ... + + @_unlock_and_set + def _create_nested_str(self, *args, **kwargs): + ... + + @_fallback + def _stack_onto_(self, *args, **kwargs): + ... + + @_fallback_property + def batch_size(self) -> torch.Size: + ... + + @_fallback + def contiguous(self, *args, **kwargs): + ... + + @lock_blocked + @_unlock_and_set + def del_(self, *args, **kwargs): + ... + + @_fallback + def detach_(self, *args, **kwargs): + ... 
+ + @_fallback_property + def device(self): + ... + + @_fallback + def entry_class(self, *args, **kwargs): + ... + + @_fallback + def is_contiguous(self, *args, **kwargs): + ... + + @_fallback + def keys(self, *args, **kwargs): + ... + + @_fallback + def masked_fill(self, *args, **kwargs): + ... + + @_fallback + def masked_fill_(self, *args, **kwargs): + ... + + def memmap_( + self, prefix: str | None = None, copy_existing: bool = False + ) -> TensorDictBase: + raise RuntimeError("Cannot build a memmap TensorDict in-place.") + + @_fallback_property + def names(self): + ... + + @_fallback + def pin_memory(self, *args, **kwargs): + ... + + @_unlock_and_set + def select(self, *args, **kwargs): + ... + + @_fallback + def share_memory_(self, *args, **kwargs): + ... + + @property + def is_locked(self) -> bool: + # Cannot be locked + return self._is_locked + + @is_locked.setter + def is_locked(self, value): + self._is_locked = bool(value) + + @_fallback_property + def is_shared(self) -> bool: + ... + + @_fallback_property + def is_memmap(self) -> bool: + ... + + @_fallback_property + def shape(self) -> torch.Size: + ... + + @erase_cache + def _propagate_unlock(self, lock_ids=None): + if lock_ids is not None: + self._lock_id.difference_update(lock_ids) + else: + lock_ids = set() + self._is_locked = False + + unlocked_tds = [self] + lock_ids.add(id(self)) + self._locked_tensordicts = [] + + self._is_shared = False + self._is_memmap = False + return unlocked_tds + + unlock_ = TensorDict.unlock_ + lock_ = TensorDict.lock_ + + @property + def data(self): + return self.apply(lambda x: x.data) + + @_unlock_and_set(inplace=True) + def flatten_keys( + self, separator: str = ".", inplace: bool = False + ) -> TensorDictBase: + ... + + @_unlock_and_set(inplace=True) + def unflatten_keys( + self, separator: str = ".", inplace: bool = False + ) -> TensorDictBase: + ... + + @_unlock_and_set(inplace=True) + def exclude(self, *keys: str, inplace: bool = False) -> TensorDictBase: + ... + + @_carry_over + def transpose(self, dim0, dim1): + ... + + @_carry_over + def permute( + self, + *dims_list: int, + dims: list[int] | None = None, + ) -> TensorDictBase: + ... + + @_carry_over + def squeeze(self, dim: int | None = None) -> TensorDictBase: + ... + + @_carry_over + def unsqueeze(self, dim: int) -> TensorDictBase: + ... + + @_unlock_and_set + def create_nested(self, key): + ... + + +TDPARAM_HANDLED_FUNCTIONS = copy(TD_HANDLED_FUNCTIONS) + + +def implements_for_tdparam(torch_function: Callable) -> Callable[[Callable], Callable]: + """Register a torch function override for TensorDictParams.""" + + @functools.wraps(torch_function) + def decorator(func: Callable) -> Callable: + TDPARAM_HANDLED_FUNCTIONS[torch_function] = func + return func + + return decorator + + +@implements_for_tdparam(torch.empty_like) +def _empty_like(td: TensorDictBase, *args, **kwargs) -> TensorDictBase: + try: + tdclone = td.clone() + except Exception as err: + raise RuntimeError( + "The tensordict passed to torch.empty_like cannot be " + "cloned, preventing empty_like to be called. " + "Consider calling tensordict.to_tensordict() first." 
+ ) from err + return tdclone.data.apply_(lambda x: torch.empty_like(x, *args, **kwargs)) diff --git a/tensordict/persistent.py b/tensordict/persistent.py index 445eb1251..84d361760 100644 --- a/tensordict/persistent.py +++ b/tensordict/persistent.py @@ -39,6 +39,7 @@ DeviceType, expand_right, IndexType, + lock_blocked, NestedKey, NUMPY_TO_TORCH_DTYPE_DICT, ) @@ -475,6 +476,7 @@ def contiguous(self): """Materializes a PersistentTensorDict on a regular TensorDict.""" return self.to_tensordict() + @lock_blocked def del_(self, key): key = self._process_key(key) del self.file[key] diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index bf520ca04..898b70b10 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -72,8 +72,6 @@ except ImportError: from tensordict.utils import infer_size_impl -# from torch.utils._pytree import _register_pytree_node - _has_functorch = False try: @@ -782,21 +780,21 @@ def _set_tuple(self, key, value, *, inplace, validated): ... def set_at_( - self, key: NestedKey, value: CompatibleType, idx: IndexType + self, key: NestedKey, value: CompatibleType, index: IndexType ) -> TensorDictBase: """Sets the values in-place at the index indicated by :obj:`idx`. Args: key (str, tuple of str): key to be modified. value (torch.Tensor): value to be set at the index `idx` - idx (int, tensor or tuple): index where to write the values. + index (int, tensor or tuple): index where to write the values. Returns: self """ key = _unravel_key_to_tuple(key) - return self._set_at_tuple(key, value, idx, validated=False) + return self._set_at_tuple(key, value, index, validated=False) @abc.abstractmethod def _set_at_str(self, key, value, idx, *, validated): @@ -2003,6 +2001,8 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is not None and issubclass(exc_type, Exception): + return False _last_op = self._last_op_queue.pop() if _last_op is not None: last_op, (args, kwargs) = _last_op @@ -2051,6 +2051,10 @@ def __ne__(self, other: object) -> TensorDictBase: ) return True + # @abc.abstractmethod + # def __hash__(self): + # ... + def __eq__(self, other: object) -> TensorDictBase: """Compares two tensordicts against each other, for every key. The two tensordicts must have the same key set. @@ -2079,11 +2083,11 @@ def __eq__(self, other: object) -> TensorDictBase: return False @abc.abstractmethod - def del_(self, key: str) -> TensorDictBase: + def del_(self, key: NestedKey) -> TensorDictBase: """Deletes a key of the tensordict. Args: - key (str): key to be deleted + key (NestedKey): key to be deleted Returns: self @@ -3769,6 +3773,9 @@ def __init__( for key, value in source.items(): self.set(key, value) + # def __hash__(self): + # return hash((self._tensordict, self._batch_size, self._device)) + @classmethod def from_dict(cls, input_dict, batch_size=None, device=None, batch_dims=None): """Returns a TensorDict created from a dictionary or another :class:`TensorDict`. 
@@ -4146,7 +4153,8 @@ def _set_at_tuple(self, key, value, idx, *, validated): td._set_at_tuple(key[1:], value, idx, validated=validated) return self - def del_(self, key: str) -> TensorDictBase: + @lock_blocked + def del_(self, key: NestedKey) -> TensorDictBase: key = _unravel_key_to_tuple(key) if len(key) > 1: td, subkey = _get_leaf_tensordict(self, key) @@ -4526,7 +4534,8 @@ def __setstate__(self, state): for slot, value in state.items(): setattr(self, slot, value) self._cache = None - self._last_op = collections.deque() + self.__last_op_queue = None + self._last_op = None # some custom methods for efficiency def items( @@ -4805,7 +4814,6 @@ def _empty_like(td: TensorDictBase, *args, **kwargs) -> TensorDictBase: "cloned, preventing empty_like to be called. " "Consider calling tensordict.to_tensordict() first." ) from err - return tdclone.apply_(lambda x: torch.empty_like(x, *args, **kwargs)) @@ -5260,9 +5268,8 @@ def __init__( if batch_size is not None and batch_size != self.batch_size: raise RuntimeError("batch_size does not match self.batch_size.") - # @staticmethod - # def _convert_range(idx): - # return tuple(list(_idx) if isinstance(_idx, range) else _idx for _idx in idx) + # def __hash__(self): + # return hash((self._source, self.idx)) @staticmethod def _convert_ellipsis(idx, shape): @@ -5605,7 +5612,8 @@ def get_parent_tensordict(self) -> TensorDictBase: ) return self._source - def del_(self, key: str) -> TensorDictBase: + @lock_blocked + def del_(self, key: NestedKey) -> TensorDictBase: self._source = self._source.del_(key) return self @@ -7025,6 +7033,9 @@ def __getitem__(self, index: IndexType) -> TensorDictBase: # return out[td_index] # return out + # def __hash__(self): + # return hash(self.tensordicts) + def __eq__(self, other): if is_tensorclass(other): return other == self @@ -7155,7 +7166,8 @@ def _irecv( future.wait() return - def del_(self, key: str, **kwargs: Any) -> TensorDictBase: + @lock_blocked + def del_(self, key: NestedKey, **kwargs: Any) -> TensorDictBase: ids = set() cur_len = len(ids) is_deleted = False @@ -7605,6 +7617,9 @@ def __init__( if batch_size is not None and batch_size != self.batch_size: raise RuntimeError("batch_size does not match self.batch_size.") + # def __hash__(self): + # return hash((self._source, self.custom_op, self.inv_op, self.custom_op_kwargs, self.inv_op_kwargs)) + def _update_custom_op_kwargs(self, source_tensor: Tensor) -> dict[str, Any]: """Allows for a transformation to be customized for a certain shape, device or dtype. 
@@ -7713,13 +7728,14 @@ def _set_tuple(self, key, value, *, inplace: bool, validated: bool): if source is None: self._source._create_nested_str(key[0]) source = self._source._get_str(key[0], NO_DEFAULT) - type(self)( + nested = type(self)( source, custom_op=self.custom_op, inv_op=self.inv_op, custom_op_kwargs=self._update_custom_op_kwargs(source), inv_op_kwargs=self._update_inv_op_kwargs(source), - )._set_tuple(key[1:], value, inplace=inplace, validated=validated) + ) + nested._set_tuple(key[1:], value, inplace=inplace, validated=validated) return self def _set_at_str(self, key, value, idx, *, validated): @@ -7837,7 +7853,8 @@ def rename_key_( rename_key = _renamed_inplace_method(rename_key_) - def del_(self, key: str) -> _CustomOpTensorDict: + @lock_blocked + def del_(self, key: NestedKey) -> _CustomOpTensorDict: self._source = self._source.del_(key) return self diff --git a/test/_utils_internal.py b/test/_utils_internal.py index ed64284e9..8f266a119 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -9,6 +9,7 @@ import torch from tensordict import PersistentTensorDict, tensorclass, TensorDict +from tensordict.nn.params import TensorDictParams from tensordict.tensordict import _stack as stack_td @@ -224,6 +225,9 @@ def td_h5( ) return td_h5 + def td_params(self, device): + return TensorDictParams(self.td(device)) + def expand_list(list_of_tensors, *dims): n = len(list_of_tensors) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 57f9d4d26..19fcf6ca0 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -523,6 +523,7 @@ def test_convert_ellipsis_to_idx_invalid(ellipsis_index, expectation): "nested_tensorclass", "permute_td", "nested_stacked_td", + "td_params", pytest.param( "td_h5", marks=pytest.mark.skipif(not _has_h5py, reason="h5py not found.") ), @@ -539,8 +540,29 @@ def test_permute_applied_twice(self, td_name, device): other_p = inv_p while (other_p == inv_p).all(): other_p = torch.randperm(4) - assert tensordict.permute(*p).permute(*inv_p) is tensordict - assert tensordict.permute(*p).permute(*other_p) is not tensordict + other_p = tuple(other_p.tolist()) + p = tuple(p.tolist()) + inv_p = tuple(inv_p.tolist()) + if td_name in ("td_params",): + assert ( + tensordict.permute(*p).permute(*inv_p)._param_td + is tensordict._param_td + ) + assert ( + tensordict.permute(*p).permute(*other_p)._param_td + is not tensordict._param_td + ) + assert torch.permute(tensordict, p).permute( + inv_p + )._param_td is tensordict._param_td + assert torch.permute(tensordict, p).permute( + other_p + )._param_td is not tensordict._param_td + else: + assert tensordict.permute(*p).permute(*inv_p) is tensordict + assert tensordict.permute(*p).permute(*other_p) is not tensordict + assert torch.permute(tensordict, p).permute(inv_p) is tensordict + assert torch.permute(tensordict, p).permute(other_p) is not tensordict def test_to_tensordict(self, td_name, device): torch.manual_seed(1) @@ -614,7 +636,8 @@ def test_exclude(self, td_name, device): and "a" not in td2.clone().keys() ) - td2 = td.exclude("a", inplace=True) + with td.unlock_(): + td2 = td.exclude("a", inplace=True) assert td2 is td def test_assert(self, td_name, device): @@ -656,19 +679,29 @@ def test_broadcast(self, td_name, device): sub_td = td[:, :2].to_tensordict() sub_td.zero_() sub_dict = sub_td.to_dict() - td[:, :2] = sub_dict + if td_name == "td_params": + td_set = td.data + else: + td_set = td + td_set[:, :2] = sub_dict assert (td[:, :2] == 0).all() @pytest.mark.parametrize("call_del", 
[True, False]) def test_remove(self, td_name, device, call_del): torch.manual_seed(1) td = getattr(self, td_name)(device) - if call_del: - del td["a"] - else: - td = td.del_("a") + with td.unlock_(): + if call_del: + del td["a"] + else: + td = td.del_("a") assert td is not None assert "a" not in td.keys() + if td_name in ("sub_td", "sub_td2"): + return + td.lock_() + with pytest.raises(RuntimeError, match="locked"): + del td["b"] def test_set_unexisting(self, td_name, device): torch.manual_seed(1) @@ -686,9 +719,17 @@ def test_set_unexisting(self, td_name, device): def test_fill_(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) - new_td = td.fill_("a", 0.1) + if td_name == "td_params": + td_set = td.data + else: + td_set = td + new_td = td_set.fill_("a", 0.1) assert (td.get("a") == 0.1).all() - assert new_td is td + assert new_td is td_set + + def test_shape(self, td_name, device): + td = getattr(self, td_name)(device) + assert td.shape == td.batch_size def test_flatten_unflatten(self, td_name, device): td = getattr(self, td_name)(device) @@ -710,8 +751,12 @@ def test_masked_fill_(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) mask = torch.zeros(td.shape, dtype=torch.bool, device=device).bernoulli_() - new_td = td.masked_fill_(mask, -10.0) - assert new_td is td + if td_name == "td_params": + td_set = td.data + else: + td_set = td + new_td = td_set.masked_fill_(mask, -10.0) + assert new_td is td_set for item in td.values(): assert (item[mask] == -10).all(), item[mask] @@ -778,7 +823,11 @@ def test_lock_write(self, td_name, device): for key, item in td_clone.items(True): with pytest.raises(RuntimeError, match="Cannot modify locked TensorDict"): td.set(key, item) - td.set_(key, item) + if td_name == "td_params": + td_set = td.data + else: + td_set = td + td_set.set_(key, item) def test_unlock(self, td_name, device): torch.manual_seed(1) @@ -846,7 +895,8 @@ def test_cache(self, td_name, device, op): b = td.unflatten_keys() assert a is b - assert len(td._cache) + if td_name != "td_params": + assert len(td._cache) td.unlock_() assert td._cache is None for val in td.values(True): @@ -920,6 +970,12 @@ def test_masked_fill(self, td_name, device): def test_zero_(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) + if td_name == "td_params": + with pytest.raises( + RuntimeError, match="a leaf Variable that requires grad" + ): + new_td = td.zero_() + return new_td = td.zero_() assert new_td is td for k in td.keys(): @@ -929,6 +985,10 @@ def test_zero_(self, td_name, device): def test_apply(self, td_name, device, inplace): td = getattr(self, td_name)(device) td_c = td.to_tensordict() + if inplace and td_name == "td_params": + with pytest.raises(ValueError, match="Failed to update"): + td.apply(lambda x: x + 1, inplace=inplace) + return td_1 = td.apply(lambda x: x + 1, inplace=inplace) if inplace: for key in td.keys(True, True): @@ -943,7 +1003,11 @@ def test_apply(self, td_name, device, inplace): def test_apply_other(self, td_name, device, inplace): td = getattr(self, td_name)(device) td_c = td.to_tensordict() - td_1 = td.apply(lambda x, y: x + y, td_c, inplace=inplace) + if inplace and td_name == "td_params": + td_set = td.data + else: + td_set = td + td_1 = td_set.apply(lambda x, y: x + y, td_c, inplace=inplace) if inplace: for key in td.keys(True, True): assert (td_c[key] * 2 == td[key]).all() @@ -999,9 +1063,18 @@ def test_equal(self, td_name, device): def test_equal_float(self, td_name, device): 
torch.manual_seed(1) td = getattr(self, td_name)(device) - td.zero_() + if td_name == "td_params": + td_set = td.data + else: + td_set = td + td_set.zero_() assert (td == 0.0).all() - td0 = td.clone().zero_() + td0 = td.clone() + if td_name == "td_params": + td_set = td0.data + else: + td_set = td0 + td_set.zero_() assert (td0 != 1.0).all() def test_equal_other(self, td_name, device): @@ -1012,7 +1085,11 @@ def test_equal_other(self, td_name, device): def test_equal_int(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) - td.zero_() + if td_name == "td_params": + td_set = td.data + else: + td_set = td + td_set.zero_() assert (td == 0).all() td0 = td.to_tensordict().zero_() assert (td0 != 1).all() @@ -1020,7 +1097,11 @@ def test_equal_int(self, td_name, device): def test_equal_tensor(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) - td.zero_() + if td_name == "td_params": + td_set = td.data + else: + td_set = td + td_set.zero_() assert (td == torch.zeros([], dtype=torch.int, device=device)).all() td0 = td.to_tensordict().zero_() assert (td0 != torch.ones([], dtype=torch.int, device=device)).all() @@ -1046,6 +1127,12 @@ def test_gather(self, td_name, device, dim): # gather with out td_gather.zero_() out = td_gather.clone() + if td_name == "td_params": + with pytest.raises( + RuntimeError, match="don't support automatic differentiation" + ): + torch.gather(td, dim=dim, index=index, out=out) + return td_gather2 = torch.gather(td, dim=dim, index=index, out=out) assert (td_gather2 != 0).any() @@ -1063,7 +1150,13 @@ def test_masking_set(self, td_name, device): ), batch_size=[n, *td.batch_size[d:]], ) - td[mask] = pseudo_td + + if td_name == "td_params": + td_set = td.data + else: + td_set = td + + td_set[mask] = pseudo_td for item in td.values(): assert (item[mask] == 0).all() @@ -1218,6 +1311,10 @@ def test_squeeze_with_none(self, td_name, device, squeeze_dim=None): td = getattr(self, td_name)(device) td_squeeze = torch.squeeze(td, dim=None) tensor = torch.ones_like(td.get("a").squeeze()) + if td_name == "td_params": + with pytest.raises(ValueError, match="Failed to update"): + td_squeeze.set_("a", tensor) + return td_squeeze.set_("a", tensor) assert (td_squeeze.get("a") == tensor).all() if td_name == "unsqueezed_td": @@ -1267,6 +1364,7 @@ def test_exclude_nested(self, td_name, device, nested): "unsqueezed_td", "squeezed_td", "permute_td", + "td_params", ): # TODO: document this as an edge-case: with a sub-tensordict, exclude acts on the parent tensordict # perhaps exclude should return an error in these cases? 
@@ -1320,12 +1418,21 @@ def test_update(self, td_name, device, clone): def test_update_at_(self, td_name, device): td = getattr(self, td_name)(device) td0 = td[1].clone().zero_() + if td_name == "td_params": + with pytest.raises(RuntimeError, match="a view of a leaf Variable"): + td.update_at_(td0, 0) + return td.update_at_(td0, 0) assert (td[0] == 0).all() def test_write_on_subtd(self, td_name, device): td = getattr(self, td_name)(device) sub_td = td.get_sub_tensordict(0) + # should not work with td_params + if td_name == "td_params": + with pytest.raises(RuntimeError, match="a view of a leaf"): + sub_td["a"] = torch.full((3, 2, 1, 5), 1.0, device=device) + return sub_td["a"] = torch.full((3, 2, 1, 5), 1.0, device=device) assert (td["a"][0] == 1).all() @@ -1541,7 +1648,10 @@ def test_rename_key(self, td_name, device) -> None: torch.testing.assert_close(new_z, td.get("z")) new_z = torch.randn_like(z) - td.set_("z", new_z) + if td_name == "td_params": + td.data.set_("z", new_z) + else: + td.set_("z", new_z) torch.testing.assert_close(new_z, td.get("z")) def test_rename_key_nested(self, td_name, device) -> None: @@ -1598,11 +1708,22 @@ def test_setitem_ellipsis(self, td_name, device, actual_index): idx = actual_index td_clone = td.clone() - actual_td = td_clone[idx].clone().zero_() + actual_td = td_clone[idx].clone() + if td_name in ("td_params",): + td_set = actual_td.apply(lambda x: x.data) + else: + td_set = actual_td + td_set.zero_() for key in actual_td.keys(): assert (actual_td.get(key) == 0).all() - td_clone[idx] = actual_td + + if td_name in ("td_params",): + td_set = td_clone.data + else: + td_set = td_clone + + td_set[idx] = actual_td for key in td_clone.keys(): assert (td_clone[idx].get(key) == 0).all() @@ -1617,7 +1738,10 @@ def test_setitem(self, td_name, device, idx): return td_clone = td[idx].to_tensordict().zero_() - td[idx] = td_clone + if td_name == "td_params": + td.data[idx] = td_clone + else: + td[idx] = td_clone assert (td[idx].get("a") == 0).all() td_clone = torch.cat([td_clone, td_clone], 0) @@ -1698,11 +1822,17 @@ def test_transpose(self, td_name, device): tdt = td.transpose(-1, -2) for key, value in tdt.items(True): assert value.shape == td.get(key).transpose(2, 3).shape - assert tdt.transpose(-1, -2) is td + if td_name in ("td_params",): + assert tdt.transpose(-1, -2)._param_td is td._param_td + else: + assert tdt.transpose(-1, -2) is td with td.unlock_(): tdt.set(("some", "transposed", "tensor"), torch.zeros(tdt.shape)) assert td.get(("some", "transposed", "tensor")).shape == td.shape - assert td.transpose(0, 0) is td + if td_name in ("td_params",): + assert td.transpose(0, 0)._param_td is td._param_td + else: + assert td.transpose(0, 0) is td with pytest.raises( ValueError, match="The provided dimensions are incompatible" ): @@ -1774,7 +1904,10 @@ def test_tensordict_set_dict_value(self, td_name, device): # test set_ val2 = {"subkey1": torch.zeros(4, 3, 2, 1, 10)} - td.set_("key1", val2) + if td_name in ("td_params",): + td.data.set_("key1", val2) + else: + td.set_("key1", val2) assert (td.get("key1").get("subkey1") == 0).all() if td_name not in ("stacked_td", "nested_stacked_td"): @@ -1790,6 +1923,10 @@ def test_tensordict_set_dict_value(self, td_name, device): def test_delitem(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) + if td_name in ("memmap_td",): + with pytest.raises(RuntimeError, match="Cannot modify"): + del td["a"] + return del td["a"] assert "a" not in td.keys() @@ -1844,6 +1981,10 @@ def 
test_update_subtensordict(self, td_name, device, index): td0 = td0.to_tensordict() td0 = td0.apply(lambda x: x * 0 + 2) assert sub_td.shape == td0.shape + if td_name == "td_params": + with pytest.raises(RuntimeError, match="a leaf Variable"): + sub_td.update(td0) + return sub_td.update(td0) assert (sub_td == 2).all() assert (td[index] == 2).all() @@ -1856,10 +1997,23 @@ def test_stack_onto(self, td_name, device, tmpdir): td0 = td.clone(newfile=tmpdir / "file0.h5").apply_(lambda x: x.zero_()) td1 = td.clone(newfile=tmpdir / "file1.h5").apply_(lambda x: x.zero_() + 1) else: - td0 = td.clone().apply_(lambda x: x.zero_()) - td1 = td.clone().apply_(lambda x: x.zero_() + 1) + td0 = td.clone() + if td_name in ("td_params",): + td0.data.apply_(lambda x: x.zero_()) + else: + td0.apply_(lambda x: x.zero_()) + td1 = td.clone() + if td_name in ("td_params",): + td1.data.apply_(lambda x: x.zero_() + 1) + else: + td1.apply_(lambda x: x.zero_() + 1) + td_out = td.unsqueeze(1).expand(td.shape[0], 2, *td.shape[1:]).clone() td_stack = torch.stack([td0, td1], 1) + if td_name == "td_params": + with pytest.raises(RuntimeError, match="out.batch_size and stacked"): + torch.stack([td0, td1], 0, out=td_out) + return torch.stack([td0, td1], 1, out=td_out) assert (td_stack == td_out).all() @@ -1885,6 +2039,10 @@ def test_stack_tds_on_subclass(self, td_name, device): with pytest.raises(IndexError, match="storages of the indexed tensors"): torch.stack(tds_list, 0, out=td) return + if td_name == "td_params": + with pytest.raises(RuntimeError, match="arguments don't support automatic"): + torch.stack(tds_list, 0, out=td) + return stacked_td = torch.stack(tds_list, 0, out=td) assert stacked_td.batch_size == td.batch_size assert stacked_td is td @@ -1897,6 +2055,10 @@ def test_stack_subclasses_on_td(self, td_name, device): td = getattr(self, td_name)(device) td = td.expand(3, *td.batch_size).clone().zero_() tds_list = [getattr(self, td_name)(device) for _ in range(3)] + if td_name == "td_params": + with pytest.raises(RuntimeError, match="arguments don't support automatic"): + torch.stack(tds_list, 0, out=td) + return stacked_td = stack_td(tds_list, 0, out=td) assert stacked_td.batch_size == td.batch_size for key in ("a", "b", "c"): @@ -1991,6 +2153,8 @@ def test_items_values_keys(self, td_name, device): def test_set_requires_grad(self, td_name, device): td = getattr(self, td_name)(device) + if td_name in ("td_params",): + td.apply(lambda x: x.requires_grad_(False)) td.unlock_() assert not td.get("a").requires_grad if td_name in ("td_h5",): @@ -2167,7 +2331,7 @@ def test_memmap_(self, td_name, device): match="Converting a sub-tensordict values to memmap cannot be done", ): td.memmap_() - elif td_name in ("td_h5",): + elif td_name in ("td_h5", "td_params"): with pytest.raises( RuntimeError, match="Cannot build a memmap TensorDict in-place", @@ -2199,7 +2363,7 @@ def test_memmap_prefix(self, td_name, device, tmp_path): ): td.memmap_(tmp_path / "tensordict") return - elif td_name in ("td_h5",): + elif td_name in ("td_h5", "td_params"): with pytest.raises( RuntimeError, match="Cannot build a memmap TensorDict in-place", @@ -2229,7 +2393,7 @@ def test_memmap_existing(self, td_name, device, copy_existing, tmp_path): pytest.skip( "Memmap case is redundant, functionality checked by other cases" ) - elif td_name in ("sub_td", "sub_td2", "td_h5"): + elif td_name in ("sub_td", "sub_td2", "td_h5", "td_params"): pytest.skip( "SubTensorDict/H5 and memmap_ incompatibility is checked elsewhere" ) @@ -2395,58 +2559,91 @@ def test_pop(self, 
td_name, device): td = getattr(self, td_name)(device) assert "a" in td.keys() a = td["a"].clone() - out = td.pop("a") - assert (out == a).all() - assert "a" not in td.keys() + with td.unlock_(): + out = td.pop("a") + assert (out == a).all() + assert "a" not in td.keys() - assert "b" in td.keys() - b = td["b"].clone() - default = torch.zeros_like(b).to(device) - assert (default != b).all() - out = td.pop("b", default) + assert "b" in td.keys() + b = td["b"].clone() + default = torch.zeros_like(b).to(device) + assert (default != b).all() + out = td.pop("b", default) - assert torch.ne(out, default).all() - assert (out == b).all() + assert torch.ne(out, default).all() + assert (out == b).all() - assert "z" not in td.keys() - out = td.pop("z", default) - assert (out == default).all() + assert "z" not in td.keys() + out = td.pop("z", default) + assert (out == default).all() - with pytest.raises( - KeyError, - match=re.escape(r"You are trying to pop key"), - ): - td.pop("z") + with pytest.raises( + KeyError, + match=re.escape(r"You are trying to pop key"), + ): + td.pop("z") def test_setitem_slice(self, td_name, device): td = getattr(self, td_name)(device) - td[:] = td.clone() - td[:1] = td[:1].clone().zero_() + if td_name == "td_params": + td_set = td.data + else: + td_set = td + td_set[:] = td.clone() + td_set[:1] = td[:1].clone().zero_() assert (td[:1] == 0).all() td = getattr(self, td_name)(device) - td[:1] = td[:1].to_tensordict().zero_() + if td_name == "td_params": + td_set = td.data + else: + td_set = td + td_set[:1] = td[:1].to_tensordict().zero_() assert (td[:1] == 0).all() # with broadcast td = getattr(self, td_name)(device) - td[:1] = td[0].clone().zero_() + if td_name == "td_params": + td_set = td.data + else: + td_set = td + td_set[:1] = td[0].clone().zero_() assert (td[:1] == 0).all() td = getattr(self, td_name)(device) - td[:1] = td[0].to_tensordict().zero_() + if td_name == "td_params": + td_set = td.data + else: + td_set = td + td_set[:1] = td[0].to_tensordict().zero_() assert (td[:1] == 0).all() td = getattr(self, td_name)(device) - td[:1, 0] = td[0, 0].clone().zero_() + if td_name == "td_params": + td_set = td.data + else: + td_set = td + td_set[:1, 0] = td[0, 0].clone().zero_() assert (td[:1, 0] == 0).all() td = getattr(self, td_name)(device) - td[:1, 0] = td[0, 0].to_tensordict().zero_() + if td_name == "td_params": + td_set = td.data + else: + td_set = td + td_set[:1, 0] = td[0, 0].to_tensordict().zero_() assert (td[:1, 0] == 0).all() td = getattr(self, td_name)(device) - td[:1, :, 0] = td[0, :, 0].clone().zero_() + if td_name == "td_params": + td_set = td.data + else: + td_set = td + td_set[:1, :, 0] = td[0, :, 0].clone().zero_() assert (td[:1, :, 0] == 0).all() td = getattr(self, td_name)(device) - td[:1, :, 0] = td[0, :, 0].to_tensordict().zero_() + if td_name == "td_params": + td_set = td.data + else: + td_set = td + td_set[:1, :, 0] = td[0, :, 0].to_tensordict().zero_() assert (td[:1, :, 0] == 0).all() def test_casts(self, td_name, device): @@ -2475,6 +2672,11 @@ def test_empty_like(self, td_name, device): # we do not call skip to avoid systematic skips in internal code base return td_empty = torch.empty_like(td) + if td_name == "td_params": + with pytest.raises(ValueError, match="Failed to update"): + td.apply_(lambda x: x + 1.0) + return + td.apply_(lambda x: x + 1.0) assert type(td) is type(td_empty) assert all(val.any() for val in (td != td_empty).values(True, True)) @@ -2495,6 +2697,12 @@ def test_add_batch_dim_cache(self, td_name, device, nested): fun(td) return 
fun(td) + + if td_name == "td_params": + with pytest.raises(RuntimeError, match="leaf Variable that requires grad"): + td.zero_() + return + td.zero_() # this value should be cached std = fun(td) From bf42255b51420f925da6bcce78a355d89356fd60 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 26 Jul 2023 20:51:27 +0100 Subject: [PATCH 2/4] amend --- tensordict/nn/params.py | 6 ++-- test/test_nn.py | 68 ++++++++++++++++++++--------------------- test/test_tensordict.py | 14 +++++---- 3 files changed, 45 insertions(+), 43 deletions(-) diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 068f97a5e..64868ce1c 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -263,15 +263,15 @@ def get( def __getitem__(self, index: IndexType) -> TensorDictBase: ... - @_replace + @_carry_over def to(self, dest: DeviceType | type | torch.Size, **kwargs) -> TensorDictBase: ... - @_replace + @_carry_over def cpu(self): ... - @_replace + @_carry_over def cuda(self): ... diff --git a/test/test_nn.py b/test/test_nn.py index 067fde1f4..20dde26e7 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -1078,15 +1078,15 @@ def test_functional(self): assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 tdmodule[1] = tdmodule2 - params.unlock_() - params["module", "1"] = params["module", "2"] - params.lock_() + with params.unlock_(): + params["module", "1"] = params["module", "2"] assert len(tdmodule) == 3 assert hasattr(tdmodule, "__delitem__") assert len(tdmodule) == 3 del tdmodule[2] - del params["module", "2"] + with params.unlock_(): + del params["module", "2"] assert len(tdmodule) == 2 assert hasattr(tdmodule, "__getitem__") @@ -1159,16 +1159,16 @@ def test_functional_probabilistic_deprec(self): assert len(tdmodule) == 4 tdmodule[1] = tdmodule2 tdmodule[2] = prob_module - params.unlock_() - params["module", "1"] = params["module", "2"] - params["module", "2"] = params["module", "3"] - params.lock_() + with params.unlock_(): + params["module", "1"] = params["module", "2"] + params["module", "2"] = params["module", "3"] assert len(tdmodule) == 4 assert hasattr(tdmodule, "__delitem__") assert len(tdmodule) == 4 del tdmodule[3] - del params["module", "3"] + with params.unlock_(): + del params["module", "3"] assert len(tdmodule) == 3 assert hasattr(tdmodule.module, "__getitem__") @@ -1221,17 +1221,17 @@ def test_functional_probabilistic(self): tdmodule[1] = tdmodule2 tdmodule[2] = normal_params tdmodule[3] = prob_module - params.unlock_() - params["module", "1"] = params["module", "2"] - params["module", "2"] = params["module", "3"] - params["module", "3"] = params["module", "4"] - params.lock_() + with params.unlock_(): + params["module", "1"] = params["module", "2"] + params["module", "2"] = params["module", "3"] + params["module", "3"] = params["module", "4"] assert len(tdmodule) == 5 assert hasattr(tdmodule, "__delitem__") assert len(tdmodule) == 5 del tdmodule[4] - del params["module", "4"] + with params.unlock_(): + del params["module", "4"] assert len(tdmodule) == 4 assert hasattr(tdmodule.module, "__getitem__") @@ -1273,15 +1273,15 @@ def test_functional_with_buffer(self): assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 tdmodule[1] = tdmodule2 - params.unlock_() - params["module", "1"] = params["module", "2"] - params.lock_() + with params.unlock_(): + params["module", "1"] = params["module", "2"] assert len(tdmodule) == 3 assert hasattr(tdmodule, "__delitem__") assert len(tdmodule) == 3 del tdmodule[2] - del params["module", "2"] + with 
params.unlock_(): + del params["module", "2"] assert len(tdmodule) == 2 assert hasattr(tdmodule, "__getitem__") @@ -1333,16 +1333,16 @@ def test_functional_with_buffer_probabilistic_deprec(self): assert len(tdmodule.module) == 4 tdmodule[1] = tdmodule2 tdmodule[2] = prob_module - params.unlock_() - params["module", "1"] = params["module", "2"] - params["module", "2"] = params["module", "3"] - params.lock_() + with params.unlock_(): + params["module", "1"] = params["module", "2"] + params["module", "2"] = params["module", "3"] assert len(tdmodule) == 4 assert hasattr(tdmodule.module, "__delitem__") assert len(tdmodule.module) == 4 del tdmodule.module[3] - del params["module", "3"] + with params.unlock_(): + del params["module", "3"] assert len(tdmodule.module) == 3 assert hasattr(tdmodule.module, "__getitem__") @@ -1399,17 +1399,17 @@ def test_functional_with_buffer_probabilistic(self): tdmodule[1] = tdmodule2 tdmodule[2] = normal_params tdmodule[3] = prob_module - params.unlock_() - params["module", "1"] = params["module", "2"] - params["module", "2"] = params["module", "3"] - params["module", "3"] = params["module", "4"] - params.lock_() + with params.unlock_(): + params["module", "1"] = params["module", "2"] + params["module", "2"] = params["module", "3"] + params["module", "3"] = params["module", "4"] assert len(tdmodule) == 5 assert hasattr(tdmodule.module, "__delitem__") assert len(tdmodule.module) == 5 del tdmodule.module[4] - del params["module", "4"] + with params.unlock_(): + del params["module", "4"] assert len(tdmodule.module) == 4 assert hasattr(tdmodule.module, "__getitem__") @@ -1450,15 +1450,15 @@ def test_vmap(self): assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 tdmodule[1] = tdmodule2 - params.unlock_() - params["module", "1"] = params["module", "2"] - params.lock_() + with params.unlock_(): + params["module", "1"] = params["module", "2"] assert len(tdmodule) == 3 assert hasattr(tdmodule, "__delitem__") assert len(tdmodule) == 3 del tdmodule[2] - del params["module", "2"] + with params.unlock_(): + del params["module", "2"] assert len(tdmodule) == 2 assert hasattr(tdmodule, "__getitem__") diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 19fcf6ca0..f9dec0f59 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -552,12 +552,14 @@ def test_permute_applied_twice(self, td_name, device): tensordict.permute(*p).permute(*other_p)._param_td is not tensordict._param_td ) - assert torch.permute(tensordict, p).permute( - inv_p - )._param_td is tensordict._param_td - assert torch.permute(tensordict, p).permute( - other_p - )._param_td is not tensordict._param_td + assert ( + torch.permute(tensordict, p).permute(inv_p)._param_td + is tensordict._param_td + ) + assert ( + torch.permute(tensordict, p).permute(other_p)._param_td + is not tensordict._param_td + ) else: assert tensordict.permute(*p).permute(*inv_p) is tensordict assert tensordict.permute(*p).permute(*other_p) is not tensordict From 73234d7163aaccfd55276ebe0027301ddb18f837 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 26 Jul 2023 21:14:20 +0100 Subject: [PATCH 3/4] amend --- docs/source/reference/nn.rst | 1 + tensordict/nn/__init__.py | 1 + tensordict/nn/params.py | 54 ++++++++++++++++++++++++++++++++++-- test/test_nn.py | 46 ++++++++++++++++++++++++++++++ 4 files changed, 100 insertions(+), 2 deletions(-) diff --git a/docs/source/reference/nn.rst b/docs/source/reference/nn.rst index 27dc10a35..be0e6560d 100644 --- a/docs/source/reference/nn.rst +++ 
b/docs/source/reference/nn.rst @@ -334,3 +334,4 @@ Utils biased_softplus set_skip_existing skip_existing + TensorDictParams diff --git a/tensordict/nn/__init__.py b/tensordict/nn/__init__.py index 575d14567..66d18eff3 100644 --- a/tensordict/nn/__init__.py +++ b/tensordict/nn/__init__.py @@ -50,5 +50,6 @@ "make_tensordict", "biased_softplus", "inv_softplus", + "TensorDictParams", "is_functional", ] diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 64868ce1c..296e69fb9 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -141,13 +141,60 @@ class TensorDictParams(TensorDictBase, nn.Module): the leaves of the tensordict. Indexing works exactly as the indexing of the wrapped tensordict. - TODO: Parameter names + The parameter names will be registered within this module using :meth:`~.TensorDict.flatten_keys("_")`. + Therefore, the result of :meth:`~.named_parameters()` and the content of the + tensordict will differ slightly in term of key names. Any operation that sets a tensor in the tensordict will be augmented by a :class:`torch.nn.Parameter` conversion. + + Args: + parameters (TensorDictBase): a tensordict to represent as parameters. + Values will be converted to parameters unless ``no_convert=True``. + + Keyword Args: + no_convert (bool): if ``True``, no conversion to ``nn.Parameter`` will occur. + Defaults to ``False``. + + Examples: + >>> from torch import nn + >>> from tensordict import TensorDict + >>> module = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 4)) + >>> params = TensorDict.from_module(module) + >>> params.lock_() + >>> p = TensorDictParams(params) + >>> print(p) + TensorDictParams(params=TensorDict( + fields={ + 0: TensorDict( + fields={ + bias: Parameter(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False), + weight: Parameter(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + 1: TensorDict( + fields={ + bias: Parameter(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False), + weight: Parameter(shape=torch.Size([4, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False)) + >>> class CustomModule(nn.Module): + ... def __init__(self, params): + ... super().__init__() + ... self.params = params + >>> m = CustomModule(p) + >>> # the wrapper supports assignment and values are turned in Parameter + >>> m.params['other'] = torch.randn(3) + >>> assert isinstance(m.params['other'], nn.Parameter) + """ - def __init__(self, parameters: TensorDictBase, no_convert=False): + def __init__(self, parameters: TensorDictBase, *, no_convert=False): super().__init__() self._param_td = parameters if not no_convert: @@ -511,6 +558,9 @@ def unsqueeze(self, dim: int) -> TensorDictBase: def create_nested(self, key): ... 
+ def __repr__(self): + return f"TensorDictParams(params={self._param_td})" + TDPARAM_HANDLED_FUNCTIONS = copy(TD_HANDLED_FUNCTIONS) diff --git a/test/test_nn.py b/test/test_nn.py index 20dde26e7..efae6ce83 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -20,6 +20,7 @@ ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential, TensorDictModuleBase, + TensorDictParams, TensorDictSequential, ) from tensordict.nn.common import TensorDictModule, TensorDictModuleWrapper @@ -2783,6 +2784,51 @@ def test_reset_once(self): ), f"Reset parameters called {lin.reset_parameters.call_count} times should be 2" +class TestTensorDictParams: + def _get_params(self): + module = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 4)) + params = TensorDict.from_module(module) + params.lock_() + return params + + class CustomModule(nn.Module): + def __init__(self, params): + super().__init__() + self.params = params + + def test_td_params(self): + params = self._get_params() + p = TensorDictParams(params) + m = self.CustomModule(p) + assert ( + TensorDict(dict(m.named_parameters()), []) + == TensorDict({"params": params.flatten_keys("_")}, []).flatten_keys(".") + ).all() + + assert not m.params.is_locked + assert m.params._param_td.is_locked + + assert ( + m.params["0", "weight"] is not None + ) # assess that param can be accessed via nested indexing + + # assert assignment + m.params["other"] = torch.randn(3) + assert isinstance(m.params["other"], nn.Parameter) + assert m.params["other"].requires_grad + + # change that locking is unchanged + assert not m.params.is_locked + assert m.params._param_td.is_locked + + assert m.params.other.requires_grad + del m.params["other"] + + assert m.params["0", "weight"].requires_grad + assert (m.params == params).all() + assert (params == m.params).all() + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) From 07284d7149d69b7ece7c9526082c1070606a4b61 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 27 Jul 2023 08:23:50 +0100 Subject: [PATCH 4/4] amend --- tensordict/nn/params.py | 20 +++++++++++++------- test/test_nn.py | 9 +++++++++ 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 296e69fb9..f77ae6c50 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -310,17 +310,23 @@ def get( def __getitem__(self, index: IndexType) -> TensorDictBase: ... - @_carry_over def to(self, dest: DeviceType | type | torch.Size, **kwargs) -> TensorDictBase: - ... + params = self._param_td.to(dest) + if params is self._param_td: + return self + return TensorDictParams(params) - @_carry_over def cpu(self): - ... + params = self._param_td.cpu() + if params is self._param_td: + return self + return TensorDictParams(params) - @_carry_over - def cuda(self): - ... 
+ def cuda(self, device=None): + params = self._param_td.cuda(device=device) + if params is self._param_td: + return self + return TensorDictParams(params) def clone(self, recurse: bool = True) -> TensorDictBase: return TensorDictParams(self._param_td.clone(recurse=recurse)) diff --git a/test/test_nn.py b/test/test_nn.py index efae6ce83..06e721c25 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2828,6 +2828,15 @@ def test_td_params(self): assert (m.params == params).all() assert (params == m.params).all() + def test_td_params_cast(self): + params = self._get_params() + p = TensorDictParams(params) + m = self.CustomModule(p) + for dtype in ("half", "double", "float"): + getattr(m, dtype)() + for p in params.values(True, True): + assert p.dtype == getattr(torch, dtype) + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args()
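
Usage note appended for readers of this series (not part of the patches themselves): a minimal sketch of the TensorDictParams API introduced above, assembled from the class docstring and the tests added in these commits. CustomModule is only an illustrative container name taken from that docstring; everything else follows the patched code.

    import torch
    from torch import nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictParams

    module = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 4))
    params = TensorDict.from_module(module)
    params.lock_()

    # Wrap the parameter tensordict; floating-point leaves are converted to nn.Parameter.
    p = TensorDictParams(params)

    class CustomModule(nn.Module):
        def __init__(self, params):
            super().__init__()
            # registering the wrapper exposes its leaves to named_parameters()
            self.params = params

    m = CustomModule(p)
    # parameter names are flattened with "_" and prefixed by the attribute name
    assert any(name.startswith("params.") for name, _ in m.named_parameters())

    # setting a value converts it to a Parameter on the fly
    m.params["other"] = torch.randn(3)
    assert isinstance(m.params["other"], nn.Parameter)

    # in-place ops that must bypass autograd go through the .data view,
    # as the tests above do for zero_(), set_() and apply_()
    m.params.data.zero_()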