Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
HIT-cwh committed Nov 22, 2022
1 parent 15e9d18 commit a6bb97d
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 43 deletions.
Original file line number Diff line number Diff line change
@@ -1,27 +1,22 @@
_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py']

train_cfg = dict(
_delete_=True,
type='mmrazor.QATEpochBasedLoop',
max_epochs=_base_.train_cfg.max_epochs)
_base_ = ['mmcls::resnet/resnet18_8xb16_cifar10.py']

resnet = _base_.model
ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501
resnet.init_cfg = dict(type='Pretrained', checkpoint=ckpt)
pretrained_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth' # noqa: E501

model = dict(
_delete_=True,
_scope_='mmrazor',
type='GeneralQuant',
# data_preprocessor = dict(
# num_classes=1000,
# # RGB format normalization parameters
# mean=[123.675, 116.28, 103.53],
# std=[58.395, 57.12, 57.375],
# # convert image from BGR to RGB
# to_rgb=True,
# ),
data_preprocessor=dict(
type='mmcls.ClsDataPreprocessor',
num_classes=10,
# RGB format normalization parameters
mean=[125.307, 122.961, 113.8575],
std=[51.5865, 50.847, 51.255],
# loaded images are already RGB format
to_rgb=False),
architecture=resnet,
pretrained_ckpt=pretrained_ckpt,
quantizer=dict(
type='TensorRTQuantizer',
skipped_methods=[
Expand All @@ -30,8 +25,8 @@
],
qconfig=dict(
qtype='affine',
w_observer=dict(type='mmrazor.MinMaxObserver'),
a_observer=dict(type='mmrazor.EMAMinMaxObserver'),
w_observer=dict(type='mmrazor.LSQObserver'),
a_observer=dict(type='mmrazor.LSQObserver'),
w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'),
a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'),
w_qscheme=dict(
Expand Down Expand Up @@ -59,17 +54,17 @@
begin=0,
end=100)

default_hooks = dict(
checkpoint=dict(
type='CheckpointHook',
interval=5,
max_keep_ckpts=3,
out_dir='/mnt/petrelfs/caoweihan.p/training_ckpt/quant'))

model_wrapper_cfg = dict(
type='mmrazor.GeneralQuantDDP',
broadcast_buffers=False,
find_unused_parameters=False)
find_unused_parameters=True)

# train, val, test setting
train_cfg = dict(
_delete_=True,
type='mmrazor.QATEpochBasedLoop',
by_epoch=True,
max_epochs=100,
val_interval=1)
val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop')
test_cfg = val_cfg
5 changes: 2 additions & 3 deletions mmrazor/engine/hooks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dump_subnet_hook import DumpSubnetHook
from .estimate_resources_hook import EstimateResourcesHook
from .visualization_hook import RazorVisualizationHook

# from .quant_hook import QuantitiveHook

__all__ = ['DumpSubnetHook', 'EstimateResourcesHook']
__all__ = ['DumpSubnetHook', 'EstimateResourcesHook', 'RazorVisualizationHook']
4 changes: 2 additions & 2 deletions mmrazor/models/algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .distill import (DAFLDataFreeDistillation, DataFreeDistillation,
FpnTeacherDistill, OverhaulFeatureDistillation,
SelfDistill, SingleTeacherDistill)
from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP, Dsnas, DsnasDDP
from .nas import DSNAS, DSNASDDP, SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP
from .pruning import SlimmableNetwork, SlimmableNetworkDDP
from .pruning.ite_prune_algorithm import ItePruneAlgorithm
from .quantization import GeneralQuant
Expand All @@ -13,5 +13,5 @@
'SlimmableNetwork', 'SlimmableNetworkDDP', 'AutoSlim', 'AutoSlimDDP',
'Darts', 'DartsDDP', 'SelfDistill', 'DataFreeDistillation',
'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation',
'ItePruneAlgorithm', 'Dsnas', 'DsnasDDP', 'GeneralQuant'
'ItePruneAlgorithm', 'DSNAS', 'DSNASDDP', 'GeneralQuant'
]
2 changes: 0 additions & 2 deletions mmrazor/models/observers/minmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ def forward(self, x_orig):
max_val = torch.max(self.max_val, max_val_cur)
self.min_val = min_val
self.max_val = max_val
# self.min_val.copy_(min_val)
# self.max_val.copy_(max_val)

return x

Expand Down
2 changes: 1 addition & 1 deletion mmrazor/models/quantizers/trt_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class TensorRTQuantizer(CustomQuantizer):
"""Quantizer for TensorRT backend."""

def __init__(self,
qconfig=DefalutQconfigs['default'],
qconfig=DefalutQconfigs['tensorrt'],
is_qat=True,
skipped_methods=None,
prepare_custom_config_dict=None,
Expand Down
32 changes: 25 additions & 7 deletions tests/test_models/test_task_modules/test_custom_tracer.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,35 @@
# Copyright (c) OpenMMLab. All rights reserved.
class testCustomTracer:
from unittest import TestCase

def test_init():
pass
from mmrazor.models.task_modules import CustomTracer, UntracedMethodRegistry
from mmrazor.testing import ConvBNReLU


class testCustomTracer(TestCase):

def test_init(self):
tracer = CustomTracer()
assert tracer.skipped_methods.__len__() == 0

def test_trace():
def test_trace(self):
tracer = CustomTracer()
model = ConvBNReLU(3, 3, norm_cfg=dict(type='BN'))
graph = tracer.trace(model) # noqa: F841

def test_auto_skip_call_module(self):
pass

def test_auto_skip_call_module():
def test_auto_skip_call_method(self):
pass

def test_auto_skip_call_method():
def test_configurable_skipped_methods(self):
pass

def test_configurable_skipped_methods():

class testUntracedMethodRgistry(TestCase):

def test_init(self):
self.assertEqual(len(UntracedMethodRegistry.method_dict), 0)

def test_add_method(self):
pass
4 changes: 2 additions & 2 deletions tools/tracer_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mmengine.config import Config
from mmengine.registry import MODELS

from mmrazor.models.task_modules.tracer import custom_symbolic_tracer
from mmrazor.models.task_modules.tracer import custom_symbolic_trace

cfg_path = 'configs/quantization/ptq/demo.py'
_ADAROUND_SUPPORT_TYPE = (torch.nn.Conv2d, torch.nn.Linear)
Expand Down Expand Up @@ -73,7 +73,7 @@ def main():
# load config
cfg = Config.fromfile(cfg_path)
model = MODELS.build(cfg.model)
symbolic_traced = custom_symbolic_tracer(
symbolic_traced = custom_symbolic_trace(
model, concrete_args={'mode': 'tensor'})
# block_slices = extract_blocks(symbolic_traced)
block_slices = extract_layers(
Expand Down

0 comments on commit a6bb97d

Please sign in to comment.