Skip to content

Commit 968b463

Browse files
committed
fix-names-from_dict
1 parent 628b884 commit 968b463

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

torchrl/data/tensor_specs.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5740,17 +5740,22 @@ def rand(self, shape: torch.Size = None) -> TensorDictBase:
57405740
for key, item in self.items():
57415741
if item is not None:
57425742
_dict[key] = item.rand(shape)
5743-
if self.data_cls is None:
5744-
cls = TensorDict
5743+
5744+
cls = self.data_cls if self.data_cls is not None else TensorDict
5745+
if cls is not TensorDict:
5746+
kwargs = {}
5747+
if self._td_dim_names is not None:
5748+
warnings.warn(f"names for cls {cls} is not supported for rand.")
57455749
else:
5746-
cls = self.data_cls
5750+
kwargs = {"names": self._td_dim_names}
5751+
57475752
# No need to run checks since we know Composite is compliant with
57485753
# TensorDict requirements
57495754
return cls.from_dict(
57505755
_dict,
57515756
batch_size=_size([*shape, *_remove_neg_shapes(self.shape)]),
57525757
device=self.device,
5753-
names=self._td_dim_names,
5758+
**kwargs,
57545759
)
57555760

57565761
def keys(
@@ -6018,10 +6023,13 @@ def zero(self, shape: torch.Size = None) -> TensorDictBase:
60186023
except RuntimeError:
60196024
device = self._device
60206025

6021-
if self.data_cls is not None:
6022-
cls = self.data_cls
6026+
cls = self.data_cls if self.data_cls is not None else TensorDict
6027+
if cls is not TensorDict:
6028+
kwargs = {}
6029+
if self._td_dim_names is not None:
6030+
warnings.warn(f"names for cls {cls} is not supported for zero.")
60236031
else:
6024-
cls = TensorDict
6032+
kwargs = {"names": self._td_dim_names}
60256033

60266034
return cls.from_dict(
60276035
{
@@ -6031,7 +6039,7 @@ def zero(self, shape: torch.Size = None) -> TensorDictBase:
60316039
},
60326040
batch_size=_size([*shape, *self._safe_shape]),
60336041
device=device,
6034-
names=self._td_dim_names,
6042+
**kwargs,
60356043
)
60366044

60376045
def __eq__(self, other: object) -> bool:

0 commit comments

Comments
 (0)