diff --git a/src/super_gradients/recipes/checkpoint_params/default_checkpoint_params.yaml b/src/super_gradients/recipes/checkpoint_params/default_checkpoint_params.yaml
index 513c565f0b..2dae76d703 100644
--- a/src/super_gradients/recipes/checkpoint_params/default_checkpoint_params.yaml
+++ b/src/super_gradients/recipes/checkpoint_params/default_checkpoint_params.yaml
@@ -7,3 +7,9 @@ strict_load: # key matching strictness for loading checkpoint's weights
   _target_: super_gradients.training.sg_trainer.StrictLoad
   value: no_key_matching
 pretrained_weights: # a string describing the dataset of the pretrained weights (for example "imagenent").
+
+# num_classes of checkpoint_path/ pretrained_weights, when checkpoint_path is not None.
+# Used when num_classes != checkpoint_num_class.
+# In this case, the module will be initialized with checkpoint_num_class, then weights will be loaded.
+# Finally model.replace_head(new_num_classes=num_classes) is called to replace the head with new_num_classes.
+checkpoint_num_classes: # number of classes in the checkpoint
diff --git a/src/super_gradients/training/kd_trainer/kd_trainer.py b/src/super_gradients/training/kd_trainer/kd_trainer.py
index 4f6a6b3064..325ddfe4cc 100644
--- a/src/super_gradients/training/kd_trainer/kd_trainer.py
+++ b/src/super_gradients/training/kd_trainer/kd_trainer.py
@@ -76,6 +76,7 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> None:
             pretrained_weights=cfg.student_checkpoint_params.pretrained_weights,
             checkpoint_path=cfg.student_checkpoint_params.checkpoint_path,
             load_backbone=cfg.student_checkpoint_params.load_backbone,
+            checkpoint_num_classes=get_param(cfg.student_checkpoint_params, "checkpoint_num_classes"),
         )
 
         teacher = models.get(
@@ -85,6 +86,7 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> None:
             pretrained_weights=cfg.teacher_checkpoint_params.pretrained_weights,
             checkpoint_path=cfg.teacher_checkpoint_params.checkpoint_path,
             load_backbone=cfg.teacher_checkpoint_params.load_backbone,
+            checkpoint_num_classes=get_param(cfg.teacher_checkpoint_params, "checkpoint_num_classes"),
         )
 
         recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py
index 2d2254f9c8..c6f5d51901 100755
--- a/src/super_gradients/training/sg_trainer/sg_trainer.py
+++ b/src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -380,6 +380,7 @@ def evaluate_from_recipe(cls, cfg: DictConfig) -> Tuple[nn.Module, Tuple]:
             pretrained_weights=cfg.checkpoint_params.pretrained_weights,
             checkpoint_path=cfg.checkpoint_params.checkpoint_path,
             load_backbone=cfg.checkpoint_params.load_backbone,
+            checkpoint_num_classes=get_param(cfg.checkpoint_params, "checkpoint_num_classes"),
         )
 
         # TEST
@@ -2340,6 +2341,7 @@ def quantize_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module,
             pretrained_weights=cfg.checkpoint_params.pretrained_weights,
             checkpoint_path=cfg.checkpoint_params.checkpoint_path,
             load_backbone=False,
+            checkpoint_num_classes=get_param(cfg.checkpoint_params, "checkpoint_num_classes"),
         )
 
         recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
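
For context, here is a minimal sketch of how the new checkpoint_num_classes argument could be used directly through models.get, following the flow described in default_checkpoint_params.yaml: the model is first built with the checkpoint's class count so the weights load cleanly, and the head is then replaced via model.replace_head(new_num_classes=num_classes) to match the new task. The architecture name, class counts, and checkpoint path below are hypothetical, chosen only for illustration; the checkpoint_num_classes keyword itself is the one this diff threads through the trainer calls.

from super_gradients.training import models

# Hypothetical example: the checkpoint was trained with 80 classes, but the new
# task has 20. Passing checkpoint_num_classes=80 lets the checkpoint load without
# key/shape mismatches; the head is then re-created for num_classes=20.
model = models.get(
    "yolox_n",                                     # hypothetical architecture
    num_classes=20,                                # classes of the new task
    checkpoint_path="/path/to/80_class_ckpt.pth",  # hypothetical checkpoint path
    checkpoint_num_classes=80,                     # classes the checkpoint was trained with
)

In a recipe, the same behaviour is driven by setting checkpoint_num_classes under checkpoint_params (or student_checkpoint_params / teacher_checkpoint_params for KD), which the trainers forward via get_param as shown in the hunks above.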