Decision transformer #52

Open · wants to merge 4 commits into base: master
7 changes: 7 additions & 0 deletions habitat_extensions/config/default.py
@@ -44,6 +44,13 @@
# ----------------------------------------------------------------------------
_C.TASK.VLN_ORACLE_PROGRESS_SENSOR = CN()
_C.TASK.VLN_ORACLE_PROGRESS_SENSOR.TYPE = "VLNOracleProgressSensor"

# -----------------------------------------------------------------------------
# VLN ORACLE DISTANCE LEFT SENSOR
# ----------------------------------------------------------------------------
_C.TASK.VLN_ORACLE_DISTANCE_LEFT_SENSOR = CN()
_C.TASK.VLN_ORACLE_DISTANCE_LEFT_SENSOR.TYPE = "VLNOracleDistanceLeftSensor"

# ----------------------------------------------------------------------------
# PANO ANGLE FEATURE SENSOR
# ----------------------------------------------------------------------------
3 changes: 2 additions & 1 deletion habitat_extensions/config/vlnce_task.yaml
@@ -23,7 +23,8 @@ TASK:
SENSORS: [
INSTRUCTION_SENSOR,
SHORTEST_PATH_SENSOR,
VLN_ORACLE_PROGRESS_SENSOR
VLN_ORACLE_PROGRESS_SENSOR,
VLN_ORACLE_DISTANCE_LEFT_SENSOR
]
INSTRUCTION_SENSOR_UUID: instruction
POSSIBLE_ACTIONS: [STOP, MOVE_FORWARD, TURN_LEFT, TURN_RIGHT]
3 changes: 2 additions & 1 deletion habitat_extensions/config/vlnce_task_aug.yaml
@@ -26,7 +26,8 @@ TASK:
SENSORS: [
INSTRUCTION_SENSOR,
SHORTEST_PATH_SENSOR,
VLN_ORACLE_PROGRESS_SENSOR
VLN_ORACLE_PROGRESS_SENSOR,
VLN_ORACLE_DISTANCE_LEFT_SENSOR
]
INSTRUCTION_SENSOR_UUID: instruction
POSSIBLE_ACTIONS: [STOP, MOVE_FORWARD, TURN_LEFT, TURN_RIGHT]
35 changes: 35 additions & 0 deletions habitat_extensions/sensors.py
@@ -86,6 +86,41 @@ def get_observation(self, *args: Any, episode, **kwargs: Any) -> float:
[(distance_from_start - distance_to_target) / distance_from_start]
)

@registry.register_sensor
class VLNOracleDistanceLeftSensor(Sensor):
"""Distance left towards goal"""

cls_uuid: str = "distance_left"

def __init__(
self, sim: Simulator, config: Config, *args: Any, **kwargs: Any
) -> None:
self._sim = sim
super().__init__(config=config)

def _get_uuid(self, *args: Any, **kwargs: Any) -> str:
return self.cls_uuid

def _get_sensor_type(self, *args: Any, **kwargs: Any) -> SensorTypes:
return SensorTypes.MEASUREMENT

def _get_observation_space(self, *args: Any, **kwargs: Any) -> Space:
# unlike progress, the remaining distance is not bounded by 1.0
return spaces.Box(low=0.0, high=np.inf, shape=(1,), dtype=np.float32)

def get_observation(self, *args: Any, episode, **kwargs: Any) -> float:
distance_to_target = self._sim.geodesic_distance(
self._sim.get_agent_state().position.tolist(),
episode.goals[0].position,
)

# guard against non-finite distances if the agent ends up somewhere unreachable
if not np.isfinite(distance_to_target):
distance_to_target = 0.0

return np.array([distance_to_target])

@registry.register_sensor
class AngleFeaturesSensor(Sensor):
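The new sensor only reports the remaining geodesic distance; turning those readings into rewards and return-to-go targets (see the `POINT_GOAL_NAV_REWARD` and `sensor_uuid = "distance_left"` options added in `vlnce_baselines/config/default.py` later in this diff) happens in the trainer, which is not shown here. A minimal sketch of one plausible formulation, assuming a progress-shaped reward with the configured `step_penalty` and `success` bonus; the helper name and exact shaping are assumptions, not the trainer's actual code:

```python
import numpy as np

def rewards_and_returns_to_go(distance_left, success, step_penalty=-0.01, success_bonus=1.0):
    """Hypothetical reward shaping from a sequence of distance_left readings.

    distance_left: per-step sensor readings d_0 ... d_T (metres to goal)
    success: whether the episode ended successfully
    """
    d = np.asarray(distance_left, dtype=np.float32)
    # reward each step by the distance gained towards the goal, minus a step penalty
    rewards = (d[:-1] - d[1:]) + step_penalty
    if success:
        rewards[-1] += success_bonus
    # return-to-go at step t is the sum of future rewards from t onwards
    returns_to_go = np.cumsum(rewards[::-1])[::-1]
    return rewards, returns_to_go
```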
31 changes: 26 additions & 5 deletions run.py
@@ -8,7 +8,7 @@
import torch
from habitat import logger
from habitat_baselines.common.baseline_registry import baseline_registry

import gc
import habitat_extensions # noqa: F401
import vlnce_baselines # noqa: F401
from vlnce_baselines.config.default import get_config
@@ -30,7 +30,8 @@ def main():
"--exp-config",
type=str,
required=True,
help="path to config yaml containing info about experiment",
help="path to config yaml containing info about experiment. "
"If this is a directory, run every yaml file it contains.",
)
parser.add_argument(
"opts",
@@ -40,7 +41,19 @@
)

args = parser.parse_args()
run_exp(**vars(args))
if os.path.isdir(args.exp_config):
conf_parameter = args.exp_config
print("Running several config files from:", conf_parameter)
for file in sorted(os.listdir(conf_parameter)):
if file.endswith(".yaml") or file.endswith(".yml"):
file_path = os.path.join(conf_parameter, file)
print("exp_config", file_path)
run_exp(exp_config=file_path, run_type=args.run_type, opts=args.opts)
else:
print("Not a valid config file:", file)
else:
run_exp(**vars(args))


def run_exp(exp_config: str, run_type: str, opts=None) -> None:
@@ -52,11 +65,16 @@ def run_exp(exp_config: str, run_type: str, opts=None) -> None:
opts: list of strings of additional config options.
"""
config = get_config(exp_config, opts)
logger.info(f"config: {config}")
logdir = "/".join(config.LOG_FILE.split("/")[:-1])
if not logdir:
logdir = "logs"
os.makedirs(logdir, exist_ok=True)
config_file_root_name = logdir + "/" + exp_config.split("/")[-1].split(".")[0]
if logdir:
os.makedirs(logdir, exist_ok=True)
logger.add_filehandler(config.LOG_FILE)
log_file = config_file_root_name + "_" + config.LOG_FILE
logger.add_filehandler(log_file)
logger.info(f"config: {config}")

random.seed(config.TASK_CONFIG.SEED)
np.random.seed(config.TASK_CONFIG.SEED)
@@ -87,6 +105,9 @@ def run_exp(exp_config: str, run_type: str, opts=None) -> None:
elif run_type == "inference":
trainer.inference()

# avoid writing to the log files of previous runs when looping over several configs
logger.removeHandler(logger.handlers[-1])
gc.collect()

if __name__ == "__main__":
main()
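For readability, here is a worked example of the per-config log naming introduced above; the file name and `LOG_FILE` value are hypothetical:

```python
# Suppose run.py is called with a single config file:
exp_config = "vlnce_baselines/config/r2r_baselines/seq2seq.yaml"
log_file_cfg = "train.log"                                 # config.LOG_FILE
logdir = "/".join(log_file_cfg.split("/")[:-1]) or "logs"  # -> "logs"
config_file_root_name = logdir + "/" + exp_config.split("/")[-1].split(".")[0]  # -> "logs/seq2seq"
log_file = config_file_root_name + "_" + log_file_cfg      # -> "logs/seq2seq_train.log"
# so each config in a directory run gets its own log file
```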
5 changes: 4 additions & 1 deletion vlnce_baselines/__init__.py
@@ -2,6 +2,9 @@
dagger_trainer,
ddppo_waypoint_trainer,
recollect_trainer,
decision_transformer_trainer
)
from vlnce_baselines.common import environments
from vlnce_baselines.models import cma_policy, seq2seq_policy, waypoint_policy
from vlnce_baselines.models import (cma_policy, seq2seq_policy,
waypoint_policy,
decision_transformer_policy)
19 changes: 12 additions & 7 deletions vlnce_baselines/common/base_il_trainer.py
@@ -65,8 +65,8 @@ def _initialize_policy(
action_space=action_space,
)
self.policy.to(self.device)

self.optimizer = torch.optim.Adam(
# config.IL.optimizer is a dotted path, e.g. "torch.optim.Adam" or "torch.optim.RAdam"
self.optimizer = eval(config.IL.optimizer)(
self.policy.parameters(), lr=self.config.IL.lr
)
if load_from_ckpt:
Expand Down Expand Up @@ -191,16 +191,21 @@ def _pause_envs(
):
# pausing envs with no new episode
if len(envs_to_pause) > 0:
# sorting first avoids nasty index-ordering bugs when new trainers reuse this helper
envs_to_pause = sorted(envs_to_pause)
state_index = list(range(envs.num_envs))
for idx in reversed(envs_to_pause):
state_index.pop(idx)
envs.pause_at(idx)

# indexing along the batch dimensions
recurrent_hidden_states = recurrent_hidden_states[state_index]
not_done_masks = not_done_masks[state_index]
prev_actions = prev_actions[state_index]

# indexing along the batch dimension: the environments to pause were removed from
# state_index above, so only the entries for the active environments are kept
if recurrent_hidden_states is not None:
recurrent_hidden_states = recurrent_hidden_states[state_index]
if not_done_masks is not None:
not_done_masks = not_done_masks[state_index]
if prev_actions is not None:
prev_actions = prev_actions[state_index]
for k, v in batch.items():
batch[k] = v[state_index]

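A note on the `eval(config.IL.optimizer)` change above: the config entry is a dotted path such as "torch.optim.Adam", so the optimizer can be swapped from the YAML alone. If `eval` is a concern, a sketch of an equivalent lookup, assuming the path stays under `torch.optim`:

```python
import torch

def resolve_optimizer(name: str):
    """Resolve e.g. "torch.optim.RAdam" without eval() (assumes a torch.optim.* path)."""
    module_path, _, class_name = name.rpartition(".")
    assert module_path == "torch.optim", f"unexpected optimizer path: {name}"
    return getattr(torch.optim, class_name)

# optimizer_cls = resolve_optimizer(config.IL.optimizer)
# self.optimizer = optimizer_cls(self.policy.parameters(), lr=config.IL.lr)
```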
1 change: 1 addition & 0 deletions vlnce_baselines/common/env_utils.py
@@ -93,6 +93,7 @@ def construct_envs(
env_fn_args=tuple(zip(configs, env_classes)),
auto_reset_done=auto_reset_done,
workers_ignore_signals=workers_ignore_signals,
multiprocessing_start_method=config.MULTIPROCESSING
)
return envs

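The new `multiprocessing_start_method` argument is driven by the top-level `MULTIPROCESSING` option added to `vlnce_baselines/config/default.py` later in this diff. Assuming the usual yacs opts handling in `get_config`, the start method could presumably also be overridden for a single debug run without editing the YAML; a sketch:

```python
from vlnce_baselines.config.default import get_config

# hypothetical: keep "forkserver" in the YAML, switch to "spawn" only for a debug run
config = get_config("vlnce_baselines/config/r2r_baselines/seq2seq.yaml",
                    ["MULTIPROCESSING", "spawn"])
```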
20 changes: 20 additions & 0 deletions vlnce_baselines/common/environments.py
@@ -12,6 +12,26 @@
from habitat_extensions.utils import generate_video, navigator_video_frame


@baseline_registry.register_env(name="VLNCEDecisionTransformerEnv")
class VLNCEDecisionTransformerEnv(habitat.RLEnv):
def __init__(self, config: Config, dataset: Optional[Dataset] = None):
super().__init__(config.TASK_CONFIG, dataset)

def get_reward_range(self) -> Tuple[float, float]:
# Rewards are not produced by the Habitat framework here; they are
# computed from the collected trajectories instead.
return (0.0, 0.0)

def get_reward(self, observations: Observations) -> float:
return 0.0

def get_done(self, observations: Observations) -> bool:
return self._env.episode_over

def get_info(self, observations: Observations) -> Dict[Any, Any]:
return self.habitat_env.get_metrics()


@baseline_registry.register_env(name="VLNCEDaggerEnv")
class VLNCEDaggerEnv(habitat.RLEnv):
def __init__(self, config: Config, dataset: Optional[Dataset] = None):
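The new environment is presumably selected like `VLNCEDaggerEnv`, via the `ENV_NAME` field of the experiment config; a sketch of the underlying registry lookup, assuming the standard `habitat_baselines` API:

```python
from habitat_baselines.common.baseline_registry import baseline_registry

env_class = baseline_registry.get_env("VLNCEDecisionTransformerEnv")
# env = env_class(config=config, dataset=dataset)  # as done when constructing the vector envs
```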
91 changes: 88 additions & 3 deletions vlnce_baselines/config/default.py
@@ -24,7 +24,10 @@
_C.VIDEO_DIR = "data/videos/debug"
_C.TENSORBOARD_DIR = "data/tensorboard_dirs/debug"
_C.RESULTS_DIR = "data/checkpoints/pretrained/evals"

# Multiprocessing start method. The default "forkserver" can hang the PyCharm debugger:
# https://youtrack.jetbrains.com/issue/PY-52273/Debugger-multiprocessing-hangs-pycharm-20213
_C.MULTIPROCESSING = "forkserver"  # set to "spawn" when debugging with PyCharm
_C.use_pbar = True
# ----------------------------------------------------------------------------
# EVAL CONFIG
# ----------------------------------------------------------------------------
@@ -37,7 +40,9 @@
_C.EVAL.EVAL_NONLEARNING = False
_C.EVAL.NONLEARNING = CN()
_C.EVAL.NONLEARNING.AGENT = "RandomAgent"

_C.EVAL.VAL_SEEN_SMALL = "val_seen_80_ep"  # only used when run in train_complete mode
_C.EVAL.VAL_SEEN = "val_seen"  # only used when run in train_complete mode
_C.EVAL.VAL_UNSEEN = "val_unseen"  # only used when run in train_complete mode
# ----------------------------------------------------------------------------
# INFERENCE CONFIG
# ----------------------------------------------------------------------------
@@ -57,10 +62,13 @@
# IMITATION LEARNING CONFIG
# ----------------------------------------------------------------------------
_C.IL = CN()
_C.IL.optimizer = "torch.optim.Adam"
_C.IL.dataload_workers = 1
_C.IL.lr = 2.5e-4
_C.IL.batch_size = 5
# number of network update rounds per iteration
_C.IL.epochs = 4
_C.IL.preload_dataloader_size = 100
# if true, uses class-based inflection weighting
_C.IL.use_iw = True
# inflection coefficient for RxR training set GT trajectories (guide): 1.9
Expand All @@ -69,9 +77,12 @@
# load an already trained model for fine tuning
_C.IL.load_from_ckpt = False
_C.IL.ckpt_to_load = "data/checkpoints/ckpt.0.pth"
_C.IL.continue_ckpt_naming = True
# if True, loads the optimizer state, epoch, and step_id from the ckpt dict.
_C.IL.is_requeue = False

_C.IL.checkpoint_frequency = 1  # save a checkpoint whenever epoch % checkpoint_frequency == 0
_C.IL.mean_loss_to_save_checkpoint = 0.40
_C.IL.mean_loss_to_stop_training = 0.06
# ----------------------------------------------------------------------------
# IL: RECOLLECT TRAINER CONFIG
# ----------------------------------------------------------------------------
@@ -116,6 +127,28 @@
"data/trajectories_dirs/debug/trajectories.lmdb"
)
_C.IL.DAGGER.drop_existing_lmdb_features = True

# ----------------------------------------------------------------------------
# IL: DAGGER / DECISION TRANSFORMER CONFIG
# ----------------------------------------------------------------------------

_C.IL.DECISION_TRANSFORMER = CN()
_C.IL.DECISION_TRANSFORMER.episode_horizon = 183
_C.IL.DECISION_TRANSFORMER.use_perfect_episode_only_for_dagger = True
_C.IL.DECISION_TRANSFORMER.use_oracle_actions = False
_C.IL.DECISION_TRANSFORMER.reward_type = "POINT_GOAL_NAV_REWARD" # POINT_GOAL_NAV_REWARD or SPARSE_REWARD
_C.IL.DECISION_TRANSFORMER.sensor_uuid = "distance_left"  # used to compute the return-to-go
_C.IL.DECISION_TRANSFORMER.recompute_reward = True
_C.IL.DECISION_TRANSFORMER.POINT_GOAL_NAV_REWARD = CN()
_C.IL.DECISION_TRANSFORMER.POINT_GOAL_NAV_REWARD.step_penalty = -0.01
_C.IL.DECISION_TRANSFORMER.POINT_GOAL_NAV_REWARD.success = 1.0
_C.IL.DECISION_TRANSFORMER.SPARSE_REWARD = CN()
_C.IL.DECISION_TRANSFORMER.SPARSE_REWARD.step_penalty = -0.01
_C.IL.DECISION_TRANSFORMER.SPARSE_REWARD.success = 1.0
_C.IL.DECISION_TRANSFORMER.NDTW_REWARD = CN()
_C.IL.DECISION_TRANSFORMER.NDTW_REWARD.step_penalty = -0.01
_C.IL.DECISION_TRANSFORMER.NDTW_REWARD.success = 1.0
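Since `reward_type` names one of the sibling `CN()` nodes above, the trainer can presumably fetch the matching parameters with a plain attribute lookup; a sketch under that assumption (the `config` variable is whatever experiment config is in scope):

```python
# hypothetical helper: fetch the parameters of the configured reward
dt_cfg = config.IL.DECISION_TRANSFORMER
reward_cfg = getattr(dt_cfg, dt_cfg.reward_type)  # e.g. POINT_GOAL_NAV_REWARD
step_penalty, success_bonus = reward_cfg.step_penalty, reward_cfg.success
```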

# ----------------------------------------------------------------------------
# RL CONFIG
# ----------------------------------------------------------------------------
@@ -284,6 +317,58 @@
_C.MODEL.WAYPOINT.discrete_offsets = 7
_C.MODEL.WAYPOINT.offset_temperature = 1.0

# ----------------------------------------------------------------------------
# DECISION TRANSFORMER CONFIG
# ----------------------------------------------------------------------------
_C.MODEL.DECISION_TRANSFORMER = CN()
_C.MODEL.DECISION_TRANSFORMER.use_re_zero = False # https://arxiv.org/abs/2003.04887
_C.MODEL.DECISION_TRANSFORMER.hidden_dim = 128
# the maximum episode length in the training split
_C.MODEL.DECISION_TRANSFORMER.episode_horizon = _C.IL.DECISION_TRANSFORMER.episode_horizon
_C.MODEL.DECISION_TRANSFORMER.reward_type = "POINT_GOAL_NAV_REWARD" # POINT_GOAL_NAV_REWARD or SPARSE_REWARD
_C.MODEL.DECISION_TRANSFORMER.return_to_go_inference = 1.0
_C.MODEL.DECISION_TRANSFORMER.spatial_output = False  # if False, depth and RGB features are averaged
_C.MODEL.DECISION_TRANSFORMER.model_type = None
_C.MODEL.DECISION_TRANSFORMER.n_layer = 2
_C.MODEL.DECISION_TRANSFORMER.n_head = 1
_C.MODEL.DECISION_TRANSFORMER.n_embd = _C.MODEL.DECISION_TRANSFORMER.hidden_dim
_C.MODEL.DECISION_TRANSFORMER.use_transformer_encoded_instruction = False
# these options must be filled in externally
_C.MODEL.DECISION_TRANSFORMER.vocab_size = 4
_C.MODEL.DECISION_TRANSFORMER.step_size = 3  # three tokens per time step: [reward, action, state]
_C.MODEL.DECISION_TRANSFORMER.block_size = _C.MODEL.DECISION_TRANSFORMER.episode_horizon * _C.MODEL.DECISION_TRANSFORMER.step_size
_C.MODEL.DECISION_TRANSFORMER.allowed_models = ["DecisionTransformerNet",
"DecisionTransformerEnhancedNet",
"FullDecisionTransformerNet",
"FullDecisionTransformerSingleVisionStateNet"]
_C.MODEL.DECISION_TRANSFORMER.allowed_rewards = ["point_nav_reward_to_go", "sparse_reward_to_go",
"point_nav_reward", "sparse_reward", "ndtw_reward",
"ndtw_reward_to_go"]
_C.MODEL.DECISION_TRANSFORMER.exclude_past_action_for_prediction = True
_C.MODEL.DECISION_TRANSFORMER.normalize_depth = False # Needs to be done during dataset creation
_C.MODEL.DECISION_TRANSFORMER.normalize_rgb = False
# dropout hyperparameters
_C.MODEL.DECISION_TRANSFORMER.embd_pdrop = 0.1
_C.MODEL.DECISION_TRANSFORMER.resid_pdrop = 0.1
_C.MODEL.DECISION_TRANSFORMER.attn_pdrop = 0.1
_C.MODEL.DECISION_TRANSFORMER.activation_action_drop = 0.3
_C.MODEL.DECISION_TRANSFORMER.activation_instruction_drop = 0.0
_C.MODEL.DECISION_TRANSFORMER.activation_rgb_drop = 0.0
_C.MODEL.DECISION_TRANSFORMER.activation_depth_drop = 0.0
_C.MODEL.DECISION_TRANSFORMER.ENCODER = CN()
_C.MODEL.DECISION_TRANSFORMER.ENCODER.n_layer = 2
_C.MODEL.DECISION_TRANSFORMER.ENCODER.n_head = 1
_C.MODEL.DECISION_TRANSFORMER.ENCODER.use_sentence_encoding = True
# Only for FullDecisionTransformerNet
_C.MODEL.DECISION_TRANSFORMER.ENCODER.use_rgb_state_embeddings = True
_C.MODEL.DECISION_TRANSFORMER.ENCODER.use_depth_state_embeddings = True
_C.MODEL.DECISION_TRANSFORMER.ENCODER.use_output_rgb_instructions = True
_C.MODEL.DECISION_TRANSFORMER.ENCODER.use_output_depth_instructions = True
_C.MODEL.DECISION_TRANSFORMER.ENCODER.use_output_rgb = True
_C.MODEL.DECISION_TRANSFORMER.ENCODER.use_output_depth = True
# Only for FullDecisionTransformerSingleVisionStateNet
_C.MODEL.DECISION_TRANSFORMER.ENCODER.use_output_state_instructions = True
_C.MODEL.DECISION_TRANSFORMER.ENCODER.use_output_state = True
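The `step_size`/`block_size` arithmetic above implies the usual Decision Transformer input layout: each time step contributes a reward token, an action token, and a state token, so the context holds `episode_horizon * 3` tokens. A minimal interleaving sketch; the tensor names are assumptions, and the actual model code lives in `decision_transformer_policy`, which is not shown in this diff:

```python
import torch

B, T, H = 4, 183, 128               # batch, episode_horizon, hidden_dim
reward_emb = torch.zeros(B, T, H)   # placeholder per-step embeddings
action_emb = torch.zeros(B, T, H)
state_emb = torch.zeros(B, T, H)

# interleave to [reward_0, action_0, state_0, reward_1, ...] -> (B, 3 * T, H),
# matching block_size = episode_horizon * step_size
tokens = torch.stack((reward_emb, action_emb, state_emb), dim=2).reshape(B, 3 * T, H)
assert tokens.shape[1] == T * 3
```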

def purge_keys(config: CN, keys: List[str]) -> None:
for k in keys:
18 changes: 18 additions & 0 deletions vlnce_baselines/config/r2r_baselines/README.md
@@ -36,3 +36,21 @@ gdown https://drive.google.com/uc?id=1xIxh5eUkjGzSL_3AwBqDQlNXjkFrpcg4
# Seq2Seq_DA (135MB)
gdown https://drive.google.com/uc?id=14y7dXkAEwB_q81cDCow8JNKD2aAAPxbW
```


## Experimental

A trainer based on the [Decision Transformer](https://arxiv.org/abs/2106.01345) has been added.

The results are underwhelming (below the best Seq2Seq baseline), but it is a good starting point for
anybody wanting to experiment with Transformers in VLN-CE.

Pretrained models, each with its corresponding training config file, are available under:


[Decision Transformer Agent](https://drive.google.com/file/d/1-E1l5g7DM36m3HYx8b4b4CNBC8d-OS83/view?usp=sharing)

[Enhanced Decision Transformer Agent](https://drive.google.com/file/d/1b2hpkHpiZIc2CBsaLzZWCa7qurfKDDsu/view?usp=sharing)

[Full Decision Transformer Agent](https://drive.google.com/file/d/1rS2_yo9_z35zzaHW4CtByorZ-jpDpht_/view?usp=sharing)
