Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix 960 initialization problem #207

Merged
merged 2 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions engine/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ if (MODE_CHESS)
# add_definitions(-DVERSION=2)
# add_definitions(-DSUB_VERSION=8)
add_definitions(-DVERSION=3)
# add_definitions(-DSUPPORT960)
add_definitions(-DSUPPORT960)
endif()

if (MODE_LICHESS)
Expand All @@ -57,7 +57,7 @@ if (MODE_LICHESS)
add_definitions(-DATOMIC)
add_definitions(-DHORDE)
add_definitions(-DRACE)
# add_definitions(-DSUPPORT960)
add_definitions(-DSUPPORT960)
add_definitions(-DMCTS_TB_SUPPORT)
add_definitions(-DVERSION=1)
endif()
Expand Down
4 changes: 2 additions & 2 deletions engine/src/environments/chess_related/boardstate.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ class StateConstantsBoard : public StateConstantsInterface<StateConstantsBoard>
return OutputRepresentation::MV_LOOKUP[action];
}
}
static void init(bool isPolicyMap) {
static void init(bool isPolicyMap, bool is960) {
OutputRepresentation::init_labels();
OutputRepresentation::init_policy_constants(isPolicyMap);
OutputRepresentation::init_policy_constants(isPolicyMap, is960);
}
// -------------------------------------------------
// | Additional custom methods |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,8 @@ void apply_softmax(DynamicVector<float> &policyProbSmall)
policyProbSmall = softmax(policyProbSmall);
}

void OutputRepresentation::init_policy_constants(bool isPolicyMap)
void OutputRepresentation::init_policy_constants(bool isPolicyMap, bool is960)
{
#ifdef SUPPORT960
const bool is960 = true;
#else
const bool is960 = false;
#endif
// fill mirrored label list and look-up table
for (size_t mvIdx = 0; mvIdx < StateConstants::NB_LABELS(); mvIdx++) {
LABELS_MIRRORED[mvIdx] = mirror_move(LABELS[mvIdx]);
Expand Down
3 changes: 2 additions & 1 deletion engine/src/environments/chess_related/outputrepresentation.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ struct OutputRepresentation{
/**
* @brief init_policy_constants Fills the hash maps for a action to nn index binding.
* @param isPolicyMap describes if a policy map head is used for the NN.
* @param is960 defines if 960 variant should be supported
*/
static void init_policy_constants(bool isPolicyMap);
static void init_policy_constants(bool isPolicyMap, bool is960);


};
Expand Down
2 changes: 1 addition & 1 deletion engine/src/environments/fairy_state/fairystate.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class StateConstantsFairy : public StateConstantsInterface<StateConstantsFairy>
: string{char('a' + file_of(to)), '1', '0'};
return fromSquare + toSquare;
}
static void init(bool isPolicyMap) {
static void init(bool isPolicyMap, bool is960) {
FairyOutputRepresentation::init_labels();
FairyOutputRepresentation::init_policy_constants(isPolicyMap);
}
Expand Down
2 changes: 1 addition & 1 deletion engine/src/environments/open_spiel/openspielstate.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class StateConstantsOpenSpiel : public StateConstantsInterface<StateConstantsOpe
static MoveIdx action_to_index(Action action) {
return action; // TODO
}
static void init(bool isPolicyMap) {
static void init(bool isPolicyMap, bool is960) {
return; // pass
}

Expand Down
2 changes: 1 addition & 1 deletion engine/src/environments/stratego_related/strategostate.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class StateConstantsStratego : public StateConstantsInterface<StateConstantsStra
static MoveIdx action_to_index(Action action) {
return action; // TODO
}
static void init(bool isPolicyMap) {
static void init(bool isPolicyMap, bool is960) {
return; // pass
}
static std::vector<std::string> available_variants() {
Expand Down
5 changes: 3 additions & 2 deletions engine/src/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,10 @@ class StateConstantsInterface
/**
* @brief init Init function which is called after a neural network has been loaded and can be used to initialize static variables.
* @param isPolicyMap Boolean indicating if the neural network uses a policy map representation
* @param is960 Boolean indicating if the 960 variant shall be used for initialization
*/
static void init(bool isPolicyMap) {
return T::init(isPolicyMap);
static void init(bool isPolicyMap, bool is960) {
return T::init(isPolicyMap, is960);
}

/**
Expand Down
18 changes: 12 additions & 6 deletions engine/src/uci/crazyara.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,7 @@ CrazyAra::CrazyAra():
useRawNetwork(false), // will be initialized in init_search_settings()
networkLoaded(false),
ongoingSearch(false),
#ifdef SUPPORT960
is960(true),
#else
is960(false),
#endif
changedUCIoption(false)
{
}
Expand Down Expand Up @@ -577,7 +573,7 @@ bool CrazyAra::is_ready()
netBatches.front()->validate_neural_network();
mctsAgent = create_new_mcts_agent(netSingle.get(), netBatches, &searchSettings);
rawAgent = make_unique<RawNetAgent>(netSingle.get(), &playSettings, false);
StateConstants::init(mctsAgent->is_policy_map());
StateConstants::init(mctsAgent->is_policy_map(), is960);
timeoutThread.kill();
if (timeoutMS != 0) {
tTimeoutThread.join();
Expand Down Expand Up @@ -653,12 +649,22 @@ void CrazyAra::set_uci_option(istringstream &is, StateObj& state)
const string prevUciVariant = Options["UCI_Variant"];
const int prevFirstDeviceID = Options["First_Device_ID"];
const int prevLastDeviceID = Options["Last_Device_ID"];
#ifdef SUPPORT960
const bool prevIs960 = Options["UCI_Chess960"];
#else
const bool prevIs960 = is960;
#endif

OptionsUCI::setoption(is, variant, state);
#ifdef SUPPORT960
const bool curIs960 = Options["UCI_Chess960"];
#else
const bool curIs960 = is960;
#endif
changedUCIoption = true;
if (networkLoaded) {
if (string(Options["Model_Directory"]) != prevModelDir || int(Options["Threads"]) != prevThreads || string(Options["UCI_Variant"]) != prevUciVariant ||
int(Options["First_Device_ID"]) != prevFirstDeviceID || int(Options["Last_Device_ID"] != prevLastDeviceID)) {
int(Options["First_Device_ID"]) != prevFirstDeviceID || int(Options["Last_Device_ID"] != prevLastDeviceID) || prevIs960 != curIs960) {
networkLoaded = false;
is_ready<false>();
}
Expand Down
2 changes: 1 addition & 1 deletion engine/src/uci/crazyara.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class CrazyAra
bool networkLoaded;
bool ongoingSearch;
bool is960;
bool changedUCIoption = false;
bool changedUCIoption;

public:
CrazyAra();
Expand Down
2 changes: 1 addition & 1 deletion engine/tests/tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ TEST_CASE("6-Men WDL"){
#endif

TEST_CASE("LABELS length"){
StateConstants::init(true);
StateConstants::init(true, false);
REQUIRE(OutputRepresentation::LABELS.size() == size_t(StateConstants::NB_LABELS()));
REQUIRE(OutputRepresentation::LABELS_MIRRORED.size() == size_t(StateConstants::NB_LABELS()));
}
Expand Down