run_d2rl_sac.py
"""Train D2RL-SAC on a continuous-control Gym environment using tf2rl's Trainer."""
from tf2rl.algos.d2rl_sac import D2RLSAC
from tf2rl.experiments.trainer import Trainer
from tf2rl.envs.utils import make


if __name__ == '__main__':
    # Build the CLI parser from Trainer and D2RLSAC, then override a few defaults.
    parser = Trainer.get_argument()
    parser = D2RLSAC.get_argument(parser)
    parser.add_argument('--env-name', type=str, default="Pendulum-v0")
    parser.set_defaults(batch_size=256)
    parser.set_defaults(n_warmup=10000)
    parser.set_defaults(max_steps=3e6)
    args = parser.parse_args()

    # Separate environments for training and evaluation.
    env = make(args.env_name)
    test_env = make(args.env_name)

    # D2RL variant of SAC: deeper actor/critic MLPs (four 256-unit layers each).
    policy = D2RLSAC(
        state_shape=env.observation_space.shape,
        action_dim=env.action_space.high.size,
        actor_units=(256, 256, 256, 256),
        critic_units=(256, 256, 256, 256),
        gpu=args.gpu,
        memory_capacity=args.memory_capacity,
        max_action=env.action_space.high[0],
        batch_size=args.batch_size,
        n_warmup=args.n_warmup,
        alpha=args.alpha,
        auto_alpha=args.auto_alpha)
    trainer = Trainer(policy, env, args, test_env=test_env)

    # Run evaluation only if --evaluate is passed; otherwise train.
    if args.evaluate:
        trainer.evaluate_policy_continuously()
    else:
        trainer()