[Core] [!4Bit] fix cannot pickle 'module' object for 8 bit (fix huggingface#47) (huggingface#49)

* fix cannot pickle 'module' object for 8 bit

* remove unused import

* remove print

* check with tuple

* revert to len check

* add test for 8bit

* set same QuantizeConfig

* check if it's 4 bit

* fix grammar

* remove params

* it's not a list

* set gptqmodel_cuda back

* check is tuple

* format

* set desc_act=True

* set desc_act=True

* format

* format

* Refactor fix

* desc_act=True

---------

Co-authored-by: Qubitium <Qubitium@modelcloud.ai>
CSY-ModelCloud and Qubitium authored Jun 24, 2024
1 parent 9a485ba commit f44d984
Showing 5 changed files with 76 additions and 32 deletions.
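Background for the change: copy.deepcopy pickles object state, and Python module objects cannot be pickled. In the 8-bit path each CUDA quant-linear layer keeps the compiled extension as a plain module attribute (gptqmodel_cuda), so the deepcopy inside save_quantized raised TypeError: cannot pickle 'module' object. A minimal reproduction of the failure mode, independent of this repo (class and attribute names are illustrative):

import copy
import types

class FakeQuantLinear:
    # stand-in for a quant-linear layer that stores a compiled CUDA extension
    def __init__(self):
        self.gptqmodel_cuda = types.ModuleType("fake_cuda_ext")  # a module object

try:
    copy.deepcopy(FakeQuantLinear())
except TypeError as e:
    print(e)  # -> cannot pickle 'module' object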
41 changes: 28 additions & 13 deletions gptqmodel/models/base.py
@@ -103,9 +103,9 @@ def hf_device_map(self):
return getattr(self.model, "hf_device_map", None)

def _prepare_dataset_for_quantization(
self,
calibration_dataset: List[Dict[str, Union[List[int], torch.LongTensor]]],
batch_size: int = 1,
self,
calibration_dataset: List[Dict[str, Union[List[int], torch.LongTensor]]],
batch_size: int = 1,
):
def _convert_tensor_to_list(tensor):
if isinstance(tensor, torch.Tensor):
@@ -139,7 +139,7 @@ def _convert_tensor_to_list(tensor):
pad_token_id = self.config.eos_token_id

new_calibration_dataset = [
collate_data(new_calibration_dataset[start : start + batch_size], pad_token_id)
collate_data(new_calibration_dataset[start: start + batch_size], pad_token_id)
for start in range(0, len(new_calibration_dataset), batch_size)
]
for new_example in new_calibration_dataset:
@@ -184,7 +184,7 @@ def quantize(

if len(calibration_dataset) < MIN_CALIBRATION_DATASET_SIZE:
logger.warning(f"Calibration dataset size should be greater than {MIN_CALIBRATION_DATASET_SIZE}. "
f"Current size: {len(calibration_dataset)}.")
f"Current size: {len(calibration_dataset)}.")

# Calculate the average length of the average input_ids
total_input_ids_length = 0
@@ -195,7 +195,7 @@

if avg < MIN_CALIBRATION_DATASET_INPUT_IDS_AVG_LENGTH:
logger.warning(f"The average length of input_ids of calibration_dataset should be greater than "
f"{MIN_CALIBRATION_DATASET_INPUT_IDS_AVG_LENGTH}! Current AVG is {avg}.")
f"{MIN_CALIBRATION_DATASET_INPUT_IDS_AVG_LENGTH}! Current AVG is {avg}.")

device_map = self.hf_device_map
if device_map:
@@ -240,10 +240,7 @@ def store_input_hook(_, args, kwargs):
if pos_ids is not None:
position_ids.append(move_to(pos_ids, data_device))
one_kwargs = {}
for (
k,
v,
) in kwargs.items(): # make sure other arguments also be captured
for (k, v) in kwargs.items(): # make sure other arguments also be captured
if k not in ["hidden_states", "attention_mask", "position_ids"]:
one_kwargs[k] = nested_move_to(v, data_device)
layer_input_kwargs.append(one_kwargs)
@@ -498,8 +495,8 @@ def save_quantized(

if model_base_name is None:
model_base_name = (
self.quantize_config.model_file_base_name or
f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
self.quantize_config.model_file_base_name or
f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
)

state_dict = self.model.state_dict()
@@ -543,7 +540,25 @@ def save_quantized(
if format is None and quantize_config.format == FORMAT.GPTQ:
# Model qzeros may be edited in place.
# TODO: avoid inplace modification of the weights
model = copy.deepcopy(self.model)
# fix ModelCloud/GPTQModel/issues/47
# fix gptqmodel_cuda cannot be serialized
# no need to set it back, no calculation below
if quantize_config.bits != 4:
cuda_name_modules = {}
from gptqmodel.nn_modules.qlinear.qlinear_cuda import BaseCudaQuantLinear
for name, module in self.model.named_modules():
if isinstance(module, BaseCudaQuantLinear):
cuda_name_modules[name] = module.gptqmodel_cuda
module.gptqmodel_cuda = None
model = copy.deepcopy(self.model)

for name, module in model.named_modules():
    if isinstance(module, BaseCudaQuantLinear) and name in cuda_name_modules:
        module.gptqmodel_cuda = cuda_name_modules[name]

del cuda_name_modules
else:
model = copy.deepcopy(self.model)
model = convert_gptq_v2_to_v1_format(
model, quantize_config=quantize_config, qlinear_kernel=self.qlinear_kernel
)
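The fix above follows a detach / deepcopy / restore pattern: stash the unpicklable gptqmodel_cuda handle for each CUDA quant-linear module, null it out, deepcopy the model, then re-attach the saved handles on the copy. The original model is deliberately left detached, per the in-code comment, because no further kernel calls follow in this path. A generic sketch of the same pattern, assuming a torch model; the helper name is hypothetical:

import copy

from torch import nn

def deepcopy_detaching(model: nn.Module, attr: str = "gptqmodel_cuda") -> nn.Module:
    # stash and detach the unpicklable handle so deepcopy (which pickles) succeeds
    saved = {}
    for name, module in model.named_modules():
        if getattr(module, attr, None) is not None:
            saved[name] = getattr(module, attr)
            setattr(module, attr, None)
    clone = copy.deepcopy(model)
    # re-attach the shared extension handle on the copy; aliasing the original
    # extension is assumed safe here since the handle itself holds no tensor state
    for name, module in clone.named_modules():
        if name in saved:
            setattr(module, attr, saved[name])
    return clone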
5 changes: 5 additions & 0 deletions gptqmodel/nn_modules/qlinear/__init__.py
@@ -4,3 +4,8 @@
class BaseQuantLinear(nn.Module):
# override me
QUANT_TYPE = "base"


class BaseCudaQuantLinear(BaseQuantLinear):
# override me
QUANT_TYPE = "base-cuda"
4 changes: 2 additions & 2 deletions gptqmodel/nn_modules/qlinear/qlinear_cuda.py
@@ -7,12 +7,12 @@
import torch
import torch.nn as nn
import transformers
from gptqmodel.nn_modules.qlinear import BaseQuantLinear
from gptqmodel.nn_modules.qlinear import BaseCudaQuantLinear

logger = getLogger(__name__)


class QuantLinear(BaseQuantLinear):
class QuantLinear(BaseCudaQuantLinear):
QUANT_TYPE = "cuda"

def __init__(
4 changes: 2 additions & 2 deletions gptqmodel/nn_modules/qlinear/qlinear_cuda_old.py
@@ -7,12 +7,12 @@
import torch
import torch.nn as nn
import transformers
from gptqmodel.nn_modules.qlinear import BaseQuantLinear
from gptqmodel.nn_modules.qlinear import BaseCudaQuantLinear

logger = getLogger(__name__)


class QuantLinear(BaseQuantLinear):
class QuantLinear(BaseCudaQuantLinear):
QUANT_TYPE = "cuda-old"

def __init__(
54 changes: 39 additions & 15 deletions tests/test_quant_formats.py
@@ -13,6 +13,18 @@


class TestQuantization(unittest.TestCase):

def setUp(self):
self.pretrained_model_dir = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"

self.tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model_dir, use_fast=True)
self.calibration_dataset = [
self.tokenizer(
"auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
),
self.tokenizer("Today I am in Paris and it is a wonderful day."),
]

@parameterized.expand(
[
(False, True, FORMAT.GPTQ_V2),
@@ -21,16 +33,6 @@ class TestQuantization(unittest.TestCase):
]
)
def test_quantize(self, use_marlin: bool, sym: bool, format: FORMAT):
pretrained_model_dir = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
calibration_dataset = [
tokenizer(
"auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
),
tokenizer("Today I am in Paris and it is a wonderful day."),
]

quantize_config = QuantizeConfig(
bits=4,
group_size=128,
@@ -40,17 +42,15 @@ def test_quantize(self, use_marlin: bool, sym: bool, format: FORMAT):
)

model = GPTQModel.from_pretrained(
pretrained_model_dir,
self.pretrained_model_dir,
quantize_config=quantize_config,
use_flash_attention_2=False,
)

model.quantize(calibration_dataset)
model.quantize(self.calibration_dataset)

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(
tmpdirname,
)
model.save_quantized(tmpdirname)

logging.info(f"Saved config mem: {model.quantize_config}")

@@ -117,3 +117,27 @@ def test_quantize(self, use_marlin: bool, sym: bool, format: FORMAT):
format=format,
)
assert isinstance(model.quantize_config, QuantizeConfig)

def test_gptq_8bit(self):
quantize_config = QuantizeConfig(
bits=8,
group_size=128,
format=FORMAT.GPTQ,
desc_act=True
)

model = GPTQModel.from_pretrained(
self.pretrained_model_dir,
quantize_config=quantize_config,
)

model.quantize(self.calibration_dataset)

with tempfile.TemporaryDirectory() as tmpdirname:
err = None
try:
model.save_quantized(tmpdirname)
except Exception as e:
print(e)
err = e
self.assertTrue(err is None)
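The new regression test can be run on its own with pytest, which collects unittest-style classes, e.g.:

pytest tests/test_quant_formats.py::TestQuantization::test_gptq_8bit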
