Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[drake_gym] Add info_handler callback #21900

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
9 changes: 8 additions & 1 deletion bindings/pydrake/examples/gym/envs/cart_pole.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def make_sim(meshcat=None,
time_step=sim_time_step,
contact_model=contact_model,
discrete_contact_approximation=contact_approximation,
)
)

plant, scene_graph = AddMultibodyPlant(multibody_plant_config, builder)

Expand Down Expand Up @@ -371,6 +371,12 @@ def reset_handler(simulator, diagram_context, seed):
body.SetMass(plant_context, mass+pair[1])


def info_handler(simulator: Simulator) -> dict:
info = dict()
info["timestamp"] = simulator.get_context().get_time()
return info


def DrakeCartPoleEnv(
meshcat=None,
time_limit=gym_time_limit,
Expand Down Expand Up @@ -415,6 +421,7 @@ def DrakeCartPoleEnv(
action_port_id="actions",
observation_port_id="observations",
reset_handler=reset_handler,
info_handler=info_handler,
render_rgb_port_id="color_image" if monitoring_camera else None)

# Expose parameters that could be useful for learning.
Expand Down
22 changes: 20 additions & 2 deletions bindings/pydrake/gym/_drake_gym_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(self,
render_rgb_port_id: Union[OutputPortIndex, str] = None,
render_mode: str = 'human',
reset_handler: Callable[[Simulator, Context], None] = None,
info_handler: Callable[[Simulator], dict] = None,
hardware: bool = False):
"""
Args:
Expand Down Expand Up @@ -87,6 +88,12 @@ def __init__(self,
(e.g. ``joint.set_random_pose_distribution()``
using the ``reset()`` seed), (otherwise) using
``reset_handler()``.
info_handler: A function that returns a ``dict[str, Any]``
containing auxiliary diagnostic information (helpful for
debugging, learning, and logging). Note: if ``step()``
terminates with a ``RuntimeError``, then, to avoid
unexpected behavior, `info_handler()`` will not be called
and an empty info will be returned instead.
hardware: If True, it prevents from setting random context at
``reset()`` when using ``random_generator``, but it does
execute ``reset_handler()`` if given.
Expand Down Expand Up @@ -158,6 +165,12 @@ def __init__(self,
else:
raise ValueError("reset_handler is not callable.")

# Default return value of `info_handler()` is an empty `dict`.
if info_handler is callable(info_handler):
self.info_handler = info_handler
else:
self.info_handler = lambda _: dict()

self.hardware = hardware

if self.simulator:
Expand Down Expand Up @@ -223,7 +236,6 @@ def step(self, action):
truncated = False
# Observation prior to advancing the simulation.
prev_observation = self.observation_port.Eval(context)
info = dict()
try:
status = self.simulator.AdvanceTo(time + self.time_step)
except RuntimeError as e:
Expand All @@ -245,6 +257,9 @@ def step(self, action):
truncated = True
terminated = False
reward = 0
# Do not call info handler, as the simulator has faulted.
info = dict()

return prev_observation, reward, terminated, truncated, info

observation = self.observation_port.Eval(context)
Expand All @@ -253,6 +268,7 @@ def step(self, action):
not truncated
and (status.reason()
== SimulatorStatus.ReturnReason.kReachedTerminationCondition))
info = self.info_handler(self.simulator)

return observation, reward, terminated, truncated, info

Expand Down Expand Up @@ -293,7 +309,9 @@ def reset(self, *,
# Note: The output port will be evaluated without fixing the input
# port.
observations = self.observation_port.Eval(context)
return observations, dict()
info = self.info_handler(self.simulator)

return observations, info

def render(self):
"""
Expand Down