Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Fix quantization loop #507

Merged
merged 8 commits into from
Apr 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ jobs:
coverage report -m
# Upload coverage report for python3.8 && pytorch1.12.0 cpu
- name: Upload coverage to Codecov
if: ${{matrix.torch == '1.12.0' && matrix.python-version == '3.8'}}
uses: codecov/codecov-action@v2
if: ${{matrix.torch == '1.13.0' && matrix.python-version == '3.8'}}
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
flags: unittests
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@
_delete_=True,
type='mmrazor.LSQEpochBasedLoop',
max_epochs=100,
val_interval=1)
val_interval=1,
freeze_bn_begin=1)
val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop')
test_cfg = val_cfg

# Make sure the buffer such as min_val/max_val in saved checkpoint is the same
# among different rank.
default_hooks = dict(sync=dict(type='SyncBuffersHook'))
63 changes: 63 additions & 0 deletions configs/quantization/qat/lsq_openvino_resnet18_8xb32_10e_in1k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py']

resnet = _base_.model
float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501

global_qconfig = dict(
w_observer=dict(type='mmrazor.LSQPerChannelObserver'),
a_observer=dict(type='mmrazor.LSQObserver'),
w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'),
a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'),
w_qscheme=dict(
qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True),
a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True),
)

model = dict(
_delete_=True,
_scope_='mmrazor',
type='MMArchitectureQuant',
data_preprocessor=dict(
type='mmcls.ClsDataPreprocessor',
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True),
architecture=resnet,
float_checkpoint=float_checkpoint,
quantizer=dict(
type='mmrazor.OpenVINOQuantizer',
global_qconfig=global_qconfig,
tracer=dict(
type='mmrazor.CustomTracer',
skipped_methods=[
'mmcls.models.heads.ClsHead._get_loss',
'mmcls.models.heads.ClsHead._get_predictions'
])))

optim_wrapper = dict(
optimizer=dict(type='SGD', lr=0.0001, momentum=0.9, weight_decay=0.0001))

# learning policy
param_scheduler = dict(
_delete_=True, type='ConstantLR', factor=1.0, by_epoch=True)

model_wrapper_cfg = dict(
type='mmrazor.MMArchitectureQuantDDP',
broadcast_buffers=False,
find_unused_parameters=True)

# train, val, test setting
train_cfg = dict(
_delete_=True,
type='mmrazor.LSQEpochBasedLoop',
max_epochs=10,
val_interval=1,
freeze_bn_begin=1)
val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop')

# Make sure the buffer such as min_val/max_val in saved checkpoint is the same
# among different rank.
default_hooks = dict(sync=dict(type='SyncBuffersHook'))
62 changes: 62 additions & 0 deletions configs/quantization/qat/qat_openvino_resnet18_10e_8xb32_in1k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py']

resnet = _base_.model
float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501

global_qconfig = dict(
w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'),
a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'),
w_fake_quant=dict(type='mmrazor.FakeQuantize'),
a_fake_quant=dict(type='mmrazor.FakeQuantize'),
w_qscheme=dict(
qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True),
a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True),
)

model = dict(
_delete_=True,
_scope_='mmrazor',
type='MMArchitectureQuant',
data_preprocessor=dict(
type='mmcls.ClsDataPreprocessor',
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True),
architecture=resnet,
float_checkpoint=float_checkpoint,
quantizer=dict(
type='mmrazor.OpenVINOQuantizer',
global_qconfig=global_qconfig,
tracer=dict(
type='mmrazor.CustomTracer',
skipped_methods=[
'mmcls.models.heads.ClsHead._get_loss',
'mmcls.models.heads.ClsHead._get_predictions'
])))

optim_wrapper = dict(
optimizer=dict(type='SGD', lr=0.0001, momentum=0.9, weight_decay=0.0001))

# learning policy
param_scheduler = dict(
_delete_=True, type='ConstantLR', factor=1.0, by_epoch=True)

model_wrapper_cfg = dict(
type='mmrazor.MMArchitectureQuantDDP',
broadcast_buffers=False,
find_unused_parameters=False)

# train, val, test setting
train_cfg = dict(
_delete_=True,
type='mmrazor.QATEpochBasedLoop',
max_epochs=10,
val_interval=1)
val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop')

# Make sure the buffer such as min_val/max_val in saved checkpoint is the same
# among different rank.
default_hooks = dict(sync=dict(type='SyncBuffersHook'))
61 changes: 41 additions & 20 deletions mmrazor/engine/runner/quantization_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
from torch.nn.intrinsic.qat import freeze_bn_stats
except ImportError:
from mmrazor.utils import get_placeholder

disable_observer = get_placeholder('torch>=1.13')
enable_fake_quant = get_placeholder('torch>=1.13')
enable_observer = get_placeholder('torch>=1.13')
freeze_bn_stats = get_placeholder('torch>=1.13')

from mmengine.dist import all_reduce_params, is_distributed
from torch.utils.data import DataLoader

from mmrazor.models import register_torch_fake_quants, register_torch_observers
Expand Down Expand Up @@ -69,7 +71,18 @@ def prepare_for_run_epoch(self):
"""Toggle the state of the observers and fake quantizers before qat
training."""
self.runner.model.apply(enable_fake_quant)
self.runner.model.apply(enable_observer)

# The initialized _epoch equals to 0 so _epoch + 1
# equal to the current epoch
if (self.disable_observer_begin > 0
and self._epoch + 1 >= self.disable_observer_begin):
self.runner.model.apply(disable_observer)
else:
self.runner.model.apply(enable_observer)

if (self.freeze_bn_begin > 0
and self._epoch + 1 >= self.freeze_bn_begin):
self.runner.model.apply(freeze_bn_stats)

def prepare_for_val(self):
"""Toggle the state of the observers and fake quantizers before
Expand All @@ -89,8 +102,6 @@ def run(self):
if (self.runner.val_loop is not None
and self._epoch >= self.val_begin
and self._epoch % self.val_interval == 0):
# observer disabled during evaluation
self.prepare_for_val()
self.runner.val_loop.run()

self.runner.call_hook('after_train')
Expand All @@ -100,18 +111,13 @@ def run_epoch(self) -> None:
self.runner.call_hook('before_train_epoch')
self.runner.model.train()

# The initialized _epoch equals to 0 so _epoch + 1
# equal to the current epoch
if self._epoch + 1 >= self.disable_observer_begin:
self.runner.model.apply(disable_observer)

if self._epoch + 1 >= self.freeze_bn_begin:
self.runner.model.apply(freeze_bn_stats)

for idx, data_batch in enumerate(self.dataloader):
self.run_iter(idx, data_batch)

self.runner.model.sync_qparams(src_mode='loss')
# Make sure the registered buffer such as `observer_enabled` is
# correct in the saved checkpoint.
self.prepare_for_val()
self.runner.call_hook('after_train_epoch')
self._epoch += 1

Expand Down Expand Up @@ -156,11 +162,16 @@ def __init__(
dynamic_intervals=dynamic_intervals)

self.is_first_batch = True
self.distributed = is_distributed()

def prepare_for_run_epoch(self):
"""Toggle the state of the observers and fake quantizers before qat
training."""
pass
if (self.freeze_bn_begin > 0
and self._epoch + 1 >= self.freeze_bn_begin):
self.runner.model.apply(freeze_bn_stats)

self.runner.model.apply(enable_param_learning)

def prepare_for_val(self):
"""Toggle the state of the observers and fake quantizers before
Expand All @@ -172,20 +183,30 @@ def run_epoch(self) -> None:
self.runner.call_hook('before_train_epoch')
self.runner.model.train()

# TODO freeze bn
if self._epoch + 1 >= self.freeze_bn_begin:
self.runner.model.apply(freeze_bn_stats)

for idx, data_batch in enumerate(self.dataloader):
if self.is_first_batch:
# lsq init
self.is_first_batch = False
# lsq observer init
self.runner.model.apply(enable_static_estimate)
else:
self.runner.model.apply(enable_param_learning)

self.run_iter(idx, data_batch)

if self.is_first_batch:
# In the first batch, scale in LearnableFakeQuantize is
# calculated through lsq observer. As the values of `scale` of
# different observers in different rank are usually different,
# we have to sync the `scale` here.
if self.distributed:
all_reduce_params(
self.runner.model.parameters(), op='mean')

# Change back to param learning mode
self.is_first_batch = False
self.runner.model.apply(enable_param_learning)

self.runner.model.sync_qparams(src_mode='loss')
# Make sure the registered buffer such as `observer_enabled` is
# correct in the saved checkpoint.
self.prepare_for_val()
self.runner.call_hook('after_train_epoch')
self._epoch += 1

Expand Down
87 changes: 0 additions & 87 deletions mmrazor/models/losses/adaround_loss.py

This file was deleted.

2 changes: 1 addition & 1 deletion requirements/tests.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
codecov
coverage
flake8
interrogate
isort==4.3.21
Expand Down