Skip to content

Commit

Permalink
Fix 960 initialization problem (#207)
Browse files Browse the repository at this point in the history
* Fix is960 initialization problem

* Update tests.cpp
  • Loading branch information
QueensGambit authored Aug 8, 2023
1 parent df8122b commit 66e26b4
Show file tree
Hide file tree
Showing 11 changed files with 27 additions and 24 deletions.
4 changes: 2 additions & 2 deletions engine/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ if (MODE_CHESS)
# add_definitions(-DVERSION=2)
# add_definitions(-DSUB_VERSION=8)
add_definitions(-DVERSION=3)
# add_definitions(-DSUPPORT960)
add_definitions(-DSUPPORT960)
endif()

if (MODE_LICHESS)
Expand All @@ -57,7 +57,7 @@ if (MODE_LICHESS)
add_definitions(-DATOMIC)
add_definitions(-DHORDE)
add_definitions(-DRACE)
# add_definitions(-DSUPPORT960)
add_definitions(-DSUPPORT960)
add_definitions(-DMCTS_TB_SUPPORT)
add_definitions(-DVERSION=1)
endif()
Expand Down
4 changes: 2 additions & 2 deletions engine/src/environments/chess_related/boardstate.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ class StateConstantsBoard : public StateConstantsInterface<StateConstantsBoard>
return OutputRepresentation::MV_LOOKUP[action];
}
}
static void init(bool isPolicyMap) {
static void init(bool isPolicyMap, bool is960) {
OutputRepresentation::init_labels();
OutputRepresentation::init_policy_constants(isPolicyMap);
OutputRepresentation::init_policy_constants(isPolicyMap, is960);
}
// -------------------------------------------------
// | Additional custom methods |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,8 @@ void apply_softmax(DynamicVector<float> &policyProbSmall)
policyProbSmall = softmax(policyProbSmall);
}

void OutputRepresentation::init_policy_constants(bool isPolicyMap)
void OutputRepresentation::init_policy_constants(bool isPolicyMap, bool is960)
{
#ifdef SUPPORT960
const bool is960 = true;
#else
const bool is960 = false;
#endif
// fill mirrored label list and look-up table
for (size_t mvIdx = 0; mvIdx < StateConstants::NB_LABELS(); mvIdx++) {
LABELS_MIRRORED[mvIdx] = mirror_move(LABELS[mvIdx]);
Expand Down
3 changes: 2 additions & 1 deletion engine/src/environments/chess_related/outputrepresentation.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ struct OutputRepresentation{
/**
* @brief init_policy_constants Fills the hash maps for a action to nn index binding.
* @param isPolicyMap describes if a policy map head is used for the NN.
* @param is960 defines if 960 variant should be supported
*/
static void init_policy_constants(bool isPolicyMap);
static void init_policy_constants(bool isPolicyMap, bool is960);


};
Expand Down
2 changes: 1 addition & 1 deletion engine/src/environments/fairy_state/fairystate.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class StateConstantsFairy : public StateConstantsInterface<StateConstantsFairy>
: string{char('a' + file_of(to)), '1', '0'};
return fromSquare + toSquare;
}
static void init(bool isPolicyMap) {
static void init(bool isPolicyMap, bool is960) {
FairyOutputRepresentation::init_labels();
FairyOutputRepresentation::init_policy_constants(isPolicyMap);
}
Expand Down
2 changes: 1 addition & 1 deletion engine/src/environments/open_spiel/openspielstate.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class StateConstantsOpenSpiel : public StateConstantsInterface<StateConstantsOpe
static MoveIdx action_to_index(Action action) {
return action; // TODO
}
static void init(bool isPolicyMap) {
static void init(bool isPolicyMap, bool is960) {
return; // pass
}

Expand Down
2 changes: 1 addition & 1 deletion engine/src/environments/stratego_related/strategostate.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class StateConstantsStratego : public StateConstantsInterface<StateConstantsStra
static MoveIdx action_to_index(Action action) {
return action; // TODO
}
static void init(bool isPolicyMap) {
static void init(bool isPolicyMap, bool is960) {
return; // pass
}
static std::vector<std::string> available_variants() {
Expand Down
5 changes: 3 additions & 2 deletions engine/src/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,10 @@ class StateConstantsInterface
/**
* @brief init Init function which is called after a neural network has been loaded and can be used to initialize static variables.
* @param isPolicyMap Boolean indicating if the neural network uses a policy map representation
* @param is960 Boolean indicating if the 960 variant shall be used for initialization
*/
static void init(bool isPolicyMap) {
return T::init(isPolicyMap);
static void init(bool isPolicyMap, bool is960) {
return T::init(isPolicyMap, is960);
}

/**
Expand Down
18 changes: 12 additions & 6 deletions engine/src/uci/crazyara.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,7 @@ CrazyAra::CrazyAra():
useRawNetwork(false), // will be initialized in init_search_settings()
networkLoaded(false),
ongoingSearch(false),
#ifdef SUPPORT960
is960(true),
#else
is960(false),
#endif
changedUCIoption(false)
{
}
Expand Down Expand Up @@ -577,7 +573,7 @@ bool CrazyAra::is_ready()
netBatches.front()->validate_neural_network();
mctsAgent = create_new_mcts_agent(netSingle.get(), netBatches, &searchSettings);
rawAgent = make_unique<RawNetAgent>(netSingle.get(), &playSettings, false);
StateConstants::init(mctsAgent->is_policy_map());
StateConstants::init(mctsAgent->is_policy_map(), is960);
timeoutThread.kill();
if (timeoutMS != 0) {
tTimeoutThread.join();
Expand Down Expand Up @@ -653,12 +649,22 @@ void CrazyAra::set_uci_option(istringstream &is, StateObj& state)
const string prevUciVariant = Options["UCI_Variant"];
const int prevFirstDeviceID = Options["First_Device_ID"];
const int prevLastDeviceID = Options["Last_Device_ID"];
#ifdef SUPPORT960
const bool prevIs960 = Options["UCI_Chess960"];
#else
const bool prevIs960 = is960;
#endif

OptionsUCI::setoption(is, variant, state);
#ifdef SUPPORT960
const bool curIs960 = Options["UCI_Chess960"];
#else
const bool curIs960 = is960;
#endif
changedUCIoption = true;
if (networkLoaded) {
if (string(Options["Model_Directory"]) != prevModelDir || int(Options["Threads"]) != prevThreads || string(Options["UCI_Variant"]) != prevUciVariant ||
int(Options["First_Device_ID"]) != prevFirstDeviceID || int(Options["Last_Device_ID"] != prevLastDeviceID)) {
int(Options["First_Device_ID"]) != prevFirstDeviceID || int(Options["Last_Device_ID"] != prevLastDeviceID) || prevIs960 != curIs960) {
networkLoaded = false;
is_ready<false>();
}
Expand Down
2 changes: 1 addition & 1 deletion engine/src/uci/crazyara.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class CrazyAra
bool networkLoaded;
bool ongoingSearch;
bool is960;
bool changedUCIoption = false;
bool changedUCIoption;

public:
CrazyAra();
Expand Down
2 changes: 1 addition & 1 deletion engine/tests/tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ TEST_CASE("6-Men WDL"){
#endif

TEST_CASE("LABELS length"){
StateConstants::init(true);
StateConstants::init(true, false);
REQUIRE(OutputRepresentation::LABELS.size() == size_t(StateConstants::NB_LABELS()));
REQUIRE(OutputRepresentation::LABELS_MIRRORED.size() == size_t(StateConstants::NB_LABELS()));
}
Expand Down

0 comments on commit 66e26b4

Please sign in to comment.