
[WIP] Make AWQ more general #2400


Open · wants to merge 1 commit into main
3 changes: 2 additions & 1 deletion torchao/prototype/awq/__init__.py
@@ -1,8 +1,9 @@
from .api import awq_uintx, insert_awq_observer_
from .api import AWQConfig, awq_uintx, insert_awq_observer_
from .core import AWQObservedLinear

__all__ = [
"awq_uintx",
"insert_awq_observer_",
"AWQObservedLinear",
"AWQConfig",
]
93 changes: 77 additions & 16 deletions torchao/prototype/awq/api.py
@@ -30,12 +30,15 @@
ZeroPointDomain,
)
from torchao.quantization.transform_module import (
_QUANTIZE_CONFIG_HANDLER,
register_quantize_module_handler,
)
from torchao.utils import DummyModule

from .core import (
AWQObservedLinear,
AWQObserver,
AWQObserver2,
)

assert len(_DTYPE_TO_BIT_WIDTH) > 0, (
@@ -50,6 +53,7 @@ def insert_awq_observer_(
quant_dtype: torch.dtype = torch.uint4,
scale_search_space_size: int = 20,
group_size: int = 128,
base_config: Optional[AOBaseConfig] = None,
):
"""
Inserts AWQObserver into Linear layers of a given model.
@@ -80,22 +84,32 @@ def insert_awq_observer_(

def replace_with_observer(layer):
# creates observer and replaces linear layers with AWQObservedLinear layers
observer = AWQObserver(
layer.weight,
layer.bias,
quantization_granularity,
mapping_type,
quant_dtype,
n_validation_examples,
validation_sequence_len,
scale_search_space_size,
preserve_zero=preserve_zero,
zero_point_domain=zero_point_domain,
zero_point_dtype=zero_point_dtype,
quant_min=quant_min,
quant_max=quant_max,
eps=eps,
)
if base_config is None:
observer = AWQObserver(
layer.weight,
layer.bias,
quantization_granularity,
mapping_type,
quant_dtype,
n_validation_examples,
validation_sequence_len,
scale_search_space_size,
preserve_zero=preserve_zero,
zero_point_domain=zero_point_domain,
zero_point_dtype=zero_point_dtype,
quant_min=quant_min,
quant_max=quant_max,
eps=eps,
)
else:
observer = AWQObserver2(
Contributor: can you not add kwargs to the AWQObserver and just check 'base_config' in kwargs?

Contributor Author: yes, this is temporary; I think we can deprecate the old one in the end

layer.weight,
layer.bias,
base_config,
n_validation_examples,
validation_sequence_len,
scale_search_space_size,
)
return AWQObservedLinear.from_float(layer, observer)

_replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear)
@@ -194,3 +208,50 @@ def _awq_uintx_transform(
linear.extra_repr = types.MethodType(_linear_extra_repr, module)
linear.bias = observed_linear.bias
return linear


@dataclass
class AWQConfig(AOBaseConfig):
Contributor: OK, this is consolidating with the quantize_ API's config-based design?

"""
Configuration for quantizing linear layers when passed into quantize_()

    Args:
        base_config: the base quantization config (an AOBaseConfig, e.g. Int8DynamicActivationIntxWeightConfig)
            that is applied to the equalization-scaled weight of the observed linear layer
        set_inductor_config: if True, adjusts `torchinductor` settings to recommended values.
"""

base_config: AOBaseConfig
set_inductor_config: bool = True


@register_quantize_module_handler(AWQConfig)
def _awq_transform(
module: torch.nn.Module,
    config: AWQConfig,
) -> torch.nn.Module:
if config.set_inductor_config:
torchao.quantization.utils.recommended_inductor_config_setter()

observed_linear = module
Contributor: If this is for linear only, shouldn't you assert that this is nn.Linear? Plus, how do you make sure this function is called only on nn.Linear?

Contributor Author: yeah that's true, will add an assert; we rely on the user to use quantize_ correctly (it's through specifying the filter_fn arg in the quantize_ API: `filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None`)

equalization_scale = observed_linear.act_obs.calculate_qparams()

base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)]
dummy_mod = DummyModule(observed_linear.weight * equalization_scale)
quant_mod = base_config_handler(dummy_mod, config.base_config)
Comment on lines +242 to +243

Contributor: I am not sure what's happening here. Isn't module already nn.Module?

Contributor Author: this is just trying to quantize the weight with the quantization type specified by config.base_config
qw = quant_mod.weight
qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale)

linear = torch.nn.Linear(
observed_linear.in_features,
observed_linear.out_features,
observed_linear.bias != None,
device=observed_linear.weight.device,
dtype=observed_linear.weight.dtype,
)
linear.weight = torch.nn.Parameter(qw, requires_grad=False)
linear.extra_repr = types.MethodType(_linear_extra_repr, module)
linear.bias = observed_linear.bias
return linear
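For context, here is a minimal usage sketch of the flow these api.py changes enable: insert observers, calibrate, then quantize with AWQConfig wrapping a base config. The toy model, calibration data, and the specific Int8DynamicActivationIntxWeightConfig arguments below are illustrative assumptions, not part of this diff.

```python
import torch

from torchao.prototype.awq import AWQConfig, AWQObservedLinear, insert_awq_observer_
from torchao.quantization import Int8DynamicActivationIntxWeightConfig, quantize_
from torchao.quantization.granularity import PerGroup

# illustrative toy model and calibration data (any model / exemplar data works)
model = torch.nn.Sequential(torch.nn.Linear(512, 512))
calibration_batches = [torch.randn(1, 64, 512) for _ in range(2)]

# the base config that AWQ wraps; these particular arguments are assumptions
base_config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(128),
)

# 1. replace nn.Linear with AWQObservedLinear so activations can be recorded
insert_awq_observer_(
    model,
    n_validation_examples=2,
    validation_sequence_len=64,
    base_config=base_config,
)

# 2. calibrate: run exemplar data through the model to populate the observers
with torch.no_grad():
    for batch in calibration_batches:
        model(batch)

# 3. quantize: AWQConfig applies base_config to the equalization-scaled weights
is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
quantize_(model, AWQConfig(base_config), is_observed_linear)
```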
134 changes: 134 additions & 0 deletions torchao/prototype/awq/core.py
@@ -8,8 +8,10 @@
import torch
import torch.nn.functional as F

from torchao.core.config import AOBaseConfig
from torchao.dtypes import to_affine_quantized_intx
from torchao.dtypes.uintx.uintx_layout import UintxLayout
from torchao.quantization import Int8DynamicActivationIntxWeightConfig
from torchao.quantization.granularity import Granularity
from torchao.quantization.observer import (
AffineQuantizedObserverBase,
@@ -18,6 +20,10 @@
MappingType,
ZeroPointDomain,
)
from torchao.quantization.transform_module import (
_QUANTIZE_CONFIG_HANDLER,
)
from torchao.utils import DummyModule


class AWQObserver(AffineQuantizedObserverBase):
@@ -145,6 +151,134 @@ def calculate_qparams(self):
return best_scales.detach()


class AWQObserver2(AffineQuantizedObserverBase):
def __init__(
self,
weight: torch.Tensor,
bias: torch.Tensor,
        base_config: AOBaseConfig,
        n_validation_examples: int,
        validation_sequence_len: int,
        scale_search_space_size: int = 20,
    ):
"""
A custom observer for Activation aware Weight Quantization (AWQ)

Args:
weight: The weight tensor to be observed.
bias: The bias tensor to be observed.
            base_config: The quantization config used to quantize the weight at each candidate
                equalization scale (currently expected to be Int8DynamicActivationIntxWeightConfig)
            n_validation_examples: Number of examples used to calibrate observer
            validation_sequence_len: Number of tokens in each example
            scale_search_space_size: The number of scales to search for.
"""
        self.base_config = base_config
        quant_min = getattr(base_config, "quant_min", None)
        quant_max = getattr(base_config, "quant_max", None)

assert isinstance(base_config, Int8DynamicActivationIntxWeightConfig)
# TODO:
quantization_granularity = base_config.weight_granularity
target_dtype = base_config.weight_dtype
mapping_type = base_config.weight_mapping_type

# TODO:
super().__init__(
mapping_type,
target_dtype,
quantization_granularity,
quant_min=quant_min,
quant_max=quant_max,
)
self.quantization_granularity = quantization_granularity
self.weight = weight
self.bias = bias
self.n_validation_examples = n_validation_examples
self.validation_sequence_len = validation_sequence_len
self.calibration_token_count = 0
self.inputs = []
self.outputs = []
self.scale_options = scale_search_space_size
self.device = self.weight.device
self.average = torch.zeros((1, weight.shape[1]), device=self.device)
if self.bias is not None:
self.bias.to(self.device)

@torch.no_grad()
def forward(self, input: torch.Tensor, output: torch.Tensor):
if len(self.inputs) < self.n_validation_examples:
self.inputs.append(input.to("cpu"))
self.outputs.append(output.to("cpu"))
self.calibration_token_count += input.shape[-2]
self.average += input.abs().sum(-2)

def calculate_qparams(self):
        assert len(self.outputs) > 0, (
"calibrate observer first by running model on exemplar data"
)
self.average /= self.calibration_token_count
for i in range(self.n_validation_examples):
self.inputs[i] = self.inputs[i].to(self.device)
self.outputs[i] = self.outputs[i].to(self.device)

best_loss = float("inf")
best_scales = None
for i in range(self.scale_options):
ratio = i * 1 / self.scale_options
scales = self.average.pow(ratio).to(self.weight.dtype)
scales = scales / (scales.max() * scales.min()).sqrt()
            # quantize the candidate-scaled weight with the quantization specified by
            # self.base_config, reusing the handler registered for that config type
base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(self.base_config)]
dummy_mod = DummyModule(self.weight * scales)
quant_mod = base_config_handler(dummy_mod, self.base_config)
w = quant_mod.weight

loss = 0
for i in range(self.n_validation_examples):
q_out = F.linear(self.inputs[i] / scales, w, self.bias)
loss += (self.outputs[i] - q_out).pow(2).mean().item()
if loss < best_loss:
best_scales = scales
best_loss = loss
for i in range(self.n_validation_examples):
self.inputs[i].to("cpu")
self.outputs[i].to("cpu")
return best_scales.detach()
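As a reviewer aid (notation below is mine, not from the code): calculate_qparams above implements the usual AWQ scale search. With x̄ the per-channel mean absolute activation accumulated in self.average, N = scale_search_space_size, W and b the observed weight and bias, X_i and Y_i the recorded calibration inputs and outputs, and Q(·) the weight quantization specified by base_config, it evaluates

$$ s_\alpha = \frac{\bar{x}^{\,\alpha}}{\sqrt{\max(\bar{x}^{\,\alpha}) \cdot \min(\bar{x}^{\,\alpha})}}, \qquad \alpha \in \left\{0,\ \tfrac{1}{N},\ \ldots,\ \tfrac{N-1}{N}\right\}, $$

and returns the candidate scale with the smallest validation error:

$$ s^{*} = \arg\min_{\alpha} \sum_{i} \operatorname{mean}\Big( \big(Y_i - \mathrm{linear}(X_i / s_\alpha,\ Q(W \cdot s_\alpha),\ b)\big)^{2} \Big). $$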


class AWQObservedLinear(torch.nn.Linear):
def __init__(
self,