-
Notifications
You must be signed in to change notification settings - Fork 131
/
Copy pathevaluate.py
82 lines (72 loc) · 3.15 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import argparse
import torch
import time
import imageio
import numpy as np
from pathlib import Path
from torch.autograd import Variable
from utils.make_env import make_env
from algorithms.maddpg import MADDPG
def _resolve_model_path(config):
    """Return the checkpoint file selected by *config*.

    The path is ``./models/<env_id>/<model_name>/run<run_num>/`` followed by
    ``incremental/model_ep<incremental>.pt`` when ``config.incremental`` is
    not None, and by ``model.pt`` otherwise.
    """
    run_dir = (Path('./models') / config.env_id / config.model_name /
               ('run%i' % config.run_num))
    if config.incremental is not None:
        return run_dir / 'incremental' / ('model_ep%i.pt' % config.incremental)
    return run_dir / 'model.pt'


def _next_gif_file(gif_dir, ep_i):
    """Return the first ``<n>_<ep_i>.gif`` path in *gif_dir* that does not exist.

    ``n`` starts at 0 and is incremented until a free name is found, so GIFs
    written by earlier evaluations are never overwritten.
    """
    gif_num = 0
    while (gif_dir / ('%i_%i.gif' % (gif_num, ep_i))).exists():
        gif_num += 1
    return gif_dir / ('%i_%i.gif' % (gif_num, ep_i))


def run(config):
    """Roll out a saved MADDPG policy and optionally record one GIF per episode.

    Args:
        config: namespace providing ``env_id``, ``model_name``, ``run_num``,
            ``incremental``, ``save_gifs``, ``n_episodes``,
            ``episode_length`` and ``fps`` (see the command-line parser).

    Raises:
        FileNotFoundError: if the selected checkpoint file does not exist.
    """
    model_path = _resolve_model_path(config)
    # Check before touching the filesystem so a wrong env/model/run is
    # reported by name instead of as a mkdir failure on the gifs directory
    # (and no stray directories are created).
    if not model_path.is_file():
        raise FileNotFoundError('No saved model found at %s' % model_path)

    if config.save_gifs:
        # GIFs live next to the checkpoint (inside ``incremental/`` for
        # incremental checkpoints).
        gif_path = model_path.parent / 'gifs'
        gif_path.mkdir(exist_ok=True)

    maddpg = MADDPG.init_from_save(model_path)
    env = make_env(config.env_id, discrete_action=maddpg.discrete_action)
    maddpg.prep_rollouts(device='cpu')
    ifi = 1 / config.fps  # inter-frame interval, in seconds

    try:
        for ep_i in range(config.n_episodes):
            print("Episode %i of %i" % (ep_i + 1, config.n_episodes))
            obs = env.reset()
            if config.save_gifs:
                # NOTE(review): render('rgb_array') appears to return a
                # sequence of images; only the first is kept — confirm.
                frames = [env.render('rgb_array')[0]]
            env.render('human')
            for _ in range(config.episode_length):
                # perf_counter is monotonic: wall-clock adjustments cannot
                # produce a negative or inflated elapsed time.
                calc_start = time.perf_counter()
                # One observation per agent, flattened into a (1, -1) row
                # tensor; no gradients are needed for evaluation.
                torch_obs = [Variable(torch.Tensor(obs[i]).view(1, -1),
                                      requires_grad=False)
                             for i in range(maddpg.nagents)]
                # explore=False: act without exploration during evaluation
                # (exact meaning is defined by MADDPG.step).
                torch_actions = maddpg.step(torch_obs, explore=False)
                actions = [ac.data.numpy().flatten() for ac in torch_actions]
                # Rewards, done flags and infos are unused: every episode
                # runs for exactly ``episode_length`` steps.
                obs, _, _, _ = env.step(actions)
                if config.save_gifs:
                    frames.append(env.render('rgb_array')[0])
                # Throttle the loop to at most ``fps`` steps per second.
                elapsed = time.perf_counter() - calc_start
                if elapsed < ifi:
                    time.sleep(ifi - elapsed)
                env.render('human')
            if config.save_gifs:
                # NOTE(review): imageio's legacy GIF writer reads `duration`
                # in seconds; newer pillow-based writers may expect
                # milliseconds — confirm against the installed version.
                imageio.mimsave(str(_next_gif_file(gif_path, ep_i)),
                                frames, duration=ifi)
    finally:
        # Release the environment even if a rollout raises.
        env.close()
def _positive_int(text):
    """argparse ``type`` callable that accepts only integers greater than zero."""
    value = int(text)
    if value <= 0:
        raise argparse.ArgumentTypeError('%r is not a positive integer' % text)
    return value


def build_parser():
    """Build the command-line parser for evaluating a saved MADDPG model."""
    parser = argparse.ArgumentParser()
    parser.add_argument("env_id", help="Name of environment")
    parser.add_argument("model_name",
                        help="Name of model")
    # nargs='?' makes the positional optional; argparse never applies the
    # default of a required positional, so without it default=1 was dead.
    parser.add_argument("run_num", nargs='?', default=1, type=int,
                        help="Run number to load (default: 1)")
    parser.add_argument("--save_gifs", action="store_true",
                        help="Saves gif of each episode into model directory")
    parser.add_argument("--incremental", default=None, type=int,
                        help="Load incremental policy from given episode " +
                             "rather than final policy")
    parser.add_argument("--n_episodes", default=10, type=int,
                        help="Number of episodes to roll out")
    parser.add_argument("--episode_length", default=25, type=int,
                        help="Number of environment steps per episode")
    # fps must be positive: the evaluation loop divides 1 by it to get the
    # inter-frame interval.
    parser.add_argument("--fps", default=30, type=_positive_int,
                        help="Maximum steps per second (also the GIF frame rate)")
    return parser


if __name__ == '__main__':
    parser = build_parser()
    config = parser.parse_args()
    run(config)