add test
vmoens committed Jun 6, 2024
1 parent 9c7cc36 commit 252d64a
Showing 1 changed file with 147 additions and 0 deletions.
147 changes: 147 additions & 0 deletions test/test_specs.py
@@ -1908,6 +1908,153 @@ def test_unboundeddiscrete(
assert spec == torch.stack(spec.unbind(-1), -1)


@pytest.mark.parametrize(
"device",
[torch.device("cpu")]
    + [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())],
)
class TestTo:
@pytest.mark.parametrize("shape1", [(5, 4)])
def test_binary(self, shape1, device):
spec = BinaryDiscreteTensorSpec(
n=4, shape=shape1, device="cpu", dtype=torch.bool
)
assert spec.to(device).device == device

@pytest.mark.parametrize(
"shape1,mini,maxi",
[
[(10,), -torch.ones([]), torch.ones([])],
[None, -torch.ones([10]), torch.ones([])],
[None, -torch.ones([]), torch.ones([10])],
[(10,), -torch.ones([]), torch.ones([10])],
[(10,), -torch.ones([10]), torch.ones([])],
[(10,), -torch.ones([10]), torch.ones([10])],
],
)
def test_bounded(self, shape1, mini, maxi, device):
spec = BoundedTensorSpec(
mini, maxi, shape=shape1, device="cpu", dtype=torch.bool
)
assert spec.to(device).device == device

def test_composite(self, device):
batch_size = (5,)
spec1 = BoundedTensorSpec(
-torch.ones([*batch_size, 10]),
torch.ones([*batch_size, 10]),
shape=(
*batch_size,
10,
),
device="cpu",
dtype=torch.bool,
)
spec2 = BinaryDiscreteTensorSpec(
n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool
)
spec3 = DiscreteTensorSpec(
n=4, shape=batch_size, device="cpu", dtype=torch.long
)
spec4 = MultiDiscreteTensorSpec(
nvec=(4, 5, 6), shape=(*batch_size, 3), device="cpu", dtype=torch.long
)
spec5 = MultiOneHotDiscreteTensorSpec(
nvec=(4, 5, 6), shape=(*batch_size, 15), device="cpu", dtype=torch.long
)
spec6 = OneHotDiscreteTensorSpec(
n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long
)
spec7 = UnboundedContinuousTensorSpec(
shape=(*batch_size, 9),
device="cpu",
dtype=torch.float64,
)
spec8 = UnboundedDiscreteTensorSpec(
shape=(*batch_size, 9),
device="cpu",
dtype=torch.long,
)
spec = CompositeSpec(
spec1=spec1,
spec2=spec2,
spec3=spec3,
spec4=spec4,
spec5=spec5,
spec6=spec6,
spec7=spec7,
spec8=spec8,
shape=batch_size,
)
assert spec.to(device).device == device

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_discrete(
self,
shape1,
device,
):
spec = DiscreteTensorSpec(n=4, shape=shape1, device="cpu", dtype=torch.long)
assert spec.to(device).device == device

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_multidiscrete(self, shape1, device):
if shape1 is None:
shape1 = (3,)
else:
shape1 = (*shape1, 3)
spec = MultiDiscreteTensorSpec(
nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long
)
assert spec.to(device).device == device

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_multionehot(self, shape1, device):
if shape1 is None:
shape1 = (15,)
else:
shape1 = (*shape1, 15)
spec = MultiOneHotDiscreteTensorSpec(
nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long
)
assert spec.to(device).device == device

def test_non_tensor(self, device):
spec = NonTensorSpec(shape=(3, 4), device="cpu")
assert spec.to(device).device == device

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_onehot(self, shape1, device):
if shape1 is None:
shape1 = (15,)
else:
shape1 = (*shape1, 15)
spec = OneHotDiscreteTensorSpec(
n=15, shape=shape1, device="cpu", dtype=torch.long
)
assert spec.to(device).device == device

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_unbounded(self, shape1, device):
if shape1 is None:
shape1 = (15,)
else:
shape1 = (*shape1, 15)
spec = UnboundedContinuousTensorSpec(
shape=shape1, device="cpu", dtype=torch.float64
)
assert spec.to(device).device == device

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_unboundeddiscrete(self, shape1, device):
if shape1 is None:
shape1 = (15,)
else:
shape1 = (*shape1, 15)
spec = UnboundedDiscreteTensorSpec(shape=shape1, device="cpu", dtype=torch.long)
assert spec.to(device).device == device


@pytest.mark.parametrize(
"shape,stack_dim",
[[(), 0], [(2,), 0], [(2,), 1], [(2, 3), 0], [(2, 3), 1], [(2, 3), 2]],
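The new TestTo class follows a single pattern throughout: build a spec on CPU, call .to(device), and check that the returned spec reports the target device. A minimal standalone sketch of that pattern, assuming BoundedTensorSpec is importable from torchrl.data and accepts the same positional bounds and keyword arguments used in the diff above:

import torch
from torchrl.data import BoundedTensorSpec  # import path assumed

# Pick a CUDA device when one is visible, otherwise stay on CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Build the spec on CPU, mirroring the tests above.
spec = BoundedTensorSpec(
    -torch.ones(10), torch.ones(10), shape=(10,), device="cpu", dtype=torch.float32
)

# .to(device) should return a spec whose .device matches the target device.
assert spec.to(device).device == device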
