Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the sharpness limit in WDLRescale configurable, and fix the Elo --> Contempt calculation #1941

Merged
merged 14 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion src/mcts/params.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,21 @@ SearchParams::WDLRescaleParams AccurateWDLRescaleParams(
return SearchParams::WDLRescaleParams(ratio, diff);
}

// Converts regular Elo into ideal UHO game pair Elo based on the same Elo
// dependent draw rate model used below. Necessary because regular Elo doesn't
// behave well at higher level, while the ideal UHO game pair Elo calculated
// from the decisive game pair ratio underestimates Elo differences by a
// factor of 2 at lower levels.

float ConvertRegularToGamePairElo(float elo_regular) {
const float transition_sharpness = 250.0f;
const float transition_midpoint = 2737.0f;
return elo_regular +
0.5f * transition_sharpness *
std::log(1.0f + std::exp((transition_midpoint - elo_regular) /
transition_sharpness));
}

// Calculate ratio and diff for WDL conversion from the contempt settings.
// Less accurate Elo model, but automatically chooses draw rate and accuracy
// based on the absolute Elo of both sides. Doesn't require clamping, but still
Expand All @@ -129,6 +144,10 @@ SearchParams::WDLRescaleParams SimplifiedWDLRescaleParams(
(1.0f - draw_rate_reference));
float elo_opp =
elo_active - std::clamp(contempt, -contempt_max, contempt_max);
// Convert regular Elo input into internally used game pair Elo.
elo_active = ConvertRegularToGamePairElo(elo_active);
elo_opp = ConvertRegularToGamePairElo(elo_opp);
// Estimate draw rate from given Elo.
float scale_active =
1.0f / (1.0f / scale_zero + std::exp(elo_active / elo_slope - offset));
float scale_opp =
Expand All @@ -144,7 +163,8 @@ SearchParams::WDLRescaleParams SimplifiedWDLRescaleParams(
float mu_opp =
-std::log(10) / 200 * scale_zero * elo_slope *
std::log(1.0f + std::exp(-elo_opp / elo_slope + offset) / scale_zero);
float diff = (mu_active - mu_opp) * contempt_attenuation;
float diff = 1.0f / (scale_reference * scale_reference) *
(mu_active - mu_opp) * contempt_attenuation;
return SearchParams::WDLRescaleParams(ratio, diff);
}
} // namespace
Expand Down Expand Up @@ -375,6 +395,12 @@ const OptionId SearchParams::kWDLContemptAttenuationId{
"wdl-contempt-attenuation", "WDLContemptAttenuation",
"Scales how Elo advantage is applied for contempt. Use 1.0 for realistic "
"analysis, and 0.5-0.6 for optimal match performance."};
const OptionId SearchParams::kWDLMaxSId{
"wdl-max-s", "WDLMaxS",
"Limits the WDL derived sharpness s to a reasonable value to avoid "
"erratic behavior at high contempt values. Default recommended for "
"regular chess, increase value for more volatile positions like DFRC "
"or piece odds."};
const OptionId SearchParams::kWDLEvalObjectivityId{
"wdl-eval-objectivity", "WDLEvalObjectivity",
"When calculating the centipawn eval output, decides how objective/"
Expand Down Expand Up @@ -531,6 +557,7 @@ void SearchParams::Populate(OptionsParser* options) {
options->Add<FloatOption>(kContemptMaxValueId, 0, 10000.0f) = 420.0f;
options->Add<FloatOption>(kWDLCalibrationEloId, 0, 10000.0f) = 0.0f;
options->Add<FloatOption>(kWDLContemptAttenuationId, -10.0f, 10.0f) = 1.0f;
options->Add<FloatOption>(kWDLMaxSId, 0.0f, 10.0f) = 1.4f;
options->Add<FloatOption>(kWDLEvalObjectivityId, 0.0f, 1.0f) = 1.0f;
options->Add<FloatOption>(kWDLDrawRateTargetId, 0.001f, 0.999f) = 0.5f;
options->Add<FloatOption>(kWDLDrawRateReferenceId, 0.001f, 0.999f) = 0.5f;
Expand Down Expand Up @@ -568,6 +595,7 @@ void SearchParams::Populate(OptionsParser* options) {
options->HideOption(kTemperatureVisitOffsetId);
options->HideOption(kContemptMaxValueId);
options->HideOption(kWDLContemptAttenuationId);
options->HideOption(kWDLMaxSId);
options->HideOption(kWDLDrawRateTargetId);
options->HideOption(kWDLBookExitBiasId);
}
Expand Down Expand Up @@ -633,6 +661,7 @@ SearchParams::SearchParams(const OptionsDict& options)
options.Get<float>(kWDLCalibrationEloId),
options.Get<float>(kContemptMaxValueId),
options.Get<float>(kWDLContemptAttenuationId))),
kWDLMaxS(options.Get<float>(kWDLMaxSId)),
kWDLEvalObjectivity(options.Get<float>(kWDLEvalObjectivityId)),
kMaxOutOfOrderEvalsFactor(
options.Get<float>(kMaxOutOfOrderEvalsFactorId)),
Expand Down
3 changes: 3 additions & 0 deletions src/mcts/params.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ class SearchParams {
}
float GetWDLRescaleRatio() const { return kWDLRescaleParams.ratio; }
float GetWDLRescaleDiff() const { return kWDLRescaleParams.diff; }
float GetWDLMaxS() const { return kWDLMaxS; }
float GetWDLEvalObjectivity() const { return kWDLEvalObjectivity; }
float GetMaxOutOfOrderEvalsFactor() const {
return kMaxOutOfOrderEvalsFactor;
Expand Down Expand Up @@ -213,6 +214,7 @@ class SearchParams {
static const OptionId kContemptMaxValueId;
static const OptionId kWDLCalibrationEloId;
static const OptionId kWDLContemptAttenuationId;
static const OptionId kWDLMaxSId;
static const OptionId kWDLEvalObjectivityId;
static const OptionId kWDLDrawRateTargetId;
static const OptionId kWDLDrawRateReferenceId;
Expand Down Expand Up @@ -275,6 +277,7 @@ class SearchParams {
const float kDrawScore;
const float kContempt;
const WDLRescaleParams kWDLRescaleParams;
const float kWDLMaxS;
const float kWDLEvalObjectivity;
const float kMaxOutOfOrderEvalsFactor;
const float kNpsLimit;
Expand Down
26 changes: 11 additions & 15 deletions src/mcts/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ void ApplyDirichletNoise(Node* node, float eps, double alpha) {
namespace {
// WDL conversion formula based on random walk model.
inline double WDLRescale(float& v, float& d, float wdl_rescale_ratio,
float wdl_rescale_diff, float sign, bool invert) {
float wdl_rescale_diff, float sign, bool invert,
float max_reasonable_s) {
if (invert) {
wdl_rescale_diff = -wdl_rescale_diff;
wdl_rescale_ratio = 1.0f / wdl_rescale_ratio;
Expand All @@ -238,8 +239,7 @@ inline double WDLRescale(float& v, float& d, float wdl_rescale_ratio,
auto b = FastLog(1 / w - 1);
auto s = 2 / (a + b);
// Safeguard against unrealistically broad WDL distributions coming from
// the NN. Could be made into a parameter, but probably unnecessary.
const float max_reasonable_s = 1.4f;
// the NN. Originally hardcoded, made into a parameter for piece odds.
if (!invert) s = std::min(max_reasonable_s, s);
auto mu = (a - b) / (a + b);
auto s_new = s * wdl_rescale_ratio;
Expand Down Expand Up @@ -313,7 +313,7 @@ void Search::SendUciInfo() REQUIRES(nodes_mutex_) REQUIRES(counters_mutex_) {
contempt_mode_ == ContemptMode::NONE
? 0
: params_.GetWDLRescaleDiff() * params_.GetWDLEvalObjectivity(),
sign, true);
sign, true, params_.GetWDLMaxS());
}
const auto q = edge.GetQ(default_q, draw_score);
if (edge.IsTerminal() && wl != 0.0f) {
Expand Down Expand Up @@ -523,13 +523,10 @@ std::vector<std::string> Search::GetVerboseStats(Node* node) const {
up = -up;
std::swap(lo, up);
}
*oss << (lo == up
? "(T) "
: lo == GameResult::DRAW && up == GameResult::WHITE_WON
? "(W) "
: lo == GameResult::BLACK_WON && up == GameResult::DRAW
? "(L) "
: "");
*oss << (lo == up ? "(T) "
: lo == GameResult::DRAW && up == GameResult::WHITE_WON ? "(W) "
: lo == GameResult::BLACK_WON && up == GameResult::DRAW ? "(L) "
: "");
}
};

Expand Down Expand Up @@ -1315,9 +1312,8 @@ void SearchWorker::GatherMinibatch() {
// massive nps drop.
if (thread_count > 1 && minibatch_size > 0 &&
computation_->GetCacheMisses() > params_.GetIdlingMinimumWork() &&
thread_count -
search_->backend_waiting_counter_.load(
std::memory_order_relaxed) >
thread_count - search_->backend_waiting_counter_.load(
std::memory_order_relaxed) >
params_.GetThreadIdlingThreshold()) {
return;
}
Expand Down Expand Up @@ -2221,7 +2217,7 @@ void SearchWorker::FetchSingleNodeResult(NodeToProcess* node_to_process,
search_->contempt_mode_ == ContemptMode::NONE
? 0
: params_.GetWDLRescaleDiff(),
sign, false);
sign, false, params_.GetWDLMaxS());
}
node_to_process->v = v;
node_to_process->d = d;
Expand Down