[Features] Quantize pipeline (open-mmlab#350)
* init demo

* add custom_tracer

* add quantizer

* add fake_quant, loop, config

* remove CPatcher in custom_tracer

* demo_try

* init version

* modified base.py

* pre-rebase

* wip of adaround series

* adaround experiment

* transfer to s2

* update api

* point at sub_reconstruction

* pre-checkout

* export onnx

* add customtracer

* fix lint

* move custom tracer

* fix import

* update

* updated

* retina loss & predict & tensor DONE

* for RFC

* Customized FX initialize

* add UT init

* TODO: UTs

* Successfully RUN

* update loop

* update loop docstrings

* update quantizer docstrings

* update qscheme docstrings

* update qobserver docstrings

* update tracer docstrings

* update UTs init

* fix bugs

* fix lsq

* refactor quantize pipeline

* fix quant

* WIP: debug qat

* fix lsq bugs

* fix qat, docstring in progress

* fixed DefaultQconfigs name

* fix bugs

* add comments and fix typos

* delete useless codes

* fix bugs and add comments

* rename prepare_module_dict

* update lsq config

Co-authored-by: humu789 <humu@pjlab.org.cn>
Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com>
Co-authored-by: FreakieHuang <frank0huang@foxmail.com>
Co-authored-by: pppppM <gjf_mail@126.com>
5 people committed Jan 9, 2023
1 parent ffb8247 commit 6b1e482
Showing 23 changed files with 739 additions and 154 deletions.
8 changes: 2 additions & 6 deletions configs/quantization/ptq/adaround.py
@@ -1,12 +1,8 @@
- _base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py']
+ _base_ = ['mmcls::resnet/resnet18_8xb16_cifar10.py']

test_cfg = dict(
    _delete_=True,
    type='mmrazor.PTQLoop',
    dataloader=_base_.test_dataloader,
    evaluator=_base_.test_evaluator,
    calibrate_dataloader=_base_.train_dataloader,
    batch_num=32,

    # reconstruction_cfg=dict(
    #     pattern='layer',
    #     loss=dict(
…
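For context, the PTQLoop configured above boils down to a calibration pass: a fixed number of batches from calibrate_dataloader are forwarded through the observed model so the observers can record activation ranges before evaluation. A minimal sketch of that idea in plain PyTorch, independent of mmrazor's actual loop implementation (the function name and dataloader unpacking here are illustrative assumptions):

import torch

def calibrate(model, calibrate_dataloader, batch_num=32):
    """Forward a fixed number of batches so attached observers can
    collect activation statistics (mirrors batch_num=32 above)."""
    model.eval()
    with torch.no_grad():
        for i, (inputs, _) in enumerate(calibrate_dataloader):
            if i >= batch_num:
                break
            model(inputs)  # observers record min/max during forward
    return model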
1 change: 0 additions & 1 deletion configs/quantization/qat/demo.py

This file was deleted.

70 changes: 70 additions & 0 deletions configs/quantization/qat/lsq_resnet18_8xb16_cifar10.py
@@ -0,0 +1,70 @@
_base_ = ['mmcls::resnet/resnet18_8xb16_cifar10.py']

resnet = _base_.model
pretrained_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth' # noqa: E501

model = dict(
    _delete_=True,
    _scope_='mmrazor',
    type='GeneralQuant',
    data_preprocessor=dict(
        type='mmcls.ClsDataPreprocessor',
        num_classes=10,
        # RGB format normalization parameters
        mean=[125.307, 122.961, 113.8575],
        std=[51.5865, 50.847, 51.255],
        # loaded images are already RGB format
        to_rgb=False),
    architecture=resnet,
    pretrained_ckpt=pretrained_ckpt,
    quantizer=dict(
        type='CustomQuantizer',
        skipped_methods=[
            'mmcls.models.heads.ClsHead._get_loss',
            'mmcls.models.heads.ClsHead._get_predictions'
        ],
        qconfig=dict(
            qtype='affine',
            w_observer=dict(type='mmrazor.LSQObserver'),
            a_observer=dict(type='mmrazor.LSQObserver'),
            w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'),
            a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'),
            w_qscheme=dict(
                bit=8,
                is_symmetry=False,
                is_per_channel=False,
                is_pot_scale=False,
            ),
            a_qscheme=dict(
                bit=8,
                is_symmetry=False,
                is_per_channel=False,
                is_pot_scale=False),
        )))

optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.004, momentum=0.9, weight_decay=0.0001))

# learning policy
param_scheduler = dict(
    _delete_=True,
    type='CosineAnnealingLR',
    T_max=100,
    by_epoch=True,
    begin=0,
    end=100)

model_wrapper_cfg = dict(
    type='mmrazor.GeneralQuantDDP',
    broadcast_buffers=False,
    find_unused_parameters=True)

# train, val, test setting
train_cfg = dict(
    _delete_=True,
    type='mmrazor.QATEpochBasedLoop',
    by_epoch=True,
    max_epochs=100,
    val_interval=1)
val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop')
test_cfg = val_cfg
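A config like this is normally consumed through MMEngine's Runner. A minimal launch sketch, assuming mmrazor and mmcls are installed and the file sits at the path shown in the diff header (the work_dir value is a hypothetical choice, not part of this commit):

from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('configs/quantization/qat/lsq_resnet18_8xb16_cifar10.py')
cfg.work_dir = 'work_dirs/lsq_resnet18_8xb16_cifar10'  # hypothetical output dir

runner = Runner.from_cfg(cfg)
runner.train()  # dispatches to mmrazor.QATEpochBasedLoop per train_cfg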
75 changes: 75 additions & 0 deletions configs/quantization/qat/lsq_resnet18_8xb32_in1k.py
@@ -0,0 +1,75 @@
_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py']

train_cfg = dict(
    _delete_=True,
    type='mmrazor.QATEpochBasedLoop',
    max_epochs=_base_.train_cfg.max_epochs)

resnet = _base_.model
ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth'  # noqa: E501
resnet.init_cfg = dict(type='Pretrained', checkpoint=ckpt)

model = dict(
    _delete_=True,
    _scope_='mmrazor',
    type='GeneralQuant',
    # data_preprocessor = dict(
    #     num_classes=1000,
    #     # RGB format normalization parameters
    #     mean=[123.675, 116.28, 103.53],
    #     std=[58.395, 57.12, 57.375],
    #     # convert image from BGR to RGB
    #     to_rgb=True,
    # ),
    architecture=resnet,
    quantizer=dict(
        type='CustomQuantizer',
        skipped_methods=[
            'mmcls.models.heads.ClsHead._get_loss',
            'mmcls.models.heads.ClsHead._get_predictions'
        ],
        qconfig=dict(
            qtype='affine',
            w_observer=dict(type='mmrazor.MinMaxObserver'),
            a_observer=dict(type='mmrazor.EMAMinMaxObserver'),
            w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'),
            a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'),
            w_qscheme=dict(
                bit=8,
                is_symmetry=False,
                is_per_channel=False,
                is_pot_scale=False,
            ),
            a_qscheme=dict(
                bit=8,
                is_symmetry=False,
                is_per_channel=False,
                is_pot_scale=False),
        )))

optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.004, momentum=0.9, weight_decay=0.0001))

# learning policy
param_scheduler = dict(
    _delete_=True,
    type='CosineAnnealingLR',
    T_max=100,
    by_epoch=True,
    begin=0,
    end=100)

default_hooks = dict(
    checkpoint=dict(
        type='CheckpointHook',
        interval=5,
        max_keep_ckpts=3,
        out_dir='/mnt/petrelfs/caoweihan.p/training_ckpt/quant'))

model_wrapper_cfg = dict(
    type='mmrazor.GeneralQuantDDP',
    broadcast_buffers=False,
    find_unused_parameters=False)

val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop')
test_cfg = val_cfg
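The LearnableFakeQuantize entries in both configs point at LSQ-style quantization, where the quantization scale is a trainable parameter updated through a straight-through estimator. A compact sketch of the underlying transform for the symmetric case (the configs above use an affine scheme, which adds a zero-point offset; this simplification is mine, not mmrazor's code):

import torch

def lsq_fake_quant(x: torch.Tensor, scale: torch.Tensor,
                   quant_min: int = -128, quant_max: int = 127) -> torch.Tensor:
    """Symmetric LSQ fake quantization (Esser et al., ICLR 2020)."""
    # Gradient scaling keeps updates to `scale` well-conditioned.
    g = 1.0 / (x.numel() * quant_max) ** 0.5
    # Same forward value as `scale`, but its gradient is multiplied by g.
    s = (scale - scale * g).detach() + scale * g
    x_q = torch.clamp(x / s, quant_min, quant_max)
    # Straight-through estimator: round() in the forward pass,
    # identity in the backward pass.
    x_q = (torch.round(x_q) - x_q).detach() + x_q
    return x_q * s  # dequantize back to floating point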
37 changes: 0 additions & 37 deletions configs/quantization/qat/lsq_resnet50_8xb16_cifar10.py

This file was deleted.

… (18 more changed files not shown)
