Train High-Level Policies in Hierarchical Approaches #1053
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -219,10 +219,31 @@ class Eq2CubeConfig(ObsTransformConfig): | |
) | ||
|
||
|
||
@dataclass | ||
class HrlDefinedSkill(HabitatBaselinesBaseConfig): | ||
|
||
skill_name: str = MISSING | ||
name: str = "PointNavResNetPolicy" | ||
|
||
action_distribution_type: str = "gaussian" | ||
load_ckpt_file: str = "" | ||
max_skill_steps: int = 200 | ||
force_end_on_timeout: bool = True | ||
force_config_file: str = "" | ||
at_resting_threshold: float = 0.15 | ||
apply_postconds: bool = False | ||
|
||
obs_skill_inputs: List[str] = field(default_factory=list) | ||
obs_skill_input_dim: int = 3 | ||
start_zone_radius: float = 0.3 | ||
# For the oracle navigation skill | ||
nav_action_name: str = "base_velocity" | ||
|
||
stop_thresh: float = 0.001 | ||
|
||
# For the reset_arm_skill | ||
reset_joint_state: List[float] = MISSING | ||
|
||
|
||
@dataclass | ||
class HierarchicalPolicy(HabitatBaselinesBaseConfig): | ||
|
||
high_level_policy: Dict[str, Any] = MISSING | ||
defined_skills: Dict[str, Any] = field(default_factory=dict) | ||
defined_skills: Dict[str, HrlDefinedSkill] = field(default_factory=dict) | ||
use_skills: Dict[str, str] = field(default_factory=dict) | ||
|
||
|
||
|
@@ -383,6 +404,8 @@ class HabitatBaselinesConfig(HabitatBaselinesBaseConfig): | |
# ) | ||
# cmd_trailing_opts: List[str] = field(default_factory=list) | ||
trainer_name: str = "ppo" | ||
updater_name: str = "PPO" | ||
distrib_updater_name: str = "DDPPO" | ||
Comment on lines 371 to +373
This looks super redundant.
No, I think they are best here because they are needed to instantiate the updater. But this was a mistake,
Can't these just be inferred from other configurations?
|
||
torch_gpu_id: int = 0 | ||
video_render_views: List[str] = field(default_factory=list) | ||
tensorboard_dir: str = "tb" | ||
|
@@ -394,6 +417,7 @@ class HabitatBaselinesConfig(HabitatBaselinesBaseConfig): | |
eval_ckpt_path_dir: str = "data/checkpoints" | ||
num_environments: int = 16 | ||
num_processes: int = -1 # deprecated | ||
rollout_storage: str = "RolloutStorage" | ||
|
||
checkpoint_folder: str = "data/checkpoints" | ||
num_updates: int = 10000 | ||
num_checkpoints: int = 10 | ||
|
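The switch from `Dict[str, Any]` to `Dict[str, HrlDefinedSkill]` in the diff above gives every skill entry a typed schema with defaults. The following is a minimal sketch of that idea using only the standard library. `SkillSketch`, `HierarchicalPolicySketch`, and `build_skills` are hypothetical names for illustration, the `MISSING` constant is a stand-in for the one the real config imports, and only a subset of the fields from the diff is reproduced.

```python
from dataclasses import dataclass, field
from typing import Dict, List

# Stand-in for the MISSING marker used in the diff, so that this sketch
# runs with the standard library only (assumption, not the real import).
MISSING = "???"


@dataclass
class SkillSketch:
    # Subset of the HrlDefinedSkill fields shown in the diff, same defaults.
    skill_name: str = MISSING
    max_skill_steps: int = 200
    force_end_on_timeout: bool = True
    apply_postconds: bool = False
    obs_skill_inputs: List[str] = field(default_factory=list)


@dataclass
class HierarchicalPolicySketch:
    defined_skills: Dict[str, SkillSketch] = field(default_factory=dict)
    use_skills: Dict[str, str] = field(default_factory=dict)


def build_skills(raw: Dict[str, dict]) -> Dict[str, SkillSketch]:
    """Turn plain dicts (for example parsed YAML) into typed skill entries.

    A misspelled or unknown key raises TypeError here, which an untyped
    Dict[str, Any] would silently accept.
    """
    return {name: SkillSketch(**values) for name, values in raw.items()}


policy = HierarchicalPolicySketch(
    defined_skills=build_skills(
        {
            "pick": {
                "skill_name": "NoopSkillPolicy",
                "max_skill_steps": 1,
                "apply_postconds": True,
            },
            "wait_skill": {
                "skill_name": "WaitSkillPolicy",
                "max_skill_steps": -1,
                "force_end_on_timeout": False,
            },
        }
    ),
    use_skills={"pick": "pick", "wait": "wait_skill"},
)
```

Fields left out of an entry, such as `force_end_on_timeout` for `pick`, fall back to the dataclass defaults, which is why the YAML entries can stay short.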
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
# @package _global_ | ||
|
||
# Pick and place are kinematically simulated. | ||
|
||
# Navigation is fully simulated. | ||
|
||
|
||
defaults: | ||
- /benchmark/rearrange: rearrange_easy | ||
- /habitat_baselines: habitat_baselines_rl_config_base | ||
- /habitat_baselines/rl/policy/obs_transforms: | ||
- add_virtual_keys_base | ||
- /habitat/task/actions: | ||
- pddl_apply_action | ||
- oracle_nav_action | ||
- arm_action | ||
- base_velocity | ||
- rearrange_stop | ||
- _self_ | ||
|
||
habitat: | ||
gym: | ||
auto_name: RearrangeEasy | ||
obs_keys: | ||
- robot_head_depth | ||
- relative_resting_position | ||
- obj_start_sensor | ||
- obj_goal_sensor | ||
- obj_start_gps_compass | ||
- obj_goal_gps_compass | ||
- joint | ||
- is_holding | ||
- localization_sensor | ||
|
||
|
||
habitat_baselines: | ||
|
||
verbose: False | ||
trainer_name: "ddppo" | ||
torch_gpu_id: 0 | ||
tensorboard_dir: "tb" | ||
rollout_storage: "HrlRolloutStorage" | ||
updater_name: "HrlPPO" | ||
distrib_updater_name: "HrlDDPPO" | ||
video_dir: "video_dir" | ||
video_fps: 30 | ||
video_render_views: | ||
- "third_rgb_sensor" | ||
test_episode_count: -1 | ||
eval_ckpt_path_dir: "" | ||
num_environments: 3 | ||
writer_type: 'tb' | ||
checkpoint_folder: "data/new_checkpoints" | ||
num_updates: -1 | ||
total_num_steps: 1.0e8 | ||
log_interval: 10 | ||
num_checkpoints: 20 | ||
force_torch_single_threaded: True | ||
eval_keys_to_include_in_name: ['reward', 'force', 'composite_success'] | ||
load_resume_state_config: False | ||
|
||
eval: | ||
use_ckpt_config: False | ||
should_load_ckpt: False | ||
video_option: ["disk"] | ||
|
||
rl: | ||
policy: | ||
|
||
name: "HierarchicalPolicy" | ||
obs_transforms: | ||
add_virtual_keys: | ||
virtual_keys: | ||
"goal_to_agent_gps_compass": 2 | ||
hierarchical_policy: | ||
high_level_policy: | ||
name: "NeuralHighLevelPolicy" | ||
allow_other_place: False | ||
hidden_dim: 512 | ||
use_rnn: True | ||
rnn_type: 'LSTM' | ||
backbone: resnet18 | ||
normalize_visual_inputs: False | ||
num_rnn_layers: 2 | ||
policy_input_keys: | ||
- "robot_head_depth" | ||
defined_skills: | ||
open_cab: | ||
skill_name: "NoopSkillPolicy" | ||
max_skill_steps: 1 | ||
apply_postconds: True | ||
|
||
open_fridge: | ||
skill_name: "NoopSkillPolicy" | ||
max_skill_steps: 1 | ||
apply_postconds: True | ||
|
||
close_cab: | ||
skill_name: "NoopSkillPolicy" | ||
obs_skill_inputs: ["obj_start_sensor"] | ||
max_skill_steps: 1 | ||
|
||
close_fridge: | ||
skill_name: "NoopSkillPolicy" | ||
obs_skill_inputs: ["obj_start_sensor"] | ||
max_skill_steps: 1 | ||
apply_postconds: True | ||
|
||
pick: | ||
skill_name: "NoopSkillPolicy" | ||
obs_skill_inputs: ["obj_start_sensor"] | ||
max_skill_steps: 1 | ||
apply_postconds: True | ||
|
||
place: | ||
skill_name: "NoopSkillPolicy" | ||
obs_skill_inputs: ["obj_goal_sensor"] | ||
max_skill_steps: 1 | ||
apply_postconds: True | ||
|
||
nav_to_obj: | ||
skill_name: "OracleNavPolicy" | ||
obs_skill_inputs: ["obj_start_sensor", "abs_obj_start_sensor", "obj_goal_sensor", "abs_obj_goal_sensor"] | ||
max_skill_steps: 300 | ||
|
||
wait_skill: | ||
skill_name: "WaitSkillPolicy" | ||
max_skill_steps: -1 | ||
force_end_on_timeout: False | ||
|
||
reset_arm_skill: | ||
skill_name: "ResetArmSkill" | ||
max_skill_steps: 50 | ||
reset_joint_state: [-4.5003259e-01, -1.0799699e00, 9.9526465e-02, 9.3869519e-01, -7.8854430e-04, 1.5702540e00, 4.6168058e-03] | ||
force_end_on_timeout: False | ||
|
||
use_skills: | ||
open_cab: "open_cab" | ||
open_fridge: "open_fridge" | ||
close_cab: "close_cab" | ||
close_fridge: "close_fridge" | ||
pick: "pick" | ||
place: "place" | ||
nav: "nav_to_obj" | ||
nav_to_receptacle: "nav_to_obj" | ||
wait: "wait_skill" | ||
reset_arm: "reset_arm_skill" | ||
|
||
ppo: | ||
# ppo params | ||
clip_param: 0.2 | ||
ppo_epoch: 2 | ||
num_mini_batch: 2 | ||
value_loss_coef: 0.5 | ||
entropy_coef: 0.0001 | ||
lr: 2.5e-4 | ||
eps: 1e-5 | ||
max_grad_norm: 0.2 | ||
num_steps: 128 | ||
use_gae: True | ||
gamma: 0.99 | ||
tau: 0.95 | ||
use_linear_clip_decay: False | ||
use_linear_lr_decay: False | ||
reward_window_size: 50 | ||
|
||
use_normalized_advantage: False | ||
|
||
hidden_size: 512 | ||
|
||
# Use double buffered sampling, typically helps | ||
# when environment time is similar or larger than | ||
# policy inference time during rollout generation | ||
use_double_buffered_sampler: False | ||
|
||
ddppo: | ||
sync_frac: 0.6 | ||
# The PyTorch distributed backend to use | ||
distrib_backend: NCCL | ||
# Visual encoder backbone | ||
pretrained_weights: data/ddppo-models/gibson-2plus-resnet50.pth | ||
# Initialize with pretrained weights | ||
pretrained: False | ||
# Initialize just the visual encoder backbone with pretrained weights | ||
pretrained_encoder: False | ||
# Whether the visual encoder backbone will be trained. | ||
train_encoder: True | ||
# Whether to reset the critic linear layer | ||
reset_critic: False | ||
|
||
# Model parameters | ||
backbone: resnet18 | ||
rnn_type: LSTM | ||
num_recurrent_layers: 2 |
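In the config above, `use_skills` appears to map a high-level action name to a key in `defined_skills`, with `nav` and `nav_to_receptacle` both pointing at `nav_to_obj`. How the policy actually consumes this mapping is not shown in the diff, so the following is only a sketch under that assumption. `resolve_skills` is a hypothetical helper, and the dictionaries reproduce a few entries from the YAML.

```python
from typing import Any, Dict


def resolve_skills(
    use_skills: Dict[str, str],
    defined_skills: Dict[str, Dict[str, Any]],
) -> Dict[str, Dict[str, Any]]:
    """Map each action name to the skill config it refers to.

    Raises KeyError when use_skills names a skill that is not defined,
    so a typo in the config fails at load time rather than mid-episode.
    """
    missing = sorted(
        {skill for skill in use_skills.values() if skill not in defined_skills}
    )
    if missing:
        raise KeyError(f"use_skills refers to undefined skills: {missing}")
    return {action: defined_skills[skill] for action, skill in use_skills.items()}


# A few entries copied from the YAML above.
defined_skills = {
    "nav_to_obj": {"skill_name": "OracleNavPolicy", "max_skill_steps": 300},
    "wait_skill": {
        "skill_name": "WaitSkillPolicy",
        "max_skill_steps": -1,
        "force_end_on_timeout": False,
    },
}
use_skills = {
    "nav": "nav_to_obj",
    "nav_to_receptacle": "nav_to_obj",
    "wait": "wait_skill",
}

resolved = resolve_skills(use_skills, defined_skills)
```

Because two action names share one entry, `resolved["nav"]` and `resolved["nav_to_receptacle"]` refer to the same skill config in this sketch.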
storage and updater are not very descriptive. Can you add a little text here to clarify which base classes are being registered?
Maybe even add a
Done.
Where is the assert?
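The thread above asks which base classes the storage and updater registrations expect, and where the assert is. The code under review is not included in this excerpt, so the following is only a sketch of one common way to enforce such a contract. `SketchRegistry`, `BaseStorage`, `BaseUpdater`, and the registered class names are all hypothetical and are not taken from the project.

```python
from typing import Callable, Dict, Optional, Type


class BaseStorage:
    """Hypothetical base class that every registered storage must extend."""


class BaseUpdater:
    """Hypothetical base class that every registered updater must extend."""


class SketchRegistry:
    """Toy registry that checks the base class at registration time."""

    def __init__(self) -> None:
        self._entries: Dict[str, Dict[str, type]] = {"storage": {}, "updater": {}}

    def _register(self, kind: str, base: type, name: Optional[str]) -> Callable[[type], type]:
        def wrap(cls: type) -> type:
            # The assert documents and enforces the expected base class.
            assert issubclass(cls, base), (
                f"{kind} '{cls.__name__}' must inherit from {base.__name__}"
            )
            self._entries[kind][name or cls.__name__] = cls
            return cls

        return wrap

    def register_storage(self, name: Optional[str] = None) -> Callable[[type], type]:
        return self._register("storage", BaseStorage, name)

    def register_updater(self, name: Optional[str] = None) -> Callable[[type], type]:
        return self._register("updater", BaseUpdater, name)

    def get_storage(self, name: str) -> Optional[Type[BaseStorage]]:
        return self._entries["storage"].get(name)

    def get_updater(self, name: str) -> Optional[Type[BaseUpdater]]:
        return self._entries["updater"].get(name)


registry = SketchRegistry()


@registry.register_storage()
class HrlRolloutStorageSketch(BaseStorage):
    pass


@registry.register_updater(name="HrlPPO")
class HrlPPOSketch(BaseUpdater):
    pass
```

With this shape, a config string such as `updater_name: "HrlPPO"` can be looked up by name, and a class that does not extend the expected base fails as soon as it is decorated.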