diff --git a/README.md b/README.md
index e99f56590..bcb3966bf 100644
--- a/README.md
+++ b/README.md
@@ -19,8 +19,8 @@ All with no intrusive code changes and minimal accuracy degradation.
 Quantizing your models is a 1 liner that should work on any model with an `nn.Linear` including your favorite HuggingFace model. You can find a more comprehensive usage instructions [here](torchao/quantization/) and a HuggingFace inference example [here](scripts/hf_eval.py)
 
 ```python
-from torchao.quantization.quant_api import quantize, int4_weight_only
-m = quantize(m, int4_weight_only())
+from torchao.quantization.quant_api import quantize_, int4_weight_only
+quantize_(m, int4_weight_only())
 ```
 
 Benchmarks are run on a machine with a single A100 GPU using the script in `_models/llama` which generates text in a latency-optimized way (batchsize=1)
@@ -70,7 +70,7 @@ swap_linear_with_semi_sparse_linear(model, {"seq.0": SemiSparseLinear})
 * [MX](torchao/prototype/mx_formats) implementing training and inference support with tensors using the [OCP MX spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) data types, which can be described as groupwise scaled float8/float6/float4/int8, with the scales being constrained to powers of two. This work is prototype as the hardware support is not available yet.
 * [nf4](torchao/dtypes/nf4tensor.py) which was used to [implement QLoRA](https://github.com/pytorch/torchtune/blob/main/docs/source/tutorials/qlora_finetune.rst) one of the most popular finetuning algorithms without writing custom Triton or CUDA code. Accessible talk [here](https://x.com/HamelHusain/status/1800315287574847701)
-* [fp6](torchao/prototype/quant_llm/) for 2x faster inference over fp16 with an easy to use API `quantize(model, fp6_llm_weight_only())`
+* [fp6](torchao/prototype/quant_llm/) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fp6_llm_weight_only())`
 
 ## Composability
diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 4d5a2c511..c21f3a38b 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -23,7 +23,7 @@
     int4_weight_only,
     int8_weight_only,
     int8_dynamic_activation_int8_weight,
-    quantize,
+    quantize_,
     _replace_with_custom_fn_if_matches_filter,
 )
 # APIs to be deprecated (used for torch 2.2.2 and 2.3)
@@ -98,21 +98,21 @@
 
 def _int8wo_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int8_weight_only(), set_inductor_config=False)
+        quantize_(mod, int8_weight_only(), set_inductor_config=False)
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int8_woqtensors(mod)
 
 def _int8da_int8w_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
+        quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int8_dqtensors(mod)
 
 def _int4wo_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int4_weight_only(), set_inductor_config=False)
+        quantize_(mod, int4_weight_only(), set_inductor_config=False)
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int4_woqtensors(mod)
@@ -127,8 +127,8 @@ def _int4wo_api(mod):
 
 def undo_recommended_configs():
     torch._inductor.config.coordinate_descent_tuning = False
    torch._inductor.config.coordinate_descent_check_all_directions = False
-    torch._inductor.config.force_fuse_int_mm_with_mul = False
-    torch._inductor.config.fx_graph_cache = False
+    torch._inductor.config.force_fuse_int_mm_with_mul = False
+    torch._inductor.config.fx_graph_cache = False
     torch._inductor.config.triton.unique_kernel_names = False
     torch.set_float32_matmul_precision("highest")
 
@@ -844,7 +844,7 @@ def api(mod):
             kwargs_copy = kwargs.copy()
             kwargs_copy["group_size"] = groupsize
             del kwargs_copy["groupsize"]
-            quantize(mod, int4_weight_only(**kwargs_copy))
+            quantize_(mod, int4_weight_only(**kwargs_copy))
             unwrap_tensor_subclass(mod)
         else:
             change_linear_weights_to_int4_woqtensors(mod, **kwargs)
@@ -865,7 +865,7 @@ def test_dynamic_quant(self):
         m = nn.Sequential(nn.Linear(K, N))
 
         y_ref = m(x)
-        quantize(m, int8_dynamic_activation_int8_weight())
+        quantize_(m, int8_dynamic_activation_int8_weight())
         y_test = m(x)
 
         sqnr = compute_error(y_ref, y_test)
@@ -1259,7 +1259,7 @@ def test_autoquant_manual(self, device, dtype):
         out3 = mod(example_input)
         sqnr2 = SQNR(out, out3)
         self.assertTrue(sqnr2 >= 30)
-        
+
 
     @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
         [
diff --git a/test/prototype/test_quant_llm.py b/test/prototype/test_quant_llm.py
index 77eac6f69..fab2d972b 100644
--- a/test/prototype/test_quant_llm.py
+++ b/test/prototype/test_quant_llm.py
@@ -16,7 +16,7 @@
 )
 from torchao.prototype.quant_llm.quant_llm import _pack_tc_fpx, _pack_tc_fp6
 from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32
-from torchao.quantization.quant_api import quantize
+from torchao.quantization.quant_api import quantize_
 
 
 _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
@@ -91,7 +91,7 @@ def test_quant_llm_quantize(self, ebits, mbits, bias):
 
         linear = torch.nn.Linear(IC, OC, bias=bias, device=device)
         fpx_linear = copy.deepcopy(linear)
-        quantize(fpx_linear, quant_llm_fpx_weight_only(ebits, mbits))
+        quantize_(fpx_linear, quant_llm_fpx_weight_only(ebits, mbits))
 
         x = torch.randn(N, IC, device=device, dtype=torch.half)
         expected = fpx_linear(x)
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index e8b9d606d..b137cd22d 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -31,7 +31,7 @@
     Int8WeightOnlyQuantizedLinearWeight,
     Int4WeightOnlyQuantizedLinearWeight,
 )
-from torchao import quantize
+from torchao import quantize_
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
     Quantizer,
@@ -89,7 +89,7 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module:
 
 class TorchCompileDynamicQuantizer(Quantizer):
     def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
-        quantize(model, int8_dynamic_activation_int8_weight())
+        quantize_(model, int8_dynamic_activation_int8_weight())
         return model
 
 class ToyLinearModel(torch.nn.Module):
@@ -152,7 +152,7 @@ class TestQuantFlow(TestCase):
     def test_dynamic_quant_gpu_singleline(self):
         m = ToyLinearModel().eval()
         example_inputs = m.example_inputs()
-        m = quantize(m, int8_dynamic_activation_int8_weight())
+        quantize_(m, int8_dynamic_activation_int8_weight())
         quantized = m(*example_inputs)
         # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
         # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
@@ -195,7 +195,7 @@ def test_int8_wo_quant_save_load(self):
         )
         m = ToyLinearModel().eval().cpu()
         def api(model):
-            model = quantize(model, int8_weight_only())
+            quantize_(model, int8_weight_only())
             unwrap_tensor_subclass(model)
 
         api(m)
@@ -501,7 +501,7 @@ def test_quantized_tensor_subclass_8da4w(self):
         m = ToyLinearModel().eval()
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs()
-        m = quantize(m, int8_dynamic_activation_int4_weight(group_size=group_size))
+        quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size))
 
         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
@@ -530,7 +530,7 @@ def test_quantized_tensor_subclass_int4(self):
         example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")
 
         group_size = 32
-        m = quantize(m, int4_weight_only(group_size=group_size))
+        quantize_(m, int4_weight_only(group_size=group_size))
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
@@ -550,7 +550,7 @@ def test_quantized_tensor_subclass_int8_wo(self):
         m_copy = copy.deepcopy(m)
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
 
-        m = quantize(m, int8_weight_only())
+        quantize_(m, int8_weight_only())
 
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
@@ -573,7 +573,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
         m_copy = copy.deepcopy(m)
         # setting batch_size to 20 to be compatible with the kernel
         example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
-        m = quantize(m, int8_dynamic_activation_int8_weight())
+        quantize_(m, int8_dynamic_activation_int8_weight())
 
         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
@@ -607,7 +607,7 @@ def test_quantized_tensor_subclass_save_load(self):
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs(dtype=torch.bfloat16)
 
-        m = quantize(m, int8_weight_only())
+        quantize_(m, int8_weight_only())
         ref = m(*example_inputs)
         with tempfile.NamedTemporaryFile() as f:
             torch.save(m.state_dict(), f)
diff --git a/torchao/__init__.py b/torchao/__init__.py
index 3b5a1b3c0..104dc5f31 100644
--- a/torchao/__init__.py
+++ b/torchao/__init__.py
@@ -30,14 +30,14 @@
 
 from torchao.quantization import (
     autoquant,
-    quantize,
+    quantize_,
 )
 from . import dtypes
 
 __all__ = [
     "dtypes",
     "autoquant",
-    "quantize",
+    "quantize_",
 ]
 
 # test-pytorchbot
diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py
index 73deafffe..35e35ecf0 100644
--- a/torchao/_models/llama/eval.py
+++ b/torchao/_models/llama/eval.py
@@ -13,7 +13,7 @@
 )
 
 from torchao.quantization.quant_api import (
-    quantize, int4_weight_only, int8_weight_only, int8_dynamic_activation_int8_weight, unwrap_tensor_subclass
+    quantize_, int4_weight_only, int8_weight_only, int8_dynamic_activation_int8_weight, unwrap_tensor_subclass
 )
 from torchao._models._eval import TransformerEvalWrapper, InputRecorder
 
@@ -60,13 +60,13 @@ def run_evaluation(
 
     if quantization:
         if "int8wo" in quantization:
-            quantize(model, int8_weight_only())
+            quantize_(model, int8_weight_only())
         if "int8dq" in quantization:
-            quantize(model, int8_dynamic_activation_int8_weight())
+            quantize_(model, int8_dynamic_activation_int8_weight())
         if "int4wo" in quantization and not "gptq" in quantization:
             groupsize=int(quantization.split("-")[-1])
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
-            quantize(model.to(device), int4_weight_only(group_size=groupsize))
+            quantize_(model.to(device), int4_weight_only(group_size=groupsize))
         if "int4wo" in quantization and "gptq" in quantization:
             groupsize=int(quantization.split("-")[-2])
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
@@ -94,8 +94,8 @@ def run_evaluation(
         model = torch.compile(model, mode="max-autotune", fullgraph=True)
     with torch.no_grad():
         TransformerEvalWrapper(
-            model=model.to(device),
-            tokenizer=tokenizer,
+            model=model.to(device),
+            tokenizer=tokenizer,
             max_seq_length=max_length,
             input_prep_func=prepare_inputs_for_model,
             device=device,
@@ -122,16 +122,16 @@ def run_evaluation(
     args = parser.parse_args()
     run_evaluation(
-        args.checkpoint_path,
-        args.tasks,
-        args.limit,
-        args.device,
-        args.precision,
-        args.quantization,
-        args.compile,
-        args.max_length,
-        args.calibration_tasks,
-        args.calibration_limit,
-        args.calibration_seq_length,
-        args.pad_calibration_inputs,
+        args.checkpoint_path,
+        args.tasks,
+        args.limit,
+        args.device,
+        args.precision,
+        args.quantization,
+        args.compile,
+        args.max_length,
+        args.calibration_tasks,
+        args.calibration_limit,
+        args.calibration_seq_length,
+        args.pad_calibration_inputs,
     )
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index 8142f80bb..34ff9abb1 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -100,7 +100,7 @@ def generate(
     T_new = T + max_new_tokens
     seq = torch.empty(T_new, dtype=prompt.dtype, device=device)
     seq[:T] = prompt.view(-1)
-    
+
     # setup model cache
     max_seq_length = min(T_new, model.config.block_size) if not interactive else 350
     with torch.device(device):
@@ -158,7 +158,7 @@ def main(
     """
     torchao.quantization.utils.recommended_inductor_config_setter()
-    
+
     assert checkpoint_path.is_file(), checkpoint_path
     tokenizer_path = checkpoint_path.parent / "tokenizer.model"
     assert tokenizer_path.is_file(), str(tokenizer_path)
@@ -180,11 +180,11 @@ def main(
     prompt_length = encoded.size(0)
 
     torch.manual_seed(1234)
-    
+
     if quantization:
         from torchao.quantization.quant_api import (
-            quantize,
+            quantize_,
             int8_weight_only,
             int8_dynamic_activation_int8_weight,
             int4_weight_only,
@@ -193,13 +193,13 @@ def main(
         )
 
         if "int8wo" in quantization:
-            quantize(model, int8_weight_only())
+            quantize_(model, int8_weight_only())
         if "int8dq" in quantization:
-            quantize(model, int8_dynamic_activation_int8_weight())
+            quantize_(model, int8_dynamic_activation_int8_weight())
         if "int4wo" in quantization:
             groupsize=int(quantization.split("-")[-1])
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
-            quantize(model, int4_weight_only(group_size=groupsize))
+            quantize_(model, int4_weight_only(group_size=groupsize))
 
         if "autoquant" == quantization:
             model = autoquant(model, manual=True)
diff --git a/torchao/prototype/quant_llm/README.md b/torchao/prototype/quant_llm/README.md
index 631df3081..f0ecd38d5 100644
--- a/torchao/prototype/quant_llm/README.md
+++ b/torchao/prototype/quant_llm/README.md
@@ -5,15 +5,15 @@ This is a FP16 x FPx mixed matmul kernel optimized for io bound workloads per [F
 ## Usage
 
 ```python
-from torchao.quantization.quant_api import quantize
+from torchao.quantization.quant_api import quantize_
 from torchao.prototype.quant_llm import fp6_llm_weight_only, quant_llm_fpx_weight_only
 
 model = ...
 model.half() # not necessary, but recommeneded to maintain accuracy
-quantize(model, fp6_llm_weight_only()) # convert nn.Lineaer.weight to FP6 E3M2 in-place
+quantize_(model, fp6_llm_weight_only()) # convert nn.Linear.weight to FP6 E3M2 in-place
 
 # for generic FPx EyMz where x = 1 + y + z
-# quantize(model, quant_llm_fpx_weight_only(2, 2)) # use FP5 E2M2 instead
+# quantize_(model, quant_llm_fpx_weight_only(2, 2)) # use FP5 E2M2 instead
 
 # fully compatible with torch.compile()
 model.compile(mode="max-autotune", fullgraph=True)
diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md
index 76e7cd9ff..4765d6a5f 100644
--- a/torchao/quantization/README.md
+++ b/torchao/quantization/README.md
@@ -74,7 +74,7 @@ from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
 from torchao.dtypes import to_affine_quantized
 import copy
 from torchao.quantization.quant_api import (
-    quantize,
+    quantize_,
     int4_weight_only,
 )
 
@@ -101,7 +101,7 @@ m_bf16 = torch.compile(m_bf16, mode='max-autotune')
 # apply int4 weight only quant (compatible with tinygemm int4 weight only quant mm kernel in torchao)
 group_size = 32
 # only works for torch 2.4+
-m = quantize(m, int4_weight_only(group_size=group_size))
+quantize_(m, int4_weight_only(group_size=group_size))
 
 # temporary workaround for tensor subclass + torch.compile
 from torchao.utils import unwrap_tensor_subclass
@@ -168,7 +168,7 @@ torch._inductor.config.force_fuse_int_mm_with_mul = True
 
 # for torch 2.4+
 from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
-quantize(model, int8_dynamic_activation_int8_weight())
+quantize_(model, int8_dynamic_activation_int8_weight())
 
 # for torch 2.2.2 and 2.3
 from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
 change_linear_weights_to_int8_dqtensors(model)
@@ -180,7 +180,7 @@
 ```python
 # for torch 2.4+
 from torchao.quantization import quantize, int8_weight_only
-quantize(model, int8_weight_only())
+quantize_(model, int8_weight_only())
 
 # for torch 2.2.2 and 2.3
 from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
@@ -195,7 +195,7 @@ This technique works best when the torch._inductor.config.use_mixed_mm option is
 ```python
 # for torch 2.4+
 from torchao.quantization import quantize, int4_weight_only
-quantize(model, int4_weight_only())
+quantize_(model, int4_weight_only())
 
 # for torch 2.2.2 and 2.3
 from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index 115062c8f..a1cf1bf03 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -29,7 +29,7 @@
     "quantize_affine",
     "dequantize_affine",
     "choose_qprams_affine",
-    "quantize",
+    "quantize_",
     "int8_dynamic_activation_int4_weight",
     "int8_dynamic_activation_int8_weight",
     "int4_weight_only",
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 31ab71f38..3da530b94 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -54,7 +54,7 @@
     "Int4WeightOnlyQuantizer",
     "autoquant",
     "_get_subclass_inserter",
-    "quantize",
+    "quantize_",
     "int8_dynamic_activation_int4_weight",
     "int8_dynamic_activation_int8_weight",
     "int4_weight_only",
@@ -259,8 +259,8 @@ def insert_subclass(lin):
 
     return insert_subclass
 
-def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True) -> torch.nn.Module:
-    """Convert the weight of linear modules in the model with `apply_tensor_subclass`
+def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True):
+    """Convert the weight of linear modules in the model with `apply_tensor_subclass`, modifying the model in place
 
     Args:
         model (torch.nn.Module): input model
@@ -273,7 +273,7 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tens
         import torch
         import torch.nn as nn
-        from torchao import quantize
+        from torchao import quantize_
 
         # 1. quantize with some predefined `apply_tensor_subclass` method that corresponds to
         # optimized execution paths or kernels (e.g. int4 tinygemm kernel)
@@ -286,7 +286,7 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tens
         from torchao.quantization.quant_api import int4_weight_only
 
         m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
-        m = quantize(m, int4_weight_only(group_size=32))
+        quantize_(m, int4_weight_only(group_size=32))
 
         # 2. write your own new apply_tensor_subclass
         # You can also add your own apply_tensor_subclass by manually calling tensor subclass constructor
@@ -305,7 +305,7 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
             return isinstance(module, nn.Linear)
 
         m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
-        m = quantize(m, apply_weight_quant, filter_fn)
+        quantize_(m, apply_weight_quant, filter_fn)
     """
     if set_inductor_config:
@@ -315,7 +315,7 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
         _get_linear_subclass_inserter(apply_tensor_subclass),
         _is_linear if filter_fn is None else filter_fn,
     )
-    return model
+
 
 def int8_dynamic_activation_int4_weight(group_size=32):
     """Applies int8 dynamic per token asymmetric activation quantization and int4 per group weight symmetric quantization to linear
diff --git a/tutorials/quantize_vit/run_vit_b_quant.py b/tutorials/quantize_vit/run_vit_b_quant.py
index 07e0118d2..a082cfe53 100644
--- a/tutorials/quantize_vit/run_vit_b_quant.py
+++ b/tutorials/quantize_vit/run_vit_b_quant.py
@@ -19,9 +19,9 @@
 # for APIs for earlier torch version and other quantization techniques
 
 # for torch 2.4+
-from torchao.quantization.quant_api import quantize
+from torchao.quantization.quant_api import quantize_
 from torchao.quantization.quant_api import int8_dynamic_activation_int8_weight
-quantize(model, int8_dynamic_activation_int8_weight())
+quantize_(model, int8_dynamic_activation_int8_weight())
 ## Quantization code - end
 
 ## compilation configs
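
The sketch below summarizes the migration this patch makes, using only names that appear in the updated README snippets above (`quantize_`, `int8_weight_only`, `unwrap_tensor_subclass`, `torch.compile`). The toy `nn.Sequential` model, layer sizes, and input shape are illustrative assumptions, not part of the patch.

```python
import torch
import torch.nn as nn

# After this change, quantize_ mutates the model in place instead of returning a new one.
from torchao.quantization.quant_api import quantize_, int8_weight_only

# Toy model with nn.Linear layers (illustrative only).
m = nn.Sequential(nn.Linear(32, 64), nn.Linear(64, 32))

# Old pattern: m = quantize(m, int8_weight_only())
# New pattern: call quantize_ for its side effect and keep using the same module object.
quantize_(m, int8_weight_only())

# Temporary workaround for tensor subclass + torch.compile, as in the updated quantization README.
from torchao.utils import unwrap_tensor_subclass
unwrap_tensor_subclass(m)

m = torch.compile(m, mode="max-autotune")
out = m(torch.randn(16, 32))
```

The same call pattern applies to the other configs touched in this diff (`int4_weight_only(group_size=...)`, `int8_dynamic_activation_int8_weight()`, `fp6_llm_weight_only()`): construct the config, pass it to `quantize_`, and do not rebind the model to the return value.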