# -*- coding: utf-8 -*-
"""AlphaZero
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1Bl4LkvNENrOjkg1eDlS_-g_4E25R8x9O
"""
import torch
import numpy as np
# np.seterr(all='raise')
import time
from torch.nn import Linear, BatchNorm1d
from torch.nn.functional import relu, softmax, leaky_relu
from torch import tanh
import copy
from tictactoe import tictactoe
import matplotlib.pyplot as plt
# set up NN
# game state should have player turn as last value
try:
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
except Exception:
    device = torch.device("cpu")  # fall back to CPU
print(f"Using {device}")
class AlphaZeroConv(torch.nn.Module):
    # Before passing through the network, convert the +1/-1/0 board representation to a one-hot-style encoding:
    # two 3x3 planes (one per player, with 1 marking that player's pieces) plus a plane filled with the turn indicator.
def __init__(self, input_dim, hidden_layer_dim, output_dim): # should probably use conv layers and residual layers
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.input_layer = Linear(input_dim, hidden_layer_dim)
self.conv2d1 = torch.nn.Conv2d(3, 128, 2, stride=1)
self.conv2d2 = torch.nn.Conv2d(128, 128, 2, stride=1)
self.fc_layers = torch.nn.ModuleList(
[Linear(hidden_layer_dim, hidden_layer_dim), Linear(hidden_layer_dim, hidden_layer_dim)])
self.batch_norms = torch.nn.ModuleList([BatchNorm1d(hidden_layer_dim) for x in self.fc_layers])
self.policy_layer = Linear(hidden_layer_dim, output_dim)
self.value_layer = Linear(hidden_layer_dim, 1)
self.fc1 = Linear(128, 128)
def forward(self, x):
if len(x.shape) == 1:
x = x.unsqueeze(0)
turn = x[:, -1]
board = x[:, :-1]
zeros = torch.zeros_like(board)
map1 = torch.where(board == 1, board, zeros)
map2 = torch.where(board == -1, -board, zeros)
map1 = map1.view(-1, 3, 3)
map2 = map2.view(-1, 3, 3)
turn = torch.ones_like(map1) * turn.unsqueeze(-1).unsqueeze(-1).expand_as(map1)
h = torch.stack((map1, map2, turn), dim=1)
h = leaky_relu(self.conv2d1(h))
h = leaky_relu(self.conv2d2(h))
        h = h.squeeze(-1).squeeze(-1)  # two 2x2 convs reduce the 3x3 board to 1x1; flatten to (batch, 128) for the residual FC layers
h = leaky_relu(self.fc1(h))
# h = relu(self.input_layer(x))
for i, l in enumerate(self.fc_layers): # all the hidden layers
h = l(h) + h
h = self.batch_norms[i](h)
h = leaky_relu(h)
p = softmax(self.policy_layer(h), dim=-1) # probs for each action
_v = self.value_layer(h) # predict a value for this state
v = tanh(_v)
return torch.cat((p, v), dim=1).squeeze()
class AlphaZeroConvLarge(torch.nn.Module):
    # Before passing through the network, convert the +1/-1/0 board representation to a one-hot-style encoding:
    # two 3x3 planes (one per player, with 1 marking that player's pieces) plus a plane filled with the turn indicator.
def __init__(self, input_dim, hidden_layer_dim, output_dim): # should probably use conv layers and residual layers
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.input_layer = Linear(input_dim, hidden_layer_dim)
self.conv2d1 = torch.nn.Conv2d(3, 128, 2, stride=1)
self.conv2d2 = torch.nn.Conv2d(128, 128, 2, stride=1)
self.fc_layers = torch.nn.ModuleList(
[Linear(hidden_layer_dim, hidden_layer_dim), Linear(hidden_layer_dim, hidden_layer_dim),
Linear(hidden_layer_dim, hidden_layer_dim), Linear(hidden_layer_dim, hidden_layer_dim)])
self.batch_norms = torch.nn.ModuleList([BatchNorm1d(hidden_layer_dim) for x in self.fc_layers])
self.pre_policy_layers = torch.nn.ModuleList(
[Linear(hidden_layer_dim, hidden_layer_dim), Linear(hidden_layer_dim, hidden_layer_dim)])
self.policy_layer = Linear(hidden_layer_dim, output_dim)
self.pre_value_layers = torch.nn.ModuleList(
[Linear(hidden_layer_dim, hidden_layer_dim), Linear(hidden_layer_dim, hidden_layer_dim)])
self.value_layer = Linear(hidden_layer_dim, 1)
self.fc1 = Linear(128, 128)
def forward(self, x):
if len(x.shape) == 1:
x = x.unsqueeze(0)
turn = x[:, -1]
board = x[:, :-1]
zeros = torch.zeros_like(board)
map1 = torch.where(board == 1, board, zeros)
map2 = torch.where(board == -1, -board, zeros)
map1 = map1.view(-1, 3, 3)
map2 = map2.view(-1, 3, 3)
turn = torch.ones_like(map1) * turn.unsqueeze(-1).unsqueeze(-1).expand_as(map1)
h = torch.stack((map1, map2, turn), dim=1)
h = leaky_relu(self.conv2d1(h))
h = leaky_relu(self.conv2d2(h))
        h = h.squeeze(-1).squeeze(-1)  # two 2x2 convs reduce the 3x3 board to 1x1; flatten to (batch, 128) for the residual FC layers
h = leaky_relu(self.fc1(h))
# h = relu(self.input_layer(x))
for i, l in enumerate(self.fc_layers): # all the hidden layers
h = l(h) + h
h = self.batch_norms[i](h)
h = leaky_relu(h)
_p = h
_v = h
for l in self.pre_policy_layers:
_p = l(_p) + _p
_p = leaky_relu(_p)
p = softmax(self.policy_layer(_p), dim=-1) # probs for each action
for l in self.pre_value_layers:
_v = l(_v) + _v
_v = leaky_relu(_v)
_v = self.value_layer(_v) # predict a value for this state
v = tanh(_v)
return torch.cat((p, v), dim=1).squeeze()
class AlphaZero(torch.nn.Module):
def __init__(self, input_dim, hidden_layer_dim, output_dim): # should probably use conv layers and residual layers
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.input_layer = Linear(input_dim, hidden_layer_dim)
self.fc_layers = torch.nn.ModuleList(
[Linear(hidden_layer_dim, hidden_layer_dim),
Linear(hidden_layer_dim, hidden_layer_dim)])
self.policy_layer = Linear(hidden_layer_dim, output_dim)
self.value_layer = Linear(hidden_layer_dim, 1)
def forward(self, x):
h = relu(self.input_layer(x))
for l in self.fc_layers: # all the hidden layers
h = relu(l(h))
p = softmax(self.policy_layer(h), dim=-1) # probs for each action
_v = self.value_layer(h) # predict a value for this state
v = tanh(_v)
if len(p.size()) == 2: # we were doing a batch input of vectors x
return torch.cat((p, v), dim=1) # cat along dim 1
else: # a single vector x was input, cat along dim 0
return torch.cat((p, v), dim=0)
class AlphaZeroResidual(torch.nn.Module):
def __init__(self, input_dim, hidden_layer_dim, output_dim):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.input_layer = Linear(input_dim, hidden_layer_dim)
self.fc_layers = torch.nn.ModuleList(
[Linear(hidden_layer_dim, hidden_layer_dim), Linear(hidden_layer_dim, hidden_layer_dim),
Linear(hidden_layer_dim, hidden_layer_dim)])
# , Linear(hidden_layer_dim, hidden_layer_dim), Linear(hidden_layer_dim, hidden_layer_dim)])
self.batch_norms = torch.nn.ModuleList([BatchNorm1d(hidden_layer_dim) for x in self.fc_layers])
self.policy_layer = Linear(hidden_layer_dim, output_dim)
self.value_layer = Linear(hidden_layer_dim, 1)
def forward(self, x):
if len(x.shape) == 1:
x = x.unsqueeze(0)
h = relu(self.input_layer(x))
for i, l in enumerate(self.fc_layers): # all the hidden layers
h = l(h) + h # residual block
h = self.batch_norms[i](h)
h = relu(h)
p = softmax(self.policy_layer(h), dim=-1) # probs for each action
_v = self.value_layer(h) # predict a value for this state
v = tanh(_v)
return torch.cat((p, v), dim=1).squeeze()
class Node():
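    """A node in the MCTS tree: stores a copy of the game state, whose turn it is,
    the network's prior move probabilities and value estimate, per-action Q estimates
    (running averages), and per-action visit counts."""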
def __init__(self, state, model, parent_node):
self.c = 1 # hyperparameter
self.state = np.copy(state)
self.z = None
self.parent = parent_node
self.player = 1 if parent_node is None else -parent_node.player # +1 if first player, -1 if 2nd player
x = torch.tensor(self.state, device=device, dtype=torch.float32)
predict = model(x)
self.initial_probs = predict[:-1].detach().cpu().numpy() # don't need grad here, this is making training data
self.value = predict[-1].detach().cpu().numpy() # nor here
action_size = predict.size()[0] - 1
        self.Q = .5 * np.ones(action_size)  # optimistic initial Q of 0.5 for every action (rather than zeros)
self.actions_taken = np.zeros(action_size)
self.last_action = -1 # used when we pass back through the tree to update Q values
def search(current_node, node_dict, agent, game):
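    """Run one MCTS simulation from current_node: select an action by a PUCT-style
    score (with occasional random exploration), step the (copied) game, then either
    recurse into an existing child node, expand a new node, or, at a terminal state,
    backpropagate the reward up the tree."""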
    # PUCT-style upper confidence score for each action
    u = current_node.Q + current_node.c * current_node.initial_probs * np.sqrt(np.sum(current_node.actions_taken)) / (
            1 + current_node.actions_taken) + 1e-6  # small constant added for numerical stability/simplicity
mask = game.getLegalActionMask()
u_masked = mask * u
    u_masked[u_masked == 0] = np.nan  # mark masked-out (illegal) actions as NaN so nanargmax ignores them
action = np.nanargmax(u_masked)
if np.random.random() < .2: # select a random legal action
probs = np.ones_like(u_masked)
probs = probs * mask
probs = probs / np.sum(probs)
action = np.random.choice(list(range(game.action_space_size)), p=probs)
current_node.last_action = action
current_node.actions_taken[action] += 1
next_state, reward, done = game.step(action)
if done: # reached terminal state
# update q values for the nodes along this path
node = current_node
propagate_value_up_tree(node, reward)
return
else: # grow tree and search from the next node
key = next_state.tobytes()
if key in node_dict:
next_node = node_dict[key]
search(next_node, node_dict, agent, game)
return
else: # this is a new state not currently in the tree, we create the new node and backpropagate its value up the tree
new_node = make_new_node_in_tree(next_state, node_dict, agent, current_node)
# propagate_value_up_tree(new_node, new_node.value) # done in make_new_node_in_tree
# we only search the next node if it was already in the tree
return
def propagate_value_up_tree(leaf_node, value):
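    """Walk from this node up to the root, updating each node's running-average Q
    estimate for the action it most recently took, with the value signed by that
    node's player."""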
node = leaf_node
# update Q with moving average
new_q = node.Q[node.last_action] * (node.actions_taken[node.last_action] - 1) + value * node.player
new_q /= node.actions_taken[node.last_action]
node.Q[node.last_action] = new_q
if leaf_node.parent is None: # root node
return
else:
node = leaf_node.parent
propagate_value_up_tree(node, value)
def make_new_node_in_tree(obs, node_dict, agent, last_node):
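    """Create a node for a newly visited state, propagate its predicted value up the
    tree from last_node, and register it in node_dict keyed by the state's bytes."""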
new_node = Node(obs, agent, last_node)
v = new_node.value
# pass the value of new node up tree and update the Q values
propagate_value_up_tree(last_node, v)
# use a dictionary to store the graph
key = obs.tobytes()
node_dict[key] = new_node
return new_node
# simulate one self-play game, growing the MCTS tree of nodes as we go
def simulate_game(game, agent, search_steps):
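    """Self-play one game. At each move, run `search_steps` MCTS simulations from the
    current node, then sample an action from the visit-count distribution (with some
    random exploration). Returns the visited states, the MCTS policy targets, and the
    final outcome broadcast to every position."""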
s_list = []
pi_list = []
z_list = []
obs, reward, done = game.reset()
root = Node(obs, agent, None)
node_dict = {}
node_dict[root.state.tobytes()] = root
current_node = root
while not done:
current_state = obs
key = obs.tobytes()
if key in node_dict:
current_node = node_dict[key]
else:
current_node = make_new_node_in_tree(current_state, node_dict, agent, current_node)
current_node.parent = None # we don't need to back propagate the values through the tree past this node
# add some noise to the root of the tree to encourage exploration and variety
# mix_ratio = .1
# diri_alpha = .2
# noise = np.random.dirichlet(diri_alpha * np.ones_like(current_node.initial_probs))
# current_node.initial_probs = (1 - mix_ratio) * current_node.initial_probs + mix_ratio * noise
for n in range(search_steps): # run search_steps trajectories through game
copy_game = game.copy()
search(current_node, node_dict, agent,
copy_game) # game will get modified, copy game before calling search or be able to reset state
# normalize the actions taken during the searching to get probabilities learned during MCTS
legal_action_mask = game.getLegalActionMask()
pi = current_node.actions_taken * legal_action_mask
pi = pi / np.sum(pi) # softmax?
# save the true distribution or a greedy version of it
# if np.random.random() < .05:
# _pi = pi
# else:
# greedy_pi = np.zeros_like(pi)
# greedy_pi[np.argmax(pi)] = 1
# _pi = greedy_pi
_pi = pi
s_list.append(current_state)
pi_list.append(_pi)
z_list.append(0) # to be filled with game outcome later
# select next action with some variation to make examples robust
# action = np.random.choice(list(range(game.action_space_size)), p=pi)
if np.random.random() < .2: # select a random legal action
probs = np.ones_like(pi)
probs = probs * legal_action_mask
probs = probs / np.sum(probs)
action = np.random.choice(list(range(game.action_space_size)), p=probs)
else: # select action proportional to determined probs from MCTS
action = np.random.choice(list(range(game.action_space_size)), p=pi)
# reward should be +1 if 1st player won, -1 if 2nd player won, 0 if tie
obs, reward, done = game.step(action)
if done: # save data about winning state
s_list.append(obs)
pi_list.append(np.zeros_like(_pi))
z_list.append(reward)
for i in range(len(z_list)): # fill in the outcome for every tuple along this path
z_list[i] = reward
return s_list, pi_list, z_list
def generate_training_data(game, num_games, search_steps, agent):
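    """Simulate `num_games` self-play games with `search_steps` MCTS simulations per
    move and return the collected states, policy targets, and outcomes as float
    tensors on `device`."""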
    # initialize tensors from the first simulated game
s_list, pi_list, z_list = simulate_game(game, agent, search_steps)
s = torch.tensor(s_list, device=device).to(dtype=torch.float)
pi = torch.tensor(pi_list, device=device).to(dtype=torch.float)
z = torch.tensor(z_list, device=device).to(dtype=torch.float)
    for _ in range(num_games - 1):  # simulate the remaining games and append their examples
s_list, pi_list, z_list = simulate_game(game, agent, search_steps)
_s = torch.tensor(s_list, device=device).to(dtype=torch.float)
_pi = torch.tensor(pi_list, device=device).to(dtype=torch.float)
_z = torch.tensor(z_list, device=device).to(dtype=torch.float)
s = torch.cat((s, _s))
pi = torch.cat((pi, _pi))
z = torch.cat((z, _z))
print(f"Generated {s.size(0)} new states")
return s, pi, z
# set up policy improvement of NN using the simulated games
def improve_model(model, training_data, epochs, optimizer, batch_size=64, verbose=False):
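    """Train the network on (state, policy target, outcome) examples for `epochs`
    passes over shuffled minibatches, minimizing a squared-error value loss plus a
    cross-entropy policy loss. Returns the losses from the final batch."""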
mean_loss = None
s = training_data[0]
pi = training_data[1]
z = training_data[2]
for i in range(epochs):
rands = torch.randperm(s.size(0))
s = s[rands]
pi = pi[rands]
z = z[rands]
for b in range(0, s.size(0), batch_size):
            end = min(b + batch_size, s.size(0))
            if end - b < 2:  # skip a trailing batch of one sample (BatchNorm1d needs more than one in training mode)
                continue
            y = model(s[b:end, :])
p = y[:, :-1] + 1e-6 # numerical stability so we don't get log(0)
v = y[:, -1]
# loss = (v - z) * (v - z) - torch.sum(pi * torch.log(p), dim=-1)
value_loss = torch.mean((v - z[b:end]) * (v - z[b:end]))
policy_loss = -torch.mean(torch.sum(pi[b:end] * torch.log(p), dim=-1))
mean_loss = value_loss + policy_loss
if torch.isnan(mean_loss):
print("getting nans")
optimizer.zero_grad()
mean_loss.backward()
optimizer.step()
print(f"Epoch {i+1}, Loss: {mean_loss}")
return mean_loss.detach().cpu().numpy(), value_loss.detach().cpu().numpy(), policy_loss.detach().cpu().numpy()
# set up playing games between old agent and updated agent, keeping the winner
def play_game(p1, p2, game, greedy=False):
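    """Play one game with p1 moving first and p2 second, each choosing moves from its
    legality-masked policy (sampled, or argmax if greedy). Returns +1 if p1 wins,
    -1 if p2 wins, 0 for a draw."""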
obs, reward, done = game.reset()
model_playing = p1
while not done:
x = torch.tensor(obs, device=device, dtype=torch.float32)
y = model_playing(x)
v = y[-1]
p = y[:-1]
p = p.detach().cpu().numpy() + 1e-6 # numerical stability term
mask = game.getLegalActionMask()
p = p * mask
p = p / np.sum(p) # normalize probs
        if greedy:  # deterministic: the same game gets played every time, so this is a poor comparison
action = np.argmax(p)
else:
action = np.random.choice(list(range(game.action_space_size)), p=p)
obs, reward, done = game.step(action)
if done:
return reward # +1 if p1 wins, -1 if p2 wins, 0 if draw
# alternate turns
if model_playing is p1:
model_playing = p2
else:
            model_playing = p1
def compare_agents(old_agent, new_agent, game, num_games, greedy=False): # num_games should be even
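    """Play `num_games` games between the agents, each moving first in half of them,
    and return (new_agent wins, old_agent wins, ties)."""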
new_wins = 0
old_wins = 0
ties = 0
for i in range(num_games // 2):
r = play_game(old_agent, new_agent, game, greedy=greedy)
if r == 1:
old_wins += 1
elif r == -1:
new_wins += 1
elif r == 0:
ties += 1
else:
raise RuntimeError("Got bad return from game")
for i in range(num_games // 2):
r = play_game(new_agent, old_agent, game, greedy=greedy)
if r == 1:
new_wins += 1
elif r == -1:
old_wins += 1
elif r == 0:
ties += 1
else:
raise RuntimeError("Got bad return from game")
return new_wins, old_wins, ties
def get_agent_action(agent, game, state, greedy=False, verbose=False):
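    """Choose an action for `state` from the agent's legality-masked policy, either
    greedily or by sampling; optionally print the value estimate and policy next to
    the board."""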
mask = game.getLegalActionMask()
x = torch.tensor(state, device=device, dtype=torch.float)
y = agent(x)
pi = y[:-1]
v = y[-1].detach().cpu().numpy()
_pi = pi.detach().cpu().numpy() + 1e-6 # for numerical stability
_pi = _pi * mask
_pi = _pi / np.sum(_pi)
if greedy:
action = np.argmax(_pi)
else:
        action = np.random.choice(list(range(game.action_space_size)), p=_pi)
if verbose:
print(f" Value: {v}, action: {action}")
print(f"{state[0]} {state[1]} {state[2]} || {_pi[0]:.3f} {_pi[1]:.3f} {_pi[2]:.3f}")
print(f"{state[3]} {state[4]} {state[5]} || {_pi[3]:.3f} {_pi[4]:.3f} {_pi[5]:.3f}")
print(f"{state[6]} {state[7]} {state[8]} || {_pi[6]:.3f} {_pi[7]:.3f} {_pi[8]:.3f}")
return action
def play_against_heuristics(game, agent, agent_plays=1, render=False, verbose=False):
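    """Play one game between the agent (acting greedily) and the game's built-in
    heuristic player; agent_plays selects whether the agent moves first (1) or
    second (2). Returns the final game reward."""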
obs, r, done = game.reset()
while not done:
if agent_plays == 1:
agent_action = get_agent_action(agent, game, obs, greedy=True, verbose=False)
obs, reward, done = game.step(agent_action)
if render:
game.render()
# heuristics takes turn
if not done:
action = game.get_computer_move()
obs, reward, done = game.step(action)
if render:
game.render()
elif agent_plays == 2:
action = game.get_computer_move()
obs, reward, done = game.step(action)
if render:
game.render()
if not done:
agent_action = get_agent_action(agent, game, obs, greedy=True, verbose=False)
obs, reward, done = game.step(agent_action)
if render:
game.render()
else:
raise Exception("Invalid value for agent plays parameter. Choose 1 or 2")
if verbose:
if reward == 0:
print("Tie game")
elif reward == 1 and agent_plays == 1:
print("Agent won")
else:
print("Heuristics won")
return reward
# training loop
def main():
game = tictactoe()
input_size = game.obs_space_size
output_size = game.action_space_size
hidden_layer_size = 128
agent = AlphaZeroResidual(input_size, hidden_layer_size, output_size)
    # agent = AlphaZero(input_size, hidden_layer_size, output_size)
# agent = AlphaZeroConvLarge(input_size, hidden_layer_size, output_size)
agent.to(device)
print(agent)
agent.eval() # sets batch norm properly
print(f"Parameters: {sum([p.nelement() for p in agent.parameters()])}")
iterations = 25 # how many times we want to make training data, update a model
    num_games = 30  # number of self-play games used to generate training examples each iteration
    search_steps = 30  # number of MCTS simulations to run for each move of each game
optimization_steps = 10 # once we have generated training data, how many epochs do we do on this data
num_faceoff_games = 2 # when comparing updated model and old model, how many games they play to determine winner
lr = .01
optimizer = torch.optim.Adam(agent.parameters(), lr=lr, weight_decay=0.01)
lr_decay = .1
lr_decay_period = 200
replay_buffer_size = 800
best_loss = .1
training_data = None
new_agent = None
do_compare_agents = False
play_vs_heuristics = True
with open("saved_models/saved_data.csv", "w+") as f:
f.write(f"total loss, value loss, policy loss, Agent wins, Heuristic wins, Ties\n")
for itr in range(iterations):
print(f"Starting iteration {itr + 1} / {iterations}")
# if itr != 0 and itr % lr_decay_period == 0:
# lr *= lr_decay
# print(f"Decayed lr to {lr}")
print(f"Generating training data...")
# sliding window
if training_data is None:
training_data = generate_training_data(game, num_games, search_steps, agent)
            else:  # append newly generated self-play data to the existing buffer
print("Generating additional data...")
more_data = generate_training_data(game, num_games, search_steps, agent)
s = torch.cat((training_data[0], more_data[0]))
pi = torch.cat((training_data[1], more_data[1]))
z = torch.cat((training_data[2], more_data[2]))
training_data = [s, pi, z]
# Clip the training data to hold last states
# Acts as a replay buffer with a sliding window
_s = training_data[0][-replay_buffer_size:, :]
_pi = training_data[1][-replay_buffer_size:, :]
_z = training_data[2][-replay_buffer_size:]
training_data = [_s, _pi, _z]
#training_data = generate_training_data(game, num_games, search_steps, agent)
print(f"Generated training data: {training_data[0].size(0)} states")
# if new_agent is None: # keep the progress on the new model if we failed to beat the old one
# new_agent = copy.deepcopy(agent)
print(f"Improving model...")
agent.train()
loss, value_loss, policy_loss = improve_model(agent, training_data, optimization_steps, optimizer, verbose=False)
agent.eval()
print(f"Finished improving model, total loss: {loss}, value loss: {value_loss}, policy loss: {policy_loss}")
            if loss < best_loss:  # best loss so far; checkpoint this model
torch.save(agent.state_dict(), f"saved_models/tictactoe_agent_{loss}.pt")
best_loss = loss
# if do_compare_agents:
# # 1. We can compare agents and keep the best one, or 2. just keep improving the same network repeatedly
            # # AlphaGo Zero uses 1. (the specialized version of the algorithm for Go, a predecessor to the general-purpose AlphaZero)
            # # AlphaZero uses 2.
# print(f"Comparing agents...")
# new_wins, old_wins, ties = compare_agents(agent, new_agent, game, num_faceoff_games)
# print(f"New wins: {new_wins}, old wins: {old_wins}, ties: {ties}")
# if new_wins > old_wins:
# agent = new_agent
# new_agent = None
# else:
# agent = new_agent
if play_vs_heuristics:
agent_wins_first = 0
agent_wins_second = 0
                heuristic_wins_first = 0
                heuristic_wins_second = 0
ties = 0
for i in range(num_faceoff_games // 2):
game.reset()
r = play_against_heuristics(game, agent, agent_plays=1)
if r == 0:
ties += 1
elif r == 1:
agent_wins_first += 1
elif r == -1:
                        heuristic_wins_second += 1
for i in range(num_faceoff_games // 2):
game.reset()
r = play_against_heuristics(game, agent, agent_plays=2)
if r == 0:
ties += 1
elif r == 1:
                        heuristic_wins_first += 1
elif r == -1:
agent_wins_second += 1
print(f"Agent wins: ({agent_wins_first},{agent_wins_second}) , Heuristic wins: ({heursitic_wins_first},{heursitic_wins_second}) , "
f"Ties: {ties}")
f.write(f"{loss}, {value_loss}, {policy_loss}, {agent_wins_first+agent_wins_second}, {heursitic_wins_first+heursitic_wins_second}, {ties}\n")
# training_data = None # reset training data every time, no replay buffer
print("Saving agent")
torch.save(agent.state_dict(), "saved_models/tictactoe_agent.pt")
print(f"Best loss seen: {best_loss}")
return agent
if __name__ == "__main__":
start = time.time()
main()
end = time.time()
elapsed = end - start
    h = int(elapsed // 3600)
    m = int((elapsed - h * 3600) // 60)
    s = elapsed % 60
    print(f"Time: {h} h {m} m {s:.1f} s")