Update imports after QAT was moved out of prototype #1883

Merged (1 commit, Oct 29, 2024)
74 changes: 46 additions & 28 deletions torchtune/training/quantization.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

from typing import Callable, Optional
+from warnings import warn

from torchtune.utils._import_guard import _USE_NEW_TENSOR_CORE_TILED_LAYOUT_API

@@ -18,22 +19,29 @@
    int8_dynamic_activation_int4_weight,
    quantize_,
)
-from torchao.quantization.prototype.qat import (
-    disable_4w_fake_quant,
-    disable_8da4w_fake_quant,
-    enable_4w_fake_quant,
-    enable_8da4w_fake_quant,
-    Int4WeightOnlyQATQuantizer,
-    Int8DynActInt4WeightQATQuantizer,
-)
-from torchao.quantization.prototype.qat._module_swap_api import (
-    disable_4w_fake_quant_module_swap,
-    disable_8da4w_fake_quant_module_swap,
-    enable_4w_fake_quant_module_swap,
-    enable_8da4w_fake_quant_module_swap,
-    Int4WeightOnlyQATQuantizerModuleSwap,
-    Int8DynActInt4WeightQATQuantizerModuleSwap,
-)

+try:
+    # torchao 0.7+
+    from torchao.quantization.qat import (
+        Int4WeightOnlyQATQuantizer,
+        Int8DynActInt4WeightQATQuantizer,
+    )
+    from torchao.quantization.qat.linear import (
+        disable_4w_fake_quant,
+        disable_8da4w_fake_quant,
+        enable_4w_fake_quant,
+        enable_8da4w_fake_quant,
+    )
+except ImportError:
+    # torchao 0.6 and before
+    from torchao.quantization.prototype.qat import (
+        disable_4w_fake_quant,
+        disable_8da4w_fake_quant,
+        enable_4w_fake_quant,
+        enable_8da4w_fake_quant,
+        Int4WeightOnlyQATQuantizer,
+        Int8DynActInt4WeightQATQuantizer,
+    )


__all__ = [
@@ -52,9 +60,9 @@
_quantizer_mode_to_enable_fake_quant = {}


-# ========================================================
-# int8 dynamic activations + int4 weight tensor subclass |
-# ========================================================
+# ========================================
+# int8 dynamic activations + int4 weight |
+# ========================================


class Int8DynActInt4WeightQuantizer:
@@ -106,15 +114,15 @@ def quantize(self, model):
_quantizer_mode_to_enable_fake_quant["4w-qat"] = enable_4w_fake_quant


-# =============
-# module swap |
-# =============
+# ====================== #
+# Backward compatibility #
+# ====================== #

# Note: QAT tensor subclass implementation in torchao only works
# with FSDP2 today. For other distribution strategies like DDP and
# FSDP1, users will need to fall back to the old module swap flow.

# int4 weight-only
+Int4WeightOnlyQATQuantizerModuleSwap = Int4WeightOnlyQATQuantizer
+disable_4w_fake_quant_module_swap = disable_4w_fake_quant
+enable_4w_fake_quant_module_swap = enable_4w_fake_quant
_quantizer_to_mode[Int4WeightOnlyQATQuantizerModuleSwap] = "4w-qat-module-swap"
_quantizer_mode_to_disable_fake_quant[
"4w-qat-module-swap"
@@ -124,6 +132,9 @@ def quantize(self, model):
] = enable_4w_fake_quant_module_swap

# int8 dynamic activations + int4 weight
+Int8DynActInt4WeightQATQuantizerModuleSwap = Int8DynActInt4WeightQATQuantizer
+disable_8da4w_fake_quant_module_swap = disable_8da4w_fake_quant
+enable_8da4w_fake_quant_module_swap = enable_8da4w_fake_quant
_quantizer_to_mode[Int8DynActInt4WeightQATQuantizerModuleSwap] = "8da4w-qat-module-swap"
_quantizer_mode_to_disable_fake_quant[
"8da4w-qat-module-swap"
@@ -141,16 +152,23 @@ def get_quantizer_mode(quantizer: Optional[Callable]) -> Optional[str]:

    Currently supported:

-    - :class:`~torchao.quantization.quant_api.Int8DynActInt4WeightQuantizer`: "8da4w" (requires ``torch>=2.3.0``)
-    - :class:`~torchao.quantization.prototype.qat.Int8DynActInt4WeightQATQuantizer`: "8da4w-qat" (requires ``torch>=2.4.0``)
+    - :class:`~torchtune.training.quantization.Int8DynActInt4WeightQuantizer`: "8da4w"
+    - :class:`~torchtune.training.quantization.Int4WeightOnlyQuantizer`: "4w"
+    - :class:`~torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer`: "8da4w-qat"
+    - :class:`~torchao.quantization.qat.Int4WeightOnlyQATQuantizer`: "4w-qat"

    Args:
        quantizer (Optional[Callable]): A callable object that implements the `quantize` method.

    Returns:
        Optional[str]: The quantization mode.
    """
-    return _quantizer_to_mode.get(type(quantizer), None)
+    mode = _quantizer_to_mode.get(type(quantizer), None)
+    if mode is not None and "module-swap" in mode:
+        warn(
+            "*QuantizerModuleSwap is deprecated. Please use the version without 'ModuleSwap' instead"
+        )
+    return mode

Review thread on the warning message:

Contributor:
nit: should be able to just give the exact classes in the warn message using e.g. f"{quantizer.__class__.__name__}" (or something like that)

Contributor Author:
hmm I don't think that's possible given we just assign the classes today, e.g.:

>>> class A: pass
...
>>> B = A
>>> A().__class__.__name__
'A'
>>> B().__class__.__name__
'A'


def _get_disable_fake_quant(quantizer_mode: str) -> Callable:
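To make the aliasing point from the review thread concrete, here is a self-contained sketch in plain Python. The class and dict below are simplified stand-ins for the torchtune code in the diff, not the real implementation, and no torchao/torchtune install is needed to run it:

from warnings import warn


class Int8DynActInt4WeightQATQuantizer:
    """Simplified stand-in for the real torchao quantizer."""


# Backward compatibility, as in the diff above: the old "module swap"
# spelling is now a plain alias, so both names are the *same* class object.
Int8DynActInt4WeightQATQuantizerModuleSwap = Int8DynActInt4WeightQATQuantizer

_quantizer_to_mode = {
    Int8DynActInt4WeightQATQuantizerModuleSwap: "8da4w-qat-module-swap",
}


def get_quantizer_mode(quantizer):
    mode = _quantizer_to_mode.get(type(quantizer), None)
    if mode is not None and "module-swap" in mode:
        # Per the thread above, quantizer.__class__.__name__ cannot name the
        # deprecated alias: the alias and the new class share one __name__.
        warn(
            "*QuantizerModuleSwap is deprecated. "
            "Please use the version without 'ModuleSwap' instead"
        )
    return mode


# Both spellings resolve to the same class, hence the same mode string:
print(get_quantizer_mode(Int8DynActInt4WeightQATQuantizerModuleSwap()))
# -> "8da4w-qat-module-swap" (after a UserWarning)
print(Int8DynActInt4WeightQATQuantizerModuleSwap.__name__)
# -> "Int8DynActInt4WeightQATQuantizer"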
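For downstream code, the practical migration is just a rename. A hedged sketch follows; it assumes torchtune with this patch installed, and the groupsize argument and prepare() call follow the upstream torchao QAT quantizer API rather than anything shown in this diff:

# Hypothetical migration example, not part of the PR.
import torch.nn as nn

# Before (still works, but now emits a deprecation warning):
# from torchtune.training.quantization import (
#     Int8DynActInt4WeightQATQuantizerModuleSwap as Quantizer,
# )

# After (preferred spelling; the old name is an alias of this same class):
from torchtune.training.quantization import Int8DynActInt4WeightQATQuantizer

model = nn.Sequential(nn.Linear(512, 512))
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=256)

# prepare() swaps in fake-quantized modules for QAT fine-tuning; see the
# torchao QAT docs for the full prepare()/convert() flow.
model = quantizer.prepare(model)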