Feature/sg 132 models convert (#598)
* black
* formatting mid
* black lint
* docs and final cleanup lint
* unit tests lint and black
* broken import fix
* num classes arg in unit test
* Standardize typo
* check out path
* docs edit
* lower repvgg unit tests tolerance
* unit test prints
* prints lint
* tolerance fix repvgg unit test
Showing 17 changed files with 287 additions and 24 deletions.
Empty file.
30 changes: 30 additions & 0 deletions
src/super_gradients/examples/convert_recipe_example/convert_recipe_example.py
@@ -0,0 +1,30 @@
""" | ||
Example code for running SuperGradient's recipes. | ||
General use: python convert_recipe_example.py --config-name=DESIRED_RECIPE'S_CONVERSION_PARAMS experiment_name=DESIRED_RECIPE'S_EXPERIMENT_NAME. | ||
For more optoins see : super_gradients/recipes/conversion_params/default_conversion_params.yaml. | ||
Note: conversion_params yaml file should reside under super_gradients/recipes/conversion_params | ||
""" | ||
|
||
from omegaconf import DictConfig | ||
import hydra | ||
import pkg_resources | ||
from super_gradients import init_trainer | ||
from super_gradients.training import models | ||
|
||
|
||
@hydra.main(config_path=pkg_resources.resource_filename("super_gradients.recipes.conversion_params", ""), version_base="1.2") | ||
def main(cfg: DictConfig) -> None: | ||
# INSTANTIATE ALL OBJECTS IN CFG | ||
models.convert_from_config(cfg) | ||
|
||
|
||
def run(): | ||
init_trainer() | ||
main() | ||
|
||
|
||
if __name__ == "__main__": | ||
run() |
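The script drives everything through Hydra's CLI entry point. For programmatic use (e.g. from a notebook), a minimal sketch using Hydra's compose API instead of @hydra.main could look like the following; the experiment name is illustrative and assumes a previously trained recipe:

# Hedged sketch: composing the conversion config programmatically.
# "resnet18_cifar" is illustrative; it assumes that experiment was already trained.
import pkg_resources
from hydra import compose, initialize_config_dir

from super_gradients import init_trainer
from super_gradients.training import models

init_trainer()
config_dir = pkg_resources.resource_filename("super_gradients.recipes.conversion_params", "")
with initialize_config_dir(config_dir=config_dir, version_base="1.2"):
    cfg = compose(config_name="cifar10_conversion_params", overrides=["experiment_name=resnet18_cifar"])
models.convert_from_config(cfg)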
Empty file.
36 changes: 36 additions & 0 deletions
src/super_gradients/recipes/conversion_params/cifar10_conversion_params.yaml
@@ -0,0 +1,36 @@
# Example conversion parameters, to be used with super_gradients/examples/convert_recipe_example/convert_recipe_example.py
# Suppose you trained cifar10_resnet using train_from_recipe beforehand, then:
#   python convert_recipe_example.py --config-name=cifar10_conversion_params experiment_name=YOUR_EXPERIMENT_NAME
# Alternatively (or if the checkpoints are located anywhere other than the default checkpoints dir), you can give the full checkpoint path:
#   python convert_recipe_example.py --config-name=cifar10_conversion_params checkpoint_path=YOUR_CHECKPOINT_PATH
defaults:
  - default_conversion_params
  - _self_

experiment_name: resnet18_cifar # The experiment name used to train the model (optional; ignored when checkpoint_path is given)

# CONVERSION RELATED PARAMS
out_path: # str, destination path for the .onnx file. When None, out_path is the resolved checkpoint path with the .ckpt suffix replaced by .onnx.
input_shape: # Input shape, not including batch_size. Always channels-first (e.g. (3, 224, 224)).
  - 3
  - 32
  - 32
pre_process: # Preprocessing pipeline, resolved by TransformsFactory() and baked into the converted model (optional).
  Compose:
    transforms:
      - Standardize
      - Normalize:
          mean:
            - 0.4914
            - 0.4822
            - 0.4465
          std:
            - 0.2023
            - 0.1994
            - 0.2010

post_process: # Postprocessing pipeline, resolved by TransformsFactory() and baked into the converted model (optional).
prep_model_for_conversion_kwargs: # For SgModules, args passed to model.prep_model_for_conversion prior to the torch.onnx.export call.
torch_onnx_export_kwargs: # kwargs to be unpacked in the torch.onnx.export call (EXCLUDING the first 3 kwargs: model, args, f).
  opset_version: 16
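For reference, here is the pre_process block above built directly in Python, a hedged sketch using the same transforms the factory resolves (Standardize comes from SuperGradients, Compose and Normalize from torchvision):

# Hedged sketch: the pre_process pipeline above, constructed in code.
from torchvision.transforms import Compose, Normalize
from super_gradients.training.transforms import Standardize

pre_process = Compose(
    [
        Standardize(),  # same first step the YAML lists
        Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
    ]
)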
19 changes: 19 additions & 0 deletions
src/super_gradients/recipes/conversion_params/default_conversion_params.yaml
@@ -0,0 +1,19 @@
experiment_name: # The experiment name used to train the model (optional; ignored when checkpoint_path is given)
ckpt_root_dir: # The checkpoint root directory, such that ckpt_root_dir/experiment_name/ckpt_name resides.
  # Can be ignored if the checkpoints directory is the default (i.e. the path to the checkpoints module from the contents root), or when checkpoint_path is given
ckpt_name: ckpt_best.pth # Name of the checkpoint to export ("ckpt_latest.pth", "average_model.pth" or "ckpt_best.pth" for instance).
checkpoint_path:
strict_load: no_key_matching # One of [On, Off, no_key_matching] (case insensitive). See super_gradients/common/data_types/enum/strict_load.py
# NOTES ON ckpt_root_dir, checkpoint_path, and ckpt_name:
# - ckpt_root_dir, experiment_name and ckpt_name are only used when checkpoint_path is None.
# - When checkpoint_path is None, the model will be built according to the output yaml config inside ckpt_root_dir/experiment_name/ckpt_name. Also note that in
#   this case it is also legal not to pass ckpt_root_dir, which will be resolved to the default SG checkpoints dir.

# CONVERSION RELATED PARAMS
out_path: # str, destination path for the .onnx file. When None, will be set to checkpoint_path.replace(".ckpt", ".onnx").
input_shape: # Input shape, not including batch_size. Always channels-first (e.g. (3, 224, 224)).
pre_process: # Preprocessing pipeline, resolved by TransformsFactory() and baked into the converted model (optional).
post_process: # Postprocessing pipeline, resolved by TransformsFactory() and baked into the converted model (optional).
prep_model_for_conversion_kwargs: # For SgModules, args passed to model.prep_model_for_conversion prior to the torch.onnx.export call.
torch_onnx_export_kwargs: # kwargs to be unpacked in the torch.onnx.export call (EXCLUDING the first 3 kwargs: model, args, f).
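Putting these defaults together, a hedged example invocation (names are illustrative) that exports the averaged weights of an experiment with strict loading switched on could look like:

python convert_recipe_example.py --config-name=cifar10_conversion_params experiment_name=YOUR_EXPERIMENT_NAME ckpt_name=average_model.pth strict_load=On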
@@ -0,0 +1,130 @@
from pathlib import Path

import hydra
import torch
from omegaconf import DictConfig
import numpy as np
from torch.nn import Identity

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.common.factories.transforms_factory import TransformsFactory
from super_gradients.training import models
from super_gradients.training.utils.checkpoint_utils import get_checkpoints_dir_path
from super_gradients.training.utils.hydra_utils import load_experiment_cfg
from super_gradients.training.utils.sg_trainer_utils import parse_args
import os
import pathlib

logger = get_logger(__name__)


class ConvertableCompletePipelineModel(torch.nn.Module):
    """
    Exportable nn.Module that wraps the model, preprocessing and postprocessing.

    Args:
        model: torch.nn.Module, the main model. Takes its input from pre_process's output, and feeds post_process.
        pre_process: torch.nn.Module, preprocessing module; its output will be the model's input. When None (default), set to Identity().
        post_process: torch.nn.Module, postprocessing module; its output is the final output. When None (default), set to Identity().
        **prep_model_for_conversion_kwargs: for SgModules, args to be passed to model.prep_model_for_conversion
            prior to the torch.onnx.export call.
    """

    def __init__(self, model: torch.nn.Module, pre_process: torch.nn.Module = None, post_process: torch.nn.Module = None, **prep_model_for_conversion_kwargs):
        super(ConvertableCompletePipelineModel, self).__init__()
        model.eval()
        pre_process = pre_process or Identity()
        post_process = post_process or Identity()
        if hasattr(model, "prep_model_for_conversion"):
            model.prep_model_for_conversion(**prep_model_for_conversion_kwargs)
        self.model = model
        self.pre_process = pre_process
        self.post_process = post_process

    def forward(self, x):
        return self.post_process(self.model(self.pre_process(x)))


@resolve_param("pre_process", TransformsFactory())
@resolve_param("post_process", TransformsFactory())
def convert_to_onnx(
    model: torch.nn.Module,
    out_path: str,
    input_shape: tuple,
    pre_process: torch.nn.Module = None,
    post_process: torch.nn.Module = None,
    prep_model_for_conversion_kwargs=None,
    torch_onnx_export_kwargs=None,
):
    """
    Exports model to ONNX.

    :param model: torch.nn.Module, model to export to ONNX.
    :param out_path: str, destination path for the .onnx file.
    :param input_shape: tuple, input shape, excluding batch_size (e.g. (3, 224, 224)).
    :param pre_process: torch.nn.Module, preprocessing pipeline, will be resolved by TransformsFactory()
    :param post_process: torch.nn.Module, postprocessing pipeline, will be resolved by TransformsFactory()
    :param prep_model_for_conversion_kwargs: dict, for SgModules, args to be passed to model.prep_model_for_conversion
        prior to the torch.onnx.export call.
    :param torch_onnx_export_kwargs: kwargs to be unpacked in the torch.onnx.export call (EXCLUDING the first 3 kwargs: model, args, f)
    :return: out_path
    """
    if not os.path.isdir(pathlib.Path(out_path).parent.resolve()):
        raise FileNotFoundError(f"Could not find destination directory {out_path} for the ONNX file.")
    torch_onnx_export_kwargs = torch_onnx_export_kwargs or dict()
    prep_model_for_conversion_kwargs = prep_model_for_conversion_kwargs or dict()
    onnx_input = torch.Tensor(np.zeros([1, *input_shape]))
    if not out_path.endswith(".onnx"):
        out_path = out_path + ".onnx"
    complete_model = ConvertableCompletePipelineModel(model, pre_process, post_process, **prep_model_for_conversion_kwargs)

    torch.onnx.export(model=complete_model, args=onnx_input, f=out_path, **torch_onnx_export_kwargs)
    return out_path


def prepare_conversion_cfgs(cfg: DictConfig):
    """
    Builds the cfg (i.e. conversion_params) and experiment_cfg (i.e. the recipe config according to cfg.experiment_name)
    to be used by convert_recipe_example.

    :param cfg: DictConfig, conversion_params config
    :return: cfg, experiment_cfg
    """
    cfg = hydra.utils.instantiate(cfg)
    # CREATE THE EXPERIMENT CFG
    experiment_cfg = load_experiment_cfg(cfg.experiment_name, cfg.ckpt_root_dir)
    hydra.utils.instantiate(experiment_cfg)
    if cfg.checkpoint_path is None:
        logger.info(
            "checkpoint_params.checkpoint_path was not provided, so the model will be converted using weights from "
            "checkpoints_dir/training_hyperparams.ckpt_name"
        )
        checkpoints_dir = Path(get_checkpoints_dir_path(experiment_name=cfg.experiment_name, ckpt_root_dir=cfg.ckpt_root_dir))
        cfg.checkpoint_path = str(checkpoints_dir / cfg.ckpt_name)
    cfg.out_path = cfg.out_path or cfg.checkpoint_path.replace(".ckpt", ".onnx")
    logger.info(f"Exporting checkpoint: {cfg.checkpoint_path} to ONNX.")
    return cfg, experiment_cfg


def convert_from_config(cfg: DictConfig) -> str:
    """
    Exports a model according to cfg.

    See:
        super_gradients/recipes/conversion_params/default_conversion_params.yaml for the full cfg content documentation,
        and super_gradients/examples/convert_recipe_example/convert_recipe_example.py for usage.

    :param cfg: DictConfig, the conversion parameters.
    :return: out_path, the path of the saved .onnx file.
    """
    cfg, experiment_cfg = prepare_conversion_cfgs(cfg)
    model = models.get(
        model_name=experiment_cfg.architecture,
        num_classes=experiment_cfg.arch_params.num_classes,
        arch_params=experiment_cfg.arch_params,
        strict_load=cfg.strict_load,
        checkpoint_path=cfg.checkpoint_path,
    )
    cfg = parse_args(cfg, models.convert_to_onnx)
    out_path = models.convert_to_onnx(model=model, **cfg)
    return out_path
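Because torch_onnx_export_kwargs is forwarded verbatim to torch.onnx.export, a dynamic batch dimension can be requested at export time. A hedged sketch (the tensor names and the /tmp destination are illustrative, not required by the API):

# Hedged sketch: direct convert_to_onnx call with a dynamic batch axis.
from super_gradients.training import models

model = models.get("resnet18", num_classes=10)  # any torch.nn.Module works here
models.convert_to_onnx(
    model=model,
    out_path="/tmp/resnet18_dynamic.onnx",  # illustrative destination
    input_shape=(3, 32, 32),
    torch_onnx_export_kwargs={
        "input_names": ["input"],
        "output_names": ["output"],
        "dynamic_axes": {"input": {0: "batch"}, "output": {0: "batch"}},
        "opset_version": 16,
    },
)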
@@ -0,0 +1,20 @@
import tempfile
import unittest

from super_gradients.training import models
from torchvision.transforms import Compose, Normalize, Resize
from super_gradients.training.transforms import Standardize
import os


class TestModelsONNXExport(unittest.TestCase):
    def test_models_onnx_export(self):
        pretrained_model = models.get("resnet18", num_classes=1000, pretrained_weights="imagenet")
        preprocess = Compose([Resize(224), Standardize(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
        with tempfile.TemporaryDirectory() as tmpdirname:
            out_path = os.path.join(tmpdirname, "resnet18.onnx")
            models.convert_to_onnx(model=pretrained_model, out_path=out_path, input_shape=(3, 256, 256), pre_process=preprocess)


if __name__ == "__main__":
    unittest.main()
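The test above only verifies that the export call completes. A hedged follow-up check, run inside the same TemporaryDirectory block and assuming the optional onnxruntime package is installed, could execute the exported graph and sanity-check the output shape:

# Hedged sketch: validate the exported file with onnxruntime (optional dependency).
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(out_path, providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name
dummy = np.zeros((1, 3, 256, 256), dtype=np.float32)  # matches input_shape above, plus a batch dim
(output,) = session.run(None, {input_name: dummy})
assert output.shape == (1, 1000)  # resnet18 with 1000 ImageNet classes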