Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Virtual_Visit, Virtual_Mix, Virtual_Offset #205

Merged
merged 34 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
5ac613f
Add VIRTUAL_VISIT as new virtual simulation style
QueensGambit Jul 27, 2023
a730b11
Sort folder names in train archive by time stamp (close #189)
QueensGambit Jul 31, 2023
bdb422a
Add virtualVisitIncrement
QueensGambit Aug 2, 2023
14b5dcb
Add UCI-Option "Virtual_Visit_Increment"
QueensGambit Aug 2, 2023
283f319
Add VIRTUAL_OFFSET UCI-Option
QueensGambit Aug 2, 2023
6908a18
Add missing ;
QueensGambit Aug 2, 2023
de76cab
Simplify VIRTUAL_VISIT
QueensGambit Aug 2, 2023
e46506a
Update virtual-Offset implementation
QueensGambit Aug 3, 2023
bb0fec1
Add missing implementation for VIRTUAL_OFFSET in
QueensGambit Aug 3, 2023
faf0312
Change Centi_Virtual_Loss to Milli_Virtual_loss
QueensGambit Aug 3, 2023
52b544a
Change Milli_Virtual_Loss to Micro_Virtual_Loss
QueensGambit Aug 3, 2023
c92b7a4
Update max value for Micro_Virtual_Loss
QueensGambit Aug 3, 2023
1d17d91
Add VIRTUAL_MIX
QueensGambit Aug 3, 2023
875820f
Use realVisitSum as condition for Virtual_Mix
QueensGambit Aug 3, 2023
38dc0fe
Simplify code and use realVisits of child node for VIRTUAL_MIX threshold
QueensGambit Aug 3, 2023
2d0141e
Fix compile bugs
QueensGambit Aug 3, 2023
7d910b7
Use d->childNumber visits again
QueensGambit Aug 3, 2023
522f347
Deactive virtualWeight for now
QueensGambit Aug 3, 2023
5a55691
Deactive virtualWeight
QueensGambit Aug 3, 2023
4e92cb5
Use Q_INIT for comparision
QueensGambit Aug 3, 2023
511d583
revert last change
QueensGambit Aug 3, 2023
e8fd85a
Add virtualOffsetStrenght(0.001)
QueensGambit Aug 3, 2023
46dd935
revert
QueensGambit Aug 3, 2023
cfa391b
Update VIRTUAL_OFFSET
QueensGambit Aug 3, 2023
c38d9f4
Switch between VIRTUAL_LOSS and VIRTUAL_VISIT
QueensGambit Aug 3, 2023
0f94920
Add virtualLossIncrement
QueensGambit Aug 4, 2023
8eba23f
remove virtualLossIncrement again due to underperformance
QueensGambit Aug 4, 2023
3ed0c8d
Merge remote-tracking branch 'origin/master' into virtual_visit
QueensGambit Aug 4, 2023
e1ef98b
Update UCI-default values
QueensGambit Aug 4, 2023
514dad5
Fix compile error
QueensGambit Aug 4, 2023
d51ff2f
Disable 960 Support for now due to problems
QueensGambit Aug 7, 2023
0aa57a4
Fix init of second argument in first_and_second_max()
QueensGambit Aug 7, 2023
db35e1d
Remove init of additional nodes in NodeData
QueensGambit Aug 7, 2023
3f4a5a2
Remove Virtual_Weight for now
QueensGambit Aug 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions engine/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ if (MODE_CHESS)
# add_definitions(-DVERSION=2)
# add_definitions(-DSUB_VERSION=8)
add_definitions(-DVERSION=3)
add_definitions(-DSUPPORT960)
# add_definitions(-DSUPPORT960)
endif()

if (MODE_LICHESS)
Expand All @@ -57,7 +57,7 @@ if (MODE_LICHESS)
add_definitions(-DATOMIC)
add_definitions(-DHORDE)
add_definitions(-DRACE)
add_definitions(-DSUPPORT960)
# add_definitions(-DSUPPORT960)
add_definitions(-DMCTS_TB_SUPPORT)
add_definitions(-DVERSION=1)
endif()
Expand Down
6 changes: 4 additions & 2 deletions engine/src/agents/config/searchsettings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ SearchSettings::SearchSettings():
nodePolicyTemperature(1.0f),
qValueWeight(1.0f),
qVetoDelta(0.4f),
virtualLoss(1.0f),
verbose(true),
epsilonChecksCounter(100),
useMCGS(true),
Expand All @@ -49,7 +48,10 @@ SearchSettings::SearchSettings():
epsilonGreedyCounter(20),
reuseTree(true),
mctsSolver(false),
searchPlayerMode(MODE_TWO_PLAYER)
searchPlayerMode(MODE_TWO_PLAYER),
virtualStyle(VIRTUAL_VISIT),
virtualMixThreshold(1000),
virtualOffsetStrenght(0.001)
{

}
14 changes: 13 additions & 1 deletion engine/src/agents/config/searchsettings.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ enum SearchPlayerMode {
MODE_TWO_PLAYER
};

// Selects how a pending (not yet backed-up) playout is marked on a child edge
// in order to avoid conflicts between different threads within the same mini-batch.
// All styles increment the child's visit count and its virtual loss counter while
// the playout is pending; they differ in how the Q-value is modified.
enum VirtualStyle {
// The Q-value is updated as if one additional game had been lost from this child.
VIRTUAL_LOSS,
// The Q-value is left untouched; only the visit count is virtually increased.
VIRTUAL_VISIT,
// The Q-value is reduced by SearchSettings::virtualOffsetStrenght per pending playout.
VIRTUAL_OFFSET,
// Resolved by get_virtual_style(): behaves like VIRTUAL_VISIT until the child's visits
// exceed SearchSettings::virtualMixThreshold and like VIRTUAL_LOSS afterwards.
VIRTUAL_MIX
};

struct SearchSettings
{
uint16_t multiPV;
Expand All @@ -48,7 +55,6 @@ struct SearchSettings
float qValueWeight;
// describes how much better the highest Q-Value has to be to replace the candidate move with the highest visit count
float qVetoDelta;
uint_fast32_t virtualLoss;
bool verbose;
uint_fast8_t epsilonChecksCounter;
// bool enhanceCaptures; currently not support
Expand All @@ -75,6 +81,12 @@ struct SearchSettings
bool mctsSolver;
// Defines the number of players within the MCTS search. Available are MODE_SINGLE_PLAYER and MODE_TWO_PLAYER
SearchPlayerMode searchPlayerMode;
// Defines the virtual style to avoid conflicts between different threads within the same mini-batch
VirtualStyle virtualStyle;
// Defines the number of visits to switch from virtual-visit to virtual-loss
uint_fast32_t virtualMixThreshold;
// Defines the strength of the virtual offset
double virtualOffsetStrenght;
SearchSettings();

};
Expand Down
70 changes: 55 additions & 15 deletions engine/src/node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ bool Node::is_sorted() const
return sorted;
}

double Node::get_q_sum(ChildIdx childIdx, float virtualLoss) const
double Node::get_q_sum_virtual_loss(ChildIdx childIdx) const
{
return get_child_number_visits(childIdx) * double(get_q_value(childIdx)) + get_virtual_loss_counter(childIdx) * virtualLoss;
return get_child_number_visits(childIdx) * double(get_q_value(childIdx)) + get_virtual_loss_counter(childIdx);
}

bool Node::is_transposition() const
Expand Down Expand Up @@ -504,18 +504,28 @@ bool Node::has_nn_results() const
return hasNNResults;
}

void Node::apply_virtual_loss_to_child(ChildIdx childIdx, uint_fast32_t virtualLoss)
void Node::apply_virtual_loss_to_child(ChildIdx childIdx, const SearchSettings* searchSettings)
{
// update the stats of the parent node
// make it look like if one has lost X games from this node forward where X is the virtual loss value
// temporarily reduce the attraction of this node by applying a virtual loss /
// the effect of virtual loss will be undone if the playout is over
d->qValues[childIdx] = (double(d->qValues[childIdx]) * d->childNumberVisits[childIdx] - virtualLoss) / double(d->childNumberVisits[childIdx] + virtualLoss);
switch (get_virtual_style(searchSettings, d->childNumberVisits[childIdx])) {
case VIRTUAL_LOSS:
d->qValues[childIdx] = (double(d->qValues[childIdx]) * d->childNumberVisits[childIdx] - 1) / double(d->childNumberVisits[childIdx] + 1);
break;
case VIRTUAL_OFFSET:
d->qValues[childIdx] -= searchSettings->virtualOffsetStrenght;
case VIRTUAL_VISIT: ; // ignore
case VIRTUAL_MIX: ; // unreachable
}

// virtual increase the number of visits
d->childNumberVisits[childIdx] += virtualLoss;
d->visitSum += virtualLoss;
++d->childNumberVisits[childIdx];
++d->visitSum;

// increment virtual loss counter
update_virtual_loss_counter<true>(childIdx, virtualLoss);
update_virtual_loss_counter<true>(childIdx);
}

float Node::get_q_value(ChildIdx childIdx) const
Expand Down Expand Up @@ -642,20 +652,29 @@ uint32_t Node::get_real_visits(ChildIdx childIdx) const
return d->childNumberVisits[childIdx] - d->virtualLossCounter[childIdx];
}

void backup_collision(float virtualLoss, const Trajectory& trajectory) {
void backup_collision(const SearchSettings* searchSettings, const Trajectory& trajectory) {
for (auto it = trajectory.rbegin(); it != trajectory.rend(); ++it) {
it->node->revert_virtual_loss(it->childIdx, virtualLoss);
it->node->revert_virtual_loss(it->childIdx, searchSettings);
}
}

void Node::revert_virtual_loss(ChildIdx childIdx, float virtualLoss)
void Node::revert_virtual_loss(ChildIdx childIdx, const SearchSettings* searchSettings)
{
lock();
d->qValues[childIdx] = (double(d->qValues[childIdx]) * d->childNumberVisits[childIdx] + virtualLoss) / (d->childNumberVisits[childIdx] - virtualLoss);
d->childNumberVisits[childIdx] -= virtualLoss;
d->visitSum -= virtualLoss;
switch (get_virtual_style(searchSettings, d->childNumberVisits[childIdx])) {
case VIRTUAL_LOSS:
d->qValues[childIdx] = (double(d->qValues[childIdx]) * d->childNumberVisits[childIdx] + 1) / (d->childNumberVisits[childIdx] - 1);
break;
case VIRTUAL_OFFSET:
d->qValues[childIdx] += searchSettings->virtualOffsetStrenght;
case VIRTUAL_MIX: ; // ignore
case VIRTUAL_VISIT: ; // ignore
}
--d->childNumberVisits[childIdx];
--d->visitSum;

// decrement virtual loss counter
update_virtual_loss_counter<false>(childIdx, virtualLoss);
update_virtual_loss_counter<false>(childIdx);
unlock();
}

Expand Down Expand Up @@ -990,6 +1009,27 @@ void Node::disable_action(size_t childIdxForParent)
d->qValues[childIdxForParent] = -INT_MAX;
}

/**
 * @brief get_transposition_q_value Returns the Q-value of the given child with the effect of the
 * currently pending virtual style removed, as seen from the transposition node.
 * The value is negated if searchSettings->searchPlayerMode == MODE_TWO_PLAYER.
 * @param searchSettings Search settings struct (provides the virtual style, the offset strength and the player mode)
 * @param childIdx Index of the child that connects to the transposition node
 * @param transposVisits Number of real visits of that child; used as divisor for VIRTUAL_LOSS and
 * therefore expected to be non-zero (the visible caller checks this before calling)
 * @return Corrected Q-value as double
 */
double Node::get_transposition_q_value(const SearchSettings *searchSettings, ChildIdx childIdx, uint_fast32_t transposVisits)
{
    // Start from the stored Q-value so that the result is never indeterminate,
    // even if a style is encountered that has no dedicated correction below.
    double transposQValue = get_q_value(childIdx);

    switch(get_virtual_style(searchSettings, d->childNumberVisits[childIdx])) {
    case VIRTUAL_LOSS:
        // undo the virtual losses by re-normalizing the corrected Q-sum with the real visits
        transposQValue = get_q_sum_virtual_loss(childIdx) / transposVisits;
        break;
    case VIRTUAL_VISIT:
        // the Q-value is not modified by virtual visits
        break;
    case VIRTUAL_OFFSET:
        // add back the offset that was subtracted once per pending playout
        transposQValue += get_virtual_loss_counter(childIdx) * searchSettings->virtualOffsetStrenght;
        break;
    case VIRTUAL_MIX:
        // unreachable: get_virtual_style() resolves VIRTUAL_MIX to VIRTUAL_LOSS or VIRTUAL_VISIT
        break;
    }

    if (searchSettings->searchPlayerMode == MODE_TWO_PLAYER) {
        return -transposQValue;
    }
    return transposQValue;
}

void Node::enhance_moves(const SearchSettings* searchSettings)
{
// if (!searchSettings->enhanceChecks && !searchSettings->enhanceCaptures) {
Expand Down Expand Up @@ -1295,7 +1335,7 @@ bool is_terminal_value(float value)
return (value == WIN_VALUE || value == DRAW_VALUE || value == LOSS_VALUE);
}

float get_transposition_q_value(uint_fast32_t transposVisits, double transposQValue, double targetQValue)
// Computes the value to back up through a transposition edge so that its Q-value
// moves towards targetQValue; the result is clamped to [LOSS_VALUE, WIN_VALUE].
float get_transposition_backup_value(uint_fast32_t transposVisits, double transposQValue, double targetQValue)
{
    // difference between the desired Q-value and the Q-value currently stored on the edge
    const double qValueGap = targetQValue - transposQValue;
    const double unclampedBackup = transposVisits * qValueGap + targetQValue;
    return std::clamp(unclampedBackup, double(LOSS_VALUE), double(WIN_VALUE));
}
Expand Down
79 changes: 55 additions & 24 deletions engine/src/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,16 @@ struct NodeAndBudget {
node(node), budget(budget), curState(state) {}
};

/**
 * @brief get_virtual_style Resolves the virtual style that is effectively applied for a given visit count.
 * Every configured style except VIRTUAL_MIX is returned unchanged. VIRTUAL_MIX is mapped to
 * VIRTUAL_LOSS if visits is greater than searchSettings->virtualMixThreshold and to VIRTUAL_VISIT otherwise.
 * @param searchSettings Search settings struct
 * @param visits Visit count that is compared against the mix threshold
 * @return Effective virtual style (never VIRTUAL_MIX)
 */
inline VirtualStyle get_virtual_style(const SearchSettings* searchSettings, uint_fast32_t visits) {
    const VirtualStyle configuredStyle = searchSettings->virtualStyle;
    if (configuredStyle != VIRTUAL_MIX) {
        return configuredStyle;
    }
    const bool aboveThreshold = visits > searchSettings->virtualMixThreshold;
    return aboveThreshold ? VIRTUAL_LOSS : VIRTUAL_VISIT;
}

class Node
{
private:
Expand Down Expand Up @@ -190,28 +200,42 @@ class Node
void revert_virtual_loss_and_update(ChildIdx childIdx, float value, const SearchSettings* searchSettings, bool solveForTerminal)
{
lock();
// decrement virtual loss counter
update_virtual_loss_counter<false>(childIdx, searchSettings->virtualLoss);

valueSum += value;
++realVisitsSum;

if (d->childNumberVisits[childIdx] == searchSettings->virtualLoss) {
if (d->childNumberVisits[childIdx] == 1) {
// set new Q-value based on return
// (the initialization of the Q-value was by Q_INIT which we don't want to recover.)
d->qValues[childIdx] = value;
}
else {
// revert virtual loss and update the Q-value
assert(d->childNumberVisits[childIdx] != 0);
d->qValues[childIdx] = (double(d->qValues[childIdx]) * d->childNumberVisits[childIdx] + searchSettings->virtualLoss + value) / d->childNumberVisits[childIdx];
uint_fast32_t childRealVisit;
double newQVal;
switch(get_virtual_style(searchSettings, d->childNumberVisits[childIdx])) {
case VIRTUAL_LOSS:
d->qValues[childIdx] = (double(d->qValues[childIdx]) * d->childNumberVisits[childIdx] + 1 + value) / d->childNumberVisits[childIdx];
break;
case VIRTUAL_VISIT:
childRealVisit = get_real_visits(childIdx);
d->qValues[childIdx] = (double(d->qValues[childIdx]) * childRealVisit + value) / (childRealVisit + 1);
break;
case VIRTUAL_OFFSET:
newQVal = double(d->qValues[childIdx]) + d->virtualLossCounter[childIdx] * searchSettings->virtualOffsetStrenght;
newQVal = (newQVal * childRealVisit + value) / (childRealVisit + 1.0);
d->qValues[childIdx] = newQVal - ((d->virtualLossCounter[childIdx]-1) * searchSettings->virtualOffsetStrenght);
case VIRTUAL_MIX: ;
// unreachable
}

assert(!isnan(d->qValues[childIdx]));
}

if (searchSettings->virtualLoss != 1) {
d->childNumberVisits[childIdx] -= size_t(searchSettings->virtualLoss) - 1;
d->visitSum -= size_t(searchSettings->virtualLoss) - 1;
}
// decrement virtual loss counter
update_virtual_loss_counter<false>(childIdx);

if (freeBackup) {
++d->freeVisits;
}
Expand All @@ -225,7 +249,7 @@ class Node
* @brief revert_virtual_loss Reverts the virtual loss for a target node
* @param childIdx Index to the child node to update
*/
void revert_virtual_loss(ChildIdx childIdx, float virtualLoss);
void revert_virtual_loss(ChildIdx childIdx, const SearchSettings* searchSettings);

bool is_playout_node() const;

Expand Down Expand Up @@ -259,7 +283,7 @@ class Node
double get_value_sum() const;
uint32_t get_real_visits() const;

void apply_virtual_loss_to_child(ChildIdx childIdx, uint_fast32_t virtualLoss);
void apply_virtual_loss_to_child(ChildIdx childIdx, const SearchSettings* searchSettings);

void increment_no_visit_idx();
void fully_expand_node();
Expand Down Expand Up @@ -470,17 +494,17 @@ class Node
*/
void decrement_number_parents();

double get_q_sum(ChildIdx childIdx, float virtualLoss) const;
double get_q_sum_virtual_loss(ChildIdx childIdx) const;

template<bool increment>
void update_virtual_loss_counter(ChildIdx childIdx, float virtualLoss)
void update_virtual_loss_counter(ChildIdx childIdx)
{
if (increment) {
d->virtualLossCounter[childIdx] += virtualLoss;
++d->virtualLossCounter[childIdx];
}
else {
assert(d->virtualLossCounter[childIdx] != 0);
d->virtualLossCounter[childIdx] -= virtualLoss;
--d->virtualLossCounter[childIdx];
}
}

Expand Down Expand Up @@ -515,6 +539,17 @@ class Node

uint32_t get_number_of_nodes() const;


/**
* @brief get_transposition_q_value Returns the Q-value (without virtualLoss) which connects to the transposition node.
* The q-Value is also multiplied by -1 if searchSettings->searchPlayerMode == MODE_TWO_PLAYER.
* @param searchSettings Search settings struct (selects the virtual style and the search player mode)
* @param childIdx child index
* @param transposVisits Number of visits connecting to the transposition node
* @return Q-Value converted to double
*/
double get_transposition_q_value(const SearchSettings* searchSettings, ChildIdx childIdx, uint_fast32_t transposVisits);

private:
/**
* @brief reserve_full_memory Reserves memory for all available child nodes
Expand Down Expand Up @@ -765,12 +800,12 @@ bool is_terminal_value(float value);
/**
* @brief backup_collision Iteratively removes the virtual loss of the collision event that occurred
* @param rootNode Root node of the tree
* @param virtualLoss Virtual loss value
* @param searchSettings Search settings struct
* @param trajectory Trajectory on how to get to the given collision
*/
void backup_collision(float virtualLoss, const Trajectory& trajectory);
void backup_collision(const SearchSettings* searchSettings, const Trajectory& trajectory);

float get_transposition_q_value(uint_fast32_t transposVisits, double transposQValue, double masterQValue);
float get_transposition_backup_value(uint_fast32_t transposVisits, double transposQValue, double masterQValue);

/**
* @brief backup_value Iteratively backpropagates a value prediction across all of the parents for this node.
Expand All @@ -788,15 +823,12 @@ void backup_value(float value, const SearchSettings* searchSettings, const Traje
if (targetQValue != 0) {
const uint_fast32_t transposVisits = it->node->get_real_visits(it->childIdx);
if (transposVisits != 0) {
const double transposQValue = -it->node->get_q_sum(it->childIdx, searchSettings->virtualLoss) / transposVisits;
value = get_transposition_q_value(transposVisits, transposQValue, targetQValue);
const double transposQValue = it->node->get_transposition_q_value(searchSettings, it->childIdx, transposVisits);
value = get_transposition_backup_value(transposVisits, transposQValue, targetQValue);
}
}
switch (searchSettings->searchPlayerMode) {
case MODE_TWO_PLAYER:
if (searchSettings->searchPlayerMode == MODE_TWO_PLAYER) {
value = -value;
break;
case MODE_SINGLE_PLAYER: ;
}
freeBackup ? it->node->revert_virtual_loss_and_update<true>(it->childIdx, value, searchSettings, solveForTerminal) :
it->node->revert_virtual_loss_and_update<false>(it->childIdx, value, searchSettings, solveForTerminal);
Expand All @@ -818,5 +850,4 @@ void backup_value(float value, const SearchSettings* searchSettings, const Traje
*/
bool is_transposition_verified(const Node* node, const StateObj* state);


#endif // NODE_H
4 changes: 0 additions & 4 deletions engine/src/nodedata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,6 @@ void NodeData::reserve_initial_space()
virtualLossCounter.reserve(initSize);
nodeTypes.reserve(initSize);
add_empty_node();
if (initSize > 1) {
add_empty_node();
++noVisitIdx;
}
}

NodeData::NodeData():
Expand Down
Loading
Loading