forked from filipmlynarski/splendor-ai
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_model.py
46 lines (39 loc) · 1.46 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import sys
sys.path.insert(0, 'splendor_ai')
from splendor_ai.alpha_zero import AlphaZero
from splendor_ai.mini_splendor_state_encoder import MiniSplendorStateEncoder
import interactive_splendor
from environment import minisplendor
# Training driver: endlessly alternates self-play with a head-to-head
# evaluation of a candidate network against the current one, promoting and
# saving the candidate whenever it wins often enough.

# NOTE(review): the meaning of the positional arguments 0.3 and 15 is defined
# by AlphaZero's constructor, which is not visible here -- confirm there.
model = AlphaZero(MiniSplendorStateEncoder(), 0.3, 15)

# Run the network's heuristic training step once, on a fresh game, before the
# self-play loop starts.
model.net.train_heuristic(model.state_encoder, minisplendor.MiniSplendor())

num_iters = 50              # self-play games per outer iteration
num_comparision_games = 20  # evaluation games between current and candidate
needed_wins = 13            # candidate wins required to replace the current model
it = 0                      # outer-iteration counter, used in the save path

while True:
    # --- Self-play phase: the current model plays both seats. ---
    model.is_learning = True
    print('self play')
    for i in range(num_iters):
        print('game', i)
        # NOTE(review): the two leading booleans are positional flags of
        # play_game; their meaning is not visible in this file.
        interactive_splendor.play_game(
            False,
            True,
            (('m1', model), ('m2', model)),
            minisplendor.MiniSplendor(),
        )

    # --- Evaluation phase: candidate (model2) versus current (model). ---
    model2 = model.produce_new_version()
    # NOTE(review): both models keep is_learning=True while being compared;
    # confirm this is intended for evaluation games.
    model.is_learning, model2.is_learning = True, True

    # r[0] counts candidate losses, r[1] counts candidate wins.
    r = [0, 0]
    print('fight!')
    half = num_comparision_games // 2
    for i in range(num_comparision_games):
        # Swap seats for the second half of the games so each model gets the
        # first seat equally often.  The comparison is `>=` (not `>`) so that
        # an even game count is split exactly in half.
        swapped = i >= half
        if swapped:
            players = (('m2', model2), ('m1', model))
        else:
            players = (('m1', model), ('m2', model2))

        result = interactive_splendor.play_game(
            False,
            False,
            players,
            minisplendor.MiniSplendor(),
        )

        # Anything other than 0/1 is treated as a non-decisive game: report
        # the raw value and skip it.  This check must come before the seat
        # correction below, which does arithmetic on `result`.
        if result not in (0, 1):
            print('game result = ', result)
            continue

        # Presumably `result` is the index of the winning seat (verify against
        # play_game); undo the seat swap so 1 always means "candidate won".
        outcome = 1 - result if swapped else result
        print('game result = ', outcome)
        r[outcome] += 1

    print('loses = ', r[0], ',wins = ', r[1])

    # Promote and save the candidate only when it reached the win threshold.
    if r[1] >= needed_wins:
        print('better model')
        model = model2
        model.net.model.save('saved/my_model' + str(it))
    it += 1