Skip to content

Commit

Permalink
Virtual_Visit, Virtual_Mix, Virtual_Offset (#205)
Browse files Browse the repository at this point in the history
* Add VIRTUAL_VISIT as new virtual simulation style

* Sort folder names in train archive by time stamp (close #189)

* Add virtualVisitIncrement

also only increase virtualLoss Counter by 1

* Add UCI-Option "Virtual_Visit_Increment"

* Add VIRTUAL_OFFSET UCI-Option

* Add missing ;

* Simplify VIRTUAL_VISIT

* Update virtual-Offset implementation

+ use also virtual-visit of 1
+ don't use additional vector operation

* Add missing implementation for VIRTUAL_OFFSET in
get_transposition_q_value()

* Change Centi_Virtual_Loss to Milli_Virtual_loss

* Change Milli_Virtual_Loss to Micro_Virtual_Loss

* Update max value for Micro_Virtual_Loss

* Add VIRTUAL_MIX
Rename VIRTUAL_LOSS into VIRTUAL_WEIGHT
Replace virtualVisitIncrement by virtualMixThreshold

* Use realVisitSum as condition for Virtual_Mix

* Simplify code and use realVisits of child node for VIRTUAL_MIX threshold

* Fix compile bugs

* Use d->childNumber visits again

* Deactivate virtualWeight for now

* Deactivate virtualWeight

* Use Q_INIT for comparison

* revert last change

* Add virtualOffsetStrenght(0.001)
Switch between VIRTUAL_OFFSET and VIRTUAL_VISIT when using VIRTUAL_MIX

* revert

* Update VIRTUAL_OFFSET

* Switch between VIRTUAL_LOSS and VIRTUAL_VISIT

* Add virtualLossIncrement

* remove virtualLossIncrement again due to underperformance

* Update UCI-default values

* Fix compile error

* Disable 960 Support for now due to problems

* Fix init of second argument in first_and_second_max()

* Remove init of additional nodes in NodeData

* Remove Virtual_Weight for now
  • Loading branch information
QueensGambit authored Aug 8, 2023
1 parent 282c671 commit 812d62d
Show file tree
Hide file tree
Showing 11 changed files with 165 additions and 58 deletions.
4 changes: 2 additions & 2 deletions engine/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ if (MODE_CHESS)
# add_definitions(-DVERSION=2)
# add_definitions(-DSUB_VERSION=8)
add_definitions(-DVERSION=3)
add_definitions(-DSUPPORT960)
# add_definitions(-DSUPPORT960)
endif()

if (MODE_LICHESS)
Expand All @@ -57,7 +57,7 @@ if (MODE_LICHESS)
add_definitions(-DATOMIC)
add_definitions(-DHORDE)
add_definitions(-DRACE)
add_definitions(-DSUPPORT960)
# add_definitions(-DSUPPORT960)
add_definitions(-DMCTS_TB_SUPPORT)
add_definitions(-DVERSION=1)
endif()
Expand Down
6 changes: 4 additions & 2 deletions engine/src/agents/config/searchsettings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ SearchSettings::SearchSettings():
nodePolicyTemperature(1.0f),
qValueWeight(1.0f),
qVetoDelta(0.4f),
virtualLoss(1.0f),
verbose(true),
epsilonChecksCounter(100),
useMCGS(true),
Expand All @@ -49,7 +48,10 @@ SearchSettings::SearchSettings():
epsilonGreedyCounter(20),
reuseTree(true),
mctsSolver(false),
searchPlayerMode(MODE_TWO_PLAYER)
searchPlayerMode(MODE_TWO_PLAYER),
virtualStyle(VIRTUAL_VISIT),
virtualMixThreshold(1000),
virtualOffsetStrenght(0.001)
{

}
14 changes: 13 additions & 1 deletion engine/src/agents/config/searchsettings.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ enum SearchPlayerMode {
MODE_TWO_PLAYER
};

// Selects how a child that is currently being evaluated by another thread is made
// less attractive until its playout result has been backed up.
enum VirtualStyle {
// The child's Q-value is recomputed as if one additional lost game had been played.
VIRTUAL_LOSS,
// Only the visit counters are increased; the Q-value is left untouched.
VIRTUAL_VISIT,
// The child's Q-value is lowered by SearchSettings::virtualOffsetStrenght.
VIRTUAL_OFFSET,
// Resolved per child by get_virtual_style(): VIRTUAL_LOSS if the child has more visits
// than SearchSettings::virtualMixThreshold, otherwise VIRTUAL_VISIT.
VIRTUAL_MIX
};

struct SearchSettings
{
uint16_t multiPV;
Expand All @@ -48,7 +55,6 @@ struct SearchSettings
float qValueWeight;
// describes how much better the highest Q-Value has to be to replace the candidate move with the highest visit count
float qVetoDelta;
uint_fast32_t virtualLoss;
bool verbose;
uint_fast8_t epsilonChecksCounter;
// bool enhanceCaptures; currently not support
Expand All @@ -75,6 +81,12 @@ struct SearchSettings
bool mctsSolver;
// Defines the number of players within the MCTS search. Available are MODE_SINGLE_PLAYER and MODE_TWO_PLAYER
SearchPlayerMode searchPlayerMode;
// Defines the virtual style to avoid conflicts between different threads within the same mini-batch
VirtualStyle virtualStyle;
// Defines the number of visits to switch from virtual-visit to virtual-loss
uint_fast32_t virtualMixThreshold;
// Defines the strength of the virtual offset
double virtualOffsetStrenght;
SearchSettings();

};
Expand Down
70 changes: 55 additions & 15 deletions engine/src/node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ bool Node::is_sorted() const
return sorted;
}

double Node::get_q_sum(ChildIdx childIdx, float virtualLoss) const
double Node::get_q_sum_virtual_loss(ChildIdx childIdx) const
{
return get_child_number_visits(childIdx) * double(get_q_value(childIdx)) + get_virtual_loss_counter(childIdx) * virtualLoss;
return get_child_number_visits(childIdx) * double(get_q_value(childIdx)) + get_virtual_loss_counter(childIdx);
}

bool Node::is_transposition() const
Expand Down Expand Up @@ -504,18 +504,28 @@ bool Node::has_nn_results() const
return hasNNResults;
}

void Node::apply_virtual_loss_to_child(ChildIdx childIdx, uint_fast32_t virtualLoss)
void Node::apply_virtual_loss_to_child(ChildIdx childIdx, const SearchSettings* searchSettings)
{
// update the stats of the parent node
// make it look like if one has lost X games from this node forward where X is the virtual loss value
// temporarily reduce the attraction of this node by applying a virtual loss /
// the effect of virtual loss will be undone if the playout is over
d->qValues[childIdx] = (double(d->qValues[childIdx]) * d->childNumberVisits[childIdx] - virtualLoss) / double(d->childNumberVisits[childIdx] + virtualLoss);
switch (get_virtual_style(searchSettings, d->childNumberVisits[childIdx])) {
case VIRTUAL_LOSS:
d->qValues[childIdx] = (double(d->qValues[childIdx]) * d->childNumberVisits[childIdx] - 1) / double(d->childNumberVisits[childIdx] + 1);
break;
case VIRTUAL_OFFSET:
d->qValues[childIdx] -= searchSettings->virtualOffsetStrenght;
case VIRTUAL_VISIT: ; // ignore
case VIRTUAL_MIX: ; // unreachable
}

// virtual increase the number of visits
d->childNumberVisits[childIdx] += virtualLoss;
d->visitSum += virtualLoss;
++d->childNumberVisits[childIdx];
++d->visitSum;

// increment virtual loss counter
update_virtual_loss_counter<true>(childIdx, virtualLoss);
update_virtual_loss_counter<true>(childIdx);
}

float Node::get_q_value(ChildIdx childIdx) const
Expand Down Expand Up @@ -642,20 +652,29 @@ uint32_t Node::get_real_visits(ChildIdx childIdx) const
return d->childNumberVisits[childIdx] - d->virtualLossCounter[childIdx];
}

void backup_collision(float virtualLoss, const Trajectory& trajectory) {
void backup_collision(const SearchSettings* searchSettings, const Trajectory& trajectory) {
for (auto it = trajectory.rbegin(); it != trajectory.rend(); ++it) {
it->node->revert_virtual_loss(it->childIdx, virtualLoss);
it->node->revert_virtual_loss(it->childIdx, searchSettings);
}
}

void Node::revert_virtual_loss(ChildIdx childIdx, float virtualLoss)
void Node::revert_virtual_loss(ChildIdx childIdx, const SearchSettings* searchSettings)
{
lock();
d->qValues[childIdx] = (double(d->qValues[childIdx]) * d->childNumberVisits[childIdx] + virtualLoss) / (d->childNumberVisits[childIdx] - virtualLoss);
d->childNumberVisits[childIdx] -= virtualLoss;
d->visitSum -= virtualLoss;
switch (get_virtual_style(searchSettings, d->childNumberVisits[childIdx])) {
case VIRTUAL_LOSS:
d->qValues[childIdx] = (double(d->qValues[childIdx]) * d->childNumberVisits[childIdx] + 1) / (d->childNumberVisits[childIdx] - 1);
break;
case VIRTUAL_OFFSET:
d->qValues[childIdx] += searchSettings->virtualOffsetStrenght;
case VIRTUAL_MIX: ; // ignore
case VIRTUAL_VISIT: ; // ignore
}
--d->childNumberVisits[childIdx];
--d->visitSum;

// decrement virtual loss counter
update_virtual_loss_counter<false>(childIdx, virtualLoss);
update_virtual_loss_counter<false>(childIdx);
unlock();
}

Expand Down Expand Up @@ -990,6 +1009,27 @@ void Node::disable_action(size_t childIdxForParent)
d->qValues[childIdxForParent] = -INT_MAX;
}

/**
 * Returns the Q-value of the given child with the Q-value change of the active
 * virtual style undone, negated when searching in MODE_TWO_PLAYER.
 *
 * @param searchSettings Provides the virtual style, virtualOffsetStrenght and searchPlayerMode
 * @param childIdx Index of the child whose Q-value is requested
 * @param transposVisits Divisor for the VIRTUAL_LOSS style; must not be 0
 *        (the caller visible in backup_value() checks this before calling)
 * @return Q-value as double
 */
double Node::get_transposition_q_value(const SearchSettings *searchSettings, ChildIdx childIdx, uint_fast32_t transposVisits)
{
    // Initialized so that no path can return an indeterminate value.
    double transposQValue = 0.0;
    switch (get_virtual_style(searchSettings, d->childNumberVisits[childIdx])) {
    case VIRTUAL_LOSS:
        // The Q-sum already adds the virtual loss counter back, dividing by the
        // given visit count yields the Q-value without virtual losses.
        transposQValue = get_q_sum_virtual_loss(childIdx) / transposVisits;
        break;
    case VIRTUAL_OFFSET:
        // Every pending virtual visit lowered the Q-value by one offset step.
        transposQValue = double(get_q_value(childIdx)) + get_virtual_loss_counter(childIdx) * searchSettings->virtualOffsetStrenght;
        break;
    case VIRTUAL_VISIT:
    case VIRTUAL_MIX:
        // VIRTUAL_VISIT never alters the Q-value. VIRTUAL_MIX is resolved by
        // get_virtual_style() and not expected here; it is handled like
        // VIRTUAL_VISIT instead of leaving the result unset.
        transposQValue = get_q_value(childIdx);
        break;
    }
    if (searchSettings->searchPlayerMode == MODE_TWO_PLAYER) {
        return -transposQValue;
    }
    return transposQValue;
}

void Node::enhance_moves(const SearchSettings* searchSettings)
{
// if (!searchSettings->enhanceChecks && !searchSettings->enhanceCaptures) {
Expand Down Expand Up @@ -1295,7 +1335,7 @@ bool is_terminal_value(float value)
return (value == WIN_VALUE || value == DRAW_VALUE || value == LOSS_VALUE);
}

float get_transposition_q_value(uint_fast32_t transposVisits, double transposQValue, double targetQValue)
// Computes the value to back up so that the Q-value of a transposition edge moves
// towards targetQValue; the result is limited to the range [LOSS_VALUE, WIN_VALUE].
float get_transposition_backup_value(uint_fast32_t transposVisits, double transposQValue, double targetQValue)
{
    const double qGap = targetQValue - transposQValue;
    const double unclampedValue = transposVisits * qGap + targetQValue;
    const double lowerBound = double(LOSS_VALUE);
    const double upperBound = double(WIN_VALUE);
    return std::clamp(unclampedValue, lowerBound, upperBound);
}
Expand Down
79 changes: 55 additions & 24 deletions engine/src/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,16 @@ struct NodeAndBudget {
node(node), budget(budget), curState(state) {}
};

// Resolves the virtual style to apply for a child with the given visit count.
// Any style other than VIRTUAL_MIX is returned as configured. VIRTUAL_MIX becomes
// VIRTUAL_LOSS once visits exceeds virtualMixThreshold and VIRTUAL_VISIT otherwise,
// so VIRTUAL_MIX itself is never returned.
inline VirtualStyle get_virtual_style(const SearchSettings* searchSettings, uint_fast32_t visits) {
    const VirtualStyle configuredStyle = searchSettings->virtualStyle;
    if (configuredStyle != VIRTUAL_MIX) {
        return configuredStyle;
    }
    const bool aboveThreshold = visits > searchSettings->virtualMixThreshold;
    return aboveThreshold ? VIRTUAL_LOSS : VIRTUAL_VISIT;
}

class Node
{
private:
Expand Down Expand Up @@ -190,28 +200,42 @@ class Node
void revert_virtual_loss_and_update(ChildIdx childIdx, float value, const SearchSettings* searchSettings, bool solveForTerminal)
{
lock();
// decrement virtual loss counter
update_virtual_loss_counter<false>(childIdx, searchSettings->virtualLoss);

valueSum += value;
++realVisitsSum;

if (d->childNumberVisits[childIdx] == searchSettings->virtualLoss) {
if (d->childNumberVisits[childIdx] == 1) {
// set new Q-value based on return
// (the initialization of the Q-value was by Q_INIT which we don't want to recover.)
d->qValues[childIdx] = value;
}
else {
// revert virtual loss and update the Q-value
assert(d->childNumberVisits[childIdx] != 0);
d->qValues[childIdx] = (double(d->qValues[childIdx]) * d->childNumberVisits[childIdx] + searchSettings->virtualLoss + value) / d->childNumberVisits[childIdx];
uint_fast32_t childRealVisit;
double newQVal;
switch(get_virtual_style(searchSettings, d->childNumberVisits[childIdx])) {
case VIRTUAL_LOSS:
d->qValues[childIdx] = (double(d->qValues[childIdx]) * d->childNumberVisits[childIdx] + 1 + value) / d->childNumberVisits[childIdx];
break;
case VIRTUAL_VISIT:
childRealVisit = get_real_visits(childIdx);
d->qValues[childIdx] = (double(d->qValues[childIdx]) * childRealVisit + value) / (childRealVisit + 1);
break;
case VIRTUAL_OFFSET:
newQVal = double(d->qValues[childIdx]) + d->virtualLossCounter[childIdx] * searchSettings->virtualOffsetStrenght;
newQVal = (newQVal * childRealVisit + value) / (childRealVisit + 1.0);
d->qValues[childIdx] = newQVal - ((d->virtualLossCounter[childIdx]-1) * searchSettings->virtualOffsetStrenght);
case VIRTUAL_MIX: ;
// unreachable
}

assert(!isnan(d->qValues[childIdx]));
}

if (searchSettings->virtualLoss != 1) {
d->childNumberVisits[childIdx] -= size_t(searchSettings->virtualLoss) - 1;
d->visitSum -= size_t(searchSettings->virtualLoss) - 1;
}
// decrement virtual loss counter
update_virtual_loss_counter<false>(childIdx);

if (freeBackup) {
++d->freeVisits;
}
Expand All @@ -225,7 +249,7 @@ class Node
* @brief revert_virtual_loss Reverts the virtual loss for a target node
* @param childIdx Index to the child node to update
*/
void revert_virtual_loss(ChildIdx childIdx, float virtualLoss);
void revert_virtual_loss(ChildIdx childIdx, const SearchSettings* searchSettings);

bool is_playout_node() const;

Expand Down Expand Up @@ -259,7 +283,7 @@ class Node
double get_value_sum() const;
uint32_t get_real_visits() const;

void apply_virtual_loss_to_child(ChildIdx childIdx, uint_fast32_t virtualLoss);
void apply_virtual_loss_to_child(ChildIdx childIdx, const SearchSettings* searchSettings);

void increment_no_visit_idx();
void fully_expand_node();
Expand Down Expand Up @@ -470,17 +494,17 @@ class Node
*/
void decrement_number_parents();

double get_q_sum(ChildIdx childIdx, float virtualLoss) const;
double get_q_sum_virtual_loss(ChildIdx childIdx) const;

template<bool increment>
void update_virtual_loss_counter(ChildIdx childIdx, float virtualLoss)
void update_virtual_loss_counter(ChildIdx childIdx)
{
if (increment) {
d->virtualLossCounter[childIdx] += virtualLoss;
++d->virtualLossCounter[childIdx];
}
else {
assert(d->virtualLossCounter[childIdx] != 0);
d->virtualLossCounter[childIdx] -= virtualLoss;
--d->virtualLossCounter[childIdx];
}
}

Expand Down Expand Up @@ -515,6 +539,17 @@ class Node

uint32_t get_number_of_nodes() const;


/**
 * @brief get_transposition_q_value Returns the Q-value of a child with the Q-value change of the active virtual style undone.
 * The Q-value is also multiplied by -1 if searchSettings->searchPlayerMode == MODE_TWO_PLAYER.
 * @param searchSettings Search settings struct (supplies the virtual style, virtualOffsetStrenght and searchPlayerMode)
 * @param childIdx child index
 * @param transposVisits Number of visits connecting to the transposition node (used as divisor for VIRTUAL_LOSS)
 * @return Q-Value converted to double
 */
double get_transposition_q_value(const SearchSettings* searchSettings, ChildIdx childIdx, uint_fast32_t transposVisits);

private:
/**
* @brief reserve_full_memory Reserves memory for all available child nodes
Expand Down Expand Up @@ -765,12 +800,12 @@ bool is_terminal_value(float value);
/**
* @brief backup_collision Iteratively removes the virtual loss of the collision event that occurred
* @param rootNode Root node of the tree
* @param virtualLoss Virtual loss value
* @param searchSettings Search settings struct
* @param trajectory Trajectory on how to get to the given collision
*/
void backup_collision(float virtualLoss, const Trajectory& trajectory);
void backup_collision(const SearchSettings* searchSettings, const Trajectory& trajectory);

float get_transposition_q_value(uint_fast32_t transposVisits, double transposQValue, double masterQValue);
float get_transposition_backup_value(uint_fast32_t transposVisits, double transposQValue, double masterQValue);

/**
* @brief backup_value Iteratively backpropagates a value prediction across all of the parents for this node.
Expand All @@ -788,15 +823,12 @@ void backup_value(float value, const SearchSettings* searchSettings, const Traje
if (targetQValue != 0) {
const uint_fast32_t transposVisits = it->node->get_real_visits(it->childIdx);
if (transposVisits != 0) {
const double transposQValue = -it->node->get_q_sum(it->childIdx, searchSettings->virtualLoss) / transposVisits;
value = get_transposition_q_value(transposVisits, transposQValue, targetQValue);
const double transposQValue = it->node->get_transposition_q_value(searchSettings, it->childIdx, transposVisits);
value = get_transposition_backup_value(transposVisits, transposQValue, targetQValue);
}
}
switch (searchSettings->searchPlayerMode) {
case MODE_TWO_PLAYER:
if (searchSettings->searchPlayerMode == MODE_TWO_PLAYER) {
value = -value;
break;
case MODE_SINGLE_PLAYER: ;
}
freeBackup ? it->node->revert_virtual_loss_and_update<true>(it->childIdx, value, searchSettings, solveForTerminal) :
it->node->revert_virtual_loss_and_update<false>(it->childIdx, value, searchSettings, solveForTerminal);
Expand All @@ -818,5 +850,4 @@ void backup_value(float value, const SearchSettings* searchSettings, const Traje
*/
bool is_transposition_verified(const Node* node, const StateObj* state);


#endif // NODE_H
4 changes: 0 additions & 4 deletions engine/src/nodedata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,6 @@ void NodeData::reserve_initial_space()
virtualLossCounter.reserve(initSize);
nodeTypes.reserve(initSize);
add_empty_node();
if (initSize > 1) {
add_empty_node();
++noVisitIdx;
}
}

NodeData::NodeData():
Expand Down
Loading

0 comments on commit 812d62d

Please sign in to comment.