diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/manager_based_rl_env.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/manager_based_rl_env.py
index 0b899c0c33..af059866ac 100644
--- a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/manager_based_rl_env.py
+++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/manager_based_rl_env.py
@@ -203,6 +203,9 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn:
         # -- step interval events
         if "interval" in self.event_manager.available_modes:
             self.event_manager.apply(mode="interval", dt=self.step_dt)
+        # -- step duration events (switched on at intervals, switched off after a duration)
+        if "duration" in self.event_manager.available_modes:
+            self.event_manager.apply(mode="duration", dt=self.step_dt)
         # -- compute observations
         # note: done after reset to get the correct observations for reset envs
         self.obs_buf = self.observation_manager.compute()
diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/events.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/events.py
index 8f2d737eb7..cb55b7a779 100644
--- a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/events.py
+++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/events.py
@@ -564,6 +564,55 @@ def apply_external_force_torque(
     # note: these are only applied when you call: `asset.write_data_to_sim()`
     asset.set_external_force_and_torque(forces, torques, env_ids=env_ids, body_ids=asset_cfg.body_ids)
 
+
+def apply_external_force_torque_duration(
+    env: ManagerBasedEnv,
+    env_ids: torch.Tensor,
+    open: bool,
+    force_range: tuple[float, float],
+    torque_range: tuple[float, float],
+    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
+):
+    """Switch a randomized external force and torque on the bodies on or off.
+
+    Unlike ``apply_external_force_torque``, which keeps the sampled force and torque for the whole episode,
+    this term is meant for the ``"duration"`` event mode: it is called with ``open=True`` to set a sampled
+    force and torque and later with ``open=False`` to cancel them again.
+
+    .. note::
+        This function only sets the values into the buffers. They are only applied when
+        ``asset.write_data_to_sim()`` is called.
+
+    Args:
+        env: The environment instance.
+        env_ids: The environment ids to set the values for. If None, all environments are used.
+        open: Whether to set sampled forces and torques (True) or zero forces and torques (False).
+        force_range: The range ``(min, max)`` from which the forces are sampled.
+        torque_range: The range ``(min, max)`` from which the torques are sampled.
+        asset_cfg: The asset and bodies to set the values for. Defaults to the asset named ``"robot"``.
+    """
+    # extract the used quantities (to enable type-hinting)
+    asset: RigidObject | Articulation = env.scene[asset_cfg.name]
+    # resolve environment ids
+    if env_ids is None:
+        env_ids = torch.arange(env.scene.num_envs, device=asset.device)
+    # resolve number of bodies
+    num_bodies = len(asset_cfg.body_ids) if isinstance(asset_cfg.body_ids, list) else asset.num_bodies
+
+    # one 3D force and torque vector per selected environment and body
+    size = (len(env_ids), num_bodies, 3)
+    if open:
+        # sample random forces and torques
+        forces = math_utils.sample_uniform(*force_range, size, asset.device).clone()
+        torques = math_utils.sample_uniform(*torque_range, size, asset.device).clone()
+    else:
+        # zero forces and torques cancel the previously set values
+        forces = torch.zeros(size=size, device=asset.device)
+        torques = torch.zeros(size=size, device=asset.device)
+    # set the forces and torques into the buffers
+    # note: these are only applied when you call: `asset.write_data_to_sim()`
+    asset.set_external_force_and_torque(forces, torques, env_ids=env_ids, body_ids=asset_cfg.body_ids)
+
 
 def push_by_setting_velocity(
     env: ManagerBasedEnv,
diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/event_manager.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/event_manager.py
index e8aa7407ed..ed49189d6a 100644
--- a/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/event_manager.py
+++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/event_manager.py
@@ -41,6 +41,7 @@ class EventManager(ManagerBase):
     - "startup": Event is applied once at the beginning of the training.
     - "reset": Event is applied at every reset.
     - "interval": Event is applied at pre-specified intervals of time.
+    - "duration": Event is switched on at pre-specified intervals of time and off again after a duration.
 
     However, you can also define your own modes and use them in the training process as you see fit.
     For this you will need to add the triggering of that mode in the environment implementation as well.
@@ -79,6 +80,11 @@ def __str__(self) -> str:
                 table.align["Name"] = "l"
                 for index, (name, cfg) in enumerate(zip(self._mode_term_names[mode], self._mode_term_cfgs[mode])):
                     table.add_row([index, name, cfg.interval_range_s])
+            elif mode == "duration":
+                table.field_names = ["Index", "Name", "Interval time range (s)", "Duration time range (s)"]
+                table.align["Name"] = "l"
+                for index, (name, cfg) in enumerate(zip(self._mode_term_names[mode], self._mode_term_cfgs[mode])):
+                    table.add_row([index, name, cfg.interval_range_s, cfg.duration_range_s])
             else:
                 table.field_names = ["Index", "Name"]
                 table.align["Name"] = "l"
@@ -198,6 +204,60 @@ def apply(
 
                         # call the event term
                         term_cfg.func(self._env, valid_env_ids, **term_cfg.params)
+            elif mode == "duration":
+                # per-term state: one element if the time is global, otherwise one per environment
+                time_left = self._duration_term_interval_time_left[index]
+                started = self._duration_term_started[index]
+                duration_left = self._duration_term_duration_time_left[index]
+                # the interval timer always runs; the duration timer only runs while the event is active
+                time_left -= dt
+                duration_left[started] -= dt
+
+                # pass the on/off flag per call instead of writing it into the shared config dictionary
+                params_on = {**term_cfg.params, "open": True}
+                params_off = {**term_cfg.params, "open": False}
+
+                # check if the interval or the duration has passed and sample new values
+                # note: we compare with a small value to handle floating point errors
+                if term_cfg.is_global_time:
+                    # interval check: switch the event on
+                    if time_left.item() < 1e-6:
+                        lower, upper = term_cfg.interval_range_s
+                        time_left[:] = torch.rand(1, device=self.device) * (upper - lower) + lower
+
+                        # call the event term (with None for env_ids)
+                        term_cfg.func(self._env, None, **params_on)
+                        started[:] = True
+
+                    # duration check: switch the event off and sample the next duration
+                    if started.item() and duration_left.item() < 1e-6:
+                        term_cfg.func(self._env, None, **params_off)
+                        started[:] = False
+
+                        lower, upper = term_cfg.duration_range_s
+                        duration_left[:] = torch.rand(1, device=self.device) * (upper - lower) + lower
+                else:
+                    # interval check: switch the event on for the environments whose interval has passed
+                    valid_env_ids = (time_left < 1e-6).nonzero().flatten()
+                    if len(valid_env_ids) > 0:
+                        lower, upper = term_cfg.interval_range_s
+                        sampled_time = torch.rand(len(valid_env_ids), device=self.device) * (upper - lower) + lower
+                        time_left[valid_env_ids] = sampled_time
+
+                        # call the event term
+                        term_cfg.func(self._env, valid_env_ids, **params_on)
+                        started[valid_env_ids] = True
+
+                    # duration check: switch the event off for the active environments whose duration has passed
+                    expired_env_ids = (started & (duration_left < 1e-6)).nonzero().flatten()
+                    if len(expired_env_ids) > 0:
+                        term_cfg.func(self._env, expired_env_ids, **params_off)
+                        started[expired_env_ids] = False
+
+                        lower, upper = term_cfg.duration_range_s
+                        new_duration = torch.rand(len(expired_env_ids), device=self.device) * (upper - lower) + lower
+                        duration_left[expired_env_ids] = new_duration
+
             elif mode == "reset":
                 # obtain the minimum step count between resets
                 min_step_count = term_cfg.min_step_count_between_reset
@@ -302,6 +362,10 @@ def _prepare_terms(self):
         # buffer to store the time left for "interval" mode
         # if interval is global, then it is a single value, otherwise it is per environment
         self._interval_term_time_left: list[torch.Tensor] = list()
+        # buffers for "duration" mode: time until switch-on, active flag, and remaining active time
+        self._duration_term_interval_time_left: list[torch.Tensor] = list()
+        self._duration_term_started: list[torch.Tensor] = list()
+        self._duration_term_duration_time_left: list[torch.Tensor] = list()
         # buffer to store the step count when the term was last triggered for each environment for "reset" mode
         self._reset_term_last_triggered_step_id: list[torch.Tensor] = list()
         self._reset_term_last_triggered_once: list[torch.Tensor] = list()
@@ -363,6 +427,30 @@ def _prepare_terms(self):
                     lower, upper = term_cfg.interval_range_s
                     time_left = torch.rand(self.num_envs, device=self.device) * (upper - lower) + lower
                     self._interval_term_time_left.append(time_left)
+            # -- duration mode
+            elif term_cfg.mode == "duration":
+                if term_cfg.interval_range_s is None:
+                    raise ValueError(
+                        f"Event term '{term_name}' has mode 'duration' but 'interval_range_s' is not specified."
+                    )
+                if term_cfg.duration_range_s is None:
+                    raise ValueError(
+                        f"Event term '{term_name}' has mode 'duration' but 'duration_range_s' is not specified."
+                    )
+
+                # a single shared timer if the time is global, otherwise one timer per environment
+                num_timers = 1 if term_cfg.is_global_time else self.num_envs
+                # sample the time left until the event is switched on
+                lower, upper = term_cfg.interval_range_s
+                time_left = torch.rand(num_timers, device=self.device) * (upper - lower) + lower
+                self._duration_term_interval_time_left.append(time_left)
+                # sample how long the event stays on once it has been switched on
+                lower, upper = term_cfg.duration_range_s
+                duration_left = torch.rand(num_timers, device=self.device) * (upper - lower) + lower
+                self._duration_term_duration_time_left.append(duration_left)
+                # the event is off for all timers at the start
+                started = torch.zeros(num_timers, dtype=torch.bool, device=self.device)
+                self._duration_term_started.append(started)
             # -- reset mode
             elif term_cfg.mode == "reset":
                 if term_cfg.min_step_count_between_reset < 0:
diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/manager_term_cfg.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/manager_term_cfg.py
index 9a2250e48b..f06daf1565 100644
--- a/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/manager_term_cfg.py
+++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/manager_term_cfg.py
@@ -213,6 +213,14 @@ class EventTermCfg(ManagerTermBaseCfg):
     This is only used if the mode is ``"interval"``.
     """
 
+    duration_range_s: tuple[float, float] | None = None
+    """The range of time in seconds for which the term stays applied. Defaults to None.
+
+    The duration is sampled uniformly from this range. The range is a tuple of the form ``(min, max)``.
+
+    This is only used if the mode is ``"duration"``.
+    """
+
     is_global_time: bool = False
     """Whether randomization should be tracked on a per-environment basis. Defaults to False.
 