Skip to content

Commit 8d2ad89

Browse files
louisfauryLouis Fauryvmoens
committed
[Feature] Composite specs can create named tensors with 'zero' and 'rand' (#3214)
Co-authored-by: Louis Faury <louis.faury@helsing.ai> Co-authored-by: vmoens <vincentmoens@gmail.com>
1 parent 3814305 commit 8d2ad89

File tree

2 files changed

+36
-6
lines changed

2 files changed

+36
-6
lines changed

test/test_specs.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4585,6 +4585,26 @@ def test_names_repr(self):
45854585
assert "Composite" in repr_str
45864586
assert "obs" in repr_str
45874587

4588+
def test_zero_create_names(self):
4589+
"""Test that creating tensors with 'zero' propagates names."""
4590+
spec = Composite(
4591+
{"obs": Bounded(low=-1, high=1, shape=(10, 3, 4))},
4592+
shape=(10,),
4593+
names=["batch"],
4594+
)
4595+
td = spec.zero()
4596+
td.names = ["batch"]
4597+
4598+
def test_rand_create_names(self):
4599+
"""Test that creating tensors with 'rand' propagates names."""
4600+
spec = Composite(
4601+
{"obs": Bounded(low=-1, high=1, shape=(10, 3, 4))},
4602+
shape=(10,),
4603+
names=["batch"],
4604+
)
4605+
td = spec.rand()
4606+
td.names = ["batch"]
4607+
45884608

45894609
if __name__ == "__main__":
45904610
args, unknown = argparse.ArgumentParser().parse_known_args()

torchrl/data/tensor_specs.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5740,16 +5740,22 @@ def rand(self, shape: torch.Size = None) -> TensorDictBase:
57405740
for key, item in self.items():
57415741
if item is not None:
57425742
_dict[key] = item.rand(shape)
5743-
if self.data_cls is None:
5744-
cls = TensorDict
5743+
5744+
cls = self.data_cls if self.data_cls is not None else TensorDict
5745+
if cls is not TensorDict:
5746+
kwargs = {}
5747+
if self._td_dim_names is not None:
5748+
warnings.warn(f"names for cls {cls} is not supported for rand.")
57455749
else:
5746-
cls = self.data_cls
5750+
kwargs = {"names": self._td_dim_names}
5751+
57475752
# No need to run checks since we know Composite is compliant with
57485753
# TensorDict requirements
57495754
return cls.from_dict(
57505755
_dict,
57515756
batch_size=_size([*shape, *_remove_neg_shapes(self.shape)]),
57525757
device=self.device,
5758+
**kwargs,
57535759
)
57545760

57555761
def keys(
@@ -6017,10 +6023,13 @@ def zero(self, shape: torch.Size = None) -> TensorDictBase:
60176023
except RuntimeError:
60186024
device = self._device
60196025

6020-
if self.data_cls is not None:
6021-
cls = self.data_cls
6026+
cls = self.data_cls if self.data_cls is not None else TensorDict
6027+
if cls is not TensorDict:
6028+
kwargs = {}
6029+
if self._td_dim_names is not None:
6030+
warnings.warn(f"names for cls {cls} is not supported for zero.")
60226031
else:
6023-
cls = TensorDict
6032+
kwargs = {"names": self._td_dim_names}
60246033

60256034
return cls.from_dict(
60266035
{
@@ -6030,6 +6039,7 @@ def zero(self, shape: torch.Size = None) -> TensorDictBase:
60306039
},
60316040
batch_size=_size([*shape, *self._safe_shape]),
60326041
device=device,
6042+
**kwargs,
60336043
)
60346044

60356045
def __eq__(self, other: object) -> bool:

0 commit comments

Comments
 (0)