test/prototype/test_smoothquant.py (74 changes: 40 additions & 34 deletions)
@@ -15,9 +15,15 @@
 )
 from torchao.prototype.smoothquant.core import SmoothQuantStep
 from torchao.quantization import quantize_
+from torchao.quantization.linear_activation_scale import (
+    WeightTensorWithLinearActivationScaleMetadata,
+)
 from torchao.quantization.quant_api import (
     Int8DynamicActivationInt8WeightConfig,
 )
+from torchao.quantization.utils import (
+    compute_error as SQNR,
+)
 
 
 class ToyLinearModel(torch.nn.Module):
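
Reviewer note: the new `compute_error` import (aliased to `SQNR`) lets the test assert an absolute quality bar in decibels on top of the relative MSE comparison it already makes. For readers unfamiliar with the metric, here is a minimal sketch of a signal-to-quantization-noise ratio in the spirit of torchao's `compute_error` (a paraphrase for illustration, not the library source):

```python
import torch

def sqnr(reference: torch.Tensor, quantized: torch.Tensor) -> torch.Tensor:
    # Ratio of signal power to error power, expressed in dB. Higher is
    # better: 20 dB means the error norm is 10x smaller than the signal
    # norm, which is the bar the test sets with assertGreater(..., 20.0).
    signal = torch.linalg.norm(reference)
    noise = torch.linalg.norm(reference - quantized)
    return 20 * torch.log10(signal / noise)
```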
@@ -34,16 +40,19 @@ def example_inputs(
         dtype=torch.bfloat16,
         device="cuda",
     ):
-        return [
-            torch.randn(
-                1,
-                sequence_length,
-                self.linear1.in_features,
-                dtype=dtype,
-                device=device,
-            )
-            for j in range(batch_size)
-        ]
+        # For SmoothQuant tests, we intentionally insert some outliers into the input features
+        x = torch.randn(
+            batch_size,
+            sequence_length,
+            self.linear1.in_features,
+            dtype=dtype,
+            device=device,
+        )
+        n_outliers = max(1, int(x.size(-1) * 0.1))
+        # Randomly select outlier features
+        outlier_indices = torch.randperm(x.size(-1))[:n_outliers]
+        x[:, :, outlier_indices] *= 10.0
+        return (x,)
 
     def forward(self, x):
         x = self.linear1(x)
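
Reviewer note: injecting outliers is what makes this a meaningful SmoothQuant fixture. SmoothQuant migrates quantization difficulty from activations to weights via per-channel scales, so it only beats plain int8 quantization when some activation channels are much larger than others; on uniform Gaussian inputs the old fixture left the accuracy comparison close to a coin flip. A sketch of the paper's smoothing rule, with assumed helper names (torchao's internals may differ):

```python
import torch

def smoothing_factors(
    act_absmax: torch.Tensor,     # per-channel max |activation|, from calibration
    weight_absmax: torch.Tensor,  # per-channel max |weight| for the same channels
    alpha: float = 0.5,
) -> torch.Tensor:
    # SmoothQuant: s_j = max|X_j|^alpha / max|W_j|^(1 - alpha).
    # Dividing activations by s and multiplying weights by s keeps the
    # linear layer's output unchanged while shrinking activation outliers
    # like the 10x channels created above.
    return act_absmax.pow(alpha) / weight_absmax.pow(1.0 - alpha)
```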
@@ -52,7 +61,9 @@ def forward(self, x):
         return x
 
 
-@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+device_list = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
+
+
 @unittest.skipIf(torch.version.hip is not None, "Skipping tests in ROCm")
 class TestSmoothQuant(unittest.TestCase):
     """SmoothQuant tests using only supported quantization configs."""
@@ -72,37 +83,25 @@ def setUpClass(cls):
             # TODO(#1639): Fix for supporting more API in torchao/quantization/quant_api.py
         ],
     )
-    @common_utils.parametrize("device", ["cpu", "cuda"])
+    @common_utils.parametrize("device", device_list)
     @common_utils.parametrize("input_dtype", [torch.bfloat16])
     def test_smoothquant_accuracy(self, alpha, base_config, device, input_dtype):
         """Test if SmoothQuant achieves lower loss than basic quantization."""
-        in_features = 64
-        out_features = 128
-
-        # Note: This is sanity check. For real run, consider Transformer model to reproduce.
-        X = torch.randn(16, in_features, dtype=input_dtype, device=device)
-        W = torch.randn(out_features, in_features, dtype=input_dtype, device=device)
-
-        # Create linear layer
-        linear = (
-            torch.nn.Linear(in_features, out_features, bias=False)
-            .to(device)
-            .to(input_dtype)
-        )
-        with torch.no_grad():
-            linear.weight.copy_(W)
+        m = ToyLinearModel().eval().to(device).to(input_dtype)
+        x = m.example_inputs(batch_size=16, dtype=input_dtype, device=device)
 
         # Reference output
-        out_ref = linear(X)
+        out_ref = m(*x)
 
         # Step 1. Basic quantization
-        basic_model = deepcopy(linear)
+        basic_model = deepcopy(m)
         quantize_(basic_model, base_config)
-        out_basic = basic_model(X)
+        out_basic = basic_model(*x)
         loss_base = torch.nn.functional.mse_loss(out_basic, out_ref).item()
 
-        # SmoothQuant quantization
-        model = deepcopy(linear)
+        # Step 2. SmoothQuant
+        model = deepcopy(m)
         config = SmoothQuantConfig(
             base_config=base_config,
             step=SmoothQuantStep.PREPARE,
@@ -111,18 +110,25 @@ def test_smoothquant_accuracy(self, alpha, base_config, device, input_dtype):
         quantize_(model, config)
 
         # Perform calibration with test data
-        model(X)
+        model(*x)
 
-        # Step 2. SmoothQuant
         config.step = SmoothQuantStep.CONVERT
         quantize_(model, config)
+        assert isinstance(
+            model.linear1.weight, WeightTensorWithLinearActivationScaleMetadata
+        )
+        assert isinstance(
+            model.linear2.weight, WeightTensorWithLinearActivationScaleMetadata
+        )
 
-        out_smoothquant = model(X)
+        out_smoothquant = model(*x)
         loss_smoothquant = torch.nn.functional.mse_loss(out_smoothquant, out_ref).item()
 
         assert loss_smoothquant < loss_base, (
             f"SmoothQuant loss ({loss_smoothquant:.6f}) should not be higher than basic loss ({loss_base:.6f})"
         )
+        # Make sure the result is reasonable
+        self.assertGreater(SQNR(out_ref, out_smoothquant), 20.0)
 
     @common_utils.parametrize(
         "base_config",