Remove is_gpt_fast flag #172

Merged 3 commits on Apr 24, 2024
51 changes: 0 additions & 51 deletions test/quantization/test_quant_api.py
@@ -268,57 +268,6 @@ def test_8da4w_quantizer_eval(self):
f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}"
)

@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_gptq_quantizer_gpt_fast(self):
@msaroufim (Member) commented on Apr 24, 2024:

Is this test not useful to keep around, or is there a new version of it? Or, more generally, can we no longer use gpt-fast and ao together?

@jerryzh168 (Contributor, Author) commented on Apr 24, 2024:

The next test, test_gptq_quantizer_int4wo, covers the gpt-fast code path. This test was originally added when we were trying to merge the gpt-fast code path and the 8da4w code path into the same quantizer and distinguish them with a flag, but the quantizer code has since been duplicated instead.

Since the gpt-fast code path has been removed from Int8DynActInt4WeightGPTQQuantizer, we no longer need to test it here.

-        from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer, InputRecorder
-        # should be similar to TorchCompileDynamicQuantizer
-        precision = torch.bfloat16
-        device = "cuda"
-        checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
-        model = Transformer.from_name(checkpoint_path.parent.name)
-        checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
-        model.load_state_dict(checkpoint, assign=True)
-        model = model.to(dtype=precision, device=device)
-        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
-        assert tokenizer_path.is_file(), tokenizer_path
-        tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
-            model_file=str(tokenizer_path)
-        )
-        blocksize = 128
-        percdamp = 0.01
-        groupsize = 128
-        calibration_tasks = ["wikitext"]
-        calibration_limit = 1
-        calibration_seq_length = 100
-        input_prep_func = prepare_inputs_for_model
-        pad_calibration_inputs = False
-
-        inputs = InputRecorder(
-            tokenizer,
-            calibration_seq_length,
-            input_prep_func,
-            pad_calibration_inputs,
-            model.config.vocab_size,
-        ).record_inputs(
-            calibration_tasks,
-            calibration_limit,
-        ).get_inputs()
-
-        quantizer = Int8DynActInt4WeightGPTQQuantizer(
-            blocksize,
-            percdamp,
-            groupsize,
-            _is_gpt_fast=True,
-            _use_cuda=True,
-        )
-
-        model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
-
-        model = quantizer.quantize(model, inputs)
-        compiled = torch.compile(model, mode="max-autotune")
-        with torch.no_grad():
-            compiled(inputs[0].values[0], inputs[1].values[0])
-
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
     def test_gptq_quantizer_int4wo(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer, InputRecorder, TransformerEvalWrapper
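
The discussion above amounts to an API split rather than lost functionality. As a rough sketch (not code from this PR; the model and calibration inputs are assumed to be prepared exactly as in the deleted test, and keyword names or defaults in torchao may differ slightly), the two code paths look like this once the flag is gone:

# Illustrative sketch only: values mirror the deleted test above.
from torchao.quantization.GPTQ import (
    Int8DynActInt4WeightGPTQQuantizer,  # 8da4w GPTQ path (no _is_gpt_fast flag anymore)
    Int4WeightOnlyGPTQQuantizer,        # tinygemm / gpt-fast path, covered by test_gptq_quantizer_int4wo
)

blocksize, percdamp, groupsize = 128, 0.01, 128  # values from the deleted test

# Before this PR, the gpt-fast path was selected with a private flag:
#     Int8DynActInt4WeightGPTQQuantizer(blocksize, percdamp, groupsize,
#                                       _is_gpt_fast=True, _use_cuda=True)

# After this PR, each path has its own quantizer class:
quantizer_8da4w = Int8DynActInt4WeightGPTQQuantizer(blocksize, percdamp, groupsize)
quantizer_int4wo = Int4WeightOnlyGPTQQuantizer(blocksize, percdamp, groupsize)

# `model` and `inputs` would be built as in the deleted test
# (gpt-fast Transformer checkpoint + InputRecorder over wikitext):
#     model = quantizer_8da4w.quantize(model, inputs)
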
92 changes: 19 additions & 73 deletions torchao/quantization/GPTQ.py
@@ -1176,17 +1176,8 @@ def __init__(
         padding_allowed: bool = False,
         precision: torch.dtype = torch.float32,
         scales_precision: torch.dtype = torch.float32,
-        inner_k_tiles: Optional[int] = None,
-        _is_gpt_fast: bool = False,
     ) -> None:
         super().__init__()
-        if _is_gpt_fast:
-            assert inner_k_tiles in [2, 4, 8]
-            assert groupsize in [32, 64, 128, 256]
-        else:
-            assert inner_k_tiles is None
-        self._is_gpt_fast = _is_gpt_fast
-        self.inner_k_tiles = inner_k_tiles
         self.groupsize: int = groupsize
         self.padding_allowed: bool = padding_allowed
         self.precision: torch.dtype = precision
@@ -1210,9 +1201,7 @@ def _create_quantized_state_dict(
), f"require in_features:{in_features} % self.groupsize:{self.groupsize} == 0"

weight = mod.weight.data
if not _check_linear_int4_k(
in_features, self.groupsize, self.inner_k_tiles
):
if not _check_linear_int4_k(in_features, self.groupsize):
if self.padding_allowed:
from .utils import find_multiple
import torch.nn.functional as F
@@ -1233,36 +1222,21 @@ def _create_quantized_state_dict(
                     self.groupsize,
                     self.scales_precision,
                 )
-                if self._is_gpt_fast:
-                    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int8.to(torch.int32), self.inner_k_tiles)
-                    scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
-                    cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
-                    cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")
-                else:
-                    cur_state_dict[f"{fqn}.weight"] = weight_int8.to("cpu")
-                    cur_state_dict[f"{fqn}.scales"] = scales.to("cpu")
-                    cur_state_dict[f"{fqn}.zeros"] = zeros.to("cpu")
+                cur_state_dict[f"{fqn}.weight"] = weight_int8.to("cpu")
+                cur_state_dict[f"{fqn}.scales"] = scales.to("cpu")
+                cur_state_dict[f"{fqn}.zeros"] = zeros.to("cpu")
                 # TODO: support bias?
 
         return cur_state_dict
 
     def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
-        if self._is_gpt_fast:
-            # TODO: temporary path for gpt-fast, will remove later
-            replace_linear_int4(
-                model,
-                self.groupsize,
-                self.inner_k_tiles,
-                self.padding_allowed,
-            )
-        else:
-            replace_linear_8da4w(
-                model,
-                self.groupsize,
-                self.padding_allowed,
-                self.precision,
-                self.precision,
-            )
+        replace_linear_8da4w(
+            model,
+            self.groupsize,
+            self.padding_allowed,
+            self.precision,
+            self.precision,
+        )
         return model
 
     def quantize(
@@ -1284,9 +1258,7 @@ def __init__(
         inner_k_tiles=8,
         padding_allowed=True,
         precision=torch.float32,
-        _is_gpt_fast=False,
     ):
-        self._is_gpt_fast = _is_gpt_fast
         self.blocksize = blocksize
         self.percdamp = percdamp
         self.groupsize = groupsize
@@ -1327,23 +1299,6 @@ def __init__(
         )
 
         # we need to do the padding here, both for q and the qparams if necessary
-
-        # TODO: this is the gpt-fast version, merge with the main version later
-        def make_names_and_values_dict_func_gpt_fast(q, qparams):
-            k = q.shape[1]
-            new_k = find_multiple(k, 1024)
-            # how much we need to pad the weight
-            delta_k = new_k - q.shape[1]
-            q = q.to(torch.int32)
-            final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles)
-            scales = qparams[0].to(torch.bfloat16)
-            zeros = qparams[1].to(torch.bfloat16)
-            scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
-            # how many new groups we need for padded weight
-            delta_groups = new_k // groupsize - scales_and_zeros.shape[0]
-            final_s_and_z = F.pad(scales_and_zeros, pad=(0,0,0,0,0, delta_groups), value=1)
-            return {"weight": final_q, "scales_and_zeros": final_s_and_z}
-
         def make_names_and_values_dict_func(q, qparams):
             k = q.shape[1]
             new_k = find_multiple(k, 1 if groupsize is None else groupsize)
@@ -1354,26 +1309,17 @@ def make_names_and_values_dict_func(q, qparams):
             zeros = qparams[1].to(self.precision)
             return {"weight": final_q, "scales": scales, "zeros": zeros}
 
-        self.make_names_and_values_dict_func = make_names_and_values_dict_func_gpt_fast if self._is_gpt_fast else make_names_and_values_dict_func
+        self.make_names_and_values_dict_func = make_names_and_values_dict_func
         super().__init__()
 
     def _convert_for_runtime(self, model):
-        if self._is_gpt_fast:
-            # TODO: temporary path for gpt-fast, will remove later
-            replace_linear_int4(
-                model,
-                self.groupsize,
-                self.inner_k_tiles,
-                self.padding_allowed,
-            )
-        else:
-            replace_linear_8da4w(
-                model,
-                self.groupsize,
-                self.padding_allowed,
-                self.precision,
-                self.precision,
-            )
+        replace_linear_8da4w(
+            model,
+            self.groupsize,
+            self.padding_allowed,
+            self.precision,
+            self.precision,
+        )
         return model
 
     def quantize(self, model: torch.nn.Module, inputs: List[MultiInput], **kwargs: Any) -> torch.nn.Module:
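
For readers skimming the deleted gpt-fast helper above: its job was to pad the quantized weight out to a multiple of 1024 columns before tinygemm int4 packing, and to pad scales_and_zeros with enough extra groups (filled with 1) to match. Below is a minimal sketch of just that padding arithmetic, using a made-up weight shape and a locally re-implemented find_multiple (the real helper is imported from torchao's quantization utils in the diff); it is an illustration, not code from this PR.

import torch
import torch.nn.functional as F

def find_multiple(n: int, k: int) -> int:
    # Same behavior the deleted code relied on: round n up to a multiple of k.
    return n if n % k == 0 else n + k - (n % k)

# Hypothetical quantized weight: 32 output rows, 11008 input columns, groupsize 128.
groupsize = 128
q = torch.zeros(32, 11008, dtype=torch.int32)

k = q.shape[1]                   # 11008
new_k = find_multiple(k, 1024)   # 11264, the next multiple of 1024
delta_k = new_k - k              # 256 zero columns appended to the weight
padded_q = F.pad(q, pad=(0, delta_k))

# scales_and_zeros is laid out per group along k, so the padded weight needs
# new_k // groupsize = 88 groups while the original had k // groupsize = 86;
# the deleted code padded the missing groups with value=1.
delta_groups = new_k // groupsize - k // groupsize
print(new_k, delta_k, delta_groups)   # 11264 256 2
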