Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GPTQ Activation Ordering #94

Merged
merged 64 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from 46 commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
012138a
actorder
horheynm Jul 2, 2024
f88c84e
g_idx fix
horheynm Jul 10, 2024
3211fe1
fix
horheynm Jul 10, 2024
bbbf564
lint
horheynm Jul 10, 2024
8d29f0d
propagagte g_idx with perm
horheynm Jul 11, 2024
89224e9
scratch
horheynm Jul 12, 2024
cb8446d
GPTQ - move calibration of quantiztion params to after hessian calibr…
Jul 18, 2024
d7029a0
no recompute
horheynm Jul 22, 2024
eeff533
clean up
horheynm Jul 22, 2024
842b150
remvoe unwanted code
horheynm Jul 22, 2024
240c39d
draft
horheynm Jul 27, 2024
820d08a
draft
horheynm Jul 31, 2024
564845e
draft
horheynm Aug 1, 2024
6f54737
mimic gptq
horheynm Aug 6, 2024
2cc99bb
permutation seems to be working
kylesayrs Aug 9, 2024
6fe537d
WIP: fails on non-square weights
kylesayrs Aug 9, 2024
6611073
pass perm into quant params calculation
kylesayrs Aug 9, 2024
9077969
works on vllm and loading with identity permutation
kylesayrs Aug 12, 2024
6a1565e
WIP: working pytorch with actorder
kylesayrs Aug 12, 2024
1940df4
able to inference with script and reload, needed to set
kylesayrs Aug 12, 2024
11beac1
remove testing comments
kylesayrs Aug 13, 2024
9456698
remove scripts
kylesayrs Aug 13, 2024
0c773e6
remove dregs
kylesayrs Aug 13, 2024
b6bebc2
merge actorder and group cases
kylesayrs Aug 13, 2024
3bde194
code structuring and cleanup
kylesayrs Aug 13, 2024
758c495
use `refresh_layer_weight_quant_params`
kylesayrs Aug 13, 2024
85fb1ff
update_layer_weight_quant_params reuse
kylesayrs Aug 14, 2024
5b52e9d
deep copy H to allow for future reuse
kylesayrs Aug 14, 2024
9e2cef9
hoist group_size
kylesayrs Aug 16, 2024
e725cc7
remove footer note
kylesayrs Aug 16, 2024
2392b83
apply style
kylesayrs Aug 16, 2024
a5a30e1
fix rebase dreggs
kylesayrs Aug 16, 2024
ca6fc6e
remove extra line
kylesayrs Aug 16, 2024
6f99634
move lines for better grouping
kylesayrs Aug 16, 2024
b726bd6
move for better diff
kylesayrs Aug 16, 2024
2002761
remove extra lines
kylesayrs Aug 16, 2024
0ef0c5b
use getattr to avoid pr dep
kylesayrs Aug 17, 2024
476aed0
Revert "use getattr to avoid pr dep"
kylesayrs Aug 17, 2024
ffb809c
add actorder to docstring
kylesayrs Aug 21, 2024
edc02d4
Merge remote-tracking branch 'origin' into kylesayrs/activation-ordering
kylesayrs Aug 22, 2024
bc49946
do not clone hessian
kylesayrs Aug 22, 2024
99f2286
apply style
kylesayrs Aug 22, 2024
48b36c2
avoid unset g_idx parameter by observing directly
kylesayrs Aug 22, 2024
9550f14
use update_layer_weight_quant_params
kylesayrs Aug 22, 2024
d22ff2e
Merge remote-tracking branch 'origin/main' into kylesayrs/activation-…
kylesayrs Aug 23, 2024
72d919f
Merge branch 'main' into kylesayrs/activation-ordering
kylesayrs Aug 25, 2024
e4d37a6
indent for when quantization_scheme is missing
kylesayrs Aug 25, 2024
cdc8bcd
add actorder e2e test
kylesayrs Aug 25, 2024
1fe188b
do not freeze if initialized from gptq
kylesayrs Aug 27, 2024
b06a103
add get_attr_chain helper function
kylesayrs Aug 27, 2024
f293efd
cleanup and clarify logic
kylesayrs Aug 27, 2024
a99e0da
apply style
kylesayrs Aug 27, 2024
bf915d4
rename to getattr_chain, handle no default case
kylesayrs Aug 27, 2024
66ef96b
out of place type conversion
kylesayrs Aug 27, 2024
98aaf88
Merge remote-tracking branch 'origin/gptq-cleanup' into kylesayrs/act…
kylesayrs Aug 27, 2024
91c877a
account for extra case
kylesayrs Aug 27, 2024
b711e14
remove freeze_quantization argument
kylesayrs Aug 28, 2024
974dbc7
remove fake_quantization case, update debug message
kylesayrs Aug 28, 2024
094e429
remove todo
kylesayrs Aug 28, 2024
582c179
Merge remote-tracking branch 'origin/gptq-cleanup' into kylesayrs/act…
kylesayrs Aug 28, 2024
febb741
correct name
kylesayrs Aug 28, 2024
83a1d93
Merge remote-tracking branch 'origin/main' into kylesayrs/activation-…
kylesayrs Aug 28, 2024
a1646e5
Merge remote-tracking branch 'origin/main' into kylesayrs/activation-…
kylesayrs Aug 28, 2024
eef6bab
change to false in docstring
kylesayrs Aug 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/llmcompressor/modifiers/quantization/gptq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class GPTQModifier(Modifier):
| symmetric: true
| strategy: "tensor"
| group_size: 128
| actorder: True


:param sequential_update: Whether or not to update weights sequentially by layer,
Expand Down
77 changes: 53 additions & 24 deletions src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,40 @@ def compress(
W = W.t()
W = W.float()

tick = time.time()

# if activation ordering is enabled, permute the weight columns
# in order of greatest hessian values. Columns are unpermuted after
# quantization is finished
actorder = False
if hasattr(self.layer, "quantization_scheme"):
quant_scheme = self.layer.quantization_scheme
quant_weights = quant_scheme.weights
if quant_weights is not None:
actorder = quant_weights.actorder
if actorder:
# use hessian to create a permutation of weights
perm = torch.argsort(torch.diag(self.H), descending=True)

# permute weight and hessian
W = W[:, perm]
self.H = self.H[perm][:, perm]

# fetch latest correct scale and ZP relevant for any changes
from compressed_tensors.quantization import update_layer_weight_quant_params

# TODO: experiment with updating before each block
update_layer_weight_quant_params(self.layer, weight=W, reset_obs=True)
scale = self.layer.weight_scale.data
zero_point = self.layer.weight_zero_point.data
Satrat marked this conversation as resolved.
Show resolved Hide resolved

group_size = (
quant_scheme.weights.group_size
if quant_scheme.weights.group_size is not None
else W.shape[1]
)

# mask sparsity if applicable
sparsity = tensor_sparsity(W)
preserve_zeros = sparsity >= SPARSITY_THRESHOLD
W_nz_mask = (
Expand All @@ -106,25 +140,14 @@ def compress(
else None
)

tick = time.time()

if hasattr(self.layer, "quantization_scheme"):
quant_scheme = self.layer.quantization_scheme
if quant_scheme.weights is not None:
# fetch latest correct scale and ZP relevant for any changes
# such as activation reordering
from compressed_tensors.quantization import (
update_layer_weight_quant_params,
)

update_layer_weight_quant_params(self.layer)

# invalidate dead hessian values
dead = torch.diag(self.H) == 0
self.H[dead, dead] = 1
W[:, dead] = 0

Losses = torch.zeros(self.rows, device=self.dev)

# compute inverse hessian in place to save memory
damp = percdamp * torch.mean(torch.diag(self.H))
diag = torch.arange(self.columns, device=self.dev)
self.H[diag, diag] += damp
Expand Down Expand Up @@ -165,8 +188,6 @@ def compress(
elif hasattr(self.layer, "quantization_scheme"):
quant_scheme = self.layer.quantization_scheme
if quant_scheme.weights is not None:
scale = self.layer.weight_scale
zero_point = self.layer.weight_zero_point
from compressed_tensors.quantization import QuantizationStrategy
from compressed_tensors.quantization.lifecycle.forward import (
fake_quantize,
Expand All @@ -192,9 +213,7 @@ def compress(
else: # strategy == QuantizationStrategy.GROUP
# get the group index for the current column
column_idx = i1 + i
input_dim_group = (
column_idx // quant_scheme.weights.group_size
)
input_dim_group = column_idx // group_size

# Since we're only applying quantization to a slice, this
# ends up being a channelwise application
Expand Down Expand Up @@ -249,12 +268,26 @@ def compress(
f"Compressed layer size: {get_layer_size_bytes(self.layer)} MB",
)

if actorder:
kylesayrs marked this conversation as resolved.
Show resolved Hide resolved
# restore original permutation
invperm = torch.argsort(perm)
W = W[:, invperm]

# g_idx describes the group index of the permuted weight
g_idx = torch.tensor(
[i // group_size for i in range(self.columns)],
dtype=torch.int,
).to(device=invperm.device)

# invert to get the group index of the unpermuted weight
self.layer.weight_g_idx.data = g_idx[invperm]

if isinstance(self.layer, transformers.Conv1D):
W = W.t()
W = W.reshape(final_shape).to(final_dtype)

# This is a bit hacky, but FSDP updates only work if we change the weight in
# place, clone() or direct assignment won't work
# This is a bit hacky, but FSDP updates only work if we change
# the weight in place, clone() or direct assignment won't work
self.layer.weight -= self.layer.weight
self.layer.weight += W

Expand All @@ -263,10 +296,6 @@ def compress(
update_prefix_dict(self.layer, "weight", self.layer.weight.to(device))
self.layer._hf_hook.post_forward(self.layer, None)

del W
del Losses
del diag

def free(self):
"""
Free the Hessian memory after the layer is complete
Expand Down
Loading