
Train High-Level Policies in Hierarchical Approaches #1053

Merged: 55 commits, Feb 10, 2023
a787e34
Trainable HL policy
ASzot Dec 24, 2022
f6d1f11
Working on HRL trainer
ASzot Dec 26, 2022
daa84db
Fixed config setup
ASzot Dec 27, 2022
7ae2c6b
Train hl modif (#1057)
akshararai Jan 6, 2023
bb2e4da
Update README.md
xavierpuigf Jan 9, 2023
628063d
Match tensor device when checking if the skills is done
xavierpuigf Jan 9, 2023
d179ecf
Train hl modif2 (#1076)
xavierpuigf Jan 13, 2023
515b3cf
Merged with main
ASzot Jan 19, 2023
8eddcd1
Fixed RNN problem
ASzot Jan 20, 2023
fa94bfc
Fixed tests
ASzot Jan 20, 2023
11659b4
Fixed formatting
ASzot Jan 20, 2023
633ebf3
Fixed device issues. Cleaned up configs.
ASzot Jan 27, 2023
1c82390
More config cleanup
ASzot Jan 27, 2023
2e973f5
Addressing PR comments
ASzot Jan 27, 2023
3a805ea
Updated circular reference
ASzot Jan 27, 2023
7d92072
Addressing PR comments
ASzot Jan 27, 2023
4696551
Addressing PR comments
ASzot Jan 27, 2023
c11601f
Update habitat-baselines/habitat_baselines/rl/hrl/skills/skill.py
ASzot Jan 27, 2023
90155e1
Addressing PR comments
ASzot Jan 27, 2023
755acb4
Resolved storage problem
ASzot Jan 28, 2023
8d07335
merged
ASzot Jan 28, 2023
4754fd7
Update oracle_nav.py
xavierpuigf Jan 30, 2023
587672e
Fix for agent rotation
ASzot Jan 30, 2023
f16bc14
Missing key
ASzot Jan 30, 2023
22cb83f
More docs
ASzot Jan 31, 2023
df8b7a6
Update habitat-baselines/habitat_baselines/rl/hrl/hrl_rollout_storage.py
ASzot Jan 31, 2023
422c036
Update habitat-baselines/habitat_baselines/rl/hrl/utils.py
ASzot Jan 31, 2023
5643b0e
Updated name
ASzot Jan 31, 2023
ebe877e
fixes for training
ASzot Feb 1, 2023
e1a8727
Fixed env issue
ASzot Feb 2, 2023
2673f4f
Fixed deprecated configs
ASzot Feb 2, 2023
03faec7
Merge branch 'main' into train_hl
ASzot Feb 2, 2023
dd01d50
Speed fix
ASzot Feb 3, 2023
b20fcb1
Updated configs
ASzot Feb 3, 2023
7982f40
Pddl action fixes
ASzot Feb 3, 2023
902bfa1
Removed speed opts. Fixed some bugs
ASzot Feb 4, 2023
74278bd
Fixed rendering text to the frame
ASzot Feb 4, 2023
63f610c
Merged with main
ASzot Feb 4, 2023
49a71a4
Addressing Vince's PR comments
ASzot Feb 4, 2023
b41133a
Refactored navigation to be much clearer
ASzot Feb 4, 2023
e7a877b
Fixed some of the tests
ASzot Feb 5, 2023
5c213e3
Adddressed PR comments
ASzot Feb 6, 2023
d9721f1
Fixed rotation issue
ASzot Feb 6, 2023
f8387de
Fixed black
ASzot Feb 6, 2023
1c8f54c
Addressed PR comments
ASzot Feb 8, 2023
f2c6731
Addressed PR comments
ASzot Feb 8, 2023
4b65b3f
Merge branch 'main' into train_hl
ASzot Feb 8, 2023
11e77c3
Fixed config
ASzot Feb 8, 2023
6f5ea76
Fixed typo
ASzot Feb 8, 2023
cb6ce62
Fixed another typo
ASzot Feb 8, 2023
6d4b968
CI
ASzot Feb 8, 2023
d6c957e
Merge branch 'main' into train_hl
vincentpierre Feb 8, 2023
9d2c2f5
Updated to work with older pytorch version
ASzot Feb 9, 2023
a17fbfa
Merge branch 'main' into train_hl
ASzot Feb 9, 2023
48142fb
renaming --exp-config to --config-name again
vincentpierre Feb 9, 2023
13 changes: 7 additions & 6 deletions habitat-baselines/habitat_baselines/README.md
Original file line number Diff line number Diff line change
@@ -62,17 +62,18 @@ To run the following examples, you need the [ReplicaCAD dataset](https://github.
To train a high-level policy, while using pre-learned low-level skills (SRL baseline from [Habitat2.0](https://arxiv.org/abs/2106.14405)), you can run:

```bash
-python -u habitat_baselines/run.py \
-    --exp-config habitat_baselines/config/rearrange/rl_hl_srl_onav.yaml \
+python -u habitat-baselines/habitat_baselines/run.py \
+    --exp-config habitat-baselines/habitat_baselines/config/rearrange/rl_hierarchical.yaml \
--run-type train
```
-To run a rearrangement episode with oracle in both low- and high-level, you can run:
+To run a rearrangement episode with oracle low-level skills and a fixed task planner, run:

```bash
-python -u habitat_baselines/run.py \
-    --exp-config habitat_baselines/config/rearrange/rl_hl_srl_onav.yaml \
+python -u habitat-baselines/habitat_baselines/run.py \
+    --exp-config habitat-baselines/habitat_baselines/config/rearrange/rl_hierarchical.yaml \
--run-type eval \
-    habitat_baselines/rl/policy=hierarchical_tp_noop_onav
+    habitat_baselines/rl/policy=hl_fixed \
+    habitat_baselines/rl/policy/hierarchical_policy/defined_skills=oracle_skills
```

To change the task (like set table) that you train your skills on, you can change the line `/habitat/task/rearrange: rearrange_easy` to `/habitat/task/rearrange: set_table` in the defaults of your config.
12 changes: 4 additions & 8 deletions habitat-baselines/habitat_baselines/agents/ppo_agents.py
@@ -126,23 +126,19 @@ def reset(self) -> None:
def act(self, observations: Observations) -> Dict[str, int]:
batch = batch_obs([observations], device=self.device)
with torch.no_grad():
-            (
-                _,
-                actions,
-                _,
-                self.test_recurrent_hidden_states,
-            ) = self.actor_critic.act(
+            action_data = self.actor_critic.act(
batch,
self.test_recurrent_hidden_states,
self.prev_actions,
self.not_done_masks,
deterministic=False,
)
+        self.test_recurrent_hidden_states = action_data.rnn_hidden_states
# Make masks not done till reset (end of episode) will be called
self.not_done_masks.fill_(True)
-        self.prev_actions.copy_(actions)  # type: ignore
+        self.prev_actions.copy_(action_data.actions)  # type: ignore

-        return {"action": actions[0][0].item()}
+        return {"action": action_data.env_actions[0][0].item()}


def main():
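The diff above replaces positional tuple unpacking with a single result object. A minimal sketch of the pattern, assuming toy values (the field names `actions`, `env_actions`, and `rnn_hidden_states` follow the diff; the classes here are illustrative stand-ins, not the Habitat API):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class PolicyActionData:
    """Named bundle replacing the old (_, actions, _, hidden_states) tuple."""

    actions: List[List[int]]
    env_actions: List[List[int]]
    rnn_hidden_states: List[float]


class ToyAgent:
    def act(self, observations) -> PolicyActionData:
        # A real actor-critic would run the policy network here.
        return PolicyActionData(
            actions=[[3]], env_actions=[[3]], rnn_hidden_states=[0.0]
        )


agent = ToyAgent()
action_data = agent.act(None)
# Callers read named fields instead of remembering tuple positions:
print(action_data.env_actions[0][0])  # -> 3
```

The named-field form also survives signature changes: adding a new return value no longer breaks every unpacking call site.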
@@ -36,7 +36,6 @@ def __init__(
num_recurrent_layers=1,
is_double_buffered: bool = False,
):

if is_continuous_action_space(action_space):
# Assume ALL actions are NOT discrete
action_shape = (
@@ -251,6 +251,9 @@ class HrlDefinedSkillConfig(HabitatBaselinesBaseConfig):
stop_thresh: float = 0.001
# For the reset_arm_skill
reset_joint_state: List[float] = MISSING
# The set of PDDL action names (as defined in the PDDL domain file) that
# map to this skill. If not specified, the name of the skill must match the
# PDDL action name.
pddl_action_names: Optional[List[str]] = None
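A sketch of the fallback described in the comment above (the function name is hypothetical; the real mapping lives in the HRL skill setup):

```python
from typing import List, Optional


def resolve_pddl_actions(
    skill_name: str, pddl_action_names: Optional[List[str]]
) -> List[str]:
    # Per the config comment: with no explicit list, the skill's own
    # name must match the PDDL action name.
    if pddl_action_names is None:
        return [skill_name]
    return pddl_action_names


print(resolve_pddl_actions("open_cab", ["open_cab_by_name"]))  # -> ['open_cab_by_name']
print(resolve_pddl_actions("pick", None))  # -> ['pick']
```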


@@ -1,26 +1,38 @@
# Oracle skills that will teleport to the skill post-condition. When automatically setting predicates you may want to run the simulation in kinematic mode:
# To run in kinematic mode, add: `habitat.simulator.kinematic_mode=True habitat.simulator.ac_freq_ratio=1 habitat.task.measurements.force_terminate.max_accum_force=-1.0 habitat.task.measurements.force_terminate.max_instant_force=-1.0`

defaults:
- /habitat/task/actions:
- pddl_apply_action

open_cab:
skill_name: "NoopSkillPolicy"
max_skill_steps: 1
apply_postconds: True
force_end_on_timeout: False
pddl_action_names: ["open_cab_by_name"]

open_fridge:
skill_name: "NoopSkillPolicy"
max_skill_steps: 1
apply_postconds: True
force_end_on_timeout: False
pddl_action_names: ["open_fridge_by_name"]

close_cab:
skill_name: "NoopSkillPolicy"
obs_skill_inputs: ["obj_start_sensor"]
max_skill_steps: 1
force_end_on_timeout: False
pddl_action_names: ["close_cab_by_name"]

close_fridge:
skill_name: "NoopSkillPolicy"
obs_skill_inputs: ["obj_start_sensor"]
max_skill_steps: 1
apply_postconds: True
force_end_on_timeout: False
pddl_action_names: ["close_fridge_by_name"]

pick:
skill_name: "NoopSkillPolicy"
@@ -47,7 +59,7 @@ nav_to_obj:
apply_postconds: True
force_end_on_timeout: False
obs_skill_input_dim: 2
-  pddl_action_names: ["nav", "nav_to_receptacle"]
+  pddl_action_names: ["nav", "nav_to_receptacle_by_name"]

reset_arm:
skill_name: "ResetArmSkill"
@@ -6,6 +6,15 @@ obs_transforms:
hierarchical_policy:
high_level_policy:
name: "NeuralHighLevelPolicy"
allowed_actions:
- nav
- pick
- place
- nav_to_receptacle_by_name
- open_fridge_by_name
- close_fridge_by_name
- open_cab_by_name
- close_cab_by_name
allow_other_place: False
hidden_dim: 512
use_rnn: True
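The `allowed_actions` list above acts as a whitelist over the PDDL domain's possible actions; the filtering in the policy's `_setup_actions` amounts to the following, simplified to plain strings (the real code compares `ac.name` on `PddlAction` objects):

```python
# Actions the high-level policy may select, as in the config above.
allowed_actions = ["nav", "pick", "place", "nav_to_receptacle_by_name"]

# All actions the PDDL domain can ground (toy subset for illustration).
all_actions = ["nav", "pick", "place", "wait", "nav_to_receptacle_by_name"]

# Keep only whitelisted actions; anything else is invisible to the policy.
filtered = [ac for ac in all_actions if ac in allowed_actions]
print(filtered)  # -> ['nav', 'pick', 'place', 'nav_to_receptacle_by_name']
```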
Contributor:
@akshararai check which of these fields are needed

@@ -22,9 +22,9 @@ habitat_baselines:
num_environments: 4
writer_type: 'tb'
num_updates: -1
-  total_num_steps: 1.0e8
+  total_num_steps: 5.0e7
log_interval: 10
-  num_checkpoints: 20
+  num_checkpoints: 10
force_torch_single_threaded: True
eval_keys_to_include_in_name: ['reward', 'force', 'composite_success']
load_resume_state_config: False
@@ -9,7 +9,6 @@
defaults:
- rl_hierarchical
- /habitat/task/actions:
Contributor:
What is the oracle navigation action, and why is it different from the skill?

Contributor Author:
The oracle navigation action is a Habitat-Lab action that accesses the underlying simulator state to compute the oracle path. The oracle navigation skill does not have access to the simulator because it is on the Habitat-Baselines side.

- pddl_apply_action
- oracle_nav_action
- _self_

@@ -2,8 +2,8 @@

defaults:
- /benchmark/rearrange: rearrange
- /habitat_baselines/rl/policy: monolithic
- /habitat_baselines: habitat_baselines_rl_config_base
- /habitat_baselines/rl/policy: monolithic
- /habitat/simulator/sim_sensors@habitat_baselines.eval.extra_sim_sensors.third_rgb_sensor: third_rgb_sensor
- _self_

@@ -227,7 +227,7 @@ def act(
(self._num_envs,), dtype=torch.bool
)

-        hl_policy_wants_termination = self._high_level_policy.get_termination(
+        hl_wants_skill_term = self._high_level_policy.get_termination(
observations,
rnn_hidden_states,
prev_actions,
@@ -249,7 +249,7 @@
"prev_actions": prev_actions,
"masks": masks,
"actions": actions,
-                "hl_policy_wants_termination": hl_policy_wants_termination,
+                "hl_wants_skill_term": hl_wants_skill_term,
},
# Only decide on skill termination if the episode is active.
should_adds=masks,
@@ -1,3 +1,4 @@
import logging
from itertools import chain
from typing import Any, List

@@ -7,6 +8,7 @@
import torch.nn as nn

from habitat.tasks.rearrange.multi_task.pddl_action import PddlAction
from habitat_baselines.common.logging import baselines_logger
from habitat_baselines.rl.ddppo.policy import resnet
from habitat_baselines.rl.ddppo.policy.resnet_policy import ResNetEncoder
from habitat_baselines.rl.hrl.hl.high_level_policy import HighLevelPolicy
@@ -81,6 +83,9 @@ def __init__(self, *args, **kwargs):

def _setup_actions(self) -> List[PddlAction]:
all_actions = self._pddl_prob.get_possible_actions()
all_actions = [
ac for ac in all_actions if ac.name in self._config.allowed_actions
]
if not self._config.allow_other_place:
all_actions = [
ac
@@ -191,6 +196,8 @@ def get_next_skill(
if should_plan != 1.0:
continue
use_ac = self._all_actions[skill_sel[batch_idx]]
if baselines_logger.level >= logging.DEBUG:
Contributor:
Do you need this? I thought .debug already checks if the level is correct.

Contributor Author:
Yes, this is needed because I want to avoid any cost of formatting the use_ac value.
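A sketch of the cost being avoided in this exchange: an f-string argument is formatted before `debug()` can discard it, so a level guard skips formatting entirely when DEBUG is off (shown with the stdlib's `isEnabledFor`, the idiomatic form of the same check):

```python
import logging

logger = logging.getLogger("hl_policy_demo")
logger.setLevel(logging.INFO)  # DEBUG records are discarded


class ExpensiveRepr:
    """Stand-in for a PDDL action whose repr is costly to build."""

    def __repr__(self) -> str:
        raise RuntimeError("repr was formatted!")


use_ac = ExpensiveRepr()

# Guarded: __repr__ never runs because the whole branch is skipped.
if logger.isEnabledFor(logging.DEBUG):
    logger.debug(f"HL Policy selected skill {use_ac}")

# Unguarded, the same f-string would build the repr (here: raise)
# before logging even got the chance to drop the record.
print("no formatting cost incurred")  # -> no formatting cost incurred
```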

baselines_logger.debug(f"HL Policy selected skill {use_ac}")
next_skill[batch_idx] = self._skill_name_to_idx[use_ac.name]
skill_args_data[batch_idx] = [
entity.name for entity in use_ac.param_values
7 changes: 6 additions & 1 deletion habitat-baselines/habitat_baselines/rl/hrl/hrl_ppo.py
@@ -28,7 +28,12 @@ def reduce_loss(loss):

self._set_grads_to_none()

-        (values, action_log_probs, dist_entropy, _,) = self._evaluate_actions(
+        (
+            values,
+            action_log_probs,
+            dist_entropy,
+            _,
+        ) = self._evaluate_actions(
batch["observations"],
batch["recurrent_hidden_states"],
batch["prev_actions"],
54 changes: 14 additions & 40 deletions habitat-baselines/habitat_baselines/rl/hrl/skills/oracle_nav.py
@@ -8,11 +8,7 @@
import torch

from habitat.core.spaces import ActionSpace
-from habitat.tasks.rearrange.actions.oracle_nav_action import (
-    get_possible_nav_to_actions,
-)
from habitat.tasks.rearrange.multi_task.pddl_domain import PddlProblem
-from habitat.tasks.rearrange.multi_task.rearrange_pddl import RIGID_OBJ_TYPE
from habitat.tasks.rearrange.rearrange_sensors import LocalizationSensor
from habitat_baselines.common.logging import baselines_logger
from habitat_baselines.rl.hrl.skills.nn_skill import NnSkillPolicy
@@ -23,9 +19,11 @@
class OracleNavPolicy(NnSkillPolicy):
@dataclass
class OracleNavActionArgs:
+        """
+        :property action_idx: The index of the oracle action we want to execute
+        """

action_idx: int
-        is_target_obj: bool
-        target_idx: int

def __init__(
self,
@@ -53,7 +51,7 @@ def __init__(
pddl_task_path,
task_config,
)
-        self._poss_actions = get_possible_nav_to_actions(self._pddl_problem)
+        self._all_entities = self._pddl_problem.get_ordered_entities_list()
self._oracle_nav_ac_idx, _ = find_action_range(
action_space, "oracle_nav_action"
)
@@ -130,47 +128,23 @@ def _is_skill_done(
return ret

def _parse_skill_arg(self, skill_arg):
marker = None
if len(skill_arg) == 2:
targ_obj, _ = skill_arg
search_target, _ = skill_arg
elif len(skill_arg) == 3:
marker, targ_obj, _ = skill_arg
_, search_target, _ = skill_arg
else:
raise ValueError(
f"Unexpected number of skill arguments in {skill_arg}"
)

targ_obj_idx = int(targ_obj.split("|")[-1])

targ_obj = self._pddl_problem.get_entity(targ_obj)
if marker is not None:
marker = self._pddl_problem.get_entity(marker)

match_i = None
for i, action in enumerate(self._poss_actions):
match_obj = action.get_arg_value("obj")
if marker is not None:
match_marker = action.get_arg_value("marker")
if match_marker != marker:
continue
if match_obj != targ_obj:
continue
match_i = i
break
if match_i is None:
raise ValueError(f"Cannot find matching action for {skill_arg}")
is_target_obj = targ_obj.expr_type.is_subtype_of(
self._pddl_problem.expr_types[RIGID_OBJ_TYPE]
)
return OracleNavPolicy.OracleNavActionArgs(
match_i, is_target_obj, targ_obj_idx
)

def _get_multi_sensor_index(self, batch_idx):
return [self._cur_skill_args[i].target_idx for i in batch_idx]
target = self._pddl_problem.get_entity(search_target)
if target is None:
raise ValueError(
f"Cannot find matching entity for {search_target}"
)
match_i = self._all_entities.index(target)

def requires_rnn_state(self):
return False
return OracleNavPolicy.OracleNavActionArgs(match_i)

def _internal_act(
self,
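The rewritten `_parse_skill_arg` reduces to: pull the search target out of a 2- or 3-tuple of arguments, then look up its index in the ordered entity list. A string-based sketch (the real code resolves `PddlEntity` objects via `self._pddl_problem.get_entity`):

```python
from typing import List, Sequence


def parse_skill_arg(skill_arg: Sequence[str], all_entities: List[str]) -> int:
    """Return the index of the navigation target in the ordered entity list."""
    if len(skill_arg) == 2:
        search_target, _ = skill_arg
    elif len(skill_arg) == 3:
        # 3-argument form carries a marker first (e.g. for articulated
        # receptacles); only the middle element names the target.
        _, search_target, _ = skill_arg
    else:
        raise ValueError(f"Unexpected number of skill arguments in {skill_arg}")

    if search_target not in all_entities:
        raise ValueError(f"Cannot find matching entity for {search_target}")
    return all_entities.index(search_target)


entities = ["fridge", "obj0", "robot_0"]
print(parse_skill_arg(("obj0", "robot_0"), entities))  # -> 1
print(parse_skill_arg(("marker", "fridge", "robot_0"), entities))  # -> 0
```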
1 change: 0 additions & 1 deletion habitat-baselines/habitat_baselines/rl/hrl/skills/place.py
@@ -46,7 +46,6 @@ def _is_skill_done(
if is_done.sum() > 0:
self._internal_log(
            f"Terminating with {rel_resting_pos} and {is_holding}",
-            observations,
)
return is_done

12 changes: 7 additions & 5 deletions habitat-baselines/habitat_baselines/rl/hrl/skills/reset.py
@@ -21,10 +21,12 @@ def __init__(
batch_size,
):
super().__init__(config, action_space, batch_size, True)
-        self._target = np.array([float(x) for x in config.reset_joint_state])
+        self._rest_state = np.array(
+            [float(x) for x in config.reset_joint_state]
+        )

self._arm_ac_range = find_action_range(action_space, "arm_action")
-        self._arm_ac_range = (self._arm_ac_range[0], self._target.shape[0])
+        self._arm_ac_range = (self._arm_ac_range[0], self._rest_state.shape[0])

def on_enter(
self,
@@ -43,7 +45,7 @@ def on_enter(
)

self._initial_delta = (
-            self._target - observations["joint"].cpu().numpy()
+            self._rest_state - observations["joint"].cpu().numpy()
)

return ret
@@ -58,7 +60,7 @@ def _is_skill_done(

return (
torch.as_tensor(
-                np.abs(current_joint_pos - self._target).max(-1),
+                np.abs(current_joint_pos - self._rest_state).max(-1),
dtype=torch.float32,
)
< 5e-2
@@ -74,7 +76,7 @@ def _internal_act(
deterministic=False,
):
current_joint_pos = observations["joint"].cpu().numpy()
-        delta = self._rest_state - current_joint_pos
+        delta = self._rest_state - current_joint_pos

# Dividing by max initial delta means that the action will
        # always be in [-1,1] and has the benefit of reducing the delta
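The comment above states the key invariant of `ResetArmSkill`: dividing the current delta by the largest component of the delta recorded at skill entry keeps every action component in [-1, 1] and naturally shrinks steps near the rest state. A list-based sketch of the numpy code:

```python
rest_state = [0.0, 0.5, -0.5]          # config.reset_joint_state
initial_joint_pos = [1.0, -0.5, 0.5]   # observed when the skill starts

# Cached once at skill entry, like self._initial_delta in the diff.
initial_delta = [r - j for r, j in zip(rest_state, initial_joint_pos)]
max_initial = max(abs(d) for d in initial_delta)


def reset_arm_action(current_joint_pos):
    delta = [r - j for r, j in zip(rest_state, current_joint_pos)]
    # |delta| stays within |initial_delta| as the arm converges, so
    # dividing by the max initial delta bounds the action in [-1, 1].
    return [d / max_initial for d in delta]


print(reset_arm_action(initial_joint_pos))  # -> [-1.0, 1.0, -1.0]
print(reset_arm_action(rest_state))         # -> [0.0, 0.0, 0.0]
```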