Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug/sg 861 decouple qat from train from config #1001

Merged
merged 17 commits into from
May 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,8 @@ jobs:
python3.8 -m pip install -r requirements.txt
python3.8 -m pip install .
python3.8 -m pip install torch torchvision torchaudio
python3.8 -m pip install pytorch-quantization==2.1.2 --extra-index-url https://pypi.ngc.nvidia.com
python3.8 src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=coco2017_pose_dekr_w32_no_dc experiment_name=shortened_coco2017_pose_dekr_w32_ap_test batch_size=4 val_batch_size=8 epochs=1 training_hyperparams.lr_warmup_steps=0 training_hyperparams.average_best_models=False training_hyperparams.max_train_batches=1000 training_hyperparams.max_valid_batches=100 multi_gpu=DDP num_gpus=4
python3.8 src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=cifar10_resnet experiment_name=shortened_cifar10_resnet_accuracy_test epochs=100 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
python3.8 src/super_gradients/examples/convert_recipe_example/convert_recipe_example.py --config-name=cifar10_conversion_params experiment_name=shortened_cifar10_resnet_accuracy_test
Expand Down
3 changes: 3 additions & 0 deletions src/super_gradients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from super_gradients.common.registry.registry import ARCHITECTURES
from super_gradients.sanity_check import env_sanity_check
from super_gradients.training.utils.distributed_training_utils import setup_device
from super_gradients.training.pre_launch_callbacks import AutoTrainBatchSizeSelectionCallback, QATRecipeModificationCallback

__all__ = [
"ARCHITECTURES",
Expand All @@ -18,6 +19,8 @@
"is_distributed",
"env_sanity_check",
"setup_device",
"QATRecipeModificationCallback",
"AutoTrainBatchSizeSelectionCallback",
]

__version__ = "3.1.1"
Expand Down
5 changes: 2 additions & 3 deletions src/super_gradients/qat_from_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@
import hydra
from omegaconf import DictConfig

from super_gradients import init_trainer
from super_gradients.training.qat_trainer.qat_trainer import QATTrainer
from super_gradients import init_trainer, Trainer


@hydra.main(config_path="recipes", version_base="1.2")
def _main(cfg: DictConfig) -> None:
QATTrainer.train_from_config(cfg)
Trainer.quantize_from_config(cfg)


def main():
Expand Down
2 changes: 2 additions & 0 deletions src/super_gradients/training/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from super_gradients.training.kd_trainer import KDTrainer
from super_gradients.training.qat_trainer import QATTrainer
from super_gradients.common import MultiGPUMode, StrictLoad, EvaluationType
from super_gradients.training.pre_launch_callbacks import modify_params_for_qat

__all__ = [
"distributed_training_utils",
Expand All @@ -16,4 +17,5 @@
"MultiGPUMode",
"StrictLoad",
"EvaluationType",
"modify_params_for_qat",
]
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
PreLaunchCallback,
AutoTrainBatchSizeSelectionCallback,
QATRecipeModificationCallback,
modify_params_for_qat,
)
from super_gradients.common.registry.registry import ALL_PRE_LAUNCH_CALLBACKS

__all__ = ["PreLaunchCallback", "AutoTrainBatchSizeSelectionCallback", "QATRecipeModificationCallback", "ALL_PRE_LAUNCH_CALLBACKS"]
__all__ = ["PreLaunchCallback", "AutoTrainBatchSizeSelectionCallback", "QATRecipeModificationCallback", "ALL_PRE_LAUNCH_CALLBACKS", "modify_params_for_qat"]

Large diffs are not rendered by default.

192 changes: 4 additions & 188 deletions src/super_gradients/training/qat_trainer/qat_trainer.py
Original file line number Diff line number Diff line change
@@ -1,201 +1,17 @@
import os
from typing import Union, Tuple

import copy
import hydra
import torch.cuda
from deprecated import deprecated
from omegaconf import DictConfig
from omegaconf import OmegaConf
from torch import nn

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.environment.device_utils import device_config
from super_gradients.training import utils as core_utils, models, dataloaders
from super_gradients.training.sg_trainer import Trainer
from super_gradients.training.utils import get_param
from super_gradients.training.utils.distributed_training_utils import setup_device
from super_gradients.modules.repvgg_block import fuse_repvgg_blocks_residual_branches

logger = get_logger(__name__)
try:
from super_gradients.training.utils.quantization.calibrator import QuantizationCalibrator
from super_gradients.training.utils.quantization.export import export_quantized_module_to_onnx
from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer

_imported_pytorch_quantization_failure = None

except (ImportError, NameError, ModuleNotFoundError) as import_err:
logger.debug("Failed to import pytorch_quantization:")
logger.debug(import_err)
_imported_pytorch_quantization_failure = import_err


class QATTrainer(Trainer):
@classmethod
def train_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tuple]:
"""
Perform quantization aware training (QAT) according to a recipe configuration.
This method will instantiate all the objects specified in the recipe, build and quantize the model,
and calibrate the quantized model. The resulting quantized model and the output of the trainer.train()
method will be returned.
The quantized model will be exported to ONNX along with other checkpoints.
:param cfg: The parsed DictConfig object from yaml recipe files or a dictionary.
:return: A tuple containing the quantized model and the output of trainer.train() method.
:rtype: Tuple[nn.Module, Tuple]
:raises ValueError: If the recipe does not have the required key `quantization_params` or
`checkpoint_params.checkpoint_path` in it.
:raises NotImplementedError: If the recipe requests multiple GPUs or num_gpus is not equal to 1.
:raises ImportError: If pytorch-quantization import was unsuccessful
"""
if _imported_pytorch_quantization_failure is not None:
raise _imported_pytorch_quantization_failure

# INSTANTIATE ALL OBJECTS IN CFG
cfg = hydra.utils.instantiate(cfg)

# TRIGGER CFG MODIFYING CALLBACKS
cfg = cls._trigger_cfg_modifying_callbacks(cfg)

if "quantization_params" not in cfg:
raise ValueError("Your recipe does not have quantization_params. Add them to use QAT.")

if "checkpoint_path" not in cfg.checkpoint_params:
raise ValueError("Starting checkpoint is a must for QAT finetuning.")

num_gpus = core_utils.get_param(cfg, "num_gpus")
multi_gpu = core_utils.get_param(cfg, "multi_gpu")
device = core_utils.get_param(cfg, "device")
if num_gpus != 1:
raise NotImplementedError(
f"Recipe requests multi_gpu={cfg.multi_gpu} and num_gpus={cfg.num_gpus}. QAT is proven to work correctly only with multi_gpu=OFF and num_gpus=1"
)

setup_device(device=device, multi_gpu=multi_gpu, num_gpus=num_gpus)

# INSTANTIATE DATA LOADERS
train_dataloader = dataloaders.get(
name=get_param(cfg, "train_dataloader"),
dataset_params=copy.deepcopy(cfg.dataset_params.train_dataset_params),
dataloader_params=copy.deepcopy(cfg.dataset_params.train_dataloader_params),
)

val_dataloader = dataloaders.get(
name=get_param(cfg, "val_dataloader"),
dataset_params=copy.deepcopy(cfg.dataset_params.val_dataset_params),
dataloader_params=copy.deepcopy(cfg.dataset_params.val_dataloader_params),
)

if "calib_dataloader" in cfg:
calib_dataloader_name = get_param(cfg, "calib_dataloader")
calib_dataloader_params = copy.deepcopy(cfg.dataset_params.calib_dataloader_params)
calib_dataset_params = copy.deepcopy(cfg.dataset_params.calib_dataset_params)
else:
calib_dataloader_name = get_param(cfg, "train_dataloader")
calib_dataloader_params = copy.deepcopy(cfg.dataset_params.train_dataloader_params)
calib_dataset_params = copy.deepcopy(cfg.dataset_params.train_dataset_params)

# if we use whole dataloader for calibration, don't shuffle it
# HistogramCalibrator collection routine is sensitive to order of batches and produces slightly different results
# if we use several batches, we don't want it to be from one class if it's sequential in dataloader
# model is in eval mode, so BNs will not be affected
calib_dataloader_params.shuffle = cfg.quantization_params.calib_params.num_calib_batches is not None
# we don't need training transforms during calibration, distribution of activations will be skewed
calib_dataset_params.transforms = cfg.dataset_params.val_dataset_params.transforms

calib_dataloader = dataloaders.get(
name=calib_dataloader_name,
dataset_params=calib_dataset_params,
dataloader_params=calib_dataloader_params,
)

# BUILD MODEL
model = models.get(
model_name=cfg.arch_params.get("model_name", None) or cfg.architecture,
num_classes=cfg.get("num_classes", None) or cfg.arch_params.num_classes,
arch_params=cfg.arch_params,
strict_load=cfg.checkpoint_params.strict_load,
pretrained_weights=cfg.checkpoint_params.pretrained_weights,
checkpoint_path=cfg.checkpoint_params.checkpoint_path,
load_backbone=False,
)
model.to(device_config.device)

# QUANTIZE MODEL
model.eval()
fuse_repvgg_blocks_residual_branches(model)

q_util = SelectiveQuantizer(
default_quant_modules_calibrator_weights=cfg.quantization_params.selective_quantizer_params.calibrator_w,
default_quant_modules_calibrator_inputs=cfg.quantization_params.selective_quantizer_params.calibrator_i,
default_per_channel_quant_weights=cfg.quantization_params.selective_quantizer_params.per_channel,
default_learn_amax=cfg.quantization_params.selective_quantizer_params.learn_amax,
verbose=cfg.quantization_params.calib_params.verbose,
)
q_util.register_skip_quantization(layer_names=cfg.quantization_params.selective_quantizer_params.skip_modules)
q_util.quantize_module(model)

# CALIBRATE MODEL
logger.info("Calibrating model...")
calibrator = QuantizationCalibrator(
verbose=cfg.quantization_params.calib_params.verbose,
torch_hist=True,
)
calibrator.calibrate_model(
model,
method=cfg.quantization_params.calib_params.histogram_calib_method,
calib_data_loader=calib_dataloader,
num_calib_batches=cfg.quantization_params.calib_params.num_calib_batches or len(train_dataloader),
percentile=get_param(cfg.quantization_params.calib_params, "percentile", 99.99),
)
calibrator.reset_calibrators(model) # release memory taken by calibrators

# VALIDATE PTQ MODEL AND PRINT SUMMARY
logger.info("Validating PTQ model...")
trainer = Trainer(experiment_name=cfg.experiment_name, ckpt_root_dir=get_param(cfg, "ckpt_root_dir", default_val=None))
valid_metrics_dict = trainer.test(model=model, test_loader=val_dataloader, test_metrics_list=cfg.training_hyperparams.valid_metrics_list)
results = ["PTQ Model Validation Results"]
results += [f" - {metric:10}: {value}" for metric, value in valid_metrics_dict.items()]
logger.info("\n".join(results))

# TRAIN
if cfg.quantization_params.ptq_only:
logger.info("cfg.quantization_params.ptq_only=True. Performing PTQ only!")
suffix = "ptq"
res = None
else:
model.train()
recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
trainer = Trainer(experiment_name=cfg.experiment_name, ckpt_root_dir=get_param(cfg, "ckpt_root_dir", default_val=None))
torch.cuda.empty_cache()

res = trainer.train(
model=model,
train_loader=train_dataloader,
valid_loader=val_dataloader,
training_params=cfg.training_hyperparams,
additional_configs_to_log=recipe_logged_cfg,
)
suffix = "qat"

# EXPORT QUANTIZED MODEL TO ONNX
input_shape = next(iter(val_dataloader))[0].shape
os.makedirs(trainer.checkpoints_dir_path, exist_ok=True)

qdq_onnx_path = os.path.join(trainer.checkpoints_dir_path, f"{cfg.experiment_name}_{'x'.join((str(x) for x in input_shape))}_{suffix}.onnx")
# TODO: modify SG's convert_to_onnx for quantized models and use it instead
export_quantized_module_to_onnx(
model=model.cpu(),
onnx_filename=qdq_onnx_path,
input_shape=input_shape,
input_size=input_shape,
train=False,
)

logger.info(f"Exported {suffix.upper()} ONNX to {qdq_onnx_path}")

return model, res
@deprecated(version="3.2.0", reason="QATTrainer is deprecated and will be removed in future release, use Trainer " "class instead.")
def quantize_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tuple]:
return Trainer.quantize_from_config(cfg)
Loading