
Commit 8f77f17

Enhance woq model loading & support hf woq model loading
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
1 parent 855b988 commit 8f77f17

File tree

8 files changed (+529, -31 lines)

neural_compressor/torch/algorithms/weight_only/save_load.py

Lines changed: 469 additions & 7 deletions
Large diffs are not rendered by default.
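Since the 469 added lines of save_load.py are not rendered here, the following is only a hypothetical sketch of the format dispatch that the new call sites in load_entry.py (below) imply: a `format` argument, an optional original `model`, and pass-through huggingface arguments. The helper names `_restore_inc_woq_checkpoint` and `_build_hf_woq_model` are invented placeholders, not the functions actually added by this commit.

# Hypothetical sketch only; the real logic lives in the unrendered save_load.py diff.
def _restore_inc_woq_checkpoint(model_name_or_path, model):
    # Placeholder: restore INC-saved WOQ weights into the caller's float model.
    ...

def _build_hf_woq_model(model_name_or_path, *args, **kwargs):
    # Placeholder: rebuild a quantized causal LM from a huggingface checkpoint.
    ...

def load(model_name_or_path, model=None, format="default", *hf_model_args, **hf_model_kwargs):
    if format == "huggingface":
        return _build_hf_woq_model(model_name_or_path, *hf_model_args, **hf_model_kwargs)
    return _restore_inc_woq_checkpoint(model_name_or_path, model)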

neural_compressor/torch/quantization/load_entry.py

Lines changed: 38 additions & 19 deletions
@@ -31,28 +31,47 @@
 }


-def load(output_dir="./saved_results", model=None):
-    from neural_compressor.common.base_config import ConfigRegistry
+def load(model_name_or_path="./saved_results", model=None, format="default", *hf_model_args, **hf_model_kwargs):
+    """Load quantized model.
+
+    Args:
+        model_name_or_path (str, optional): local path where the quantized weights or model are saved,
+            or a huggingface model id. Defaults to "./saved_results".
+        model (torch.nn.Module, optional): original model. Required when loading an INC WOQ quantized
+            model or an FP8 model. Defaults to None.
+        format (str, optional): 'default' for loading an INC quantized model,
+            'huggingface' currently only for loading a huggingface WOQ causal language model.
+            Defaults to "default".
+
+    Returns:
+        torch.nn.Module: quantized model
+    """
+    if format == "default":
+        from neural_compressor.common.base_config import ConfigRegistry
+        from neural_compressor.torch.algorithms.static_quant import load as static_quant_load
+        from neural_compressor.torch.algorithms.weight_only.save_load import load as woq_load
+        from neural_compressor.torch.algorithms.habana_fp8 import load as habana_fp8_load

-    qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), "qconfig.json")
-    with open(qconfig_file_path, "r") as f:
-        per_op_qconfig = json.load(f)
+        qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(model_name_or_path)), "qconfig.json")
+        with open(qconfig_file_path, "r") as f:
+            per_op_qconfig = json.load(f)

-    if " " in per_op_qconfig.keys():  # ipex qconfig format: {' ': {'q_op_infos': {'0': {'op_type': ...
-        from neural_compressor.torch.algorithms.static_quant import load
-
-        return load(output_dir)
-    else:
-        config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"])
-        # select load function
-        config_object = config_mapping[next(iter(config_mapping))]
-        if isinstance(config_object, (RTNConfig, GPTQConfig, AWQConfig, TEQConfig, AutoRoundConfig)):  # WOQ
-            from neural_compressor.torch.algorithms.weight_only.save_load import load
-
-            return load(output_dir)
-
-        model.qconfig = config_mapping
-        if isinstance(config_object, FP8Config):  # FP8
-            from neural_compressor.torch.algorithms.habana_fp8 import load
-
-            return load(model, output_dir)  # pylint: disable=E1121
+        if " " in per_op_qconfig.keys():  # ipex qconfig format: {' ': {'q_op_infos': {'0': {'op_type': ...
+            return static_quant_load(model_name_or_path)
+        else:
+            config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"])
+            # select load function
+            config_object = config_mapping[next(iter(config_mapping))]
+
+            if isinstance(config_object, (RTNConfig, GPTQConfig, AWQConfig, TEQConfig, AutoRoundConfig)):  # WOQ
+                return woq_load(model_name_or_path, model=model, format=format)
+
+            model.qconfig = config_mapping
+            if isinstance(config_object, FP8Config):  # FP8
+                return habana_fp8_load(model, model_name_or_path)
+    elif format == "huggingface":
+        # currently only supports loading a huggingface WOQ causal language model
+        from neural_compressor.torch.algorithms.weight_only.save_load import load as woq_load
+
+        return woq_load(model_name_or_path, format=format, *hf_model_args, **hf_model_kwargs)
+    else:
+        raise ValueError("`format` in load function can only be 'huggingface' or 'default', but got {}".format(format))
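Put together, the new entry point can be exercised in both modes roughly as follows; `float_model` is a placeholder for the original unquantized module, and the hub model id is the one used in the new test below.

import copy
from neural_compressor.torch.quantization import load

# Default format: reload an INC WOQ checkpoint saved under "saved_results";
# the original float model must now be passed in (see the updated tests below).
# float_model is a placeholder for the original (unquantized) torch.nn.Module.
loaded_model = load("saved_results", model=copy.deepcopy(float_model))

# Huggingface format: load a WOQ causal language model straight from the hub.
qmodel = load("TheBloke/TinyLlama-1.1B-python-v0.1-GPTQ", format="huggingface")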

test/3x/torch/quantization/weight_only/test_autoround.py

Lines changed: 1 addition & 1 deletion
@@ -148,7 +148,7 @@ def test_save_and_load(self):
         from neural_compressor.torch.quantization import load

         # loading compressed model
-        loaded_model = load("saved_results")
+        loaded_model = load("saved_results", model=copy.deepcopy(self.gptj))
         loaded_out = loaded_model(self.inp)[0]
         assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check."

test/3x/torch/quantization/weight_only/test_awq.py

Lines changed: 1 addition & 1 deletion
@@ -131,7 +131,7 @@ def calib_func(model):
         from neural_compressor.torch.quantization import load

         # loading compressed model
-        loaded_model = load("saved_results")
+        loaded_model = load("saved_results", model=copy.deepcopy(self.tiny_gptj))
         loaded_out = loaded_model(self.example_inputs)[0]
         assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check."
         assert isinstance(loaded_model.lm_head, WeightOnlyLinear), "loading compressed model failed."

test/3x/torch/quantization/weight_only/test_gptq.py

Lines changed: 1 addition & 1 deletion
@@ -254,7 +254,7 @@ def test_save_and_load(self):
         from neural_compressor.torch.quantization import load

         # loading compressed model
-        loaded_model = load("saved_results")
+        loaded_model = load("saved_results", model=copy.deepcopy(self.tiny_gptj))
         loaded_out = loaded_model(self.example_inputs)[0]
         assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check."
         assert isinstance(
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+import torch
+from transformers import AutoTokenizer
+from neural_compressor.torch.utils import accelerator
+
+device = accelerator.current_device_name()
+
+class TestHFModelLoad:
+    def setup_class(self):
+        self.model_name = "TheBloke/TinyLlama-1.1B-python-v0.1-GPTQ"
+        self.example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long).to(device)
+
+    def test_load_hf_woq_model(self):
+        from neural_compressor.torch.quantization import load
+
+        qmodel = load(self.model_name, format="huggingface")
+        output = qmodel(self.example_inputs)[0]
+        assert len(output) > 0, "Not loading the model correctly"
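The test only checks that a forward pass produces output. If the loaded object behaves like a regular transformers causal LM (an assumption this commit view does not confirm), a natural follow-up is to decode a short generation with the AutoTokenizer the test already imports:

from transformers import AutoTokenizer
from neural_compressor.torch.quantization import load

model_name = "TheBloke/TinyLlama-1.1B-python-v0.1-GPTQ"
qmodel = load(model_name, format="huggingface")
tokenizer = AutoTokenizer.from_pretrained(model_name)

inputs = tokenizer("def hello_world():", return_tensors="pt")
# generate() is the standard transformers API for causal LMs; whether the
# loaded WOQ model supports it depends on the wrapped model class.
outputs = qmodel.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))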

test/3x/torch/quantization/weight_only/test_rtn.py

Lines changed: 1 addition & 1 deletion
@@ -290,7 +290,7 @@ def test_save_and_load(self):
         from neural_compressor.torch.quantization import load

         # loading compressed model
-        loaded_model = load("saved_results")
+        loaded_model = load("saved_results", model=copy.deepcopy(self.tiny_gptj))
         loaded_out = loaded_model(self.example_inputs)[0]
         assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check."
         assert isinstance(loaded_model.lm_head, WeightOnlyLinear), "loading compressed model failed."

test/3x/torch/quantization/weight_only/test_teq.py

Lines changed: 1 addition & 1 deletion
@@ -141,7 +141,7 @@ def test_save_and_load(self):
         from neural_compressor.torch.quantization import load

         # loading compressed model
-        loaded_model = load("saved_results")
+        loaded_model = load("saved_results", model=copy.deepcopy(self.gptj))
         loaded_out = loaded_model(self.example_inputs)[0]
         assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check."
