diff --git a/.circleci/config.yml b/.circleci/config.yml
index f89c2dd1e5..590329e459 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -205,7 +205,8 @@ jobs:
            python3.8 -m pip install -r requirements.txt
            python3.8 -m pip install git+https://github.com/Deci-AI/super-gradients.git@${CIRCLE_BRANCH}
            python3.8 -m pip install torch==1.12.0+cu116 torchvision==0.13.0+cu116 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu116
-           python3.8 src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=cifar10_resnet experiment_name=shortened_cifar10_resnet_accuracy_test training_hyperparams.max_epochs=100 training_hyperparams.average_best_models=False +multi_gpu=DDP +num_gpus=4
+           python3.8 src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=cifar10_resnet experiment_name=shortened_cifar10_resnet_accuracy_test training_hyperparams.max_epochs=100 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
+           python3.8 src/super_gradients/examples/convert_recipe_example/convert_recipe_example.py --config-name=cifar10_conversion_params experiment_name=shortened_cifar10_resnet_accuracy_test
            python3.8 src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=coco2017_yolox experiment_name=shortened_coco2017_yolox_n_map_test architecture=yolox_n training_hyperparams.loss=yolox_fast_loss training_hyperparams.max_epochs=10 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
            python3.8 src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=cityscapes_regseg48 experiment_name=shortened_cityscapes_regseg48_iou_test training_hyperparams.max_epochs=10 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
            coverage run --source=super_gradients -m unittest tests/deci_core_recipe_test_suite_runner.py
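A note on the command change above: in Hydra, "+key=value" appends a key that is absent from the composed config, while "key=value" overrides a key the config already defines. Because this PR adds multi_gpu and num_gpus to cifar10_resnet.yaml (see the recipe diff below), the cifar10 training command drops the "+" prefix, and a second command is added to exercise the new conversion example against the checkpoint that the shortened run just produced.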
diff --git a/src/super_gradients/common/object_names.py b/src/super_gradients/common/object_names.py
index 8c63b619b5..7b5343d303 100644
--- a/src/super_gradients/common/object_names.py
+++ b/src/super_gradients/common/object_names.py
@@ -56,6 +56,7 @@ class Transforms:
     RandAugmentTransform = "RandAugmentTransform"
     Lighting = "Lighting"
     RandomErase = "RandomErase"
+    Standardize = "Standardize"

     # From torch
     Compose = "Compose"
diff --git a/src/super_gradients/examples/convert_recipe_example/__init__.py b/src/super_gradients/examples/convert_recipe_example/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/super_gradients/examples/convert_recipe_example/convert_recipe_example.py b/src/super_gradients/examples/convert_recipe_example/convert_recipe_example.py
new file mode 100644
index 0000000000..521ef95689
--- /dev/null
+++ b/src/super_gradients/examples/convert_recipe_example/convert_recipe_example.py
@@ -0,0 +1,30 @@
+"""
+Example code for converting a model trained with SuperGradients recipes to ONNX.
+
+General use: python convert_recipe_example.py --config-name=DESIRED_RECIPE'S_CONVERSION_PARAMS experiment_name=DESIRED_RECIPE'S_EXPERIMENT_NAME
+
+For more options, see: super_gradients/recipes/conversion_params/default_conversion_params.yaml.
+
+Note: the conversion_params yaml file should reside under super_gradients/recipes/conversion_params.
+"""
+
+from omegaconf import DictConfig
+import hydra
+import pkg_resources
+from super_gradients import init_trainer
+from super_gradients.training import models
+
+
+@hydra.main(config_path=pkg_resources.resource_filename("super_gradients.recipes.conversion_params", ""), version_base="1.2")
+def main(cfg: DictConfig) -> None:
+    # INSTANTIATE ALL OBJECTS IN CFG
+    models.convert_from_config(cfg)
+
+
+def run():
+    init_trainer()
+    main()
+
+
+if __name__ == "__main__":
+    run()
diff --git a/src/super_gradients/recipes/cifar10_resnet.yaml b/src/super_gradients/recipes/cifar10_resnet.yaml
index 7954046c8d..9fb1e1f90f 100644
--- a/src/super_gradients/recipes/cifar10_resnet.yaml
+++ b/src/super_gradients/recipes/cifar10_resnet.yaml
@@ -30,7 +30,8 @@ architecture: resnet18_cifar

 experiment_name: resnet18_cifar

-
+multi_gpu: Off
+num_gpus: 1
 # THE FOLLOWING PARAMS ARE DIRECTLY USED BY HYDRA
 hydra:
   run:
diff --git a/src/super_gradients/recipes/conversion_params/__init__.py b/src/super_gradients/recipes/conversion_params/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/super_gradients/recipes/conversion_params/cifar10_conversion_params.yaml b/src/super_gradients/recipes/conversion_params/cifar10_conversion_params.yaml
new file mode 100644
index 0000000000..fc4390f3d4
--- /dev/null
+++ b/src/super_gradients/recipes/conversion_params/cifar10_conversion_params.yaml
@@ -0,0 +1,36 @@
+# Example conversion parameters, to be used with super_gradients/examples/convert_recipe_example/convert_recipe_example.py
+# Suppose you trained cifar10_resnet using train_from_recipe beforehand. Then:
+# python convert_recipe_example.py --config-name=cifar10_conversion_params experiment_name=YOUR_EXPERIMENT_NAME
+# Alternatively (or if the checkpoints are located anywhere other than the default checkpoints directory), you can pass the full checkpoint path:
+# python convert_recipe_example.py --config-name=cifar10_conversion_params checkpoint_path=YOUR_CHECKPOINT_PATH
+defaults:
+  - default_conversion_params
+  - _self_
+
+experiment_name: resnet18_cifar # The experiment name used to train the model (optional; ignored when checkpoint_path is given)
+
+# CONVERSION RELATED PARAMS
+out_path: # str, destination path for the .onnx file. When None, out_path will be the resolved checkpoint path with the .ckpt suffix replaced by .onnx.
+input_shape: # Input shape, not including batch_size. Always channels-first (e.g. (3, 224, 224)).
+  - 3
+  - 32
+  - 32
+pre_process: # Preprocessing pipeline; resolved by TransformsFactory() and baked into the converted model (optional).
+  Compose:
+    transforms:
+      - Standardize
+      - Normalize:
+          mean:
+            - 0.4914
+            - 0.4822
+            - 0.4465
+          std:
+            - 0.2023
+            - 0.1994
+            - 0.2010
+
+
+post_process: # Postprocessing pipeline; resolved by TransformsFactory() and baked into the converted model (optional).
+prep_model_for_conversion_kwargs: # For SgModules, args to be passed to model.prep_model_for_conversion prior to the torch.onnx.export call.
+torch_onnx_export_kwargs: # kwargs to be unpacked into the torch.onnx.export call (excluding the first three arguments: model, args, f).
+  opset_version: 16
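For orientation, the pre_process block above is resolved by TransformsFactory into an ordinary torchvision Compose. A minimal sketch of the equivalent direct API call (the class count and checkpoint path here are illustrative assumptions, not part of this PR):

    from torchvision.transforms import Compose, Normalize

    from super_gradients.training import models
    from super_gradients.training.transforms import Standardize

    # Same pipeline the YAML describes: scale raw 0-255 inputs to [0, 1],
    # then normalize with the CIFAR-10 statistics listed above.
    pre_process = Compose(
        [
            Standardize(),
            Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
        ]
    )

    model = models.get("resnet18_cifar", num_classes=10, checkpoint_path="/path/to/ckpt_best.pth")  # illustrative path
    models.convert_to_onnx(
        model=model,
        out_path="resnet18_cifar.onnx",
        input_shape=(3, 32, 32),
        pre_process=pre_process,
        torch_onnx_export_kwargs={"opset_version": 16},
    )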
diff --git a/src/super_gradients/recipes/conversion_params/default_conversion_params.yaml b/src/super_gradients/recipes/conversion_params/default_conversion_params.yaml
new file mode 100644
index 0000000000..407d7651f0
--- /dev/null
+++ b/src/super_gradients/recipes/conversion_params/default_conversion_params.yaml
@@ -0,0 +1,19 @@
+experiment_name: # The experiment name used to train the model (optional; ignored when checkpoint_path is given)
+ckpt_root_dir: # The checkpoint root directory, s.t. ckpt_root_dir/experiment_name/ckpt_name resides.
+  # Can be ignored if the checkpoints directory is the default (i.e. the path to the checkpoints module from the contents root), or when checkpoint_path is given
+ckpt_name: ckpt_best.pth # Name of the checkpoint to export ("ckpt_latest.pth", "average_model.pth" or "ckpt_best.pth" for instance).
+checkpoint_path:
+strict_load: no_key_matching # One of [On, Off, no_key_matching] (case insensitive). See super_gradients/common/data_types/enum/strict_load.py
+# NOTES ON: ckpt_root_dir, checkpoint_path, and ckpt_name:
+# - ckpt_root_dir, experiment_name and ckpt_name are only used when checkpoint_path is None.
+# - When checkpoint_path is None, the model will be built according to the output yaml config inside ckpt_root_dir/experiment_name/ckpt_name. Also note that in
+#   this case it is also legal not to pass ckpt_root_dir, which will then be resolved to the default SG checkpoints directory.
+
+
+# CONVERSION RELATED PARAMS
+out_path: # str, destination path for the .onnx file. When None, will be set to checkpoint_path.replace(".ckpt", ".onnx").
+input_shape: # Input shape, not including batch_size. Always channels-first (e.g. (3, 224, 224)).
+pre_process: # Preprocessing pipeline; resolved by TransformsFactory() and baked into the converted model (optional).
+post_process: # Postprocessing pipeline; resolved by TransformsFactory() and baked into the converted model (optional).
+prep_model_for_conversion_kwargs: # For SgModules, args to be passed to model.prep_model_for_conversion prior to the torch.onnx.export call.
+torch_onnx_export_kwargs: # kwargs to be unpacked into the torch.onnx.export call (excluding the first three arguments: model, args, f).
diff --git a/src/super_gradients/training/models/__init__.py b/src/super_gradients/training/models/__init__.py
index b0190f46dc..9db502330e 100755
--- a/src/super_gradients/training/models/__init__.py
+++ b/src/super_gradients/training/models/__init__.py
@@ -21,3 +21,4 @@
 from super_gradients.training.models.user_models import *
 from super_gradients.training.models.model_factory import get
 from super_gradients.training.models.arch_params_factory import get_arch_params
+from super_gradients.training.models.conversion import convert_to_onnx, convert_from_config
diff --git a/src/super_gradients/training/models/conversion.py b/src/super_gradients/training/models/conversion.py
new file mode 100644
index 0000000000..ad9753001d
--- /dev/null
+++ b/src/super_gradients/training/models/conversion.py
@@ -0,0 +1,130 @@
+from pathlib import Path
+
+import hydra
+import torch
+from omegaconf import DictConfig
+import numpy as np
+from torch.nn import Identity
+
+from super_gradients.common.abstractions.abstract_logger import get_logger
+from super_gradients.common.decorators.factory_decorator import resolve_param
+from super_gradients.common.factories.transforms_factory import TransformsFactory
+from super_gradients.training import models
+from super_gradients.training.utils.checkpoint_utils import get_checkpoints_dir_path
+from super_gradients.training.utils.hydra_utils import load_experiment_cfg
+from super_gradients.training.utils.sg_trainer_utils import parse_args
+import os
+import pathlib
+
+logger = get_logger(__name__)
+
+
+class ConvertableCompletePipelineModel(torch.nn.Module):
+    """
+    Exportable nn.Module that wraps the model, preprocessing and postprocessing.
+    Args:
+        model: torch.nn.Module, the main model. Takes its input from pre_process's output, and feeds post_process.
+        pre_process: torch.nn.Module, preprocessing module; its output will be the model's input. When None (default), set to Identity().
+        post_process: torch.nn.Module, postprocessing module; its output is the final output. When None (default), set to Identity().
+        **prep_model_for_conversion_kwargs: for SgModules, args to be passed to model.prep_model_for_conversion
+            prior to the torch.onnx.export call.
+    """
+
+    def __init__(self, model: torch.nn.Module, pre_process: torch.nn.Module = None, post_process: torch.nn.Module = None, **prep_model_for_conversion_kwargs):
+        super(ConvertableCompletePipelineModel, self).__init__()
+        model.eval()
+        pre_process = pre_process or Identity()
+        post_process = post_process or Identity()
+        if hasattr(model, "prep_model_for_conversion"):
+            model.prep_model_for_conversion(**prep_model_for_conversion_kwargs)
+        self.model = model
+        self.pre_process = pre_process
+        self.post_process = post_process
+
+    def forward(self, x):
+        return self.post_process(self.model(self.pre_process(x)))
+
+
+@resolve_param("pre_process", TransformsFactory())
+@resolve_param("post_process", TransformsFactory())
+def convert_to_onnx(
+    model: torch.nn.Module,
+    out_path: str,
+    input_shape: tuple,
+    pre_process: torch.nn.Module = None,
+    post_process: torch.nn.Module = None,
+    prep_model_for_conversion_kwargs=None,
+    torch_onnx_export_kwargs=None,
+):
+    """
+    Exports the model to ONNX.
+
+    :param model: torch.nn.Module, model to export to ONNX.
+    :param out_path: str, destination path for the .onnx file.
+    :param input_shape: tuple, input shape, excluding batch_size (e.g. (3, 224, 224)).
+    :param pre_process: torch.nn.Module, preprocessing pipeline, will be resolved by TransformsFactory().
+    :param post_process: torch.nn.Module, postprocessing pipeline, will be resolved by TransformsFactory().
+    :param prep_model_for_conversion_kwargs: dict, for SgModules, args to be passed to model.prep_model_for_conversion
+        prior to the torch.onnx.export call.
+    :param torch_onnx_export_kwargs: kwargs to be unpacked into the torch.onnx.export call (excluding the first three arguments: model, args, f).
+
+    :return: out_path
+    """
+    if not os.path.isdir(pathlib.Path(out_path).parent.resolve()):
+        raise FileNotFoundError(f"Could not find destination directory {out_path} for the ONNX file.")
+    torch_onnx_export_kwargs = torch_onnx_export_kwargs or dict()
+    prep_model_for_conversion_kwargs = prep_model_for_conversion_kwargs or dict()
+    onnx_input = torch.Tensor(np.zeros([1, *input_shape]))
+    if not out_path.endswith(".onnx"):
+        out_path = out_path + ".onnx"
+    complete_model = ConvertableCompletePipelineModel(model, pre_process, post_process, **prep_model_for_conversion_kwargs)
+
+    torch.onnx.export(model=complete_model, args=onnx_input, f=out_path, **torch_onnx_export_kwargs)
+    return out_path
+
+
+def prepare_conversion_cfgs(cfg: DictConfig):
+    """
+    Builds the cfg (i.e. conversion_params) and experiment_cfg (i.e. the recipe config according to cfg.experiment_name)
+    to be used by convert_recipe_example.
+
+    :param cfg: DictConfig, conversion_params config
+    :return: cfg, experiment_cfg
+    """
+    cfg = hydra.utils.instantiate(cfg)
+    # CREATE THE EXPERIMENT CFG
+    experiment_cfg = load_experiment_cfg(cfg.experiment_name, cfg.ckpt_root_dir)
+    hydra.utils.instantiate(experiment_cfg)
+    if cfg.checkpoint_path is None:
+        logger.info(
+            "checkpoint_path was not provided, so the model will be converted using weights from "
+            "checkpoints_dir/ckpt_name."
+        )
+        checkpoints_dir = Path(get_checkpoints_dir_path(experiment_name=cfg.experiment_name, ckpt_root_dir=cfg.ckpt_root_dir))
+        cfg.checkpoint_path = str(checkpoints_dir / cfg.ckpt_name)
+    cfg.out_path = cfg.out_path or cfg.checkpoint_path.replace(".ckpt", ".onnx")
+    logger.info(f"Exporting checkpoint: {cfg.checkpoint_path} to ONNX.")
+    return cfg, experiment_cfg
+
+
+def convert_from_config(cfg: DictConfig) -> str:
+    """
+    Exports the model according to cfg.
+
+    See:
+    super_gradients/recipes/conversion_params/default_conversion_params.yaml for full documentation of the cfg content,
+    and super_gradients/examples/convert_recipe_example/convert_recipe_example.py for usage.
+
+    :param cfg: DictConfig, the conversion parameters config.
+    :return: out_path, the path of the saved .onnx file.
+    """
+    cfg, experiment_cfg = prepare_conversion_cfgs(cfg)
+    model = models.get(
+        model_name=experiment_cfg.architecture,
+        num_classes=experiment_cfg.arch_params.num_classes,
+        arch_params=experiment_cfg.arch_params,
+        strict_load=cfg.strict_load,
+        checkpoint_path=cfg.checkpoint_path,
+    )
+    cfg = parse_args(cfg, models.convert_to_onnx)
+    out_path = models.convert_to_onnx(model=model, **cfg)
+    return out_path
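Not part of this PR, but a quick way to sanity-check what convert_to_onnx produces (assumes the onnxruntime package is installed): because pre_process is baked into the exported graph, the session should accept raw, unnormalized pixel values.

    import numpy as np
    import onnxruntime

    # The exported graph already contains Standardize/Normalize, so feed raw 0-255 values.
    session = onnxruntime.InferenceSession("resnet18_cifar.onnx", providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    raw_batch = np.random.randint(0, 256, size=(1, 3, 32, 32)).astype(np.float32)
    (logits,) = session.run(None, {input_name: raw_batch})
    print(logits.shape)  # e.g. (1, 10) for the CIFAR-10 recipe above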
diff --git a/src/super_gradients/training/transforms/__init__.py b/src/super_gradients/training/transforms/__init__.py
index cd9b518139..900eac853a 100644
--- a/src/super_gradients/training/transforms/__init__.py
+++ b/src/super_gradients/training/transforms/__init__.py
@@ -6,6 +6,7 @@
     DetectionHSV,
     DetectionPaddedRescale,
     DetectionTargetsFormatTransform,
+    Standardize,
 )
 from super_gradients.training.transforms.all_transforms import (
     TRANSFORMS,
@@ -26,6 +27,7 @@
     "DetectionPaddedRescale",
     "DetectionTargetsFormatTransform",
     "imported_albumentations_failure",
+    "Standardize",
 ]

 cv2.setNumThreads(0)
diff --git a/src/super_gradients/training/transforms/all_transforms.py b/src/super_gradients/training/transforms/all_transforms.py
index f6bda6c853..38867af0d8 100644
--- a/src/super_gradients/training/transforms/all_transforms.py
+++ b/src/super_gradients/training/transforms/all_transforms.py
@@ -25,6 +25,7 @@
     DetectionTargetsFormat,
     DetectionPaddedRescale,
     DetectionTargetsFormatTransform,
+    Standardize,
 )
 from torchvision.transforms import (
     Compose,
@@ -123,6 +124,7 @@
     Transforms.RandomAdjustSharpness: RandomAdjustSharpness,
     Transforms.RandomAutocontrast: RandomAutocontrast,
     Transforms.RandomEqualize: RandomEqualize,
+    Transforms.Standardize: Standardize,
 }

 logger = get_logger(__name__)
diff --git a/src/super_gradients/training/transforms/transforms.py b/src/super_gradients/training/transforms/transforms.py
index 4ac7e0334d..c51ed38d57 100644
--- a/src/super_gradients/training/transforms/transforms.py
+++ b/src/super_gradients/training/transforms/transforms.py
@@ -3,6 +3,7 @@
 import random
 from typing import Optional, Union, Tuple, List, Sequence, Dict

+import torch.nn
 from PIL import Image, ImageFilter, ImageOps
 from torchvision import transforms as transforms
 import numpy as np
@@ -1104,3 +1105,20 @@ def rescale_and_pad_to_size(img, input_size, swap=(2, 0, 1), pad_val=114):
     padded_img = padded_img.transpose(swap)
     padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
     return padded_img, r
+
+
+class Standardize(torch.nn.Module):
+    """
+    Standardize image pixel values.
+    forward(img) returns img / max_val.
+
+    Attributes:
+        max_val: float, the value to divide the input by (default=255.0).
+    """
+
+    def __init__(self, max_val=255.0):
+        super(Standardize, self).__init__()
+        self.max_val = max_val
+
+    def forward(self, img):
+        return img / self.max_val
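In other words, Standardize rescales raw 8-bit pixel intensities into [0, 1]. A one-line illustration:

    import torch

    from super_gradients.training.transforms import Standardize

    t = Standardize()  # max_val defaults to 255.0
    print(t(torch.tensor([0.0, 127.5, 255.0])))  # tensor([0.0000, 0.5000, 1.0000])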
+ """ + cfg, experiment_cfg = prepare_conversion_cfgs(cfg) + model = models.get( + model_name=experiment_cfg.architecture, + num_classes=experiment_cfg.arch_params.num_classes, + arch_params=experiment_cfg.arch_params, + strict_load=cfg.strict_load, + checkpoint_path=cfg.checkpoint_path, + ) + cfg = parse_args(cfg, models.convert_to_onnx) + out_path = models.convert_to_onnx(model=model, **cfg) + return out_path diff --git a/src/super_gradients/training/transforms/__init__.py b/src/super_gradients/training/transforms/__init__.py index cd9b518139..900eac853a 100644 --- a/src/super_gradients/training/transforms/__init__.py +++ b/src/super_gradients/training/transforms/__init__.py @@ -6,6 +6,7 @@ DetectionHSV, DetectionPaddedRescale, DetectionTargetsFormatTransform, + Standardize, ) from super_gradients.training.transforms.all_transforms import ( TRANSFORMS, @@ -26,6 +27,7 @@ "DetectionPaddedRescale", "DetectionTargetsFormatTransform", "imported_albumentations_failure", + "Standardize", ] cv2.setNumThreads(0) diff --git a/src/super_gradients/training/transforms/all_transforms.py b/src/super_gradients/training/transforms/all_transforms.py index f6bda6c853..38867af0d8 100644 --- a/src/super_gradients/training/transforms/all_transforms.py +++ b/src/super_gradients/training/transforms/all_transforms.py @@ -25,6 +25,7 @@ DetectionTargetsFormat, DetectionPaddedRescale, DetectionTargetsFormatTransform, + Standardize, ) from torchvision.transforms import ( Compose, @@ -123,6 +124,7 @@ Transforms.RandomAdjustSharpness: RandomAdjustSharpness, Transforms.RandomAutocontrast: RandomAutocontrast, Transforms.RandomEqualize: RandomEqualize, + Transforms.Standardize: Standardize, } logger = get_logger(__name__) diff --git a/src/super_gradients/training/transforms/transforms.py b/src/super_gradients/training/transforms/transforms.py index 4ac7e0334d..c51ed38d57 100644 --- a/src/super_gradients/training/transforms/transforms.py +++ b/src/super_gradients/training/transforms/transforms.py @@ -3,6 +3,7 @@ import random from typing import Optional, Union, Tuple, List, Sequence, Dict +import torch.nn from PIL import Image, ImageFilter, ImageOps from torchvision import transforms as transforms import numpy as np @@ -1104,3 +1105,20 @@ def rescale_and_pad_to_size(img, input_size, swap=(2, 0, 1), pad_val=114): padded_img = padded_img.transpose(swap) padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) return padded_img, r + + +class Standardize(torch.nn.Module): + """ + Standardize image pixel values. 
diff --git a/tests/recipe_training_tests/shortened_recipes_accuracy_test.py b/tests/recipe_training_tests/shortened_recipes_accuracy_test.py
index 5d9ea4aee2..9ab3c18a44 100644
--- a/tests/recipe_training_tests/shortened_recipes_accuracy_test.py
+++ b/tests/recipe_training_tests/shortened_recipes_accuracy_test.py
@@ -1,7 +1,7 @@
 import unittest
 import shutil

-from coverage.annotate import os
+import os
 from super_gradients.common.environment import environment_config

 import torch
@@ -14,6 +14,10 @@ def setUp(cls):

     def test_shortened_cifar10_resnet_accuracy(self):
         self.assertTrue(self._reached_goal_metric(experiment_name="shortened_cifar10_resnet_accuracy_test", metric_value=0.9167, delta=0.05))

+    def test_convert_shortened_cifar10_resnet(self):
+        ckpt_dir = os.path.join(environment_config.PKG_CHECKPOINTS_DIR, "shortened_cifar10_resnet_accuracy_test")
+        self.assertTrue(os.path.exists(os.path.join(ckpt_dir, "ckpt_best.onnx")))
+
     def test_shortened_coco2017_yolox_n_map(self):
         self.assertTrue(self._reached_goal_metric(experiment_name="shortened_coco2017_yolox_n_map_test", metric_value=0.044, delta=0.02))
diff --git a/tests/unit_tests/export_onnx_test.py b/tests/unit_tests/export_onnx_test.py
new file mode 100644
index 0000000000..a51bd0e093
--- /dev/null
+++ b/tests/unit_tests/export_onnx_test.py
@@ -0,0 +1,20 @@
+import tempfile
+import unittest
+
+from super_gradients.training import models
+from torchvision.transforms import Compose, Normalize, Resize
+from super_gradients.training.transforms import Standardize
+import os
+
+
+class TestModelsONNXExport(unittest.TestCase):
+    def test_models_onnx_export(self):
+        pretrained_model = models.get("resnet18", num_classes=1000, pretrained_weights="imagenet")
+        preprocess = Compose([Resize(224), Standardize(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            out_path = os.path.join(tmpdirname, "resnet18.onnx")
+            models.convert_to_onnx(model=pretrained_model, out_path=out_path, input_shape=(3, 256, 256), pre_process=preprocess)
+
+
+if __name__ == "__main__":
+    unittest.main()
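A note on the shapes in this test: the export dummy input is 1x3x256x256, and the baked-in Resize(224) brings it down to the 224x224 resolution the ImageNet-pretrained weights expect. A quick check of that pipeline outside of ONNX (sketch; assumes a torchvision version whose tensor transforms accept batched input):

    import torch
    from torchvision.transforms import Compose, Normalize, Resize

    from super_gradients.training.transforms import Standardize

    preprocess = Compose([Resize(224), Standardize(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    x = torch.rand(1, 3, 256, 256) * 255  # raw-range input, as the exported graph expects
    print(preprocess(x).shape)  # torch.Size([1, 3, 224, 224])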
diff --git a/tests/unit_tests/repvgg_unit_test.py b/tests/unit_tests/repvgg_unit_test.py
index c9d5040066..1ab1fc19f5 100644
--- a/tests/unit_tests/repvgg_unit_test.py
+++ b/tests/unit_tests/repvgg_unit_test.py
@@ -10,11 +10,11 @@ class BackboneBasedModel(torch.nn.Module):
     """
     Auxiliary model which will use repvgg as backbone
     """
+
     def __init__(self, backbone, backbone_output_channel, num_classes=1000):
         super(BackboneBasedModel, self).__init__()
         self.backbone = backbone
-        self.conv = torch.nn.Conv2d(in_channels=backbone_output_channel, out_channels=backbone_output_channel, kernel_size=1,
-                                    stride=1, padding=0)
+        self.conv = torch.nn.Conv2d(in_channels=backbone_output_channel, out_channels=backbone_output_channel, kernel_size=1, stride=1, padding=0)
         self.bn = torch.nn.BatchNorm2d(backbone_output_channel)  # Adding a bn layer that should NOT be fused
         self.avgpool = torch.nn.AdaptiveAvgPool2d(output_size=1)
         self.linear = torch.nn.Linear(backbone_output_channel, num_classes)
@@ -28,17 +28,14 @@ def forward(self, x):
         return self.linear(x)

     def prep_model_for_conversion(self):
-        if hasattr(self.backbone, 'prep_model_for_conversion'):
+        if hasattr(self.backbone, "prep_model_for_conversion"):
             self.backbone.prep_model_for_conversion()


 class TestRepVgg(unittest.TestCase):
-
     def setUp(self):
         # contains all arch_params needed for initialization of all architectures
-        self.all_arch_params = HpmStruct(**{'num_classes': 10,
-                                            'width_mult': 1,
-                                            'build_residual_branches': True})
+        self.all_arch_params = HpmStruct(**{"num_classes": 10, "width_mult": 1, "build_residual_branches": True})

         self.backbone_arch_params = copy.deepcopy(self.all_arch_params)
         self.backbone_arch_params.override(backbone_mode=True)
@@ -51,31 +48,31 @@ def test_deployment_architecture(self):
         in_channels = 3
         for arch_name in ARCHITECTURES:
             # skip custom constructors to keep all_arch_params as general as a possible
-            if 'repvgg' not in arch_name or 'custom' in arch_name:
+            if "repvgg" not in arch_name or "custom" in arch_name:
                 continue
             model = ARCHITECTURES[arch_name](arch_params=self.all_arch_params)
-            self.assertTrue(hasattr(model.stem, 'branch_3x3'))  # check single layer for training mode
+            self.assertTrue(hasattr(model.stem, "branch_3x3"))  # check single layer for training mode
             self.assertTrue(model.build_residual_branches)

             training_mode_sd = model.state_dict()
             for module in training_mode_sd:
-                self.assertFalse('reparam' in module)  # deployment block included in training mode
+                self.assertFalse("reparam" in module)  # no deployment (reparam) blocks in training mode
             test_input = torch.ones((1, in_channels, image_size, image_size))
             model.eval()
             training_mode_output = model(test_input)

             model.prep_model_for_conversion()
-            self.assertTrue(hasattr(model.stem, 'rbr_reparam'))  # check single layer for training mode
+            self.assertTrue(hasattr(model.stem, "rbr_reparam"))  # check single layer for deployment mode
             self.assertFalse(model.build_residual_branches)

             deployment_mode_sd = model.state_dict()
             for module in deployment_mode_sd:
-                self.assertFalse('running_mean' in module)  # BN were not fused
-                self.assertFalse('branch' in module)  # branches were not joined
+                self.assertFalse("running_mean" in module)  # BN were fused (no running stats left)
+                self.assertFalse("branch" in module)  # branches were joined
             deployment_mode_output = model(test_input)
             # difference is of very low magnitude
-            self.assertFalse(False in torch.isclose(training_mode_output, deployment_mode_output, atol=1e-5))
+            self.assertFalse(False in torch.isclose(training_mode_output, deployment_mode_output, atol=1e-4))

     def test_backbone_mode(self):
         """
@@ -85,8 +82,7 @@ def test_backbone_mode(self):
         in_channels = 3
         test_input = torch.rand((1, in_channels, image_size, image_size))
         backbone_model = RepVggA1(self.backbone_arch_params)
-        model = BackboneBasedModel(backbone_model, backbone_output_channel=1280,
-                                   num_classes=self.backbone_arch_params.num_classes)
+        model = BackboneBasedModel(backbone_model, backbone_output_channel=1280, num_classes=self.backbone_arch_params.num_classes)

         backbone_model.eval()
         model.eval()
@@ -98,17 +94,17 @@ def test_backbone_mode(self):

         training_mode_sd = model.state_dict()
         for module in training_mode_sd:
-            self.assertFalse('reparam' in module)  # deployment block included in training mode
+            self.assertFalse("reparam" in module)  # no deployment (reparam) blocks in training mode

         model.prep_model_for_conversion()
         deployment_mode_sd_list = list(model.state_dict().keys())
-        self.assertTrue('bn.running_mean' in deployment_mode_sd_list)  # Verify non backbone batch norm wasn't fused
+        self.assertTrue("bn.running_mean" in deployment_mode_sd_list)  # Verify the non-backbone batch norm wasn't fused
         for module in deployment_mode_sd_list:
-            self.assertFalse('running_mean' in module and module.startswith('backbone'))  # BN were not fused
-            self.assertFalse('branch' in module and module.startswith('backbone'))  # branches were not joined
+            self.assertFalse("running_mean" in module and module.startswith("backbone"))  # backbone BN were fused
+            self.assertFalse("branch" in module and module.startswith("backbone"))  # backbone branches were joined

         model_deployment_mode_output = model(test_input)
         self.assertFalse(False in torch.isclose(model_deployment_mode_output, model_training_mode_output, atol=1e-5))


-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
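On the relaxed tolerance above: prep_model_for_conversion folds each block's batch-norm statistics and parallel branches into a single convolution in float32, so fused and unfused graphs agree only up to accumulated rounding error, which made atol=1e-5 flaky. A self-contained toy sketch of the same fusion arithmetic on one conv + BN pair (illustrative code, not SG's implementation):

    import torch
    from torch import nn

    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False).eval()
    bn = nn.BatchNorm2d(8).eval()
    bn.running_mean.uniform_(-1, 1)  # pretend these statistics were accumulated during training
    bn.running_var.uniform_(0.5, 1.5)

    # Fold BN into the conv: w' = w * gamma / sqrt(var + eps), b' = beta - mean * gamma / sqrt(var + eps)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True).eval()
    fused.weight.data = conv.weight * scale.reshape(-1, 1, 1, 1)
    fused.bias.data = bn.bias - bn.running_mean * scale

    x = torch.rand(1, 3, 32, 32)
    with torch.no_grad():
        ref, out = bn(conv(x)), fused(x)
    print((ref - out).abs().max())  # tiny but typically non-zero float32 drift
    print(torch.isclose(ref, out, atol=1e-4).all())  # True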