Skip to content

Commit

Permalink
Support monitoring and configuring draw rate
Browse files Browse the repository at this point in the history
  • Loading branch information
ianfab committed Sep 15, 2024
1 parent 8eecaf0 commit ed0eafd
Showing 1 changed file with 20 additions and 16 deletions.
36 changes: 20 additions & 16 deletions src/tools/training_data_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ namespace Stockfish::Tools

std::string seed;

bool write_out_draw_game_in_training_data_generation = true;
float write_out_draw_game_in_training_data_generation = 1;
bool detect_draw_by_consecutive_low_score = true;
bool detect_draw_by_insufficient_mating_material = true;
bool filter_captures = false;
Expand Down Expand Up @@ -179,6 +179,7 @@ namespace Stockfish::Tools
void generate_worker(
Thread& th,
std::atomic<uint64_t>& counter,
std::atomic<uint64_t>& draw_counter,
uint64_t limit);

bool was_seen_before(const Position& pos);
Expand All @@ -201,12 +202,13 @@ namespace Stockfish::Tools
PSVector& sfens,
int8_t lastTurnIsWin,
std::atomic<uint64_t>& counter,
std::atomic<uint64_t>& draw_counter,
uint64_t limit,
Color result_color);

void report(uint64_t done, uint64_t new_done);
void report(uint64_t done, uint64_t draws, uint64_t new_done);

void maybe_report(uint64_t done);
void maybe_report(uint64_t done, uint64_t draws);
};

void TrainingDataGenerator::set_gensfen_search_limits()
Expand Down Expand Up @@ -235,16 +237,17 @@ namespace Stockfish::Tools
set_gensfen_search_limits();

std::atomic<uint64_t> counter{0};
Threads.execute_with_workers([&counter, limit, this](Thread& th) {
generate_worker(th, counter, limit);
std::atomic<uint64_t> draw_counter{0};
Threads.execute_with_workers([&counter, &draw_counter, limit, this](Thread& th) {
generate_worker(th, counter, draw_counter, limit);
});
Threads.wait_for_workers_finished();

sfen_writer.flush();

if (limit % REPORT_STATS_EVERY != 0)
{
report(limit, limit % REPORT_STATS_EVERY);
report(limit, draw_counter, limit % REPORT_STATS_EVERY);
}

std::cout << std::endl;
Expand All @@ -253,6 +256,7 @@ namespace Stockfish::Tools
void TrainingDataGenerator::generate_worker(
Thread& th,
std::atomic<uint64_t>& counter,
std::atomic<uint64_t>& draw_counter,
uint64_t limit)
{
// For the time being, it will be treated as a draw
Expand Down Expand Up @@ -304,7 +308,7 @@ namespace Stockfish::Tools
vector<int> move_hist_scores;

auto flush_psv = [&](int8_t result) {
quit = commit_psv(th, packed_sfens, result, counter, limit, pos.side_to_move());
quit = commit_psv(th, packed_sfens, result, counter, draw_counter, limit, pos.side_to_move());
};

for (int ply = 0; ; ++ply)
Expand Down Expand Up @@ -661,14 +665,12 @@ namespace Stockfish::Tools
PSVector& sfens,
int8_t result,
std::atomic<uint64_t>& counter,
std::atomic<uint64_t>& draw_counter,
uint64_t limit,
Color result_color)
{
if (!params.write_out_draw_game_in_training_data_generation && result == 0)
{
// We didn't write anything so why quit.
if (float(draw_counter + 1) / (counter + 1) > params.write_out_draw_game_in_training_data_generation && result == 0)
return false;
}

auto side_to_move_from_sfen = [](auto& sfen){
return (Color)(sfen.sfen.data[0] & 1);
Expand All @@ -692,7 +694,8 @@ namespace Stockfish::Tools
return true;

// because `iter` was done, now we do one more
maybe_report(iter + 1);
draw_counter += result == 0;
maybe_report(iter + 1, draw_counter);

// Write out one sfen.
sfen_writer.write(th.id(), sfen);
Expand All @@ -701,7 +704,7 @@ namespace Stockfish::Tools
return false;
}

void TrainingDataGenerator::report(uint64_t done, uint64_t new_done)
void TrainingDataGenerator::report(uint64_t done, uint64_t draws, uint64_t new_done)
{
const auto now_time = now();
const TimePoint elapsed = now_time - last_stats_report_time + 1;
Expand All @@ -710,14 +713,15 @@ namespace Stockfish::Tools
<< endl
<< done << " sfens, "
<< new_done * 1000 / elapsed << " sfens/second, "
<< "draw rate " << draws * 100 / done << "%, "
<< "at " << now_string() << sync_endl;

last_stats_report_time = now_time;

out = sync_region_cout.new_region();
}

void TrainingDataGenerator::maybe_report(uint64_t done)
void TrainingDataGenerator::maybe_report(uint64_t done, uint64_t draws)
{
if (done % REPORT_DOT_EVERY == 0)
{
Expand All @@ -735,7 +739,7 @@ namespace Stockfish::Tools

if (done % REPORT_STATS_EVERY == 0)
{
report(done, REPORT_STATS_EVERY);
report(done, draws, REPORT_STATS_EVERY);
}
}
}
Expand Down Expand Up @@ -892,7 +896,7 @@ namespace Stockfish::Tools
<< " - output_file_name = " << params.output_file_name << endl
<< " - save_every = " << params.save_every << endl
<< " - random_file_name = " << random_file_name << endl
<< " - write_drawn_games = " << params.write_out_draw_game_in_training_data_generation << endl
<< " - keep_draws = " << params.write_out_draw_game_in_training_data_generation << endl
<< " - draw by low score = " << params.detect_draw_by_consecutive_low_score << endl
<< " - draw by insuff. mat. = " << params.detect_draw_by_insufficient_mating_material << endl
<< " - filter_captures = " << params.filter_captures << endl
Expand Down

0 comments on commit ed0eafd

Please sign in to comment.