Skip to content

Commit

Permalink
Merge pull request #126 from FragileTech/rgb
Browse files Browse the repository at this point in the history
Add rgb_shape. Return RGB array at reset
  • Loading branch information
Guillemdb authored Feb 1, 2025
2 parents 3b7930f + cc7b3e9 commit e60709e
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 8 deletions.
14 changes: 14 additions & 0 deletions src/plangym/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ class TestPlanEnv:
"autoreset",
"delay_setup",
"return_image",
"img_shape",
)

def test_init(self, env):
Expand Down Expand Up @@ -141,6 +142,15 @@ def test_obs_shape(self, env):
obs, *_ = env.step(env.sample_action())
assert obs.shape == env.obs_shape, (obs.shape, env.obs_shape)

def test_img_shape(self, env):
assert hasattr(env, "img_shape")
if env.img_shape is None:
return
assert isinstance(env.img_shape, tuple)
if env.img_shape:
for val in env.img_shape:
assert isinstance(val, int)

def test_action_shape(self, env):
assert hasattr(env, "action_shape")
assert isinstance(env.action_shape, tuple)
Expand Down Expand Up @@ -193,6 +203,10 @@ def test_set_state(self, env):
def test_reset(self, env):
_ = env.reset(return_state=False)
state, obs, info = env.reset(return_state=True)
if env.return_image:
assert "rgb" in info
assert isinstance(info["rgb"], numpy.ndarray)
assert info["rgb"].shape == env.img_shape
state_is_array = isinstance(state, numpy.ndarray)
obs_is_array = isinstance(obs, numpy.ndarray)
assert isinstance(info, dict), info
Expand Down
30 changes: 22 additions & 8 deletions src/plangym/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Plangym API implementation."""

from abc import ABC
from functools import cached_property
from typing import Any, Callable, Iterable

import gymnasium as gym
Expand Down Expand Up @@ -106,6 +107,20 @@ def return_image(self) -> bool:
"""
return self._return_image

@cached_property
def img_shape(self) -> tuple[int, ...] | None:
"""Return the shape of the image returned by the environment.
If the environment does not return an image, it will return None. This also applies
to environments that throw an error when trying to get the image
(like when running in headless machines without a virtual display).
"""
try:
img = self.get_image()
return img.shape
except Exception:
return None

def get_image(self) -> None | numpy.ndarray:
"""Return a numpy array containing the rendered view of the environment.
Expand Down Expand Up @@ -563,7 +578,7 @@ def gym_env(self):
return self._gym_env

@property
def obs_shape(self) -> tuple[int, ...]:
def obs_shape(self) -> tuple[int, ...] | None:
"""Tuple containing the shape of the *observations* returned by the Environment."""
if self.observation_space is None:
return None
Expand Down Expand Up @@ -701,15 +716,12 @@ def get_image(self) -> numpy.ndarray:

def apply_reset(
self,
return_state: bool = True,
) -> numpy.ndarray | tuple[numpy.ndarray, numpy.ndarray]:
) -> tuple[numpy.ndarray, dict[str, Any]]:
"""Restart the environment.
Args:
return_state: If ``True`` it will return the state of the environment.
Returns:
``(state, obs)`` if ```return_state`` is ``True`` else return ``obs``.
Returns
``(obs, info)``. If ```return_image`` is ``True``, the info dictionary contains an
``'rgb'`` key with the corresponding image.
"""
# FIXME: WTF this return_state thing?
Expand All @@ -720,6 +732,8 @@ def apply_reset(
obs, info = data
else:
obs, info = data, {}
if self.return_image:
info["rgb"] = self.get_image()
return obs, info

def apply_action(self, action):
Expand Down

0 comments on commit e60709e

Please sign in to comment.