From d9e7beca5a9df9cbe8543f5750f02a47201403e8 Mon Sep 17 00:00:00 2001 From: Zhiqiang Tang Date: Tue, 30 May 2023 10:32:41 -0700 Subject: [PATCH] [AutoMM] Support loading specific checkpoints (#3244) --- .../src/autogluon/multimodal/matcher.py | 50 +++-------- .../src/autogluon/multimodal/predictor.py | 54 +++--------- .../autogluon/multimodal/utils/__init__.py | 2 +- .../src/autogluon/multimodal/utils/load.py | 88 ++++++++++++++++++- .../unittests/predictor/test_predictor.py | 18 ++++ 5 files changed, 130 insertions(+), 82 deletions(-) diff --git a/multimodal/src/autogluon/multimodal/matcher.py b/multimodal/src/autogluon/multimodal/matcher.py index 1c4f3cb1e07..962b08a4060 100644 --- a/multimodal/src/autogluon/multimodal/matcher.py +++ b/multimodal/src/autogluon/multimodal/matcher.py @@ -78,8 +78,10 @@ filter_hyperparameters, get_available_devices, get_config, + get_dir_ckpt_paths, get_fit_complete_message, get_fit_start_message, + get_load_ckpt_paths, get_local_pretrained_config_paths, get_minmax_mode, get_stopping_threshold, @@ -1941,10 +1943,11 @@ def load( ------- The loaded matcher object. """ - path = os.path.abspath(os.path.expanduser(path)) - assert os.path.isdir(path), f"'{path}' must be an existing directory." + dir_path, ckpt_path = get_dir_ckpt_paths(path=path) + + assert os.path.isdir(dir_path), f"'{dir_path}' must be an existing directory." matcher = cls(query="", response="") - matcher = cls._load_metadata(matcher=matcher, path=path, resume=resume, verbosity=verbosity) + matcher = cls._load_metadata(matcher=matcher, path=dir_path, resume=resume, verbosity=verbosity) query_model, response_model = create_siamese_model( query_config=matcher._query_config, @@ -1952,42 +1955,11 @@ def load( pretrained=False, ) - resume_ckpt_path = os.path.join(path, LAST_CHECKPOINT) - final_ckpt_path = os.path.join(path, MODEL_CHECKPOINT) - if resume: # resume training which crashed before - if not os.path.isfile(resume_ckpt_path): - if os.path.isfile(final_ckpt_path): - raise ValueError( - f"Resuming checkpoint '{resume_ckpt_path}' doesn't exist, but " - f"final checkpoint '{final_ckpt_path}' exists, which means training " - f"is already completed." - ) - else: - raise ValueError( - f"Resuming checkpoint '{resume_ckpt_path}' and " - f"final checkpoint '{final_ckpt_path}' both don't exist. " - f"Consider starting training from scratch." - ) - load_path = resume_ckpt_path - logger.info(f"Resume training from checkpoint: '{resume_ckpt_path}'") - ckpt_path = resume_ckpt_path - else: # load a model checkpoint for prediction, evaluation, or continuing training on new data - if not os.path.isfile(final_ckpt_path): - if os.path.isfile(resume_ckpt_path): - raise ValueError( - f"Final checkpoint '{final_ckpt_path}' doesn't exist, but " - f"resuming checkpoint '{resume_ckpt_path}' exists, which means training " - f"is not done yet. Consider resume training from '{resume_ckpt_path}'." - ) - else: - raise ValueError( - f"Resuming checkpoint '{resume_ckpt_path}' and " - f"final checkpoint '{final_ckpt_path}' both don't exist. " - f"Consider starting training from scratch." - ) - load_path = final_ckpt_path - logger.info(f"Load pretrained checkpoint: {os.path.join(path, MODEL_CHECKPOINT)}") - ckpt_path = None # must set None since we do not resume training + load_path, ckpt_path = get_load_ckpt_paths( + ckpt_path=ckpt_path, + dir_path=dir_path, + resume=resume, + ) query_model, response_model = cls._load_state_dict( query_model=query_model, diff --git a/multimodal/src/autogluon/multimodal/predictor.py b/multimodal/src/autogluon/multimodal/predictor.py index f57907c4b53..460e61d1e5e 100644 --- a/multimodal/src/autogluon/multimodal/predictor.py +++ b/multimodal/src/autogluon/multimodal/predictor.py @@ -132,8 +132,10 @@ get_available_devices, get_config, get_detection_classes, + get_dir_ckpt_paths, get_fit_complete_message, get_fit_start_message, + get_load_ckpt_paths, get_local_pretrained_config_paths, get_minmax_mode, get_mixup, @@ -152,7 +154,6 @@ modify_duplicate_model_names, object_detection_data_to_df, predict, - process_batch, save_ovd_result_df, save_pretrained_model_configs, save_result_df, @@ -2623,6 +2624,7 @@ def load( can be completely or partially trained by .fit(). If a previous training has completed, it will load the checkpoint `model.ckpt`. Otherwise if a previous training accidentally collapses in the middle, it can load the `last.ckpt` checkpoint by setting `resume=True`. + It also supports loading one specific checkpoint given its path. Parameters ---------- @@ -2639,11 +2641,12 @@ def load( ------- The loaded predictor object. """ - path = os.path.abspath(os.path.expanduser(path)) - assert os.path.isdir(path), f"'{path}' must be an existing directory." + dir_path, ckpt_path = get_dir_ckpt_paths(path=path) + + assert os.path.isdir(dir_path), f"'{dir_path}' must be an existing directory." predictor = cls(label="dummy_label") - with open(os.path.join(path, "assets.json"), "r") as fp: + with open(os.path.join(dir_path, "assets.json"), "r") as fp: assets = json.load(fp) if "class_name" in assets and assets["class_name"] == "MultiModalMatcher": predictor._matcher = MultiModalMatcher.load( @@ -2653,7 +2656,7 @@ def load( ) return predictor - predictor = cls._load_metadata(predictor=predictor, path=path, resume=resume, verbosity=verbosity) + predictor = cls._load_metadata(predictor=predictor, path=dir_path, resume=resume, verbosity=verbosity) efficient_finetune = OmegaConf.select(predictor._config, "optimization.efficient_finetune") @@ -2674,42 +2677,11 @@ def load( model=model, ) - resume_ckpt_path = os.path.join(path, LAST_CHECKPOINT) - final_ckpt_path = os.path.join(path, MODEL_CHECKPOINT) - if resume: # resume training which crashed before - if not os.path.isfile(resume_ckpt_path): - if os.path.isfile(final_ckpt_path): - raise ValueError( - f"Resuming checkpoint '{resume_ckpt_path}' doesn't exist, but " - f"final checkpoint '{final_ckpt_path}' exists, which means training " - f"is already completed." - ) - else: - raise ValueError( - f"Resuming checkpoint '{resume_ckpt_path}' and " - f"final checkpoint '{final_ckpt_path}' both don't exist. " - f"Consider starting training from scratch." - ) - load_path = resume_ckpt_path - logger.info(f"Resume training from checkpoint: '{resume_ckpt_path}'") - ckpt_path = resume_ckpt_path - else: # load a model checkpoint for prediction, evaluation, or continuing training on new data - if not os.path.isfile(final_ckpt_path): - if os.path.isfile(resume_ckpt_path): - raise ValueError( - f"Final checkpoint '{final_ckpt_path}' doesn't exist, but " - f"resuming checkpoint '{resume_ckpt_path}' exists, which means training " - f"is not done yet. Consider resume training from '{resume_ckpt_path}'." - ) - else: - raise ValueError( - f"Resuming checkpoint '{resume_ckpt_path}' and " - f"final checkpoint '{final_ckpt_path}' both don't exist. " - f"Consider starting training from scratch." - ) - load_path = final_ckpt_path - logger.info(f"Load pretrained checkpoint: {os.path.join(path, MODEL_CHECKPOINT)}") - ckpt_path = None # must set None since we do not resume training + load_path, ckpt_path = get_load_ckpt_paths( + ckpt_path=ckpt_path, + dir_path=dir_path, + resume=resume, + ) model = cls._load_state_dict( model=model, diff --git a/multimodal/src/autogluon/multimodal/utils/__init__.py b/multimodal/src/autogluon/multimodal/utils/__init__.py index 7cec1a72b94..e47d768d90a 100644 --- a/multimodal/src/autogluon/multimodal/utils/__init__.py +++ b/multimodal/src/autogluon/multimodal/utils/__init__.py @@ -44,7 +44,7 @@ from .export import ExportMixin from .hpo import hyperparameter_tune from .inference import extract_from_output, infer_batch, predict, process_batch, use_realtime -from .load import CustomUnpickler, load_text_tokenizers +from .load import CustomUnpickler, get_dir_ckpt_paths, get_load_ckpt_paths, load_text_tokenizers from .log import LogFilter, apply_log_filter, get_fit_complete_message, get_fit_start_message, make_exp_dir from .map import MeanAveragePrecision from .matcher import compute_semantic_similarity, convert_data_for_ranking, create_siamese_model, semantic_search diff --git a/multimodal/src/autogluon/multimodal/utils/load.py b/multimodal/src/autogluon/multimodal/utils/load.py index c87c3c3653e..1d3e5ec3915 100644 --- a/multimodal/src/autogluon/multimodal/utils/load.py +++ b/multimodal/src/autogluon/multimodal/utils/load.py @@ -3,7 +3,7 @@ import pickle from typing import Dict, List, Optional, Tuple, Union -from ..constants import AUTOMM +from ..constants import LAST_CHECKPOINT, MODEL_CHECKPOINT from ..data import DocumentProcessor, NerProcessor, TextProcessor logger = logging.getLogger(__name__) @@ -51,3 +51,89 @@ def find_class(self, module, name): renamed_module = module.replace("autogluon.text.automm", "autogluon.multimodal") return super(CustomUnpickler, self).find_class(renamed_module, name) + + +def get_dir_ckpt_paths(path: str): + """ + Get the dir path and ckpt path from a path. + + Parameters + ---------- + path + A path which can be either a dir or ckpt path. + + Returns + ------- + The dir and ckpt paths. + """ + path = os.path.abspath(os.path.expanduser(path)) + if os.path.isfile(path): + dir_path = os.path.dirname(path) + ckpt_path = path + else: + dir_path = path + ckpt_path = None + + return dir_path, ckpt_path + + +def get_load_ckpt_paths(ckpt_path: str, dir_path: str, resume: bool): + """ + Get the load_path and ckpt_path. They can be the same or different. + #TODO: merging load_path and ckpt_path. + + Parameters + ---------- + ckpt_path + The path of one checkpoint, which can be None. + dir_path + The dir path from where to load model. + resume + Whether to resume training. + + Returns + ------- + load_path and ckpt_path + """ + if ckpt_path: + load_path = ckpt_path + logger.info(f"Loading checkpoint: '{ckpt_path}'") + else: + resume_ckpt_path = os.path.join(dir_path, LAST_CHECKPOINT) + final_ckpt_path = os.path.join(dir_path, MODEL_CHECKPOINT) + if resume: # resume training which crashed before + if not os.path.isfile(resume_ckpt_path): + if os.path.isfile(final_ckpt_path): + raise ValueError( + f"Resuming checkpoint '{resume_ckpt_path}' doesn't exist, but " + f"final checkpoint '{final_ckpt_path}' exists, which means training " + f"is already completed." + ) + else: + raise ValueError( + f"Resuming checkpoint '{resume_ckpt_path}' and " + f"final checkpoint '{final_ckpt_path}' both don't exist. " + f"Consider starting training from scratch." + ) + load_path = resume_ckpt_path + logger.info(f"Resume training from checkpoint: '{resume_ckpt_path}'") + ckpt_path = resume_ckpt_path + else: # load a model checkpoint for prediction, evaluation, or continuing training on new data + if not os.path.isfile(final_ckpt_path): + if os.path.isfile(resume_ckpt_path): + raise ValueError( + f"Final checkpoint '{final_ckpt_path}' doesn't exist, but " + f"resuming checkpoint '{resume_ckpt_path}' exists, which means training " + f"is not done yet. Consider resume training from '{resume_ckpt_path}'." + ) + else: + raise ValueError( + f"Resuming checkpoint '{resume_ckpt_path}' and " + f"final checkpoint '{final_ckpt_path}' both don't exist. " + f"Consider starting training from scratch." + ) + load_path = final_ckpt_path + logger.info(f"Load pretrained checkpoint: {os.path.join(dir_path, MODEL_CHECKPOINT)}") + ckpt_path = None # must set None since we do not resume training + + return load_path, ckpt_path diff --git a/multimodal/tests/unittests/predictor/test_predictor.py b/multimodal/tests/unittests/predictor/test_predictor.py index cd9b14f9306..d03a272ea8a 100644 --- a/multimodal/tests/unittests/predictor/test_predictor.py +++ b/multimodal/tests/unittests/predictor/test_predictor.py @@ -708,3 +708,21 @@ def test_fit_with_data_path(): predictor = MultiModalPredictor(label="label") predictor.fit(train_data=train_csv_file, time_limit=0) predictor.fit(train_data=train_csv_file, tuning_data=train_csv_file, time_limit=0) + + +def test_load_ckpt(): + download_dir = "./" + train_data, test_data = shopee_dataset(download_dir=download_dir) + predictor = MultiModalPredictor(label="label") + predictor.fit(train_data=train_data, time_limit=20) + src_file = os.path.join(predictor.path, "model.ckpt") + dest_file = os.path.join(predictor.path, "epoch=8-step=18.ckpt") + shutil.copy(src_file, dest_file) + loaded_predictor = MultiModalPredictor.load(path=dest_file) + + predictions = predictor.predict(test_data, as_pandas=False) + predictions2 = loaded_predictor.predict(test_data, as_pandas=False) + npt.assert_equal(predictions, predictions2) + predictions_prob = predictor.predict_proba(test_data, as_pandas=False) + predictions2_prob = loaded_predictor.predict_proba(test_data, as_pandas=False) + npt.assert_equal(predictions_prob, predictions2_prob)