-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathbot_caffe.py
69 lines (51 loc) · 1.99 KB
/
bot_caffe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import tempfile
import subprocess
import os
import logging
import array
import numpy as np
import time
import caffe
from players import DistributionBot, DistWrappingMaxPlayer
import cubes
from state import gomill_gamestate2state, State
class DetlefDistBot(DistributionBot):
    """
    Move-probability bot backed by a Caffe CNN.

    CNN as kindly provided by Detlef Schmicker. See
    http://computer-go.org/pipermail/computer-go/2015-December/008324.html
    The net should (as of January 2016) be available here:
    http://physik.de/CNNlast.tar.gz
    """

    def __init__(self, caffe_net):
        """
        :param caffe_net: a loaded Caffe net; it is fed through an input
            blob named 'data' and read back from an output blob named 'ip'.
        """
        super(DetlefDistBot, self).__init__()
        self.caffe_net = caffe_net

    def gen_probdist_raw(self, game_state, player):
        """
        Return a (side, side) array of move probabilities for `player`.

        The position is encoded with cubes.get_cube_detlef, run through the
        net as a batch of one, and only the first output channel is kept,
        normalized so that it sums to 1.

        :param game_state: object exposing `board.side` (and whatever
            gomill_gamestate2state needs).
        :param player: the colour to move, passed through to the encoder.
        """
        side = game_state.board.side
        cube = cubes.get_cube_detlef(gomill_gamestate2state(game_state), player)
        # prepend a leading axis of length 1: a batch holding one position
        cube = cube.reshape((1,) + cube.shape)
        logging.debug("%s sending data of shape=%s", self, cube.shape)
        resp = self.caffe_net.forward_all(data=cube)['ip']
        logging.debug("%s read response of shape=%s", self, resp.shape)
        # FIXME update: the net emits 128 output channels of which only the
        # first is used -- Detlef's mistake :-). The channel count is left to
        # numpy (-1) so a corrected single-channel net works unchanged.
        resp = resp.reshape((-1, side, side))
        tot = resp.sum()
        ret = resp[0]
        # report how much probability mass the discarded channels carried
        logging.debug("%s trimming channels took off %.3f %%", self,
                      100 * (tot - ret.sum()) / tot)
        return ret / ret.sum()
if __name__ == "__main__":
    def test_bot():
        """Smoke test: load the net and print the bot's move on an empty 19x19 board."""
        # `boards` is a gomill submodule; a bare `import gomill` only exposes
        # it if some other module happened to import it first.
        import gomill.boards

        logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s',
                            level=logging.DEBUG)

        # substitute with your own
        # Phase 0 is caffe.TRAIN; inference should run in the TEST phase so
        # phase-dependent layers (e.g. dropout) behave deterministically.
        caffe_net = caffe.Net('golast19.prototxt', 'golast.trained', caffe.TEST)
        player = DistWrappingMaxPlayer(DetlefDistBot(caffe_net))

        # Ad-hoc stand-in for a game state; presumably board/ko_point/
        # move_history are what gomill_gamestate2state reads -- verify.
        class GameState:
            pass

        s = GameState()
        s.board = gomill.boards.Board(19)
        s.ko_point = None
        s.move_history = []
        # single-argument print() behaves the same on Python 2 and 3
        print(player.genmove(s, 'b').move)

    test_bot()