Skip to content

Commit

Permalink
Update Omniverse Isaac Gym wrapper to use space utils in jax
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Oct 13, 2024
1 parent 80615af commit ead61e8
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions skrl/envs/wrappers/jax/omniverse_isaacgym_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from skrl import logger
from skrl.envs.wrappers.jax.base import Wrapper
from skrl.utils.spaces.torch import flatten_tensorized_space, tensorize_space, unflatten_tensorized_space


# ML frameworks conversion utilities
Expand Down Expand Up @@ -69,12 +70,14 @@ def step(self, actions: Union[np.ndarray, jax.Array]) -> \
actions = _jax2torch(actions, self._env_device, self._jax)

with torch.no_grad():
self._observations, reward, terminated, info = self._env.step(actions)
observations, reward, terminated, info = self._env.step(unflatten_tensorized_space(self.action_space, actions))

observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["obs"]))
terminated = terminated.to(dtype=torch.int8)
truncated = info["time_outs"].to(dtype=torch.int8) if "time_outs" in info else torch.zeros_like(terminated)

return _torch2jax(self._observations["obs"], self._jax), \
self._observations = _torch2jax(observations, self._jax)
return self._observations, \
_torch2jax(reward.view(-1, 1), self._jax), \
_torch2jax(terminated.view(-1, 1), self._jax), \
_torch2jax(truncated.view(-1, 1), self._jax), \
Expand All @@ -87,9 +90,11 @@ def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]:
:rtype: np.ndarray or jax.Array and any other info
"""
if self._reset_once:
self._observations = self._env.reset()
observations = self._env.reset()
observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["obs"]))
self._observations = _torch2jax(observations, self._jax)
self._reset_once = False
return _torch2jax(self._observations["obs"], self._jax), {}
return self._observations, {}

def render(self, *args, **kwargs) -> None:
"""Render the environment
Expand Down

0 comments on commit ead61e8

Please sign in to comment.