Fix GPTQ compatibility with AutoGPTQ #1574

Merged: 2 commits, Dec 7, 2023
2 changes: 1 addition & 1 deletion optimum/gptq/constants.py
@@ -20,4 +20,4 @@
"model.layers",
]

GPTQ_CONFIG = "quantization_config.json"
GPTQ_CONFIG = "quantize_config.json"
32 changes: 18 additions & 14 deletions optimum/gptq/quantizer.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import os
from enum import Enum
@@ -35,7 +34,6 @@

if is_accelerate_available():
from accelerate import (
Accelerator,
cpu_offload_with_hook,
load_checkpoint_and_dispatch,
)
@@ -146,6 +144,17 @@ def __init__(
self.quant_method = QuantizationMethod.GPTQ
self.cache_block_outputs = cache_block_outputs

self.serialization_keys = [
"bits",
"dataset",
"group_size",
"damp_percent",
"desc_act",
"sym",
"true_sequential",
"quant_method",
]

if self.bits not in [2, 3, 4, 8]:
raise ValueError("only support quantize to [2,3,4,8] bits.")
if self.group_size != -1 and self.group_size <= 0:
@@ -169,7 +178,10 @@ def to_dict(self):
"""
Returns the args in dict format.
"""
return copy.deepcopy(self.__dict__)
gptq_dict = {}
for key in self.serialization_keys:
gptq_dict[key] = getattr(self, key)
return gptq_dict

@classmethod
def from_dict(cls, config_dict: Dict[str, Any]):
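For illustration (the argument values below are hypothetical, not taken from the diff): with the serialization_keys whitelist, to_dict() now returns only plain JSON-serializable fields instead of a deep copy of __dict__, so internal attributes such as tracked modules never leak into quantize_config.json. A rough sketch:

# Illustrative arguments; the untouched keys fall back to their usual defaults.
from optimum.gptq import GPTQQuantizer

quantizer = GPTQQuantizer(bits=4, dataset="c4", group_size=128)
print(quantizer.to_dict())
# Roughly: {"bits": 4, "dataset": "c4", "group_size": 128, "damp_percent": 0.1,
#           "desc_act": False, "sym": True, "true_sequential": True,
#           "quant_method": "gptq"}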
@@ -600,7 +612,7 @@ def pack_model(

logger.info("Model packed.")

def save(self, model: nn.Module, save_dir: str, max_shard_size: str = "10GB", safe_serialization: bool = False):
def save(self, model: nn.Module, save_dir: str, max_shard_size: str = "10GB", safe_serialization: bool = True):
"""
Save model state dict and configs

Expand All @@ -618,20 +630,12 @@ def save(self, model: nn.Module, save_dir: str, max_shard_size: str = "10GB", sa
which will be bigger than `max_shard_size`.

</Tip>
safe_serialization (`bool`, defaults to `False`):
safe_serialization (`bool`, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).

"""

if not is_accelerate_available():
raise RuntimeError(
"You need to install accelerate in order to save a quantized model. You can do it with `pip install accelerate`"
)

os.makedirs(save_dir, exist_ok=True)
# save model and config
accelerator = Accelerator()
accelerator.save_model(model, save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
model.save_pretrained(save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
with open(os.path.join(save_dir, GPTQ_CONFIG), "w", encoding="utf-8") as f:
json.dump(self.to_dict(), f, indent=2)

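Usage-wise, a sketch under assumed names (not part of the diff): after quantization, save() now goes through the model's own save_pretrained, with safetensors serialization enabled by default, and writes quantize_config.json next to the weights:

# "opt-125m-gptq" is a hypothetical output directory; quantized_model is a
# transformers model that has already been quantized with this quantizer.
quantizer.save(quantized_model, "opt-125m-gptq")
# The folder should then hold the safetensors weights (sharded at 10GB by
# default), the model's config.json, and quantize_config.json, which both
# load_quantized_model and AutoGPTQ can read back.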
2 changes: 1 addition & 1 deletion setup.py
@@ -38,7 +38,7 @@
"invisible-watermark",
]

QUALITY_REQUIRE = ["black~=23.1", "ruff>=0.0.241,<=0.0.259"]
QUALITY_REQUIRE = ["black~=23.1", "ruff==0.1.5"]

BENCHMARK_REQUIRE = ["optuna", "tqdm", "scikit-learn", "seqeval", "torchvision", "evaluate>=0.2.0"]

24 changes: 24 additions & 0 deletions tests/gptq/test_quantization.py
@@ -23,9 +23,14 @@

from optimum.gptq import GPTQQuantizer, load_quantized_model
from optimum.gptq.data import get_dataset
from optimum.utils.import_utils import is_auto_gptq_available
from optimum.utils.testing_utils import require_accelerate, require_auto_gptq, require_torch_gpu


if is_auto_gptq_available():
from auto_gptq import AutoGPTQForCausalLM


@slow
@require_auto_gptq
@require_torch_gpu
@@ -125,7 +130,9 @@ def check_inference_correctness(self, model):
def test_generate_quality(self):
self.check_inference_correctness(self.quantized_model)

@require_torch_gpu
@require_accelerate
@slow
def test_serialization(self):
"""
Test the serialization of the model and the loading of the quantized weights
@@ -148,6 +155,11 @@ def test_serialization(self):
exllama_config=self.exllama_config,
)
self.check_quantized_layers_type(quantized_model_from_saved, "cuda-old")

with torch.device("cuda"):
_ = AutoModelForCausalLM.from_pretrained(tmpdirname)
_ = AutoGPTQForCausalLM.from_quantized(tmpdirname)

self.check_inference_correctness(quantized_model_from_saved)


@@ -177,6 +189,7 @@ def test_serialization(self):
# act_order don't work with qlinear_cuda kernel
pass

@require_torch_gpu
def test_exllama_serialization(self):
"""
Test the serialization of the model and the loading of the quantized weights with exllama kernel
@@ -195,6 +208,11 @@ def test_exllama_serialization(self):
empty_model, save_folder=tmpdirname, device_map={"": 0}, exllama_config={"version": 1}
)
self.check_quantized_layers_type(quantized_model_from_saved, "exllama")

with torch.device("cuda"):
_ = AutoModelForCausalLM.from_pretrained(tmpdirname)
_ = AutoGPTQForCausalLM.from_quantized(tmpdirname)

self.check_inference_correctness(quantized_model_from_saved)

def test_exllama_max_input_length(self):
@@ -245,6 +263,7 @@ def test_serialization(self):
# don't need to test
pass

@require_torch_gpu
def test_exllama_serialization(self):
"""
Test the serialization of the model and the loading of the quantized weights with exllamav2 kernel
@@ -265,6 +284,11 @@ def test_exllama_serialization(self):
device_map={"": 0},
)
self.check_quantized_layers_type(quantized_model_from_saved, "exllamav2")

with torch.device("cuda"):
_ = AutoModelForCausalLM.from_pretrained(tmpdirname)
_ = AutoGPTQForCausalLM.from_quantized(tmpdirname)

self.check_inference_correctness(quantized_model_from_saved)

