@@ -5740,17 +5740,22 @@ def rand(self, shape: torch.Size = None) -> TensorDictBase:
57405740 for key , item in self .items ():
57415741 if item is not None :
57425742 _dict [key ] = item .rand (shape )
5743- if self .data_cls is None :
5744- cls = TensorDict
5743+
5744+ cls = self .data_cls if self .data_cls is not None else TensorDict
5745+ if cls is not TensorDict :
5746+ kwargs = {}
5747+ if self ._td_dim_names is not None :
5748+ warnings .warn (f"names for cls { cls } is not supported for rand." )
57455749 else :
5746- cls = self .data_cls
5750+ kwargs = {"names" : self ._td_dim_names }
5751+
57475752 # No need to run checks since we know Composite is compliant with
57485753 # TensorDict requirements
57495754 return cls .from_dict (
57505755 _dict ,
57515756 batch_size = _size ([* shape , * _remove_neg_shapes (self .shape )]),
57525757 device = self .device ,
5753- names = self . _td_dim_names ,
5758+ ** kwargs ,
57545759 )
57555760
57565761 def keys (
@@ -6018,10 +6023,13 @@ def zero(self, shape: torch.Size = None) -> TensorDictBase:
60186023 except RuntimeError :
60196024 device = self ._device
60206025
6021- if self .data_cls is not None :
6022- cls = self .data_cls
6026+ cls = self .data_cls if self .data_cls is not None else TensorDict
6027+ if cls is not TensorDict :
6028+ kwargs = {}
6029+ if self ._td_dim_names is not None :
6030+ warnings .warn (f"names for cls { cls } is not supported for zero." )
60236031 else :
6024- cls = TensorDict
6032+ kwargs = { "names" : self . _td_dim_names }
60256033
60266034 return cls .from_dict (
60276035 {
@@ -6031,7 +6039,7 @@ def zero(self, shape: torch.Size = None) -> TensorDictBase:
60316039 },
60326040 batch_size = _size ([* shape , * self ._safe_shape ]),
60336041 device = device ,
6034- names = self . _td_dim_names ,
6042+ ** kwargs ,
60356043 )
60366044
60376045 def __eq__ (self , other : object ) -> bool :
0 commit comments