diff --git a/scripts/experiment.py b/scripts/experiment.py index 51c9f37..bde6a62 100644 --- a/scripts/experiment.py +++ b/scripts/experiment.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -import gym +import gym import safety_gym import safe_rl from safe_rl.utils.run_utils import setup_logger_kwargs @@ -21,8 +21,9 @@ def main(robot, task, algo, seed, exp_name, cpu): assert robot.lower() in robot_list, "Invalid robot" # Hyperparameters - exp_name = algo + '_' + robot + task - if robot=='Doggo': + if exp_name is None: + exp_name = (algo + '_' + robot.lower() + task.lower()) + if robot == 'Doggo': num_steps = 1e8 steps_per_epoch = 60000 else: @@ -37,7 +38,6 @@ def main(robot, task, algo, seed, exp_name, cpu): mpi_fork(cpu) # Prepare Logger - exp_name = exp_name or (algo + '_' + robot.lower() + task.lower()) logger_kwargs = setup_logger_kwargs(exp_name, seed) # Algo and Env @@ -58,7 +58,6 @@ def main(robot, task, algo, seed, exp_name, cpu): ) - if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() @@ -69,5 +68,5 @@ def main(robot, task, algo, seed, exp_name, cpu): parser.add_argument('--exp_name', type=str, default='') parser.add_argument('--cpu', type=int, default=1) args = parser.parse_args() - exp_name = args.exp_name if not(args.exp_name=='') else None - main(args.robot, args.task, args.algo, args.seed, exp_name, args.cpu) \ No newline at end of file + exp_name = args.exp_name if not(args.exp_name == '') else None + main(args.robot, args.task, args.algo, args.seed, exp_name, args.cpu)