Train High-Level Policies in Hierarchical Approaches #1053

Merged
merged 55 commits into from
Feb 10, 2023
Changes from 6 commits
Commits
a787e34
Trainable HL policy
ASzot Dec 24, 2022
f6d1f11
Working on HRL trainer
ASzot Dec 26, 2022
daa84db
Fixed config setup
ASzot Dec 27, 2022
7ae2c6b
Train hl modif (#1057)
akshararai Jan 6, 2023
bb2e4da
Update README.md
xavierpuigf Jan 9, 2023
628063d
Match tensor device when checking if the skills is done
xavierpuigf Jan 9, 2023
d179ecf
Train hl modif2 (#1076)
xavierpuigf Jan 13, 2023
515b3cf
Merged with main
ASzot Jan 19, 2023
8eddcd1
Fixed RNN problem
ASzot Jan 20, 2023
fa94bfc
Fixed tests
ASzot Jan 20, 2023
11659b4
Fixed formatting
ASzot Jan 20, 2023
633ebf3
Fixed device issues. Cleaned up configs.
ASzot Jan 27, 2023
1c82390
More config cleanup
ASzot Jan 27, 2023
2e973f5
Addressing PR comments
ASzot Jan 27, 2023
3a805ea
Updated circular reference
ASzot Jan 27, 2023
7d92072
Addressing PR comments
ASzot Jan 27, 2023
4696551
Addressing PR comments
ASzot Jan 27, 2023
c11601f
Update habitat-baselines/habitat_baselines/rl/hrl/skills/skill.py
ASzot Jan 27, 2023
90155e1
Addressing PR comments
ASzot Jan 27, 2023
755acb4
Resolved storage problem
ASzot Jan 28, 2023
8d07335
merged
ASzot Jan 28, 2023
4754fd7
Update oracle_nav.py
xavierpuigf Jan 30, 2023
587672e
Fix for agent rotation
ASzot Jan 30, 2023
f16bc14
Missing key
ASzot Jan 30, 2023
22cb83f
More docs
ASzot Jan 31, 2023
df8b7a6
Update habitat-baselines/habitat_baselines/rl/hrl/hrl_rollout_storage.py
ASzot Jan 31, 2023
422c036
Update habitat-baselines/habitat_baselines/rl/hrl/utils.py
ASzot Jan 31, 2023
5643b0e
Updated name
ASzot Jan 31, 2023
ebe877e
fixes for training
ASzot Feb 1, 2023
e1a8727
Fixed env issue
ASzot Feb 2, 2023
2673f4f
Fixed deprecated configs
ASzot Feb 2, 2023
03faec7
Merge branch 'main' into train_hl
ASzot Feb 2, 2023
dd01d50
Speed fix
ASzot Feb 3, 2023
b20fcb1
Updated configs
ASzot Feb 3, 2023
7982f40
Pddl action fixes
ASzot Feb 3, 2023
902bfa1
Removed speed opts. Fixed some bugs
ASzot Feb 4, 2023
74278bd
Fixed rendering text to the frame
ASzot Feb 4, 2023
63f610c
Merged with main
ASzot Feb 4, 2023
49a71a4
Addressing Vince's PR comments
ASzot Feb 4, 2023
b41133a
Refactored navigation to be much clearer
ASzot Feb 4, 2023
e7a877b
Fixed some of the tests
ASzot Feb 5, 2023
5c213e3
Adddressed PR comments
ASzot Feb 6, 2023
d9721f1
Fixed rotation issue
ASzot Feb 6, 2023
f8387de
Fixed black
ASzot Feb 6, 2023
1c8f54c
Addressed PR comments
ASzot Feb 8, 2023
f2c6731
Addressed PR comments
ASzot Feb 8, 2023
4b65b3f
Merge branch 'main' into train_hl
ASzot Feb 8, 2023
11e77c3
Fixed config
ASzot Feb 8, 2023
6f5ea76
Fixed typo
ASzot Feb 8, 2023
cb6ce62
Fixed another typo
ASzot Feb 8, 2023
6d4b968
CI
ASzot Feb 8, 2023
d6c957e
Merge branch 'main' into train_hl
vincentpierre Feb 8, 2023
9d2c2f5
Updated to work with older pytorch version
ASzot Feb 9, 2023
a17fbfa
Merge branch 'main' into train_hl
ASzot Feb 9, 2023
48142fb
renaming --exp-config to --config-name again
vincentpierre Feb 9, 2023
25 changes: 25 additions & 0 deletions habitat-baselines/habitat_baselines/README.md
@@ -53,6 +53,31 @@ To use them download pre-trained pytorch models from [link](https://dl.fbaipubli
The `habitat_baselines/config/pointnav/ppo_pointnav.yaml` config has better hyperparameters for large scale training and loads the [Gibson PointGoal Navigation Dataset](/README.md#datasets) instead of the test scenes.
Change the `/benchmark/nav/pointnav: pointnav_gibson` in `habitat_baselines/config/pointnav/ppo_pointnav.yaml` to `/benchmark/nav/pointnav: pointnav_mp3d` in the defaults list for training on [MatterPort3D PointGoal Navigation Dataset](/README.md#datasets).

### Hierarchical Reinforcement Learning (HRL)

We provide a two-layer hierarchical policy class, consisting of low-level skills that move the robot and a high-level policy that decides which low-level skill to use in the current state. This can be especially powerful in long-horizon mobile manipulation tasks, like those introduced in [Habitat 2.0](https://arxiv.org/abs/2106.14405). Both the low-level skills and the high-level policy can be either learned or an oracle. For the oracle high-level policy we use [PDDL](https://planning.wiki/guide/whatis/pddl), and for oracle low-level skills we use instantaneous transitions, with the environment set to the final desired state. Additionally, for navigation, we provide an oracle navigation skill that uses A* and the map of the environment to move the robot to its goal.
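
As a rough illustration of the two-layer control flow described above, the following minimal sketch runs a fixed high-level plan (a stand-in for an oracle planner) and lets each low-level skill act until it reports that it is done. All names here (`ToySkill`, `run_episode`) are hypothetical and are not part of the habitat-baselines API.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ToySkill:
    """Hypothetical low-level skill that finishes after a fixed number of steps."""

    name: str
    steps_to_finish: int
    steps_taken: int = 0

    def reset(self) -> None:
        self.steps_taken = 0

    def act(self) -> str:
        self.steps_taken += 1
        return f"{self.name}_action"

    def is_done(self) -> bool:
        return self.steps_taken >= self.steps_to_finish


def run_episode(plan: List[str], skills: Dict[str, ToySkill]) -> List[str]:
    """Execute a fixed high-level plan, one skill at a time."""
    trace: List[str] = []
    for skill_name in plan:  # high level: choose the next skill
        skill = skills[skill_name]
        skill.reset()
        while not skill.is_done():  # low level: emit actions until done
            trace.append(skill.act())
    return trace


skills = {"nav": ToySkill("nav", 2), "pick": ToySkill("pick", 1)}
trace = run_episode(["nav", "pick", "nav"], skills)
print(trace)
```

A learned high-level policy would replace the fixed `plan` list with a choice made from the current observation each time a skill terminates.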

To run the following examples, you need the [ReplicaCAD dataset](https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md#replicacad).

To train a high-level policy while using pre-trained low-level skills (the SRL baseline from [Habitat 2.0](https://arxiv.org/abs/2106.14405)), run:

```bash
python -u habitat_baselines/run.py \
--exp-config habitat_baselines/config/rearrange/rl_hl_srl_onav.yaml \
--run-type train
```

To run a rearrangement episode with an oracle at both the low and the high level, run:

```bash
python -u habitat_baselines/run.py \
--exp-config habitat_baselines/config/rearrange/rl_hl_srl_onav.yaml \
--run-type eval \
habitat_baselines/rl/policy=hierarchical_tp_noop_onav
```

To change the task that the skills are trained on (for example, to set table), change the line `/habitat/task/rearrange: rearrange_easy` to `/habitat/task/rearrange: set_table` in the defaults list of your config.


### Classic

**SLAM based**
16 changes: 16 additions & 0 deletions habitat-baselines/habitat_baselines/common/baseline_registry.py
@@ -136,5 +136,21 @@ def register_auxiliary_loss(
    def get_auxiliary_loss(cls, name: str):
        return cls._get_impl("aux_loss", name)

    @classmethod

Contributor: storage and updater are not very descriptive. Can you add a little text here to clarify which base classes are being registered?

Contributor: Maybe even add a

assert isinstance(to_register, RolloutStorage)

Contributor Author: Done.

Contributor: Where is the assert?

    def register_storage(cls, to_register=None, *, name: Optional[str] = None):
        return cls._register_impl("storage", to_register, name)

    @classmethod
    def get_storage(cls, name: str):
        return cls._get_impl("storage", name)

    @classmethod
    def register_updater(cls, to_register=None, *, name: Optional[str] = None):
        return cls._register_impl("updater", to_register, name)

    @classmethod
    def get_updater(cls, name: str):
        return cls._get_impl("updater", name)


baseline_registry = BaselineRegistry()
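
The registration pattern in this diff, together with the base-class check the reviewers asked about, can be sketched as follows. This is a self-contained toy, not the actual `BaselineRegistry`: `ToyRegistry`, `BaseStorage`, and the two storage classes are hypothetical names, and the check uses `issubclass` because it is classes, not instances, that get registered.

```python
from typing import Dict, Optional


class BaseStorage:
    """Hypothetical stand-in for the storage base class."""


class ToyRegistry:
    """Hypothetical name -> class registry, loosely modeled on the diff above."""

    _mapping: Dict[str, Dict[str, type]] = {"storage": {}}

    @classmethod
    def register_storage(cls, to_register: Optional[type] = None, *, name: Optional[str] = None):
        def wrap(klass: type) -> type:
            # Validate the base class at registration time.
            assert issubclass(klass, BaseStorage), (
                f"{klass.__name__} must subclass BaseStorage"
            )
            cls._mapping["storage"][name or klass.__name__] = klass
            return klass

        # Support both @register_storage and @register_storage(name="...").
        return wrap if to_register is None else wrap(to_register)

    @classmethod
    def get_storage(cls, name: str) -> Optional[type]:
        return cls._mapping["storage"].get(name)


@ToyRegistry.register_storage
class ToyRolloutStorage(BaseStorage):
    pass


@ToyRegistry.register_storage(name="HrlStorage")
class ToyHrlStorage(ToyRolloutStorage):
    pass


# A config string can now select the storage class by name.
storage_cls = ToyRegistry.get_storage("HrlStorage")
print(storage_cls.__name__)
```

Looking classes up by name is what lets a plain string in a config (such as a `rollout_storage` field) choose the implementation at run time.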
3 changes: 3 additions & 0 deletions habitat-baselines/habitat_baselines/common/rollout_storage.py
@@ -10,13 +10,15 @@
import numpy as np
import torch

from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.common.tensor_dict import DictTree, TensorDict
from habitat_baselines.rl.models.rnn_state_encoder import (
    build_pack_info_from_dones,
    build_rnn_build_seq_info,
)


@baseline_registry.register_storage
class RolloutStorage:
    r"""Class for storing rollout information for RL trainers."""

@@ -116,6 +118,7 @@ def insert(
        rewards=None,
        next_masks=None,
        buffer_index: int = 0,
        **kwargs,
    ):
        if not self.is_double_buffered:
            assert buffer_index == 0
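
The diff does not state why `**kwargs` was added to `insert`. One plausible reading, shown here purely as an assumption, is that it lets a caller pass the same arguments to any registered storage class, with subclasses consuming extra fields and the base class ignoring them. `ToyStorage`, `ToyHrlStorage`, `skill_id`, and `collect_step` are hypothetical names.

```python
from typing import Any, List, Optional


class ToyStorage:
    """Hypothetical base storage: ignores keyword arguments it does not know."""

    def __init__(self) -> None:
        self.rewards: List[float] = []

    def insert(
        self,
        rewards: Optional[float] = None,
        buffer_index: int = 0,
        **kwargs: Any,
    ) -> None:
        if rewards is not None:
            self.rewards.append(rewards)


class ToyHrlStorage(ToyStorage):
    """Hypothetical subclass that consumes one extra field."""

    def __init__(self) -> None:
        super().__init__()
        self.skill_ids: List[int] = []

    def insert(self, skill_id: Optional[int] = None, **kwargs: Any) -> None:
        if skill_id is not None:
            self.skill_ids.append(skill_id)
        # Forward everything else to the base implementation.
        super().insert(**kwargs)


def collect_step(storage: ToyStorage) -> None:
    # The caller passes the same arguments regardless of the storage class.
    storage.insert(rewards=1.0, skill_id=3)


base, hrl = ToyStorage(), ToyHrlStorage()
collect_step(base)
collect_step(hrl)
print(base.rewards, hrl.rewards, hrl.skill_ids)
```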
@@ -219,10 +219,31 @@ class Eq2CubeConfig(ObsTransformConfig):
    )


@dataclass
class HrlDefinedSkill(HabitatBaselinesBaseConfig):
    skill_name: str = MISSING
    name: str = "PointNavResNetPolicy"
    action_distribution_type: str = "gaussian"
    load_ckpt_file: str = ""
    max_skill_steps: int = 200
    force_end_on_timeout: bool = True
    force_config_file: str = ""
    at_resting_threshold: float = 0.15
    apply_postconds: bool = False
    obs_skill_inputs: List[str] = field(default_factory=list)
    obs_skill_input_dim: int = 3
    start_zone_radius: float = 0.3
    # For the oracle navigation skill
    nav_action_name: str = "base_velocity"
    stop_thresh: float = 0.001
    # For the reset_arm_skill
    reset_joint_state: List[float] = MISSING


@dataclass
class HierarchicalPolicy(HabitatBaselinesBaseConfig):
    high_level_policy: Dict[str, Any] = MISSING
    defined_skills: Dict[str, Any] = field(default_factory=dict)  # old line (removed by this diff)
    defined_skills: Dict[str, HrlDefinedSkill] = field(default_factory=dict)  # new line (added by this diff)
    use_skills: Dict[str, str] = field(default_factory=dict)
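
To illustrate how a typed skill config with defaults and a required field behaves, here is a minimal sketch using only the standard library. It is not the habitat-baselines or OmegaConf machinery: `ToySkillConfig`, `build_skill_config`, and the `_REQUIRED` sentinel (standing in for `MISSING`) are hypothetical, and only a subset of the fields above is reproduced.

```python
from dataclasses import dataclass, field, fields
from typing import Any, Dict, List

_REQUIRED = "???"  # hypothetical sentinel playing the role of MISSING


@dataclass
class ToySkillConfig:
    skill_name: str = _REQUIRED
    name: str = "PointNavResNetPolicy"
    max_skill_steps: int = 200
    force_end_on_timeout: bool = True
    obs_skill_inputs: List[str] = field(default_factory=list)


def build_skill_config(overrides: Dict[str, Any]) -> ToySkillConfig:
    """Apply per-skill overrides on top of the dataclass defaults."""
    known = {f.name for f in fields(ToySkillConfig)}
    unknown = set(overrides) - known
    if unknown:
        # A typed schema rejects misspelled keys instead of silently keeping them.
        raise KeyError(f"Unknown skill config keys: {sorted(unknown)}")
    cfg = ToySkillConfig(**overrides)
    if cfg.skill_name == _REQUIRED:
        raise ValueError("skill_name is required")
    return cfg


wait_cfg = build_skill_config(
    {
        "skill_name": "WaitSkillPolicy",
        "max_skill_steps": -1,
        "force_end_on_timeout": False,
    }
)
print(wait_cfg)
```

Compared with `Dict[str, Any]`, a typed entry means each skill only needs to list the values that differ from the defaults, and typos in key names surface early.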


@@ -383,6 +404,8 @@ class HabitatBaselinesConfig(HabitatBaselinesBaseConfig):
    # )
    # cmd_trailing_opts: List[str] = field(default_factory=list)
    trainer_name: str = "ppo"
    updater_name: str = "PPO"
    distrib_updater_name: str = "DDPPO"
Comment on lines 371 to +373

Contributor: This looks super redundant. Some of these are capitalized. Neither distrib_updater_name nor updater_name is ever changed in any of the configurations. Can distrib_updater_name and updater_name be properties of the trainer?

Contributor Author: No, I think they are best here because they are needed to instantiate the updater. But this was a mistake: rl_hierarchical was supposed to change these properties.

Contributor: Can't these just be inferred from other configurations?


    torch_gpu_id: int = 0
    video_render_views: List[str] = field(default_factory=list)
    tensorboard_dir: str = "tb"
@@ -394,6 +417,7 @@ class HabitatBaselinesConfig(HabitatBaselinesBaseConfig):
    eval_ckpt_path_dir: str = "data/checkpoints"
    num_environments: int = 16
    num_processes: int = -1  # deprecated
    rollout_storage: str = "RolloutStorage"
    checkpoint_folder: str = "data/checkpoints"
    num_updates: int = 10000
    num_checkpoints: int = 10
@@ -0,0 +1,77 @@
name: "HierarchicalPolicy"
obs_transforms:
  add_virtual_keys:
    virtual_keys:
      "goal_to_agent_gps_compass": 2
hierarchical_policy:
  high_level_policy:
    name: "NeuralHighLevelPolicy"
    allow_other_place: False
    hidden_dim: 512
    use_rnn: True
    rnn_type: 'LSTM'
    backbone: resnet18
    normalize_visual_inputs: False
    num_rnn_layers: 2
    policy_input_keys:
      - "robot_head_depth"
  defined_skills:
    open_cab:
      skill_name: "ArtObjSkillPolicy"
      name: "PointNavResNetPolicy"
      load_ckpt_file: "data/models/open_cab.pth"

    open_fridge:
      skill_name: "ArtObjSkillPolicy"
      load_ckpt_file: "data/models/open_fridge.pth"

    close_cab:
      skill_name: "ArtObjSkillPolicy"
      obs_skill_inputs: ["obj_start_sensor"]
      load_ckpt_file: "data/models/close_cab.pth"

    close_fridge:
      skill_name: "ArtObjSkillPolicy"
      obs_skill_inputs: ["obj_start_sensor"]
      load_ckpt_file: "data/models/close_fridge.pth"

    pick:
      skill_name: "PickSkillPolicy"
      name: "PointNavResNetPolicy"
      obs_skill_inputs: ["obj_start_sensor"]
      load_ckpt_file: "data/models/pick.pth"

    place:
      skill_name: "PlaceSkillPolicy"
      name: "PointNavResNetPolicy"
      obs_skill_inputs: ["obj_goal_sensor"]
      load_ckpt_file: "data/models/place.pth"

    wait_skill:
      skill_name: "WaitSkillPolicy"
      max_skill_steps: -1
      force_end_on_timeout: False

    oracle_nav:
      skill_name: "OracleNavPolicy"
      obs_skill_inputs: ["obj_start_sensor", "abs_obj_start_sensor", "obj_goal_sensor", "abs_obj_goal_sensor"]
      max_skill_steps: 300

    reset_arm_skill:
      skill_name: "ResetArmSkill"
      max_skill_steps: 50
      reset_joint_state: [-4.50e-01, -1.08e00, 9.95e-02, 9.38e-01, -7.88e-04, 1.57e00, 4.62e-03]
      force_end_on_timeout: False

  use_skills:
    # Uncomment if you are also using these skills
    # open_cab: "open_cab"
    # open_fridge: "open_fridge"
    # close_cab: "open_cab"
    # close_fridge: "open_fridge"
    pick: "pick"
    place: "place"
    nav: "oracle_nav"
    nav_to_receptacle: "oracle_nav"
    wait: "wait_skill"
    reset_arm: "reset_arm_skill"
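
The relationship between `use_skills` and `defined_skills` in this config can be sketched in plain Python. This assumes, based only on the key names in the YAML above, that `use_skills` maps the action names used by the high-level policy to keys of `defined_skills`; `resolve_skills` is a hypothetical helper, not a habitat-baselines function, and only a few skills are reproduced.

```python
from typing import Any, Dict

defined_skills: Dict[str, Dict[str, Any]] = {
    "pick": {"skill_name": "PickSkillPolicy"},
    "place": {"skill_name": "PlaceSkillPolicy"},
    "oracle_nav": {"skill_name": "OracleNavPolicy", "max_skill_steps": 300},
    "wait_skill": {"skill_name": "WaitSkillPolicy"},
}
use_skills: Dict[str, str] = {
    "pick": "pick",
    "place": "place",
    "nav": "oracle_nav",
    "nav_to_receptacle": "oracle_nav",
    "wait": "wait_skill",
}


def resolve_skills(
    use: Dict[str, str], defined: Dict[str, Dict[str, Any]]
) -> Dict[str, Dict[str, Any]]:
    """Map each high-level action name to its skill definition."""
    missing = sorted(set(use.values()) - set(defined))
    if missing:
        raise KeyError(f"use_skills refers to undefined skills: {missing}")
    return {action: defined[skill_key] for action, skill_key in use.items()}


resolved = resolve_skills(use_skills, defined_skills)
print(resolved["nav"]["skill_name"])
```

Under this reading, two action names (`nav` and `nav_to_receptacle`) can share a single skill definition, and swapping `oracle_nav` for a learned navigation skill only requires changing the right-hand side of the mapping.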
@@ -0,0 +1,72 @@

name: "HierarchicalPolicy"
obs_transforms:
  add_virtual_keys:
    virtual_keys:
      "goal_to_agent_gps_compass": 2
hierarchical_policy:
  high_level_policy:
    name: "FixedHighLevelPolicy"
    add_arm_rest: True

  defined_skills:
    open_cab:
      skill_name: "NoopSkillPolicy"
      max_skill_steps: 1
      apply_postconds: True

    open_fridge:
      skill_name: "NoopSkillPolicy"
      max_skill_steps: 1
      apply_postconds: True

    close_cab:
      skill_name: "NoopSkillPolicy"
      obs_skill_inputs: ["obj_start_sensor"]
      max_skill_steps: 1

    close_fridge:
      skill_name: "NoopSkillPolicy"
      obs_skill_inputs: ["obj_start_sensor"]
      max_skill_steps: 1
      apply_postconds: True

    pick:
      skill_name: "NoopSkillPolicy"
      obs_skill_inputs: ["obj_start_sensor"]
      max_skill_steps: 1
      apply_postconds: True

    place:
      skill_name: "NoopSkillPolicy"
      obs_skill_inputs: ["obj_goal_sensor"]
      max_skill_steps: 1
      apply_postconds: True

    nav_to_obj:
      skill_name: "OracleNavPolicy"
      obs_skill_inputs: ["obj_start_sensor", "abs_obj_start_sensor", "obj_goal_sensor", "abs_obj_goal_sensor"]
      max_skill_steps: 300

    wait_skill:
      skill_name: "WaitSkillPolicy"
      max_skill_steps: -1
      force_end_on_timeout: False

    reset_arm_skill:
      skill_name: "ResetArmSkill"
      max_skill_steps: 50
      reset_joint_state: [-4.50e-01, -1.07e00, 9.95e-02, 9.38e-01, -7.88e-04, 1.57e00, 4.62e-03]
      force_end_on_timeout: False

  use_skills:
    open_cab: "open_cab"
    open_fridge: "open_fridge"
    close_cab: "close_cab"
    close_fridge: "close_fridge"
    pick: "pick"
    place: "place"
    nav: "nav_to_obj"
    nav_to_receptacle: "nav_to_obj"
    wait: "wait_skill"
    reset_arm: "reset_arm_skill"
@@ -0,0 +1,71 @@
name: "HierarchicalPolicy"
obs_transforms:
  add_virtual_keys:
    virtual_keys:
      "goal_to_agent_gps_compass": 2
hierarchical_policy:
  high_level_policy:
    name: "FixedHighLevelPolicy"
    add_arm_rest: True
  defined_skills:
    open_cab:
      skill_name: "ArtObjSkillPolicy"
      name: "PointNavResNetPolicy"
      load_ckpt_file: "data/models/open_cab.pth"

    open_fridge:
      skill_name: "ArtObjSkillPolicy"
      skill_name: "ArtObjSkillPolicy"
      load_ckpt_file: "data/models/open_fridge.pth"

    close_cab:
      skill_name: "ArtObjSkillPolicy"
      obs_skill_inputs: ["obj_start_sensor"]
      load_ckpt_file: "data/models/close_cab.pth"

    close_fridge:
      skill_name: "ArtObjSkillPolicy"
      obs_skill_inputs: ["obj_start_sensor"]
      load_ckpt_file: "data/models/close_fridge.pth"

    pick:
      skill_name: "PickSkillPolicy"
      obs_skill_inputs: ["obj_start_sensor"]
      load_ckpt_file: "data/models/pick.pth"

    place:
      skill_name: "PlaceSkillPolicy"
      obs_skill_inputs: ["obj_goal_sensor"]
      load_ckpt_file: "data/models/place.pth"

    wait_skill:
      skill_name: "WaitSkillPolicy"
      max_skill_steps: -1
      force_end_on_timeout: False

    nav_to_obj:
      skill_name: "NavSkillPolicy"
      obs_skill_inputs: ["goal_to_agent_gps_compass"]
      obs_skill_input_dim: 2
      load_ckpt_file: "data/models/nav.pth"
      max_skill_steps: 300
      force_end_on_timeout: False

    reset_arm_skill:
      skill_name: "ResetArmSkill"
      max_skill_steps: 50
      reset_joint_state: [-4.50e-01, -1.08e00, 9.95e-02, 9.38e-01, -7.88e-04, 1.57e00, 4.62e-03]
      force_end_on_timeout: False

  use_skills:
    # Uncomment if you are also using these skills
    # open_cab: "open_cab"
    # open_fridge: "open_fridge"
    # close_cab: "open_cab"
    # close_fridge: "open_fridge"
    pick: "pick"
    place: "place"
    nav: "nav_to_obj"
    nav_to_receptacle: "nav_to_obj"
    wait: "wait_skill"
    reset_arm: "reset_arm_skill"
@@ -0,0 +1,4 @@
name: "PointNavResNetPolicy"
action_distribution_type: "gaussian"
action_dist:
  use_log_std: True