@@ -367,8 +367,13 @@ def test(
367
367
# NOTE, trainer.test takes only lightning based checkpoint.
368
368
# So, it can't take the OTX1.x checkpoint.
369
369
if checkpoint is not None and not is_ir_ckpt :
370
+ kwargs_user_input : dict [str , Any ] = {}
371
+ if self .task == OTXTaskType .ZERO_SHOT_VISUAL_PROMPTING :
372
+ # to update user's custom infer_reference_info_root through cli for zero-shot learning
373
+ kwargs_user_input .update (infer_reference_info_root = self .model .infer_reference_info_root )
374
+
370
375
model_cls = model .__class__
371
- model = model_cls .load_from_checkpoint (checkpoint_path = checkpoint )
376
+ model = model_cls .load_from_checkpoint (checkpoint_path = checkpoint , ** kwargs_user_input )
372
377
373
378
if model .label_info != self .datamodule .label_info :
374
379
if (
@@ -462,8 +467,13 @@ def predict(
462
467
datamodule = self ._auto_configurator .update_ov_subset_pipeline (datamodule = datamodule , subset = "test" )
463
468
464
469
if checkpoint is not None and not is_ir_ckpt :
470
+ kwargs_user_input : dict [str , Any ] = {}
471
+ if self .task == OTXTaskType .ZERO_SHOT_VISUAL_PROMPTING :
472
+ # to update user's custom infer_reference_info_root through cli for zero-shot learning
473
+ kwargs_user_input .update (infer_reference_info_root = self .model .infer_reference_info_root )
474
+
465
475
model_cls = model .__class__
466
- model = model_cls .load_from_checkpoint (checkpoint_path = checkpoint )
476
+ model = model_cls .load_from_checkpoint (checkpoint_path = checkpoint , ** kwargs_user_input )
467
477
468
478
if model .label_info != self .datamodule .label_info :
469
479
msg = (
@@ -574,8 +584,17 @@ def export(
574
584
)
575
585
576
586
if not is_ir_ckpt :
587
+ kwargs_user_input : dict [str , Any ] = {}
588
+ if self .task == OTXTaskType .ZERO_SHOT_VISUAL_PROMPTING :
589
+ # to update user's custom infer_reference_info_root through cli for zero-shot learning
590
+ kwargs_user_input .update (infer_reference_info_root = self .model .infer_reference_info_root )
591
+
577
592
model_cls = self .model .__class__
578
- self .model = model_cls .load_from_checkpoint (checkpoint_path = checkpoint , map_location = "cpu" )
593
+ self .model = model_cls .load_from_checkpoint (
594
+ checkpoint_path = checkpoint ,
595
+ map_location = "cpu" ,
596
+ ** kwargs_user_input ,
597
+ )
579
598
self .model .eval ()
580
599
581
600
self .model .explain_mode = explain
@@ -738,8 +757,13 @@ def explain(
738
757
model = self ._auto_configurator .get_ov_model (model_name = str (checkpoint ), label_info = datamodule .label_info )
739
758
740
759
if checkpoint is not None and not is_ir_ckpt :
760
+ kwargs_user_input : dict [str , Any ] = {}
761
+ if self .task == OTXTaskType .ZERO_SHOT_VISUAL_PROMPTING :
762
+ # to update user's custom infer_reference_info_root through cli for zero-shot learning
763
+ kwargs_user_input .update (infer_reference_info_root = self .model .infer_reference_info_root )
764
+
741
765
model_cls = model .__class__
742
- model = model_cls .load_from_checkpoint (checkpoint_path = checkpoint )
766
+ model = model_cls .load_from_checkpoint (checkpoint_path = checkpoint , ** kwargs_user_input )
743
767
744
768
if model .label_info != self .datamodule .label_info :
745
769
msg = (
@@ -841,8 +865,17 @@ def benchmark(
841
865
)
842
866
843
867
if not is_ir_ckpt :
868
+ kwargs_user_input : dict [str , Any ] = {}
869
+ if self .task == OTXTaskType .ZERO_SHOT_VISUAL_PROMPTING :
870
+ # to update user's custom infer_reference_info_root through cli for zero-shot learning
871
+ kwargs_user_input .update (infer_reference_info_root = self .model .infer_reference_info_root )
872
+
844
873
model_cls = self .model .__class__
845
- self .model = model_cls .load_from_checkpoint (checkpoint_path = checkpoint , map_location = "cpu" )
874
+ self .model = model_cls .load_from_checkpoint (
875
+ checkpoint_path = checkpoint ,
876
+ map_location = "cpu" ,
877
+ ** kwargs_user_input ,
878
+ )
846
879
elif isinstance (self .model , OVModel ):
847
880
msg = "To run benchmark on OV model, checkpoint must be specified."
848
881
raise RuntimeError (msg )
0 commit comments