# -*- coding: utf-8 -*-
import time
from datetime import datetime
import gym
import torch
from torch.autograd import Variable

from model import ActorCritic
from utils import state_to_tensor, plot_line


def test(rank, args, T, shared_model):
  torch.manual_seed(args.seed + rank)

  env = gym.make(args.env)
  env.seed(args.seed + rank)
  model = ActorCritic(env.observation_space, env.action_space, args.hidden_size, args.sigma_init, args.no_noise)
  model.eval()

  can_test = True  # Test flag
  t_start = 1  # Test step counter to check against global counter
  rewards, steps = [], []  # Rewards and steps for plotting
  l = str(len(str(args.T_max)))  # Max num. of digits for logging steps
  done = True  # Start new episode

  while T.value() <= args.T_max:
    if can_test:
      t_start = T.value()  # Reset counter

      # Evaluate over several episodes and average results
      avg_rewards, avg_episode_lengths = [], []
      for _ in range(args.evaluation_episodes):
        while True:
          # Reset or pass on hidden state
          if done:
            # Sync with shared model every episode
            model.load_state_dict(shared_model.state_dict())
            hx = Variable(torch.zeros(1, args.hidden_size), volatile=True)
            cx = Variable(torch.zeros(1, args.hidden_size), volatile=True)
            # Reset environment and done flag
            state = state_to_tensor(env.reset())
            done, episode_length = False, 0
            reward_sum = 0
            model.remove_noise()  # Run without noise

          # Optionally render validation states
          if args.render:
            env.render()

          # Calculate policy
          policy, _, (hx, cx) = model(Variable(state, volatile=True), (hx.detach(), cx.detach()))  # Break graph for memory efficiency

          # Choose action greedily
          action = policy.max(1)[1].data[0]

          # Step
          state, reward, done, _ = env.step(action)
          state = state_to_tensor(state)
          reward_sum += reward
          done = done or episode_length >= args.max_episode_length  # Stop episodes at a max length
          episode_length += 1  # Increase episode counter

          # Log and reset statistics at the end of every episode
          if done:
            avg_rewards.append(reward_sum)
            avg_episode_lengths.append(episode_length)
            break

      print(('[{}] Step: {:<' + l + '} Avg. Reward: {:<8} Avg. Episode Length: {:<8}').format(
        datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3],
        t_start,
        sum(avg_rewards) / args.evaluation_episodes,
        sum(avg_episode_lengths) / args.evaluation_episodes))
      if not args.no_noise:
        print(('[{}] Step: {:<' + l + '} Avg. σ^w (π): {:<8} Avg. σ^b (π): {:<8}').format(
          datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3],
          t_start,
          model.fc_actor.sigma_weight.abs().mean().data[0],
          model.fc_actor.sigma_bias.abs().mean().data[0]))
        print(('[{}] Step: {:<' + l + '} Avg. σ^w (V): {:<8} Avg. σ^b (V): {:<8}').format(
          datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3],
          t_start,
          model.fc_critic.sigma_weight.abs().mean().data[0],
          model.fc_critic.sigma_bias.abs().mean().data[0]))

      if args.evaluate:
        return

      rewards.append(avg_rewards)  # Keep all evaluations
      steps.append(t_start)
      plot_line(steps, rewards)  # Plot rewards
      torch.save(model.state_dict(), 'model.pth')  # Save model params
      can_test = False  # Finish testing
    else:
      if T.value() - t_start >= args.evaluation_interval:
        can_test = True
    time.sleep(0.001)  # Check if available to test every millisecond

  env.close()