@@ -5740,16 +5740,22 @@ def rand(self, shape: torch.Size = None) -> TensorDictBase:
57405740 for key , item in self .items ():
57415741 if item is not None :
57425742 _dict [key ] = item .rand (shape )
5743- if self .data_cls is None :
5744- cls = TensorDict
5743+
5744+ cls = self .data_cls if self .data_cls is not None else TensorDict
5745+ if cls is not TensorDict :
5746+ kwargs = {}
5747+ if self ._td_dim_names is not None :
5748+ warnings .warn (f"names for cls { cls } is not supported for rand." )
57455749 else :
5746- cls = self .data_cls
5750+ kwargs = {"names" : self ._td_dim_names }
5751+
57475752 # No need to run checks since we know Composite is compliant with
57485753 # TensorDict requirements
57495754 return cls .from_dict (
57505755 _dict ,
57515756 batch_size = _size ([* shape , * _remove_neg_shapes (self .shape )]),
57525757 device = self .device ,
5758+ ** kwargs ,
57535759 )
57545760
57555761 def keys (
@@ -6017,10 +6023,13 @@ def zero(self, shape: torch.Size = None) -> TensorDictBase:
60176023 except RuntimeError :
60186024 device = self ._device
60196025
6020- if self .data_cls is not None :
6021- cls = self .data_cls
6026+ cls = self .data_cls if self .data_cls is not None else TensorDict
6027+ if cls is not TensorDict :
6028+ kwargs = {}
6029+ if self ._td_dim_names is not None :
6030+ warnings .warn (f"names for cls { cls } is not supported for zero." )
60226031 else :
6023- cls = TensorDict
6032+ kwargs = { "names" : self . _td_dim_names }
60246033
60256034 return cls .from_dict (
60266035 {
@@ -6030,6 +6039,7 @@ def zero(self, shape: torch.Size = None) -> TensorDictBase:
60306039 },
60316040 batch_size = _size ([* shape , * self ._safe_shape ]),
60326041 device = device ,
6042+ ** kwargs ,
60336043 )
60346044
60356045 def __eq__ (self , other : object ) -> bool :
0 commit comments