Skip to content

Commit

Permalink
MoE MCTS fixes (#218)
Browse files Browse the repository at this point in the history
* Fix rawNetAgent and compile bugs for RL mode
Change ordering of netBatchesVector (threads go first now)
Use raw pointers in NeuralNetAPIUser now to resolve ownership problem
Remove create_new_net_batches
Add fill_nn_vectors
Add fill_single_nn_vector

* Make Nodes_Limit available in RL mode

* remove unused code
  • Loading branch information
QueensGambit authored Sep 6, 2024
1 parent 04509dd commit 6b43b32
Show file tree
Hide file tree
Showing 14 changed files with 112 additions and 95 deletions.
2 changes: 1 addition & 1 deletion engine/src/agents/agent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ void Agent::set_must_wait(bool value)
mustWait = value;
}

Agent::Agent(vector<unique_ptr<NeuralNetAPI>>& nets, PlaySettings* playSettings, bool verbose):
Agent::Agent(const vector<unique_ptr<NeuralNetAPI>>& nets, const PlaySettings* playSettings, bool verbose):
NeuralNetAPIUser(nets),
playSettings(playSettings), mustWait(true), verbose(verbose), isRunning(false)
{
Expand Down
4 changes: 2 additions & 2 deletions engine/src/agents/agent.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class Agent : public NeuralNetAPIUser

protected:
SearchLimits* searchLimits;
PlaySettings* playSettings;
const PlaySettings* playSettings;
StateObj* state;
EvalInfo* evalInfo;
// Protect the isRunning attribute and makes sure that the stop() command can only be called after the search has actually been started.
Expand All @@ -72,7 +72,7 @@ class Agent : public NeuralNetAPIUser
bool isRunning;

public:
Agent(vector<unique_ptr<NeuralNetAPI>>& nets, PlaySettings* playSettings, bool verbose);
Agent(const vector<unique_ptr<NeuralNetAPI>>& nets, const PlaySettings* playSettings, bool verbose);

/**
* @brief perform_action Selects an action based on the evaluation result
Expand Down
10 changes: 3 additions & 7 deletions engine/src/agents/mctsagent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
#include "../node.h"
#include "../util/communication.h"

MCTSAgent::MCTSAgent(vector<unique_ptr<NeuralNetAPI>>& netSingleVector, vector<vector<unique_ptr<NeuralNetAPI>>>& netBatchesVector,
MCTSAgent::MCTSAgent(const vector<unique_ptr<NeuralNetAPI>>& netSingleVector, const vector<vector<unique_ptr<NeuralNetAPI>>>& netBatchesVector,
SearchSettings* searchSettings, PlaySettings* playSettings):
Agent(netSingleVector, playSettings, true),
searchSettings(searchSettings),
Expand All @@ -51,12 +51,8 @@ MCTSAgent::MCTSAgent(vector<unique_ptr<NeuralNetAPI>>& netSingleVector, vector<v
{
mapWithMutex.hashTable.reserve(1e6);

for (auto i = 0; i < searchSettings->threads; ++i) {
vector<unique_ptr<NeuralNetAPI>> netBatchVector; // stores the ith element of all netBatches in netBatchesVector
for (auto& netBatches : netBatchesVector) {
netBatchVector.push_back(std::move(netBatches[i]));
}
searchThreads.emplace_back(new SearchThread(netBatchVector, searchSettings, &mapWithMutex));
for (size_t idx = 0; idx < searchSettings->threads; ++idx) {
searchThreads.emplace_back(new SearchThread(netBatchesVector[idx], searchSettings, &mapWithMutex));
}
timeManager = make_unique<TimeManager>(searchSettings->randomMoveFactor);
generator = default_random_engine(r());
Expand Down
4 changes: 2 additions & 2 deletions engine/src/agents/mctsagent.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ class MCTSAgent : public Agent
unique_ptr<ThreadManager> threadManager;
bool reachedTablebases;
public:
MCTSAgent(vector<unique_ptr<NeuralNetAPI>>& netSingleVector,
vector<vector<unique_ptr<NeuralNetAPI>>>& netBatchesVector,
MCTSAgent(const vector<unique_ptr<NeuralNetAPI>>& netSingleVector,
const vector<vector<unique_ptr<NeuralNetAPI>>>& netBatchesVector,
SearchSettings* searchSettings,
PlaySettings* playSettings);
~MCTSAgent();
Expand Down
2 changes: 1 addition & 1 deletion engine/src/agents/rawnetagent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

using blaze::HybridVector;

RawNetAgent::RawNetAgent(vector<unique_ptr<NeuralNetAPI>>& nets, PlaySettings* playSettings, bool verbose, SearchSettings* searchSettings):
RawNetAgent::RawNetAgent(const vector<unique_ptr<NeuralNetAPI>>& nets, const PlaySettings* playSettings, bool verbose, const SearchSettings* searchSettings):
Agent(nets, playSettings, verbose),
searchSettings(searchSettings)
{
Expand Down
4 changes: 2 additions & 2 deletions engine/src/agents/rawnetagent.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ using namespace crazyara;
class RawNetAgent: public Agent
{
public:
SearchSettings* searchSettings;
const SearchSettings* searchSettings;

RawNetAgent(vector<unique_ptr<NeuralNetAPI>>& nets, PlaySettings* playSettings, bool verbose, SearchSettings* searchSettings);
RawNetAgent(const vector<unique_ptr<NeuralNetAPI>>& nets, const PlaySettings* playSettings, bool verbose, const SearchSettings* searchSettings);
RawNetAgent(const RawNetAgent&) = delete;
RawNetAgent& operator=(RawNetAgent const&) = delete;

Expand Down
9 changes: 6 additions & 3 deletions engine/src/nn/neuralnetapi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,12 @@ Version read_version_from_string(const string &modelFileName)
GamePhase read_game_phase_from_string(const string& modelDir)
{
// use last char of modelDir and convert to int by subtracting '0'
// TODO throw errors if necessary (if last letter is not a digit)

int gamePhase = (modelDir[modelDir.length() - 2]) - '0';
// assume phase 0 if last character is not a digit
char phaseChar = modelDir[modelDir.length() - 2];
if (!std::isdigit(phaseChar)) {
return GamePhase(0);
}
int gamePhase = phaseChar - '0';
return GamePhase(gamePhase);
}

Expand Down
7 changes: 4 additions & 3 deletions engine/src/nn/neuralnetapiuser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@
#include "common.h"
#endif

NeuralNetAPIUser::NeuralNetAPIUser(vector<unique_ptr<NeuralNetAPI>>& netsNew) :
NeuralNetAPIUser::NeuralNetAPIUser(const vector<unique_ptr<NeuralNetAPI>>& netsNew) :
auxiliaryOutputs(nullptr)
{
nets = std::move(netsNew);
for (size_t idx = 0; idx < netsNew.size(); idx++) {
nets.push_back(netsNew[idx].get());
}
numPhases = nets.size();

for (unsigned int i = 0; i < numPhases; i++)
{
GamePhase phaseOfNetI = nets[i]->get_game_phase();
Expand Down
4 changes: 2 additions & 2 deletions engine/src/nn/neuralnetapiuser.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
class NeuralNetAPIUser
{
protected:
vector<unique_ptr<NeuralNetAPI>> nets; // vector of net objects
vector<NeuralNetAPI*> nets; // vector of net objects
unsigned int numPhases;
std::map<GamePhase, int> phaseToNetsIndex; // maps a GamePhase to the index of the net that should be used

Expand All @@ -50,7 +50,7 @@ class NeuralNetAPIUser
float* auxiliaryOutputs;

public:
NeuralNetAPIUser(vector<unique_ptr<NeuralNetAPI>>& netsNew);
NeuralNetAPIUser(const vector<unique_ptr<NeuralNetAPI>>& netsNew);
~NeuralNetAPIUser();
NeuralNetAPIUser(NeuralNetAPIUser&) = delete;

Expand Down
7 changes: 3 additions & 4 deletions engine/src/searchthread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ size_t SearchThread::get_max_depth() const
return depthMax;
}

SearchThread::SearchThread(vector<unique_ptr<NeuralNetAPI>>& netBatchVector, const SearchSettings* searchSettings, MapWithMutex* mapWithMutex):
SearchThread::SearchThread(const vector<unique_ptr<NeuralNetAPI>>& netBatchVector, const SearchSettings* searchSettings, MapWithMutex* mapWithMutex):
NeuralNetAPIUser(netBatchVector),
rootNode(nullptr), rootState(nullptr), newState(nullptr), // will be be set via setter methods
newNodes(make_unique<FixedVector<Node*>>(searchSettings->batchSize)),
Expand All @@ -68,7 +68,7 @@ void SearchThread::set_root_node(Node *value)
visitsPreSearch = rootNode->get_visits();
}

void SearchThread::set_search_limits(SearchLimits *s)
void SearchThread::set_search_limits(const SearchLimits *s)
{
searchLimits = s;
}
Expand Down Expand Up @@ -116,7 +116,7 @@ Node *SearchThread::get_root_node() const
return rootNode;
}

SearchLimits *SearchThread::get_search_limits() const
const SearchLimits *SearchThread::get_search_limits() const
{
return searchLimits;
}
Expand Down Expand Up @@ -398,7 +398,6 @@ void SearchThread::thread_iteration()
GamePhase majorityPhase = pr->first;

phaseCountMap.clear();

// query the network that corresponds to the majority phase
nets[phaseToNetsIndex.at(majorityPhase)]->predict(inputPlanes, valueOutputs, probOutputs, auxiliaryOutputs);
set_nn_results_to_child_nodes();
Expand Down
8 changes: 4 additions & 4 deletions engine/src/searchthread.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class SearchThread : public NeuralNetAPIUser

MapWithMutex* mapWithMutex;
const SearchSettings* searchSettings;
SearchLimits* searchLimits;
const SearchLimits* searchLimits;
size_t tbHits;
size_t depthSum;
size_t depthMax;
Expand All @@ -89,7 +89,7 @@ class SearchThread : public NeuralNetAPIUser
* @param searchSettings Given settings for this search run
* @param MapWithMutex Handle to the hash table
*/
SearchThread(vector<unique_ptr<NeuralNetAPI>>& netBatchVector, const SearchSettings* searchSettings, MapWithMutex* mapWithMutex);
SearchThread(const vector<unique_ptr<NeuralNetAPI>>& netBatchVector, const SearchSettings* searchSettings, MapWithMutex* mapWithMutex);

/**
* @brief create_mini_batch Creates a mini-batch of new unexplored nodes.
Expand Down Expand Up @@ -123,9 +123,9 @@ class SearchThread : public NeuralNetAPIUser
void stop();

// Getter, setter functions
void set_search_limits(SearchLimits *s);
void set_search_limits(const SearchLimits *s);
Node* get_root_node() const;
SearchLimits *get_search_limits() const;
const SearchLimits *get_search_limits() const;
void set_root_node(Node *value);
bool is_running() const;
void set_is_running(bool value);
Expand Down
115 changes: 61 additions & 54 deletions engine/src/uci/crazyara.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ CrazyAra::CrazyAra():
rawAgent(nullptr),
mctsAgent(nullptr),
#ifdef USE_RL
netSingleContender(nullptr),
mctsAgentContender(nullptr),
#endif
searchSettings(SearchSettings()),
Expand Down Expand Up @@ -368,9 +367,8 @@ void CrazyAra::arena(istringstream &is)
{
prepare_search_config_structs();
SelfPlay selfPlay(rawAgent.get(), mctsAgent.get(), &searchLimits, &playSettings, &rlSettings, Options);
netSingleContender = create_new_net_single(Options["Model_Directory_Contender"]);
netBatchesContender = create_new_net_batches(Options["Model_Directory_Contender"]);
mctsAgentContender = create_new_mcts_agent(netSingleContender.get(), netBatchesContender, &searchSettings);
fill_nn_vectors(Options["Model_Directory_Contender"], netSingleContenderVector, netBatchesContenderVector);
mctsAgentContender = create_new_mcts_agent(netSingleContenderVector, netBatchesContenderVector, &searchSettings);
size_t numberOfGames;
is >> numberOfGames;
TournamentResult tournamentResult = selfPlay.go_arena(mctsAgentContender.get(), numberOfGames, variant);
Expand Down Expand Up @@ -401,12 +399,11 @@ void CrazyAra::multimodel_arena(istringstream &is, const string &modelDirectory1
is >> folder;
modelDir1 = "m" + std::to_string(folder) + "/";
}
auto mcts1 = create_new_mcts_agent(netSingle.get(), netBatches, &searchSettings, static_cast<MCTSAgentType>(type));
auto mcts1 = create_new_mcts_agent(netSingleVector, netBatchesVector, &searchSettings, static_cast<MCTSAgentType>(type));
if (modelDir1 != "")
{
netSingle = create_new_net_single(modelDir1);
netBatches = create_new_net_batches(modelDir1);
mcts1 = create_new_mcts_agent(netSingle.get(), netBatches, &searchSettings, static_cast<MCTSAgentType>(type));
fill_nn_vectors(modelDir1, netSingleVector, netBatchesVector);
mcts1 = create_new_mcts_agent(netSingleVector, netBatchesVector, &searchSettings, static_cast<MCTSAgentType>(type));
}

is >> type;
Expand All @@ -416,12 +413,11 @@ void CrazyAra::multimodel_arena(istringstream &is, const string &modelDirectory1
is >> folder;
modelDir2 = "m" + std::to_string(folder) + "/";
}
auto mcts2 = create_new_mcts_agent(netSingle.get(), netBatches, &searchSettings, static_cast<MCTSAgentType>(type));
auto mcts2 = create_new_mcts_agent(netSingleVector, netBatchesVector, &searchSettings, static_cast<MCTSAgentType>(type));
if (modelDir2 != "")
{
netSingleContender = create_new_net_single(modelDir2);
netBatchesContender = create_new_net_batches(modelDir2);
mcts2 = create_new_mcts_agent(netSingleContender.get(), netBatchesContender, &searchSettings, static_cast<MCTSAgentType>(type));
fill_nn_vectors(modelDir2, netSingleContenderVector, netBatchesContenderVector);
mcts2 = create_new_mcts_agent(netSingleContenderVector, netBatchesContenderVector, &searchSettings, static_cast<MCTSAgentType>(type));
}

SelfPlay selfPlay(rawAgent.get(), mcts1.get(), &searchLimits, &playSettings, &rlSettings, Options);
Expand Down Expand Up @@ -549,13 +545,55 @@ void CrazyAra::init()
#endif
}

void CrazyAra::fill_single_nn_vector(const string& modelDirectory, vector<unique_ptr<NeuralNetAPI>>& netSingleVector, vector<vector<unique_ptr<NeuralNetAPI>>>& netBatchesVector)
{
unique_ptr<NeuralNetAPI> netSingleTmp = create_new_net(modelDirectory, int(Options["First_Device_ID"]), 1);
netSingleTmp->validate_neural_network();
netSingleVector.push_back(std::move(netSingleTmp));

size_t idx = 0;
for (int deviceId = int(Options["First_Device_ID"]); deviceId <= int(Options["Last_Device_ID"]); ++deviceId) {
for (size_t i = 0; i < size_t(Options["Threads"]); ++i) {
unique_ptr<NeuralNetAPI> netBatchesTmp = create_new_net(modelDirectory, deviceId, searchSettings.batchSize);
netBatchesTmp->validate_neural_network();
netBatchesVector[idx].push_back(std::move(netBatchesTmp));
++idx;
}
}
}

void CrazyAra::fill_nn_vectors(const string& modelDirectory, vector<unique_ptr<NeuralNetAPI>>& netSingleVector, vector<vector<unique_ptr<NeuralNetAPI>>>& netBatchesVector)
{
netSingleVector.clear();
netBatchesVector.clear();
// threads is the first dimension, the phase are the 2nd dimension
netBatchesVector.resize(Options["Threads"] * get_num_gpus(Options));

// early return if no phases are used
for (const auto& entry : fs::directory_iterator(modelDirectory)) {
if (!fs::is_directory(entry.path())) {
fill_single_nn_vector(modelDirectory, netSingleVector, netBatchesVector);
return;
}
else {
break;
}
}

// analyse directory to get num phases
for (const auto& entry : fs::directory_iterator(modelDirectory)) {
std::cout << entry.path().generic_string() << std::endl;

fill_single_nn_vector(entry.path().generic_string(), netSingleVector, netBatchesVector);
}
}


template<bool verbose>
bool CrazyAra::is_ready()
{
bool hasReplied = false;
if (!networkLoaded) {
netSingleVector.clear();
netBatchesVector.clear();
const size_t timeoutMS = Options["Timeout_MS"];
TimeOutReadyThread timeoutThread(timeoutMS);
thread tTimeoutThread;
Expand All @@ -568,22 +606,10 @@ bool CrazyAra::is_ready()
init_rl_settings();
#endif

// analyse directory to get num phases
for (const auto& entry : fs::directory_iterator(string(Options["Model_Directory"]))) {
std::cout << entry.path().generic_string() << std::endl;

unique_ptr<NeuralNetAPI> netSingleTmp = create_new_net_single(entry.path().generic_string());
netSingleTmp->validate_neural_network();
vector<unique_ptr<NeuralNetAPI>> netBatchesTmp = create_new_net_batches(entry.path().generic_string());
netBatchesTmp.front()->validate_neural_network();

netSingleVector.push_back(std::move(netSingleTmp));
netBatchesVector.push_back(std::move(netBatchesTmp));
}
fill_nn_vectors(Options["Model_Directory"], netSingleVector, netBatchesVector);

mctsAgent = create_new_mcts_agent(netSingleVector, netBatchesVector, &searchSettings);
//rawAgent = make_unique<RawNetAgent>(netSingleVector, &playSettings, false, &searchSettings);
// TODO: rawAgent currently doesn't work (netSingleVector somehow doesn't contain any nets)
rawAgent = make_unique<RawNetAgent>(netSingleVector, &playSettings, false, &searchSettings);
StateConstants::init(mctsAgent->is_policy_map(), Options["UCI_Chess960"]);

timeoutThread.kill();
Expand Down Expand Up @@ -617,40 +643,21 @@ string CrazyAra::engine_info()
return ss.str();
}

unique_ptr<NeuralNetAPI> CrazyAra::create_new_net_single(const string& modelDirectory)
unique_ptr<NeuralNetAPI> CrazyAra::create_new_net(const string& modelDirectory, int deviceId, unsigned int batchSize)
{
#ifdef MXNET
return make_unique<MXNetAPI>(Options["Context"], int(Options["First_Device_ID"]), 1, modelDirectory, Options["Precision"], false);
#elif defined TENSORRT
return make_unique<TensorrtAPI>(int(Options["First_Device_ID"]), 1, modelDirectory, Options["Precision"]);
#elif defined OPENVINO
return make_unique<OpenVinoAPI>(int(Options["First_Device_ID"]), 1, modelDirectory, Options["Threads_NN_Inference"]);
#endif
return nullptr;
}

vector<unique_ptr<NeuralNetAPI>> CrazyAra::create_new_net_batches(const string& modelDirectory)
{
vector<unique_ptr<NeuralNetAPI>> netBatches;
#ifdef MXNET
#ifdef TENSORRT
const bool useTensorRT = bool(Options["Use_TensorRT"]);
#else
const bool useTensorRT = false;
#endif
return make_unique<MXNetAPI>(Options["Context"], deviceId, batchSize, modelDirectory, Options["Precision"], useTensorRT);
#elif defined TENSORRT
return make_unique<TensorrtAPI>(deviceId, batchSize, modelDirectory, Options["Precision"]);
#elif defined OPENVINO
return make_unique<OpenVinoAPI>(deviceId, batchSize, modelDirectory, Options["Threads_NN_Inference"]);
#endif
for (int deviceId = int(Options["First_Device_ID"]); deviceId <= int(Options["Last_Device_ID"]); ++deviceId) {
for (size_t i = 0; i < size_t(Options["Threads"]); ++i) {
#ifdef MXNET
netBatches.push_back(make_unique<MXNetAPI>(Options["Context"], deviceId, searchSettings.batchSize, modelDirectory, Options["Precision"], useTensorRT));
#elif defined TENSORRT
netBatches.push_back(make_unique<TensorrtAPI>(deviceId, searchSettings.batchSize, modelDirectory, Options["Precision"]));
#elif defined OPENVINO
netBatches.push_back(make_unique<OpenVinoAPI>(deviceId, searchSettings.batchSize, modelDirectory, Options["Threads_NN_Inference"]));
#endif
}
}
return netBatches;
return nullptr;
}

void CrazyAra::set_uci_option(istringstream &is, StateObj& state)
Expand Down
Loading

0 comments on commit 6b43b32

Please sign in to comment.