1 change: 1 addition & 0 deletions neural_compressor/common/utils/constants.py
@@ -29,6 +29,7 @@
COMPOSABLE_CONFIG = "composable_config"
RTN = "rtn"
STATIC_QUANT = "static_quant"
SMOOTH_QUANT = "smooth_quant"
GPTQ = "gptq"
FP8_QUANT = "fp8_quant"

4 changes: 4 additions & 0 deletions neural_compressor/torch/__init__.py
@@ -21,6 +21,10 @@
get_default_rtn_config,
GPTQConfig,
get_default_gptq_config,
StaticQuantConfig,
get_default_static_config,
SmoothQuantConfig,
get_default_sq_config,
)

from neural_compressor.common.base_tuning import TuningConfig
4 changes: 4 additions & 0 deletions neural_compressor/torch/quantization/__init__.py
@@ -18,4 +18,8 @@
get_default_rtn_config,
GPTQConfig,
get_default_gptq_config,
StaticQuantConfig,
get_default_static_config,
SmoothQuantConfig,
get_default_sq_config,
)
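
With the re-exports above, the new configs become reachable from both neural_compressor.torch and neural_compressor.torch.quantization. A minimal import sketch exercising only what this diff adds (no quantization entry point is assumed here):

# Sketch: exercising the exports added in this PR.
from neural_compressor.torch.quantization import (
    StaticQuantConfig,
    get_default_static_config,
    SmoothQuantConfig,
    get_default_sq_config,
)

static_cfg = get_default_static_config()  # same as StaticQuantConfig()
sq_cfg = get_default_sq_config()          # same as SmoothQuantConfig()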
209 changes: 208 additions & 1 deletion neural_compressor/torch/quantization/config.py
@@ -24,7 +24,15 @@
import torch

from neural_compressor.common.base_config import BaseConfig, config_registry, register_config
from neural_compressor.common.utils import DEFAULT_WHITE_LIST, FP8_QUANT, GPTQ, OP_NAME_OR_MODULE_TYPE, RTN
from neural_compressor.common.utils import (
DEFAULT_WHITE_LIST,
FP8_QUANT,
GPTQ,
OP_NAME_OR_MODULE_TYPE,
RTN,
SMOOTH_QUANT,
STATIC_QUANT,
)
from neural_compressor.torch.utils.constants import PRIORITY_GPTQ, PRIORITY_RTN
from neural_compressor.torch.utils.utility import is_hpex_avaliable, logger

@@ -311,6 +319,205 @@ def get_default_gptq_config() -> GPTQConfig:
return GPTQConfig()


######################## Static Quant Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=STATIC_QUANT)
class StaticQuantConfig(BaseConfig):
"""Config class for static quantization."""

name = STATIC_QUANT
supported_configs: List[OperatorConfig] = []
params_list = [
"w_dtype",
"w_sym",
"w_granularity",
"w_algo",
"act_dtype" "act_sym",
"act_granularity",
"act_algo",
]

def __init__(
self,
w_dtype: str = "int8",
w_sym: bool = True,
w_granularity: str = "per_channel",
w_algo: str = "minmax",
act_dtype: str = "uint8",
act_sym: bool = False,
act_granularity: str = "per_tensor",
act_algo: str = "minmax",
white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
):
"""Init Static Quant Configs."""
super().__init__(white_list=white_list)
self.w_dtype = w_dtype
self.w_sym = w_sym
self.w_granularity = w_granularity
self.w_algo = w_algo
self.act_dtype = act_dtype
self.act_sym = act_sym
self.act_granularity = act_granularity
self.act_algo = act_algo
self._post_init()

def to_dict(self):
return super().to_dict(params_list=self.params_list, operator2str=operator2str)

@classmethod
def from_dict(cls, config_dict):
return super(StaticQuantConfig, cls).from_dict(config_dict=config_dict, str2operator=str2operator)

@classmethod
def register_supported_configs(cls) -> List[OperatorConfig]:
supported_configs = []
# TODO(Yi)
linear_static_config = StaticQuantConfig()
operators = [torch.nn.Linear, torch.nn.functional.linear]
supported_configs.append(
OperatorConfig(config=linear_static_config, operators=operators, backend=Backend.DEFAULT)
)
cls.supported_configs = supported_configs

@staticmethod
def get_model_info(model: torch.nn.Module) -> List[Tuple[str, str]]:
white_list = (torch.nn.Linear,)
filter_result = []
for op_name, module in model.named_modules():
if isinstance(module, white_list):
pair = (op_name, type(module).__name__)
filter_result.append(pair)
logger.debug(f"Get model info: {filter_result}")
return filter_result
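
For reference, a small sketch of what get_model_info is expected to return; the toy module below is illustrative, and only nn.Linear submodules pass the white_list filter:

import torch

toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
info = StaticQuantConfig.get_model_info(toy)
# named_modules() yields ("", Sequential), ("0", Linear), ("1", ReLU);
# only the Linear entry is kept, so info == [("0", "Linear")].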


# TODO(Yi) run `register_supported_configs` for all registered config.
StaticQuantConfig.register_supported_configs()


def get_default_static_config() -> StaticQuantConfig:
"""Generate the default static quant config.

Returns:
the default static quant config.
"""
return StaticQuantConfig()
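
A brief usage sketch for the static quant path; the overrides below are illustrative, and the exact dictionary layout produced by to_dict is determined by BaseConfig and params_list rather than shown here:

# Defaults: int8 per-channel symmetric weights, uint8 per-tensor asymmetric activations.
default_cfg = get_default_static_config()

# Overriding a couple of fields (illustrative values only).
custom_cfg = StaticQuantConfig(w_granularity="per_tensor", act_sym=True)
cfg_dict = custom_cfg.to_dict()                   # serializable view driven by params_list
restored = StaticQuantConfig.from_dict(cfg_dict)  # round-trip sketch via the classmethod above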


######################## Smooth Quant Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=SMOOTH_QUANT)
class SmoothQuantConfig(BaseConfig):
"""Config class for smooth quantization."""

name = SMOOTH_QUANT
supported_configs: List[OperatorConfig] = []
params_list = [
"w_dtype",
"w_sym",
"w_granularity",
"w_algo",
"act_dtype" "act_sym",
"act_granularity",
"act_algo",
"alpha",
"folding",
"scale_sharing",
"auto_alpha_args",
]

def __init__(
self,
w_dtype: str = "int8",
w_sym: bool = True,
w_granularity: str = "per_channel",
w_algo: str = "minmax",
act_dtype: str = "uint8",
act_sym: bool = False,
act_granularity: str = "per_tensor",
act_algo: str = "minmax",
alpha: float = 0.5,
folding: bool = False,
# below for autotune
scale_sharing: bool = False,
init_alpha: float = 0.5,
alpha_min: float = 0.0,
alpha_max: float = 1.0,
alpha_step: float = 0.1,
shared_criterion: str = "max",
enable_blockwise_loss: bool = False,
auto_alpha_args: Optional[dict] = None,
white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
):
"""Init SmoothQuant Configs."""
super().__init__(white_list=white_list)
self.w_dtype = w_dtype
self.w_sym = w_sym
self.w_granularity = w_granularity
self.w_algo = w_algo
self.act_dtype = act_dtype
self.act_sym = act_sym
self.act_granularity = act_granularity
self.act_algo = act_algo
self.alpha = alpha
self.folding = folding
# below for autotune
self.scale_sharing = scale_sharing
self.init_alpha = init_alpha
self.alpha_min = alpha_min
self.alpha_max = alpha_max
self.alpha_step = alpha_step
self.shared_criterion = shared_criterion
self.enable_blockwise_loss = enable_blockwise_loss
self.auto_alpha_args = {
"init_alpha": self.init_alpha,
"alpha_min": self.alpha_min,
"alpha_max": self.alpha_max,
"alpha_step": self.alpha_step,
"shared_criterion": self.shared_criterion,
"enable_blockwise_loss": self.enable_blockwise_loss,
}
self._post_init()

def to_dict(self):
return super().to_dict(params_list=self.params_list, operator2str=operator2str)

@classmethod
def from_dict(cls, config_dict):
return super(SmoothQuantConfig, cls).from_dict(config_dict=config_dict, str2operator=str2operator)

@classmethod
def register_supported_configs(cls) -> List[OperatorConfig]:
supported_configs = []
# TODO(Yi)
linear_sq_config = SmoothQuantConfig()
operators = [torch.nn.Linear, torch.nn.functional.linear]
supported_configs.append(OperatorConfig(config=linear_sq_config, operators=operators, backend=Backend.DEFAULT))
cls.supported_configs = supported_configs

@staticmethod
def get_model_info(model: torch.nn.Module) -> List[Tuple[str, str]]:
white_list = (torch.nn.Linear,)
filter_result = []
for op_name, module in model.named_modules():
if isinstance(module, white_list):
pair = (op_name, type(module).__name__)
filter_result.append(pair)
logger.debug(f"Get model info: {filter_result}")
return filter_result


# TODO(Yi) run `register_supported_configs` for all registered config.
SmoothQuantConfig.register_supported_configs()


def get_default_sq_config() -> SmoothQuantConfig:
"""Generate the default smoothquant config.

Returns:
the default smoothquant config.
"""
return SmoothQuantConfig()
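
A corresponding sketch for SmoothQuant; the values are illustrative, and note that the constructor itself packs the individual alpha-search arguments into auto_alpha_args:

sq_cfg = SmoothQuantConfig(
    alpha=0.7,     # fixed smoothing strength for a plain (non-autotuned) run
    folding=True,  # fold smoothing scales into neighboring weights instead of inserting mul ops
    # knobs consumed by the alpha auto-tuning path
    alpha_min=0.3,
    alpha_max=0.9,
    alpha_step=0.05,
)
# The constructor assembles sq_cfg.auto_alpha_args from init_alpha, alpha_min,
# alpha_max, alpha_step, shared_criterion and enable_blockwise_loss.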


######################## FP8 Config ###############################
if is_hpex_avaliable():
