diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py
index 069668b186..cd433eefd6 100755
--- a/src/super_gradients/training/sg_trainer/sg_trainer.py
+++ b/src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -324,12 +324,14 @@ def evaluate_from_recipe(cls, cfg: DictConfig) -> Tuple[nn.Module, Tuple]:
             name=cfg.val_dataloader, dataset_params=cfg.dataset_params.val_dataset_params, dataloader_params=cfg.dataset_params.val_dataloader_params
         )

-        if cfg.checkpoint_params.pretrained_weights is None and cfg.checkpoint_params.checkpoint_path is None:
+        if cfg.checkpoint_params.checkpoint_path is None:
             logger.info(
                 "checkpoint_params.checkpoint_path was not provided, "
                 "so the recipe will be evaluated using checkpoints_dir/training_hyperparams.ckpt_name"
             )
-            checkpoints_dir = Path(get_checkpoints_dir_path(experiment_name=cfg.experiment_name, ckpt_root_dir=cfg.ckpt_root_dir))
-            cfg.checkpoint_params.checkpoint_path = str(checkpoints_dir / cfg.training_hyperparams.ckpt_name)
+            checkpoints_dir = get_checkpoints_dir_path(experiment_name=cfg.experiment_name, ckpt_root_dir=cfg.ckpt_root_dir)
+            checkpoint_path = os.path.join(checkpoints_dir, cfg.training_hyperparams.ckpt_name)
+            if os.path.exists(checkpoint_path):
+                cfg.checkpoint_params.checkpoint_path = checkpoint_path

         logger.info(f"Evaluating checkpoint: {cfg.checkpoint_params.checkpoint_path}")
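
For reference, a minimal, self-contained sketch of the checkpoint-path resolution behavior this hunk introduces. The function name and its flat argument list are hypothetical, written for illustration only; in the actual diff the values come from `cfg` and the existing `get_checkpoints_dir_path` helper.

```python
import os
from typing import Optional


def resolve_eval_checkpoint_path(
    explicit_checkpoint_path: Optional[str], checkpoints_dir: str, ckpt_name: str
) -> Optional[str]:
    """Hypothetical standalone mirror of the hunk above.

    An explicitly provided checkpoint_path always wins. Otherwise fall back to
    checkpoints_dir/ckpt_name, but only when that file actually exists, so that
    evaluating with pretrained weights alone no longer depends on a local checkpoint.
    """
    if explicit_checkpoint_path is not None:
        return explicit_checkpoint_path
    candidate = os.path.join(checkpoints_dir, ckpt_name)
    return candidate if os.path.exists(candidate) else None


# Example (hypothetical paths): no explicit checkpoint_path given.
# resolve_eval_checkpoint_path(None, "/checkpoints/my_experiment", "ckpt_best.pth")
# -> "/checkpoints/my_experiment/ckpt_best.pth" if the file exists, otherwise None.
```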