Skip to content

Commit

Permalink
[Feature] Unravel nested keys (#403)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jun 2, 2023
1 parent 6f7b3df commit d2edd94
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 9 deletions.
2 changes: 2 additions & 0 deletions tensordict/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
TensorDictBase,
)
from tensordict.utils import (
_maybe_unravel_keys_silent,
_shape,
DeviceType,
expand_right,
Expand Down Expand Up @@ -363,6 +364,7 @@ def __getitem__(self, item):
__getitems__ = __getitem__

def __setitem__(self, index, value):
index = _maybe_unravel_keys_silent(index)
if isinstance(index, str) or (
isinstance(index, tuple) and all(isinstance(val, str) for val in index)
):
Expand Down
16 changes: 12 additions & 4 deletions tensordict/tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
_get_item,
_getitem_batch_size,
_is_shared,
_nested_key_check,
_maybe_unravel_keys_silent,
_set_item,
_shape,
_sub_index,
Expand All @@ -52,6 +52,7 @@
int_generator,
NestedKey,
prod,
unravel_keys,
)
from torch import distributed as dist, Tensor
from torch.utils._pytree import tree_map
Expand Down Expand Up @@ -1233,7 +1234,7 @@ def get_item_shape(self, key: NestedKey):
def pop(
self, key: NestedKey, default: str | CompatibleType = NO_DEFAULT
) -> CompatibleType:
_nested_key_check(key)
key = unravel_keys(key)
try:
# using try/except for get/del is suboptimal, but
# this is faster than checking if key is in self keys
Expand Down Expand Up @@ -1525,7 +1526,7 @@ def _convert_to_tensordict(self, dict_value: dict[str, Any]) -> TensorDictBase:
)

def _validate_key(self, key: NestedKey) -> NestedKey:
_nested_key_check(key)
key = unravel_keys(key)

if isinstance(key, tuple) and len(key) == 1:
key = key[0]
Expand Down Expand Up @@ -3023,6 +3024,8 @@ def __getitem__(self, idx: IndexType) -> TensorDictBase:
"""
if isinstance(idx, tuple) and len(idx) == 1:
idx = idx[0]
if isinstance(idx, tuple):
idx = _maybe_unravel_keys_silent(idx)
if isinstance(idx, str) or (
isinstance(idx, tuple) and all(isinstance(sub_idx, str) for sub_idx in idx)
):
Expand Down Expand Up @@ -3089,6 +3092,9 @@ def __setitem__(
elif isinstance(index, (list, range)):
index = torch.tensor(index, device=self.device)
elif isinstance(index, tuple):
if isinstance(index, tuple):
index = _maybe_unravel_keys_silent(index)

if any(isinstance(sub_index, (list, range)) for sub_index in index):
index = tuple(
torch.tensor(sub_index, device=self.device)
Expand Down Expand Up @@ -3863,7 +3869,7 @@ def set_at_(
def get(
self, key: NestedKey, default: str | CompatibleType = NO_DEFAULT
) -> CompatibleType:
_nested_key_check(key)
key = unravel_keys(key)

try:
if isinstance(key, tuple):
Expand Down Expand Up @@ -5905,6 +5911,8 @@ def __contains__(self, item: IndexType) -> bool:
def __getitem__(self, index: IndexType) -> TensorDictBase:
if isinstance(index, tuple) and len(index) == 1:
index = index[0]
if isinstance(index, tuple):
index = _maybe_unravel_keys_silent(index)
if index is Ellipsis or (isinstance(index, tuple) and Ellipsis in index):
index = convert_ellipsis_to_idx(index, self.batch_size)
if index is None:
Expand Down
42 changes: 37 additions & 5 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,11 +449,6 @@ def _seq_of_nested_key_check(seq: Sequence[NestedKey]) -> None:
raise ValueError(f"seq should be a Sequence[NestedKey]. Got {seq}")


def _nested_key_check(key: NestedKey) -> None:
    """Validate ``key`` with ``is_nested_key``.

    Returns ``None`` when ``is_nested_key(key)`` is truthy.

    Raises:
        ValueError: if ``is_nested_key(key)`` is falsy.
    """
    if is_nested_key(key):
        return
    raise ValueError(f"key should be a Sequence[NestedKey]. Got {key}")


def _normalize_key(key: NestedKey) -> NestedKey:
# normalises tuples of length one to their string contents
return key if not isinstance(key, tuple) or len(key) > 1 else key[0]
Expand Down Expand Up @@ -834,3 +829,40 @@ def _is_lis_of_list_of_bools(index, first_level=True):
if isinstance(index[0], list):
return _is_lis_of_list_of_bools(index[0], False)
return False


def unravel_keys(key):
    """Flatten a (possibly nested) tuple of string keys into a flat tuple.

    Use this only when ``key`` is known to be a key rather than an index.

    Args:
        key: a string, or a tuple whose elements are strings or further
            tuples of the same form.

    Returns:
        ``key`` itself when it is a string; otherwise a flat tuple holding
        every string found, in left-to-right order.

    Raises:
        ValueError: if ``key``, or any element reached while flattening, is
            neither a string nor a tuple. The message reports the offending
            element, not the outermost key.
    """
    if isinstance(key, str):
        return key
    if not isinstance(key, tuple):
        raise ValueError(f"key should be a Sequence[NestedKey]. Got {key}")

    flat = []
    for part in key:
        if isinstance(part, str):
            flat.append(part)
        else:
            # Recursion either yields a flat tuple or raises for a non-key.
            flat.extend(unravel_keys(part))
    return tuple(flat)


def _maybe_unravel_keys_silent(index):
"""Attemps to unravel keys.
If not possible (not keys) return the original index.
"""
if isinstance(index, tuple):
newkey = []
for key in index:
if isinstance(key, str):
newkey.append(key)
else:
_key = _maybe_unravel_keys_silent(key)
if _key is key:
return index
newkey += _key
newkey = tuple(newkey)
else:
return index
return newkey
14 changes: 14 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1542,6 +1542,20 @@ def test_getitem_string(self, td_name, device):
td = getattr(self, td_name)(device)
assert isinstance(td["a"], (MemmapTensor, torch.Tensor))

def test_getitem_nestedtuple(self, td_name, device):
    """Check that ``td[key]`` and ``td.get(key)`` return a tensor for a tuple key."""
    torch.manual_seed(1)
    td = getattr(self, td_name)(device)
    tensor_types = (MemmapTensor, torch.Tensor)
    # NOTE(review): ``(("a",))`` is just the 1-tuple ``("a",)`` -- the outer
    # parentheses only group, so this key is not actually nested. Confirm
    # whether ``(("a",),)`` was intended.
    key = (("a",))
    assert isinstance(td[key], tensor_types)
    assert isinstance(td.get(key), tensor_types)

def test_setitem_nestedtuple(self, td_name, device):
    """Write through a nested-tuple key and read back through the flat key."""
    torch.manual_seed(1)
    td = getattr(self, td_name)(device)
    if td.is_locked:
        td.unlock_()
    # The same path spelled two ways: with nested sub-tuples and fully flat.
    nested_key = (" a ", (("little", "story")), "about", ("myself",))
    flat_key = (" a ", "little", "story", "about", "myself")
    td[nested_key] = torch.zeros(td.shape)
    assert (td[flat_key] == 0).all()

def test_getitem_range(self, td_name, device):
torch.manual_seed(1)
td = getattr(self, td_name)(device)
Expand Down

0 comments on commit d2edd94

Please sign in to comment.