Skip to content

Commit

Permalink
Add hp search space for PC/classic control.
Browse files Browse the repository at this point in the history
  • Loading branch information
ernestum committed Feb 27, 2024
1 parent 9bb89c9 commit 57018d5
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions tuning/hp_search_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,47 @@ def __call__(
},
},
),
# Search space for preference comparisons (PC) on classic control tasks.
# Each trial runs the `train_preference_comparisons` Sacred experiment with the
# "reward.reward_ensemble" named config plus the config updates built below.
pc_classic_control=RunSacredAsTrial(
    sacred_ex=imitation.scripts.train_preference_comparisons.train_preference_comparisons_ex,
    suggest_named_configs=lambda _: ["reward.reward_ensemble"],
    suggest_config_updates=lambda trial: {
        # --- Fixed (non-tuned) settings -------------------------------------
        # The trial number doubles as the seed, so every trial is seeded
        # differently but reproducibly.
        "seed": trial.number,
        "environment": {"num_vec": 8},
        "total_timesteps": 1e6,
        "total_comparisons": 1000,
        "active_selection": True,
        # --- Tuned settings --------------------------------------------------
        # NOTE: `trial.suggest_int(name, low, high)` includes BOTH endpoints
        # (Optuna semantics), so `high` must be the largest admissible value.
        # NOTE(review): the 11 / 51 / 11 upper bounds further down look like
        # exclusive-style bounds (intended 10 / 50 / 10?) -- confirm; they are
        # kept unchanged because no hard limit is documented for them.
        "active_selection_oversampling": trial.suggest_int(
            "active_selection_oversampling", 1, 11
        ),
        # Upper bound determined by total_comparisons=1000 (inclusive).
        "comparison_queue_size": trial.suggest_int(
            "comparison_queue_size", 1, 1000
        ),
        "exploration_frac": trial.suggest_float("exploration_frac", 0.0, 0.5),
        # Trajectories are 1000 steps long, so a fragment must not exceed 1000
        # steps (inclusive upper bound).
        "fragment_length": trial.suggest_int("fragment_length", 1, 1000),
        "gatherer_kwargs": {
            "temperature": trial.suggest_float("gatherer_temperature", 0.0, 2.0),
            "discount_factor": trial.suggest_float(
                "gatherer_discount_factor", 0.95, 1.0
            ),
            "sample": trial.suggest_categorical("gatherer_sample", [True, False]),
        },
        "initial_epoch_multiplier": trial.suggest_float(
            "initial_epoch_multiplier", 1, 200.0
        ),
        "initial_comparison_frac": trial.suggest_float(
            "initial_comparison_frac", 0.01, 1.0
        ),
        "num_iterations": trial.suggest_int("num_iterations", 1, 51),
        "preference_model_kwargs": {
            "noise_prob": trial.suggest_float(
                "preference_model_noise_prob", 0.0, 0.1
            ),
            "discount_factor": trial.suggest_float(
                "preference_model_discount_factor", 0.95, 1.0
            ),
        },
        "query_schedule": trial.suggest_categorical(
            "query_schedule", ["hyperbolic", "constant", "inverse_quadratic"]
        ),
        "trajectory_generator_kwargs": {
            "switch_prob": trial.suggest_float("tr_gen_switch_prob", 0.1, 1),
            "random_prob": trial.suggest_float("tr_gen_random_prob", 0.1, 0.9),
        },
        "transition_oversampling": trial.suggest_float(
            "transition_oversampling", 0.9, 2.0
        ),
        "reward_trainer_kwargs": {
            "epochs": trial.suggest_int("reward_trainer_epochs", 1, 11),
        },
        "rl": {
            "rl_kwargs": {
                # Sampled on a log scale because the range spans four orders
                # of magnitude.
                "ent_coef": trial.suggest_float(
                    "rl_ent_coef", 1e-7, 1e-3, log=True
                ),
            },
        },
    },
),
sqil=RunSacredAsTrial(
sacred_ex=imitation.scripts.train_imitation.train_imitation_ex,
command_name="sqil",
Expand Down

0 comments on commit 57018d5

Please sign in to comment.