Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/sg 132 models convert #598

Merged
merged 18 commits into from
Jan 8, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@ jobs:
python3.8 -m pip install -r requirements.txt
python3.8 -m pip install git+https://github.com/Deci-AI/super-gradients.git@${CIRCLE_BRANCH}
python3.8 -m pip install torch==1.12.0+cu116 torchvision==0.13.0+cu116 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu116
python3.8 src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=cifar10_resnet experiment_name=shortened_cifar10_resnet_accuracy_test training_hyperparams.max_epochs=100 training_hyperparams.average_best_models=False +multi_gpu=DDP +num_gpus=4
python3.8 src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=cifar10_resnet experiment_name=shortened_cifar10_resnet_accuracy_test training_hyperparams.max_epochs=100 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
python3.8 src/super_gradients/examples/convert_recipe_example/convert_recipe_example.py --config-name=cifar10_conversion_params experiment_name=shortened_cifar10_resnet_accuracy_test
python3.8 src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=coco2017_yolox experiment_name=shortened_coco2017_yolox_n_map_test architecture=yolox_n training_hyperparams.loss=yolox_fast_loss training_hyperparams.max_epochs=10 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
python3.8 src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=cityscapes_regseg48 experiment_name=shortened_cityscapes_regseg48_iou_test training_hyperparams.max_epochs=10 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
coverage run --source=super_gradients -m unittest tests/deci_core_recipe_test_suite_runner.py
Expand Down
1 change: 1 addition & 0 deletions src/super_gradients/common/object_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class Transforms:
RandAugmentTransform = "RandAugmentTransform"
Lighting = "Lighting"
RandomErase = "RandomErase"
Standardize = "Standardize"

# From torch
Compose = "Compose"
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""
Example code for running SuperGradient's recipes.

General use: python convert_recipe_example.py --config-name=DESIRED_RECIPE'S_CONVERSION_PARAMS experiment_name=DESIRED_RECIPE'S_EXPERIMENT_NAME.

For more optoins see : super_gradients/recipes/conversion_params/default_conversion_params.yaml.

Note: conversion_params yaml file should reside under super_gradients/recipes/conversion_params
"""

from omegaconf import DictConfig
import hydra
import pkg_resources
from super_gradients import init_trainer
from super_gradients.training import models


@hydra.main(config_path=pkg_resources.resource_filename("super_gradients.recipes.conversion_params", ""), version_base="1.2")
def main(cfg: DictConfig) -> None:
# INSTANTIATE ALL OBJECTS IN CFG
models.convert_from_config(cfg)


def run():
init_trainer()
main()


if __name__ == "__main__":
run()
3 changes: 2 additions & 1 deletion src/super_gradients/recipes/cifar10_resnet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ architecture: resnet18_cifar

experiment_name: resnet18_cifar


multi_gpu: Off
num_gpus: 1
# THE FOLLOWING PARAMS ARE DIRECTLY USED BY HYDRA
hydra:
run:
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Example conversion parameters, to be used with super_gradients/examples/convert_recipe_example/convert_recipe_example.py
# Suppose you trained cifar10_resnet using train_from_recipe beforehand, Then:
# python convert_recipe_example.py --config-name=cifar10_conversion_params experiment_name=YOUR_EXPERIMENT_NAME.
# Alternatively (or if ckpts are located anywhere else from the default checkpoints dir), you can give the full checkpoint path:
# python convert_recipe_example.py --config-name=cifar10_conversion_params checkpoint_path=YOUR_CHECKPOINT_PATH
defaults:
- default_conversion_params
- _self_

experiment_name: resnet18_cifar # The experiment name used to train the model (optional- ignored when checkpoint_path is given)

# CONVERSION RELATED PARAMS
out_path: # str, Destination path for the .onnx file. When None- out_path will be the resolved checkpoint path replacing .ckpt suffix with .onnx.
input_shape: # input shape, not including batch_size. Always channels first (i.e (3, 224, 224)).
- 3
- 32
- 32
pre_process: # Preprocessing pipeline, will be resolved by TransformsFactory(), and will be baked into the converted model (optional).
Compose:
transforms:
- Standardize
- Normalize:
mean:
- 0.4914
- 0.4822
- 0.4465
std:
- 0.2023
- 0.1994
- 0.2010


post_process: # Postprocessing pipeline, will be resolved by TransformsFactory(), and will be baked into the converted model (optional).
prep_model_for_conversion_kwargs: # For SgModules, args to be passed to model.prep_model_for_conversion prior to torch.onnx.export call.
torch_onnx_export_kwargs: # kwargs (EXCLUDING: FIRST 3 KWARGS- MODEL, F, ARGS). to be unpacked in torch.onnx.export call
opset_version: 16
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
experiment_name: # The experiment name used to train the model (optional- ignored when checkpoint_path is given)
ckpt_root_dir: # The checkpoint root directory, s.t ckpt_root_dir/experiment_name/ckpt_name resides.
# Can be ignored if the checkpoints directory is the default (i.e path to checkpoints module from contents root), or when checkpoint_path is given
ckpt_name: ckpt_best.pth # Name of the checkpoint to export ("ckpt_latest.pth", "average_model.pth" or "ckpt_best.pth" for instance).
checkpoint_path:
ofrimasad marked this conversation as resolved.
Show resolved Hide resolved
strict_load: no_key_matching # One of [On, Off, no_key_matching] (case insensitive) See super_gradients/common/data_types/enum/strict_load.py
# NOTES ON: ckpt_root_dir, checkpoint_path, and ckpt_name:
# - ckpt_root_dir, experiment_name and ckpt_name are only used when checkpoint_path is None.
# - when checkpoint_path is None, the model will be vuilt according to the output yaml config inside ckpt_root_dir/experiment_name/ckpt_name. Also note that in
# this case its also legal not to pass ckpt_root_dir, which will be resolved to the default SG ckpt dir.


# CONVERSION RELATED PARAMS
out_path: # str, Destination path for the .onnx file. When None- will be set to the checkpoint_path.replace(".ckpt",".onnx").
input_shape: # input shape, not including batch_size. Always channels first (i.e (3, 224, 224)).
pre_process: # Preprocessing pipeline, will be resolved by TransformsFactory(), and will be baked into the converted model (optional).
post_process: # Postprocessing pipeline, will be resolved by TransformsFactory(), and will be baked into the converted model (optional).
prep_model_for_conversion_kwargs: # For SgModules, args to be passed to model.prep_model_for_conversion prior to torch.onnx.export call.
torch_onnx_export_kwargs: # kwargs (EXCLUDING: FIRST 3 KWARGS- MODEL, F, ARGS). to be unpacked in torch.onnx.export call
1 change: 1 addition & 0 deletions src/super_gradients/training/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@
from super_gradients.training.models.user_models import *
from super_gradients.training.models.model_factory import get
from super_gradients.training.models.arch_params_factory import get_arch_params
from super_gradients.training.models.conversion import convert_to_onnx, convert_from_config
130 changes: 130 additions & 0 deletions src/super_gradients/training/models/conversion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from pathlib import Path
ofrimasad marked this conversation as resolved.
Show resolved Hide resolved

import hydra
import torch
from omegaconf import DictConfig
import numpy as np
from torch.nn import Identity

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.common.factories.transforms_factory import TransformsFactory
from super_gradients.training import models
from super_gradients.training.utils.checkpoint_utils import get_checkpoints_dir_path
from super_gradients.training.utils.hydra_utils import load_experiment_cfg
from super_gradients.training.utils.sg_trainer_utils import parse_args
import os
import pathlib

logger = get_logger(__name__)


class ConvertableCompletePipelineModel(torch.nn.Module):
"""
Exportable nn.Module that wraps the model, preprocessing and postprocessing.
Args:
model: torch.nn.Module, the main model. takes input from pre_process' output, and feeds pre_process.
pre_process: torch.nn.Module, preprocessing module, its output will be model's input. When none (default), set to Identity().
pre_process: torch.nn.Module, postprocessing module, its output is the final output. When none (default), set to Identity().
**prep_model_for_conversion_kwargs: for SgModules- args to be passed to model.prep_model_for_conversion
prior to torch.onnx.export call.
"""

def __init__(self, model: torch.nn.Module, pre_process: torch.nn.Module = None, post_process: torch.nn.Module = None, **prep_model_for_conversion_kwargs):
super(ConvertableCompletePipelineModel, self).__init__()
model.eval()
pre_process = pre_process or Identity()
post_process = post_process or Identity()
if hasattr(model, "prep_model_for_conversion"):
model.prep_model_for_conversion(**prep_model_for_conversion_kwargs)
self.model = model
self.pre_process = pre_process
self.post_process = post_process

def forward(self, x):
return self.post_process(self.model(self.pre_process(x)))


@resolve_param("pre_process", TransformsFactory())
@resolve_param("post_process", TransformsFactory())
def convert_to_onnx(
model: torch.nn.Module,
out_path: str,
input_shape: tuple,
pre_process: torch.nn.Module = None,
post_process: torch.nn.Module = None,
prep_model_for_conversion_kwargs=None,
torch_onnx_export_kwargs=None,
):
"""
Exports model to ONNX.

:param model: torch.nn.Module, model to export to ONNX.
:param out_path: str, destination path for the .onnx file.
:param input_shape: tuple, input shape, excluding batch_size (i.e (3, 224, 224)).
:param pre_process: torch.nn.Module, preprocessing pipeline, will be resolved by TransformsFactory()
:param post_process: torch.nn.Module, postprocessing pipeline, will be resolved by TransformsFactory()
:param prep_model_for_conversion_kwargs: dict, for SgModules- args to be passed to model.prep_model_for_conversion
prior to torch.onnx.export call.
:param torch_onnx_export_kwargs: kwargs (EXCLUDING: FIRST 3 KWARGS- MODEL, F, ARGS). to be unpacked in torch.onnx.export call

:return: out_path
"""
if not os.path.isdir(pathlib.Path(out_path).parent.resolve()):
raise FileNotFoundError(f"Could not find destination directory {out_path} for the ONNX file.")
torch_onnx_export_kwargs = torch_onnx_export_kwargs or dict()
prep_model_for_conversion_kwargs = prep_model_for_conversion_kwargs or dict()
onnx_input = torch.Tensor(np.zeros([1, *input_shape]))
ofrimasad marked this conversation as resolved.
Show resolved Hide resolved
if not out_path.endswith(".onnx"):
out_path = out_path + ".onnx"
complete_model = ConvertableCompletePipelineModel(model, pre_process, post_process, **prep_model_for_conversion_kwargs)

torch.onnx.export(model=complete_model, args=onnx_input, f=out_path, **torch_onnx_export_kwargs)
return out_path


def prepare_conversion_cfgs(cfg: DictConfig):
"""
Builds the cfg (i.e conversion_params) and experiment_cfg (i.e recipe config according to cfg.experiment_name)
to be used by convert_recipe_example

:param cfg: DictConfig, converion_params config
:return: cfg, experiment_cfg
"""
cfg = hydra.utils.instantiate(cfg)
# CREATE THE EXPERIMENT CFG
experiment_cfg = load_experiment_cfg(cfg.experiment_name, cfg.ckpt_root_dir)
hydra.utils.instantiate(experiment_cfg)
if cfg.checkpoint_path is None:
logger.info(
"checkpoint_params.checkpoint_path was not provided, so the model will be converted using weights from "
"checkpoints_dir/training_hyperparams.ckpt_name "
)
checkpoints_dir = Path(get_checkpoints_dir_path(experiment_name=cfg.experiment_name, ckpt_root_dir=cfg.ckpt_root_dir))
cfg.checkpoint_path = str(checkpoints_dir / cfg.ckpt_name)
cfg.out_path = cfg.out_path or cfg.checkpoint_path.replace(".ckpt", ".onnx")
logger.info(f"Exporting checkpoint: {cfg.checkpoint_path} to ONNX.")
return cfg, experiment_cfg


def convert_from_config(cfg: DictConfig) -> str:
"""
Exports model according to cfg.

See:
super_gradients/recipes/conversion_params/default_conversion_params.yaml for the full cfg content documentation,
and super_gradients/examples/convert_recipe_example/convert_recipe_example.py for usage.
:param cfg:
:return: out_path, the path of the saved .onnx file.
"""
cfg, experiment_cfg = prepare_conversion_cfgs(cfg)
model = models.get(
model_name=experiment_cfg.architecture,
num_classes=experiment_cfg.arch_params.num_classes,
arch_params=experiment_cfg.arch_params,
strict_load=cfg.strict_load,
checkpoint_path=cfg.checkpoint_path,
)
cfg = parse_args(cfg, models.convert_to_onnx)
out_path = models.convert_to_onnx(model=model, **cfg)
return out_path
2 changes: 2 additions & 0 deletions src/super_gradients/training/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
DetectionHSV,
DetectionPaddedRescale,
DetectionTargetsFormatTransform,
Standardize,
)
from super_gradients.training.transforms.all_transforms import (
TRANSFORMS,
Expand All @@ -26,6 +27,7 @@
"DetectionPaddedRescale",
"DetectionTargetsFormatTransform",
"imported_albumentations_failure",
"Standardize",
]

cv2.setNumThreads(0)
2 changes: 2 additions & 0 deletions src/super_gradients/training/transforms/all_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
DetectionTargetsFormat,
DetectionPaddedRescale,
DetectionTargetsFormatTransform,
Standardize,
)
from torchvision.transforms import (
Compose,
Expand Down Expand Up @@ -123,6 +124,7 @@
Transforms.RandomAdjustSharpness: RandomAdjustSharpness,
Transforms.RandomAutocontrast: RandomAutocontrast,
Transforms.RandomEqualize: RandomEqualize,
Transforms.Standardize: Standardize,
}
logger = get_logger(__name__)

Expand Down
18 changes: 18 additions & 0 deletions src/super_gradients/training/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import random
from typing import Optional, Union, Tuple, List, Sequence, Dict

import torch.nn
from PIL import Image, ImageFilter, ImageOps
from torchvision import transforms as transforms
import numpy as np
Expand Down Expand Up @@ -1104,3 +1105,20 @@ def rescale_and_pad_to_size(img, input_size, swap=(2, 0, 1), pad_val=114):
padded_img = padded_img.transpose(swap)
padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
return padded_img, r


class Standardize(torch.nn.Module):
"""
Standardize image pixel values.
:return img/max_val

attributes:
max_val: float, value to as described above (default=255)
"""

def __init__(self, max_val=255.0):
super(Standardize, self).__init__()
self.max_val = max_val

def forward(self, img):
return img / self.max_val
2 changes: 2 additions & 0 deletions tests/deci_core_unit_test_suite_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tests.end_to_end_tests import TestTrainer
from tests.unit_tests.detection_utils_test import TestDetectionUtils
from tests.unit_tests.detection_dataset_test import DetectionDatasetTest
from tests.unit_tests.export_onnx_test import TestModelsONNXExport
from tests.unit_tests.local_ckpt_head_replacement_test import LocalCkptHeadReplacementTest
from tests.unit_tests.phase_delegates_test import ContextMethodsTest
from tests.unit_tests.quantization_utility_tests import QuantizationUtilityTest
Expand Down Expand Up @@ -115,6 +116,7 @@ def _add_modules_to_unit_tests_suite(self):
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestRepVGGBlock))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(LocalCkptHeadReplacementTest))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(DetectionDatasetTest))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestModelsONNXExport))

def _add_modules_to_end_to_end_tests_suite(self):
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
import shutil

from coverage.annotate import os
import os
from super_gradients.common.environment import environment_config
import torch

Expand All @@ -14,6 +14,10 @@ def setUp(cls):
def test_shortened_cifar10_resnet_accuracy(self):
self.assertTrue(self._reached_goal_metric(experiment_name="shortened_cifar10_resnet_accuracy_test", metric_value=0.9167, delta=0.05))

def test_convert_shortened_cifar10_resnet(self):
ckpt_dir = os.path.join(environment_config.PKG_CHECKPOINTS_DIR, "shortened_cifar10_resnet_accuracy_test")
self.assertTrue(os.path.exists(os.path.join(ckpt_dir, "ckpt_best.onnx")))

def test_shortened_coco2017_yolox_n_map(self):
self.assertTrue(self._reached_goal_metric(experiment_name="shortened_coco2017_yolox_n_map_test", metric_value=0.044, delta=0.02))

Expand Down
20 changes: 20 additions & 0 deletions tests/unit_tests/export_onnx_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import tempfile
import unittest

from super_gradients.training import models
from torchvision.transforms import Compose, Normalize, Resize
from super_gradients.training.transforms import Standardize
import os


class TestModelsONNXExport(unittest.TestCase):
def test_models_onnx_export(self):
pretrained_model = models.get("resnet18", num_classes=1000, pretrained_weights="imagenet")
preprocess = Compose([Resize(224), Standardize(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
with tempfile.TemporaryDirectory() as tmpdirname:
out_path = os.path.join(tmpdirname, "resnet18.onnx")
models.convert_to_onnx(model=pretrained_model, out_path=out_path, input_shape=(3, 256, 256), pre_process=preprocess)


if __name__ == "__main__":
unittest.main()
Loading