Train High-Level Policies in Hierarchical Approaches #1053

Merged
merged 55 commits into from
Feb 10, 2023
Changes from 44 commits
a787e34
Trainable HL policy
ASzot Dec 24, 2022
f6d1f11
Working on HRL trainer
ASzot Dec 26, 2022
daa84db
Fixed config setup
ASzot Dec 27, 2022
7ae2c6b
Train hl modif (#1057)
akshararai Jan 6, 2023
bb2e4da
Update README.md
xavierpuigf Jan 9, 2023
628063d
Match tensor device when checking if the skills is done
xavierpuigf Jan 9, 2023
d179ecf
Train hl modif2 (#1076)
xavierpuigf Jan 13, 2023
515b3cf
Merged with main
ASzot Jan 19, 2023
8eddcd1
Fixed RNN problem
ASzot Jan 20, 2023
fa94bfc
Fixed tests
ASzot Jan 20, 2023
11659b4
Fixed formatting
ASzot Jan 20, 2023
633ebf3
Fixed device issues. Cleaned up configs.
ASzot Jan 27, 2023
1c82390
More config cleanup
ASzot Jan 27, 2023
2e973f5
Addressing PR comments
ASzot Jan 27, 2023
3a805ea
Updated circular reference
ASzot Jan 27, 2023
7d92072
Addressing PR comments
ASzot Jan 27, 2023
4696551
Addressing PR comments
ASzot Jan 27, 2023
c11601f
Update habitat-baselines/habitat_baselines/rl/hrl/skills/skill.py
ASzot Jan 27, 2023
90155e1
Addressing PR comments
ASzot Jan 27, 2023
755acb4
Resolved storage problem
ASzot Jan 28, 2023
8d07335
merged
ASzot Jan 28, 2023
4754fd7
Update oracle_nav.py
xavierpuigf Jan 30, 2023
587672e
Fix for agent rotation
ASzot Jan 30, 2023
f16bc14
Missing key
ASzot Jan 30, 2023
22cb83f
More docs
ASzot Jan 31, 2023
df8b7a6
Update habitat-baselines/habitat_baselines/rl/hrl/hrl_rollout_storage.py
ASzot Jan 31, 2023
422c036
Update habitat-baselines/habitat_baselines/rl/hrl/utils.py
ASzot Jan 31, 2023
5643b0e
Updated name
ASzot Jan 31, 2023
ebe877e
fixes for training
ASzot Feb 1, 2023
e1a8727
Fixed env issue
ASzot Feb 2, 2023
2673f4f
Fixed deprecated configs
ASzot Feb 2, 2023
03faec7
Merge branch 'main' into train_hl
ASzot Feb 2, 2023
dd01d50
Speed fix
ASzot Feb 3, 2023
b20fcb1
Updated configs
ASzot Feb 3, 2023
7982f40
Pddl action fixes
ASzot Feb 3, 2023
902bfa1
Removed speed opts. Fixed some bugs
ASzot Feb 4, 2023
74278bd
Fixed rendering text to the frame
ASzot Feb 4, 2023
63f610c
Merged with main
ASzot Feb 4, 2023
49a71a4
Addressing Vince's PR comments
ASzot Feb 4, 2023
b41133a
Refactored navigation to be much clearer
ASzot Feb 4, 2023
e7a877b
Fixed some of the tests
ASzot Feb 5, 2023
5c213e3
Adddressed PR comments
ASzot Feb 6, 2023
d9721f1
Fixed rotation issue
ASzot Feb 6, 2023
f8387de
Fixed black
ASzot Feb 6, 2023
1c8f54c
Addressed PR comments
ASzot Feb 8, 2023
f2c6731
Addressed PR comments
ASzot Feb 8, 2023
4b65b3f
Merge branch 'main' into train_hl
ASzot Feb 8, 2023
11e77c3
Fixed config
ASzot Feb 8, 2023
6f5ea76
Fixed typo
ASzot Feb 8, 2023
cb6ce62
Fixed another typo
ASzot Feb 8, 2023
6d4b968
CI
ASzot Feb 8, 2023
d6c957e
Merge branch 'main' into train_hl
vincentpierre Feb 8, 2023
9d2c2f5
Updated to work with older pytorch version
ASzot Feb 9, 2023
a17fbfa
Merge branch 'main' into train_hl
ASzot Feb 9, 2023
48142fb
renaming --exp-config to --config-name again
vincentpierre Feb 9, 2023
26 changes: 26 additions & 0 deletions habitat-baselines/habitat_baselines/README.md
@@ -53,6 +53,32 @@ To use them download pre-trained pytorch models from [link](https://dl.fbaipubli
The `habitat_baselines/config/pointnav/ppo_pointnav.yaml` config has better hyperparameters for large scale training and loads the [Gibson PointGoal Navigation Dataset](/README.md#datasets) instead of the test scenes.
Change the `/benchmark/nav/pointnav: pointnav_gibson` in `habitat_baselines/config/pointnav/ppo_pointnav.yaml` to `/benchmark/nav/pointnav: pointnav_mp3d` in the defaults list for training on [MatterPort3D PointGoal Navigation Dataset](/README.md#datasets).

### Hierarchical Reinforcement Learning (HRL)

We provide a two-layer hierarchical policy class, consisting of low-level skills that move the robot and a high-level policy that decides which low-level skill to use in the current state. This can be especially powerful in long-horizon mobile manipulation tasks, like those introduced in [Habitat2.0](https://arxiv.org/abs/2106.14405). Both the low-level and the high-level policies can be either learned or oracles. For the oracle high-level policy we use [PDDL](https://planning.wiki/guide/whatis/pddl), and for the oracle low-level skills we use instantaneous transitions, with the environment set to the final desired state. Additionally, for navigation, we provide an oracle navigation skill that uses A* and the map of the environment to move the robot to its goal.
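
As a rough illustration of the two-layer structure described above, the sketch below runs a toy high-level policy that picks a skill by name and toy low-level skills that change the state. It is not the habitat-baselines implementation: the integer state, the skill names, and the functions `make_move_skill`, `high_level_policy`, and `run_episode` are all hypothetical.

```python
from typing import Callable, Dict, List

# Hypothetical toy world: the state is an integer position on a line.
State = int


def make_move_skill(step: int) -> Callable[[State], State]:
    """Low-level skill: moves the agent by a fixed step."""
    return lambda state: state + step


def high_level_policy(state: State, goal: State) -> str:
    """High-level policy: picks the name of the skill to run next."""
    if state < goal:
        return "forward"
    if state > goal:
        return "backward"
    return "wait"


def run_episode(start: State, goal: State, max_steps: int = 20) -> List[str]:
    """Alternate between choosing a skill and executing it."""
    skills: Dict[str, Callable[[State], State]] = {
        "forward": make_move_skill(+1),
        "backward": make_move_skill(-1),
        "wait": make_move_skill(0),
    }
    state = start
    chosen: List[str] = []
    for _ in range(max_steps):
        name = high_level_policy(state, goal)
        chosen.append(name)
        if name == "wait":  # the goal has been reached
            break
        state = skills[name](state)
    return chosen
```

Either layer could be swapped for a learned or an oracle version without changing the loop, which is the property the hierarchical policy class relies on.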

To run the following examples, you need the [ReplicaCAD dataset](https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md#replicacad).

To train a high-level policy, while using pre-learned low-level skills (SRL baseline from [Habitat2.0](https://arxiv.org/abs/2106.14405)), you can run:

```bash
python -u habitat-baselines/habitat_baselines/run.py \
--exp-config habitat-baselines/habitat_baselines/config/rearrange/rl_hierarchical.yaml \
--run-type train
```
To run a rearrangement episode with oracle low-level skills and a fixed task planner, run:

```bash
python -u habitat-baselines/habitat_baselines/run.py \
--exp-config habitat-baselines/habitat_baselines/config/rearrange/rl_hierarchical.yaml \
--run-type eval \
habitat_baselines/rl/policy=hl_fixed \
habitat_baselines/rl/policy/hierarchical_policy/defined_skills=oracle_skills
```

To change the task your skills are trained on (for example, set table), change the line `/habitat/task/rearrange: rearrange_easy` to `/habitat/task/rearrange: set_table` in the defaults list of your config.


### Classic

**SLAM based**
12 changes: 4 additions & 8 deletions habitat-baselines/habitat_baselines/agents/ppo_agents.py
@@ -126,23 +126,19 @@ def reset(self) -> None:
     def act(self, observations: Observations) -> Dict[str, int]:
         batch = batch_obs([observations], device=self.device)
         with torch.no_grad():
-            (
-                _,
-                actions,
-                _,
-                self.test_recurrent_hidden_states,
-            ) = self.actor_critic.act(
+            action_data = self.actor_critic.act(
                 batch,
                 self.test_recurrent_hidden_states,
                 self.prev_actions,
                 self.not_done_masks,
                 deterministic=False,
             )
+            self.test_recurrent_hidden_states = action_data.rnn_hidden_states
             # Make masks not done till reset (end of episode) will be called
             self.not_done_masks.fill_(True)
-            self.prev_actions.copy_(actions)  # type: ignore
+            self.prev_actions.copy_(action_data.actions)  # type: ignore

-        return {"action": actions[0][0].item()}
+        return {"action": action_data.env_actions[0][0].item()}


def main():
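
The hunk above swaps positional tuple unpacking for a single returned object whose fields are read by name. A minimal sketch of that pattern follows, using plain lists instead of tensors. Only the field names `actions`, `env_actions`, and `rnn_hidden_states` come from the diff; the class `ToyActionData`, the `values` field, and the functions `toy_act` and `agent_step` are hypothetical.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class ToyActionData:
    """Hypothetical stand-in for the object returned by `actor_critic.act`."""

    actions: List[List[int]]  # what the policy sampled
    env_actions: List[List[int]]  # what is sent to the environment
    rnn_hidden_states: List[float]  # recurrent state for the next step
    values: Optional[List[float]] = None  # new fields do not break callers


def toy_act(prev_hidden: List[float]) -> ToyActionData:
    """Pretend policy step: always picks action 2 and bumps the hidden state."""
    return ToyActionData(
        actions=[[2]],
        env_actions=[[2]],
        rnn_hidden_states=[h + 1.0 for h in prev_hidden],
    )


def agent_step(hidden: List[float]) -> Dict[str, int]:
    """Mirror of the agent's `act`: read only the fields that are needed."""
    action_data = toy_act(hidden)
    hidden[:] = action_data.rnn_hidden_states  # update the state in place
    return {"action": action_data.env_actions[0][0]}
```

With a named container, a policy that needs to return extra information (for example a hierarchical one) can add fields without every caller having to change its unpacking.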
25 changes: 25 additions & 0 deletions habitat-baselines/habitat_baselines/common/baseline_registry.py
@@ -136,5 +136,30 @@ def register_auxiliary_loss(
def get_auxiliary_loss(cls, name: str):
return cls._get_impl("aux_loss", name)

@classmethod
Contributor

`storage` and `updater` are not very descriptive. Can you add a little text here to clarify which base classes are being registered?

Contributor

Maybe even add a

assert isinstance(to_register, RolloutStorage)

Contributor Author

Done.

Contributor

Where is the assert ?

def register_storage(cls, to_register=None, *, name: Optional[str] = None):
"""
Registers data storage for storing data in the policy rollout in the
trainer and then for fetching data batches for the updater.
"""

return cls._register_impl("storage", to_register, name)

@classmethod
def get_storage(cls, name: str):
return cls._get_impl("storage", name)

@classmethod
def register_updater(cls, to_register=None, *, name: Optional[str] = None):
"""
Registers a policy updater.
"""

return cls._register_impl("updater", to_register, name)

@classmethod
def get_updater(cls, name: str):
return cls._get_impl("updater", name)


baseline_registry = BaselineRegistry()
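
For readers following the review thread on `register_storage`, here is a minimal sketch of a name-based registry that also performs the type check the reviewers asked about. It is not the real `BaselineRegistry`: `MiniRegistry`, `RolloutStorageBase`, and the example classes are hypothetical. Because the decorator receives a class rather than an instance, the sketch uses `issubclass` instead of the `isinstance` call suggested in the thread.

```python
from typing import Dict, Optional, Type


class RolloutStorageBase:
    """Hypothetical base class standing in for RolloutStorage."""


class MiniRegistry:
    """Toy registry, not the real BaselineRegistry."""

    _storages: Dict[str, Type[RolloutStorageBase]] = {}

    @classmethod
    def register_storage(cls, to_register=None, *, name: Optional[str] = None):
        def wrap(klass):
            # Reject classes that do not derive from the expected base.
            assert issubclass(
                klass, RolloutStorageBase
            ), f"{klass} must subclass RolloutStorageBase"
            cls._storages[name or klass.__name__] = klass
            return klass

        # Support both @register_storage and @register_storage(name="...").
        return wrap if to_register is None else wrap(to_register)

    @classmethod
    def get_storage(cls, name: str):
        return cls._storages.get(name)


@MiniRegistry.register_storage
class MyStorage(RolloutStorageBase):
    pass


@MiniRegistry.register_storage(name="custom")
class OtherStorage(RolloutStorageBase):
    pass
```

A config string such as `rollout_storage: "MyStorage"` could then be resolved with `MiniRegistry.get_storage(...)`, which is the role the `rollout_storage` and `updater_name` config fields appear to play in this PR.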
24 changes: 21 additions & 3 deletions habitat-baselines/habitat_baselines/common/rollout_storage.py
@@ -5,18 +5,24 @@
# LICENSE file in the root directory of this source tree.

 import warnings
-from typing import Any, Dict, Iterator, Optional, Tuple
+from typing import Any, Dict, Iterator, Optional

 import numpy as np
 import torch

+from habitat_baselines.common.baseline_registry import baseline_registry
 from habitat_baselines.common.tensor_dict import DictTree, TensorDict
 from habitat_baselines.rl.models.rnn_state_encoder import (
     build_pack_info_from_dones,
     build_rnn_build_seq_info,
 )
+from habitat_baselines.utils.common import (
+    get_num_actions,
+    is_continuous_action_space,
+)


+@baseline_registry.register_storage
 class RolloutStorage:
     r"""Class for storing rollout information for RL trainers."""

@@ -28,10 +34,21 @@ def __init__(
         action_space,
         recurrent_hidden_state_size,
         num_recurrent_layers=1,
-        action_shape: Optional[Tuple[int]] = None,
         is_double_buffered: bool = False,
-        discrete_actions: bool = True,
     ):
+        if is_continuous_action_space(action_space):
+            # Assume ALL actions are NOT discrete
+            action_shape = (
+                get_num_actions(
+                    action_space,
+                ),
+            )
+            discrete_actions = False
+        else:
+            # For discrete pointnav
+            action_shape = (1,)
+            discrete_actions = True
+
         self.buffers = TensorDict()
         self.buffers["observations"] = TensorDict()

@@ -115,6 +132,7 @@ def insert(
         rewards=None,
         next_masks=None,
         buffer_index: int = 0,
+        **kwargs,
     ):
         if not self.is_double_buffered:
             assert buffer_index == 0
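
The constructor change above derives the action buffer layout from the action space instead of taking it as arguments. A self-contained sketch of that decision is shown below. `ToyDiscrete`, `ToyBox`, and `storage_action_layout` are hypothetical stand-ins, and the sketch assumes a flat continuous space whose size is `shape[0]`; the real code delegates this to `is_continuous_action_space` and `get_num_actions`.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass
class ToyDiscrete:
    """Hypothetical stand-in for a discrete action space with `n` choices."""

    n: int


@dataclass
class ToyBox:
    """Hypothetical stand-in for a continuous action space of a given shape."""

    shape: Tuple[int, ...]


def storage_action_layout(
    action_space: Union[ToyDiscrete, ToyBox]
) -> Tuple[Tuple[int, ...], bool]:
    """Return (action_shape, discrete_actions) for the rollout buffer."""
    if isinstance(action_space, ToyBox):
        # Continuous: one buffer slot per action dimension.
        # Assumption: the space is flat, so its size is shape[0].
        return (action_space.shape[0],), False
    # Discrete: a single integer action index per step.
    return (1,), True
```

Deriving both values from the space keeps the storage and the policy from disagreeing about the action layout, at the cost of no longer letting callers override it.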
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple

from hydra.core.config_store import ConfigStore
from omegaconf import MISSING
@@ -224,10 +224,45 @@ class Eq2CubeConfig(ObsTransformConfig):


@dataclass
-class HierarchicalPolicy(HabitatBaselinesBaseConfig):
+class HrlDefinedSkillConfig(HabitatBaselinesBaseConfig):
"""
Defines a low-level skill to be used in the hierarchical policy.
"""

skill_name: str = MISSING
name: str = "PointNavResNetPolicy"
action_distribution_type: str = "gaussian"
load_ckpt_file: str = ""
max_skill_steps: int = 200
# If true, the stop action will be called if the skill times out.
force_end_on_timeout: bool = True
# Overrides the config file of a neural network skill rather than loading
# the config file from the checkpoint file.
force_config_file: str = ""
at_resting_threshold: float = 0.15
# If true, this will apply the post-conditions of the skill after it
# terminates.
apply_postconds: bool = False
obs_skill_inputs: List[str] = field(default_factory=list)
obs_skill_input_dim: int = 3
start_zone_radius: float = 0.3
# For the oracle navigation skill
action_name: str = "base_velocity"
stop_thresh: float = 0.001
# For the reset_arm_skill
reset_joint_state: List[float] = MISSING
# The set of PDDL action names (as defined in the PDDL domain file) that
# map to this skill. If not specified, the name of the skill must match the
# PDDL action name.
pddl_action_names: Optional[List[str]] = None


@dataclass
class HierarchicalPolicyConfig(HabitatBaselinesBaseConfig):
high_level_policy: Dict[str, Any] = MISSING
-defined_skills: Dict[str, Any] = field(default_factory=dict)
-use_skills: Dict[str, str] = field(default_factory=dict)
+defined_skills: Dict[str, HrlDefinedSkillConfig] = field(
+    default_factory=dict
+)
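
The `pddl_action_names` comment above describes a lookup from PDDL action names to skills, with a fallback when the field is unset. A small sketch of one way to build that lookup follows. `ToySkillConfig` and `build_pddl_to_skill` are hypothetical, and the sketch assumes that "the name of the skill" means the key under `defined_skills` (for example `pick`), not the `skill_name` class string.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class ToySkillConfig:
    """Hypothetical, trimmed-down version of a defined-skill config."""

    skill_name: str
    pddl_action_names: Optional[List[str]] = None


def build_pddl_to_skill(
    defined_skills: Dict[str, ToySkillConfig]
) -> Dict[str, str]:
    """Map each PDDL action name to the key of the skill that handles it."""
    mapping: Dict[str, str] = {}
    for skill_key, cfg in defined_skills.items():
        # Fall back to the skill's own key when no names are listed.
        names = (
            cfg.pddl_action_names
            if cfg.pddl_action_names is not None
            else [skill_key]
        )
        for action_name in names:
            if action_name in mapping:
                raise ValueError(f"{action_name} is mapped to two skills")
            mapping[action_name] = skill_key
    return mapping
```

With this layout one navigation skill can serve several PDDL actions, while skills such as `pick` need no extra configuration.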


@dataclass
@@ -238,7 +273,7 @@ class PolicyConfig(HabitatBaselinesBaseConfig):
# For gaussian action distribution:
action_dist: ActionDistributionConfig = ActionDistributionConfig()
obs_transforms: Dict[str, ObsTransformConfig] = field(default_factory=dict)
-hierarchical_policy: HierarchicalPolicy = MISSING
+hierarchical_policy: HierarchicalPolicyConfig = MISSING


@dataclass
@@ -345,6 +380,8 @@ class HabitatBaselinesConfig(HabitatBaselinesBaseConfig):
# )
# cmd_trailing_opts: List[str] = field(default_factory=list)
trainer_name: str = "ppo"
updater_name: str = "PPO"
distrib_updater_name: str = "DDPPO"
Comment on lines 371 to +373
Contributor

This looks super redundant.
Some of these are capitalized.
Neither distrib_updater_name nor updater_name are ever changed in any of the configurations.
Can distrib_updater_name and updater_name be properties of the trainer ?

Contributor Author

No, I think they are best here because they are needed to instantiate the updater. But this was a mistake, rl_hierarchical was supposed to change these properties.

Contributor

Can't these just be inferred from other configurations ?

torch_gpu_id: int = 0
tensorboard_dir: str = "tb"
writer_type: str = "tb"
@@ -355,6 +392,7 @@ class HabitatBaselinesConfig(HabitatBaselinesBaseConfig):
eval_ckpt_path_dir: str = "data/checkpoints"
num_environments: int = 16
num_processes: int = -1 # deprecated
rollout_storage: str = "RolloutStorage"
checkpoint_folder: str = "data/checkpoints"
num_updates: int = 10000
num_checkpoints: int = 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
open_cab:
  skill_name: "ArtObjSkillPolicy"
  load_ckpt_file: "data/models/open_cab.pth"

open_fridge:
  skill_name: "ArtObjSkillPolicy"
  load_ckpt_file: "data/models/open_fridge.pth"

close_cab:
  skill_name: "ArtObjSkillPolicy"
  load_ckpt_file: "data/models/close_cab.pth"

close_fridge:
  skill_name: "ArtObjSkillPolicy"
  load_ckpt_file: "data/models/close_fridge.pth"

pick:
  skill_name: "PickSkillPolicy"
  obs_skill_inputs: ["obj_start_sensor"]
  load_ckpt_file: "data/models/pick.pth"

place:
  skill_name: "PlaceSkillPolicy"
  obs_skill_inputs: ["obj_goal_sensor"]
  load_ckpt_file: "data/models/place.pth"

wait:
  skill_name: "WaitSkillPolicy"
  max_skill_steps: -1
  force_end_on_timeout: False

nav_to_obj:
  skill_name: "NavSkillPolicy"
  obs_skill_inputs: ["goal_to_agent_gps_compass"]
  load_ckpt_file: "data/models/nav.pth"
  max_skill_steps: 300
  obs_skill_input_dim: 2
  pddl_action_names: ["nav", "nav_to_receptacle"]

reset_arm:
  skill_name: "ResetArmSkill"
  max_skill_steps: 50
  reset_joint_state: [-4.50e-01, -1.08e00, 9.95e-02, 9.38e-01, -7.88e-04, 1.57e00, 4.62e-03]
  force_end_on_timeout: False
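
The `max_skill_steps` and `force_end_on_timeout` fields above interact: per the config comment, the stop action is called when a skill times out and `force_end_on_timeout` is true. The sketch below shows one plausible reading of that rule. The function names are hypothetical, and treating a non-positive `max_skill_steps` (such as the `-1` used by `wait`) as "no time limit" is an assumption, not something the diff states.

```python
def skill_timed_out(cur_skill_steps: int, max_skill_steps: int) -> bool:
    """True once the skill has run for its full step budget."""
    # Assumption: a non-positive budget disables the timeout entirely.
    if max_skill_steps <= 0:
        return False
    return cur_skill_steps >= max_skill_steps


def should_force_stop(
    cur_skill_steps: int, max_skill_steps: int, force_end_on_timeout: bool
) -> bool:
    """True if the stop action should be issued because of a timeout."""
    return force_end_on_timeout and skill_timed_out(
        cur_skill_steps, max_skill_steps
    )
```

Under this reading, `nav_to_obj` (300 steps, default `force_end_on_timeout: True`) would end the episode on a timeout, while `reset_arm` (50 steps, `False`) would hand control back to the high-level policy.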
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Oracle skills that will teleport to the skill post-condition. When automatically setting predicates you may want to run the simulation in kinematic mode:
# To run in kinematic mode, add: `habitat.simulator.kinematic_mode=True habitat.simulator.ac_freq_ratio=1 habitat.task.measurements.force_terminate.max_accum_force=-1.0 habitat.task.measurements.force_terminate.max_instant_force=-1.0`

defaults:
  - /habitat/task/actions:
    - pddl_apply_action

open_cab:
  skill_name: "NoopSkillPolicy"
  max_skill_steps: 1
  apply_postconds: True
  force_end_on_timeout: False
  pddl_action_names: ["open_cab_by_name"]

open_fridge:
  skill_name: "NoopSkillPolicy"
  max_skill_steps: 1
  apply_postconds: True
  force_end_on_timeout: False
  pddl_action_names: ["open_fridge_by_name"]

close_cab:
  skill_name: "NoopSkillPolicy"
  obs_skill_inputs: ["obj_start_sensor"]
  max_skill_steps: 1
  force_end_on_timeout: False
  pddl_action_names: ["close_cab_by_name"]

close_fridge:
  skill_name: "NoopSkillPolicy"
  obs_skill_inputs: ["obj_start_sensor"]
  max_skill_steps: 1
  apply_postconds: True
  force_end_on_timeout: False
  pddl_action_names: ["close_fridge_by_name"]

pick:
  skill_name: "NoopSkillPolicy"
  obs_skill_inputs: ["obj_start_sensor"]
  max_skill_steps: 1
  apply_postconds: True
  force_end_on_timeout: False
Contributor

@akshararai check if all these parameters are needed, and add a sample documentation for them somewhere.


place:
  skill_name: "NoopSkillPolicy"
  obs_skill_inputs: ["obj_goal_sensor"]
  max_skill_steps: 1
  apply_postconds: True
  force_end_on_timeout: False

wait:
  skill_name: "WaitSkillPolicy"
  max_skill_steps: -1

nav_to_obj:
  skill_name: "NoopSkillPolicy"
  obs_skill_inputs: ["goal_to_agent_gps_compass"]
  max_skill_steps: 1
  apply_postconds: True
  force_end_on_timeout: False
  obs_skill_input_dim: 2
  pddl_action_names: ["nav", "nav_to_receptacle_by_name"]

reset_arm:
  skill_name: "ResetArmSkill"
  max_skill_steps: 50
  reset_joint_state: [-4.50e-01, -1.07e00, 9.95e-02, 9.38e-01, -7.88e-04, 1.57e00, 4.62e-03]
  force_end_on_timeout: False
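
The oracle skills above take a single no-op step and, when `apply_postconds` is true, set the world directly to the skill's post-condition. The sketch below illustrates that idea on a toy state made of boolean predicates. The function `run_oracle_skill` and the predicate names are hypothetical and do not reflect how the PDDL post-conditions are actually applied in habitat-lab.

```python
from typing import Dict


def run_oracle_skill(
    state: Dict[str, bool],
    postconds: Dict[str, bool],
    apply_postconds: bool,
) -> Dict[str, bool]:
    """Return the state after a one-step oracle skill finishes."""
    new_state = dict(state)  # leave the caller's state untouched
    if apply_postconds:
        # "Teleport": overwrite the predicates the skill is meant to achieve.
        new_state.update(postconds)
    return new_state
```

In this sketch an entry that omits `apply_postconds`, as `close_cab` does above, falls back to the default of `False` from `HrlDefinedSkillConfig` and leaves the state unchanged.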
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
name: "HierarchicalPolicy"
obs_transforms:
  add_virtual_keys:
    virtual_keys:
      "goal_to_agent_gps_compass": 2
hierarchical_policy:
  high_level_policy:
    name: "FixedHighLevelPolicy"
    add_arm_rest: True
  defined_skills: {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
name: "HierarchicalPolicy"
obs_transforms:
  add_virtual_keys:
    virtual_keys:
      "goal_to_agent_gps_compass": 2
hierarchical_policy:
  high_level_policy:
    name: "NeuralHighLevelPolicy"
    allowed_actions:
      - nav
      - pick
      - place
      - nav_to_receptacle_by_name
      - open_fridge_by_name
      - close_fridge_by_name
      - open_cab_by_name
      - close_cab_by_name
    allow_other_place: False
    hidden_dim: 512
    use_rnn: True
Contributor

@akshararai check which of these fields are needed

    rnn_type: 'LSTM'
    backbone: resnet18
    normalize_visual_inputs: False
    num_rnn_layers: 2
    policy_input_keys:
      - "robot_head_depth"
  defined_skills: {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
name: "PointNavResNetPolicy"
action_distribution_type: "gaussian"
action_dist:
  use_log_std: True