diff --git a/engine/src/agents/mctsagent.cpp b/engine/src/agents/mctsagent.cpp index ad651b63..b503c705 100644 --- a/engine/src/agents/mctsagent.cpp +++ b/engine/src/agents/mctsagent.cpp @@ -166,8 +166,12 @@ shared_ptr MCTSAgent::get_root_node_from_tree(StateObj *state) void MCTSAgent::set_root_node_predictions() { state->get_state_planes(true, inputPlanes, nets.front()->get_version()); - GamePhase currentPhase = state->get_phase(numPhases, searchSettings->gamePhaseDefinition); - nets[phaseToNetsIndex.at(currentPhase)]->predict(inputPlanes, valueOutputs, probOutputs, auxiliaryOutputs); + size_t netIdx = 0; + if (nets.size() > 1) { + GamePhase currentPhase = state->get_phase(numPhases, searchSettings->gamePhaseDefinition); + netIdx = phaseToNetsIndex.at(currentPhase); + } + nets[netIdx]->predict(inputPlanes, valueOutputs, probOutputs, auxiliaryOutputs); size_t tbHits = 0; fill_nn_results(0, nets[phaseToNetsIndex.at(currentPhase)]->is_policy_map(), valueOutputs, probOutputs, auxiliaryOutputs, rootNode.get(), tbHits, rootState->mirror_policy(state->side_to_move()), searchSettings, rootNode->is_tablebase());