Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 4 additions & 33 deletions examples/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,45 +213,17 @@ def train_rllib(submodule, flags):
run_experiments({flow_params["exp_tag"]: exp_config})


def train_h_baselines(flow_params, args, multiagent):
def train_h_baselines(env_name, args, multiagent):
"""Train policies using SAC and TD3 with h-baselines."""
from hbaselines.algorithms import OffPolicyRLAlgorithm
from hbaselines.utils.train import parse_options, get_hyperparameters
from hbaselines.envs.mixed_autonomy import FlowEnv

flow_params = deepcopy(flow_params)

# Get the command-line arguments that are relevant here
args = parse_options(description="", example_usage="", args=args)

# the base directory that the logged data will be stored in
base_dir = "training_data"

# Create the training environment.
env = FlowEnv(
flow_params,
multiagent=multiagent,
shared=args.shared,
maddpg=args.maddpg,
render=args.render,
version=0
)

# Create the evaluation environment.
if args.evaluate:
eval_flow_params = deepcopy(flow_params)
eval_flow_params['env'].evaluate = True
eval_env = FlowEnv(
eval_flow_params,
multiagent=multiagent,
shared=args.shared,
maddpg=args.maddpg,
render=args.render_eval,
version=1
)
else:
eval_env = None

for i in range(args.n_training):
# value of the next seed
seed = args.seed + i
Expand Down Expand Up @@ -299,8 +271,8 @@ def train_h_baselines(flow_params, args, multiagent):
# Create the algorithm object.
alg = OffPolicyRLAlgorithm(
policy=policy,
env=env,
eval_env=eval_env,
env="flow:{}".format(env_name),
eval_env="flow:{}".format(env_name) if args.evaluate else None,
**hp
)

Expand Down Expand Up @@ -393,8 +365,7 @@ def main(args):
elif flags.rl_trainer.lower() == "stable-baselines":
train_stable_baselines(submodule, flags)
elif flags.rl_trainer.lower() == "h-baselines":
flow_params = submodule.flow_params
train_h_baselines(flow_params, args, multiagent)
train_h_baselines(flags.exp_config, args, multiagent)
else:
raise ValueError("rl_trainer should be either 'rllib', 'h-baselines', "
"or 'stable-baselines'.")
Expand Down
10 changes: 5 additions & 5 deletions tests/fast_tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,22 +229,22 @@ class TestHBaselineExamples(unittest.TestCase):
confirming that it runs.
"""
@staticmethod
def run_exp(flow_params, multiagent):
def run_exp(env_name, multiagent):
train_h_baselines(
flow_params=flow_params,
env_name=env_name,
args=[
flow_params["env_name"].__name__,
env_name,
"--initial_exploration_steps", "1",
"--total_steps", "10"
],
multiagent=multiagent,
)

def test_singleagent_ring(self):
self.run_exp(singleagent_ring.copy(), multiagent=False)
self.run_exp("singleagent_ring", multiagent=False)

def test_multiagent_ring(self):
self.run_exp(multiagent_ring.copy(), multiagent=True)
self.run_exp("multiagent_ring", multiagent=True)


class TestRllibExamples(unittest.TestCase):
Expand Down