-
Notifications
You must be signed in to change notification settings - Fork 12
/
eval.py
134 lines (110 loc) · 4.05 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# ============================================================== #
# imports {{{
# ============================================================== #
import gym
import sys
from gym import spaces
import numpy as np
import scipy.io as sio
import json # json.loads
import os # os.listdir
import re # re.findall
import ray
import ray.rllib.agents.dqn as dqn
import argparse as ap
from CodeEnv import *
# }}}
# ============================================================== #
# parameters {{{
# ============================================================== #
# Root of the Ray temp directory used by ray.init; change it when running on a
# cluster or shared machine.
tmpdir = "/tmp/fc94"
parser = ap.ArgumentParser("python eval.py")
parser.add_argument("dB_range", help="set of SNRs, e.g., [5] or [2,3,4]")
# type=int lets argparse reject malformed counts with a usage message instead
# of a bare traceback from int() further down.
parser.add_argument("minCwErr", type=int, help="get at least this many codeword errors")
parser.add_argument("maxCw", type=int, help="but stop if this many codewords have been simulated")
parser.add_argument("NUM", type=int, help="number appended to the name of the result file")
parser.add_argument("--path", dest="path_to_results", default="./latest", help="path to the result directory")
parser.add_argument('--save', dest='SAVE', action='store_true', help="write res_NUM.mat and res_NUM.txt (default)")
parser.add_argument('--no-save', dest='SAVE', action='store_false', help="do not write result files")
parser.set_defaults(SAVE=True)
args = parser.parse_args()
# NOTE(review): eval() executes arbitrary Python taken from the command line.
# It is kept for backward compatibility (callers may pass expressions rather
# than plain list literals); never run this script on untrusted arguments.
# np.atleast_1d makes a bare scalar such as "5" behave like "[5]"; a 0-d array
# would otherwise break len(dB_range) in the evaluation loop.
dB_range = np.atleast_1d(np.asarray(eval(args.dB_range)))
minCwErr = int(args.minCwErr)
maxCw = int(args.maxCw)
NUM = int(args.NUM)
path_to_results = args.path_to_results
SAVE = args.SAVE
if SAVE:
    save_path = os.path.join(path_to_results, "res_{}.mat".format(NUM))
    save_path_txt = os.path.join(path_to_results, "res_{}.txt".format(NUM))
# }}}
# ============================================================== #
# Load the training configuration stored in the result directory.
with open(os.path.join(path_to_results, "params.json")) as h:
    config = json.load(h)
env_config = config["env_config"]
# Find all checkpoint_<N> entries and restore the one with the largest N.
# Raw string: "\d" inside a plain string literal is an invalid escape sequence
# (DeprecationWarning, SyntaxWarning on Python 3.12+).
checkpoint_re = re.compile(r"checkpoint_(\d+)")
filenames = os.listdir(path_to_results)
checkpoint_numbers = []
for fn in filenames:
    m = checkpoint_re.findall(fn)
    if not m:
        continue
    print(m[0])
    checkpoint_numbers.append(int(m[0]))
if not checkpoint_numbers:
    # max() of an empty list would fail with an unhelpful ValueError.
    sys.exit("no checkpoint_<N> entries found in " + path_to_results)
mc = max(checkpoint_numbers)
checkpoint_path = os.path.join(path_to_results, "checkpoint_{}".format(mc), "checkpoint-{}".format(mc))
print("found {} checkpoints".format(len(checkpoint_numbers)))
print("restoring " + checkpoint_path)
# ============================================================== #
# evaluation {{{
# ============================================================== #
# you may need to change the temp directory in case it runs on a cluster or shared machine
ray.init(temp_dir=tmpdir + "/ray")
try:
    # A config using AsyncReplayOptimizer is restored with ApexTrainer,
    # anything else with DQNTrainer.
    if config["optimizer_class"] == "AsyncReplayOptimizer":
        trainer = dqn.ApexTrainer(config=config, env=CodeEnv)
    else:
        trainer = dqn.DQNTrainer(config=config, env=CodeEnv)
    trainer.restore(checkpoint_path)
    env = CodeEnv(env_config)
    n = env.n
    dB_len = len(dB_range)
    # Per-SNR counters, indexed like dB_range.
    BitErr = np.zeros([dB_len], dtype=int)
    CwErr = np.zeros([dB_len], dtype=int)
    totCw = np.zeros([dB_len], dtype=int)
    totBit = np.zeros([dB_len], dtype=int)
    for i in range(dB_len):
        print("\n--------\nSimulating EbNo = {} dB".format(dB_range[i]))
        env.set_EbNo_dB(dB_range[i])
        # Keep simulating until minCwErr codeword errors were observed or the
        # budget of maxCw codewords is used up.
        while CwErr[i] < minCwErr and totCw[i] < maxCw:
            obs = env.reset()
            # If the syndrome is already zero, the decoding loop is skipped.
            done = (env.syndrome.sum() == 0)
            while not done:
                action = trainer.compute_action(obs)
                obs, _, done, _ = env.step(action)
            # NOTE(review): the sum of env.chat is counted as bit errors, which
            # presumes CodeEnv keeps the residual error pattern there
            # (e.g. all-zero codeword transmitted) -- confirm in CodeEnv.
            BitErrThis = np.sum(env.chat)
            BitErr[i] += BitErrThis
            if BitErrThis > 0:
                CwErr[i] += 1
            totCw[i] += 1
            totBit[i] += n
        print("CwErr:", CwErr[i])
        print("BitErr:", BitErr[i])
        print("TotCw:", totCw[i])
        # Report nan explicitly (instead of a 0/0 RuntimeWarning) when no
        # codeword was simulated, e.g. maxCw <= 0 or minCwErr <= 0.
        print("CER:", CwErr[i] / totCw[i] if totCw[i] else float("nan"))
        print("BER:", BitErr[i] / totBit[i] if totBit[i] else float("nan"))
    if SAVE:
        resdict = {
            "dB_range": dB_range,
            "CwErr": CwErr,
            "BitErr": BitErr,
            "TotCw": totCw,
            "TotBit": totBit,
        }
        print("\n****\nSaving files to:\n.mat -->" + save_path + "\n.txt -->" + save_path_txt)
        sio.savemat(save_path, resdict)
        with open(save_path_txt, 'w') as file_txt:
            file_txt.write(str(resdict))
finally:
    # Release Ray even if restoring, simulating or saving raises.
    ray.shutdown()
print("done!")
# }}}
# ============================================================== #