
RL pytorch fix (#222)
* Fix runtime errors in rl_loop.py

* Add save_cur_phase(const StateObj *pos) and all changes needed for it

* Update the SelfPlay() constructor and make the settings const

* Add const to get_num_phases() and remove const from SearchLimits

* Add save_cur_phase() to the header and fix the SelfPlay() constructor declaration

* Add const to rlSettings

* Add missing ) in traindataexporter.cpp

* Add phaseVector export

* Add save_cur_phase(pos); to traindataexporter.cpp

* Update rl_training.py: adjust the get_validation_data() call

* Update trainer_agent_pytorch.py: check that the path to delete is a file before removing it
QueensGambit authored Jan 10, 2025
1 parent c045744 commit 6197410
Showing 10 changed files with 65 additions and 17 deletions.
3 changes: 2 additions & 1 deletion DeepCrazyhouse/src/training/trainer_agent_pytorch.py
@@ -274,7 +274,8 @@ def delete_previous_weights(self):
# delete previous weights to save space
files = glob.glob(self.tc.export_dir + 'weights/*')
for f in files:
- os.remove(f)
+ if os.path.isfile(f):
+     os.remove(f)

def _get_train_loader(self, part_id):
# load one chunk of the dataset from memory
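
For context on the trainer_agent_pytorch.py change above: os.remove() raises an OSError when it is handed a directory, so any sub-directory inside weights/ would previously abort the cleanup. A minimal standalone sketch of the patched behaviour (the free function and its export_dir argument are illustrative, not the project's exact API):

    import glob
    import os

    def delete_previous_weights(export_dir):
        """Delete previous weight files, skipping anything that is not a regular file."""
        for f in glob.glob(os.path.join(export_dir, 'weights', '*')):
            if os.path.isfile(f):  # a sub-directory would make os.remove() raise OSError
                os.remove(f)
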
5 changes: 5 additions & 0 deletions engine/src/nn/neuralnetapiuser.cpp
@@ -108,3 +108,8 @@ void NeuralNetAPIUser::run_inference(uint_fast16_t iterations)
}
}

+ unsigned int NeuralNetAPIUser::get_num_phases() const
+ {
+ return numPhases;
+ }

6 changes: 6 additions & 0 deletions engine/src/nn/neuralnetapiuser.h
@@ -60,6 +60,12 @@ class NeuralNetAPIUser
* @param iterations Number of iterations to run
*/
void run_inference(uint_fast16_t iterations);

+ /**
+ * @brief get_num_phases Returns the number of phases
+ * @return numPhases
+ */
+ unsigned int get_num_phases() const;
};

#endif // NEURALNETAPIUSER_H
6 changes: 3 additions & 3 deletions engine/src/rl/rl_loop.py
@@ -49,7 +49,7 @@ def __init__(self, args, rl_config, nb_arena_games=100, lr_reduction=0.0001, k_s
self.rl_config = rl_config

self.file_io = FileIO(orig_binary_name=self.rl_config.binary_name, binary_dir=self.rl_config.binary_dir,
- uci_variant=self.rl_config.uci_variant, framework=self.tc.framework)
+ uci_variant=self.rl_config.uci_variant)
self.binary_io = None

if nb_arena_games % 2 == 1:
@@ -85,7 +85,7 @@ def initialize(self, is_arena=False):
is_arena: Signals that UCI option should be set for arena comparison
:return:
"""
- self.model_name = self.file_io.get_current_model_weight_file()
+ self.model_name = self.file_io.get_current_model_tar_file()
self.binary_io = BinaryIO(binary_path=self.file_io.binary_dir+self.current_binary_name)
self.binary_io.set_uci_options(self.rl_config.uci_variant, self.args.context, self.args.device_id,
self.rl_config.precision, self.file_io.model_dir,
@@ -105,7 +105,7 @@ def check_for_new_model(self):
self.nn_update_index = extract_nn_update_idx_from_binary_name(self.current_binary_name)

# If a new model is available, the binary name has also changed
- model_name = self.file_io.get_current_model_weight_file()
+ model_name = self.file_io.get_current_model_tar_file()
if model_name != "" and model_name != self.model_name:
logging.info("Loading new model: %s" % model_name)
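
Both rl_loop.py call sites above now ask the FileIO helper for the current .tar checkpoint instead of a weight file, matching the PyTorch training backend. Purely as an illustration (the real get_current_model_tar_file() lives in FileIO and may be implemented differently), such a lookup can be sketched as:

    import glob
    import os

    def get_current_model_tar_file(model_dir):
        """Illustrative stand-in: return the newest *.tar checkpoint in model_dir, or "" if none exists."""
        tar_files = glob.glob(os.path.join(model_dir, "*.tar"))
        return max(tar_files, key=os.path.getmtime) if tar_files else ""

The empty-string fallback mirrors the model_name != "" check in check_for_new_model().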

2 changes: 1 addition & 1 deletion engine/src/rl/rl_training.py
@@ -49,7 +49,7 @@ def update_network(queue, nn_update_idx: int, tar_filename: Path, convert_to_onn
raise Exception('No .zip files for training available. Check the path in main_config["planes_train_dir"]:'
' %s' % main_config["planes_train_dir"])

- val_data, x_val, _ = get_validation_data(train_config)
+ val_data, x_val = get_validation_data(train_config)

input_shape = x_val[0].shape
# calculate how many iterations per epoch exist
8 changes: 5 additions & 3 deletions engine/src/rl/selfplay.cpp
@@ -79,9 +79,9 @@ string load_random_fen(string filepath)
}


- SelfPlay::SelfPlay(RawNetAgent* rawAgent, MCTSAgent* mctsAgent, SearchLimits* searchLimits, PlaySettings* playSettings,
- RLSettings* rlSettings, OptionsMap& options):
- rawAgent(rawAgent), mctsAgent(mctsAgent), searchLimits(searchLimits), playSettings(playSettings),
+ SelfPlay::SelfPlay(RawNetAgent* rawAgent, MCTSAgent* mctsAgent, const SearchSettings* searchSettings, SearchLimits* searchLimits, const PlaySettings* playSettings,
+ const RLSettings* rlSettings, OptionsMap& options):
+ rawAgent(rawAgent), mctsAgent(mctsAgent), searchSettings(searchSettings), searchLimits(searchLimits), playSettings(playSettings),
rlSettings(rlSettings), gameIdx(0), gamesPerMin(0), samplesPerMin(0), options(options)
{
is960 = options["UCI_Chess960"];
@@ -113,6 +113,8 @@ SelfPlay::SelfPlay(RawNetAgent* rawAgent, MCTSAgent* mctsAgent, SearchLimits* se
gamePGN.round = "?";
gamePGN.is960 = is960;
this->exporter = new TrainDataExporter(string("data_") + mctsAgent->get_device_name() + string(".zarr"),
+ mctsAgent->get_num_phases(),
+ searchSettings->gamePhaseDefinition,
rlSettings->numberChunks, rlSettings->chunkSize);
filenamePGNSelfplay = string("games_") + mctsAgent->get_device_name() + string(".pgn");
filenamePGNArena = string("arena_games_")+ mctsAgent->get_device_name() + string(".pgn");
10 changes: 6 additions & 4 deletions engine/src/rl/selfplay.h
@@ -68,9 +68,10 @@ class SelfPlay
private:
RawNetAgent* rawAgent;
MCTSAgent* mctsAgent;
+ const SearchSettings* searchSettings;
SearchLimits* searchLimits;
- PlaySettings* playSettings;
- RLSettings* rlSettings;
+ const PlaySettings* playSettings;
+ const RLSettings* rlSettings;
OptionsMap& options;
GamePGN gamePGN;
TrainDataExporter* exporter;
@@ -90,13 +91,14 @@
* @brief SelfPlay
* @param rawAgent Raw network agent which uses the raw network policy for e.g. game initialization
* @param mctsAgent MCTSAgent which is used during selfplay for game generation
+ * @param searchSettings Search settings configuration struct
* @param searchLimits Search limit configuration struct
* @param playSettings Playing setting configuration struct
* @param RLSettings Additional settings for reinforcement learning usage
* @param options Object holding all UCI options
*/
- SelfPlay(RawNetAgent* rawAgent, MCTSAgent* mctsAgent, SearchLimits* searchLimits, PlaySettings* playSettings,
- RLSettings* rlSettings, OptionsMap& options);
+ SelfPlay(RawNetAgent* rawAgent, MCTSAgent* mctsAgent, const SearchSettings* searchSettings, SearchLimits* searchLimits, const PlaySettings* playSettings,
+ const RLSettings* rlSettings, OptionsMap& options);
~SelfPlay();

/**
22 changes: 21 additions & 1 deletion engine/src/rl/traindataexporter.cpp
@@ -40,6 +40,7 @@ void TrainDataExporter::save_sample(const StateObj* pos, const EvalInfo& eval)
save_best_move_q(eval);
save_side_to_move(Color(pos->side_to_move()));
save_cur_sample_index();
+ save_cur_phase(pos);
++curSampleIdx;
// value will be set later in export_game_result()
firstMove = false;
@@ -87,6 +88,20 @@ void TrainDataExporter::save_cur_sample_index()
}
}

+ void TrainDataExporter::save_cur_phase(const StateObj* pos)
+ {
+ // curGamePhase, starting from 0
+ xt::xarray<int16_t> phaseArray({ 1 }, pos->get_phase(numPhases, gamePhaseDefinition));
+
+ if (firstMove) {
+ gamePhaseVector = phaseArray;
+ }
+ else {
+ // concatenate the sample to array for the current game
+ gamePhaseVector = xt::concatenate(xtuple(gamePhaseVector, phaseArray));
+ }
+ }

void TrainDataExporter::export_game_samples(Result result) {
if (startIdx >= numberSamples) {
info_string("Extended number of maximum samples");
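
The save_cur_phase() implementation above follows the same accumulation pattern as the other per-game arrays: the first sample of a game initialises gamePhaseVector, and every later sample is concatenated onto it. A NumPy analogue of that pattern, purely illustrative and not project code:

    import numpy as np

    game_phase_vector = None

    def save_cur_phase(phase_id, first_move):
        """Append one phase id (int16) to the per-game vector, mirroring the xtensor code above."""
        global game_phase_vector
        sample = np.array([phase_id], dtype=np.int16)
        if first_move:
            game_phase_vector = sample  # start a fresh vector for a new game
        else:
            game_phase_vector = np.concatenate((game_phase_vector, sample))
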
@@ -106,13 +121,16 @@
z5::types::ShapeType offsetPolicy = { startIdx, 0 };
z5::multiarray::writeSubarray<float>(dPolicy, gamePolicy, offsetPolicy.begin());
z5::multiarray::writeSubarray<int16_t>(dPlysToEnd, gamePlysToEnd, offset.begin());
+ z5::multiarray::writeSubarray<int16_t>(dPhaseVector, gamePhaseVector, offset.begin());

startIdx += curSampleIdx;
gameIdx++;
save_start_idx();
}

- TrainDataExporter::TrainDataExporter(const string& fileName, size_t numberChunks, size_t chunkSize):
+ TrainDataExporter::TrainDataExporter(const string& fileName, unsigned int numPhases, GamePhaseDefinition gamePhaseDefinition, size_t numberChunks, size_t chunkSize):
+ numPhases(numPhases),
+ gamePhaseDefinition(gamePhaseDefinition),
numberChunks(numberChunks),
chunkSize(chunkSize),
numberSamples(numberChunks * chunkSize),
@@ -214,6 +232,7 @@ void TrainDataExporter::open_dataset_from_file(const z5::filesystem::handle::Fil
dPolicy = z5::openDataset(file, "y_policy");
dbestMoveQ = z5::openDataset(file, "y_best_move_q");
dPlysToEnd = z5::openDataset(file, "plys_to_end");
+ dPhaseVector = z5::openDataset(file, "phase_vector");
}

void TrainDataExporter::create_new_dataset_file(const z5::filesystem::handle::File &file)
@@ -231,6 +250,7 @@ void TrainDataExporter::create_new_dataset_file(const z5::filesystem::handle::Fi
dPolicy = z5::createDataset(file, "y_policy", "float32", { numberSamples, StateConstants::NB_LABELS() }, { chunkSize, StateConstants::NB_LABELS() });
dbestMoveQ = z5::createDataset(file, "y_best_move_q", "float32", { numberSamples }, { chunkSize });
dPlysToEnd = z5::createDataset(file, "plys_to_end", "int16", { numberSamples }, { chunkSize });
+ dPhaseVector = z5::createDataset(file, "phase_vector", "int16", { numberSamples }, { chunkSize });

save_start_idx();
}
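
Because the new phase_vector dataset is created with the same length and chunking as plys_to_end, it can be inspected alongside the other training arrays. A read-back sketch, assuming the zarr Python package can open the z5-written store and that the file is named data_<device>.zarr as in selfplay.cpp (the concrete file name below is illustrative):

    import zarr

    store = zarr.open("data_gpu_0.zarr", mode="r")  # illustrative file name
    phases = store["phase_vector"][:]  # int16, one game-phase id per exported sample
    plys = store["plys_to_end"][:]     # same shape and chunking as phase_vector
    print(phases.shape, phases.min(), phases.max())
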
14 changes: 13 additions & 1 deletion engine/src/rl/traindataexporter.h
@@ -48,6 +48,8 @@
class TrainDataExporter
{
private:
+ unsigned int numPhases;
+ GamePhaseDefinition gamePhaseDefinition;
size_t numberChunks;
size_t chunkSize;
size_t numberSamples;
@@ -57,12 +59,14 @@ class TrainDataExporter
std::unique_ptr<z5::Dataset> dPolicy;
std::unique_ptr<z5::Dataset> dbestMoveQ;
std::unique_ptr<z5::Dataset> dPlysToEnd;
+ std::unique_ptr<z5::Dataset> dPhaseVector;

xt::xarray<int16_t> gameX;
xt::xarray<int16_t> gameValue;
xt::xarray<float> gamePolicy;
xt::xarray<float> gameBestMoveQ;
xt::xarray<int16_t> gamePlysToEnd;
+ xt::xarray<int16_t> gamePhaseVector;
bool firstMove;

// current number of games - 1
@@ -106,6 +110,12 @@ class TrainDataExporter
*/
void save_cur_sample_index();

+ /**
+ * @brief save_cur_phase Saves the current phase id for the current position.
+ * @param pos Current position
+ */
+ void save_cur_phase(const StateObj* pos);

/**
* @brief save_start_idx Saves the current starting index where the next game starts to the game array
*/
@@ -140,11 +150,13 @@ class TrainDataExporter
/**
* @brief TrainDataExporter
* @param fileNameExport File name of the uncompressed data to be exported in (e.g. "data.zarr")
+ * @param numPhases Number of game phases to support for exporting
+ * @param gamePhaseDefinition Game phase definition to use
* @param numberChunks Defines how many chunks a single file should contain.
* The product of the number of chunks and its chunk size yields the total number of samples of a file.
* @param chunkSize Defines the chunk size of a single chunk
*/
- TrainDataExporter(const string& fileNameExport, size_t numberChunks=200, size_t chunkSize=128);
+ TrainDataExporter(const string& fileNameExport, unsigned int numPhases, GamePhaseDefinition gamePhaseDefinition, size_t numberChunks=200, size_t chunkSize=128);

/**
* @brief export_pos Saves a given board position, policy and Q-value to the specific game arrays
6 changes: 3 additions & 3 deletions engine/src/uci/crazyara.cpp
@@ -356,7 +356,7 @@ void CrazyAra::activeuci()
void CrazyAra::selfplay(istringstream &is)
{
prepare_search_config_structs();
- SelfPlay selfPlay(rawAgent.get(), mctsAgent.get(), &searchLimits, &playSettings, &rlSettings, Options);
+ SelfPlay selfPlay(rawAgent.get(), mctsAgent.get(), &searchSettings, &searchLimits, &playSettings, &rlSettings, Options);
size_t numberOfGames;
is >> numberOfGames;
selfPlay.go(numberOfGames, variant);
@@ -366,7 +366,7 @@ void CrazyAra::selfplay(istringstream &is)
void CrazyAra::arena(istringstream &is)
{
prepare_search_config_structs();
- SelfPlay selfPlay(rawAgent.get(), mctsAgent.get(), &searchLimits, &playSettings, &rlSettings, Options);
+ SelfPlay selfPlay(rawAgent.get(), mctsAgent.get(), &searchSettings, &searchLimits, &playSettings, &rlSettings, Options);
fill_nn_vectors(Options["Model_Directory_Contender"], netSingleContenderVector, netBatchesContenderVector);
mctsAgentContender = create_new_mcts_agent(netSingleContenderVector, netBatchesContenderVector, &searchSettings);
size_t numberOfGames;
@@ -420,7 +420,7 @@ void CrazyAra::multimodel_arena(istringstream &is, const string &modelDirectory1
mcts2 = create_new_mcts_agent(netSingleContenderVector, netBatchesContenderVector, &searchSettings, static_cast<MCTSAgentType>(type));
}

- SelfPlay selfPlay(rawAgent.get(), mcts1.get(), &searchLimits, &playSettings, &rlSettings, Options);
+ SelfPlay selfPlay(rawAgent.get(), mcts1.get(), &searchSettings, &searchLimits, &playSettings, &rlSettings, Options);
size_t numberOfGames;
is >> numberOfGames;
TournamentResult tournamentResult = selfPlay.go_arena(mcts2.get(), numberOfGames, variant);
