Skip to content

Commit

Permalink
Support Autotune FP16 Mix-precision on torch 3.0 new API (#1793)
Browse files Browse the repository at this point in the history
Signed-off-by: zehao-intel <zehao.huang@intel.com>
  • Loading branch information
zehao-intel authored May 17, 2024
1 parent bacc164 commit 2e1cdc5
Show file tree
Hide file tree
Showing 8 changed files with 316 additions and 1 deletion.
1 change: 1 addition & 0 deletions neural_compressor/common/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
TEQ = "teq" # pragma: no cover
AUTOROUND = "autoround"
FP8_QUANT = "fp8_quant"
MIX_PRECISION = "mix_precision"

# options
import datetime
Expand Down
19 changes: 19 additions & 0 deletions neural_compressor/torch/algorithms/mix_precision/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from neural_compressor.torch.algorithms.mix_precision.half_precision_convert import HalfPrecisionConverter
from neural_compressor.torch.algorithms.mix_precision.module_wrappers import HalfPrecisionModuleWrapper
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Half-precision Convert for Torch Modules."""

from typing import Dict, Tuple

import torch

from neural_compressor.common import logger
from neural_compressor.torch.algorithms.mix_precision.module_wrappers import HalfPrecisionModuleWrapper
from neural_compressor.torch.utils import get_device


class HalfPrecisionConverter:
"""Converter Class for FP16 and BF16."""

dtype_mapping = {
"fp16": torch.float16,
"bf16": torch.bfloat16,
}

def __init__(self, configs_mapping: Dict[Tuple[str], object], *args, **kwargs):
"""Initialize the Half-precision Converter with config.
Args:
configs_mapping (Dict): config class for mix-precision.
"""
self.configs_mapping = configs_mapping
self.device = get_device()

def convert(self, model: torch.nn.Module):
"""Convert to FP16 or BF16 model.
Args:
model (torch.nn.Module): the input model.
Returns:
mix_precision_model (torch.nn.Module): model with mix-precision.
"""
if len(self.configs_mapping) > 0:
logger.info("Convert operators to half-precision")

if next(model.parameters()).is_cuda:
self.device = "cuda"
elif next(model.parameters()).is_cpu:
self.device = "cpu"

mix_precision_model = self._wrap_half_precision_model(model)
mix_precision_model.to(self.device)

return mix_precision_model

def _wrap_half_precision_model(self, model: torch.nn.Module, prefix=""):
"""Wrap and replace half-precision target modules.
Args:
model (torch.nn.Module): the input module.
prefix (str): the name prefix for named children.
Returns:
model (torch.nn.Module): the model whose target modules have been wrapped.
"""
for name, child in model.named_children():
op_name = prefix + "." + name if prefix != "" else name
for op_info, config in self.configs_mapping.items():
if op_name == op_info[0] and config.dtype in ("fp16", "bf16"):
child = HalfPrecisionModuleWrapper(
module=child, device=self.device, dtype=self.dtype_mapping[config.dtype]
)
else:
self._wrap_half_precision_model(child, op_name)
setattr(model, name, child)

return model
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Half-precision Wrapper for Torch Modules."""

import torch


class HalfPrecisionModuleWrapper(torch.nn.Module):
"""FP16 or BF16 Module Wrapper Class."""

def __init__(self, module, device="cpu", dtype=torch.float16):
"""Init a HalfPrecisionModuleWrapper object."""
super(HalfPrecisionModuleWrapper, self).__init__()
self.add_module("module", module)
self.device = device
self.dtype = dtype
self.weight = self.module.weight if hasattr(self.module, "weight") else None
self.bias = self.module.bias if hasattr(self.module, "bias") else None

def forward(self, X):
"""Convert dtype."""
with torch.autocast(device_type=self.device, dtype=self.dtype):
X = self.module(X)
return X.float()
3 changes: 3 additions & 0 deletions neural_compressor/torch/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
FP8Config,
get_default_fp8_config,
get_default_fp8_config_set,
MixPrecisionConfig,
get_default_mix_precision_config,
get_default_mix_precision_config_set,
get_woq_tuning_config,
DynamicQuantConfig,
get_default_dynamic_config,
Expand Down
16 changes: 16 additions & 0 deletions neural_compressor/torch/quantization/algorithm_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
FP8_QUANT,
GPTQ,
HQQ,
MIX_PRECISION,
RTN,
SMOOTH_QUANT,
STATIC_QUANT,
Expand All @@ -36,6 +37,7 @@
FP8Config,
GPTQConfig,
HQQConfig,
MixPrecisionConfig,
RTNConfig,
SmoothQuantConfig,
StaticQuantConfig,
Expand Down Expand Up @@ -528,3 +530,17 @@ def fp8_quant_entry(
model.qconfig = configs_mapping
model.save = MethodType(save, model)
return model


###################### Mixed Precision Algo Entry ##################################
@register_algo(MIX_PRECISION)
def mix_precision_entry(
model: torch.nn.Module, configs_mapping: Dict[Tuple[str], MixPrecisionConfig], *args, **kwargs
) -> torch.nn.Module:
# only support fp16 and bf16 now, more types might be added later
from neural_compressor.torch.algorithms.mix_precision import HalfPrecisionConverter

half_precision_converter = HalfPrecisionConverter(configs_mapping, *args, **kwargs)
mix_precision_model = half_precision_converter.convert(model)

return mix_precision_model
76 changes: 76 additions & 0 deletions neural_compressor/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
FP8_QUANT,
GPTQ,
HQQ,
MIX_PRECISION,
OP_NAME_OR_MODULE_TYPE,
RTN,
SMOOTH_QUANT,
Expand Down Expand Up @@ -1196,6 +1197,81 @@ def get_default_fp8_config_set() -> FP8Config:
return FP8Config.get_config_set_for_tuning()


######################## MixPrecision Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=MIX_PRECISION)
class MixPrecisionConfig(BaseConfig):
"""Config class for mix-precision."""

name = MIX_PRECISION
supported_configs: List[OperatorConfig] = []
params_list = [
"dtype",
]
supported_half_precision_ops = (
torch.nn.Linear,
torch.nn.Conv1d,
torch.nn.Conv2d,
torch.nn.Conv3d,
)

def __init__(
self,
dtype: Union[str, List[str]] = "fp16",
white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
):
"""Init MixPrecision config.
Args:
"""
super().__init__(white_list=white_list)
self.dtype = dtype
self._post_init()

@classmethod
def register_supported_configs(cls) -> List[OperatorConfig]:
supported_configs = []
mix_precision_config = MixPrecisionConfig(
dtype=["fp16", "bf16", "fp32"],
)
operators = cls.supported_half_precision_ops
supported_configs.append(OperatorConfig(config=mix_precision_config, operators=operators))
cls.supported_configs = supported_configs

@staticmethod
def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
white_list = tuple(MixPrecisionConfig.supported_half_precision_ops)
filter_result = []
for op_name, module in model.named_modules():
if isinstance(module, white_list):
pair = (op_name, type(module).__name__)
filter_result.append(pair)
logger.debug(f"Get model info: {filter_result}")
return filter_result

@classmethod
def get_config_set_for_tuning(cls) -> Union[None, "MixPrecisionConfig", List["MixPrecisionConfig"]]:
# TODO fwk owner needs to update it.
return MixPrecisionConfig(dtype=["fp16", "bf16", "fp32"])


def get_default_mix_precision_config() -> MixPrecisionConfig:
"""Generate the default mix-precision config.
Returns:
the default mix-precision config.
"""
return MixPrecisionConfig()


def get_default_mix_precision_config_set() -> MixPrecisionConfig:
"""Generate the default mix-precision config set.
Returns:
the default mix-precision config.
"""
return MixPrecisionConfig.get_config_set_for_tuning()


##################### Algo Configs End ###################################


Expand Down
76 changes: 75 additions & 1 deletion test/3x/torch/test_autotune.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@
import transformers

from neural_compressor.common import logger
from neural_compressor.torch.quantization import RTNConfig, TuningConfig, autotune, get_all_config_set
from neural_compressor.torch.quantization import (
MixPrecisionConfig,
RTNConfig,
TuningConfig,
autotune,
get_all_config_set,
)
from neural_compressor.torch.utils import constants

FAKE_DOUBLE_QUANT_CONFIGS = {
Expand Down Expand Up @@ -332,6 +338,74 @@ def eval_acc_fn(model):
)
self.assertIsNone(best_model)

@reset_tuning_target
def test_autotune_mix_precision_default(self):
from neural_compressor.torch.algorithms.mix_precision import HalfPrecisionModuleWrapper

baseline = [1]
acc_res_lst = baseline + [0.9, 0.99, 1]

def eval_acc_fn(model):
res = acc_res_lst.pop(0)
return res

custom_tune_config = TuningConfig(config_set=[MixPrecisionConfig(dtype=["fp16", "bf16", "fp32"])], max_trials=3)
best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fn=eval_acc_fn)

self.assertIsNotNone(best_model)
self.assertTrue(isinstance(best_model.fc1, HalfPrecisionModuleWrapper))
self.assertTrue(isinstance(best_model.fc2, HalfPrecisionModuleWrapper))
self.assertTrue(isinstance(best_model.fc3, HalfPrecisionModuleWrapper))

@reset_tuning_target
def test_autotune_mix_precision_set_op_name(self):
from neural_compressor.common.base_config import ComposableConfig, config_registry
from neural_compressor.torch.algorithms.mix_precision import HalfPrecisionModuleWrapper

baseline = [1]
acc_res_lst = baseline + [0.9, 1.1]

def eval_acc_fn(model):
res = acc_res_lst.pop(0)
return res

config1 = {
"mix_precision": {
"global": {
"dtype": "bf16",
},
"local": {
"fc2": {
"dtype": "fp32",
}
},
}
}
config2 = {
"mix_precision": {
"global": {
"dtype": "fp16",
},
"local": {
"fc1": {
"dtype": "fp32",
}
},
}
}

registered_configs = config_registry.get_cls_configs()
config1 = ComposableConfig.from_dict(config1, config_registry=registered_configs["torch"])
config2 = ComposableConfig.from_dict(config2, config_registry=registered_configs["torch"])

custom_tune_config = TuningConfig(config_set=[config1, config2], max_trials=2)
best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fn=eval_acc_fn)

self.assertIsNotNone(best_model)
self.assertTrue(isinstance(best_model.fc1, torch.nn.Linear))
self.assertTrue(isinstance(best_model.fc2, HalfPrecisionModuleWrapper))
self.assertTrue(isinstance(best_model.fc3, HalfPrecisionModuleWrapper))


if __name__ == "__main__":
unittest.main()

0 comments on commit 2e1cdc5

Please sign in to comment.