Add HQQ support #605
@@ -0,0 +1,114 @@
import unittest
import torch
from torchao.dtypes.affine_quantized_tensor import (
    to_affine_quantized,
    ZeroPointDomain,
    PlainAQTLayout,
    PlainLayoutType,
    TensorCoreTiledAQTLayout,
    TensorCoreTiledLayoutType,
    MappingType,
)

from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_4,
    TORCH_VERSION_AT_LEAST_2_5,
)

cuda_available = torch.cuda.is_available()

# Parameters
device = 'cuda:0'
compute_dtype = torch.bfloat16
group_size = 64
mapping_type = MappingType.ASYMMETRIC
block_size = (1, group_size)  # axis=1
preserve_zero = False
zero_point_domain = ZeroPointDomain.FLOAT
zero_point_dtype = compute_dtype
inner_k_tiles = 8
in_features = 4096
out_features = 11800
torch_seed = 100


def _init_data(in_features, out_features, compute_dtype, device, torch_seed):
    torch.random.manual_seed(torch_seed)
    linear_layer = torch.nn.Linear(in_features, out_features, bias=False, device=device)
    x = torch.randn((1, linear_layer.in_features), dtype=torch.float, device=device) / 20.
    y_ref = linear_layer(x)
    W = linear_layer.weight.data.clone().to(device=device, dtype=compute_dtype)
    return W, x, y_ref


def _eval_hqq(nbits, layout_type):
    W, x, y_ref = _init_data(in_features, out_features, compute_dtype, device, torch_seed)

    # Plain layout
    target_dtype = torch.uint8
    # Tensorcore layout
    if isinstance(layout_type, TensorCoreTiledLayoutType):
        target_dtype = torch.uint8 if TORCH_VERSION_AT_LEAST_2_5 else torch.int32

    q_tensor_hqq = to_affine_quantized(
        input_float=W,
        mapping_type=mapping_type,
        block_size=block_size,
        target_dtype=target_dtype,
        quant_min=0,
        quant_max=2**nbits - 1,
        zero_point_domain=zero_point_domain,
        preserve_zero=preserve_zero,
        layout_type=layout_type,
        use_hqq=True,
    )

    quant_linear_layer = torch.nn.Linear(W.shape[1], W.shape[0], bias=False, device=W.device)
    del quant_linear_layer.weight
    quant_linear_layer.weight = q_tensor_hqq
    dequantize_error = (W - q_tensor_hqq.dequantize()).abs().mean().item()
    dot_product_error = (y_ref - quant_linear_layer(x.to(compute_dtype))).abs().mean().item()

    return dequantize_error, dot_product_error


class TestHQQBase(unittest.TestCase):
    @unittest.skipIf(not cuda_available, "Need CUDA available")
    def test_hqq(self, nbits=None, layout_type=None, ref_dequantize_error=None, ref_dot_product_error=None):
        if nbits is None:
            return
        dequantize_error, dot_product_error = _eval_hqq(nbits=nbits, layout_type=layout_type)
        self.assertTrue(dequantize_error < ref_dequantize_error)
        self.assertTrue(dot_product_error < ref_dot_product_error)


class TestHQQ8Bit(TestHQQBase):
    def test_hqq_plain_8bit(self):
        self.test_hqq(nbits=8, layout_type=PlainLayoutType(), ref_dequantize_error=5e-5, ref_dot_product_error=0.00013)


class TestHQQ7Bit(TestHQQBase):
    def test_hqq_plain_7bit(self):
        self.test_hqq(nbits=7, layout_type=PlainLayoutType(), ref_dequantize_error=6e-05, ref_dot_product_error=0.000193)


class TestHQQ6Bit(TestHQQBase):
    def test_hqq_plain_6bit(self):
        self.test_hqq(nbits=6, layout_type=PlainLayoutType(), ref_dequantize_error=0.0001131, ref_dot_product_error=0.000353)


class TestHQQ5Bit(TestHQQBase):
    def test_hqq_plain_5bit(self):
        self.test_hqq(nbits=5, layout_type=PlainLayoutType(), ref_dequantize_error=0.00023, ref_dot_product_error=0.000704)


class TestHQQ4bit(TestHQQBase):
    def test_hqq_plain_4bit(self):
        self.test_hqq(nbits=4, layout_type=PlainLayoutType(), ref_dequantize_error=0.000487, ref_dot_product_error=0.001472)

    def test_hqq_tensorcore_4bit(self):
        self.test_hqq(nbits=4, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles), ref_dequantize_error=0.000487, ref_dot_product_error=0.00147)


class TestHQQ3Bit(TestHQQBase):
    def test_hqq_plain_3bit(self):
        self.test_hqq(nbits=3, layout_type=PlainLayoutType(), ref_dequantize_error=0.00101, ref_dot_product_error=0.003047)


class TestHQQ2Bit(TestHQQBase):
    def test_hqq_plain_2bit(self):
        self.test_hqq(nbits=2, layout_type=PlainLayoutType(), ref_dequantize_error=0.002366, ref_dot_product_error=0.007255)


if __name__ == "__main__":
    unittest.main()
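The suite above runs directly with unittest. A minimal sketch for invoking just the 4-bit cases, assuming the test file is importable as test_hqq (the module name is an assumption; the diff view does not show the file path):

import unittest

import test_hqq  # hypothetical module name for the test file above

# Load and run only the 4-bit test class (plain + tensor-core layouts).
suite = unittest.TestLoader().loadTestsFromTestCase(test_hqq.TestHQQ4bit)
unittest.TextTestRunner(verbosity=2).run(suite)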
@@ -0,0 +1,118 @@
import torch
from torchao.prototype.hqq.core import HQQQuantizer
from torchao.dtypes.affine_quantized_tensor import (
    to_affine_quantized,
    ZeroPointDomain,
    PlainAQTLayout,
    PlainLayoutType,
    TensorCoreTiledAQTLayout,
    TensorCoreTiledLayoutType,
    MappingType,
)

# Parameters
device, compute_dtype = "cuda:0", torch.bfloat16
group_size, axis = 64, 1
in_features, out_features = 4096, 11800

torch.random.manual_seed(100)
linear_layer = torch.nn.Linear(in_features, out_features, bias=False, device=device)
x = torch.randn((1, linear_layer.in_features), dtype=torch.float, device=device) / 20.
y_ref = linear_layer(x)
W = linear_layer.weight.data.clone().to(device=device, dtype=compute_dtype)
del linear_layer.weight

################################################################################################
# AffineQuantizedTensor example
################################################################################################
print('-------------------------------------------------------------------')
print('AffineQuantizedTensor example')
print('-------------------------------------------------------------------')
mapping_type = MappingType.ASYMMETRIC
block_size = (1, group_size)
target_dtype = torch.uint8  # until sub-byte dtypes are supported
preserve_zero = False
zero_point_domain = ZeroPointDomain.FLOAT
zero_point_dtype = compute_dtype
layout_type = PlainLayoutType()

for nbits in list(range(2, 9))[::-1]:
    print('------------------------------------------------------------------------------')
    q_tensor_default = to_affine_quantized(
        input_float=W,
        mapping_type=mapping_type,
        block_size=block_size,
        target_dtype=target_dtype,
        quant_min=0,
        quant_max=2**nbits - 1,
        zero_point_domain=zero_point_domain,
        preserve_zero=preserve_zero,
        layout_type=layout_type,
    )

    linear_layer.weight = q_tensor_default
    print("nbits", nbits, "| Default dequantization error", (W - q_tensor_default.dequantize()).abs().mean().item())
    print("nbits", nbits, '| Default Dot product error', (y_ref - linear_layer(x.to(compute_dtype))).abs().mean().item())
    # nbits 4 | Default dequantization error 0.001953125
    # nbits 4 | Default Dot product error 0.005926903802901506

    q_tensor_hqq = to_affine_quantized(
        input_float=W,
        mapping_type=mapping_type,
        block_size=block_size,
        target_dtype=target_dtype,
        quant_min=0,
        quant_max=2**nbits - 1,
        zero_point_domain=zero_point_domain,
        preserve_zero=preserve_zero,
        layout_type=layout_type,
        use_hqq=True,
    )

    linear_layer.weight = q_tensor_hqq
    print("nbits", nbits, "| HQQ dequantization error", (W - q_tensor_hqq.dequantize()).abs().mean().item())
    print("nbits", nbits, '| HQQ Dot product error', (y_ref - linear_layer(x.to(compute_dtype))).abs().mean().item())
    # nbits 4 | HQQ dequantization error 0.0004863739013671875
    # nbits 4 | HQQ Dot product error 0.0014713306445628405

################################################################################################
# quant_api example
################################################################################################
print('-------------------------------------------------------------------')
print('Quant API example')
print('-------------------------------------------------------------------')

from torchao.quantization.quant_api import int4_weight_only
nbits = 4
target_dtype = torch.int32
inner_k_tiles = 8
layout_type = TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)

int4_weight_only_patch_fct = int4_weight_only(group_size=group_size, inner_k_tiles=inner_k_tiles)
linear_layer_default = torch.nn.Linear(in_features, out_features, bias=False, device=device)
linear_layer_default.weight.data = W.clone()
linear_layer_default = int4_weight_only_patch_fct(linear_layer_default)
print("nbits", nbits, "| Default dequantization error", (W - linear_layer_default(torch.eye(W.shape[1], dtype=W.dtype, device=W.device)).T).abs().mean().item())
print("nbits", nbits, '| Default Dot product error', (y_ref - linear_layer_default(x.to(compute_dtype))).abs().mean().item())
# nbits 4 | Default dequantization error 0.000492095947265625
# nbits 4 | Default Dot product error 0.0015244047390297055

q_tensor_hqq = to_affine_quantized(
    input_float=W,
    mapping_type=mapping_type,
    block_size=block_size,
    target_dtype=target_dtype,
    quant_min=0,
    quant_max=2**nbits - 1,
    zero_point_domain=zero_point_domain,
    preserve_zero=preserve_zero,
    layout_type=layout_type,
    use_hqq=True,
)
linear_layer.weight = q_tensor_hqq
print("nbits", nbits, "| HQQ dequantization error", (W - linear_layer(torch.eye(W.shape[1], dtype=W.dtype, device=W.device)).T).abs().mean().item())
print("nbits", nbits, '| HQQ Dot product error', (y_ref - linear_layer(x.to(compute_dtype))).abs().mean().item())
# nbits 4 | HQQ dequantization error 0.0004863739013671875
# nbits 4 | HQQ Dot product error 0.0014699687017127872
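Taking the printed 4-bit numbers at face value: with the plain layout, HQQ cuts the mean dequantization error roughly 4x versus the default round-to-nearest affine path (0.000486 vs. 0.001953) and the dot-product error by a similar factor (0.00147 vs. 0.00593). Against the tensor-core int4 baseline the two are close, with HQQ slightly ahead on both metrics.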
@@ -389,7 +389,7 @@ def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner
          size is more fine grained, choices are [256, 128, 64, 32]
         `layout_type`: layout type for quantized tensor, default is `TensorCoreTiledLayoutType(inner_k_tiles=8)`
     """
-    def apply_int4_weight_only_quant(weight):
+    def apply_int4_weight_only_quant(weight, use_hqq=False):
Review comment: I just found that this flag is not used, so we don't really expose hqq to users right now. Are you planning to create a new function for hqq? cc @mobicham

Reply: My understanding was that @HDCharles suggested putting it there and later turning it on by default.

Reply: commented in #786 (comment)
         if weight.shape[-1] % group_size != 0:
             return weight
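For context on the review thread above: the new use_hqq flag is added to the inner function's signature but nothing forwards it yet. A minimal sketch of what threading it through could look like, assuming int4_weight_only passes the flag to to_affine_quantized with the same arguments the example script uses for the tensor-core path. This is hypothetical wiring, not the PR's code; per the thread, the actual exposure was deferred (see #786).

import torch
from torchao.dtypes.affine_quantized_tensor import (
    to_affine_quantized,
    ZeroPointDomain,
    TensorCoreTiledLayoutType,
    MappingType,
)

# Hypothetical sketch: thread use_hqq from the public factory down to
# to_affine_quantized so users can opt into HQQ-optimized int4 weights.
def int4_weight_only(group_size=128, inner_k_tiles=8, use_hqq=False):
    def apply_int4_weight_only_quant(weight):
        if weight.shape[-1] % group_size != 0:
            return weight
        return to_affine_quantized(
            input_float=weight,
            mapping_type=MappingType.ASYMMETRIC,
            block_size=(1, group_size),
            target_dtype=torch.int32,
            quant_min=0,
            quant_max=15,  # 4-bit: 2**4 - 1
            zero_point_domain=ZeroPointDomain.FLOAT,
            preserve_zero=False,
            layout_type=TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles),
            use_hqq=use_hqq,  # the only change: forward the flag
        )
    return apply_int4_weight_only_quant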