-
Notifications
You must be signed in to change notification settings - Fork 517
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feature/sg 128 kd recipe resnet50 (#213)
* add kd train with resnet50 * add kd train with resnet50 * wip * update acc and s3 path * wip * split train_from_recipe * change import * load new resnet50 weights * wip * changes
- Loading branch information
1 parent
bd56ade
commit adf5dca
Showing
14 changed files
with
192 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,9 @@ | ||
from super_gradients.training import ARCHITECTURES, losses, utils, datasets_utils, DataAugmentation, \ | ||
TestDatasetInterface, SegmentationTestDatasetInterface, DetectionTestDatasetInterface, ClassificationTestDatasetInterface, SgModel | ||
TestDatasetInterface, SegmentationTestDatasetInterface, DetectionTestDatasetInterface, ClassificationTestDatasetInterface, SgModel, KDModel | ||
from super_gradients.common import init_trainer, is_distributed | ||
from super_gradients.examples.train_from_recipe_example import train_from_recipe | ||
from super_gradients.examples.train_from_kd_recipe_example import train_from_kd_recipe | ||
|
||
__all__ = ['ARCHITECTURES', 'losses', 'utils', 'datasets_utils', 'DataAugmentation', | ||
'TestDatasetInterface', 'SgModel', 'SegmentationTestDatasetInterface', 'DetectionTestDatasetInterface', | ||
'ClassificationTestDatasetInterface', 'init_trainer', 'is_distributed', 'train_from_recipe'] | ||
'TestDatasetInterface', 'SgModel', 'KDModel', 'SegmentationTestDatasetInterface', 'DetectionTestDatasetInterface', | ||
'ClassificationTestDatasetInterface', 'init_trainer', 'is_distributed', 'train_from_recipe', 'train_from_kd_recipe'] |
Empty file.
22 changes: 22 additions & 0 deletions
22
src/super_gradients/examples/train_from_kd_recipe_example/train_from_kd_recipe.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
""" | ||
Example code for running SuperGradient's recipes. | ||
General use: python train_from_kd_recipe.py --config-name="DESIRED_RECIPE". | ||
For recipe's specific instructions and details refer to the recipe's configuration file in the recipes directory. | ||
""" | ||
|
||
import super_gradients | ||
from omegaconf import DictConfig | ||
import hydra | ||
import pkg_resources | ||
from super_gradients.training.kd_trainer import KDTrainer | ||
|
||
|
||
@hydra.main(config_path=pkg_resources.resource_filename("super_gradients.recipes", "")) | ||
def main(cfg: DictConfig) -> None: | ||
KDTrainer.train(cfg) | ||
|
||
|
||
if __name__ == "__main__": | ||
super_gradients.init_trainer() | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# ResNet50 Imagenet classification training: | ||
# This example trains with batch_size = 192 * 8 GPUs, total 1536. | ||
# Training time on 8 x GeForce RTX A5000 is 9min / epoch. | ||
# Reach => 81.91 Top1 accuracy. | ||
# | ||
# Log and tensorboard at s3://deci-pretrained-models/KD_ResNet50_Beit_Base_ImageNet/average_model.pth | ||
|
||
# Instructions: | ||
# running from the command line, set the PYTHONPATH environment variable: (Replace "YOUR_LOCAL_PATH" with the path to the downloaded repo): | ||
# export PYTHONPATH="YOUR_LOCAL_PATH"/super_gradients/:"YOUR_LOCAL_PATH"/super_gradients/src/ | ||
# Then: | ||
# python train_from_recipe_example/train_from_kd_recipe.py --config-name=imagenet_resnet50_kd | ||
|
||
defaults: | ||
- training_hyperparams: imagenet_resnet50_kd_train_params | ||
- dataset_params: imagenet_dataset_params | ||
- arch_params: default_arch_params | ||
- checkpoint_params: default_checkpoint_params | ||
|
||
training_hyperparams: | ||
loss: kd_loss | ||
criterion_params: | ||
distillation_loss_coeff: 0.8 | ||
task_loss_fn: | ||
_target_: super_gradients.training.losses.label_smoothing_cross_entropy_loss.LabelSmoothingCrossEntropyLoss | ||
|
||
arch_params: | ||
teacher_input_adapter: | ||
_target_: super_gradients.training.utils.kd_model_utils.NormalizationAdapter | ||
mean_original: [0.485, 0.456, 0.406] | ||
std_original: [0.229, 0.224, 0.225] | ||
mean_required: [0.5, 0.5, 0.5] | ||
std_required: [0.5, 0.5, 0.5] | ||
|
||
student_arch_params: | ||
num_classes: 1000 | ||
|
||
teacher_arch_params: | ||
num_classes: 1000 | ||
image_size: [224, 224] | ||
patch_size: [16, 16] | ||
|
||
dataset_params: | ||
batch_size: 192 | ||
val_batch_size: 256 | ||
random_erase_prob: 0 | ||
random_erase_value: random | ||
train_interpolation: random | ||
rand_augment_config_string: rand-m7-mstd0.5 | ||
cutmix: True | ||
cutmix_params: | ||
mixup_alpha: 0.2 | ||
cutmix_alpha: 1.0 | ||
label_smoothing: 0.1 | ||
aug_repeat_count: 3 | ||
|
||
dataset_interface: | ||
imagenet: | ||
dataset_params: ${dataset_params} | ||
|
||
data_loader_num_workers: 8 | ||
|
||
model_checkpoints_location: local | ||
load_checkpoint: False | ||
checkpoint_params: | ||
load_checkpoint: ${load_checkpoint} | ||
teacher_pretrained_weights: imagenet | ||
|
||
run_teacher_on_eval: True | ||
|
||
experiment_name: resnet50_imagenet_KD_Model | ||
|
||
ckpt_root_dir: | ||
|
||
multi_gpu: | ||
_target_: super_gradients.training.sg_model.MultiGPUMode | ||
value: DDP | ||
|
||
sg_model: | ||
_target_: super_gradients.KDModel | ||
experiment_name: ${experiment_name} | ||
model_checkpoints_location: ${model_checkpoints_location} | ||
ckpt_root_dir: ${ckpt_root_dir} | ||
multi_gpu: ${multi_gpu} | ||
|
||
architecture: kd_module | ||
student_architecture: resnet50 | ||
teacher_architecture: beit_base_patch16_224 |
24 changes: 24 additions & 0 deletions
24
src/super_gradients/recipes/training_hyperparams/imagenet_resnet50_kd_train_params.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
defaults: | ||
- default_train_params | ||
|
||
max_epochs: 610 | ||
initial_lr: 5e-3 | ||
lr_mode: cosine | ||
lr_warmup_epochs: 5 | ||
lr_cooldown_epochs: 10 | ||
ema: True | ||
mixed_precision: True | ||
zero_weight_decay_on_bias_and_bn: True | ||
optimizer: Lamb | ||
optimizer_params: | ||
weight_decay: 0.02 | ||
loss: cross_entropy | ||
train_metrics_list: # metrics for evaluation | ||
- _target_: super_gradients.training.metrics.Accuracy | ||
- _target_: super_gradients.training.metrics.Top5 | ||
valid_metrics_list: # metrics for evaluation | ||
- _target_: super_gradients.training.metrics.Accuracy | ||
- _target_: super_gradients.training.metrics.Top5 | ||
loss_logging_items_names: ["Loss", "Task Loss", "Distillation Loss"] | ||
|
||
_convert_: all |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# PACKAGE IMPORTS FOR EXTERNAL USAGE | ||
|
||
from super_gradients.training.kd_model.kd_model import KDModel | ||
|
||
__all__ = ['KDModel'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from super_gradients.training.trainer import Trainer | ||
|
||
|
||
class KDTrainer(Trainer): | ||
""" | ||
Class for running SuperGradient's recipes for KD Models. | ||
See train_from_kd_recipe example in the examples directory to demonstrate it's usage. | ||
""" | ||
|
||
@classmethod | ||
def build_model(cls, cfg): | ||
cfg.sg_model.build_model(student_architecture=cfg.student_architecture, | ||
teacher_architecture=cfg.teacher_architecture, | ||
arch_params=cfg.arch_params, student_arch_params=cfg.student_arch_params, | ||
teacher_arch_params=cfg.teacher_arch_params, | ||
checkpoint_params=cfg.checkpoint_params, run_teacher_on_eval=cfg.run_teacher_on_eval) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import torch | ||
|
||
|
||
class NormalizationAdapter(torch.nn.Module): | ||
def __init__(self, mean_original, std_original, mean_required, std_required): | ||
super(NormalizationAdapter, self).__init__() | ||
mean_original = torch.tensor(mean_original).unsqueeze(-1).unsqueeze(-1) | ||
std_original = torch.tensor(std_original).unsqueeze(-1).unsqueeze(-1) | ||
mean_required = torch.tensor(mean_required).unsqueeze(-1).unsqueeze(-1) | ||
std_required = torch.tensor(std_required).unsqueeze(-1).unsqueeze(-1) | ||
|
||
self.additive = torch.nn.Parameter((mean_original - mean_required) / std_original) | ||
self.multiplier = torch.nn.Parameter(std_original / std_required) | ||
|
||
def forward(self, x): | ||
x = (x + self.additive) * self.multiplier | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters