-
Notifications
You must be signed in to change notification settings - Fork 1
/
validate_reinforce.py
72 lines (53 loc) · 2.42 KB
/
validate_reinforce.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import argparse
import datetime
import random
from collections import namedtuple
from itertools import count
from matplotlib.pylab import plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
from torch.utils.tensorboard import SummaryWriter
from kalah.agents.maxscoreagent import MaxScoreAgent
from kalah.agents.maxscorerepeatagent import MaxScoreRepeatAgent
from kalah.agents.minimaxagent import MinimaxAgent
from kalah.agents.randomagent import RandomAgent
from kalah.agents.reinforceagent import ReinforceAgent
from kalah.kalahboard import KalahBoard
from kalah.kalahenv import KalahEnv
from kalah.models.reinforce import ReinforceModel
from kalah.utils import compare_agents
# Command-line configuration for validating a trained reinforce model on Kalah.
parser = argparse.ArgumentParser(description='Validate a reinforce model for Kalah')
parser.add_argument('--bins', type=int, default=6, metavar='B',
                    help='bins of the Kalah board (default: 6)')
parser.add_argument('--seeds', type=int, default=4, metavar='S',
                    help='seeds of the Kalah board (default: 4)')
parser.add_argument('--seed', type=int, default=543, metavar='N',
                    help='random seed (default: 543)')
parser.add_argument('--validation-games', type=int, default=100, metavar='EG',
                    help='how many games to play (default: 100)')
parser.add_argument('--model-path', type=str, default='./model.pt', metavar='MP',
                    help='path to trained model (default: ./model.pt)')
args = parser.parse_args()

# The game count is used as the divisor of the reported win percentage, so a
# non-positive value is meaningless (and 0 would raise ZeroDivisionError).
if args.validation_games < 1:
    parser.error('--validation-games must be a positive integer')

# Seed the torch, Python and NumPy global RNGs so validation runs are reproducible.
torch.manual_seed(args.seed)
random.seed(args.seed)
np.random.seed(args.seed)
def final_compare(model):
    """Play the trained model against each baseline agent and print the results.

    For every opponent class (random, max-score, max-score-repeat, minimax),
    this does the following:
    - wraps *model* in a ``ReinforceAgent`` and builds a new opponent, both
      seeded with ``args.seed``;
    - plays ``args.validation_games`` games via ``compare_agents`` on a board
      with ``args.bins`` bins and ``args.seeds`` seeds;
    - prints the reinforce agent's win percentage and the number of draws.

    Args:
        model: the trained model handed to ``ReinforceAgent``.
    """
    print("Comparisons with other agent:")
    n_games = args.validation_games
    agent_class1 = ReinforceAgent
    agent_classes = [RandomAgent, MaxScoreAgent, MaxScoreRepeatAgent, MinimaxAgent]
    for agent_class2 in agent_classes:
        # A new, identically seeded agent pair is built for every opponent
        # (presumably so each matchup starts from the same RNG state).
        agent1 = agent_class1(model, seed=args.seed)
        agent2 = agent_class2(seed=args.seed)
        # compare_agents returns four values; only the first agent's wins
        # (index 0) and the draws (index 2) are reported here.
        [wins_agent_one, _, draws, _] = compare_agents(args.bins, args.seeds, n_games, agent1, agent2)
        # Guard the division: with zero games there is no meaningful percentage.
        win_percentage = 100 * wins_agent_one / n_games if n_games > 0 else 0.0
        print(agent_class1.__name__, "won", win_percentage, "% ( n =", wins_agent_one,
              ") of all N =", n_games, "games against", agent_class2.__name__,
              "Number of draws:", draws)
def main():
    """Load the trained model from ``args.model_path`` and run the validation matches."""
    import inspect

    # SECURITY NOTE: torch.load unpickles arbitrary Python objects, so only point
    # --model-path at checkpoints you trust.
    # The loaded object is used as a module (.eval() is called on it), so the
    # checkpoint is presumably a whole pickled model. torch >= 2.6 refuses such
    # files under its default weights_only=True, so opt out explicitly whenever
    # this torch version supports the flag (older versions lack the parameter).
    load_kwargs = {}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    model = torch.load(args.model_path, **load_kwargs)
    # Switch layers such as dropout/batch-norm to inference behaviour.
    model.eval()
    final_compare(model)


if __name__ == '__main__':
    main()