-
Notifications
You must be signed in to change notification settings - Fork 508
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Train High-Level Policies in Hierarchical Approaches #1053
Changes from 44 commits
a787e34
f6d1f11
daa84db
7ae2c6b
bb2e4da
628063d
d179ecf
515b3cf
8eddcd1
fa94bfc
11659b4
633ebf3
1c82390
2e973f5
3a805ea
7d92072
4696551
c11601f
90155e1
755acb4
8d07335
4754fd7
587672e
f16bc14
22cb83f
df8b7a6
422c036
5643b0e
ebe877e
e1a8727
2673f4f
03faec7
dd01d50
b20fcb1
7982f40
902bfa1
74278bd
63f610c
49a71a4
b41133a
e7a877b
5c213e3
d9721f1
f8387de
1c8f54c
f2c6731
4b65b3f
11e77c3
6f5ea76
cb6ce62
6d4b968
d6c957e
9d2c2f5
a17fbfa
48142fb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
# LICENSE file in the root directory of this source tree. | ||
|
||
from dataclasses import dataclass, field | ||
from typing import Any, Dict, List, Tuple | ||
from typing import Any, Dict, List, Optional, Tuple | ||
|
||
from hydra.core.config_store import ConfigStore | ||
from omegaconf import MISSING | ||
|
@@ -224,10 +224,45 @@ class Eq2CubeConfig(ObsTransformConfig): | |
|
||
|
||
@dataclass | ||
class HierarchicalPolicy(HabitatBaselinesBaseConfig): | ||
class HrlDefinedSkillConfig(HabitatBaselinesBaseConfig): | ||
""" | ||
Defines a low-level skill to be used in the hierarchical policy. | ||
""" | ||
|
||
skill_name: str = MISSING | ||
name: str = "PointNavResNetPolicy" | ||
akshararai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
action_distribution_type: str = "gaussian" | ||
load_ckpt_file: str = "" | ||
max_skill_steps: int = 200 | ||
# If true, the stop action will be called if the skill times out. | ||
force_end_on_timeout: bool = True | ||
# Overrides the config file of a neural network skill rather than loading | ||
# the config file from the checkpoint file. | ||
force_config_file: str = "" | ||
at_resting_threshold: float = 0.15 | ||
# If true, this willapply the post-conditions of the skill after it | ||
# terminates. | ||
apply_postconds: bool = False | ||
ASzot marked this conversation as resolved.
Show resolved
Hide resolved
|
||
obs_skill_inputs: List[str] = field(default_factory=list) | ||
obs_skill_input_dim: int = 3 | ||
start_zone_radius: float = 0.3 | ||
# For the oracle navigation skill | ||
action_name: str = "base_velocity" | ||
stop_thresh: float = 0.001 | ||
ASzot marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# For the reset_arm_skill | ||
reset_joint_state: List[float] = MISSING | ||
# The set of PDDL action names (as defined in the PDDL domain file) that | ||
# map to this skill. If not specified,the name of the skill must match the | ||
# PDDL action name. | ||
pddl_action_names: Optional[List[str]] = None | ||
|
||
|
||
@dataclass | ||
class HierarchicalPolicyConfig(HabitatBaselinesBaseConfig): | ||
high_level_policy: Dict[str, Any] = MISSING | ||
defined_skills: Dict[str, Any] = field(default_factory=dict) | ||
use_skills: Dict[str, str] = field(default_factory=dict) | ||
defined_skills: Dict[str, HrlDefinedSkillConfig] = field( | ||
default_factory=dict | ||
) | ||
|
||
|
||
@dataclass | ||
|
@@ -238,7 +273,7 @@ class PolicyConfig(HabitatBaselinesBaseConfig): | |
# For gaussian action distribution: | ||
action_dist: ActionDistributionConfig = ActionDistributionConfig() | ||
obs_transforms: Dict[str, ObsTransformConfig] = field(default_factory=dict) | ||
hierarchical_policy: HierarchicalPolicy = MISSING | ||
hierarchical_policy: HierarchicalPolicyConfig = MISSING | ||
|
||
|
||
@dataclass | ||
|
@@ -345,6 +380,8 @@ class HabitatBaselinesConfig(HabitatBaselinesBaseConfig): | |
# ) | ||
# cmd_trailing_opts: List[str] = field(default_factory=list) | ||
trainer_name: str = "ppo" | ||
updater_name: str = "PPO" | ||
distrib_updater_name: str = "DDPPO" | ||
Comment on lines
371
to
+373
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks super redundant. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, I think they are best here because they are needed to instantiate the updater. But this was a mistake, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can't these just be inferred from other configurations ? |
||
torch_gpu_id: int = 0 | ||
tensorboard_dir: str = "tb" | ||
writer_type: str = "tb" | ||
|
@@ -355,6 +392,7 @@ class HabitatBaselinesConfig(HabitatBaselinesBaseConfig): | |
eval_ckpt_path_dir: str = "data/checkpoints" | ||
num_environments: int = 16 | ||
num_processes: int = -1 # deprecated | ||
rollout_storage: str = "RolloutStorage" | ||
ASzot marked this conversation as resolved.
Show resolved
Hide resolved
|
||
checkpoint_folder: str = "data/checkpoints" | ||
num_updates: int = 10000 | ||
num_checkpoints: int = 10 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
open_cab: | ||
skill_name: "ArtObjSkillPolicy" | ||
load_ckpt_file: "data/models/open_cab.pth" | ||
|
||
open_fridge: | ||
skill_name: "ArtObjSkillPolicy" | ||
load_ckpt_file: "data/models/open_fridge.pth" | ||
|
||
close_cab: | ||
skill_name: "ArtObjSkillPolicy" | ||
load_ckpt_file: "data/models/close_cab.pth" | ||
|
||
close_fridge: | ||
skill_name: "ArtObjSkillPolicy" | ||
load_ckpt_file: "data/models/close_fridge.pth" | ||
|
||
pick: | ||
skill_name: "PickSkillPolicy" | ||
obs_skill_inputs: ["obj_start_sensor"] | ||
load_ckpt_file: "data/models/pick.pth" | ||
|
||
place: | ||
skill_name: "PlaceSkillPolicy" | ||
obs_skill_inputs: ["obj_goal_sensor"] | ||
load_ckpt_file: "data/models/place.pth" | ||
|
||
wait: | ||
skill_name: "WaitSkillPolicy" | ||
max_skill_steps: -1 | ||
force_end_on_timeout: False | ||
|
||
nav_to_obj: | ||
skill_name: "NavSkillPolicy" | ||
obs_skill_inputs: ["goal_to_agent_gps_compass"] | ||
load_ckpt_file: "data/models/nav.pth" | ||
max_skill_steps: 300 | ||
obs_skill_input_dim: 2 | ||
pddl_action_names: ["nav", "nav_to_receptacle"] | ||
|
||
reset_arm: | ||
skill_name: "ResetArmSkill" | ||
max_skill_steps: 50 | ||
reset_joint_state: [-4.50e-01, -1.08e00, 9.95e-02, 9.38e-01, -7.88e-04, 1.57e00, 4.62e-03] | ||
force_end_on_timeout: False |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# Oracle skills that will teleport to the skill post-condition. When automatically setting predicates you may want to run the simulation in kinematic mode: | ||
akshararai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# To run in kinematic mode, add: `habitat.simulator.kinematic_mode=True habitat.simulator.ac_freq_ratio=1 habitat.task.measurements.force_terminate.max_accum_force=-1.0 habitat.task.measurements.force_terminate.max_instant_force=-1.0` | ||
|
||
defaults: | ||
- /habitat/task/actions: | ||
- pddl_apply_action | ||
|
||
open_cab: | ||
skill_name: "NoopSkillPolicy" | ||
max_skill_steps: 1 | ||
apply_postconds: True | ||
force_end_on_timeout: False | ||
pddl_action_names: ["open_cab_by_name"] | ||
|
||
open_fridge: | ||
skill_name: "NoopSkillPolicy" | ||
max_skill_steps: 1 | ||
apply_postconds: True | ||
force_end_on_timeout: False | ||
pddl_action_names: ["open_fridge_by_name"] | ||
|
||
close_cab: | ||
skill_name: "NoopSkillPolicy" | ||
obs_skill_inputs: ["obj_start_sensor"] | ||
max_skill_steps: 1 | ||
force_end_on_timeout: False | ||
pddl_action_names: ["close_cab_by_name"] | ||
|
||
close_fridge: | ||
skill_name: "NoopSkillPolicy" | ||
obs_skill_inputs: ["obj_start_sensor"] | ||
max_skill_steps: 1 | ||
apply_postconds: True | ||
force_end_on_timeout: False | ||
pddl_action_names: ["close_fridge_by_name"] | ||
|
||
pick: | ||
skill_name: "NoopSkillPolicy" | ||
obs_skill_inputs: ["obj_start_sensor"] | ||
max_skill_steps: 1 | ||
apply_postconds: True | ||
force_end_on_timeout: False | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @akshararai check if all these parameters are needed, and add a sample documentation for them somewhere. |
||
|
||
place: | ||
skill_name: "NoopSkillPolicy" | ||
obs_skill_inputs: ["obj_goal_sensor"] | ||
max_skill_steps: 1 | ||
apply_postconds: True | ||
force_end_on_timeout: False | ||
|
||
wait: | ||
skill_name: "WaitSkillPolicy" | ||
max_skill_steps: -1 | ||
|
||
nav_to_obj: | ||
skill_name: "NoopSkillPolicy" | ||
obs_skill_inputs: ["goal_to_agent_gps_compass"] | ||
max_skill_steps: 1 | ||
apply_postconds: True | ||
force_end_on_timeout: False | ||
obs_skill_input_dim: 2 | ||
pddl_action_names: ["nav", "nav_to_receptacle_by_name"] | ||
|
||
reset_arm: | ||
skill_name: "ResetArmSkill" | ||
max_skill_steps: 50 | ||
reset_joint_state: [-4.50e-01, -1.07e00, 9.95e-02, 9.38e-01, -7.88e-04, 1.57e00, 4.62e-03] | ||
force_end_on_timeout: False |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
name: "HierarchicalPolicy" | ||
obs_transforms: | ||
add_virtual_keys: | ||
virtual_keys: | ||
"goal_to_agent_gps_compass": 2 | ||
hierarchical_policy: | ||
high_level_policy: | ||
name: "FixedHighLevelPolicy" | ||
add_arm_rest: True | ||
defined_skills: {} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
name: "HierarchicalPolicy" | ||
obs_transforms: | ||
add_virtual_keys: | ||
virtual_keys: | ||
"goal_to_agent_gps_compass": 2 | ||
hierarchical_policy: | ||
high_level_policy: | ||
name: "NeuralHighLevelPolicy" | ||
allowed_actions: | ||
- nav | ||
- pick | ||
- place | ||
- nav_to_receptacle_by_name | ||
- open_fridge_by_name | ||
- close_fridge_by_name | ||
- open_cab_by_name | ||
- close_cab_by_name | ||
allow_other_place: False | ||
hidden_dim: 512 | ||
use_rnn: True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @akshararai check which of these fields are needed |
||
rnn_type: 'LSTM' | ||
backbone: resnet18 | ||
normalize_visual_inputs: False | ||
num_rnn_layers: 2 | ||
policy_input_keys: | ||
- "robot_head_depth" | ||
defined_skills: {} | ||
ASzot marked this conversation as resolved.
Show resolved
Hide resolved
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
name: "PointNavResNetPolicy" | ||
action_distribution_type: "gaussian" | ||
action_dist: | ||
use_log_std: True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
storage and updater are not very descriptive. Can you add a little text here to clarify what is the base classes that are being registered?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe even add a
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where is the assert ?