Skip to content

Commit

Permalink
Update mctsagent.cpp (#216)
Browse files Browse the repository at this point in the history
Only use phase selection for nets.size() > 1
  • Loading branch information
QueensGambit committed Oct 6, 2024
1 parent b4216ea commit f984efa
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions engine/src/agents/mctsagent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,12 @@ shared_ptr<Node> MCTSAgent::get_root_node_from_tree(StateObj *state)
void MCTSAgent::set_root_node_predictions()
{
state->get_state_planes(true, inputPlanes, nets.front()->get_version());
GamePhase currentPhase = state->get_phase(numPhases, searchSettings->gamePhaseDefinition);
nets[phaseToNetsIndex.at(currentPhase)]->predict(inputPlanes, valueOutputs, probOutputs, auxiliaryOutputs);
size_t netIdx = 0;
if (nets.size() > 1) {
GamePhase currentPhase = state->get_phase(numPhases, searchSettings->gamePhaseDefinition);
netIdx = phaseToNetsIndex.at(currentPhase);
}
nets[netIdx]->predict(inputPlanes, valueOutputs, probOutputs, auxiliaryOutputs);
size_t tbHits = 0;
fill_nn_results(0, nets[phaseToNetsIndex.at(currentPhase)]->is_policy_map(), valueOutputs, probOutputs, auxiliaryOutputs, rootNode.get(), tbHits,
rootState->mirror_policy(state->side_to_move()), searchSettings, rootNode->is_tablebase());
Expand Down

0 comments on commit f984efa

Please sign in to comment.