Skip to content

Commit

Permalink
added passing of main_config and train_config to rl_training.py
Browse files Browse the repository at this point in the history
  • Loading branch information
QueensGambit committed Dec 14, 2020
1 parent bf552e9 commit 500da21
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
12 changes: 9 additions & 3 deletions engine/src/rl/rl_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ def __init__(self, args, nb_games_to_update=1024, nb_arena_games=100, lr_reducti
binary_mappings = {"crazyhouse": "CrazyAra",
"chess": "ClassicAra"}
self.crazyara_binary_name = binary_mappings[self.args.uci_variant]
main_config["planes_train_dir"] = args.crazyara_binary_dir + "export/train/"
main_config["planes_val_dir"] = args.crazyara_binary_dir + "export/val/"

self.proc = None
self.nb_games_to_update = nb_games_to_update
if nb_arena_games % 2 == 1:
Expand Down Expand Up @@ -242,7 +245,9 @@ def _set_uci_options(self, is_arena=False):
# set_uci_param(self.proc, "MaxInitPly", 30)
set_uci_param(self.proc, "Reuse_Tree", "false")
set_uci_param(self.proc, "Precision", self.args.precision)
# set_uci_param(self.proc, "Selfplay_Number_Chunks", 1)
#set_uci_param(self.proc, "Selfplay_Number_Chunks", 1)
set_uci_param(self.proc, "Selfplay_Number_Chunks", 640)
set_uci_param(self.proc, "Selfplay_Chunk_Size", 128)

if is_arena is True:
# set_uci_param(self.proc, "Centi_Temperature", 60) cz
Expand Down Expand Up @@ -490,7 +495,8 @@ def check_for_enough_train_data(self, number_files_to_update):
process = Process(target=update_network, args=(queue, self.nn_update_index, self.k_steps,
self.max_lr, self._get_current_model_arch_file(),
self._get_current_model_weight_file(),
self.crazyara_binary_dir, not self.args.no_onnx_export))
self.crazyara_binary_dir, not self.args.no_onnx_export,
main_config, train_config))
logging.info("start training")
process.start()
self.k_steps = queue.get() + 1
Expand Down Expand Up @@ -551,7 +557,7 @@ def parse_args(cmd_args: list):
parser.add_argument('--nn-update-idx', type=int, default=0,
help="Index of how many NN updates have been done so far."
" This will be used to label the NN weights (default: 0)")
parser.add_argument("--nn-update-files", type=int, default=15, #10,
parser.add_argument("--nn-update-files", type=int, default=10,
help="How many new generated training files are needed to apply an update to the NN")
parser.add_argument("--arena-games", type=int, default=100,
help="How many arena games will be done to judge the quality of the new network")
Expand Down
6 changes: 3 additions & 3 deletions engine/src/rl/rl_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
import mxnet as mx

sys.path.append("../../../")
from DeepCrazyhouse.configs.main_config import main_config
from DeepCrazyhouse.configs.train_config import train_config
from DeepCrazyhouse.src.preprocessing.dataset_loader import load_pgn_dataset
from DeepCrazyhouse.src.training.trainer_agent import acc_sign, cross_entropy, acc_distribution
from DeepCrazyhouse.src.training.trainer_agent_mxnet import TrainerAgentMXNET, add_non_sparse_cross_entropy,\
Expand All @@ -24,7 +22,7 @@
from DeepCrazyhouse.src.domain.neural_net.onnx.convert_to_onnx import convert_mxnet_model_to_onnx


def update_network(queue, nn_update_idx, k_steps_initial, max_lr, symbol_filename, params_filename, cwd, convert_to_onnx):
def update_network(queue, nn_update_idx, k_steps_initial, max_lr, symbol_filename, params_filename, cwd, convert_to_onnx, main_config, train_config):
"""
Creates a new NN checkpoint in the model contender directory after training using the game files stored in the
training directory
Expand All @@ -38,6 +36,8 @@ def update_network(queue, nn_update_idx, k_steps_initial, max_lr, symbol_filenam
Updates the neural network with the newly acquired games from the replay memory
:param cwd: Current working directory (must end with "/")
:param convert_to_onnx: Boolean indicating if the network shall be exported to ONNX to allow TensorRT inference
:param main_config: Dict of the main_config (imported from main_config.py)
:param train_config: Dict of the train_config (imported from train_config.py)
:return: k_steps_final
"""

Expand Down

0 comments on commit 500da21

Please sign in to comment.