Skip to content

Commit

Permalink
[gym_jiminy/common] Faster env pipeline.
Browse files Browse the repository at this point in the history
  • Loading branch information
duburcqa committed Dec 1, 2024
1 parent 7913241 commit 7cca584
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 44 deletions.
13 changes: 6 additions & 7 deletions python/gym_jiminy/common/gym_jiminy/common/bases/compositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

import numpy as np

from ..utils.spaces import _array_contains

from .interfaces import InfoType, InterfaceJiminyEnv
from .quantities import QuantityCreator

Expand Down Expand Up @@ -580,8 +582,8 @@ def __init__(self,
Optional: False by default.
"""
# Backup user argument(s)
self.low = low
self.high = high
self.low = np.asarray(low) if isinstance(low, Sequence) else low
self.high = np.asarray(high) if isinstance(high, Sequence) else high

# Call base implementation
super().__init__(
Expand Down Expand Up @@ -617,13 +619,10 @@ def compute(self, info: InfoType) -> bool:
# Evaluate the quantity
value = self.data.get()

# Check if the quantity is out-of-bounds bound.
# Check if the quantity is out-of-bounds.
# Note that it may be `None` if the quantity is ill-defined for the
# current simulation state, which triggers termination unconditionally.
is_done = value is None
is_done |= self.low is not None and bool(np.any(self.low > value))
is_done |= self.high is not None and bool(np.any(value > self.high))
return is_done
return value is None or not _array_contains(value, self.low, self.high)


QuantityTermination.name.__doc__ = \
Expand Down
107 changes: 70 additions & 37 deletions python/gym_jiminy/common/gym_jiminy/common/utils/spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,36 +83,26 @@ def _array_clip(value: np.ndarray,
@nb.jit(nopython=True, cache=True, fastmath=True)
def _array_contains(value: np.ndarray,
low: Optional[ArrayOrScalar],
high: Optional[ArrayOrScalar],
tol_abs: float,
tol_rel: float) -> bool:
high: Optional[ArrayOrScalar]) -> bool:
"""Check that all array elements are withing bounds, up to some tolerance
threshold. If both absolute and relative tolerances are provided, then
satisfying only one of the two criteria is considered sufficient.
threshold.
:param value: Array holding values to check.
:param low: Optional lower bound.
:param high: Optional upper bound.
:param tol_abs: Absolute tolerance.
:param tol_rel: Relative tolerance. It will be ignored if either the lower
or upper is not specified.
"""
if value.ndim:
tol_nd = np.full_like(value, tol_abs)
if low is not None and high is not None and tol_rel > 0.0:
tol_nd = np.maximum((high - low) * tol_rel, tol_nd)
value_ = np.asarray(value)
if value_.ndim:
value_1d = np.atleast_1d(value_)
# Reversed bound check because 'all' is always true for empty arrays
if low is not None and not (low - tol_nd <= value).all():
if low is not None and not (low <= value_1d).all():
return False
if high is not None and not (value <= high + tol_nd).all():
if high is not None and not (value_1d <= high).all():
return False
return True
tol_0d = tol_abs
if low is not None and high is not None and tol_rel > 0.0:
tol_0d = max((high.item() - low.item()) * tol_rel, tol_0d)
if low is not None and (low.item() - tol_0d > value.item()):
if low is not None and (low.item() > value_.item()):
return False
if high is not None and (value.item() > high.item() + tol_0d):
if high is not None and (value_.item() > high.item()):
return False
return True

Expand Down Expand Up @@ -247,21 +237,44 @@ def get_robot_measurements_space(robot: jiminy.Robot) -> gym.spaces.Dict:
sensor_space_lower.items(), sensor_space_upper.values())))


def get_bounds(space: gym.Space
def get_bounds(space: gym.Space,
tol_abs: float = 0.0,
tol_rel: float = 0.0,
) -> Tuple[Optional[ArrayOrScalar], Optional[ArrayOrScalar]]:
"""Get the lower and upper bounds of a given 'gym.Space' if any.
:param space: `gym.Space` on which to operate.
:param tol_abs: Absolute tolerance.
Optional: 0.0 by default
:param tol_rel: Relative tolerance. It will be ignored if either the lower
or upper is not specified.
Optional: 0.0 by default.
:returns: Lower and upper bounds as a tuple.
"""
# Extract lower and upper bounds depending on the gym space
low, high = None, None
if isinstance(space, gym.spaces.Box):
return space.low, space.high
low, high = space.low, space.high
if isinstance(space, gym.spaces.Discrete):
return space.start, space.n
low, high = space.start, space.n
if isinstance(space, gym.spaces.MultiDiscrete):
return 0, space.nvec
return None, None
low, high = 0, space.nvec

# Take into account the absolute and relative tolerances
# assert tol_abs >= 0.0 and tol_rel >= 0.0
if tol_abs or tol_rel:
dtype = low.dtype
tol_nd = np.full_like(low, tol_abs)
if low is not None and high is not None and tol_rel:
tol_nd = np.maximum((high - low) * tol_rel, tol_nd)
if low is not None:
low = (low - tol_nd).astype(dtype)
if low is not None:
high = (high + tol_nd).astype(dtype)

return low, high



@no_type_check
Expand Down Expand Up @@ -410,7 +423,7 @@ def contains(data: DataNested,
if tree.issubclass_sequence(data_type):
return all(contains(data[i], subspace, tol_abs, tol_rel)
for i, subspace in enumerate(space))
return _array_contains(data, *get_bounds(space), tol_abs, tol_rel)
return _array_contains(data, *get_bounds(space, tol_abs, tol_rel))


@no_type_check
Expand All @@ -421,7 +434,9 @@ def build_reduce(fn: Callable[..., ValueInT],
arity: Optional[Literal[0, 1]],
*args: Any,
initializer: Optional[Callable[[], ValueOutT]] = None,
forward_bounds: bool = True) -> Callable[..., ValueOutT]:
forward_bounds: bool = True,
tol_abs: float = 0.0,
tol_rel: float = 0.0) -> Callable[..., ValueOutT]:
"""Generate specialized callable applying transform and reduction on all
leaves of given nested space.
Expand Down Expand Up @@ -484,6 +499,12 @@ def build_reduce(fn: Callable[..., ValueInT],
sure all leaves have bounds, otherwise it will raise
an exception at generation-time. This argument is
ignored if not space is specified.
:param tol_abs: Absolute tolerance added to the lower and upper bounds of
the `gym.Space` associated with each leaf.
Optional: 0.0 by default.
:param tol_rel: Relative tolerance added to the lower and upper bounds of
the `gym.Space` associated with each leaf.
Optional: 0.0 by default.
:returns: Fully-specialized reduction callable.
"""
Expand Down Expand Up @@ -803,7 +824,8 @@ def _build_transform_and_reduce(
post_fn = fn if not dataset else partial(fn, *dataset)
post_args = args
if forward_bounds and space is not None:
post_args = (*get_bounds(space), *post_args)
post_args = (
*get_bounds(space, tol_abs, tol_rel), *post_args)
post_fn = partial(post_fn, post_args)
if parent is None:
post_fn = _build_forward(
Expand Down Expand Up @@ -906,8 +928,9 @@ def build_map(fn: Callable[..., ValueT],
space: Optional[gym.Space[DataNested]],
arity: Optional[Literal[0, 1]],
*args: Any,
forward_bounds: bool = True
) -> Callable[[], StructNested[ValueT]]:
forward_bounds: bool = True,
tol_abs: float = 0.0,
tol_rel: float = 0.0) -> Callable[[], StructNested[ValueT]]:
"""Generate specialized callable returning applying out-of-place transform
to all leaves of given nested space.
Expand Down Expand Up @@ -950,6 +973,12 @@ def build_map(fn: Callable[..., ValueT],
an exception at generation-time. This argument is
ignored if not space is specified.
Optional: `True` by default.
:param tol_abs: Absolute tolerance added to the lower and upper bounds of
the `gym.Space` associated with each leaf.
Optional: 0.0 by default.
:param tol_rel: Relative tolerance added to the lower and upper bounds of
the `gym.Space` associated with each leaf.
Optional: 0.0 by default.
:returns: Fully-specialized mapping callable.
"""
Expand Down Expand Up @@ -1087,7 +1116,8 @@ def _build_map(
post_fn = fn if data is None else partial(fn, data)
post_args = args
if forward_bounds and space is not None:
post_args = (*get_bounds(space), *post_args)
post_args = (
*get_bounds(space, tol_abs, tol_rel), *post_args)
post_fn = partial(post_fn, post_args)
if parent is None:
post_fn = _build_setitem(arity, None, post_fn, None)
Expand Down Expand Up @@ -1232,9 +1262,7 @@ class ShortCircuitContains(Exception):
@nb.jit(nopython=True, cache=True)
def _contains_or_raises(value: np.ndarray,
low: Optional[ArrayOrScalar],
high: Optional[ArrayOrScalar],
tol_abs: float,
tol_rel: float) -> bool:
high: Optional[ArrayOrScalar]) -> bool:
"""Thin wrapper around original `_array_contains` method to raise
an exception if the test fails. It enables short-circuit mechanism
to abort checking remaining leaves if any.
Expand All @@ -1247,10 +1275,8 @@ def _contains_or_raises(value: np.ndarray,
:param value: Array holding values to check.
:param low: Lower bound.
:param high: Upper bound.
:param tol_abs: Absolute tolerance.
:param tol_rel: Relative tolerance.
"""
if not _array_contains(value, low, high, tol_abs, tol_rel):
if not _array_contains(value, low, high):
raise ShortCircuitContains("Short-circuit exception.")
return True

Expand All @@ -1270,7 +1296,14 @@ def _exception_handling(out_fn: Callable[[], bool]) -> bool:
return True

return partial(_exception_handling, build_reduce(
_contains_or_raises, None, (data,), space, 0, tol_abs, tol_rel))
_contains_or_raises,
None,
(data,),
space,
arity=0,
forward_bounds=True,
tol_abs=tol_abs,
tol_rel=tol_rel))


def build_normalize(space: gym.Space[DataNested],
Expand Down

0 comments on commit 7cca584

Please sign in to comment.