Skip to content

Commit

Permalink
fixed test for int4wo and add __init__.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanxian97 committed Jul 29, 2024
1 parent c08ab33 commit 75b55c2
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 10 deletions.
16 changes: 6 additions & 10 deletions test/quantization/test_mixed_precision.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
import torch
import torch.nn as nn

import os
import sys
# append the path to the naive_intNwo.py file
sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "torchao/quantization/prototype/mixed_precision/scripts"))
from naive_intNwo import intN_weight_only
from torchao.quantization.prototype.mixed_precision.scripts.naive_intNwo import intN_weight_only

from torchao.quantization import quantize_, int8_weight_only, int4_weight_only

Expand All @@ -18,11 +14,11 @@
)

def test_weight_only_quant(quantization_bit=2, symmetric=False):
for x_shape in [[32, 64], [80, 80, 80, 64], [16, 64, 64]]:
x = torch.randn(*x_shape)
m = nn.Sequential(nn.Linear(64, 80))
for x_shape in [[64, 32], [80, 80, 80, 32], [16, 64, 32]]:
x = torch.randn(*x_shape, dtype=torch.bfloat16)
m = nn.Sequential(nn.Linear(32, 80)).bfloat16()
y_ref = m(x)
quantize_(m, intN_weight_only(n=quantization_bit, group_size=16, symmetric=symmetric))
quantize_(m, intN_weight_only(n=quantization_bit, group_size=32, symmetric=symmetric))
y_wo = m(x)
sqnr = compute_error(y_ref, y_wo)
#SQNR_dB can be approximated by 6.02n, where n is the bit width of the quantization
Expand All @@ -31,7 +27,7 @@ def test_weight_only_quant(quantization_bit=2, symmetric=False):


# test if the asymmetric and symmetric quantization API works with different bit widths
for i in [2,3,5,6,8]:
for i in [2, 3, 4, 5, 6, 8]:
#test for asymmetric quantization
try:
test_weight_only_quant(i, False)
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .naive_intNwo import intN_weight_only

0 comments on commit 75b55c2

Please sign in to comment.