From 1f07175a2663689111996960947af67ec90c360d Mon Sep 17 00:00:00 2001 From: shayaharon Date: Sun, 14 May 2023 17:38:35 +0300 Subject: [PATCH 01/10] adde unit tests --- .circleci/config.yml | 2 + src/super_gradients/__init__.py | 3 + src/super_gradients/qat_from_recipe.py | 2 +- .../recipes/roboflow_yolo_nas_s_qat.yaml | 1 - .../training/qat_trainer/qat_trainer.py | 232 ++++++++++++++---- tests/deci_core_recipe_test_suite_runner.py | 2 + .../coded_qat_launch_test.py | 39 +++ 7 files changed, 226 insertions(+), 55 deletions(-) create mode 100644 tests/recipe_training_tests/coded_qat_launch_test.py diff --git a/.circleci/config.yml b/.circleci/config.yml index f3d050922c..1aeaf84f08 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -454,6 +454,8 @@ jobs: python3.8 -m pip install -r requirements.txt python3.8 -m pip install . python3.8 -m pip install torch torchvision torchaudio + python3.8 -m pip install pytorch-quantization==2.1.2 --extra-index-url https://pypi.ngc.nvidia.com + python3.8 src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=coco2017_pose_dekr_w32_no_dc experiment_name=shortened_coco2017_pose_dekr_w32_ap_test batch_size=4 val_batch_size=8 epochs=1 training_hyperparams.lr_warmup_steps=0 training_hyperparams.average_best_models=False training_hyperparams.max_train_batches=1000 training_hyperparams.max_valid_batches=100 multi_gpu=DDP num_gpus=4 python3.8 src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=cifar10_resnet experiment_name=shortened_cifar10_resnet_accuracy_test epochs=100 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4 python3.8 src/super_gradients/examples/convert_recipe_example/convert_recipe_example.py --config-name=cifar10_conversion_params experiment_name=shortened_cifar10_resnet_accuracy_test diff --git a/src/super_gradients/__init__.py b/src/super_gradients/__init__.py index be8887e4dc..342d68e454 100755 --- a/src/super_gradients/__init__.py +++ b/src/super_gradients/__init__.py @@ -3,6 +3,7 @@ from super_gradients.common.registry.registry import ARCHITECTURES from super_gradients.sanity_check import env_sanity_check from super_gradients.training.utils.distributed_training_utils import setup_device +from super_gradients.training.pre_launch_callbacks import AutoTrainBatchSizeSelectionCallback, QATRecipeModificationCallback __all__ = [ "ARCHITECTURES", @@ -18,6 +19,8 @@ "is_distributed", "env_sanity_check", "setup_device", + "QATRecipeModificationCallback", + "AutoTrainBatchSizeSelectionCallback", ] __version__ = "3.1.1" diff --git a/src/super_gradients/qat_from_recipe.py b/src/super_gradients/qat_from_recipe.py index b6e39a5e03..e6e6a30497 100644 --- a/src/super_gradients/qat_from_recipe.py +++ b/src/super_gradients/qat_from_recipe.py @@ -14,7 +14,7 @@ @hydra.main(config_path="recipes", version_base="1.2") def _main(cfg: DictConfig) -> None: - QATTrainer.train_from_config(cfg) + QATTrainer.quantize_from_config(cfg) def main(): diff --git a/src/super_gradients/recipes/roboflow_yolo_nas_s_qat.yaml b/src/super_gradients/recipes/roboflow_yolo_nas_s_qat.yaml index 10ec0598f9..ac92b7032d 100644 --- a/src/super_gradients/recipes/roboflow_yolo_nas_s_qat.yaml +++ b/src/super_gradients/recipes/roboflow_yolo_nas_s_qat.yaml @@ -4,7 +4,6 @@ defaults: - _self_ checkpoint_params: - checkpoint_path: ??? 
strict_load: no_key_matching pre_launch_callbacks_list: diff --git a/src/super_gradients/training/qat_trainer/qat_trainer.py b/src/super_gradients/training/qat_trainer/qat_trainer.py index d4c11ec5ef..630a012362 100644 --- a/src/super_gradients/training/qat_trainer/qat_trainer.py +++ b/src/super_gradients/training/qat_trainer/qat_trainer.py @@ -1,5 +1,6 @@ import os -from typing import Union, Tuple +from typing import Union, Tuple, Dict, Mapping, List +from torchmetrics import Metric import copy import hydra @@ -7,8 +8,10 @@ from omegaconf import DictConfig from omegaconf import OmegaConf from torch import nn +from torch.utils.data import DataLoader from super_gradients.common.abstractions.abstract_logger import get_logger +from super_gradients.common.environment.cfg_utils import load_recipe from super_gradients.common.environment.device_utils import device_config from super_gradients.training import utils as core_utils, models, dataloaders from super_gradients.training.sg_trainer import Trainer @@ -32,7 +35,7 @@ class QATTrainer(Trainer): @classmethod - def train_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tuple]: + def quantize_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tuple]: """ Perform quantization aware training (QAT) according to a recipe configuration. @@ -42,6 +45,10 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tup The quantized model will be exported to ONNX along with other checkpoints. + The call to self.quantize (see docs in the next method) is done with the created + train_loader and valid_loader. If no calibration data loader is passed through cfg.calib_loader, + a train data laoder with the validation transforms is used for calibration. + :param cfg: The parsed DictConfig object from yaml recipe files or a dictionary. :return: A tuple containing the quantized model and the output of trainer.train() method. :rtype: Tuple[nn.Module, Tuple] @@ -61,11 +68,13 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tup # TRIGGER CFG MODIFYING CALLBACKS cfg = cls._trigger_cfg_modifying_callbacks(cfg) - if "quantization_params" not in cfg: - raise ValueError("Your recipe does not have quantization_params. Add them to use QAT.") + quantization_params = get_param(cfg, "quantization_params") + + if quantization_params is None: + raise logger.warning("Your recipe does not include quantization_params. 
Using default quantization params.") - if "checkpoint_path" not in cfg.checkpoint_params: - raise ValueError("Starting checkpoint is a must for QAT finetuning.") + if get_param(cfg.checkpoint_params, "checkpoint_path") is None and get_param(cfg.checkpoint_params, "pretrained_weights") is None: + raise ValueError("Starting checkpoint / pretrained weights are a must for QAT finetuning.") num_gpus = core_utils.get_param(cfg, "num_gpus") multi_gpu = core_utils.get_param(cfg, "multi_gpu") @@ -123,70 +132,117 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tup checkpoint_path=cfg.checkpoint_params.checkpoint_path, load_backbone=False, ) - model.to(device_config.device) - - # QUANTIZE MODEL - model.eval() - fuse_repvgg_blocks_residual_branches(model) - q_util = SelectiveQuantizer( - default_quant_modules_calibrator_weights=cfg.quantization_params.selective_quantizer_params.calibrator_w, - default_quant_modules_calibrator_inputs=cfg.quantization_params.selective_quantizer_params.calibrator_i, - default_per_channel_quant_weights=cfg.quantization_params.selective_quantizer_params.per_channel, - default_learn_amax=cfg.quantization_params.selective_quantizer_params.learn_amax, - verbose=cfg.quantization_params.calib_params.verbose, + recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)} + trainer = QATTrainer(experiment_name=cfg.experiment_name, ckpt_root_dir=get_param(cfg, "ckpt_root_dir")) + + res = trainer.quantize( + model=model, + quantization_params=quantization_params, + calib_dataloader=calib_dataloader, + val_dataloader=val_dataloader, + train_dataloader=train_dataloader, + training_params=cfg.training_hyperparams, + additional_qat_configs_to_log=recipe_logged_cfg, ) - q_util.register_skip_quantization(layer_names=cfg.quantization_params.selective_quantizer_params.skip_modules) - q_util.quantize_module(model) - # CALIBRATE MODEL - logger.info("Calibrating model...") - calibrator = QuantizationCalibrator( - verbose=cfg.quantization_params.calib_params.verbose, - torch_hist=True, - ) - calibrator.calibrate_model( - model, - method=cfg.quantization_params.calib_params.histogram_calib_method, - calib_data_loader=calib_dataloader, - num_calib_batches=cfg.quantization_params.calib_params.num_calib_batches or len(train_dataloader), - percentile=get_param(cfg.quantization_params.calib_params, "percentile", 99.99), - ) - calibrator.reset_calibrators(model) # release memory taken by calibrators + return model, res - # VALIDATE PTQ MODEL AND PRINT SUMMARY - logger.info("Validating PTQ model...") - trainer = Trainer(experiment_name=cfg.experiment_name, ckpt_root_dir=get_param(cfg, "ckpt_root_dir", default_val=None)) - valid_metrics_dict = trainer.test(model=model, test_loader=val_dataloader, test_metrics_list=cfg.training_hyperparams.valid_metrics_list) - results = ["PTQ Model Validation Results"] - results += [f" - {metric:10}: {value}" for metric, value in valid_metrics_dict.items()] - logger.info("\n".join(results)) + def quantize( + self, + calib_dataloader: DataLoader, + model: torch.nn.Module = None, + val_dataloader: DataLoader = None, + train_dataloader: DataLoader = None, + quantization_params: Mapping = None, + training_params: Mapping = None, + additional_qat_configs_to_log: Dict = None, + valid_metrics_list: List[Metric] = None, + ): + """ + Performs post-training quantization (PTQ), and optionally quantization-aware training (QAT). + Exports the ONNX model to the checkpoints directory. 
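As a concrete usage sketch of this entry point as it stands at this point in the patch (the model, loaders and training settings below are lifted from the unit test added later in this same diff, so they are illustrative rather than prescriptive):

    from super_gradients import QATTrainer
    from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
    from super_gradients.training.metrics import Accuracy, Top5
    from super_gradients.training.models import ResNet18

    # Regular training first, so the trainer holds a net, a valid_loader and training_params
    # that quantize(...) can fall back on when they are not passed explicitly.
    trainer = QATTrainer("qat_quickstart_example")
    net = ResNet18(num_classes=5, arch_params={})
    train_params = {
        "max_epochs": 2,
        "lr_updates": [1],
        "lr_decay_factor": 0.1,
        "lr_mode": "step",
        "lr_warmup_epochs": 0,
        "initial_lr": 0.1,
        "loss": "cross_entropy",
        "optimizer": "SGD",
        "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
        "train_metrics_list": [Accuracy(), Top5()],
        "valid_metrics_list": [Accuracy(), Top5()],
        "metric_to_watch": "Accuracy",
        "greater_metric_to_watch_is_better": True,
    }
    trainer.train(
        model=net,
        training_params=train_params,
        train_loader=classification_test_dataloader(batch_size=10),
        valid_loader=classification_test_dataloader(batch_size=10),
    )

    # PTQ + QAT with the default quantization_params; only the calibration loader is mandatory here.
    trainer.quantize(calib_dataloader=classification_test_dataloader(batch_size=10))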
+ + :param calib_dataloader: DataLoader, data loader for calibration. + + :param model: torch.nn.Module, Model to perform QAT/PTQ on. When None, will try to use self.net which is set + in previous self.train(..) call (default=None). + + + :param val_dataloader: DataLoader, data loader for validation. Used both for validating the calibrated model after PTQ and during QAT. + When None, will try to use self.valid_loader if it was set in previous self.train(..) call (default=None). + :param train_dataloader: DataLoader, data loader for QA training, can be ignored when quantization_params["ptq_only"]=True (default=None). + + :param quantization_params: Mapping, with the following entries:defaults- + + ptq_only: False # whether to launch QAT, or leave PTQ only + selective_quantizer_params: + calibrator_w: "max" # calibrator type for weights, acceptable types are ["max", "histogram"] + calibrator_i: "histogram" # calibrator type for inputs acceptable types are ["max", "histogram"] + per_channel: True # per-channel quantization of weights, activations stay per-tensor by default + learn_amax: False # enable learnable amax in all TensorQuantizers using straight-through estimator + skip_modules: # optional list of module names (strings) to skip from quantization + + calib_params: + histogram_calib_method: "percentile" # calibration method for all "histogram" calibrators, acceptable types are ["percentile", "entropy", mse"], + "max" calibrators always use "max" + percentile: 99.99 # percentile for all histogram calibrators with method "percentile", other calibrators are not affected + + num_calib_batches: # number of batches to use for calibration, if None, 512 / batch_size will be used + verbose: False # if calibrator should be verbose + + + :param training_params: Mapping, training hyper parameters for QAT, same as in super.train(...). When None, will try to use self.training_params + which is set in previous self.train(..) call (default=None). + + :param additional_qat_configs_to_log: Dict, Optional dictionary containing configs that will be added to the QA training's + sg_logger. Format should be {"Config_title_1": {...}, "Config_title_2":{..}}. + + :param valid_metrics_list: (list(torchmetrics.Metric)) metrics list for evaluation of the calibrated model. + When None, the validation metrics from training_params are used). (default=None). + + :return: Validation results of the QAT model in case quantization_params['ptq_only']=False and of the PTQ + model otherwise. + """ + + if quantization_params is None: + quantization_params = load_recipe("quantization_params/default_quantization_params").quantization_params + logger.info(f"Using default quantization params: {quantization_params}") + training_params = training_params or self.training_params.to_dict() + valid_metrics_list = valid_metrics_list or get_param(training_params, "valid_metrics_list") + train_dataloader = train_dataloader or self.train_loader + val_dataloader = val_dataloader or self.valid_loader + model = model or self.net + + res = self.calibrate_model( + calib_dataloader=calib_dataloader, + model=model, + quantization_params=quantization_params, + val_dataloader=val_dataloader, + valid_metrics_list=valid_metrics_list, + ) # TRAIN - if cfg.quantization_params.ptq_only: - logger.info("cfg.quantization_params.ptq_only=True. Performing PTQ only!") + if get_param(quantization_params, "ptq_only", False): + logger.info("quantization_params.ptq_only=True. 
Performing PTQ only!") suffix = "ptq" - res = None else: model.train() - recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)} - trainer = Trainer(experiment_name=cfg.experiment_name, ckpt_root_dir=get_param(cfg, "ckpt_root_dir", default_val=None)) torch.cuda.empty_cache() - res = trainer.train( + res = self.train( model=model, train_loader=train_dataloader, valid_loader=val_dataloader, - training_params=cfg.training_hyperparams, - additional_configs_to_log=recipe_logged_cfg, + training_params=training_params, + additional_configs_to_log=additional_qat_configs_to_log, ) suffix = "qat" - # EXPORT QUANTIZED MODEL TO ONNX input_shape = next(iter(val_dataloader))[0].shape - os.makedirs(trainer.checkpoints_dir_path, exist_ok=True) + os.makedirs(self.checkpoints_dir_path, exist_ok=True) + qdq_onnx_path = os.path.join(self.checkpoints_dir_path, f"{self.experiment_name}_{'x'.join((str(x) for x in input_shape))}_{suffix}.onnx") - qdq_onnx_path = os.path.join(trainer.checkpoints_dir_path, f"{cfg.experiment_name}_{'x'.join((str(x) for x in input_shape))}_{suffix}.onnx") # TODO: modify SG's convert_to_onnx for quantized models and use it instead export_quantized_module_to_onnx( model=model.cpu(), @@ -195,7 +251,77 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tup input_size=input_shape, train=False, ) - logger.info(f"Exported {suffix.upper()} ONNX to {qdq_onnx_path}") + return res - return model, res + def calibrate_model(self, calib_dataloader, model, quantization_params, val_dataloader, valid_metrics_list): + """ + Performs calibration. + + :param calib_dataloader: DataLoader, data loader for calibration. + + :param model: torch.nn.Module, Model to perform calibration on. When None, will try to use self.net which is + set in previous self.train(..) call (default=None). + + :param val_dataloader: DataLoader, data loader for validation. Used both for validating the calibrated model. + When None, will try to use self.valid_loader if it was set in previous self.train(..) call (default=None). + + :param quantization_params: Mapping, with the following entries:defaults- + selective_quantizer_params: + calibrator_w: "max" # calibrator type for weights, acceptable types are ["max", "histogram"] + calibrator_i: "histogram" # calibrator type for inputs acceptable types are ["max", "histogram"] + per_channel: True # per-channel quantization of weights, activations stay per-tensor by default + learn_amax: False # enable learnable amax in all TensorQuantizers using straight-through estimator + skip_modules: # optional list of module names (strings) to skip from quantization + + calib_params: histogram_calib_method: "percentile" # calibration method for all "histogram" calibrators, + acceptable types are ["percentile", "entropy", mse"], "max" calibrators always use "max" percentile: + 99.99 # percentile for all histogram calibrators with method "percentile", + other calibrators are not affected num_calib_batches: # number of batches to use for + calibration, if None, 512 / batch_size will be used verbose: False # if calibrator + should be verbose + + + + :param valid_metrics_list: (list(torchmetrics.Metric)) metrics list for evaluation of the calibrated model. + + :return: Validation results of the calibrated model. 
+ """ + selective_quantizer_params = get_param(quantization_params, "selective_quantizer_params") + calib_params = get_param(quantization_params, "calib_params") + model = model or self.net + model.to(device_config.device) + # QUANTIZE MODEL + model.eval() + fuse_repvgg_blocks_residual_branches(model) + q_util = SelectiveQuantizer( + default_quant_modules_calibrator_weights=get_param(selective_quantizer_params, "calibrator_w"), + default_quant_modules_calibrator_inputs=get_param(selective_quantizer_params, "calibrator_i"), + default_per_channel_quant_weights=get_param(selective_quantizer_params, "per_channel"), + default_learn_amax=get_param(selective_quantizer_params, "learn_amax"), + verbose=get_param(calib_params, "verbose"), + ) + q_util.register_skip_quantization(layer_names=get_param(selective_quantizer_params, "skip_modules")) + q_util.quantize_module(model) + # CALIBRATE MODEL + logger.info("Calibrating model...") + calibrator = QuantizationCalibrator( + verbose=get_param(calib_params, "verbose"), + torch_hist=True, + ) + calibrator.calibrate_model( + model, + method=get_param(calib_params, "histogram_calib_method"), + calib_data_loader=calib_dataloader, + num_calib_batches=get_param(calib_params, "num_calib_batches") or len(calib_dataloader), + percentile=get_param(calib_params, "percentile", 99.99), + ) + calibrator.reset_calibrators(model) # release memory taken by calibrators + # VALIDATE PTQ MODEL AND PRINT SUMMARY + logger.info("Validating PTQ model...") + valid_metrics_dict = self.test(model=model, test_loader=val_dataloader, test_metrics_list=valid_metrics_list) + results = ["PTQ Model Validation Results"] + results += [f" - {metric:10}: {value}" for metric, value in valid_metrics_dict.items()] + logger.info("\n".join(results)) + + return valid_metrics_dict diff --git a/tests/deci_core_recipe_test_suite_runner.py b/tests/deci_core_recipe_test_suite_runner.py index 02696c1498..9ffbfd3ac4 100644 --- a/tests/deci_core_recipe_test_suite_runner.py +++ b/tests/deci_core_recipe_test_suite_runner.py @@ -2,6 +2,7 @@ import unittest from tests.recipe_training_tests.automatic_batch_selection_single_gpu_test import TestAutoBatchSelectionSingleGPU +from tests.recipe_training_tests.coded_qat_launch_test import CodedQATLuanchTest from tests.recipe_training_tests.shortened_recipes_accuracy_test import ShortenedRecipesAccuracyTests @@ -17,6 +18,7 @@ def _add_modules_to_unit_tests_suite(self): _add_modules_to_unit_tests_suite - Adds unit tests to the Unit Tests Test Suite :return: """ + self.recipe_tests_suite.addTest(self.test_loader.loadTestsFromModule(CodedQATLuanchTest)) self.recipe_tests_suite.addTest(self.test_loader.loadTestsFromModule(ShortenedRecipesAccuracyTests)) self.recipe_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestAutoBatchSelectionSingleGPU)) diff --git a/tests/recipe_training_tests/coded_qat_launch_test.py b/tests/recipe_training_tests/coded_qat_launch_test.py new file mode 100644 index 0000000000..948e7826eb --- /dev/null +++ b/tests/recipe_training_tests/coded_qat_launch_test.py @@ -0,0 +1,39 @@ +import unittest + +from super_gradients import QATTrainer +from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader +from super_gradients.training.metrics import Accuracy, Top5 +from super_gradients.training.models import ResNet18 + + +class CodedQATLuanchTest(unittest.TestCase): + def test_qat_launch(self): + trainer = QATTrainer("test_launch_qat_with_minimal_changes") + net = ResNet18(num_classes=5, arch_params={}) + 
train_params = { + "max_epochs": 2, + "lr_updates": [1], + "lr_decay_factor": 0.1, + "lr_mode": "step", + "lr_warmup_epochs": 0, + "initial_lr": 0.1, + "loss": "cross_entropy", + "optimizer": "SGD", + "criterion_params": {}, + "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9}, + "train_metrics_list": [Accuracy(), Top5()], + "valid_metrics_list": [Accuracy(), Top5()], + "metric_to_watch": "Accuracy", + "greater_metric_to_watch_is_better": True, + } + trainer.train( + model=net, + training_params=train_params, + train_loader=classification_test_dataloader(batch_size=10), + valid_loader=classification_test_dataloader(batch_size=10), + ) + trainer.quantize(calib_dataloader=classification_test_dataloader(batch_size=10)) + + +if __name__ == "__main__": + unittest.main() From 73b4ba8768ac328612cb254ed2d63338858b7424 Mon Sep 17 00:00:00 2001 From: shayaharon Date: Sun, 14 May 2023 17:40:32 +0300 Subject: [PATCH 02/10] changed local --- src/super_gradients/recipes/roboflow_yolo_nas_s_qat.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/src/super_gradients/recipes/roboflow_yolo_nas_s_qat.yaml b/src/super_gradients/recipes/roboflow_yolo_nas_s_qat.yaml index ac92b7032d..10ec0598f9 100644 --- a/src/super_gradients/recipes/roboflow_yolo_nas_s_qat.yaml +++ b/src/super_gradients/recipes/roboflow_yolo_nas_s_qat.yaml @@ -4,6 +4,7 @@ defaults: - _self_ checkpoint_params: + checkpoint_path: ??? strict_load: no_key_matching pre_launch_callbacks_list: From 33699f3a42f0e706a1bc03307bd7d96a8488759e Mon Sep 17 00:00:00 2001 From: shayaharon Date: Mon, 15 May 2023 10:48:11 +0300 Subject: [PATCH 03/10] switch to ema model before quantization if exists --- src/super_gradients/training/qat_trainer/qat_trainer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/super_gradients/training/qat_trainer/qat_trainer.py b/src/super_gradients/training/qat_trainer/qat_trainer.py index 630a012362..1d2a8f4a71 100644 --- a/src/super_gradients/training/qat_trainer/qat_trainer.py +++ b/src/super_gradients/training/qat_trainer/qat_trainer.py @@ -165,8 +165,9 @@ def quantize( :param calib_dataloader: DataLoader, data loader for calibration. - :param model: torch.nn.Module, Model to perform QAT/PTQ on. When None, will try to use self.net which is set - in previous self.train(..) call (default=None). + :param model: torch.nn.Module, Model to perform QAT/PTQ on. When None, will try to use the network from + previous self.train call(that is, if there was one - will try to use self.ema_model.ema if EMA was used, + otherwise self.net)................................ :param val_dataloader: DataLoader, data loader for validation. Used both for validating the calibrated model after PTQ and during QAT. 
@@ -213,7 +214,7 @@ def quantize( valid_metrics_list = valid_metrics_list or get_param(training_params, "valid_metrics_list") train_dataloader = train_dataloader or self.train_loader val_dataloader = val_dataloader or self.valid_loader - model = model or self.net + model = model or get_param(self.ema_model, "ema") or self.net res = self.calibrate_model( calib_dataloader=calib_dataloader, @@ -289,7 +290,7 @@ def calibrate_model(self, calib_dataloader, model, quantization_params, val_data """ selective_quantizer_params = get_param(quantization_params, "selective_quantizer_params") calib_params = get_param(quantization_params, "calib_params") - model = model or self.net + model = model or get_param(self.ema_model, "ema") or self.net model.to(device_config.device) # QUANTIZE MODEL model.eval() From b7cddea9a3511aa8bb840f6a500b6c736a722f01 Mon Sep 17 00:00:00 2001 From: shayaharon Date: Tue, 16 May 2023 13:51:02 +0300 Subject: [PATCH 04/10] midifying method complete --- .../pre_launch_callbacks.py | 180 +++++++--- .../training/qat_trainer/qat_trainer.py | 319 +---------------- .../training/sg_trainer/sg_trainer.py | 324 +++++++++++++++++- .../coded_qat_launch_test.py | 2 +- 4 files changed, 449 insertions(+), 376 deletions(-) diff --git a/src/super_gradients/training/pre_launch_callbacks/pre_launch_callbacks.py b/src/super_gradients/training/pre_launch_callbacks/pre_launch_callbacks.py index 279eca858b..4f3e8263cd 100644 --- a/src/super_gradients/training/pre_launch_callbacks/pre_launch_callbacks.py +++ b/src/super_gradients/training/pre_launch_callbacks/pre_launch_callbacks.py @@ -5,6 +5,7 @@ from omegaconf import DictConfig import torch +from super_gradients.common.environment.cfg_utils import load_recipe from super_gradients.common.registry.registry import register_pre_launch_callback from super_gradients import is_distributed from super_gradients.common.abstractions.abstract_logger import get_logger @@ -13,6 +14,8 @@ import cv2 import numpy as np +from super_gradients.training.utils import get_param + logger = get_logger(__name__) @@ -70,7 +73,7 @@ class AutoTrainBatchSizeSelectionCallback(PreLaunchCallback): :param max_batch_size: int, optional, upper limit of the batch sizes to try. When None, the search will continue until the maximal batch size that does not raise CUDA OUT OF MEMORY is found (deafult=None). - :param scale_lr: bool, whether to linearly scale cfg.training_hyperparams.initial_lr, i.e multiply by + :param scale_lr: bool, whether to linearly scale cfg.training_hyperparamsinitial_lr, i.e multiply by FOUND_BATCH_SIZE/cfg.dataset_params.train_datalaoder_params.batch_size (default=True) :param mode: str, one of ["fastest","largest"], whether to select the largest batch size that fits memory or the one that the resulted in overall fastest execution. 
@@ -103,14 +106,14 @@ def __call__(self, cfg: DictConfig) -> DictConfig: load_backbone=cfg.checkpoint_params.load_backbone, ) tmp_cfg = deepcopy(cfg) - tmp_cfg.training_hyperparams.batch_accumulate = 1 - tmp_cfg.training_hyperparams.max_train_batches = self.num_forward_passes - tmp_cfg.training_hyperparams.run_validation_freq = 2 - tmp_cfg.training_hyperparams.silent_mode = True - tmp_cfg.training_hyperparams.save_model = False - tmp_cfg.training_hyperparams.max_epochs = 1 - tmp_cfg.training_hyperparams.average_best_models = False - tmp_cfg.training_hyperparams.kill_ddp_pgroup_on_end = False + tmp_cfg.training_hyperparamsbatch_accumulate = 1 + tmp_cfg.training_hyperparamsmax_train_batches = self.num_forward_passes + tmp_cfg.training_hyperparamsrun_validation_freq = 2 + tmp_cfg.training_hyperparamssilent_mode = True + tmp_cfg.training_hyperparamssave_model = False + tmp_cfg.training_hyperparamsmax_epochs = 1 + tmp_cfg.training_hyperparamsaverage_best_models = False + tmp_cfg.training_hyperparamskill_ddp_pgroup_on_end = False tmp_cfg.pre_launch_callbacks_list = [] fastest_batch_time = np.inf @@ -166,7 +169,7 @@ def _inject_selected_batch_size_to_config(self, cfg, model, msg, selected_batch_ def _adapt_lr_if_needed(self, cfg: DictConfig, found_batch_size: int) -> DictConfig: if self.scale_lr: scale_factor = found_batch_size / cfg.dataset_params.train_dataloader_params.batch_size - cfg.training_hyperparams.initial_lr = cfg.training_hyperparams.initial_lr * scale_factor + cfg.training_hyperparamsinitial_lr = cfg.training_hyperparamsinitial_lr * scale_factor return cfg @classmethod @@ -180,6 +183,113 @@ def _clear_model_gpu_mem(cls, model): barrier() +def modify_params_for_qat( + training_hyperparams, + train_dataset_params, + val_dataset_params, + train_dataloader_params, + val_dataloader_params, + quantization_params=None, + batch_size_divisor: int = 2, + max_epochs_divisor: int = 10, + lr_decay_factor: float = 0.01, + warmup_epochs_divisor: int = 10, + cosine_final_lr_ratio: float = 0.01, + disable_phase_callbacks: bool = True, + disable_augmentations: bool = False, +): + """ + + This method modifies the recipe for QAT to implement rules of thumb based on the regular non-qat recipe. + It does so by manipulating the training_hyperparams, train_dataloader_params, val_dataloader_params, train_dataset_params, val_dataset_params. + Usage: + train_dataloader_params = {'batch_size':32 + + :param val_dataset_params: Dict, validation dataset_params to be passed to dataloaders.get(...) when instantiating the train dataloader. + :param train_dataset_params: Dict, train dataset_params to be passed to dataloaders.get(...) when instantiating the validation dataloader. + :param val_dataloader_params: Dict, validation dataloader_params to be passed to dataloaders.get(...) when instantiating the validation dataloader. + :param train_dataloader_params: Dict, train dataloader_params to be passed to dataloaders.get(...) when instantiating the train dataloader. + :param training_hyperparams: Dict, train parameters passed to Trainer.qat(...) + :param quantization_params: Dict, quantization parameters as passed to Trainer.qat(...). When None, will use the + default parameters in super_gradients/recipes/quantization_params/default_quantization_params.yaml + :param int batch_size_divisor: Divisor used to calculate the batch size. Default value is 2. + :param int max_epochs_divisor: Divisor used to calculate the maximum number of epochs. Default value is 10. 
+ :param float lr_decay_factor: Factor used to decay the learning rate, weight decay and warmup. Default value is 0.01. + :param int warmup_epochs_divisor: Divisor used to calculate the number of warm-up epochs. Default value is 10. + :param float cosine_final_lr_ratio: Ratio used to determine the final learning rate in a cosine annealing schedule. Default value is 0.01. + :param bool disable_phase_callbacks: Flag to control to disable phase callbacks, which can interfere with QAT. Default value is True. + :param bool disable_augmentations: Flag to control to disable phase augmentations, which can interfere with QAT. Default value is False. + :return: modified (copy) quantization_params, training_hyperparams, train_dataloader_params, val_dataloader_params, train_dataset_params, val_dataset_params + """ + if quantization_params is None: + quantization_params = load_recipe("quantization_params/default_quantization_params").quantization_params + + quantization_params = deepcopy(quantization_params) + training_hyperparams = deepcopy(training_hyperparams) + train_dataloader_params = deepcopy(train_dataloader_params) + val_dataloader_params = deepcopy(val_dataloader_params) + train_dataset_params = deepcopy(train_dataset_params) + val_dataset_params = deepcopy(val_dataset_params) + + if "max_epochs" not in training_hyperparams.keys(): + raise ValueError("max_epochs is a required field in training_hyperparams for QAT modification.") + + if "initial_lr" not in training_hyperparams.keys(): + raise ValueError("initial_lr is a required field in training_hyperparams for QAT modification.") + + if "optimizer_params" not in training_hyperparams.keys(): + raise ValueError("optimizer_params is a required field in training_hyperparams for QAT modification.") + + if "weight_decay" not in training_hyperparams["optimizer_params"].keys(): + raise ValueError("weight_decay is a required field in training_hyperparams['optimizer_params'] for QAT modification.") + + # Q/DQ Layers take a lot of space for activations in training mode + if get_param(quantization_params, "selective_quantizer_params") and get_param(quantization_params["selective_quantizer_params"], "learn_amax"): + train_dataloader_params["batch_size"] //= batch_size_divisor + val_dataloader_params["batch_size"] //= batch_size_divisor + + logger.warning(f"New dataset_params.train_dataloader_params.batch_size: {train_dataloader_params['batch_size']}") + logger.warning(f"New dataset_params.val_dataloader_params.batch_size: {val_dataloader_params['batch_size']}") + training_hyperparams["max_epochs"] //= max_epochs_divisor + logger.warning(f"New number of epochs: {training_hyperparams['max_epochs']}") + training_hyperparams["initial_lr"] *= lr_decay_factor + if get_param(training_hyperparams, "warmup_initial_lr") is not None: + training_hyperparams["warmup_initial_lr"] *= lr_decay_factor + else: + training_hyperparams["warmup_initial_lr"] = training_hyperparams["initial_lr"] * 0.01 + training_hyperparams["optimizer_params"]["weight_decay"] *= lr_decay_factor + logger.warning(f"New learning rate: {training_hyperparams['initial_lr']}") + logger.warning(f"New weight decay: {training_hyperparams['optimizer_params']['weight_decay']}") + # as recommended by pytorch-quantization docs + if get_param(training_hyperparams, "lr_mode") != "cosine": + training_hyperparams["lr_mode"] = "cosine" + training_hyperparams["cosine_final_lr_ratio"] = cosine_final_lr_ratio + logger.warning( + f"lr_mode will be set to cosine for QAT run instead of {get_param(training_hyperparams, 
'lr_mode')} with " + f"cosine_final_lr_ratio={cosine_final_lr_ratio}" + ) + + training_hyperparams["lr_warmup_epochs"] = (training_hyperparams["max_epochs"] // warmup_epochs_divisor) or 1 + logger.warning(f"New lr_warmup_epochs: {training_hyperparams['lr_warmup_epochs']}") + + # do mess with Q/DQ + if get_param(training_hyperparams, "ema"): + logger.warning("EMA will be disabled for QAT run.") + training_hyperparams["ema"] = False + if get_param(training_hyperparams, "sync_bn"): + logger.warning("SyncBatchNorm will be disabled for QAT run.") + training_hyperparams["sync_bn"] = False + if disable_phase_callbacks and get_param(training_hyperparams, "phase_callbacks") is not None and len(training_hyperparams["phase_callbacks"]) > 0: + logger.warning(f"Recipe contains {len(training_hyperparams['phase_callbacks'])} phase callbacks. All of them will be disabled.") + training_hyperparams["phase_callbacks"] = [] + # no augmentations + if disable_augmentations and "transforms" in val_dataset_params: + logger.warning("Augmentations will be disabled for QAT run. Using validation transforms instead.") + train_dataset_params["transforms"] = val_dataset_params["transforms"] + + return quantization_params, training_hyperparams, train_dataloader_params, val_dataloader_params, train_dataset_params, val_dataset_params + + @register_pre_launch_callback() class QATRecipeModificationCallback(PreLaunchCallback): """ @@ -209,7 +319,7 @@ class QATRecipeModificationCallback(PreLaunchCallback): disable_phase_callbacks: True disable_augmentations: False - USE THIS CALLBACK ONLY WITH QATTrainer! + USE THIS CALLBACK ONLY WITH Trainer.quantize_from_config """ def __init__( @@ -234,55 +344,15 @@ def __call__(self, cfg: Union[dict, DictConfig]) -> Union[dict, DictConfig]: logger.info("Modifying recipe to suit QAT rules of thumb. 
Remove QATRecipeModificationCallback to disable.") cfg = copy.deepcopy(cfg) + quantization_params = cfg.quantization_params + dataset_params = cfg.dataset_params + training_hyperparams = cfg.training_hyperparams - # Q/DQ Layers take a lot of space for activations in training mode - if cfg.quantization_params.selective_quantizer_params.learn_amax: - cfg.dataset_params.train_dataloader_params.batch_size //= self.batch_size_divisor - cfg.dataset_params.val_dataloader_params.batch_size //= self.batch_size_divisor - - logger.warning(f"New dataset_params.train_dataloader_params.batch_size: {cfg.dataset_params.train_dataloader_params.batch_size}") - logger.warning(f"New dataset_params.val_dataloader_params.batch_size: {cfg.dataset_params.val_dataloader_params.batch_size}") - - cfg.training_hyperparams.max_epochs //= self.max_epochs_divisor - logger.warning(f"New number of epochs: {cfg.training_hyperparams.max_epochs}") - - cfg.training_hyperparams.initial_lr *= self.lr_decay_factor - if cfg.training_hyperparams.warmup_initial_lr is not None: - cfg.training_hyperparams.warmup_initial_lr *= self.lr_decay_factor - else: - cfg.training_hyperparams.warmup_initial_lr = cfg.training_hyperparams.initial_lr * 0.01 - - cfg.training_hyperparams.optimizer_params.weight_decay *= self.lr_decay_factor - - logger.warning(f"New learning rate: {cfg.training_hyperparams.initial_lr}") - logger.warning(f"New weight decay: {cfg.training_hyperparams.optimizer_params.weight_decay}") - - # as recommended by pytorch-quantization docs - cfg.training_hyperparams.lr_mode = "cosine" - cfg.training_hyperparams.lr_warmup_epochs = (cfg.training_hyperparams.max_epochs // self.warmup_epochs_divisor) or 1 - cfg.training_hyperparams.cosine_final_lr_ratio = self.cosine_final_lr_ratio - - # do mess with Q/DQ - if cfg.training_hyperparams.ema: - logger.warning("EMA will be disabled for QAT run.") - cfg.training_hyperparams.ema = False - - if cfg.training_hyperparams.sync_bn: - logger.warning("SyncBatchNorm will be disabled for QAT run.") - cfg.training_hyperparams.sync_bn = False - - if self.disable_phase_callbacks and len(cfg.training_hyperparams.phase_callbacks) > 0: - logger.warning(f"Recipe contains {len(cfg.training_hyperparams.phase_callbacks)} phase callbacks. All of them will be disabled.") - cfg.training_hyperparams.phase_callbacks = [] + cfg.dataset_params, cfg.quantization_params, cfg.training_hyperparams = modify_params_for_qat(dataset_params, quantization_params, training_hyperparams) if cfg.multi_gpu != "OFF" or cfg.num_gpus != 1: logger.warning(f"Recipe requests multi_gpu={cfg.multi_gpu} and num_gpus={cfg.num_gpus}. 
Changing to multi_gpu=OFF and num_gpus=1") cfg.multi_gpu = "OFF" cfg.num_gpus = 1 - # no augmentations - if self.disable_augmentations and "transforms" in cfg.dataset_params.val_dataset_params: - logger.warning("Augmentations will be disabled for QAT run.") - cfg.dataset_params.train_dataset_params.transforms = cfg.dataset_params.val_dataset_params.transforms - return cfg diff --git a/src/super_gradients/training/qat_trainer/qat_trainer.py b/src/super_gradients/training/qat_trainer/qat_trainer.py index 1d2a8f4a71..f9b7858237 100644 --- a/src/super_gradients/training/qat_trainer/qat_trainer.py +++ b/src/super_gradients/training/qat_trainer/qat_trainer.py @@ -1,328 +1,17 @@ -import os -from typing import Union, Tuple, Dict, Mapping, List -from torchmetrics import Metric +from typing import Union, Tuple -import copy -import hydra -import torch.cuda +from deprecated import deprecated from omegaconf import DictConfig -from omegaconf import OmegaConf from torch import nn -from torch.utils.data import DataLoader from super_gradients.common.abstractions.abstract_logger import get_logger -from super_gradients.common.environment.cfg_utils import load_recipe -from super_gradients.common.environment.device_utils import device_config -from super_gradients.training import utils as core_utils, models, dataloaders from super_gradients.training.sg_trainer import Trainer -from super_gradients.training.utils import get_param -from super_gradients.training.utils.distributed_training_utils import setup_device -from super_gradients.modules.repvgg_block import fuse_repvgg_blocks_residual_branches logger = get_logger(__name__) -try: - from super_gradients.training.utils.quantization.calibrator import QuantizationCalibrator - from super_gradients.training.utils.quantization.export import export_quantized_module_to_onnx - from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer - - _imported_pytorch_quantization_failure = None - -except (ImportError, NameError, ModuleNotFoundError) as import_err: - logger.debug("Failed to import pytorch_quantization:") - logger.debug(import_err) - _imported_pytorch_quantization_failure = import_err class QATTrainer(Trainer): @classmethod + @deprecated(version="3.2.0", reason="QATTrainer is deprecated and will be removed in future release, use Trainer " "class instead.") def quantize_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tuple]: - """ - Perform quantization aware training (QAT) according to a recipe configuration. - - This method will instantiate all the objects specified in the recipe, build and quantize the model, - and calibrate the quantized model. The resulting quantized model and the output of the trainer.train() - method will be returned. - - The quantized model will be exported to ONNX along with other checkpoints. - - The call to self.quantize (see docs in the next method) is done with the created - train_loader and valid_loader. If no calibration data loader is passed through cfg.calib_loader, - a train data laoder with the validation transforms is used for calibration. - - :param cfg: The parsed DictConfig object from yaml recipe files or a dictionary. - :return: A tuple containing the quantized model and the output of trainer.train() method. - :rtype: Tuple[nn.Module, Tuple] - - :raises ValueError: If the recipe does not have the required key `quantization_params` or - `checkpoint_params.checkpoint_path` in it. 
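The modify_params_for_qat helper introduced above can also be called directly, outside the callback; a minimal sketch, assuming plain-dict configuration and the module path shown in this diff (only the keys the helper checks for are filled in, with illustrative values):

    from super_gradients.training.pre_launch_callbacks.pre_launch_callbacks import modify_params_for_qat

    training_hyperparams = {"max_epochs": 300, "initial_lr": 0.01, "optimizer_params": {"weight_decay": 1e-4}}
    train_dataset_params = {"transforms": []}
    val_dataset_params = {"transforms": []}
    train_dataloader_params = {"batch_size": 64}
    val_dataloader_params = {"batch_size": 64}

    # quantization_params is omitted, so the default_quantization_params recipe is loaded.
    (
        quantization_params,
        training_hyperparams,
        train_dataloader_params,
        val_dataloader_params,
        train_dataset_params,
        val_dataset_params,
    ) = modify_params_for_qat(
        training_hyperparams,
        train_dataset_params,
        val_dataset_params,
        train_dataloader_params,
        val_dataloader_params,
    )

    # e.g. max_epochs is now 30, initial_lr is 0.0001 and lr_mode is "cosine".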
- :raises NotImplementedError: If the recipe requests multiple GPUs or num_gpus is not equal to 1. - :raises ImportError: If pytorch-quantization import was unsuccessful - - """ - if _imported_pytorch_quantization_failure is not None: - raise _imported_pytorch_quantization_failure - - # INSTANTIATE ALL OBJECTS IN CFG - cfg = hydra.utils.instantiate(cfg) - - # TRIGGER CFG MODIFYING CALLBACKS - cfg = cls._trigger_cfg_modifying_callbacks(cfg) - - quantization_params = get_param(cfg, "quantization_params") - - if quantization_params is None: - raise logger.warning("Your recipe does not include quantization_params. Using default quantization params.") - - if get_param(cfg.checkpoint_params, "checkpoint_path") is None and get_param(cfg.checkpoint_params, "pretrained_weights") is None: - raise ValueError("Starting checkpoint / pretrained weights are a must for QAT finetuning.") - - num_gpus = core_utils.get_param(cfg, "num_gpus") - multi_gpu = core_utils.get_param(cfg, "multi_gpu") - device = core_utils.get_param(cfg, "device") - if num_gpus != 1: - raise NotImplementedError( - f"Recipe requests multi_gpu={cfg.multi_gpu} and num_gpus={cfg.num_gpus}. QAT is proven to work correctly only with multi_gpu=OFF and num_gpus=1" - ) - - setup_device(device=device, multi_gpu=multi_gpu, num_gpus=num_gpus) - - # INSTANTIATE DATA LOADERS - train_dataloader = dataloaders.get( - name=get_param(cfg, "train_dataloader"), - dataset_params=copy.deepcopy(cfg.dataset_params.train_dataset_params), - dataloader_params=copy.deepcopy(cfg.dataset_params.train_dataloader_params), - ) - - val_dataloader = dataloaders.get( - name=get_param(cfg, "val_dataloader"), - dataset_params=copy.deepcopy(cfg.dataset_params.val_dataset_params), - dataloader_params=copy.deepcopy(cfg.dataset_params.val_dataloader_params), - ) - - if "calib_dataloader" in cfg: - calib_dataloader_name = get_param(cfg, "calib_dataloader") - calib_dataloader_params = copy.deepcopy(cfg.dataset_params.calib_dataloader_params) - calib_dataset_params = copy.deepcopy(cfg.dataset_params.calib_dataset_params) - else: - calib_dataloader_name = get_param(cfg, "train_dataloader") - calib_dataloader_params = copy.deepcopy(cfg.dataset_params.train_dataloader_params) - calib_dataset_params = copy.deepcopy(cfg.dataset_params.train_dataset_params) - - # if we use whole dataloader for calibration, don't shuffle it - # HistogramCalibrator collection routine is sensitive to order of batches and produces slightly different results - # if we use several batches, we don't want it to be from one class if it's sequential in dataloader - # model is in eval mode, so BNs will not be affected - calib_dataloader_params.shuffle = cfg.quantization_params.calib_params.num_calib_batches is not None - # we don't need training transforms during calibration, distribution of activations will be skewed - calib_dataset_params.transforms = cfg.dataset_params.val_dataset_params.transforms - - calib_dataloader = dataloaders.get( - name=calib_dataloader_name, - dataset_params=calib_dataset_params, - dataloader_params=calib_dataloader_params, - ) - - # BUILD MODEL - model = models.get( - model_name=cfg.arch_params.get("model_name", None) or cfg.architecture, - num_classes=cfg.get("num_classes", None) or cfg.arch_params.num_classes, - arch_params=cfg.arch_params, - strict_load=cfg.checkpoint_params.strict_load, - pretrained_weights=cfg.checkpoint_params.pretrained_weights, - checkpoint_path=cfg.checkpoint_params.checkpoint_path, - load_backbone=False, - ) - - recipe_logged_cfg = {"recipe_config": 
OmegaConf.to_container(cfg, resolve=True)} - trainer = QATTrainer(experiment_name=cfg.experiment_name, ckpt_root_dir=get_param(cfg, "ckpt_root_dir")) - - res = trainer.quantize( - model=model, - quantization_params=quantization_params, - calib_dataloader=calib_dataloader, - val_dataloader=val_dataloader, - train_dataloader=train_dataloader, - training_params=cfg.training_hyperparams, - additional_qat_configs_to_log=recipe_logged_cfg, - ) - - return model, res - - def quantize( - self, - calib_dataloader: DataLoader, - model: torch.nn.Module = None, - val_dataloader: DataLoader = None, - train_dataloader: DataLoader = None, - quantization_params: Mapping = None, - training_params: Mapping = None, - additional_qat_configs_to_log: Dict = None, - valid_metrics_list: List[Metric] = None, - ): - """ - Performs post-training quantization (PTQ), and optionally quantization-aware training (QAT). - Exports the ONNX model to the checkpoints directory. - - :param calib_dataloader: DataLoader, data loader for calibration. - - :param model: torch.nn.Module, Model to perform QAT/PTQ on. When None, will try to use the network from - previous self.train call(that is, if there was one - will try to use self.ema_model.ema if EMA was used, - otherwise self.net)................................ - - - :param val_dataloader: DataLoader, data loader for validation. Used both for validating the calibrated model after PTQ and during QAT. - When None, will try to use self.valid_loader if it was set in previous self.train(..) call (default=None). - - :param train_dataloader: DataLoader, data loader for QA training, can be ignored when quantization_params["ptq_only"]=True (default=None). - - :param quantization_params: Mapping, with the following entries:defaults- - - ptq_only: False # whether to launch QAT, or leave PTQ only - selective_quantizer_params: - calibrator_w: "max" # calibrator type for weights, acceptable types are ["max", "histogram"] - calibrator_i: "histogram" # calibrator type for inputs acceptable types are ["max", "histogram"] - per_channel: True # per-channel quantization of weights, activations stay per-tensor by default - learn_amax: False # enable learnable amax in all TensorQuantizers using straight-through estimator - skip_modules: # optional list of module names (strings) to skip from quantization - - calib_params: - histogram_calib_method: "percentile" # calibration method for all "histogram" calibrators, acceptable types are ["percentile", "entropy", mse"], - "max" calibrators always use "max" - percentile: 99.99 # percentile for all histogram calibrators with method "percentile", other calibrators are not affected - - num_calib_batches: # number of batches to use for calibration, if None, 512 / batch_size will be used - verbose: False # if calibrator should be verbose - - - :param training_params: Mapping, training hyper parameters for QAT, same as in super.train(...). When None, will try to use self.training_params - which is set in previous self.train(..) call (default=None). - - :param additional_qat_configs_to_log: Dict, Optional dictionary containing configs that will be added to the QA training's - sg_logger. Format should be {"Config_title_1": {...}, "Config_title_2":{..}}. - - :param valid_metrics_list: (list(torchmetrics.Metric)) metrics list for evaluation of the calibrated model. - When None, the validation metrics from training_params are used). (default=None). 
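For reference, the selective_quantizer_params and calib_params entries described above map onto the pytorch-quantization helpers used elsewhere in this patch roughly as follows (a condensed sketch of the same calls with the default values written out; model and calib_dataloader are placeholders):

    from super_gradients.modules.repvgg_block import fuse_repvgg_blocks_residual_branches
    from super_gradients.training.utils.quantization.calibrator import QuantizationCalibrator
    from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer


    def ptq_sketch(model, calib_dataloader):
        """Condensed PTQ flow used in this patch: fuse, selectively quantize, then calibrate."""
        model.eval()
        fuse_repvgg_blocks_residual_branches(model)

        q_util = SelectiveQuantizer(
            default_quant_modules_calibrator_weights="max",       # selective_quantizer_params.calibrator_w
            default_quant_modules_calibrator_inputs="histogram",  # selective_quantizer_params.calibrator_i
            default_per_channel_quant_weights=True,               # selective_quantizer_params.per_channel
            default_learn_amax=False,                             # selective_quantizer_params.learn_amax
            verbose=False,                                        # calib_params.verbose
        )
        q_util.quantize_module(model)

        calibrator = QuantizationCalibrator(verbose=False, torch_hist=True)
        calibrator.calibrate_model(
            model,
            method="percentile",                      # calib_params.histogram_calib_method
            calib_data_loader=calib_dataloader,
            num_calib_batches=len(calib_dataloader),  # fallback when calib_params.num_calib_batches is None
            percentile=99.99,                         # calib_params.percentile
        )
        calibrator.reset_calibrators(model)  # release memory taken by calibrators
        return model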
- - :return: Validation results of the QAT model in case quantization_params['ptq_only']=False and of the PTQ - model otherwise. - """ - - if quantization_params is None: - quantization_params = load_recipe("quantization_params/default_quantization_params").quantization_params - logger.info(f"Using default quantization params: {quantization_params}") - training_params = training_params or self.training_params.to_dict() - valid_metrics_list = valid_metrics_list or get_param(training_params, "valid_metrics_list") - train_dataloader = train_dataloader or self.train_loader - val_dataloader = val_dataloader or self.valid_loader - model = model or get_param(self.ema_model, "ema") or self.net - - res = self.calibrate_model( - calib_dataloader=calib_dataloader, - model=model, - quantization_params=quantization_params, - val_dataloader=val_dataloader, - valid_metrics_list=valid_metrics_list, - ) - # TRAIN - if get_param(quantization_params, "ptq_only", False): - logger.info("quantization_params.ptq_only=True. Performing PTQ only!") - suffix = "ptq" - else: - model.train() - torch.cuda.empty_cache() - - res = self.train( - model=model, - train_loader=train_dataloader, - valid_loader=val_dataloader, - training_params=training_params, - additional_configs_to_log=additional_qat_configs_to_log, - ) - suffix = "qat" - # EXPORT QUANTIZED MODEL TO ONNX - input_shape = next(iter(val_dataloader))[0].shape - os.makedirs(self.checkpoints_dir_path, exist_ok=True) - qdq_onnx_path = os.path.join(self.checkpoints_dir_path, f"{self.experiment_name}_{'x'.join((str(x) for x in input_shape))}_{suffix}.onnx") - - # TODO: modify SG's convert_to_onnx for quantized models and use it instead - export_quantized_module_to_onnx( - model=model.cpu(), - onnx_filename=qdq_onnx_path, - input_shape=input_shape, - input_size=input_shape, - train=False, - ) - logger.info(f"Exported {suffix.upper()} ONNX to {qdq_onnx_path}") - return res - - def calibrate_model(self, calib_dataloader, model, quantization_params, val_dataloader, valid_metrics_list): - """ - Performs calibration. - - :param calib_dataloader: DataLoader, data loader for calibration. - - :param model: torch.nn.Module, Model to perform calibration on. When None, will try to use self.net which is - set in previous self.train(..) call (default=None). - - :param val_dataloader: DataLoader, data loader for validation. Used both for validating the calibrated model. - When None, will try to use self.valid_loader if it was set in previous self.train(..) call (default=None). 
- - :param quantization_params: Mapping, with the following entries:defaults- - selective_quantizer_params: - calibrator_w: "max" # calibrator type for weights, acceptable types are ["max", "histogram"] - calibrator_i: "histogram" # calibrator type for inputs acceptable types are ["max", "histogram"] - per_channel: True # per-channel quantization of weights, activations stay per-tensor by default - learn_amax: False # enable learnable amax in all TensorQuantizers using straight-through estimator - skip_modules: # optional list of module names (strings) to skip from quantization - - calib_params: histogram_calib_method: "percentile" # calibration method for all "histogram" calibrators, - acceptable types are ["percentile", "entropy", mse"], "max" calibrators always use "max" percentile: - 99.99 # percentile for all histogram calibrators with method "percentile", - other calibrators are not affected num_calib_batches: # number of batches to use for - calibration, if None, 512 / batch_size will be used verbose: False # if calibrator - should be verbose - - - - :param valid_metrics_list: (list(torchmetrics.Metric)) metrics list for evaluation of the calibrated model. - - :return: Validation results of the calibrated model. - """ - selective_quantizer_params = get_param(quantization_params, "selective_quantizer_params") - calib_params = get_param(quantization_params, "calib_params") - model = model or get_param(self.ema_model, "ema") or self.net - model.to(device_config.device) - # QUANTIZE MODEL - model.eval() - fuse_repvgg_blocks_residual_branches(model) - q_util = SelectiveQuantizer( - default_quant_modules_calibrator_weights=get_param(selective_quantizer_params, "calibrator_w"), - default_quant_modules_calibrator_inputs=get_param(selective_quantizer_params, "calibrator_i"), - default_per_channel_quant_weights=get_param(selective_quantizer_params, "per_channel"), - default_learn_amax=get_param(selective_quantizer_params, "learn_amax"), - verbose=get_param(calib_params, "verbose"), - ) - q_util.register_skip_quantization(layer_names=get_param(selective_quantizer_params, "skip_modules")) - q_util.quantize_module(model) - # CALIBRATE MODEL - logger.info("Calibrating model...") - calibrator = QuantizationCalibrator( - verbose=get_param(calib_params, "verbose"), - torch_hist=True, - ) - calibrator.calibrate_model( - model, - method=get_param(calib_params, "histogram_calib_method"), - calib_data_loader=calib_dataloader, - num_calib_batches=get_param(calib_params, "num_calib_batches") or len(calib_dataloader), - percentile=get_param(calib_params, "percentile", 99.99), - ) - calibrator.reset_calibrators(model) # release memory taken by calibrators - # VALIDATE PTQ MODEL AND PRINT SUMMARY - logger.info("Validating PTQ model...") - valid_metrics_dict = self.test(model=model, test_loader=val_dataloader, test_metrics_list=valid_metrics_list) - results = ["PTQ Model Validation Results"] - results += [f" - {metric:10}: {value}" for metric, value in valid_metrics_dict.items()] - logger.info("\n".join(results)) - - return valid_metrics_dict + return Trainer.quantize_from_config(cfg) diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py index 117beb2bdc..9bc4acf6e5 100755 --- a/src/super_gradients/training/sg_trainer/sg_trainer.py +++ b/src/super_gradients/training/sg_trainer/sg_trainer.py @@ -1,24 +1,27 @@ +import copy import inspect import os from copy import deepcopy from pathlib import Path -from typing import Union, Tuple, Mapping, 
Dict, Any +from typing import Union, Tuple, Mapping, Dict, Any, List import hydra import numpy as np import torch -from omegaconf import DictConfig -from omegaconf import OmegaConf +import torch.cuda +import torch.nn +from omegaconf import DictConfig, OmegaConf from piptools.scripts.sync import _get_installed_distributions from torch import nn from torch.cuda.amp import GradScaler, autocast from torch.utils.data import DataLoader, SequentialSampler from torch.utils.data.distributed import DistributedSampler -from torchmetrics import MetricCollection +from torchmetrics import MetricCollection, Metric from tqdm import tqdm from super_gradients.common.environment.checkpoints_dir_utils import get_checkpoints_dir_path, get_ckpt_local_path from super_gradients.module_interfaces import HasPreprocessingParams, HasPredict +from super_gradients.modules.repvgg_block import fuse_repvgg_blocks_residual_branches from super_gradients.training.utils.sg_trainer_utils import get_callable_param_names from super_gradients.common.abstractions.abstract_logger import get_logger @@ -83,10 +86,21 @@ from super_gradients.common.registry.registry import LR_SCHEDULERS_CLS_DICT, LR_WARMUP_CLS_DICT from super_gradients.common.environment.device_utils import device_config from super_gradients.training.utils import HpmStruct -from super_gradients.common.environment.cfg_utils import load_experiment_cfg, add_params_to_cfg +from super_gradients.common.environment.cfg_utils import load_experiment_cfg, add_params_to_cfg, load_recipe from super_gradients.common.factories.pre_launch_callbacks_factory import PreLaunchCallbacksFactory from super_gradients.training.params import TrainingParams +try: + from super_gradients.training.utils.quantization.calibrator import QuantizationCalibrator + from super_gradients.training.utils.quantization.export import export_quantized_module_to_onnx + from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer + + _imported_pytorch_quantization_failure = None + +except (ImportError, NameError, ModuleNotFoundError) as import_err: + logger.debug("Failed to import pytorch_quantization:") + logger.debug(import_err) + _imported_pytorch_quantization_failure = import_err logger = get_logger(__name__) @@ -2006,3 +2020,303 @@ def _init_loss_logging_names(self, loss_logging_items): self.metric_to_watch = criterion_name + "/" + self.metric_to_watch else: self.loss_logging_items_names = [criterion_name] + + @classmethod + def quantize_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tuple]: + """ + Perform quantization aware training (QAT) according to a recipe configuration. + + This method will instantiate all the objects specified in the recipe, build and quantize the model, + and calibrate the quantized model. The resulting quantized model and the output of the trainer.train() + method will be returned. + + The quantized model will be exported to ONNX along with other checkpoints. + + The call to self.quantize (see docs in the next method) is done with the created + train_loader and valid_loader. If no calibration data loader is passed through cfg.calib_loader, + a train data laoder with the validation transforms is used for calibration. + + :param cfg: The parsed DictConfig object from yaml recipe files or a dictionary. + :return: A tuple containing the quantized model and the output of trainer.train() method. 
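A hedged launch sketch for this config-driven entry point: the recipe name comes from this repository, the checkpoint path is a placeholder you must supply, any other required recipe fields (for example the Roboflow dataset name) would be set the same way, and editing the composed config in place before it is instantiated is an assumption of this sketch:

    from super_gradients.common.environment.cfg_utils import load_recipe
    from super_gradients.training.sg_trainer import Trainer

    # Compose the QAT recipe shipped in this repo and fill in the mandatory starting checkpoint.
    cfg = load_recipe("roboflow_yolo_nas_s_qat")
    cfg.checkpoint_params.checkpoint_path = "/path/to/ckpt_best.pth"  # placeholder

    quantized_model, train_result = Trainer.quantize_from_config(cfg)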
+ :rtype: Tuple[nn.Module, Tuple] + + :raises ValueError: If the recipe does not have the required key `quantization_params` or + `checkpoint_params.checkpoint_path` in it. + :raises NotImplementedError: If the recipe requests multiple GPUs or num_gpus is not equal to 1. + :raises ImportError: If pytorch-quantization import was unsuccessful + + """ + if _imported_pytorch_quantization_failure is not None: + raise _imported_pytorch_quantization_failure + + # INSTANTIATE ALL OBJECTS IN CFG + cfg = hydra.utils.instantiate(cfg) + + # TRIGGER CFG MODIFYING CALLBACKS + cfg = cls._trigger_cfg_modifying_callbacks(cfg) + + quantization_params = get_param(cfg, "quantization_params") + + if quantization_params is None: + raise logger.warning("Your recipe does not include quantization_params. Using default quantization params.") + + if get_param(cfg.checkpoint_params, "checkpoint_path") is None and get_param(cfg.checkpoint_params, "pretrained_weights") is None: + raise ValueError("Starting checkpoint / pretrained weights are a must for QAT finetuning.") + + num_gpus = core_utils.get_param(cfg, "num_gpus") + multi_gpu = core_utils.get_param(cfg, "multi_gpu") + device = core_utils.get_param(cfg, "device") + if num_gpus != 1: + raise NotImplementedError( + f"Recipe requests multi_gpu={cfg.multi_gpu} and num_gpus={cfg.num_gpus}. QAT is proven to work correctly only with multi_gpu=OFF and num_gpus=1" + ) + + setup_device(device=device, multi_gpu=multi_gpu, num_gpus=num_gpus) + + # INSTANTIATE DATA LOADERS + train_dataloader = dataloaders.get( + name=get_param(cfg, "train_dataloader"), + dataset_params=copy.deepcopy(cfg.dataset_params.train_dataset_params), + dataloader_params=copy.deepcopy(cfg.dataset_params.train_dataloader_params), + ) + + val_dataloader = dataloaders.get( + name=get_param(cfg, "val_dataloader"), + dataset_params=copy.deepcopy(cfg.dataset_params.val_dataset_params), + dataloader_params=copy.deepcopy(cfg.dataset_params.val_dataloader_params), + ) + + if "calib_dataloader" in cfg: + calib_dataloader_name = get_param(cfg, "calib_dataloader") + calib_dataloader_params = copy.deepcopy(cfg.dataset_params.calib_dataloader_params) + calib_dataset_params = copy.deepcopy(cfg.dataset_params.calib_dataset_params) + else: + calib_dataloader_name = get_param(cfg, "train_dataloader") + calib_dataloader_params = copy.deepcopy(cfg.dataset_params.train_dataloader_params) + calib_dataset_params = copy.deepcopy(cfg.dataset_params.train_dataset_params) + + # if we use whole dataloader for calibration, don't shuffle it + # HistogramCalibrator collection routine is sensitive to order of batches and produces slightly different results + # if we use several batches, we don't want it to be from one class if it's sequential in dataloader + # model is in eval mode, so BNs will not be affected + calib_dataloader_params.shuffle = cfg.quantization_params.calib_params.num_calib_batches is not None + # we don't need training transforms during calibration, distribution of activations will be skewed + calib_dataset_params.transforms = cfg.dataset_params.val_dataset_params.transforms + + calib_dataloader = dataloaders.get( + name=calib_dataloader_name, + dataset_params=calib_dataset_params, + dataloader_params=calib_dataloader_params, + ) + + # BUILD MODEL + model = models.get( + model_name=cfg.arch_params.get("model_name", None) or cfg.architecture, + num_classes=cfg.get("num_classes", None) or cfg.arch_params.num_classes, + arch_params=cfg.arch_params, + strict_load=cfg.checkpoint_params.strict_load, + 
pretrained_weights=cfg.checkpoint_params.pretrained_weights, + checkpoint_path=cfg.checkpoint_params.checkpoint_path, + load_backbone=False, + ) + + recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)} + trainer = Trainer(experiment_name=cfg.experiment_name, ckpt_root_dir=get_param(cfg, "ckpt_root_dir")) + + res = trainer.qat( + model=model, + quantization_params=quantization_params, + calib_dataloader=calib_dataloader, + val_dataloader=val_dataloader, + train_dataloader=train_dataloader, + training_params=cfg.training_hyperparams, + additional_qat_configs_to_log=recipe_logged_cfg, + ) + + return model, res + + def qat( + self, + calib_dataloader: DataLoader, + model: torch.nn.Module = None, + val_dataloader: DataLoader = None, + train_dataloader: DataLoader = None, + quantization_params: Mapping = None, + training_params: Mapping = None, + additional_qat_configs_to_log: Dict = None, + valid_metrics_list: List[Metric] = None, + ): + """ + Performs post-training quantization (PTQ), and then quantization-aware training (QAT). + Exports the ONNX models (ckpt_best.pth of QAT and the calibrated model) to the checkpoints directory. + + :param calib_dataloader: DataLoader, data loader for calibration. + + :param model: torch.nn.Module, Model to perform QAT/PTQ on. When None, will try to use the network from + previous self.train call(that is, if there was one - will try to use self.ema_model.ema if EMA was used, + otherwise self.net) + + + :param val_dataloader: DataLoader, data loader for validation. Used both for validating the calibrated model after PTQ and during QAT. + When None, will try to use self.valid_loader if it was set in previous self.train(..) call (default=None). + + :param train_dataloader: DataLoader, data loader for QA training, can be ignored when quantization_params["ptq_only"]=True (default=None). + + :param quantization_params: Mapping, with the following entries:defaults- + ptq_only: False # whether to launch QAT, or leave PTQ only + selective_quantizer_params: + calibrator_w: "max" # calibrator type for weights, acceptable types are ["max", "histogram"] + calibrator_i: "histogram" # calibrator type for inputs acceptable types are ["max", "histogram"] + per_channel: True # per-channel quantization of weights, activations stay per-tensor by default + learn_amax: False # enable learnable amax in all TensorQuantizers using straight-through estimator + skip_modules: # optional list of module names (strings) to skip from quantization + + calib_params: + histogram_calib_method: "percentile" # calibration method for all "histogram" calibrators, + # acceptable types are ["percentile", "entropy", mse"], + # "max" calibrators always use "max" + + percentile: 99.99 # percentile for all histogram calibrators with method "percentile", + # other calibrators are not affected + + num_calib_batches: # number of batches to use for calibration, if None, 512 / batch_size will be used + verbose: False # if calibrator should be verbose + + + :param training_params: Mapping, training hyper parameters for QAT, same as in super.train(...). When None, will try to use self.training_params + which is set in previous self.train(..) call (default=None). + + :param additional_qat_configs_to_log: Dict, Optional dictionary containing configs that will be added to the QA training's + sg_logger. Format should be {"Config_title_1": {...}, "Config_title_2":{..}}. + + :param valid_metrics_list: (list(torchmetrics.Metric)) metrics list for evaluation of the calibrated model. 
+ When None, the validation metrics from training_params are used). (default=None). + + :return: Validation results of the QAT model in case quantization_params['ptq_only']=False and of the PTQ + model otherwise. + """ + + if quantization_params is None: + quantization_params = load_recipe("quantization_params/default_quantization_params").quantization_params + logger.info(f"Using default quantization params: {quantization_params}") + training_params = training_params or self.training_params.to_dict() + valid_metrics_list = valid_metrics_list or get_param(training_params, "valid_metrics_list") + train_dataloader = train_dataloader or self.train_loader + val_dataloader = val_dataloader or self.valid_loader + model = model or get_param(self.ema_model, "ema") or self.net + + res = self.ptq( + calib_dataloader=calib_dataloader, + model=model, + quantization_params=quantization_params, + val_dataloader=val_dataloader, + valid_metrics_list=valid_metrics_list, + ) + # TRAIN + if get_param(quantization_params, "ptq_only", False): + logger.info("quantization_params.ptq_only=True. Performing PTQ only!") + suffix = "ptq" + else: + model.train() + torch.cuda.empty_cache() + + res = self.train( + model=model, + train_loader=train_dataloader, + valid_loader=val_dataloader, + training_params=training_params, + additional_configs_to_log=additional_qat_configs_to_log, + ) + suffix = "qat" + # EXPORT QUANTIZED MODEL TO ONNX + input_shape = next(iter(val_dataloader))[0].shape + os.makedirs(self.checkpoints_dir_path, exist_ok=True) + qdq_onnx_path = os.path.join(self.checkpoints_dir_path, f"{self.experiment_name}_{'x'.join((str(x) for x in input_shape))}_{suffix}.onnx") + + # TODO: modify SG's convert_to_onnx for quantized models and use it instead + export_quantized_module_to_onnx( + model=model.cpu(), + onnx_filename=qdq_onnx_path, + input_shape=input_shape, + input_size=input_shape, + train=False, + ) + logger.info(f"Exported {suffix.upper()} ONNX to {qdq_onnx_path}") + return res + + def ptq(self, calib_dataloader, model, quantization_params, val_dataloader, valid_metrics_list): + """ + Performs calibration. + + + :param calib_dataloader: DataLoader, data loader for calibration. + + :param model: torch.nn.Module, Model to perform calibration on. When None, will try to use self.net which is + set in previous self.train(..) call (default=None). + + :param val_dataloader: DataLoader, data loader for validation. Used both for validating the calibrated model. + When None, will try to use self.valid_loader if it was set in previous self.train(..) call (default=None). 
+ + :param quantization_params: Mapping, with the following entries:defaults- + selective_quantizer_params: + calibrator_w: "max" # calibrator type for weights, acceptable types are ["max", "histogram"] + calibrator_i: "histogram" # calibrator type for inputs acceptable types are ["max", "histogram"] + per_channel: True # per-channel quantization of weights, activations stay per-tensor by default + learn_amax: False # enable learnable amax in all TensorQuantizers using straight-through estimator + skip_modules: # optional list of module names (strings) to skip from quantization + + calib_params: + histogram_calib_method: "percentile" # calibration method for all "histogram" calibrators, + # acceptable types are ["percentile", "entropy", mse"], + # "max" calibrators always use "max" + + percentile: 99.99 # percentile for all histogram calibrators with method "percentile", + # other calibrators are not affected + + num_calib_batches: # number of batches to use for calibration, if None, 512 / batch_size will be used + verbose: False # if calibrator should be verbose + + + + :param valid_metrics_list: (list(torchmetrics.Metric)) metrics list for evaluation of the calibrated model. + + :return: Validation results of the calibrated model. + """ + selective_quantizer_params = get_param(quantization_params, "selective_quantizer_params") + calib_params = get_param(quantization_params, "calib_params") + model = model or get_param(self.ema_model, "ema") or self.net + model.to(device_config.device) + # QUANTIZE MODEL + model.eval() + fuse_repvgg_blocks_residual_branches(model) + q_util = SelectiveQuantizer( + default_quant_modules_calibrator_weights=get_param(selective_quantizer_params, "calibrator_w"), + default_quant_modules_calibrator_inputs=get_param(selective_quantizer_params, "calibrator_i"), + default_per_channel_quant_weights=get_param(selective_quantizer_params, "per_channel"), + default_learn_amax=get_param(selective_quantizer_params, "learn_amax"), + verbose=get_param(calib_params, "verbose"), + ) + q_util.register_skip_quantization(layer_names=get_param(selective_quantizer_params, "skip_modules")) + q_util.quantize_module(model) + # CALIBRATE MODEL + logger.info("Calibrating model...") + calibrator = QuantizationCalibrator( + verbose=get_param(calib_params, "verbose"), + torch_hist=True, + ) + calibrator.calibrate_model( + model, + method=get_param(calib_params, "histogram_calib_method"), + calib_data_loader=calib_dataloader, + num_calib_batches=get_param(calib_params, "num_calib_batches") or len(calib_dataloader), + percentile=get_param(calib_params, "percentile", 99.99), + ) + calibrator.reset_calibrators(model) # release memory taken by calibrators + # VALIDATE PTQ MODEL AND PRINT SUMMARY + logger.info("Validating PTQ model...") + valid_metrics_dict = self.test(model=model, test_loader=val_dataloader, test_metrics_list=valid_metrics_list) + results = ["PTQ Model Validation Results"] + results += [f" - {metric:10}: {value}" for metric, value in valid_metrics_dict.items()] + logger.info("\n".join(results)) + + return valid_metrics_dict diff --git a/tests/recipe_training_tests/coded_qat_launch_test.py b/tests/recipe_training_tests/coded_qat_launch_test.py index 948e7826eb..ae7ca3925c 100644 --- a/tests/recipe_training_tests/coded_qat_launch_test.py +++ b/tests/recipe_training_tests/coded_qat_launch_test.py @@ -32,7 +32,7 @@ def test_qat_launch(self): train_loader=classification_test_dataloader(batch_size=10), valid_loader=classification_test_dataloader(batch_size=10), ) - 
trainer.quantize(calib_dataloader=classification_test_dataloader(batch_size=10)) + trainer.qat(calib_dataloader=classification_test_dataloader(batch_size=10)) if __name__ == "__main__": From 57e58801d73eade3215508858b36b00668c210ef Mon Sep 17 00:00:00 2001 From: shayaharon Date: Tue, 16 May 2023 13:57:30 +0300 Subject: [PATCH 05/10] midifying method cal in pre launch callback --- .../pre_launch_callbacks.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/super_gradients/training/pre_launch_callbacks/pre_launch_callbacks.py b/src/super_gradients/training/pre_launch_callbacks/pre_launch_callbacks.py index 4f3e8263cd..c26f16cf93 100644 --- a/src/super_gradients/training/pre_launch_callbacks/pre_launch_callbacks.py +++ b/src/super_gradients/training/pre_launch_callbacks/pre_launch_callbacks.py @@ -219,7 +219,7 @@ def modify_params_for_qat( :param float cosine_final_lr_ratio: Ratio used to determine the final learning rate in a cosine annealing schedule. Default value is 0.01. :param bool disable_phase_callbacks: Flag to control to disable phase callbacks, which can interfere with QAT. Default value is True. :param bool disable_augmentations: Flag to control to disable phase augmentations, which can interfere with QAT. Default value is False. - :return: modified (copy) quantization_params, training_hyperparams, train_dataloader_params, val_dataloader_params, train_dataset_params, val_dataset_params + :return: modified (copy) training_hyperparams, train_dataset_params, val_dataset_params, train_dataloader_params, val_dataloader_params """ if quantization_params is None: quantization_params = load_recipe("quantization_params/default_quantization_params").quantization_params @@ -287,7 +287,7 @@ def modify_params_for_qat( logger.warning("Augmentations will be disabled for QAT run. Using validation transforms instead.") train_dataset_params["transforms"] = val_dataset_params["transforms"] - return quantization_params, training_hyperparams, train_dataloader_params, val_dataloader_params, train_dataset_params, val_dataset_params + return training_hyperparams, train_dataset_params, val_dataset_params, train_dataloader_params, val_dataloader_params @register_pre_launch_callback() @@ -344,11 +344,20 @@ def __call__(self, cfg: Union[dict, DictConfig]) -> Union[dict, DictConfig]: logger.info("Modifying recipe to suit QAT rules of thumb. Remove QATRecipeModificationCallback to disable.") cfg = copy.deepcopy(cfg) - quantization_params = cfg.quantization_params - dataset_params = cfg.dataset_params - training_hyperparams = cfg.training_hyperparams - cfg.dataset_params, cfg.quantization_params, cfg.training_hyperparams = modify_params_for_qat(dataset_params, quantization_params, training_hyperparams) + ( + cfg.training_hyperparams, + cfg.dataset_params.train_dataset_params, + cfg.dataset_params.val_dataset_params, + cfg.dataset_params.train_dataloader_params, + cfg.dataset_params.val_dataloader_params, + ) = modify_params_for_qat( + training_hyperparams=cfg.training_hyperparams, + train_dataset_params=cfg.dataset_params.train_dataset_params, + val_dataset_params=cfg.dataset_params.val_dataset_params, + val_dataloader_params=cfg.dataset_params.train_dataloader_params, + quantization_params=cfg.quantization_params, + ) if cfg.multi_gpu != "OFF" or cfg.num_gpus != 1: logger.warning(f"Recipe requests multi_gpu={cfg.multi_gpu} and num_gpus={cfg.num_gpus}. 
Changing to multi_gpu=OFF and num_gpus=1") From 7c9506e03613fea107554f8671c5344a2127d72b Mon Sep 17 00:00:00 2001 From: shayaharon Date: Tue, 16 May 2023 15:23:03 +0300 Subject: [PATCH 06/10] removed option to get the defaults from previous training --- src/super_gradients/training/__init__.py | 2 + .../training/pre_launch_callbacks/__init__.py | 3 +- .../pre_launch_callbacks.py | 8 +++ .../training/sg_trainer/sg_trainer.py | 64 +++++++++++-------- .../training/utils/quantization/export.py | 11 +++- .../coded_qat_launch_test.py | 51 +++++++++++++-- 6 files changed, 103 insertions(+), 36 deletions(-) diff --git a/src/super_gradients/training/__init__.py b/src/super_gradients/training/__init__.py index 7db7ea638d..22db6eb6d6 100755 --- a/src/super_gradients/training/__init__.py +++ b/src/super_gradients/training/__init__.py @@ -5,6 +5,7 @@ from super_gradients.training.kd_trainer import KDTrainer from super_gradients.training.qat_trainer import QATTrainer from super_gradients.common import MultiGPUMode, StrictLoad, EvaluationType +from super_gradients.training.pre_launch_callbacks import modify_params_for_qat __all__ = [ "distributed_training_utils", @@ -16,4 +17,5 @@ "MultiGPUMode", "StrictLoad", "EvaluationType", + "modify_params_for_qat", ] diff --git a/src/super_gradients/training/pre_launch_callbacks/__init__.py b/src/super_gradients/training/pre_launch_callbacks/__init__.py index f8563522e7..3a2b932dc0 100644 --- a/src/super_gradients/training/pre_launch_callbacks/__init__.py +++ b/src/super_gradients/training/pre_launch_callbacks/__init__.py @@ -2,7 +2,8 @@ PreLaunchCallback, AutoTrainBatchSizeSelectionCallback, QATRecipeModificationCallback, + modify_params_for_qat, ) from super_gradients.common.registry.registry import ALL_PRE_LAUNCH_CALLBACKS -__all__ = ["PreLaunchCallback", "AutoTrainBatchSizeSelectionCallback", "QATRecipeModificationCallback", "ALL_PRE_LAUNCH_CALLBACKS"] +__all__ = ["PreLaunchCallback", "AutoTrainBatchSizeSelectionCallback", "QATRecipeModificationCallback", "ALL_PRE_LAUNCH_CALLBACKS", "modify_params_for_qat"] diff --git a/src/super_gradients/training/pre_launch_callbacks/pre_launch_callbacks.py b/src/super_gradients/training/pre_launch_callbacks/pre_launch_callbacks.py index c26f16cf93..4f58372451 100644 --- a/src/super_gradients/training/pre_launch_callbacks/pre_launch_callbacks.py +++ b/src/super_gradients/training/pre_launch_callbacks/pre_launch_callbacks.py @@ -354,9 +354,17 @@ def __call__(self, cfg: Union[dict, DictConfig]) -> Union[dict, DictConfig]: ) = modify_params_for_qat( training_hyperparams=cfg.training_hyperparams, train_dataset_params=cfg.dataset_params.train_dataset_params, + train_dataloader_params=cfg.dataset_params.train_dataloader_params, val_dataset_params=cfg.dataset_params.val_dataset_params, val_dataloader_params=cfg.dataset_params.train_dataloader_params, quantization_params=cfg.quantization_params, + batch_size_divisor=self.batch_size_divisor, + disable_phase_callbacks=self.disable_phase_callbacks, + cosine_final_lr_ratio=self.cosine_final_lr_ratio, + warmup_epochs_divisor=self.warmup_epochs_divisor, + lr_decay_factor=self.lr_decay_factor, + max_epochs_divisor=self.max_epochs_divisor, + disable_augmentations=self.disable_augmentations, ) if cfg.multi_gpu != "OFF" or cfg.num_gpus != 1: diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py index 9bc4acf6e5..8d0f37c42f 100755 --- a/src/super_gradients/training/sg_trainer/sg_trainer.py +++ 
b/src/super_gradients/training/sg_trainer/sg_trainer.py @@ -2138,11 +2138,11 @@ def quantize_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, def qat( self, calib_dataloader: DataLoader, - model: torch.nn.Module = None, - val_dataloader: DataLoader = None, - train_dataloader: DataLoader = None, - quantization_params: Mapping = None, + model: torch.nn.Module, + val_dataloader: DataLoader, + train_dataloader: DataLoader, training_params: Mapping = None, + quantization_params: Mapping = None, additional_qat_configs_to_log: Dict = None, valid_metrics_list: List[Metric] = None, ): @@ -2163,7 +2163,6 @@ def qat( :param train_dataloader: DataLoader, data loader for QA training, can be ignored when quantization_params["ptq_only"]=True (default=None). :param quantization_params: Mapping, with the following entries:defaults- - ptq_only: False # whether to launch QAT, or leave PTQ only selective_quantizer_params: calibrator_w: "max" # calibrator type for weights, acceptable types are ["max", "histogram"] calibrator_i: "histogram" # calibrator type for inputs acceptable types are ["max", "histogram"] @@ -2199,39 +2198,33 @@ def qat( if quantization_params is None: quantization_params = load_recipe("quantization_params/default_quantization_params").quantization_params logger.info(f"Using default quantization params: {quantization_params}") - training_params = training_params or self.training_params.to_dict() valid_metrics_list = valid_metrics_list or get_param(training_params, "valid_metrics_list") - train_dataloader = train_dataloader or self.train_loader - val_dataloader = val_dataloader or self.valid_loader model = model or get_param(self.ema_model, "ema") or self.net - res = self.ptq( + _ = self.ptq( calib_dataloader=calib_dataloader, model=model, quantization_params=quantization_params, val_dataloader=val_dataloader, valid_metrics_list=valid_metrics_list, + deepcopy_model_for_export=True, ) # TRAIN - if get_param(quantization_params, "ptq_only", False): - logger.info("quantization_params.ptq_only=True. 
Performing PTQ only!") - suffix = "ptq" - else: - model.train() - torch.cuda.empty_cache() - - res = self.train( - model=model, - train_loader=train_dataloader, - valid_loader=val_dataloader, - training_params=training_params, - additional_configs_to_log=additional_qat_configs_to_log, - ) - suffix = "qat" + model.train() + torch.cuda.empty_cache() + + res = self.train( + model=model, + train_loader=train_dataloader, + valid_loader=val_dataloader, + training_params=training_params, + additional_configs_to_log=additional_qat_configs_to_log, + ) + # EXPORT QUANTIZED MODEL TO ONNX input_shape = next(iter(val_dataloader))[0].shape os.makedirs(self.checkpoints_dir_path, exist_ok=True) - qdq_onnx_path = os.path.join(self.checkpoints_dir_path, f"{self.experiment_name}_{'x'.join((str(x) for x in input_shape))}_{suffix}.onnx") + qdq_onnx_path = os.path.join(self.checkpoints_dir_path, f"{self.experiment_name}_{'x'.join((str(x) for x in input_shape))}_qat.onnx") # TODO: modify SG's convert_to_onnx for quantized models and use it instead export_quantized_module_to_onnx( @@ -2241,10 +2234,10 @@ def qat( input_size=input_shape, train=False, ) - logger.info(f"Exported {suffix.upper()} ONNX to {qdq_onnx_path}") + logger.info(f"Exported QAT ONNX to {qdq_onnx_path}") return res - def ptq(self, calib_dataloader, model, quantization_params, val_dataloader, valid_metrics_list): + def ptq(self, calib_dataloader, model, quantization_params, val_dataloader, valid_metrics_list, deepcopy_model_for_export=False): """ Performs calibration. @@ -2280,6 +2273,9 @@ def ptq(self, calib_dataloader, model, quantization_params, val_dataloader, vali :param valid_metrics_list: (list(torchmetrics.Metric)) metrics list for evaluation of the calibrated model. + :param deepcopy_model_for_export: bool, Whether to export deepcopy(model). Necessary in case further training is + performed and prep_model_for_conversion makes the network un-trainable (i.e RepVGG blocks). + :return: Validation results of the calibrated model. 
""" selective_quantizer_params = get_param(quantization_params, "selective_quantizer_params") @@ -2319,4 +2315,18 @@ def ptq(self, calib_dataloader, model, quantization_params, val_dataloader, vali results += [f" - {metric:10}: {value}" for metric, value in valid_metrics_dict.items()] logger.info("\n".join(results)) + input_shape = next(iter(val_dataloader))[0].shape + os.makedirs(self.checkpoints_dir_path, exist_ok=True) + qdq_onnx_path = os.path.join(self.checkpoints_dir_path, f"{self.experiment_name}_{'x'.join((str(x) for x in input_shape))}_ptq.onnx") + + # TODO: modify SG's convert_to_onnx for quantized models and use it instead + export_quantized_module_to_onnx( + model=model.cpu(), + onnx_filename=qdq_onnx_path, + input_shape=input_shape, + input_size=input_shape, + train=False, + deepcopy_model=deepcopy_model_for_export, + ) + return valid_metrics_dict diff --git a/src/super_gradients/training/utils/quantization/export.py b/src/super_gradients/training/utils/quantization/export.py index 48e189cb23..fc6341d5c1 100644 --- a/src/super_gradients/training/utils/quantization/export.py +++ b/src/super_gradients/training/utils/quantization/export.py @@ -1,3 +1,5 @@ +from copy import deepcopy + import torch from torch.onnx import TrainingMode @@ -14,10 +16,14 @@ _imported_pytorch_quantization_failure = import_err -def export_quantized_module_to_onnx(model: torch.nn.Module, onnx_filename: str, input_shape: tuple, train: bool = False, to_cpu: bool = True, **kwargs): +def export_quantized_module_to_onnx( + model: torch.nn.Module, onnx_filename: str, input_shape: tuple, train: bool = False, to_cpu: bool = True, deepcopy_model=False, **kwargs +): """ Method for exporting onnx after QAT. + :param deepcopy_model: Whether to export deepcopy(model). Necessary in case further training is performed and + prep_model_for_conversion makes the network un-trainable (i.e RepVGG blocks). 
:param to_cpu: transfer model to CPU before converting to ONNX, dirty workaround when model's tensors are on different devices :param train: export model in training mode :param model: torch.nn.Module, model to export @@ -27,6 +33,9 @@ def export_quantized_module_to_onnx(model: torch.nn.Module, onnx_filename: str, if _imported_pytorch_quantization_failure is not None: raise _imported_pytorch_quantization_failure + if deepcopy_model: + model = deepcopy(model) + use_fb_fake_quant_state = quant_nn.TensorQuantizer.use_fb_fake_quant quant_nn.TensorQuantizer.use_fb_fake_quant = True diff --git a/tests/recipe_training_tests/coded_qat_launch_test.py b/tests/recipe_training_tests/coded_qat_launch_test.py index ae7ca3925c..a815197c20 100644 --- a/tests/recipe_training_tests/coded_qat_launch_test.py +++ b/tests/recipe_training_tests/coded_qat_launch_test.py @@ -1,7 +1,10 @@ import unittest +from torchvision.transforms import Normalize, ToTensor, RandomHorizontalFlip, RandomCrop + from super_gradients import QATTrainer -from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader +from super_gradients.training import modify_params_for_qat +from super_gradients.training.dataloaders.dataloaders import cifar10_train, cifar10_val from super_gradients.training.metrics import Accuracy, Top5 from super_gradients.training.models import ResNet18 @@ -9,10 +12,10 @@ class CodedQATLuanchTest(unittest.TestCase): def test_qat_launch(self): trainer = QATTrainer("test_launch_qat_with_minimal_changes") - net = ResNet18(num_classes=5, arch_params={}) + net = ResNet18(num_classes=10, arch_params={}) train_params = { - "max_epochs": 2, - "lr_updates": [1], + "max_epochs": 10, + "lr_updates": [], "lr_decay_factor": 0.1, "lr_mode": "step", "lr_warmup_epochs": 0, @@ -25,14 +28,48 @@ def test_qat_launch(self): "valid_metrics_list": [Accuracy(), Top5()], "metric_to_watch": "Accuracy", "greater_metric_to_watch_is_better": True, + "ema": True, + } + + train_dataset_params = { + "transforms": [ + RandomCrop(size=32, padding=4), + RandomHorizontalFlip, + ToTensor, + Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]), + ] } + + train_dataloader_params = {"batch_size": 256} + + val_dataset_params = {"transforms": [ToTensor, Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])]} + + val_dataloader_params = {"batch_size": 256} + + train_loader = cifar10_train(dataset_params=train_dataset_params, dataloader_params=train_dataloader_params) + valid_loader = cifar10_val(dataset_params=val_dataset_params, dataloader_params=val_dataloader_params) + trainer.train( model=net, training_params=train_params, - train_loader=classification_test_dataloader(batch_size=10), - valid_loader=classification_test_dataloader(batch_size=10), + train_loader=train_loader, + valid_loader=valid_loader, + ) + + train_params, train_dataset_params, val_dataset_params, train_dataloader_params, val_dataloader_params = modify_params_for_qat( + train_params, train_dataset_params, val_dataset_params, train_dataloader_params, val_dataloader_params + ) + + train_loader = cifar10_train(dataset_params=train_dataset_params, dataloader_params=train_dataloader_params) + valid_loader = cifar10_val(dataset_params=val_dataset_params, dataloader_params=val_dataloader_params) + + trainer.qat( + model=net, + training_params=train_params, + train_loader=train_loader, + valid_loader=valid_loader, + calib_dataloader=train_loader, ) - trainer.qat(calib_dataloader=classification_test_dataloader(batch_size=10)) if 
__name__ == "__main__": From 15eed290a0f3101ae0577a6003151ccf48b4cdeb Mon Sep 17 00:00:00 2001 From: shayaharon Date: Tue, 16 May 2023 16:02:07 +0300 Subject: [PATCH 07/10] added unit tests passing --- .../training/sg_trainer/sg_trainer.py | 83 +++++++++++++------ .../coded_qat_launch_test.py | 63 +++++++++++++- 2 files changed, 115 insertions(+), 31 deletions(-) diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py index 8d0f37c42f..a3f86e59a8 100755 --- a/src/super_gradients/training/sg_trainer/sg_trainer.py +++ b/src/super_gradients/training/sg_trainer/sg_trainer.py @@ -10,6 +10,7 @@ import torch import torch.cuda import torch.nn +import torchmetrics from omegaconf import DictConfig, OmegaConf from piptools.scripts.sync import _get_installed_distributions from torch import nn @@ -2123,24 +2124,33 @@ def quantize_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)} trainer = Trainer(experiment_name=cfg.experiment_name, ckpt_root_dir=get_param(cfg, "ckpt_root_dir")) - res = trainer.qat( - model=model, - quantization_params=quantization_params, - calib_dataloader=calib_dataloader, - val_dataloader=val_dataloader, - train_dataloader=train_dataloader, - training_params=cfg.training_hyperparams, - additional_qat_configs_to_log=recipe_logged_cfg, - ) + if quantization_params.ptq_only: + res = trainer.ptq( + calib_loader=calib_dataloader, + model=model, + quantization_params=quantization_params, + valid_loader=val_dataloader, + valid_metrics_list=cfg.training_hyperparams.valid_metrics_list, + ) + else: + res = trainer.qat( + model=model, + quantization_params=quantization_params, + calib_loader=calib_dataloader, + valid_loader=val_dataloader, + train_loader=train_dataloader, + training_params=cfg.training_hyperparams, + additional_qat_configs_to_log=recipe_logged_cfg, + ) return model, res def qat( self, - calib_dataloader: DataLoader, + calib_loader: DataLoader, model: torch.nn.Module, - val_dataloader: DataLoader, - train_dataloader: DataLoader, + valid_loader: DataLoader, + train_loader: DataLoader, training_params: Mapping = None, quantization_params: Mapping = None, additional_qat_configs_to_log: Dict = None, @@ -2150,17 +2160,17 @@ def qat( Performs post-training quantization (PTQ), and then quantization-aware training (QAT). Exports the ONNX models (ckpt_best.pth of QAT and the calibrated model) to the checkpoints directory. - :param calib_dataloader: DataLoader, data loader for calibration. + :param calib_loader: DataLoader, data loader for calibration. :param model: torch.nn.Module, Model to perform QAT/PTQ on. When None, will try to use the network from previous self.train call(that is, if there was one - will try to use self.ema_model.ema if EMA was used, otherwise self.net) - :param val_dataloader: DataLoader, data loader for validation. Used both for validating the calibrated model after PTQ and during QAT. + :param valid_loader: DataLoader, data loader for validation. Used both for validating the calibrated model after PTQ and during QAT. When None, will try to use self.valid_loader if it was set in previous self.train(..) call (default=None). - :param train_dataloader: DataLoader, data loader for QA training, can be ignored when quantization_params["ptq_only"]=True (default=None). + :param train_loader: DataLoader, data loader for QA training, can be ignored when quantization_params["ptq_only"]=True (default=None). 
:param quantization_params: Mapping, with the following entries:defaults- selective_quantizer_params: @@ -2182,6 +2192,9 @@ def qat( verbose: False # if calibrator should be verbose + When None, the above default config is used (default=None) + + :param training_params: Mapping, training hyper parameters for QAT, same as in super.train(...). When None, will try to use self.training_params which is set in previous self.train(..) call (default=None). @@ -2202,10 +2215,10 @@ def qat( model = model or get_param(self.ema_model, "ema") or self.net _ = self.ptq( - calib_dataloader=calib_dataloader, + calib_loader=calib_loader, model=model, quantization_params=quantization_params, - val_dataloader=val_dataloader, + valid_loader=valid_loader, valid_metrics_list=valid_metrics_list, deepcopy_model_for_export=True, ) @@ -2215,14 +2228,14 @@ def qat( res = self.train( model=model, - train_loader=train_dataloader, - valid_loader=val_dataloader, + train_loader=train_loader, + valid_loader=valid_loader, training_params=training_params, additional_configs_to_log=additional_qat_configs_to_log, ) # EXPORT QUANTIZED MODEL TO ONNX - input_shape = next(iter(val_dataloader))[0].shape + input_shape = next(iter(valid_loader))[0].shape os.makedirs(self.checkpoints_dir_path, exist_ok=True) qdq_onnx_path = os.path.join(self.checkpoints_dir_path, f"{self.experiment_name}_{'x'.join((str(x) for x in input_shape))}_qat.onnx") @@ -2237,17 +2250,25 @@ def qat( logger.info(f"Exported QAT ONNX to {qdq_onnx_path}") return res - def ptq(self, calib_dataloader, model, quantization_params, val_dataloader, valid_metrics_list, deepcopy_model_for_export=False): + def ptq( + self, + calib_loader: DataLoader, + model: nn.Module, + valid_loader: DataLoader, + valid_metrics_list: List[torchmetrics.Metric], + quantization_params: Dict = None, + deepcopy_model_for_export: bool = False, + ): """ Performs calibration. - :param calib_dataloader: DataLoader, data loader for calibration. + :param calib_loader: DataLoader, data loader for calibration. :param model: torch.nn.Module, Model to perform calibration on. When None, will try to use self.net which is set in previous self.train(..) call (default=None). - :param val_dataloader: DataLoader, data loader for validation. Used both for validating the calibrated model. + :param valid_loader: DataLoader, data loader for validation. Used both for validating the calibrated model. When None, will try to use self.valid_loader if it was set in previous self.train(..) call (default=None). :param quantization_params: Mapping, with the following entries:defaults- @@ -2270,6 +2291,9 @@ def ptq(self, calib_dataloader, model, quantization_params, val_dataloader, vali verbose: False # if calibrator should be verbose + When None, the above default config is used (default=None) + + :param valid_metrics_list: (list(torchmetrics.Metric)) metrics list for evaluation of the calibrated model. @@ -2278,6 +2302,11 @@ def ptq(self, calib_dataloader, model, quantization_params, val_dataloader, vali :return: Validation results of the calibrated model. 
""" + + if quantization_params is None: + quantization_params = load_recipe("quantization_params/default_quantization_params").quantization_params + logger.info(f"Using default quantization params: {quantization_params}") + selective_quantizer_params = get_param(quantization_params, "selective_quantizer_params") calib_params = get_param(quantization_params, "calib_params") model = model or get_param(self.ema_model, "ema") or self.net @@ -2303,19 +2332,19 @@ def ptq(self, calib_dataloader, model, quantization_params, val_dataloader, vali calibrator.calibrate_model( model, method=get_param(calib_params, "histogram_calib_method"), - calib_data_loader=calib_dataloader, - num_calib_batches=get_param(calib_params, "num_calib_batches") or len(calib_dataloader), + calib_data_loader=calib_loader, + num_calib_batches=get_param(calib_params, "num_calib_batches") or len(calib_loader), percentile=get_param(calib_params, "percentile", 99.99), ) calibrator.reset_calibrators(model) # release memory taken by calibrators # VALIDATE PTQ MODEL AND PRINT SUMMARY logger.info("Validating PTQ model...") - valid_metrics_dict = self.test(model=model, test_loader=val_dataloader, test_metrics_list=valid_metrics_list) + valid_metrics_dict = self.test(model=model, test_loader=valid_loader, test_metrics_list=valid_metrics_list) results = ["PTQ Model Validation Results"] results += [f" - {metric:10}: {value}" for metric, value in valid_metrics_dict.items()] logger.info("\n".join(results)) - input_shape = next(iter(val_dataloader))[0].shape + input_shape = next(iter(valid_loader))[0].shape os.makedirs(self.checkpoints_dir_path, exist_ok=True) qdq_onnx_path = os.path.join(self.checkpoints_dir_path, f"{self.experiment_name}_{'x'.join((str(x) for x in input_shape))}_ptq.onnx") diff --git a/tests/recipe_training_tests/coded_qat_launch_test.py b/tests/recipe_training_tests/coded_qat_launch_test.py index a815197c20..460305d9e3 100644 --- a/tests/recipe_training_tests/coded_qat_launch_test.py +++ b/tests/recipe_training_tests/coded_qat_launch_test.py @@ -34,15 +34,15 @@ def test_qat_launch(self): train_dataset_params = { "transforms": [ RandomCrop(size=32, padding=4), - RandomHorizontalFlip, - ToTensor, + RandomHorizontalFlip(), + ToTensor(), Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]), ] } train_dataloader_params = {"batch_size": 256} - val_dataset_params = {"transforms": [ToTensor, Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])]} + val_dataset_params = {"transforms": [ToTensor(), Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])]} val_dataloader_params = {"batch_size": 256} @@ -68,9 +68,64 @@ def test_qat_launch(self): training_params=train_params, train_loader=train_loader, valid_loader=valid_loader, - calib_dataloader=train_loader, + calib_loader=train_loader, ) + def test_ptq_launch(self): + trainer = QATTrainer("test_launch_qat_with_minimal_changes") + net = ResNet18(num_classes=10, arch_params={}) + train_params = { + "max_epochs": 10, + "lr_updates": [], + "lr_decay_factor": 0.1, + "lr_mode": "step", + "lr_warmup_epochs": 0, + "initial_lr": 0.1, + "loss": "cross_entropy", + "optimizer": "SGD", + "criterion_params": {}, + "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9}, + "train_metrics_list": [Accuracy(), Top5()], + "valid_metrics_list": [Accuracy(), Top5()], + "metric_to_watch": "Accuracy", + "greater_metric_to_watch_is_better": True, + "ema": True, + } + + train_dataset_params = { + "transforms": [ + RandomCrop(size=32, padding=4), + 
RandomHorizontalFlip(), + ToTensor(), + Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]), + ] + } + + train_dataloader_params = {"batch_size": 256} + + val_dataset_params = {"transforms": [ToTensor(), Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])]} + + val_dataloader_params = {"batch_size": 256} + + train_loader = cifar10_train(dataset_params=train_dataset_params, dataloader_params=train_dataloader_params) + valid_loader = cifar10_val(dataset_params=val_dataset_params, dataloader_params=val_dataloader_params) + + trainer.train( + model=net, + training_params=train_params, + train_loader=train_loader, + valid_loader=valid_loader, + ) + + train_params, train_dataset_params, val_dataset_params, train_dataloader_params, val_dataloader_params = modify_params_for_qat( + train_params, train_dataset_params, val_dataset_params, train_dataloader_params, val_dataloader_params + ) + + train_loader = cifar10_train(dataset_params=train_dataset_params, dataloader_params=train_dataloader_params) + valid_loader = cifar10_val(dataset_params=val_dataset_params, dataloader_params=val_dataloader_params) + + trainer.ptq(model=net, valid_loader=valid_loader, calib_loader=train_loader, valid_metrics_list=train_params["valid_metrics_list"]) + if __name__ == "__main__": unittest.main() From 5e422f29d90eb4419eb3e37034820b3909c36ee2 Mon Sep 17 00:00:00 2001 From: shayaharon Date: Tue, 16 May 2023 16:26:45 +0300 Subject: [PATCH 08/10] updated docs and test names --- .../pre_launch_callbacks.py | 40 ++++++++++++++++++- .../coded_qat_launch_test.py | 6 +-- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/src/super_gradients/training/pre_launch_callbacks/pre_launch_callbacks.py b/src/super_gradients/training/pre_launch_callbacks/pre_launch_callbacks.py index 4f58372451..fd0c145726 100644 --- a/src/super_gradients/training/pre_launch_callbacks/pre_launch_callbacks.py +++ b/src/super_gradients/training/pre_launch_callbacks/pre_launch_callbacks.py @@ -203,7 +203,45 @@ def modify_params_for_qat( This method modifies the recipe for QAT to implement rules of thumb based on the regular non-qat recipe. It does so by manipulating the training_hyperparams, train_dataloader_params, val_dataloader_params, train_dataset_params, val_dataset_params. Usage: - train_dataloader_params = {'batch_size':32 + trainer = Trainer("test_launch_qat_with_minimal_changes") + net = ResNet18(num_classes=10, arch_params={}) + train_params = {...} + + train_dataset_params = { + "transforms": [... 
+ ] + } + + train_dataloader_params = {"batch_size": 256} + + val_dataset_params = {"transforms": [ToTensor(), Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])]} + + val_dataloader_params = {"batch_size": 256} + + train_loader = cifar10_train(dataset_params=train_dataset_params, dataloader_params=train_dataloader_params) + valid_loader = cifar10_val(dataset_params=val_dataset_params, dataloader_params=val_dataloader_params) + + trainer.train( + model=net, + training_params=train_params, + train_loader=train_loader, + valid_loader=valid_loader, + ) + + train_params, train_dataset_params, val_dataset_params, train_dataloader_params, val_dataloader_params = modify_params_for_qat( + train_params, train_dataset_params, val_dataset_params, train_dataloader_params, val_dataloader_params + ) + + train_loader = cifar10_train(dataset_params=train_dataset_params, dataloader_params=train_dataloader_params) + valid_loader = cifar10_val(dataset_params=val_dataset_params, dataloader_params=val_dataloader_params) + + trainer.qat( + model=net, + training_params=train_params, + train_loader=train_loader, + valid_loader=valid_loader, + calib_loader=train_loader, + ) :param val_dataset_params: Dict, validation dataset_params to be passed to dataloaders.get(...) when instantiating the train dataloader. :param train_dataset_params: Dict, train dataset_params to be passed to dataloaders.get(...) when instantiating the validation dataloader. diff --git a/tests/recipe_training_tests/coded_qat_launch_test.py b/tests/recipe_training_tests/coded_qat_launch_test.py index 460305d9e3..e5bb8531c1 100644 --- a/tests/recipe_training_tests/coded_qat_launch_test.py +++ b/tests/recipe_training_tests/coded_qat_launch_test.py @@ -2,7 +2,7 @@ from torchvision.transforms import Normalize, ToTensor, RandomHorizontalFlip, RandomCrop -from super_gradients import QATTrainer +from super_gradients import Trainer from super_gradients.training import modify_params_for_qat from super_gradients.training.dataloaders.dataloaders import cifar10_train, cifar10_val from super_gradients.training.metrics import Accuracy, Top5 @@ -11,7 +11,7 @@ class CodedQATLuanchTest(unittest.TestCase): def test_qat_launch(self): - trainer = QATTrainer("test_launch_qat_with_minimal_changes") + trainer = Trainer("test_launch_qat_with_minimal_changes") net = ResNet18(num_classes=10, arch_params={}) train_params = { "max_epochs": 10, @@ -72,7 +72,7 @@ def test_qat_launch(self): ) def test_ptq_launch(self): - trainer = QATTrainer("test_launch_qat_with_minimal_changes") + trainer = Trainer("test_launch_ptq_with_minimal_changes") net = ResNet18(num_classes=10, arch_params={}) train_params = { "max_epochs": 10, From 2dceb7541cfa1fc3791d52e306fc7111b0c258ed Mon Sep 17 00:00:00 2001 From: shayaharon Date: Tue, 16 May 2023 16:37:59 +0300 Subject: [PATCH 09/10] moved logger init --- src/super_gradients/training/sg_trainer/sg_trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py index a3f86e59a8..c424e95ffd 100755 --- a/src/super_gradients/training/sg_trainer/sg_trainer.py +++ b/src/super_gradients/training/sg_trainer/sg_trainer.py @@ -91,6 +91,9 @@ from super_gradients.common.factories.pre_launch_callbacks_factory import PreLaunchCallbacksFactory from super_gradients.training.params import TrainingParams +logger = get_logger(__name__) + + try: from super_gradients.training.utils.quantization.calibrator import 
QuantizationCalibrator from super_gradients.training.utils.quantization.export import export_quantized_module_to_onnx @@ -102,7 +105,6 @@ logger.debug("Failed to import pytorch_quantization:") logger.debug(import_err) _imported_pytorch_quantization_failure = import_err -logger = get_logger(__name__) class Trainer: From 954d9f4977066cf603efa25e4ac4fec5869c6c5f Mon Sep 17 00:00:00 2001 From: shayaharon Date: Wed, 17 May 2023 11:02:03 +0300 Subject: [PATCH 10/10] comments resolved --- src/super_gradients/qat_from_recipe.py | 5 ++--- src/super_gradients/training/sg_trainer/sg_trainer.py | 6 +----- src/super_gradients/training/utils/quantization/export.py | 4 ++-- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/super_gradients/qat_from_recipe.py b/src/super_gradients/qat_from_recipe.py index e6e6a30497..a81898313d 100644 --- a/src/super_gradients/qat_from_recipe.py +++ b/src/super_gradients/qat_from_recipe.py @@ -8,13 +8,12 @@ import hydra from omegaconf import DictConfig -from super_gradients import init_trainer -from super_gradients.training.qat_trainer.qat_trainer import QATTrainer +from super_gradients import init_trainer, Trainer @hydra.main(config_path="recipes", version_base="1.2") def _main(cfg: DictConfig) -> None: - QATTrainer.quantize_from_config(cfg) + Trainer.quantize_from_config(cfg) def main(): diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py index 1cfe58c871..069668b186 100755 --- a/src/super_gradients/training/sg_trainer/sg_trainer.py +++ b/src/super_gradients/training/sg_trainer/sg_trainer.py @@ -2053,7 +2053,6 @@ def quantize_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, :param cfg: The parsed DictConfig object from yaml recipe files or a dictionary. :return: A tuple containing the quantized model and the output of trainer.train() method. - :rtype: Tuple[nn.Module, Tuple] :raises ValueError: If the recipe does not have the required key `quantization_params` or `checkpoint_params.checkpoint_path` in it. @@ -2226,7 +2225,6 @@ def qat( quantization_params = load_recipe("quantization_params/default_quantization_params").quantization_params logger.info(f"Using default quantization params: {quantization_params}") valid_metrics_list = valid_metrics_list or get_param(training_params, "valid_metrics_list") - model = model or get_param(self.ema_model, "ema") or self.net _ = self.ptq( calib_loader=calib_loader, @@ -2274,8 +2272,7 @@ def ptq( deepcopy_model_for_export: bool = False, ): """ - Performs calibration. - + Performs post-training quantization (calibration of the model).. :param calib_loader: DataLoader, data loader for calibration. @@ -2323,7 +2320,6 @@ def ptq( selective_quantizer_params = get_param(quantization_params, "selective_quantizer_params") calib_params = get_param(quantization_params, "calib_params") - model = model or get_param(self.ema_model, "ema") or self.net model.to(device_config.device) # QUANTIZE MODEL model.eval() diff --git a/src/super_gradients/training/utils/quantization/export.py b/src/super_gradients/training/utils/quantization/export.py index fc6341d5c1..fb238e79bd 100644 --- a/src/super_gradients/training/utils/quantization/export.py +++ b/src/super_gradients/training/utils/quantization/export.py @@ -22,13 +22,13 @@ def export_quantized_module_to_onnx( """ Method for exporting onnx after QAT. - :param deepcopy_model: Whether to export deepcopy(model). 
Necessary in case further training is performed and - prep_model_for_conversion makes the network un-trainable (i.e RepVGG blocks). :param to_cpu: transfer model to CPU before converting to ONNX, dirty workaround when model's tensors are on different devices :param train: export model in training mode :param model: torch.nn.Module, model to export :param onnx_filename: str, target path for the onnx file, :param input_shape: tuple, input shape (usually BCHW) + :param deepcopy_model: Whether to export deepcopy(model). Necessary in case further training is performed and + prep_model_for_conversion makes the network un-trainable (i.e RepVGG blocks). """ if _imported_pytorch_quantization_failure is not None: raise _imported_pytorch_quantization_failure
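
modify_params_for_qat now returns training_hyperparams, train_dataset_params, val_dataset_params, train_dataloader_params, val_dataloader_params (quantization_params is no longer part of the return value). A short sketch of the new unpacking is below; it assumes the *_params dicts were already built by the caller, as in coded_qat_launch_test.py, and is illustrative rather than part of any patch.

    from super_gradients.training import modify_params_for_qat

    # New return order: training hyperparams first, dataloader params last.
    (
        training_hyperparams,
        train_dataset_params,
        val_dataset_params,
        train_dataloader_params,
        val_dataloader_params,
    ) = modify_params_for_qat(
        training_hyperparams=training_hyperparams,
        train_dataset_params=train_dataset_params,
        val_dataset_params=val_dataset_params,
        train_dataloader_params=train_dataloader_params,
        val_dataloader_params=val_dataloader_params,
        quantization_params=None,  # None -> quantization_params/default_quantization_params recipe defaults
    )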
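
The quantization_params mapping documented in the Trainer.ptq / Trainer.qat docstrings can also be passed explicitly instead of relying on the defaults loaded from quantization_params/default_quantization_params. The sketch below mirrors the documented default values; trainer (a Trainer instance), model, calib_loader and valid_loader are assumed to exist already (for example, built as in the coded test), so treat this as an illustration, not code from the patch.

    from super_gradients.training.metrics import Accuracy, Top5

    quantization_params = {
        "selective_quantizer_params": {
            "calibrator_w": "max",        # weight calibrator type: "max" or "histogram"
            "calibrator_i": "histogram",  # input calibrator type: "max" or "histogram"
            "per_channel": True,          # per-channel weights, activations stay per-tensor
            "learn_amax": False,          # learnable amax via straight-through estimator
            "skip_modules": None,         # optional list of module names to leave unquantized
        },
        "calib_params": {
            "histogram_calib_method": "percentile",  # "percentile", "entropy" or "mse"
            "percentile": 99.99,                     # only affects "percentile" histogram calibrators
            "num_calib_batches": None,               # None -> number of batches derived from the calibration loader
            "verbose": False,
        },
    }

    ptq_metrics = trainer.ptq(
        calib_loader=calib_loader,
        model=model,
        valid_loader=valid_loader,
        valid_metrics_list=[Accuracy(), Top5()],
        quantization_params=quantization_params,
    )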
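
The new deepcopy_model flag of export_quantized_module_to_onnx exists so that exporting the calibrated PTQ model does not leave the in-memory network un-trainable (e.g. fused RepVGG blocks) before the QAT phase. A usage sketch follows; the quantized_model name, the output file name and the 1x3x32x32 shape are illustrative assumptions, and input_size is forwarded via **kwargs the same way the trainer's own export call does.

    from super_gradients.training.utils.quantization.export import export_quantized_module_to_onnx

    export_quantized_module_to_onnx(
        model=quantized_model.cpu(),                          # a model that was already quantized and calibrated
        onnx_filename="resnet18_cifar10_1x3x32x32_ptq.onnx",  # illustrative path
        input_shape=(1, 3, 32, 32),                           # BCHW, e.g. taken from one validation batch
        input_size=(1, 3, 32, 32),                            # forwarded to model preparation via **kwargs
        train=False,
        deepcopy_model=True,                                  # export a deepcopy so the original stays trainable for QAT
    )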