Skip to content

Commit 154fdf5

Browse files
author
Vincent Moens
committed
[BugFix] Pass type directly during reduction
ghstack-source-id: 2a0f011 Pull Request resolved: #1223
1 parent 0b901a7 commit 154fdf5

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

tensordict/_reductions.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@
1111
from tensordict._lazy import LazyStackedTensorDict
1212
from tensordict._td import TensorDict
1313

14-
from tensordict.tensorclass import NonTensorData
15-
from tensordict.utils import _STRDTYPE2DTYPE
14+
from tensordict.tensorclass import NonTensorData, NonTensorStack
15+
from tensordict.utils import _is_tensorclass, _STRDTYPE2DTYPE
1616

1717
CLS_MAP = {
1818
"TensorDict": TensorDict,
1919
"LazyStackedTensorDict": LazyStackedTensorDict,
20+
"NonTensorData": NonTensorData,
21+
"NonTensorStack": NonTensorStack,
2022
}
2123

2224

@@ -57,7 +59,9 @@ def from_metadata(metadata=metadata_dict, prefix=None):
5759
d[k] = from_metadata(
5860
v, prefix=prefix + (k,) if prefix is not None else (k,)
5961
)
60-
result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata)
62+
if isinstance(cls, str):
63+
cls = CLS_MAP[cls]
64+
result = cls._from_dict_validated(d, **cls_metadata)
6165
if is_locked:
6266
result.lock_()
6367
# if is_shared:
@@ -121,10 +125,15 @@ def from_metadata(metadata=metadata, prefix=None):
121125
d[k] = from_metadata(
122126
v, prefix=prefix + (k,) if prefix is not None else (k,)
123127
)
124-
result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata)
128+
if isinstance(cls, str):
129+
cls = CLS_MAP[cls]
130+
result = cls._from_dict_validated(d, **cls_metadata)
125131
if is_locked:
126132
result = result.lock_()
127-
result._consolidated = consolidated
133+
if _is_tensorclass(cls):
134+
result._tensordict._consolidated = consolidated
135+
else:
136+
result._consolidated = consolidated
128137
return result
129138

130139
return from_metadata()

tensordict/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4994,7 +4994,7 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata):
49944994
if requires_metadata:
49954995
# metadata is nested
49964996
metadata_dict = {
4997-
"cls": type(self).__name__,
4997+
"cls": type(self),
49984998
"non_tensors": {},
49994999
"leaves": {},
50005000
"cls_metadata": self._reduce_get_metadata(),
@@ -5055,7 +5055,7 @@ def assign(
50555055
metadata_dict_key = None
50565056
if requires_metadata:
50575057
metadata_dict_key = metadata_dict[key] = {
5058-
"cls": cls.__name__,
5058+
"cls": cls,
50595059
"non_tensors": {},
50605060
"leaves": {},
50615061
"cls_metadata": value._reduce_get_metadata(),

0 commit comments

Comments
 (0)