9 changes: 4 additions & 5 deletions .github/scripts/torchao_model_releases/quantize_and_upload.py
@@ -206,7 +206,7 @@ def _untie_weights_and_save_locally(model_id):

_int4_quant_code = """
from torchao.quantization import Int4WeightOnlyConfig
quant_config = Int4WeightOnlyConfig(group_size=128, packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq", version=2)
quant_config = Int4WeightOnlyConfig(group_size=128, int4_packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq")
quantization_config = TorchAoConfig(quant_type=quant_config)
quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -256,7 +256,7 @@ def _untie_weights_and_save_locally(model_id):
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

base_config = Int4WeightOnlyConfig(group_size=128, version=2)
base_config = Int4WeightOnlyConfig(group_size=128)
quant_config = AWQConfig(base_config, step="prepare")
quantize_(
model,
@@ -633,9 +633,8 @@ def quantize_and_upload(
"FP8": Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
"INT4": Int4WeightOnlyConfig(
group_size=128,
packing_format="tile_packed_to_4d",
int4_packing_format="tile_packed_to_4d",
int4_choose_qparams_algorithm="hqq",
version=2,
),
"INT8-INT4": ModuleFqnToConfig(
{
@@ -669,7 +668,7 @@ def quantize_and_upload(
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

base_config = Int4WeightOnlyConfig(group_size=128, version=2)
base_config = Int4WeightOnlyConfig(group_size=128)
quant_config = AWQConfig(base_config, step="prepare")
quantize_(
model,
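For context on the script change above: the release script now uses the current Int4WeightOnlyConfig keywords (`int4_packing_format`, `int4_choose_qparams_algorithm`) instead of `packing_format` plus an explicit `version=2`, since version 2 is now the default. A minimal sketch of the updated quantization call, assuming a transformers install with TorchAoConfig support; the model id is only a placeholder borrowed from the tests later in this PR:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from torchao.quantization import Int4WeightOnlyConfig

# Old call (deprecated): Int4WeightOnlyConfig(group_size=128,
#     packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq", version=2)

# New call: version 2 is the default and the packing keyword is int4-specific
quant_config = Int4WeightOnlyConfig(
    group_size=128,
    int4_packing_format="tile_packed_to_4d",
    int4_choose_qparams_algorithm="hqq",
)

model_id = "facebook/opt-125m"  # placeholder; any causal LM supported by torchao works
quantization_config = TorchAoConfig(quant_type=quant_config)
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
```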
1 change: 1 addition & 0 deletions test/dtypes/test_affine_quantized_tensor_parallel.py
@@ -145,6 +145,7 @@ def test_tp(self, dtype):

class TestInt4woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
QUANT_METHOD_FN = staticmethod(int4_weight_only)
QUANT_METHOD_KWARGS = {"version": 1}
COMMON_DTYPES = [torch.bfloat16]

@common_utils.parametrize("dtype", COMMON_DTYPES)
111 changes: 93 additions & 18 deletions test/integration/test_load_and_run_checkpoint.py
@@ -24,9 +24,22 @@

# please check model card for how to generate these models

_DEPRECATED_SINGLE_LINEAR_MODEL_NAMES = [
# high precision model, used for testing config deprecation warning
_HIGH_PRECISION_MODEL = "facebook/opt-125m"

_DEPRECATED_SINGLE_LINEAR_MODEL_INFO = [
# model card: https://huggingface.co/torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev
"torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev"
(
"torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev",
1,
"Float8DynamicActivationFloat8WeightConfig",
),
# model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-v1-0.14.dev
(
"torchao-testing/single-linear-Int4WeightOnlyConfig-v1-0.14.dev",
1,
"Int4WeightOnlyConfig",
),
]

_DEPRECATED_MODEL_INFO = [
@@ -36,15 +49,33 @@
1,
"Float8DynamicActivationFloat8WeightConfig",
),
# model card: https://huggingface.co/torchao-testing/opt-125m-Int4WeightOnlyConfig-v1-0.14.dev
(
"torchao-testing/opt-125m-Int4WeightOnlyConfig-v1-0.14.dev",
1,
"Int4WeightOnlyConfig",
),
]

_SINGLE_LINEAR_MODEL_NAMES = [
_SINGLE_LINEAR_MODEL_INFO = [
# model card: https://huggingface.co/torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev
"torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev",
(
"torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev",
2,
"Float8DynamicActivationFloat8WeightConfig",
),
# model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-v2-0.13.dev
"torchao-testing/single-linear-Int4WeightOnlyConfig-v2-0.13.dev",
(
"torchao-testing/single-linear-Int4WeightOnlyConfig-v2-0.13.dev",
2,
"Int4WeightOnlyConfig",
),
# model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-preshuffled-v2-0.13.dev
"torchao-testing/single-linear-Int4WeightOnlyConfig-preshuffled-v2-0.13.dev",
(
"torchao-testing/single-linear-Int4WeightOnlyConfig-preshuffled-v2-0.13.dev",
2,
"Int4WeightOnlyConfig",
),
]


Expand All @@ -55,7 +86,9 @@
"Skipping the test in fbcode for now, not sure how to download from transformers",
)
class TestLoadAndRunCheckpoint(TestCase):
def _test_single_linear_helper(self, model_name):
def _test_single_linear_helper(
self, model_name, version, config_name, is_deprecated
):
from huggingface_hub import hf_hub_download

downloaded_model = hf_hub_download(model_name, filename="model.pt")
@@ -69,8 +102,20 @@ def _test_single_linear_helper(self, model_name):
model = torch.nn.Sequential(
torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda")
)
with open(downloaded_model, "rb") as f:

with (
open(downloaded_model, "rb") as f,
warnings.catch_warnings(record=True) as caught_warnings,
):
model.load_state_dict(torch.load(f), assign=True)
if is_deprecated:
assert any(
f"Models quantized with version {version} of {config_name} is deprecated"
in str(w.message)
for w in caught_warnings
), (
f"Didn't get expected warning message for deprecation for model: {model_name}"
)

downloaded_example_inputs = hf_hub_download(
model_name, filename="model_inputs.pt"
@@ -84,17 +129,23 @@ def _test_single_linear_helper(self, model_name):
output = model(*example_inputs)
self.assertTrue(torch.equal(output, ref_output))

@common_utils.parametrize("model_name", _DEPRECATED_SINGLE_LINEAR_MODEL_NAMES)
def test_deprecated_single_linear(self, model_name):
self._test_single_linear_helper(model_name)
@common_utils.parametrize("model_info", _DEPRECATED_SINGLE_LINEAR_MODEL_INFO)
def test_deprecated_single_linear(self, model_info):
model_name, version, config_name = model_info
self._test_single_linear_helper(
model_name, version, config_name, is_deprecated=True
)

@common_utils.parametrize("model_name", _SINGLE_LINEAR_MODEL_NAMES)
def test_single_linear(self, model_name):
@common_utils.parametrize("model_info", _SINGLE_LINEAR_MODEL_INFO)
def test_single_linear(self, model_info):
"""Test that we can load and run the quantized linear checkpoint with saved sample input
and match the saved output, to make sure there are no BC-breaking changes
when we make changes to tensor subclass implementations
"""
self._test_single_linear_helper(model_name)
model_name, version, config_name = model_info
self._test_single_linear_helper(
model_name, version, config_name, is_deprecated=False
)

@common_utils.parametrize("model_info", _DEPRECATED_MODEL_INFO)
def test_deprecated_hf_models(self, model_info):
@@ -109,17 +160,23 @@ def test_deprecated_hf_models(self, model_info):
torch_dtype="bfloat16",
device_map="cuda:0",
)
# version mismatch check in config.py
assert any(
"Stored version is not the same as current default version of the config"
in str(w.message)
for w in caught_warnings
), "Didn't get expected warning message for version mismatch"
), (
f"Didn't get expected warning message for version mismatch for config {config_name}, model {model_name}"
)

# checkpoint deprecation
assert any(
f"Models quantized with version 1 of {config_name} is deprecated"
f"Models quantized with version {version} of {config_name} is deprecated"
in str(w.message)
for w in caught_warnings
), "Didn't get expected warning message for deprecation"
), (
f"Didn't get expected warning message for deprecation for model {model_name}"
)
assert isinstance(quantized_model.config.quantization_config, TorchAoConfig)
assert (
quantized_model.config.quantization_config.quant_type.version == version
@@ -139,7 +196,8 @@ def test_deprecated_hf_models(self, model_info):
return_tensors="pt",
).to("cuda")
generated_ids = quantized_model.generate(
**inputs, max_new_tokens=128, temperature=0
**inputs,
max_new_tokens=128,
)

downloaded_output = hf_hub_download(model_name, filename="model_output.pt")
@@ -153,6 +211,23 @@ def test_deprecated_hf_models(self, model_info):
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)

# make sure we throw warning for config deprecation
with warnings.catch_warnings(record=True) as caught_warnings:
_ = AutoModelForCausalLM.from_pretrained(
_HIGH_PRECISION_MODEL,
torch_dtype="bfloat16",
device_map="cuda:0",
quantization_config=quantized_model.config.quantization_config,
)
# config version deprecation in quant_api.py
assert any(
f"Config Deprecation: version {version} of {config_name} is deprecated and will no longer be supported in a future release"
in str(w.message)
for w in caught_warnings
), (
f"Didn't get expected warning message for version deprecation for config {config_name}, model {model_name}"
)


common_utils.instantiate_parametrized_tests(TestLoadAndRunCheckpoint)
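As a usage note, the warning-capture pattern in _test_single_linear_helper can be reproduced outside the test suite. A minimal sketch, assuming a CUDA device and the huggingface_hub package; the checkpoint repo is one of the deprecated test models listed above, and the warning text comes from the layout changes later in this diff:

```python
import warnings

import torch
from huggingface_hub import hf_hub_download

model_name = "torchao-testing/single-linear-Int4WeightOnlyConfig-v1-0.14.dev"
downloaded = hf_hub_download(model_name, filename="model.pt")

# Shape matches the single-linear test checkpoints used above
model = torch.nn.Sequential(
    torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda")
)

with open(downloaded, "rb") as f, warnings.catch_warnings(record=True) as caught:
    model.load_state_dict(torch.load(f), assign=True)

# Version-1 Int4 tensor layouts now warn when they are constructed (see the
# int4_cpu_layout.py / tensor_core_tiled_layout.py changes below)
assert any(
    "Models quantized with version 1 of Int4WeightOnlyConfig is deprecated" in str(w.message)
    for w in caught
), "expected a deprecation warning when loading a v1 checkpoint"
```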

6 changes: 3 additions & 3 deletions test/prototype/test_awq.py
@@ -73,7 +73,7 @@ def test_awq_functionality(self):
m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)

# baseline quantization
base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
base_config = Int4WeightOnlyConfig(group_size=group_size)
m_baseline = copy.deepcopy(m)
quantize_(m_baseline, base_config)

@@ -123,7 +123,7 @@ def test_awq_loading(self):
calibration_data = dataset[:n_calibration_examples]

# calibrate
base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
base_config = Int4WeightOnlyConfig(group_size=group_size)
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
quantize_(m, quant_config)

@@ -177,7 +177,7 @@ def test_awq_loading_vllm(self):
calibration_data = dataset[:n_calibration_examples]

# calibrate
base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
base_config = Int4WeightOnlyConfig(group_size=group_size)
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
quantize_(m, quant_config)

(new file; path not shown in this view)
@@ -27,7 +27,6 @@
BF16_ACT_CONFIG = Int4WeightOnlyConfig(
group_size=128,
int4_packing_format="marlin_sparse",
version=2,
)


(new file; path not shown in this view)
@@ -29,7 +29,6 @@ def get_config(group_size):
return Int4WeightOnlyConfig(
group_size=group_size,
int4_packing_format="opaque",
version=2,
)


(new file; path not shown in this view)
@@ -29,7 +29,6 @@ def get_config(group_size):
return Int4WeightOnlyConfig(
group_size=group_size,
int4_packing_format="plain_int32",
version=2,
)


(new file; path not shown in this view)
@@ -30,7 +30,6 @@
BF16_ACT_CONFIG = Int4WeightOnlyConfig(
group_size=128,
int4_packing_format="preshuffled",
version=2,
)

# only 128 group_size is supported
(new file; path not shown in this view)
@@ -35,7 +35,6 @@ def setUp(self):
self.config = Int4WeightOnlyConfig(
group_size=128,
int4_packing_format="plain",
version=2,
)
self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

(new file; path not shown in this view)
@@ -25,14 +25,12 @@
INT4_CONFIG = Int4WeightOnlyConfig(
group_size=128,
int4_packing_format="tile_packed_to_4d",
version=2,
)

INT4_HQQ_CONFIG = Int4WeightOnlyConfig(
group_size=128,
int4_packing_format="tile_packed_to_4d",
int4_choose_qparams_algorithm="hqq",
version=2,
)


4 changes: 4 additions & 0 deletions torchao/dtypes/uintx/int4_cpu_layout.py
@@ -3,6 +3,7 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple

@@ -78,6 +79,9 @@ def __init__(
transposed: bool,
_layout: Layout,
):
warnings.warn(
"Models quantized with version 1 of Int4WeightOnlyConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2948 for more details"
)
self.packed_weight = packed_weight
self.scale_and_zero = scale_and_zero
self.transposed = False
4 changes: 4 additions & 0 deletions torchao/dtypes/uintx/int4_xpu_layout.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import warnings
from dataclasses import dataclass
from typing import Optional, Tuple

@@ -207,6 +208,9 @@ def __init__(
scale: torch.Tensor = None,
zero: torch.Tensor = None,
):
warnings.warn(
"Models quantized with version 1 of Int4WeightOnlyConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2948 for more details"
)
self.packed_weight = packed_weight
self.scale_and_zero = scale_and_zero
self.transposed = False
4 changes: 4 additions & 0 deletions torchao/dtypes/uintx/marlin_sparse_layout.py
@@ -3,6 +3,7 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from dataclasses import dataclass

import torch
@@ -158,6 +159,9 @@ def __init__(
group_size: int,
num_bits: int,
):
warnings.warn(
"Models quantized with version 1 of Int4WeightOnlyConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2948 for more details"
)
self.int_data = int_data
self.scale_and_zero = None
self.scale = scale
4 changes: 4 additions & 0 deletions torchao/dtypes/uintx/tensor_core_tiled_layout.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import logging
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple

@@ -237,6 +238,9 @@ def __init__(
transposed: bool,
_layout: Layout,
):
warnings.warn(
"Models quantized with version 1 of Int4WeightOnlyConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2948 for more details"
)
self.packed_weight = packed_weight
self.scale_and_zero = scale_and_zero
self.transposed = False
2 changes: 1 addition & 1 deletion torchao/prototype/awq/example.py
@@ -226,7 +226,7 @@ def quantize_and_eval(
# TODO: this is temporary, we'll be using Int4WeightOnlyConfig soon
from torchao.quantization import Int4WeightOnlyConfig

base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
base_config = Int4WeightOnlyConfig(group_size=group_size)
print(f"running {quant} prepare and calibrate")
t0 = time.time()
quant_config = AWQConfig(base_config, step="prepare")
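The AWQ example keeps the same two-step flow exercised in test_awq.py above: prepare with a plain Int4WeightOnlyConfig base, calibrate, then convert. A minimal sketch of that flow under stated assumptions: the toy model, the random calibration tensors, and the "convert" step name are illustrative (the visible hunks only show step="prepare"), so check torchao.prototype.awq.AWQStep for the exact step names:

```python
import torch
from torchao.prototype.awq import AWQConfig
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# Toy model; real usage would load a transformers model as in example.py
model = torch.nn.Sequential(
    torch.nn.Linear(512, 512), torch.nn.ReLU(), torch.nn.Linear(512, 512)
).to(torch.bfloat16).to("cuda")

base_config = Int4WeightOnlyConfig(group_size=128)

# Step 1: insert AWQ observers
quantize_(model, AWQConfig(base_config, step="prepare"))

# Step 2: calibrate with representative inputs
for _ in range(10):
    model(torch.randn(2, 512, dtype=torch.bfloat16, device="cuda"))

# Step 3: apply AWQ-scaled int4 quantization (step name assumed; see AWQStep)
quantize_(model, AWQConfig(base_config, step="convert"))
```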