
Commit 30b36b8

Add pt2e dynamic quantization (#1795)
Signed-off-by: yiliu30 <yi4.liu@intel.com>
1 parent d4b0f0f commit 30b36b8
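
In short, this commit threads a new is_dynamic flag from a DynamicQuantConfig, through the algorithm-entry dispatch, down to torch's X86InductorQuantizer, and renames W8A8StaticQuantizer to W8A8PT2EQuantizer since it now serves both flavors. A minimal usage sketch of the dynamic path (the toy model is illustrative, and convert() is assumed from the base Quantizer API exercised in the repo's tests, not shown in this diff):

import torch
from neural_compressor.torch.export import export_model_for_pt2e_quant
from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
example_inputs = (torch.randn(2, 8),)
exported_model = export_model_for_pt2e_quant(model, example_inputs=example_inputs)

# pt2e_dynamic_quant_entry flips this class-level flag before building the quantizer
W8A8PT2EQuantizer.is_dynamic = True
quantizer = W8A8PT2EQuantizer()  # quant_config=None -> default X86Inductor dynamic config
prepared_model = quantizer.prepare(exported_model, example_inputs=example_inputs)
# dynamic quantization derives activation scales at runtime, so no calibration loop
converted_model = quantizer.convert(prepared_model)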


10 files changed: +167 additions, -41 deletions


neural_compressor/torch/algorithms/pt2e_quant/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -13,4 +13,4 @@
 # limitations under the License.
 
 
-from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8StaticQuantizer
+from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer

neural_compressor/torch/algorithms/pt2e_quant/core.py

Lines changed: 9 additions & 5 deletions
@@ -18,9 +18,7 @@
 
 from typing import Any
 
-import torch
 import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
-from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
 from torch.fx.graph_module import GraphModule
@@ -30,15 +28,21 @@
 from neural_compressor.torch.utils import create_xiq_quantizer_from_pt2e_config
 
 
-class W8A8StaticQuantizer(Quantizer):
+class W8A8PT2EQuantizer(Quantizer):
+    is_dynamic = False
+
+    def __init__(self, quant_config=None):
+        super().__init__(quant_config)
 
     @staticmethod
     def update_quantizer_based_on_quant_config(quant_config=None) -> X86InductorQuantizer:
         if not quant_config:
             quantizer = X86InductorQuantizer()
-            quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
+            quantizer.set_global(
+                xiq.get_default_x86_inductor_quantization_config(is_dynamic=W8A8PT2EQuantizer.is_dynamic)
+            )
         else:
-            quantizer = create_xiq_quantizer_from_pt2e_config(quant_config)
+            quantizer = create_xiq_quantizer_from_pt2e_config(quant_config, is_dynamic=W8A8PT2EQuantizer.is_dynamic)
         return quantizer
 
     def prepare(self, model: GraphModule, example_inputs=None, inplace=True, *args, **kwargs) -> GraphModule:
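
The is_dynamic switch is a class attribute rather than a constructor argument, so pt2e_static_quant_entry's call sites stay unchanged while pt2e_dynamic_quant_entry just flips the flag. A sketch of what the default path (quant_config=None) now resolves to, assuming torch>=2.3 where get_default_x86_inductor_quantization_config accepts is_dynamic:

import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer

W8A8PT2EQuantizer.is_dynamic = True
quantizer = W8A8PT2EQuantizer.update_quantizer_based_on_quant_config(None)
# equivalent to:
#   quantizer = xiq.X86InductorQuantizer()
#   quantizer.set_global(xiq.get_default_x86_inductor_quantization_config(is_dynamic=True))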

neural_compressor/torch/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -35,6 +35,8 @@
     get_default_fp8_config,
     get_default_fp8_config_set,
     get_woq_tuning_config,
+    DynamicQuantConfig,
+    get_default_dynamic_config,
 )
 
 from neural_compressor.torch.quantization.autotune import (

neural_compressor/torch/quantization/algorithm_entry.py

Lines changed: 24 additions & 4 deletions
@@ -14,7 +14,7 @@
 
 from copy import deepcopy
 from types import MethodType
-from typing import Any, Callable, Dict, Tuple
+from typing import Callable, Dict, Tuple
 
 import torch
 
@@ -42,7 +42,7 @@
     TEQConfig,
 )
 from neural_compressor.torch.utils import get_quantizer, is_ipex_imported, logger, postprocess_model, register_algo
-from neural_compressor.torch.utils.constants import PT2E_STATIC_QUANT
+from neural_compressor.torch.utils.constants import PT2E_DYNAMIC_QUANT, PT2E_STATIC_QUANT
 
 
 ###################### RTN Algo Entry ##################################
@@ -186,19 +186,39 @@ def static_quant_entry(
     return model
 
 
+###################### PT2E Dynamic Quant Algo Entry ##################################
+@register_algo(name=PT2E_DYNAMIC_QUANT)
+@torch.no_grad()
+def pt2e_dynamic_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode, *args, **kwargs) -> torch.nn.Module:
+    logger.info("Quantize model with the PT2E dynamic quant algorithm.")
+    from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
+
+    run_fn = kwargs.get("run_fn", None)
+    example_inputs = kwargs.get("example_inputs", None)
+    inplace = kwargs.get("inplace", True)
+    W8A8PT2EQuantizer.is_dynamic = True
+    for _, quant_config in configs_mapping.items():
+        if quant_config.name == PT2E_DYNAMIC_QUANT:
+            w8a8_quantizer = W8A8PT2EQuantizer(quant_config=quant_config)
+            model = w8a8_quantizer.execute(
+                model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace
+            )
+    return model
+
+
 ###################### PT2E Static Quant Algo Entry ##################################
 @register_algo(name=PT2E_STATIC_QUANT)
 @torch.no_grad()
 def pt2e_static_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode, *args, **kwargs) -> torch.nn.Module:
     logger.info("Quantize model with the PT2E static quant algorithm.")
-    from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8StaticQuantizer
+    from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
 
     run_fn = kwargs.get("run_fn", None)
     example_inputs = kwargs.get("example_inputs", None)
     inplace = kwargs.get("inplace", True)
     for _, quant_config in configs_mapping.items():
         if quant_config.name == STATIC_QUANT:
-            w8a8_quantizer = W8A8StaticQuantizer(quant_config=quant_config)
+            w8a8_quantizer = W8A8PT2EQuantizer(quant_config=quant_config)
             model = w8a8_quantizer.execute(
                 model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace
             )
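
Both entries are registered by name via register_algo, and the dynamic entry only acts on configs whose name matches PT2E_DYNAMIC_QUANT. The mapping it iterates comes from DynamicQuantConfig.to_config_mapping (added in config.py below); a quick sketch of its shape:

from neural_compressor.torch.quantization import DynamicQuantConfig

cfg = DynamicQuantConfig()
configs_mapping = cfg.to_config_mapping()
# OrderedDict([('pt2e_dynamic_quant', cfg)]) -- keys are ignored by the entry;
# only quant_config.name == PT2E_DYNAMIC_QUANT is checked before quantizing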

neural_compressor/torch/quantization/config.py

Lines changed: 76 additions & 1 deletion
@@ -17,7 +17,7 @@
 # pylint:disable=import-error
 
 from collections import OrderedDict
-from typing import Any, Callable, Dict, List, NamedTuple, Optional
+from typing import Callable, Dict, List, NamedTuple, Optional
 from typing import OrderedDict as OrderedDictType
 from typing import Tuple, Union
 
@@ -50,6 +50,7 @@
     PRIORITY_HQQ,
     PRIORITY_RTN,
     PRIORITY_TEQ,
+    PT2E_DYNAMIC_QUANT,
 )
 
 __all__ = [
@@ -778,6 +779,80 @@ def get_default_AutoRound_config() -> AutoRoundConfig:
     return AutoRoundConfig()
 
 
+######################## Dynamic Quant Config ###############################
+@register_config(framework_name=FRAMEWORK_NAME, algo_name=PT2E_DYNAMIC_QUANT)
+class DynamicQuantConfig(BaseConfig):
+    """Config class for dynamic quantization."""
+
+    name = PT2E_DYNAMIC_QUANT
+    params_list = [
+        "w_dtype",
+        "w_sym",
+        "w_granularity",
+        "w_algo",
+        "act_dtype",
+        "act_sym",
+        "act_granularity",
+        "act_algo",
+    ]
+    supported_configs: List[OperatorConfig] = []
+
+    def __init__(
+        self,
+        w_dtype: str = "int8",
+        w_sym: bool = True,
+        w_granularity: str = "per_tensor",
+        w_algo: str = "minmax",
+        act_dtype: str = "uint8",
+        act_sym: bool = False,
+        act_granularity: str = "per_tensor",
+        act_algo: str = "kl",
+        white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
+    ):
+        """Init Dynamic Quant Configs."""
+        super().__init__(white_list=white_list)
+        self.w_dtype = w_dtype
+        self.w_sym = w_sym
+        self.w_granularity = w_granularity
+        self.w_algo = w_algo
+        self.act_dtype = act_dtype
+        self.act_sym = act_sym
+        self.act_granularity = act_granularity
+        self.act_algo = act_algo
+        self._post_init()
+
+    @classmethod
+    def register_supported_configs(cls) -> List[OperatorConfig]:
+        supported_configs = []
+        linear_static_config = cls()
+        operators = [torch.nn.Linear]
+        supported_configs.append(OperatorConfig(config=linear_static_config, operators=operators))
+        cls.supported_configs = supported_configs
+
+    @staticmethod
+    def get_model_info(model: torch.nn.Module, example_inputs=None):
+        return None
+
+    def to_config_mapping(
+        self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
+    ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
+        config_mapping = OrderedDict({self.name: self})
+        return config_mapping
+
+    @classmethod
+    def get_config_set_for_tuning(cls) -> Union[None, "DynamicQuantConfig", List["DynamicQuantConfig"]]:
+        return cls(act_sym=[True, False], act_algo=["kl", "minmax"])
+
+
+def get_default_dynamic_config() -> DynamicQuantConfig:
+    """Generate the default dynamic quant config.
+
+    Returns:
+        the default dynamic quant config.
+    """
+    return DynamicQuantConfig()
+
+
 ######################## Static Quant Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=STATIC_QUANT)
 class StaticQuantConfig(BaseConfig):
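
DynamicQuantConfig mirrors StaticQuantConfig's weight/activation knobs. For example, per the defaults in __init__ above:

from neural_compressor.torch.quantization import DynamicQuantConfig, get_default_dynamic_config

default_cfg = get_default_dynamic_config()  # same as DynamicQuantConfig(); act_algo="kl"
custom_cfg = DynamicQuantConfig(act_sym=True, act_algo="minmax")
tuning_set = DynamicQuantConfig.get_config_set_for_tuning()
# one config whose act_sym/act_algo fields hold candidate lists for the autotuner to expand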

neural_compressor/torch/utils/constants.py

Lines changed: 1 addition & 0 deletions
@@ -52,3 +52,4 @@
 
 
 PT2E_STATIC_QUANT = "pt2e_static_quant"
+PT2E_DYNAMIC_QUANT = "pt2e_dynamic_quant"

neural_compressor/torch/utils/environ.py

Lines changed: 9 additions & 9 deletions
@@ -31,20 +31,20 @@ def is_hpex_available():
     return _hpex_available
 
 
-try:
-    import intel_extension_for_pytorch as ipex
-
-    _ipex_available = True
-except:
-    _ipex_available = False
-
-
 def is_ipex_available():
+    try:
+        import intel_extension_for_pytorch as ipex
+
+        _ipex_available = True
+    except:
+        _ipex_available = False
     return _ipex_available
 
 
 def get_ipex_version():
-    if _ipex_available:
+    if is_ipex_available():
+        import intel_extension_for_pytorch as ipex
+
         try:
             ipex_version = ipex.__version__.split("+")[0]
         except ValueError as e:  # pragma: no cover
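
Moving the probe inside is_ipex_available() makes the intel_extension_for_pytorch import lazy: nothing is imported (or fails) when the module loads, only on first call. Usage is unchanged; a short sketch:

from neural_compressor.torch.utils.environ import get_ipex_version, is_ipex_available

if is_ipex_available():        # the import is attempted here, not at module import time
    print(get_ipex_version())  # version string with any "+cpu"-style suffix stripped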

neural_compressor/torch/utils/utility.py

Lines changed: 20 additions & 9 deletions
@@ -17,7 +17,7 @@
 
 import torch
 import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
-from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
+from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver, PlaceholderObserver
 from torch.ao.quantization.quantizer import QuantizationSpec
 from torch.ao.quantization.quantizer.x86_inductor_quantizer import QuantizationConfig, X86InductorQuantizer
 from typing_extensions import TypeAlias
@@ -172,30 +172,41 @@ def postprocess_model(model, mode, quantizer):
         del model.quantizer
 
 
-def create_quant_spec_from_config(dtype, sym, granularity, algo) -> QuantizationSpec:
+def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=False) -> QuantizationSpec:
     dtype_mapping: Dict[str, torch.dtype] = {"int8": torch.int8, "uint8": torch.uint8}
+    select_dtype = dtype_mapping[dtype]
+    min_max_mapping = {torch.int8: (-128, 127), torch.uint8: (0, 255)}
     qscheme_mapping = {
         "per_channel": {True: torch.per_channel_symmetric, False: torch.per_tensor_affine},
         "per_tensor": {True: torch.per_tensor_symmetric, False: torch.per_tensor_affine},
     }
     observer_mapping = {
+        "placeholder": PlaceholderObserver,
         "minmax": MinMaxObserver,
        "kl": HistogramObserver,
     }
+    # Force to use placeholder observer for dynamic quantization
+    if is_dynamic:
+        algo = "placeholder"
     # algo
     observer_or_fake_quant_ctr = observer_mapping[algo]
     # qscheme
     qscheme = qscheme_mapping[granularity][sym]
     quantization_spec = QuantizationSpec(
-        dtype=dtype_mapping[dtype], observer_or_fake_quant_ctr=observer_or_fake_quant_ctr, qscheme=qscheme
+        dtype=select_dtype,
+        quant_min=min_max_mapping[select_dtype][0],
+        quant_max=min_max_mapping[select_dtype][1],
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
+        qscheme=qscheme,
+        is_dynamic=is_dynamic,
     )
     return quantization_spec
 
 
-def _map_inc_config_to_torch_quant_config(inc_config) -> QuantizationConfig:
-    default_quant_config = xiq.get_default_x86_inductor_quantization_config()
+def _map_inc_config_to_torch_quant_config(inc_config, is_dynamic=False) -> QuantizationConfig:
+    default_quant_config = xiq.get_default_x86_inductor_quantization_config(is_dynamic=is_dynamic)
     input_act_quant_spec = create_quant_spec_from_config(
-        inc_config.act_dtype, inc_config.act_sym, inc_config.act_granularity, inc_config.act_algo
+        inc_config.act_dtype, inc_config.act_sym, inc_config.act_granularity, inc_config.act_algo, is_dynamic=is_dynamic
     )
     weight_quant_spec = create_quant_spec_from_config(
         inc_config.w_dtype, inc_config.w_sym, inc_config.w_granularity, inc_config.w_algo
@@ -210,14 +221,14 @@ def _map_inc_config_to_torch_quant_config(inc_config) -> QuantizationConfig:
     return quant_config
 
 
-def create_xiq_quantizer_from_pt2e_config(config) -> X86InductorQuantizer:
+def create_xiq_quantizer_from_pt2e_config(config, is_dynamic=False) -> X86InductorQuantizer:
     quantizer = xiq.X86InductorQuantizer()
     # set global
-    global_config = _map_inc_config_to_torch_quant_config(config)
+    global_config = _map_inc_config_to_torch_quant_config(config, is_dynamic)
     quantizer.set_global(global_config)
     # set local
     for module_or_func_name, local_config in config.local_config.items():
-        local_quant_config = _map_inc_config_to_torch_quant_config(local_config)
+        local_quant_config = _map_inc_config_to_torch_quant_config(local_config, is_dynamic)
         if isinstance(module_or_func_name, torch.nn.Module):
             quantizer.set_module_type_qconfig(module_or_func_name, local_quant_config)
         else:
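
For dynamic quantization the observer choice is forced to PlaceholderObserver, since activation qparams are computed at runtime rather than from calibration statistics; the explicit quant_min/quant_max now match the selected dtype. A quick check of the resulting spec, per the mappings above:

from neural_compressor.torch.utils.utility import create_quant_spec_from_config

act_spec = create_quant_spec_from_config("uint8", False, "per_tensor", "kl", is_dynamic=True)
# the "kl" algo is overridden to "placeholder" because is_dynamic=True
assert act_spec.is_dynamic
assert (act_spec.quant_min, act_spec.quant_max) == (0, 255)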

test/3x/torch/algorithms/pt2e_quant/test_pt2e_w8a8.py

Lines changed: 4 additions & 4 deletions
@@ -5,12 +5,12 @@
 import torch
 
 from neural_compressor.common.utils import logger
-from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8StaticQuantizer
+from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
 from neural_compressor.torch.export import export_model_for_pt2e_quant
 from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version
 
 
-class TestW8A8StaticQuantizer:
+class TestW8A8PT2EQuantizer:
 
     @staticmethod
     def get_toy_model():
@@ -52,7 +52,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
     @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
     def test_quantizer_on_simple_model(self):
         model, example_inputs = self.build_simple_torch_model_and_example_inputs()
-        w8a8_static_quantizer = W8A8StaticQuantizer()
+        w8a8_static_quantizer = W8A8PT2EQuantizer()
         # prepare
         prepare_model = w8a8_static_quantizer.prepare(model, example_inputs=example_inputs)
         # calibrate
@@ -81,7 +81,7 @@ def test_quantizer_on_llm(self):
         model = export_model_for_pt2e_quant(model, example_inputs=example_inputs)
 
         quant_config = None
-        w8a8_static_quantizer = W8A8StaticQuantizer()
+        w8a8_static_quantizer = W8A8PT2EQuantizer()
         # prepare
         prepare_model = w8a8_static_quantizer.prepare(model)
         # calibrate
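
The tests here only pick up the rename and still exercise the static path; a dynamic counterpart would follow the same pattern with the class flag flipped (illustrative sketch, not part of this commit):

    def test_quantizer_on_simple_model_dynamic(self):
        model, example_inputs = self.build_simple_torch_model_and_example_inputs()
        W8A8PT2EQuantizer.is_dynamic = True
        w8a8_quantizer = W8A8PT2EQuantizer()
        prepared_model = w8a8_quantizer.prepare(model, example_inputs=example_inputs)
        # no calibration step: activation scales are computed at runtime
        converted_model = w8a8_quantizer.convert(prepared_model)
        assert converted_model is not None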
