From d2edd94bd7b03b508d096fa15d961064d704f638 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 2 Jun 2023 14:40:08 +0100 Subject: [PATCH] [Feature] Unravel nested keys (#403) --- tensordict/persistent.py | 2 ++ tensordict/tensordict.py | 16 +++++++++++---- tensordict/utils.py | 42 +++++++++++++++++++++++++++++++++++----- test/test_tensordict.py | 14 ++++++++++++++ 4 files changed, 65 insertions(+), 9 deletions(-) diff --git a/tensordict/persistent.py b/tensordict/persistent.py index 7e9a28a9a..d19b71ed5 100644 --- a/tensordict/persistent.py +++ b/tensordict/persistent.py @@ -33,6 +33,7 @@ TensorDictBase, ) from tensordict.utils import ( + _maybe_unravel_keys_silent, _shape, DeviceType, expand_right, @@ -363,6 +364,7 @@ def __getitem__(self, item): __getitems__ = __getitem__ def __setitem__(self, index, value): + index = _maybe_unravel_keys_silent(index) if isinstance(index, str) or ( isinstance(index, tuple) and all(isinstance(val, str) for val in index) ): diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index 2b4f95946..a2c1962f7 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -40,7 +40,7 @@ _get_item, _getitem_batch_size, _is_shared, - _nested_key_check, + _maybe_unravel_keys_silent, _set_item, _shape, _sub_index, @@ -52,6 +52,7 @@ int_generator, NestedKey, prod, + unravel_keys, ) from torch import distributed as dist, Tensor from torch.utils._pytree import tree_map @@ -1233,7 +1234,7 @@ def get_item_shape(self, key: NestedKey): def pop( self, key: NestedKey, default: str | CompatibleType = NO_DEFAULT ) -> CompatibleType: - _nested_key_check(key) + key = unravel_keys(key) try: # using try/except for get/del is suboptimal, but # this is faster that checkink if key in self keys @@ -1525,7 +1526,7 @@ def _convert_to_tensordict(self, dict_value: dict[str, Any]) -> TensorDictBase: ) def _validate_key(self, key: NestedKey) -> NestedKey: - _nested_key_check(key) + key = unravel_keys(key) if isinstance(key, tuple) and 
len(key) == 1: key = key[0] @@ -3023,6 +3024,8 @@ def __getitem__(self, idx: IndexType) -> TensorDictBase: """ if isinstance(idx, tuple) and len(idx) == 1: idx = idx[0] + if isinstance(idx, tuple): + idx = _maybe_unravel_keys_silent(idx) if isinstance(idx, str) or ( isinstance(idx, tuple) and all(isinstance(sub_idx, str) for sub_idx in idx) ): @@ -3089,6 +3092,9 @@ def __setitem__( elif isinstance(index, (list, range)): index = torch.tensor(index, device=self.device) elif isinstance(index, tuple): + if isinstance(index, tuple): + index = _maybe_unravel_keys_silent(index) + if any(isinstance(sub_index, (list, range)) for sub_index in index): index = tuple( torch.tensor(sub_index, device=self.device) @@ -3863,7 +3869,7 @@ def set_at_( def get( self, key: NestedKey, default: str | CompatibleType = NO_DEFAULT ) -> CompatibleType: - _nested_key_check(key) + key = unravel_keys(key) try: if isinstance(key, tuple): @@ -5905,6 +5911,8 @@ def __contains__(self, item: IndexType) -> bool: def __getitem__(self, index: IndexType) -> TensorDictBase: if isinstance(index, tuple) and len(index) == 1: index = index[0] + if isinstance(index, tuple): + index = _maybe_unravel_keys_silent(index) if index is Ellipsis or (isinstance(index, tuple) and Ellipsis in index): index = convert_ellipsis_to_idx(index, self.batch_size) if index is None: diff --git a/tensordict/utils.py b/tensordict/utils.py index 7e431a51c..4e377492f 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -449,11 +449,6 @@ def _seq_of_nested_key_check(seq: Sequence[NestedKey]) -> None: raise ValueError(f"seq should be a Sequence[NestedKey]. Got {seq}") -def _nested_key_check(key: NestedKey) -> None: - if not is_nested_key(key): - raise ValueError(f"key should be a Sequence[NestedKey]. 
Got {key}") - - def _normalize_key(key: NestedKey) -> NestedKey: # normalises tuples of length one to their string contents return key if not isinstance(key, tuple) or len(key) > 1 else key[0] @@ -834,3 +829,40 @@ def _is_lis_of_list_of_bools(index, first_level=True): if isinstance(index[0], list): return _is_lis_of_list_of_bools(index[0], False) return False + + +def unravel_keys(key): + """Unravels keys when one can be sure that they are keys.""" + if isinstance(key, tuple): + newkey = [] + for subkey in key: + if isinstance(subkey, str): + newkey.append(subkey) + else: + _key = unravel_keys(subkey) + newkey += _key + key = tuple(newkey) + elif not isinstance(key, str): + raise ValueError(f"key should be a Sequence[NestedKey]. Got {key}") + return key + + +def _maybe_unravel_keys_silent(index): + """Attempts to unravel keys. + + If not possible (not keys) return the original index. + """ + if isinstance(index, tuple): + newkey = [] + for key in index: + if isinstance(key, str): + newkey.append(key) + else: + _key = _maybe_unravel_keys_silent(key) + if _key is key: + return index + newkey += _key + newkey = tuple(newkey) + else: + return index + return newkey diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 11b407eee..92b8fe384 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -1542,6 +1542,20 @@ def test_getitem_string(self, td_name, device): td = getattr(self, td_name)(device) assert isinstance(td["a"], (MemmapTensor, torch.Tensor)) + def test_getitem_nestedtuple(self, td_name, device): + torch.manual_seed(1) + td = getattr(self, td_name)(device) + assert isinstance(td[(("a",))], (MemmapTensor, torch.Tensor)) + assert isinstance(td.get((("a",))), (MemmapTensor, torch.Tensor)) + + def test_setitem_nestedtuple(self, td_name, device): + torch.manual_seed(1) + td = getattr(self, td_name)(device) + if td.is_locked: + td.unlock_() + td[" a ", (("little", "story")), "about", ("myself",)] = torch.zeros(td.shape) + assert (td[" a ", 
"little", "story", "about", "myself"] == 0).all() + def test_getitem_range(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device)