[WIP] Make AWQ more general #2400

base: main
First file (the AWQ package `__init__.py`, inferred from its contents):

```diff
@@ -1,8 +1,9 @@
-from .api import awq_uintx, insert_awq_observer_
+from .api import AWQConfig, awq_uintx, insert_awq_observer_
 from .core import AWQObservedLinear

 __all__ = [
     "awq_uintx",
     "insert_awq_observer_",
     "AWQObservedLinear",
+    "AWQConfig",
 ]
```
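For orientation, the practical effect of this hunk is that `AWQConfig` becomes importable from the package root. A minimal sketch, assuming the file is `torchao/prototype/awq/__init__.py` (the diff itself does not name it):

```python
# Hedged sketch: the package path is an assumption, not shown in the diff.
from torchao.prototype.awq import AWQConfig, awq_uintx, insert_awq_observer_
```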
Second file (the AWQ `api.py`, inferred from its contents):

```diff
@@ -30,12 +30,15 @@
     ZeroPointDomain,
 )
 from torchao.quantization.transform_module import (
+    _QUANTIZE_CONFIG_HANDLER,
     register_quantize_module_handler,
 )
+from torchao.utils import DummyModule

 from .core import (
     AWQObservedLinear,
     AWQObserver,
+    AWQObserver2,
 )

 assert len(_DTYPE_TO_BIT_WIDTH) > 0, (
```
```diff
@@ -50,6 +53,7 @@ def insert_awq_observer_(
     quant_dtype: torch.dtype = torch.uint4,
     scale_search_space_size: int = 20,
     group_size: int = 128,
+    base_config: Optional[AOBaseConfig] = None,
 ):
     """
     Inserts AWQObserver into Linear layers of a given model.
```
```diff
@@ -80,22 +84,32 @@ def insert_awq_observer_(

     def replace_with_observer(layer):
         # creates observer and replaces linear layers with AWQObservedLinear layers
-        observer = AWQObserver(
-            layer.weight,
-            layer.bias,
-            quantization_granularity,
-            mapping_type,
-            quant_dtype,
-            n_validation_examples,
-            validation_sequence_len,
-            scale_search_space_size,
-            preserve_zero=preserve_zero,
-            zero_point_domain=zero_point_domain,
-            zero_point_dtype=zero_point_dtype,
-            quant_min=quant_min,
-            quant_max=quant_max,
-            eps=eps,
-        )
+        if base_config is None:
+            observer = AWQObserver(
+                layer.weight,
+                layer.bias,
+                quantization_granularity,
+                mapping_type,
+                quant_dtype,
+                n_validation_examples,
+                validation_sequence_len,
+                scale_search_space_size,
+                preserve_zero=preserve_zero,
+                zero_point_domain=zero_point_domain,
+                zero_point_dtype=zero_point_dtype,
+                quant_min=quant_min,
+                quant_max=quant_max,
+                eps=eps,
+            )
+        else:
+            observer = AWQObserver2(
+                layer.weight,
+                layer.bias,
+                base_config,
+                n_validation_examples,
+                validation_sequence_len,
+                scale_search_space_size,
+            )
         return AWQObservedLinear.from_float(layer, observer)

     _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear)
```
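Taken together, the hunks above make observer selection hinge on `base_config`: omit it and the legacy `AWQObserver` path runs; pass one and `AWQObserver2` is used instead. A hedged sketch of the calling pattern (the toy model, calibration data, and `Int4WeightOnlyConfig` as the base config are assumptions for illustration, not taken from this PR):

```python
import torch
from torchao.prototype.awq import insert_awq_observer_  # path assumed, as above
from torchao.quantization import Int4WeightOnlyConfig  # assumed base config choice

model = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()  # toy model

# New path sketched by this PR: a non-None base_config routes observer
# creation to AWQObserver2; without it, the legacy AWQObserver is used.
insert_awq_observer_(
    model,
    n_validation_examples=2,
    validation_sequence_len=16,
    base_config=Int4WeightOnlyConfig(group_size=32),
)

# Calibration forwards feed the inserted observers.
for _ in range(2):
    model(torch.randn(1, 16, 64))
```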
```diff
@@ -194,3 +208,50 @@ def _awq_uintx_transform(
     linear.extra_repr = types.MethodType(_linear_extra_repr, module)
     linear.bias = observed_linear.bias
     return linear
+
+
+@dataclass
+class AWQConfig(AOBaseConfig):
```
Review comment: Ok, this is consolidating with the `quantize_` API's config-based design?
""" | ||||
Configuration for quantizing linear layers when passed into quantize_() | ||||
|
||||
Args: | ||||
quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to be used but can be used with torch.uint1 -> torch.uint8 | ||||
`layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)` | ||||
group_size: Quantization granularity. Use -1 for channel wise quantization | ||||
weight_quant_fn: The quantization function to be used, which takes in the weight and returns the quantized weight. If None, then affine uint4 quantization is used | ||||
set_inductor_config: if True, adjusts `torchinductor` settings to recommended values. | ||||
""" | ||||
|
||||
base_config: AOBaseConfig | ||||
set_inductor_config: bool = True | ||||
|
||||
|
||||
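Per the docstring, the config is meant to be handed to `quantize_()` after calibration. A hedged sketch of that step (the `filter_fn` restricting the transform to observed linears is inferred from the review discussion below, not spelled out in this diff):

```python
from torchao.quantization import quantize_

# Sketch only: quantize just the layers that were swapped to
# AWQObservedLinear during the calibration phase.
quantize_(
    model,
    AWQConfig(base_config=Int4WeightOnlyConfig(group_size=32)),
    filter_fn=lambda m, fqn: isinstance(m, AWQObservedLinear),
)
```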
```diff
+
+
+@register_quantize_module_handler(AWQConfig)
+def _awq_transform(
+    module: torch.nn.Module,
+    config: AWQConfig,
+) -> torch.nn.Module:
+    if config.set_inductor_config:
+        torchao.quantization.utils.recommended_inductor_config_setter()
+
+    observed_linear = module
```
Review comment: If this is for linear only, shouldn't you assert that this is an nn.Linear? Plus, how do you make sure this function is called only on nn.Linear?

Author reply: Yeah, that's true, will add an assert. We rely on the user to use quantize_ correctly (by specifying the filter_fn arg in the quantize_ API): ao/torchao/quantization/quant_api.py, line 578 in 4e3d019.
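A minimal sketch of the guard being agreed on here (hypothetical wording; the assert is not in the diff yet):

```python
# Hypothetical assert at the top of _awq_transform, per the review above.
assert isinstance(module, torch.nn.Linear), (
    f"_awq_transform expects an (observed) nn.Linear, got {type(module)}"
)
```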
```diff
+
+    equalization_scale = observed_linear.act_obs.calculate_qparams()
+
+    base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)]
+    dummy_mod = DummyModule(observed_linear.weight * equalization_scale)
+    quant_mod = base_config_handler(dummy_mod, config.base_config)
```
Comment on lines +242 to +243:

Review comment: I am not sure what's happening here? Isn't …

Author reply: This is just trying to quantize the weight with the quantization type specified by config.base_config.
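To unpack the reply: a base config's handler operates on a module with a `.weight`, so the equalized weight is wrapped in a throwaway module purely to reuse that handler. An illustrative stand-in for the idea (the real `DummyModule` lives in `torchao.utils`; this version is a sketch, not that implementation):

```python
import torch

class DummyModule(torch.nn.Module):
    """Sketch: wraps a bare tensor as a module so a module-level
    quantization handler can run on a weight with no parent Linear."""

    def __init__(self, weight: torch.Tensor):
        super().__init__()
        self.weight = torch.nn.Parameter(weight, requires_grad=False)

# Usage shape (mirroring the diff): quantize W * s by pretending it is
# the weight of a module, then pull the quantized tensor back out.
#   dummy = DummyModule(observed_linear.weight * equalization_scale)
#   quant_mod = base_config_handler(dummy, config.base_config)
#   qw = quant_mod.weight
```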
```diff
+
+    qw = quant_mod.weight
+    qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale)
+
+    linear = torch.nn.Linear(
+        observed_linear.in_features,
+        observed_linear.out_features,
+        observed_linear.bias is not None,
+        device=observed_linear.weight.device,
+        dtype=observed_linear.weight.dtype,
+    )
+    linear.weight = torch.nn.Parameter(qw, requires_grad=False)
+    linear.extra_repr = types.MethodType(_linear_extra_repr, module)
+    linear.bias = observed_linear.bias
+    return linear
```
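For intuition on the wrapped scale metadata: AWQ folds the per-channel equalization scale into the weight (W' = W * s), so at inference the activation must be divided by the same scale before the matmul. A conceptual sketch of that runtime behavior (not the actual torchao tensor-subclass implementation):

```python
import torch

def awq_linear(x, qw, equalization_scale, bias=None):
    # The weight was scaled up by `equalization_scale` before quantization,
    # so the activation is scaled down to keep x @ W.T numerically unchanged.
    return torch.nn.functional.linear(x / equalization_scale, qw, bias)
```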
Review comment: Can you not add kwargs to the AWQObserver and just check `'base_config' in kwargs`?

Author reply: Yes, this is temporary. I think we can deprecate the old one in the end.