
Bug/sg 861 decouple qat from train from config #1001

Merged — 17 commits, merged May 23, 2023

Changes from 5 commits
2 changes: 2 additions & 0 deletions .circleci/config.yml
@@ -454,6 +454,8 @@ jobs:
python3.8 -m pip install -r requirements.txt
python3.8 -m pip install .
python3.8 -m pip install torch torchvision torchaudio
python3.8 -m pip install pytorch-quantization==2.1.2 --extra-index-url https://pypi.ngc.nvidia.com

python3.8 src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=coco2017_pose_dekr_w32_no_dc experiment_name=shortened_coco2017_pose_dekr_w32_ap_test batch_size=4 val_batch_size=8 epochs=1 training_hyperparams.lr_warmup_steps=0 training_hyperparams.average_best_models=False training_hyperparams.max_train_batches=1000 training_hyperparams.max_valid_batches=100 multi_gpu=DDP num_gpus=4
python3.8 src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=cifar10_resnet experiment_name=shortened_cifar10_resnet_accuracy_test epochs=100 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
python3.8 src/super_gradients/examples/convert_recipe_example/convert_recipe_example.py --config-name=cifar10_conversion_params experiment_name=shortened_cifar10_resnet_accuracy_test
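
A quick way to confirm the new CI step took effect — a minimal sketch, assuming it runs in the same Python 3.8 environment the job installs into:

# Sanity check (assumption: run inside the CI environment above): the pinned
# pytorch-quantization build must be importable before the QAT recipe tests run.
import pytorch_quantization

assert pytorch_quantization.__version__ == "2.1.2", pytorch_quantization.__version__
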
3 changes: 3 additions & 0 deletions src/super_gradients/__init__.py
@@ -3,6 +3,7 @@
from super_gradients.common.registry.registry import ARCHITECTURES
from super_gradients.sanity_check import env_sanity_check
from super_gradients.training.utils.distributed_training_utils import setup_device
from super_gradients.training.pre_launch_callbacks import AutoTrainBatchSizeSelectionCallback, QATRecipeModificationCallback

__all__ = [
"ARCHITECTURES",
@@ -18,6 +19,8 @@
"is_distributed",
"env_sanity_check",
"setup_device",
"QATRecipeModificationCallback",
"AutoTrainBatchSizeSelectionCallback",
]

__version__ = "3.1.1"
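
With this change both pre-launch callbacks are re-exported from the package root, so downstream code can import them without the full module path — a small sketch:

# Previously these lived only under super_gradients.training.pre_launch_callbacks.
from super_gradients import AutoTrainBatchSizeSelectionCallback, QATRecipeModificationCallback
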
2 changes: 1 addition & 1 deletion src/super_gradients/qat_from_recipe.py
@@ -14,7 +14,7 @@

@hydra.main(config_path="recipes", version_base="1.2")
def _main(cfg: DictConfig) -> None:
QATTrainer.train_from_config(cfg)
QATTrainer.quantize_from_config(cfg)


def main():
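
The entry point now delegates to the renamed quantize_from_config instead of train_from_config. A programmatic equivalent of the Hydra launch — a minimal sketch, where the recipe name is a placeholder for any recipe that carries checkpoint_params and (optionally) quantization_params:

from super_gradients.common.environment.cfg_utils import load_recipe
from super_gradients.training.qat_trainer.qat_trainer import QATTrainer

cfg = load_recipe("cifar10_resnet")  # placeholder recipe name
model, res = QATTrainer.quantize_from_config(cfg)  # quantized model + train()/test() output
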
233 changes: 180 additions & 53 deletions src/super_gradients/training/qat_trainer/qat_trainer.py
@@ -1,14 +1,17 @@
import os
from typing import Union, Tuple
from typing import Union, Tuple, Dict, Mapping, List
from torchmetrics import Metric

import copy
import hydra
import torch.cuda
from omegaconf import DictConfig
from omegaconf import OmegaConf
from torch import nn
from torch.utils.data import DataLoader

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.environment.cfg_utils import load_recipe
from super_gradients.common.environment.device_utils import device_config
from super_gradients.training import utils as core_utils, models, dataloaders
from super_gradients.training.sg_trainer import Trainer
@@ -32,7 +35,7 @@

class QATTrainer(Trainer):
@classmethod
def train_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tuple]:
def quantize_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tuple]:
"""
Perform quantization aware training (QAT) according to a recipe configuration.

@@ -42,6 +45,10 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tup

The quantized model will be exported to ONNX along with other checkpoints.

The call to self.quantize (see docs in the next method) is done with the created
train_loader and valid_loader. If no calibration data loader is passed through cfg.calib_loader,
a train data loader with the validation transforms is used for calibration.

:param cfg: The parsed DictConfig object from yaml recipe files or a dictionary.
:return: A tuple containing the quantized model and the output of trainer.train() method.
:rtype: Tuple[nn.Module, Tuple]
@@ -61,11 +68,13 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tup
# TRIGGER CFG MODIFYING CALLBACKS
cfg = cls._trigger_cfg_modifying_callbacks(cfg)

if "quantization_params" not in cfg:
raise ValueError("Your recipe does not have quantization_params. Add them to use QAT.")
quantization_params = get_param(cfg, "quantization_params")

if quantization_params is None:
logger.warning("Your recipe does not include quantization_params. Using default quantization params.")

if "checkpoint_path" not in cfg.checkpoint_params:
raise ValueError("Starting checkpoint is a must for QAT finetuning.")
if get_param(cfg.checkpoint_params, "checkpoint_path") is None and get_param(cfg.checkpoint_params, "pretrained_weights") is None:
raise ValueError("Starting checkpoint / pretrained weights are a must for QAT finetuning.")

num_gpus = core_utils.get_param(cfg, "num_gpus")
multi_gpu = core_utils.get_param(cfg, "multi_gpu")
@@ -123,70 +132,118 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tup
checkpoint_path=cfg.checkpoint_params.checkpoint_path,
load_backbone=False,
)
model.to(device_config.device)

# QUANTIZE MODEL
model.eval()
fuse_repvgg_blocks_residual_branches(model)

q_util = SelectiveQuantizer(
default_quant_modules_calibrator_weights=cfg.quantization_params.selective_quantizer_params.calibrator_w,
default_quant_modules_calibrator_inputs=cfg.quantization_params.selective_quantizer_params.calibrator_i,
default_per_channel_quant_weights=cfg.quantization_params.selective_quantizer_params.per_channel,
default_learn_amax=cfg.quantization_params.selective_quantizer_params.learn_amax,
verbose=cfg.quantization_params.calib_params.verbose,
recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
trainer = QATTrainer(experiment_name=cfg.experiment_name, ckpt_root_dir=get_param(cfg, "ckpt_root_dir"))

res = trainer.quantize(
model=model,
quantization_params=quantization_params,
calib_dataloader=calib_dataloader,
val_dataloader=val_dataloader,
train_dataloader=train_dataloader,
training_params=cfg.training_hyperparams,
additional_qat_configs_to_log=recipe_logged_cfg,
)
q_util.register_skip_quantization(layer_names=cfg.quantization_params.selective_quantizer_params.skip_modules)
q_util.quantize_module(model)

# CALIBRATE MODEL
logger.info("Calibrating model...")
calibrator = QuantizationCalibrator(
verbose=cfg.quantization_params.calib_params.verbose,
torch_hist=True,
)
calibrator.calibrate_model(
model,
method=cfg.quantization_params.calib_params.histogram_calib_method,
calib_data_loader=calib_dataloader,
num_calib_batches=cfg.quantization_params.calib_params.num_calib_batches or len(train_dataloader),
percentile=get_param(cfg.quantization_params.calib_params, "percentile", 99.99),
)
calibrator.reset_calibrators(model) # release memory taken by calibrators
return model, res

# VALIDATE PTQ MODEL AND PRINT SUMMARY
logger.info("Validating PTQ model...")
trainer = Trainer(experiment_name=cfg.experiment_name, ckpt_root_dir=get_param(cfg, "ckpt_root_dir", default_val=None))
valid_metrics_dict = trainer.test(model=model, test_loader=val_dataloader, test_metrics_list=cfg.training_hyperparams.valid_metrics_list)
results = ["PTQ Model Validation Results"]
results += [f" - {metric:10}: {value}" for metric, value in valid_metrics_dict.items()]
logger.info("\n".join(results))
def quantize(
self,
calib_dataloader: DataLoader,
model: torch.nn.Module = None,
val_dataloader: DataLoader = None,
train_dataloader: DataLoader = None,
quantization_params: Mapping = None,
training_params: Mapping = None,
additional_qat_configs_to_log: Dict = None,
valid_metrics_list: List[Metric] = None,
):
"""
Performs post-training quantization (PTQ), and optionally quantization-aware training (QAT).
Exports the ONNX model to the checkpoints directory.

:param calib_dataloader: DataLoader, data loader for calibration.

:param model: torch.nn.Module, Model to perform QAT/PTQ on. When None, will try to use the network from the
previous self.train call (that is, if there was one - will try to use self.ema_model.ema if EMA was used,
otherwise self.net) (default=None).

:param val_dataloader: DataLoader, data loader for validation. Used both for validating the calibrated model after PTQ and during QAT.
When None, will try to use self.valid_loader if it was set in previous self.train(..) call (default=None).

:param train_dataloader: DataLoader, data loader for QA training, can be ignored when quantization_params["ptq_only"]=True (default=None).

:param quantization_params: Mapping with the following entries (defaults shown below):

ptq_only: False # whether to launch QAT, or leave PTQ only
selective_quantizer_params:
calibrator_w: "max" # calibrator type for weights, acceptable types are ["max", "histogram"]
calibrator_i: "histogram" # calibrator type for inputs, acceptable types are ["max", "histogram"]
per_channel: True # per-channel quantization of weights, activations stay per-tensor by default
learn_amax: False # enable learnable amax in all TensorQuantizers using straight-through estimator
skip_modules: # optional list of module names (strings) to skip from quantization

calib_params:
histogram_calib_method: "percentile" # calibration method for all "histogram" calibrators, acceptable types are ["percentile", "entropy", "mse"];
"max" calibrators always use "max"
percentile: 99.99 # percentile for all histogram calibrators with method "percentile", other calibrators are not affected

num_calib_batches: # number of batches to use for calibration, if None, 512 / batch_size will be used
verbose: False # if calibrator should be verbose


:param training_params: Mapping, training hyperparameters for QAT, same as in super.train(...). When None, will try to use self.training_params
which was set in the previous self.train(..) call (default=None).

:param additional_qat_configs_to_log: Dict, Optional dictionary containing configs that will be added to the QA training's
sg_logger. Format should be {"Config_title_1": {...}, "Config_title_2":{..}}.

:param valid_metrics_list: (list(torchmetrics.Metric)) metrics list for evaluation of the calibrated model.
When None, the validation metrics from training_params are used (default=None).

:return: Validation results of the QAT model when quantization_params['ptq_only'] is False, and of the PTQ model otherwise.
"""

if quantization_params is None:
quantization_params = load_recipe("quantization_params/default_quantization_params").quantization_params
logger.info(f"Using default quantization params: {quantization_params}")
training_params = training_params or self.training_params.to_dict()
valid_metrics_list = valid_metrics_list or get_param(training_params, "valid_metrics_list")
train_dataloader = train_dataloader or self.train_loader
val_dataloader = val_dataloader or self.valid_loader
model = model or get_param(self.ema_model, "ema") or self.net

res = self.calibrate_model(
calib_dataloader=calib_dataloader,
model=model,
quantization_params=quantization_params,
val_dataloader=val_dataloader,
valid_metrics_list=valid_metrics_list,
)
# TRAIN
if cfg.quantization_params.ptq_only:
logger.info("cfg.quantization_params.ptq_only=True. Performing PTQ only!")
if get_param(quantization_params, "ptq_only", False):
logger.info("quantization_params.ptq_only=True. Performing PTQ only!")
suffix = "ptq"
res = None
else:
model.train()
recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
trainer = Trainer(experiment_name=cfg.experiment_name, ckpt_root_dir=get_param(cfg, "ckpt_root_dir", default_val=None))
torch.cuda.empty_cache()

res = trainer.train(
res = self.train(
model=model,
train_loader=train_dataloader,
valid_loader=val_dataloader,
training_params=cfg.training_hyperparams,
additional_configs_to_log=recipe_logged_cfg,
training_params=training_params,
additional_configs_to_log=additional_qat_configs_to_log,
)
suffix = "qat"

# EXPORT QUANTIZED MODEL TO ONNX
input_shape = next(iter(val_dataloader))[0].shape
os.makedirs(trainer.checkpoints_dir_path, exist_ok=True)
os.makedirs(self.checkpoints_dir_path, exist_ok=True)
qdq_onnx_path = os.path.join(self.checkpoints_dir_path, f"{self.experiment_name}_{'x'.join((str(x) for x in input_shape))}_{suffix}.onnx")

qdq_onnx_path = os.path.join(trainer.checkpoints_dir_path, f"{cfg.experiment_name}_{'x'.join((str(x) for x in input_shape))}_{suffix}.onnx")
# TODO: modify SG's convert_to_onnx for quantized models and use it instead
export_quantized_module_to_onnx(
model=model.cpu(),
@@ -195,7 +252,77 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tup
input_size=input_shape,
train=False,
)

logger.info(f"Exported {suffix.upper()} ONNX to {qdq_onnx_path}")
return res

return model, res
def calibrate_model(self, calib_dataloader, model, quantization_params, val_dataloader, valid_metrics_list):
"""
Performs calibration.

:param calib_dataloader: DataLoader, data loader for calibration.

:param model: torch.nn.Module, Model to perform calibration on. When None, will try to use self.net which was
set in the previous self.train(..) call (default=None).

:param val_dataloader: DataLoader, data loader for validation. Used for validating the calibrated model.
When None, will try to use self.valid_loader if it was set in the previous self.train(..) call (default=None).

:param quantization_params: Mapping with the following entries (defaults shown below):

selective_quantizer_params:
calibrator_w: "max" # calibrator type for weights, acceptable types are ["max", "histogram"]
calibrator_i: "histogram" # calibrator type for inputs, acceptable types are ["max", "histogram"]
per_channel: True # per-channel quantization of weights, activations stay per-tensor by default
learn_amax: False # enable learnable amax in all TensorQuantizers using straight-through estimator
skip_modules: # optional list of module names (strings) to skip from quantization

calib_params:
histogram_calib_method: "percentile" # calibration method for all "histogram" calibrators, acceptable types are ["percentile", "entropy", "mse"];
"max" calibrators always use "max"
percentile: 99.99 # percentile for all histogram calibrators with method "percentile", other calibrators are not affected
num_calib_batches: # number of batches to use for calibration, if None, 512 / batch_size will be used
verbose: False # if calibrator should be verbose

:param valid_metrics_list: (list(torchmetrics.Metric)) metrics list for evaluation of the calibrated model.

:return: Validation results of the calibrated model.
"""
selective_quantizer_params = get_param(quantization_params, "selective_quantizer_params")
calib_params = get_param(quantization_params, "calib_params")
model = model or get_param(self.ema_model, "ema") or self.net
model.to(device_config.device)
# QUANTIZE MODEL
model.eval()
fuse_repvgg_blocks_residual_branches(model)
q_util = SelectiveQuantizer(
default_quant_modules_calibrator_weights=get_param(selective_quantizer_params, "calibrator_w"),
default_quant_modules_calibrator_inputs=get_param(selective_quantizer_params, "calibrator_i"),
default_per_channel_quant_weights=get_param(selective_quantizer_params, "per_channel"),
default_learn_amax=get_param(selective_quantizer_params, "learn_amax"),
verbose=get_param(calib_params, "verbose"),
)
q_util.register_skip_quantization(layer_names=get_param(selective_quantizer_params, "skip_modules"))
q_util.quantize_module(model)
# CALIBRATE MODEL
logger.info("Calibrating model...")
calibrator = QuantizationCalibrator(
verbose=get_param(calib_params, "verbose"),
torch_hist=True,
)
calibrator.calibrate_model(
model,
method=get_param(calib_params, "histogram_calib_method"),
calib_data_loader=calib_dataloader,
num_calib_batches=get_param(calib_params, "num_calib_batches") or len(calib_dataloader),
percentile=get_param(calib_params, "percentile", 99.99),
)
calibrator.reset_calibrators(model) # release memory taken by calibrators
# VALIDATE PTQ MODEL AND PRINT SUMMARY
logger.info("Validating PTQ model...")
valid_metrics_dict = self.test(model=model, test_loader=val_dataloader, test_metrics_list=valid_metrics_list)
results = ["PTQ Model Validation Results"]
results += [f" - {metric:10}: {value}" for metric, value in valid_metrics_dict.items()]
logger.info("\n".join(results))

return valid_metrics_dict
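
Taken together, the refactor lets PTQ/QAT run decoupled from train_from_config through the new instance-level API. A usage sketch of trainer.quantize — the model, dataloaders, and training_hyperparams below are placeholders built elsewhere, and the quantization_params dict simply mirrors the defaults documented in the docstring above:

from super_gradients.training.qat_trainer.qat_trainer import QATTrainer

trainer = QATTrainer(experiment_name="qat_example")  # placeholder experiment name

# Plain-dict version of the documented default quantization_params; any entry can be overridden.
quantization_params = {
    "ptq_only": False,  # run calibration, then QAT fine-tuning
    "selective_quantizer_params": {
        "calibrator_w": "max",
        "calibrator_i": "histogram",
        "per_channel": True,
        "learn_amax": False,
        "skip_modules": None,
    },
    "calib_params": {
        "histogram_calib_method": "percentile",
        "percentile": 99.99,
        "num_calib_batches": None,
        "verbose": False,
    },
}

res = trainer.quantize(
    model=model,                            # placeholder: an already-instantiated nn.Module
    calib_dataloader=calib_loader,          # placeholder dataloaders
    train_dataloader=train_loader,
    val_dataloader=valid_loader,
    quantization_params=quantization_params,
    training_params=training_hyperparams,   # placeholder: same schema as Trainer.train
)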
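
For PTQ without the fine-tuning stage, the same call with ptq_only=True stops after calibrate_model validates the calibrated model, and the exported ONNX file gets the "_ptq" suffix instead of "_qat" — a sketch under the same placeholder assumptions:

quantization_params["ptq_only"] = True      # skip the QAT fine-tuning stage
ptq_metrics = trainer.quantize(
    model=model,
    calib_dataloader=calib_loader,
    val_dataloader=valid_loader,
    quantization_params=quantization_params,
    training_params=training_hyperparams,   # still read for valid_metrics_list
    # train_dataloader can be omitted when ptq_only=True
)
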
2 changes: 2 additions & 0 deletions tests/deci_core_recipe_test_suite_runner.py
@@ -2,6 +2,7 @@
import unittest

from tests.recipe_training_tests.automatic_batch_selection_single_gpu_test import TestAutoBatchSelectionSingleGPU
from tests.recipe_training_tests.coded_qat_launch_test import CodedQATLuanchTest
from tests.recipe_training_tests.shortened_recipes_accuracy_test import ShortenedRecipesAccuracyTests


@@ -17,6 +18,7 @@ def _add_modules_to_unit_tests_suite(self):
_add_modules_to_unit_tests_suite - Adds unit tests to the Unit Tests Test Suite
:return:
"""
self.recipe_tests_suite.addTest(self.test_loader.loadTestsFromModule(CodedQATLuanchTest))
self.recipe_tests_suite.addTest(self.test_loader.loadTestsFromModule(ShortenedRecipesAccuracyTests))
self.recipe_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestAutoBatchSelectionSingleGPU))
