statistics.py
# ACKNOWLEDGEMENT - Modified from https://github.com/tambetm/simple_dqn/blob/master/src/statistics.py
import numpy as np
import tensorflow as tf
from tensorflow import contrib
import random
import scipy
from scipy import misc
import csv
import time
import sys
import constants as C
import pdb


class Stats:
    """
    Keep running statistics on everything we want to plot.
    These statistics only get plotted every C.STEPS_PER_PLOT steps.
    """

    def __init__(self, network, game):
        self.network = network
        self.game = game
        # register this Stats instance as the network's callback
        self.network.callback = self

        self.csv_path = C.stats_csv_path
        if self.csv_path != "":
            self.csv_file = open(self.csv_path, "wb")
            self.csv_writer = csv.writer(self.csv_file)
            self.csv_writer.writerow((
                "epoch",
                "steps",
                "average_reward_per_game",
                "average_q",
                "average_cost",
                "num_games_per_epoch",
                "epoch_max_reward",
                "epoch_min_reward",
            ))
            self.csv_file.flush()

        self.num_games_total = 0
        self.epoch = 0
        self.num_steps = 0
        self.game_rewards = 0  # running tally of rewards for the current game
        self.average_reward_per_game = 0  # running average
        self.average_cost = 0  # running average
        self.average_q = 0  # running average

        # these are reset on a per-epoch basis
        self.epoch_max_reward = 0
        self.epoch_min_reward = 999999
        self.num_games_per_epoch = 0

    # called on every step in the game
    def on_step(self, action, reward, terminal):
        self.game_rewards += reward
        self.num_steps += 1
        if terminal:
            self.num_games_total += 1
            self.num_games_per_epoch += 1
            self.epoch_max_reward = max(self.epoch_max_reward, self.game_rewards)
            self.epoch_min_reward = min(self.epoch_min_reward, self.game_rewards)
            # incremental mean: avg += (x - avg) / n
            self.average_reward_per_game += float(self.game_rewards - self.average_reward_per_game) / self.num_games_total
            self.game_rewards = 0  # reset the running reward for the next game

    # called after every training batch update
    def on_train(self, loss, runs):
        # incremental mean of the batch loss over `runs` updates
        self.average_cost += float(loss - self.average_cost) / float(runs)
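
    # Note on the update rule used here and in on_step (a worked example,
    # not original code): adding (x - avg) / n to avg after the n-th sample x
    # keeps avg equal to the arithmetic mean of all samples seen so far,
    # without storing them. For samples 2, 4, 9:
    #   avg = 0.0
    #   avg += (2 - avg) / 1.0  -> 2.0
    #   avg += (4 - avg) / 2.0  -> 3.0
    #   avg += (9 - avg) / 3.0  -> 5.0  (== mean of [2, 4, 9])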

    # write one row of statistics and reset the per-epoch counters
    def write(self, epoch):
        self.epoch = epoch
        print "Plotted Statistics at Epoch: {}".format(self.epoch)
        if self.network.validation:
            # validation set was initialized
            max_qvalues = self.network.predict(self.network.validation_set)
            self.average_q = np.mean(max_qvalues)
        else:
            self.average_q = 0
        if self.csv_path != "":
            self.csv_writer.writerow((
                self.epoch,
                self.num_steps,
                self.average_reward_per_game,
                self.average_q,
                self.average_cost,
                self.num_games_per_epoch,
                self.epoch_max_reward,
                self.epoch_min_reward,
            ))
            self.csv_file.flush()
        # reset all stats that are tracked per epoch only
        self.epoch_max_reward = 0
        self.epoch_min_reward = 999999
        self.num_games_per_epoch = 0
        self.average_cost = 0
        self.network.trained_called = 0
        self.average_reward_per_game = 0
        self.num_games_total = 0

    def close(self):
        if self.csv_path:
            self.csv_file.close()
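

# Usage sketch: one way this class might be wired into a training loop.
# The DQNNetwork, GameEnvironment, and agent names below are hypothetical
# stand-ins, as are C.TOTAL_STEPS and the epoch arithmetic; the only
# assumptions made about the network are the attributes this file itself
# touches (callback, validation, validation_set, predict, trained_called).
#
#   network = DQNNetwork()          # exposes the attributes listed above
#   game = GameEnvironment()
#   stats = Stats(network, game)
#   for step in xrange(C.TOTAL_STEPS):
#       action, reward, terminal = agent.act(game)
#       stats.on_step(action, reward, terminal)
#       # on_train(loss, runs) is expected to be invoked by the network's
#       # training code through the callback registered in __init__
#       if step > 0 and step % C.STEPS_PER_PLOT == 0:
#           stats.write(step // C.STEPS_PER_PLOT)
#   stats.close()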