Refactor int4 weight only quantization with call to quantize

Summary:
This is similar to #294, but applied to int4 weight-only quantization.

Test Plan:
unit perf test:

python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int4_wo_quant_perf
elapsed time: 0.2166275215148926, ref elapsed time: 0.2191881561279297
elapsed time: 0.2376406478881836, ref elapsed time: 0.22721023559570314
elapsed time: 0.21919679641723633, ref elapsed time: 0.2154969596862793

integration perf test:
reference: elapsed_time: 2.5900126953125 milliseconds
after refactor: elapsed_time: 2.56680078125 milliseconds
diff: no diff

TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py
Before:
After:
generated code diff:

Reviewers:
Subscribers:
Tasks:
Tags:
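For context: before this change, int4 weight-only quantization went through the subclass-swapping helper change_linear_weights_to_int4_woqtensors; after it, the path routes through the unified quantize entry point introduced for int8 in #294. A minimal sketch of the call pattern follows; apply_int4wo is a placeholder for whatever apply-function this commit registers, so treat that name and its groupsize argument as assumptions, not the confirmed API at this commit.

import torch
from torchao.quantization.quant_api import quantize  # unified entry point from #294

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024, bias=False)).to(torch.bfloat16).to("cuda")

# Deprecated path, kept in the benchmark below as the numerics/perf reference:
#   change_linear_weights_to_int4_woqtensors(model, groupsize=32)
# Refactored path: a single quantize() call. `apply_int4wo` is a hypothetical
# stand-in for the int4 weight-only apply-function this commit wires up.
# quantize(model, apply_int4wo(groupsize=32))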
1 parent 55a4676 · commit b250618
Showing 7 changed files with 353 additions and 205 deletions.
"""Benchmarks for affine quantized tensor, this includes int8 dynamic quant, int8 weight only quant and int4 weight only quant APIs | ||
""" | ||
import torch | ||
from torchao.quantization.subclass import ( | ||
Int8WeightOnlyQuantizedLinearWeight, | ||
Int4WeightOnlyQuantizedLinearWeight, | ||
) | ||
from torchao.quantization.utils import ( | ||
TORCH_VERSION_AFTER_2_4, | ||
) | ||
from torchao.quantization.quant_api import ( | ||
_replace_with_custom_fn_if_matches_filter, | ||
) | ||
import copy | ||

class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=64, n=32, k=64):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
        self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)

    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
    """
    The deprecated implementation of the int8 dynamic quant API, used as a
    reference for numerics and performance.
    """
    from torchao.quantization.quant_api import (
        _in_features_greater_than_16,
        _is_linear,
        _get_subclass_inserter,
    )
    from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight

    if filter_fn is None:
        filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16(*args)

    _replace_with_custom_fn_if_matches_filter(
        model,
        _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs),
        filter_fn,
    )

def _get_ref_change_linear_weights_to_woqtensors(deprecated_tensor_subclass):
    def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
        """
        The deprecated implementation of the weight-only quant APIs, used as a
        reference for numerics and performance.
        """
        from torchao.quantization.quant_api import _is_linear, _get_subclass_inserter

        if filter_fn is None:
            filter_fn = _is_linear

        _replace_with_custom_fn_if_matches_filter(
            model,
            _get_subclass_inserter(deprecated_tensor_subclass, enable_parametrization=True, **kwargs),
            filter_fn,
        )

    return _ref_change_linear_weights_to_woqtensors

_ref_change_linear_weights_to_int8_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight)
_ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)

def _bench_quantized_tensor_subclass_perf(api, ref_api, kwargs=None):
    if kwargs is None:
        kwargs = {}

    m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
    m_ref = copy.deepcopy(m)
    # setting batch_size to 20 to be compatible with the kernel
    example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")

    # quantize one copy with the new API and the other with the deprecated reference API
    api(m, **kwargs)
    ref_api(m_ref, **kwargs)

    # numerics: the refactored path must match the reference exactly
    res = m(*example_inputs)
    ref = m_ref(*example_inputs)
    assert torch.equal(res, ref)

    # perf comparison
    from torchao.utils import benchmark_model

    WARMUP = 5
    RUNS = 100
    input_tensor = example_inputs[0]

    m = torch.compile(m, mode='max-autotune', fullgraph=True)
    benchmark_model(m, WARMUP, input_tensor)  # warmup
    elapsed_time = benchmark_model(m, RUNS, input_tensor)

    m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
    benchmark_model(m_ref, WARMUP, input_tensor)  # warmup
    ref_elapsed_time = benchmark_model(m_ref, RUNS, input_tensor)

    print(f"elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}")
    assert elapsed_time < 1.05 * ref_elapsed_time

if __name__ == "__main__" and TORCH_VERSION_AFTER_2_4 and torch.cuda.is_available():
    from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
    _bench_quantized_tensor_subclass_perf(change_linear_weights_to_int8_dqtensors, _ref_change_linear_weights_to_int8_dqtensors)

    from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
    _bench_quantized_tensor_subclass_perf(change_linear_weights_to_int8_woqtensors, _ref_change_linear_weights_to_int8_woqtensors)

    kwargs = {"groupsize": 32}
    from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
    _bench_quantized_tensor_subclass_perf(change_linear_weights_to_int4_woqtensors, _ref_change_linear_weights_to_int4_woqtensors, kwargs)
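As a usage note, the same harness can compare any new/reference API pair or sweep other configurations; the group size below is an illustrative value, not one exercised by the test plan above.

# Example (hypothetical sweep): rerun the int4 comparison at a different group size.
# _bench_quantized_tensor_subclass_perf(
#     change_linear_weights_to_int4_woqtensors,
#     _ref_change_linear_weights_to_int4_woqtensors,
#     {"groupsize": 128},
# )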