|
16 | 16 | # limitations under the License. |
17 | 17 | # pylint:disable=import-error |
18 | 18 |
|
19 | | -import torch |
20 | 19 | from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union |
21 | 20 |
|
22 | | -from neural_compressor.common.base_config import BaseConfig, config_registry, register_config |
23 | | -from neural_compressor.common.utility import ( |
24 | | - OP_NAME_OR_MODULE_TYPE, |
25 | | - DEFAULT_WHITE_LIST, |
26 | | - FP8_QUANT, |
27 | | - GPTQ, |
28 | | - RTN, |
29 | | -) |
| 21 | +import torch |
30 | 22 |
|
| 23 | +from neural_compressor.common.base_config import BaseConfig, config_registry, register_config |
| 24 | +from neural_compressor.common.utility import DEFAULT_WHITE_LIST, FP8_QUANT, GPTQ, OP_NAME_OR_MODULE_TYPE, RTN |
31 | 25 | from neural_compressor.torch.utils import is_hpex_available, logger |
32 | 26 | from neural_compressor.torch.utils.constants import PRIORITY_GPTQ, PRIORITY_RTN |
33 | 27 |
|
@@ -75,7 +69,7 @@ def __init__( |
75 | 69 | use_layer_wise: bool = False, |
76 | 70 | export_compressed_model: bool = False, |
77 | 71 | double_quant_dtype: str = "fp32", |
78 | | - double_quant_bits: int = 8, # not available when double_quant_dtype is not 'int' |
| 72 | + double_quant_bits: int = 8, # not available when double_quant_dtype is not 'int' |
79 | 73 | double_quant_sym: bool = True, |
80 | 74 | double_quant_group_size: int = 256, |
81 | 75 | white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, |
@@ -255,9 +249,7 @@ def register_supported_configs(cls) -> List[OperatorConfig]: |
255 | 249 | # TODO(Yi) |
256 | 250 | linear_gptq_config = GPTQConfig() |
257 | 251 | operators = [torch.nn.Linear, torch.nn.functional.linear] |
258 | | - supported_configs.append( |
259 | | - OperatorConfig(config=linear_gptq_config, operators=operators) |
260 | | - ) |
| 252 | + supported_configs.append(OperatorConfig(config=linear_gptq_config, operators=operators)) |
261 | 253 | cls.supported_configs = supported_configs |
262 | 254 |
|
263 | 255 | @staticmethod |
@@ -362,8 +354,10 @@ def get_default_fp8_qconfig() -> FP8QConfig: |
362 | 354 | """ |
363 | 355 | return FP8QConfig() |
364 | 356 |
|
| 357 | + |
365 | 358 | ##################### Algo Configs End ################################### |
366 | 359 |
|
| 360 | + |
367 | 361 | def get_all_registered_configs() -> Dict[str, BaseConfig]: |
368 | 362 | registered_configs = config_registry.get_all_configs() |
369 | 363 | return registered_configs.get(FRAMEWORK_NAME, {}) |
0 commit comments