diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py index 85893f2241..581f75b925 100644 --- a/test/prototype/test_smoothquant.py +++ b/test/prototype/test_smoothquant.py @@ -3,7 +3,6 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -import tempfile import unittest from copy import deepcopy @@ -13,14 +12,11 @@ from torchao.prototype.smoothquant import ( SmoothQuantConfig, SmoothQuantObservedLinear, - insert_smooth_quant_observer_, - load_smooth_quant_recipe, - save_smooth_quant_recipe, ) +from torchao.prototype.smoothquant.core import SmoothQuantStep from torchao.quantization import quantize_ -from torchao.quantization.utils import ( - dequantize_per_channel, - dynamically_quantize_per_channel, +from torchao.quantization.quant_api import ( + Int8DynamicActivationInt8WeightConfig, ) @@ -29,14 +25,22 @@ def __init__(self, m=512, n=256, k=128): super().__init__() self.linear1 = torch.nn.Linear(m, n, bias=False) self.linear2 = torch.nn.Linear(n, k, bias=False) - self.linear3 = torch.nn.Linear(k, 1, bias=False) + self.linear3 = torch.nn.Linear(k, 64, bias=False) def example_inputs( - self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda" + self, + batch_size, + sequence_length=10, + dtype=torch.bfloat16, + device="cuda", ): return [ torch.randn( - 1, sequence_length, self.linear1.in_features, dtype=dtype, device=device + 1, + sequence_length, + self.linear1.in_features, + dtype=dtype, + device=device, ) for j in range(batch_size) ] @@ -48,218 +52,161 @@ def forward(self, x): return x +@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(torch.version.hip is not None, "Skipping tests in ROCm") class TestSmoothQuant(unittest.TestCase): + """SmoothQuant tests using only supported quantization configs.""" + @classmethod def setUpClass(cls): """Set up class-level configuration for tests.""" # This test case will trigger recompilation many times, so set a large cache_size_limit here torch._dynamo.config.cache_size_limit = 128 - @unittest.skip("This test is broken on recent PyTorch, TODO(#1639): fix it") - @common_utils.parametrize("bias", [True, False]) - @common_utils.parametrize("alpha", [None, 0.5, 0.75]) - @common_utils.parametrize("quant_mode", ["static", "dynamic"]) + @common_utils.parametrize("alpha", [0.5, 0.75]) @common_utils.parametrize( - "device", ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) + "base_config", + [ + Int8DynamicActivationInt8WeightConfig(), + # Note: float8_static_activation_float8_weight is broken after recent PyTorch update. + # TODO(#1639): Fix for supporting more API in torchao/quantization/quant_api.py + ], ) - @common_utils.parametrize("input_dtype", [torch.float, torch.bfloat16, torch.half]) - def test_smoothquant_accuracy(self, bias, alpha, quant_mode, device, input_dtype): - """Test the margin error of SmoothQuant across bias, alpha, dtype, etc.""" + @common_utils.parametrize("device", ["cpu", "cuda"]) + @common_utils.parametrize("input_dtype", [torch.bfloat16]) + def test_smoothquant_accuracy(self, alpha, base_config, device, input_dtype): + """Test if SmoothQuant achieves lower loss than basic quantization.""" + in_features = 64 + out_features = 128 + + # Note: This is sanity check. For real run, consider Transformer model to reproduce. 
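+        # The check below quantizes the same single linear layer twice - once with
+        # the plain base_config and once through the SmoothQuant PREPARE -> calibrate
+        # -> CONVERT flow - and asserts that the SmoothQuant variant reconstructs the
+        # float reference output with lower MSE.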
+ X = torch.randn(16, in_features, dtype=input_dtype, device=device) + W = torch.randn(out_features, in_features, dtype=input_dtype, device=device) + + # Create linear layer + linear = ( + torch.nn.Linear(in_features, out_features, bias=False) + .to(device) + .to(input_dtype) + ) + with torch.no_grad(): + linear.weight.copy_(W) + + # Reference output + out_ref = linear(X) + + # Step 1. Basic quantization + basic_model = deepcopy(linear) + quantize_(basic_model, base_config) + out_basic = basic_model(X) + loss_base = torch.nn.functional.mse_loss(out_basic, out_ref).item() + + # SmoothQuant quantization + model = deepcopy(linear) + config = SmoothQuantConfig( + base_config=base_config, + step=SmoothQuantStep.PREPARE, + alpha=alpha, + ) + quantize_(model, config) - class SimpleLinear(torch.nn.Module): - def __init__(self, bias: bool): - super().__init__() - self.fc = torch.nn.Linear(32, 32, bias) - self.fc.weight.data = torch.randn_like(self.fc.weight.data) + # Perform calibration with test data + model(X) - def forward(self, x): - return self.fc(x) + # Step 2. SmoothQuant + config.step = SmoothQuantStep.CONVERT + quantize_(model, config) - # Create model, reference, and test data - m = SimpleLinear(bias).eval().to(input_dtype).to(device) - m_ref = deepcopy(m) - test_data = torch.randn(2, 32, dtype=input_dtype, device=device) + out_smoothquant = model(X) + loss_smoothquant = torch.nn.functional.mse_loss(out_smoothquant, out_ref).item() - # Step 1: Setup quantized model with observer insertion and calibration - insert_smooth_quant_observer_(m, alpha, quant_mode) + assert loss_smoothquant < loss_base, ( + f"SmoothQuant loss ({loss_smoothquant:.6f}) should not be higher than basic loss ({loss_base:.6f})" + ) - # Perform calibration with test data + @common_utils.parametrize( + "base_config", + [ + Int8DynamicActivationInt8WeightConfig(), + # TODO: Check more quantization APIs + ], + ) + def test_observer_insertion(self, base_config): + """Test that PREPARE step correctly inserts SmoothQuantObservedLinear.""" + + m = ToyLinearModel().eval() + + # Before quantization - should be regular Linear + self.assertIsInstance(m.linear1, torch.nn.Linear) + self.assertNotIsInstance(m.linear1, SmoothQuantObservedLinear) + + # PREPARE step - should insert observers + config = SmoothQuantConfig( + base_config=base_config, + step=SmoothQuantStep.PREPARE, + ) + quantize_(m, config) + + # After PREPARE - should be SmoothQuantObservedLinear + self.assertIsInstance(m.linear1, SmoothQuantObservedLinear) + self.assertTrue(hasattr(m.linear1, "obs")) + + # Test calibration + test_data = torch.randn(2, 512) m(test_data) - # Apply quantization configuration - is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) - quantize_(m, SmoothQuantConfig(), is_observed_linear) + # CONVERT step - should produce regular Linear with quantized weights + config.step = SmoothQuantStep.CONVERT + quantize_(m, config) - # Apply compilation if supported - m = torch.compile(m, fullgraph=True) + # After CONVERT - should be regular Linear again (but quantized) + self.assertIsInstance(m.linear1, torch.nn.Linear) + self.assertNotIsInstance(m.linear1, SmoothQuantObservedLinear) - # Step 2: Inference quantized model + @common_utils.parametrize( + "base_config", + [ + Int8DynamicActivationInt8WeightConfig(), + # TODO: Check more quantization APIs + ], + ) + def test_prepare_for_loading(self, base_config): + """Test PREPARE_FOR_LOADING step for loading pre-quantized checkpoints.""" + + m = ToyLinearModel().eval() + + # Before 
quantization - should be regular Linear + self.assertIsInstance(m.linear1, torch.nn.Linear) + self.assertNotIsInstance(m.linear1, SmoothQuantObservedLinear) + + # PREPARE_FOR_LOADING step - should create quantized model ready for loading + config = SmoothQuantConfig( + base_config=base_config, + step=SmoothQuantStep.PREPARE_FOR_LOADING, + alpha=0.5, + ) + quantize_(m, config) + + # After PREPARE_FOR_LOADING - should be regular Linear with quantized weights + self.assertIsInstance(m.linear1, torch.nn.Linear) + self.assertNotIsInstance(m.linear1, SmoothQuantObservedLinear) + + # Test that model can run inference + test_data = torch.randn(2, 512) with torch.inference_mode(): - q_out = m(test_data) - - # Step 3: Compute reference - weight = m_ref.fc.weight.data.float() - b = m_ref.fc.bias if bias else None - x_abs_max_per_ic = torch.abs(test_data).max(dim=0).values - w_abs_max_per_ic = torch.abs(weight).max(dim=0).values - - if alpha is not None: - # Apply SmoothQuant - smoothing_factor = torch.pow(x_abs_max_per_ic, alpha) / torch.pow( - w_abs_max_per_ic, 1 - alpha - ) - else: - smoothing_factor = torch.ones_like(x_abs_max_per_ic) - - # Apply smoothing to activations and weights - smoothed_activation = test_data / smoothing_factor - smoothed_weight = weight * smoothing_factor - - # Quantize weights using per-channel quantization - qw, w_scales, w_zps = dynamically_quantize_per_channel( - smoothed_weight, -127, 127, torch.int8 + output = m(test_data) + + # Validate output + self.assertIsNotNone( + output, "PREPARE_FOR_LOADING model output should not be None" ) - fq_wei = dequantize_per_channel(qw, w_scales, w_zps, input_dtype) - - # Handle activation quantization based on mode - if quant_mode == "static": - # activation is quantized per-tensor - act_min, act_max = torch.aminmax(smoothed_activation.float()) - max_val_pos = torch.max(-act_min, act_max) - activation_scale = max_val_pos / 127.0 - - fq_act = ( - torch.quantize_per_tensor( - smoothed_activation.float(), - scale=activation_scale.item(), - zero_point=0, - dtype=torch.qint8, - ) - .dequantize() - .to(input_dtype) - ) - else: - # activation is quantized per-row (batch * sequence_length) - qx, x_scales, x_zps = dynamically_quantize_per_channel( - smoothed_activation.float(), -127, 127, torch.int8 - ) - fq_act = dequantize_per_channel( - qx, - x_scales, - x_zps, - input_dtype, - ) - - # Compute final linear operation - reference_out = torch.nn.functional.linear(fq_act, fq_wei, b) - - # Step 4: Validate numerical accuracy - tolerance = ( - 0.1 - if input_dtype == torch.float - else (0.2 if input_dtype == torch.half else 0.3) + self.assertFalse( + torch.isnan(output).any(), "Model should not produce NaN values" ) - torch.testing.assert_close( - q_out, - reference_out.to(input_dtype), - atol=tolerance, - msg=f"Quantized output differs from reference for " - f"bias={bias}, alpha={alpha}, quant_mode={quant_mode}, " - f"device={device}, dtype={input_dtype}", + self.assertEqual( + output.shape, (2, 64), "Output shape should match expected dimensions" ) - @unittest.skip("This test is broken on recent PyTorch, TODO(#1639): fix it") - @common_utils.parametrize("alpha", [None, 0.5, 0.75]) - @common_utils.parametrize("quant_mode", ["static", "dynamic"]) - @common_utils.parametrize( - "device", ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) - ) - @common_utils.parametrize("input_dtype", [torch.float, torch.bfloat16, torch.half]) - def test_save_load_recipe(self, alpha, quant_mode, device, input_dtype): - """Test save/load recipe 
functionality.""" - dataset_size = 20 - layer_dims = (512, 256, 128) # Input, hidden, output dimensions - n_calib_examples = 10 - sequence_length = 5 - - # Create two identical models for comparison - m = ToyLinearModel(*layer_dims).eval().to(input_dtype).to(device) - m_save_load = deepcopy(m) - - # Generate calibration dataset - dataset = m.example_inputs( - dataset_size, - sequence_length=sequence_length, - dtype=input_dtype, - device=device, - ) - calibration_data = dataset[:n_calib_examples] - - # Step 1: Setup first quantized model with observer insertion and calibration - insert_smooth_quant_observer_(m, alpha, quant_mode) - - # Perform calibration with calibration data - for data in calibration_data: - m(data) - - # Apply quantization configuration - is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) - quantize_(m, SmoothQuantConfig(), is_observed_linear) - - # Apply compilation if supported - m = torch.compile(m, fullgraph=True) - - # Step 2: Setup save/load model with recipe functionality - insert_smooth_quant_observer_(m_save_load, alpha, quant_mode) - for example in calibration_data: - m_save_load(example.to(device)) - - # Step 3: Test save/load recipe functionality - with tempfile.NamedTemporaryFile() as temp_file: - save_path = temp_file.name - save_smooth_quant_recipe(m_save_load, save_path) - load_smooth_quant_recipe(m_save_load, save_path) - - # Step 4: Complete quantization for save/load model - is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) - quantize_(m_save_load, SmoothQuantConfig(), is_observed_linear) - - m_save_load = torch.compile(m_save_load, fullgraph=True) - - # Step 5: Validate outputs on full dataset - with torch.inference_mode(): - original_outputs = [] - save_load_outputs = [] - - for data in dataset: - # Remove batch dimension for model input - input_tensor = data.squeeze(0) - - original_output = m(input_tensor) - save_load_output = m_save_load(input_tensor) - - original_outputs.append(original_output) - save_load_outputs.append(save_load_output) - - # Concatenate all outputs for comparison - original_result = torch.cat(original_outputs) - save_load_out = torch.cat(save_load_outputs) - - self.assertIsNotNone( - original_result, "Original model output should not be None" - ) - self.assertIsNotNone( - save_load_out, "Save/load model output should not be None" - ) - - torch.testing.assert_close( - original_result, - save_load_out, - msg=f"Save/load recipe should produce identical results for " - f"alpha={alpha}, quant_mode={quant_mode}, device={device}, dtype={input_dtype}", - ) - common_utils.instantiate_parametrized_tests(TestSmoothQuant) diff --git a/torchao/prototype/smoothquant/README.md b/torchao/prototype/smoothquant/README.md index c268a83504..00e819c438 100644 --- a/torchao/prototype/smoothquant/README.md +++ b/torchao/prototype/smoothquant/README.md @@ -1,98 +1,82 @@ -# SmothQuant quantization -This is a native PyTorch implementation of the algorithm described in [this paper](https://arxiv.org/abs/2211.10438). +# SmoothQuant quantization -In this implementation, weights are smoothed (equalized) and quantized to int8 during quantization. Activations are smoothed and quantized to int8 at runtime. Quantization is done either dynamically or statically. If activations are dynamically quantized, qparams (i.e., scales) are found at runtime while qparams are found during quantization for static quantization. For dynamic quantization, activations are quantized per token. 
And for static quantization, activations are quantized per tensor. Generally, dynamic quantization produces better accuracy while static quantization has better latency. In both cases, weights and activations are symmetrically quantized.
+This is a native PyTorch implementation of the algorithm described in [this paper](https://arxiv.org/abs/2211.10438), built on the TorchAO quantization APIs.
+
+The smoothing factor for input channel $j$ is
+
+$$
+s_{j} = \frac{\max(|X_{j}|)^{\alpha}}{\max(|W_{j}|)^{1-\alpha}}, \quad j = 1, 2, \dots, C_{i}
+$$
+
+In this implementation, weights are smoothed (equalized) and quantized to int8 during quantization. Activations are smoothed and quantized to int8 at runtime. Quantization is done either dynamically or statically. For dynamic quantization, activations are quantized per token; for static quantization, they are quantized per tensor.
 
 ## Quick start
+
 Run the example code with
+
 ```bash
-python example.py -m MODLE_ID --device= --quant-mode=
+python example.py --model <model_id> --device <device>
 # An example
-python example.py -m meta-llama/Llama-2-7b-hf --device=cuda --quant-mode=dynamic
-```
-To use the `torch.compile` for speedup, add `--compile`. You may want to export `TORCHINDUCTOR_FREEZING=1` for even better performance.
-```bash
-TORCHINDUCTOR_FREEZING=1 python example.py -m MODLE_ID --device= --quant-mode= --compile
+python example.py --model meta-llama/Llama-2-7b-chat-hf
 ```
-To save a quantized model for reuse, specify `--model-save-path`
-```bash
-python example.py -m MODLE_ID --device= --quant-mode= --model-save-path ./quantized_model.pt
-```
-And load it by `--model-load-path`
+
+To save a quantized model for reuse, specify `--model_save_path`
+
 ```bash
-python example.py -m MODLE_ID --device= --quant-mode= --model-load-path ./quantized_model.pt
+python example.py --model <model_id> --model_save_path ./model_smoothquant.pt
 ```
-
 ## Usage of API
-The following APIs are provided:
-- insert_smooth_quant_observer_
-- SmoothQuantConfig
-- save_smooth_quant_recipe (advanced)
-- load_smooth_quant_recipe (advanced)
-`insert_smooth_quant_observer_` inserts observers into the model to be quantized. For example:
-```python
-insert_smooth_quant_observer_(model, alpha=0.5, quant_mode="dynamic")
-```
-After insertion, run the model for calibration on a certain dataset or (advanced) load a recipe.
+`SmoothQuantConfig` configures applying SmoothQuant to each linear layer of the model. Use it with `torchao.quantization.quantize_`. For example:
-`SmoothQuantConfig` configures appliying SmoothQuant to each linear layer of the model. Use it by calling `torchao.quantization.quantize_`. For example:
 ```python
-from torchao.prototype.smoothquant import SmoothQuantObservedLinear
-is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
-torchao.quantization.quantize_(model, SmoothQuantConfig(), is_observed_linear)
-```
-`is_observed_linear` is a filter so that we only quantize observed linear layers.
-
-(Advanced) `save_smooth_quant_recipe` and `load_smooth_quant_recipe` saves or loads a recipe for a model.
+from torchao.prototype.smoothquant import SmoothQuantConfig
+from torchao.prototype.smoothquant.core import SmoothQuantStep
+from torchao.quantization import quantize_
+from torchao.quantization.quant_api import Int8DynamicActivationInt8WeightConfig
 
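+# Assumed setup for this snippet (not shown above): `model` is any torch.nn.Module
+# containing nn.Linear layers, e.g. a Hugging Face transformer already on the
+# calibration device, and `calibration_dataset` yields example input tensors for it.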
-A recipe contains smoothing factors and quantization parameters of weights and activation for all linear layers that are to be quantized. For advanced users, these parameters can be saved and modified somehow to produce better accuray, e.g., different alpha for different layers. Users can even leave some linear layers unquantized by deleting these layers in the recipe. Such modifications can be published as a recipe. By loading the recipe, it can be reused and calibration is no longer needed.
+# Step 1: Prepare - insert observers
+quant_config = SmoothQuantConfig(
+    base_config=Int8DynamicActivationInt8WeightConfig(),
+    step=SmoothQuantStep.PREPARE,
+    alpha=0.5,
+)
+quantize_(model, quant_config)
 
-To save a recipe, users should insert observers and run calibration first. For example,
-```python
-insert_smooth_quant_observer_(model, alpha=0.5, quant_mode="dynamic")
-for data in dataset_for_calibration:
+# Step 2: Calibration
+for data in calibration_dataset:
     model(data)
-save_smooth_quant_recipe(model, "./smooth_quant_recipe.json")
-```
-To load a recipe, users should insert observers first. For example,
-```python
-insert_smooth_quant_observer_(model)
-load_smooth_quant_recipe(model, "./smooth_quant_recipe.json")
+
+# Step 3: Convert
+quant_config.step = SmoothQuantStep.CONVERT
+quantize_(model, quant_config)
 ```
 
-## Benchmark
-Running the example with `torch.compile` on a NVIDIA A10G GPU.
-### meta-llama/Llama-2-7b-hf
-Perplexity
-| Quant Method | alpha=0.25 | alpha=0.5 | alpha=0.75 | alpha=None* |
-|-|-|-|-|-|
-| Dynamic | 8.1872 | 7.4257 | 7.2518 | 7.5509 |
-| Static | 43.8051 | 11.2984 | 7.5791 | 19.5050 |
+## Benchmarks
 
-Note*: Conventional quantization without SmoothQuant
+All experiments use the `meta-llama/Llama-2-7b-chat-hf` model with max sequence length (SeqLen) 512 and calibration limit 128 on a 1xH100 80GB HBM2 instance. For a comprehensive comparison, we benchmark three cases: 1. the original (unquantized) model, 2. W8A8, and 3. SmoothQuant (W8A8).
 
-### meta-llama/Meta-Llama-3-8B
-Perplexity
-| Quant Method | alpha=0.25 | alpha=0.5 | alpha=0.75 | alpha=None* |
-|-|-|-|-|-|
-| Dynamic | 21.2475 | 8.8288 | 9.6514 | 8.3574 |
-| Static | 301.7118 | 18.0617 | 10.8343 | 278.9819 |
+### Benchmark Results
 
-Note*: Conventional quantization without SmoothQuant
+The results show that SmoothQuant with W8A8 slightly increases perplexity while improving throughput by 33.82%. Since the tinygemm kernel only accepts bfloat16 inputs, tokens/sec decreases for float16 inputs.
 
-### Test method
-**Commands**
-```bash
-# dynamic quant
-TORCHINDUCTOR_FREEZING=1 python example.py -m --device=cuda --quant-mode=dynamic --compile
-# static quant
-TORCHINDUCTOR_FREEZING=1 python example.py -m --device=cuda --quant-mode=static --compile
-```
-Use `--alpha` to specify the alpha parameter. Add `--disable-smooth-quant` to run quantization without SmoothQuant.
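+The numbers below correspond to the three cases compared by `example.py`; a command along these lines should reproduce a similar run:
+
+```bash
+python example.py --model meta-llama/Llama-2-7b-chat-hf --device cuda --alpha 0.5 --max_seq_length 512 --calibration_limit 128
+```
+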
+| Precision dtype | Quantization | Perplexity | Tokens/sec | PPL Change | Speed Change |
+|-----------|--------------|------------|------------|------------|--------------|
+| bfloat16 | - | 6.93 | 667 | - | - |
+| bfloat16* | - | 6.93 | 27 🐌 | - | - |
+| bfloat16 | W8A8-dynamic | 7.35 | 1,967 | +6.07% | +33.89% |
+| bfloat16 | W8A8-dynamic** | 7.03 | **1,972** | **+1.39%** | **+33.82%** |
+| float16 | - | 6.93 | 625 | - | - |
+| float16 | W8A8-dynamic | 7.29 | 523 | +5.21% | -19.42% |
+| float16 | W8A8-dynamic** | 6.94 | 516 | **+0.21%** | -21.23% |
+| bfloat16* | W8A8-dynamic** | 6.92 | 3 🐌 | -0.18% | -768.29% |
+
+> *Used with `torch.compile`, **Used with SmoothQuant
+
+### Key Findings
+
+- **Speed Improvement**: The bfloat16 configurations show a ~34% throughput improvement with both W8A8 and SmoothQuant-W8A8; the float16 configurations do not benefit
+- **Quality Trade-off**: With SmoothQuant, the perplexity increase is slight (+0.2% to +1.4%), versus ~5-6% for plain W8A8
+- **Compilation Impact**: Using the `--compile` flag significantly degrades performance in this benchmark (768% slower)
+- **Best Configuration**: `bfloat16` without `--compile` provides the best balance
 
-**Environment**
-- AWS g5.12xlarge instance
-- torch==2.6.0.dev20241017+cu124
-- python==3.12.6
+> Note: Unlike AWQ, this benchmark isn't computed with the scripts in `vllm/benchmarks` or `lm_eval`. A vLLM-based benchmark will be introduced in the foreseeable future. See https://github.com/pytorch/ao/issues/2815 for more information.
diff --git a/torchao/prototype/smoothquant/__init__.py b/torchao/prototype/smoothquant/__init__.py
index 948a99c080..2ea8b5713a 100644
--- a/torchao/prototype/smoothquant/__init__.py
+++ b/torchao/prototype/smoothquant/__init__.py
@@ -1,15 +1,13 @@
-from .api import (
-    SmoothQuantConfig,
-    insert_smooth_quant_observer_,
-    load_smooth_quant_recipe,
-    save_smooth_quant_recipe,
+from .api import SmoothQuantConfig
+from .core import (
+    SmoothQuantObservedLinear,
+    SmoothQuantObserver,
+    SmoothQuantStep,
 )
-from .core import SmoothQuantObservedLinear
 
 __all__ = [
-    "insert_smooth_quant_observer_",
-    "load_smooth_quant_recipe",
-    "save_smooth_quant_recipe",
     "SmoothQuantConfig",
+    "SmoothQuantStep",
+    "SmoothQuantObserver",
     "SmoothQuantObservedLinear",
 ]
diff --git a/torchao/prototype/smoothquant/api.py b/torchao/prototype/smoothquant/api.py
index 9397b340b3..9f78c49fb8 100644
--- a/torchao/prototype/smoothquant/api.py
+++ b/torchao/prototype/smoothquant/api.py
@@ -5,227 +5,122 @@
 # LICENSE file in the root directory of this source tree.
import types from dataclasses import dataclass -from typing import Dict, Optional +from typing import Optional import torch -import torchao from torchao.core.config import AOBaseConfig -from torchao.dtypes import to_affine_quantized_intx, to_affine_quantized_intx_static -from torchao.prototype.smoothquant.core import ( - SmoothQuantObservedLinear, - SmoothQuantObserver, -) -from torchao.quantization import quantize_ -from torchao.quantization.linear_activation_quantized_tensor import ( - to_linear_activation_quantized, -) from torchao.quantization.linear_activation_scale import ( to_weight_tensor_with_linear_activation_scale_metadata, ) from torchao.quantization.quant_api import ( + _QUANTIZE_CONFIG_HANDLER, _linear_extra_repr, - _replace_with_custom_fn_if_matches_filter, ) -from torchao.quantization.quant_primitives import MappingType from torchao.quantization.transform_module import ( register_quantize_module_handler, ) -from torchao.quantization.utils import _get_per_token_block_size -from torchao.quantization.weight_tensor_linear_activation_quantization import ( - to_weight_tensor_with_linear_activation_quantization_metadata, -) - - -def insert_smooth_quant_observer_( - model: torch.nn.Module, alpha: Optional[float] = 0.5, quant_mode: str = "dynamic" -): - """ - Inserts SmoothQuantObserver into Linear layers of a given model. - - Args: - model: The model to be modified (in place). Ensure model is on the desired device for calibration - alpha: The alpha value to determine smoothing factor. Factor = 1 if alpha is None, which means - falling back to conventional quantization. - quant_mode: dynamic or static quantization of activation - """ - _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear) - - quant_min, quant_max = -127, 127 - eps = torch.finfo(torch.float32).eps - - def replace_with_observer(layer): - # creates observer and replaces linear layers with observed linear layers - observer = SmoothQuantObserver( - layer.weight, - alpha, - quant_mode, - quant_min=quant_min, - quant_max=quant_max, - eps=eps, - ) - return SmoothQuantObservedLinear.from_float(layer, observer) - - _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear) - - -def save_smooth_quant_recipe( - model: torch.nn.Module, save_path: str -) -> Dict[str, torch.Tensor]: - """ - Save smoothing_factors, act_scales, and wei_scales for each SmoothQuantObservedLinear layer in the model. 
- """ - result = {} - - def recurse(module: torch.nn.Module, name: str = ""): - for child_name, child in module.named_children(): - full_name = f"{name}.{child_name}" if name else child_name - - # Apply the analysis function to this layer - if isinstance(child, SmoothQuantObservedLinear): - smoothing_factor, act_scales, wei_scales = child.obs.calculate_qparams() - result[full_name + ".smoothing_factor"] = smoothing_factor - result[full_name + ".act_scales"] = act_scales - result[full_name + ".wei_scales"] = wei_scales - - # Recurse into child modules - recurse(child, full_name) - - recurse(model) - - torch.save(result, save_path) - - -def load_smooth_quant_recipe( - model: torch.nn.Module, recipe_path: str, device=None -) -> torch.nn.Module: - recipe = torch.load(recipe_path, weights_only=True) - - def recurse(module: torch.nn.Module, name: str = ""): - if isinstance(module, SmoothQuantObservedLinear): - smoothing_factor = recipe.get(name + ".smoothing_factor", None) - act_scales = recipe.get(name + ".act_scales", None) - wei_scales = recipe.get(name + ".wei_scales", None) - if device is not None: - module.to(device=device) - # act_scales is None for dynamic quantization - if any(x is None for x in (smoothing_factor, wei_scales)): - return module - is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) - wrapper = torch.nn.Sequential(module) - quantize_( - wrapper, - SmoothQuantConfig(smoothing_factor, act_scales, wei_scales), - is_observed_linear, - ) - return wrapper[0] - - mod_new = module - - for child_name, child in module.named_children(): - full_name = f"{name}.{child_name}" if name else child_name - setattr(mod_new, child_name, recurse(child, full_name)) - return mod_new - - recurse(model) - - -class _ActQuantizer: - def __init__(self, target_dtype, quant_min=-127): - self.target_dtype = target_dtype - self.quant_min = quant_min - - def dynamic_quantize(self, input): - return to_affine_quantized_intx( - input, - MappingType.SYMMETRIC, - _get_per_token_block_size(input), - self.target_dtype, - self.quant_min, - ) +from torchao.utils import DummyModule - def static_quantize(self, input, scale, zero_point): - return to_affine_quantized_intx_static( - input, - scale, - zero_point, - list(input.shape), - self.target_dtype, - self.quant_min, - ) +from .core import ( + SmoothQuantObservedLinear, + SmoothQuantObserver, + SmoothQuantStep, +) @dataclass class SmoothQuantConfig(AOBaseConfig): """ - Configuration for quantizing linear layers when passed into quantize_() + Configuration for SmoothQuant quantization when passed into quantize_() Args: - smoothing_factor: The smoothing factor for the layer. Acquired from the layer's observer if None. - act_scales: The activation scales for the layer. Acquired from the layer's observer if None. - wei_scales: The weight scales for the layer. Acquired from the layer's observer if None. - set_inductor_config: if True, adjusts `torchinductor` settings to recommended values. + base_config: Base quantization configuration that SmoothQuant is applied on top of + step (SmoothQuantStep): The step for SmoothQuant process + PREPARE: insert SmoothQuant Observers to linear layers + CONVERT: convert the observed linear modules to quantized modules + PREPARE_FOR_LOADING: convert the floating point model to a dummy smoothquant quantized model, so we can + load the quantized weights through copy_ later + alpha: The alpha value to determine smoothing factor. 
Factor = 1 if alpha is None, which means + Fall back to conventional quantization if None """ - smoothing_factor: Optional[torch.Tensor] = None - act_scales: Optional[torch.Tensor] = None - wei_scales: Optional[torch.Tensor] = None - set_inductor_config: bool = True + base_config: AOBaseConfig + step: SmoothQuantStep + alpha: Optional[float] = 0.5 + + def __post_init__(self): + self.step = self.step.lower() if isinstance(self.step, str) else self.step.value + all_step_values = [s.value for s in SmoothQuantStep] + if self.step not in all_step_values: + raise ValueError(f"{self.step} is not one of {all_step_values}") @register_quantize_module_handler(SmoothQuantConfig) def _smooth_quant_transform( module: torch.nn.Module, config: SmoothQuantConfig, -): - smoothing_factor = config.smoothing_factor - act_scales = config.act_scales - wei_scales = config.wei_scales - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() - observed_linear = module - - linear = torch.nn.Linear( - observed_linear.in_features, - observed_linear.out_features, - observed_linear.bias is not None, - device=observed_linear.weight.device, - dtype=observed_linear.weight.dtype, - ) - linear.bias = observed_linear.bias +) -> torch.nn.Module: + step = config.step + base_config = config.base_config - target_dtype = torch.int8 - # act_scales is None for dynamic quantization thus not checked - if any(x is None for x in (smoothing_factor, wei_scales)): - factor, x_scale, w_scales = observed_linear.obs.calculate_qparams() - weight = observed_linear.obs.weight * factor - else: - factor, x_scale, w_scales = smoothing_factor, act_scales, wei_scales - weight = observed_linear.weight * factor - weight = weight.to(observed_linear.weight.dtype) - block_size = (1, weight.size(1)) - wei_zero_points = torch.zeros_like(w_scales, dtype=torch.int64) - qw = to_affine_quantized_intx_static( - weight, - w_scales, - wei_zero_points, - block_size, - target_dtype, - ) + if step == SmoothQuantStep.PREPARE: + observer = SmoothQuantObserver( + weight=module.weight, + alpha=config.alpha, + ) + return SmoothQuantObservedLinear.from_float(module, observer) - if x_scale is None: - # dynamic quant - qw = to_linear_activation_quantized( - qw, _ActQuantizer(target_dtype).dynamic_quantize + if step == SmoothQuantStep.PREPARE_FOR_LOADING: + # loading from pre-quantized checkpoint + observer = SmoothQuantObserver( + weight=module.weight, + alpha=config.alpha, ) + observed_linear = SmoothQuantObservedLinear.from_float(module, observer) + example_input = torch.randn( + (1, module.weight.shape[1]), + device=module.weight.device, + dtype=module.weight.dtype, + ) + observed_linear(example_input) + + elif step == SmoothQuantStep.CONVERT: + if not isinstance(module, SmoothQuantObservedLinear): + print( + f"convert: module is not SmoothQuantObservedLinear, skipping: {type(module)}" + ) + return module + observed_linear = module else: - # static quant - x_zero_point = torch.zeros_like(x_scale, dtype=torch.int64) - qw = to_weight_tensor_with_linear_activation_quantization_metadata( - qw, _ActQuantizer(target_dtype).static_quantize, x_scale, x_zero_point + raise ValueError(f"Unexpected step: {step}") + + # Compute smoothed weight parameters + smoothing_factor = observed_linear.obs.calculate_qparams() + weight = observed_linear.weight * smoothing_factor + + # Create new linear layer + with torch.device("meta"): + linear = torch.nn.Linear( + observed_linear.in_features, + observed_linear.out_features, + observed_linear.bias is not 
None, + device=observed_linear.weight.device, + dtype=observed_linear.weight.dtype, ) + linear.bias = observed_linear.bias - qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, factor.to(qw.dtype)) + # Quantize weights + base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(base_config)] + dummy_mod = DummyModule(weight) + quant_mod = base_config_handler(dummy_mod, base_config) + qw = quant_mod.weight + + # Add smoothing factor metadata + qw = to_weight_tensor_with_linear_activation_scale_metadata( + qw, smoothing_factor.to(qw.dtype) + ) linear.weight = torch.nn.Parameter(qw, requires_grad=False) - linear.extra_repr = types.MethodType(_linear_extra_repr, module) + linear.extra_repr = types.MethodType(_linear_extra_repr, linear) + return linear diff --git a/torchao/prototype/smoothquant/core.py b/torchao/prototype/smoothquant/core.py index 3e6c6ea5d5..83f1e78275 100644 --- a/torchao/prototype/smoothquant/core.py +++ b/torchao/prototype/smoothquant/core.py @@ -3,15 +3,17 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +from enum import Enum from typing import Optional import torch import torch.nn.functional as F -from torchao.quantization.observer import AffineQuantizedMinMaxObserver, PerAxis -from torchao.quantization.quant_primitives import ( - MappingType, -) + +class SmoothQuantStep(str, Enum): + PREPARE = "prepare" + CONVERT = "convert" + PREPARE_FOR_LOADING = "prepare_for_loading" class SmoothQuantObserver(torch.nn.Module): @@ -19,113 +21,48 @@ def __init__( self, weight: torch.Tensor, alpha: Optional[float] = 0.5, - quant_mode: str = "static", # or dynamic - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, - eps: Optional[float] = None, ): """ - A custom observer for SmoothQuant + A custom observer for smoothing factor, main concept of SmoothQuant. Args: weight: The weight tensor to be observed. alpha: The alpha value to determine smoothing factor, normally between 0 and 1. - Fall back to conventional quantization if alpha is None. - quant_mode: The mode of activation quantization, either static or dynamic - quant_min: The minimum quantized value - quant_max: The maximum quantized value - eps: The minimum scale to avoid dividing by zero. 
""" super().__init__() assert weight.ndim == 2 self.weight = weight - self.inputs = [] - self.device = self.weight.device self.alpha = alpha - assert quant_mode in ["static", "dynamic"] - self.quant_mode = quant_mode - self.quant_min = quant_min - self.quant_max = quant_max - self.eps = eps - # act.shape = [mb, ic] (reshape if needed), wei.shape = [oc, ic] - # *_ic_obs are used to determine smoothing_factor - # wei_oc_obs is used to find qparams for quantization - self.act_ic_obs = AffineQuantizedMinMaxObserver( - MappingType.SYMMETRIC, - torch.int8, - PerAxis(-1), - eps=eps, - ) - self.wei_ic_obs = AffineQuantizedMinMaxObserver( - MappingType.SYMMETRIC, - torch.int8, - PerAxis(-1), - eps=eps, - ) - self.wei_oc_obs = AffineQuantizedMinMaxObserver( - MappingType.SYMMETRIC, - torch.int8, - PerAxis(0), - quant_min=quant_min, - quant_max=quant_max, - eps=eps, - ) - self.wei_ic_obs(self.weight) + self.inputs = [] + self.device = weight.device @torch.no_grad() def forward(self, input: torch.Tensor): - self.act_ic_obs(input.to("cpu")) + self.inputs.append(input.to("cpu")) return input def calculate_qparams(self): - # 1 Get min/max per IC from observers - wei_min_per_ic = self.wei_ic_obs.min_val - wei_max_per_ic = self.wei_ic_obs.max_val - act_min_per_ic = self.act_ic_obs.min_val - act_max_per_ic = self.act_ic_obs.max_val - x_abs_max_per_ic = ( - torch.max(torch.abs(act_min_per_ic), torch.abs(act_max_per_ic)) + self.eps - ) - w_abs_max_per_ic = ( - torch.max(torch.abs(wei_min_per_ic), torch.abs(wei_max_per_ic)) + self.eps + assert self.inputs and len(self.inputs) > 0, ( + "calibrate observer first by running model on exemplar data" ) - # 2 calculate the smoothing factor + inputs = [inp.to(self.device) for inp in self.inputs] + acc = torch.cat(inputs, dim=0) + # Reshape if needed: [batch, seq, features] -> [batch*seq, features] + if acc.ndim > 2: + acc = acc.view(-1, acc.shape[-1]) + + # Calculate per-channel max values + x_abs_max = torch.max(torch.abs(acc), dim=0)[0] + w_abs_max = torch.max(torch.abs(self.weight), dim=0)[0] + + # Calculate smoothing factor if self.alpha is None: - # fall back to conventional quantization if alpha is None - smoothing_factor = torch.ones_like( - x_abs_max_per_ic, - dtype=x_abs_max_per_ic.dtype, - device=x_abs_max_per_ic.device, - ) - else: - smoothing_factor = torch.pow(x_abs_max_per_ic, self.alpha) / torch.pow( - w_abs_max_per_ic.to(x_abs_max_per_ic.device), 1 - self.alpha - ) - # 3 apply smoothing factor to activations and find scales for static quantization - act_scales = None - if self.quant_mode == "static": - act_min_per_ic_new = act_min_per_ic / smoothing_factor.reshape( - act_min_per_ic.shape - ) - act_max_per_ic_new = act_max_per_ic / smoothing_factor.reshape( - act_max_per_ic.shape - ) - min_val_per_tensor = torch.min(act_min_per_ic_new) - max_val_per_tensor = torch.max(act_max_per_ic_new) - min_val_neg = torch.min( - min_val_per_tensor, torch.zeros_like(min_val_per_tensor) - ) - max_val_pos = torch.max( - max_val_per_tensor, torch.zeros_like(max_val_per_tensor) - ) - max_val_pos = torch.max(-min_val_neg, max_val_pos) - act_scale = max_val_pos / (float(self.quant_max - self.quant_min) / 2) - act_scales = act_scale.to(self.device) - # 4 update weight and find scales - self.wei_oc_obs(self.weight * smoothing_factor.to(self.device)) - wei_scales, _ = self.wei_oc_obs.calculate_qparams() - # 5 return results - return smoothing_factor.to(self.device), act_scales, wei_scales.to(self.device) + return torch.ones_like(x_abs_max) + + eps = 
torch.finfo(torch.float32).eps + return torch.pow(x_abs_max + eps, self.alpha) / torch.pow( + w_abs_max + eps, 1 - self.alpha + ) class SmoothQuantObservedLinear(torch.nn.Linear): @@ -133,30 +70,31 @@ def __init__( self, in_features: int, out_features: int, - bias: bool, obs: SmoothQuantObserver, + is_bias: bool = False, device=None, dtype=None, ): - super().__init__(in_features, out_features, bias, device, dtype) - assert isinstance(obs, SmoothQuantObserver) + super().__init__( + in_features, out_features, bias=is_bias, device=device, dtype=dtype + ) self.obs = obs def forward(self, input: torch.Tensor): input = self.obs(input) - output = F.linear(input, self.weight, self.bias) - return output + return F.linear(input, self.weight) @classmethod def from_float(cls, float_linear: torch.nn.Linear, obs: SmoothQuantObserver): - observed_linear = cls( - float_linear.in_features, - float_linear.out_features, - float_linear.bias is not None, - obs, - device=float_linear.weight.device, - dtype=float_linear.weight.dtype, - ) + with torch.device("meta"): + observed_linear = cls( + float_linear.in_features, + float_linear.out_features, + obs, + is_bias=float_linear.bias is not None, + device=float_linear.weight.device, + dtype=float_linear.weight.dtype, + ) observed_linear.weight = float_linear.weight observed_linear.bias = float_linear.bias return observed_linear diff --git a/torchao/prototype/smoothquant/example.py b/torchao/prototype/smoothquant/example.py index de1e4ed93e..dbf764e526 100644 --- a/torchao/prototype/smoothquant/example.py +++ b/torchao/prototype/smoothquant/example.py @@ -4,185 +4,263 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. import argparse -import os import time -from typing import Optional import torch from datasets import load_dataset -from tqdm import tqdm -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig +from torchao.prototype.awq.example import get_calib_dataset from torchao.prototype.smoothquant import ( SmoothQuantConfig, - SmoothQuantObservedLinear, - insert_smooth_quant_observer_, ) +from torchao.prototype.smoothquant.core import SmoothQuantStep from torchao.quantization import quantize_ +from torchao.quantization.quant_api import Int8DynamicActivationInt8WeightConfig -def get_calib_dataset(tokenizer=None, n_samples=100, block_size=512): - dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation") - samples = [] - n_tokens = n_samples * block_size - n_run = n_tokens - for data in dataset: - line = data["text"] - line = line.strip() - line_encoded = tokenizer.encode(line) - if len(line_encoded) > 512: - continue - sample = torch.tensor([line_encoded]) - if sample.numel() == 0: - continue - samples.append(sample) - n_run -= len(line_encoded) - if n_run <= n_samples: - break - - cat_samples = torch.cat(samples, dim=1) - return [ - cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_samples) - ] - - -def wiki2_eval( - model, tokenizer, sequence_length, stride=512, verbose=True, device="cuda" -): - model.eval() - tokenizer.pad_token = tokenizer.eos_token - tokenizer.padding_side = "right" - tokenizer.add_eos_token = False - - print("Loading dataset") - t0 = time.time() +# TODO: Build benchmark within vLLM ecosystem with more quantization APIs +# See https://github.com/pytorch/ao/issues/2815 for more details +def benchmark(model, tokenizer, max_seq_length=512, 
tasks=["PPL"], device="cuda"): + """Benchmark model with perplexity calculation on WikiText-2""" + # Load WikiText-2 test set dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") - encodings = tokenizer("\n\n".join(dataset["text"]), return_tensors="pt") - print(f"Time to load dataset: {time.time() - t0:.02f} seconds") - - encodings["input_ids"] = encodings["input_ids"].to(device) - - print("Running evaluation") - lls, t = [], [] - for i in tqdm( - range(0, encodings["input_ids"].size(1), stride), disable=not verbose - ): - begin_loc = max(i + stride - sequence_length, 0) - end_loc = min(i + stride, encodings["input_ids"].size(1)) - trg_len = end_loc - i - input_ids = encodings["input_ids"][:, begin_loc:end_loc] - target_ids = input_ids.clone() - target_ids[:, :-trg_len] = -100 # ignore context - - t1 = time.time() - with torch.no_grad(): - log_likelihood = model(input_ids, labels=target_ids).loss * trg_len - if device == "cuda": - torch.cuda.synchronize() - t2 = time.time() - t.append((t2 - t1)) - lls.append(log_likelihood) - - del input_ids, target_ids - - ppl = float(torch.exp(torch.stack(lls).sum() / end_loc)) - pred_time = sum(t) / len(t) - if verbose: - print("perplexity", ppl) - print("time", str(pred_time) + " sec/it") - - return {"perplexity": ppl, "prediction_time": pred_time} - - -def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"): + + # Prepare text data and truncate if necessary + text = "\n\n".join(dataset["text"]) + # Get model's maximum sequence length + model_max_length = getattr(tokenizer, "model_max_length", max_seq_length) + if model_max_length > 1000000: # Default large value, use our max_seq_length + model_max_length = max_seq_length + + encodings = tokenizer( + text, return_tensors="pt", truncation=True, max_length=model_max_length + ) + + # Calculate perplexity model.eval() - model.config.use_cache = False - if tasks is None: - tasks = ["PPL"] - results = {} - if "PPL" in tasks: - results["perplexity"] = wiki2_eval( - model, tokenizer, 512, verbose=True, device=device - ) - return results - - -def wikitext2_ppl( + nlls = [] + + with torch.no_grad(): + seq_len = encodings.input_ids.size(1) + prev_end_loc = 0 + + for begin_loc in range(0, seq_len, max_seq_length): + end_loc = min(begin_loc + max_seq_length, seq_len) + trg_len = end_loc - prev_end_loc + + input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device) + target_ids = input_ids.clone() + target_ids[:, :-trg_len] = -100 + + # Measure inference time + start_time = time.time() + outputs = model(input_ids, labels=target_ids) + inference_time = time.time() - start_time + + neg_log_likelihood = outputs.loss * trg_len + nlls.append(neg_log_likelihood) + + prev_end_loc = end_loc + if end_loc == seq_len: + break + + ppl = torch.exp(torch.stack(nlls).sum() / end_loc) + + return { + "perplexity": ppl.item(), + "tokens_per_sec": input_ids.size(1) / inference_time, + } + + +def quantize_and_eval( model_id: str, - alpha: Optional[float], - quant_mode: str, - calibration_size: int, + alpha: float, + tasks: list[str], + max_seq_length: int, + calibration_limit: int, device: str, - precision: torch.dtype, - sequence_length: int, - compile: bool, - model_load_path: str, model_save_path: str, + model_save_hf_hub_path: str, ): print(f"Loading model on {device}...") torch.manual_seed(34) t0 = time.time() tokenizer = AutoTokenizer.from_pretrained(model_id) - if model_load_path is not None and os.path.exists(model_load_path): - print(f"Loading quantized model from {model_load_path}") - t0 = 
time.time() - model = torch.load(model_load_path, weights_only=False).to(device) - print(f"Time to load quantized model: {time.time() - t0:.02f} seconds") - else: - model = ( - AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=precision) - .eval() - .to(device) - ) - print(f"Time to load model: {time.time() - t0:.02f} seconds") - print("running calibration") - t0 = time.time() - # insert observers to find average magnitude and calculate scales - insert_smooth_quant_observer_(model, alpha, quant_mode) - calibration_data = get_calib_dataset( - tokenizer=tokenizer, n_samples=calibration_size, block_size=sequence_length - ) - for batch in calibration_data: - model(batch.to(device)) - batch.to("cpu") - print(f"time for calibration: {time.time() - t0:.02f} seconds") - - is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) - print(f"running SmoothQuant with {quant_mode} quantization") - t0 = time.time() - quantize_(model, SmoothQuantConfig(), is_observed_linear) - print(f"time for quantization: {time.time() - t0:.02f} seconds") - if model_save_path is not None: - print(f"Saving quantized model to {model_save_path}") - t0 = time.time() - torch.save(model, model_save_path) - print(f"Time to save quantized model: {time.time() - t0:.02f} seconds") - if compile: - model = torch.compile(model, dynamic=True) - - return benchmark(model, tokenizer, sequence_length, tasks=["PPL"], device=device) + model = ( + AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16) + .eval() + .to(device) + ) + print(f"Time to load model: {time.time() - t0:.02f} seconds") + # Step 1: Prepare - insert observers + print("running SmoothQuant prepare and calibrate") + t0 = time.time() + quant_config = SmoothQuantConfig( + base_config=Int8DynamicActivationInt8WeightConfig(), + step=SmoothQuantStep.PREPARE, + alpha=alpha, + ) + quantize_(model, quant_config) -if __name__ == "__main__": + # Step 2: Calibration + calibration_data = get_calib_dataset( + tokenizer=tokenizer, n_samples=calibration_limit, block_size=max_seq_length + ) + for batch in calibration_data: + model(batch.to(device)) + batch.to("cpu") + + print(f"time for prepare and calibration: {time.time() - t0:.02f} seconds") + + # Step 3: Convert to quantized model + print("running SmoothQuant convert") + t0 = time.time() + quant_config.step = SmoothQuantStep.CONVERT + quantize_(model, quant_config) + print(f"time for convert: {time.time() - t0:.02f} seconds") + + # Set up config for loading + quant_config.step = SmoothQuantStep.PREPARE_FOR_LOADING + model.config.quantization_config = TorchAoConfig(quant_config) + + if model_save_path is not None: + print(f"Saving model to {model_save_path}") + torch.save(model, model_save_path) + + if model_save_hf_hub_path is not None: + print("pushing model to hub:", model_save_hf_hub_path) + model.push_to_hub(model_save_hf_hub_path, safe_serialization=False) + tokenizer.push_to_hub(model_save_hf_hub_path) + + print("Benchmarking SmoothQuant model...") + return benchmark(model, tokenizer, max_seq_length, tasks=tasks, device=device) + + +def compare_models( + model_id: str, + alpha: float, + tasks: list[str], + max_seq_length: int, + calibration_limit: int, + device: str, + model_save_path: str, + model_save_hf_hub_path: str, +): + """Compare perplexity and speed for behchmarking SmoothQuant""" + + # Case 1: Base model without quantization + print("Benchmarking base model...") + torch.manual_seed(34) + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = ( + 
AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16) + .eval() + .to(device) + ) + base_results = benchmark( + model, tokenizer, max_seq_length, tasks=tasks, device=device + ) + + # Case 2: W8A8-dynamic without SmoothQuant + print("Benchmarking W8A8-dynamic without SmoothQuant...") + torch.manual_seed(34) + w8a8_model = ( + AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16) + .eval() + .to(device) + ) + quantize_(w8a8_model, Int8DynamicActivationInt8WeightConfig()) + w8a8_results = benchmark( + w8a8_model, tokenizer, max_seq_length, tasks=tasks, device=device + ) + + # Case 3: SmoothQuant + W8A8-dynamic + print("Benchmarking SmoothQuant with W8A8-dynamic...") + smoothquant_results = quantize_and_eval( + model_id, + alpha, + tasks, + max_seq_length, + calibration_limit, + device, + model_save_path, + model_save_hf_hub_path, + ) + + # Calculate changes and display results + w8a8_ppl_change = ( + (w8a8_results["perplexity"] - base_results["perplexity"]) + / base_results["perplexity"] + * 100 + ) + w8a8_speed_change = ( + (w8a8_results["tokens_per_sec"] - base_results["tokens_per_sec"]) + / base_results["tokens_per_sec"] + * 100 + ) + + smoothquant_ppl_change = ( + (smoothquant_results["perplexity"] - base_results["perplexity"]) + / base_results["perplexity"] + * 100 + ) + smoothquant_speed_change = ( + (smoothquant_results["tokens_per_sec"] - base_results["tokens_per_sec"]) + / base_results["tokens_per_sec"] + * 100 + ) + + # Print results + print( + f"\nBase: PPL={base_results['perplexity']:.2f}, Speed={base_results['tokens_per_sec']:.2f} tokens/sec" + ) + print( + f"w8a8-Dynamic: PPL={w8a8_results['perplexity']:.2f}, Speed={w8a8_results['tokens_per_sec']:.2f} tokens/sec" + ) + print( + f"SmoothQuant+w8a8: PPL={smoothquant_results['perplexity']:.2f}, Speed={smoothquant_results['tokens_per_sec']:.2f} tokens/sec" + ) + print(f"w8a8 Changes: PPL {w8a8_ppl_change:+.2f}%, Speed {w8a8_speed_change:+.2f}%") + print( + f"SmoothQuant Changes: PPL {smoothquant_ppl_change:+.2f}%, Speed {smoothquant_speed_change:+.2f}%" + ) + + return { + "base_model": base_results, + "w8a8_model": w8a8_results, + "smoothquant_model": smoothquant_results, + "w8a8_ppl_change_percent": w8a8_ppl_change, + "w8a8_speed_improvement_percent": w8a8_speed_change, + "smoothquant_ppl_change_percent": smoothquant_ppl_change, + "smoothquant_speed_improvement_percent": smoothquant_speed_change, + } + + +def create_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( - description="Evaluate a model with the specified parameters." + description="Evaluate a model with SmoothQuant quantization." ) - # Optional arguments with default values parser.add_argument( - "--model-id", "-m", type=str, help="Repository ID of the model." + "--model", type=str, required=True, help="Model ID from Huggingface hub." ) parser.add_argument( "--alpha", type=float, default=0.5, - help="The alpha hyperparameter for SmoothQuant.", + help="The alpha hyperparameter for SmoothQuant. Default is 0.5.", ) parser.add_argument( - "--quant-mode", type=str, help="Quantization mode, either static or dynamic." + "--tasks", + nargs="+", + type=str, + help="Task to benchmark model on.", + default=["PPL"], ) parser.add_argument( - "--calibration-samples", + "--calibration_limit", type=int, default=10, help="Number of samples to use for calibration. Default is 10.", @@ -194,54 +272,38 @@ def wikitext2_ppl( help="Device to run the evaluation on. 
Default is 'cuda'.", ) parser.add_argument( - "--precision", - type=str, - default="bfloat16", - help="Precision type. Default is 'bfloat16'.", - ) - parser.add_argument( - "--seq_len", + "--max_seq_length", type=int, default=512, - help="Length of examples to calibrate and evaluate model on. Default is 512", + help="Maximum sequence length. Default is 512", ) parser.add_argument( - "--compile", - action="store_true", - help="Flag to indicate if compilation is required.", - ) - parser.add_argument( - "--model-load-path", + "--model_save_path", type=str, default=None, - help="Path to load quantized model. If this is provided, " - "the model will be loaded from this path instead of quantizing the model.", + help="Path to store the quantized model.", ) parser.add_argument( - "--model-save-path", + "--model_save_hf_hub_path", type=str, default=None, - help="Path to store quantized model.", - ) - parser.add_argument( - "--disable-smooth-quant", - action="store_true", - help="Run conventional dynamic or static quantization for testing or debugging.", + help="Huggingface hub path to store the quantized model and tokenizer.", ) + return parser + + +if __name__ == "__main__": + parser = create_parser() args = parser.parse_args() - # Convert precision argument to torch dtype - precision_dtype = getattr(torch, args.precision, torch.bfloat16) - ppl = wikitext2_ppl( - args.model_id, - None if args.disable_smooth_quant else args.alpha, - args.quant_mode, - args.calibration_samples, + result = compare_models( + args.model, + args.alpha, + args.tasks, + args.max_seq_length, + args.calibration_limit, args.device, - args.precision, - args.seq_len, - args.compile, - args.model_load_path, args.model_save_path, + args.model_save_hf_hub_path, )