From d2168f2301977cbff5b33c14a4ea8584b95e8e77 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Fri, 5 Sep 2025 16:26:43 -0700
Subject: [PATCH] Bump int4 weight only config version to 2

Summary:
Int4WeightOnlyConfig currently has version 1 and version 2, and the default is
version 1. This PR changes the default to version 2 and updates the affected
call sites.

For call sites that still rely on the old configuration, we added an explicit
`version=1`; those call sites can be migrated to version 2 separately.

For READMEs, we migrate the usage to version 2 directly.

Deprecation:
TODO

Test Plan:
Regression tests:
python test/dtypes/test_affine_quantized.py
python test/quantization/test_quant_api.py
python test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py
python test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py
python test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
python test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
python test/quantization/quantize_/workflows/int4/test_int4_tensor.py
python test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py

Reviewers:

Subscribers:

Tasks:

Tags:
---
 .../quantize_and_upload.py                    |   9 +-
 .../test_affine_quantized_tensor_parallel.py  |   1 +
 .../test_load_and_run_checkpoint.py           | 111 +++++++++++++++---
 test/prototype/test_awq.py                    |   6 +-
 .../int4/test_int4_marlin_sparse_tensor.py    |   1 -
 .../workflows/int4/test_int4_opaque_tensor.py |   1 -
 .../int4/test_int4_plain_int32_tensor.py      |   1 -
 .../int4/test_int4_preshuffled_tensor.py      |   1 -
 .../workflows/int4/test_int4_tensor.py        |   1 -
 .../test_int4_tile_packed_to_4d_tensor.py     |   2 -
 torchao/dtypes/uintx/int4_cpu_layout.py       |   4 +
 torchao/dtypes/uintx/int4_xpu_layout.py       |   4 +
 torchao/dtypes/uintx/marlin_sparse_layout.py  |   4 +
 .../dtypes/uintx/tensor_core_tiled_layout.py  |   4 +
 torchao/prototype/awq/example.py              |   2 +-
 torchao/quantization/README.md                |  12 +-
 torchao/quantization/quant_api.py             |   9 +-
 torchao/sparsity/README.md                    |   3 +-
 18 files changed, 131 insertions(+), 45 deletions(-)

diff --git a/.github/scripts/torchao_model_releases/quantize_and_upload.py b/.github/scripts/torchao_model_releases/quantize_and_upload.py
index e118cf8002..24f19fe6ec 100644
--- a/.github/scripts/torchao_model_releases/quantize_and_upload.py
+++ b/.github/scripts/torchao_model_releases/quantize_and_upload.py
@@ -206,7 +206,7 @@ def _untie_weights_and_save_locally(model_id):
 
 _int4_quant_code = """
 from torchao.quantization import Int4WeightOnlyConfig
-quant_config = Int4WeightOnlyConfig(group_size=128, packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq", version=2)
+quant_config = Int4WeightOnlyConfig(group_size=128, int4_packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq")
 quantization_config = TorchAoConfig(quant_type=quant_config)
 quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
 tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -256,7 +256,7 @@ def _untie_weights_and_save_locally(model_id):
 )
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-base_config = Int4WeightOnlyConfig(group_size=128, version=2)
+base_config = Int4WeightOnlyConfig(group_size=128)
 quant_config = AWQConfig(base_config, step="prepare")
 quantize_(
     model,
@@ -633,9 +633,8 @@ def quantize_and_upload(
         "FP8": Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
         "INT4": Int4WeightOnlyConfig(
             group_size=128,
-            packing_format="tile_packed_to_4d",
+            int4_packing_format="tile_packed_to_4d",
             int4_choose_qparams_algorithm="hqq",
-            version=2,
         ),
         "INT8-INT4": ModuleFqnToConfig(
             {
@@ -669,7 +668,7 @@ def quantize_and_upload(
         )
         tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-        base_config = Int4WeightOnlyConfig(group_size=128, version=2)
+        base_config = Int4WeightOnlyConfig(group_size=128)
         quant_config = AWQConfig(base_config, step="prepare")
         quantize_(
             model,
diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py
index fd5f43a470..49471d3ad1 100644
--- a/test/dtypes/test_affine_quantized_tensor_parallel.py
+++ b/test/dtypes/test_affine_quantized_tensor_parallel.py
@@ -145,6 +145,7 @@ def test_tp(self, dtype):
 
 class TestInt4woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
     QUANT_METHOD_FN = staticmethod(int4_weight_only)
+    QUANT_METHOD_KWARGS = {"version": 1}
     COMMON_DTYPES = [torch.bfloat16]
 
     @common_utils.parametrize("dtype", COMMON_DTYPES)
diff --git a/test/integration/test_load_and_run_checkpoint.py b/test/integration/test_load_and_run_checkpoint.py
index d18feaef9a..58c43d9008 100644
--- a/test/integration/test_load_and_run_checkpoint.py
+++ b/test/integration/test_load_and_run_checkpoint.py
@@ -24,9 +24,22 @@
 
 # please check model card for how to generate these models
-_DEPRECATED_SINGLE_LINEAR_MODEL_NAMES = [
+# high precision model, used for testing config deprecation warning
+_HIGH_PRECISION_MODEL = "facebook/opt-125m"
+
+_DEPRECATED_SINGLE_LINEAR_MODEL_INFO = [
     # model card: https://huggingface.co/torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev
-    "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev"
+    (
+        "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev",
+        1,
+        "Float8DynamicActivationFloat8WeightConfig",
+    ),
+    # model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-v1-0.14.dev
+    (
+        "torchao-testing/single-linear-Int4WeightOnlyConfig-v1-0.14.dev",
+        1,
+        "Int4WeightOnlyConfig",
+    ),
 ]
 
 _DEPRECATED_MODEL_INFO = [
@@ -36,15 +49,33 @@
         1,
         "Float8DynamicActivationFloat8WeightConfig",
     ),
+    # model card: https://huggingface.co/torchao-testing/opt-125m-Int4WeightOnlyConfig-v1-0.14.dev
+    (
+        "torchao-testing/opt-125m-Int4WeightOnlyConfig-v1-0.14.dev",
+        1,
+        "Int4WeightOnlyConfig",
+    ),
 ]
 
-_SINGLE_LINEAR_MODEL_NAMES = [
+_SINGLE_LINEAR_MODEL_INFO = [
     # model card: https://huggingface.co/torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev
-    "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev",
+    (
+        "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev",
+        2,
+        "Float8DynamicActivationFloat8WeightConfig",
+    ),
     # model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-v2-0.13.dev
-    "torchao-testing/single-linear-Int4WeightOnlyConfig-v2-0.13.dev",
+    (
+        "torchao-testing/single-linear-Int4WeightOnlyConfig-v2-0.13.dev",
+        2,
+        "Int4WeightOnlyConfig",
+    ),
     # model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-preshuffled-v2-0.13.dev
-    "torchao-testing/single-linear-Int4WeightOnlyConfig-preshuffled-v2-0.13.dev",
+    (
+        "torchao-testing/single-linear-Int4WeightOnlyConfig-preshuffled-v2-0.13.dev",
+        2,
+        "Int4WeightOnlyConfig",
+    ),
 ]
 
 
@@ -55,7 +86,9 @@
     "Skipping the test in fbcode for now, not sure how to download from transformers",
 )
 class TestLoadAndRunCheckpoint(TestCase):
-    def _test_single_linear_helper(self, model_name):
+    def _test_single_linear_helper(
+        self, model_name, version, config_name, is_deprecated
+    ):
         from huggingface_hub import hf_hub_download
 
         downloaded_model = hf_hub_download(model_name, filename="model.pt")
@@ -69,8 +102,20 @@ def _test_single_linear_helper(self, model_name):
         model = torch.nn.Sequential(
             torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda")
         )
-        with open(downloaded_model, "rb") as f:
+
+        with (
+            open(downloaded_model, "rb") as f,
+            warnings.catch_warnings(record=True) as caught_warnings,
+        ):
             model.load_state_dict(torch.load(f), assign=True)
+            if is_deprecated:
+                assert any(
+                    f"Models quantized with version {version} of {config_name} is deprecated"
+                    in str(w.message)
+                    for w in caught_warnings
+                ), (
+                    f"Didn't get expected warning message for deprecation for model: {model_name}"
+                )
 
         downloaded_example_inputs = hf_hub_download(
             model_name, filename="model_inputs.pt"
@@ -84,17 +129,23 @@ def _test_single_linear_helper(self, model_name):
         output = model(*example_inputs)
         self.assertTrue(torch.equal(output, ref_output))
 
-    @common_utils.parametrize("model_name", _DEPRECATED_SINGLE_LINEAR_MODEL_NAMES)
-    def test_deprecated_single_linear(self, model_name):
-        self._test_single_linear_helper(model_name)
+    @common_utils.parametrize("model_info", _DEPRECATED_SINGLE_LINEAR_MODEL_INFO)
+    def test_deprecated_single_linear(self, model_info):
+        model_name, version, config_name = model_info
+        self._test_single_linear_helper(
+            model_name, version, config_name, is_deprecated=True
+        )
 
-    @common_utils.parametrize("model_name", _SINGLE_LINEAR_MODEL_NAMES)
-    def test_single_linear(self, model_name):
+    @common_utils.parametrize("model_info", _SINGLE_LINEAR_MODEL_INFO)
+    def test_single_linear(self, model_info):
         """Test that we can load and run the quantized linear checkpoint with saved sample input
         and match the saved output, to make sure there is no BC breaking changes
         when we make changes to tensor subclass implementations
         """
-        self._test_single_linear_helper(model_name)
+        model_name, version, config_name = model_info
+        self._test_single_linear_helper(
+            model_name, version, config_name, is_deprecated=False
+        )
 
     @common_utils.parametrize("model_info", _DEPRECATED_MODEL_INFO)
     def test_deprecated_hf_models(self, model_info):
@@ -109,17 +160,23 @@ def test_deprecated_hf_models(self, model_info):
                 torch_dtype="bfloat16",
                 device_map="cuda:0",
             )
+            # version mismatch check in config.py
            assert any(
                 "Stored version is not the same as current default version of the config"
                 in str(w.message)
                 for w in caught_warnings
-            ), "Didn't get expected warning message for version mismatch"
+            ), (
+                f"Didn't get expected warning message for version mismatch for config {config_name}, model {model_name}"
+            )
+            # checkpoint deprecation
             assert any(
-                f"Models quantized with version 1 of {config_name} is deprecated"
+                f"Models quantized with version {version} of {config_name} is deprecated"
                 in str(w.message)
                 for w in caught_warnings
-            ), "Didn't get expected warning message for deprecation"
+            ), (
+                f"Didn't get expected warning message for deprecation for model {model_name}"
+            )
 
         assert isinstance(quantized_model.config.quantization_config, TorchAoConfig)
         assert (
             quantized_model.config.quantization_config.quant_type.version == version
@@ -139,7 +196,8 @@ def test_deprecated_hf_models(self, model_info):
             return_tensors="pt",
         ).to("cuda")
         generated_ids = quantized_model.generate(
-            **inputs, max_new_tokens=128, temperature=0
+            **inputs,
+            max_new_tokens=128,
         )
 
         downloaded_output = hf_hub_download(model_name, filename="model_output.pt")
@@ -153,6 +211,23 @@ def test_deprecated_hf_models(self, model_info):
             generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
         )
 
+        # make sure we throw warning for config deprecation
+        with warnings.catch_warnings(record=True) as caught_warnings:
+            _ = AutoModelForCausalLM.from_pretrained(
+                _HIGH_PRECISION_MODEL,
+                torch_dtype="bfloat16",
+                device_map="cuda:0",
+                quantization_config=quantized_model.config.quantization_config,
+            )
+            # config version deprecation in quant_api.py
+            assert any(
+                f"Config Deprecation: version {version} of {config_name} is deprecated and will no longer be supported in a future release"
+                in str(w.message)
+                for w in caught_warnings
+            ), (
+                f"Didn't get expected warning message for version deprecation for config {config_name}, model {model_name}"
+            )
+
 
 common_utils.instantiate_parametrized_tests(TestLoadAndRunCheckpoint)
 
diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py
index e6bd573029..0f18be5d01 100644
--- a/test/prototype/test_awq.py
+++ b/test/prototype/test_awq.py
@@ -73,7 +73,7 @@ def test_awq_functionality(self):
         m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
 
         # baseline quantization
-        base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
+        base_config = Int4WeightOnlyConfig(group_size=group_size)
         m_baseline = copy.deepcopy(m)
         quantize_(m_baseline, base_config)
 
@@ -123,7 +123,7 @@ def test_awq_loading(self):
         calibration_data = dataset[:n_calibration_examples]
 
         # calibrate
-        base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
+        base_config = Int4WeightOnlyConfig(group_size=group_size)
         quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
         quantize_(m, quant_config)
 
@@ -177,7 +177,7 @@ def test_awq_loading_vllm(self):
         calibration_data = dataset[:n_calibration_examples]
 
         # calibrate
-        base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
+        base_config = Int4WeightOnlyConfig(group_size=group_size)
         quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
         quantize_(m, quant_config)
 
diff --git a/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py
index d6961dfa23..56994b2639 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py
@@ -27,7 +27,6 @@
 BF16_ACT_CONFIG = Int4WeightOnlyConfig(
     group_size=128,
     int4_packing_format="marlin_sparse",
-    version=2,
 )
 
 
diff --git a/test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py
index 0b3e84fb77..3f6a8846d0 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py
@@ -29,7 +29,6 @@ def get_config(group_size):
     return Int4WeightOnlyConfig(
         group_size=group_size,
         int4_packing_format="opaque",
-        version=2,
     )
 
 
diff --git a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
index 728ebd880a..82a10916fa 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
@@ -29,7 +29,6 @@ def get_config(group_size):
     return Int4WeightOnlyConfig(
         group_size=group_size,
         int4_packing_format="plain_int32",
-        version=2,
     )
 
 
diff --git a/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
index 4760f75257..df25b650b2 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
@@ -30,7 +30,6 @@
 BF16_ACT_CONFIG = Int4WeightOnlyConfig(
     group_size=128,
     int4_packing_format="preshuffled",
-    version=2,
 )
 
 # only 128 group_size is supported
diff --git a/test/quantization/quantize_/workflows/int4/test_int4_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_tensor.py
index a971db609e..f438d9c3db 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_tensor.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_tensor.py
@@ -35,7 +35,6 @@ def setUp(self):
         self.config = Int4WeightOnlyConfig(
             group_size=128,
             int4_packing_format="plain",
-            version=2,
         )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
 
diff --git a/test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py
index 64519e327a..9fe9fddfb8 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py
@@ -25,14 +25,12 @@
 INT4_CONFIG = Int4WeightOnlyConfig(
     group_size=128,
     int4_packing_format="tile_packed_to_4d",
-    version=2,
 )
 
 INT4_HQQ_CONFIG = Int4WeightOnlyConfig(
     group_size=128,
     int4_packing_format="tile_packed_to_4d",
     int4_choose_qparams_algorithm="hqq",
-    version=2,
 )
 
 
diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py
index cd09eec452..1ae9dca3b6 100644
--- a/torchao/dtypes/uintx/int4_cpu_layout.py
+++ b/torchao/dtypes/uintx/int4_cpu_layout.py
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+import warnings
 from dataclasses import dataclass
 from typing import Optional, Tuple
 
@@ -78,6 +79,9 @@ def __init__(
         transposed: bool,
         _layout: Layout,
     ):
+        warnings.warn(
+            "Models quantized with version 1 of Int4WeightOnlyConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2948 for more details"
+        )
         self.packed_weight = packed_weight
         self.scale_and_zero = scale_and_zero
         self.transposed = False
diff --git a/torchao/dtypes/uintx/int4_xpu_layout.py b/torchao/dtypes/uintx/int4_xpu_layout.py
index a01fad31c2..ff6dc68813 100644
--- a/torchao/dtypes/uintx/int4_xpu_layout.py
+++ b/torchao/dtypes/uintx/int4_xpu_layout.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+import warnings
 from dataclasses import dataclass
 from typing import Optional, Tuple
 
@@ -207,6 +208,9 @@ def __init__(
         scale: torch.Tensor = None,
         zero: torch.Tensor = None,
     ):
+        warnings.warn(
+            "Models quantized with version 1 of Int4WeightOnlyConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2948 for more details"
+        )
         self.packed_weight = packed_weight
         self.scale_and_zero = scale_and_zero
         self.transposed = False
diff --git a/torchao/dtypes/uintx/marlin_sparse_layout.py b/torchao/dtypes/uintx/marlin_sparse_layout.py
index af1f8040f6..cba2428d94 100644
--- a/torchao/dtypes/uintx/marlin_sparse_layout.py
+++ b/torchao/dtypes/uintx/marlin_sparse_layout.py
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+import warnings
 from dataclasses import dataclass
 
 import torch
@@ -158,6 +159,9 @@ def __init__(
         group_size: int,
         num_bits: int,
     ):
+        warnings.warn(
+            "Models quantized with version 1 of Int4WeightOnlyConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2948 for more details"
+        )
         self.int_data = int_data
         self.scale_and_zero = None
         self.scale = scale
diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py
index 992294b766..1961cc33c5 100644
--- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py
+++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import logging
+import warnings
 from dataclasses import dataclass
 from typing import Optional, Tuple
 
@@ -237,6 +238,9 @@ def __init__(
         transposed: bool,
         _layout: Layout,
     ):
+        warnings.warn(
+            "Models quantized with version 1 of Int4WeightOnlyConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2948 for more details"
+        )
         self.packed_weight = packed_weight
         self.scale_and_zero = scale_and_zero
         self.transposed = False
diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py
index 222e184075..cc7f530b6f 100644
--- a/torchao/prototype/awq/example.py
+++ b/torchao/prototype/awq/example.py
@@ -226,7 +226,7 @@ def quantize_and_eval(
         # TODO: this is temporary, we'll be using Int4WeightOnlyConfig soon
         from torchao.quantization import Int4WeightOnlyConfig
 
-        base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
+        base_config = Int4WeightOnlyConfig(group_size=group_size)
         print(f"running {quant} prepare and calibrate")
         t0 = time.time()
         quant_config = AWQConfig(base_config, step="prepare")
diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md
index 5c9ec82de7..e1e4c20a31 100644
--- a/torchao/quantization/README.md
+++ b/torchao/quantization/README.md
@@ -129,9 +129,9 @@ from torchao.quantization import quantize_, Int4WeightOnlyConfig
 group_size = 32
 
 # you can enable [hqq](https://github.com/mobiusml/hqq/tree/master) quantization which is expected to improves accuracy through
-# use_hqq flag for `Int4WeightOnlyConfig` quantization
+# by setting int4_choose_qparams_algorithm to "hqq" for `Int4WeightOnlyConfig` quantization
 use_hqq = False
-quantize_(model, Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq, version=1))
+quantize_(model, Int4WeightOnlyConfig(group_size=group_size, int4_packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq"))
 ```
 
 Note: The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model.
@@ -150,7 +150,7 @@ from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfi
 quantize_(model, Int8DynamicActivationInt8WeightConfig())
 ```
 
-### A16W8 Float8 WeightOnly Quantization
+#### A16W8 Float8 WeightOnly Quantization
 
 ```python
 # for torch 2.5+
@@ -285,9 +285,9 @@ m_bf16 = torch.compile(m_bf16, mode='max-autotune')
 # apply int4 weight only quant (compatible with tinygemm int4 weight only quant mm kernel in torchao)
 group_size = 32
 # only works for torch 2.4+
-quantize_(m, Int4WeightOnlyConfig(group_size=group_size, version=1))
-## If different zero_point_domain needed
-# quantize_(m, Int4WeightOnlyConfig(group_size=group_size, zero_point_domain=ZeroPointDomain.FLOAT, version=1))
+quantize_(m, Int4WeightOnlyConfig(group_size=group_size, int4_packing_format="tile_packed_to_4d"))
+# can also specify different packing format
+# quantize_(m, Int4WeightOnlyConfig(group_size=group_size, int4_packing_format="plain"))
 
 # compile the model to improve performance
 m = torch.compile(m, mode='max-autotune')
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 67216a6d94..c6122e6cb3 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -1092,7 +1092,7 @@ class Int4WeightOnlyConfig(AOBaseConfig):
     int4_choose_qparams_algorithm: Int4ChooseQParamsAlgorithm = (
         Int4ChooseQParamsAlgorithm.TINYGEMM
     )
-    version: int = 1
+    version: int = 2
 
     def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.Int4WeightOnlyConfig")
@@ -1175,6 +1175,9 @@ def _int4_weight_only_quantize_tensor(weight, config):
 
     assert config.version == 1
 
+    warnings.warn(
+        "Config Deprecation: version 1 of Int4WeightOnlyConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2948 for more details"
+    )
     mapping_type = MappingType.ASYMMETRIC
     target_dtype = torch.int32
     quant_min = 0
@@ -1583,7 +1586,7 @@ def __post_init__(self):
 def _float8_weight_only_quant_tensor(weight, config):
     if config.version == 1:
         warnings.warn(
-            "version 1 of Float8WeightOnlyConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2649 for more details"
+            "Config Deprecation: version 1 of Float8WeightOnlyConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2649 for more details"
         )
         from torchao.dtypes import to_affine_quantized_floatx
 
@@ -1763,7 +1766,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
 
     if config.version == 1:
         warnings.warn(
-            "version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2649 for more details"
+            "Config Deprecation: version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2649 for more details"
        )
 
         block_size = get_block_size(weight.shape[-2:], weight_granularity)
diff --git a/torchao/sparsity/README.md b/torchao/sparsity/README.md
index e32a2f706d..2c62c2738a 100644
--- a/torchao/sparsity/README.md
+++ b/torchao/sparsity/README.md
@@ -53,11 +53,10 @@ Sparse-Marlin 2:4 is an optimized GPU kernel that extends the Mixed Auto-Regress
 
 ```py
 from torchao.quantization.quant_api import quantize_, Int4WeightOnlyConfig
-from torchao.dtypes import MarlinSparseLayout
 
 # Your FP16 model
 model = model.cuda().half()
-quantize_(model, Int4WeightOnlyConfig(layout=MarlinSparseLayout(), version=1))
+quantize_(model, Int4WeightOnlyConfig(int4_packing_format="marlin_sparse"))
 ```
 
 Note the existing API results in an extremely high accuracy degredation and is intended to be used in concert with an already sparsified+finetuned checkpoint where possible until we develop