-
Notifications
You must be signed in to change notification settings - Fork 260
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support Autotune FP16 Mix-precision on torch 3.0 new API (#1793)
Signed-off-by: zehao-intel <zehao.huang@intel.com>
- Loading branch information
1 parent
bacc164
commit 2e1cdc5
Showing
8 changed files
with
316 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
19 changes: 19 additions & 0 deletions
19
neural_compressor/torch/algorithms/mix_precision/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (c) 2024 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from neural_compressor.torch.algorithms.mix_precision.half_precision_convert import HalfPrecisionConverter | ||
from neural_compressor.torch.algorithms.mix_precision.module_wrappers import HalfPrecisionModuleWrapper |
88 changes: 88 additions & 0 deletions
88
neural_compressor/torch/algorithms/mix_precision/half_precision_convert.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (c) 2024 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Half-precision Convert for Torch Modules.""" | ||
|
||
from typing import Dict, Tuple | ||
|
||
import torch | ||
|
||
from neural_compressor.common import logger | ||
from neural_compressor.torch.algorithms.mix_precision.module_wrappers import HalfPrecisionModuleWrapper | ||
from neural_compressor.torch.utils import get_device | ||
|
||
|
||
class HalfPrecisionConverter: | ||
"""Converter Class for FP16 and BF16.""" | ||
|
||
dtype_mapping = { | ||
"fp16": torch.float16, | ||
"bf16": torch.bfloat16, | ||
} | ||
|
||
def __init__(self, configs_mapping: Dict[Tuple[str], object], *args, **kwargs): | ||
"""Initialize the Half-precision Converter with config. | ||
Args: | ||
configs_mapping (Dict): config class for mix-precision. | ||
""" | ||
self.configs_mapping = configs_mapping | ||
self.device = get_device() | ||
|
||
def convert(self, model: torch.nn.Module): | ||
"""Convert to FP16 or BF16 model. | ||
Args: | ||
model (torch.nn.Module): the input model. | ||
Returns: | ||
mix_precision_model (torch.nn.Module): model with mix-precision. | ||
""" | ||
if len(self.configs_mapping) > 0: | ||
logger.info("Convert operators to half-precision") | ||
|
||
if next(model.parameters()).is_cuda: | ||
self.device = "cuda" | ||
elif next(model.parameters()).is_cpu: | ||
self.device = "cpu" | ||
|
||
mix_precision_model = self._wrap_half_precision_model(model) | ||
mix_precision_model.to(self.device) | ||
|
||
return mix_precision_model | ||
|
||
def _wrap_half_precision_model(self, model: torch.nn.Module, prefix=""): | ||
"""Wrap and replace half-precision target modules. | ||
Args: | ||
model (torch.nn.Module): the input module. | ||
prefix (str): the name prefix for named children. | ||
Returns: | ||
model (torch.nn.Module): the model whose target modules have been wrapped. | ||
""" | ||
for name, child in model.named_children(): | ||
op_name = prefix + "." + name if prefix != "" else name | ||
for op_info, config in self.configs_mapping.items(): | ||
if op_name == op_info[0] and config.dtype in ("fp16", "bf16"): | ||
child = HalfPrecisionModuleWrapper( | ||
module=child, device=self.device, dtype=self.dtype_mapping[config.dtype] | ||
) | ||
else: | ||
self._wrap_half_precision_model(child, op_name) | ||
setattr(model, name, child) | ||
|
||
return model |
38 changes: 38 additions & 0 deletions
38
neural_compressor/torch/algorithms/mix_precision/module_wrappers.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (c) 2024 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Half-precision Wrapper for Torch Modules.""" | ||
|
||
import torch | ||
|
||
|
||
class HalfPrecisionModuleWrapper(torch.nn.Module): | ||
"""FP16 or BF16 Module Wrapper Class.""" | ||
|
||
def __init__(self, module, device="cpu", dtype=torch.float16): | ||
"""Init a HalfPrecisionModuleWrapper object.""" | ||
super(HalfPrecisionModuleWrapper, self).__init__() | ||
self.add_module("module", module) | ||
self.device = device | ||
self.dtype = dtype | ||
self.weight = self.module.weight if hasattr(self.module, "weight") else None | ||
self.bias = self.module.bias if hasattr(self.module, "bias") else None | ||
|
||
def forward(self, X): | ||
"""Convert dtype.""" | ||
with torch.autocast(device_type=self.device, dtype=self.dtype): | ||
X = self.module(X) | ||
return X.float() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters