-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathUCTNode.h
123 lines (113 loc) · 3.66 KB
/
UCTNode.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#ifndef UCTNODE_H_INCLUDED
#define UCTNODE_H_INCLUDED
#include "config.h"
#include <tuple>
#include <atomic>
#include "SMP.h"
#include "GameState.h"
#include "Playout.h"
#include "Network.h"
/// A node of the UCT search tree (UCT statistics, RAVE statistics and
/// optional network ("dcnn") move scoring / value evaluation).
///
/// Children are kept as a linked list: each node stores a pointer to its
/// first child (m_firstchild) and to its next sibling (m_nextsibling);
/// see get_first_child() / get_sibling().
///
/// Concurrency: the win/visit, RAVE and evaluation statistics, plus the
/// m_has_children and m_valid flags, are std::atomic. The tree links and the
/// remaining flags (m_is_expanding, m_is_netscoring, m_is_evaluating, ...) are
/// plain members. NOTE(review): presumably those are only touched while holding
/// the per-node mutex returned by get_mutex() — confirm in UCTNode.cpp.
///
/// Nodes are neither copyable nor movable. The std::atomic members already
/// forbid it; the deleted operations below only make that explicit.
class UCTNode {
public:
    /// Tuple used when sorting children. NOTE(review): the meaning of each
    /// field (float, int, child pointer) is defined by the sorting code in
    /// UCTNode.cpp — confirm there.
    using sortnode_t = std::tuple<float, int, UCTNode*>;

    UCTNode(int vertex, float score,
            int expand_threshold, int netscore_threshold,
            int movenum);
    ~UCTNode();

    // Non-copyable, non-movable (already implied by the std::atomic members).
    UCTNode(const UCTNode&) = delete;
    UCTNode& operator=(const UCTNode&) = delete;
    UCTNode(UCTNode&&) = delete;
    UCTNode& operator=(UCTNode&&) = delete;

    // --- Node state queries ---
    bool first_visit() const;
    bool has_children() const;
    /// Win rate as seen by the player `tomove` (derived from the black-wins
    /// statistic; see UCTNode.cpp for the exact formula).
    float get_winrate(int tomove) const;
    float get_raverate() const;
    double get_blackwins() const;

    // --- Expansion and network scoring ---
    void create_children(std::atomic<int> & nodecount,
                         FastState & state, bool at_root, bool use_nets);
    void netscore_children(std::atomic<int> & nodecount,
                           FastState & state, bool at_root);
    /// Callback receiving raw network results for this node's children.
    /// Note that it takes `nodecount` by pointer, unlike the other expansion
    /// functions. NOTE(review): presumably invoked from the network evaluation
    /// path — confirm the calling thread in UCTNode.cpp.
    void scoring_cb(std::atomic<int> * nodecount,
                    FastState & state,
                    Network::Netresult & raw_netlist,
                    bool all_symmetries);
    void run_value_net(FastState & state);

    // --- Pruning / validity (superko) ---
    /// NOTE(review): presumably marks children that violate superko as
    /// invalid (see m_valid) — confirm in UCTNode.cpp.
    void kill_superkos(KoState & state);
    void delete_child(UCTNode * child);
    void invalidate();
    bool valid() const;

    // --- Expansion thresholds ---
    bool should_expand() const;
    bool should_netscore() const;

    // --- Accessors ---
    int get_move() const;
    int get_visits() const;
    int get_ravevisits() const;
    bool has_netscore() const;
    float get_score() const;
    void set_score(float score);
    float get_eval(int tomove) const;
    float get_mixed_score(int tomove);
    /// Blends a board evaluation with a win rate depending on the move number.
    static float score_mix_function(int movenum, float eval, float winrate);
    double get_blackevals() const;
    int get_evalcount() const;
    bool has_eval_propagated() const;
    void set_eval_propagated();

    // --- Mutators ---
    void set_move(int move);
    void set_visits(int visits);
    void set_blackwins(double wins);
    void set_expand_cnt(int runs, int netscore_runs);
    void set_blackevals(double blackevals);
    void set_evalcount(int evalcount);
    void set_expand_cnt(int runs);
    void set_eval(float eval);
    void accumulate_eval(float eval);

    // --- Backpropagation ---
    void update(Playout & gameresult, int color, bool update_eval);
    void updateRAVE(Playout & playout, int color);

    // --- Child selection / traversal ---
    UCTNode* uct_select_child(int color, bool use_nets);
    UCTNode* get_first_child() const;
    UCTNode* get_pass_child() const;
    UCTNode* get_nopass_child() const;
    UCTNode* get_sibling() const;
    void sort_root_children(int color);
    void sort_children();

    /// Per-node lock (m_nodemutex).
    SMP::Mutex & get_mutex();

private:
    UCTNode();
    void link_child(UCTNode * newchild);
    void link_nodelist(std::atomic<int> & nodecount,
                       FastBoard & state,
                       Network::Netresult & nodes,
                       bool use_nets);
    void rescore_nodelist(std::atomic<int> & nodecount,
                          FastBoard & state,
                          Network::Netresult & nodes,
                          bool all_symmetries);
    float smp_noise();

    // NOTE: member order is significant (layout and initialization order);
    // do not reorder without checking the constructor in UCTNode.cpp.

    // Tree data
    std::atomic<bool> m_has_children;
    UCTNode* m_firstchild;
    UCTNode* m_nextsibling;
    // Move
    int m_move;
    // UCT
    std::atomic<double> m_blackwins;
    std::atomic<int> m_visits;
    // RAVE
    std::atomic<double> m_ravestmwins;
    std::atomic<int> m_ravevisits;
    // move order
    float m_score;
    // board eval
    bool m_eval_propagated;
    std::atomic<double> m_blackevals;
    std::atomic<int> m_evalcount;
    int m_movenum;
    // Plain bool busy flag, not a lock. NOTE(review): presumably set/cleared
    // under m_nodemutex to prevent concurrent evaluation — confirm.
    bool m_is_evaluating;
    // alive (superko)
    std::atomic<bool> m_valid;
    // extend node
    int m_expand_cnt;
    bool m_is_expanding;
    // dcnn node
    bool m_has_netscore;
    int m_netscore_thresh;
    int m_symmetries_done;
    bool m_is_netscoring;
    SMP::Mutex m_nodemutex;
};
#endif