
Commit

Add gptq chunk support (#269)
Co-authored-by: gushiqiao <gushiqiao@sensetime.com>
gushiqiao authored Dec 17, 2024
1 parent 408ba31 commit e8a7345
Showing 2 changed files with 10 additions and 2 deletions.
@@ -42,6 +42,7 @@ quant:
     static_groups: True
     percdamp: 0.01
     blocksize: 128
+    chunk_num: 4
     true_sequential: True
     online_rotate: True
     fp32_had: True
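
For context, chunk_num controls how many slices the flattened calibration activations are split into when GPTQ accumulates its Hessian (see the gptq.py change below). A rough, illustrative estimate of the peak float32 buffer with and without chunking, assuming a hidden size of 4096 and a calibration batch of 128 sequences of length 2048 (hypothetical numbers, not taken from the commit):

    # Back-of-the-envelope peak-memory estimate for the float() upcast in add_batch.
    hidden_dim = 4096        # width of the layer being quantized (assumed)
    tokens = 128 * 2048      # calibration activations flattened along dim 1 (assumed)
    bytes_fp32 = 4

    full_upcast = hidden_dim * tokens * bytes_fp32   # original: inp.float() all at once
    chunked_upcast = full_upcast // 4                 # chunk_num: 4 upcasts one slice at a time
    print(full_upcast / 2**30, chunked_upcast / 2**30)  # 4.0 GiB vs 1.0 GiB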
11 changes: 9 additions & 2 deletions llmc/compression/quantization/gptq.py
@@ -42,6 +42,7 @@ def add_quant_config(self):
         self.blocksize = special_config['blocksize']
 
         self.owq = special_config.get('owq', False)
+        self.chunk_num = special_config.get('chunk_num', 1)
 
         if self.owq:
             self.n_outs = special_config['n_outs']
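
Worth noting: the default of 1 keeps the original behaviour, since torch.chunk with a single chunk simply returns the whole tensor, so the loop added further down degenerates to the original single matmul. A minimal check of that assumption:

    import torch

    inp = torch.randn(8, 16)             # assumed toy activation, [hidden, tokens]
    chunks = torch.chunk(inp, 1, dim=1)  # chunk_num = 1 (the default)
    assert len(chunks) == 1 and chunks[0].shape == inp.shape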
@@ -275,12 +276,18 @@ def add_batch(self, layer, name, inp, out):
             inp = inp.permute([1, 0, 2])
             inp = inp.flatten(1)
 
+        assert inp.shape[1] % self.chunk_num == 0, \
+            f'Error: inp.shape[1] ({inp.shape[1]}) cannot be evenly divided by chunk_num.'
+        chunks = torch.chunk(inp, self.chunk_num, dim=1)
+
         self.layers_cache[name]['H'] *= self.layers_cache[name]['nsamples'] / (
             self.layers_cache[name]['nsamples'] + tmp
         )
         self.layers_cache[name]['nsamples'] += tmp
-        inp = math.sqrt(2 / self.layers_cache[name]['nsamples']) * inp.float()
-        self.layers_cache[name]['H'] += inp.matmul(inp.t())
+
+        for chunk in chunks:
+            chunk = math.sqrt(2 / self.layers_cache[name]['nsamples']) * chunk.float()
+            self.layers_cache[name]['H'] += chunk.matmul(chunk.t())
 
         dist.all_reduce(self.layers_cache[name]['H'], op=dist.ReduceOp.SUM)
         dist.all_reduce(torch.tensor(self.layers_cache[name]['nsamples']).cuda(),
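
To see the change in isolation, here is a small self-contained sketch (assumed toy shapes and names, not code from the repository) showing that accumulating the Hessian chunk by chunk along the sample dimension matches the original single matmul, while only upcasting one chunk to float32 at a time:

    import math
    import torch

    torch.manual_seed(0)
    hidden_dim, tokens, chunk_num = 64, 256, 4                   # assumed toy sizes
    inp = torch.randn(hidden_dim, tokens, dtype=torch.float16)   # flattened activations, [d, n]
    nsamples = tokens

    # Original path: upcast the whole activation at once.
    x = math.sqrt(2 / nsamples) * inp.float()
    H_full = x.matmul(x.t())

    # Chunked path (what the commit adds): split along dim 1 and accumulate.
    assert inp.shape[1] % chunk_num == 0
    H_chunked = torch.zeros(hidden_dim, hidden_dim)
    for chunk in torch.chunk(inp, chunk_num, dim=1):
        chunk = math.sqrt(2 / nsamples) * chunk.float()
        H_chunked += chunk.matmul(chunk.t())

    # The two results match up to float accumulation error.
    print(torch.allclose(H_full, H_chunked, atol=1e-4))          # True

The real add_batch also rescales H and nsamples across calibration batches and all-reduces across ranks; those parts are unchanged by the commit, so they are omitted here.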
