Skip to content

Commit

Permalink
Fix (calibrate): fix for minifloat act calibration (#966)
Browse files: browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored Jun 3, 2024
1 parent 0f60606 commit 02f5b6b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
6 changes: 3 additions & 3 deletions src/brevitas/graph/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjectorBase
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjectorBase
-from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
+from brevitas.proxy.runtime_quant import ActQuantProxyFromInjectorBase
from brevitas.proxy.runtime_quant import ClampQuantProxyFromInjector
from brevitas.proxy.runtime_quant import TruncQuantProxyFromInjector
from brevitas.quant_tensor import QuantTensor
Expand Down Expand Up @@ -188,7 +188,7 @@ def disable_act_quantization(self, model, is_training):
# will be discarded through the hook. It is useful for collecting activation stats,
# for example during activation calibration in PTQ
for module in model.modules():
-if isinstance(module, ActQuantProxyFromInjector):
+if isinstance(module, ActQuantProxyFromInjectorBase):
module.train(is_training)
if self.call_act_quantizer_impl:
hook = module.register_forward_hook(self.disable_act_quant_hook)
Expand Down Expand Up @@ -216,7 +216,7 @@ def enable_act_quantization(self, model, is_training):
if isinstance(module, _ACC_PROXIES):
module.train(is_training)
module.disable_quant = False
-elif isinstance(module, ActQuantProxyFromInjector):
+elif isinstance(module, ActQuantProxyFromInjectorBase):
module.disable_quant = False
module.train(is_training)
for hook in self.disable_act_quant_hooks:
Expand Down
15 changes: 8 additions & 7 deletions src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,14 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat):
'sym': CNNInt8DynamicActPerTensorFloat,
'asym': CNNShiftedUint8DynamicActPerTensorFloat}}}}},
'float': {
-'float_scale': {
-'stats': {
-'per_tensor': {
-'sym': Fp8e4m3ActPerTensorFloat}},
-'mse': {
-'per_tensor': {
-'sym': Fp8e4m3ActPerTensorFloatMSE}}}}}
+'static': {
+'float_scale': {
+'stats': {
+'per_tensor': {
+'sym': Fp8e4m3ActPerTensorFloat}},
+'mse': {
+'per_tensor': {
+'sym': Fp8e4m3ActPerTensorFloatMSE}}}}}}


def quantize_model(
Expand Down

0 comments on commit 02f5b6b

Please sign in to comment.