Skip to content

Commit

Permalink
Add select_nn_index() for phase selection (#216)
Browse files Browse the repository at this point in the history
- return 0 if no phases is enabled

Set Game_Phase_Definition default to "lichess"
  • Loading branch information
QueensGambit committed Oct 6, 2024
1 parent 36582c6 commit b4216ea
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 15 deletions.
36 changes: 22 additions & 14 deletions engine/src/searchthread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -379,27 +379,35 @@ void SearchThread::create_mini_batch()
}
}

size_t SearchThread::select_nn_index()
{
if (nets.size() == 1) {
return 0;
}
// determine majority class in current batch
using pair_type = decltype(phaseCountMap)::value_type;
auto pr = std::max_element
(
std::begin(phaseCountMap), std::end(phaseCountMap),
[](const pair_type& p1, const pair_type& p2) {
return p1.second < p2.second;
}
);

GamePhase majorityPhase = pr->first;

phaseCountMap.clear();
return phaseToNetsIndex.at(majorityPhase);
}

void SearchThread::thread_iteration()
{
create_mini_batch();
#ifndef SEARCH_UCT
if (newNodes->size() != 0) {

// determine majority class in current batch
using pair_type = decltype(phaseCountMap)::value_type;
auto pr = std::max_element
(
std::begin(phaseCountMap), std::end(phaseCountMap),
[](const pair_type& p1, const pair_type& p2) {
return p1.second < p2.second;
}
);

GamePhase majorityPhase = pr->first;

phaseCountMap.clear();
// query the network that corresponds to the majority phase
nets[phaseToNetsIndex.at(majorityPhase)]->predict(inputPlanes, valueOutputs, probOutputs, auxiliaryOutputs);
nets[select_nn_index()]->predict(inputPlanes, valueOutputs, probOutputs, auxiliaryOutputs);
set_nn_results_to_child_nodes();
}
#endif
Expand Down
7 changes: 7 additions & 0 deletions engine/src/searchthread.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,13 @@ class SearchThread : public NeuralNetAPIUser
* @return Q-Value converted to double
*/
double get_current_transposition_q_value(const Node* currentNode, ChildIdx childIdx, uint_fast32_t transposVisits);

/**
* @brief select_nn_index Returns the index according to the majority phase in the current batch.
* If no phases is enabled, 0 will be returned.
* @return Majority phase index or 0
*/
size_t select_nn_index();
};

void run_search_thread(SearchThread *t);
Expand Down
2 changes: 1 addition & 1 deletion engine/src/uci/optionsuci.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ void OptionsUCI::init(OptionsMap &o)
o["Use_Raw_Network"] << Option(false);
o["Virtual_Style"] << Option("virtual_mix", { "virtual_loss", "virtual_visit", "virtual_offset", "virtual_mix" });
o["Virtual_Mix_Threshold"] << Option(1000, 1, 99999999);
o["Game_Phase_Definition"] << Option("movecount", { "lichess", "movecount"});
o["Game_Phase_Definition"] << Option("lichess", { "lichess", "movecount"});
// additional UCI-Options for RL only
#ifdef USE_RL
o["Centi_Node_Random_Factor"] << Option(10, 0, 100);
Expand Down

0 comments on commit b4216ea

Please sign in to comment.