Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
285 changes: 285 additions & 0 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4301,6 +4301,291 @@ def test_composite(self):
assert c_enum["b"].shape == torch.Size((20, 3))


class TestCompositeNames:
"""Test the names functionality of Composite specs."""

def test_names_property_basic(self):
"""Test basic names property functionality."""
# Test with names
spec = Composite(
{"obs": Bounded(low=-1, high=1, shape=(10, 5, 3, 4))},
shape=(10, 5),
names=["batch", "time"],
)
assert spec.names == ["batch", "time"]
assert spec._has_names() is True

# Test without names
spec_no_names = Composite(
{"obs": Bounded(low=-1, high=1, shape=(10, 5, 3, 4))}, shape=(10, 5)
)
assert spec_no_names.names == [None, None]
assert spec_no_names._has_names() is False

def test_names_setter(self):
"""Test setting names."""
spec = Composite(
{"obs": Bounded(low=-1, high=1, shape=(10, 5, 3, 4))}, shape=(10, 5)
)

# Set names
spec.names = ["batch", "time"]
assert spec.names == ["batch", "time"]
assert spec._has_names() is True

# Clear names
spec.names = None
assert spec.names == [None, None]
assert spec._has_names() is False

def test_names_setter_validation(self):
"""Test names setter validation."""
spec = Composite(
{"obs": Bounded(low=-1, high=1, shape=(10, 5, 3, 4))}, shape=(10, 5)
)

# Test wrong number of names
with pytest.raises(ValueError, match="Expected 2 names, but got 3 names"):
spec.names = ["batch", "time", "extra"]

def test_refine_names_basic(self):
"""Test basic refine_names functionality."""
spec = Composite(
{"obs": Bounded(low=-1, high=1, shape=(10, 5, 3, 4))}, shape=(10, 5, 3)
)

# Initially no names
assert spec.names == [None, None, None]
assert spec._has_names() is False

# Refine names
spec_refined = spec.refine_names(None, None, "feature")
assert spec_refined.names == [None, None, "feature"]
assert spec_refined._has_names() is True

def test_refine_names_ellipsis(self):
"""Test refine_names with ellipsis."""
spec = Composite(
{"obs": Bounded(low=-1, high=1, shape=(10, 5, 3, 4))},
shape=(10, 5, 3),
names=["batch", None, None],
)

# Use ellipsis to fill remaining dimensions
spec_refined = spec.refine_names("batch", ...)
assert spec_refined.names == ["batch", None, None]

def test_refine_names_validation(self):
"""Test refine_names validation."""
spec = Composite(
{"obs": Bounded(low=-1, high=1, shape=(10, 5, 3, 4))},
shape=(10, 5),
names=["batch", "time"],
)

# Try to refine to different name
with pytest.raises(RuntimeError, match="cannot coerce Composite names"):
spec.refine_names("batch", "different")

def test_expand_preserves_names(self):
"""Test that expand preserves names."""
spec = Composite(
{"obs": Bounded(low=-1, high=1, shape=(10, 3, 4))},
shape=(10,),
names=["batch"],
)

expanded = spec.expand(5, 10)
assert expanded.names == [None, "batch"]
assert expanded.shape == torch.Size([5, 10])

def test_squeeze_preserves_names(self):
"""Test that squeeze preserves names."""
spec = Composite(
{"obs": Bounded(low=-1, high=1, shape=(10, 1, 5, 3, 4))},
shape=(10, 1, 5),
names=["batch", "dummy", "time"],
)

squeezed = spec.squeeze(1) # Remove the dimension with size 1
assert squeezed.names == ["batch", "time"]
assert squeezed.shape == torch.Size([10, 5])

def test_squeeze_all_ones_clears_names(self):
"""Test that squeezing all dimensions clears names if all become None."""
spec = Composite(
{"obs": Bounded(low=-1, high=1, shape=(1, 1, 3, 4))},
shape=(1, 1),
names=["dummy1", "dummy2"],
)

squeezed = spec.squeeze()
assert squeezed.names == [] # All dimensions removed, so no names
assert squeezed.shape == torch.Size([])

def test_unsqueeze_preserves_names(self):
"""Test that unsqueeze preserves names."""
spec = Composite(
{"obs": Bounded(low=-1, high=1, shape=(10, 5, 3, 4))},
shape=(10, 5),
names=["batch", "time"],
)

unsqueezed = spec.unsqueeze(1)
assert unsqueezed.names == ["batch", None, "time"]
assert unsqueezed.shape == torch.Size([10, 1, 5])

def test_unbind_preserves_names(self):
"""Test that unbind preserves names."""
spec = Composite(
{"obs": Bounded(low=-1, high=1, shape=(3, 5, 3, 4))},
shape=(3, 5),
names=["batch", "time"],
)

unbound = spec.unbind(0)
assert len(unbound) == 3
for spec_item in unbound:
assert spec_item.names == ["time"]
assert spec_item.shape == torch.Size([5])

def test_clone_preserves_names(self):
"""Test that clone preserves names."""
spec = Composite(
{"obs": Bounded(low=-1, high=1, shape=(10, 3, 4))},
shape=(10,),
names=["batch"],
)

cloned = spec.clone()
assert cloned.names == ["batch"]
assert cloned.shape == spec.shape
assert cloned is not spec # Different objects

def test_to_preserves_names(self):
"""Test that to() preserves names."""
spec = Composite(
{"obs": Bounded(low=-1, high=1, shape=(10, 3, 4))},
shape=(10,),
names=["batch"],
)

moved = spec.to("cpu")
assert moved.names == ["batch"]
assert moved.device == torch.device("cpu")

def test_indexing_preserves_names(self):
"""Test that indexing preserves names."""
spec = Composite(
{"obs": Bounded(low=-1, high=1, shape=(10, 5, 3, 4))},
shape=(10, 5),
names=["batch", "time"],
)

# Test single dimension indexing
indexed = spec[0]
assert indexed.names == ["time"]
assert indexed.shape == torch.Size([5])

# Test slice indexing
sliced = spec[0:5]
assert sliced.names == ["batch", "time"]
assert sliced.shape == torch.Size([5, 5])

def test_nested_composite_names_propagation(self):
"""Test that names are propagated to nested Composite specs."""
nested_spec = Composite(
{
"outer": Composite(
{"inner": Bounded(low=-1, high=1, shape=(10, 3, 2))}, shape=(10, 3)
)
},
shape=(10,),
names=["batch"],
)

assert nested_spec.names == ["batch"]
assert nested_spec["outer"].names == ["batch", None]

def test_erase_names(self):
"""Test erasing names."""
spec = Composite(
{"obs": Bounded(low=-1, high=1, shape=(10, 3, 4))},
shape=(10,),
names=["batch"],
)

assert spec._has_names() is True
spec._erase_names()
assert spec._has_names() is False
assert spec.names == [None]

def test_names_with_different_shapes(self):
"""Test names with different spec shapes."""
spec = Composite(
{
"obs": Bounded(low=-1, high=1, shape=(10, 5, 3, 4)),
"action": Bounded(low=0, high=1, shape=(10, 5, 2)),
},
shape=(10, 5),
names=["batch", "time"],
)

assert spec.names == ["batch", "time"]
assert spec["obs"].shape == torch.Size([10, 5, 3, 4])
assert spec["action"].shape == torch.Size([10, 5, 2])

def test_names_constructor_parameter(self):
"""Test names parameter in constructor."""
# Test with names
spec = Composite(
{"obs": Bounded(low=-1, high=1, shape=(10, 5, 3, 4))},
shape=(10, 5),
names=["batch", "time"],
)
assert spec.names == ["batch", "time"]

# Test without names
spec_no_names = Composite(
{"obs": Bounded(low=-1, high=1, shape=(10, 5, 3, 4))}, shape=(10, 5)
)
assert spec_no_names.names == [None, None]

def test_names_with_empty_composite(self):
"""Test names with empty Composite."""
spec = Composite({}, shape=(10,), names=["batch"])
assert spec.names == ["batch"]
assert spec._has_names() is True

def test_names_equality(self):
"""Test that names don't affect equality."""
spec1 = Composite(
{"obs": Bounded(low=-1, high=1, shape=(10, 3, 4))},
shape=(10,),
names=["batch"],
)

spec2 = Composite(
{"obs": Bounded(low=-1, high=1, shape=(10, 3, 4))}, shape=(10,)
)

# They should be equal despite different names
assert spec1 == spec2

def test_names_repr(self):
"""Test that names don't break repr."""
spec = Composite(
{"obs": Bounded(low=-1, high=1, shape=(10, 3, 4))},
shape=(10,),
names=["batch"],
)

# Should not raise an error
repr_str = repr(spec)
assert "Composite" in repr_str
assert "obs" in repr_str


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
Loading
Loading