Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat (ptq): learned round support in evaluate/benchmark #639

Merged
merged 10 commits into from
Jul 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/brevitas/core/function_wrapper/learned_round.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
Different implementations for LearnedRound.
"""

from typing import Optional

import torch

import brevitas
from brevitas import config
from brevitas.core.utils import SliceTensor
from brevitas.function.ops_ste import floor_ste


Expand Down Expand Up @@ -56,14 +59,21 @@ class LearnedRoundSte(brevitas.jit.ScriptModule):
"""

def __init__(
        self,
        learned_round_impl: torch.nn.Module,
        learned_round_init: torch.Tensor,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None) -> None:
    """Build the learned-round wrapper around a rounding implementation.

    Args:
        learned_round_impl: Module stored as ``self.learned_round_impl``.
        learned_round_init: Initial values of the learnable rounding parameter.
        device: Optional device the initial values are moved to before being
            registered as a parameter.
        dtype: Optional dtype the initial values are cast to before being
            registered as a parameter.
    """
    super(LearnedRoundSte, self).__init__()
    self.learned_round_impl = learned_round_impl
    # Move/cast first so the registered parameter is created directly with the
    # requested placement; with both arguments left as None the tensor is used as given.
    learned_round_init = learned_round_init.to(device=device, dtype=dtype)
    self.tensor_slicer = SliceTensor()
    self.value = torch.nn.Parameter(learned_round_init)

@brevitas.jit.script_method
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Return ``floor_ste(x)`` plus the learned rounding offset.

    The offset is whatever ``p_forward()`` produces, passed through
    ``self.tensor_slicer`` and cast to the dtype of ``x`` before the addition.
    """
    offset = self.tensor_slicer(self.p_forward())
    return floor_ste(x) + offset.to(x.dtype)

def p_forward(self):
Expand Down
14 changes: 13 additions & 1 deletion src/brevitas/proxy/parameter_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing_extensions import Protocol
from typing_extensions import runtime_checkable

from brevitas import config
from brevitas.function import max_int
from brevitas.quant_tensor import QuantTensor

Expand Down Expand Up @@ -49,12 +50,23 @@ class ParameterQuantProxyFromInjector(QuantProxyFromInjector):
def tracked_parameter_list(self):
pass

def init_tensor_quant(self, preserve_state_dict=False):
    """(Re)build the tensor quantizer for the currently tracked parameters.

    Args:
        preserve_state_dict: When True, the proxy's state dict is captured
            before the quantizer is rebuilt and loaded back afterwards, with
            ``config.REINIT_ON_STATE_DICT_LOAD`` forced to False and
            ``config.IGNORE_MISSING_KEYS`` forced to True for the duration.
            Both flags are restored to their previous values on exit, even
            if the rebuild or the reload raises.
    """
    param_list = self.tracked_parameter_list

    # params might not be there yet, e.g. bias before merging
    if not param_list:
        return

    if preserve_state_dict:
        reinit_on_state_dict = config.REINIT_ON_STATE_DICT_LOAD
        ignore_missing_key = config.IGNORE_MISSING_KEYS
        # NOTE(review): presumably this stops load_state_dict from re-triggering
        # initialization and tolerates keys the snapshot does not contain - confirm.
        config.REINIT_ON_STATE_DICT_LOAD = False
        config.IGNORE_MISSING_KEYS = True
    try:
        if preserve_state_dict:
            state_dict = self.state_dict()
        self.quant_injector = self.quant_injector.let(tracked_parameter_list=param_list)
        super(ParameterQuantProxyFromInjector, self).init_tensor_quant()
        if preserve_state_dict:
            self.load_state_dict(state_dict)
    finally:
        # The flags are process-wide: never leave them altered after a failure.
        if preserve_state_dict:
            config.IGNORE_MISSING_KEYS = ignore_missing_key
            config.REINIT_ON_STATE_DICT_LOAD = reinit_on_state_dict

def max_uint_value(self, bit_width):
return max_int(False, self.is_narrow_range, bit_width)
Expand Down
20 changes: 20 additions & 0 deletions src/brevitas_examples/imagenet_classification/ptq/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Furthermore, Brevitas additional PTQ techniques can be enabled:
- Graph equalization[<sup>1 </sup>].
- If Graph equalization is enabled, the _merge\_bias_ technique can be enabled.[<sup>2 </sup>] [<sup>3 </sup>].
- GPTQ [<sup>4 </sup>].
- Learned Round [<sup>5 </sup>].


Internally, when defining a quantized model programmatically, Brevitas leverages `torch.fx` and its `symbolic_trace` functionality, meaning that an input model is required to pass symbolic tracing for it to work.
Expand Down Expand Up @@ -70,15 +71,24 @@ usage: ptq_evaluate.py [-h] --calibration-dir CALIBRATION_DIR --validation-dir
[--scale-factor-type {float32,po2}]
[--act-bit-width ACT_BIT_WIDTH]
[--weight-bit-width WEIGHT_BIT_WIDTH]
[--layerwise-first-last-bit-width LAYERWISE_FIRST_LAST_BIT_WIDTH]
[--bias-bit-width {int32,int16}]
[--act-quant-type {symmetric,asymmetric}]
[--act-equalization {fx,layerwise,None}]
[--act-quant-calibration-type {percentile,mse}]
[--graph-eq-iterations GRAPH_EQ_ITERATIONS]
[--learned-round-iters LEARNED_ROUND_ITERS]
[--learned-round-lr LEARNED_ROUND_LR]
[--act-quant-percentile ACT_QUANT_PERCENTILE]
[--export-onnx-qcdq] [--export-torch-qcdq]
[--scaling-per-output-channel | --no-scaling-per-output-channel]
[--bias-corr | --no-bias-corr]
[--graph-eq-merge-bias | --no-graph-eq-merge-bias]
[--weight-narrow-range | --no-weight-narrow-range]
[--gptq | --no-gptq]
[--gptq-act-order | --no-gptq-act-order]
[--learned-round | --no-learned-round]
[--calibrate-bn | --no-calibrate-bn]

PyTorch ImageNet PTQ Validation

Expand Down Expand Up @@ -138,6 +148,11 @@ optional arguments:
Activation quantization type (default: symmetric)
--graph-eq-iterations GRAPH_EQ_ITERATIONS
Numbers of iterations for graph equalization (default: 20)
--learned-round-iters LEARNED_ROUND_ITERS
Numbers of iterations for learned round for each layer
(default: 1000)
--learned-round-lr LEARNED_ROUND_LR
Learning rate for learned round (default: 1e-3)
--act-quant-percentile ACT_QUANT_PERCENTILE
Percentile to use for stats of activation quantization (default: 99.999)
--export-onnx-qcdq If true, export the model in onnx qcdq format
Expand All @@ -160,6 +175,10 @@ optional arguments:
--no-gptq Disable GPTQ (default: enabled)
--gptq-act-order Enable GPTQ Act order heuristic (default: disabled)
--no-gptq-act-order Disable GPTQ Act order heuristic (default: disabled)
--learned-round Enable Learned round (default: disabled)
--no-learned-round Disable Learned round (default: disabled)
--calibrate-bn Enable Calibrate BN (default: disabled)
--no-calibrate-bn Disable Calibrate BN (default: disabled)
```

The script requires specifying a calibration folder (`--calibration-dir`), from which the calibration samples will be taken (configurable with the `--calibration-samples` argument), and a validation folder (`--validation-dir`).
Expand Down Expand Up @@ -188,3 +207,4 @@ and a `RESULTS_IMGCLSMOB.csv` with the results on manually quantized models star
[<sup>2 </sup>]: https://github.com/Xilinx/Vitis-AI/blob/50da04ddae396d10a1545823aca30b3abb24a276/src/vai_quantizer/vai_q_pytorch/nndct_shared/optimization/commander.py#L450
[<sup>3 </sup>]: https://github.com/openppl-public/ppq/blob/master/ppq/quantization/algorithm/equalization.py
[<sup>4 </sup>]: https://arxiv.org/abs/2210.17323
[<sup>5 </sup>]: https://arxiv.org/abs/2004.10568
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization
from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_bias_correction
from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gptq
from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_learned_round_learning
from brevitas_examples.imagenet_classification.ptq.ptq_common import calibrate
from brevitas_examples.imagenet_classification.ptq.ptq_common import quantize_model
from brevitas_examples.imagenet_classification.ptq.utils import get_gpu_index
Expand Down Expand Up @@ -54,7 +55,8 @@
'bias_corr': [True], # Bias Correction
'graph_eq_iterations': [0, 20], # Graph Equalization
'graph_eq_merge_bias': [False, True], # Merge bias for Graph Equalization
'act_eq': ['fx', 'layerwise', None], # Perform Activation Equalization (Smoothquant)
'act_equalization': ['fx', 'layerwise', None], # Perform Activation Equalization (Smoothquant)
'learned_round': [False, True], # Enable/Disable Learned Round
'gptq': [False, True], # Enable/Disable GPTQ
'gptq_act_order': [False, True], # Use act_order euristics for GPTQ
'act_quant_percentile': [99.9, 99.99, 99.999], # Activation Quantization Percentile
Expand All @@ -71,7 +73,8 @@
'bias_corr': [True], # Bias Correction
'graph_eq_iterations': [20], # Graph Equalization
'graph_eq_merge_bias': [True], # Merge bias for Graph Equalization
'act_eq': ['fx'], # Perform Activation Equalization (Smoothquant)
'act_equalization': ['fx'], # Perform Activation Equalization (Smoothquant)
'learned_round': [False], # Enable/Disable Learned Round
'gptq': [True], # Enable/Disable GPTQ
'gptq_act_order': [False], # Use act_order euristics for GPTQ
'act_quant_percentile': [99.999], # Activation Quantization Percentile
Expand Down Expand Up @@ -134,8 +137,6 @@ def ptq_torchvision_models(df, args):
return
combination = combinations[args.idx]

# k = 0
# for combination in combinations:
config_namespace = SimpleNamespace()
for key, value in zip(OPTIONS.keys(), combination):
setattr(config_namespace, key, value)
Expand Down Expand Up @@ -203,7 +204,8 @@ def ptq_torchvision_models(df, args):

if config_namespace.act_equalization is not None:
print("Applying activation equalization:")
apply_act_equalization(model, calib_loader, layerwise=args.act_equalization == 'layerwise')
apply_act_equalization(
model, calib_loader, layerwise=config_namespace.act_equalization == 'layerwise')

# Define the quantized model
quant_model = quantize_model(
Expand Down Expand Up @@ -231,6 +233,10 @@ def ptq_torchvision_models(df, args):
print("Performing gptq")
apply_gptq(calib_loader, quant_model, config_namespace.gptq_act_order)

if config_namespace.learned_round:
print("Applying Learned Round:")
apply_learned_round_learning(quant_model, calib_loader)
Giuseppe5 marked this conversation as resolved.
Show resolved Hide resolved

if config_namespace.bias_corr:
print("Applying bias correction")
apply_bias_correction(calib_loader, quant_model)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

# Part of this code has been re-adapted from https://github.com/yhhhli/BRECQ
# under the following LICENSE:

# MIT License

# Copyright (c) 2021 Yuhang Li

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import numpy as np
import torch
import torch.nn.functional as F

from brevitas import config
from brevitas.core.function_wrapper.learned_round import LearnedRoundSte
from brevitas.graph.calibrate import DisableEnableQuantization
from brevitas.inject.enum import FloatToIntImplType
from brevitas.inject.enum import LearnedRoundImplType
from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL

# NOTE(review): this runs at import time and mutates the shared brevitas config,
# so every state-dict load performed in the process after this module is imported
# sees the flag set. Presumably needed so state dicts that predate the inserted
# learned-round parameters still load - confirm, and consider scoping it.
config.IGNORE_MISSING_KEYS = True


class StopFwdException(Exception):
    """Raised to abort a forward pass early; meant to be caught by the code that started it."""


class DataSaverHook:
    """Forward hook that records a module's input (and optionally its output).

    After recording, the hook raises ``StopFwdException`` so that the rest of
    the forward pass is skipped; the caller is expected to catch it.
    """

    def __init__(self, store_output: bool = False):
        """
        Args:
            store_output: When True the module output is recorded as well as
                its input.
        """
        self.store_output = store_output
        self.input_store = None
        self.output_store = None

    def __call__(self, module, input_batch, output_batch):
        """Record the batch seen by ``module`` and abort the forward pass.

        Raises:
            StopFwdException: always, once the data has been stored.
        """
        if self.store_output:
            self.output_store = output_batch
        # The input is recorded unconditionally, as handed to the hook.
        self.input_store = input_batch
        raise StopFwdException


class LinearTempDecay:
    """Piecewise-linear schedule that goes from ``start_b`` down to ``end_b``.

    The value stays at ``start_b`` until step ``rel_start_decay * t_max``, then
    decreases linearly and reaches ``end_b`` at step ``t_max``; it is clamped
    to ``end_b`` afterwards.
    """

    def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 10, end_b: int = 2):
        """
        Args:
            t_max: Step at which the schedule reaches ``end_b``.
            rel_start_decay: Fraction of ``t_max`` spent at ``start_b`` before
                the decay begins.
            start_b: Value returned before the decay begins.
            end_b: Value returned once the decay has finished.
        """
        self.t_max = t_max
        self.start_decay = rel_start_decay * t_max
        self.start_b = start_b
        self.end_b = end_b

    def __call__(self, t):
        """Return the scheduled value at step ``t``."""
        if t < self.start_decay:
            return self.start_b
        decay_span = self.t_max - self.start_decay
        if decay_span <= 0:
            # Empty decay window (rel_start_decay >= 1): there is nothing to
            # interpolate over, so the schedule is already at its final value.
            # Without this guard the division below fails when the span is 0.
            return self.end_b
        rel_t = (t - self.start_decay) / decay_span
        return self.end_b + (self.start_b - self.end_b) * max(0.0, (1 - rel_t))


class Loss:
    """Reconstruction loss plus a rounding regularizer that kicks in after a warm-up.

    Each call counts as one iteration. The reconstruction term is the squared
    error summed over dimension 1 and averaged over what remains. Once the
    iteration count reaches ``max_count * warmup``, the term
    ``weight * sum(1 - |2 * (h - 0.5)| ** b)`` is added, where ``h`` comes from
    ``learned_round_module.p_forward()`` and ``b`` from a ``LinearTempDecay``
    schedule.
    """

    def __init__(
            self,
            module,
            learned_round_module,
            weight=0.01,
            max_count=1000,
            b_range=(20, 2),
            warmup=0.2,
            decay_start=0.0):
        """
        Args:
            module: Layer being optimized; stored as ``self.module``.
            learned_round_module: Object whose ``p_forward()`` supplies the
                values the regularizer is computed on.
            weight: Multiplier of the regularizer.
            max_count: Total number of iterations the schedule spans.
            b_range: ``(start_b, end_b)`` of the temperature schedule.
            warmup: Fraction of ``max_count`` during which the regularizer is 0.
            decay_start: Fraction of the post-warm-up phase to wait before the
                temperature starts decaying.
        """
        self.module = module
        self.learned_round_module = learned_round_module
        self.weight = weight
        self.iter = 0
        self.loss_start = max_count * warmup
        decay_begin = warmup + (1.0 - warmup) * decay_start
        self.temp_decay = LinearTempDecay(
            max_count, start_b=b_range[0], end_b=b_range[1], rel_start_decay=decay_begin)

    def __call__(self, pred, tgt):
        """Return ``(total_loss, rec_loss, round_loss, b)`` for one iteration."""
        self.iter += 1

        reconstruction = F.mse_loss(pred, tgt, reduction='none').sum(1).mean()
        temperature = self.temp_decay(self.iter)

        if self.iter < self.loss_start:
            # Warm-up: only the reconstruction error is optimized.
            regularizer = 0
        else:
            # 1 - |(h-0.5)*2|**b
            soft_values = self.learned_round_module.p_forward()
            distance = ((soft_values - 0.5).abs() * 2).pow(temperature)
            regularizer = self.weight * (1 - distance).sum()

        return reconstruction + regularizer, reconstruction, regularizer, temperature


def find_learned_round_module(module):
    """Return the first ``LearnedRoundSte`` found in ``module.modules()``, or False if none."""
    matches = (candidate for candidate in module.modules() if isinstance(candidate, LearnedRoundSte))
    return next(matches, False)


def insert_learned_round_quantizer(layer, learned_round_zeta=1.1, learned_round_gamma=-0.1):
    """Switch the weight quantizer of ``layer`` to learned rounding, in place.

    Nothing happens when ``layer`` is not a ``QuantWBIOL`` or already contains
    a ``LearnedRoundSte`` module.

    Args:
        layer: Candidate layer.
        learned_round_zeta: Upper stretch parameter passed to the quantizer.
        learned_round_gamma: Lower stretch parameter passed to the quantizer.
    """
    if not isinstance(layer, QuantWBIOL):
        return
    if find_learned_round_module(layer):
        return

    # Evaluate the quantizer once so the floor and the fractional part are
    # derived from exactly the same scale.
    scaled_weight = layer.weight.data / layer.quant_weight().scale
    delta = scaled_weight - torch.floor(scaled_weight)
    # Inverse of sigmoid(v) * (zeta - gamma) + gamma evaluated at delta, i.e. the
    # v for which that expression gives back the current fractional part.
    value = -torch.log((learned_round_zeta - learned_round_gamma) /
                       (delta - learned_round_gamma) - 1)
    layer.weight_quant.quant_injector = layer.weight_quant.quant_injector.let(
        float_to_int_impl_type=FloatToIntImplType.LEARNED_ROUND,
        learned_round_impl_type=LearnedRoundImplType.HARD_SIGMOID,
        learned_round_gamma=learned_round_gamma,
        learned_round_zeta=learned_round_zeta,
        learned_round_init=value)
    # Rebuild the quantizer with the new settings while keeping the values of
    # the existing state dict.
    layer.weight_quant.init_tensor_quant(preserve_state_dict=True)


def split_layers(model, layers):
    """Append every ``QuantWBIOL`` found under ``model`` to ``layers``, in traversal order.

    Children are visited depth-first in ``children()`` order. A ``QuantWBIOL``
    is collected without being descended into; any other module is expanded
    into its own children. ``layers`` is mutated in place and nothing is returned.
    """
    # Explicit stack; children are pushed reversed so they pop in original order.
    pending = list(model.children())[::-1]
    while pending:
        node = pending.pop()
        if isinstance(node, QuantWBIOL):
            layers.append(node)
            continue
        pending.extend(reversed(list(node.children())))


def learned_round_iterator(layers, iters=1000):
    """Yield ``(layer, loss, learned_round_module)`` for each layer, ready for optimization.

    Before yielding, the layer gets a learned-round quantizer (if it does not
    have one), all of its parameters are frozen, and only the learned-round
    ``value`` is made trainable again. ``layer.eval()`` is called when the
    generator is resumed, i.e. after the consumer has finished with that layer.

    Args:
        layers: Iterable of layers to process in order.
        iters: Iteration budget handed to ``Loss`` as ``max_count``.
    """
    for current in layers:
        insert_learned_round_quantizer(current)

        for param in current.parameters():
            param.requires_grad = False

        # NOTE(review): find_learned_round_module returns False when nothing is
        # found, in which case the attribute access below fails.
        rounding = find_learned_round_module(current)
        rounding.value.requires_grad = True
        objective = Loss(module=current, learned_round_module=rounding, max_count=iters)
        yield current, objective, rounding
        current.eval()


def save_inp_out_data(
        model,
        module,
        dataloader: torch.utils.data.DataLoader,
        store_inp=False,
        store_out=False,
        keep_gpu: bool = True,
        disable_quant=False):
    """Collect the inputs and/or outputs seen by ``module`` while ``model`` runs over ``dataloader``.

    A ``DataSaverHook`` is registered on ``module``; it aborts every forward
    pass as soon as ``module`` has run, so the layers after it are not executed.

    Args:
        model: Model that is called on each batch.
        module: Submodule of ``model`` whose data is captured.
        dataloader: Yields ``(images, target)`` pairs; only the images are used.
        store_inp: Collect the first positional input of ``module``.
        store_out: Collect the output of ``module``.
        keep_gpu: When False, captured tensors are moved to the CPU; otherwise
            they stay on the device they were produced on.
        disable_quant: Disable activation and parameter quantization on
            ``model`` while collecting, and re-enable both afterwards.

    Returns:
        A two-element list ``[inputs, outputs]``. Each entry is the batches
        concatenated with ``torch.cat`` when requested, otherwise an empty list.
    """
    if disable_quant:
        disable_quant_class = DisableEnableQuantization()
        disable_quant_class.disable_act_quantization(model, False)
        disable_quant_class.disable_param_quantization(model, False)
    try:
        device = next(model.parameters()).device
        data_saver = DataSaverHook(store_output=store_out)
        handle = module.register_forward_hook(data_saver)
        cached = [[], []]
        try:
            with torch.no_grad():
                for img, _ in dataloader:
                    try:
                        model(img.to(device))
                    except StopFwdException:
                        # Expected: the hook stops the pass once the data is captured.
                        pass
                    # NOTE(review): if `module` is never reached, the hook does not
                    # fire and the stores still hold the previous batch (or None).
                    if store_inp:
                        captured_in = data_saver.input_store[0].detach()
                        cached[0].append(captured_in if keep_gpu else captured_in.cpu())
                    if store_out:
                        captured_out = data_saver.output_store.detach()
                        cached[1].append(captured_out if keep_gpu else captured_out.cpu())
        finally:
            # Always unregister, otherwise a failure would leave a hook behind
            # that aborts every later forward pass of the model.
            handle.remove()
        if store_inp:
            cached[0] = torch.cat(cached[0])
        if store_out:
            cached[1] = torch.cat(cached[1])
    finally:
        # Restore quantization even when collection failed.
        if disable_quant:
            disable_quant_class.enable_act_quantization(model, False)
            disable_quant_class.enable_param_quantization(model, False)
    return cached
Loading
Loading