From 06f25c6a22d2166be7948449a8e9aa7abaf1e198 Mon Sep 17 00:00:00 2001 From: xsr7qsr Date: Mon, 26 Apr 2021 10:32:37 +0200 Subject: [PATCH 01/19] Add Fairy-Stockfish (fork) to submodules --- .gitmodules | 3 +++ engine/3rdparty/Fairy-Stockfish | 1 + 2 files changed, 4 insertions(+) create mode 160000 engine/3rdparty/Fairy-Stockfish diff --git a/.gitmodules b/.gitmodules index dc2c8068..6358e3e0 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "engine/3rdparty/Stockfish"] path = engine/3rdparty/Stockfish url = https://github.com/QueensGambit/Stockfish.git +[submodule "engine/3rdparty/Fairy-Stockfish"] + path = engine/3rdparty/Fairy-Stockfish + url = https://github.com/xsr7qsr/Fairy-Stockfish.git diff --git a/engine/3rdparty/Fairy-Stockfish b/engine/3rdparty/Fairy-Stockfish new file mode 160000 index 00000000..c4edc95b --- /dev/null +++ b/engine/3rdparty/Fairy-Stockfish @@ -0,0 +1 @@ +Subproject commit c4edc95b096880362f0f02a0a6fc627bb1ddf9b7 From 42e8a989355bad07ac21e5567ecbd54dd0501dcf Mon Sep 17 00:00:00 2001 From: xsr7qsr Date: Mon, 26 Apr 2021 11:07:49 +0200 Subject: [PATCH 02/19] Add xiangqi support for training --- DeepCrazyhouse/src/domain/util.py | 59 +- .../src/domain/variants/constants.py | 2181 ++++++++++++++++- .../variants/plane_policy_representation.py | 2092 +++++++++++++++- .../src/preprocessing/dataset_loader.py | 65 +- .../src/training/trainer_agent_mxnet.py | 61 +- 5 files changed, 4416 insertions(+), 42 deletions(-) diff --git a/DeepCrazyhouse/src/domain/util.py b/DeepCrazyhouse/src/domain/util.py index 724c842d..01ef12c8 100755 --- a/DeepCrazyhouse/src/domain/util.py +++ b/DeepCrazyhouse/src/domain/util.py @@ -15,6 +15,8 @@ MODE, MODE_LICHESS, MODE_CRAZYHOUSE, + MODE_XIANGQI, + LABELS_XIANGQI, CHANNEL_MAPPING_CONST, CHANNEL_MAPPING_POS, MAX_NB_MOVES, @@ -27,6 +29,10 @@ ) +# file lookup for vertically mirrored xiangqi boards +mirrored_files_lookup = {'a': 'i', 'b': 'h', 'c': 'g', 'd': 'f', 'e': 'e', 'f': 'd', 'g': 'c', 'h': 'b', 'i': 'a'} + + def get_row_col(position, mirror=False): """ Maps a value [0,63] to its row and column index @@ -139,6 +145,24 @@ def get_numpy_arrays(pgn_dataset): return start_indices, x, y_value, y_policy, entries[0], entries[1] +def get_x_y_and_indices(dataset): + """ + Loads the content of the given dataset into numpy arrays. + + :param dataset: dataset file handle + :return: numpy-arrays: + starting_idx - defines the index where each game starts + x - the board representation for all games + y_value - the game outcome (-1,0,1) for each board position + y_policy - the movement policy for the next_move played + """ + start_indices = np.array(dataset["start_indices"]) + x = np.array(dataset["x"]) + y_value = np.array(dataset["y_value"]) + y_policy = np.array(dataset["y_policy"]) + return start_indices, x, y_value, y_policy + + def normalize_input_planes(x): """ Normalizes input planes to range [0,1]. Works in place / meaning the input parameter x is manipulated @@ -161,13 +185,21 @@ def normalize_input_planes(x): mat_pos[channel, :, :] /= MAX_NB_PRISONERS # the prison for black begins 5 channels later mat_pos[channel + POCKETS_SIZE_PIECE_TYPE, :, :] /= MAX_NB_PRISONERS + # xiangqi has 7 piece types (king/general is excluded as prisoner) + elif MODE == MODE_XIANGQI: + for p_type in range(6): + channel = CHANNEL_MAPPING_POS["prisoners"] + p_type + mat_pos[channel, :, :] /= MAX_NB_PRISONERS + # the prison for opponent begins 6 channels later + mat_pos[channel + POCKETS_SIZE_PIECE_TYPE, :, :] /= MAX_NB_PRISONERS # Total Move Count # 500 was set as the max number of total moves mat_const[CHANNEL_MAPPING_CONST["total_mv_cnt"], :, :] /= MAX_NB_MOVES # No progress count # after 40 moves of no progress the 40 moves rule for draw applies - mat_const[CHANNEL_MAPPING_CONST["no_progress_cnt"], :, :] /= MAX_NB_NO_PROGRESS + if MODE != MODE_XIANGQI: + mat_const[CHANNEL_MAPPING_CONST["no_progress_cnt"], :, :] /= MAX_NB_NO_PROGRESS return x @@ -177,6 +209,31 @@ def normalize_input_planes(x): MATRIX_NORMALIZER = normalize_input_planes(np.ones((NB_CHANNELS_TOTAL, BOARD_HEIGHT, BOARD_WIDTH))) +def augment(x, y_policy): + """ + Augments a given set of planes and their corresponding policy targets. + The returned planes are vertically mirrored. The returned policy targets + are adjusted, so that they correspond to the new planes. + Works in-place. + :param x: Input planes + :param y_policy: Policy targets + """ + for i in range(x.shape[0]): + for j in range(x.shape[1]): + x[i][j] = np.fliplr(x[i][j]) + + idx_mv = np.where(y_policy[i] == 1)[0][0] + y_policy[i][idx_mv] = 0 + ucci = LABELS_XIANGQI[idx_mv] + + from_square_aug = mirrored_files_lookup[ucci[0]] + ucci[1] + to_square_aug = mirrored_files_lookup[ucci[2]] + ucci[3] + ucci_aug = from_square_aug + to_square_aug + + idx_mv_aug = LABELS_XIANGQI.index(ucci_aug) + y_policy[i][idx_mv_aug] = 1 + + def customize_input_planes(x): """ Reverts normalization back to integer values. Works in place. diff --git a/DeepCrazyhouse/src/domain/variants/constants.py b/DeepCrazyhouse/src/domain/variants/constants.py index 0b810bc8..6320282f 100644 --- a/DeepCrazyhouse/src/domain/variants/constants.py +++ b/DeepCrazyhouse/src/domain/variants/constants.py @@ -26,6 +26,9 @@ MODE_CHESS: -- Input and output constants which only support classical chess and chess960 + +MODE_XIANGQI +-- Input and output constants which only support xiangqi ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ! DO NOT CHANGE THE LABEL LIST OTHERWISE YOU WILL BREAK THE MOVE MAPPING OF THE NETWORK ! @@ -50,7 +53,8 @@ MODE_CRAZYHOUSE = 0 MODE_LICHESS = 1 MODE_CHESS = 2 -MODES = [MODE_CRAZYHOUSE, MODE_LICHESS, MODE_CHESS] +MODE_XIANGQI = 3 +MODES = [MODE_CRAZYHOUSE, MODE_LICHESS, MODE_CHESS, MODE_XIANGQI] # Active mode MODE = main_config["mode"] VERSION = main_config["version"] @@ -58,8 +62,12 @@ if MODE not in MODES: raise ValueError('unsupported "mode" specification in main_config.py') -# The same ordering is used in the python-chess (only that python-chess also includes a "null" piece) -PIECES = ["P", "N", "B", "R", "Q", "K", "p", "n", "b", "r", "q", "k"] +if MODE == MODE_XIANGQI: + # The same ordering is used for red and black player channels in the preprocessed data + PIECES = ["K", "A", "E", "H", "R", "C", "P", "k", "a", "e", "h", "r", "c", "p"] +else: + # The same ordering is used in the python-chess (only that python-chess also includes a "null" piece) + PIECES = ["P", "N", "B", "R", "Q", "K", "p", "n", "b", "r", "q", "k"] # Dictionary for mapping the piece type to an integer value P_MAP = {"P": 0, "N": 1, "B": 2, "R": 3, "Q": 4, "K": 5, "p": 6, "n": 7, "b": 8, "r": 9, "q": 10, "k": 11} @@ -68,11 +76,16 @@ # each index indicates where each section start if MODE == MODE_CRAZYHOUSE or MODE == MODE_LICHESS: CHANNEL_MAPPING_POS = {"pieces": 0, "repetitions": 12, "prisoners": 14, "promo": 24, "ep_square": 26} +elif MODE == MODE_XIANGQI: + CHANNEL_MAPPING_POS = {"pieces": 0, "prisoners": 14} else: # MODE = MODE_CHESS CHANNEL_MAPPING_POS = {"pieces": 0, "repetitions": 12, "ep_square": 14} # constant value inputs -CHANNEL_MAPPING_CONST = {"color": 0, "total_mv_cnt": 1, "castling": 2, "no_progress_cnt": 6, "remaining_checks": 7} +if MODE == MODE_XIANGQI: + CHANNEL_MAPPING_CONST = {"color": 0, "total_mv_cnt": 1} +else: + CHANNEL_MAPPING_CONST = {"color": 0, "total_mv_cnt": 1, "castling": 2, "no_progress_cnt": 6, "remaining_checks": 7} # chess variant specification, same order as in lichess' opening explorer, name is given as ucivariant # (except is960 which is a boolean and can be checked by board.chess960 and is indicated first) @@ -84,8 +97,12 @@ "horde": HordeBoard, "racingkings": RacingKingsBoard} # Define the board size -BOARD_WIDTH = 8 -BOARD_HEIGHT = 8 +if MODE == MODE_XIANGQI: + BOARD_WIDTH = 9 + BOARD_HEIGHT = 10 +else: + BOARD_WIDTH = 8 + BOARD_HEIGHT = 8 # Define constants indicating the number of channels for the input plane presentation # and the number of channels used for the policy map representation @@ -110,6 +127,14 @@ NB_CHANNELS_PER_HISTORY_ITEM = 0 else: NB_CHANNELS_PER_HISTORY_ITEM = 2 +elif MODE == MODE_XIANGQI: + NB_CHANNELS_POS = 26 + NB_CHANNELS_CONST = 2 + NB_CHANNELS_VARIANTS = 0 + NB_POLICY_MAP_CHANNELS = 50 + NB_CHANNELS_HISTORY = 0 + NB_LAST_MOVES = 0 + NB_CHANNELS_PER_HISTORY_ITEM = 0 else: # MODE = MODE_CHESS NB_CHANNELS_POS = 15 NB_CHANNELS_CONST = 7 @@ -120,12 +145,18 @@ NB_CHANNELS_PER_HISTORY_ITEM = 2 # number of labels of the corresponding flattened policy map. Most of these entries are unreachable (always 0) -NB_LABELS_POLICY_MAP = NB_POLICY_MAP_CHANNELS * BOARD_HEIGHT * BOARD_WIDTH +if MODE == MODE_XIANGQI: + NB_LABELS_POLICY_MAP = 2086 +else: + NB_LABELS_POLICY_MAP = NB_POLICY_MAP_CHANNELS * BOARD_HEIGHT * BOARD_WIDTH NB_CHANNELS_HISTORY = NB_LAST_MOVES * NB_CHANNELS_PER_HISTORY_ITEM NB_CHANNELS_TOTAL = NB_CHANNELS_POS + NB_CHANNELS_CONST + NB_CHANNELS_VARIANTS + NB_CHANNELS_HISTORY -# define the number of different pieces one can have in his pocket (the king is excluded) -POCKETS_SIZE_PIECE_TYPE = 5 +# define the number of different pieces one can have in his pocket (the king/general is excluded) +if MODE == MODE_XIANGQI: + POCKETS_SIZE_PIECE_TYPE = 6 +else: + POCKETS_SIZE_PIECE_TYPE = 5 # (this used for normalization the input planes and setting an appropriate integer representation (e.g. int16) # use a constant matrix for normalization to allow broad cast operations @@ -133,6 +164,10 @@ MAX_NB_PRISONERS = 32 # define the maximum number of pieces of each type in a pocket MAX_NB_MOVES = 500 # 500 was set as the max number of total moves MAX_NB_NO_PROGRESS = 40 # originally this was set to 40, but actually it is meant to be 50 move rule +elif MODE == MODE_XIANGQI: + MAX_NB_PRISONERS = 5 + MAX_NB_MOVES = 500 + MAX_NB_NO_PROGRESS = 0 # not used in xiangqi else: # MODE = MODE_LICHESS or MODE = MODE_CHESS: MAX_NB_PRISONERS = 16 # at maximum you can have only 16 pawns (your own and the ones of the opponent) MAX_NB_MOVES = 500 # 500 was set as the max number of total moves @@ -4750,6 +4785,2095 @@ "Q@h8", ] +LABELS_XIANGQI = [ + 'a0b0', + 'a0c0', + 'a0d0', + 'a0e0', + 'a0f0', + 'a0g0', + 'a0h0', + 'a0i0', + 'a0a1', + 'a0a2', + 'a0a3', + 'a0a4', + 'a0a5', + 'a0a6', + 'a0a7', + 'a0a8', + 'a0a9', + 'a0b2', + 'a0c1', + 'b0a0', + 'b0c0', + 'b0d0', + 'b0e0', + 'b0f0', + 'b0g0', + 'b0h0', + 'b0i0', + 'b0b1', + 'b0b2', + 'b0b3', + 'b0b4', + 'b0b5', + 'b0b6', + 'b0b7', + 'b0b8', + 'b0b9', + 'b0a2', + 'b0c2', + 'b0d1', + 'c0a0', + 'c0b0', + 'c0d0', + 'c0e0', + 'c0f0', + 'c0g0', + 'c0h0', + 'c0i0', + 'c0c1', + 'c0c2', + 'c0c3', + 'c0c4', + 'c0c5', + 'c0c6', + 'c0c7', + 'c0c8', + 'c0c9', + 'c0a1', + 'c0b2', + 'c0d2', + 'c0e1', + 'c0a2', + 'c0e2', + 'd0a0', + 'd0b0', + 'd0c0', + 'd0e0', + 'd0f0', + 'd0g0', + 'd0h0', + 'd0i0', + 'd0d1', + 'd0d2', + 'd0d3', + 'd0d4', + 'd0d5', + 'd0d6', + 'd0d7', + 'd0d8', + 'd0d9', + 'd0b1', + 'd0c2', + 'd0e2', + 'd0f1', + 'e0a0', + 'e0b0', + 'e0c0', + 'e0d0', + 'e0f0', + 'e0g0', + 'e0h0', + 'e0i0', + 'e0e1', + 'e0e2', + 'e0e3', + 'e0e4', + 'e0e5', + 'e0e6', + 'e0e7', + 'e0e8', + 'e0e9', + 'e0c1', + 'e0d2', + 'e0f2', + 'e0g1', + 'f0a0', + 'f0b0', + 'f0c0', + 'f0d0', + 'f0e0', + 'f0g0', + 'f0h0', + 'f0i0', + 'f0f1', + 'f0f2', + 'f0f3', + 'f0f4', + 'f0f5', + 'f0f6', + 'f0f7', + 'f0f8', + 'f0f9', + 'f0d1', + 'f0e2', + 'f0g2', + 'f0h1', + 'g0a0', + 'g0b0', + 'g0c0', + 'g0d0', + 'g0e0', + 'g0f0', + 'g0h0', + 'g0i0', + 'g0g1', + 'g0g2', + 'g0g3', + 'g0g4', + 'g0g5', + 'g0g6', + 'g0g7', + 'g0g8', + 'g0g9', + 'g0e1', + 'g0f2', + 'g0h2', + 'g0i1', + 'g0e2', + 'g0i2', + 'h0a0', + 'h0b0', + 'h0c0', + 'h0d0', + 'h0e0', + 'h0f0', + 'h0g0', + 'h0i0', + 'h0h1', + 'h0h2', + 'h0h3', + 'h0h4', + 'h0h5', + 'h0h6', + 'h0h7', + 'h0h8', + 'h0h9', + 'h0f1', + 'h0g2', + 'h0i2', + 'i0a0', + 'i0b0', + 'i0c0', + 'i0d0', + 'i0e0', + 'i0f0', + 'i0g0', + 'i0h0', + 'i0i1', + 'i0i2', + 'i0i3', + 'i0i4', + 'i0i5', + 'i0i6', + 'i0i7', + 'i0i8', + 'i0i9', + 'i0g1', + 'i0h2', + 'a1b1', + 'a1c1', + 'a1d1', + 'a1e1', + 'a1f1', + 'a1g1', + 'a1h1', + 'a1i1', + 'a1a0', + 'a1a2', + 'a1a3', + 'a1a4', + 'a1a5', + 'a1a6', + 'a1a7', + 'a1a8', + 'a1a9', + 'a1b3', + 'a1c2', + 'a1c0', + 'b1a1', + 'b1c1', + 'b1d1', + 'b1e1', + 'b1f1', + 'b1g1', + 'b1h1', + 'b1i1', + 'b1b0', + 'b1b2', + 'b1b3', + 'b1b4', + 'b1b5', + 'b1b6', + 'b1b7', + 'b1b8', + 'b1b9', + 'b1a3', + 'b1c3', + 'b1d2', + 'b1d0', + 'c1a1', + 'c1b1', + 'c1d1', + 'c1e1', + 'c1f1', + 'c1g1', + 'c1h1', + 'c1i1', + 'c1c0', + 'c1c2', + 'c1c3', + 'c1c4', + 'c1c5', + 'c1c6', + 'c1c7', + 'c1c8', + 'c1c9', + 'c1a0', + 'c1a2', + 'c1b3', + 'c1d3', + 'c1e2', + 'c1e0', + 'd1a1', + 'd1b1', + 'd1c1', + 'd1e1', + 'd1f1', + 'd1g1', + 'd1h1', + 'd1i1', + 'd1d0', + 'd1d2', + 'd1d3', + 'd1d4', + 'd1d5', + 'd1d6', + 'd1d7', + 'd1d8', + 'd1d9', + 'd1b0', + 'd1b2', + 'd1c3', + 'd1e3', + 'd1f2', + 'd1f0', + 'e1a1', + 'e1b1', + 'e1c1', + 'e1d1', + 'e1f1', + 'e1g1', + 'e1h1', + 'e1i1', + 'e1e0', + 'e1e2', + 'e1e3', + 'e1e4', + 'e1e5', + 'e1e6', + 'e1e7', + 'e1e8', + 'e1e9', + 'e1c0', + 'e1c2', + 'e1d3', + 'e1f3', + 'e1g2', + 'e1g0', + 'e1d0', + 'e1d2', + 'e1f2', + 'e1f0', + 'f1a1', + 'f1b1', + 'f1c1', + 'f1d1', + 'f1e1', + 'f1g1', + 'f1h1', + 'f1i1', + 'f1f0', + 'f1f2', + 'f1f3', + 'f1f4', + 'f1f5', + 'f1f6', + 'f1f7', + 'f1f8', + 'f1f9', + 'f1d0', + 'f1d2', + 'f1e3', + 'f1g3', + 'f1h2', + 'f1h0', + 'g1a1', + 'g1b1', + 'g1c1', + 'g1d1', + 'g1e1', + 'g1f1', + 'g1h1', + 'g1i1', + 'g1g0', + 'g1g2', + 'g1g3', + 'g1g4', + 'g1g5', + 'g1g6', + 'g1g7', + 'g1g8', + 'g1g9', + 'g1e0', + 'g1e2', + 'g1f3', + 'g1h3', + 'g1i2', + 'g1i0', + 'h1a1', + 'h1b1', + 'h1c1', + 'h1d1', + 'h1e1', + 'h1f1', + 'h1g1', + 'h1i1', + 'h1h0', + 'h1h2', + 'h1h3', + 'h1h4', + 'h1h5', + 'h1h6', + 'h1h7', + 'h1h8', + 'h1h9', + 'h1f0', + 'h1f2', + 'h1g3', + 'h1i3', + 'i1a1', + 'i1b1', + 'i1c1', + 'i1d1', + 'i1e1', + 'i1f1', + 'i1g1', + 'i1h1', + 'i1i0', + 'i1i2', + 'i1i3', + 'i1i4', + 'i1i5', + 'i1i6', + 'i1i7', + 'i1i8', + 'i1i9', + 'i1g0', + 'i1g2', + 'i1h3', + 'a2b2', + 'a2c2', + 'a2d2', + 'a2e2', + 'a2f2', + 'a2g2', + 'a2h2', + 'a2i2', + 'a2a0', + 'a2a1', + 'a2a3', + 'a2a4', + 'a2a5', + 'a2a6', + 'a2a7', + 'a2a8', + 'a2a9', + 'a2b4', + 'a2c3', + 'a2c1', + 'a2b0', + 'a2c4', + 'a2c0', + 'b2a2', + 'b2c2', + 'b2d2', + 'b2e2', + 'b2f2', + 'b2g2', + 'b2h2', + 'b2i2', + 'b2b0', + 'b2b1', + 'b2b3', + 'b2b4', + 'b2b5', + 'b2b6', + 'b2b7', + 'b2b8', + 'b2b9', + 'b2a0', + 'b2a4', + 'b2c4', + 'b2d3', + 'b2d1', + 'b2c0', + 'c2a2', + 'c2b2', + 'c2d2', + 'c2e2', + 'c2f2', + 'c2g2', + 'c2h2', + 'c2i2', + 'c2c0', + 'c2c1', + 'c2c3', + 'c2c4', + 'c2c5', + 'c2c6', + 'c2c7', + 'c2c8', + 'c2c9', + 'c2b0', + 'c2a1', + 'c2a3', + 'c2b4', + 'c2d4', + 'c2e3', + 'c2e1', + 'c2d0', + 'd2a2', + 'd2b2', + 'd2c2', + 'd2e2', + 'd2f2', + 'd2g2', + 'd2h2', + 'd2i2', + 'd2d0', + 'd2d1', + 'd2d3', + 'd2d4', + 'd2d5', + 'd2d6', + 'd2d7', + 'd2d8', + 'd2d9', + 'd2c0', + 'd2b1', + 'd2b3', + 'd2c4', + 'd2e4', + 'd2f3', + 'd2f1', + 'd2e0', + 'e2a2', + 'e2b2', + 'e2c2', + 'e2d2', + 'e2f2', + 'e2g2', + 'e2h2', + 'e2i2', + 'e2e0', + 'e2e1', + 'e2e3', + 'e2e4', + 'e2e5', + 'e2e6', + 'e2e7', + 'e2e8', + 'e2e9', + 'e2d0', + 'e2c1', + 'e2c3', + 'e2d4', + 'e2f4', + 'e2g3', + 'e2g1', + 'e2f0', + 'e2c4', + 'e2g4', + 'e2c0', + 'e2g0', + 'f2a2', + 'f2b2', + 'f2c2', + 'f2d2', + 'f2e2', + 'f2g2', + 'f2h2', + 'f2i2', + 'f2f0', + 'f2f1', + 'f2f3', + 'f2f4', + 'f2f5', + 'f2f6', + 'f2f7', + 'f2f8', + 'f2f9', + 'f2e0', + 'f2d1', + 'f2d3', + 'f2e4', + 'f2g4', + 'f2h3', + 'f2h1', + 'f2g0', + 'g2a2', + 'g2b2', + 'g2c2', + 'g2d2', + 'g2e2', + 'g2f2', + 'g2h2', + 'g2i2', + 'g2g0', + 'g2g1', + 'g2g3', + 'g2g4', + 'g2g5', + 'g2g6', + 'g2g7', + 'g2g8', + 'g2g9', + 'g2f0', + 'g2e1', + 'g2e3', + 'g2f4', + 'g2h4', + 'g2i3', + 'g2i1', + 'g2h0', + 'h2a2', + 'h2b2', + 'h2c2', + 'h2d2', + 'h2e2', + 'h2f2', + 'h2g2', + 'h2i2', + 'h2h0', + 'h2h1', + 'h2h3', + 'h2h4', + 'h2h5', + 'h2h6', + 'h2h7', + 'h2h8', + 'h2h9', + 'h2g0', + 'h2f1', + 'h2f3', + 'h2g4', + 'h2i4', + 'h2i0', + 'i2a2', + 'i2b2', + 'i2c2', + 'i2d2', + 'i2e2', + 'i2f2', + 'i2g2', + 'i2h2', + 'i2i0', + 'i2i1', + 'i2i3', + 'i2i4', + 'i2i5', + 'i2i6', + 'i2i7', + 'i2i8', + 'i2i9', + 'i2h0', + 'i2g1', + 'i2g3', + 'i2h4', + 'i2g4', + 'i2g0', + 'a3b3', + 'a3c3', + 'a3d3', + 'a3e3', + 'a3f3', + 'a3g3', + 'a3h3', + 'a3i3', + 'a3a0', + 'a3a1', + 'a3a2', + 'a3a4', + 'a3a5', + 'a3a6', + 'a3a7', + 'a3a8', + 'a3a9', + 'a3b5', + 'a3c4', + 'a3c2', + 'a3b1', + 'b3a3', + 'b3c3', + 'b3d3', + 'b3e3', + 'b3f3', + 'b3g3', + 'b3h3', + 'b3i3', + 'b3b0', + 'b3b1', + 'b3b2', + 'b3b4', + 'b3b5', + 'b3b6', + 'b3b7', + 'b3b8', + 'b3b9', + 'b3a1', + 'b3a5', + 'b3c5', + 'b3d4', + 'b3d2', + 'b3c1', + 'c3a3', + 'c3b3', + 'c3d3', + 'c3e3', + 'c3f3', + 'c3g3', + 'c3h3', + 'c3i3', + 'c3c0', + 'c3c1', + 'c3c2', + 'c3c4', + 'c3c5', + 'c3c6', + 'c3c7', + 'c3c8', + 'c3c9', + 'c3b1', + 'c3a2', + 'c3a4', + 'c3b5', + 'c3d5', + 'c3e4', + 'c3e2', + 'c3d1', + 'd3a3', + 'd3b3', + 'd3c3', + 'd3e3', + 'd3f3', + 'd3g3', + 'd3h3', + 'd3i3', + 'd3d0', + 'd3d1', + 'd3d2', + 'd3d4', + 'd3d5', + 'd3d6', + 'd3d7', + 'd3d8', + 'd3d9', + 'd3c1', + 'd3b2', + 'd3b4', + 'd3c5', + 'd3e5', + 'd3f4', + 'd3f2', + 'd3e1', + 'e3a3', + 'e3b3', + 'e3c3', + 'e3d3', + 'e3f3', + 'e3g3', + 'e3h3', + 'e3i3', + 'e3e0', + 'e3e1', + 'e3e2', + 'e3e4', + 'e3e5', + 'e3e6', + 'e3e7', + 'e3e8', + 'e3e9', + 'e3d1', + 'e3c2', + 'e3c4', + 'e3d5', + 'e3f5', + 'e3g4', + 'e3g2', + 'e3f1', + 'f3a3', + 'f3b3', + 'f3c3', + 'f3d3', + 'f3e3', + 'f3g3', + 'f3h3', + 'f3i3', + 'f3f0', + 'f3f1', + 'f3f2', + 'f3f4', + 'f3f5', + 'f3f6', + 'f3f7', + 'f3f8', + 'f3f9', + 'f3e1', + 'f3d2', + 'f3d4', + 'f3e5', + 'f3g5', + 'f3h4', + 'f3h2', + 'f3g1', + 'g3a3', + 'g3b3', + 'g3c3', + 'g3d3', + 'g3e3', + 'g3f3', + 'g3h3', + 'g3i3', + 'g3g0', + 'g3g1', + 'g3g2', + 'g3g4', + 'g3g5', + 'g3g6', + 'g3g7', + 'g3g8', + 'g3g9', + 'g3f1', + 'g3e2', + 'g3e4', + 'g3f5', + 'g3h5', + 'g3i4', + 'g3i2', + 'g3h1', + 'h3a3', + 'h3b3', + 'h3c3', + 'h3d3', + 'h3e3', + 'h3f3', + 'h3g3', + 'h3i3', + 'h3h0', + 'h3h1', + 'h3h2', + 'h3h4', + 'h3h5', + 'h3h6', + 'h3h7', + 'h3h8', + 'h3h9', + 'h3g1', + 'h3f2', + 'h3f4', + 'h3g5', + 'h3i5', + 'h3i1', + 'i3a3', + 'i3b3', + 'i3c3', + 'i3d3', + 'i3e3', + 'i3f3', + 'i3g3', + 'i3h3', + 'i3i0', + 'i3i1', + 'i3i2', + 'i3i4', + 'i3i5', + 'i3i6', + 'i3i7', + 'i3i8', + 'i3i9', + 'i3h1', + 'i3g2', + 'i3g4', + 'i3h5', + 'a4b4', + 'a4c4', + 'a4d4', + 'a4e4', + 'a4f4', + 'a4g4', + 'a4h4', + 'a4i4', + 'a4a0', + 'a4a1', + 'a4a2', + 'a4a3', + 'a4a5', + 'a4a6', + 'a4a7', + 'a4a8', + 'a4a9', + 'a4b6', + 'a4c5', + 'a4c3', + 'a4b2', + 'b4a4', + 'b4c4', + 'b4d4', + 'b4e4', + 'b4f4', + 'b4g4', + 'b4h4', + 'b4i4', + 'b4b0', + 'b4b1', + 'b4b2', + 'b4b3', + 'b4b5', + 'b4b6', + 'b4b7', + 'b4b8', + 'b4b9', + 'b4a2', + 'b4a6', + 'b4c6', + 'b4d5', + 'b4d3', + 'b4c2', + 'c4a4', + 'c4b4', + 'c4d4', + 'c4e4', + 'c4f4', + 'c4g4', + 'c4h4', + 'c4i4', + 'c4c0', + 'c4c1', + 'c4c2', + 'c4c3', + 'c4c5', + 'c4c6', + 'c4c7', + 'c4c8', + 'c4c9', + 'c4b2', + 'c4a3', + 'c4a5', + 'c4b6', + 'c4d6', + 'c4e5', + 'c4e3', + 'c4d2', + 'c4a2', + 'c4e2', + 'd4a4', + 'd4b4', + 'd4c4', + 'd4e4', + 'd4f4', + 'd4g4', + 'd4h4', + 'd4i4', + 'd4d0', + 'd4d1', + 'd4d2', + 'd4d3', + 'd4d5', + 'd4d6', + 'd4d7', + 'd4d8', + 'd4d9', + 'd4c2', + 'd4b3', + 'd4b5', + 'd4c6', + 'd4e6', + 'd4f5', + 'd4f3', + 'd4e2', + 'e4a4', + 'e4b4', + 'e4c4', + 'e4d4', + 'e4f4', + 'e4g4', + 'e4h4', + 'e4i4', + 'e4e0', + 'e4e1', + 'e4e2', + 'e4e3', + 'e4e5', + 'e4e6', + 'e4e7', + 'e4e8', + 'e4e9', + 'e4d2', + 'e4c3', + 'e4c5', + 'e4d6', + 'e4f6', + 'e4g5', + 'e4g3', + 'e4f2', + 'f4a4', + 'f4b4', + 'f4c4', + 'f4d4', + 'f4e4', + 'f4g4', + 'f4h4', + 'f4i4', + 'f4f0', + 'f4f1', + 'f4f2', + 'f4f3', + 'f4f5', + 'f4f6', + 'f4f7', + 'f4f8', + 'f4f9', + 'f4e2', + 'f4d3', + 'f4d5', + 'f4e6', + 'f4g6', + 'f4h5', + 'f4h3', + 'f4g2', + 'g4a4', + 'g4b4', + 'g4c4', + 'g4d4', + 'g4e4', + 'g4f4', + 'g4h4', + 'g4i4', + 'g4g0', + 'g4g1', + 'g4g2', + 'g4g3', + 'g4g5', + 'g4g6', + 'g4g7', + 'g4g8', + 'g4g9', + 'g4f2', + 'g4e3', + 'g4e5', + 'g4f6', + 'g4h6', + 'g4i5', + 'g4i3', + 'g4h2', + 'g4e2', + 'g4i2', + 'h4a4', + 'h4b4', + 'h4c4', + 'h4d4', + 'h4e4', + 'h4f4', + 'h4g4', + 'h4i4', + 'h4h0', + 'h4h1', + 'h4h2', + 'h4h3', + 'h4h5', + 'h4h6', + 'h4h7', + 'h4h8', + 'h4h9', + 'h4g2', + 'h4f3', + 'h4f5', + 'h4g6', + 'h4i6', + 'h4i2', + 'i4a4', + 'i4b4', + 'i4c4', + 'i4d4', + 'i4e4', + 'i4f4', + 'i4g4', + 'i4h4', + 'i4i0', + 'i4i1', + 'i4i2', + 'i4i3', + 'i4i5', + 'i4i6', + 'i4i7', + 'i4i8', + 'i4i9', + 'i4h2', + 'i4g3', + 'i4g5', + 'i4h6', + 'a5b5', + 'a5c5', + 'a5d5', + 'a5e5', + 'a5f5', + 'a5g5', + 'a5h5', + 'a5i5', + 'a5a0', + 'a5a1', + 'a5a2', + 'a5a3', + 'a5a4', + 'a5a6', + 'a5a7', + 'a5a8', + 'a5a9', + 'a5b7', + 'a5c6', + 'a5c4', + 'a5b3', + 'b5a5', + 'b5c5', + 'b5d5', + 'b5e5', + 'b5f5', + 'b5g5', + 'b5h5', + 'b5i5', + 'b5b0', + 'b5b1', + 'b5b2', + 'b5b3', + 'b5b4', + 'b5b6', + 'b5b7', + 'b5b8', + 'b5b9', + 'b5a3', + 'b5a7', + 'b5c7', + 'b5d6', + 'b5d4', + 'b5c3', + 'c5a5', + 'c5b5', + 'c5d5', + 'c5e5', + 'c5f5', + 'c5g5', + 'c5h5', + 'c5i5', + 'c5c0', + 'c5c1', + 'c5c2', + 'c5c3', + 'c5c4', + 'c5c6', + 'c5c7', + 'c5c8', + 'c5c9', + 'c5b3', + 'c5a4', + 'c5a6', + 'c5b7', + 'c5d7', + 'c5e6', + 'c5e4', + 'c5d3', + 'c5a7', + 'c5e7', + 'd5a5', + 'd5b5', + 'd5c5', + 'd5e5', + 'd5f5', + 'd5g5', + 'd5h5', + 'd5i5', + 'd5d0', + 'd5d1', + 'd5d2', + 'd5d3', + 'd5d4', + 'd5d6', + 'd5d7', + 'd5d8', + 'd5d9', + 'd5c3', + 'd5b4', + 'd5b6', + 'd5c7', + 'd5e7', + 'd5f6', + 'd5f4', + 'd5e3', + 'e5a5', + 'e5b5', + 'e5c5', + 'e5d5', + 'e5f5', + 'e5g5', + 'e5h5', + 'e5i5', + 'e5e0', + 'e5e1', + 'e5e2', + 'e5e3', + 'e5e4', + 'e5e6', + 'e5e7', + 'e5e8', + 'e5e9', + 'e5d3', + 'e5c4', + 'e5c6', + 'e5d7', + 'e5f7', + 'e5g6', + 'e5g4', + 'e5f3', + 'f5a5', + 'f5b5', + 'f5c5', + 'f5d5', + 'f5e5', + 'f5g5', + 'f5h5', + 'f5i5', + 'f5f0', + 'f5f1', + 'f5f2', + 'f5f3', + 'f5f4', + 'f5f6', + 'f5f7', + 'f5f8', + 'f5f9', + 'f5e3', + 'f5d4', + 'f5d6', + 'f5e7', + 'f5g7', + 'f5h6', + 'f5h4', + 'f5g3', + 'g5a5', + 'g5b5', + 'g5c5', + 'g5d5', + 'g5e5', + 'g5f5', + 'g5h5', + 'g5i5', + 'g5g0', + 'g5g1', + 'g5g2', + 'g5g3', + 'g5g4', + 'g5g6', + 'g5g7', + 'g5g8', + 'g5g9', + 'g5f3', + 'g5e4', + 'g5e6', + 'g5f7', + 'g5h7', + 'g5i6', + 'g5i4', + 'g5h3', + 'g5e7', + 'g5i7', + 'h5a5', + 'h5b5', + 'h5c5', + 'h5d5', + 'h5e5', + 'h5f5', + 'h5g5', + 'h5i5', + 'h5h0', + 'h5h1', + 'h5h2', + 'h5h3', + 'h5h4', + 'h5h6', + 'h5h7', + 'h5h8', + 'h5h9', + 'h5g3', + 'h5f4', + 'h5f6', + 'h5g7', + 'h5i7', + 'h5i3', + 'i5a5', + 'i5b5', + 'i5c5', + 'i5d5', + 'i5e5', + 'i5f5', + 'i5g5', + 'i5h5', + 'i5i0', + 'i5i1', + 'i5i2', + 'i5i3', + 'i5i4', + 'i5i6', + 'i5i7', + 'i5i8', + 'i5i9', + 'i5h3', + 'i5g4', + 'i5g6', + 'i5h7', + 'a6b6', + 'a6c6', + 'a6d6', + 'a6e6', + 'a6f6', + 'a6g6', + 'a6h6', + 'a6i6', + 'a6a0', + 'a6a1', + 'a6a2', + 'a6a3', + 'a6a4', + 'a6a5', + 'a6a7', + 'a6a8', + 'a6a9', + 'a6b8', + 'a6c7', + 'a6c5', + 'a6b4', + 'b6a6', + 'b6c6', + 'b6d6', + 'b6e6', + 'b6f6', + 'b6g6', + 'b6h6', + 'b6i6', + 'b6b0', + 'b6b1', + 'b6b2', + 'b6b3', + 'b6b4', + 'b6b5', + 'b6b7', + 'b6b8', + 'b6b9', + 'b6a4', + 'b6a8', + 'b6c8', + 'b6d7', + 'b6d5', + 'b6c4', + 'c6a6', + 'c6b6', + 'c6d6', + 'c6e6', + 'c6f6', + 'c6g6', + 'c6h6', + 'c6i6', + 'c6c0', + 'c6c1', + 'c6c2', + 'c6c3', + 'c6c4', + 'c6c5', + 'c6c7', + 'c6c8', + 'c6c9', + 'c6b4', + 'c6a5', + 'c6a7', + 'c6b8', + 'c6d8', + 'c6e7', + 'c6e5', + 'c6d4', + 'd6a6', + 'd6b6', + 'd6c6', + 'd6e6', + 'd6f6', + 'd6g6', + 'd6h6', + 'd6i6', + 'd6d0', + 'd6d1', + 'd6d2', + 'd6d3', + 'd6d4', + 'd6d5', + 'd6d7', + 'd6d8', + 'd6d9', + 'd6c4', + 'd6b5', + 'd6b7', + 'd6c8', + 'd6e8', + 'd6f7', + 'd6f5', + 'd6e4', + 'e6a6', + 'e6b6', + 'e6c6', + 'e6d6', + 'e6f6', + 'e6g6', + 'e6h6', + 'e6i6', + 'e6e0', + 'e6e1', + 'e6e2', + 'e6e3', + 'e6e4', + 'e6e5', + 'e6e7', + 'e6e8', + 'e6e9', + 'e6d4', + 'e6c5', + 'e6c7', + 'e6d8', + 'e6f8', + 'e6g7', + 'e6g5', + 'e6f4', + 'f6a6', + 'f6b6', + 'f6c6', + 'f6d6', + 'f6e6', + 'f6g6', + 'f6h6', + 'f6i6', + 'f6f0', + 'f6f1', + 'f6f2', + 'f6f3', + 'f6f4', + 'f6f5', + 'f6f7', + 'f6f8', + 'f6f9', + 'f6e4', + 'f6d5', + 'f6d7', + 'f6e8', + 'f6g8', + 'f6h7', + 'f6h5', + 'f6g4', + 'g6a6', + 'g6b6', + 'g6c6', + 'g6d6', + 'g6e6', + 'g6f6', + 'g6h6', + 'g6i6', + 'g6g0', + 'g6g1', + 'g6g2', + 'g6g3', + 'g6g4', + 'g6g5', + 'g6g7', + 'g6g8', + 'g6g9', + 'g6f4', + 'g6e5', + 'g6e7', + 'g6f8', + 'g6h8', + 'g6i7', + 'g6i5', + 'g6h4', + 'h6a6', + 'h6b6', + 'h6c6', + 'h6d6', + 'h6e6', + 'h6f6', + 'h6g6', + 'h6i6', + 'h6h0', + 'h6h1', + 'h6h2', + 'h6h3', + 'h6h4', + 'h6h5', + 'h6h7', + 'h6h8', + 'h6h9', + 'h6g4', + 'h6f5', + 'h6f7', + 'h6g8', + 'h6i8', + 'h6i4', + 'i6a6', + 'i6b6', + 'i6c6', + 'i6d6', + 'i6e6', + 'i6f6', + 'i6g6', + 'i6h6', + 'i6i0', + 'i6i1', + 'i6i2', + 'i6i3', + 'i6i4', + 'i6i5', + 'i6i7', + 'i6i8', + 'i6i9', + 'i6h4', + 'i6g5', + 'i6g7', + 'i6h8', + 'a7b7', + 'a7c7', + 'a7d7', + 'a7e7', + 'a7f7', + 'a7g7', + 'a7h7', + 'a7i7', + 'a7a0', + 'a7a1', + 'a7a2', + 'a7a3', + 'a7a4', + 'a7a5', + 'a7a6', + 'a7a8', + 'a7a9', + 'a7b9', + 'a7c8', + 'a7c6', + 'a7b5', + 'a7c9', + 'a7c5', + 'b7a7', + 'b7c7', + 'b7d7', + 'b7e7', + 'b7f7', + 'b7g7', + 'b7h7', + 'b7i7', + 'b7b0', + 'b7b1', + 'b7b2', + 'b7b3', + 'b7b4', + 'b7b5', + 'b7b6', + 'b7b8', + 'b7b9', + 'b7a5', + 'b7a9', + 'b7c9', + 'b7d8', + 'b7d6', + 'b7c5', + 'c7a7', + 'c7b7', + 'c7d7', + 'c7e7', + 'c7f7', + 'c7g7', + 'c7h7', + 'c7i7', + 'c7c0', + 'c7c1', + 'c7c2', + 'c7c3', + 'c7c4', + 'c7c5', + 'c7c6', + 'c7c8', + 'c7c9', + 'c7b5', + 'c7a6', + 'c7a8', + 'c7b9', + 'c7d9', + 'c7e8', + 'c7e6', + 'c7d5', + 'd7a7', + 'd7b7', + 'd7c7', + 'd7e7', + 'd7f7', + 'd7g7', + 'd7h7', + 'd7i7', + 'd7d0', + 'd7d1', + 'd7d2', + 'd7d3', + 'd7d4', + 'd7d5', + 'd7d6', + 'd7d8', + 'd7d9', + 'd7c5', + 'd7b6', + 'd7b8', + 'd7c9', + 'd7e9', + 'd7f8', + 'd7f6', + 'd7e5', + 'e7a7', + 'e7b7', + 'e7c7', + 'e7d7', + 'e7f7', + 'e7g7', + 'e7h7', + 'e7i7', + 'e7e0', + 'e7e1', + 'e7e2', + 'e7e3', + 'e7e4', + 'e7e5', + 'e7e6', + 'e7e8', + 'e7e9', + 'e7d5', + 'e7c6', + 'e7c8', + 'e7d9', + 'e7f9', + 'e7g8', + 'e7g6', + 'e7f5', + 'e7c9', + 'e7g9', + 'e7c5', + 'e7g5', + 'f7a7', + 'f7b7', + 'f7c7', + 'f7d7', + 'f7e7', + 'f7g7', + 'f7h7', + 'f7i7', + 'f7f0', + 'f7f1', + 'f7f2', + 'f7f3', + 'f7f4', + 'f7f5', + 'f7f6', + 'f7f8', + 'f7f9', + 'f7e5', + 'f7d6', + 'f7d8', + 'f7e9', + 'f7g9', + 'f7h8', + 'f7h6', + 'f7g5', + 'g7a7', + 'g7b7', + 'g7c7', + 'g7d7', + 'g7e7', + 'g7f7', + 'g7h7', + 'g7i7', + 'g7g0', + 'g7g1', + 'g7g2', + 'g7g3', + 'g7g4', + 'g7g5', + 'g7g6', + 'g7g8', + 'g7g9', + 'g7f5', + 'g7e6', + 'g7e8', + 'g7f9', + 'g7h9', + 'g7i8', + 'g7i6', + 'g7h5', + 'h7a7', + 'h7b7', + 'h7c7', + 'h7d7', + 'h7e7', + 'h7f7', + 'h7g7', + 'h7i7', + 'h7h0', + 'h7h1', + 'h7h2', + 'h7h3', + 'h7h4', + 'h7h5', + 'h7h6', + 'h7h8', + 'h7h9', + 'h7g5', + 'h7f6', + 'h7f8', + 'h7g9', + 'h7i9', + 'h7i5', + 'i7a7', + 'i7b7', + 'i7c7', + 'i7d7', + 'i7e7', + 'i7f7', + 'i7g7', + 'i7h7', + 'i7i0', + 'i7i1', + 'i7i2', + 'i7i3', + 'i7i4', + 'i7i5', + 'i7i6', + 'i7i8', + 'i7i9', + 'i7h5', + 'i7g6', + 'i7g8', + 'i7h9', + 'i7g9', + 'i7g5', + 'a8b8', + 'a8c8', + 'a8d8', + 'a8e8', + 'a8f8', + 'a8g8', + 'a8h8', + 'a8i8', + 'a8a0', + 'a8a1', + 'a8a2', + 'a8a3', + 'a8a4', + 'a8a5', + 'a8a6', + 'a8a7', + 'a8a9', + 'a8c9', + 'a8c7', + 'a8b6', + 'b8a8', + 'b8c8', + 'b8d8', + 'b8e8', + 'b8f8', + 'b8g8', + 'b8h8', + 'b8i8', + 'b8b0', + 'b8b1', + 'b8b2', + 'b8b3', + 'b8b4', + 'b8b5', + 'b8b6', + 'b8b7', + 'b8b9', + 'b8a6', + 'b8d9', + 'b8d7', + 'b8c6', + 'c8a8', + 'c8b8', + 'c8d8', + 'c8e8', + 'c8f8', + 'c8g8', + 'c8h8', + 'c8i8', + 'c8c0', + 'c8c1', + 'c8c2', + 'c8c3', + 'c8c4', + 'c8c5', + 'c8c6', + 'c8c7', + 'c8c9', + 'c8b6', + 'c8a7', + 'c8a9', + 'c8e9', + 'c8e7', + 'c8d6', + 'd8a8', + 'd8b8', + 'd8c8', + 'd8e8', + 'd8f8', + 'd8g8', + 'd8h8', + 'd8i8', + 'd8d0', + 'd8d1', + 'd8d2', + 'd8d3', + 'd8d4', + 'd8d5', + 'd8d6', + 'd8d7', + 'd8d9', + 'd8c6', + 'd8b7', + 'd8b9', + 'd8f9', + 'd8f7', + 'd8e6', + 'e8a8', + 'e8b8', + 'e8c8', + 'e8d8', + 'e8f8', + 'e8g8', + 'e8h8', + 'e8i8', + 'e8e0', + 'e8e1', + 'e8e2', + 'e8e3', + 'e8e4', + 'e8e5', + 'e8e6', + 'e8e7', + 'e8e9', + 'e8d6', + 'e8c7', + 'e8c9', + 'e8g9', + 'e8g7', + 'e8f6', + 'e8d7', + 'e8d9', + 'e8f9', + 'e8f7', + 'f8a8', + 'f8b8', + 'f8c8', + 'f8d8', + 'f8e8', + 'f8g8', + 'f8h8', + 'f8i8', + 'f8f0', + 'f8f1', + 'f8f2', + 'f8f3', + 'f8f4', + 'f8f5', + 'f8f6', + 'f8f7', + 'f8f9', + 'f8e6', + 'f8d7', + 'f8d9', + 'f8h9', + 'f8h7', + 'f8g6', + 'g8a8', + 'g8b8', + 'g8c8', + 'g8d8', + 'g8e8', + 'g8f8', + 'g8h8', + 'g8i8', + 'g8g0', + 'g8g1', + 'g8g2', + 'g8g3', + 'g8g4', + 'g8g5', + 'g8g6', + 'g8g7', + 'g8g9', + 'g8f6', + 'g8e7', + 'g8e9', + 'g8i9', + 'g8i7', + 'g8h6', + 'h8a8', + 'h8b8', + 'h8c8', + 'h8d8', + 'h8e8', + 'h8f8', + 'h8g8', + 'h8i8', + 'h8h0', + 'h8h1', + 'h8h2', + 'h8h3', + 'h8h4', + 'h8h5', + 'h8h6', + 'h8h7', + 'h8h9', + 'h8g6', + 'h8f7', + 'h8f9', + 'h8i6', + 'i8a8', + 'i8b8', + 'i8c8', + 'i8d8', + 'i8e8', + 'i8f8', + 'i8g8', + 'i8h8', + 'i8i0', + 'i8i1', + 'i8i2', + 'i8i3', + 'i8i4', + 'i8i5', + 'i8i6', + 'i8i7', + 'i8i9', + 'i8h6', + 'i8g7', + 'i8g9', + 'a9b9', + 'a9c9', + 'a9d9', + 'a9e9', + 'a9f9', + 'a9g9', + 'a9h9', + 'a9i9', + 'a9a0', + 'a9a1', + 'a9a2', + 'a9a3', + 'a9a4', + 'a9a5', + 'a9a6', + 'a9a7', + 'a9a8', + 'a9c8', + 'a9b7', + 'b9a9', + 'b9c9', + 'b9d9', + 'b9e9', + 'b9f9', + 'b9g9', + 'b9h9', + 'b9i9', + 'b9b0', + 'b9b1', + 'b9b2', + 'b9b3', + 'b9b4', + 'b9b5', + 'b9b6', + 'b9b7', + 'b9b8', + 'b9a7', + 'b9d8', + 'b9c7', + 'c9a9', + 'c9b9', + 'c9d9', + 'c9e9', + 'c9f9', + 'c9g9', + 'c9h9', + 'c9i9', + 'c9c0', + 'c9c1', + 'c9c2', + 'c9c3', + 'c9c4', + 'c9c5', + 'c9c6', + 'c9c7', + 'c9c8', + 'c9b7', + 'c9a8', + 'c9e8', + 'c9d7', + 'c9a7', + 'c9e7', + 'd9a9', + 'd9b9', + 'd9c9', + 'd9e9', + 'd9f9', + 'd9g9', + 'd9h9', + 'd9i9', + 'd9d0', + 'd9d1', + 'd9d2', + 'd9d3', + 'd9d4', + 'd9d5', + 'd9d6', + 'd9d7', + 'd9d8', + 'd9c7', + 'd9b8', + 'd9f8', + 'd9e7', + 'e9a9', + 'e9b9', + 'e9c9', + 'e9d9', + 'e9f9', + 'e9g9', + 'e9h9', + 'e9i9', + 'e9e0', + 'e9e1', + 'e9e2', + 'e9e3', + 'e9e4', + 'e9e5', + 'e9e6', + 'e9e7', + 'e9e8', + 'e9d7', + 'e9c8', + 'e9g8', + 'e9f7', + 'f9a9', + 'f9b9', + 'f9c9', + 'f9d9', + 'f9e9', + 'f9g9', + 'f9h9', + 'f9i9', + 'f9f0', + 'f9f1', + 'f9f2', + 'f9f3', + 'f9f4', + 'f9f5', + 'f9f6', + 'f9f7', + 'f9f8', + 'f9e7', + 'f9d8', + 'f9h8', + 'f9g7', + 'g9a9', + 'g9b9', + 'g9c9', + 'g9d9', + 'g9e9', + 'g9f9', + 'g9h9', + 'g9i9', + 'g9g0', + 'g9g1', + 'g9g2', + 'g9g3', + 'g9g4', + 'g9g5', + 'g9g6', + 'g9g7', + 'g9g8', + 'g9f7', + 'g9e8', + 'g9i8', + 'g9h7', + 'g9e7', + 'g9i7', + 'h9a9', + 'h9b9', + 'h9c9', + 'h9d9', + 'h9e9', + 'h9f9', + 'h9g9', + 'h9i9', + 'h9h0', + 'h9h1', + 'h9h2', + 'h9h3', + 'h9h4', + 'h9h5', + 'h9h6', + 'h9h7', + 'h9h8', + 'h9g7', + 'h9f8', + 'h9i7', + 'i9a9', + 'i9b9', + 'i9c9', + 'i9d9', + 'i9e9', + 'i9f9', + 'i9g9', + 'i9h9', + 'i9i0', + 'i9i1', + 'i9i2', + 'i9i3', + 'i9i4', + 'i9i5', + 'i9i6', + 'i9i7', + 'i9i8', + 'i9h7', + 'i9g8', + 'd0e1', + 'f0e1', + 'd2e1', + 'f2e1', + 'd9e8', + 'f9e8', + 'd7e8', + 'f7e8', +] + # legal moves total: NB_LABELS_LICHESS = 2316 # remove dropping moves for chess variant @@ -4762,6 +6886,9 @@ # Includes king promotion moves were added to support antichess LABELS = LABELS_LICHESS NB_LABELS = NB_LABELS_LICHESS +elif MODE == MODE_XIANGQI: + LABELS = LABELS_XIANGQI + NB_LABELS = len(LABELS_XIANGQI) else: # MODE = MODE_CHESS # same as for crazyhouse but without dropping moves LABELS = LABELS_CZ[:NB_LABELS_CHESS] @@ -4780,29 +6907,33 @@ def mirror_move(move: chess.Move): return chess.Move(from_square, to_square, move.promotion, move.drop) -# flip the labels for BLACK -LABELS_MIRRORED = [None] * NB_LABELS +LABELS_MIRRORED = None +MV_LOOKUP = None +MV_LOOKUP_MIRRORED=None +if MODE != MODE_XIANGQI: + # flip the labels for BLACK + LABELS_MIRRORED = [None] * NB_LABELS -for i, label in enumerate(LABELS): - mv = chess.Move.from_uci(label) - mv_mirrored = mirror_move(mv) - LABELS_MIRRORED[i] = mv_mirrored.uci() + for i, label in enumerate(LABELS): + mv = chess.Move.from_uci(label) + mv_mirrored = mirror_move(mv) + LABELS_MIRRORED[i] = mv_mirrored.uci() -# The movement lookup table is a dictionary/hash-map which maps the string to the corresponding label index -MV_LOOKUP = {} + # The movement lookup table is a dictionary/hash-map which maps the string to the corresponding label index + MV_LOOKUP = {} -# iterate over all moves and assign the integer move index to the string -for i, label in enumerate(LABELS): - MV_LOOKUP[label] = i + # iterate over all moves and assign the integer move index to the string + for i, label in enumerate(LABELS): + MV_LOOKUP[label] = i -# do the same for the black player -MV_LOOKUP_MIRRORED = {} + # do the same for the black player + MV_LOOKUP_MIRRORED = {} -# iterate over all moves and assign the integer move index to the string -for i, label in enumerate(LABELS_MIRRORED): - MV_LOOKUP_MIRRORED[label] = i + # iterate over all moves and assign the integer move index to the string + for i, label in enumerate(LABELS_MIRRORED): + MV_LOOKUP_MIRRORED[label] = i if __name__ == "__main__": print(LABELS[-20:]) diff --git a/DeepCrazyhouse/src/domain/variants/plane_policy_representation.py b/DeepCrazyhouse/src/domain/variants/plane_policy_representation.py index 3c482359..ca0c6e6e 100644 --- a/DeepCrazyhouse/src/domain/variants/plane_policy_representation.py +++ b/DeepCrazyhouse/src/domain/variants/plane_policy_representation.py @@ -14,7 +14,7 @@ import numpy as np import chess from DeepCrazyhouse.src.domain.variants.constants import BOARD_WIDTH, BOARD_HEIGHT, LABELS, P_MAP,\ - NB_POLICY_MAP_CHANNELS, NB_LABELS_CHESS, MODE, MODE_CRAZYHOUSE, MODE_LICHESS + NB_POLICY_MAP_CHANNELS, NB_LABELS_CHESS, MODE, MODE_CRAZYHOUSE, MODE_LICHESS, MODE_XIANGQI from DeepCrazyhouse.src.domain.util import get_row_col @@ -4830,6 +4830,2094 @@ def get_move_planes(move): 5183, ] +FLAT_PLANE_IDX_XIANGQI = [ + 810, + 900, + 990, + 1080, + 1170, + 1260, + 1350, + 1440, + 0, + 90, + 180, + 270, + 360, + 450, + 540, + 630, + 720, + 3780, + 3870, + 2341, + 811, + 901, + 991, + 1081, + 1171, + 1261, + 1351, + 1, + 91, + 181, + 271, + 361, + 451, + 541, + 631, + 721, + 4411, + 3781, + 3871, + 2432, + 2342, + 812, + 902, + 992, + 1082, + 1172, + 1262, + 2, + 92, + 182, + 272, + 362, + 452, + 542, + 632, + 722, + 4322, + 4412, + 3782, + 3872, + 3692, + 3152, + 2523, + 2433, + 2343, + 813, + 903, + 993, + 1083, + 1173, + 3, + 93, + 183, + 273, + 363, + 453, + 543, + 633, + 723, + 4323, + 4413, + 3783, + 3873, + 2614, + 2524, + 2434, + 2344, + 814, + 904, + 994, + 1084, + 4, + 94, + 184, + 274, + 364, + 454, + 544, + 634, + 724, + 4324, + 4414, + 3784, + 3874, + 2705, + 2615, + 2525, + 2435, + 2345, + 815, + 905, + 995, + 5, + 95, + 185, + 275, + 365, + 455, + 545, + 635, + 725, + 4325, + 4415, + 3785, + 3875, + 2796, + 2706, + 2616, + 2526, + 2436, + 2346, + 816, + 906, + 6, + 96, + 186, + 276, + 366, + 456, + 546, + 636, + 726, + 4326, + 4416, + 3786, + 3876, + 3696, + 3156, + 2887, + 2797, + 2707, + 2617, + 2527, + 2437, + 2347, + 817, + 7, + 97, + 187, + 277, + 367, + 457, + 547, + 637, + 727, + 4327, + 4417, + 3787, + 2978, + 2888, + 2798, + 2708, + 2618, + 2528, + 2438, + 2348, + 8, + 98, + 188, + 278, + 368, + 458, + 548, + 638, + 728, + 4328, + 4418, + 819, + 909, + 999, + 1089, + 1179, + 1269, + 1359, + 1449, + 1539, + 9, + 99, + 189, + 279, + 369, + 459, + 549, + 639, + 3789, + 3879, + 3969, + 2350, + 820, + 910, + 1000, + 1090, + 1180, + 1270, + 1360, + 1540, + 10, + 100, + 190, + 280, + 370, + 460, + 550, + 640, + 4420, + 3790, + 3880, + 3970, + 2441, + 2351, + 821, + 911, + 1001, + 1091, + 1181, + 1271, + 1541, + 11, + 101, + 191, + 281, + 371, + 461, + 551, + 641, + 4241, + 4331, + 4421, + 3791, + 3881, + 3971, + 2532, + 2442, + 2352, + 822, + 912, + 1002, + 1092, + 1182, + 1542, + 12, + 102, + 192, + 282, + 372, + 462, + 552, + 642, + 4242, + 4332, + 4422, + 3792, + 3882, + 3972, + 2623, + 2533, + 2443, + 2353, + 823, + 913, + 1003, + 1093, + 1543, + 13, + 103, + 193, + 283, + 373, + 463, + 553, + 643, + 4243, + 4333, + 4423, + 3793, + 3883, + 3973, + 3433, + 3613, + 3073, + 3253, + 2714, + 2624, + 2534, + 2444, + 2354, + 824, + 914, + 1004, + 1544, + 14, + 104, + 194, + 284, + 374, + 464, + 554, + 644, + 4244, + 4334, + 4424, + 3794, + 3884, + 3974, + 2805, + 2715, + 2625, + 2535, + 2445, + 2355, + 825, + 915, + 1545, + 15, + 105, + 195, + 285, + 375, + 465, + 555, + 645, + 4245, + 4335, + 4425, + 3795, + 3885, + 3975, + 2896, + 2806, + 2716, + 2626, + 2536, + 2446, + 2356, + 826, + 1546, + 16, + 106, + 196, + 286, + 376, + 466, + 556, + 646, + 4246, + 4336, + 4426, + 3796, + 2987, + 2897, + 2807, + 2717, + 2627, + 2537, + 2447, + 2357, + 1547, + 17, + 107, + 197, + 287, + 377, + 467, + 557, + 647, + 4247, + 4337, + 4427, + 828, + 918, + 1008, + 1098, + 1188, + 1278, + 1368, + 1458, + 1638, + 1548, + 18, + 108, + 198, + 288, + 378, + 468, + 558, + 3798, + 3888, + 3978, + 4068, + 3168, + 3348, + 2359, + 829, + 919, + 1009, + 1099, + 1189, + 1279, + 1369, + 1639, + 1549, + 19, + 109, + 199, + 289, + 379, + 469, + 559, + 4159, + 4429, + 3799, + 3889, + 3979, + 4069, + 2450, + 2360, + 830, + 920, + 1010, + 1100, + 1190, + 1280, + 1640, + 1550, + 20, + 110, + 200, + 290, + 380, + 470, + 560, + 4160, + 4250, + 4340, + 4430, + 3800, + 3890, + 3980, + 4070, + 2541, + 2451, + 2361, + 831, + 921, + 1011, + 1101, + 1191, + 1641, + 1551, + 21, + 111, + 201, + 291, + 381, + 471, + 561, + 4161, + 4251, + 4341, + 4431, + 3801, + 3891, + 3981, + 4071, + 2632, + 2542, + 2452, + 2362, + 832, + 922, + 1012, + 1102, + 1642, + 1552, + 22, + 112, + 202, + 292, + 382, + 472, + 562, + 4162, + 4252, + 4342, + 4432, + 3802, + 3892, + 3982, + 4072, + 3712, + 3172, + 3532, + 3352, + 2723, + 2633, + 2543, + 2453, + 2363, + 833, + 923, + 1013, + 1643, + 1553, + 23, + 113, + 203, + 293, + 383, + 473, + 563, + 4163, + 4253, + 4343, + 4433, + 3803, + 3893, + 3983, + 4073, + 2814, + 2724, + 2634, + 2544, + 2454, + 2364, + 834, + 924, + 1644, + 1554, + 24, + 114, + 204, + 294, + 384, + 474, + 564, + 4164, + 4254, + 4344, + 4434, + 3804, + 3894, + 3984, + 4074, + 2905, + 2815, + 2725, + 2635, + 2545, + 2455, + 2365, + 835, + 1645, + 1555, + 25, + 115, + 205, + 295, + 385, + 475, + 565, + 4165, + 4255, + 4345, + 4435, + 3805, + 4075, + 2996, + 2906, + 2816, + 2726, + 2636, + 2546, + 2456, + 2366, + 1646, + 1556, + 26, + 116, + 206, + 296, + 386, + 476, + 566, + 4166, + 4256, + 4346, + 4436, + 3716, + 3536, + 837, + 927, + 1017, + 1107, + 1197, + 1287, + 1377, + 1467, + 1737, + 1647, + 1557, + 27, + 117, + 207, + 297, + 387, + 477, + 3807, + 3897, + 3987, + 4077, + 2368, + 838, + 928, + 1018, + 1108, + 1198, + 1288, + 1378, + 1738, + 1648, + 1558, + 28, + 118, + 208, + 298, + 388, + 478, + 4168, + 4438, + 3808, + 3898, + 3988, + 4078, + 2459, + 2369, + 839, + 929, + 1019, + 1109, + 1199, + 1289, + 1739, + 1649, + 1559, + 29, + 119, + 209, + 299, + 389, + 479, + 4169, + 4259, + 4349, + 4439, + 3809, + 3899, + 3989, + 4079, + 2550, + 2460, + 2370, + 840, + 930, + 1020, + 1110, + 1200, + 1740, + 1650, + 1560, + 30, + 120, + 210, + 300, + 390, + 480, + 4170, + 4260, + 4350, + 4440, + 3810, + 3900, + 3990, + 4080, + 2641, + 2551, + 2461, + 2371, + 841, + 931, + 1021, + 1111, + 1741, + 1651, + 1561, + 31, + 121, + 211, + 301, + 391, + 481, + 4171, + 4261, + 4351, + 4441, + 3811, + 3901, + 3991, + 4081, + 2732, + 2642, + 2552, + 2462, + 2372, + 842, + 932, + 1022, + 1742, + 1652, + 1562, + 32, + 122, + 212, + 302, + 392, + 482, + 4172, + 4262, + 4352, + 4442, + 3812, + 3902, + 3992, + 4082, + 2823, + 2733, + 2643, + 2553, + 2463, + 2373, + 843, + 933, + 1743, + 1653, + 1563, + 33, + 123, + 213, + 303, + 393, + 483, + 4173, + 4263, + 4353, + 4443, + 3813, + 3903, + 3993, + 4083, + 2914, + 2824, + 2734, + 2644, + 2554, + 2464, + 2374, + 844, + 1744, + 1654, + 1564, + 34, + 124, + 214, + 304, + 394, + 484, + 4174, + 4264, + 4354, + 4444, + 3814, + 4084, + 3005, + 2915, + 2825, + 2735, + 2645, + 2555, + 2465, + 2375, + 1745, + 1655, + 1565, + 35, + 125, + 215, + 305, + 395, + 485, + 4175, + 4265, + 4355, + 4445, + 846, + 936, + 1026, + 1116, + 1206, + 1296, + 1386, + 1476, + 1836, + 1746, + 1656, + 1566, + 36, + 126, + 216, + 306, + 396, + 3816, + 3906, + 3996, + 4086, + 2377, + 847, + 937, + 1027, + 1117, + 1207, + 1297, + 1387, + 1837, + 1747, + 1657, + 1567, + 37, + 127, + 217, + 307, + 397, + 4177, + 4447, + 3817, + 3907, + 3997, + 4087, + 2468, + 2378, + 848, + 938, + 1028, + 1118, + 1208, + 1298, + 1838, + 1748, + 1658, + 1568, + 38, + 128, + 218, + 308, + 398, + 4178, + 4268, + 4358, + 4448, + 3818, + 3908, + 3998, + 4088, + 3548, + 3368, + 2559, + 2469, + 2379, + 849, + 939, + 1029, + 1119, + 1209, + 1839, + 1749, + 1659, + 1569, + 39, + 129, + 219, + 309, + 399, + 4179, + 4269, + 4359, + 4449, + 3819, + 3909, + 3999, + 4089, + 2650, + 2560, + 2470, + 2380, + 850, + 940, + 1030, + 1120, + 1840, + 1750, + 1660, + 1570, + 40, + 130, + 220, + 310, + 400, + 4180, + 4270, + 4360, + 4450, + 3820, + 3910, + 4000, + 4090, + 2741, + 2651, + 2561, + 2471, + 2381, + 851, + 941, + 1031, + 1841, + 1751, + 1661, + 1571, + 41, + 131, + 221, + 311, + 401, + 4181, + 4271, + 4361, + 4451, + 3821, + 3911, + 4001, + 4091, + 2832, + 2742, + 2652, + 2562, + 2472, + 2382, + 852, + 942, + 1842, + 1752, + 1662, + 1572, + 42, + 132, + 222, + 312, + 402, + 4182, + 4272, + 4362, + 4452, + 3822, + 3912, + 4002, + 4092, + 3552, + 3372, + 2923, + 2833, + 2743, + 2653, + 2563, + 2473, + 2383, + 853, + 1843, + 1753, + 1663, + 1573, + 43, + 133, + 223, + 313, + 403, + 4183, + 4273, + 4363, + 4453, + 3823, + 4093, + 3014, + 2924, + 2834, + 2744, + 2654, + 2564, + 2474, + 2384, + 1844, + 1754, + 1664, + 1574, + 44, + 134, + 224, + 314, + 404, + 4184, + 4274, + 4364, + 4454, + 855, + 945, + 1035, + 1125, + 1215, + 1305, + 1395, + 1485, + 1935, + 1845, + 1755, + 1665, + 1575, + 45, + 135, + 225, + 315, + 3825, + 3915, + 4005, + 4095, + 2386, + 856, + 946, + 1036, + 1126, + 1216, + 1306, + 1396, + 1936, + 1846, + 1756, + 1666, + 1576, + 46, + 136, + 226, + 316, + 4186, + 4456, + 3826, + 3916, + 4006, + 4096, + 2477, + 2387, + 857, + 947, + 1037, + 1127, + 1217, + 1307, + 1937, + 1847, + 1757, + 1667, + 1577, + 47, + 137, + 227, + 317, + 4187, + 4277, + 4367, + 4457, + 3827, + 3917, + 4007, + 4097, + 3737, + 3197, + 2568, + 2478, + 2388, + 858, + 948, + 1038, + 1128, + 1218, + 1938, + 1848, + 1758, + 1668, + 1578, + 48, + 138, + 228, + 318, + 4188, + 4278, + 4368, + 4458, + 3828, + 3918, + 4008, + 4098, + 2659, + 2569, + 2479, + 2389, + 859, + 949, + 1039, + 1129, + 1939, + 1849, + 1759, + 1669, + 1579, + 49, + 139, + 229, + 319, + 4189, + 4279, + 4369, + 4459, + 3829, + 3919, + 4009, + 4099, + 2750, + 2660, + 2570, + 2480, + 2390, + 860, + 950, + 1040, + 1940, + 1850, + 1760, + 1670, + 1580, + 50, + 140, + 230, + 320, + 4190, + 4280, + 4370, + 4460, + 3830, + 3920, + 4010, + 4100, + 2841, + 2751, + 2661, + 2571, + 2481, + 2391, + 861, + 951, + 1941, + 1851, + 1761, + 1671, + 1581, + 51, + 141, + 231, + 321, + 4191, + 4281, + 4371, + 4461, + 3831, + 3921, + 4011, + 4101, + 3741, + 3201, + 2932, + 2842, + 2752, + 2662, + 2572, + 2482, + 2392, + 862, + 1942, + 1852, + 1762, + 1672, + 1582, + 52, + 142, + 232, + 322, + 4192, + 4282, + 4372, + 4462, + 3832, + 4102, + 3023, + 2933, + 2843, + 2753, + 2663, + 2573, + 2483, + 2393, + 1943, + 1853, + 1763, + 1673, + 1583, + 53, + 143, + 233, + 323, + 4193, + 4283, + 4373, + 4463, + 864, + 954, + 1044, + 1134, + 1224, + 1314, + 1404, + 1494, + 2034, + 1944, + 1854, + 1764, + 1674, + 1584, + 54, + 144, + 234, + 3834, + 3924, + 4014, + 4104, + 2395, + 865, + 955, + 1045, + 1135, + 1225, + 1315, + 1405, + 2035, + 1945, + 1855, + 1765, + 1675, + 1585, + 55, + 145, + 235, + 4195, + 4465, + 3835, + 3925, + 4015, + 4105, + 2486, + 2396, + 866, + 956, + 1046, + 1136, + 1226, + 1316, + 2036, + 1946, + 1856, + 1766, + 1676, + 1586, + 56, + 146, + 236, + 4196, + 4286, + 4376, + 4466, + 3836, + 3926, + 4016, + 4106, + 2577, + 2487, + 2397, + 867, + 957, + 1047, + 1137, + 1227, + 2037, + 1947, + 1857, + 1767, + 1677, + 1587, + 57, + 147, + 237, + 4197, + 4287, + 4377, + 4467, + 3837, + 3927, + 4017, + 4107, + 2668, + 2578, + 2488, + 2398, + 868, + 958, + 1048, + 1138, + 2038, + 1948, + 1858, + 1768, + 1678, + 1588, + 58, + 148, + 238, + 4198, + 4288, + 4378, + 4468, + 3838, + 3928, + 4018, + 4108, + 2759, + 2669, + 2579, + 2489, + 2399, + 869, + 959, + 1049, + 2039, + 1949, + 1859, + 1769, + 1679, + 1589, + 59, + 149, + 239, + 4199, + 4289, + 4379, + 4469, + 3839, + 3929, + 4019, + 4109, + 2850, + 2760, + 2670, + 2580, + 2490, + 2400, + 870, + 960, + 2040, + 1950, + 1860, + 1770, + 1680, + 1590, + 60, + 150, + 240, + 4200, + 4290, + 4380, + 4470, + 3840, + 3930, + 4020, + 4110, + 2941, + 2851, + 2761, + 2671, + 2581, + 2491, + 2401, + 871, + 2041, + 1951, + 1861, + 1771, + 1681, + 1591, + 61, + 151, + 241, + 4201, + 4291, + 4381, + 4471, + 3841, + 4111, + 3032, + 2942, + 2852, + 2762, + 2672, + 2582, + 2492, + 2402, + 2042, + 1952, + 1862, + 1772, + 1682, + 1592, + 62, + 152, + 242, + 4202, + 4292, + 4382, + 4472, + 873, + 963, + 1053, + 1143, + 1233, + 1323, + 1413, + 1503, + 2133, + 2043, + 1953, + 1863, + 1773, + 1683, + 1593, + 63, + 153, + 3843, + 3933, + 4023, + 4113, + 3213, + 3393, + 2404, + 874, + 964, + 1054, + 1144, + 1234, + 1324, + 1414, + 2134, + 2044, + 1954, + 1864, + 1774, + 1684, + 1594, + 64, + 154, + 4204, + 4474, + 3844, + 3934, + 4024, + 4114, + 2495, + 2405, + 875, + 965, + 1055, + 1145, + 1235, + 1325, + 2135, + 2045, + 1955, + 1865, + 1775, + 1685, + 1595, + 65, + 155, + 4205, + 4295, + 4385, + 4475, + 3845, + 3935, + 4025, + 4115, + 2586, + 2496, + 2406, + 876, + 966, + 1056, + 1146, + 1236, + 2136, + 2046, + 1956, + 1866, + 1776, + 1686, + 1596, + 66, + 156, + 4206, + 4296, + 4386, + 4476, + 3846, + 3936, + 4026, + 4116, + 2677, + 2587, + 2497, + 2407, + 877, + 967, + 1057, + 1147, + 2137, + 2047, + 1957, + 1867, + 1777, + 1687, + 1597, + 67, + 157, + 4207, + 4297, + 4387, + 4477, + 3847, + 3937, + 4027, + 4117, + 3757, + 3217, + 3577, + 3397, + 2768, + 2678, + 2588, + 2498, + 2408, + 878, + 968, + 1058, + 2138, + 2048, + 1958, + 1868, + 1778, + 1688, + 1598, + 68, + 158, + 4208, + 4298, + 4388, + 4478, + 3848, + 3938, + 4028, + 4118, + 2859, + 2769, + 2679, + 2589, + 2499, + 2409, + 879, + 969, + 2139, + 2049, + 1959, + 1869, + 1779, + 1689, + 1599, + 69, + 159, + 4209, + 4299, + 4389, + 4479, + 3849, + 3939, + 4029, + 4119, + 2950, + 2860, + 2770, + 2680, + 2590, + 2500, + 2410, + 880, + 2140, + 2050, + 1960, + 1870, + 1780, + 1690, + 1600, + 70, + 160, + 4210, + 4300, + 4390, + 4480, + 3850, + 4120, + 3041, + 2951, + 2861, + 2771, + 2681, + 2591, + 2501, + 2411, + 2141, + 2051, + 1961, + 1871, + 1781, + 1691, + 1601, + 71, + 161, + 4211, + 4301, + 4391, + 4481, + 3761, + 3581, + 882, + 972, + 1062, + 1152, + 1242, + 1332, + 1422, + 1512, + 2232, + 2142, + 2052, + 1962, + 1872, + 1782, + 1692, + 1602, + 72, + 3942, + 4032, + 4122, + 2413, + 883, + 973, + 1063, + 1153, + 1243, + 1333, + 1423, + 2233, + 2143, + 2053, + 1963, + 1873, + 1783, + 1693, + 1603, + 73, + 4213, + 3943, + 4033, + 4123, + 2504, + 2414, + 884, + 974, + 1064, + 1154, + 1244, + 1334, + 2234, + 2144, + 2054, + 1964, + 1874, + 1784, + 1694, + 1604, + 74, + 4214, + 4304, + 4394, + 3944, + 4034, + 4124, + 2595, + 2505, + 2415, + 885, + 975, + 1065, + 1155, + 1245, + 2235, + 2145, + 2055, + 1965, + 1875, + 1785, + 1695, + 1605, + 75, + 4215, + 4305, + 4395, + 3945, + 4035, + 4125, + 2686, + 2596, + 2506, + 2416, + 886, + 976, + 1066, + 1156, + 2236, + 2146, + 2056, + 1966, + 1876, + 1786, + 1696, + 1606, + 76, + 4216, + 4306, + 4396, + 3946, + 4036, + 4126, + 3496, + 3676, + 3136, + 3316, + 2777, + 2687, + 2597, + 2507, + 2417, + 887, + 977, + 1067, + 2237, + 2147, + 2057, + 1967, + 1877, + 1787, + 1697, + 1607, + 77, + 4217, + 4307, + 4397, + 3947, + 4037, + 4127, + 2868, + 2778, + 2688, + 2598, + 2508, + 2418, + 888, + 978, + 2238, + 2148, + 2058, + 1968, + 1878, + 1788, + 1698, + 1608, + 78, + 4218, + 4308, + 4398, + 3948, + 4038, + 4128, + 2959, + 2869, + 2779, + 2689, + 2599, + 2509, + 2419, + 889, + 2239, + 2149, + 2059, + 1969, + 1879, + 1789, + 1699, + 1609, + 79, + 4219, + 4309, + 4399, + 4129, + 3050, + 2960, + 2870, + 2780, + 2690, + 2600, + 2510, + 2420, + 2240, + 2150, + 2060, + 1970, + 1880, + 1790, + 1700, + 1610, + 80, + 4220, + 4310, + 4400, + 891, + 981, + 1071, + 1161, + 1251, + 1341, + 1431, + 1521, + 2331, + 2241, + 2151, + 2061, + 1971, + 1881, + 1791, + 1701, + 1611, + 4041, + 4131, + 2422, + 892, + 982, + 1072, + 1162, + 1252, + 1342, + 1432, + 2332, + 2242, + 2152, + 2062, + 1972, + 1882, + 1792, + 1702, + 1612, + 4222, + 4042, + 4132, + 2513, + 2423, + 893, + 983, + 1073, + 1163, + 1253, + 1343, + 2333, + 2243, + 2153, + 2063, + 1973, + 1883, + 1793, + 1703, + 1613, + 4223, + 4313, + 4043, + 4133, + 3593, + 3413, + 2604, + 2514, + 2424, + 894, + 984, + 1074, + 1164, + 1254, + 2334, + 2244, + 2154, + 2064, + 1974, + 1884, + 1794, + 1704, + 1614, + 4224, + 4314, + 4044, + 4134, + 2695, + 2605, + 2515, + 2425, + 895, + 985, + 1075, + 1165, + 2335, + 2245, + 2155, + 2065, + 1975, + 1885, + 1795, + 1705, + 1615, + 4225, + 4315, + 4045, + 4135, + 2786, + 2696, + 2606, + 2516, + 2426, + 896, + 986, + 1076, + 2336, + 2246, + 2156, + 2066, + 1976, + 1886, + 1796, + 1706, + 1616, + 4226, + 4316, + 4046, + 4136, + 2877, + 2787, + 2697, + 2607, + 2517, + 2427, + 897, + 987, + 2337, + 2247, + 2157, + 2067, + 1977, + 1887, + 1797, + 1707, + 1617, + 4227, + 4317, + 4047, + 4137, + 3597, + 3417, + 2968, + 2878, + 2788, + 2698, + 2608, + 2518, + 2428, + 898, + 2338, + 2248, + 2158, + 2068, + 1978, + 1888, + 1798, + 1708, + 1618, + 4228, + 4318, + 4138, + 3059, + 2969, + 2879, + 2789, + 2699, + 2609, + 2519, + 2429, + 2339, + 2249, + 2159, + 2069, + 1979, + 1889, + 1799, + 1709, + 1619, + 4229, + 4319, + 3063, + 3605, + 3261, + 3443, + 3324, + 3506, + 3126, + 3668, +] FLAT_PLANE_IDX = None # in policy version 2, the king promotion moves were added to support antichess, this deprecates older nets @@ -4837,6 +6925,8 @@ def get_move_planes(move): FLAT_PLANE_IDX = FLAT_PLANE_IDX_CRAZYHOUSE elif MODE == MODE_LICHESS: FLAT_PLANE_IDX = FLAT_PLANE_IDX_LICHESS +elif MODE == MODE_XIANGQI: + FLAT_PLANE_IDX = FLAT_PLANE_IDX_XIANGQI else: # MODE = MODE_CHESS # for chess the same as for crazyhouse is used without dropping moves FLAT_PLANE_IDX = FLAT_PLANE_IDX_CRAZYHOUSE[:NB_LABELS_CHESS] diff --git a/DeepCrazyhouse/src/preprocessing/dataset_loader.py b/DeepCrazyhouse/src/preprocessing/dataset_loader.py index 016c72da..10fbdd67 100644 --- a/DeepCrazyhouse/src/preprocessing/dataset_loader.py +++ b/DeepCrazyhouse/src/preprocessing/dataset_loader.py @@ -11,7 +11,7 @@ import numpy as np import zarr from DeepCrazyhouse.configs.main_config import main_config -from DeepCrazyhouse.src.domain.util import get_numpy_arrays, MATRIX_NORMALIZER +from DeepCrazyhouse.src.domain.util import get_numpy_arrays, get_x_y_and_indices, MATRIX_NORMALIZER def _load_dataset_file(dataset_filepath): @@ -105,3 +105,66 @@ def load_pgn_dataset( # apply rescaling using a predefined scaling constant (this makes use of vectorized operations) x *= MATRIX_NORMALIZER return start_indices, x, y_value, y_policy, plys_to_end, pgn_dataset + + +def load_xiangqi_dataset(dataset_type="train", part_id=0, verbose=True, normalize=False): + """ + Loads one part of the preprocessed data set of xiangqi games, originally given as csv. + + :parram dataset_type: either ['train', 'test', 'val'] + :param part_id: Decides which part of the data set will be loaded + :param verbose: True if the log message shall be shown + :param normalize: True if the inputs shall be normalized to 0-1 + :return: numpy-arrays: + start_indices - defines the index where each game starts + x - the board representation for all games + y_value - the game outcome (-1,0,1) for each board position + y_policy - the movement policy for the next_move played + dataset - the dataset file handle (you can use .tree() to show the file structure) + """ + if dataset_type == "train": + zarr_filepaths = glob.glob(main_config["planes_train_dir"] + "**/*.zip") + elif dataset_type == "val": + zarr_filepaths = glob.glob(main_config["planes_val_dir"] + "**/*.zip") + elif dataset_type == "test": + zarr_filepaths = glob.glob(main_config["planes_test_dir"] + "**/*.zip") + else: + raise Exception( + 'Invalid dataset type "%s" given. It must be either "train", "val" or "test"' % dataset_type + ) + + if part_id >= len(zarr_filepaths): + raise Exception("There aren't enough parts available (%d parts) in the given directory for part_id=%d" + % (len(zarr_filepaths), part_id)) + + # load zarr-files + datasets = zarr_filepaths + if verbose: + logging.debug("loading: %s...\n", datasets[part_id]) + + dataset = zarr.group(store=zarr.ZipStore(datasets[part_id], mode="r")) + start_indices, x, y_value, y_policy = get_x_y_and_indices(dataset) + + if verbose: + logging.info("STATISTICS:") + try: + for member in dataset["statistics"]: + if member in ["avg_elo", "avg_elo_red", "avg_elo_black", "num_red_wins", "num_black_wins", "num_draws"]: + print(member, list(dataset["statistics"][member])) + except KeyError: + logging.warning("no statistics found") + + logging.info("PARAMETERS:") + try: + for member in dataset["parameters"]: + print(member, list(dataset["parameters"][member])) + except KeyError: + logging.warning("no parameters found") + + if normalize: + x = x.astype(np.float32) + y_value = y_value.astype(np.float32) + y_policy = y_policy.astype(np.float32) + + x *= MATRIX_NORMALIZER + return start_indices, x, y_value, y_policy, dataset diff --git a/DeepCrazyhouse/src/training/trainer_agent_mxnet.py b/DeepCrazyhouse/src/training/trainer_agent_mxnet.py index 310af279..28e482c5 100644 --- a/DeepCrazyhouse/src/training/trainer_agent_mxnet.py +++ b/DeepCrazyhouse/src/training/trainer_agent_mxnet.py @@ -16,9 +16,10 @@ from tqdm import tqdm_notebook from rtpt import RTPT from DeepCrazyhouse.configs.train_config import TrainConfig, TrainObjects +from DeepCrazyhouse.src.domain.util import augment from DeepCrazyhouse.src.domain.variants.plane_policy_representation import FLAT_PLANE_IDX -from DeepCrazyhouse.src.preprocessing.dataset_loader import load_pgn_dataset -from DeepCrazyhouse.src.domain.variants.constants import NB_LABELS_POLICY_MAP +from DeepCrazyhouse.src.preprocessing.dataset_loader import load_pgn_dataset, load_xiangqi_dataset +from DeepCrazyhouse.src.domain.variants.constants import NB_LABELS_POLICY_MAP, MODE, MODE_XIANGQI from DeepCrazyhouse.src.training.crossentropy import * @@ -163,7 +164,8 @@ def __init__( val_iter, train_config: TrainConfig, train_objects: TrainObjects, - use_rtpt: bool + use_rtpt: bool, + augment = False ): """ Class for training the neural network. @@ -185,18 +187,22 @@ def __init__( self._val_iter = val_iter self.x_train = self.yv_train = self.yp_train = None self._ctx = get_context(train_config.context, train_config.device_id) + self._augment = augment # define a summary writer that logs data and flushes to the file every 5 seconds if self.tc.log_metrics_to_tensorboard: self.sum_writer = SummaryWriter(logdir=self.tc.export_dir+"logs", flush_secs=5, verbose=False) # Define the optimizer if self.tc.optimizer_name == "adam": - self.optimizer = mx.optimizer.Adam(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, lazy_update=True, rescale_grad=(1.0/batch_size)) + self.optimizer = mx.optimizer.Adam(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, lazy_update=True, rescale_grad=(1.0/self.tc.batch_size)) elif self.tc.optimizer_name == "nag": self.optimizer = mx.optimizer.NAG(momentum=self.to.momentum_schedule(0), wd=self.tc.wd, rescale_grad=(1.0/self.tc.batch_size)) else: raise Exception("%s is currently not supported as an optimizer." % self.tc.optimizer_name) self.ordering = list(range(self.tc.nb_parts)) # define a list which describes the order of the processed batches + # if we augment the data set each part is loaded twice + if self._augment: + self.ordering += self.ordering # decides if the policy indices shall be selected directly from spatial feature maps without dense layer self.batch_end_callbacks = [self.batch_callback] @@ -281,25 +287,52 @@ def train(self, cur_it=None): # Probably needs refactoring self.t_s_steps = time() self._model.init_optimizer(optimizer=self.optimizer) + if self._augment: + # stores part ids that were not augmented yet + parts_not_augmented = list(set(self.ordering.copy())) + # stores part ids that were loaded before but not augmented + parts_to_augment = [] + for part_id in tqdm_notebook(self.ordering): - # load one chunk of the dataset from memory - _, self.x_train, self.yv_train, self.yp_train, plys_to_end, _ = load_pgn_dataset(dataset_type="train", - part_id=part_id, - normalize=self.tc.normalize, - verbose=False, - q_value_ratio=self.tc.q_value_ratio) + if MODE == MODE_XIANGQI: + _, self.x_train, self.yv_train, self.yp_train, _ = load_xiangqi_dataset(dataset_type="train", + part_id=part_id, + normalize=self.tc.normalize, + verbose=False) + if self._augment: + # check whether the current part should be augmented + if part_id in parts_to_augment: + augment(self.x_train, self.yp_train) + logging.debug("Using augmented part with id {}".format(part_id)) + elif part_id in parts_not_augmented: + if random.randint(0, 1): + augment(self.x_train, self.yp_train) + parts_not_augmented.remove(part_id) + logging.debug("Using augmented part with id {}".format(part_id)) + else: + parts_to_augment.append(part_id) + logging.debug("Using unaugmented part with id {}".format(part_id)) + else: + # load one chunk of the dataset from memory + _, self.x_train, self.yv_train, self.yp_train, plys_to_end, _ = load_pgn_dataset(dataset_type="train", + part_id=part_id, + normalize=self.tc.normalize, + verbose=False, + q_value_ratio=self.tc.q_value_ratio) # fill_up_batch if there aren't enough games if len(self.yv_train) < self.tc.batch_size: logging.info("filling up batch with too few samples %d" % len(self.yv_train)) self.x_train = fill_up_batch(self.x_train, self.tc.batch_size) self.yv_train = fill_up_batch(self.yv_train, self.tc.batch_size) self.yp_train = fill_up_batch(self.yp_train, self.tc.batch_size) - if plys_to_end is not None: - plys_to_end = fill_up_batch(plys_to_end, self.tc.batch_size) + if MODE != MODE_XIANGQI: + if plys_to_end is not None: + plys_to_end = fill_up_batch(plys_to_end, self.tc.batch_size) - if self.tc.discount != 1: - self.yv_train *= self.tc.discount**plys_to_end + if MODE != MODE_XIANGQI: + if self.tc.discount != 1: + self.yv_train *= self.tc.discount**plys_to_end self.yp_train = prepare_policy(self.yp_train, self.tc.select_policy_from_plane, self.tc.sparse_policy_label, self.tc.is_policy_from_plane_data) From 1b32e23638ef9d9184d96f6ba4964fe990029b49 Mon Sep 17 00:00:00 2001 From: xsr7qsr Date: Mon, 26 Apr 2021 11:07:54 +0200 Subject: [PATCH 03/19] Add xiangqi support for training --- .../src/training/train_xiangqi.ipynb | 607 ++++++++++++++++++ 1 file changed, 607 insertions(+) create mode 100644 DeepCrazyhouse/src/training/train_xiangqi.ipynb diff --git a/DeepCrazyhouse/src/training/train_xiangqi.ipynb b/DeepCrazyhouse/src/training/train_xiangqi.ipynb new file mode 100644 index 00000000..b9349a63 --- /dev/null +++ b/DeepCrazyhouse/src/training/train_xiangqi.ipynb @@ -0,0 +1,607 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "otherwise-causing", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append('../../../')\n", + "\n", + "import glob\n", + "import logging\n", + "import numpy as np\n", + "import mxnet as mx\n", + "from mxnet import gluon\n", + "\n", + "from DeepCrazyhouse.configs.main_config import main_config\n", + "from DeepCrazyhouse.configs.train_config import TrainConfig, TrainObjects\n", + "from DeepCrazyhouse.src.runtime.color_logger import enable_color_logging\n", + "from DeepCrazyhouse.src.domain.variants.constants import NB_POLICY_MAP_CHANNELS, NB_LABELS\n", + "from DeepCrazyhouse.src.domain.variants.plane_policy_representation import FLAT_PLANE_IDX\n", + "\n", + "from DeepCrazyhouse.src.preprocessing.dataset_loader import load_xiangqi_dataset\n", + "\n", + "from DeepCrazyhouse.src.training.lr_schedules.lr_schedules import *\n", + "\n", + "from DeepCrazyhouse.src.domain.neural_net.architectures.rise_mobile_v2 import rise_mobile_v2_symbol\n", + "from DeepCrazyhouse.src.domain.neural_net.architectures.rise_mobile_v3 import rise_mobile_v3_symbol\n", + "\n", + "from DeepCrazyhouse.src.training.trainer_agent import TrainerAgent, evaluate_metrics, acc_sign, reset_metrics\n", + "from DeepCrazyhouse.src.training.trainer_agent_mxnet import TrainerAgentMXNET, get_context\n", + "\n", + "enable_color_logging()\n", + "\n", + "print(\"mxnet version: \", mx.__version__)" + ] + }, + { + "cell_type": "markdown", + "id": "geological-athletics", + "metadata": {}, + "source": [ + "# Main Configuration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "precious-analyst", + "metadata": {}, + "outputs": [], + "source": [ + "for key in main_config.keys():\n", + " print(key, \"= \", main_config[key])" + ] + }, + { + "cell_type": "markdown", + "id": "norman-british", + "metadata": {}, + "source": [ + "# Settings for training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "champion-ancient", + "metadata": {}, + "outputs": [], + "source": [ + "tc = TrainConfig()\n", + "to = TrainObjects()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "passive-remains", + "metadata": {}, + "outputs": [], + "source": [ + "# Setting the context to GPU is strongly recommended\n", + "tc.context = \"gpu\" # Be sure to check the used devices!!!\n", + "tc.device_id = 0\n", + "\n", + "# Used for reproducibility\n", + "tc.seed = 7\n", + "\n", + "tc.export_weights = True\n", + "tc.log_metrics_to_tensorboard = True\n", + "tc.export_grad_histograms = True\n", + "\n", + "# div factor is a constant which can be used to reduce the batch size and learning rate respectively\n", + "# use a value larger 1 if you enconter memory allocation errors\n", + "tc.div_factor = 2\n", + "\n", + "# defines how often a new checkpoint will be saved and the metrics evaluated\n", + "# (batch_steps = 1000 means that every 1000 batches the validation set gets processed)\n", + "tc.batch_steps = 100 * tc.div_factor\n", + "# k_steps_initial defines how many steps have been trained before\n", + "# (k_steps_initial != 0 if you continue training from a checkpoint)\n", + "tc.k_steps_initial = 0\n", + "\n", + "# these are the weights to continue training with\n", + "tc.symbol_file = None # 'model-0.81901-0.713-symbol.json'\n", + "tc.params_file = None #'model-0.81901-0.713-0498.params'\n", + "\n", + "#typically if you half the batch_size, you should double the lr\n", + "tc.batch_size = int(1024 / tc.div_factor)\n", + "\n", + "# optimization parameters\n", + "tc.optimizer_name = \"nag\"\n", + "tc.max_lr = 0.35 / tc.div_factor\n", + "tc.min_lr = 0.00001\n", + "tc.max_momentum = 0.95\n", + "tc.min_momentum = 0.8\n", + "# loads a previous checkpoint if the loss increased significanly\n", + "tc.use_spike_recovery = True\n", + "# stop training as soon as max_spikes has been reached\n", + "tc.max_spikes = 20\n", + "# define spike threshold when the detection will be triggered\n", + "tc.spike_thresh = 1.5\n", + "# weight decay\n", + "tc.wd = 1e-4\n", + "tc.dropout_rate = 0 #0.15\n", + "# weight the value loss a lot lower than the policy loss in order to prevent overfitting\n", + "tc.val_loss_factor = 0.01\n", + "tc.policy_loss_factor = 0.99\n", + "tc.discount = 1.0\n", + "\n", + "tc.normalize = True\n", + "tc.nb_epochs = 7\n", + "# Boolean if potential legal moves will be selected from final policy output\n", + "tc.select_policy_from_plane = True \n", + "\n", + "# define if policy training target is one-hot encoded a distribution (e.g. mcts samples, knowledge distillation)\n", + "tc.sparse_policy_label = True\n", + "# define if the policy data is also defined in \"select_policy_from_plane\" representation\n", + "tc.is_policy_from_plane_data = False\n", + "\n", + "# Decide between mxnet and gluon style for training\n", + "tc.use_mxnet_style = True \n", + "\n", + "# additional custom validation set files which will be logged to tensorboard\n", + "tc.variant_metrics = None #[\"chess960\", \"koth\", \"three_check\"]\n", + "\n", + "tc.name_initials = \"Your Initials\"\n", + "\n", + "# enable data set augmentation\n", + "augment = True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bridal-meditation", + "metadata": {}, + "outputs": [], + "source": [ + "mode = main_config[\"mode\"]\n", + "ctx = get_context(tc.context, tc.device_id)\n", + "\n", + "# if use_extra_variant_input is true the current active variant is passed two each residual block and\n", + "# concatenated at the end of the final feature representation\n", + "use_extra_variant_input = False\n", + "\n", + "# iteration counter used for the momentum and learning rate schedule\n", + "cur_it = tc.k_steps_initial * tc.batch_steps \n", + "\n", + "# Fix the random seed\n", + "mx.random.seed(tc.seed)" + ] + }, + { + "cell_type": "markdown", + "id": "genuine-network", + "metadata": {}, + "source": [ + "### Crete a ./logs and ./weights directory" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "satisfactory-lobby", + "metadata": {}, + "outputs": [], + "source": [ + "!mkdir ./logs && mkdir ./weights" + ] + }, + { + "cell_type": "markdown", + "id": "logical-hygiene", + "metadata": {}, + "source": [ + "# Load Datasets" + ] + }, + { + "cell_type": "markdown", + "id": "suburban-firmware", + "metadata": {}, + "source": [ + "### Validation set" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "earlier-assessment", + "metadata": {}, + "outputs": [], + "source": [ + "combined = True\n", + "\n", + "if combined:\n", + " start_indices_val_0, x_val_0, y_value_val_0, y_policy_val_0, dataset_0 = load_xiangqi_dataset(dataset_type=\"val\",\n", + " part_id=0,\n", + " verbose=True,\n", + " normalize=tc.normalize)\n", + " start_indices_val_1, x_val_1, y_value_val_1, y_policy_val_1, dataset_1 = load_xiangqi_dataset(dataset_type=\"val\",\n", + " part_id=1,\n", + " verbose=True,\n", + " normalize=tc.normalize)\n", + " # X\n", + " nb_inputs = x_val_0.shape[0] + x_val_1.shape[0]\n", + " nb_planes = x_val_0.shape[1]\n", + " nb_rows = x_val_0.shape[2]\n", + " nb_cols = x_val_0.shape[3]\n", + " x_val = np.zeros((nb_inputs, nb_planes, nb_rows, nb_cols))\n", + " x_val[:x_val_0.shape[0]] = x_val_0\n", + " x_val[x_val_0.shape[0]:] = x_val_1\n", + "\n", + " # value targets\n", + " nb_targets_value = y_value_val_0.shape[0] + y_value_val_1.shape[0]\n", + " y_value_val = np.zeros((nb_targets_value,))\n", + " y_value_val[:y_value_val_0.shape[0]] = y_value_val_0\n", + " y_value_val[y_value_val_0.shape[0]:] = y_value_val_1\n", + "\n", + " # policy targets\n", + " nb_targets_policy = y_policy_val_0.shape[0] + y_policy_val_1.shape[0]\n", + " y_policy_val = np.zeros((nb_targets_policy,y_policy_val_0.shape[1]))\n", + " y_policy_val[:y_policy_val_0.shape[0]] = y_policy_val_0\n", + " y_policy_val[y_policy_val_0.shape[0]:] = y_policy_val_1\n", + "else:\n", + " start_indices_val, x_val, y_value_val, y_policy_val, dataset = load_xiangqi_dataset(dataset_type=\"val\",\n", + " part_id=0,\n", + " verbose=True,\n", + " normalize=tc.normalize)\n", + "if tc.normalize:\n", + " assert x_val.max() <= 1.0, \"Error: Normalization not working.\"\n", + "\n", + "if tc.select_policy_from_plane:\n", + " val_iter = mx.io.NDArrayIter({'data': x_val}, \n", + " {'value_label': y_value_val, \n", + " 'policy_label': np.array(FLAT_PLANE_IDX)[y_policy_val.argmax(axis=1)]},\n", + " tc.batch_size)\n", + "else:\n", + " val_iter = mx.io.NDArrayIter({'data': x_val}, \n", + " {'value_label': y_value_val, \n", + " 'policy_label': y_policy_val.argmax(axis=1)}, \n", + " tc.batch_size)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "missing-correction", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "print(\"x_val.shape: \", x_val.shape)\n", + "print(\"y_value_val.shape: \", y_value_val.shape)\n", + "print(\"y_policy_val.shape: \", y_policy_val.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "fatal-steal", + "metadata": {}, + "source": [ + "### Training properties" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "above-donor", + "metadata": {}, + "outputs": [], + "source": [ + "len(x_val)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "growing-occasion", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "tc.nb_parts = len(glob.glob(main_config[\"planes_train_dir\"] + \"**/*\"))\n", + "print(\"Parts training dataset: \", tc.nb_parts)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "recent-anchor", + "metadata": {}, + "outputs": [], + "source": [ + "# one iteration is defined by passing 1 batch and doing backpropagation\n", + "if augment:\n", + " nb_it_per_epoch = (len(x_val) * tc.nb_parts * 2) // tc.batch_size\n", + "else:\n", + " nb_it_per_epoch = (len(x_val) * tc.nb_parts) // tc.batch_size\n", + "tc.total_it = int(nb_it_per_epoch * tc.nb_epochs)\n", + "print(\"Total iterations: \", tc.total_it)" + ] + }, + { + "cell_type": "markdown", + "id": "exceptional-tutorial", + "metadata": {}, + "source": [ + "# Learning Rate Schedule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "banned-participant", + "metadata": {}, + "outputs": [], + "source": [ + "to.lr_schedule = OneCycleSchedule(start_lr=tc.max_lr/8, \n", + " max_lr=tc.max_lr, \n", + " cycle_length=tc.total_it*.3, \n", + " cooldown_length=tc.total_it*.6, \n", + " finish_lr=tc.min_lr)\n", + "to.lr_schedule = LinearWarmUp(to.lr_schedule, start_lr=tc.min_lr, length=tc.total_it/30)\n", + "plot_schedule(to.lr_schedule, iterations=tc.total_it)" + ] + }, + { + "cell_type": "markdown", + "id": "secret-thumbnail", + "metadata": {}, + "source": [ + "# Momentum Schedule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "brief-wyoming", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "to.momentum_schedule = MomentumSchedule(to.lr_schedule, tc.min_lr, tc.max_lr, tc.min_momentum, tc.max_momentum)\n", + "plot_schedule(to.momentum_schedule, iterations=tc.total_it, ylabel='Momentum')" + ] + }, + { + "cell_type": "markdown", + "id": "viral-advantage", + "metadata": {}, + "source": [ + "# Define NN model / Load pretrained model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "gothic-peripheral", + "metadata": {}, + "outputs": [], + "source": [ + "input_shape = x_val[0].shape\n", + "print(\"Input shape: \", input_shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "august-customs", + "metadata": {}, + "outputs": [], + "source": [ + "bc_res_blocks = [3] * 5 # 13\n", + "if tc.symbol_file is None:\n", + " # channels_operating_init, channel_expansion\n", + " symbol = rise_mobile_v2_symbol(channels=256, channels_operating_init=512, channel_expansion=0, channels_value_head=8,\n", + " channels_policy_head=NB_POLICY_MAP_CHANNELS, value_fc_size=256, bc_res_blocks=bc_res_blocks, res_blocks=[], act_type='relu',\n", + " n_labels=NB_LABELS, grad_scale_value=tc.val_loss_factor, grad_scale_policy=tc.policy_loss_factor, select_policy_from_plane=tc.select_policy_from_plane,\n", + " use_se=True, dropout_rate=tc.dropout_rate, use_extra_variant_input=use_extra_variant_input)\n", + "else:\n", + " symbol = mx.sym.load(\"weights/\" + symbol_file)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "explicit-reviewer", + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "bc_res_blocks = [3] * 13 \n", + "if tc.symbol_file is None:\n", + " # channels_operating_init, channel_expansion\n", + " symbol = rise_mobile_v2_symbol(channels=256, channels_operating_init=128, channel_expansion=64, channels_value_head=8,\n", + " channels_policy_head=NB_POLICY_MAP_CHANNELS, value_fc_size=256, bc_res_blocks=bc_res_blocks, res_blocks=[], act_type='relu',\n", + " n_labels=NB_LABELS, grad_scale_value=tc.val_loss_factor, grad_scale_policy=tc.policy_loss_factor, select_policy_from_plane=tc.select_policy_from_plane,\n", + " use_se=True, dropout_rate=tc.dropout_rate, use_extra_variant_input=use_extra_variant_input)\n", + "else:\n", + " symbol = mx.sym.load(\"weights/\" + symbol_file)\n", + "\"\"\"" + ] + }, + { + "cell_type": "markdown", + "id": "involved-walter", + "metadata": {}, + "source": [ + "# Network summary" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "closing-springer", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "display(mx.viz.plot_network(\n", + " symbol,\n", + " shape={'data':(1, input_shape[0], input_shape[1], input_shape[2])},\n", + " node_attrs={\"shape\":\"oval\",\"fixedsize\":\"false\"}\n", + " ))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "raised-baker", + "metadata": {}, + "outputs": [], + "source": [ + "mx.viz.print_summary(\n", + " symbol,\n", + " shape={'data':(1, input_shape[0], input_shape[1], input_shape[2])},\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "computational-cleveland", + "metadata": {}, + "source": [ + "# Initialize weights if no pretrained weights are used" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "interracial-processing", + "metadata": {}, + "outputs": [], + "source": [ + "# create a trainable module on compute context\n", + "model = mx.mod.Module(symbol=symbol, context=ctx, label_names=['value_label', 'policy_label'])\n", + "model.bind(for_training=True, data_shapes=[('data', (tc.batch_size, input_shape[0], input_shape[1], input_shape[2]))],\n", + " label_shapes=val_iter.provide_label)\n", + "model.init_params(mx.initializer.Xavier(rnd_type='uniform', factor_type='avg', magnitude=2.24))\n", + "if tc.params_file:\n", + " model.load_params(\"weights/\" + tc.params_file) " + ] + }, + { + "cell_type": "markdown", + "id": "supposed-barrier", + "metadata": {}, + "source": [ + "# Metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "regional-plate", + "metadata": {}, + "outputs": [], + "source": [ + "metrics_mxnet = [\n", + "mx.metric.MSE(name='value_loss', output_names=['value_output'], label_names=['value_label']),\n", + "mx.metric.CrossEntropy(name='policy_loss', output_names=['policy_output'],\n", + " label_names=['policy_label']),\n", + "mx.metric.create(acc_sign, name='value_acc_sign', output_names=['value_output'],\n", + " label_names=['value_label']),\n", + "mx.metric.Accuracy(axis=1, name='policy_acc', output_names=['policy_output'],\n", + " label_names=['policy_label'])\n", + "]\n", + "metrics_gluon = {\n", + "'value_loss': mx.metric.MSE(name='value_loss', output_names=['value_output']),\n", + "'policy_loss': mx.metric.CrossEntropy(name='policy_loss', output_names=['policy_output'],\n", + " label_names=['policy_label']),\n", + "'value_acc_sign': mx.metric.create(acc_sign, name='value_acc_sign', output_names=['value_output'],\n", + " label_names=['value_label']),\n", + "'policy_acc': mx.metric.Accuracy(axis=1, name='policy_acc', output_names=['policy_output'],\n", + " label_names=['policy_label'])\n", + "}\n", + "to.metrics = metrics_mxnet" + ] + }, + { + "cell_type": "markdown", + "id": "gothic-poetry", + "metadata": {}, + "source": [ + "# Training Agent" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fresh-lithuania", + "metadata": {}, + "outputs": [], + "source": [ + "train_agent = TrainerAgentMXNET(model, symbol, val_iter, tc, to, use_rtpt=False, augment=augment)" + ] + }, + { + "cell_type": "markdown", + "id": "peripheral-sense", + "metadata": {}, + "source": [ + "# Performance Pre-Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "clear-burner", + "metadata": {}, + "outputs": [], + "source": [ + "print(model.score(val_iter, to.metrics))" + ] + }, + { + "cell_type": "markdown", + "id": "requested-wrist", + "metadata": {}, + "source": [ + "# Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ethical-piece", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "(k_steps_final, value_loss_final, policy_loss_final, value_acc_sign_final, val_p_acc_final), (k_steps_best, val_loss_best, val_p_acc_best) = train_agent.train(cur_it)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 25b48b7f8f74c0a05b1c0fbe7ae63b3f059755aa Mon Sep 17 00:00:00 2001 From: xsr7qsr Date: Mon, 26 Apr 2021 12:48:39 +0200 Subject: [PATCH 04/19] Add xiangqi support for MCTS --- engine/CMakeLists.txt | 63 +- .../environments/fairy_state/fairyboard.cpp | 138 ++ .../src/environments/fairy_state/fairyboard.h | 31 + .../fairy_state/fairyinputrepresentation.cpp | 64 + .../fairy_state/fairyinputrepresentation.h | 23 + .../fairy_state/fairyoutputrepresentation.cpp | 171 ++ .../fairy_state/fairyoutputrepresentation.h | 54 + .../fairypolicymaprepresentation.h | 2093 +++++++++++++++++ .../environments/fairy_state/fairystate.cpp | 126 + .../src/environments/fairy_state/fairystate.h | 145 ++ .../environments/fairy_state/fairyutil.cpp | 35 + .../src/environments/fairy_state/fairyutil.h | 45 + engine/src/nn/tensorrtapi.cpp | 4 +- engine/src/state.h | 2 +- engine/src/stateobj.h | 6 + engine/src/uci/crazyara.cpp | 52 +- engine/src/uci/main.cpp | 4 + engine/src/uci/optionsuci.cpp | 15 +- engine/src/uci/variants.h | 9 +- 19 files changed, 3061 insertions(+), 19 deletions(-) create mode 100644 engine/src/environments/fairy_state/fairyboard.cpp create mode 100644 engine/src/environments/fairy_state/fairyboard.h create mode 100644 engine/src/environments/fairy_state/fairyinputrepresentation.cpp create mode 100644 engine/src/environments/fairy_state/fairyinputrepresentation.h create mode 100644 engine/src/environments/fairy_state/fairyoutputrepresentation.cpp create mode 100644 engine/src/environments/fairy_state/fairyoutputrepresentation.h create mode 100644 engine/src/environments/fairy_state/fairypolicymaprepresentation.h create mode 100644 engine/src/environments/fairy_state/fairystate.cpp create mode 100644 engine/src/environments/fairy_state/fairystate.h create mode 100644 engine/src/environments/fairy_state/fairyutil.cpp create mode 100644 engine/src/environments/fairy_state/fairyutil.h diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index 1af4b366..8c983420 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -10,18 +10,19 @@ option(BACKEND_TORCH "Build with Torch backend (CPU/GPU) support" OF option(USE_960 "Build with 960 variant support" OFF) option(BUILD_TESTS "Build and run tests" OFF) # enable a single mode for different model input / outputs -option(MODE_CRAZYHOUSE "Build with crazyhouse only support" ON) +option(MODE_CRAZYHOUSE "Build with crazyhouse only support" OFF) option(MODE_CHESS "Build with chess + chess960 only support" OFF) option(MODE_LICHESS "Build with lichess variants support" OFF) option(MODE_OPEN_SPIEL "Build with open_spiel environment support" OFF) +option(MODE_XIANGQI "Build with xiangqi only support" ON) option(SEARCH_UCT "Build with UCT instead of PUCT search" OFF) add_definitions(-DIS_64BIT) -add_definitions(-DCRAZYHOUSE) if (MODE_CRAZYHOUSE) project(CrazyAra CXX) add_definitions(-DMODE_CRAZYHOUSE) + add_definitions(-DCRAZYHOUSE) endif() if (MODE_CHESS) @@ -49,6 +50,15 @@ if (MODE_OPEN_SPIEL) add_definitions(-DACTION_64_BIT) endif() +if (MODE_XIANGQI) + project(XiangqiAra CXX) + add_definitions(-DNO_THREADS) + add_definitions(-DXIANGQI) + add_definitions(-DMODE_XIANGQI) + add_definitions(-DLARGEBOARDS) + add_definitions(-DPRECOMPUTED_MAGICS) +endif () + if (BUILD_TESTS) add_definitions(-DBUILD_TESTS) endif() @@ -112,7 +122,12 @@ file(GLOB_RECURSE CPP_PACKAGE_HEADERS ) find_package (Threads) -include_directories("3rdparty/Stockfish/src") + +if (MODE_XIANGQI) + include_directories("3rdparty/Fairy-Stockfish/src") +else () + include_directories("3rdparty/Stockfish/src") +endif () include_directories("src") file(GLOB sf_related_files @@ -136,20 +151,50 @@ file(GLOB chess_related_files "src/environments/chess_related/*.h" "src/environments/chess_related/*.cpp" ) - -set(source_files - ${source_files} - ${sf_related_files} - ${uci_files} +file(GLOB xiangqi_related_files + "src/environments/fairy_state/*.h" + "src/environments/fairy_state/*.cpp" + ) +file(GLOB fsf_related_files + "3rdparty/Fairy-Stockfish/src/*.h" + "3rdparty/Fairy-Stockfish/src/*.cpp" + "3rdparty/Fairy-Stockfish/src/syzygy/tbprobe.h" + "3rdparty/Fairy-Stockfish/src/syzygy/tbprobe.cpp" + "3rdparty/Fairy-Stockfish/src/nnue/*.h" + "3rdparty/Fairy-Stockfish/src/nnue/*.cpp" + "3rdparty/Fairy-Stockfish/src/nnue/layers/*.h" + "3rdparty/Fairy-Stockfish/src/nnue/layers/*.cpp" + "3rdparty/Fairy-Stockfish/src/nnue/features/*.h" + "3rdparty/Fairy-Stockfish/src/nnue/features/*.cpp" + "3rdparty/Fairy-Stockfish/src/nnue/architectures/*.h" + "3rdparty/Fairy-Stockfish/src/nnue/architectures/*.cpp" ) +list(REMOVE_ITEM fsf_related_files ${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/Fairy-Stockfish/src/ffishjs.cpp) +list(REMOVE_ITEM fsf_related_files ${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/Fairy-Stockfish/src/pyffish.cpp) + +if (NOT MODE_XIANGQI) + set(source_files + ${source_files} + ${sf_related_files} + ${uci_files} + ) +endif () -if (NOT MODE_OPEN_SPIEL) +if (NOT MODE_OPEN_SPIEL AND NOT MODE_XIANGQI) set(source_files ${source_files} ${chess_related_files} ) endif() +if (MODE_XIANGQI) + set(source_files + ${source_files} + ${xiangqi_related_files} + ${fsf_related_files} + ${uci_files}) +endif () + if (MODE_OPEN_SPIEL) set (OPEN_SPIEL_CORE_FILES diff --git a/engine/src/environments/fairy_state/fairyboard.cpp b/engine/src/environments/fairy_state/fairyboard.cpp new file mode 100644 index 00000000..1143d58f --- /dev/null +++ b/engine/src/environments/fairy_state/fairyboard.cpp @@ -0,0 +1,138 @@ +#include "fairyboard.h" +#include "apiutil.h" +#include "fairyutil.h" + + +FairyBoard::FairyBoard() {} + +FairyBoard::FairyBoard(const FairyBoard &b) { + operator=(b); +} + +FairyBoard::~FairyBoard() {} + +FairyBoard& FairyBoard::operator=(const FairyBoard &b) { + std::copy(b.board, b.board+SQUARE_NB, this->board); + std::copy(b.byTypeBB, b.byTypeBB+PIECE_TYPE_NB, this->byTypeBB); + std::copy(b.byColorBB, b.byColorBB+COLOR_NB, this->byColorBB); + std::copy(b.pieceCount, b.pieceCount+PIECE_NB, this->pieceCount); + std::copy(b.castlingRightsMask, b.castlingRightsMask+SQUARE_NB, this->castlingRightsMask); + std::copy(b.castlingRookSquare, b.castlingRookSquare+CASTLING_RIGHT_NB, castlingRookSquare); + std::copy(b.castlingPath, b.castlingPath+CASTLING_RIGHT_NB, this->castlingPath); + this->gamePly = b.gamePly; + sideToMove = b.sideToMove; + psq = b.psq; + thisThread = b.thisThread; + st = b.st; + tsumeMode = b.tsumeMode; + chess960 = b.chess960; + std::copy(&b.pieceCountInHand[0][0], &b.pieceCountInHand[0][0]+COLOR_NB*PIECE_TYPE_NB, &this->pieceCountInHand[0][0]); + promotedPieces = b.promotedPieces; + var = b.var; + return *this; +} + +int FairyBoard::get_pocket_count(Color c, PieceType pt) const { + return Position::count_in_hand(c, pt); +} + +Key FairyBoard::hash_key() const { + return state()->key; +} + +bool FairyBoard::is_terminal() const { + for (const ExtMove move : MoveList(*this)) { + return false; + } + return true; +} + +size_t FairyBoard::number_repetitions() const { + StateInfo *st = state(); + if (st->repetition == 0) { + return 0; + } + else if (st->repetition) { + return 1; + } + else return 0; +} + +Result get_result(const FairyBoard &pos, bool inCheck) { + if (pos.is_terminal()) { + if (!inCheck) { + return DRAWN; + } + if (pos.side_to_move() == BLACK) { + return WHITE_WIN; + } + else { + return BLACK_WIN; + } + } + return NO_RESULT; +} + +std::string wxf_move(Move m, const FairyBoard& pos) { + Notation notation = NOTATION_XIANGQI_WXF; + + std::string wxf = ""; + + Color us = pos.side_to_move(); + Square from = from_sq(m); + Square to = to_sq(m); + + wxf += SAN::piece(pos, m, notation); + SAN::Disambiguation d = SAN::disambiguation_level(pos, m, notation); + wxf += disambiguation(pos, from, notation, d); + + if (rank_of(from) == rank_of(to)) { + wxf += "="; + } + else if (relative_rank(us, to, pos.max_rank()) > relative_rank(us, from, pos.max_rank())) { + wxf += "+"; + } + else { + wxf += "-"; + } + + if (type_of(m) != DROP) { + wxf += file_of(to) == file_of(from) ? std::to_string(std::abs(rank_of(to) - rank_of(from))) : SAN::file(pos, to, notation); + } + else { + wxf += SAN::square(pos, to, notation); + } + return wxf; +} + +std::string uci_move(Move m) { + std::string uciMove; + + Square from = from_sq(m); + Square to = to_sq(m); + + char fromFile = file_to_uci(file_of(from)); + std::string fromRank = rank_to_uci(rank_of(from)); + char toFile = file_to_uci(file_of(to)); + std::string toRank = rank_to_uci(rank_of(to)); + + return std::string(1, fromFile) + fromRank + std::string(1, toFile) + toRank; +} + +char file_to_uci(File file) { + for (auto it = FILE_LOOKUP.begin(); it != FILE_LOOKUP.end(); ++it) { + if (it->second == file) { + return it->first; + } + } + return char(); +} + +std::string rank_to_uci(Rank rank) { + for (auto it = RANK_LOOKUP.begin(); it != RANK_LOOKUP.end(); ++it) { + if (it->second == rank) { + return std::string(1, it->first); + } + } + return "10"; +} diff --git a/engine/src/environments/fairy_state/fairyboard.h b/engine/src/environments/fairy_state/fairyboard.h new file mode 100644 index 00000000..538ea183 --- /dev/null +++ b/engine/src/environments/fairy_state/fairyboard.h @@ -0,0 +1,31 @@ +#ifndef FAIRYBOARD_H +#define FAIRYBOARD_H + +#include +#include +#include "state.h" + +using blaze::StaticVector; +using blaze::DynamicVector; + +class FairyBoard : public Position +{ +public: + FairyBoard(); + FairyBoard(const FairyBoard& b); + ~FairyBoard(); + FairyBoard& operator=(const FairyBoard &b); + + int get_pocket_count(Color c, PieceType pt) const; + Key hash_key() const; + bool is_terminal() const; + size_t number_repetitions() const; +}; + +Result get_result(const FairyBoard &pos, bool inCheck); +std::string wxf_move(Move m, const FairyBoard& pos); +std::string uci_move(Move m); +char file_to_uci(File file); +std::string rank_to_uci(Rank rank); + +#endif //FAIRYBOARD_H diff --git a/engine/src/environments/fairy_state/fairyinputrepresentation.cpp b/engine/src/environments/fairy_state/fairyinputrepresentation.cpp new file mode 100644 index 00000000..4e1377d2 --- /dev/null +++ b/engine/src/environments/fairy_state/fairyinputrepresentation.cpp @@ -0,0 +1,64 @@ +#include "fairyinputrepresentation.h" +#include "fairystate.h" + +using namespace std; + +void set_bits_from_bitmap(Bitboard bitboard, size_t channel, float *inputPlanes, Color color) { + size_t p = 0; + while (bitboard != 0) { + if (bitboard & 0x1) { + if (color == WHITE) { + int col = std::abs(9-std::floor(p/9)); + int row = p % 9; + inputPlanes[channel * StateConstantsFairy::NB_SQUARES() + col * 9 + row] = 1; + } + else { + inputPlanes[channel * StateConstantsFairy::NB_SQUARES() + p] = 1; + } + } + // Largeboards use 12 files per rank, xiangqi boards only use 9 files per rank + (p+1) % 9 == 0 ? bitboard >>= 4 : bitboard >>= 1; + p++; + } +} + +void board_to_planes(const FairyBoard* pos, bool normalize, float *inputPlanes) { + fill(inputPlanes, inputPlanes + StateConstantsFairy::NB_VALUES_TOTAL(), 0.0f); + size_t currentChannel = 0; + Color me = pos->side_to_move(); + Color you = ~me; + + // pieces (ORDER: King, Advisor, Elephant, Horse, Rook, Cannon, Soldier) + for (Color color : {me, you}) { + for (PieceType piece : {KING, FERS, ELEPHANT, HORSE, ROOK, CANNON, SOLDIER}) { + const Bitboard pieces = pos->pieces(color, piece); + set_bits_from_bitmap(pieces, currentChannel, inputPlanes, me); + currentChannel++; + } + } + + // pocket count + for (Color color : {me, you}) { + for (PieceType piece : {FERS, ELEPHANT, HORSE, ROOK, CANNON, SOLDIER}) { + int pocket_cnt = pos->get_pocket_count(color, piece); + if (pocket_cnt > 0) { + std::fill(inputPlanes + currentChannel * StateConstantsFairy::NB_SQUARES(), + inputPlanes + (currentChannel + 1) * StateConstantsFairy::NB_SQUARES(), + normalize ? pocket_cnt / StateConstantsFairy::MAX_NB_PRISONERS() : pocket_cnt); + } + currentChannel++; + } + } + + // color + if (me == WHITE) { + std::fill(inputPlanes + currentChannel * StateConstantsFairy::NB_SQUARES(), + inputPlanes + (currentChannel + 1) * StateConstantsFairy::NB_SQUARES(), 1.0f); + } + currentChannel++; + + // total move count + std::fill(inputPlanes + currentChannel * StateConstantsFairy::NB_SQUARES(), + inputPlanes + (currentChannel + 1) * StateConstantsFairy::NB_SQUARES(), + normalize ? (std::floor(pos->game_ply() / 2 )) / StateConstantsFairy::MAX_FULL_MOVE_COUNTER() : std::floor(pos->game_ply() / 2 )); +} diff --git a/engine/src/environments/fairy_state/fairyinputrepresentation.h b/engine/src/environments/fairy_state/fairyinputrepresentation.h new file mode 100644 index 00000000..eff8a866 --- /dev/null +++ b/engine/src/environments/fairy_state/fairyinputrepresentation.h @@ -0,0 +1,23 @@ +#ifndef FAIRYINPUTREPRESENTATION_H +#define FAIRYINPUTREPRESENTATION_H + +#include "fairyboard.h" + +/** + * @brief board_to_planes Converts the given board representation into the plane representation. + * @param pos Board position + * @param normalize Flag, telling if the representation should be rescaled into the [0,1] range + * @param input_planes Output where the plane representation will be stored. + */ +void board_to_planes(const FairyBoard* pos, bool normalize, float *inputPlanes); + +/** + * @brief set_bits_from_bitmap Sets the individual bits from a given bitboard on the given channel for the inputPlanes + * @param bitboard Bitboard of a single 8x8 plane + * @param channel Channel index on where to set the bits + * @param input_planes Input planes encoded as flat vector + * @param color Color of the side to move + */ +inline void set_bits_from_bitmap(Bitboard bitboard, size_t channel, float *inputPlanes, Color color); + +#endif // FAIRYINPUTREPRESENTATION_H diff --git a/engine/src/environments/fairy_state/fairyoutputrepresentation.cpp b/engine/src/environments/fairy_state/fairyoutputrepresentation.cpp new file mode 100644 index 00000000..9f02e55f --- /dev/null +++ b/engine/src/environments/fairy_state/fairyoutputrepresentation.cpp @@ -0,0 +1,171 @@ +#include +#include "fairyoutputrepresentation.h" +#include "fairypolicymaprepresentation.h" +#include "fairystate.h" +#include "fairyutil.h" + +using namespace std; +using uci_labels::nbRanks; +using uci_labels::nbFiles; + +vector> uci_labels::get_destinations(int rankIdx, int fileIdx) { + vector> destinations; + for (int i = 0; i < nbFiles; ++i) { + tuple tmp{rankIdx, i}; + destinations.emplace_back(tmp); + } + + for (int i = 0; i < nbRanks; ++i) { + tuple tmp{i, fileIdx}; + destinations.emplace_back(tmp); + } + + // horse moves + array horseRankOffsets = {-2, -1, 1, 2, 2, 1, -1, -2}; + array horseFileOffsets = {-1, -2, -2, -1, 1, 2, 2, 1}; + for (int i = 0; i < horseFileOffsets.size(); ++i) { + int rankOffset = horseRankOffsets[i]; + int fileOffset = horseFileOffsets[i]; + tuple tmp{rankIdx + rankOffset, fileIdx + fileOffset}; + destinations.emplace_back(tmp); + } + + // elephant moves + vector elephantRankOffsets; + vector elephantFileOffsets; + if ((rankIdx == 0 && fileIdx == 2) || (rankIdx == 0 && fileIdx == 6) + || (rankIdx == 2 && fileIdx == 0) || (rankIdx == 2 && fileIdx == 4) || (rankIdx == 2 && fileIdx == 8) + || (rankIdx == 7 && fileIdx == 0) || (rankIdx == 7 && fileIdx == 4) || (rankIdx == 7 && fileIdx == 8) + || (rankIdx == 9 && fileIdx == 2) || (rankIdx == 9 && fileIdx == 6)) { + elephantRankOffsets = {2, 2, -2, -2}; + elephantFileOffsets = {-2, 2, -2, 2}; + } else if (rankIdx == 4 && (fileIdx == 2 || fileIdx == 6)) { + elephantRankOffsets = {-2, -2}; + elephantFileOffsets = {-2, 2}; + } else if (rankIdx == 5 && (fileIdx == 2 || fileIdx == 6)) { + elephantRankOffsets = {2, 2}; + elephantFileOffsets = {-2, 2}; + } + for (int i = 0; i < elephantFileOffsets.size(); ++i) { + int rankOffset = elephantRankOffsets[i]; + int fileOffset = elephantFileOffsets[i]; + tuple tmp{rankIdx + rankOffset, fileIdx + fileOffset}; + destinations.emplace_back(tmp); + } + + // advisor diagonal moves from mid palace + if (fileIdx == 4 && (rankIdx == 1 || rankIdx == 8)) { + array advisorRankOffsets = {-1, 1, 1, -1}; + array advisorFileOffsets = {-1, -1, 1, 1}; + for (int i = 0; i < advisorFileOffsets.size(); ++i) { + int rankOffset = advisorRankOffsets[i]; + int fileOffset = advisorFileOffsets[i]; + tuple tmp{rankIdx + rankOffset, fileIdx + fileOffset}; + destinations.emplace_back(tmp); + } + } + return destinations; +} + +vector uci_labels::generate_uci_labels() { + vector labels; + const array ranks = uci_labels::ranks(); + const array files = uci_labels::files(); + for (int rankIdx = 0; rankIdx < nbRanks; ++rankIdx) { + for (int fileIdx = 0; fileIdx < nbFiles; ++fileIdx) { + vector> destinations = uci_labels::get_destinations(rankIdx, fileIdx); + for (tuple destination : destinations) { + int rankIdx2 = get<0>(destination); + int fileIdx2 = get<1>(destination); + if ((fileIdx != fileIdx2 || rankIdx != rankIdx2) + && fileIdx2 >= 0 && fileIdx2 < nbFiles && rankIdx2 >= 0 && rankIdx2 < nbRanks) { + string move = files[fileIdx] + ranks[rankIdx] + files[fileIdx2] + ranks[rankIdx2]; + labels.emplace_back(move); + } + } + } + } + + // advidor moves to mid palace + labels.emplace_back("d1e2"); + labels.emplace_back("f1e2"); + labels.emplace_back("d3e2"); + labels.emplace_back("f3e2"); + labels.emplace_back("d10e9"); + labels.emplace_back("f10e9"); + labels.emplace_back("d8e9"); + labels.emplace_back("f8e9"); + return labels; +} + +string uci_labels::mirror_move(const string &ucciMove) { + // a10b10 + if (ucciMove.size() == 6) { + return string(1, ucciMove[0]) + string(1, '1') + string(1, ucciMove[3]) + string(1, '1'); + } + else if (ucciMove.size() == 5) { + // a10a9 + if (isdigit(ucciMove[2])) { + int rankTo = ucciMove[4] - '0'; + int rankToMirrored = 10 - rankTo + 1; + return string(1, ucciMove[0]) + string(1, '1') + string(1, ucciMove[3]) + to_string(rankToMirrored); + } + // a9a10 + else { + int rankFrom = ucciMove[1] - '0'; + int rankFromMirrored = 10 - rankFrom + 1; + return string(1, ucciMove[0]) + to_string(rankFromMirrored) + string(1, ucciMove[2]) + string(1, '1'); + } + } + // a1b1 + else { + string moveMirrored; + for (size_t i = 0; i < ucciMove.length(); ++i) { + if (isdigit(ucciMove[i])) { + int rank = ucciMove[i] - '0'; + int rankMirrored = 10 - rank + 1; + moveMirrored += to_string(rankMirrored); + } + else { + moveMirrored += ucciMove[i]; + } + } + return moveMirrored; + } +} + +array uci_labels::files() { + return {"a", "b", "c", "d", "e", "f", "g", "h", "i"}; +} + +array uci_labels::ranks() { + return {"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}; +} + +void FairyOutputRepresentation::init_labels() { + LABELS = uci_labels::generate_uci_labels(); + if (LABELS.size() != StateConstantsFairy::NB_LABELS()) { + cerr << "LABELS.size() != StateConstantsFairy::NB_LABELS():" << LABELS.size() << " " + << StateConstantsFairy::NB_LABELS() << endl; + assert(false); + } + LABELS_MIRRORED.resize(LABELS.size()); +} + +void FairyOutputRepresentation::init_policy_constants(bool isPolicyMap) { + for (size_t i = 0; i < StateConstantsFairy::NB_LABELS(); ++i) { + LABELS_MIRRORED[i] = uci_labels::mirror_move(LABELS[i]); + + Square fromSquare = get_origin_square(LABELS[i]); + Square toSquare = get_destination_square(LABELS[i]); + Move move = make_move(fromSquare, toSquare); + isPolicyMap ? MV_LOOKUP[move] = FLAT_PLANE_IDX[i] : MV_LOOKUP[move] = i; + MV_LOOKUP_CLASSIC[move] = i; + + Square fromSquareMirrored = get_origin_square(LABELS_MIRRORED[i]); + Square toSquareMirrored = get_destination_square(LABELS_MIRRORED[i]); + Move moveMirrored = make_move(fromSquareMirrored, toSquareMirrored); + isPolicyMap ? MV_LOOKUP_MIRRORED[moveMirrored] = FLAT_PLANE_IDX[i] : MV_LOOKUP_MIRRORED[moveMirrored] = i; + MV_LOOKUP_MIRRORED_CLASSIC[moveMirrored] = i; + } +} diff --git a/engine/src/environments/fairy_state/fairyoutputrepresentation.h b/engine/src/environments/fairy_state/fairyoutputrepresentation.h new file mode 100644 index 00000000..bada7280 --- /dev/null +++ b/engine/src/environments/fairy_state/fairyoutputrepresentation.h @@ -0,0 +1,54 @@ +#ifndef FAIRYOUTPUTREPRESENTATION_H +#define FAIRYOUTPUTREPRESENTATION_H + +#include +#include +#include +#include "state.h" + +using blaze::HybridVector; +using blaze::DynamicVector; +using action_idx_map = Action[USHRT_MAX]; + +using namespace std; + +namespace uci_labels { + const int nbRanks = 10; + const int nbFiles = 9; + + /** + * @brief get_destinations Returns all possible destinations on a Xiangqi board for a given square index. + */ + vector> get_destinations(int rankIdx, int fileIdx); + + vector generate_uci_labels(); + + string mirror_move(const string &ucciMove); + + // For the ucci labels we begin with index 0 for the ranks + array ranks(); + + array files(); +} + +struct FairyOutputRepresentation { + static vector LABELS; + static vector LABELS_MIRRORED; + static action_idx_map MV_LOOKUP; + static action_idx_map MV_LOOKUP_MIRRORED; + static action_idx_map MV_LOOKUP_CLASSIC; + static action_idx_map MV_LOOKUP_MIRRORED_CLASSIC; + + /** + * @brief init_labels Generates all labels in ucci move notation. + */ + static void init_labels(); + + /** + * @brief init_policy_constants Fills the hash maps for a action to Neural Network index binding. + * @param isPolicyMap describes if a policy map head is used for the Neural Network. + */ + static void init_policy_constants(bool isPolicyMap); +}; + +#endif //FAIRYOUTPUTREPRESENTATION_H diff --git a/engine/src/environments/fairy_state/fairypolicymaprepresentation.h b/engine/src/environments/fairy_state/fairypolicymaprepresentation.h new file mode 100644 index 00000000..a43d4ae9 --- /dev/null +++ b/engine/src/environments/fairy_state/fairypolicymaprepresentation.h @@ -0,0 +1,2093 @@ +#ifndef FAIRYPOLICYMAPREPRESENTATION_H +#define FAIRYPOLICYMAPREPRESENTATION_H + +const unsigned long FLAT_PLANE_IDX[] = { + 810, + 900, + 990, + 1080, + 1170, + 1260, + 1350, + 1440, + 0, + 90, + 180, + 270, + 360, + 450, + 540, + 630, + 720, + 3780, + 3870, + 2341, + 811, + 901, + 991, + 1081, + 1171, + 1261, + 1351, + 1, + 91, + 181, + 271, + 361, + 451, + 541, + 631, + 721, + 4411, + 3781, + 3871, + 2432, + 2342, + 812, + 902, + 992, + 1082, + 1172, + 1262, + 2, + 92, + 182, + 272, + 362, + 452, + 542, + 632, + 722, + 4322, + 4412, + 3782, + 3872, + 3692, + 3152, + 2523, + 2433, + 2343, + 813, + 903, + 993, + 1083, + 1173, + 3, + 93, + 183, + 273, + 363, + 453, + 543, + 633, + 723, + 4323, + 4413, + 3783, + 3873, + 2614, + 2524, + 2434, + 2344, + 814, + 904, + 994, + 1084, + 4, + 94, + 184, + 274, + 364, + 454, + 544, + 634, + 724, + 4324, + 4414, + 3784, + 3874, + 2705, + 2615, + 2525, + 2435, + 2345, + 815, + 905, + 995, + 5, + 95, + 185, + 275, + 365, + 455, + 545, + 635, + 725, + 4325, + 4415, + 3785, + 3875, + 2796, + 2706, + 2616, + 2526, + 2436, + 2346, + 816, + 906, + 6, + 96, + 186, + 276, + 366, + 456, + 546, + 636, + 726, + 4326, + 4416, + 3786, + 3876, + 3696, + 3156, + 2887, + 2797, + 2707, + 2617, + 2527, + 2437, + 2347, + 817, + 7, + 97, + 187, + 277, + 367, + 457, + 547, + 637, + 727, + 4327, + 4417, + 3787, + 2978, + 2888, + 2798, + 2708, + 2618, + 2528, + 2438, + 2348, + 8, + 98, + 188, + 278, + 368, + 458, + 548, + 638, + 728, + 4328, + 4418, + 819, + 909, + 999, + 1089, + 1179, + 1269, + 1359, + 1449, + 1539, + 9, + 99, + 189, + 279, + 369, + 459, + 549, + 639, + 3789, + 3879, + 3969, + 2350, + 820, + 910, + 1000, + 1090, + 1180, + 1270, + 1360, + 1540, + 10, + 100, + 190, + 280, + 370, + 460, + 550, + 640, + 4420, + 3790, + 3880, + 3970, + 2441, + 2351, + 821, + 911, + 1001, + 1091, + 1181, + 1271, + 1541, + 11, + 101, + 191, + 281, + 371, + 461, + 551, + 641, + 4241, + 4331, + 4421, + 3791, + 3881, + 3971, + 2532, + 2442, + 2352, + 822, + 912, + 1002, + 1092, + 1182, + 1542, + 12, + 102, + 192, + 282, + 372, + 462, + 552, + 642, + 4242, + 4332, + 4422, + 3792, + 3882, + 3972, + 2623, + 2533, + 2443, + 2353, + 823, + 913, + 1003, + 1093, + 1543, + 13, + 103, + 193, + 283, + 373, + 463, + 553, + 643, + 4243, + 4333, + 4423, + 3793, + 3883, + 3973, + 3433, + 3613, + 3073, + 3253, + 2714, + 2624, + 2534, + 2444, + 2354, + 824, + 914, + 1004, + 1544, + 14, + 104, + 194, + 284, + 374, + 464, + 554, + 644, + 4244, + 4334, + 4424, + 3794, + 3884, + 3974, + 2805, + 2715, + 2625, + 2535, + 2445, + 2355, + 825, + 915, + 1545, + 15, + 105, + 195, + 285, + 375, + 465, + 555, + 645, + 4245, + 4335, + 4425, + 3795, + 3885, + 3975, + 2896, + 2806, + 2716, + 2626, + 2536, + 2446, + 2356, + 826, + 1546, + 16, + 106, + 196, + 286, + 376, + 466, + 556, + 646, + 4246, + 4336, + 4426, + 3796, + 2987, + 2897, + 2807, + 2717, + 2627, + 2537, + 2447, + 2357, + 1547, + 17, + 107, + 197, + 287, + 377, + 467, + 557, + 647, + 4247, + 4337, + 4427, + 828, + 918, + 1008, + 1098, + 1188, + 1278, + 1368, + 1458, + 1638, + 1548, + 18, + 108, + 198, + 288, + 378, + 468, + 558, + 3798, + 3888, + 3978, + 4068, + 3168, + 3348, + 2359, + 829, + 919, + 1009, + 1099, + 1189, + 1279, + 1369, + 1639, + 1549, + 19, + 109, + 199, + 289, + 379, + 469, + 559, + 4159, + 4429, + 3799, + 3889, + 3979, + 4069, + 2450, + 2360, + 830, + 920, + 1010, + 1100, + 1190, + 1280, + 1640, + 1550, + 20, + 110, + 200, + 290, + 380, + 470, + 560, + 4160, + 4250, + 4340, + 4430, + 3800, + 3890, + 3980, + 4070, + 2541, + 2451, + 2361, + 831, + 921, + 1011, + 1101, + 1191, + 1641, + 1551, + 21, + 111, + 201, + 291, + 381, + 471, + 561, + 4161, + 4251, + 4341, + 4431, + 3801, + 3891, + 3981, + 4071, + 2632, + 2542, + 2452, + 2362, + 832, + 922, + 1012, + 1102, + 1642, + 1552, + 22, + 112, + 202, + 292, + 382, + 472, + 562, + 4162, + 4252, + 4342, + 4432, + 3802, + 3892, + 3982, + 4072, + 3712, + 3172, + 3532, + 3352, + 2723, + 2633, + 2543, + 2453, + 2363, + 833, + 923, + 1013, + 1643, + 1553, + 23, + 113, + 203, + 293, + 383, + 473, + 563, + 4163, + 4253, + 4343, + 4433, + 3803, + 3893, + 3983, + 4073, + 2814, + 2724, + 2634, + 2544, + 2454, + 2364, + 834, + 924, + 1644, + 1554, + 24, + 114, + 204, + 294, + 384, + 474, + 564, + 4164, + 4254, + 4344, + 4434, + 3804, + 3894, + 3984, + 4074, + 2905, + 2815, + 2725, + 2635, + 2545, + 2455, + 2365, + 835, + 1645, + 1555, + 25, + 115, + 205, + 295, + 385, + 475, + 565, + 4165, + 4255, + 4345, + 4435, + 3805, + 4075, + 2996, + 2906, + 2816, + 2726, + 2636, + 2546, + 2456, + 2366, + 1646, + 1556, + 26, + 116, + 206, + 296, + 386, + 476, + 566, + 4166, + 4256, + 4346, + 4436, + 3716, + 3536, + 837, + 927, + 1017, + 1107, + 1197, + 1287, + 1377, + 1467, + 1737, + 1647, + 1557, + 27, + 117, + 207, + 297, + 387, + 477, + 3807, + 3897, + 3987, + 4077, + 2368, + 838, + 928, + 1018, + 1108, + 1198, + 1288, + 1378, + 1738, + 1648, + 1558, + 28, + 118, + 208, + 298, + 388, + 478, + 4168, + 4438, + 3808, + 3898, + 3988, + 4078, + 2459, + 2369, + 839, + 929, + 1019, + 1109, + 1199, + 1289, + 1739, + 1649, + 1559, + 29, + 119, + 209, + 299, + 389, + 479, + 4169, + 4259, + 4349, + 4439, + 3809, + 3899, + 3989, + 4079, + 2550, + 2460, + 2370, + 840, + 930, + 1020, + 1110, + 1200, + 1740, + 1650, + 1560, + 30, + 120, + 210, + 300, + 390, + 480, + 4170, + 4260, + 4350, + 4440, + 3810, + 3900, + 3990, + 4080, + 2641, + 2551, + 2461, + 2371, + 841, + 931, + 1021, + 1111, + 1741, + 1651, + 1561, + 31, + 121, + 211, + 301, + 391, + 481, + 4171, + 4261, + 4351, + 4441, + 3811, + 3901, + 3991, + 4081, + 2732, + 2642, + 2552, + 2462, + 2372, + 842, + 932, + 1022, + 1742, + 1652, + 1562, + 32, + 122, + 212, + 302, + 392, + 482, + 4172, + 4262, + 4352, + 4442, + 3812, + 3902, + 3992, + 4082, + 2823, + 2733, + 2643, + 2553, + 2463, + 2373, + 843, + 933, + 1743, + 1653, + 1563, + 33, + 123, + 213, + 303, + 393, + 483, + 4173, + 4263, + 4353, + 4443, + 3813, + 3903, + 3993, + 4083, + 2914, + 2824, + 2734, + 2644, + 2554, + 2464, + 2374, + 844, + 1744, + 1654, + 1564, + 34, + 124, + 214, + 304, + 394, + 484, + 4174, + 4264, + 4354, + 4444, + 3814, + 4084, + 3005, + 2915, + 2825, + 2735, + 2645, + 2555, + 2465, + 2375, + 1745, + 1655, + 1565, + 35, + 125, + 215, + 305, + 395, + 485, + 4175, + 4265, + 4355, + 4445, + 846, + 936, + 1026, + 1116, + 1206, + 1296, + 1386, + 1476, + 1836, + 1746, + 1656, + 1566, + 36, + 126, + 216, + 306, + 396, + 3816, + 3906, + 3996, + 4086, + 2377, + 847, + 937, + 1027, + 1117, + 1207, + 1297, + 1387, + 1837, + 1747, + 1657, + 1567, + 37, + 127, + 217, + 307, + 397, + 4177, + 4447, + 3817, + 3907, + 3997, + 4087, + 2468, + 2378, + 848, + 938, + 1028, + 1118, + 1208, + 1298, + 1838, + 1748, + 1658, + 1568, + 38, + 128, + 218, + 308, + 398, + 4178, + 4268, + 4358, + 4448, + 3818, + 3908, + 3998, + 4088, + 3548, + 3368, + 2559, + 2469, + 2379, + 849, + 939, + 1029, + 1119, + 1209, + 1839, + 1749, + 1659, + 1569, + 39, + 129, + 219, + 309, + 399, + 4179, + 4269, + 4359, + 4449, + 3819, + 3909, + 3999, + 4089, + 2650, + 2560, + 2470, + 2380, + 850, + 940, + 1030, + 1120, + 1840, + 1750, + 1660, + 1570, + 40, + 130, + 220, + 310, + 400, + 4180, + 4270, + 4360, + 4450, + 3820, + 3910, + 4000, + 4090, + 2741, + 2651, + 2561, + 2471, + 2381, + 851, + 941, + 1031, + 1841, + 1751, + 1661, + 1571, + 41, + 131, + 221, + 311, + 401, + 4181, + 4271, + 4361, + 4451, + 3821, + 3911, + 4001, + 4091, + 2832, + 2742, + 2652, + 2562, + 2472, + 2382, + 852, + 942, + 1842, + 1752, + 1662, + 1572, + 42, + 132, + 222, + 312, + 402, + 4182, + 4272, + 4362, + 4452, + 3822, + 3912, + 4002, + 4092, + 3552, + 3372, + 2923, + 2833, + 2743, + 2653, + 2563, + 2473, + 2383, + 853, + 1843, + 1753, + 1663, + 1573, + 43, + 133, + 223, + 313, + 403, + 4183, + 4273, + 4363, + 4453, + 3823, + 4093, + 3014, + 2924, + 2834, + 2744, + 2654, + 2564, + 2474, + 2384, + 1844, + 1754, + 1664, + 1574, + 44, + 134, + 224, + 314, + 404, + 4184, + 4274, + 4364, + 4454, + 855, + 945, + 1035, + 1125, + 1215, + 1305, + 1395, + 1485, + 1935, + 1845, + 1755, + 1665, + 1575, + 45, + 135, + 225, + 315, + 3825, + 3915, + 4005, + 4095, + 2386, + 856, + 946, + 1036, + 1126, + 1216, + 1306, + 1396, + 1936, + 1846, + 1756, + 1666, + 1576, + 46, + 136, + 226, + 316, + 4186, + 4456, + 3826, + 3916, + 4006, + 4096, + 2477, + 2387, + 857, + 947, + 1037, + 1127, + 1217, + 1307, + 1937, + 1847, + 1757, + 1667, + 1577, + 47, + 137, + 227, + 317, + 4187, + 4277, + 4367, + 4457, + 3827, + 3917, + 4007, + 4097, + 3737, + 3197, + 2568, + 2478, + 2388, + 858, + 948, + 1038, + 1128, + 1218, + 1938, + 1848, + 1758, + 1668, + 1578, + 48, + 138, + 228, + 318, + 4188, + 4278, + 4368, + 4458, + 3828, + 3918, + 4008, + 4098, + 2659, + 2569, + 2479, + 2389, + 859, + 949, + 1039, + 1129, + 1939, + 1849, + 1759, + 1669, + 1579, + 49, + 139, + 229, + 319, + 4189, + 4279, + 4369, + 4459, + 3829, + 3919, + 4009, + 4099, + 2750, + 2660, + 2570, + 2480, + 2390, + 860, + 950, + 1040, + 1940, + 1850, + 1760, + 1670, + 1580, + 50, + 140, + 230, + 320, + 4190, + 4280, + 4370, + 4460, + 3830, + 3920, + 4010, + 4100, + 2841, + 2751, + 2661, + 2571, + 2481, + 2391, + 861, + 951, + 1941, + 1851, + 1761, + 1671, + 1581, + 51, + 141, + 231, + 321, + 4191, + 4281, + 4371, + 4461, + 3831, + 3921, + 4011, + 4101, + 3741, + 3201, + 2932, + 2842, + 2752, + 2662, + 2572, + 2482, + 2392, + 862, + 1942, + 1852, + 1762, + 1672, + 1582, + 52, + 142, + 232, + 322, + 4192, + 4282, + 4372, + 4462, + 3832, + 4102, + 3023, + 2933, + 2843, + 2753, + 2663, + 2573, + 2483, + 2393, + 1943, + 1853, + 1763, + 1673, + 1583, + 53, + 143, + 233, + 323, + 4193, + 4283, + 4373, + 4463, + 864, + 954, + 1044, + 1134, + 1224, + 1314, + 1404, + 1494, + 2034, + 1944, + 1854, + 1764, + 1674, + 1584, + 54, + 144, + 234, + 3834, + 3924, + 4014, + 4104, + 2395, + 865, + 955, + 1045, + 1135, + 1225, + 1315, + 1405, + 2035, + 1945, + 1855, + 1765, + 1675, + 1585, + 55, + 145, + 235, + 4195, + 4465, + 3835, + 3925, + 4015, + 4105, + 2486, + 2396, + 866, + 956, + 1046, + 1136, + 1226, + 1316, + 2036, + 1946, + 1856, + 1766, + 1676, + 1586, + 56, + 146, + 236, + 4196, + 4286, + 4376, + 4466, + 3836, + 3926, + 4016, + 4106, + 2577, + 2487, + 2397, + 867, + 957, + 1047, + 1137, + 1227, + 2037, + 1947, + 1857, + 1767, + 1677, + 1587, + 57, + 147, + 237, + 4197, + 4287, + 4377, + 4467, + 3837, + 3927, + 4017, + 4107, + 2668, + 2578, + 2488, + 2398, + 868, + 958, + 1048, + 1138, + 2038, + 1948, + 1858, + 1768, + 1678, + 1588, + 58, + 148, + 238, + 4198, + 4288, + 4378, + 4468, + 3838, + 3928, + 4018, + 4108, + 2759, + 2669, + 2579, + 2489, + 2399, + 869, + 959, + 1049, + 2039, + 1949, + 1859, + 1769, + 1679, + 1589, + 59, + 149, + 239, + 4199, + 4289, + 4379, + 4469, + 3839, + 3929, + 4019, + 4109, + 2850, + 2760, + 2670, + 2580, + 2490, + 2400, + 870, + 960, + 2040, + 1950, + 1860, + 1770, + 1680, + 1590, + 60, + 150, + 240, + 4200, + 4290, + 4380, + 4470, + 3840, + 3930, + 4020, + 4110, + 2941, + 2851, + 2761, + 2671, + 2581, + 2491, + 2401, + 871, + 2041, + 1951, + 1861, + 1771, + 1681, + 1591, + 61, + 151, + 241, + 4201, + 4291, + 4381, + 4471, + 3841, + 4111, + 3032, + 2942, + 2852, + 2762, + 2672, + 2582, + 2492, + 2402, + 2042, + 1952, + 1862, + 1772, + 1682, + 1592, + 62, + 152, + 242, + 4202, + 4292, + 4382, + 4472, + 873, + 963, + 1053, + 1143, + 1233, + 1323, + 1413, + 1503, + 2133, + 2043, + 1953, + 1863, + 1773, + 1683, + 1593, + 63, + 153, + 3843, + 3933, + 4023, + 4113, + 3213, + 3393, + 2404, + 874, + 964, + 1054, + 1144, + 1234, + 1324, + 1414, + 2134, + 2044, + 1954, + 1864, + 1774, + 1684, + 1594, + 64, + 154, + 4204, + 4474, + 3844, + 3934, + 4024, + 4114, + 2495, + 2405, + 875, + 965, + 1055, + 1145, + 1235, + 1325, + 2135, + 2045, + 1955, + 1865, + 1775, + 1685, + 1595, + 65, + 155, + 4205, + 4295, + 4385, + 4475, + 3845, + 3935, + 4025, + 4115, + 2586, + 2496, + 2406, + 876, + 966, + 1056, + 1146, + 1236, + 2136, + 2046, + 1956, + 1866, + 1776, + 1686, + 1596, + 66, + 156, + 4206, + 4296, + 4386, + 4476, + 3846, + 3936, + 4026, + 4116, + 2677, + 2587, + 2497, + 2407, + 877, + 967, + 1057, + 1147, + 2137, + 2047, + 1957, + 1867, + 1777, + 1687, + 1597, + 67, + 157, + 4207, + 4297, + 4387, + 4477, + 3847, + 3937, + 4027, + 4117, + 3757, + 3217, + 3577, + 3397, + 2768, + 2678, + 2588, + 2498, + 2408, + 878, + 968, + 1058, + 2138, + 2048, + 1958, + 1868, + 1778, + 1688, + 1598, + 68, + 158, + 4208, + 4298, + 4388, + 4478, + 3848, + 3938, + 4028, + 4118, + 2859, + 2769, + 2679, + 2589, + 2499, + 2409, + 879, + 969, + 2139, + 2049, + 1959, + 1869, + 1779, + 1689, + 1599, + 69, + 159, + 4209, + 4299, + 4389, + 4479, + 3849, + 3939, + 4029, + 4119, + 2950, + 2860, + 2770, + 2680, + 2590, + 2500, + 2410, + 880, + 2140, + 2050, + 1960, + 1870, + 1780, + 1690, + 1600, + 70, + 160, + 4210, + 4300, + 4390, + 4480, + 3850, + 4120, + 3041, + 2951, + 2861, + 2771, + 2681, + 2591, + 2501, + 2411, + 2141, + 2051, + 1961, + 1871, + 1781, + 1691, + 1601, + 71, + 161, + 4211, + 4301, + 4391, + 4481, + 3761, + 3581, + 882, + 972, + 1062, + 1152, + 1242, + 1332, + 1422, + 1512, + 2232, + 2142, + 2052, + 1962, + 1872, + 1782, + 1692, + 1602, + 72, + 3942, + 4032, + 4122, + 2413, + 883, + 973, + 1063, + 1153, + 1243, + 1333, + 1423, + 2233, + 2143, + 2053, + 1963, + 1873, + 1783, + 1693, + 1603, + 73, + 4213, + 3943, + 4033, + 4123, + 2504, + 2414, + 884, + 974, + 1064, + 1154, + 1244, + 1334, + 2234, + 2144, + 2054, + 1964, + 1874, + 1784, + 1694, + 1604, + 74, + 4214, + 4304, + 4394, + 3944, + 4034, + 4124, + 2595, + 2505, + 2415, + 885, + 975, + 1065, + 1155, + 1245, + 2235, + 2145, + 2055, + 1965, + 1875, + 1785, + 1695, + 1605, + 75, + 4215, + 4305, + 4395, + 3945, + 4035, + 4125, + 2686, + 2596, + 2506, + 2416, + 886, + 976, + 1066, + 1156, + 2236, + 2146, + 2056, + 1966, + 1876, + 1786, + 1696, + 1606, + 76, + 4216, + 4306, + 4396, + 3946, + 4036, + 4126, + 3496, + 3676, + 3136, + 3316, + 2777, + 2687, + 2597, + 2507, + 2417, + 887, + 977, + 1067, + 2237, + 2147, + 2057, + 1967, + 1877, + 1787, + 1697, + 1607, + 77, + 4217, + 4307, + 4397, + 3947, + 4037, + 4127, + 2868, + 2778, + 2688, + 2598, + 2508, + 2418, + 888, + 978, + 2238, + 2148, + 2058, + 1968, + 1878, + 1788, + 1698, + 1608, + 78, + 4218, + 4308, + 4398, + 3948, + 4038, + 4128, + 2959, + 2869, + 2779, + 2689, + 2599, + 2509, + 2419, + 889, + 2239, + 2149, + 2059, + 1969, + 1879, + 1789, + 1699, + 1609, + 79, + 4219, + 4309, + 4399, + 4129, + 3050, + 2960, + 2870, + 2780, + 2690, + 2600, + 2510, + 2420, + 2240, + 2150, + 2060, + 1970, + 1880, + 1790, + 1700, + 1610, + 80, + 4220, + 4310, + 4400, + 891, + 981, + 1071, + 1161, + 1251, + 1341, + 1431, + 1521, + 2331, + 2241, + 2151, + 2061, + 1971, + 1881, + 1791, + 1701, + 1611, + 4041, + 4131, + 2422, + 892, + 982, + 1072, + 1162, + 1252, + 1342, + 1432, + 2332, + 2242, + 2152, + 2062, + 1972, + 1882, + 1792, + 1702, + 1612, + 4222, + 4042, + 4132, + 2513, + 2423, + 893, + 983, + 1073, + 1163, + 1253, + 1343, + 2333, + 2243, + 2153, + 2063, + 1973, + 1883, + 1793, + 1703, + 1613, + 4223, + 4313, + 4043, + 4133, + 3593, + 3413, + 2604, + 2514, + 2424, + 894, + 984, + 1074, + 1164, + 1254, + 2334, + 2244, + 2154, + 2064, + 1974, + 1884, + 1794, + 1704, + 1614, + 4224, + 4314, + 4044, + 4134, + 2695, + 2605, + 2515, + 2425, + 895, + 985, + 1075, + 1165, + 2335, + 2245, + 2155, + 2065, + 1975, + 1885, + 1795, + 1705, + 1615, + 4225, + 4315, + 4045, + 4135, + 2786, + 2696, + 2606, + 2516, + 2426, + 896, + 986, + 1076, + 2336, + 2246, + 2156, + 2066, + 1976, + 1886, + 1796, + 1706, + 1616, + 4226, + 4316, + 4046, + 4136, + 2877, + 2787, + 2697, + 2607, + 2517, + 2427, + 897, + 987, + 2337, + 2247, + 2157, + 2067, + 1977, + 1887, + 1797, + 1707, + 1617, + 4227, + 4317, + 4047, + 4137, + 3597, + 3417, + 2968, + 2878, + 2788, + 2698, + 2608, + 2518, + 2428, + 898, + 2338, + 2248, + 2158, + 2068, + 1978, + 1888, + 1798, + 1708, + 1618, + 4228, + 4318, + 4138, + 3059, + 2969, + 2879, + 2789, + 2699, + 2609, + 2519, + 2429, + 2339, + 2249, + 2159, + 2069, + 1979, + 1889, + 1799, + 1709, + 1619, + 4229, + 4319, + 3063, + 3605, + 3261, + 3443, + 3324, + 3506, + 3126, + 3668 +}; + +#endif FAIRYPOLICYMAPREPRESENTATION_H diff --git a/engine/src/environments/fairy_state/fairystate.cpp b/engine/src/environments/fairy_state/fairystate.cpp new file mode 100644 index 00000000..1412158f --- /dev/null +++ b/engine/src/environments/fairy_state/fairystate.cpp @@ -0,0 +1,126 @@ +#include "fairystate.h" +#include "fairyinputrepresentation.h" +#include "position.h" +#include "movegen.h" +#include "variant.h" + + +action_idx_map FairyOutputRepresentation::MV_LOOKUP = {}; +action_idx_map FairyOutputRepresentation::MV_LOOKUP_MIRRORED = {}; +action_idx_map FairyOutputRepresentation::MV_LOOKUP_CLASSIC = {}; +action_idx_map FairyOutputRepresentation::MV_LOOKUP_MIRRORED_CLASSIC = {}; +vector FairyOutputRepresentation::LABELS; +vector FairyOutputRepresentation::LABELS_MIRRORED; + +FairyState::FairyState() : + State(), + states(StateListPtr(new std::deque(0))) {} + +FairyState::FairyState(const FairyState &f) : + State(), + board(f.board), + states(StateListPtr(new std::deque(0))) { + states->emplace_back(f.states->back()); +} + +std::vector FairyState::legal_actions() const { + std::vector legalMoves; + for (const ExtMove &move : MoveList(board)) { + legalMoves.push_back(Action(move.move)); + } + return legalMoves; +} + +void FairyState::set(const string &fenStr, bool isChess960, int variant) { + states = StateListPtr(new std::deque(1)); + Thread *thread; + board.set(variants.find("xiangqi")->second, fenStr, isChess960, &states->back(), thread, false); +} + +void FairyState::get_state_planes(bool normalize, float *inputPlanes) const { + board_to_planes(&board, normalize, inputPlanes); +} + +unsigned int FairyState::steps_from_null() const { + return board.game_ply(); +} + +bool FairyState::is_chess960() const { + return false; +} + +string FairyState::fen() const { + return board.fen(); +} + +void FairyState::do_action(Action action) { + states->emplace_back(); + board.do_move(Move(action), states->back()); +} + +void FairyState::undo_action(Action action) { + board.undo_move(Move(action)); +} + +void FairyState::prepare_action() { + // pass +} + +int FairyState::side_to_move() const { + return board.side_to_move(); +} + +Key FairyState::hash_key() const { + return board.hash_key(); +} + +void FairyState::flip() { + board.flip(); +} + +Action FairyState::uci_to_action(string &uciStr) const { + return Action(UCI::to_move(board, uciStr)); +} + +TerminalType FairyState::is_terminal(size_t numberLegalMoves, bool inCheck, float &customTerminalValue) const { + if (numberLegalMoves == 0) { + if (inCheck) { + return TERMINAL_LOSS; + } + return TERMINAL_DRAW; + } + return TERMINAL_NONE; +} + +Result FairyState::check_result(bool inCheck) const { + return get_result(board, inCheck); +} + +bool FairyState::gives_check(Action action) const { + return board.gives_check(Move(action)); +} + +void FairyState::print(ostream &os) const +{ + os << board; +} + +FairyState* FairyState::clone() const { + return new FairyState(*this); +} + +unsigned int FairyState::number_repetitions() const { + return board.number_repetitions(); +} + +string FairyState::action_to_san(Action action, const std::vector &legalActions, bool leadsToWin, bool bookMove) const { + return uci_move(Move(action)); +} + +Tablebase::WDLScore FairyState::check_for_tablebase_wdl(Tablebase::ProbeState &result) { + +} + +void FairyState::set_auxiliary_outputs(const float* auxiliaryOutputs) { + +} diff --git a/engine/src/environments/fairy_state/fairystate.h b/engine/src/environments/fairy_state/fairystate.h new file mode 100644 index 00000000..6fc3534f --- /dev/null +++ b/engine/src/environments/fairy_state/fairystate.h @@ -0,0 +1,145 @@ +#ifndef FAIRYSTATE_H +#define FAIRYSTATE_H + +#include "fairyboard.h" +#include "fairyoutputrepresentation.h" +#include "state.h" +#include "uci.h" +#include "variant.h" + + +class StateConstantsFairy : public StateConstantsInterface +{ +public: + static uint BOARD_WIDTH() { + return NB_SQUARES_HORIZONTAL(); + } + static uint BOARD_HEIGHT() { + return NB_SQUARES_VERTICAL(); + } + static uint NB_SQUARES() { + return BOARD_WIDTH() * BOARD_HEIGHT(); + } + static uint NB_CHANNELS_TOTAL() { + return NB_CHANNELS_POS() + NB_CHANNELS_CONST(); + } + static uint NB_LABELS() { + return 2086; + } + static uint NB_LABELS_POLICY_MAP() { + return 4500; + } + static uint NB_AUXILIARY_OUTPUTS() { + return 0U; + } + static uint NB_PLAYERS() { + return 2; + } + template + static MoveIdx action_to_index(Action action) { + switch (p) { + case normal: + switch (m) { + case notMirrored: + return FairyOutputRepresentation::MV_LOOKUP[action]; + case mirrored: + return FairyOutputRepresentation::MV_LOOKUP_MIRRORED[action]; + default: + return FairyOutputRepresentation::MV_LOOKUP[action]; + } + case classic: + switch (m) { + case notMirrored: + return FairyOutputRepresentation::MV_LOOKUP_CLASSIC[action]; + case mirrored: + return FairyOutputRepresentation::MV_LOOKUP_MIRRORED_CLASSIC[action]; + default: + return FairyOutputRepresentation::MV_LOOKUP_CLASSIC[action]; + } + default: + return FairyOutputRepresentation::MV_LOOKUP[action]; + } + } + // Currently only ucci notation is supported + static string action_to_uci(Action action, bool is960) { + Move m = Move(action); + Square from = from_sq(m); + Square to = to_sq(m); + + if (m == MOVE_NONE) { + return Options["Protocol"] == "usi" ? "resign" : "(none)"; + } + if (m == MOVE_NULL) { + return "0000"; + } + + if (is_pass(m) && Options["Protocol"] == "xboard") { + return "@@@@"; + } + string fromSquare = rank_of(from) < RANK_10 ? string{char('a' + file_of(from)), char('1' + rank_of(from))} + : string{char('a' + file_of(from)), '1', '0'}; + string toSquare = rank_of(to) < RANK_10 ? string{char('a' + file_of(to)), char('1' + rank_of(to))} + : string{char('a' + file_of(to)), '1', '0'}; + return fromSquare + toSquare; + } + static void init(bool isPolicyMap) { + FairyOutputRepresentation::init_labels(); + FairyOutputRepresentation::init_policy_constants(isPolicyMap); + } +#ifdef MODE_XIANGQI + static uint NB_SQUARES_HORIZONTAL() { + return 9; + } + static uint NB_SQUARES_VERTICAL() { + return 10; + } + static uint NB_CHANNELS_POS() { + return 26; + } + static uint NB_CHANNELS_CONST() { + return 2; + } + static float MAX_NB_PRISONERS() { + return 5; + } + static float MAX_FULL_MOVE_COUNTER() { + return 500; + } + #endif +}; + +class FairyState : public State +{ +private: + FairyBoard board; + StateListPtr states; + +public: + FairyState(); + FairyState(const FairyState& f); + + std::vector legal_actions() const override; + void set(const std::string &fenStr, bool isChess960, int variant) override; + void get_state_planes(bool normalize, float *inputPlanes) const override; + unsigned int steps_from_null() const override; + bool is_chess960() const override; + std::string fen() const override; + void do_action(Action action) override; + void undo_action(Action action) override; + void prepare_action() override; + int side_to_move() const override; + Key hash_key() const override; + void flip() override; + Action uci_to_action(std::string& uciStr) const override; + TerminalType is_terminal(size_t numberLegalMoves, bool inCheck, float& customTerminalValue) const override; + Result check_result(bool inCheck) const; + bool gives_check(Action action) const override; + void print(std::ostream& os) const override; + FairyState* clone() const override; + unsigned int number_repetitions() const override; + string action_to_san(Action action, const std::vector &legalActions, bool leadsToWin, bool bookMove) const override; + Tablebase::WDLScore check_for_tablebase_wdl(Tablebase::ProbeState &result) override; + void set_auxiliary_outputs(const float* auxiliaryOutputs) override; +}; + +#endif // FAIRYSTATE_H diff --git a/engine/src/environments/fairy_state/fairyutil.cpp b/engine/src/environments/fairy_state/fairyutil.cpp new file mode 100644 index 00000000..c5ceeba2 --- /dev/null +++ b/engine/src/environments/fairy_state/fairyutil.cpp @@ -0,0 +1,35 @@ +#include "fairyutil.h" +#include + + +Square get_origin_square(const string& uciMove) +{ + File fromFile = FILE_LOOKUP.at(uciMove[0]); + Rank fromRank = isdigit(uciMove[2]) ? RANK_10 : RANK_LOOKUP.at(uciMove[1]); + return make_square(fromFile, fromRank); +} + +Square get_destination_square(const string& uciMove) +{ + File toFile; + Rank toRank; + if (uciMove.size() == 6) { + toFile = FILE_LOOKUP.at(uciMove[3]); + toRank = RANK_10; + } + else if (uciMove.size() == 5) { + if (isdigit(uciMove[2])) { + toFile = FILE_LOOKUP.at(uciMove[3]); + toRank = RANK_LOOKUP.at(uciMove[4]); + } + else { + toFile = FILE_LOOKUP.at(uciMove[2]); + toRank = RANK_10; + } + } + else { + toFile = FILE_LOOKUP.at(uciMove[2]); + toRank = RANK_LOOKUP.at(uciMove[3]); + } + return make_square(toFile, toRank); +} diff --git a/engine/src/environments/fairy_state/fairyutil.h b/engine/src/environments/fairy_state/fairyutil.h new file mode 100644 index 00000000..1872f417 --- /dev/null +++ b/engine/src/environments/fairy_state/fairyutil.h @@ -0,0 +1,45 @@ +#ifndef FAIRYUTIL_H +#define FAIRYUTIL_H + +#include + +using namespace std; + +const unordered_map FILE_LOOKUP = { + {'a', FILE_A}, + {'b', FILE_B}, + {'c', FILE_C}, + {'d', FILE_D}, + {'e', FILE_E}, + {'f', FILE_F}, + {'g', FILE_G}, + {'h', FILE_H}, + {'i', FILE_I}}; + +// Note that we have 10 ranks but use a char to Rank lookup... +const unordered_map RANK_LOOKUP = { + {'1', RANK_1}, + {'2', RANK_2}, + {'3', RANK_3}, + {'4', RANK_4}, + {'5', RANK_5}, + {'6', RANK_6}, + {'7', RANK_7}, + {'8', RANK_8}, + {'9', RANK_9}}; + +/** + * @brief get_origin_square Returns the origin square for a valid ucciMove + * @param uciMove uci-Move in string notation + * @return origin square + */ +Square get_origin_square(const string &uciMove); + +/** + * @brief get_origin_square Returns the destination square for a valid ucciMove + * @param uciMove uci-Move in string notation + * @return destination square + */ +Square get_destination_square(const string &uciMove); + +#endif //FAIRYUTIL_H diff --git a/engine/src/nn/tensorrtapi.cpp b/engine/src/nn/tensorrtapi.cpp index 41399063..e72679b3 100644 --- a/engine/src/nn/tensorrtapi.cpp +++ b/engine/src/nn/tensorrtapi.cpp @@ -34,7 +34,7 @@ #include "EntropyCalibrator.h" #include "stateobj.h" #include "../util/communication.h" -#ifndef MODE_POMMERMAN +#if !defined(MODE_POMMERMAN) && !defined(MODE_XIANGQI) #include "environments/chess_related/chessbatchstream.h" #endif @@ -220,7 +220,7 @@ void TensorrtAPI::set_config_settings(SampleUniquePtr& #elif defined MODE_CRAZYHOUSE calibrationStream.reset(new ChessBatchStream(1, 232)); #endif -#if !defined(MODE_POMMERMAN) && !defined(MODE_OPEN_SPIEL) +#if !defined(MODE_POMMERMAN) && !defined(MODE_OPEN_SPIEL) && !defined(MODE_XIANGQI) calibrator.reset(new Int8EntropyCalibrator2(*(dynamic_cast(calibrationStream.get())), 0, "model", "data")); #endif config->setInt8Calibrator(calibrator.get()); diff --git a/engine/src/state.h b/engine/src/state.h index 82b10806..b989b1e2 100644 --- a/engine/src/state.h +++ b/engine/src/state.h @@ -416,7 +416,7 @@ class State * @param isChess960 If true 960 mode will be active * @param variant Variant which the position corresponds to */ - virtual void init(int variant, bool isChess960) = 0; + //virtual void init(int variant, bool isChess960) = 0; }; #endif // GAMESTATE_H diff --git a/engine/src/stateobj.h b/engine/src/stateobj.h index a9c12e9c..e9b6ea61 100644 --- a/engine/src/stateobj.h +++ b/engine/src/stateobj.h @@ -40,6 +40,9 @@ using blaze::DynamicVector; #include "pommermanstate.h" #elif MODE_OPEN_SPIEL #include "environments/open_spiel/openspielstate.h" +#elif MODE_XIANGQI +#include "environments/fairy_state/fairystate.h" +#include "environments/fairy_state/fairyoutputrepresentation.h" #else #include "environments/chess_related/boardstate.h" #include "environments/chess_related/outputrepresentation.h" @@ -51,6 +54,9 @@ using blaze::DynamicVector; #elif MODE_OPEN_SPIEL using StateObj = OpenSpielState; using StateConstants = StateConstantsOpenSpiel; +#elif MODE_XIANGQI + using StateObj = FairyState; + using StateConstants = StateConstantsFairy; #else using StateObj = BoardState; using StateConstants = StateConstantsBoard; diff --git a/engine/src/uci/crazyara.cpp b/engine/src/uci/crazyara.cpp index c68ccf6c..fecffe0d 100644 --- a/engine/src/uci/crazyara.cpp +++ b/engine/src/uci/crazyara.cpp @@ -41,6 +41,9 @@ #include "optionsuci.h" #include "../tests/benchmarkpositions.h" #include "util/communication.h" +#ifdef MODE_XIANGQI + #include "piece.h" +#endif #ifdef MXNET #include "nn/mxnetapi.h" #elif defined TENSORRT @@ -60,6 +63,8 @@ CrazyAra::CrazyAra(): playSettings(PlaySettings()), #ifdef MODE_CRAZYHOUSE variant(CRAZYHOUSE_VARIANT), +#elif defined(MODE_XIANGQI) + variant(*variants.find("xiangqi")->second), #else variant(CHESS_VARIANT), #endif @@ -89,11 +94,14 @@ void CrazyAra::uci_loop(int argc, char *argv[]) unique_ptr state = make_unique(); string token, cmd; EvalInfo evalInfo; +#ifndef MODE_XIANGQI auto uiThread = make_shared(0); - variant = UCI::variant_from_name(Options["UCI_Variant"]); state->set(StartFENs[variant], is960, variant); - +#endif +#ifdef MODE_XIANGQI + state->set(variant.startFen, is960, 0); +#endif for (int i = 1; i < argc; ++i) cmd += string(argv[i]) + " "; @@ -104,7 +112,9 @@ void CrazyAra::uci_loop(int argc, char *argv[]) }; do { - +#ifdef MODE_XIANGQI + state.get(); +#endif if (it < commands.size()) { cmd = commands[it]; cout << ">>" << cmd << endl; @@ -201,7 +211,13 @@ void CrazyAra::go(const string& fen, string goCommand, EvalInfo& evalInfo) { unique_ptr state = make_unique(); string token, cmd; + +#ifndef MODE_XIANGQI + variant = UCI::variant_from_name(Options["UCI_Variant"]); state->set(StartFENs[variant], is960, variant); +#else + state->set(variant.startFen, is960, 0); +#endif istringstream is("fen " + fen); position(state.get(), is); @@ -231,11 +247,17 @@ void CrazyAra::position(StateObj* state, istringstream& is) Action action; string token, fen; - +#ifndef MODE_XIANGQI + variant = UCI::variant_from_name(Options["UCI_Variant"]); +#endif is >> token; if (token == "startpos") { +#ifndef MODE_XIANGQI fen = StartFENs[variant]; +#else + fen = variant.startFen; +#endif is >> token; // Consume "moves" token if any } else if (token == "fen") { @@ -246,7 +268,11 @@ void CrazyAra::position(StateObj* state, istringstream& is) else return; - state->set(fen, is960, variant); +#ifndef MODE_XIANGQI + state->set(fen, is960, variant); +#else + state->set(fen, is960, 0); +#endif Action lastMove = ACTION_NONE; // Parse move list (if any) @@ -382,11 +408,27 @@ void CrazyAra::init_rl_settings() void CrazyAra::init() { +#ifndef MODE_XIANGQI OptionsUCI::init(Options); Bitboards::init(); Position::init(); Bitbases::init(); Search::init(); +#endif +#ifdef MODE_XIANGQI + pieceMap.init(); + OptionsUCI::init(Options); + UCI::init(Options); + Bitboards::init(); + Position::init(); + Bitbases::init(); + Search::init(); + Tablebases::init(""); + + // This is a workaround for compatibility with Fairy-Stockfish + // Option with key "Threads" is also removed. (See /3rdparty/Fairy-Stockfish/src/ucioption.cpp) + Options.erase("Hash"); +#endif } bool CrazyAra::is_ready() diff --git a/engine/src/uci/main.cpp b/engine/src/uci/main.cpp index 356a3d6b..874c925f 100644 --- a/engine/src/uci/main.cpp +++ b/engine/src/uci/main.cpp @@ -29,9 +29,13 @@ #include "stateobj.h" #include #include "crazyara.h" +#include "variants.h" #ifndef BUILD_TESTS int main(int argc, char* argv[]) { +#ifdef XIANGQI + variants.init(); +#endif CrazyAra crazyara; crazyara.init(); crazyara.welcome(); diff --git a/engine/src/uci/optionsuci.cpp b/engine/src/uci/optionsuci.cpp index 80569080..e16bb056 100644 --- a/engine/src/uci/optionsuci.cpp +++ b/engine/src/uci/optionsuci.cpp @@ -24,6 +24,9 @@ */ #include "optionsuci.h" +#ifdef MODE_XIANGQI + #include "variant.h" +#endif #include #include #include @@ -39,10 +42,11 @@ void on_logger(const Option& o) { } // method is based on 3rdparty/Stockfish/uci.cpp +#ifndef MODE_XIANGQI void on_tb_path(const Option& o) { Tablebases::init(UCI::variant_from_name(Options["UCI_Variant"]), Options["SyzygyPath"]); } - +#endif void OptionsUCI::init(OptionsMap &o) { @@ -141,12 +145,16 @@ void OptionsUCI::init(OptionsMap &o) #else o["Simulations"] << Option(0, 0, 99999999); #endif +#ifndef MODE_XIANGQI o["SyzygyPath"] << Option("", on_tb_path); +#endif o["Threads"] << Option(2, 1, 512); #ifdef MODE_CRAZYHOUSE o["UCI_Variant"] << Option("crazyhouse", {"crazyhouse", "crazyhouse"}); #elif defined MODE_LICHESS o["UCI_Variant"] << Option(availableVariants.front().c_str(), availableVariants); +#elif defined MODE_XIANGQI + o["UCI_Variant"] << Option("xiangqi", {"xiangqi", "xiangqi"}); #else // MODE = MODE_CHESS o["UCI_Variant"] << Option("chess", {"chess", "chess"}); #endif @@ -193,9 +201,14 @@ void OptionsUCI::setoption(istringstream &is, Variant& variant, StateObj& state) cout << "info string Updated option " << name << " to " << value << endl; std::transform(name.begin(), name.end(), name.begin(), ::tolower); if (name == "uci_variant") { +#ifdef XIANGQI + // Workaround. Fairy-Stockfish does not use an enum for variants + cout << "info string variant " << "Xiangqi" << " startpos " << "rnbakabnr/9/1c5c1/p1p1p1p1p/9/9/P1P1P1P1P/1C5C1/9/RNBAKABNR w - - 0 1" << endl; +#else variant = UCI::variant_from_name(value); cout << "info string variant " << (string)Options["UCI_Variant"] << " startpos " << StartFENs[variant] << endl; state.set(StartFENs[variant], Options["UCI_Chess960"], variant); +#endif #ifdef MODE_LICHESS // Set model path for new variant; just in case set model_contender as well Options["Model_Directory"] << ((string) "model" + "/" + string(Options["UCI_Variant"])).c_str(); diff --git a/engine/src/uci/variants.h b/engine/src/uci/variants.h index c4391e14..981e402a 100644 --- a/engine/src/uci/variants.h +++ b/engine/src/uci/variants.h @@ -40,10 +40,14 @@ static vector availableVariants = { "giveaway", // antichess "horde", "kingofthehill", - "racingkings" + "racingkings", + "xiangqi" }; // FEN strings of the initial positions +#ifdef XIANGQI +const int SUBVARIANT_NB = 20; // Thats high quality code +#endif const static string StartFENs[SUBVARIANT_NB] = { "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1", #ifdef ANTI @@ -104,6 +108,9 @@ const static string StartFENs[SUBVARIANT_NB] = { #ifdef TWOKINGSSYMMETRIC "rnbqkknr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKKNR w KQkq - 0 1", #endif + #ifdef XIANGQI + "rnbakabnr/9/1c5c1/p1p1p1p1p/9/9/P1P1P1P1P/1C5C1/9/RNBAKABNR w - - 0 1", + #endif }; #endif // VARIANTS_H From 024bd37175adeab429a78fcefa3c3479552a0713 Mon Sep 17 00:00:00 2001 From: xsr7qsr Date: Mon, 26 Apr 2021 12:48:47 +0200 Subject: [PATCH 05/19] Add xiangqi support for MCTS --- engine/3rdparty/Fairy-Stockfish | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/3rdparty/Fairy-Stockfish b/engine/3rdparty/Fairy-Stockfish index c4edc95b..a9d2a160 160000 --- a/engine/3rdparty/Fairy-Stockfish +++ b/engine/3rdparty/Fairy-Stockfish @@ -1 +1 @@ -Subproject commit c4edc95b096880362f0f02a0a6fc627bb1ddf9b7 +Subproject commit a9d2a16035856b1bd0a88e5646b36619eb1264da From 08af294f7e380d6cf49055cf1b62761e6a0fcd4c Mon Sep 17 00:00:00 2001 From: xsr7qsr Date: Mon, 26 Apr 2021 13:54:22 +0200 Subject: [PATCH 06/19] Fix duplicated main.cpp and .nnue file --- engine/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index 8c983420..d61474b0 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -171,6 +171,7 @@ file(GLOB fsf_related_files ) list(REMOVE_ITEM fsf_related_files ${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/Fairy-Stockfish/src/ffishjs.cpp) list(REMOVE_ITEM fsf_related_files ${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/Fairy-Stockfish/src/pyffish.cpp) +list(REMOVE_ITEM fsf_related_files ${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/Fairy-Stockfish/src/main.cpp) if (NOT MODE_XIANGQI) set(source_files From acdd5b9e519e997ca2e35b0b2dd66bcbe321f1e2 Mon Sep 17 00:00:00 2001 From: xsr7qsr Date: Mon, 26 Apr 2021 13:54:26 +0200 Subject: [PATCH 07/19] Fix duplicated main.cpp and .nnue file --- engine/3rdparty/Fairy-Stockfish | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/3rdparty/Fairy-Stockfish b/engine/3rdparty/Fairy-Stockfish index a9d2a160..ba5cf81c 160000 --- a/engine/3rdparty/Fairy-Stockfish +++ b/engine/3rdparty/Fairy-Stockfish @@ -1 +1 @@ -Subproject commit a9d2a16035856b1bd0a88e5646b36619eb1264da +Subproject commit ba5cf81cf0e522e6ca0974a9fdc53674cea4866c From 53f0c3d29ba1353dad498996a9b112872af49578 Mon Sep 17 00:00:00 2001 From: xsr7qsr Date: Mon, 26 Apr 2021 19:30:46 +0200 Subject: [PATCH 08/19] Delete killed Fairy-Stockfish submodule --- engine/3rdparty/Fairy-Stockfish | 1 - 1 file changed, 1 deletion(-) delete mode 160000 engine/3rdparty/Fairy-Stockfish diff --git a/engine/3rdparty/Fairy-Stockfish b/engine/3rdparty/Fairy-Stockfish deleted file mode 160000 index ba5cf81c..00000000 --- a/engine/3rdparty/Fairy-Stockfish +++ /dev/null @@ -1 +0,0 @@ -Subproject commit ba5cf81cf0e522e6ca0974a9fdc53674cea4866c From aa001047c64c7772ccdf999b499411dab648356f Mon Sep 17 00:00:00 2001 From: xsr7qsr Date: Mon, 26 Apr 2021 20:07:55 +0200 Subject: [PATCH 09/19] Reintroduce Fairy-Stockfish fork --- engine/3rdparty/Fairy-Stockfish | 1 + 1 file changed, 1 insertion(+) create mode 160000 engine/3rdparty/Fairy-Stockfish diff --git a/engine/3rdparty/Fairy-Stockfish b/engine/3rdparty/Fairy-Stockfish new file mode 160000 index 00000000..c4edc95b --- /dev/null +++ b/engine/3rdparty/Fairy-Stockfish @@ -0,0 +1 @@ +Subproject commit c4edc95b096880362f0f02a0a6fc627bb1ddf9b7 From 77282d6cc0b52a99e10e3d73fe41a24e90123da8 Mon Sep 17 00:00:00 2001 From: xsr7qsr Date: Mon, 26 Apr 2021 20:26:58 +0200 Subject: [PATCH 10/19] Reintroduce Fairy-Stockfish fork --- engine/3rdparty/Fairy-Stockfish | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/3rdparty/Fairy-Stockfish b/engine/3rdparty/Fairy-Stockfish index c4edc95b..f75abe5d 160000 --- a/engine/3rdparty/Fairy-Stockfish +++ b/engine/3rdparty/Fairy-Stockfish @@ -1 +1 @@ -Subproject commit c4edc95b096880362f0f02a0a6fc627bb1ddf9b7 +Subproject commit f75abe5d3cb1831a79edc050736d790f81cbdf4e From 7e6cf34566bdde3da255895163714e37dd06b12b Mon Sep 17 00:00:00 2001 From: xsr7qsr Date: Mon, 26 Apr 2021 20:35:45 +0200 Subject: [PATCH 11/19] Add tests for xiangqi planes --- engine/tests/tests.cpp | 207 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 207 insertions(+) diff --git a/engine/tests/tests.cpp b/engine/tests/tests.cpp index d25f1d0a..467d54e2 100644 --- a/engine/tests/tests.cpp +++ b/engine/tests/tests.cpp @@ -27,6 +27,7 @@ #ifdef BUILD_TESTS #include +#ifndef MODE_XIANGQI #include #include "catch.hpp" #include "uci.h" @@ -385,5 +386,211 @@ TEST_CASE("State: clone()"){ unique_ptr state2 = unique_ptr(state.clone()); REQUIRE(state2->fen() == state.fen()); } +#else +#include "catch.hpp" +#include "piece.h" +#include "thread.h" +#include "uci.h" +#include "uci/optionsuci.h" +#include "variant.h" +#include "environments/fairy_state/fairyboard.h" +#include "environments/fairy_state/fairystate.h" +#include "environments/fairy_state/fairyutil.h" +#include "environments/fairy_state/fairyinputrepresentation.h" + + +void init() { + pieceMap.init(); + variants.init(); + OptionsUCI::init(Options); + UCI::init(Options); + Bitboards::init(); + Position::init(); + Bitbases::init(); + Search::init(); + Tablebases::init(""); +} + +void get_planes_statistics(const FairyBoard* pos, bool normalize, double& sum, double& maxNum, double& key, size_t& argMax) { + float inputPlanes[StateConstantsFairy::NB_VALUES_TOTAL()]; + board_to_planes(pos, normalize, inputPlanes); + sum = 0; + maxNum = 0; + key = 0; + argMax = 0; + for (unsigned int i = 0; i < StateConstantsFairy::NB_VALUES_TOTAL(); ++i) { + const float val = inputPlanes[i]; + sum += val; + if (val > maxNum) { + maxNum = val; + argMax = i; + } + key += i * val; + } +} + +void apply_moves_to_board(const vector& uciMoves, FairyBoard& pos, StateListPtr& states) { + for (string uciMove : uciMoves) { + Move m = UCI::to_move(pos, uciMove); + states->emplace_back(); + pos.do_move(m, states->back()); + } +} + +void apply_move_to_board(string uciMove, FairyBoard& pos, StateListPtr& states) { + Move m = UCI::to_move(pos, uciMove); + states->emplace_back(); + pos.do_move(m, states->back()); +} +TEST_CASE("Xiangqi_Input_Planes") { + init(); + FairyBoard pos; + StateInfo newState; + StateListPtr states = StateListPtr(new std::deque(1)); + + auto uiThread = make_shared(0); + + const Variant *xiangqiVariant = variants.find("xiangqi")->second; + string startFen = xiangqiVariant->startFen; + pos.set(xiangqiVariant, startFen, false, &states->back(), uiThread.get(), false); + + // starting position test + double sum, maxNum, key; + size_t argMax; + get_planes_statistics(&pos, false, sum, maxNum, key, argMax); + REQUIRE(sum == 122); + REQUIRE(maxNum == 1); + REQUIRE(key == 236909); + REQUIRE(argMax == 85); + REQUIRE(pos.fen() == startFen); + + string uciMove = "c4c5"; + apply_move_to_board(uciMove, pos, states); + get_planes_statistics(&pos, false, sum, maxNum, key, argMax); + REQUIRE(sum == 32); + REQUIRE(maxNum == 1); + REQUIRE(key == 22313); + REQUIRE(argMax == 85); + REQUIRE(pos.fen() == "rnbakabnr/9/1c5c1/p1p1p1p1p/9/2P6/P3P1P1P/1C5C1/9/RNBAKABNR b - - 1 1"); + + uciMove = "g7g6"; + apply_move_to_board(uciMove, pos, states); + get_planes_statistics(&pos, false, sum, maxNum, key, argMax); + REQUIRE(sum == 212); + REQUIRE(maxNum == 1); + REQUIRE(key == 459614); + REQUIRE(argMax == 85); + REQUIRE(pos.fen() == "rnbakabnr/9/1c5c1/p1p1p3p/6p2/2P6/P3P1P1P/1C5C1/9/RNBAKABNR w - - 2 2"); + + uciMove = "h3g3"; + apply_move_to_board(uciMove, pos, states); + get_planes_statistics(&pos, false, sum, maxNum, key, argMax); + REQUIRE(sum == 122); + REQUIRE(maxNum == 1); + REQUIRE(key == 245008); + REQUIRE(argMax == 85); + REQUIRE(pos.fen() == "rnbakabnr/9/1c5c1/p1p1p3p/6p2/2P6/P3P1P1P/1C4C2/9/RNBAKABNR b - - 3 2"); + + uciMove = "c10e8"; + apply_move_to_board(uciMove, pos, states); + get_planes_statistics(&pos, false, sum, maxNum, key, argMax); + REQUIRE(sum == 302); + REQUIRE(maxNum == 2); + REQUIRE(key == 682338); + REQUIRE(argMax == 2430); + REQUIRE(pos.fen() == "rn1akabnr/9/1c2b2c1/p1p1p3p/6p2/2P6/P3P1P1P/1C4C2/9/RNBAKABNR w - - 4 3"); + + uciMove = "h1i3"; + apply_move_to_board(uciMove, pos, states); + get_planes_statistics(&pos, false, sum, maxNum, key, argMax); + REQUIRE(sum == 212); + REQUIRE(maxNum == 2); + REQUIRE(key == 467716); + REQUIRE(argMax == 2430); + REQUIRE(pos.fen() == "rn1akabnr/9/1c2b2c1/p1p1p3p/6p2/2P6/P3P1P1P/1C4C1N/9/RNBAKAB1R b - - 5 3"); + + uciMove = "h10g8"; + apply_move_to_board(uciMove, pos, states); + get_planes_statistics(&pos, false, sum, maxNum, key, argMax); + REQUIRE(sum == 392); + REQUIRE(maxNum == 3); + REQUIRE(key == 905043); + REQUIRE(argMax == 2430); + REQUIRE(pos.fen() == "rn1akab1r/9/1c2b1nc1/p1p1p3p/6p2/2P6/P3P1P1P/1C4C1N/9/RNBAKAB1R w - - 6 4"); + + uciMove = "i1h1"; + apply_move_to_board(uciMove, pos, states); + get_planes_statistics(&pos, false, sum, maxNum, key, argMax); + REQUIRE(sum == 302); + REQUIRE(maxNum == 3); + REQUIRE(key == 690401); + REQUIRE(argMax == 2430); + REQUIRE(pos.fen() == "rn1akab1r/9/1c2b1nc1/p1p1p3p/6p2/2P6/P3P1P1P/1C4C1N/9/RNBAKABR1 b - - 7 4"); + + uciMove = "i10h10"; + apply_move_to_board(uciMove, pos, states); + get_planes_statistics(&pos, false, sum, maxNum, key, argMax); + REQUIRE(sum == 482); + REQUIRE(maxNum == 4); + REQUIRE(key == 1127746); + REQUIRE(argMax == 2430); + REQUIRE(pos.fen() == "rn1akabr1/9/1c2b1nc1/p1p1p3p/6p2/2P6/P3P1P1P/1C4C1N/9/RNBAKABR1 w - - 8 5"); + + uciMove = "b3e3"; + apply_move_to_board(uciMove, pos, states); + get_planes_statistics(&pos, false, sum, maxNum, key, argMax); + REQUIRE(sum == 392); + REQUIRE(maxNum == 4); + REQUIRE(key == 913108); + REQUIRE(argMax == 2430); + REQUIRE(pos.fen() == "rn1akabr1/9/1c2b1nc1/p1p1p3p/6p2/2P6/P3P1P1P/4C1C1N/9/RNBAKABR1 b - - 9 5"); + + uciMove = "h8h4"; + apply_move_to_board(uciMove, pos, states); + get_planes_statistics(&pos, false, sum, maxNum, key, argMax); + REQUIRE(sum == 572); + REQUIRE(maxNum == 5); + REQUIRE(key == 1350490); + REQUIRE(argMax == 2430); + REQUIRE(pos.fen() == "rn1akabr1/9/1c2b1n2/p1p1p3p/6p2/2P6/P3P1PcP/4C1C1N/9/RNBAKABR1 w - - 10 6"); + + uciMove = "b1c3"; + apply_move_to_board(uciMove, pos, states); + get_planes_statistics(&pos, false, sum, maxNum, key, argMax); + REQUIRE(sum == 482); + REQUIRE(maxNum == 5); + REQUIRE(key == 1135796); + REQUIRE(argMax == 2430); + REQUIRE(pos.fen() == "rn1akabr1/9/1c2b1n2/p1p1p3p/6p2/2P6/P3P1PcP/2N1C1C1N/9/R1BAKABR1 b - - 11 6"); + + uciMove = "b10d9"; + apply_move_to_board(uciMove, pos, states); + get_planes_statistics(&pos, false, sum, maxNum, key, argMax); + REQUIRE(sum == 662); + REQUIRE(maxNum == 6); + REQUIRE(key == 1573189); + REQUIRE(argMax == 2430); + REQUIRE(pos.fen() == "r2akabr1/3n5/1c2b1n2/p1p1p3p/6p2/2P6/P3P1PcP/2N1C1C1N/9/R1BAKABR1 w - - 12 7"); + + uciMove = "a1a2"; + apply_move_to_board(uciMove, pos, states); + get_planes_statistics(&pos, false, sum, maxNum, key, argMax); + REQUIRE(sum == 572); + REQUIRE(maxNum == 6); + REQUIRE(key == 1358503); + REQUIRE(argMax == 2430); + REQUIRE(pos.fen() == "r2akabr1/3n5/1c2b1n2/p1p1p3p/6p2/2P6/P3P1PcP/2N1C1C1N/R8/2BAKABR1 b - - 13 7"); + + uciMove = "d10e9"; + apply_move_to_board(uciMove, pos, states); + get_planes_statistics(&pos, false, sum, maxNum, key, argMax); + REQUIRE(sum == 752); + REQUIRE(maxNum == 7); + REQUIRE(key == 1795895); + REQUIRE(argMax == 2430); + REQUIRE(pos.fen() == "r3kabr1/3na4/1c2b1n2/p1p1p3p/6p2/2P6/P3P1PcP/2N1C1C1N/R8/2BAKABR1 w - - 14 8"); + REQUIRE(StateConstantsFairy::NB_VALUES_TOTAL() == 28*90); +} +#endif // MODE_XIANGQI #endif From 26869ed9d3d7c87661c1f1de39c0c0c551ab37b3 Mon Sep 17 00:00:00 2001 From: xsr7qsr Date: Mon, 26 Apr 2021 21:00:26 +0200 Subject: [PATCH 12/19] Add preprocessing support for xiangqi (data is given as CSV) --- .../preprocessing/convert_csv_to_planes.ipynb | 72 +++ .../preprocessing/csv_to_planes_converter.py | 490 ++++++++++++++++ DeepCrazyhouse/src/preprocessing/ucci_util.py | 151 +++++ .../preprocessing/xiangqi_board/__init__.py | 0 .../xiangqi_board/xiangqi_board.py | 554 ++++++++++++++++++ requirements.txt | 1 + 6 files changed, 1268 insertions(+) create mode 100644 DeepCrazyhouse/src/preprocessing/convert_csv_to_planes.ipynb create mode 100644 DeepCrazyhouse/src/preprocessing/csv_to_planes_converter.py create mode 100644 DeepCrazyhouse/src/preprocessing/ucci_util.py create mode 100644 DeepCrazyhouse/src/preprocessing/xiangqi_board/__init__.py create mode 100644 DeepCrazyhouse/src/preprocessing/xiangqi_board/xiangqi_board.py diff --git a/DeepCrazyhouse/src/preprocessing/convert_csv_to_planes.ipynb b/DeepCrazyhouse/src/preprocessing/convert_csv_to_planes.ipynb new file mode 100644 index 00000000..c155e151 --- /dev/null +++ b/DeepCrazyhouse/src/preprocessing/convert_csv_to_planes.ipynb @@ -0,0 +1,72 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "dimensional-sword", + "metadata": {}, + "outputs": [], + "source": [ + "import sys, os\n", + "sys.path.insert(0,'../../../')\n", + "from csv_to_planes_converter import CSV2PlanesConverter" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "mathematical-antigua", + "metadata": {}, + "outputs": [], + "source": [ + "PATH_CSV = \"\"\n", + "PATH_EXPORT = \"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "human-capitol", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "csv_converter = CSV2PlanesConverter(PATH_CSV, min_elo=0, min_number_moves=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "surface-waste", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "csv_converter.export_batches(PATH_EXPORT)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/DeepCrazyhouse/src/preprocessing/csv_to_planes_converter.py b/DeepCrazyhouse/src/preprocessing/csv_to_planes_converter.py new file mode 100644 index 00000000..2408cf6a --- /dev/null +++ b/DeepCrazyhouse/src/preprocessing/csv_to_planes_converter.py @@ -0,0 +1,490 @@ +from xiangqi_board.xiangqi_board import XiangqiBoard +from DeepCrazyhouse.src.domain.variants.constants import LABELS_XIANGQI +from ucci_util import xiangqi_board_move_to_ucci, mirror_ucci +import time +import re +import zarr +import math +import logging +import numpy as np +import pandas as pd +from numcodecs import Blosc + + +class CSV2PlanesConverter: + def __init__(self, + path_csv, + min_elo=None, + min_number_moves=None, + num_games_per_file=1000, + clevel=5, + compression="lz4"): + log = logging.getLogger() + log.setLevel(logging.DEBUG) + + logging.info("Reading csv into pandas dataframe") + self._df = pd.read_csv(path_csv, delimiter=';') + logging.info("Reading csv finished") + + self._min_elo = min_elo + self._min_number_moves = min_number_moves + + if min_elo is not None or min_number_moves is not None: + self._filter_csv() + + self.xiangqi_board = XiangqiBoard() + + self._num_games_per_file = num_games_per_file + self._clevel = clevel + self._compression = compression + + def _filter_csv(self): + logging.info("Filter csv") + + if self._min_elo is not None: + self._df = self._df[(self._df.red_elo >= self._min_elo) & (self._df.black_elo >= self._min_elo)] + if self._min_number_moves is not None: + self._df = self._df[self._df.num_moves >= self._min_number_moves] + + logging.info("Filter csv finished") + + def get_plain_piece_planes(self): + plain_board = np.zeros((10, 9), dtype=np.int) + + king_plane, advisor_plane, elephant_plane, horse_plane, \ + rook_plane, cannon_plane, pawn_plane = (plain_board.copy() for i in range(7)) + + red_piece_planes = np.array((king_plane, + advisor_plane, + elephant_plane, + horse_plane, + rook_plane, + cannon_plane, + pawn_plane)) + + black_piece_planes = np.array((king_plane, + advisor_plane, + elephant_plane, + horse_plane, + rook_plane, + cannon_plane, + pawn_plane)) + return red_piece_planes, black_piece_planes + + def get_plain_pocket_count_planes(self): + plain_board = np.zeros((10, 9), dtype=np.int) + + advisor_pocket_plane, elephant_pocket_plane, horse_pocket_plane, rook_pocket_plane, \ + cannon_pocket_plane, pawn_pocket_plane = (plain_board.copy() for i in range(6)) + + red_pocket_planes = np.array((advisor_pocket_plane, + elephant_pocket_plane, + horse_pocket_plane, + rook_pocket_plane, + cannon_pocket_plane, + pawn_pocket_plane)) + + black_pocket_planes = np.array((advisor_pocket_plane, + elephant_pocket_plane, + horse_pocket_plane, + rook_pocket_plane, + cannon_pocket_plane, + pawn_pocket_plane)) + return red_pocket_planes, black_pocket_planes + + def get_pocket_count_planes(self): + red_pocket_planes, black_pocket_planes = self.get_plain_pocket_count_planes() + + initial_pieces_on_board = {'k': 1, 'a': 2, 'e': 2, 'h': 2, 'r': 2, 'c': 2, 'p': 5, + 'K': 1, 'A': 2, 'E': 2, 'H': 2, 'R': 2, 'C': 2, 'P': 5} + current_pieces_on_board = self.xiangqi_board.get_num_figures() + for wxf_identifier in current_pieces_on_board.keys(): + if wxf_identifier == 'A': + red_pocket_planes[0].fill(initial_pieces_on_board['A'] - current_pieces_on_board['A']) + elif wxf_identifier == 'a': + black_pocket_planes[0].fill(initial_pieces_on_board['a'] - current_pieces_on_board['a']) + elif wxf_identifier == 'E': + red_pocket_planes[1].fill(initial_pieces_on_board['E'] - current_pieces_on_board['E']) + elif wxf_identifier == 'e': + black_pocket_planes[1].fill(initial_pieces_on_board['e'] - current_pieces_on_board['e']) + elif wxf_identifier == 'H': + red_pocket_planes[2].fill(initial_pieces_on_board['H'] - current_pieces_on_board['H']) + elif wxf_identifier == 'h': + black_pocket_planes[2].fill(initial_pieces_on_board['h'] - current_pieces_on_board['h']) + elif wxf_identifier == 'R': + red_pocket_planes[3].fill(initial_pieces_on_board['R'] - current_pieces_on_board['R']) + elif wxf_identifier == 'r': + black_pocket_planes[3].fill(initial_pieces_on_board['r'] - current_pieces_on_board['r']) + elif wxf_identifier == 'C': + red_pocket_planes[4].fill(initial_pieces_on_board['C'] - current_pieces_on_board['C']) + elif wxf_identifier == 'c': + black_pocket_planes[4].fill(initial_pieces_on_board['c'] - current_pieces_on_board['c']) + elif wxf_identifier == 'P': + red_pocket_planes[5].fill(initial_pieces_on_board['P'] - current_pieces_on_board['P']) + elif wxf_identifier == 'p': + black_pocket_planes[5].fill(initial_pieces_on_board['p'] - current_pieces_on_board['p']) + + return red_pocket_planes, black_pocket_planes + + def board_to_planes(self, red_move, pocket_count=True, flip=True): + red_player_planes, black_player_planes = self.get_plain_piece_planes() + + for row in range(len(self.xiangqi_board.board)): + for col in range(len(self.xiangqi_board.board[0])): + piece = self.xiangqi_board.board[row][col] + if piece != 0: + if piece.wxf_identifier == 'K': + red_player_planes[0][row, col] = 1 + elif piece.wxf_identifier == 'k': + black_player_planes[0][row, col] = 1 + + elif piece.wxf_identifier == 'A': + red_player_planes[1][row, col] = 1 + elif piece.wxf_identifier == 'a': + black_player_planes[1][row, col] = 1 + + elif piece.wxf_identifier == 'E': + red_player_planes[2][row, col] = 1 + elif piece.wxf_identifier == 'e': + black_player_planes[2][row, col] = 1 + + elif piece.wxf_identifier == 'H': + red_player_planes[3][row, col] = 1 + elif piece.wxf_identifier == 'h': + black_player_planes[3][row, col] = 1 + + elif piece.wxf_identifier == 'R': + red_player_planes[4][row, col] = 1 + elif piece.wxf_identifier == 'r': + black_player_planes[4][row, col] = 1 + + elif piece.wxf_identifier == 'C': + red_player_planes[5][row, col] = 1 + elif piece.wxf_identifier == 'c': + black_player_planes[5][row, col] = 1 + + elif piece.wxf_identifier == 'P': + red_player_planes[6][row, col] = 1 + elif piece.wxf_identifier == 'p': + black_player_planes[6][row, col] = 1 + + if flip and not red_move: + for i in range(7): + red_player_planes[i] = np.flip(red_player_planes[i], 0) + black_player_planes[i] = np.flip(black_player_planes[i], 0) + + if red_move: + planes = np.vstack((red_player_planes, black_player_planes)) + else: + planes = np.vstack((black_player_planes, red_player_planes)) + + if pocket_count: + red_pocket_planes, black_pocket_planes = self.get_pocket_count_planes() + if red_move: + pocket_planes = np.vstack((red_pocket_planes, black_pocket_planes)) + else: + pocket_planes = np.vstack((black_pocket_planes, red_pocket_planes)) + planes = np.vstack((planes, pocket_planes)) + + return planes + + def game_to_planes(self, result, movelist, flip=True): + red_move = movelist[0][0].isupper() + + total_move_count = 0 + total_move_count_fen = 0 + + board_planes = self.board_to_planes(red_move, pocket_count=True, flip=flip) + color_plane = np.ones((10, 9), dtype=np.int) if red_move else np.zeros((10, 9), dtype=np.int) + total_move_count_plane = np.full((10, 9), total_move_count_fen, dtype=np.int) + + X_planes = [np.vstack((board_planes, np.array((color_plane, total_move_count_plane))))] + y_value = [] + y_policy = [] + + for i in range(len(movelist)): + if red_move: + if result == "1-0": + y_value.append(1) + elif result == "0-1": + y_value.append(-1) + else: + y_value.append(0) + else: + if result == "1-0": + y_value.append(-1) + elif result == "0-1": + y_value.append(1) + else: + y_value.append(0) + + # Play move + coordinate_change = self.xiangqi_board.parse_single_move(movelist[i], red_move)[0] + + # Build policy + old_pos = coordinate_change[0] + new_pos = coordinate_change[1] + ucci = xiangqi_board_move_to_ucci(old_pos, new_pos) + if flip and not red_move: + ucci = mirror_ucci(ucci) + + # Index of ucci in LABELS_XIANGQI constant + index = np.where(np.asarray(LABELS_XIANGQI) == ucci)[0][0] + + plain_policy_vector = np.zeros((len(LABELS_XIANGQI))) + plain_policy_vector[index] = 1 + y_policy.append(plain_policy_vector) + + total_move_count += 1 + if total_move_count % 2 == 0: + total_move_count_fen += 1 + red_move = not red_move + + board_planes = self.board_to_planes(red_move, pocket_count=True, flip=flip) + color_plane = np.ones((10, 9), dtype=np.int) if red_move else np.zeros((10, 9), dtype=np.int) + total_move_count_plane = np.full((10, 9), total_move_count_fen, dtype=np.int) + + if i != len(movelist) - 1: + X_planes.append(np.vstack((board_planes, np.array((color_plane, total_move_count_plane))))) + + # Reset board + self.xiangqi_board = XiangqiBoard() + + return np.asarray(X_planes), np.asarray(y_value), np.asarray(y_policy) + + def export_batches(self, export_path): + start_time = time.time() + + batch_size = self._num_games_per_file + num_batches = math.ceil(len(self._df) / batch_size) + + regex = re.compile(r'^\d+\.$') + batch_start = 0 + for b in range(num_batches): + logging.info("Creating batch {}/{}".format(str(b + 1), num_batches)) + + batch = {} + for g in range(batch_start, batch_start + batch_size): + if g >= len(self._df): + break + + movelist = self._df.iloc[g].moves + movelist = [move for move in movelist.split(' ') if not regex.match(move)] + result = self._df.iloc[g].result + x, y_value, y_policy = self.game_to_planes(result, movelist) + + # Game playing data + if 'x' not in batch.keys(): + batch['x'] = x + batch['start_indices'] = np.asarray([0]) + else: + batch['start_indices'] = np.concatenate((batch['start_indices'], + np.asarray([len(batch['x'])]))) + batch['x'] = np.concatenate((batch['x'], x)) + + if 'y_value' not in batch.keys(): + batch['y_value'] = y_value + else: + batch['y_value'] = np.concatenate((batch['y_value'], y_value)) + + if 'y_policy' not in batch.keys(): + batch['y_policy'] = y_policy + else: + batch['y_policy'] = np.concatenate((batch['y_policy'], y_policy)) + + # Statistics + if 'elo_red' not in batch.keys(): + batch['elo_red'] = np.asarray([int(self._df.iloc[g].red_elo)]) + else: + current_elo_red = int(self._df.iloc[g].red_elo) + batch['elo_red'] = np.concatenate((batch['elo_red'], np.asarray([current_elo_red]))) + + if 'elo_black' not in batch.keys(): + batch['elo_black'] = np.asarray([int(self._df.iloc[g].black_elo)]) + else: + current_elo_black = int(self._df.iloc[g].black_elo) + batch['elo_black'] = np.concatenate((batch['elo_black'], np.asarray([current_elo_black]))) + + if 'num_moves' not in batch.keys(): + batch['num_moves'] = np.asarray([int(self._df.iloc[g].num_moves)]) + else: + current_num_moves = int(self._df.iloc[g].num_moves) + batch['num_moves'] = np.concatenate((batch['num_moves'], np.asarray([current_num_moves]))) + + # Metadata + if 'player_red' not in batch.keys(): + batch['player_red'] = [self._df.iloc[g].red] + else: + batch['player_red'].append(self._df.iloc[g].red) + + if 'player_black' not in batch.keys(): + batch['player_black'] = [self._df.iloc[g].black] + else: + batch['player_black'].append(self._df.iloc[g].black) + + if 'result' not in batch.keys(): + batch['result'] = [result] + else: + batch['result'].append(result) + + if 'event' not in batch.keys(): + batch['event'] = [self._df.iloc[g].event] + else: + batch['event'].append(self._df.iloc[g].event) + + logging.info("Exporting batch (time: {:.3f}m)".format((time.time() - start_time) / 60)) + self.export_batch(batch, export_path, "batch_" + str(b)) + + batch_start += batch_size + return True + + def export_batch(self, batch, export_path, filename): + start_time = time.time() + store = zarr.ZipStore(export_path + filename + ".zip", mode="w") + zarr_file = zarr.group(store=store, overwrite=True) + compressor = Blosc(cname=self._compression, clevel=self._clevel, shuffle=Blosc.SHUFFLE) + + x = batch['x'] + start_indices = batch['start_indices'] + y_value = batch['y_value'] + y_policy = batch['y_policy'] + + elo_red = batch['elo_red'] + elo_black = batch['elo_black'] + num_moves = batch['num_moves'] + + # Discard missing entries from average elo + indices_red_elo_not_zero = np.where(elo_red > 0)[0] + indices_black_elo_not_zero = np.where(elo_black > 0)[0] + + avg_elo_red = int(elo_red[indices_red_elo_not_zero].sum() / len(indices_red_elo_not_zero)) + avg_elo_black = int(elo_black[indices_black_elo_not_zero].sum() / len(indices_black_elo_not_zero)) + avg_elo = int((avg_elo_red + avg_elo_black) / 2) + + # metadata + player_red = np.asarray(batch['player_red'], dtype=' 1)]))[0] + else: + shared_col = col + + upper_pos = None + for p in pos: + if p[1] == shared_col: + if upper_pos is None: + upper_pos = p + else: + if wxf_identifier.isupper(): + if p[0] < upper_pos[0]: + upper_pos = p + else: + if p[0] > upper_pos[0]: + upper_pos = p + return upper_pos + + def get_position_lower(self, wxf_identifier, col=None): + pos = self.get_positions(wxf_identifier) + + if col is None: + cols = [p[1] for p in pos] + shared_col = list(set([c for c in cols if (cols.count(c) > 1)]))[0] + else: + shared_col = col + + lower_pos = None + for p in pos: + if p[1] == shared_col: + if lower_pos is None: + lower_pos = p + else: + if wxf_identifier.isupper(): + if p[0] > lower_pos[0]: + lower_pos = p + else: + if p[0] < lower_pos[0]: + lower_pos = p + return lower_pos + + def get_position_middle(self, wxf_identifier, col=None): + pos = self.get_positions(wxf_identifier) + + if col is None: + cols = [p[1] for p in pos] + shared_col = list(set([c for c in cols if (cols.count(c) > 1)]))[0] + else: + shared_col = col + + rows = [p[0] for p in pos if p[1] == shared_col] + rows.sort() + mid_row = rows[len(rows) // 2] + + return (mid_row, shared_col) + + def get_position_consider_tandem(self, move): + wxf_identifier = move[0] + if move[1] == '+': + old_row, old_col = self.get_position_upper(wxf_identifier) + elif move[1] == '-': + old_row, old_col = self.get_position_lower(wxf_identifier) + else: + if wxf_identifier.isupper(): + old_col = 9 - int(move[1]) + else: + old_col = int(move[1]) - 1 + old_row = self.get_positions(wxf_identifier, col=old_col)[0][0] + + return (old_row, old_col) + + def get_positions_sorted_by_row(self, wxf_identifier, col=None): + pos = self.get_positions(wxf_identifier) + + if col is None: + cols = [p[1] for p in pos] + shared_col = list(set([c for c in cols if (cols.count(c) > 1)]))[0] + else: + shared_col = col + + pos_shared_col = [p for p in pos if p[1] == shared_col] + return sorted(pos_shared_col, key=lambda tup: tup[0]) + + def move_king(self, move): + wxf_identifier = move[0] + old_row, old_col = self.get_positions(wxf_identifier)[0] + + piece = self.board[old_row][old_col] + self.board[old_row][old_col] = 0 + + direction = move[2] + if wxf_identifier.isupper(): + if direction == '.': + new_col = 9 - int(move[3]) + new_row = old_row + elif direction == '+': + new_col = old_col + new_row = old_row - 1 + elif direction == '-': + new_col = old_col + new_row = old_row + 1 + else: + if direction == '.': + new_col = int(move[3]) - 1 + new_row = old_row + elif direction == '+': + new_col = old_col + new_row = old_row + 1 + elif direction == '-': + new_col = old_col + new_row = old_row - 1 + + self.board[new_row][new_col] = piece + return [(old_row, old_col), (new_row, new_col)] + + def move_advisor(self, move): + wxf_identifier = move[0] + + old_row, old_col = self.get_position_consider_tandem(move) + piece = self.board[old_row][old_col] + self.board[old_row][old_col] = 0 + + direction = move[2] + if wxf_identifier.isupper(): + new_col = 9 - int(move[3]) + if direction == '+': + new_row = old_row - 1 + elif direction == '-': + new_row = old_row + 1 + else: + new_col = int(move[3]) - 1 + if direction == '+': + new_row = old_row + 1 + elif direction == '-': + new_row = old_row - 1 + + self.board[new_row][new_col] = piece + return [(old_row, old_col), (new_row, new_col)] + + def move_elephant(self, move): + wxf_identifier = move[0] + + old_row, old_col = self.get_position_consider_tandem(move) + piece = self.board[old_row][old_col] + self.board[old_row][old_col] = 0 + + direction = move[2] + if wxf_identifier.isupper(): + new_col = 9 - int(move[3]) + if direction == '+': + new_row = old_row - 2 + elif direction == '-': + new_row = old_row + 2 + else: + new_col = int(move[3]) - 1 + if direction == '+': + new_row = old_row + 2 + elif direction == '-': + new_row = old_row - 2 + + self.board[new_row][new_col] = piece + return [(old_row, old_col), (new_row, new_col)] + + def move_horse(self, move): + wxf_identifier = move[0] + + old_row, old_col = self.get_position_consider_tandem(move) + piece = self.board[old_row][old_col] + self.board[old_row][old_col] = 0 + + direction = move[2] + if wxf_identifier.isupper(): + new_col = 9 - int(move[3]) + if direction == '+': + new_row = old_row - 2 if abs(old_col - new_col) == 1 else old_row - 1 + elif direction == '-': + new_row = old_row + 2 if abs(old_col - new_col) == 1 else old_row + 1 + else: + new_col = int(move[3]) - 1 + if direction == '+': + new_row = old_row + 2 if abs(old_col - new_col) == 1 else old_row + 1 + elif direction == '-': + new_row = old_row - 2 if abs(old_col - new_col) == 1 else old_row - 1 + + self.board[new_row][new_col] = piece + return [(old_row, old_col), (new_row, new_col)] + + def move_chariot_or_cannon(self, move): + # Chariots and cannons share possible movements + wxf_identifier = move[0] + + old_row, old_col = self.get_position_consider_tandem(move) + piece = self.board[old_row][old_col] + self.board[old_row][old_col] = 0 + + direction = move[2] + if wxf_identifier.isupper(): + if direction == '.': + new_col = 9 - int(move[3]) + new_row = old_row + elif direction == '+': + new_col = old_col + new_row = old_row - int(move[3]) + elif direction == '-': + new_col = old_col + new_row = old_row + int(move[3]) + else: + if direction == '.': + new_col = int(move[3]) - 1 + new_row = old_row + elif direction == '+': + new_col = old_col + new_row = old_row + int(move[3]) + elif direction == '-': + new_col = old_col + new_row = old_row - int(move[3]) + + self.board[new_row][new_col] = piece + return [(old_row, old_col), (new_row, new_col)] + + def move_pawn(self, move, red_move): + # There are special cases for pawns as we can have + # 1, 2, 3, 4, 5 pawns in a column, as well as + # tandem pawns in two columns + one_to_nine_str = ['1', '2', '3', '4', '5', '6', '7', '8', '9'] + + direction = move[2] + + # The given identifier might not be the figure but + # the column in which the figure currently is positioned + wxf_figure = 'P' if red_move else 'p' + pos = self.get_positions(wxf_figure) + + # Find the current position + cols = [p[1] for p in pos] + shared_cols = list(set([c for c in cols if (cols.count(c) > 1)])) + # Check whether we want to move a pawn from a tandem position + if move[1] in one_to_nine_str: + col_on_board = (9 - int(move[1])) if red_move else (int(move[1]) - 1) + if col_on_board not in shared_cols: + shared_cols = [] + + if len(shared_cols) == 0: + col = (9 - int(move[1])) if red_move else (int(move[1]) - 1) + old_row, old_col = self.get_positions(wxf_figure, col=col)[0] + elif len(shared_cols) == 1: + shared_col = shared_cols[0] + # how many pawns in the same column + rows = [p[0] for p in pos if p[1] == shared_col] + if len(rows) == 1: + old_row, old_col = (rows[0], shared_col) + elif len(rows) == 2: + if move[1] == '+': + old_row, old_col = self.get_position_upper(wxf_figure, col=shared_col) + elif move[1] == '-': + old_row, old_col = self.get_position_lower(wxf_figure, col=shared_col) + elif len(rows) == 3: + if move[1] in one_to_nine_str: + old_row, old_col = self.get_position_middle(wxf_figure, col=shared_col) + elif move[1] == '+': + old_row, old_col = self.get_position_upper(wxf_figure, col=shared_col) + elif move[1] == '-': + old_row, old_col = self.get_position_lower(wxf_figure, col=shared_col) + elif len(rows) == 4: + if move[0] == '+': + old_row, old_col = self.get_position_upper(wxf_figure, col=shared_col) + elif move[0] == '-': + old_row, old_col = self.get_position_lower(wxf_figure, col=shared_col) + elif move[0] == wxf_figure: + possible_pos = self.get_positions_sorted_by_row(wxf_figure, col=shared_col) + if move[1] == '+': + old_row, old_col = possible_pos[-2] + elif move[1] == '-': + old_row, old_col = possible_pos[1] + elif len(rows) == 5: + if move[0] == '+' and move[1] == '+': + old_row, old_col = self.get_position_upper(wxf_figure, col=shared_col) + elif move[0] == '-' and move[1] == '-': + old_row, old_col = self.get_position_lower(wxf_figure, col=shared_col) + elif move[0] == wxf_figure: + if move[1] == '+': + old_row, old_col = self.get_positions_sorted_by_row(wxf_figure, col=shared_col)[-2] + elif move[1] == '-': + old_row, old_col = self.get_positions_sorted_by_row(wxf_figure, col=shared_col)[1] + elif move[1] in one_to_nine_str: + old_row, old_col = self.get_position_middle(wxf_figure, col=shared_col) + else: + # the current column of the figure + wxf_identifier = (9 - int(move[0])) if red_move else (int(move[0]) - 1) + if move[1] in one_to_nine_str: + old_row, old_col = self.get_position_middle(wxf_figure, col=wxf_identifier) + elif move[1] == '+': + old_row, old_col = self.get_position_upper(wxf_figure, col=wxf_identifier) + elif move[1] == '-': + old_row, old_col = self.get_position_lower(wxf_figure, col=wxf_identifier) + + piece = self.board[old_row][old_col] + self.board[old_row][old_col] = 0 + + if wxf_figure.isupper(): + if direction == '.': + new_col = 9 - int(move[3]) + new_row = old_row + elif direction == '+': + new_col = old_col + new_row = old_row - 1 + else: + if direction == '.': + new_col = int(move[3]) - 1 + new_row = old_row + elif direction == '+': + new_col = old_col + new_row = old_row + 1 + + self.board[new_row][new_col] = piece + return [(old_row, old_col), (new_row, new_col)] + + def parse_movelist(self, movelist, display_moves=False): + # In the case that 2 different columns are shared by at least 2 pawns each + one_to_nine_str = ['1', '2', '3', '4', '5', '6', '7', '8', '9'] + + if display_moves: + self.display_board() + + coordinate_changes = [] + + regex = re.compile(r'^\d+\.$') + movelist = [move for move in movelist.split(' ') if not regex.match(move)] + + red_move = True if movelist[0][0].isupper() else False + for move in movelist: + if display_moves: + print("Move: ", move) + + wxf_identifier = move[0] + + if wxf_identifier in ['K', 'k']: + coordinate_change = self.move_king(move) + elif wxf_identifier in ['A', 'a']: + coordinate_change = self.move_advisor(move) + elif wxf_identifier in ['E', 'e']: + coordinate_change = self.move_elephant(move) + elif wxf_identifier in ['H', 'h']: + coordinate_change = self.move_horse(move) + elif wxf_identifier in ['C', 'c', 'R', 'r']: + coordinate_change = self.move_chariot_or_cannon(move) + elif wxf_identifier in ['P', 'p', '+', '-'] or wxf_identifier in one_to_nine_str: + coordinate_change = self.move_pawn(move, red_move) + + coordinate_changes.append(coordinate_change) + + red_move = not red_move + + if display_moves: + self.display_board() + + return coordinate_changes + + def parse_single_move(self, move, red_move, display_move=False): + # In the case that 2 different columns are shared by at least 2 pawns each + one_to_nine_str = ['1', '2', '3', '4', '5', '6', '7', '8', '9'] + + if display_move: + self.display_board() + + coordinate_changes = [] + + if display_move: + print("Move: ", move) + + wxf_identifier = move[0] + + if wxf_identifier in ['K', 'k']: + coordinate_change = self.move_king(move) + elif wxf_identifier in ['A', 'a']: + coordinate_change = self.move_advisor(move) + elif wxf_identifier in ['E', 'e']: + coordinate_change = self.move_elephant(move) + elif wxf_identifier in ['H', 'h']: + coordinate_change = self.move_horse(move) + elif wxf_identifier in ['C', 'c', 'R', 'r']: + coordinate_change = self.move_chariot_or_cannon(move) + elif wxf_identifier in ['P', 'p', '+', '-'] or wxf_identifier in one_to_nine_str: + coordinate_change = self.move_pawn(move, red_move) + + coordinate_changes.append(coordinate_change) + + if display_move: + self.display_board() + + return coordinate_changes + + def init_board(self): + board = [[0] * 9 for _ in range(10)] + + board[0][0] = Rook(Color.BLACK) + board[0][-1] = Rook(Color.BLACK) + board[-1][0] = Rook(Color.RED) + board[-1][-1] = Rook(Color.RED) + + board[0][1] = Horse(Color.BLACK) + board[0][-2] = Horse(Color.BLACK) + board[-1][1] = Horse(Color.RED) + board[-1][-2] = Horse(Color.RED) + + board[0][2] = Elephant(Color.BLACK) + board[0][-3] = Elephant(Color.BLACK) + board[-1][2] = Elephant(Color.RED) + board[-1][-3] = Elephant(Color.RED) + + board[0][3] = Advisor(Color.BLACK) + board[0][-4] = Advisor(Color.BLACK) + board[-1][3] = Advisor(Color.RED) + board[-1][-4] = Advisor(Color.RED) + + board[0][4] = King(Color.BLACK) + board[-1][4] = King(Color.RED) + + board[2][1] = Cannon(Color.BLACK) + board[2][-2] = Cannon(Color.BLACK) + board[-3][1] = Cannon(Color.RED) + board[-3][-2] = Cannon(Color.RED) + + board[3][0] = Pawn(Color.BLACK) + board[3][2] = Pawn(Color.BLACK) + board[3][4] = Pawn(Color.BLACK) + board[3][6] = Pawn(Color.BLACK) + board[3][8] = Pawn(Color.BLACK) + board[-4][0] = Pawn(Color.RED) + board[-4][2] = Pawn(Color.RED) + board[-4][4] = Pawn(Color.RED) + board[-4][6] = Pawn(Color.RED) + board[-4][8] = Pawn(Color.RED) + + return board + + def get_num_figures(self): + figures = {'k': 0, 'a': 0, 'e': 0, 'h': 0, 'r': 0, 'c': 0, 'p': 0, + 'K': 0, 'A': 0, 'E': 0, 'H': 0, 'R': 0, 'C': 0, 'P': 0} + + for row in range(len(self.board)): + for col in range(len(self.board[0])): + if self.board[row][col] != 0: + wxf_identifier = self.board[row][col].wxf_identifier + figures[wxf_identifier] += 1 + + return figures + + def display_board(self): + board = [[0] * 9 for _ in range(10)] + for row in range(len(self.board)): + for col in range(len(self.board[0])): + if self.board[row][col] != 0: + piece = self.board[row][col] + board[row][col] = str(piece.wxf_identifier) + else: + board[row][col] = '0' + print(board) + + def get_bitboard(self): + board = [[0] * 9 for _ in range(10)] + for row in range(len(self.board)): + for col in range(len(self.board[0])): + if self.board[row][col] != 0: + board[row][col] = 1 + else: + board[row][col] = 0 + return board + + +class Color(Enum): + RED = 1 + BLACK = 2 + + +class King: + def __init__(self, color): + self.color = color + self.wxf_identifier = 'K' if color == color.RED else 'k' + + +class Advisor: + def __init__(self, color): + self.color = color + self.wxf_identifier = 'A' if color == color.RED else 'a' + + +class Elephant: + def __init__(self, color): + self.color = color + self.wxf_identifier = 'E' if color == color.RED else 'e' + + +class Rook: + def __init__(self, color): + self.color = color + self.wxf_identifier = 'R' if color == color.RED else 'r' + + +class Cannon: + def __init__(self, color): + self.color = color + self.wxf_identifier = 'C' if color == color.RED else 'c' + + +class Horse: + def __init__(self, color): + self.color = color + self.wxf_identifier = 'H' if color == color.RED else 'h' + + +class Pawn: + def __init__(self, color): + self.color = color + self.wxf_identifier = 'P' if color == color.RED else 'p' diff --git a/requirements.txt b/requirements.txt index c564469a..b2c143d3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ numpy python-chess mxnet onnx==1.3.0 +pandas From a3aa0d9a937b42b4688b5b5bbe78b8c7d0436c41 Mon Sep 17 00:00:00 2001 From: xsr7qsr Date: Tue, 27 Apr 2021 08:14:39 +0200 Subject: [PATCH 13/19] Fix rank and file lookups --- .../environments/fairy_state/fairyboard.cpp | 18 -------- .../src/environments/fairy_state/fairyboard.h | 2 - .../fairypolicymaprepresentation.h | 2 +- .../environments/fairy_state/fairyutil.cpp | 42 +++++++++++++++++++ .../src/environments/fairy_state/fairyutil.h | 36 +++++++--------- 5 files changed, 57 insertions(+), 43 deletions(-) diff --git a/engine/src/environments/fairy_state/fairyboard.cpp b/engine/src/environments/fairy_state/fairyboard.cpp index 1143d58f..7f16c2d0 100644 --- a/engine/src/environments/fairy_state/fairyboard.cpp +++ b/engine/src/environments/fairy_state/fairyboard.cpp @@ -118,21 +118,3 @@ std::string uci_move(Move m) { return std::string(1, fromFile) + fromRank + std::string(1, toFile) + toRank; } - -char file_to_uci(File file) { - for (auto it = FILE_LOOKUP.begin(); it != FILE_LOOKUP.end(); ++it) { - if (it->second == file) { - return it->first; - } - } - return char(); -} - -std::string rank_to_uci(Rank rank) { - for (auto it = RANK_LOOKUP.begin(); it != RANK_LOOKUP.end(); ++it) { - if (it->second == rank) { - return std::string(1, it->first); - } - } - return "10"; -} diff --git a/engine/src/environments/fairy_state/fairyboard.h b/engine/src/environments/fairy_state/fairyboard.h index 538ea183..76f6c3cd 100644 --- a/engine/src/environments/fairy_state/fairyboard.h +++ b/engine/src/environments/fairy_state/fairyboard.h @@ -25,7 +25,5 @@ class FairyBoard : public Position Result get_result(const FairyBoard &pos, bool inCheck); std::string wxf_move(Move m, const FairyBoard& pos); std::string uci_move(Move m); -char file_to_uci(File file); -std::string rank_to_uci(Rank rank); #endif //FAIRYBOARD_H diff --git a/engine/src/environments/fairy_state/fairypolicymaprepresentation.h b/engine/src/environments/fairy_state/fairypolicymaprepresentation.h index a43d4ae9..9e37eefb 100644 --- a/engine/src/environments/fairy_state/fairypolicymaprepresentation.h +++ b/engine/src/environments/fairy_state/fairypolicymaprepresentation.h @@ -2090,4 +2090,4 @@ const unsigned long FLAT_PLANE_IDX[] = { 3668 }; -#endif FAIRYPOLICYMAPREPRESENTATION_H +#endif // FAIRYPOLICYMAPREPRESENTATION_H diff --git a/engine/src/environments/fairy_state/fairyutil.cpp b/engine/src/environments/fairy_state/fairyutil.cpp index c5ceeba2..21b429e4 100644 --- a/engine/src/environments/fairy_state/fairyutil.cpp +++ b/engine/src/environments/fairy_state/fairyutil.cpp @@ -2,6 +2,30 @@ #include +const unordered_map FILE_LOOKUP = { + {'a', FILE_A}, + {'b', FILE_B}, + {'c', FILE_C}, + {'d', FILE_D}, + {'e', FILE_E}, + {'f', FILE_F}, + {'g', FILE_G}, + {'h', FILE_H}, + {'i', FILE_I}}; + +// Note that we have 10 ranks but use a char to Rank lookup... +const unordered_map RANK_LOOKUP = { + {'1', RANK_1}, + {'2', RANK_2}, + {'3', RANK_3}, + {'4', RANK_4}, + {'5', RANK_5}, + {'6', RANK_6}, + {'7', RANK_7}, + {'8', RANK_8}, + {'9', RANK_9}}; + + Square get_origin_square(const string& uciMove) { File fromFile = FILE_LOOKUP.at(uciMove[0]); @@ -33,3 +57,21 @@ Square get_destination_square(const string& uciMove) } return make_square(toFile, toRank); } + +char file_to_uci(File file) { + for (auto it = FILE_LOOKUP.begin(); it != FILE_LOOKUP.end(); ++it) { + if (it->second == file) { + return it->first; + } + } + return char(); +} + +std::string rank_to_uci(Rank rank) { + for (auto it = RANK_LOOKUP.begin(); it != RANK_LOOKUP.end(); ++it) { + if (it->second == rank) { + return std::string(1, it->first); + } + } + return "10"; +} diff --git a/engine/src/environments/fairy_state/fairyutil.h b/engine/src/environments/fairy_state/fairyutil.h index 1872f417..0d523174 100644 --- a/engine/src/environments/fairy_state/fairyutil.h +++ b/engine/src/environments/fairy_state/fairyutil.h @@ -5,28 +5,6 @@ using namespace std; -const unordered_map FILE_LOOKUP = { - {'a', FILE_A}, - {'b', FILE_B}, - {'c', FILE_C}, - {'d', FILE_D}, - {'e', FILE_E}, - {'f', FILE_F}, - {'g', FILE_G}, - {'h', FILE_H}, - {'i', FILE_I}}; - -// Note that we have 10 ranks but use a char to Rank lookup... -const unordered_map RANK_LOOKUP = { - {'1', RANK_1}, - {'2', RANK_2}, - {'3', RANK_3}, - {'4', RANK_4}, - {'5', RANK_5}, - {'6', RANK_6}, - {'7', RANK_7}, - {'8', RANK_8}, - {'9', RANK_9}}; /** * @brief get_origin_square Returns the origin square for a valid ucciMove @@ -42,4 +20,18 @@ Square get_origin_square(const string &uciMove); */ Square get_destination_square(const string &uciMove); +/** + * @brief file_to_uci Returns the uci corresponding to the given file + * @param file FILE to convert + * @return uci corresponding to the file + */ +char file_to_uci(File file); + +/** + * @brief rank_to_uci Returns the uci corresponding to the given rank + * @param rank Rank to convert + * @return uci corresponding to the rank + */ +std::string rank_to_uci(Rank rank); + #endif //FAIRYUTIL_H From dcffc578dbeb169d4d23d59e3756032705b39b4c Mon Sep 17 00:00:00 2001 From: xsr7qsr Date: Tue, 27 Apr 2021 08:29:09 +0200 Subject: [PATCH 14/19] Disable NNUE evaluation --- engine/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index d61474b0..70f9f64b 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -57,6 +57,7 @@ if (MODE_XIANGQI) add_definitions(-DMODE_XIANGQI) add_definitions(-DLARGEBOARDS) add_definitions(-DPRECOMPUTED_MAGICS) + add_definitions(-DNNUE_EMBEDDING_OFF) endif () if (BUILD_TESTS) From a8df88aca343dc85cf01661b2d1cd5e7c13dc5f9 Mon Sep 17 00:00:00 2001 From: Johannes Czech Date: Tue, 27 Apr 2021 19:54:38 +0200 Subject: [PATCH 15/19] Update fairystate.h added init() --- engine/src/environments/fairy_state/fairystate.h | 1 + 1 file changed, 1 insertion(+) diff --git a/engine/src/environments/fairy_state/fairystate.h b/engine/src/environments/fairy_state/fairystate.h index 6fc3534f..6ca5f7db 100644 --- a/engine/src/environments/fairy_state/fairystate.h +++ b/engine/src/environments/fairy_state/fairystate.h @@ -140,6 +140,7 @@ class FairyState : public State string action_to_san(Action action, const std::vector &legalActions, bool leadsToWin, bool bookMove) const override; Tablebase::WDLScore check_for_tablebase_wdl(Tablebase::ProbeState &result) override; void set_auxiliary_outputs(const float* auxiliaryOutputs) override; + void init(int variant, bool isChess960); }; #endif // FAIRYSTATE_H From 784c238b4ac42e515f1aa82119f3a755c650dbde Mon Sep 17 00:00:00 2001 From: Johannes Czech Date: Tue, 27 Apr 2021 19:55:18 +0200 Subject: [PATCH 16/19] update fairystate.cpp added init() --- engine/src/environments/fairy_state/fairystate.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/engine/src/environments/fairy_state/fairystate.cpp b/engine/src/environments/fairy_state/fairystate.cpp index 1412158f..9cba9e00 100644 --- a/engine/src/environments/fairy_state/fairystate.cpp +++ b/engine/src/environments/fairy_state/fairystate.cpp @@ -124,3 +124,9 @@ Tablebase::WDLScore FairyState::check_for_tablebase_wdl(Tablebase::ProbeState &r void FairyState::set_auxiliary_outputs(const float* auxiliaryOutputs) { } + +void FairyState::init(int variant, bool isChess960) +{ + states = StateListPtr(new std::deque(1)); + board.set(variants.find("xiangqi")->second, variants.find("xiangqi")->second->startFen, isChess960, &states->back(), nullptr, false); +} From 2c2f090e9a0e25c8f83031fba553251665a0185b Mon Sep 17 00:00:00 2001 From: Johannes Czech Date: Tue, 27 Apr 2021 19:55:56 +0200 Subject: [PATCH 17/19] Update state.h uncomment init() --- engine/src/state.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/src/state.h b/engine/src/state.h index b989b1e2..82b10806 100644 --- a/engine/src/state.h +++ b/engine/src/state.h @@ -416,7 +416,7 @@ class State * @param isChess960 If true 960 mode will be active * @param variant Variant which the position corresponds to */ - //virtual void init(int variant, bool isChess960) = 0; + virtual void init(int variant, bool isChess960) = 0; }; #endif // GAMESTATE_H From 0a9914a29e78f8432fe3120b19d3437d43c44731 Mon Sep 17 00:00:00 2001 From: Johannes Czech Date: Tue, 27 Apr 2021 19:57:20 +0200 Subject: [PATCH 18/19] Update optionsuci.cpp use model/xiangqi as Model_Directory --- engine/src/uci/optionsuci.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/engine/src/uci/optionsuci.cpp b/engine/src/uci/optionsuci.cpp index e16bb056..ab3576f2 100644 --- a/engine/src/uci/optionsuci.cpp +++ b/engine/src/uci/optionsuci.cpp @@ -107,6 +107,8 @@ void OptionsUCI::init(OptionsMap &o) o["Model_Directory"] << Option(((string) "model" + "/" + availableVariants.front()).c_str()); #elif defined MODE_CHESS o["Model_Directory"] << Option("model/chess"); +#elif defined MODE_XIANGQI + o["Model_Directory"] << Option("model/xiangqi"); #else o["Model_Directory"] << Option("model"); #endif From c872e250a903b3f36a1ae062cbb2c06511e5821f Mon Sep 17 00:00:00 2001 From: Johannes Czech Date: Tue, 27 Apr 2021 19:58:10 +0200 Subject: [PATCH 19/19] Update CMakeLists.txt use MODE_CRAYZHOUSE as default --- engine/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index 70f9f64b..8ee2df32 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -10,11 +10,11 @@ option(BACKEND_TORCH "Build with Torch backend (CPU/GPU) support" OF option(USE_960 "Build with 960 variant support" OFF) option(BUILD_TESTS "Build and run tests" OFF) # enable a single mode for different model input / outputs -option(MODE_CRAZYHOUSE "Build with crazyhouse only support" OFF) +option(MODE_CRAZYHOUSE "Build with crazyhouse only support" ON) option(MODE_CHESS "Build with chess + chess960 only support" OFF) option(MODE_LICHESS "Build with lichess variants support" OFF) option(MODE_OPEN_SPIEL "Build with open_spiel environment support" OFF) -option(MODE_XIANGQI "Build with xiangqi only support" ON) +option(MODE_XIANGQI "Build with xiangqi only support" OFF) option(SEARCH_UCT "Build with UCT instead of PUCT search" OFF) add_definitions(-DIS_64BIT)