Skip to content

Commit

Permalink
Use gymnsasium batch utility to sample fundamental spaces
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Oct 18, 2024
1 parent c41ad0f commit 94083a9
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
15 changes: 9 additions & 6 deletions skrl/utils/spaces/jax/spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,26 +280,29 @@ def sample_space(space: spaces.Space, batch_size: int = 1, backend: str = Litera
# fundamental spaces
# Box
if isinstance(space, spaces.Box):
sample = gymnasium.vector.utils.batch_space(space, batch_size).sample()
if backend == "numpy":
return np.stack([space.sample() for _ in range(batch_size)])
return np.array(sample).reshape(batch_size, *space.shape)
elif backend == "jax":
return jnp.array(np.stack([space.sample() for _ in range(batch_size)]), device=device)
return jnp.array(sample, device=device).reshape(batch_size, *space.shape)
else:
raise ValueError(f"Unsupported backend type ({backend})")
# Discrete
elif isinstance(space, spaces.Discrete):
sample = gymnasium.vector.utils.batch_space(space, batch_size).sample()
if backend == "numpy":
return np.stack([[space.sample()] for _ in range(batch_size)])
return np.array(sample).reshape(batch_size, -1)
elif backend == "jax":
return jnp.array(np.stack([[space.sample()] for _ in range(batch_size)]), device=device)
return jnp.array(sample, device=device).reshape(batch_size, -1)
else:
raise ValueError(f"Unsupported backend type ({backend})")
# MultiDiscrete
elif isinstance(space, spaces.MultiDiscrete):
sample = gymnasium.vector.utils.batch_space(space, batch_size).sample()
if backend == "numpy":
return np.stack([space.sample() for _ in range(batch_size)])
return np.array(sample).reshape(batch_size, *space.nvec.shape)
elif backend == "jax":
return jnp.array(np.stack([space.sample() for _ in range(batch_size)]), device=device)
return jnp.array(sample, device=device).reshape(batch_size, *space.nvec.shape)
else:
raise ValueError(f"Unsupported backend type ({backend})")
# composite spaces
Expand Down
15 changes: 9 additions & 6 deletions skrl/utils/spaces/torch/spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,26 +248,29 @@ def sample_space(space: spaces.Space, batch_size: int = 1, backend: str = Litera
# fundamental spaces
# Box
if isinstance(space, spaces.Box):
sample = gymnasium.vector.utils.batch_space(space, batch_size).sample()
if backend == "numpy":
return np.stack([space.sample() for _ in range(batch_size)])
return np.array(sample).reshape(batch_size, *space.shape)
elif backend == "torch":
return torch.tensor(np.stack([space.sample() for _ in range(batch_size)]), device=device)
return torch.tensor(sample, device=device).reshape(batch_size, *space.shape)
else:
raise ValueError(f"Unsupported backend type ({backend})")
# Discrete
elif isinstance(space, spaces.Discrete):
sample = gymnasium.vector.utils.batch_space(space, batch_size).sample()
if backend == "numpy":
return np.stack([[space.sample()] for _ in range(batch_size)])
return np.array(sample).reshape(batch_size, -1)
elif backend == "torch":
return torch.tensor(np.stack([[space.sample()] for _ in range(batch_size)]), device=device)
return torch.tensor(sample, device=device).reshape(batch_size, -1)
else:
raise ValueError(f"Unsupported backend type ({backend})")
# MultiDiscrete
elif isinstance(space, spaces.MultiDiscrete):
sample = gymnasium.vector.utils.batch_space(space, batch_size).sample()
if backend == "numpy":
return np.stack([space.sample() for _ in range(batch_size)])
return np.array(sample).reshape(batch_size, *space.nvec.shape)
elif backend == "torch":
return torch.tensor(np.stack([space.sample() for _ in range(batch_size)]), device=device)
return torch.tensor(sample, device=device).reshape(batch_size, *space.nvec.shape)
else:
raise ValueError(f"Unsupported backend type ({backend})")
# composite spaces
Expand Down

0 comments on commit 94083a9

Please sign in to comment.