Skip to content

Commit

Permalink
Fix applying model's hparams when loading model from checkpoint (#4057)
Browse files Browse the repository at this point in the history
  • Loading branch information
sungchul2 authored Oct 23, 2024
1 parent 7b07e6b commit d85f5da
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 10 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ All notable changes to this project will be documented in this file.
(<https://github.com/openvinotoolkit/training_extensions/pull/4056>)
- Upgrade MAPI in 2.2
(<https://github.com/openvinotoolkit/training_extensions/pull/4052>)
- Fix applying model's hparams when loading model from checkpoint
(<https://github.com/openvinotoolkit/training_extensions/pull/4057>)

## \[v2.1.0\]

Expand Down
40 changes: 35 additions & 5 deletions src/otx/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,8 +367,14 @@ def test(
# NOTE, trainer.test takes only lightning based checkpoint.
# So, it can't take the OTX1.x checkpoint.
if checkpoint is not None and not is_ir_ckpt:
kwargs_user_input: dict[str, Any] = {}
if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
# to update user's custom infer_reference_info_root through cli for zero-shot learning
# TODO (sungchul): revisit for better solution
kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)

model_cls = model.__class__
model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **model.hparams)
model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **kwargs_user_input)

if model.label_info != self.datamodule.label_info:
if (
Expand Down Expand Up @@ -462,8 +468,14 @@ def predict(
datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test")

if checkpoint is not None and not is_ir_ckpt:
kwargs_user_input: dict[str, Any] = {}
if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
# to update user's custom infer_reference_info_root through cli for zero-shot learning
# TODO (sungchul): revisit for better solution
kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)

model_cls = model.__class__
model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **model.hparams)
model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **kwargs_user_input)

if model.label_info != self.datamodule.label_info:
msg = (
Expand Down Expand Up @@ -574,11 +586,17 @@ def export(
)

if not is_ir_ckpt:
kwargs_user_input: dict[str, Any] = {}
if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
# to update user's custom infer_reference_info_root through cli for zero-shot learning
# TODO (sungchul): revisit for better solution
kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)

model_cls = self.model.__class__
self.model = model_cls.load_from_checkpoint(
checkpoint_path=checkpoint,
map_location="cpu",
**self.model.hparams,
**kwargs_user_input,
)
self.model.eval()

Expand Down Expand Up @@ -742,8 +760,14 @@ def explain(
model = self._auto_configurator.get_ov_model(model_name=str(checkpoint), label_info=datamodule.label_info)

if checkpoint is not None and not is_ir_ckpt:
kwargs_user_input: dict[str, Any] = {}
if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
# to update user's custom infer_reference_info_root through cli for zero-shot learning
# TODO (sungchul): revisit for better solution
kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)

model_cls = model.__class__
model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **model.hparams)
model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **kwargs_user_input)

if model.label_info != self.datamodule.label_info:
msg = (
Expand Down Expand Up @@ -845,11 +869,17 @@ def benchmark(
)

if not is_ir_ckpt:
kwargs_user_input: dict[str, Any] = {}
if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
# to update user's custom infer_reference_info_root through cli for zero-shot learning
# TODO (sungchul): revisit for better solution
kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)

model_cls = self.model.__class__
self.model = model_cls.load_from_checkpoint(
checkpoint_path=checkpoint,
map_location="cpu",
**self.model.hparams,
**kwargs_user_input,
)
elif isinstance(self.model, OVModel):
msg = "To run benchmark on OV model, checkpoint must be specified."
Expand Down
6 changes: 1 addition & 5 deletions tests/unit/engine/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,11 +223,7 @@ def test_exporting(self, fxt_engine, mocker) -> None:
checkpoint = "path/to/checkpoint.ckpt"
fxt_engine.checkpoint = checkpoint
fxt_engine.export()
mock_load_from_checkpoint.assert_called_once_with(
checkpoint_path=checkpoint,
map_location="cpu",
**fxt_engine.model.hparams,
)
mock_load_from_checkpoint.assert_called_once_with(checkpoint_path=checkpoint, map_location="cpu")
mock_export.assert_called_once_with(
output_dir=Path(fxt_engine.work_dir),
base_name="exported_model",
Expand Down

0 comments on commit d85f5da

Please sign in to comment.