diff --git a/DeepCrazyhouse/src/domain/agent/neural_net_api.py b/DeepCrazyhouse/src/domain/agent/neural_net_api.py index c581fad8..8f7ab75b 100644 --- a/DeepCrazyhouse/src/domain/agent/neural_net_api.py +++ b/DeepCrazyhouse/src/domain/agent/neural_net_api.py @@ -43,7 +43,12 @@ def __init__(self, ctx="cpu", batch_size=1, select_policy_form_planes: bool = Tr self.symbol_path = glob.glob(model_architecture_dir + "*")[0] if model_weights_dir == "default": - self.params_path = glob.glob(main_config["model_weights_dir"] + "*")[0] + self.params_path = None + + paths = glob.glob(main_config["model_weights_dir"] + "*") + for path in paths: + if ".params" in path: + self.params_path = path else: self.params_path = glob.glob(model_weights_dir + "*")[0] # make sure the needed files have been found