|
11 | 11 | from tensordict._lazy import LazyStackedTensorDict |
12 | 12 | from tensordict._td import TensorDict |
13 | 13 |
|
14 | | -from tensordict.tensorclass import NonTensorData |
15 | | -from tensordict.utils import _STRDTYPE2DTYPE |
| 14 | +from tensordict.tensorclass import NonTensorData, NonTensorStack |
| 15 | +from tensordict.utils import _is_tensorclass, _STRDTYPE2DTYPE |
16 | 16 |
|
17 | 17 | CLS_MAP = { |
18 | 18 | "TensorDict": TensorDict, |
19 | 19 | "LazyStackedTensorDict": LazyStackedTensorDict, |
| 20 | + "NonTensorData": NonTensorData, |
| 21 | + "NonTensorStack": NonTensorStack, |
20 | 22 | } |
21 | 23 |
|
22 | 24 |
|
@@ -57,7 +59,9 @@ def from_metadata(metadata=metadata_dict, prefix=None): |
57 | 59 | d[k] = from_metadata( |
58 | 60 | v, prefix=prefix + (k,) if prefix is not None else (k,) |
59 | 61 | ) |
60 | | - result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata) |
| 62 | + if isinstance(cls, str): |
| 63 | + cls = CLS_MAP[cls] |
| 64 | + result = cls._from_dict_validated(d, **cls_metadata) |
61 | 65 | if is_locked: |
62 | 66 | result.lock_() |
63 | 67 | # if is_shared: |
@@ -121,10 +125,15 @@ def from_metadata(metadata=metadata, prefix=None): |
121 | 125 | d[k] = from_metadata( |
122 | 126 | v, prefix=prefix + (k,) if prefix is not None else (k,) |
123 | 127 | ) |
124 | | - result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata) |
| 128 | + if isinstance(cls, str): |
| 129 | + cls = CLS_MAP[cls] |
| 130 | + result = cls._from_dict_validated(d, **cls_metadata) |
125 | 131 | if is_locked: |
126 | 132 | result = result.lock_() |
127 | | - result._consolidated = consolidated |
| 133 | + if _is_tensorclass(cls): |
| 134 | + result._tensordict._consolidated = consolidated |
| 135 | + else: |
| 136 | + result._consolidated = consolidated |
128 | 137 | return result |
129 | 138 |
|
130 | 139 | return from_metadata() |
|
0 commit comments