Modify no_merge_train_val to retrain. #332

Merged · 2 commits · Sep 15, 2023
Changes from 1 commit
54 changes: 32 additions & 22 deletions search_params.py
@@ -141,16 +141,16 @@ def init_search_algorithm(search_alg, metric=None, mode=None):
        logging.info(f"{search_alg} search is found, run BasicVariantGenerator().")


-def prepare_retrain_config(best_config, best_log_dir, merge_train_val):
+def prepare_retrain_config(best_config, best_log_dir, retrain):
    """Prepare the configuration for re-training.

    Args:
        best_config (AttributeDict): The best hyper-parameter configuration.
        best_log_dir (str): The directory of the best trial of the experiment.
-        merge_train_val (bool): Whether to merge the training and validation data.
+        retrain (bool): Whether to retrain the model on the merged training and validation data.
    """
-    if merge_train_val:
-        best_config.merge_train_val = True
+    if retrain:
+        best_config.retrain = True

    log_path = os.path.join(best_log_dir, "logs.json")
    if os.path.isfile(log_path):
@@ -165,15 +165,15 @@ def prepare_retrain_config(best_config, best_log_dir, merge_train_val):
        optimal_idx = log_metric.argmax() if best_config.mode == "max" else log_metric.argmin()
        best_config.epochs = optimal_idx.item() + 1  # +1 because epoch counts are 1-based
    else:
-        best_config.merge_train_val = False
+        best_config.retrain = False
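
The logic above picks the retraining epoch by scanning the per-epoch validation scores recorded in logs.json. A minimal standalone sketch of that selection, assuming logs.json stores a list of per-epoch metric dicts under a "val" key (the exact schema is not shown in this diff):

    import json
    import numpy as np

    def pick_best_epoch(log_path, val_metric, mode="max"):
        # Load the per-epoch validation history written during the trial.
        with open(log_path) as f:
            logs = json.load(f)
        # One score per epoch, in training order (schema assumed, see note above).
        log_metric = np.array([step[val_metric] for step in logs["val"]])
        optimal_idx = log_metric.argmax() if mode == "max" else log_metric.argmin()
        return int(optimal_idx) + 1  # +1 because epoch counts are 1-based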


-def load_static_data(config, merge_train_val=False):
+def load_static_data(config, retrain=False):
    """Preload static data once for multiple trials.

    Args:
        config (AttributeDict): Config of the experiment.
-        merge_train_val (bool, optional): Whether to merge the training and validation data.
+        retrain (bool, optional): Whether to retrain the model on the merged training and validation data.
            Defaults to False.

    Returns:
@@ -184,7 +184,7 @@ def load_static_data(config, merge_train_val=False):
        test_data=config.test_file,
        val_data=config.val_file,
        val_size=config.val_size,
-        merge_train_val=merge_train_val,
+        merge_train_val=retrain,
        tokenize_text="lm_weight" not in config.network_config,
        remove_no_label_data=config.remove_no_label_data,
    )
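
Note that `retrain` is forwarded unchanged as `merge_train_val`, so "retrain" here means fitting the final model on the training and validation data combined. A toy sketch of that semantics (hypothetical helper, not LibMultiLabel's actual loader):

    def resolve_splits(train, val, merge_train_val=False):
        # When retraining, fold the held-out validation split into training.
        if merge_train_val:
            return train + val, []  # train on everything; nothing held out
        return train, val
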
@@ -205,31 +205,31 @@ def load_static_data(config, merge_train_val=False):
    }


-def retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val):
+def retrain_best_model(exp_name, best_config, best_log_dir, retrain):
    """Re-train the model with the best hyper-parameters.
-    A new model is trained on the combined training and validation data if `merge_train_val` is True.
+    A new model is trained on the combined training and validation data if `retrain` is True.
    If a test set is provided, the obtained model is evaluated on it.

    Args:
        exp_name (str): The directory to save trials generated by ray tune.
        best_config (AttributeDict): The best hyper-parameter configuration.
        best_log_dir (str): The directory of the best trial of the experiment.
-        merge_train_val (bool): Whether to merge the training and validation data.
+        retrain (bool): Whether to retrain the model on the merged training and validation data.
    """
    best_config.silent = False
    checkpoint_dir = os.path.join(best_config.result_dir, exp_name, "trial_best_params")
    os.makedirs(checkpoint_dir, exist_ok=True)
    with open(os.path.join(checkpoint_dir, "params.yml"), "w") as fp:
        yaml.dump(dict(best_config), fp)
    best_config.run_name = "_".join(exp_name.split("_")[:-1]) + "_best"
    best_config.checkpoint_dir = checkpoint_dir
    best_config.log_path = os.path.join(best_config.checkpoint_dir, "logs.json")
-    prepare_retrain_config(best_config, best_log_dir, merge_train_val)
+    prepare_retrain_config(best_config, best_log_dir, retrain)
    set_seed(seed=best_config.seed)
    with open(os.path.join(checkpoint_dir, "params.yml"), "w") as fp:
        yaml.dump(dict(best_config), fp)

-    data = load_static_data(best_config, merge_train_val=best_config.merge_train_val)
+    data = load_static_data(best_config, retrain=best_config.retrain)

-    if merge_train_val:
+    if retrain:
        logging.info(f"Re-training with best config: \n{best_config}")
        trainer = TorchTrainer(config=best_config, **data)
        trainer.train()
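
Because params.yml is dumped both before and after prepare_retrain_config mutates the config, the final file records the resolved retraining settings (e.g. the chosen epoch count). It can be reloaded to reproduce the run, for example:

    import yaml

    # Path under <result_dir>/<exp_name>/ as written in the code above.
    with open("trial_best_params/params.yml") as f:
        best_params = yaml.safe_load(f)
    print(best_params.get("epochs"))  # set only when logs.json was found
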
@@ -247,7 +247,7 @@ def retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val):

    if "test" in data["datasets"]:
        test_results = trainer.test()
-        if merge_train_val:
+        if retrain:
            logging.info(f"Test results after re-training: {test_results}")
        else:
            logging.info(f"Test results of best config: {test_results}")
@@ -260,8 +260,18 @@ def main():
        "--config",
        help="Path to configuration file (default: %(default)s). Please specify a config with all arguments in LibMultiLabel/main.py::get_config.",
    )
-    parser.add_argument("--cpu_count", type=int, default=4, help="Number of CPU per trial (default: %(default)s)")
-    parser.add_argument("--gpu_count", type=int, default=1, help="Number of GPU per trial (default: %(default)s)")
+    parser.add_argument(
+        "--cpu_count",
+        type=int,
+        default=4,
+        help="Number of CPUs per trial (default: %(default)s)",
+    )
+    parser.add_argument(
+        "--gpu_count",
+        type=int,
+        default=1,
+        help="Number of GPUs per trial (default: %(default)s)",
+    )
    parser.add_argument(
        "--num_samples",
        type=int,
@@ -275,9 +285,9 @@
        help="Search algorithms (default: %(default)s)",
    )
    parser.add_argument(
-        "--no_merge_train_val",
+        "--no_retrain",
        action="store_true",
-        help="Do not add the validation set in re-training the final model after hyper-parameter search.",
+        help="Do not retrain the model with the validation set after hyperparameter search.",
    )
    args, _ = parser.parse_known_args()
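
A run that skips the final retraining pass would then look like the following (the config path is a placeholder, not a file guaranteed to exist in the repository):

    python search_params.py --config example_config/rcv1/kim_cnn.yml --no_retrain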

@@ -343,7 +353,7 @@ def main():
    # Save best model after parameter search.
    best_config = analysis.get_best_config(f"val_{config.val_metric}", config.mode, scope="all")
    best_log_dir = analysis.get_best_logdir(f"val_{config.val_metric}", config.mode, scope="all")
-    retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val=not config.no_merge_train_val)
+    retrain_best_model(exp_name, best_config, best_log_dir, retrain=not config.no_retrain)
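
For context, `analysis` is a Ray Tune ExperimentAnalysis object, and get_best_config / get_best_logdir are its standard accessors. A hedged sketch of the surrounding flow (the experiment path and the metric name val_P@1 are illustrative only):

    from ray.tune import ExperimentAnalysis

    analysis = ExperimentAnalysis("~/ray_results/my_experiment")  # hypothetical path
    best_config = analysis.get_best_config(metric="val_P@1", mode="max", scope="all")
    best_log_dir = analysis.get_best_logdir(metric="val_P@1", mode="max", scope="all")
    # retrain=True refits on train+val; retrain=False evaluates the best config as trained.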


if __name__ == "__main__":