[AutoMM] Support loading specific checkpoints (open-mmlab#3244)
zhiqiangdon authored May 30, 2023
1 parent 45ad3ec commit d9e7bec
Showing 5 changed files with 130 additions and 82 deletions.
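In short, `MultiModalPredictor.load` (and `MultiModalMatcher.load`) now accept either a model directory or the path of one specific checkpoint file inside it. A minimal usage sketch in Python (the directory and checkpoint names are illustrative; the checkpoint-file form is exercised by the new unit test at the bottom of this diff):

from autogluon.multimodal import MultiModalPredictor

# Load from a model directory, as before: a completed fit picks up `model.ckpt`,
# while `resume=True` picks up `last.ckpt` instead.
predictor = MultiModalPredictor.load(path="./ag_models/my_run")

# New in this commit: load one specific checkpoint by passing its file path.
predictor = MultiModalPredictor.load(path="./ag_models/my_run/epoch=8-step=18.ckpt")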
50 changes: 11 additions & 39 deletions multimodal/src/autogluon/multimodal/matcher.py
@@ -78,8 +78,10 @@
filter_hyperparameters,
get_available_devices,
get_config,
get_dir_ckpt_paths,
get_fit_complete_message,
get_fit_start_message,
get_load_ckpt_paths,
get_local_pretrained_config_paths,
get_minmax_mode,
get_stopping_threshold,
@@ -1941,53 +1943,23 @@ def load(
-------
The loaded matcher object.
"""
path = os.path.abspath(os.path.expanduser(path))
assert os.path.isdir(path), f"'{path}' must be an existing directory."
dir_path, ckpt_path = get_dir_ckpt_paths(path=path)

assert os.path.isdir(dir_path), f"'{dir_path}' must be an existing directory."
matcher = cls(query="", response="")
matcher = cls._load_metadata(matcher=matcher, path=path, resume=resume, verbosity=verbosity)
matcher = cls._load_metadata(matcher=matcher, path=dir_path, resume=resume, verbosity=verbosity)

query_model, response_model = create_siamese_model(
query_config=matcher._query_config,
response_config=matcher._response_config,
pretrained=False,
)

resume_ckpt_path = os.path.join(path, LAST_CHECKPOINT)
final_ckpt_path = os.path.join(path, MODEL_CHECKPOINT)
if resume: # resume training which crashed before
if not os.path.isfile(resume_ckpt_path):
if os.path.isfile(final_ckpt_path):
raise ValueError(
f"Resuming checkpoint '{resume_ckpt_path}' doesn't exist, but "
f"final checkpoint '{final_ckpt_path}' exists, which means training "
f"is already completed."
)
else:
raise ValueError(
f"Resuming checkpoint '{resume_ckpt_path}' and "
f"final checkpoint '{final_ckpt_path}' both don't exist. "
f"Consider starting training from scratch."
)
load_path = resume_ckpt_path
logger.info(f"Resume training from checkpoint: '{resume_ckpt_path}'")
ckpt_path = resume_ckpt_path
else: # load a model checkpoint for prediction, evaluation, or continuing training on new data
if not os.path.isfile(final_ckpt_path):
if os.path.isfile(resume_ckpt_path):
raise ValueError(
f"Final checkpoint '{final_ckpt_path}' doesn't exist, but "
f"resuming checkpoint '{resume_ckpt_path}' exists, which means training "
f"is not done yet. Consider resume training from '{resume_ckpt_path}'."
)
else:
raise ValueError(
f"Resuming checkpoint '{resume_ckpt_path}' and "
f"final checkpoint '{final_ckpt_path}' both don't exist. "
f"Consider starting training from scratch."
)
load_path = final_ckpt_path
logger.info(f"Load pretrained checkpoint: {os.path.join(path, MODEL_CHECKPOINT)}")
ckpt_path = None # must set None since we do not resume training
load_path, ckpt_path = get_load_ckpt_paths(
ckpt_path=ckpt_path,
dir_path=dir_path,
resume=resume,
)

query_model, response_model = cls._load_state_dict(
query_model=query_model,
54 changes: 13 additions & 41 deletions multimodal/src/autogluon/multimodal/predictor.py
@@ -132,8 +132,10 @@
get_available_devices,
get_config,
get_detection_classes,
get_dir_ckpt_paths,
get_fit_complete_message,
get_fit_start_message,
get_load_ckpt_paths,
get_local_pretrained_config_paths,
get_minmax_mode,
get_mixup,
@@ -152,7 +154,6 @@
modify_duplicate_model_names,
object_detection_data_to_df,
predict,
process_batch,
save_ovd_result_df,
save_pretrained_model_configs,
save_result_df,
@@ -2623,6 +2624,7 @@ def load(
can be completely or partially trained by .fit(). If a previous training has completed,
it will load the checkpoint `model.ckpt`. Otherwise, if a previous training run accidentally
crashed in the middle, it can load the `last.ckpt` checkpoint by setting `resume=True`.
It also supports loading one specific checkpoint given its path.
Parameters
----------
@@ -2639,11 +2641,12 @@
-------
The loaded predictor object.
"""
path = os.path.abspath(os.path.expanduser(path))
assert os.path.isdir(path), f"'{path}' must be an existing directory."
dir_path, ckpt_path = get_dir_ckpt_paths(path=path)

assert os.path.isdir(dir_path), f"'{dir_path}' must be an existing directory."
predictor = cls(label="dummy_label")

with open(os.path.join(path, "assets.json"), "r") as fp:
with open(os.path.join(dir_path, "assets.json"), "r") as fp:
assets = json.load(fp)
if "class_name" in assets and assets["class_name"] == "MultiModalMatcher":
predictor._matcher = MultiModalMatcher.load(
@@ -2653,7 +2656,7 @@
)
return predictor

predictor = cls._load_metadata(predictor=predictor, path=path, resume=resume, verbosity=verbosity)
predictor = cls._load_metadata(predictor=predictor, path=dir_path, resume=resume, verbosity=verbosity)

efficient_finetune = OmegaConf.select(predictor._config, "optimization.efficient_finetune")

@@ -2674,42 +2677,11 @@
model=model,
)

resume_ckpt_path = os.path.join(path, LAST_CHECKPOINT)
final_ckpt_path = os.path.join(path, MODEL_CHECKPOINT)
if resume: # resume training which crashed before
if not os.path.isfile(resume_ckpt_path):
if os.path.isfile(final_ckpt_path):
raise ValueError(
f"Resuming checkpoint '{resume_ckpt_path}' doesn't exist, but "
f"final checkpoint '{final_ckpt_path}' exists, which means training "
f"is already completed."
)
else:
raise ValueError(
f"Resuming checkpoint '{resume_ckpt_path}' and "
f"final checkpoint '{final_ckpt_path}' both don't exist. "
f"Consider starting training from scratch."
)
load_path = resume_ckpt_path
logger.info(f"Resume training from checkpoint: '{resume_ckpt_path}'")
ckpt_path = resume_ckpt_path
else: # load a model checkpoint for prediction, evaluation, or continuing training on new data
if not os.path.isfile(final_ckpt_path):
if os.path.isfile(resume_ckpt_path):
raise ValueError(
f"Final checkpoint '{final_ckpt_path}' doesn't exist, but "
f"resuming checkpoint '{resume_ckpt_path}' exists, which means training "
f"is not done yet. Consider resume training from '{resume_ckpt_path}'."
)
else:
raise ValueError(
f"Resuming checkpoint '{resume_ckpt_path}' and "
f"final checkpoint '{final_ckpt_path}' both don't exist. "
f"Consider starting training from scratch."
)
load_path = final_ckpt_path
logger.info(f"Load pretrained checkpoint: {os.path.join(path, MODEL_CHECKPOINT)}")
ckpt_path = None # must set None since we do not resume training
load_path, ckpt_path = get_load_ckpt_paths(
ckpt_path=ckpt_path,
dir_path=dir_path,
resume=resume,
)

model = cls._load_state_dict(
model=model,
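For reference, the directory-loading failure modes are unchanged by this refactor; they just move into `get_load_ckpt_paths` below. A sketch with hypothetical run directories:

# Raises ValueError: `last.ckpt` is missing but `model.ckpt` exists,
# i.e. training already completed and there is nothing to resume.
MultiModalPredictor.load(path="./ag_models/finished_run", resume=True)

# Raises ValueError: `model.ckpt` is missing but `last.ckpt` exists,
# i.e. training is not done yet; resume from `last.ckpt` instead.
MultiModalPredictor.load(path="./ag_models/crashed_run")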
2 changes: 1 addition & 1 deletion multimodal/src/autogluon/multimodal/utils/__init__.py
@@ -44,7 +44,7 @@
from .export import ExportMixin
from .hpo import hyperparameter_tune
from .inference import extract_from_output, infer_batch, predict, process_batch, use_realtime
from .load import CustomUnpickler, load_text_tokenizers
from .load import CustomUnpickler, get_dir_ckpt_paths, get_load_ckpt_paths, load_text_tokenizers
from .log import LogFilter, apply_log_filter, get_fit_complete_message, get_fit_start_message, make_exp_dir
from .map import MeanAveragePrecision
from .matcher import compute_semantic_similarity, convert_data_for_ranking, create_siamese_model, semantic_search
88 changes: 87 additions & 1 deletion multimodal/src/autogluon/multimodal/utils/load.py
@@ -3,7 +3,7 @@
import pickle
from typing import Dict, List, Optional, Tuple, Union

from ..constants import AUTOMM
from ..constants import LAST_CHECKPOINT, MODEL_CHECKPOINT
from ..data import DocumentProcessor, NerProcessor, TextProcessor

logger = logging.getLogger(__name__)
@@ -51,3 +51,89 @@ def find_class(self, module, name):
renamed_module = module.replace("autogluon.text.automm", "autogluon.multimodal")

return super(CustomUnpickler, self).find_class(renamed_module, name)


def get_dir_ckpt_paths(path: str):
"""
Get the dir path and ckpt path from a path.
Parameters
----------
path
A path that can be either a directory or a checkpoint file path.
Returns
-------
The dir and ckpt paths.
"""
path = os.path.abspath(os.path.expanduser(path))
if os.path.isfile(path):
dir_path = os.path.dirname(path)
ckpt_path = path
else:
dir_path = path
ckpt_path = None

return dir_path, ckpt_path


def get_load_ckpt_paths(ckpt_path: str, dir_path: str, resume: bool):
"""
Get the load_path and ckpt_path. They can be the same or different.
# TODO: merge load_path and ckpt_path.
Parameters
----------
ckpt_path
The path of one checkpoint, which can be None.
dir_path
The directory path from which to load the model.
resume
Whether to resume training.
Returns
-------
load_path and ckpt_path
"""
if ckpt_path:
load_path = ckpt_path
logger.info(f"Loading checkpoint: '{ckpt_path}'")
else:
resume_ckpt_path = os.path.join(dir_path, LAST_CHECKPOINT)
final_ckpt_path = os.path.join(dir_path, MODEL_CHECKPOINT)
if resume: # resume training which crashed before
if not os.path.isfile(resume_ckpt_path):
if os.path.isfile(final_ckpt_path):
raise ValueError(
f"Resuming checkpoint '{resume_ckpt_path}' doesn't exist, but "
f"final checkpoint '{final_ckpt_path}' exists, which means training "
f"is already completed."
)
else:
raise ValueError(
f"Resuming checkpoint '{resume_ckpt_path}' and "
f"final checkpoint '{final_ckpt_path}' both don't exist. "
f"Consider starting training from scratch."
)
load_path = resume_ckpt_path
logger.info(f"Resume training from checkpoint: '{resume_ckpt_path}'")
ckpt_path = resume_ckpt_path
else: # load a model checkpoint for prediction, evaluation, or continuing training on new data
if not os.path.isfile(final_ckpt_path):
if os.path.isfile(resume_ckpt_path):
raise ValueError(
f"Final checkpoint '{final_ckpt_path}' doesn't exist, but "
f"resuming checkpoint '{resume_ckpt_path}' exists, which means training "
f"is not done yet. Consider resume training from '{resume_ckpt_path}'."
)
else:
raise ValueError(
f"Resuming checkpoint '{resume_ckpt_path}' and "
f"final checkpoint '{final_ckpt_path}' both don't exist. "
f"Consider starting training from scratch."
)
load_path = final_ckpt_path
logger.info(f"Load pretrained checkpoint: {os.path.join(dir_path, MODEL_CHECKPOINT)}")
ckpt_path = None # must set None since we do not resume training

return load_path, ckpt_path
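Putting the two new helpers together, a short walk-through of how `load` now resolves paths (hypothetical paths; the behavior follows the functions above):

# Case 1: the caller passes one specific checkpoint file.
dir_path, ckpt_path = get_dir_ckpt_paths(path="~/run/epoch=8-step=18.ckpt")
# dir_path -> ".../run" (parent directory); ckpt_path -> the .ckpt file itself
load_path, ckpt_path = get_load_ckpt_paths(ckpt_path=ckpt_path, dir_path=dir_path, resume=False)
# load_path -> the .ckpt file; ckpt_path is returned unchanged

# Case 2: the caller passes a model directory after a completed fit.
dir_path, ckpt_path = get_dir_ckpt_paths(path="~/run")
# dir_path -> ".../run"; ckpt_path -> None
load_path, ckpt_path = get_load_ckpt_paths(ckpt_path=ckpt_path, dir_path=dir_path, resume=False)
# load_path -> ".../run/model.ckpt"; ckpt_path stays None since we are not resuming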
18 changes: 18 additions & 0 deletions multimodal/tests/unittests/predictor/test_predictor.py
@@ -708,3 +708,21 @@ def test_fit_with_data_path():
predictor = MultiModalPredictor(label="label")
predictor.fit(train_data=train_csv_file, time_limit=0)
predictor.fit(train_data=train_csv_file, tuning_data=train_csv_file, time_limit=0)


def test_load_ckpt():
download_dir = "./"
train_data, test_data = shopee_dataset(download_dir=download_dir)
predictor = MultiModalPredictor(label="label")
predictor.fit(train_data=train_data, time_limit=20)
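# Copy the final checkpoint under a Lightning-style epoch/step name to emulate
# loading one specific intermediate checkpoint.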
src_file = os.path.join(predictor.path, "model.ckpt")
dest_file = os.path.join(predictor.path, "epoch=8-step=18.ckpt")
shutil.copy(src_file, dest_file)
loaded_predictor = MultiModalPredictor.load(path=dest_file)

predictions = predictor.predict(test_data, as_pandas=False)
predictions2 = loaded_predictor.predict(test_data, as_pandas=False)
npt.assert_equal(predictions, predictions2)
predictions_prob = predictor.predict_proba(test_data, as_pandas=False)
predictions2_prob = loaded_predictor.predict_proba(test_data, as_pandas=False)
npt.assert_equal(predictions_prob, predictions2_prob)
