diff --git a/habitat-baselines/habitat_baselines/README.md b/habitat-baselines/habitat_baselines/README.md index bd891b424e..8c04a8ef2a 100644 --- a/habitat-baselines/habitat_baselines/README.md +++ b/habitat-baselines/habitat_baselines/README.md @@ -51,6 +51,30 @@ To use them download pre-trained pytorch models from [link](https://dl.fbaipubli The `habitat_baselines/config/pointnav/ppo_pointnav.yaml` config has better hyperparameters for large scale training and loads the [Gibson PointGoal Navigation Dataset](/README.md#datasets) instead of the test scenes. Change the `/benchmark/nav/pointnav: pointnav_gibson` in `habitat_baselines/config/pointnav/ppo_pointnav.yaml` to `/benchmark/nav/pointnav: pointnav_mp3d` in the defaults list for training on [MatterPort3D PointGoal Navigation Dataset](/README.md#datasets). +### Hierarchical Reinforcement Learning (HRL) + +We provide a two-layer hierarchical policy class, consisting of a set of low-level skills that move the robot, and a high-level policy that reasons about which low-level skill to use in the current state. This can be especially powerful in long-horizon mobile manipulation tasks, like those introduced in [Habitat 2.0](https://arxiv.org/abs/2106.14405). Both the low- and high-level policies can be either learned or oracles. For the oracle high-level policy we use [PDDL](https://planning.wiki/guide/whatis/pddl), and for the oracle low-level skills we use instantaneous transitions, with the environment set to the final desired state. Additionally, for navigation, we provide an oracle navigation skill that uses A* and the map of the environment to move the robot to its goal. + +To run the following examples, you need the [ReplicaCAD dataset](https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md#replicacad). + +To train a high-level policy while using pre-trained low-level skills (the SRL baseline from [Habitat 2.0](https://arxiv.org/abs/2106.14405)), you can run: + +```bash +python -u habitat-baselines/habitat_baselines/run.py \ + --config-name=rearrange/rl_hierarchical.yaml +``` +To run a rearrangement episode with oracle low-level skills and a fixed task planner, run: + +```bash +python -u habitat-baselines/habitat_baselines/run.py \ + --config-name=rearrange/rl_hierarchical.yaml \ + habitat_baselines.evaluate=True \ + habitat_baselines/rl/policy=hl_fixed \ + habitat_baselines/rl/policy/hierarchical_policy/defined_skills=oracle_skills +``` + +To change the task that you train your skills on (for example, set table), change the line `/habitat/task/rearrange: rearrange_easy` to `/habitat/task/rearrange: set_table` in the defaults list of your config.
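+The high-level policy can also be learned. To train the neural high-level policy over the same set of low-level skills, override the policy config group (shown below is one possible set of overrides; `hl_neural` is the learned high-level policy config provided alongside `hl_fixed`):
+
+```bash
+python -u habitat-baselines/habitat_baselines/run.py \
+    --config-name=rearrange/rl_hierarchical.yaml \
+    habitat_baselines/rl/policy=hl_neural
+```
+
+For the oracle navigation skill, see `habitat_baselines/config/rearrange/rl_hierarchical_oracle_nav.yaml`, which extends `rl_hierarchical.yaml` with the oracle navigation action and skill.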
+ ### Additional Utilities **Episode iterator options**: diff --git a/habitat-baselines/habitat_baselines/agents/ppo_agents.py b/habitat-baselines/habitat_baselines/agents/ppo_agents.py index 7b25704cf2..f87a85ef3e 100644 --- a/habitat-baselines/habitat_baselines/agents/ppo_agents.py +++ b/habitat-baselines/habitat_baselines/agents/ppo_agents.py @@ -126,23 +126,19 @@ def reset(self) -> None: def act(self, observations: Observations) -> Dict[str, int]: batch = batch_obs([observations], device=self.device) with torch.no_grad(): - ( - _, - actions, - _, - self.test_recurrent_hidden_states, - ) = self.actor_critic.act( + action_data = self.actor_critic.act( batch, self.test_recurrent_hidden_states, self.prev_actions, self.not_done_masks, deterministic=False, ) + self.test_recurrent_hidden_states = action_data.rnn_hidden_states # Make masks not done till reset (end of episode) will be called self.not_done_masks.fill_(True) - self.prev_actions.copy_(actions) # type: ignore + self.prev_actions.copy_(action_data.actions) # type: ignore - return {"action": actions[0][0].item()} + return {"action": action_data.env_actions[0][0].item()} def main(): diff --git a/habitat-baselines/habitat_baselines/common/baseline_registry.py b/habitat-baselines/habitat_baselines/common/baseline_registry.py index 5038cf91a5..26554f6054 100644 --- a/habitat-baselines/habitat_baselines/common/baseline_registry.py +++ b/habitat-baselines/habitat_baselines/common/baseline_registry.py @@ -136,5 +136,30 @@ def register_auxiliary_loss( def get_auxiliary_loss(cls, name: str): return cls._get_impl("aux_loss", name) + @classmethod + def register_storage(cls, to_register=None, *, name: Optional[str] = None): + """ + Registers data storage for storing data in the policy rollout in the + trainer and then for fetching data batches for the updater. + """ + + return cls._register_impl("storage", to_register, name) + + @classmethod + def get_storage(cls, name: str): + return cls._get_impl("storage", name) + + @classmethod + def register_updater(cls, to_register=None, *, name: Optional[str] = None): + """ + Registers a policy updater. + """ + + return cls._register_impl("updater", to_register, name) + + @classmethod + def get_updater(cls, name: str): + return cls._get_impl("updater", name) + baseline_registry = BaselineRegistry() diff --git a/habitat-baselines/habitat_baselines/common/rollout_storage.py b/habitat-baselines/habitat_baselines/common/rollout_storage.py index 0222ff44f1..6fe9a03fc3 100644 --- a/habitat-baselines/habitat_baselines/common/rollout_storage.py +++ b/habitat-baselines/habitat_baselines/common/rollout_storage.py @@ -5,18 +5,21 @@ # LICENSE file in the root directory of this source tree. 
import warnings -from typing import Any, Dict, Iterator, Optional, Tuple +from typing import Any, Dict, Iterator, Optional import numpy as np import torch +from habitat_baselines.common.baseline_registry import baseline_registry from habitat_baselines.common.tensor_dict import DictTree, TensorDict from habitat_baselines.rl.models.rnn_state_encoder import ( build_pack_info_from_dones, build_rnn_build_seq_info, ) +from habitat_baselines.utils.common import get_action_space_info +@baseline_registry.register_storage class RolloutStorage: r"""Class for storing rollout information for RL trainers.""" @@ -28,10 +31,10 @@ def __init__( self, action_space, recurrent_hidden_state_size, num_recurrent_layers=1, - action_shape: Optional[Tuple[int]] = None, is_double_buffered: bool = False, - discrete_actions: bool = True, ): + action_shape, discrete_actions = get_action_space_info(action_space) + self.buffers = TensorDict() self.buffers["observations"] = TensorDict() @@ -115,6 +118,7 @@ def insert( rewards=None, next_masks=None, buffer_index: int = 0, + **kwargs, ): if not self.is_double_buffered: assert buffer_index == 0 diff --git a/habitat-baselines/habitat_baselines/config/default_structured_configs.py b/habitat-baselines/habitat_baselines/config/default_structured_configs.py index 88fed951e0..f998435a7d 100644 --- a/habitat-baselines/habitat_baselines/config/default_structured_configs.py +++ b/habitat-baselines/habitat_baselines/config/default_structured_configs.py @@ -4,7 +4,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple import attr from hydra.core.config_store import ConfigStore @@ -215,9 +215,43 @@ class Eq2CubeConfig(ObsTransformConfig): @attr.s(auto_attribs=True, slots=True) -class HierarchicalPolicy(HabitatBaselinesBaseConfig): +class HrlDefinedSkillConfig(HabitatBaselinesBaseConfig): + """ + Defines a low-level skill to be used in the hierarchical policy. + """ + + skill_name: str = MISSING + name: str = "PointNavResNetPolicy" + action_distribution_type: str = "gaussian" + load_ckpt_file: str = "" + max_skill_steps: int = 200 + # If true, the stop action will be called if the skill times out. + force_end_on_timeout: bool = True + # Overrides the config file of a neural network skill rather than loading + # the config file from the checkpoint file. + force_config_file: str = "" + at_resting_threshold: float = 0.15 + # If true, this will apply the post-conditions of the skill after it + # terminates. + apply_postconds: bool = False + obs_skill_inputs: List[str] = list() + obs_skill_input_dim: int = 3 + start_zone_radius: float = 0.3 + # For the oracle navigation skill + action_name: str = "base_velocity" + stop_thresh: float = 0.001 + # For the reset_arm_skill + reset_joint_state: List[float] = MISSING + # The set of PDDL action names (as defined in the PDDL domain file) that + # map to this skill. If not specified, the name of the skill must match the + # PDDL action name.
+ pddl_action_names: Optional[List[str]] = None + + +@attr.s(auto_attribs=True, slots=True) +class HierarchicalPolicyConfig(HabitatBaselinesBaseConfig): high_level_policy: Dict[str, Any] = MISSING - defined_skills: Dict[str, Any] = dict() + defined_skills: Dict[str, HrlDefinedSkillConfig] = dict() use_skills: Dict[str, str] = dict() @@ -229,7 +263,7 @@ class PolicyConfig(HabitatBaselinesBaseConfig): # For gaussian action distribution: action_dist: ActionDistributionConfig = ActionDistributionConfig() obs_transforms: Dict[str, ObsTransformConfig] = dict() - hierarchical_policy: HierarchicalPolicy = MISSING + hierarchical_policy: HierarchicalPolicyConfig = MISSING @attr.s(auto_attribs=True, slots=True) @@ -335,6 +369,8 @@ class HabitatBaselinesConfig(HabitatBaselinesBaseConfig): # replaces --run-type eval when true evaluate: bool = False trainer_name: str = "ppo" + updater_name: str = "PPO" + distrib_updater_name: str = "DDPPO" torch_gpu_id: int = 0 tensorboard_dir: str = "tb" writer_type: str = "tb" @@ -345,6 +381,7 @@ class HabitatBaselinesConfig(HabitatBaselinesBaseConfig): eval_ckpt_path_dir: str = "data/checkpoints" num_environments: int = 16 num_processes: int = -1 # deprecated + rollout_storage_name: str = "RolloutStorage" checkpoint_folder: str = "data/checkpoints" num_updates: int = 10000 num_checkpoints: int = 10 diff --git a/habitat-baselines/habitat_baselines/config/habitat_baselines/rl/policy/hierarchical_policy/defined_skills/nn_skills.yaml b/habitat-baselines/habitat_baselines/config/habitat_baselines/rl/policy/hierarchical_policy/defined_skills/nn_skills.yaml new file mode 100644 index 0000000000..a0370b63cf --- /dev/null +++ b/habitat-baselines/habitat_baselines/config/habitat_baselines/rl/policy/hierarchical_policy/defined_skills/nn_skills.yaml @@ -0,0 +1,44 @@ +open_cab: + skill_name: "ArtObjSkillPolicy" + load_ckpt_file: "data/models/open_cab.pth" + +open_fridge: + skill_name: "ArtObjSkillPolicy" + load_ckpt_file: "data/models/open_fridge.pth" + +close_cab: + skill_name: "ArtObjSkillPolicy" + load_ckpt_file: "data/models/close_cab.pth" + +close_fridge: + skill_name: "ArtObjSkillPolicy" + load_ckpt_file: "data/models/close_fridge.pth" + +pick: + skill_name: "PickSkillPolicy" + obs_skill_inputs: ["obj_start_sensor"] + load_ckpt_file: "data/models/pick.pth" + +place: + skill_name: "PlaceSkillPolicy" + obs_skill_inputs: ["obj_goal_sensor"] + load_ckpt_file: "data/models/place.pth" + +wait: + skill_name: "WaitSkillPolicy" + max_skill_steps: -1 + force_end_on_timeout: False + +nav_to_obj: + skill_name: "NavSkillPolicy" + obs_skill_inputs: ["goal_to_agent_gps_compass"] + load_ckpt_file: "data/models/nav.pth" + max_skill_steps: 300 + obs_skill_input_dim: 2 + pddl_action_names: ["nav", "nav_to_receptacle"] + +reset_arm: + skill_name: "ResetArmSkill" + max_skill_steps: 50 + reset_joint_state: [-4.50e-01, -1.08e00, 9.95e-02, 9.38e-01, -7.88e-04, 1.57e00, 4.62e-03] + force_end_on_timeout: False diff --git a/habitat-baselines/habitat_baselines/config/habitat_baselines/rl/policy/hierarchical_policy/defined_skills/oracle_skills.yaml b/habitat-baselines/habitat_baselines/config/habitat_baselines/rl/policy/hierarchical_policy/defined_skills/oracle_skills.yaml new file mode 100644 index 0000000000..ab4bae46ef --- /dev/null +++ b/habitat-baselines/habitat_baselines/config/habitat_baselines/rl/policy/hierarchical_policy/defined_skills/oracle_skills.yaml @@ -0,0 +1,68 @@ +# Oracle skills that will teleport to the skill post-condition. 
When automatically setting predicates you may want to run the simulation in kinematic mode: +# To run in kinematic mode, add: `habitat.simulator.kinematic_mode=True habitat.simulator.ac_freq_ratio=1 habitat.task.measurements.force_terminate.max_accum_force=-1.0 habitat.task.measurements.force_terminate.max_instant_force=-1.0` + +defaults: + - /habitat/task/actions: + - pddl_apply_action + +open_cab: + skill_name: "NoopSkillPolicy" + max_skill_steps: 1 + apply_postconds: True + force_end_on_timeout: False + pddl_action_names: ["open_cab_by_name"] + +open_fridge: + skill_name: "NoopSkillPolicy" + max_skill_steps: 1 + apply_postconds: True + force_end_on_timeout: False + pddl_action_names: ["open_fridge_by_name"] + +close_cab: + skill_name: "NoopSkillPolicy" + obs_skill_inputs: ["obj_start_sensor"] + max_skill_steps: 1 + force_end_on_timeout: False + pddl_action_names: ["close_cab_by_name"] + +close_fridge: + skill_name: "NoopSkillPolicy" + obs_skill_inputs: ["obj_start_sensor"] + max_skill_steps: 1 + apply_postconds: True + force_end_on_timeout: False + pddl_action_names: ["close_fridge_by_name"] + +pick: + skill_name: "NoopSkillPolicy" + obs_skill_inputs: ["obj_start_sensor"] + max_skill_steps: 1 + apply_postconds: True + force_end_on_timeout: False + +place: + skill_name: "NoopSkillPolicy" + obs_skill_inputs: ["obj_goal_sensor"] + max_skill_steps: 1 + apply_postconds: True + force_end_on_timeout: False + +wait: + skill_name: "WaitSkillPolicy" + max_skill_steps: -1 + +nav_to_obj: + skill_name: "NoopSkillPolicy" + obs_skill_inputs: ["goal_to_agent_gps_compass"] + max_skill_steps: 1 + apply_postconds: True + force_end_on_timeout: False + obs_skill_input_dim: 2 + pddl_action_names: ["nav", "nav_to_receptacle_by_name"] + +reset_arm: + skill_name: "ResetArmSkill" + max_skill_steps: 50 + reset_joint_state: [-4.50e-01, -1.07e00, 9.95e-02, 9.38e-01, -7.88e-04, 1.57e00, 4.62e-03] + force_end_on_timeout: False diff --git a/habitat-baselines/habitat_baselines/config/habitat_baselines/rl/policy/hl_fixed.yaml b/habitat-baselines/habitat_baselines/config/habitat_baselines/rl/policy/hl_fixed.yaml new file mode 100644 index 0000000000..018fd32519 --- /dev/null +++ b/habitat-baselines/habitat_baselines/config/habitat_baselines/rl/policy/hl_fixed.yaml @@ -0,0 +1,10 @@ +name: "HierarchicalPolicy" +obs_transforms: + add_virtual_keys: + virtual_keys: + "goal_to_agent_gps_compass": 2 +hierarchical_policy: + high_level_policy: + name: "FixedHighLevelPolicy" + add_arm_rest: True + defined_skills: {} diff --git a/habitat-baselines/habitat_baselines/config/habitat_baselines/rl/policy/hl_neural.yaml b/habitat-baselines/habitat_baselines/config/habitat_baselines/rl/policy/hl_neural.yaml new file mode 100644 index 0000000000..1ede79538b --- /dev/null +++ b/habitat-baselines/habitat_baselines/config/habitat_baselines/rl/policy/hl_neural.yaml @@ -0,0 +1,27 @@ +name: "HierarchicalPolicy" +obs_transforms: + add_virtual_keys: + virtual_keys: + "goal_to_agent_gps_compass": 2 +hierarchical_policy: + high_level_policy: + name: "NeuralHighLevelPolicy" + allowed_actions: + - nav + - pick + - place + - nav_to_receptacle_by_name + - open_fridge_by_name + - close_fridge_by_name + - open_cab_by_name + - close_cab_by_name + allow_other_place: False + hidden_dim: 512 + use_rnn: True + rnn_type: 'LSTM' + backbone: resnet18 + normalize_visual_inputs: False + num_rnn_layers: 2 + policy_input_keys: + - "robot_head_depth" + defined_skills: {} diff --git 
a/habitat-baselines/habitat_baselines/config/habitat_baselines/rl/policy/monolithic.yaml b/habitat-baselines/habitat_baselines/config/habitat_baselines/rl/policy/monolithic.yaml new file mode 100644 index 0000000000..83f65a6e5f --- /dev/null +++ b/habitat-baselines/habitat_baselines/config/habitat_baselines/rl/policy/monolithic.yaml @@ -0,0 +1,4 @@ +name: "PointNavResNetPolicy" +action_distribution_type: "gaussian" +action_dist: + use_log_std: True diff --git a/habitat-baselines/habitat_baselines/config/rearrange/rl_hierarchical.yaml b/habitat-baselines/habitat_baselines/config/rearrange/rl_hierarchical.yaml new file mode 100644 index 0000000000..f5971142d4 --- /dev/null +++ b/habitat-baselines/habitat_baselines/config/rearrange/rl_hierarchical.yaml @@ -0,0 +1,73 @@ +# @package _global_ + +# Config for running hierarchical policies where a high-level (HL) policy selects from a set of low-level (LL) policies. +# Supports different HL policy configurations and using a variety of LL policies. + +defaults: + - /benchmark/rearrange: rearrange_easy + - /habitat_baselines: habitat_baselines_rl_config_base + - /habitat/simulator/sim_sensors@habitat_baselines.eval.extra_sim_sensors.third_rgb_sensor: third_rgb_sensor + - /habitat_baselines/rl/policy/obs_transforms: + - add_virtual_keys_base + - /habitat_baselines/rl/policy: hl_fixed + - /habitat_baselines/rl/policy/hierarchical_policy/defined_skills: nn_skills + - _self_ + +habitat_baselines: + verbose: False + trainer_name: "ddppo" + updater_name: "HRLPPO" + distrib_updater_name: "HRLDDPPO" + torch_gpu_id: 0 + video_fps: 30 + eval_ckpt_path_dir: "" + num_environments: 4 + writer_type: 'tb' + num_updates: -1 + total_num_steps: 5.0e7 + log_interval: 10 + num_checkpoints: 10 + force_torch_single_threaded: True + eval_keys_to_include_in_name: ['reward', 'force', 'composite_success'] + load_resume_state_config: False + rollout_storage_name: "HrlRolloutStorage" + + eval: + use_ckpt_config: False + should_load_ckpt: False + video_option: ["disk"] + + rl: + ppo: + # ppo params + clip_param: 0.2 + ppo_epoch: 2 + num_mini_batch: 2 + value_loss_coef: 0.5 + entropy_coef: 0.0001 + lr: 2.5e-4 + eps: 1e-5 + max_grad_norm: 0.2 + num_steps: 128 + use_gae: True + gamma: 0.99 + tau: 0.95 + + ddppo: + sync_frac: 0.6 + # The PyTorch distributed backend to use + distrib_backend: NCCL + # Visual encoder backbone + pretrained_weights: data/ddppo-models/gibson-2plus-resnet50.pth + # Initialize with pretrained weights + pretrained: False + # Initialize just the visual encoder backbone with pretrained weights + pretrained_encoder: False + # Whether the visual encoder backbone will be trained. + train_encoder: True + # Whether to reset the critic linear layer + reset_critic: False + # Model parameters + backbone: resnet18 + rnn_type: LSTM + num_recurrent_layers: 2 diff --git a/habitat-baselines/habitat_baselines/config/rearrange/rl_hierarchical_oracle_nav.yaml b/habitat-baselines/habitat_baselines/config/rearrange/rl_hierarchical_oracle_nav.yaml new file mode 100644 index 0000000000..0f6ddbdbfb --- /dev/null +++ b/habitat-baselines/habitat_baselines/config/rearrange/rl_hierarchical_oracle_nav.yaml @@ -0,0 +1,41 @@ +# @package _global_ + +# Extends the `rl_hierarchical` config to use an oracle navigation action. +# Several things change to support the oracle navigation: +# - Adding the oracle navigation action +# - Add sliding. +# - Use the oracle navigation skill. 
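+# For example (illustrative), evaluate this config with:
+#   python -u habitat-baselines/habitat_baselines/run.py \
+#     --config-name=rearrange/rl_hierarchical_oracle_nav.yaml \
+#     habitat_baselines.evaluate=True
+# The non-navigation skills still come from the default `nn_skills` set, so their
+# checkpoints under data/models/ must be available.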
+ +defaults: + - rl_hierarchical + - /habitat/task/actions: + - oracle_nav_action + - _self_ + +habitat: + gym: + obs_keys: + - robot_head_depth + - relative_resting_position + - obj_start_sensor + - obj_goal_sensor + - obj_start_gps_compass + - obj_goal_gps_compass + - joint + - is_holding + - ee_pos + - localization_sensor + simulator: + habitat_sim_v0: + allow_sliding: True + +habitat_baselines: + rl: + policy: + hierarchical_policy: + # Override to use the oracle navigation skill. + defined_skills: + nav_to_obj: + skill_name: "OracleNavPolicy" + obs_skill_inputs: ["obj_start_sensor", "abs_obj_start_sensor", "obj_goal_sensor", "abs_obj_goal_sensor"] + max_skill_steps: 300 diff --git a/habitat-baselines/habitat_baselines/config/rearrange/rl_rearrange.yaml b/habitat-baselines/habitat_baselines/config/rearrange/rl_rearrange.yaml index b39dcdeb51..30dfdcd50e 100644 --- a/habitat-baselines/habitat_baselines/config/rearrange/rl_rearrange.yaml +++ b/habitat-baselines/habitat_baselines/config/rearrange/rl_rearrange.yaml @@ -3,6 +3,7 @@ defaults: - /benchmark/rearrange: rearrange - /habitat_baselines: habitat_baselines_rl_config_base + - /habitat_baselines/rl/policy: monolithic - /habitat/simulator/sim_sensors@habitat_baselines.eval.extra_sim_sensors.third_rgb_sensor: third_rgb_sensor - _self_ @@ -31,11 +32,6 @@ habitat_baselines: video_option: ["disk"] rl: - policy: - name: "PointNavResNetPolicy" - action_distribution_type: "gaussian" - action_dist: - use_log_std: True ppo: # ppo params clip_param: 0.2 diff --git a/habitat-baselines/habitat_baselines/config/rearrange/rl_rearrange_easy.yaml b/habitat-baselines/habitat_baselines/config/rearrange/rl_rearrange_easy.yaml index ef37aa71eb..dcdd6bd99a 100644 --- a/habitat-baselines/habitat_baselines/config/rearrange/rl_rearrange_easy.yaml +++ b/habitat-baselines/habitat_baselines/config/rearrange/rl_rearrange_easy.yaml @@ -3,6 +3,7 @@ defaults: - /benchmark/rearrange: rearrange_easy - /habitat_baselines: habitat_baselines_rl_config_base + - /habitat_baselines/rl/policy: monolithic - /habitat/simulator/sim_sensors@habitat_baselines.eval.extra_sim_sensors.third_rgb_sensor: third_rgb_sensor - _self_ @@ -31,11 +32,6 @@ habitat_baselines: video_option: ["disk"] rl: - policy: - name: "PointNavResNetPolicy" - action_distribution_type: "gaussian" - action_dist: - use_log_std: True ppo: # ppo params clip_param: 0.2 diff --git a/habitat-baselines/habitat_baselines/config/rearrange/rl_skill.yaml b/habitat-baselines/habitat_baselines/config/rearrange/rl_skill.yaml index 7a49be1fc4..08d7a89c73 100644 --- a/habitat-baselines/habitat_baselines/config/rearrange/rl_skill.yaml +++ b/habitat-baselines/habitat_baselines/config/rearrange/rl_skill.yaml @@ -3,6 +3,7 @@ defaults: - /benchmark/rearrange: pick - /habitat_baselines: habitat_baselines_rl_config_base + - /habitat_baselines/rl/policy: monolithic - /habitat/simulator/sim_sensors@habitat_baselines.eval.extra_sim_sensors.third_rgb_sensor: third_rgb_sensor - _self_ @@ -16,9 +17,9 @@ habitat_baselines: test_episode_count: -1 eval_ckpt_path_dir: "data/new_checkpoints" # 26 environments will just barely be below 16gb. - # 20 environments will just barely be below 11gb. - num_environments: 20 - checkpoint_folder: "data/new_checkpoints" + # num_environments: 26 + # 18 environments will just barely be below 11gb. 
+ num_environments: 18 num_updates: -1 total_num_steps: 1.0e8 log_interval: 10 @@ -33,10 +34,7 @@ habitat_baselines: rl: policy: - name: "PointNavResNetPolicy" - action_distribution_type: "gaussian" action_dist: - use_log_std: True clamp_std: True std_init: -1.0 use_std_param: True diff --git a/habitat-baselines/habitat_baselines/config/rearrange/spap_pick.yaml b/habitat-baselines/habitat_baselines/config/rearrange/spap_pick.yaml deleted file mode 100644 index 5ba89283e5..0000000000 --- a/habitat-baselines/habitat_baselines/config/rearrange/spap_pick.yaml +++ /dev/null @@ -1,39 +0,0 @@ -# @package _global_ - -defaults: - - /benchmark/rearrange: pick - - /habitat_baselines: habitat_baselines_spa_config_base - - _self_ - -habitat_baselines: - video_dir: "data/vids/" - - sense_plan_act: - verbose: True - run_freq: 2 - n_grasps: 100 - mp_obj: True - mp_margin: null - mp_render: True - timeout: 3 - exec_ee_thresh: 0.1 - # "Priv" or "Reg" - mp_sim_type: "Priv" - video_dir: 'data/vids' - debug_dir: "data/mp_test" - count_obj_collisions: True - grasp_gen_is_verbose: True - ik_dist_thresh: 0.1 - - eval: - video_option: ["disk"] - -habitat: - gym: - obs_keys: ['joint', 'ee_pos'] - desired_goal_keys: ['obj_goal_pos_sensor'] - achieved_goal_keys: [] - action_keys: ['arm_action'] - task: - success_reward: 2000.0 - end_on_success: False diff --git a/habitat-baselines/habitat_baselines/config/rearrange/spap_reach_state.yaml b/habitat-baselines/habitat_baselines/config/rearrange/spap_reach_state.yaml deleted file mode 100644 index d62f630093..0000000000 --- a/habitat-baselines/habitat_baselines/config/rearrange/spap_reach_state.yaml +++ /dev/null @@ -1,42 +0,0 @@ -# @package _global_ - -defaults: - - /benchmark/rearrange: reach_state - - /habitat_baselines: habitat_baselines_spa_config_base - - _self_ - -habitat_baselines: - video_dir: "data/vids/" - - sense_plan_act: - verbose: True - run_freq: 4 - n_grasps: 100 - mp_obj: True - mp_margin: null - mp_render: True - timeout: 3 - exec_ee_thresh: 0.1 - # "Priv" or "Reg" - mp_sim_type: "Priv" - video_dir: 'data/vids' - debug_dir: "data/mp_test" - count_obj_collisions: True - grasp_gen_is_verbose: True - ik_dist_thresh: 0.1 - ik_speed_factor: 1.0 - - eval: - video_option: ["disk"] - -habitat: - gym: - obs_keys: [ 'joint', 'ee_pos' ] - desired_goal_keys: [ 'resting_position' ] - action_keys: [ 'arm_action' ] - task: - reward_measure: "rearrange_reach_reward" - success_measure: "rearrange_reach_success" - success_reward: 10.0 - slack_reward: -0.01 - end_on_success: False diff --git a/habitat-baselines/habitat_baselines/config/rearrange/tp_srl.yaml b/habitat-baselines/habitat_baselines/config/rearrange/tp_srl.yaml deleted file mode 100644 index 467492e1d8..0000000000 --- a/habitat-baselines/habitat_baselines/config/rearrange/tp_srl.yaml +++ /dev/null @@ -1,193 +0,0 @@ -# @package _global_ - -defaults: - - /benchmark/rearrange: rearrange_easy - - /habitat_baselines: habitat_baselines_rl_config_base - - /habitat_baselines/rl/policy/obs_transforms: - - add_virtual_keys_base - - /habitat/simulator/sim_sensors@habitat_baselines.eval.extra_sim_sensors.third_rgb_sensor: third_rgb_sensor - - _self_ - - -habitat_baselines: - verbose: False - trainer_name: "ddppo" - torch_gpu_id: 0 - tensorboard_dir: "tb" - video_dir: "video_dir" - video_fps: 30 - test_episode_count: -1 - eval_ckpt_path_dir: "" - num_environments: 1 - writer_type: 'tb' - checkpoint_folder: "data/new_checkpoints" - num_updates: -1 - total_num_steps: 1.0e8 - log_interval: 10 - num_checkpoints: 20 - 
force_torch_single_threaded: True - eval_keys_to_include_in_name: ['reward', 'force', 'composite_success'] - load_resume_state_config: False - eval: - use_ckpt_config: False - should_load_ckpt: False - video_option: ["disk"] - - rl: - policy: - name: "HierarchicalPolicy" - obs_transforms: - add_virtual_keys: - virtual_keys: - "goal_to_agent_gps_compass": 2 - hierarchical_policy: - high_level_policy: - name: "FixedHighLevelPolicy" - add_arm_rest: True - defined_skills: - nn_open_cab: - skill_name: "ArtObjSkillPolicy" - name: "PointNavResNetPolicy" - action_distribution_type: "gaussian" - at_resting_threshold: 0.15 - obs_skill_inputs: [] - load_ckpt_file: "data/models/open_cab.pth" - max_skill_steps: 200 - start_zone_radius: 0.3 - force_end_on_timeout: True - - nn_open_fridge: - skill_name: "ArtObjSkillPolicy" - name: "PointNavResNetPolicy" - action_distribution_type: "gaussian" - at_resting_threshold: 0.15 - obs_skill_inputs: [] - load_ckpt_file: "data/models/open_fridge.pth" - max_skill_steps: 200 - start_zone_radius: 0.3 - force_end_on_timeout: True - - nn_close_cab: - skill_name: "ArtObjSkillPolicy" - name: "PointNavResNetPolicy" - action_distribution_type: "gaussian" - at_resting_threshold: 0.2 - obs_skill_inputs: ["obj_start_sensor"] - load_ckpt_file: "data/models/close_cab.pth" - max_skill_steps: 200 - start_zone_radius: 0.3 - force_end_on_timeout: True - - nn_close_fridge: - skill_name: "ArtObjSkillPolicy" - name: "PointNavResNetPolicy" - action_distribution_type: "gaussian" - at_resting_threshold: 0.2 - obs_skill_inputs: ["obj_start_sensor"] - load_ckpt_file: "data/models/close_fridge.pth" - max_skill_steps: 200 - start_zone_radius: 0.3 - force_end_on_timeout: True - - nn_pick: - skill_name: "PickSkillPolicy" - name: "PointNavResNetPolicy" - action_distribution_type: "gaussian" - at_resting_threshold: 0.15 - obs_skill_inputs: ["obj_start_sensor"] - load_ckpt_file: "data/models/pick.pth" - max_skill_steps: 200 - force_end_on_timeout: True - - nn_place: - skill_name: "PlaceSkillPolicy" - name: "PointNavResNetPolicy" - action_distribution_type: "gaussian" - at_resting_threshold: 0.15 - obs_skill_inputs: ["obj_goal_sensor"] - load_ckpt_file: "data/models/place.pth" - max_skill_steps: 200 - force_end_on_timeout: True - - nn_nav: - skill_name: "NavSkillPolicy" - name: "PointNavResNetPolicy" - action_distribution_type: "gaussian" - obs_skill_inputs: ["goal_to_agent_gps_compass"] - obs_skill_input_dim: 2 - lin_speed_stop: 0.067 - ang_speed_stop: 0.067 - load_ckpt_file: "data/models/nav.pth" - max_skill_steps: 300 - force_end_on_timeout: False - - wait_skill: - skill_name: "WaitSkillPolicy" - max_skill_steps: -1.0 - force_end_on_timeout: False - - reset_arm_skill: - skill_name: "ResetArmSkill" - max_skill_steps: 50 - reset_joint_state: [-4.5003259e-01, -1.0799699e00, 9.9526465e-02, 9.3869519e-01, -7.8854430e-04, 1.5702540e00, 4.6168058e-03] - force_end_on_timeout: False - - use_skills: - # Uncomment if you are also using these skills - # open_cab: "NN_OPEN_CAB" - # open_fridge: "NN_OPEN_FRIDGE" - # close_cab: "NN_OPEN_CAB" - # close_fridge: "NN_OPEN_FRIDGE" - pick: "nn_pick" - place: "nn_place" - nav: "nn_nav" - nav_to_receptacle: "nn_nav" - wait: "wait_skill" - reset_arm: "reset_arm_skill" - - ppo: - # ppo params - clip_param: 0.2 - ppo_epoch: 2 - num_mini_batch: 2 - value_loss_coef: 0.5 - entropy_coef: 0.0001 - lr: 2.5e-4 - eps: 1e-5 - max_grad_norm: 0.2 - num_steps: 128 - use_gae: True - gamma: 0.99 - tau: 0.95 - use_linear_clip_decay: False - use_linear_lr_decay: False - 
reward_window_size: 50 - - use_normalized_advantage: False - - hidden_size: 512 - - # Use double buffered sampling, typically helps - # when environment time is similar or larger than - # policy inference time during rollout generation - use_double_buffered_sampler: False - - ddppo: - sync_frac: 0.6 - # The PyTorch distributed backend to use - distrib_backend: NCCL - # Visual encoder backbone - pretrained_weights: data/ddppo-models/gibson-2plus-resnet50.pth - # Initialize with pretrained weights - pretrained: False - # Initialize just the visual encoder backbone with pretrained weights - pretrained_encoder: False - # Whether the visual encoder backbone will be trained. - train_encoder: True - # Whether to reset the critic linear layer - reset_critic: False - - # Model parameters - backbone: resnet18 - rnn_type: LSTM - num_recurrent_layers: 2 diff --git a/habitat-baselines/habitat_baselines/config/rearrange/tp_srl_oracle_nav.yaml b/habitat-baselines/habitat_baselines/config/rearrange/tp_srl_oracle_nav.yaml deleted file mode 100644 index 8d8c6ca12f..0000000000 --- a/habitat-baselines/habitat_baselines/config/rearrange/tp_srl_oracle_nav.yaml +++ /dev/null @@ -1,194 +0,0 @@ -# @package _global_ - -defaults: - - /benchmark/rearrange: rearrange_easy - - /habitat_baselines: habitat_baselines_rl_config_base - - /habitat_baselines/rl/policy/obs_transforms: - - add_virtual_keys_base - - /habitat/task/actions: - - arm_action - - base_velocity - - rearrange_stop - - oracle_nav_action - - /habitat/simulator/sim_sensors@habitat_baselines.eval.extra_sim_sensors.third_rgb_sensor: third_rgb_sensor - - _self_ - -habitat_baselines: - verbose: False - trainer_name: "ddppo" - torch_gpu_id: 0 - tensorboard_dir: "tb" - video_dir: "video_dir" - video_fps: 30 - test_episode_count: -1 - eval_ckpt_path_dir: "" - num_environments: 1 - writer_type: 'tb' - checkpoint_folder: "data/new_checkpoints" - num_updates: -1 - total_num_steps: 1.0e8 - log_interval: 10 - num_checkpoints: 20 - force_torch_single_threaded: True - eval_keys_to_include_in_name: ['reward', 'force', 'composite_success'] - eval: - video_option: ["disk"] - use_ckpt_config: False - should_load_ckpt: False - - rl: - policy: - name: "HierarchicalPolicy" - obs_transforms: - add_virtual_keys: - virtual_keys: - "goal_to_agent_gps_compass": 2 - hierarchical_policy: - high_level_policy: - name: "FixedHighLevelPolicy" - add_arm_rest: True - defined_skills: - NN_OPEN_CAB: - skill_name: "ArtObjSkillPolicy" - name: "PointNavResNetPolicy" - action_distribution_type: "gaussian" - at_resting_threshold: 0.15 - obs_skill_inputs: [] - load_ckpt_file: "data/models/open_cab.pth" - max_skill_steps: 200 - start_zone_radius: 0.3 - force_end_on_timeout: True - - NN_OPEN_FRIDGE: - skill_name: "ArtObjSkillPolicy" - name: "PointNavResNetPolicy" - action_distribution_type: "gaussian" - at_resting_threshold: 0.15 - obs_skill_inputs: [] - load_ckpt_file: "data/models/open_fridge.pth" - max_skill_steps: 200 - start_zone_radius: 0.3 - force_end_on_timeout: True - - NN_CLOSE_CAB: - skill_name: "ArtObjSkillPolicy" - name: "PointNavResNetPolicy" - action_distribution_type: "gaussian" - at_resting_threshold: 0.2 - obs_skill_inputs: ["obj_start_sensor"] - load_ckpt_file: "data/models/close_cab.pth" - max_skill_steps: 200 - start_zone_radius: 0.3 - force_end_on_timeout: True - - NN_CLOSE_FRIDGE: - skill_name: "ArtObjSkillPolicy" - name: "PointNavResNetPolicy" - action_distribution_type: "gaussian" - at_resting_threshold: 0.2 - obs_skill_inputs: ["obj_start_sensor"] - load_ckpt_file: 
"data/models/close_fridge.pth" - max_skill_steps: 200 - start_zone_radius: 0.3 - force_end_on_timeout: True - - nn_pick: - skill_name: "PickSkillPolicy" - name: "PointNavResNetPolicy" - action_distribution_type: "gaussian" - at_resting_threshold: 0.15 - obs_skill_inputs: ["obj_start_sensor"] - load_ckpt_file: "data/models/pick.pth" - max_skill_steps: 200 - force_end_on_timeout: True - - GT_NAV: - skill_name: "OracleNavPolicy" - obs_skill_inputs: ["obj_start_sensor", "abs_obj_start_sensor", "obj_goal_sensor", "abs_obj_goal_sensor"] - goal_sensors: ["obj_goal_sensor", "abs_obj_goal_sensor"] - NAV_ACTION_NAME: "base_velocity" - max_skill_steps: 300 - force_end_on_timeout: True - stop_angle_thresh: 0.2 - stop_dist_thresh: 1.0 - - nn_place: - skill_name: "PlaceSkillPolicy" - name: "PointNavResNetPolicy" - action_distribution_type: "gaussian" - at_resting_threshold: 0.15 - obs_skill_inputs: ["obj_goal_sensor"] - load_ckpt_file: "data/models/place.pth" - max_skill_steps: 200 - force_end_on_timeout: True - - wait_skill: - skill_name: "WaitSkillPolicy" - max_skill_steps: -1.0 - force_end_on_timeout: False - - reset_arm_skill: - skill_name: "ResetArmSkill" - max_skill_steps: 50 - reset_joint_state: [-4.5003259e-01, -1.0799699e00, 9.9526465e-02, 9.3869519e-01, -7.8854430e-04, 1.5702540e00, 4.6168058e-03] - force_end_on_timeout: False - - use_skills: - # Uncomment if you are also using these skills - # open_cab: "NN_OPEN_CAB" - # open_fridge: "NN_OPEN_FRIDGE" - # close_cab: "NN_OPEN_CAB" - # close_fridge: "NN_OPEN_FRIDGE" - pick: "nn_pick" - place: "nn_place" - nav: "GT_NAV" - nav_to_receptacle: "GT_NAV" - wait: "wait_skill" - reset_arm: "reset_arm_skill" - - ppo: - # ppo params - clip_param: 0.2 - ppo_epoch: 2 - num_mini_batch: 2 - value_loss_coef: 0.5 - entropy_coef: 0.0001 - lr: 2.5e-4 - eps: 1e-5 - max_grad_norm: 0.2 - num_steps: 128 - use_gae: True - gamma: 0.99 - tau: 0.95 - use_linear_clip_decay: False - use_linear_lr_decay: False - reward_window_size: 50 - - use_normalized_advantage: False - - hidden_size: 512 - - # Use double buffered sampling, typically helps - # when environment time is similar or larger than - # policy inference time during rollout generation - use_double_buffered_sampler: False - - ddppo: - sync_frac: 0.6 - # The PyTorch distributed backend to use - distrib_backend: NCCL - # Visual encoder backbone - pretrained_weights: data/ddppo-models/gibson-2plus-resnet50.pth - # Initialize with pretrained weights - pretrained: False - # Initialize just the visual encoder backbone with pretrained weights - pretrained_encoder: False - # Whether the visual encoder backbone will be trained. 
- train_encoder: True - # Whether to reset the critic linear layer - reset_critic: False - - # Model parameters - backbone: resnet18 - rnn_type: LSTM - num_recurrent_layers: 2 diff --git a/habitat-baselines/habitat_baselines/config/tp_srl_test/tp_srl_test.yaml b/habitat-baselines/habitat_baselines/config/tp_srl_test/tp_srl_test.yaml deleted file mode 100644 index fd413e2749..0000000000 --- a/habitat-baselines/habitat_baselines/config/tp_srl_test/tp_srl_test.yaml +++ /dev/null @@ -1,99 +0,0 @@ -# @package _global_ - -defaults: - - /benchmark/rearrange: rearrange_easy - - /habitat_baselines: habitat_baselines_rl_config_base - - /habitat_baselines/rl/policy/obs_transforms: - - add_virtual_keys_base - -habitat_baselines: - verbose: False - trainer_name: "ppo" - torch_gpu_id: 0 - tensorboard_dir: "" - video_dir: "data/test_checkpoints/ppo/pointnav/video" - test_episode_count: 2 - eval_ckpt_path_dir: "" - num_environments: 1 - checkpoint_folder: "data/test_checkpoints/ppo/pointnav/" - num_updates: 2 - log_interval: 100 - num_checkpoints: 2 - force_torch_single_threaded: True - load_resume_state_config: False - eval: - use_ckpt_config: False - should_load_ckpt: False - - rl: - policy: - name: "HierarchicalPolicy" - - obs_transforms: - add_virtual_keys: - virtual_keys: - "goal_to_agent_gps_compass": 2 - hierarchical_policy: - high_level_policy: - name: "FixedHighLevelPolicy" - add_arm_rest: True - defined_skills: - nn_pick: - skill_name: "PickSkillPolicy" - name: "PointNavResNetPolicy" - action_distribution_type: "gaussian" - at_resting_threshold: 0.15 - obs_skill_inputs: ["obj_start_sensor"] - load_ckpt_file: "" - max_skill_steps: 200 - force_end_on_timeout: True - force_config_file: "benchmark/rearrange=pick" - - nn_place: - skill_name: "PlaceSkillPolicy" - name: "PointNavResNetPolicy" - action_distribution_type: "gaussian" - at_resting_threshold: 0.15 - obs_skill_inputs: ["obj_goal_sensor"] - load_ckpt_file: "" - max_skill_steps: 200 - force_end_on_timeout: True - force_config_file: "benchmark/rearrange=place" - - nn_nav: - skill_name: "NavSkillPolicy" - name: "PointNavResNetPolicy" - action_distribution_type: "gaussian" - obs_skill_inputs: ["goal_to_agent_gps_compass"] - obs_skill_input_dim: 2 - lin_speed_stop: 0.067 - ang_speed_stop: 0.067 - load_ckpt_file: "" - max_skill_steps: 300 - force_end_on_timeout: False - force_config_file: "benchmark/rearrange=nav_to_obj" - - wait_skill: - skill_name: "WaitSkillPolicy" - max_skill_steps: -1.0 - force_end_on_timeout: False - - reset_arm_skill: - skill_name: "ResetArmSkill" - max_skill_steps: 50 - reset_joint_state: [-4.5003259e-01, -1.0799699e00, 9.9526465e-02, 9.3869519e-01, -7.8854430e-04, 1.5702540e00, 4.6168058e-03] - force_end_on_timeout: False - - use_skills: - # Uncomment if you are also using these skills - pick: "nn_pick" - place: "nn_place" - nav: "nn_nav" - nav_to_receptacle: "nn_nav" - wait: "wait_skill" - reset_arm: "reset_arm_skill" - ddppo: - pretrained: False - pretrained_encoder: False - train_encoder: True - reset_critic: False diff --git a/habitat-baselines/habitat_baselines/rl/ddppo/algo/ddppo.py b/habitat-baselines/habitat_baselines/rl/ddppo/algo/ddppo.py index 0bf1a9aff9..25f19ea7d8 100644 --- a/habitat-baselines/habitat_baselines/rl/ddppo/algo/ddppo.py +++ b/habitat-baselines/habitat_baselines/rl/ddppo/algo/ddppo.py @@ -10,6 +10,7 @@ import torch from torch import distributed as distrib +from habitat_baselines.common.baseline_registry import baseline_registry from habitat_baselines.rl.ppo import PPO @@ -131,5 +132,6 @@ def 
_evaluate_actions(self, *args, **kwargs): ) +@baseline_registry.register_updater class DDPPO(DecentralizedDistributedMixin, PPO): pass diff --git a/habitat-baselines/habitat_baselines/rl/hrl/__init__.py b/habitat-baselines/habitat_baselines/rl/hrl/__init__.py index 0f0db8cad5..39a3f0c90d 100644 --- a/habitat-baselines/habitat_baselines/rl/hrl/__init__.py +++ b/habitat-baselines/habitat_baselines/rl/hrl/__init__.py @@ -3,3 +3,11 @@ # Copyright (c) Meta Platforms, Inc. and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + +from habitat_baselines.rl.hrl.hrl_ppo import HRLPPO +from habitat_baselines.rl.hrl.hrl_rollout_storage import HrlRolloutStorage + +__all__ = [ + "HRLPPO", + "HrlRolloutStorage", +] diff --git a/habitat-baselines/habitat_baselines/rl/hrl/hierarchical_policy.py b/habitat-baselines/habitat_baselines/rl/hrl/hierarchical_policy.py index b1437a011a..bcad61b862 100644 --- a/habitat-baselines/habitat_baselines/rl/hrl/hierarchical_policy.py +++ b/habitat-baselines/habitat_baselines/rl/hrl/hierarchical_policy.py @@ -8,20 +8,24 @@ import gym.spaces as spaces import torch +import torch.nn as nn from habitat.core.spaces import ActionSpace from habitat.tasks.rearrange.multi_task.composite_sensors import ( CompositeSuccess, ) +from habitat.tasks.rearrange.multi_task.pddl_domain import PddlProblem from habitat_baselines.common.baseline_registry import baseline_registry from habitat_baselines.common.logging import baselines_logger from habitat_baselines.rl.hrl.hl import ( # noqa: F401. FixedHighLevelPolicy, HighLevelPolicy, + NeuralHighLevelPolicy, ) from habitat_baselines.rl.hrl.skills import ( # noqa: F401. ArtObjSkillPolicy, NavSkillPolicy, + NoopSkillPolicy, OracleNavPolicy, PickSkillPolicy, PlaceSkillPolicy, @@ -30,12 +34,21 @@ WaitSkillPolicy, ) from habitat_baselines.rl.hrl.utils import find_action_range -from habitat_baselines.rl.ppo.policy import Policy +from habitat_baselines.rl.ppo.policy import Policy, PolicyActionData from habitat_baselines.utils.common import get_num_actions @baseline_registry.register_policy -class HierarchicalPolicy(Policy): +class HierarchicalPolicy(nn.Module, Policy): + """ + :property _pddl_problem: Stores the PDDL domain information. This allows + accessing all the possible entities, actions, and predicates. Note that + this is not the grounded PDDL problem with truth values assigned to the + predicates based on the current simulator state.
+ """ + + _pddl_problem: PddlProblem + def __init__( self, config, @@ -54,16 +67,23 @@ def __init__( self._name_to_idx: Dict[str, int] = {} self._idx_to_name: Dict[int, str] = {} - for i, (skill_id, use_skill_name) in enumerate( - config.hierarchical_policy.use_skills.items() - ): - if use_skill_name == "": - # Skip loading this skill if no name is provided - continue - skill_config = config.hierarchical_policy.defined_skills[ - use_skill_name - ] + task_spec_file = osp.join( + full_config.habitat.task.task_spec_base_path, + full_config.habitat.task.task_spec + ".yaml", + ) + domain_file = full_config.habitat.task.pddl_domain_def + self._pddl_problem = PddlProblem( + domain_file, + task_spec_file, + config, + ) + + skill_i = 0 + for ( + skill_name, + skill_config, + ) in config.hierarchical_policy.defined_skills.items(): cls = eval(skill_config.skill_name) skill_policy = cls.from_config( skill_config, @@ -72,13 +92,17 @@ def __init__( self._num_envs, full_config, ) - self._skills[i] = skill_policy - self._name_to_idx[skill_id] = i - self._idx_to_name[i] = skill_id + skill_policy.set_pddl_problem(self._pddl_problem) + if skill_config.pddl_action_names is None: + action_names = [skill_name] + else: + action_names = skill_config.pddl_action_names + for skill_id in action_names: + self._name_to_idx[skill_id] = skill_i + self._idx_to_name[skill_i] = skill_id + self._skills[skill_i] = skill_policy + skill_i += 1 - self._call_high_level: torch.Tensor = torch.ones( - self._num_envs, dtype=torch.bool - ) self._cur_skills: torch.Tensor = torch.full( (self._num_envs,), -1, dtype=torch.long ) @@ -88,12 +112,11 @@ def __init__( ) self._high_level_policy: HighLevelPolicy = high_level_cls( config.hierarchical_policy.high_level_policy, - osp.join( - full_config.habitat.task.task_spec_base_path, - full_config.habitat.task.task_spec + ".yaml", - ), + self._pddl_problem, num_envs, self._name_to_idx, + observation_space, + action_space, ) self._stop_action_idx, _ = find_action_range( action_space, "rearrange_stop" @@ -102,35 +125,56 @@ def __init__( def eval(self): pass - def get_policy_info(self, infos, dones): - policy_infos = [] - for i, info in enumerate(infos): + def get_policy_action_space( + self, env_action_space: spaces.Space + ) -> spaces.Space: + """ + Fetches the policy action space for learning. If we are learning the HL + policy, it will return its custom action space for learning. 
+ """ + + return self._high_level_policy.get_policy_action_space( + env_action_space + ) + + def extract_policy_info( + self, action_data, infos, dones + ) -> List[Dict[str, float]]: + ret_policy_infos = [] + for i, (info, policy_info) in enumerate( + zip(infos, action_data.policy_info) + ): cur_skill_idx = self._cur_skills[i].item() - policy_info: Dict[str, Any] = { - "cur_skill": self._idx_to_name[cur_skill_idx] + ret_policy_info: Dict[str, Any] = { + "cur_skill": self._idx_to_name[cur_skill_idx], + **policy_info, } did_skill_fail = dones[i] and not info[CompositeSuccess.cls_uuid] for skill_name, idx in self._name_to_idx.items(): - policy_info[f"failed_skill_{skill_name}"] = ( + ret_policy_info[f"failed_skill_{skill_name}"] = ( did_skill_fail if idx == cur_skill_idx else 0.0 ) - policy_infos.append(policy_info) + ret_policy_infos.append(ret_policy_info) - return policy_infos + return ret_policy_infos @property def num_recurrent_layers(self): - return self._skills[0].num_recurrent_layers + if self._high_level_policy.num_recurrent_layers != 0: + return self._high_level_policy.num_recurrent_layers + else: + return self._skills[0].num_recurrent_layers @property def should_load_agent_state(self): return False def parameters(self): - return self._skills[0].parameters() # type: ignore[attr-defined] + return self._high_level_policy.parameters() def to(self, device): + self._high_level_policy.to(device) for skill in self._skills.values(): skill.to(device) @@ -172,15 +216,31 @@ def act( masks, deterministic=False, ): - self._high_level_policy.apply_mask(masks) # type: ignore[attr-defined] + masks_cpu = masks.cpu() + log_info: List[Dict[str, Any]] = [{} for _ in range(self._num_envs)] + self._high_level_policy.apply_mask(masks_cpu) # type: ignore[attr-defined] - should_terminate: torch.BoolTensor = torch.zeros( + call_high_level: torch.BoolTensor = torch.zeros( (self._num_envs,), dtype=torch.bool ) bad_should_terminate: torch.BoolTensor = torch.zeros( (self._num_envs,), dtype=torch.bool ) + hl_wants_skill_term = self._high_level_policy.get_termination( + observations, + rnn_hidden_states, + prev_actions, + masks, + self._cur_skills, + log_info, + ) + # Initialize empty action set based on the overall action space. + actions = torch.zeros( + (self._num_envs, get_num_actions(self._action_space)), + device=masks.device, + ) + grouped_skills = self._broadcast_skill_ids( self._cur_skills, sel_dat={ @@ -188,6 +248,8 @@ def act( "rnn_hidden_states": rnn_hidden_states, "prev_actions": prev_actions, "masks": masks, + "actions": actions, + "hl_wants_skill_term": hl_wants_skill_term, }, # Only decide on skill termination if the episode is active. should_adds=masks, @@ -197,40 +259,52 @@ def act( for skill_id, (batch_ids, dat) in grouped_skills.items(): if skill_id == -1: # Policy has not prediced a skill yet. - should_terminate[batch_ids] = 1.0 + call_high_level[batch_ids] = 1.0 continue + # TODO: either change name of the function or assign actions somewhere + # else. Updating actions in should_terminate is counterintuitive + ( - should_terminate[batch_ids], + call_high_level[batch_ids], bad_should_terminate[batch_ids], + actions[batch_ids], ) = self._skills[skill_id].should_terminate( **dat, batch_idx=batch_ids, + log_info=log_info, + skill_name=[ + self._idx_to_name[self._cur_skills[i].item()] + for i in batch_ids + ], ) - self._call_high_level = should_terminate # Always call high-level if the episode is over. 
- self._call_high_level = self._call_high_level | (~masks).view(-1).cpu() + call_high_level = call_high_level | (~masks_cpu).view(-1) # If any skills want to terminate invoke the high-level policy to get # the next skill. hl_terminate = torch.zeros(self._num_envs, dtype=torch.bool) - if self._call_high_level.sum() > 0: + hl_info: Dict[str, Any] = {} + if call_high_level.sum() > 0: ( new_skills, new_skill_args, hl_terminate, + hl_info, ) = self._high_level_policy.get_next_skill( observations, rnn_hidden_states, prev_actions, masks, - self._call_high_level, + call_high_level, + deterministic, + log_info, ) sel_grouped_skills = self._broadcast_skill_ids( new_skills, sel_dat={}, - should_adds=self._call_high_level, + should_adds=call_high_level, ) for skill_id, (batch_ids, _) in sel_grouped_skills.items(): @@ -241,17 +315,16 @@ def act( rnn_hidden_states, prev_actions, ) - rnn_hidden_states[batch_ids] *= 0.0 - prev_actions[batch_ids] *= 0 - self._cur_skills = ( - (~self._call_high_level) * self._cur_skills - ) + (self._call_high_level * new_skills) - - # Compute the actions from the current skills - actions = torch.zeros( - (self._num_envs, get_num_actions(self._action_space)), - device=masks.device, - ) + if "rnn_hidden_states" not in hl_info: + rnn_hidden_states[batch_ids] *= 0.0 + prev_actions[batch_ids] *= 0 + elif self._skills[skill_id].has_hidden_state: + raise ValueError( + f"The code does not currently support neural LL and neural HL skills. Skill={self._skills[skill_id]}, HL={self._high_level_policy}" + ) + self._cur_skills = ((~call_high_level) * self._cur_skills) + ( + call_high_level * new_skills + ) grouped_skills = self._broadcast_skill_ids( self._cur_skills, @@ -263,7 +336,7 @@ def act( }, ) for skill_id, (batch_ids, batch_dat) in grouped_skills.items(): - tmp_actions, tmp_rnn = self._skills[skill_id].act( + action_data = self._skills[skill_id].act( observations=batch_dat["observations"], rnn_hidden_states=batch_dat["rnn_hidden_states"], prev_actions=batch_dat["prev_actions"], @@ -272,8 +345,9 @@ def act( ) # LL skills are not allowed to terminate the overall episode. 
- actions[batch_ids] = tmp_actions - rnn_hidden_states[batch_ids] = tmp_rnn + actions[batch_ids] += action_data.actions + # Add actions from apply_postcond + rnn_hidden_states[batch_ids] = action_data.rnn_hidden_states actions[:, self._stop_action_idx] = 0.0 should_terminate = bad_should_terminate | hl_terminate @@ -285,7 +359,44 @@ def act( ) actions[batch_idx, self._stop_action_idx] = 1.0 - return (None, actions, None, rnn_hidden_states) + action_kwargs = { + "rnn_hidden_states": rnn_hidden_states, + "actions": actions, + } + action_kwargs.update(hl_info) + + return PolicyActionData( + take_actions=actions, + policy_info=log_info, + should_inserts=call_high_level, + **action_kwargs, + ) + + def get_value(self, observations, rnn_hidden_states, prev_actions, masks): + return self._high_level_policy.get_value( + observations, rnn_hidden_states, prev_actions, masks + ) + + def _get_policy_components(self) -> List[nn.Module]: + return self._high_level_policy.get_policy_components() + + def evaluate_actions( + self, + observations, + rnn_hidden_states, + prev_actions, + masks, + action, + rnn_build_seq_info: Dict[str, torch.Tensor], + ): + return self._high_level_policy.evaluate_actions( + observations, + rnn_hidden_states, + prev_actions, + masks, + action, + rnn_build_seq_info, + ) @classmethod def from_config( diff --git a/habitat-baselines/habitat_baselines/rl/hrl/high_level_policy.py b/habitat-baselines/habitat_baselines/rl/hrl/high_level_policy.py deleted file mode 100644 index 082dc15434..0000000000 --- a/habitat-baselines/habitat_baselines/rl/hrl/high_level_policy.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and its affiliates. -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -from typing import List, Tuple - -import torch -import yaml - -from habitat.config.default import get_full_habitat_config_path -from habitat.tasks.rearrange.multi_task.rearrange_pddl import parse_func -from habitat_baselines.common.logging import baselines_logger - - -class FixedHighLevelPolicy: - """ - :property _solution_actions: List of tuples were first tuple element is the - action name and the second is the action arguments. - """ - - _solution_actions: List[Tuple[str, List[str]]] - - def __init__(self, config, task_spec_file, num_envs, skill_name_to_idx): - with open(get_full_habitat_config_path(task_spec_file), "r") as f: - task_spec = yaml.safe_load(f) - - self._solution_actions = [] - if "solution" not in task_spec: - raise ValueError( - f"The ground truth task planner only works when the task solution is hard-coded in the PDDL problem file at {task_spec_file}" - ) - for i, sol_step in enumerate(task_spec["solution"]): - sol_action = parse_func(sol_step) - self._solution_actions.append(sol_action) - if config.add_arm_rest and i < (len(task_spec["solution"]) - 1): - self._solution_actions.append(parse_func("reset_arm(0)")) - - # Add a wait action at the end. 
- self._solution_actions.append(parse_func("wait(30)")) - - self._next_sol_idxs = torch.zeros(num_envs, dtype=torch.int32) - self._num_envs = num_envs - self._skill_name_to_idx = skill_name_to_idx - - def apply_mask(self, mask): - self._next_sol_idxs *= mask.cpu().view(-1) - - def get_next_skill( - self, observations, rnn_hidden_states, prev_actions, masks, plan_masks - ): - next_skill = torch.zeros(self._num_envs) - skill_args_data = [None for _ in range(self._num_envs)] - immediate_end = torch.zeros(self._num_envs, dtype=torch.bool) - for batch_idx, should_plan in enumerate(plan_masks): - if should_plan == 1.0: - if self._next_sol_idxs[batch_idx] >= len( - self._solution_actions - ): - baselines_logger.info( - f"Calling for immediate end with {self._next_sol_idxs[batch_idx]}" - ) - immediate_end[batch_idx] = True - use_idx = len(self._solution_actions) - 1 - else: - use_idx = self._next_sol_idxs[batch_idx].item() - - skill_name, skill_args = self._solution_actions[use_idx] - baselines_logger.info( - f"Got next element of the plan with {skill_name}, {skill_args}" - ) - if skill_name not in self._skill_name_to_idx: - raise ValueError( - f"Could not find skill named {skill_name} in {self._skill_name_to_idx}" - ) - next_skill[batch_idx] = self._skill_name_to_idx[skill_name] - - skill_args_data[batch_idx] = skill_args # type: ignore[call-overload] - - self._next_sol_idxs[batch_idx] += 1 - - return next_skill, skill_args_data, immediate_end diff --git a/habitat-baselines/habitat_baselines/rl/hrl/hl/__init__.py b/habitat-baselines/habitat_baselines/rl/hrl/hl/__init__.py index 09023c316f..31bc4beeb1 100644 --- a/habitat-baselines/habitat_baselines/rl/hrl/hl/__init__.py +++ b/habitat-baselines/habitat_baselines/rl/hrl/hl/__init__.py @@ -1,4 +1,5 @@ from habitat_baselines.rl.hrl.hl.fixed_policy import FixedHighLevelPolicy from habitat_baselines.rl.hrl.hl.high_level_policy import HighLevelPolicy +from habitat_baselines.rl.hrl.hl.neural_policy import NeuralHighLevelPolicy -__all__ = ["HighLevelPolicy", "FixedHighLevelPolicy"] +__all__ = ["HighLevelPolicy", "FixedHighLevelPolicy", "NeuralHighLevelPolicy"] diff --git a/habitat-baselines/habitat_baselines/rl/hrl/hl/fixed_policy.py b/habitat-baselines/habitat_baselines/rl/hrl/hl/fixed_policy.py index acbed7d418..fa6b09f542 100644 --- a/habitat-baselines/habitat_baselines/rl/hrl/hl/fixed_policy.py +++ b/habitat-baselines/habitat_baselines/rl/hrl/hl/fixed_policy.py @@ -5,9 +5,7 @@ from typing import List, Tuple import torch -import yaml -from habitat.config.default import get_full_habitat_config_path from habitat.tasks.rearrange.multi_task.rearrange_pddl import parse_func from habitat_baselines.common.logging import baselines_logger from habitat_baselines.rl.hrl.hl.high_level_policy import HighLevelPolicy @@ -23,39 +21,25 @@ class FixedHighLevelPolicy(HighLevelPolicy): _solution_actions: List[Tuple[str, List[str]]] - def __init__(self, config, task_spec_file, num_envs, skill_name_to_idx): - """ - Initialize the `FixedHighLevelPolicy` object. - - Args: - config: Config object containing the configurations for the agent. - task_spec_file: Path to the task specification file. - num_envs: Number of parallel environments. - skill_name_to_idx: Dictionary mapping skill names to skill indices. 
- """ - with open(get_full_habitat_config_path(task_spec_file), "r") as f: - task_spec = yaml.safe_load(f) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) - self._num_envs = num_envs - self._skill_name_to_idx = skill_name_to_idx self._solution_actions = self._parse_solution_actions( - config, task_spec, task_spec_file + self._pddl_prob.solution ) - self._next_sol_idxs = torch.zeros(num_envs, dtype=torch.int32) - - def _parse_solution_actions(self, config, task_spec, task_spec_file): - if "solution" not in task_spec: - raise ValueError( - f"The ground truth task planner only works when the task solution is hard-coded in the PDDL problem file at {task_spec_file}." - ) + self._next_sol_idxs = torch.zeros(self._num_envs, dtype=torch.int32) + def _parse_solution_actions(self, solution): solution_actions = [] - for i, sol_step in enumerate(task_spec["solution"]): - sol_action = parse_func(sol_step) + for i, hl_action in enumerate(solution): + sol_action = ( + hl_action.name, + [x.name for x in hl_action.param_values], + ) solution_actions.append(sol_action) - if config.add_arm_rest and i < (len(task_spec["solution"]) - 1): + if self._config.add_arm_rest and i < (len(solution) - 1): solution_actions.append(parse_func("reset_arm(0)")) # Add a wait action at the end. @@ -93,26 +77,15 @@ def _get_next_sol_idx(self, batch_idx, immediate_end): return self._next_sol_idxs[batch_idx].item() def get_next_skill( - self, observations, rnn_hidden_states, prev_actions, masks, plan_masks + self, + observations, + rnn_hidden_states, + prev_actions, + masks, + plan_masks, + deterministic, + log_info, ): - """ - Get the next skill to be executed. - - Args: - observations: Current observations. - rnn_hidden_states: Current hidden states of the RNN. - prev_actions: Previous actions taken. - masks: Binary masks indicating which environment(s) are active. - plan_masks: Binary masks indicating which environment(s) should - plan the next skill. - - Returns: - A tuple containing: - - next_skill: Next skill to be executed. - - skill_args_data: Arguments for the next skill. - - immediate_end: Binary masks indicating which environment(s) should - end immediately. - """ next_skill = torch.zeros(self._num_envs) skill_args_data = [None for _ in range(self._num_envs)] immediate_end = torch.zeros(self._num_envs, dtype=torch.bool) @@ -134,4 +107,4 @@ def get_next_skill( self._next_sol_idxs[batch_idx] += 1 - return next_skill, skill_args_data, immediate_end + return next_skill, skill_args_data, immediate_end, {} diff --git a/habitat-baselines/habitat_baselines/rl/hrl/hl/high_level_policy.py b/habitat-baselines/habitat_baselines/rl/hrl/hl/high_level_policy.py index 848ebd92e1..ec12f0ebc1 100644 --- a/habitat-baselines/habitat_baselines/rl/hrl/hl/high_level_policy.py +++ b/habitat-baselines/habitat_baselines/rl/hrl/hl/high_level_policy.py @@ -1,15 +1,122 @@ -from typing import Any, List, Tuple +from typing import Any, Dict, List, Tuple +import gym.spaces as spaces import torch +import torch.nn as nn +from habitat.tasks.rearrange.multi_task.pddl_domain import PddlProblem + + +class HighLevelPolicy(nn.Module): + """ + High level policy that selects from low-level skills. 
+ """ + + def __init__( + self, + config, + pddl_problem: PddlProblem, + num_envs: int, + skill_name_to_idx: Dict[int, str], + observation_space: spaces.Space, + action_space: spaces.Space, + ): + super().__init__() + self._config = config + self._pddl_prob = pddl_problem + self._num_envs = num_envs + self._skill_name_to_idx = skill_name_to_idx + self._obs_space = observation_space + self._device = None + + def to(self, device): + self._device = device + return super().to(device) + + def get_value(self, observations, rnn_hidden_states, prev_actions, masks): + raise NotImplementedError() + + def evaluate_actions( + self, + observations, + rnn_hidden_states, + prev_actions, + masks, + action, + rnn_build_seq_info, + ): + raise NotImplementedError() + + @property + def num_recurrent_layers(self): + return 0 + + def parameters(self): + return iter([nn.Parameter(torch.zeros((1,), device=self._device))]) + + def get_policy_action_space( + self, env_action_space: spaces.Space + ) -> spaces.Space: + return env_action_space -class HighLevelPolicy: def get_next_skill( - self, observations, rnn_hidden_states, prev_actions, masks, plan_masks - ) -> Tuple[torch.Tensor, List[Any], torch.BoolTensor]: + self, + observations, + rnn_hidden_states: torch.Tensor, + prev_actions: torch.Tensor, + masks: torch.Tensor, + plan_masks: torch.Tensor, + deterministic: bool, + log_info: List[Dict[str, Any]], + ) -> Tuple[torch.Tensor, List[Any], torch.BoolTensor, Dict[str, Any]]: """ - :returns: A tuple containing the next skill index, a list of arguments - for the skill, and if the high-level policy requests immediate - termination. + Get the next skill to be executed. + + Args: + observations: Current observations. + rnn_hidden_states: Current hidden states of the RNN. + prev_actions: Previous actions taken. + masks: Binary masks indicating which environment(s) are active. + plan_masks: Binary masks indicating which environment(s) should + plan the next skill. + + Returns: + A tuple containing: + - next_skill: Next skill to be executed. + - skill_args_data: Arguments for the next skill. + - immediate_end: Binary masks indicating which environment(s) should + end immediately. + - Information for PolicyActionData """ raise NotImplementedError() + + def apply_mask(self, mask: torch.Tensor) -> None: + """ + Called before every step with the mask information at the current step. + """ + + def get_policy_components(self) -> List[nn.Module]: + """ + Gets the torch modules that are in the HL policy architecture. + """ + + return [] + + def get_termination( + self, + observations, + rnn_hidden_states, + prev_actions, + masks, + cur_skills, + log_info, + ) -> torch.BoolTensor: + """ + Can force the currently executing skill to terminate. + In the base HighLevelPolicy, the skill always continues. + + Returns: A binary tensor where 1 indicates the current skill should + terminate and 0 indicates the skill can continue. 
+ """ + + return torch.zeros(self._num_envs, dtype=torch.bool) diff --git a/habitat-baselines/habitat_baselines/rl/hrl/hl/neural_policy.py b/habitat-baselines/habitat_baselines/rl/hrl/hl/neural_policy.py new file mode 100644 index 0000000000..d06f2d7d90 --- /dev/null +++ b/habitat-baselines/habitat_baselines/rl/hrl/hl/neural_policy.py @@ -0,0 +1,217 @@ +import logging +from itertools import chain +from typing import Any, List + +import gym.spaces as spaces +import numpy as np +import torch +import torch.nn as nn + +from habitat.tasks.rearrange.multi_task.pddl_action import PddlAction +from habitat_baselines.common.logging import baselines_logger +from habitat_baselines.rl.ddppo.policy import resnet +from habitat_baselines.rl.ddppo.policy.resnet_policy import ResNetEncoder +from habitat_baselines.rl.hrl.hl.high_level_policy import HighLevelPolicy +from habitat_baselines.rl.models.rnn_state_encoder import ( + build_rnn_state_encoder, +) +from habitat_baselines.rl.ppo.policy import CriticHead +from habitat_baselines.utils.common import CategoricalNet + + +class NeuralHighLevelPolicy(HighLevelPolicy): + """ + A trained high-level policy that selects low-level skills and their skill + inputs. Is limited to discrete skills and discrete skill inputs. The policy + detects the available skills and their possible arguments via the PDDL + problem. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._all_actions = self._setup_actions() + self._n_actions = len(self._all_actions) + + use_obs_space = spaces.Dict( + { + k: self._obs_space.spaces[k] + for k in self._config.policy_input_keys + } + ) + self._im_obs_space = spaces.Dict( + {k: v for k, v in use_obs_space.items() if len(v.shape) == 3} + ) + + state_obs_space = { + k: v for k, v in use_obs_space.items() if len(v.shape) == 1 + } + self._state_obs_space = spaces.Dict(state_obs_space) + + rnn_input_size = sum( + v.shape[0] for v in self._state_obs_space.values() + ) + self._hidden_size = self._config.hidden_dim + if len(self._im_obs_space) > 0 and self._config.backbone != "NONE": + resnet_baseplanes = 32 + self._visual_encoder = ResNetEncoder( + self._im_obs_space, + baseplanes=resnet_baseplanes, + ngroups=resnet_baseplanes // 2, + make_backbone=getattr(resnet, self._config.backbone), + ) + self._visual_fc = nn.Sequential( + nn.Flatten(), + nn.Linear( + np.prod(self._visual_encoder.output_shape), + self._hidden_size, + ), + nn.ReLU(True), + ) + rnn_input_size += self._hidden_size + else: + self._visual_encoder = nn.Sequential() + self._visual_fc = nn.Sequential() + + self._state_encoder = build_rnn_state_encoder( + rnn_input_size, + self._hidden_size, + rnn_type=self._config.rnn_type, + num_layers=self._config.num_rnn_layers, + ) + self._policy = CategoricalNet(self._hidden_size, self._n_actions) + self._critic = CriticHead(self._hidden_size) + + def _setup_actions(self) -> List[PddlAction]: + all_actions = self._pddl_prob.get_possible_actions() + all_actions = [ + ac for ac in all_actions if ac.name in self._config.allowed_actions + ] + if not self._config.allow_other_place: + all_actions = [ + ac + for ac in all_actions + if ( + ac.name != "place" + or ac.param_values[0].name in ac.param_values[1].name + ) + ] + return all_actions + + def get_policy_action_space( + self, env_action_space: spaces.Space + ) -> spaces.Space: + return spaces.Discrete(self._n_actions) + + @property + def num_recurrent_layers(self): + return self._state_encoder.num_recurrent_layers + + def parameters(self): + return chain( + 
self._visual_encoder.parameters(), + self._visual_fc.parameters(), + self._policy.parameters(), + self._state_encoder.parameters(), + self._critic.parameters(), + ) + + def get_policy_components(self) -> List[nn.Module]: + return [self] + + def forward(self, obs, rnn_hidden_states, masks, rnn_build_seq_info=None): + hidden = [] + if len(self._im_obs_space) > 0: + im_obs = {k: obs[k] for k in self._im_obs_space.keys()} + visual_features = self._visual_encoder(im_obs) + visual_features = self._visual_fc(visual_features) + hidden.append(visual_features) + + if len(self._state_obs_space) > 0: + hidden.extend([obs[k] for k in self._state_obs_space.keys()]) + hidden = torch.cat(hidden, -1) + + return self._state_encoder( + hidden, rnn_hidden_states, masks, rnn_build_seq_info + ) + + def to(self, device): + self._device = device + return super().to(device) + + def get_value(self, observations, rnn_hidden_states, prev_actions, masks): + state, _ = self.forward(observations, rnn_hidden_states, masks) + return self._critic(state) + + def evaluate_actions( + self, + observations, + rnn_hidden_states, + prev_actions, + masks, + action, + rnn_build_seq_info, + ): + features, _ = self.forward( + observations, rnn_hidden_states, masks, rnn_build_seq_info + ) + distribution = self._policy(features) + value = self._critic(features) + action_log_probs = distribution.log_probs(action) + distribution_entropy = distribution.entropy() + + return ( + value, + action_log_probs, + distribution_entropy, + rnn_hidden_states, + {}, + ) + + def get_next_skill( + self, + observations, + rnn_hidden_states, + prev_actions, + masks, + plan_masks, + deterministic, + log_info, + ): + next_skill = torch.zeros(self._num_envs, dtype=torch.long) + skill_args_data: List[Any] = [None for _ in range(self._num_envs)] + immediate_end = torch.zeros(self._num_envs, dtype=torch.bool) + + state, rnn_hidden_states = self.forward( + observations, rnn_hidden_states, masks + ) + distrib = self._policy(state) + values = self._critic(state) + if deterministic: + skill_sel = distrib.mode() + else: + skill_sel = distrib.sample() + action_log_probs = distrib.log_probs(skill_sel) + + for batch_idx, should_plan in enumerate(plan_masks): + if should_plan != 1.0: + continue + use_ac = self._all_actions[skill_sel[batch_idx]] + if baselines_logger.level >= logging.DEBUG: + baselines_logger.debug(f"HL Policy selected skill {use_ac}") + next_skill[batch_idx] = self._skill_name_to_idx[use_ac.name] + skill_args_data[batch_idx] = [ + entity.name for entity in use_ac.param_values + ] + log_info[batch_idx]["nn_action"] = use_ac.compact_str + + return ( + next_skill, + skill_args_data, + immediate_end, + { + "action_log_probs": action_log_probs, + "values": values, + "actions": skill_sel, + "rnn_hidden_states": rnn_hidden_states, + }, + ) diff --git a/habitat-baselines/habitat_baselines/rl/hrl/hrl_ppo.py b/habitat-baselines/habitat_baselines/rl/hrl/hrl_ppo.py new file mode 100644 index 0000000000..b7bbd38bc2 --- /dev/null +++ b/habitat-baselines/habitat_baselines/rl/hrl/hrl_ppo.py @@ -0,0 +1,122 @@ +import torch +import torch.nn.functional as F + +from habitat_baselines.common.baseline_registry import baseline_registry +from habitat_baselines.rl.ddppo.algo.ddppo import DecentralizedDistributedMixin +from habitat_baselines.rl.ppo import PPO +from habitat_baselines.utils.common import ( + LagrangeInequalityCoefficient, + inference_mode, +) + + +@baseline_registry.register_updater +class HRLPPO(PPO): + def _update_from_batch(self, batch, epoch, rollouts, 
learner_metrics): + n_samples = max(batch["loss_mask"].sum(), 1) + + def record_min_mean_max(t: torch.Tensor, prefix: str): + for name, op in ( + ("min", torch.min), + ("mean", torch.mean), + ("max", torch.max), + ): + learner_metrics[f"{prefix}_{name}"].append(op(t)) + + def reduce_loss(loss): + return (loss * batch["loss_mask"]).sum() / n_samples + + self._set_grads_to_none() + + ( + values, + action_log_probs, + dist_entropy, + _, + _, + ) = self._evaluate_actions( + batch["observations"], + batch["recurrent_hidden_states"], + batch["prev_actions"], + batch["masks"], + batch["actions"], + batch["rnn_build_seq_info"], + ) + + ratio = torch.exp(action_log_probs - batch["action_log_probs"]) + + surr1 = batch["advantages"] * ratio + surr2 = batch["advantages"] * ( + torch.clamp( + ratio, + 1.0 - self.clip_param, + 1.0 + self.clip_param, + ) + ) + action_loss = -torch.min(surr1, surr2) + action_loss = reduce_loss(action_loss) + + values = values.float() + orig_values = values + + if self.use_clipped_value_loss: + delta = values.detach() - batch["value_preds"] + value_pred_clipped = batch["value_preds"] + delta.clamp( + -self.clip_param, self.clip_param + ) + + values = torch.where( + delta.abs() < self.clip_param, + values, + value_pred_clipped, + ) + + value_loss = 0.5 * F.mse_loss( + values, batch["returns"], reduction="none" + ) + value_loss = reduce_loss(value_loss) + + all_losses = [ + self.value_loss_coef * value_loss, + action_loss, + ] + + dist_entropy = reduce_loss(dist_entropy) + if isinstance(self.entropy_coef, float): + all_losses.append(-self.entropy_coef * dist_entropy) + else: + all_losses.append(self.entropy_coef.lagrangian_loss(dist_entropy)) + + total_loss = torch.stack(all_losses).sum() + + total_loss = self.before_backward(total_loss) + total_loss.backward() + self.after_backward(total_loss) + + grad_norm = self.before_step() + self.optimizer.step() + self.after_step() + + with inference_mode(): + record_min_mean_max(orig_values, "value_pred") + record_min_mean_max(ratio, "prob_ratio") + + learner_metrics["value_loss"].append(value_loss) + learner_metrics["action_loss"].append(action_loss) + learner_metrics["dist_entropy"].append(dist_entropy) + if epoch == (self.ppo_epoch - 1): + learner_metrics["ppo_fraction_clipped"].append( + (ratio > (1.0 + self.clip_param)).float().mean() + + (ratio < (1.0 - self.clip_param)).float().mean() + ) + + learner_metrics["grad_norm"].append(grad_norm) + if isinstance(self.entropy_coef, LagrangeInequalityCoefficient): + learner_metrics["entropy_coef"].append( + self.entropy_coef().detach() + ) + + +@baseline_registry.register_updater +class HRLDDPPO(DecentralizedDistributedMixin, HRLPPO): + pass diff --git a/habitat-baselines/habitat_baselines/rl/hrl/hrl_rollout_storage.py b/habitat-baselines/habitat_baselines/rl/hrl/hrl_rollout_storage.py new file mode 100644 index 0000000000..9acf0854ea --- /dev/null +++ b/habitat-baselines/habitat_baselines/rl/hrl/hrl_rollout_storage.py @@ -0,0 +1,211 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
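A brief note on the `HRLPPO` updater above before the new storage class below: because HRL rollouts contain a variable number of valid high-level transitions per environment, each loss term is multiplied by `loss_mask` and normalized by the count of valid samples instead of being averaged over the full buffer. A small, self-contained sketch of that masked reduction (tensor values are made up for illustration):

```python
import torch

# Toy batch: 6 flattened (step, env) samples, of which only 4 are valid.
advantages = torch.tensor([[0.5], [-0.2], [0.0], [0.0], [1.0], [0.3]])
ratio = torch.ones(6, 1)  # pi_new / pi_old, set to 1.0 here for simplicity
loss_mask = torch.tensor([[1.0], [1.0], [0.0], [0.0], [1.0], [1.0]])

n_samples = max(loss_mask.sum(), 1)
surrogate = -(advantages * ratio)
# Padded samples contribute zero and are excluded from the normalizer.
action_loss = (surrogate * loss_mask).sum() / n_samples
print(action_loss)  # mean surrogate loss over the 4 valid samples
```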
+ +from typing import Iterator, Optional + +import torch + +from habitat_baselines.common.baseline_registry import baseline_registry +from habitat_baselines.common.rollout_storage import RolloutStorage +from habitat_baselines.common.tensor_dict import DictTree, TensorDict +from habitat_baselines.rl.models.rnn_state_encoder import ( + build_pack_info_from_dones, + build_rnn_build_seq_info, +) + +EPS_PPO = 1e-5 + + +@baseline_registry.register_storage +class HrlRolloutStorage(RolloutStorage): + """ + Supports variable writes to the rollout buffer where data is not inserted + into the buffer on every step. When getting batches from the storage, these + batches will only contain samples that were written. This means that the + batches could be variable size and less than the maximum size of the + rollout buffer. + """ + + def __init__(self, numsteps, num_envs, *args, **kwargs): + super().__init__(numsteps, num_envs, *args, **kwargs) + self._num_envs = num_envs + self._cur_step_idxs = torch.zeros(self._num_envs, dtype=torch.long) + self._last_should_inserts = None + assert ( + not self.is_double_buffered + ), "HRL storage does not support double buffered sampling" + + def insert( + self, + next_observations=None, + next_recurrent_hidden_states=None, + actions=None, + action_log_probs=None, + value_preds=None, + rewards=None, + next_masks=None, + buffer_index: int = 0, + should_inserts: Optional[torch.BoolTensor] = None, + ): + """ + The only key different from the base `RolloutStorage` is + `should_inserts`. This is a bool tensor of shape [# environments,]. If + `should_insert[i] == True`, then this will the sample at enviroment + index `i` into the rollout buffer at environment index `i`, if not, it + will ignore the sample. If None, this defaults to the last insert + state. + + Rewards acquired of steps where `should_insert[i] == False` will be summed up and added to the next step where `should_insert[i] == True` + """ + + if next_masks is not None: + next_masks = next_masks.to(self.device) + if rewards is not None: + rewards = rewards.to(self.device) + next_step = dict( + observations=next_observations, + recurrent_hidden_states=next_recurrent_hidden_states, + prev_actions=actions, + masks=next_masks, + ) + + current_step = dict( + actions=actions, + action_log_probs=action_log_probs, + value_preds=value_preds, + ) + + next_step = TensorDict( + {k: v for k, v in next_step.items() if v is not None} + ) + current_step = TensorDict( + {k: v for k, v in current_step.items() if v is not None} + ) + + if should_inserts is None: + should_inserts = self._last_should_inserts + assert should_inserts is not None + + if should_inserts.sum() == 0: + return + + env_idxs = torch.arange(self._num_envs) + if rewards is not None: + # Accumulate rewards between updates. + reward_write_idxs = torch.clamp(self._cur_step_idxs - 1, min=0) + self.buffers["rewards"][reward_write_idxs, env_idxs] += rewards + + if len(next_step) > 0: + self.buffers.set( + ( + self._cur_step_idxs[should_inserts] + 1, + env_idxs[should_inserts], + ), + next_step[should_inserts], + strict=False, + ) + + if len(current_step) > 0: + self.buffers.set( + ( + self._cur_step_idxs[should_inserts], + env_idxs[should_inserts], + ), + current_step[should_inserts], + strict=False, + ) + self._last_should_inserts = should_inserts + + def advance_rollout(self, buffer_index: int = 0): + """ + This will advance to writing at the next step in the data buffer ONLY + if an element was written to that environment index in the previous + step. 
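To make the buffered-write semantics of `insert` above concrete, the following intentionally simplified sketch shows only the reward-accumulation idea: the write pointer advances only where `should_inserts` is True, so rewards from skipped steps pool into the most recently written slot. The tensors and loop are invented for illustration and omit everything else the class stores.

```python
import torch

num_envs = 2
num_calls = 4
# Whether each env produced a high-level transition at each call.
should_inserts = torch.tensor(
    [[True, False, False, True], [True, True, True, True]]
)
step_rewards = torch.tensor([[0.1, 0.2, 0.3, 0.4], [1.0, 1.0, 1.0, 1.0]])

buffer_rewards = torch.zeros(num_calls, num_envs)
cur_step = torch.zeros(num_envs, dtype=torch.long)
env_idxs = torch.arange(num_envs)
for t in range(num_calls):
    # Rewards always accumulate into the last written slot ...
    write_idx = torch.clamp(cur_step - 1, min=0)
    buffer_rewards[write_idx, env_idxs] += step_rewards[:, t]
    # ... but the pointer only advances where an insert actually happened.
    cur_step += should_inserts[:, t].long()

# Where should_inserts was False, the rewards pooled into the slot that was
# written most recently for that environment.
print(buffer_rewards)
```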
+ """ + + self._cur_step_idxs += self._last_should_inserts.long() + + is_past_buffer = self._cur_step_idxs >= self.num_steps + if is_past_buffer.sum() > 0: + self._cur_step_idxs[is_past_buffer] = self.num_steps - 1 + env_idxs = torch.arange(self._num_envs) + self.buffers["rewards"][ + self._cur_step_idxs[is_past_buffer], env_idxs[is_past_buffer] + ] = 0.0 + + def after_update(self): + env_idxs = torch.arange(self._num_envs) + self.buffers[0] = self.buffers[self._cur_step_idxs, env_idxs] + self.buffers["masks"][1:] = False + self.buffers["rewards"][1:] = 0.0 + + self.current_rollout_step_idxs = [ + 0 for _ in self.current_rollout_step_idxs + ] + self._cur_step_idxs[:] = 0 + + def compute_returns(self, next_value, use_gae, gamma, tau): + if not use_gae: + raise ValueError("Only GAE is supported with HRL trainer") + + assert isinstance(self.buffers["value_preds"], torch.Tensor) + gae = 0.0 + for step in reversed(range(self._cur_step_idxs.max() - 1)): + delta = ( + self.buffers["rewards"][step] + + gamma + * self.buffers["value_preds"][step + 1] + * self.buffers["masks"][step + 1] + - self.buffers["value_preds"][step] + ) + gae = delta + gamma * tau * gae * self.buffers["masks"][step + 1] + self.buffers["returns"][step] = ( # type: ignore + gae + self.buffers["value_preds"][step] # type: ignore + ) + + def recurrent_generator( + self, advantages, num_batches + ) -> Iterator[DictTree]: + """ + Generates data batches based on the data that has been written to the + rollout buffer. + """ + + num_environments = advantages.size(1) + dones_cpu = ( + torch.logical_not(self.buffers["masks"]) + .cpu() + .view(-1, self._num_envs) + .numpy() + ) + for inds in torch.randperm(num_environments).chunk(num_batches): + batch = self.buffers[0 : self.num_steps, inds] + batch["advantages"] = advantages[: self.num_steps, inds] + batch["recurrent_hidden_states"] = batch[ + "recurrent_hidden_states" + ][0:1] + batch["loss_mask"] = ( + torch.arange(self.num_steps, device=advantages.device) + .view(-1, 1, 1) + .repeat(1, len(inds), 1) + ) + for i, env_i in enumerate(inds): + # The -1 is to throw out the last transition. 
+ batch["loss_mask"][:, i] = ( + batch["loss_mask"][:, i] < self._cur_step_idxs[env_i] - 1 + ) + + batch.map_in_place(lambda v: v.flatten(0, 1)) + batch["rnn_build_seq_info"] = build_rnn_build_seq_info( + device=self.device, + build_fn_result=build_pack_info_from_dones( + dones_cpu[0 : self.num_steps, inds.numpy()].reshape( + -1, len(inds) + ), + ), + ) + + yield batch.to_tree() diff --git a/habitat-baselines/habitat_baselines/rl/hrl/skills/__init__.py b/habitat-baselines/habitat_baselines/rl/hrl/skills/__init__.py index 1b1536f999..1187ac3cd9 100644 --- a/habitat-baselines/habitat_baselines/rl/hrl/skills/__init__.py +++ b/habitat-baselines/habitat_baselines/rl/hrl/skills/__init__.py @@ -5,6 +5,7 @@ from habitat_baselines.rl.hrl.skills.art_obj import ArtObjSkillPolicy from habitat_baselines.rl.hrl.skills.nav import NavSkillPolicy from habitat_baselines.rl.hrl.skills.nn_skill import NnSkillPolicy +from habitat_baselines.rl.hrl.skills.noop import NoopSkillPolicy from habitat_baselines.rl.hrl.skills.oracle_nav import OracleNavPolicy from habitat_baselines.rl.hrl.skills.pick import PickSkillPolicy from habitat_baselines.rl.hrl.skills.place import PlaceSkillPolicy @@ -22,4 +23,5 @@ "ResetArmSkill", "SkillPolicy", "WaitSkillPolicy", + "NoopSkillPolicy", ] diff --git a/habitat-baselines/habitat_baselines/rl/hrl/skills/art_obj.py b/habitat-baselines/habitat_baselines/rl/hrl/skills/art_obj.py index 3f381a3c11..da7ff0d405 100644 --- a/habitat-baselines/habitat_baselines/rl/hrl/skills/art_obj.py +++ b/habitat-baselines/habitat_baselines/rl/hrl/skills/art_obj.py @@ -65,4 +65,4 @@ def _is_skill_done( def _parse_skill_arg(self, skill_arg): self._internal_log(f"Parsing skill argument {skill_arg}") - return int(skill_arg[-1].split("|")[1]) + return int(skill_arg[1].split("|")[1]) diff --git a/habitat-baselines/habitat_baselines/rl/hrl/skills/nav.py b/habitat-baselines/habitat_baselines/rl/hrl/skills/nav.py index 834ea862f6..642aaf9a69 100644 --- a/habitat-baselines/habitat_baselines/rl/hrl/skills/nav.py +++ b/habitat-baselines/habitat_baselines/rl/hrl/skills/nav.py @@ -47,8 +47,8 @@ def _get_filtered_obs(self, observations, cur_batch_idx) -> TensorDict: ret_obs = super()._get_filtered_obs(observations, cur_batch_idx) if NavGoalPointGoalSensor.cls_uuid in ret_obs: - for i in cur_batch_idx: - if self._cur_skill_args[cur_batch_idx[i]].is_target: + for i, batch_i in enumerate(cur_batch_idx): + if self._cur_skill_args[batch_i].is_target: replace_sensor = TargetGoalGpsCompassSensor.cls_uuid else: replace_sensor = TargetStartGpsCompassSensor.cls_uuid diff --git a/habitat-baselines/habitat_baselines/rl/hrl/skills/nn_skill.py b/habitat-baselines/habitat_baselines/rl/hrl/skills/nn_skill.py index be9ebb2258..12e172e114 100644 --- a/habitat-baselines/habitat_baselines/rl/hrl/skills/nn_skill.py +++ b/habitat-baselines/habitat_baselines/rl/hrl/skills/nn_skill.py @@ -14,6 +14,7 @@ from habitat_baselines.common.tensor_dict import TensorDict from habitat_baselines.config.default import get_config from habitat_baselines.rl.hrl.skills.skill import SkillPolicy +from habitat_baselines.rl.ppo.policy import PolicyActionData from habitat_baselines.utils.common import get_num_actions @@ -72,6 +73,10 @@ def parameters(self): else: return [] + @property + def has_hidden_state(self): + return self.num_recurrent_layers != 0 + @property def num_recurrent_layers(self): if self._wrap_policy is not None: @@ -118,7 +123,7 @@ def _internal_act( masks, cur_batch_idx, deterministic=False, - ): + ) -> PolicyActionData: filtered_obs = 
self._get_filtered_obs(observations, cur_batch_idx) filtered_prev_actions = prev_actions[ @@ -126,19 +131,25 @@ def _internal_act( ] filtered_obs = self._select_obs(filtered_obs, cur_batch_idx) - _, action, _, rnn_hidden_states = self._wrap_policy.act( + action_data = self._wrap_policy.act( filtered_obs, rnn_hidden_states, filtered_prev_actions, masks, deterministic, ) - full_action = torch.zeros(prev_actions.shape, device=masks.device) - full_action[:, self._ac_start : self._ac_start + self._ac_len] = action + full_action = torch.zeros( + (masks.shape[0], self._full_ac_size), device=masks.device + ) + full_action[ + :, self._ac_start : self._ac_start + self._ac_len + ] = action_data.actions + action_data.actions = full_action + self._did_want_done[cur_batch_idx] = full_action[ - cur_batch_idx, self._stop_action_idx + :, self._stop_action_idx ] - return full_action, rnn_hidden_states + return action_data @classmethod def from_config( @@ -171,6 +182,10 @@ def from_config( ) for k in config.obs_skill_inputs: + if k not in filtered_obs_space.spaces: + raise ValueError( + f"Could not find {k} for skill {policy_cfg.habitat.gym.auto_name}" + ) space = filtered_obs_space.spaces[k] # There is always a 3D position filtered_obs_space.spaces[k] = truncate_obs_space(space, 3) diff --git a/habitat-baselines/habitat_baselines/rl/hrl/skills/noop.py b/habitat-baselines/habitat_baselines/rl/hrl/skills/noop.py new file mode 100644 index 0000000000..271afeaa0f --- /dev/null +++ b/habitat-baselines/habitat_baselines/rl/hrl/skills/noop.py @@ -0,0 +1,39 @@ +from typing import Any + +import gym.spaces as spaces +import torch + +from habitat_baselines.rl.hrl.skills.skill import SkillPolicy +from habitat_baselines.rl.ppo.policy import PolicyActionData + + +class NoopSkillPolicy(SkillPolicy): + def __init__( + self, + config, + action_space: spaces.Space, + batch_size, + ): + super().__init__(config, action_space, batch_size, False) + + def _parse_skill_arg(self, *args, **kwargs) -> Any: + pass + + def _is_skill_done( + self, observations, rnn_hidden_states, prev_actions, masks, batch_idx + ) -> torch.BoolTensor: + return torch.zeros(masks.size(0), dtype=torch.bool) + + def _internal_act( + self, + observations, + rnn_hidden_states, + prev_actions, + masks, + cur_batch_idx, + full_action, + deterministic=False, + ): + return PolicyActionData( + actions=full_action, rnn_hidden_states=rnn_hidden_states + ) diff --git a/habitat-baselines/habitat_baselines/rl/hrl/skills/oracle_nav.py b/habitat-baselines/habitat_baselines/rl/hrl/skills/oracle_nav.py index 7ba03a0bd7..369bdf4e05 100644 --- a/habitat-baselines/habitat_baselines/rl/hrl/skills/oracle_nav.py +++ b/habitat-baselines/habitat_baselines/rl/hrl/skills/oracle_nav.py @@ -8,23 +8,22 @@ import torch from habitat.core.spaces import ActionSpace -from habitat.tasks.rearrange.actions.oracle_nav_action import ( - get_possible_nav_to_actions, -) from habitat.tasks.rearrange.multi_task.pddl_domain import PddlProblem -from habitat.tasks.rearrange.multi_task.rearrange_pddl import RIGID_OBJ_TYPE from habitat.tasks.rearrange.rearrange_sensors import LocalizationSensor from habitat_baselines.common.logging import baselines_logger from habitat_baselines.rl.hrl.skills.nn_skill import NnSkillPolicy from habitat_baselines.rl.hrl.utils import find_action_range +from habitat_baselines.rl.ppo.policy import PolicyActionData class OracleNavPolicy(NnSkillPolicy): @dataclass class OracleNavActionArgs: + """ + :property action_idx: The index of the oracle action we want to execute + """ + 
action_idx: int - is_target_obj: bool - target_idx: int def __init__( self, @@ -52,7 +51,7 @@ def __init__( pddl_task_path, task_config, ) - self._poss_actions = get_possible_nav_to_actions(self._pddl_problem) + self._all_entities = self._pddl_problem.get_ordered_entities_list() self._oracle_nav_ac_idx, _ = find_action_range( action_space, "oracle_nav_action" ) @@ -87,7 +86,7 @@ def from_config( cls, config, observation_space, action_space, batch_size, full_config ): filtered_action_space = ActionSpace( - {config.NAV_ACTION_NAME: action_space[config.NAV_ACTION_NAME]} + {config.action_name: action_space[config.action_name]} ) baselines_logger.debug( f"Loaded action space {filtered_action_space} for skill {config.skill_name}" @@ -115,58 +114,37 @@ def _is_skill_done( masks, batch_idx, ) -> torch.BoolTensor: - ret = torch.zeros(masks.shape[0], dtype=torch.bool).to(masks.device) + ret = torch.zeros(masks.shape[0], dtype=torch.bool) cur_pos = observations[LocalizationSensor.cls_uuid].cpu() for i, batch_i in enumerate(batch_idx): prev_pos = self._prev_pos[batch_i] if prev_pos is not None: - movement = torch.linalg.norm(prev_pos - cur_pos[i]) - ret[i] = movement < self._config.STOP_THRESH + movement = (prev_pos - cur_pos[i]).pow(2).sum().sqrt() + ret[i] = movement < self._config.stop_thresh self._prev_pos[batch_i] = cur_pos[i] return ret def _parse_skill_arg(self, skill_arg): - marker = None if len(skill_arg) == 2: - targ_obj, _ = skill_arg + search_target, _ = skill_arg elif len(skill_arg) == 3: - marker, targ_obj, _ = skill_arg + _, search_target, _ = skill_arg else: raise ValueError( f"Unexpected number of skill arguments in {skill_arg}" ) - targ_obj_idx = int(targ_obj.split("|")[-1]) - - targ_obj = self._pddl_problem.get_entity(targ_obj) - if marker is not None: - marker = self._pddl_problem.get_entity(marker) - - match_i = None - for i, action in enumerate(self._poss_actions): - match_obj = action.get_arg_value("obj") - if marker is not None: - match_marker = action.get_arg_value("marker") - if match_marker != marker: - continue - if match_obj != targ_obj: - continue - match_i = i - break - if match_i is None: - raise ValueError(f"Cannot find matching action for {skill_arg}") - is_target_obj = targ_obj.expr_type.is_subtype_of( - self._pddl_problem.expr_types[RIGID_OBJ_TYPE] - ) - return OracleNavPolicy.OracleNavActionArgs( - match_i, is_target_obj, targ_obj_idx - ) + target = self._pddl_problem.get_entity(search_target) + if target is None: + raise ValueError( + f"Cannot find matching entity for {search_target}" + ) + match_i = self._all_entities.index(target) - def _get_multi_sensor_index(self, batch_idx): - return [self._cur_skill_args[i].target_idx for i in batch_idx] + return OracleNavPolicy.OracleNavActionArgs(match_i) def _internal_act( self, @@ -177,11 +155,15 @@ def _internal_act( cur_batch_idx, deterministic=False, ): - full_action = torch.zeros(prev_actions.shape, device=masks.device) + full_action = torch.zeros( + (masks.shape[0], self._full_ac_size), device=masks.device + ) action_idxs = torch.FloatTensor( [self._cur_skill_args[i].action_idx + 1 for i in cur_batch_idx] ) full_action[:, self._oracle_nav_ac_idx] = action_idxs - return full_action, rnn_hidden_states + return PolicyActionData( + actions=full_action, rnn_hidden_states=rnn_hidden_states + ) diff --git a/habitat-baselines/habitat_baselines/rl/hrl/skills/pick.py b/habitat-baselines/habitat_baselines/rl/hrl/skills/pick.py index 3f88f1b1c4..4cc58e520f 100644 --- a/habitat-baselines/habitat_baselines/rl/hrl/skills/pick.py 
+++ b/habitat-baselines/habitat_baselines/rl/hrl/skills/pick.py @@ -9,6 +9,7 @@ RelativeRestingPositionSensor, ) from habitat_baselines.rl.hrl.skills.nn_skill import NnSkillPolicy +from habitat_baselines.rl.ppo.policy import PolicyActionData class PickSkillPolicy(NnSkillPolicy): @@ -33,12 +34,14 @@ def _parse_skill_arg(self, skill_arg): self._internal_log(f"Parsing skill argument {skill_arg}") return int(skill_arg[0].split("|")[1]) - def _mask_pick(self, action, observations): + def _mask_pick( + self, action: PolicyActionData, observations + ) -> PolicyActionData: # Mask out the release if the object is already held. is_holding = observations[IsHoldingSensor.cls_uuid].view(-1) for i in torch.nonzero(is_holding): # Do not release the object once it is held - action[i, self._grip_ac_idx] = 1.0 + action.actions[i, self._grip_ac_idx] = 1.0 return action def _internal_act( @@ -50,7 +53,7 @@ def _internal_act( cur_batch_idx, deterministic=False, ): - action, hxs = super()._internal_act( + action = super()._internal_act( observations, rnn_hidden_states, prev_actions, @@ -59,4 +62,4 @@ def _internal_act( deterministic, ) action = self._mask_pick(action, observations) - return action, hxs + return action diff --git a/habitat-baselines/habitat_baselines/rl/hrl/skills/place.py b/habitat-baselines/habitat_baselines/rl/hrl/skills/place.py index a8d069910c..a99ac1a8e1 100644 --- a/habitat-baselines/habitat_baselines/rl/hrl/skills/place.py +++ b/habitat-baselines/habitat_baselines/rl/hrl/skills/place.py @@ -46,7 +46,6 @@ def _is_skill_done( if is_done.sum() > 0: self._internal_log( f"Terminating with {rel_resting_pos} and {is_holding}", - observations, ) return is_done diff --git a/habitat-baselines/habitat_baselines/rl/hrl/skills/reset.py b/habitat-baselines/habitat_baselines/rl/hrl/skills/reset.py index 9e33f713c0..937f0097c1 100644 --- a/habitat-baselines/habitat_baselines/rl/hrl/skills/reset.py +++ b/habitat-baselines/habitat_baselines/rl/hrl/skills/reset.py @@ -9,7 +9,8 @@ import torch from habitat_baselines.rl.hrl.skills.skill import SkillPolicy -from habitat_baselines.utils.common import get_num_actions +from habitat_baselines.rl.hrl.utils import find_action_range +from habitat_baselines.rl.ppo.policy import PolicyActionData class ResetArmSkill(SkillPolicy): @@ -20,14 +21,12 @@ def __init__( batch_size, ): super().__init__(config, action_space, batch_size, True) - self._target = np.array([float(x) for x in config.reset_joint_state]) + self._rest_state = np.array( + [float(x) for x in config.reset_joint_state] + ) - self._ac_start = 0 - for k, space in action_space.items(): - if k != "arm_action": - self._ac_start += get_num_actions(space) - else: - break + self._arm_ac_range = find_action_range(action_space, "arm_action") + self._arm_ac_range = (self._arm_ac_range[0], self._rest_state.shape[0]) def on_enter( self, @@ -46,7 +45,7 @@ def on_enter( ) self._initial_delta = ( - self._target - observations["joint"].cpu().numpy() + self._rest_state - observations["joint"].cpu().numpy() ) return ret @@ -61,8 +60,7 @@ def _is_skill_done( return ( torch.as_tensor( - np.abs(current_joint_pos - self._target).max(-1), - device=rnn_hidden_states.device, + np.abs(current_joint_pos - self._rest_state).max(-1), dtype=torch.float32, ) < 5e-2 @@ -78,19 +76,23 @@ def _internal_act( deterministic=False, ): current_joint_pos = observations["joint"].cpu().numpy() - delta = self._target - current_joint_pos + delta = self._rest_state - current_joint_pos # Dividing by max initial delta means that the action will # always 
in [-1,1] and has the benefit of reducing the delta # amount was we converge to the target. delta = delta / np.maximum( - self._initial_delta.max(-1, keepdims=True), 1e-5 + self._initial_delta[cur_batch_idx].max(-1, keepdims=True), 1e-5 ) action = torch.zeros_like(prev_actions) + # There is an extra grab action that we don't want to set. + action[ + ..., self._arm_ac_range[0] : self._arm_ac_range[1] + ] = torch.from_numpy(delta).to( + device=action.device, dtype=action.dtype + ) - action[..., self._ac_start : self._ac_start + 7] = torch.from_numpy( - delta - ).to(device=action.device, dtype=action.dtype) - - return action, rnn_hidden_states + return PolicyActionData( + actions=action, rnn_hidden_states=rnn_hidden_states + ) diff --git a/habitat-baselines/habitat_baselines/rl/hrl/skills/skill.py b/habitat-baselines/habitat_baselines/rl/hrl/skills/skill.py index 7b857288c5..e373f83378 100644 --- a/habitat-baselines/habitat_baselines/rl/hrl/skills/skill.py +++ b/habitat-baselines/habitat_baselines/rl/hrl/skills/skill.py @@ -2,15 +2,17 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, List, Optional, Tuple +import logging +from typing import Any, Dict, List, Optional, Tuple import gym.spaces as spaces import torch +from habitat.core.simulator import Observations from habitat.tasks.rearrange.rearrange_sensors import IsHoldingSensor from habitat_baselines.common.logging import baselines_logger from habitat_baselines.rl.hrl.utils import find_action_range -from habitat_baselines.rl.ppo.policy import Policy +from habitat_baselines.rl.ppo.policy import Policy, PolicyActionData from habitat_baselines.utils.common import get_num_actions @@ -27,6 +29,9 @@ def __init__( """ self._config = config self._batch_size = batch_size + self._apply_postconds = self._config.apply_postconds + self._force_end_on_timeout = self._config.force_end_on_timeout + self._max_skill_steps = self._config.max_skill_steps self._cur_skill_step = torch.zeros(self._batch_size) self._should_keep_hold_state = should_keep_hold_state @@ -37,6 +42,22 @@ def __init__( self._raw_skill_args: List[Optional[str]] = [ None for _ in range(self._batch_size) ] + self._full_ac_size = get_num_actions(action_space) + + # TODO: for some reason this doesnt work with "pddl_apply_action" in action_space + # and needs to go through the keys argument + if "pddl_apply_action" in list(action_space.keys()): + self._pddl_ac_start, _ = find_action_range( + action_space, "pddl_apply_action" + ) + else: + self._pddl_ac_start = None + if self._apply_postconds and self._pddl_ac_start is None: + raise ValueError(f"Could not find PDDL action in skill {self}") + + self._delay_term: List[Optional[bool]] = [ + None for _ in range(self._batch_size) + ] self._grip_ac_idx = 0 found_grip = False @@ -54,7 +75,7 @@ def __init__( action_space, "rearrange_stop" ) - def _internal_log(self, s, observations=None): + def _internal_log(self, s): baselines_logger.debug( f"Skill {self._config.skill_name} @ step {self._cur_skill_step}: {s}" ) @@ -67,9 +88,17 @@ def _get_multi_sensor_index(self, batch_idx: List[int]) -> List[int]: """ return [self._cur_skill_args[i] for i in batch_idx] + @property + def has_hidden_state(self): + """ + Returns if the skill requires a hidden state. 
+ """ + + return False + def _keep_holding_state( - self, full_action: torch.Tensor, observations - ) -> torch.Tensor: + self, action_data: PolicyActionData, observations + ) -> PolicyActionData: """ Makes the action so it does not result in dropping or picking up an object. Used in navigation and other skills which are not supposed to @@ -79,48 +108,124 @@ def _keep_holding_state( is_holding = observations[IsHoldingSensor.cls_uuid].view(-1) # If it is not holding (0) want to keep releasing -> output -1. # If it is holding (1) want to keep grasping -> output +1. - full_action[:, self._grip_ac_idx] = is_holding + (is_holding - 1.0) - return full_action + action_data.write_action( + self._grip_ac_idx, is_holding + (is_holding - 1.0) + ) + return action_data + + def _apply_postcond( + self, + actions, + log_info, + skill_name, + env_i, + idx, + ): + """ + Modifies the actions according to the postconditions set in self._pddl_problem.actions[skill_name] + """ + skill_args = self._raw_skill_args[env_i] + action = self._pddl_problem.actions[skill_name] + + entities = [self._pddl_problem.get_entity(x) for x in skill_args] + assert ( + self._pddl_ac_start is not None + ), "Apply post cond not supported when pddl action not in action space" + + ac_idx = self._pddl_ac_start + found = False + for other_action in self._action_ordering: + if other_action.name != action.name: + ac_idx += other_action.n_args + else: + found = True + break + if not found: + raise ValueError(f"Could not find action {action}") + + entity_idxs = [ + self._entities_list.index(entity) + 1 for entity in entities + ] + if len(entity_idxs) != action.n_args: + raise ValueError( + f"The skill was called with the wrong # of args {action.n_args} versus {entity_idxs} for {action} with {skill_args} and {entities}. Make sure the skill and PDDL definition match." + ) + + actions[idx, ac_idx : ac_idx + action.n_args] = torch.tensor( + entity_idxs, dtype=actions.dtype, device=actions.device + ) + apply_action = action.clone() + apply_action.set_param_values(entities) + + log_info[env_i]["pddl_action"] = apply_action.compact_str + return actions[idx] def should_terminate( self, - observations, - rnn_hidden_states, - prev_actions, - masks, - batch_idx, - ) -> Tuple[torch.BoolTensor, torch.BoolTensor]: + observations: Observations, + rnn_hidden_states: torch.Tensor, + prev_actions: torch.Tensor, + masks: torch.Tensor, + actions: torch.Tensor, + hl_wants_skill_term: torch.BoolTensor, + batch_idx: List[int], + skill_name: List[str], + log_info: List[Dict[str, Any]], + ) -> Tuple[torch.BoolTensor, torch.BoolTensor, torch.Tensor]: """ - :returns: A (batch_size,) size tensor where 1 indicates the skill wants to end and 0 if not. + :returns: Both of the BoolTensor's will be on the CPU. + - `is_skill_done`: Shape (batch_size,) size tensor where 1 + indicates the skill to return control to HL policy. + - `bad_terminate`: Shape (batch_size,) size tensor where 1 + indicates the skill should immediately end the episode. 
""" is_skill_done = self._is_skill_done( observations, rnn_hidden_states, prev_actions, masks, batch_idx - ) + ).cpu() if is_skill_done.sum() > 0: self._internal_log( f"Requested skill termination {is_skill_done}", - observations, ) + cur_skill_step = self._cur_skill_step[batch_idx] + bad_terminate = torch.zeros( - self._cur_skill_step.shape, - device=self._cur_skill_step.device, + cur_skill_step.shape, + device=cur_skill_step.device, dtype=torch.bool, ) - if self._config.max_skill_steps > 0: - over_max_len = self._cur_skill_step > self._config.max_skill_steps - if self._config.force_end_on_timeout: + if self._max_skill_steps > 0: + over_max_len = cur_skill_step >= self._max_skill_steps + if self._force_end_on_timeout: bad_terminate = over_max_len else: is_skill_done = is_skill_done | over_max_len + is_skill_done |= hl_wants_skill_term + + new_actions = torch.zeros_like(actions) + for i, env_i in enumerate(batch_idx): + if self._delay_term[env_i]: + self._internal_log( + "Terminating skill due to delayed termination." + ) + self._delay_term[env_i] = False + is_skill_done[i] = True + elif self._apply_postconds and is_skill_done[i]: + new_actions[i] = self._apply_postcond( + actions, log_info, skill_name[i], env_i, i + ) + self._delay_term[env_i] = True + is_skill_done[i] = False + self._internal_log( + "Applying PDDL action and terminating on the next step." + ) + if bad_terminate.sum() > 0: self._internal_log( - f"Bad terminating due to timeout {self._cur_skill_step}, {bad_terminate}", - observations, + f"Bad terminating due to timeout {cur_skill_step}, {bad_terminate}", ) - - return is_skill_done, bad_terminate + return is_skill_done, bad_terminate, new_actions def on_enter( self, @@ -138,6 +243,10 @@ def on_enter( self._cur_skill_step[batch_idxs] = 0 for i, batch_idx in enumerate(batch_idxs): self._raw_skill_args[batch_idx] = skill_arg[i] + if baselines_logger.level >= logging.DEBUG: + baselines_logger.debug( + f"Entering skill {self} with arguments {skill_arg[i]}" + ) self._cur_skill_args[batch_idx] = self._parse_skill_arg( skill_arg[i] ) @@ -147,6 +256,11 @@ def on_enter( prev_actions[batch_idxs] * 0.0, ) + def set_pddl_problem(self, pddl_prob): + self._pddl_problem = pddl_prob + self._entities_list = self._pddl_problem.get_ordered_entities_list() + self._action_ordering = self._pddl_problem.get_ordered_actions() + @classmethod def from_config( cls, config, observation_space, action_space, batch_size, full_config @@ -166,7 +280,7 @@ def act( :returns: Predicted action and next rnn hidden state. 
""" self._cur_skill_step[cur_batch_idx] += 1 - action, hxs = self._internal_act( + action_data = self._internal_act( observations, rnn_hidden_states, prev_actions, @@ -176,11 +290,11 @@ def act( ) if self._should_keep_hold_state: - action = self._keep_holding_state(action, observations) - return action, hxs + action_data = self._keep_holding_state(action_data, observations) + return action_data def to(self, device): - self._cur_skill_step = self._cur_skill_step.to(device) + pass def _select_obs(self, obs, cur_batch_idx): """ @@ -196,9 +310,7 @@ def _select_obs(self, obs, cur_batch_idx): ) entity_positions = obs[k].view( - len(cur_batch_idx), - -1, - self._config.get("obs_skill_input_dim", 3), + len(cur_batch_idx), -1, self._config.obs_skill_input_dim ) obs[k] = entity_positions[ torch.arange(len(cur_batch_idx)), cur_multi_sensor_index @@ -209,7 +321,9 @@ def _is_skill_done( self, observations, rnn_hidden_states, prev_actions, masks, batch_idx ) -> torch.BoolTensor: """ - :returns: A (batch_size,) size tensor where 1 indicates the skill wants to end and 0 if not. + :returns: A (batch_size,) size tensor where 1 indicates the skill wants + to end and 0 if not where batch_size is potentially a subset of the + overall num_environments as specified by `batch_idx`. """ return torch.zeros(observations.shape[0], dtype=torch.bool).to( masks.device @@ -229,5 +343,5 @@ def _internal_act( masks, cur_batch_idx, deterministic=False, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> PolicyActionData: raise NotImplementedError() diff --git a/habitat-baselines/habitat_baselines/rl/hrl/skills/wait.py b/habitat-baselines/habitat_baselines/rl/hrl/skills/wait.py index 6fb9a1b552..6a0887ba96 100644 --- a/habitat-baselines/habitat_baselines/rl/hrl/skills/wait.py +++ b/habitat-baselines/habitat_baselines/rl/hrl/skills/wait.py @@ -8,6 +8,7 @@ import torch from habitat_baselines.rl.hrl.skills.skill import SkillPolicy +from habitat_baselines.rl.ppo.policy import PolicyActionData class WaitSkillPolicy(SkillPolicy): @@ -28,7 +29,7 @@ def _is_skill_done( self, observations, rnn_hidden_states, prev_actions, masks, batch_idx ) -> torch.BoolTensor: assert self._wait_time > 0 - return self._cur_skill_step >= self._wait_time + return (self._cur_skill_step >= self._wait_time)[batch_idx] def _internal_act( self, @@ -39,5 +40,9 @@ def _internal_act( cur_batch_idx, deterministic=False, ): - action = torch.zeros(prev_actions.shape, device=prev_actions.device) - return action, rnn_hidden_states + action = torch.zeros( + (masks.shape[0], self._full_ac_size), device=prev_actions.device + ) + return PolicyActionData( + actions=action, rnn_hidden_states=rnn_hidden_states + ) diff --git a/habitat-baselines/habitat_baselines/rl/hrl/utils.py b/habitat-baselines/habitat_baselines/rl/hrl/utils.py index 7cebe7db67..2c1b64bfe2 100644 --- a/habitat-baselines/habitat_baselines/rl/hrl/utils.py +++ b/habitat-baselines/habitat_baselines/rl/hrl/utils.py @@ -12,7 +12,8 @@ def find_action_range( action_space: ActionSpace, search_key: str ) -> Tuple[int, int]: """ - Returns the start and end indices of an action key in the action tensor. + Returns the start and end indices of an action key in the action tensor. If + the key is not found, a Value error will be thrown. 
""" start_idx = 0 diff --git a/habitat-baselines/habitat_baselines/rl/ppo/policy.py b/habitat-baselines/habitat_baselines/rl/ppo/policy.py index 71a2dc170b..fb83a36478 100644 --- a/habitat-baselines/habitat_baselines/rl/ppo/policy.py +++ b/habitat-baselines/habitat_baselines/rl/ppo/policy.py @@ -4,7 +4,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import abc -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Union +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union import torch from gym import spaces @@ -30,6 +31,58 @@ from omegaconf import DictConfig +@dataclass +class PolicyActionData: + """ + Information returned from the `Policy.act` method representing the + information from an agent's action. + + :property should_inserts: Of shape [# envs, 1]. If False at environment + index `i`, then don't write this transition to the rollout buffer. If + `None`, then write all data. + :property policy_info`: Optional logging information about the policy per + environment. For example, you could log the policy entropy. + :property take_actions`: If specified, these actions will be executed in + the environment, but not stored in the storage buffer. This allows + exectuing and learning from different actions. If not specified, the + agent will execute `self.actions`. + :property values: The actor value predictions. None if the actor does not predict value. + :property actions: The actions to store in the storage buffer. if + `take_actions` is None, then this is also the action executed in the + environment. + :property rnn_hidden_states: Actor hidden states. + :property action_log_probs: The log probabilities of the actions under the + current policy. + """ + + rnn_hidden_states: torch.Tensor + actions: Optional[torch.Tensor] = None + values: Optional[torch.Tensor] = None + action_log_probs: Optional[torch.Tensor] = None + take_actions: Optional[torch.Tensor] = None + policy_info: Optional[List[Dict[str, Any]]] = None + should_inserts: Optional[torch.BoolTensor] = None + + def write_action(self, write_idx: int, write_action: torch.Tensor) -> None: + """ + Used to override an action across all environments. + :param write_idx: The index in the action dimension to write the new action. + :param write_action: The action to write at `write_idx`. + """ + self.actions[:, write_idx] = write_action + + @property + def env_actions(self) -> torch.Tensor: + """ + The actions to execute in the environment. 
+ """ + + if self.take_actions is None: + return self.actions + else: + return self.take_actions + + class Policy(abc.ABC): action_distribution: nn.Module @@ -47,7 +100,29 @@ def num_recurrent_layers(self) -> int: def forward(self, *x): raise NotImplementedError - def get_policy_info(self, infos, dones) -> List[Dict[str, float]]: + def get_policy_action_space( + self, env_action_space: spaces.Space + ) -> spaces.Space: + return env_action_space + + def _get_policy_components(self) -> List[nn.Module]: + return [] + + def aux_loss_parameters(self) -> Dict[str, Iterable[torch.Tensor]]: + return {} + + def policy_parameters(self) -> Iterable[torch.Tensor]: + for c in self._get_policy_components(): + yield from c.parameters() + + def all_policy_tensors(self) -> Iterable[torch.Tensor]: + yield from self.policy_parameters() + for c in self._get_policy_components(): + yield from c.buffers() + + def extract_policy_info( + self, action_data: PolicyActionData, infos, dones + ) -> List[Dict[str, float]]: """ Gets the log information from the policy at the current time step. Currently only called during evaluation. The return list should be @@ -64,7 +139,7 @@ def act( prev_actions, masks, deterministic=False, - ): + ) -> PolicyActionData: raise NotImplementedError @classmethod @@ -155,8 +230,12 @@ def act( action = distribution.sample() action_log_probs = distribution.log_probs(action) - - return value, action, action_log_probs, rnn_hidden_states + return PolicyActionData( + values=value, + actions=action, + action_log_probs=action_log_probs, + rnn_hidden_states=rnn_hidden_states, + ) def get_value(self, observations, rnn_hidden_states, prev_actions, masks): features, _, _ = self.net( @@ -207,18 +286,8 @@ def evaluate_actions( aux_loss_res, ) - @property - def policy_components(self): - return (self.net, self.critic, self.action_distribution) - - def policy_parameters(self) -> Iterable[torch.Tensor]: - for c in self.policy_components: - yield from c.parameters() - - def all_policy_tensors(self) -> Iterable[torch.Tensor]: - yield from self.policy_parameters() - for c in self.policy_components: - yield from c.buffers() + def _get_policy_components(self) -> List[nn.Module]: + return [self.net, self.critic, self.action_distribution] def aux_loss_parameters(self) -> Dict[str, Iterable[torch.Tensor]]: return {k: v.parameters() for k, v in self.aux_loss_modules.items()} diff --git a/habitat-baselines/habitat_baselines/rl/ppo/ppo.py b/habitat-baselines/habitat_baselines/rl/ppo/ppo.py index 7b2ca19e2b..7affdc63ba 100644 --- a/habitat-baselines/habitat_baselines/rl/ppo/ppo.py +++ b/habitat-baselines/habitat_baselines/rl/ppo/ppo.py @@ -6,7 +6,7 @@ import collections import inspect -from typing import Dict, Optional, Union +from typing import Any, Dict, List, Optional, Union import torch import torch.nn as nn @@ -15,6 +15,7 @@ from torch import Tensor from habitat.utils import profiling_wrapper +from habitat_baselines.common.baseline_registry import baseline_registry from habitat_baselines.common.rollout_storage import RolloutStorage from habitat_baselines.rl.ppo.policy import NetPolicy from habitat_baselines.rl.ver.ver_rollout_storage import VERRolloutStorage @@ -26,6 +27,7 @@ EPS_PPO = 1e-5 +@baseline_registry.register_updater class PPO(nn.Module): entropy_coef: Union[float, LagrangeInequalityCoefficient] @@ -154,13 +156,10 @@ def _set_grads_to_none(self): for p in pg["params"]: p.grad = None - def update( - self, - rollouts: RolloutStorage, - ) -> Dict[str, float]: - advantages = self.get_advantages(rollouts) - 
- learner_metrics = collections.defaultdict(list) + def _update_from_batch(self, batch, epoch, rollouts, learner_metrics): + """ + Performs a gradient update from the minibatch. + """ def record_min_mean_max(t: torch.Tensor, prefix: str): for name, op in ( @@ -170,145 +169,147 @@ def record_min_mean_max(t: torch.Tensor, prefix: str): ): learner_metrics[f"{prefix}_{name}"].append(op(t)) - for epoch in range(self.ppo_epoch): - profiling_wrapper.range_push("PPO.update epoch") - data_generator = rollouts.recurrent_generator( - advantages, self.num_mini_batch + self._set_grads_to_none() + + ( + values, + action_log_probs, + dist_entropy, + _, + aux_loss_res, + ) = self._evaluate_actions( + batch["observations"], + batch["recurrent_hidden_states"], + batch["prev_actions"], + batch["masks"], + batch["actions"], + batch["rnn_build_seq_info"], + ) + + ratio = torch.exp(action_log_probs - batch["action_log_probs"]) + + surr1 = batch["advantages"] * ratio + surr2 = batch["advantages"] * ( + torch.clamp( + ratio, + 1.0 - self.clip_param, + 1.0 + self.clip_param, ) + ) + action_loss = -torch.min(surr1, surr2) - for _bid, batch in enumerate(data_generator): - self._set_grads_to_none() - - ( - values, - action_log_probs, - dist_entropy, - _, - aux_loss_res, - ) = self._evaluate_actions( - batch["observations"], - batch["recurrent_hidden_states"], - batch["prev_actions"], - batch["masks"], - batch["actions"], - batch["rnn_build_seq_info"], - ) + values = values.float() + orig_values = values - ratio = torch.exp(action_log_probs - batch["action_log_probs"]) + if self.use_clipped_value_loss: + delta = values.detach() - batch["value_preds"] + value_pred_clipped = batch["value_preds"] + delta.clamp( + -self.clip_param, self.clip_param + ) - surr1 = batch["advantages"] * ratio - surr2 = batch["advantages"] * ( - torch.clamp( - ratio, - 1.0 - self.clip_param, - 1.0 + self.clip_param, - ) - ) - action_loss = -torch.min(surr1, surr2) + values = torch.where( + delta.abs() < self.clip_param, + values, + value_pred_clipped, + ) - values = values.float() - orig_values = values + value_loss = 0.5 * F.mse_loss( + values, batch["returns"], reduction="none" + ) - if self.use_clipped_value_loss: - delta = values.detach() - batch["value_preds"] - value_pred_clipped = batch["value_preds"] + delta.clamp( - -self.clip_param, self.clip_param - ) + if "is_coeffs" in batch: + assert isinstance(batch["is_coeffs"], torch.Tensor) + ver_is_coeffs = batch["is_coeffs"].clamp(max=1.0) + mean_fn = lambda t: torch.mean(ver_is_coeffs * t) + else: + mean_fn = torch.mean - values = torch.where( - delta.abs() < self.clip_param, - values, - value_pred_clipped, - ) + action_loss, value_loss, dist_entropy = map( + mean_fn, + (action_loss, value_loss, dist_entropy), + ) - value_loss = 0.5 * F.mse_loss( - values, batch["returns"], reduction="none" - ) + all_losses = [ + self.value_loss_coef * value_loss, + action_loss, + ] - if "is_coeffs" in batch: - assert isinstance(batch["is_coeffs"], torch.Tensor) - ver_is_coeffs = batch["is_coeffs"].clamp(max=1.0) - mean_fn = lambda t: torch.mean(ver_is_coeffs * t) - else: - mean_fn = torch.mean + if isinstance(self.entropy_coef, float): + all_losses.append(-self.entropy_coef * dist_entropy) + else: + all_losses.append(self.entropy_coef.lagrangian_loss(dist_entropy)) - action_loss, value_loss, dist_entropy = map( - mean_fn, - (action_loss, value_loss, dist_entropy), - ) + all_losses.extend(v["loss"] for v in aux_loss_res.values()) - all_losses = [ - self.value_loss_coef * value_loss, - action_loss, - ] + 
total_loss = torch.stack(all_losses).sum() - if isinstance(self.entropy_coef, float): - all_losses.append(-self.entropy_coef * dist_entropy) - else: - all_losses.append( - self.entropy_coef.lagrangian_loss(dist_entropy) - ) + total_loss = self.before_backward(total_loss) + total_loss.backward() + self.after_backward(total_loss) - all_losses.extend(v["loss"] for v in aux_loss_res.values()) + grad_norm = self.before_step() + self.optimizer.step() + self.after_step() - total_loss = torch.stack(all_losses).sum() + with inference_mode(): + if "is_coeffs" in batch: + record_min_mean_max(batch["is_coeffs"], "ver_is_coeffs") + record_min_mean_max(orig_values, "value_pred") + record_min_mean_max(ratio, "prob_ratio") + + learner_metrics["value_loss"].append(value_loss) + learner_metrics["action_loss"].append(action_loss) + learner_metrics["dist_entropy"].append(dist_entropy) + if epoch == (self.ppo_epoch - 1): + learner_metrics["ppo_fraction_clipped"].append( + (ratio > (1.0 + self.clip_param)).float().mean() + + (ratio < (1.0 - self.clip_param)).float().mean() + ) - total_loss = self.before_backward(total_loss) - total_loss.backward() - self.after_backward(total_loss) + learner_metrics["grad_norm"].append(grad_norm) + if isinstance(self.entropy_coef, LagrangeInequalityCoefficient): + learner_metrics["entropy_coef"].append( + self.entropy_coef().detach() + ) - grad_norm = self.before_step() - self.optimizer.step() - self.after_step() + for name, res in aux_loss_res.items(): + for k, v in res.items(): + learner_metrics[f"aux_{name}_{k}"].append(v.detach()) - with inference_mode(): - if "is_coeffs" in batch: - record_min_mean_max( - batch["is_coeffs"], "ver_is_coeffs" - ) - record_min_mean_max(orig_values, "value_pred") - record_min_mean_max(ratio, "prob_ratio") - - learner_metrics["value_loss"].append(value_loss) - learner_metrics["action_loss"].append(action_loss) - learner_metrics["dist_entropy"].append(dist_entropy) - if epoch == (self.ppo_epoch - 1): - learner_metrics["ppo_fraction_clipped"].append( - (ratio > (1.0 + self.clip_param)).float().mean() - + (ratio < (1.0 - self.clip_param)).float().mean() - ) + if "is_stale" in batch: + assert isinstance(batch["is_stale"], torch.Tensor) + learner_metrics["fraction_stale"].append( + batch["is_stale"].float().mean() + ) - learner_metrics["grad_norm"].append(grad_norm) - if isinstance( - self.entropy_coef, LagrangeInequalityCoefficient - ): - learner_metrics["entropy_coef"].append( - self.entropy_coef().detach() - ) + if isinstance(rollouts, VERRolloutStorage): + assert isinstance(batch["policy_version"], torch.Tensor) + record_min_mean_max( + ( + rollouts.current_policy_version + - batch["policy_version"] + ).float(), + "policy_version_difference", + ) + + def update( + self, + rollouts: RolloutStorage, + ) -> Dict[str, float]: + advantages = self.get_advantages(rollouts) - for name, res in aux_loss_res.items(): - for k, v in res.items(): - learner_metrics[f"aux_{name}_{k}"].append( - v.detach() - ) + learner_metrics: Dict[str, List[Any]] = collections.defaultdict(list) - if "is_stale" in batch: - assert isinstance(batch["is_stale"], torch.Tensor) - learner_metrics["fraction_stale"].append( - batch["is_stale"].float().mean() - ) + for epoch in range(self.ppo_epoch): + profiling_wrapper.range_push("PPO.update epoch") + data_generator = rollouts.recurrent_generator( + advantages, self.num_mini_batch + ) - if isinstance(rollouts, VERRolloutStorage): - assert isinstance( - batch["policy_version"], torch.Tensor - ) - record_min_mean_max( - ( - 
rollouts.current_policy_version - - batch["policy_version"] - ).float(), - "policy_version_difference", - ) + for _bid, batch in enumerate(data_generator): + self._update_from_batch( + batch, epoch, rollouts, learner_metrics + ) profiling_wrapper.range_pop() # PPO.update epoch diff --git a/habitat-baselines/habitat_baselines/rl/ppo/ppo_trainer.py b/habitat-baselines/habitat_baselines/rl/ppo/ppo_trainer.py index 60c1fff608..dcc4289cb6 100644 --- a/habitat-baselines/habitat_baselines/rl/ppo/ppo_trainer.py +++ b/habitat-baselines/habitat_baselines/rl/ppo/ppo_trainer.py @@ -35,12 +35,14 @@ apply_obs_transforms_obs_space, get_active_obs_transforms, ) -from habitat_baselines.common.rollout_storage import RolloutStorage +from habitat_baselines.common.rollout_storage import ( # noqa: F401. + RolloutStorage, +) from habitat_baselines.common.tensorboard_utils import ( TensorboardWriter, get_writer, ) -from habitat_baselines.rl.ddppo.algo import DDPPO +from habitat_baselines.rl.ddppo.algo import DDPPO # noqa: F401. from habitat_baselines.rl.ddppo.ddp_utils import ( EXIT, get_distrib_size, @@ -58,12 +60,12 @@ from habitat_baselines.rl.hrl.hierarchical_policy import ( # noqa: F401. HierarchicalPolicy, ) -from habitat_baselines.rl.ppo import PPO +from habitat_baselines.rl.ppo import PPO # noqa: F401. from habitat_baselines.rl.ppo.policy import NetPolicy from habitat_baselines.utils.common import ( batch_obs, generate_video, - get_num_actions, + get_action_space_info, inference_mode, is_continuous_action_space, ) @@ -82,6 +84,11 @@ class PPOTrainer(BaseRLTrainer): r"""Trainer class for PPO algorithm Paper: https://arxiv.org/abs/1707.06347. + + :property env_action_space: The action space required by the environment. + :property policy_action_space: The action space the policy acts in. This + can be different from the environment action space for hierarchical + policies. 
""" supported_tasks = ["Nav-v0"] @@ -152,8 +159,8 @@ def _setup_actor_critic_agent(self, ppo_cfg: "DictConfig") -> None: self.actor_critic = policy.from_config( self.config, observation_space, - self.policy_action_space, - orig_action_space=self.orig_policy_action_space, + self.env_action_space, + orig_action_space=self.orig_env_action_space, ) self.obs_space = observation_space self.actor_critic.to(self.device) @@ -193,8 +200,18 @@ def _setup_actor_critic_agent(self, ppo_cfg: "DictConfig") -> None: nn.init.orthogonal_(self.actor_critic.critic.fc.weight) nn.init.constant_(self.actor_critic.critic.fc.bias, 0) - self.agent = (DDPPO if self._is_distributed else PPO).from_config( - self.actor_critic, ppo_cfg + if self._is_distributed: + agent_cls = baseline_registry.get_updater( + self.config.habitat_baselines.distrib_updater_name + ) + else: + agent_cls = baseline_registry.get_updater( + self.config.habitat_baselines.updater_name + ) + + self.agent = agent_cls.from_config(self.actor_critic, ppo_cfg) + self.policy_action_space = self.actor_critic.get_policy_action_space( + self.env_action_space ) def _init_envs(self, config=None, is_eval: bool = False): @@ -206,6 +223,8 @@ def _init_envs(self, config=None, is_eval: bool = False): workers_ignore_signals=is_slurm_batch_job(), enforce_scenes_greater_eq_environments=is_eval, ) + self.env_action_space = self.envs.action_spaces[0] + self.orig_env_action_space = self.envs.orig_action_spaces[0] def _init_train(self, resume_state=None): if resume_state is None: @@ -282,18 +301,6 @@ def _init_train(self, resume_state=None): self._init_envs() - action_space = self.envs.action_spaces[0] - self.policy_action_space = action_space - self.orig_policy_action_space = self.envs.orig_action_spaces[0] - if is_continuous_action_space(action_space): - # Assume ALL actions are NOT discrete - action_shape = (get_num_actions(action_space),) - discrete_actions = False - else: - # For discrete pointnav - action_shape = (1,) - discrete_actions = True - ppo_cfg = self.config.habitat_baselines.rl.ppo if torch.cuda.is_available(): self.device = torch.device( @@ -338,16 +345,17 @@ def _init_train(self, resume_state=None): self._nbuffers = 2 if ppo_cfg.use_double_buffered_sampler else 1 - self.rollouts = RolloutStorage( + rollouts_cls = baseline_registry.get_storage( + self.config.habitat_baselines.rollout_storage_name + ) + self.rollouts = rollouts_cls( ppo_cfg.num_steps, self.envs.num_envs, obs_space, self.policy_action_space, ppo_cfg.hidden_size, - num_recurrent_layers=self.actor_critic.net.num_recurrent_layers, + num_recurrent_layers=self.actor_critic.num_recurrent_layers, is_double_buffered=ppo_cfg.use_double_buffered_sampler, - action_shape=action_shape, - discrete_actions=discrete_actions, ) self.rollouts.to(self.device) @@ -439,12 +447,7 @@ def _compute_actions_and_step_envs(self, buffer_index: int = 0): ] profiling_wrapper.range_push("compute actions") - ( - values, - actions, - actions_log_probs, - recurrent_hidden_states, - ) = self.actor_critic.act( + action_data = self.actor_critic.act( step_batch["observations"], step_batch["recurrent_hidden_states"], step_batch["prev_actions"], @@ -458,14 +461,15 @@ def _compute_actions_and_step_envs(self, buffer_index: int = 0): t_step_env = time.time() for index_env, act in zip( - range(env_slice.start, env_slice.stop), actions.cpu().unbind(0) + range(env_slice.start, env_slice.stop), + action_data.env_actions.cpu().unbind(0), ): - if is_continuous_action_space(self.policy_action_space): + if 
is_continuous_action_space(self.env_action_space): # Clipping actions to the specified limits act = np.clip( act.numpy(), - self.policy_action_space.low, - self.policy_action_space.high, + self.env_action_space.low, + self.env_action_space.high, ) else: act = act.item() @@ -474,11 +478,12 @@ def _compute_actions_and_step_envs(self, buffer_index: int = 0): self.env_time += time.time() - t_step_env self.rollouts.insert( - next_recurrent_hidden_states=recurrent_hidden_states, - actions=actions, - action_log_probs=actions_log_probs, - value_preds=values, + next_recurrent_hidden_states=action_data.rnn_hidden_states, + actions=action_data.actions, + action_log_probs=action_data.action_log_probs, + value_preds=action_data.values, buffer_index=buffer_index, + should_inserts=action_data.should_inserts, ) def _collect_environment_result(self, buffer_index: int = 0): @@ -932,19 +937,10 @@ def _eval_checkpoint( self._init_envs(config, is_eval=True) - action_space = self.envs.action_spaces[0] - self.policy_action_space = action_space - self.orig_policy_action_space = self.envs.orig_action_spaces[0] - if is_continuous_action_space(action_space): - # Assume NONE of the actions are discrete - action_shape = (get_num_actions(action_space),) - discrete_actions = False - else: - # For discrete pointnav - action_shape = (1,) - discrete_actions = True - self._setup_actor_critic_agent(ppo_cfg) + action_shape, discrete_actions = get_action_space_info( + self.policy_action_space + ) if self.agent.actor_critic.should_load_agent_state: self.agent.load_state_dict(ckpt_dict["state_dict"]) @@ -1018,44 +1014,53 @@ def _eval_checkpoint( current_episodes_info = self.envs.current_episodes() with inference_mode(): - ( - _, - actions, - _, - test_recurrent_hidden_states, - ) = self.actor_critic.act( + action_data = self.actor_critic.act( batch, test_recurrent_hidden_states, prev_actions, not_done_masks, deterministic=False, ) - - prev_actions.copy_(actions) # type: ignore + if action_data.should_inserts is None: + test_recurrent_hidden_states = ( + action_data.rnn_hidden_states + ) + prev_actions.copy_(action_data.actions) # type: ignore + else: + for i, should_insert in enumerate( + action_data.should_inserts + ): + if should_insert.item(): + test_recurrent_hidden_states[ + i + ] = action_data.rnn_hidden_states[i] + prev_actions[i].copy_(action_data.actions[i]) # type: ignore # NB: Move actions to CPU. If CUDA tensors are # sent in to env.step(), that will create CUDA contexts # in the subprocesses. 
- if is_continuous_action_space(self.policy_action_space): + if is_continuous_action_space(self.env_action_space): # Clipping actions to the specified limits step_data = [ np.clip( a.numpy(), - self.policy_action_space.low, - self.policy_action_space.high, + self.env_action_space.low, + self.env_action_space.high, ) - for a in actions.cpu() + for a in action_data.env_actions.cpu() ] else: - step_data = [a.item() for a in actions.cpu()] + step_data = [a.item() for a in action_data.env_actions.cpu()] outputs = self.envs.step(step_data) observations, rewards_l, dones, infos = [ list(x) for x in zip(*outputs) ] - policy_info = self.actor_critic.get_policy_info(infos, dones) - for i in range(len(policy_info)): - infos[i].update(policy_info[i]) + policy_infos = self.actor_critic.extract_policy_info( + action_data, infos, dones + ) + for i in range(len(policy_infos)): + infos[i].update(policy_infos[i]) batch = batch_obs( # type: ignore observations, device=self.device, @@ -1098,9 +1103,7 @@ def _eval_checkpoint( frame = observations_to_image( {k: v[i] * 0.0 for k, v in batch.items()}, infos[i] ) - frame = overlay_frame( - frame, extract_scalars_from_info(infos[i]) - ) + frame = overlay_frame(frame, infos[i]) rgb_frames[i].append(frame) # episode ended diff --git a/habitat-baselines/habitat_baselines/rl/ver/inference_worker.py b/habitat-baselines/habitat_baselines/rl/ver/inference_worker.py index bd674a00a9..d8646af772 100644 --- a/habitat-baselines/habitat_baselines/rl/ver/inference_worker.py +++ b/habitat-baselines/habitat_baselines/rl/ver/inference_worker.py @@ -318,12 +318,7 @@ def step(self) -> Tuple[bool, List[Tuple[int, int]]]: PointNavResNetNet.PRETRAINED_VISUAL_FEATURES_KEY ] = self.visual_encoder(obs) - ( - values, - actions, - actions_log_probs, - next_recurrent_hidden_states, - ) = self.actor_critic.act( + action_data = self.actor_critic.act( obs, recurrent_hidden_states, prev_actions, @@ -332,10 +327,12 @@ def step(self) -> Tuple[bool, List[Tuple[int, int]]]: if not final_batch: self.rollouts.next_hidden_states.index_copy_( - 0, environment_ids, next_recurrent_hidden_states + 0, + environment_ids, + action_data.rnn_hidden_states, ) self.rollouts.next_prev_actions.index_copy_( - 0, environment_ids, actions + 0, environment_ids, action_data.actions ) if self._variable_experience: @@ -347,7 +344,7 @@ def step(self) -> Tuple[bool, List[Tuple[int, int]]]: dtype=np.int64, ) - cpu_actions = actions.to(device="cpu") + cpu_actions = action_data.env_actions.to(device="cpu") self.transfer_buffers["actions"][ self.new_reqs ] = cpu_actions.numpy() @@ -387,8 +384,8 @@ def step(self) -> Tuple[bool, List[Tuple[int, int]]]: dict( masks=to_batch["masks"], observations=obs, - actions=actions, - action_log_probs=actions_log_probs, + actions=action_data.actions, + action_log_probs=action_data.action_log_probs, recurrent_hidden_states=recurrent_hidden_states, prev_actions=prev_actions, policy_version=self.rollouts.current_policy_version.expand( @@ -397,7 +394,7 @@ def step(self) -> Tuple[bool, List[Tuple[int, int]]]: episode_ids=to_batch["episode_ids"], environment_ids=to_batch["environment_ids"], step_ids=to_batch["step_ids"], - value_preds=values, + value_preds=action_data.values, returns=torch.full( (), float("nan"), diff --git a/habitat-baselines/habitat_baselines/rl/ver/ver_rollout_storage.py b/habitat-baselines/habitat_baselines/rl/ver/ver_rollout_storage.py index 361ee96ad4..d9f6c70af0 100644 --- a/habitat-baselines/habitat_baselines/rl/ver/ver_rollout_storage.py +++ 
b/habitat-baselines/habitat_baselines/rl/ver/ver_rollout_storage.py @@ -144,9 +144,7 @@ def __init__( action_space, recurrent_hidden_state_size, num_recurrent_layers=1, - action_shape: Optional[Tuple[int]] = None, is_double_buffered: bool = False, - discrete_actions: bool = True, ): super().__init__( numsteps, @@ -155,9 +153,7 @@ def __init__( action_space, recurrent_hidden_state_size, num_recurrent_layers, - action_shape, is_double_buffered, - discrete_actions, ) self.use_is_coeffs = variable_experience diff --git a/habitat-baselines/habitat_baselines/rl/ver/ver_trainer.py b/habitat-baselines/habitat_baselines/rl/ver/ver_trainer.py index 52bc4932c7..fc8f75c674 100644 --- a/habitat-baselines/habitat_baselines/rl/ver/ver_trainer.py +++ b/habitat-baselines/habitat_baselines/rl/ver/ver_trainer.py @@ -52,12 +52,7 @@ WorkerBase, WorkerQueues, ) -from habitat_baselines.utils.common import ( - cosine_decay, - get_num_actions, - inference_mode, - is_continuous_action_space, -) +from habitat_baselines.utils.common import cosine_decay, inference_mode try: torch.backends.cudnn.allow_tf32 = True @@ -186,7 +181,8 @@ def _init_train(self, resume_state): action_space = init_reports[0]["act_space"] self.policy_action_space = action_space - self.orig_policy_action_space = None + self.env_action_space = action_space + self.orig_env_action_space = None [ ew.set_action_plugin( @@ -196,14 +192,6 @@ def _init_train(self, resume_state): ) for ew in self.environment_workers ] - if is_continuous_action_space(action_space): - # Assume ALL actions are NOT discrete - action_shape = (get_num_actions(action_space),) - discrete_actions = False - else: - # For discrete pointnav - action_shape = (1,) - discrete_actions = True ppo_cfg = self.config.habitat_baselines.rl.ppo if torch.cuda.is_available(): @@ -253,8 +241,6 @@ def _init_train(self, resume_state): action_space=self.policy_action_space, recurrent_hidden_state_size=ppo_cfg.hidden_size, num_recurrent_layers=self.actor_critic.net.num_recurrent_layers, - action_shape=action_shape, - discrete_actions=discrete_actions, observation_space=rollouts_obs_space, ) self.rollouts = VERRolloutStorage(**storage_kwargs) diff --git a/habitat-baselines/habitat_baselines/utils/common.py b/habitat-baselines/habitat_baselines/utils/common.py index 76e52423d6..2e8570de95 100644 --- a/habitat-baselines/habitat_baselines/utils/common.py +++ b/habitat-baselines/habitat_baselines/utils/common.py @@ -694,6 +694,25 @@ def is_continuous_action_space(action_space) -> bool: ) +def get_action_space_info(ac_space: spaces.Space) -> Tuple[Tuple[int], bool]: + """ + :returns: The shape of the action space and if the action space is discrete. If the action space is discrete, the shape will be `(1,)`. 
+ """ + if is_continuous_action_space(ac_space): + # Assume NONE of the actions are discrete + return ( + ( + get_num_actions( + ac_space, + ), + ), + False, + ) + else: + # For discrete pointnav + return (1,), True + + def get_num_actions(action_space) -> int: num_actions = 0 for v in iterate_action_space_recursively(action_space): diff --git a/habitat-lab/habitat/config/README.md b/habitat-lab/habitat/config/README.md index 3bfbc53a6a..8916c951a2 100644 --- a/habitat-lab/habitat/config/README.md +++ b/habitat-lab/habitat/config/README.md @@ -241,7 +241,6 @@ habitat: constraint_violation_drops_object: false force_regenerate: false should_save_to_cache: true - must_look_at_targ: true object_in_hand_sample_prob: 0.167 render_target: true ee_sample_factor: 0.2 diff --git a/habitat-lab/habitat/config/benchmark/rearrange/rearrange_easy.yaml b/habitat-lab/habitat/config/benchmark/rearrange/rearrange_easy.yaml index a9f5b8b68d..7b5995e431 100644 --- a/habitat-lab/habitat/config/benchmark/rearrange/rearrange_easy.yaml +++ b/habitat-lab/habitat/config/benchmark/rearrange/rearrange_easy.yaml @@ -19,6 +19,7 @@ habitat: - obj_goal_gps_compass - joint - is_holding + - ee_pos environment: max_episode_steps: 1500 simulator: diff --git a/habitat-lab/habitat/config/default_structured_configs.py b/habitat-lab/habitat/config/default_structured_configs.py index dcd70fcaff..c0638a849c 100644 --- a/habitat-lab/habitat/config/default_structured_configs.py +++ b/habitat-lab/habitat/config/default_structured_configs.py @@ -196,6 +196,11 @@ class RearrangeStopActionConfig(ActionConfig): type: str = "RearrangeStopAction" +@attr.s(auto_attribs=True, slots=True) +class PddlApplyActionConfig(ActionConfig): + type: str = "PddlApplyAction" + + @attr.s(auto_attribs=True, slots=True) class OracleNavActionConfig(ActionConfig): """ @@ -214,6 +219,8 @@ class OracleNavActionConfig(ActionConfig): ang_speed: float = 10.0 allow_dyn_slide: bool = True allow_back: bool = True + spawn_max_dist_to_obj: float = 2.0 + num_spawn_attempts: int = 200 # ----------------------------------------------------------------------------- @@ -850,7 +857,6 @@ class TaskConfig(HabitatBaseConfig): force_regenerate: bool = False # Saves the generated starts to a cache if they are not already generated should_save_to_cache: bool = False - must_look_at_targ: bool = True object_in_hand_sample_prob: float = 0.167 min_start_distance: float = 3.0 gfx_replay_dir = "data/replays" @@ -858,7 +864,7 @@ class TaskConfig(HabitatBaseConfig): # Spawn parameters physics_stability_steps: int = 1 num_spawn_attempts: int = 200 - spawn_max_dists_to_obj: float = 2.0 + spawn_max_dist_to_obj: float = 2.0 base_angle_noise: float = 0.523599 # EE sample parameters ee_sample_factor: float = 0.2 @@ -871,9 +877,6 @@ class TaskConfig(HabitatBaseConfig): cache_robot_init: bool = False success_state: float = 0.0 # Measurements for composite tasks. 
- # If true, does not care about navigability or collisions - # with objects when spawning robot - easy_init: bool = False should_enforce_target_within_reach: bool = False # COMPOSITE task CONFIG task_spec_base_path: str = "habitat/task/rearrange/pddl/" @@ -1330,6 +1333,12 @@ class HabitatConfig(HabitatBaseConfig): name="oracle_nav_action", node=OracleNavActionConfig, ) +cs.store( + package="habitat.task.actions.pddl_apply_action", + group="habitat/task/actions", + name="pddl_apply_action", + node=PddlApplyActionConfig, +) # Dataset Config Schema cs.store( @@ -1457,6 +1466,12 @@ class HabitatConfig(HabitatBaseConfig): name="instance_imagegoal_hfov_sensor", node=InstanceImageGoalHFOVSensorConfig, ) +cs.store( + package="habitat.task.lab_sensors.localization_sensor", + group="habitat/task/lab_sensors", + name="localization_sensor", + node=LocalizationSensorConfig, +) cs.store( package="habitat.task.lab_sensors.target_start_sensor", group="habitat/task/lab_sensors", diff --git a/habitat-lab/habitat/config/habitat/simulator/agents/rgbd_head_agent.yaml b/habitat-lab/habitat/config/habitat/simulator/agents/rgbd_head_agent.yaml new file mode 100644 index 0000000000..338bbd2dd3 --- /dev/null +++ b/habitat-lab/habitat/config/habitat/simulator/agents/rgbd_head_agent.yaml @@ -0,0 +1,6 @@ +# @package habitat.simulator.agents.rgb_head_agent + +defaults: + - agent_base + - /habitat/simulator/sim_sensors@sim_sensors.head_rgb_sensor: head_rgb_sensor + - /habitat/simulator/sim_sensors@sim_sensors.head_depth_sensor: head_depth_sensor diff --git a/habitat-lab/habitat/config/habitat/task/rearrange/pick_spa.yaml b/habitat-lab/habitat/config/habitat/task/rearrange/pick_spa.yaml index f720718123..a9879d7c44 100644 --- a/habitat-lab/habitat/config/habitat/task/rearrange/pick_spa.yaml +++ b/habitat-lab/habitat/config/habitat/task/rearrange/pick_spa.yaml @@ -36,10 +36,6 @@ base_angle_noise: 0.15 base_noise: 0.05 force_regenerate: False -# If true, does not care about navigability or collisions with objects when spawning -# robot -easy_init: False - actions: arm_action: arm_controller: "ArmAbsPosKinematicAction" diff --git a/habitat-lab/habitat/config/habitat/task/rearrange/play.yaml b/habitat-lab/habitat/config/habitat/task/rearrange/play.yaml index 639e87ecd1..74d4304263 100644 --- a/habitat-lab/habitat/config/habitat/task/rearrange/play.yaml +++ b/habitat-lab/habitat/config/habitat/task/rearrange/play.yaml @@ -26,9 +26,6 @@ base_angle_noise: 0.0 base_noise: 0.0 constraint_violation_ends_episode: False -# If true, does not care about navigability or collisions -# with objects when spawning robot -easy_init: False force_regenerate: True actions: diff --git a/habitat-lab/habitat/config/habitat/task/rearrange/prepare_groceries.yaml b/habitat-lab/habitat/config/habitat/task/rearrange/prepare_groceries.yaml index d404648720..bb54918c66 100644 --- a/habitat-lab/habitat/config/habitat/task/rearrange/prepare_groceries.yaml +++ b/habitat-lab/habitat/config/habitat/task/rearrange/prepare_groceries.yaml @@ -28,6 +28,7 @@ defaults: - end_effector_sensor - target_start_gps_compass_sensor - target_goal_gps_compass_sensor + - localization_sensor - _self_ type: RearrangeCompositeTask-v0 diff --git a/habitat-lab/habitat/config/habitat/task/rearrange/rearrange.yaml b/habitat-lab/habitat/config/habitat/task/rearrange/rearrange.yaml index 13dcf40da7..9f3b3f3608 100644 --- a/habitat-lab/habitat/config/habitat/task/rearrange/rearrange.yaml +++ b/habitat-lab/habitat/config/habitat/task/rearrange/rearrange.yaml @@ -27,6 +27,7 @@ 
defaults: - relative_resting_pos_sensor - target_start_gps_compass_sensor - target_goal_gps_compass_sensor + - localization_sensor - _self_ type: RearrangeCompositeTask-v0 diff --git a/habitat-lab/habitat/config/habitat/task/rearrange/rearrange_easy.yaml b/habitat-lab/habitat/config/habitat/task/rearrange/rearrange_easy.yaml index 793fbc1847..e7dff55b48 100644 --- a/habitat-lab/habitat/config/habitat/task/rearrange/rearrange_easy.yaml +++ b/habitat-lab/habitat/config/habitat/task/rearrange/rearrange_easy.yaml @@ -19,6 +19,7 @@ defaults: - did_violate_hold_constraint - move_objects_reward - gfx_replay_measure + - composite_stage_goals - /habitat/task/lab_sensors: - relative_resting_pos_sensor - target_start_sensor @@ -28,6 +29,7 @@ defaults: - end_effector_sensor - target_start_gps_compass_sensor - target_goal_gps_compass_sensor + - localization_sensor - _self_ type: RearrangeCompositeTask-v0 diff --git a/habitat-lab/habitat/config/habitat/task/rearrange/rearrange_easy_multi_agent.yaml b/habitat-lab/habitat/config/habitat/task/rearrange/rearrange_easy_multi_agent.yaml index 47220cf3cb..afac052bf5 100644 --- a/habitat-lab/habitat/config/habitat/task/rearrange/rearrange_easy_multi_agent.yaml +++ b/habitat-lab/habitat/config/habitat/task/rearrange/rearrange_easy_multi_agent.yaml @@ -24,6 +24,7 @@ defaults: - end_effector_sensor - target_start_gps_compass_sensor - target_goal_gps_compass_sensor + - localization_sensor - _self_ type: RearrangeCompositeTask-v0 diff --git a/habitat-lab/habitat/config/habitat/task/rearrange/set_table.yaml b/habitat-lab/habitat/config/habitat/task/rearrange/set_table.yaml index 662f8da3bb..f11cc889d9 100644 --- a/habitat-lab/habitat/config/habitat/task/rearrange/set_table.yaml +++ b/habitat-lab/habitat/config/habitat/task/rearrange/set_table.yaml @@ -28,6 +28,7 @@ defaults: - end_effector_sensor - target_start_gps_compass_sensor - target_goal_gps_compass_sensor + - localization_sensor - _self_ type: RearrangeCompositeTask-v0 diff --git a/habitat-lab/habitat/config/habitat/task/rearrange/tidy_house.yaml b/habitat-lab/habitat/config/habitat/task/rearrange/tidy_house.yaml index f1452911a4..6d176c934a 100644 --- a/habitat-lab/habitat/config/habitat/task/rearrange/tidy_house.yaml +++ b/habitat-lab/habitat/config/habitat/task/rearrange/tidy_house.yaml @@ -28,6 +28,7 @@ defaults: - end_effector_sensor - target_start_gps_compass_sensor - target_goal_gps_compass_sensor + - localization_sensor - _self_ type: RearrangeCompositeTask-v0 diff --git a/habitat-lab/habitat/sims/habitat_simulator/habitat_simulator.py b/habitat-lab/habitat/sims/habitat_simulator/habitat_simulator.py index 9b187b1519..ddd8e80a6b 100644 --- a/habitat-lab/habitat/sims/habitat_simulator/habitat_simulator.py +++ b/habitat-lab/habitat/sims/habitat_simulator/habitat_simulator.py @@ -391,8 +391,8 @@ def _update_agents_state(self) -> bool: agent_cfg = self.habitat_config.agents[agent_name] if agent_cfg.is_set_start_state: self.set_agent_state( - agent_cfg.start_position, - agent_cfg.start_rotation, + [float(k) for k in agent_cfg.start_position], + [float(k) for k in agent_cfg.start_rotation], agent_id, ) is_updated = True diff --git a/habitat-lab/habitat/tasks/rearrange/actions/oracle_nav_action.py b/habitat-lab/habitat/tasks/rearrange/actions/oracle_nav_action.py index 19523e04c1..521aa5a5b7 100644 --- a/habitat-lab/habitat/tasks/rearrange/actions/oracle_nav_action.py +++ b/habitat-lab/habitat/tasks/rearrange/actions/oracle_nav_action.py @@ -14,32 +14,34 @@ from habitat.tasks.utils import get_angle -def 
compute_turn(rel, turn_vel, robot_forward): - is_left = np.cross(robot_forward, rel) > 0 - if is_left: - vel = [0, -turn_vel] - else: - vel = [0, turn_vel] - return vel - - -def get_possible_nav_to_actions(pddl_problem): - return pddl_problem.get_possible_actions( - allowed_action_names=["nav", "nav_to_receptacle"], - true_preds=None, - ) - - @registry.register_task_action class OracleNavAction(BaseVelAction): + """ + An action that will convert the index of an entity (in the sense of + `PddlEntity`) to navigate to and convert this to base control to move the + robot to the closest navigable position to that entity. The entity index is + the index into the list of all available entities in the current scene. + """ + def __init__(self, *args, task, **kwargs): super().__init__(*args, **kwargs) self._task = task - self._poss_actions = get_possible_nav_to_actions(task.pddl_problem) + self._poss_entities = ( + self._task.pddl_problem.get_ordered_entities_list() + ) self._prev_ep_id = None self._targets = {} + @staticmethod + def _compute_turn(rel, turn_vel, robot_forward): + is_left = np.cross(robot_forward, rel) > 0 + if is_left: + vel = [0, -turn_vel] + else: + vel = [0, turn_vel] + return vel + @property def action_space(self): return spaces.Dict( @@ -62,8 +64,7 @@ def reset(self, *args, **kwargs): def _get_target_for_idx(self, nav_to_target_idx: int): if nav_to_target_idx not in self._targets: - action = self._poss_actions[nav_to_target_idx] - nav_to_obj = action.get_arg_value("obj") + nav_to_obj = self._poss_entities[nav_to_target_idx] obj_pos = self._task.pddl_problem.sim_info.get_entity_pos( nav_to_obj ) @@ -95,7 +96,7 @@ def step(self, *args, is_last_action, **kwargs): self._action_arg_prefix + "oracle_nav_action" ] if nav_to_target_idx <= 0 or nav_to_target_idx > len( - self._poss_actions + self._poss_entities ): if is_last_action: return self._sim.step(HabitatSimActions.base_velocity) @@ -134,7 +135,7 @@ def step(self, *args, is_last_action, **kwargs): if not at_goal: if dist_to_final_nav_targ < self._config.dist_thresh: # Look at the object - vel = compute_turn( + vel = OracleNavAction._compute_turn( rel_pos, self._config.turn_velocity, robot_forward ) elif angle_to_target < self._config.turn_thresh: @@ -142,7 +143,7 @@ def step(self, *args, is_last_action, **kwargs): vel = [self._config.forward_velocity, 0] else: # Look at the target waypoint. 
- vel = compute_turn( + vel = OracleNavAction._compute_turn( rel_targ, self._config.turn_velocity, robot_forward ) else: diff --git a/habitat-lab/habitat/tasks/rearrange/actions/pddl_actions.py b/habitat-lab/habitat/tasks/rearrange/actions/pddl_actions.py index 279bc079f7..df54c59593 100644 --- a/habitat-lab/habitat/tasks/rearrange/actions/pddl_actions.py +++ b/habitat-lab/habitat/tasks/rearrange/actions/pddl_actions.py @@ -8,7 +8,6 @@ from habitat.core.registry import registry from habitat.sims.habitat_simulator.actions import HabitatSimActions from habitat.tasks.rearrange.actions.grip_actions import RobotAction -from habitat.tasks.rearrange.utils import rearrange_logger @registry.register_task_action @@ -38,7 +37,10 @@ def action_space(self): { self._action_arg_prefix + "pddl_action": spaces.Box( - shape=(action_n_args,), low=-1, high=1, dtype=np.float32 + shape=(action_n_args,), + low=np.finfo(np.float32).min, + high=np.finfo(np.float32).max, + dtype=np.float32, ) } ) @@ -49,6 +51,7 @@ def was_prev_action_invalid(self): def reset(self, *args, **kwargs): self._was_prev_action_invalid = False + self._prev_action = None def get_pddl_action_start(self, action_id: int) -> int: start_idx = 0 @@ -56,10 +59,8 @@ def get_pddl_action_start(self, action_id: int) -> int: start_idx += action.n_args return start_idx - def step(self, *args, is_last_action, **kwargs): - apply_pddl_action = kwargs[self._action_arg_prefix + "pddl_action"] + def _apply_action(self, apply_pddl_action): cur_i = 0 - self._was_prev_action_invalid = False for action in self._action_ordering: action_part = apply_pddl_action[cur_i : cur_i + action.n_args][:] if sum(action_part) > 0: @@ -71,26 +72,31 @@ def step(self, *args, is_last_action, **kwargs): raise ValueError( f"Got invalid action value < 0 in {action_part} with action {action}" ) - rearrange_logger.debug(f"Got action part {real_action_idxs}") param_values = [ self._entities_list[i] for i in real_action_idxs ] - apply_action = action.copy() + apply_action = action.clone() apply_action.set_param_values(param_values) + self._prev_action = apply_action if self._task.pddl_problem.is_expr_true(apply_action.precond): - rearrange_logger.debug( - f"Applying action {action} with obj args {param_values}" - ) self._task.pddl_problem.apply_action(apply_action) else: - rearrange_logger.debug( - f"Preconds not satisfied for: action {action} with obj args {param_values}" - ) self._was_prev_action_invalid = True cur_i += action.n_args + + def step(self, *args, is_last_action, **kwargs): + self._prev_action = None + apply_pddl_action = kwargs[self._action_arg_prefix + "pddl_action"] + self._was_prev_action_invalid = False + inputs_outside = any( + a < 0 or a > len(self._entities_list) for a in apply_pddl_action + ) + if not inputs_outside: + self._apply_action(apply_pddl_action) + if is_last_action: return self._sim.step(HabitatSimActions.arm_action) else: diff --git a/habitat-lab/habitat/tasks/rearrange/multi_task/domain_configs/replica_cad.yaml b/habitat-lab/habitat/tasks/rearrange/multi_task/domain_configs/replica_cad.yaml index 2de18d0c9b..20a15150fa 100644 --- a/habitat-lab/habitat/tasks/rearrange/multi_task/domain_configs/replica_cad.yaml +++ b/habitat-lab/habitat/tasks/rearrange/multi_task/domain_configs/replica_cad.yaml @@ -310,3 +310,107 @@ actions: config_args: task.base_angle_noise: 0.0 task.spawn_region_scale: 0.0 + + ######################################################### + # Receptacle name only based variants of the receptacle skills. 
This does not + # require any information about knowing which objects the receptacle + # contains. + - name: nav_to_receptacle_by_name + parameters: + - name: marker + expr_type: art_obj_type + - name: robot + expr_type: robot_type + precondition: null + postcondition: + - robot_at(marker, robot) + task_info: + task: NavToObjTask-v0 + task_def: "nav_to_obj" + config_args: + task.force_regenerate: True + task.should_save_to_cache: False + + - name: open_fridge_by_name + parameters: + - name: fridge_id + expr_type: fridge_type + - name: robot + expr_type: robot_type + precondition: + expr_type: AND + sub_exprs: + - robot_at(fridge_id, robot) + - closed_fridge(fridge_id) + postcondition: + - opened_fridge(fridge_id) + task_info: + task_def: "open_fridge" + task: RearrangeOpenFridgeTask-v0 + add_task_args: + marker: fridge_id + config_args: + task.base_angle_noise: 0.0 + task.spawn_region_scale: 0.0 + + - name: close_fridge_by_name + parameters: + - name: fridge_id + expr_type: fridge_type + - name: robot + expr_type: robot_type + precondition: + expr_type: AND + sub_exprs: + - robot_at(fridge_id, robot) + - opened_fridge(fridge_id) + postcondition: + - closed_fridge(fridge_id) + task_info: + task_def: "close_fridge" + task: RearrangeCloseFridgeTask-v0 + add_task_args: + marker: fridge_id + config_args: + task.base_angle_noise: 0.0 + task.spawn_region_scale: 0.0 + + - name: open_cab_by_name + parameters: + - name: marker + expr_type: cab_type + - name: robot + expr_type: robot_type + precondition: + expr_type: AND + sub_exprs: + - robot_at(marker, robot) + - closed_cab(marker) + postcondition: + - opened_cab(marker) + task_info: + task_def: "open_cab" + task: RearrangeOpenDrawerTask-v0 + config_args: + task.base_angle_noise: 0.0 + task.spawn_region_scale: 0.0 + + - name: close_cab_by_name + parameters: + - name: marker + expr_type: cab_type + - name: robot + expr_type: robot_type + precondition: + expr_type: AND + sub_exprs: + - robot_at(marker, robot) + - opened_cab(marker) + postcondition: + - closed_cab(marker) + task_info: + task_def: "close_cab" + task: RearrangeCloseDrawerTask-v0 + config_args: + task.base_angle_noise: 0.0 + task.spawn_region_scale: 0.0 diff --git a/habitat-lab/habitat/tasks/rearrange/multi_task/pddl_action.py b/habitat-lab/habitat/tasks/rearrange/multi_task/pddl_action.py index 2c0966f11c..4e72132791 100644 --- a/habitat-lab/habitat/tasks/rearrange/multi_task/pddl_action.py +++ b/habitat-lab/habitat/tasks/rearrange/multi_task/pddl_action.py @@ -162,7 +162,7 @@ def apply(self, sim_info: PddlSimInfo) -> None: p.set_state(sim_info) @property - def param_values(self): + def param_values(self) -> Optional[List[PddlEntity]]: if self._param_values is None: raise ValueError( "Accessing action param values before they are set." 
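For reference, the flat action encoding that `PddlApplyAction` consumes in the hunks above can be illustrated with a small standalone sketch. The names below (`ENTITIES`, `ACTION_ARITIES`, `decode`) are hypothetical stand-ins rather than habitat-lab API, and the 1-based entity indexing (with 0 meaning "slot unused") is an assumption read off the bounds check in the new `step`; only the slice-by-arity and range-validation pattern mirrors the diff.

```python
# Standalone illustration only: ENTITIES, ACTION_ARITIES and decode() are
# hypothetical stand-ins, not habitat-lab API. Entity indices are assumed to
# be 1-based, with 0 meaning "this argument slot is unused".
from typing import List, Optional, Tuple

ENTITIES = ["robot_0", "apple", "fridge"]    # ordered entity list for a scene
ACTION_ARITIES = [("nav", 2), ("pick", 2)]   # (action name, number of args)


def decode(flat: List[int]) -> Optional[List[Tuple[str, List[str]]]]:
    # Reject any index outside [0, len(ENTITIES)], as the updated step() does
    # before applying anything.
    if any(a < 0 or a > len(ENTITIES) for a in flat):
        return None
    selected = []
    cur = 0
    for name, n_args in ACTION_ARITIES:
        part = flat[cur : cur + n_args]
        if sum(part) > 0:
            # Map the 1-based indices back to entities for this action's args.
            selected.append((name, [ENTITIES[i - 1] for i in part]))
        cur += n_args
    return selected


print(decode([0, 0, 2, 1]))  # [('pick', ['apple', 'robot_0'])]
print(decode([0, 0, 9, 1]))  # None: 9 is out of range, so nothing is applied
```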
diff --git a/habitat-lab/habitat/tasks/rearrange/rearrange_sim.py b/habitat-lab/habitat/tasks/rearrange/rearrange_sim.py index af50754257..7983f39700 100644 --- a/habitat-lab/habitat/tasks/rearrange/rearrange_sim.py +++ b/habitat-lab/habitat/tasks/rearrange/rearrange_sim.py @@ -161,6 +161,7 @@ def sleep_all_objects(self): rom = self.get_rigid_object_manager() for _, ro in rom.get_objects_by_handle_substring().items(): ro.awake = False + aom = self.get_articulated_object_manager() for _, ao in aom.get_objects_by_handle_substring().items(): ao.awake = False @@ -253,6 +254,12 @@ def reconfigure(self, config: "DictConfig", ep_info: RearrangeEpisode): if self.habitat_config.auto_sleep: self.sleep_all_objects() + rom = self.get_rigid_object_manager() + self._obj_orig_motion_types = { + handle: ro.motion_type + for handle, ro in rom.get_objects_by_handle_substring().items() + } + if new_scene: self._load_navmesh(ep_info) @@ -762,7 +769,6 @@ def internal_step( Never call sim.step_world directly or miss updating the robot. """ - # optionally step physics and update the robot for benchmarking purposes if self._step_physics: self.step_world(dt) diff --git a/habitat-lab/habitat/tasks/rearrange/rearrange_task.py b/habitat-lab/habitat/tasks/rearrange/rearrange_task.py index 1c9d38f2a4..4563bc490f 100644 --- a/habitat-lab/habitat/tasks/rearrange/rearrange_task.py +++ b/habitat-lab/habitat/tasks/rearrange/rearrange_task.py @@ -96,6 +96,9 @@ def __init__( # Duplicate sensors that handle robots. One for each robot. self._duplicate_sensor_suite(self.sensor_suite) + def overwrite_sim_config(self, config: Any, episode: Episode) -> Any: + return config + @property def targ_idx(self): return self._targ_idx diff --git a/habitat-lab/habitat/tasks/rearrange/sub_tasks/pick_task.py b/habitat-lab/habitat/tasks/rearrange/sub_tasks/pick_task.py index 23eebb33a3..f1f396d169 100644 --- a/habitat-lab/habitat/tasks/rearrange/sub_tasks/pick_task.py +++ b/habitat-lab/habitat/tasks/rearrange/sub_tasks/pick_task.py @@ -58,7 +58,7 @@ def _gen_start_pos(self, sim, episode, sel_idx): start_pos, angle_to_obj, was_succ = get_robot_spawns( snap_pos, self._config.base_angle_noise, - self._config.spawn_max_dists_to_obj, + self._config.spawn_max_dist_to_obj, sim, self._config.num_spawn_attempts, self._config.physics_stability_steps, diff --git a/habitat-lab/habitat/utils/render_wrapper.py b/habitat-lab/habitat/utils/render_wrapper.py index 8000093d99..63ad257b6c 100644 --- a/habitat-lab/habitat/utils/render_wrapper.py +++ b/habitat-lab/habitat/utils/render_wrapper.py @@ -66,6 +66,10 @@ def append_text_to_image( def overlay_frame(frame, info, additional=None): + """ + Renders text from the `info` dictionary to the `frame` image. + """ + lines = [] flattened_info = flatten_dict(info) for k, v in flattened_info.items(): diff --git a/scripts/generate_profile_shell_scripts.py b/scripts/generate_profile_shell_scripts.py index 030bcf29ec..b61728e29a 100644 --- a/scripts/generate_profile_shell_scripts.py +++ b/scripts/generate_profile_shell_scripts.py @@ -22,7 +22,7 @@ if __name__ == "__main__": # The Habitat-lab program to be profiled (the command you usually use to # invoke it). - program_str = "python -u -m habitat_baselines.run --exp-config habitat-baselines/habitat_baselines/config/pointnav/ddppo_pointnav.yaml --run-type train" + program_str = "python -u -m habitat_baselines.run --config-name=pointnav/ddppo_pointnav.yaml" # Path to Nsight Systems nsys command-line tool. This hard-coded path is # for the FAIR cluster. 
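The simplified `overlay_frame(frame, infos[i])` call in the trainer hunk earlier in this diff relies on the helper documented here flattening the info dictionary itself. A minimal usage sketch, assuming the metrics are plain scalars; the frame and metric values below are placeholders, not values from this change:

```python
# Minimal usage sketch; the frame and metric values are placeholders.
import numpy as np

from habitat.utils.render_wrapper import overlay_frame

frame = np.zeros((256, 256, 3), dtype=np.uint8)       # stand-in RGB frame
info = {"num_steps": 42.0, "composite_success": 0.0}  # stand-in eval metrics

# overlay_frame flattens `info` and renders "key: value" text onto the image,
# returning the annotated frame that the trainer appends to rgb_frames.
frame = overlay_frame(frame, info)
```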
diff --git a/test/test_baseline_config.py b/test/test_baseline_config.py index 533944bd78..15b3fb524b 100644 --- a/test/test_baseline_config.py +++ b/test/test_baseline_config.py @@ -25,4 +25,9 @@ def test_baselines_configs(test_cfg_path): cleaned_path = test_cfg_path.replace( "habitat-baselines/habitat_baselines/config/", "" ) + if "habitat_baselines" in cleaned_path: + # Do not test non-standalone config options that are + # supposed to be used with "main" configs. + return + get_config(cleaned_path) diff --git a/test/test_baseline_training.py b/test/test_baseline_training.py index b129bbff75..aa1fdeccd9 100644 --- a/test/test_baseline_training.py +++ b/test/test_baseline_training.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import glob +import itertools import os import random @@ -13,6 +14,7 @@ from habitat.config import read_write from habitat.config.default import get_agent_config +from habitat_baselines.run import execute_exp try: import torch @@ -148,6 +150,79 @@ def test_trainers(config_path, num_updates, overrides, trainer_name): # Training should complete without raising an error. +@pytest.mark.parametrize( + "config_path,policy_type,skill_type,mode", + list( + itertools.product( + [ + "habitat-baselines/habitat_baselines/config/rearrange/rl_hierarchical_oracle_nav.yaml", + "habitat-baselines/habitat_baselines/config/rearrange/rl_hierarchical.yaml", + ], + [ + "hl_neural", + "hl_fixed", + ], + [ + "nn_skills", + "oracle_skills", + ], + [ + "eval", + "train", + ], + ) + ), +) +def test_hrl(config_path, policy_type, skill_type, mode): + TRAIN_LOG_FILE = "data/test_train.log" + + if policy_type == "hl_neural" and skill_type == "nn_skills": + return + if policy_type == "hl_fixed" and mode == "train": + return + if skill_type == "oracle_skills" and "oracle" not in config_path: + return + # Remove the checkpoints from previous tests + for f in glob.glob("data/test_checkpoints/test_training/*"): + os.remove(f) + if os.path.exists(TRAIN_LOG_FILE): + os.remove(TRAIN_LOG_FILE) + + # Setup the training + config = get_config( + config_path, + [ + "habitat_baselines.num_updates=1", + "habitat_baselines.eval.split=minival", + "habitat.dataset.split=minival", + "habitat_baselines.total_num_steps=-1.0", + "habitat_baselines.test_episode_count=1", + "habitat_baselines.checkpoint_folder=data/test_checkpoints/test_training", + f"habitat_baselines.log_file={TRAIN_LOG_FILE}", + f"habitat_baselines/rl/policy={policy_type}", + f"habitat_baselines/rl/policy/hierarchical_policy/defined_skills={skill_type}", + ], + ) + with read_write(config): + config.habitat_baselines.eval.update({"video_option": []}) + for ( + skill_name, + skill, + ) in ( + config.habitat_baselines.rl.policy.hierarchical_policy.defined_skills.items() + ): + if skill.load_ckpt_file == "": + continue + skill.update( + { + "force_config_file": f"benchmark/rearrange={skill_name}", + "max_skill_steps": 1, + "load_ckpt_file": "", + } + ) + execute_exp(config, mode) + + @pytest.mark.skipif( int(os.environ.get("TEST_BASELINE_SMALL", 0)) == 0, reason="Full training tests did not run. Need `export TEST_BASELINE_SMALL=1", diff --git a/test/test_rearrange_task.py b/test/test_rearrange_task.py index ecd353a02d..99d012871c 100644 --- a/test/test_rearrange_task.py +++ b/test/test_rearrange_task.py @@ -4,16 +4,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-import gc -import itertools import json -import os import os.path as osp import time from glob import glob import pytest -import torch import yaml from omegaconf import DictConfig, OmegaConf @@ -29,8 +25,6 @@ from habitat.datasets.rearrange.rearrange_dataset import RearrangeDatasetV0 from habitat.tasks.rearrange.multi_task.composite_task import CompositeTask from habitat_baselines.config.default import get_config as baselines_get_config -from habitat_baselines.rl.ddppo.ddp_utils import find_free_port -from habitat_baselines.run import execute_exp CFG_TEST = "benchmark/rearrange/pick.yaml" GEN_TEST_CFG = ( @@ -109,33 +103,6 @@ def test_rearrange_baseline_envs(test_cfg_path): ) -@pytest.mark.parametrize( - "test_cfg_path", - list( - glob("habitat-lab/habitat/config/benchmark/rearrange/*"), - ), -) -def test_rearrange_tasks(test_cfg_path): - """ - Test the underlying Habitat Tasks - """ - if not osp.isfile(test_cfg_path): - return - - config = get_config(test_cfg_path) - if ( - config.habitat.dataset.data_path - == "data/ep_datasets/bench_scene.json.gz" - ): - pytest.skip( - "This config is only useful for examples and does not have the generated dataset" - ) - - with habitat.Env(config=config) as env: - for _ in range(5): - env.reset() - - @pytest.mark.parametrize( "test_cfg_path", list( @@ -209,33 +176,3 @@ def test_rearrange_episode_generator( logger.info( f"successful_ep = {len(dataset.episodes)} generated in {time.time()-start_time} seconds." ) - - -@pytest.mark.parametrize( - "test_cfg_path,mode", - list( - itertools.product( - glob("habitat-baselines/habitat_baselines/config/tp_srl_test/*"), - ["eval"], - ) - ), -) -def test_tp_srl(test_cfg_path, mode): - # For testing with world_size=1 - os.environ["MAIN_PORT"] = str(find_free_port()) - - baseline_config = baselines_get_config( - test_cfg_path.replace( - "habitat-baselines/habitat_baselines/config/", "" - ), - ["habitat_baselines.eval.split=train"], - ) - - execute_exp(baseline_config, mode) - - # Needed to destroy the trainer - gc.collect() - - # Deinit processes group - if torch.distributed.is_initialized(): - torch.distributed.destroy_process_group()
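To close the loop on the new `test_hrl` coverage, the same configurations can also be driven outside pytest. Below is a minimal sketch mirroring one valid combination from the test (fixed high-level policy with oracle skills, evaluation only); the config path and overrides follow the test, and the split/episode-count settings assume the ReplicaCAD rearrange minival data is available locally.

```python
# Sketch only: mirrors one parametrization of test_hrl above. Assumes the
# ReplicaCAD rearrange dataset (minival split) is available locally.
from habitat_baselines.config.default import get_config
from habitat_baselines.run import execute_exp

config = get_config(
    "habitat-baselines/habitat_baselines/config/rearrange/rl_hierarchical_oracle_nav.yaml",
    [
        "habitat_baselines.eval.split=minival",
        "habitat.dataset.split=minival",
        "habitat_baselines.test_episode_count=1",
        "habitat_baselines/rl/policy=hl_fixed",
        "habitat_baselines/rl/policy/hierarchical_policy/defined_skills=oracle_skills",
    ],
)
execute_exp(config, "eval")  # hl_fixed is evaluation-only, as in the test
```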