/*
CrazyAra, a deep learning chess variant engine
Copyright (C) 2018 Johannes Czech, Moritz Willig, Alena Beyer
Copyright (C) 2019-2020 Johannes Czech
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
/*
* @file: searchthread.h
* Created on 23.05.2019
* @author: queensgambit
*
* Handles the functionality of a single search thread in the tree.
*/
#ifndef SEARCHTHREAD_H
#define SEARCHTHREAD_H
#include "node.h"
#include "constants.h"
#include "neuralnetapi.h"
#include "config/searchlimits.h"
#include "util/fixedvector.h"
#include "nn/neuralnetapiuser.h"
enum NodeBackup : uint8_t {
NODE_COLLISION,
NODE_TERMINAL,
NODE_TRANSPOSITION,
NODE_NEW_NODE,
NODE_UNKNOWN,
};
struct NodeDescription
{
NodeBackup type;
// depth which was reached on this rollout
size_t depth;
};
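/*
 * Illustrative sketch (not part of this header): the NodeBackup type stored in a
 * NodeDescription decides how the result of a single rollout is propagated back.
 *
 *     NodeDescription description;
 *     // ... selection phase fills description.type and description.depth ...
 *     switch (description.type) {
 *     case NODE_TERMINAL:       // backpropagate the known game result; no NN request needed
 *         break;
 *     case NODE_TRANSPOSITION:  // backpropagate the value already stored for the transposed node
 *         break;
 *     case NODE_COLLISION:      // only the applied virtual loss is reverted (see backup_collisions())
 *         break;
 *     case NODE_NEW_NODE:       // enqueue the node for a neural network evaluation
 *         break;
 *     default:
 *         break;
 *     }
 */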
class SearchThread : public NeuralNetAPIUser
{
private:
Node* rootNode;
StateObj* rootState;
unique_ptr<StateObj> newState;
// list of all node objects which have been selected for expansion
unique_ptr<FixedVector<Node*>> newNodes;
unique_ptr<FixedVector<SideToMove>> newNodeSideToMove;
unique_ptr<FixedVector<float>> transpositionValues;
vector<Trajectory> newTrajectories;
vector<Trajectory> transpositionTrajectories;
vector<Trajectory> collisionTrajectories;
Trajectory trajectoryBuffer;
vector<Action> actionsBuffer;
bool isRunning;
MapWithMutex* mapWithMutex;
const SearchSettings* searchSettings;
SearchLimits* searchLimits;
size_t tbHits;
size_t depthSum;
size_t depthMax;
size_t visitsPreSearch;
uint_fast32_t terminalNodeCache; // TODO: better add "const" qualifier here if possible
bool reachedTablebases;
public:
/**
* @brief SearchThread
* @param netBatch Network API object which provides the prediction of the neural network
* @param searchSettings Given settings for this search run
* @param mapWithMutex Handle to the hash table
*/
SearchThread(NeuralNetAPI* netBatch, const SearchSettings* searchSettings, MapWithMutex* mapWithMutex);
/**
* @brief create_mini_batch Creates a mini-batch of new unexplored nodes.
* Terminal nodes are immediately backpropagated without requesting the NN.
* If the node was found in the hash table, its value is backpropagated without requesting the NN.
* If a collision occurs (the same node was selected multiple times), it is added to the collisionTrajectories vector.
*/
void create_mini_batch();
/**
* @brief thread_iteration Runs multiple MCTS rollouts until a new batch has been filled
*/
void thread_iteration();
/**
* @brief nodes_limits_ok Checks if the node limit given by searchLimits has been reached.
* In case the number of nodes is set to zero, the limit condition is ignored.
* @return bool
*/
inline bool nodes_limits_ok();
/**
* @brief is_root_node_unsolved Checks if the root node result is still unsolved (not a forced win, draw or loss)
* @return true for unsolved, else false
*/
inline bool is_root_node_unsolved();
/**
* @brief stop Stops the rollouts of the current thread
*/
void stop();
// Getter, setter functions
void set_search_limits(SearchLimits *s);
Node* get_root_node() const;
SearchLimits *get_search_limits() const;
void set_root_node(Node *value);
bool is_running() const;
void set_is_running(bool value);
void set_reached_tablebases(bool value);
/**
* @brief add_new_node_to_tree Adds a new node to the search tree by either creating a new node or duplicating an existing node in case of transposition usage
* @param newPos Board position of the new node
* @param parentNode Parent node of the new node
* @param childIdx Respective child index for the new node
* @param nodeBackup Returns NODE_TRANSPOSITION if a transposition node was added and NODE_NEW_NODE otherwise
* @return The newly added node
*/
Node* add_new_node_to_tree(StateObj* newPos, Node* parentNode, ChildIdx childIdx, NodeBackup& nodeBackup);
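/*
 * Illustrative sketch of the transposition check (the member names of MapWithMutex and the
 * hash_key() accessor are assumptions, not taken from this header):
 *
 *     lock_guard<mutex> lock(mapWithMutex->mtx);
 *     auto it = mapWithMutex->hashTable.find(newPos->hash_key());
 *     nodeBackup = (it != mapWithMutex->hashTable.end()) ? NODE_TRANSPOSITION : NODE_NEW_NODE;
 */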
/**
* @brief reset_stats Resets the search statistics, e.g. sets the number of tablebase hits to 0
*/
void reset_stats();
void set_root_state(StateObj* value);
size_t get_tb_hits() const;
size_t get_avg_depth();
size_t get_max_depth() const;
Node* get_starting_node(Node* currentNode, NodeDescription& description, ChildIdx& childIdx);
private:
/**
* @brief set_nn_results_to_child_nodes Sets the neural network value evaluation and policy prediction vector for every newly expanded node
*/
void set_nn_results_to_child_nodes();
/**
* @brief backup_value_outputs Backpropagates all newly received value evaluations from the neural network across the visited search paths
*/
void backup_value_outputs();
/**
* @brief backup_collisions Reverts the applied virtual loss for all rollouts which ended in a collision event
*/
void backup_collisions();
/**
* @brief get_new_child_to_evaluate Traverses the search tree beginning from the root node and returns the parent node and child index for the next node to expand.
* @param description Output struct which holds information what type of node it is
* @return Pointer to the next child to evaluate (can also be a terminal or transposition node, in which case no NN evaluation is required)
*/
Node* get_new_child_to_evaluate(NodeDescription& description);
void backup_values(FixedVector<Node*>& nodes, vector<Trajectory>& trajectories);
void backup_values(FixedVector<float>* values, vector<Trajectory>& trajectories);
/**
* @brief select_enhanced_move Selects an enhanced move (e.g. checking move) which has not been explored under given conditions.
* @param currentNode Current node during forward simulation
* @return uint16_t(-1) if no move is selected, else the respective child index
*/
ChildIdx select_enhanced_move(Node* currentNode) const;
/**
* @brief get_current_transposition_q_value Returns the Q-value which connects to the transposition node
* @param currentNode Current node
* @param childIdx child index
* @param transposVisits Number of visits connecting to the transposition node
* @return Q-Value converted to double
*/
double get_current_transposition_q_value(const Node* currentNode, ChildIdx childIdx, uint_fast32_t transposVisits);
};
void run_search_thread(SearchThread *t);
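/*
 * Minimal usage sketch (illustrative only; `net`, `settings`, `hashMap`, `root`, `state`
 * and `limits` are assumed to be set up elsewhere, and the exact call order inside
 * thread_iteration() is an assumption based on the comments above):
 *
 *     SearchThread searchThread(net, &settings, &hashMap);
 *     searchThread.set_root_node(root);
 *     searchThread.set_root_state(state);
 *     searchThread.set_search_limits(&limits);
 *     searchThread.set_is_running(true);
 *     run_search_thread(&searchThread);  // repeatedly calls thread_iteration(), i.e. roughly:
 *                                        // create_mini_batch() -> NN inference ->
 *                                        // set_nn_results_to_child_nodes() ->
 *                                        // backup_value_outputs() -> backup_collisions()
 */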
void fill_nn_results(size_t batchIdx, bool isPolicyMap, const float* valueOutputs, const float* probOutputs, const float* auxiliaryOutputs, Node *node, size_t& tbHits, bool mirrorPolicy, const SearchSettings* searchSettings, bool isRootNodeTB);
void node_post_process_policy(Node *node, float temperature, const SearchSettings* searchSettings);
void node_assign_value(Node *node, const float* valueOutputs, size_t& tbHits, size_t batchIdx, bool isRootNodeTB);
/**
* @brief random_playout Uses random move exploration (epsilon-greedy) from the given position. The probability of doing a random move decays with depth.
* @param currentNode Current node during trajectory
* @param childIdx Returned child index (may be unchanged)
*/
inline void random_playout(Node* currentNode, ChildIdx& childIdx);
/**
* @brief get_random_depth
* Example: drawing a random number from a uniform distribution in [0, 100]
* DEPTH 0: 0 - 50
* DEPTH 1: 51 - 75
* DEPTH 2: 76 - 87
* DEPTH 3: 88 - 94
* DEPTH 4: 95 - 97
* DEPTH 5: 98 - 99
* DEPTH 6: 100
* @return Random depth; the probability of choosing a higher depth decreases exponentially
*/
size_t get_random_depth();
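/*
 * Illustrative sketch (assumed, simplified): a depth whose probability roughly halves with
 * every additional ply, matching the ranges listed above, can be drawn geometrically:
 *
 *     size_t depth = 0;
 *     while (rand() % 2 == 0 && depth < 6) {  // 50% chance to go one ply deeper
 *         ++depth;
 *     }
 *     return depth;
 */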
#endif // SEARCHTHREAD_H