Skip to content

Commit baf2e43

Browse files
author
Vincent Moens
committed
[BugFix] Pass type directly during reduction
ghstack-source-id: 9ca5bf2 Pull Request resolved: #1223
1 parent 0b901a7 commit baf2e43

File tree

3 files changed

+24
-12
lines changed

3 files changed

+24
-12
lines changed

tensordict/_lazy.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1992,11 +1992,14 @@ def _apply_nest(
19921992
if all(r is None for r in results) and filter_empty in (None, True):
19931993
return
19941994
if not inplace:
1995-
out = type(self)(
1996-
*results,
1997-
stack_dim=self.stack_dim,
1998-
stack_dim_name=self._td_dim_name,
1999-
)
1995+
if results:
1996+
out = type(self)(
1997+
*results,
1998+
stack_dim=self.stack_dim,
1999+
stack_dim_name=self._td_dim_name,
2000+
)
2001+
else:
2002+
out = None
20002003
else:
20012004
out = self
20022005
if names is not NO_DEFAULT:

tensordict/_reductions.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@
1111
from tensordict._lazy import LazyStackedTensorDict
1212
from tensordict._td import TensorDict
1313

14-
from tensordict.tensorclass import NonTensorData
15-
from tensordict.utils import _STRDTYPE2DTYPE
14+
from tensordict.tensorclass import NonTensorData, NonTensorStack
15+
from tensordict.utils import _is_tensorclass, _STRDTYPE2DTYPE
1616

1717
CLS_MAP = {
1818
"TensorDict": TensorDict,
1919
"LazyStackedTensorDict": LazyStackedTensorDict,
20+
"NonTensorData": NonTensorData,
21+
"NonTensorStack": NonTensorStack,
2022
}
2123

2224

@@ -57,7 +59,9 @@ def from_metadata(metadata=metadata_dict, prefix=None):
5759
d[k] = from_metadata(
5860
v, prefix=prefix + (k,) if prefix is not None else (k,)
5961
)
60-
result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata)
62+
if isinstance(cls, str):
63+
cls = CLS_MAP[cls]
64+
result = cls._from_dict_validated(d, **cls_metadata)
6165
if is_locked:
6266
result.lock_()
6367
# if is_shared:
@@ -121,10 +125,15 @@ def from_metadata(metadata=metadata, prefix=None):
121125
d[k] = from_metadata(
122126
v, prefix=prefix + (k,) if prefix is not None else (k,)
123127
)
124-
result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata)
128+
if isinstance(cls, str):
129+
cls = CLS_MAP[cls]
130+
result = cls._from_dict_validated(d, **cls_metadata)
125131
if is_locked:
126132
result = result.lock_()
127-
result._consolidated = consolidated
133+
if _is_tensorclass(cls):
134+
result._tensordict._consolidated = consolidated
135+
else:
136+
result._consolidated = consolidated
128137
return result
129138

130139
return from_metadata()

tensordict/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4994,7 +4994,7 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata):
49944994
if requires_metadata:
49954995
# metadata is nested
49964996
metadata_dict = {
4997-
"cls": type(self).__name__,
4997+
"cls": type(self),
49984998
"non_tensors": {},
49994999
"leaves": {},
50005000
"cls_metadata": self._reduce_get_metadata(),
@@ -5055,7 +5055,7 @@ def assign(
50555055
metadata_dict_key = None
50565056
if requires_metadata:
50575057
metadata_dict_key = metadata_dict[key] = {
5058-
"cls": cls.__name__,
5058+
"cls": cls,
50595059
"non_tensors": {},
50605060
"leaves": {},
50615061
"cls_metadata": value._reduce_get_metadata(),

0 commit comments

Comments
 (0)