Skip to content

Commit

Permalink
fixed creation of additional dataloaders in train_cnn.ipynb
Browse files Browse the repository at this point in the history
  • Loading branch information
HelpstoneX committed Sep 23, 2024
1 parent ab8a7a7 commit f50550b
Showing 1 changed file with 2 additions and 13 deletions.
15 changes: 2 additions & 13 deletions DeepCrazyhouse/src/training/train_cnn.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -399,18 +399,7 @@
" for phase in [str(phase) for phase in to.phase_weights.keys()] + [\"None\"]:\n",
" pgn_dataset_arrays_dict = load_pgn_dataset(dataset_type='test', part_id=0,\n",
" verbose=True, normalize=tc.normalize, phase=phase)\n",
" s_idcs_val_tmp = pgn_dataset_arrays_dict[\"start_indices\"]\n",
" x_val_tmp = pgn_dataset_arrays_dict[\"x\"]\n",
" yv_val_tmp = pgn_dataset_arrays_dict[\"y_value\"]\n",
" yp_val_tmp = pgn_dataset_arrays_dict[\"y_policy\"]\n",
" plys_to_end_tmp = pgn_dataset_arrays_dict[\"plys_to_end\"]\n",
" pgn_datasets_val_tmp = pgn_dataset_arrays_dict[\"pgn_dataset\"]\n",
" phase_vector_tmp = pgn_dataset_arrays_dict[\"phase_vector\"]\n",
"\n",
" if tc.discount != 1:\n",
" yv_val_tmp *= tc.discount**plys_to_end_tmp\n",
"\n",
" data_loader = get_data_loader(x_val_tmp, yv_val_tmp, yp_val_tmp, plys_to_end_tmp, phase_vector_tmp, tc, shuffle=False)\n",
" data_loader = get_data_loader(pgn_dataset_arrays_dict, tc, shuffle=False)\n",
" additional_data_loaders[f\"Phase{phase}Test\"] = data_loader"
]
},
Expand Down Expand Up @@ -1726,4 +1715,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}

0 comments on commit f50550b

Please sign in to comment.