Commit 03c4c4a

Support using Int4PreshuffledTensor after loading (#26066)
Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
1 parent 2ec401b commit 03c4c4a
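
In short: after vLLM finishes loading (or online-quantizing) a torchao int4 model, each quantized linear weight is repacked for the hardware it will run on, so plain-format Int4Tensor weights can take the Int4PreshuffledTensor path. A minimal sketch of the conversion, assuming torchao >= 0.15.0 (which provides the prototype conversion API used below); the helper name repack_for_current_hardware is illustrative, not part of the commit:

import torch
from torchao.prototype.tensor_conversion.api import (
    convert_to_packed_tensor_based_on_current_hardware,
)


def repack_for_current_hardware(weight: torch.nn.Parameter) -> torch.nn.Parameter:
    # On SM90+ GPUs with fbgemm_gpu_genai installed, a plain Int4Tensor is
    # converted to Int4PreshuffledTensor; on other hardware the tensor keeps
    # its existing packing format.
    return torch.nn.Parameter(
        convert_to_packed_tensor_based_on_current_hardware(weight),
        requires_grad=weight.requires_grad,
    )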

File tree

2 files changed: +208 -4 lines

tests/quantization/test_torchao.py

Lines changed: 144 additions & 2 deletions
@@ -99,7 +99,7 @@ def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner):
 
 
 @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
-def test_on_the_fly_quant_config_dict_json(vllm_runner):
+def test_online_quant_config_dict_json(vllm_runner):
     """Testing on the fly quantization, load_weights integration point,
     with config dict serialized to json string
     """
@@ -133,7 +133,7 @@ def test_on_the_fly_quant_config_dict_json(vllm_runner):
 
 
 @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
-def test_on_the_fly_quant_config_file(vllm_runner):
+def test_online_quant_config_file(vllm_runner):
     """Testing on the fly quantization, load_weights integration point,
     with config file
     """
@@ -252,6 +252,148 @@ def test_opt_125m_module_fqn_to_config_regex_model(vllm_runner):
     ) as llm:
         output = llm.generate_greedy(["The capital of France is"], max_tokens=4)
 
+        assert output
+
+
+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+@pytest.mark.skip(
+    reason="since torchao nightly is only compatible with torch nightly "
+    "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
+    "torchao tests that require newer versions (0.14.0.dev+) for now"
+)
+def test_opt_125m_int4wo_model_running_preshuffled_kernel(vllm_runner, monkeypatch):
+    """We load a model with Int4Tensor (plain format) linear weights
+    and verify that the weight is updated to Int4PreshuffledTensor
+    after loading in vLLM.
+    """
+    from torchao.quantization import Int4PreshuffledTensor
+    from torchao.utils import _is_fbgemm_gpu_genai_available, is_sm_at_least_90
+
+    torch._dynamo.reset()
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+    model_name = "torchao-testing/opt-125m-Int4WeightOnlyConfig-v2-0.14.0.dev"
+    # Note: using enforce_eager=True because `bf16i4bf16_shuffled` doesn't
+    # have a meta kernel implemented yet; this flag can be removed once it is
+    with vllm_runner(
+        model_name=model_name,
+        quantization="torchao",
+        dtype="bfloat16",
+        pt_load_map_location="cuda:0",
+        enforce_eager=True,
+    ) as llm:
+
+        def has_int4_preshuffled_tensor_weight(model):
+            return isinstance(
+                model.model.decoder.layers[0].self_attn.qkv_proj.weight,
+                Int4PreshuffledTensor,
+            )
+
+        def get_weight_attrs(model):
+            weight = model.model.decoder.layers[0].self_attn.qkv_proj.weight
+            return [
+                weight.requires_grad,
+                weight.input_dim,
+                weight.output_dim,
+                hasattr(weight, "weight_loader"),
+            ]
+
+        llm_engine = llm.get_llm().llm_engine
+        has_int4_preshuffled_tensor = any(
+            llm_engine.apply_model(has_int4_preshuffled_tensor_weight)
+        )
+        weight_attrs = llm_engine.apply_model(get_weight_attrs)[0]
+
+        # make sure we are using Int4PreshuffledTensor on H100 GPUs when
+        # the fbgemm_gpu_genai library is installed; otherwise it should
+        # be using Int4Tensor
+        if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90():
+            assert has_int4_preshuffled_tensor
+        else:
+            assert not has_int4_preshuffled_tensor
+
+        assert weight_attrs == [False, 1, 0, True]
+        output = llm.generate_greedy(["The capital of France is"], max_tokens=32)
+
+        assert output
+
+
+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+@pytest.mark.skip(
+    reason="since torchao nightly is only compatible with torch nightly "
+    "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
+    "torchao tests that require newer versions (0.14.0.dev+) for now"
+)
+def test_opt_125m_int4wo_model_running_preshuffled_kernel_online_quant(
+    vllm_runner, monkeypatch
+):
+    """We load a bf16 model and online-quantize it to int4, then verify that
+    the weights are updated to Int4PreshuffledTensor after online quantization.
+    """
+    from torchao.quantization import Int4PreshuffledTensor
+    from torchao.utils import _is_fbgemm_gpu_genai_available, is_sm_at_least_90
+
+    torch._dynamo.reset()
+    model_name = "facebook/opt-125m"
+
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+
+    import json
+
+    from torchao.core.config import config_to_dict
+    from torchao.quantization import Int4WeightOnlyConfig
+
+    torchao_quant_config = Int4WeightOnlyConfig(
+        group_size=128, int4_packing_format="plain"
+    )
+    hf_overrides = {
+        "quantization_config_dict_json": json.dumps(
+            config_to_dict(torchao_quant_config)
+        )
+    }
+
+    # Note: using enforce_eager=True because `bf16i4bf16_shuffled` doesn't
+    # have a meta kernel implemented yet; this flag can be removed once it is
+    with vllm_runner(
+        model_name=model_name,
+        quantization="torchao",
+        dtype="bfloat16",
+        pt_load_map_location="cuda:0",
+        hf_overrides=hf_overrides,
+        enforce_eager=True,
+    ) as llm:
+
+        def has_int4_preshuffled_tensor_weight(model):
+            return isinstance(
+                model.model.decoder.layers[0].self_attn.qkv_proj.weight,
+                Int4PreshuffledTensor,
+            )
+
+        def get_weight_attrs(model):
+            weight = model.model.decoder.layers[0].self_attn.qkv_proj.weight
+            return [
+                weight.requires_grad,
+                weight.input_dim,
+                weight.output_dim,
+                hasattr(weight, "weight_loader"),
+            ]
+
+        llm_engine = llm.get_llm().llm_engine
+        has_int4_preshuffled_tensor = any(
+            llm_engine.apply_model(has_int4_preshuffled_tensor_weight)
+        )
+        weight_attrs = llm_engine.apply_model(get_weight_attrs)[0]
+
+        # make sure we are using Int4PreshuffledTensor on H100 GPUs when
+        # the fbgemm_gpu_genai library is installed; otherwise it should
+        # be using Int4Tensor
+        if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90():
+            assert has_int4_preshuffled_tensor
+        else:
+            assert not has_int4_preshuffled_tensor
+
+        assert weight_attrs == [False, 1, 0, True]
+        output = llm.generate_greedy(["The capital of France is"], max_tokens=32)
+
         assert output
 
 
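Both new tests drive the same checks through `llm_engine.apply_model`: the `qkv_proj` weight should be an `Int4PreshuffledTensor` only when `fbgemm_gpu_genai` is available on an SM90+ GPU (otherwise it stays a plain `Int4Tensor`), and the attributes vLLM attaches to the weight (`requires_grad=False`, `input_dim=1`, `output_dim=0`, and a `weight_loader`) must survive the conversion.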

vllm/model_executor/layers/quantization/torchao.py

Lines changed: 64 additions & 2 deletions
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import importlib
 import json
+import types
 from importlib.util import find_spec
 from typing import Any, Optional
 
@@ -27,6 +28,39 @@
 logger = init_logger(__name__)
 
 
+def _bond_method_to_cls(func, obj):
+    if hasattr(func, "__self__") or not callable(func):
+        # If the function is already bound to an instance, return it as is
+        return func
+    else:
+        return types.MethodType(func, obj)
+
+
+def _get_weight_attrs(param):
+    # record attributes attached to the weight, so we can
+    # recover them later
+    recorded_weight_attr = {}
+    for key in param.__dict__:
+        if hasattr(param, key):
+            attr = getattr(param, key)
+            if not callable(attr):
+                recorded_weight_attr[key] = attr
+            elif hasattr(attr, "__self__") and param is attr.__self__:
+                # if attr is a bound method of an instance and
+                # attr.__self__ points to that instance (param),
+                # record the underlying function object
+                recorded_weight_attr[key] = attr.__func__
+            else:
+                recorded_weight_attr[key] = attr
+    return recorded_weight_attr
+
+
+def _restore_weight_attrs(param, recorded_weight_attr):
+    for attr_name, attr in recorded_weight_attr.items():
+        if not hasattr(param, attr_name):
+            setattr(param, attr_name, _bond_method_to_cls(attr, param))
+
+
 def torchao_version_at_least(torchao_version: str) -> bool:
     if find_spec("torchao"):
         try:
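
A usage sketch of the record/restore helpers above, assuming they are in scope (the `input_dim`/`output_dim` attributes mimic the metadata vLLM attaches to weights via `set_weight_attrs`):

import torch

param = torch.nn.Parameter(torch.empty(4, 4), requires_grad=False)
param.input_dim = 1   # metadata vLLM attaches to weight parameters
param.output_dim = 0

recorded = _get_weight_attrs(param)  # capture the attached attributes
new_param = torch.nn.Parameter(param.detach().clone(), requires_grad=False)
_restore_weight_attrs(new_param, recorded)  # re-attach them to the new param
assert new_param.input_dim == 1 and new_param.output_dim == 0

Bound methods such as `weight_loader` are recorded as their underlying functions and re-bound to the new parameter by `_bond_method_to_cls`.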
@@ -57,6 +91,14 @@ def should_skip(prefix: str, skip_modules: list[str]) -> bool:
     return False
 
 
+if torchao_version_at_least("0.15.0"):
+    from torchao.prototype.tensor_conversion.api import (
+        convert_to_packed_tensor_based_on_current_hardware,
+    )
+else:
+    convert_to_packed_tensor_based_on_current_hardware = lambda t: t
+
+
 class TorchAOConfig(QuantizationConfig):
     """Config class for torchao."""
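
On torchao versions older than 0.15.0 the conversion helper falls back to an identity function (`lambda t: t`), so weights keep whatever packing format they were loaded with and no preshuffled conversion is attempted.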

@@ -307,12 +349,32 @@ def apply(
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if self.quant_config.is_checkpoint_torchao_serialized:
+            if not hasattr(layer, "weight"):
+                return
+
+            # record attributes attached to the weight, so we can
+            # recover them later
+            recorded_weight_attr = _get_weight_attrs(layer.weight)
+
+            layer.weight = Parameter(
+                convert_to_packed_tensor_based_on_current_hardware(layer.weight),
+                requires_grad=layer.weight.requires_grad,
+            )
+
+            _restore_weight_attrs(layer.weight, recorded_weight_attr)
             return
 
-        # quantize the weight on the fly if the checkpoint is not already
+        # online quantize the weight if the checkpoint is not already
         # quantized by torchao
+        recorded_weight_attr = _get_weight_attrs(layer.weight)
+
         weight = torchao_quantize_param_data(
             layer.weight, self.quant_config.torchao_config
         )
-        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
+        weight = torch.nn.Parameter(
+            convert_to_packed_tensor_based_on_current_hardware(weight),
+            weight.requires_grad,
+        )
+
+        _restore_weight_attrs(weight, recorded_weight_attr)
         layer.register_parameter("weight", weight)
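
Net effect: for torchao-serialized checkpoints the loaded weight is repacked in place, while in the online-quantization path the weight is quantized first and then repacked. In both paths the attributes previously attached to the parameter are recorded via `_get_weight_attrs` and restored afterwards, which replaces the earlier hard-coded `set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})` call.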
