Adds new MDP observation, randomization and reward terms (#60)
# Description

This MR adds the following:

1. Observations: Adds observation terms for the root state (position, orientation quaternion,
linear velocity, and angular velocity) of an asset in the environment frame. These are
important for assets such as objects during manipulation.

2. Randomizations: Adds root orientation randomization for assets (such as objects) and
joint position randomization for articulations.

3. Rewards: Adds a termination reward term for specific termination terms. This is needed
when terminations are to be weighted individually, e.g., when a successful-termination
reward should have a different weighting factor than an illegal-state termination reward
(see the usage sketch below).

Tested for functionality.
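
A minimal usage sketch for item 3. The term keys, weights, and exact import paths are
assumptions, not part of this change; in practice these would be attributes of a rewards
config class:

```python
import omni.isaac.orbit.envs.mdp as mdp
from omni.isaac.orbit.managers import RewardTermCfg as RewTerm

# Weight a successful termination and an illegal-state termination differently.
# The term keys below are placeholders and must match terms in the termination manager.
success = RewTerm(
    func=mdp.is_terminated_term, weight=10.0, params={"term_keys": "object_reached_goal"}
)
dropped = RewTerm(
    func=mdp.is_terminated_term, weight=-5.0, params={"term_keys": "object_dropped"}
)
```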

## Type of change

- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have run all the tests with `./orbit.sh --test` and they pass
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
arbhardwaj98 authored and Mayankm96 committed Mar 11, 2024
1 parent e7506fe commit 1a42eb9
Showing 9 changed files with 332 additions and 64 deletions.
2 changes: 1 addition & 1 deletion source/extensions/omni.isaac.orbit/config/extension.toml
@@ -1,7 +1,7 @@
[package]

# Note: Semantic Versioning is used: https://semver.org/
version = "0.12.1"
version = "0.12.2"

# Description
title = "ORBIT framework for Robot Learning"
17 changes: 17 additions & 0 deletions source/extensions/omni.isaac.orbit/docs/CHANGELOG.rst
@@ -1,6 +1,23 @@
Changelog
---------

0.12.2 (2024-03-10)
~~~~~~~~~~~~~~~~~~~

Added
^^^^^

* Added observation terms for states of a rigid object in world frame.
* Added randomization terms to set root state with randomized orientation and joint state within user-specified limits.
* Added reward term for penalizing specific termination terms.

Fixed
^^^^^

* Improved sampling of states inside randomization terms. Earlier, the code did multiple torch calls
for sampling different components of the vector. Now, it uses a single call to sample the entire vector.
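
As a rough illustration of the single-call sampling described above (ranges and shapes are
made up, not taken from the actual code):

```python
import torch

num_envs, device = 8, "cpu"
# per-component (x, y, z) limits, shape (3, 2)
ranges = torch.tensor([[-0.1, 0.1], [-0.1, 0.1], [0.0, 0.2]], device=device)
# a single call samples the whole (num_envs, 3) tensor instead of one call per component
offsets = ranges[:, 0] + (ranges[:, 1] - ranges[:, 0]) * torch.rand(num_envs, 3, device=device)
```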


0.12.1 (2024-03-09)
~~~~~~~~~~~~~~~~~~~

@@ -114,7 +114,7 @@ class ArticulationData(RigidObjectData):
"""Joint positions limits for all joints. Shape is (count, num_joints, 2)."""

soft_joint_vel_limits: torch.Tensor = None
"""Joint velocity limits for all joints. Shape is (count, num_joints, 2)."""
"""Joint velocity limits for all joints. Shape is (count, num_joints)."""

gear_ratio: torch.Tensor = None
"""Gear ratio for relating motor torques to applied Joint torques. Shape is (count, num_joints)."""
@@ -55,6 +55,34 @@ def projected_gravity(env: BaseEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg(
return asset.data.projected_gravity_b


def root_pos_w(env: BaseEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Asset root position in the environment frame."""
# extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name]
return asset.data.root_pos_w - env.scene.env_origins


def root_quat_w(env: BaseEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Asset root orientation in the environment frame."""
# extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name]
return asset.data.root_quat_w


def root_lin_vel_w(env: BaseEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Asset root linear velocity in the environment frame."""
# extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name]
return asset.data.root_lin_vel_w


def root_ang_vel_w(env: BaseEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Asset root angular velocity in the environment frame."""
# extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name]
return asset.data.root_ang_vel_w


"""
Joint state.
"""
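
A sketch of how the root-state terms above might be collected into an observation group;
the asset name `"object"` and the class name are assumptions:

```python
import omni.isaac.orbit.envs.mdp as mdp
from omni.isaac.orbit.managers import ObservationGroupCfg as ObsGroup
from omni.isaac.orbit.managers import ObservationTermCfg as ObsTerm
from omni.isaac.orbit.managers import SceneEntityCfg
from omni.isaac.orbit.utils import configclass


@configclass
class ObjectStateCfg(ObsGroup):
    """Root state of the "object" asset, expressed in the environment frame."""

    position = ObsTerm(func=mdp.root_pos_w, params={"asset_cfg": SceneEntityCfg("object")})
    orientation = ObsTerm(func=mdp.root_quat_w, params={"asset_cfg": SceneEntityCfg("object")})
    lin_vel = ObsTerm(func=mdp.root_lin_vel_w, params={"asset_cfg": SceneEntityCfg("object")})
    ang_vel = ObsTerm(func=mdp.root_ang_vel_w, params={"asset_cfg": SceneEntityCfg("object")})
```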

Large diffs are not rendered by default.

@@ -16,6 +16,8 @@

from omni.isaac.orbit.assets import Articulation, RigidObject
from omni.isaac.orbit.managers import SceneEntityCfg
from omni.isaac.orbit.managers.manager_base import ManagerTermBase
from omni.isaac.orbit.managers.manager_term_cfg import RewardTermCfg
from omni.isaac.orbit.sensors import ContactSensor

if TYPE_CHECKING:
@@ -36,6 +38,36 @@ def is_terminated(env: RLTaskEnv) -> torch.Tensor:
return env.termination_manager.terminated.float()


class is_terminated_term(ManagerTermBase):
"""Penalize termination for specific terms that don't correspond to episodic timeouts.
The parameters are as follows:
* attr:`term_keys`: The termination terms to penalize. This can be a string, a list of strings
or regular expressions. Default is ".*" which penalizes all terminations.
The reward is computed as the sum of the termination terms that are not episodic timeouts.
This means that the reward is 0 if the episode is terminated due to an episodic timeout. Otherwise,
if two termination terms are active, the reward is 2.
"""

def __init__(self, cfg: RewardTermCfg, env: RLTaskEnv):
# initialize the base class
super().__init__(cfg, env)
# find and store the termination terms
term_keys = cfg.params.get("term_keys", ".*")
self._term_names = env.termination_manager.find_terms(term_keys)

def __call__(self, env: RLTaskEnv, term_keys: str | list[str] = ".*") -> torch.Tensor:
# Return the unweighted reward for the termination terms
reset_buf = torch.zeros(env.num_envs, device=env.device)
for term in self._term_names:
# Sums over termination term values to account for multiple terminations in the same step
reset_buf += env.termination_manager.get_term(term)

return (reset_buf * (~env.termination_manager.time_outs)).float()


"""
Root penalties.
"""
@@ -120,18 +120,19 @@ def max_episode_length(self) -> int:

def load_managers(self):
# note: this order is important since observation manager needs to know the command and action managers
# and the reward manager needs to know the termination manager
# -- command manager
self.command_manager: CommandManager = CommandManager(self.cfg.commands, self)
print("[INFO] Command Manager: ", self.command_manager)
# call the parent class to load the managers for observations and actions.
super().load_managers()
# prepare the managers
# -- reward manager
self.reward_manager = RewardManager(self.cfg.rewards, self)
print("[INFO] Reward Manager: ", self.reward_manager)
# -- termination manager
self.termination_manager = TerminationManager(self.cfg.terminations, self)
print("[INFO] Termination Manager: ", self.termination_manager)
# -- reward manager
self.reward_manager = RewardManager(self.cfg.rewards, self)
print("[INFO] Reward Manager: ", self.reward_manager)
# -- curriculum manager
self.curriculum_manager = CurriculumManager(self.cfg.curriculum, self)
print("[INFO] Curriculum Manager: ", self.curriculum_manager)
@@ -13,6 +13,7 @@

import carb

import omni.isaac.orbit.utils.string as string_utils
from omni.isaac.orbit.utils import string_to_callable

from .manager_term_cfg import ManagerTermBaseCfg
@@ -164,6 +165,33 @@ def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, float]:
"""
return {}

def find_terms(self, name_keys: str | Sequence[str]) -> list[str]:
"""Find terms in the manager based on the names.
This function searches the manager for terms based on the names. The names can be
specified as regular expressions or a list of regular expressions. The search is
performed on the active terms in the manager.
Please check the :meth:`omni.isaac.orbit.utils.string_utils.resolve_matching_names` function for more
information on the name matching.
Args:
name_keys: A regular expression or a list of regular expressions to match the term names.
Returns:
A list of term names that match the input keys.
"""
# resolve search keys
if isinstance(self.active_terms, dict):
list_of_strings = []
for names in self.active_terms.values():
list_of_strings.extend(names)
else:
list_of_strings = self.active_terms

# return the matching names
return string_utils.resolve_matching_names(name_keys, list_of_strings)[1]

"""
Implementation specific.
"""
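
A small sketch of how `find_terms` might be called, assuming an `RLTaskEnv` instance `env`
and placeholder term names:

```python
# Assuming the termination manager defines terms such as "time_out" and "object_dropped".
names = env.termination_manager.find_terms(name_keys=[".*_dropped", "time_out"])
# -> a list like ["object_dropped", "time_out"], depending on the configured termination terms
```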
@@ -55,9 +55,9 @@ def __init__(self, cfg: object, env: RLTaskEnv):
"""
super().__init__(cfg, env)
# prepare extra info to store individual termination term information
self._episode_dones = dict()
self._term_dones = dict()
for term_name in self._term_names:
self._episode_dones[term_name] = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
self._term_dones[term_name] = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
# create buffer for managing termination per environment
self._truncated_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
self._terminated_buf = torch.zeros_like(self._truncated_buf)
@@ -133,11 +133,11 @@ def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, torch.Tensor]
env_ids = slice(None)
# add to episode dict
extras = {}
for key in self._episode_dones.keys():
for key in self._term_dones.keys():
# store information
extras["Episode Termination/" + key] = torch.count_nonzero(self._episode_dones[key][env_ids]).item()
extras["Episode Termination/" + key] = torch.count_nonzero(self._term_dones[key][env_ids]).item()
# reset episode dones
self._episode_dones[key][env_ids] = False
self._term_dones[key][env_ids] = False
# reset all the reward terms
for term_cfg in self._class_term_cfgs:
term_cfg.func.reset(env_ids=env_ids)
@@ -165,10 +165,21 @@ def compute(self) -> torch.Tensor:
else:
self._terminated_buf |= value
# add to episode dones
self._episode_dones[name] |= value
self._term_dones[name] |= value
# return combined termination signal
return self._truncated_buf | self._terminated_buf

def get_term(self, name: str) -> torch.Tensor:
"""Returns the termination term with the specified name.
Args:
name: The name of the termination term.
Returns:
The corresponding termination term value. Shape is (num_envs,).
"""
return self._term_dones[name]

"""
Operations - Term settings.
"""
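
And a similar sketch for the new `get_term` accessor, again with a placeholder term name:

```python
import torch

# Assuming an RLTaskEnv instance `env` with a termination term named "object_dropped".
dropped = env.termination_manager.get_term("object_dropped")  # bool tensor, shape (num_envs,)
num_dropped = torch.count_nonzero(dropped).item()
```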
