diff --git a/examples/train.py b/examples/train.py
index 1b2f22476..5f8edbb22 100644
--- a/examples/train.py
+++ b/examples/train.py
@@ -213,13 +213,10 @@ def train_rllib(submodule, flags):
     run_experiments({flow_params["exp_tag"]: exp_config})
 
 
-def train_h_baselines(flow_params, args, multiagent):
+def train_h_baselines(env_name, args, multiagent):
     """Train policies using SAC and TD3 with h-baselines."""
     from hbaselines.algorithms import OffPolicyRLAlgorithm
     from hbaselines.utils.train import parse_options, get_hyperparameters
-    from hbaselines.envs.mixed_autonomy import FlowEnv
-
-    flow_params = deepcopy(flow_params)
 
     # Get the command-line arguments that are relevant here
     args = parse_options(description="", example_usage="", args=args)
@@ -227,31 +224,6 @@ def train_h_baselines(flow_params, args, multiagent):
     # the base directory that the logged data will be stored in
     base_dir = "training_data"
 
-    # Create the training environment.
-    env = FlowEnv(
-        flow_params,
-        multiagent=multiagent,
-        shared=args.shared,
-        maddpg=args.maddpg,
-        render=args.render,
-        version=0
-    )
-
-    # Create the evaluation environment.
-    if args.evaluate:
-        eval_flow_params = deepcopy(flow_params)
-        eval_flow_params['env'].evaluate = True
-        eval_env = FlowEnv(
-            eval_flow_params,
-            multiagent=multiagent,
-            shared=args.shared,
-            maddpg=args.maddpg,
-            render=args.render_eval,
-            version=1
-        )
-    else:
-        eval_env = None
-
     for i in range(args.n_training):
         # value of the next seed
         seed = args.seed + i
@@ -299,8 +271,8 @@ def train_h_baselines(flow_params, args, multiagent):
         # Create the algorithm object.
         alg = OffPolicyRLAlgorithm(
             policy=policy,
-            env=env,
-            eval_env=eval_env,
+            env="flow:{}".format(env_name),
+            eval_env="flow:{}".format(env_name) if args.evaluate else None,
             **hp
         )
 
@@ -393,8 +365,7 @@ def main(args):
     elif flags.rl_trainer.lower() == "stable-baselines":
         train_stable_baselines(submodule, flags)
     elif flags.rl_trainer.lower() == "h-baselines":
-        flow_params = submodule.flow_params
-        train_h_baselines(flow_params, args, multiagent)
+        train_h_baselines(flags.exp_config, args, multiagent)
     else:
         raise ValueError("rl_trainer should be either 'rllib', 'h-baselines', "
                          "or 'stable-baselines'.")
diff --git a/tests/fast_tests/test_examples.py b/tests/fast_tests/test_examples.py
index 0b385f28a..b5faf6517 100644
--- a/tests/fast_tests/test_examples.py
+++ b/tests/fast_tests/test_examples.py
@@ -229,11 +229,11 @@ class TestHBaselineExamples(unittest.TestCase):
     confirming that it runs.
     """
     @staticmethod
-    def run_exp(flow_params, multiagent):
+    def run_exp(env_name, multiagent):
         train_h_baselines(
-            flow_params=flow_params,
+            env_name=env_name,
             args=[
-                flow_params["env_name"].__name__,
+                env_name,
                 "--initial_exploration_steps", "1",
                 "--total_steps", "10"
            ],
@@ -241,10 +241,10 @@ class TestHBaselineExamples(unittest.TestCase):
         )
 
     def test_singleagent_ring(self):
-        self.run_exp(singleagent_ring.copy(), multiagent=False)
+        self.run_exp("singleagent_ring", multiagent=False)
 
     def test_multiagent_ring(self):
-        self.run_exp(multiagent_ring.copy(), multiagent=True)
+        self.run_exp("multiagent_ring", multiagent=True)
 
 
 class TestRllibExamples(unittest.TestCase):