Skip to content

Commit 29210b3

Browse files
committed
Add support for loading a module quantized with ModuleFqnToConfig using regex
Summary: As titled — we are adding regex support to simplify the config, and enabling that support in both transformers and vLLM to make sure regex configs work everywhere. torchao PR that adds the functionality to the quantize_ API: pytorch/ao#3084. Transformers PR: (link pending). Test Plan: We save the model with the regex config in transformers; in vLLM we just make sure we can load the model: pytest tests/quantization/test_torchao.py test_opt_125m_module_fqn_to_config_regex_model_loading_with_params Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
1 parent 1b86bd8 commit 29210b3

File tree

2 files changed

+38
-3
lines changed

2 files changed

+38
-3
lines changed

tests/quantization/test_torchao.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,5 +216,23 @@ def test_reload_weights():
216216
# print("-" * 60)
217217

218218

219+
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@pytest.mark.skip(
    # NOTE: adjacent string literals are concatenated — each piece must end
    # with a space or the reason text runs together ("nightlycurrently").
    reason="since torchao nightly is only compatible with torch nightly "
    "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
    "torchao tests that require newer versions (0.14.0.dev+) for now")
def test_opt_125m_module_fqn_to_config_regex_model(vllm_runner):
    """Smoke-test loading an OPT-125M checkpoint quantized with a torchao
    ModuleFqnToConfig that uses `re:`-prefixed regex keys.

    Only verifies that the model loads and greedy generation produces a
    non-empty result; output content is not checked.
    """
    # Clear any compilation state left behind by earlier tests.
    torch._dynamo.reset()
    model_name = ("torchao-testing/opt-125m-ModuleFqnToConfig-v1-regex"
                  "-0.14.0.dev")
    with vllm_runner(model_name=model_name,
                     dtype="bfloat16",
                     pt_load_map_location="cuda:0") as llm:
        output = llm.generate_greedy(["The capital of France is"],
                                     max_tokens=32)

    assert output
235+
236+
219237
if __name__ == "__main__":
220238
pytest.main([__file__])

vllm/model_executor/layers/quantization/torchao.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
from typing import Any, Optional
55

6+
import regex as re
67
import torch
78
import torch.nn.functional as F
89
from torch.nn.parameter import Parameter
@@ -177,9 +178,25 @@ def get_quant_method(
177178
module_fqn = prefix
178179
if isinstance(self.torchao_config, ModuleFqnToConfig):
179180
module_fqn_to_config = self.torchao_config.module_fqn_to_config
180-
c = module_fqn_to_config.get(module_fqn) or module_fqn_to_config.get(
181-
"_default", None
182-
)
181+
c = None
182+
if module_fqn in module_fqn_to_config:
183+
assert not module_fqn.startswith("re:"), "module fqn should not start with" \
184+
"`re:`, which is used for specifying regex"
185+
c = module_fqn_to_config[module_fqn]
186+
else:
187+
for maybe_module_fqn_pattern in module_fqn_to_config:
188+
if not maybe_module_fqn_pattern.startswith("re:"):
189+
continue
190+
elif re.fullmatch(maybe_module_fqn_pattern[3:],
191+
module_fqn):
192+
# we'll apply the config for first fully matched pattern
193+
c = module_fqn_to_config[maybe_module_fqn_pattern]
194+
break
195+
else:
196+
# fallback to use default if no module specific
197+
# config is provided
198+
c = module_fqn_to_config.get("_default", None)
199+
183200
if c is not None:
184201
current_torchao_config = TorchAOConfig(
185202
c, self.skip_modules, self.is_checkpoint_torchao_serialized

0 commit comments

Comments
 (0)