From 43bb99d2a20b035222964ec012b48d6714443b89 Mon Sep 17 00:00:00 2001
From: humu789
Date: Tue, 11 Apr 2023 17:11:09 +0800
Subject: [PATCH] fix lint

---
 mmrazor/engine/runner/quantization_loops.py     | 3 ++-
 mmrazor/models/fake_quants/torch_fake_quants.py | 2 +-
 mmrazor/models/observers/torch_observers.py     | 1 +
 3 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py
index e392070c3..d694f3da8 100644
--- a/mmrazor/engine/runner/quantization_loops.py
+++ b/mmrazor/engine/runner/quantization_loops.py
@@ -18,14 +18,15 @@
 
 from torch.utils.data import DataLoader
 
+from mmrazor.models import register_torch_fake_quants, register_torch_observers
 from mmrazor.models.fake_quants import (enable_param_learning,
                                         enable_static_estimate, enable_val)
 from mmrazor.registry import LOOPS
-from mmrazor.models import register_torch_fake_quants, register_torch_observers
 
 TORCH_observers = register_torch_observers()
 TORCH_fake_quants = register_torch_fake_quants()
 
+
 @LOOPS.register_module()
 class QATEpochBasedLoop(EpochBasedTrainLoop):
     """`EpochBasedLoop` for `QuantizationAwareTraining`
diff --git a/mmrazor/models/fake_quants/torch_fake_quants.py b/mmrazor/models/fake_quants/torch_fake_quants.py
index e7420a8d3..06e325b32 100644
--- a/mmrazor/models/fake_quants/torch_fake_quants.py
+++ b/mmrazor/models/fake_quants/torch_fake_quants.py
@@ -10,6 +10,7 @@
     from mmrazor.utils import get_package_placeholder
     torch_fake_quant_src = get_package_placeholder('torch>=1.13')
 
+
 # TORCH_fake_quants = register_torch_fake_quants()
 # TORCH_fake_quants including:
 # FakeQuantize
@@ -35,4 +36,3 @@ def register_torch_fake_quants() -> List[str]:
             MODELS.register_module(module=_fake_quant)
             torch_fake_quants.append(module_name)
     return torch_fake_quants
-
diff --git a/mmrazor/models/observers/torch_observers.py b/mmrazor/models/observers/torch_observers.py
index 2c2d49382..4e540667a 100644
--- a/mmrazor/models/observers/torch_observers.py
+++ b/mmrazor/models/observers/torch_observers.py
@@ -30,6 +30,7 @@ def reset_min_max_vals(self):
 
 PerChannelMinMaxObserver.reset_min_max_vals = reset_min_max_vals
 
+
 # TORCH_observers = register_torch_observers()
 # TORCH_observers including:
 # FixedQParamsObserver
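
Note (reviewer sketch, not part of the patch): the two module-level calls moved in
quantization_loops.py register every torch fake-quant/observer class with mmrazor's
MODELS registry as a side effect of importing the loop module. A minimal usage sketch,
assuming torch>=1.13 is installed so the real torch.ao.quantization classes (rather
than the placeholders) get registered; the class names used below come from the
comment blocks and hunk context in this patch:

    # Minimal sketch, assuming torch>=1.13; names such as 'FakeQuantize'
    # and 'PerChannelMinMaxObserver' appear in this patch's diff context.
    from mmrazor.models import (register_torch_fake_quants,
                                register_torch_observers)
    from mmrazor.registry import MODELS

    # Each helper scans the corresponding torch module, registers every
    # qualifying class in MODELS and returns the registered names.
    TORCH_observers = register_torch_observers()
    TORCH_fake_quants = register_torch_fake_quants()  # e.g. ['FakeQuantize', ...]

    # Registered classes can then be built from config dicts as usual:
    observer = MODELS.build(dict(type='PerChannelMinMaxObserver'))
    observer.reset_min_max_vals()  # method monkey-patched in torch_observers.py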