Skip to content

Commit b465671

Browse files
committed
Update engine to enable openvinotoolkit#3607 and openvinotoolkit#3611
1 parent df29149 commit b465671

File tree

1 file changed

+38
-5
lines changed

1 file changed

+38
-5
lines changed

src/otx/engine/engine.py

+38-5
Original file line numberDiff line numberDiff line change
@@ -367,8 +367,13 @@ def test(
367367
# NOTE, trainer.test takes only lightning based checkpoint.
368368
# So, it can't take the OTX1.x checkpoint.
369369
if checkpoint is not None and not is_ir_ckpt:
370+
kwargs_user_input: dict[str, Any] = {}
371+
if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
372+
# to update user's custom infer_reference_info_root through cli for zero-shot learning
373+
kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)
374+
370375
model_cls = model.__class__
371-
model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint)
376+
model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **kwargs_user_input)
372377

373378
if model.label_info != self.datamodule.label_info:
374379
if (
@@ -462,8 +467,13 @@ def predict(
462467
datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test")
463468

464469
if checkpoint is not None and not is_ir_ckpt:
470+
kwargs_user_input: dict[str, Any] = {}
471+
if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
472+
# to update user's custom infer_reference_info_root through cli for zero-shot learning
473+
kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)
474+
465475
model_cls = model.__class__
466-
model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint)
476+
model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **kwargs_user_input)
467477

468478
if model.label_info != self.datamodule.label_info:
469479
msg = (
@@ -574,8 +584,17 @@ def export(
574584
)
575585

576586
if not is_ir_ckpt:
587+
kwargs_user_input: dict[str, Any] = {}
588+
if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
589+
# to update user's custom infer_reference_info_root through cli for zero-shot learning
590+
kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)
591+
577592
model_cls = self.model.__class__
578-
self.model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, map_location="cpu")
593+
self.model = model_cls.load_from_checkpoint(
594+
checkpoint_path=checkpoint,
595+
map_location="cpu",
596+
**kwargs_user_input,
597+
)
579598
self.model.eval()
580599

581600
self.model.explain_mode = explain
@@ -738,8 +757,13 @@ def explain(
738757
model = self._auto_configurator.get_ov_model(model_name=str(checkpoint), label_info=datamodule.label_info)
739758

740759
if checkpoint is not None and not is_ir_ckpt:
760+
kwargs_user_input: dict[str, Any] = {}
761+
if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
762+
# to update user's custom infer_reference_info_root through cli for zero-shot learning
763+
kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)
764+
741765
model_cls = model.__class__
742-
model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint)
766+
model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **kwargs_user_input)
743767

744768
if model.label_info != self.datamodule.label_info:
745769
msg = (
@@ -841,8 +865,17 @@ def benchmark(
841865
)
842866

843867
if not is_ir_ckpt:
868+
kwargs_user_input: dict[str, Any] = {}
869+
if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
870+
# to update user's custom infer_reference_info_root through cli for zero-shot learning
871+
kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)
872+
844873
model_cls = self.model.__class__
845-
self.model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, map_location="cpu")
874+
self.model = model_cls.load_from_checkpoint(
875+
checkpoint_path=checkpoint,
876+
map_location="cpu",
877+
**kwargs_user_input,
878+
)
846879
elif isinstance(self.model, OVModel):
847880
msg = "To run benchmark on OV model, checkpoint must be specified."
848881
raise RuntimeError(msg)

0 commit comments

Comments
 (0)