diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index d50e5993d..1762a8eb7 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -43,7 +43,7 @@ if (MODE_CHESS) # add_definitions(-DVERSION=2) # add_definitions(-DSUB_VERSION=8) add_definitions(-DVERSION=3) -# add_definitions(-DSUPPORT960) + add_definitions(-DSUPPORT960) endif() if (MODE_LICHESS) @@ -57,7 +57,7 @@ if (MODE_LICHESS) add_definitions(-DATOMIC) add_definitions(-DHORDE) add_definitions(-DRACE) -# add_definitions(-DSUPPORT960) + add_definitions(-DSUPPORT960) add_definitions(-DMCTS_TB_SUPPORT) add_definitions(-DVERSION=1) endif() diff --git a/engine/src/environments/chess_related/boardstate.h b/engine/src/environments/chess_related/boardstate.h index df16e4ee0..1f5b28ff9 100644 --- a/engine/src/environments/chess_related/boardstate.h +++ b/engine/src/environments/chess_related/boardstate.h @@ -95,9 +95,9 @@ class StateConstantsBoard : public StateConstantsInterface return OutputRepresentation::MV_LOOKUP[action]; } } - static void init(bool isPolicyMap) { + static void init(bool isPolicyMap, bool is960) { OutputRepresentation::init_labels(); - OutputRepresentation::init_policy_constants(isPolicyMap); + OutputRepresentation::init_policy_constants(isPolicyMap, is960); } // ------------------------------------------------- // | Additional custom methods | diff --git a/engine/src/environments/chess_related/outputrepresentation.cpp b/engine/src/environments/chess_related/outputrepresentation.cpp index 5bd9a8fb1..79302d97e 100644 --- a/engine/src/environments/chess_related/outputrepresentation.cpp +++ b/engine/src/environments/chess_related/outputrepresentation.cpp @@ -36,13 +36,8 @@ void apply_softmax(DynamicVector &policyProbSmall) policyProbSmall = softmax(policyProbSmall); } -void OutputRepresentation::init_policy_constants(bool isPolicyMap) +void OutputRepresentation::init_policy_constants(bool isPolicyMap, bool is960) { -#ifdef SUPPORT960 - const bool is960 = true; -#else - const bool is960 = false; -#endif // fill mirrored label list and look-up table for (size_t mvIdx = 0; mvIdx < StateConstants::NB_LABELS(); mvIdx++) { LABELS_MIRRORED[mvIdx] = mirror_move(LABELS[mvIdx]); diff --git a/engine/src/environments/chess_related/outputrepresentation.h b/engine/src/environments/chess_related/outputrepresentation.h index 766189c76..f8f37072f 100644 --- a/engine/src/environments/chess_related/outputrepresentation.h +++ b/engine/src/environments/chess_related/outputrepresentation.h @@ -87,8 +87,9 @@ struct OutputRepresentation{ /** * @brief init_policy_constants Fills the hash maps for a action to nn index binding. * @param isPolicyMap describes if a policy map head is used for the NN. + * @param is960 defines if 960 variant should be supported */ - static void init_policy_constants(bool isPolicyMap); + static void init_policy_constants(bool isPolicyMap, bool is960); }; diff --git a/engine/src/environments/fairy_state/fairystate.h b/engine/src/environments/fairy_state/fairystate.h index f050d4327..0a69aac20 100644 --- a/engine/src/environments/fairy_state/fairystate.h +++ b/engine/src/environments/fairy_state/fairystate.h @@ -111,7 +111,7 @@ class StateConstantsFairy : public StateConstantsInterface : string{char('a' + file_of(to)), '1', '0'}; return fromSquare + toSquare; } - static void init(bool isPolicyMap) { + static void init(bool isPolicyMap, bool is960) { FairyOutputRepresentation::init_labels(); FairyOutputRepresentation::init_policy_constants(isPolicyMap); } diff --git a/engine/src/environments/open_spiel/openspielstate.h b/engine/src/environments/open_spiel/openspielstate.h index 366531ab8..3f2073a0e 100644 --- a/engine/src/environments/open_spiel/openspielstate.h +++ b/engine/src/environments/open_spiel/openspielstate.h @@ -77,7 +77,7 @@ class StateConstantsOpenSpiel : public StateConstantsInterface available_variants() { diff --git a/engine/src/state.h b/engine/src/state.h index 351e0f434..48091ecca 100644 --- a/engine/src/state.h +++ b/engine/src/state.h @@ -207,9 +207,10 @@ class StateConstantsInterface /** * @brief init Init function which is called after a neural network has been loaded and can be used to initialize static variables. * @param isPolicyMap Boolean indicating if the neural network uses a policy map representation + * @param is960 Boolean indicating if the 960 variant shall be used for initialization */ - static void init(bool isPolicyMap) { - return T::init(isPolicyMap); + static void init(bool isPolicyMap, bool is960) { + return T::init(isPolicyMap, is960); } /** diff --git a/engine/src/uci/crazyara.cpp b/engine/src/uci/crazyara.cpp index 7051d0125..4348f99e5 100644 --- a/engine/src/uci/crazyara.cpp +++ b/engine/src/uci/crazyara.cpp @@ -62,11 +62,7 @@ CrazyAra::CrazyAra(): useRawNetwork(false), // will be initialized in init_search_settings() networkLoaded(false), ongoingSearch(false), -#ifdef SUPPORT960 - is960(true), -#else is960(false), -#endif changedUCIoption(false) { } @@ -577,7 +573,7 @@ bool CrazyAra::is_ready() netBatches.front()->validate_neural_network(); mctsAgent = create_new_mcts_agent(netSingle.get(), netBatches, &searchSettings); rawAgent = make_unique(netSingle.get(), &playSettings, false); - StateConstants::init(mctsAgent->is_policy_map()); + StateConstants::init(mctsAgent->is_policy_map(), is960); timeoutThread.kill(); if (timeoutMS != 0) { tTimeoutThread.join(); @@ -653,12 +649,22 @@ void CrazyAra::set_uci_option(istringstream &is, StateObj& state) const string prevUciVariant = Options["UCI_Variant"]; const int prevFirstDeviceID = Options["First_Device_ID"]; const int prevLastDeviceID = Options["Last_Device_ID"]; +#ifdef SUPPORT960 + const bool prevIs960 = Options["UCI_Chess960"]; +#else + const bool prevIs960 = is960; +#endif OptionsUCI::setoption(is, variant, state); +#ifdef SUPPORT960 + const bool curIs960 = Options["UCI_Chess960"]; +#else + const bool curIs960 = is960; +#endif changedUCIoption = true; if (networkLoaded) { if (string(Options["Model_Directory"]) != prevModelDir || int(Options["Threads"]) != prevThreads || string(Options["UCI_Variant"]) != prevUciVariant || - int(Options["First_Device_ID"]) != prevFirstDeviceID || int(Options["Last_Device_ID"] != prevLastDeviceID)) { + int(Options["First_Device_ID"]) != prevFirstDeviceID || int(Options["Last_Device_ID"] != prevLastDeviceID) || prevIs960 != curIs960) { networkLoaded = false; is_ready(); } diff --git a/engine/src/uci/crazyara.h b/engine/src/uci/crazyara.h index 709922dc4..21d49c867 100644 --- a/engine/src/uci/crazyara.h +++ b/engine/src/uci/crazyara.h @@ -98,7 +98,7 @@ class CrazyAra bool networkLoaded; bool ongoingSearch; bool is960; - bool changedUCIoption = false; + bool changedUCIoption; public: CrazyAra(); diff --git a/engine/tests/tests.cpp b/engine/tests/tests.cpp index 60687427d..031210c04 100644 --- a/engine/tests/tests.cpp +++ b/engine/tests/tests.cpp @@ -567,7 +567,7 @@ TEST_CASE("6-Men WDL"){ #endif TEST_CASE("LABELS length"){ - StateConstants::init(true); + StateConstants::init(true, false); REQUIRE(OutputRepresentation::LABELS.size() == size_t(StateConstants::NB_LABELS())); REQUIRE(OutputRepresentation::LABELS_MIRRORED.size() == size_t(StateConstants::NB_LABELS())); }