Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a "duration" mode for event term #1192

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn:
# -- step interval events
if "interval" in self.event_manager.available_modes:
self.event_manager.apply(mode="interval", dt=self.step_dt)
if "duration" in self.event_manager.available_modes:
self.event_manager.apply(mode="duration", dt=self.step_dt)
# -- compute observations
# note: done after reset to get the correct observations for reset envs
self.obs_buf = self.observation_manager.compute()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,40 @@ def apply_external_force_torque(
# note: these are only applied when you call: `asset.write_data_to_sim()`
asset.set_external_force_and_torque(forces, torques, env_ids=env_ids, body_ids=asset_cfg.body_ids)

def apply_external_force_torque_duration(
    env: ManagerBasedEnv,
    env_ids: torch.Tensor | None,
    open: bool,
    force_range: tuple[float, float],
    torque_range: tuple[float, float],
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
):
    """Apply a random external force and torque to an asset's bodies, or clear them.

    Unlike :func:`apply_external_force_torque`, this term can both switch the disturbance
    on and switch it off again. It is meant to be driven by the ``"duration"`` event mode,
    which calls it with ``open=True`` when the interval elapses and with ``open=False``
    when the duration elapses. The function itself keeps no timing state.

    Args:
        env: The environment whose scene contains the asset.
        env_ids: The environment indices to write to. If None, all environments are used.
        open: If True, forces and torques are sampled uniformly from the given ranges.
            If False, zero forces and torques are written, cancelling the disturbance.
        force_range: The ``(lower, upper)`` bounds passed to the uniform sampler for forces.
        torque_range: The ``(lower, upper)`` bounds passed to the uniform sampler for torques.
        asset_cfg: The scene entity to act on. Defaults to ``SceneEntityCfg("robot")``.
    """
    # extract the used quantities (to enable type-hinting)
    asset: RigidObject | Articulation = env.scene[asset_cfg.name]

    # resolve environment ids
    if env_ids is None:
        env_ids = torch.arange(env.scene.num_envs, device=asset.device)
    # resolve number of bodies
    num_bodies = len(asset_cfg.body_ids) if isinstance(asset_cfg.body_ids, list) else asset.num_bodies

    # one 3D vector per selected body of every selected environment
    size = (len(env_ids), num_bodies, 3)
    if open:
        # switch the disturbance on: sample random forces and torques
        forces = math_utils.sample_uniform(*force_range, size, asset.device).clone()
        torques = math_utils.sample_uniform(*torque_range, size, asset.device).clone()
    else:
        # switch the disturbance off: overwrite the buffers with zeros
        forces = torch.zeros(size=size, device=asset.device)
        torques = torch.zeros(size=size, device=asset.device)

    # set the forces and torques into the buffers
    # note: these are only applied when you call: `asset.write_data_to_sim()`
    asset.set_external_force_and_torque(forces, torques, env_ids=env_ids, body_ids=asset_cfg.body_ids)

def push_by_setting_velocity(
env: ManagerBasedEnv,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class EventManager(ManagerBase):
- "startup": Event is applied once at the beginning of the training.
- "reset": Event is applied at every reset.
- "interval": Event is applied at pre-specified intervals of time.
- "duration": Event is applied at pre-specified intervals of time and for a duration (e.g. applied force).

However, you can also define your own modes and use them in the training process as you see fit.
For this you will need to add the triggering of that mode in the environment implementation as well.
Expand Down Expand Up @@ -79,6 +80,11 @@ def __str__(self) -> str:
table.align["Name"] = "l"
for index, (name, cfg) in enumerate(zip(self._mode_term_names[mode], self._mode_term_cfgs[mode])):
table.add_row([index, name, cfg.interval_range_s])
elif mode == "duration":
table.field_names = ["Index", "Name", "Interval time range (s)", "Duration time range (s)"]
table.align["Name"] = "l"
for index, (name, cfg) in enumerate(zip(self._mode_term_names[mode], self._mode_term_cfgs[mode])):
table.add_row([index, name, cfg.interval_range_s, cfg.duration_range_s])
else:
table.field_names = ["Index", "Name"]
table.align["Name"] = "l"
Expand Down Expand Up @@ -198,6 +204,64 @@ def apply(

# call the event term
term_cfg.func(self._env, valid_env_ids, **term_cfg.params)
elif mode == "duration":
# extract time left for this term
time_left = self._duration_term_interval_time_left[index]
# update the time left for each environment
time_left -= dt

# update duration time
started_env_ids=self._duration_term_started
duration_left = self._duration_term_duration_time_left[index]
duration_left[started_env_ids] -= dt

# check if the interval has passed and sample a new interval
# note: we compare with a small value to handle floating point errors
if term_cfg.is_global_time:
# interval check
if time_left < 1e-6:
lower, upper = term_cfg.interval_range_s
sampled_interval = torch.rand(1) * (upper - lower) + lower
self._duration_term_interval_time_left[index][:] = sampled_interval

# call the event term (with None for env_ids)
term_cfg.params["open"]=True
term_cfg.func(self._env, None, **term_cfg.params)
self._duration_term_started[index][:] = True

# duration check
if duration_left < 1e-6:
term_cfg.params["open"]=False
term_cfg.func(self._env, None, **term_cfg.params)
self._duration_term_started[index][:] = False

lower, upper = term_cfg.duration_range_s
duration_left = torch.rand(1) * (upper - lower) + lower
self._duration_term_duration_time_left[index][:] = duration_left
else:
# interval check
valid_env_ids = (time_left < 1e-6).nonzero().flatten()
if len(valid_env_ids) > 0:
lower, upper = term_cfg.interval_range_s
sampled_time = torch.rand(len(valid_env_ids), device=self.device) * (upper - lower) + lower
self._duration_term_interval_time_left[index][valid_env_ids] = sampled_time

# call the event term
term_cfg.params["open"]=True
term_cfg.func(self._env, valid_env_ids, **term_cfg.params)
self._duration_term_started[index][valid_env_ids] = True

# duration check
valid_env_ids_duration = (duration_left < 1e-6).nonzero().flatten()
if len(valid_env_ids_duration) > 0:
term_cfg.params["open"]=False
term_cfg.func(self._env, valid_env_ids_duration, **term_cfg.params)
self._duration_term_started[index][valid_env_ids_duration] = False

lower, upper = term_cfg.duration_range_s
duration_left = torch.rand(len(valid_env_ids_duration), device=self.device) * (upper - lower) + lower
self._duration_term_duration_time_left[index][valid_env_ids_duration] = duration_left

elif mode == "reset":
# obtain the minimum step count between resets
min_step_count = term_cfg.min_step_count_between_reset
Expand Down Expand Up @@ -302,6 +366,10 @@ def _prepare_terms(self):
# buffer to store the time left for "interval" mode
# if interval is global, then it is a single value, otherwise it is per environment
self._interval_term_time_left: list[torch.Tensor] = list()
# buffer to store the time left for "duration" mode
self._duration_term_interval_time_left: list[torch.Tensor] = list()
self._duration_term_started: list[torch.Tensor] = list()
self._duration_term_duration_time_left: list[torch.Tensor] = list()
# buffer to store the step count when the term was last triggered for each environment for "reset" mode
self._reset_term_last_triggered_step_id: list[torch.Tensor] = list()
self._reset_term_last_triggered_once: list[torch.Tensor] = list()
Expand Down Expand Up @@ -363,6 +431,42 @@ def _prepare_terms(self):
lower, upper = term_cfg.interval_range_s
time_left = torch.rand(self.num_envs, device=self.device) * (upper - lower) + lower
self._interval_term_time_left.append(time_left)
elif term_cfg.mode == "duration":
if term_cfg.interval_range_s is None:
raise ValueError(
f"Event term '{term_name}' has mode 'duration' but 'duration_range_s' is not specified."
)
if term_cfg.duration_range_s is None:
raise ValueError(
f"Event term '{term_name}' has mode 'duration' but 'duration_range_s' is not specified."
)


# sample the time left for global
if term_cfg.is_global_time:
lower, upper = term_cfg.interval_range_s
time_left = torch.rand(1) * (upper - lower) + lower
self._duration_term_interval_time_left.append(time_left)

lower, upper = term_cfg.duration_range_s
duration_left = torch.rand(1) * (upper - lower) + lower
self._duration_term_duration_time_left.append(duration_left)

started = torch.zeros(1, dtype=torch.bool, device=self.device)
self._duration_term_started.append(started)
else:
# sample the time left for each environment
lower, upper = term_cfg.interval_range_s
time_left = torch.rand(self.num_envs, device=self.device) * (upper - lower) + lower
self._duration_term_interval_time_left.append(time_left)

lower, upper = term_cfg.duration_range_s
duration_left = torch.rand(self.num_envs, device=self.device) * (upper - lower) + lower
self._duration_term_duration_time_left.append(duration_left)

started = torch.zeros(self.num_envs, dtype=torch.bool, device=self.device)
self._duration_term_started.append(started)

# -- reset mode
elif term_cfg.mode == "reset":
if term_cfg.min_step_count_between_reset < 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,12 @@ class EventTermCfg(ManagerTermBaseCfg):
This is only used if the mode is ``"interval"``.
"""

duration_range_s: tuple[float, float] | None = None
"""
The range (in seconds) from which the duration of the event is sampled. Defaults to None.

This is only used if the mode is ``"duration"``.
"""

is_global_time: bool = False
"""Whether randomization should be tracked on a per-environment basis. Defaults to False.

Expand Down